# Copyright (c) 2026 Guy's and St Thomas' NHS Foundation Trust & King's College London
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import shutil
import time
from typing import Union
from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.dxo import DXO, from_bytes
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.apis.workspace import Workspace
from nvflare.app_common.abstract.formatter import Formatter
from nvflare.app_common.abstract.model_locator import ModelLocator
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.security.logging import secure_format_exception
from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector
from flip import FLIP
from flip.constants import FlipTasks, PTConstants
from flip.utils import Utils
class ModelEval(Controller):
    """Server-side workflow in which participating clients evaluate server-located models.

    The controller obtains a collection of models from a ``ModelLocator``
    component, broadcasts that collection to every participating client in a
    single evaluation task, records a per-client results path under the run's
    evaluation directory, and finally broadcasts a post-evaluation cleanup task.
    """

    def __init__(
        self,
        task_check_period=0.5,
        submit_model_timeout=600,
        validation_timeout: int = 6000,
        model_locator_id="",
        formatter_id="",
        submit_model_task_name=AppConstants.TASK_SUBMIT_MODEL,
        evaluation_task_name=PTConstants.EvalTaskName,
        cleanup_models=False,
        participating_clients=None,
        wait_for_clients_timeout=300,
        cleanup_timeout=600,
        fatal_error_delay=5,
        model_id="",
    ):
        """Model Evaluation workflow.

        Args:
            task_check_period (float, optional): How often to check for new tasks or tasks being finished.
                Defaults to 0.5.
            submit_model_timeout (int, optional): Timeout of submit_model_task. Defaults to 600 secs.
            validation_timeout (int, optional): Timeout of the evaluation task. Defaults to 6000 secs.
            model_locator_id (str, optional): ID for model_locator component. Defaults to "".
            formatter_id (str, optional): ID for formatter component. Defaults to "".
            submit_model_task_name (str, optional): Name of submit_model task.
                Defaults to AppConstants.TASK_SUBMIT_MODEL.
            evaluation_task_name (str, optional): Name of the evaluation task broadcast to clients.
                Defaults to PTConstants.EvalTaskName.
            cleanup_models (bool, optional): Whether models should be deleted after run. Defaults to False.
            participating_clients (list, optional): List of participating client names. If not provided, defaults
                to all clients connected at start of controller.
            wait_for_clients_timeout (int, optional): Timeout for clients to appear. Defaults to 300 secs.
            cleanup_timeout (int, optional): Timeout of the post-evaluation cleanup task. Defaults to 600 secs.
            fatal_error_delay (int, optional): Time in seconds to wait before handling an error return code
                received for a submitted model. Defaults to 5.
            model_id (str, required): ID of the model that the training is being performed under.
                Must be a valid UUID.

        Raises:
            TypeError: If an argument has the wrong type.
            ValueError: If ``model_id`` is not a valid UUID or a timeout is negative.
        """
        super(ModelEval, self).__init__(task_check_period=task_check_period)

        # FLIP client used to report handled client exceptions; the result
        # callbacks call self.flip.send_handled_exception(...).
        # NOTE(review): assumes FLIP() takes no arguments - confirm against the flip package.
        self.flip = FLIP()

        if not isinstance(task_check_period, float):
            raise TypeError(f"task_check_period must be float but got {type(task_check_period)}")
        if not isinstance(submit_model_timeout, int):
            raise TypeError(f"submit_model_timeout must be int but got {type(submit_model_timeout)}")
        if not isinstance(validation_timeout, int):
            raise TypeError(f"validation_timeout must be int but got {type(validation_timeout)}")
        if not isinstance(model_locator_id, str):
            raise TypeError(f"model_locator_id must be a string but got {type(model_locator_id)}")
        if not isinstance(formatter_id, str):
            raise TypeError(f"formatter_id must be a string but got {type(formatter_id)}")
        if not isinstance(submit_model_task_name, str):
            raise TypeError(f"submit_model_task_name must be a string but got {type(submit_model_task_name)}")
        if not isinstance(evaluation_task_name, str):
            raise TypeError(f"evaluation_task_name must be a string but got {type(evaluation_task_name)}")
        if not isinstance(cleanup_models, bool):
            raise TypeError(f"cleanup_models must be bool but got {type(cleanup_models)}")
        if not Utils.is_valid_uuid(model_id):
            raise ValueError(f"The model ID: {model_id} is not a valid UUID")

        # An empty/None client list is allowed: it means "use whoever is connected".
        if participating_clients:
            if not isinstance(participating_clients, list):
                raise TypeError(f"participating_clients must be a list but got {type(participating_clients)}")
            if not all(isinstance(x, str) for x in participating_clients):
                raise TypeError("participating_clients must be strings")

        if submit_model_timeout < 0:
            raise ValueError("submit_model_timeout must be greater than or equal to 0.")
        if validation_timeout < 0:
            raise ValueError("validation_timeout must be greater than or equal to 0.")
        if wait_for_clients_timeout < 0:
            raise ValueError("wait_for_clients_timeout must be greater than or equal to 0.")
        if cleanup_timeout < 0:
            raise ValueError("cleanup_timeout must be greater than or equal to 0.")

        # Component ids and task names.
        self._model_locator_id = model_locator_id
        self._formatter_id = formatter_id
        self._submit_model_task_name = submit_model_task_name
        # This stays VALIDATION for NVFLARE compatibility.
        self._evaluation_task_name = evaluation_task_name

        # Timeouts and delays (seconds).
        self._submit_model_timeout = submit_model_timeout
        self._validation_timeout = validation_timeout
        self._wait_for_clients_timeout = wait_for_clients_timeout
        self._cleanup_timeout = cleanup_timeout
        self._fatal_error_delay = fatal_error_delay

        # Directory name of the evaluation folder; _eval_results_dir is
        # rewritten to a path inside the run directory by start_controller.
        self._eval_dir = PTConstants.EvalDir
        self._eval_results_dir = PTConstants.EvalDir

        # Run-time state.
        self._cleanup_models = cleanup_models
        self._participating_clients = participating_clients
        self._eval_results = {}
        self._models = {}
        self._all_models_dxo = None
        self._client_models = {}
        self._formatter = None
        self._model_locator = None
        self._model_id = model_id

    def start_controller(self, fl_ctx: FLContext):
        """Resolve clients and components and prepare the results directory.

        Panics (via ``system_panic``) and returns early when the engine is
        missing or a configured component has the wrong type.

        Args:
            fl_ctx (FLContext): FL context of the current run.
        """
        engine = fl_ctx.get_engine()
        if not engine:
            self.system_panic("Engine not found. Workflow exiting.", fl_ctx)
            return

        # If the list of participating clients is not provided, include all clients currently available.
        if not self._participating_clients:
            self._participating_clients = [c.name for c in engine.get_clients()]

        # Results live in <run_dir>/<eval dir>. Built from the unchanged
        # directory name so calling this method again cannot nest the path.
        workspace: Workspace = engine.get_workspace()
        run_dir = workspace.get_run_dir(fl_ctx.get_job_id())
        self._eval_results_dir = os.path.join(run_dir, self._eval_dir)

        # Initialise the model locator before anything on disk is touched.
        if self._model_locator_id:
            self._model_locator = engine.get_component(self._model_locator_id)
            if not isinstance(self._model_locator, ModelLocator):
                self.system_panic(
                    reason="bad model locator {}: expect ModelLocator but got {}".format(
                        self._model_locator_id, type(self._model_locator)
                    ),
                    fl_ctx=fl_ctx,
                )
                return

        # Fire the init event.
        fl_ctx.set_prop(AppConstants.CROSS_VAL_RESULTS_PATH, self._eval_results_dir)
        self.fire_event(AppEventType.CROSS_VAL_INIT, fl_ctx)

        # Start from an empty results directory.
        if os.path.exists(self._eval_results_dir):
            shutil.rmtree(self._eval_results_dir)
        os.makedirs(self._eval_results_dir)

        # The formatter is optional; without it stats are simply not reported.
        if self._formatter_id:
            self._formatter = engine.get_component(self._formatter_id)
            if not isinstance(self._formatter, Formatter):
                self.system_panic(
                    reason=f"formatter {self._formatter_id} is not an instance of Formatter.", fl_ctx=fl_ctx
                )
                return
        if not self._formatter:
            self.log_info(fl_ctx, "Formatter not found. Stats will not be printed.")

        for c_name in self._participating_clients:
            self._client_models[c_name] = None
            self._eval_results[c_name] = {}

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        """Run the evaluation: wait for clients, broadcast models, then clean up.

        Any exception raised inside the flow is logged and reported through
        ``system_panic``.

        Args:
            abort_signal (Signal): Checked between steps; when triggered the flow returns.
            fl_ctx (FLContext): FL context of the current run.
        """
        try:
            # wait until there are some clients
            engine = fl_ctx.get_engine()
            start_time = time.time()
            while not self._participating_clients:
                # Store names, consistent with start_controller.
                self._participating_clients = [c.name for c in engine.get_clients()]
                if self._participating_clients:
                    # Clients that appear only now still need their bookkeeping entries.
                    for c_name in self._participating_clients:
                        self._client_models.setdefault(c_name, None)
                        self._eval_results.setdefault(c_name, {})
                    break
                if time.time() - start_time > self._wait_for_clients_timeout:
                    self.log_info(fl_ctx, "No clients available - quit model validation.")
                    return
                self.log_info(fl_ctx, "No clients available - waiting ...")
                time.sleep(2.0)
                if abort_signal.triggered:
                    self.log_info(fl_ctx, "Abort signal triggered. Finishing model validation.")
                    return

            self.log_info(fl_ctx, f"Beginning model validation with clients: {self._participating_clients}.")

            # This loads models in the evaluation directory, and saves them in self._models.
            if self._model_locator:
                success = self._locate_server_models(fl_ctx)
                if not success:
                    return

            # Do the validation task
            task = Task(
                name=self._evaluation_task_name,
                data=Shareable(),
                before_task_sent_cb=self._before_send_validate_task_cb,
                after_task_sent_cb=self._after_send_validate_task_cb,
                result_received_cb=self._receive_val_result_cb,
                timeout=self._validation_timeout,
            )
            self.broadcast(
                task=task,
                fl_ctx=fl_ctx,
                targets=self._participating_clients,
                min_responses=len(self._participating_clients),
                wait_time_after_min_received=0,
            )
            if abort_signal.triggered:
                self.log_info(fl_ctx, "Abort signal triggered. Finishing model evaluation.")
                return

            # broadcast() does not block, so poll until no task is outstanding.
            while self.get_num_standing_tasks():
                if abort_signal.triggered:
                    self.log_info(fl_ctx, "Abort signal triggered. Finishing cross site validation.")
                    return
                self.log_debug(fl_ctx, "Checking standing tasks to see if cross site validation finished.")
                time.sleep(self._task_check_period)

            self.log_info(fl_ctx, "Beginning post validation cleanup task...")
            cleanup_task = Task(name=FlipTasks.POST_TASK.value, data=Shareable(), timeout=self._cleanup_timeout)
            self.broadcast_and_wait(
                task=cleanup_task,
                min_responses=len(self._participating_clients),
                wait_time_after_min_received=0,
                fl_ctx=fl_ctx,
                abort_signal=abort_signal,
            )
            self.log_info(fl_ctx, "Post validation cleanup completed")
        except BaseException as e:
            # Top-level boundary of the workflow: report every failure as a panic.
            error_msg = f"Exception in cross site validator control_flow: {secure_format_exception(e)}"
            self.log_exception(fl_ctx, error_msg)
            self.system_panic(error_msg, fl_ctx)

    def stop_controller(self, fl_ctx: FLContext):
        """Cancel outstanding tasks and, if configured, delete kept model files.

        Args:
            fl_ctx (FLContext): FL context of the current run.
        """
        self.cancel_all_tasks(fl_ctx=fl_ctx)
        if self._cleanup_models:
            self.log_info(fl_ctx, "Removing local models kept for validation.")
            for model_name, model_path in self._models.items():
                if model_path and os.path.isfile(model_path):
                    os.remove(model_path)
                    self.log_debug(fl_ctx, f"Removing server model {model_name} at {model_path}.")
            for model_name, model_path in self._client_models.items():
                if model_path and os.path.isfile(model_path):
                    os.remove(model_path)
                    self.log_debug(fl_ctx, f"Removing client {model_name}'s model at {model_path}.")

    def _receive_local_model_cb(self, client_task: ClientTask, fl_ctx: FLContext):
        """Result callback for a submit-model task: process the result, then drop it."""
        client_name = client_task.client.name
        result: Shareable = client_task.result
        self._accept_local_model(client_name=client_name, result=result, fl_ctx=fl_ctx)
        # Cleanup task result
        client_task.result = None

    def _after_send_validate_task_cb(self, client_task: ClientTask, fl_ctx: FLContext):
        """Clear the task payload once it has been sent."""
        # Once task is sent clear data to restore memory
        client_task.task.data = None

    def _receive_val_result_cb(self, client_task: ClientTask, fl_ctx: FLContext):
        """Result callback for the evaluation task: process the result, then drop it."""
        # Find name of the client sending this
        result = client_task.result
        client_name = client_task.client.name
        self.log_info(fl_ctx, f"Receiving validation result from client: {client_name}: {result}")
        self._accept_val_result(client_name=client_name, result=result, fl_ctx=fl_ctx)
        client_task.result = None

    def _locate_server_models(self, fl_ctx: FLContext) -> bool:
        """Load the model collection from the model locator.

        Every name listed in the locator's ``model_names`` must map to a DXO
        inside the returned collection.

        Args:
            fl_ctx (FLContext): FL context of the current run.

        Returns:
            bool: True when the collection was accepted and stored; False after
            a panic caused by a missing collection or a non-DXO entry.
        """
        self.log_info(fl_ctx, "Locating server models.")
        # Evaluation allows us to load a collection of models to test:
        all_models = self._model_locator.locate_model(fl_ctx)
        if all_models is None:
            # Fail with a clear reason instead of an AttributeError on `.data`.
            self.system_panic("ModelLocator did not return any models.", fl_ctx)
            return False

        model_names = self._model_locator.model_names
        for name in model_names:
            model_dxo = all_models.data.get(name)
            if not isinstance(model_dxo, DXO):
                self.system_panic(f"ModelLocator gave a collection of models for which {name} is not a dxo.", fl_ctx)
                return False
            self._eval_results[name] = {}

        if model_names:
            self.log_info(fl_ctx, f"Server models loaded: {model_names}.")
        else:
            self.log_info(fl_ctx, "no server models to validate!")
        self._all_models_dxo = all_models
        return True

    def _before_send_validate_task_cb(self, client_task: ClientTask, fl_ctx: FLContext):
        """Before the evaluation task, we load all the models DXO into a Shareable that is sent to the clients."""
        if not self._all_models_dxo:
            self.system_panic(
                "Model contents empty.",
                fl_ctx=fl_ctx,
            )
            return
        model_shareable = self._all_models_dxo.to_shareable()
        client_task.task.data = model_shareable
        fl_ctx.set_prop(AppConstants.DATA_CLIENT, client_task.client, private=True, sticky=False)
        fl_ctx.set_prop(AppConstants.MODEL_TO_VALIDATE, model_shareable, private=True, sticky=False)
        fl_ctx.set_prop(AppConstants.PARTICIPATING_CLIENTS, self._participating_clients, private=True, sticky=False)
        self.fire_event(AppEventType.SEND_MODEL_FOR_VALIDATION, fl_ctx)

    def _accept_local_model(self, client_name: str, result: Shareable, fl_ctx: FLContext):
        """Handle a model submitted by a client.

        Fires ``RECEIVE_BEST_MODEL``, then logs (and, for execution exceptions,
        reports through FLIP) any error return code. Submitted models are not
        persisted by this workflow.

        Args:
            client_name (str): Name of the submitting client.
            result (Shareable): The client's submission.
            fl_ctx (FLContext): FL context of the current run.
        """
        fl_ctx.set_prop(AppConstants.RECEIVED_MODEL, result, private=False, sticky=False)
        fl_ctx.set_prop(AppConstants.CROSS_VAL_DIR, self._eval_dir, private=False, sticky=False)
        self.fire_event(AppEventType.RECEIVE_BEST_MODEL, fl_ctx)

        # get return code
        rc = result.get_return_code()
        if rc and rc != ReturnCode.OK:
            time.sleep(self._fatal_error_delay)
            # Raise errors if bad peer context or execution exception.
            if rc in [ReturnCode.MISSING_PEER_CONTEXT, ReturnCode.BAD_PEER_CONTEXT]:
                self.log_error(fl_ctx, "Peer context is bad or missing. No model submitted for this client.")
            elif rc in [ReturnCode.EXECUTION_EXCEPTION, ReturnCode.TASK_UNKNOWN]:
                formatted_exception = result.get_header("exception")
                if formatted_exception is not None:
                    self.log_error(fl_ctx, formatted_exception)
                    self.flip.send_handled_exception(
                        formatted_exception=formatted_exception, client_name=client_name, model_id=self._model_id
                    )
                self.log_error(
                    fl_ctx, "Execution Exception on client during model submission. No model submitted for this client."
                )
            # Ignore contribution if result invalid.
            elif rc in [
                ReturnCode.EXECUTION_RESULT_ERROR,
                ReturnCode.TASK_DATA_FILTER_ERROR,
                ReturnCode.TASK_RESULT_FILTER_ERROR,
            ]:
                self.log_error(fl_ctx, "Execution result is not a shareable. Model submission will be ignored.")
            else:
                self.log_error(fl_ctx, "Return code set. Model submission from client will be ignored.")
        else:
            self.log_info(fl_ctx, f"Received local model from client {client_name}.")
            # This class defines no _send_validation_task (the evaluation task
            # is broadcast once in control_flow), so only call it when a
            # subclass supplies one instead of failing with AttributeError.
            send_validation_task = getattr(self, "_send_validation_task", None)
            if callable(send_validation_task):
                send_validation_task(client_name, fl_ctx)
            else:
                self.log_warning(
                    fl_ctx, f"No per-client validation task is defined; model from {client_name} is not forwarded."
                )

    def _accept_val_result(self, client_name: str, result: Shareable, fl_ctx: FLContext):
        """Handle an evaluation result sent by a client.

        Fires ``VALIDATION_RESULT_RECEIVED``. On an error return code the
        problem is logged (and, for execution exceptions, reported through
        FLIP) and an empty result is recorded; otherwise the client's results
        path inside the results directory is recorded.

        Args:
            client_name (str): Name of the reporting client.
            result (Shareable): The client's evaluation result.
            fl_ctx (FLContext): FL context of the current run.
        """
        # Fire event. This needs to be a new local context per each client
        fl_ctx.set_prop(AppConstants.DATA_CLIENT, client_name, private=True, sticky=False)
        fl_ctx.set_prop(AppConstants.VALIDATION_RESULT, result, private=True, sticky=False)
        self.fire_event(AppEventType.VALIDATION_RESULT_RECEIVED, fl_ctx)

        rc = result.get_return_code()
        if rc and rc != ReturnCode.OK:
            # Raise errors if bad peer context or execution exception.
            if rc in [ReturnCode.MISSING_PEER_CONTEXT, ReturnCode.BAD_PEER_CONTEXT]:
                self.log_error(fl_ctx, "Peer context is bad or missing.")
            elif rc in [ReturnCode.EXECUTION_EXCEPTION, ReturnCode.TASK_UNKNOWN]:
                formatted_exception = result.get_header("exception")
                if formatted_exception is not None:
                    self.log_error(fl_ctx, formatted_exception)
                    self.flip.send_handled_exception(
                        formatted_exception=formatted_exception, client_name=client_name, model_id=self._model_id
                    )
                self.log_error(fl_ctx, "Execution Exception in model validation.")
            elif rc in [
                ReturnCode.EXECUTION_RESULT_ERROR,
                ReturnCode.TASK_DATA_FILTER_ERROR,
                ReturnCode.TASK_RESULT_FILTER_ERROR,
            ]:
                self.log_error(fl_ctx, "Execution result is not a shareable. Validation results will be ignored.")
            else:
                self.log_error(
                    fl_ctx,
                    f"Client {client_name} sent results with return code set. Logging empty results.",
                )
            # Whatever the error, the client ends up with an empty result entry.
            self._eval_results[client_name] = {}
        else:
            # Only the path is recorded here; nothing is written to disk.
            self._eval_results[client_name] = os.path.join(self._eval_results_dir, client_name)
            self.log_info(fl_ctx, f"Client {client_name} sent results for validating model.")

    def _load_validation_content(self, name: str, load_dir: str, fl_ctx: FLContext) -> Union[DXO, None]:
        """Read a serialized DXO from ``load_dir/name``.

        Args:
            name (str): File name of the stored content.
            load_dir (str): Directory that contains the file.
            fl_ctx (FLContext): FL context used for logging.

        Returns:
            DXO: The object decoded from the file's bytes.

        Raises:
            ValueError: If the file cannot be read or decoded; the original
                exception is attached as the cause.
        """
        shareable_filename = os.path.join(load_dir, name)
        try:
            with open(shareable_filename, "rb") as f:
                data = f.read()
            dxo: DXO = from_bytes(data)
            self.log_debug(fl_ctx, f"Loading cross validation shareable content with name: {name}.")
        except Exception as e:
            raise ValueError(
                f"Exception in loading shareable content for {name}: {secure_format_exception(e)}"
            ) from e
        return dxo

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Add formatted evaluation results to the stats collector on a get-stats event.

        Args:
            event_type (str): Type of the fired event.
            fl_ctx (FLContext): FL context carrying the stats collector.

        Raises:
            TypeError: If the collector in the context is not a ``GroupInfoCollector``.
        """
        super().handle_event(event_type=event_type, fl_ctx=fl_ctx)
        if event_type != InfoCollector.EVENT_TYPE_GET_STATS:
            return
        if not self._formatter:
            self.log_warning(fl_ctx, "No formatter provided. Validation results can't be printed.")
            return

        collector = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR, None)
        if not collector:
            return
        if not isinstance(collector, GroupInfoCollector):
            raise TypeError("collector must be GroupInfoCollector but got {}".format(type(collector)))

        fl_ctx.set_prop(AppConstants.VALIDATION_RESULT, self._eval_results, private=True, sticky=False)
        val_info = self._formatter.format(fl_ctx)
        collector.add_info(
            group_name=self._name,
            info={"val_results": val_info},
        )

    def process_result_of_unknown_task(
        self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext
    ):
        """Route a result whose task is no longer tracked to the matching handler.

        Args:
            client (Client): Client that sent the result.
            task_name (str): Name of the task the result belongs to.
            client_task_id (str): ID of the client task (unused here).
            result (Shareable): The result payload.
            fl_ctx (FLContext): FL context of the current run.
        """
        if task_name == self._submit_model_task_name:
            self._accept_local_model(client_name=client.name, result=result, fl_ctx=fl_ctx)
        elif task_name == self._evaluation_task_name:
            self._accept_val_result(client_name=client.name, result=result, fl_ctx=fl_ctx)
        else:
            self.log_error(fl_ctx, "Ignoring result from unknown task.")