flip.nvflare.controllers
========================

.. py:module:: flip.nvflare.controllers

.. autoapi-nested-parse::

   FLIP Controllers module containing NVFLARE workflow controllers.

   Controllers orchestrate federated learning workflows.

   Exports:

   - InitTraining: Initialization controller for training setup
   - ScatterAndGather: Main training loop controller with FedAvg aggregation
   - ScatterAndGatherLDM: Dual-phase training controller for LDM (autoencoder + diffusion model)
   - CrossSiteModelEval: Cross-site model evaluation controller
   - InitEvaluation: Initialization controller for evaluation setup
   - ModelEval: Main evaluation loop controller

Submodules
----------

.. toctree::
   :maxdepth: 1

   /reference/api/flip/nvflare/controllers/cross_site_model_eval/index
   /reference/api/flip/nvflare/controllers/fed_evaluation/index
   /reference/api/flip/nvflare/controllers/init_evaluation/index
   /reference/api/flip/nvflare/controllers/init_training/index
   /reference/api/flip/nvflare/controllers/scatter_and_gather/index
   /reference/api/flip/nvflare/controllers/scatter_and_gather_ldm/index

Classes
-------

.. autoapisummary::

   flip.nvflare.controllers.CrossSiteModelEval
   flip.nvflare.controllers.ModelEval
   flip.nvflare.controllers.InitEvaluation
   flip.nvflare.controllers.InitTraining
   flip.nvflare.controllers.ScatterAndGather
   flip.nvflare.controllers.ScatterAndGatherLDM

Package Contents
----------------

.. py:class:: CrossSiteModelEval(task_check_period=0.5, cross_val_dir=AppConstants.CROSS_VAL_DIR, submit_model_timeout=600, validation_timeout: int = 6000, model_locator_id='', formatter_id='', submit_model_task_name=AppConstants.TASK_SUBMIT_MODEL, validation_task_name=AppConstants.TASK_VALIDATION, cleanup_models=False, participating_clients=None, wait_for_clients_timeout=300, cleanup_timeout=600, fatal_error_delay=5, model_id='')

   Bases: :py:obj:`nvflare.apis.impl.controller.Controller`

   Cross Site Model Validation workflow.

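
   Conceptually, the workflow collects models from the participating clients (the
   ``submit_model_task_name`` task) and then asks each participant to validate each collected
   model (the ``validation_task_name`` task). The sketch below plans those validation pairs in
   plain Python. It is an illustration under stated assumptions, not the controller's code: the
   function ``plan_cross_site_validation``, its ``include_server_model`` flag and the
   ``"server"`` label are hypothetical, and it assumes that every collected model (including a
   server model found through ``model_locator_id``) is validated by every participant, as in
   NVFLARE's cross-site evaluation.

   .. code-block:: python

      from itertools import product
      from typing import Iterable, List, Optional, Tuple

      # Hypothetical label for a server-side model found via the model locator.
      SERVER_MODEL = "server"


      def plan_cross_site_validation(
          connected_clients: Iterable[str],
          participating_clients: Optional[List[str]] = None,
          include_server_model: bool = True,
      ) -> List[Tuple[str, str]]:
          """Return (model_owner, validating_client) pairs for one cross-site run.

          Illustrative sketch only; not the FLIP/NVFLARE implementation.
          """
          connected = list(connected_clients)
          if participating_clients is None:
              # Documented default: all clients connected at controller start take part.
              participants = connected
          else:
              unknown = sorted(set(participating_clients) - set(connected))
              if unknown:
                  # Simplification: fail fast instead of waiting for clients to appear.
                  raise ValueError(f"participating clients not connected: {unknown}")
              participants = list(participating_clients)

          # Step 1 (submit_model task): every participant contributes its local model.
          model_owners = list(participants)
          if include_server_model:
              model_owners.append(SERVER_MODEL)

          # Step 2 (validate task): every model is validated by every participant.
          return list(product(model_owners, participants))


      pairs = plan_cross_site_validation(["site-1", "site-2"])
      print(len(pairs))  # 6

   For two connected sites plus the server model, this yields six (model owner, validator)
   pairs. Raising on unknown participants keeps the sketch short; the documented
   ``wait_for_clients_timeout`` parameter suggests the real controller instead waits for
   clients to appear.
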

   :param task_check_period: How often to check for new tasks or finished tasks. Defaults to 0.5.
   :type task_check_period: float, optional
   :param cross_val_dir: Path to the cross-site validation directory, relative to the run directory. Defaults to "cross_site_val".
   :type cross_val_dir: str, optional
   :param submit_model_timeout: Timeout of the submit_model task, in seconds. Defaults to 600.
   :type submit_model_timeout: int, optional
   :param validation_timeout: Timeout of the validate_model task, in seconds. Defaults to 6000.
   :type validation_timeout: int, optional
   :param model_locator_id: ID of the model_locator component. Defaults to "".
   :type model_locator_id: str, optional
   :param formatter_id: ID of the formatter component. Defaults to "".
   :type formatter_id: str, optional
   :param submit_model_task_name: Name of the submit_model task. Defaults to ``AppConstants.TASK_SUBMIT_MODEL``.
   :type submit_model_task_name: str, optional
   :param validation_task_name: Name of the validate_model task. Defaults to ``AppConstants.TASK_VALIDATION`` ("validate").
   :type validation_task_name: str, optional
   :param cleanup_models: Whether models should be deleted after the run. Defaults to False.
   :type cleanup_models: bool, optional
   :param participating_clients: List of participating client names. If not provided, defaults to all clients connected at the start of the controller.
   :type participating_clients: list, optional
   :param wait_for_clients_timeout: Timeout for clients to appear, in seconds. Defaults to 300.
   :type wait_for_clients_timeout: int, optional
   :param fatal_error_delay: Time in seconds to delay before calling ``system_panic`` if a task returns an error result and ``ignore_result_error`` is set to False. Defaults to 5.
   :type fatal_error_delay: int, optional
   :param model_id: ID of the model that the training is being performed under.
   :type model_id: str, required

   .. py:attribute:: flip

   .. py:method:: start_controller(fl_ctx: nvflare.apis.fl_context.FLContext)

   .. py:method:: control_flow(abort_signal: nvflare.apis.signal.Signal, fl_ctx: nvflare.apis.fl_context.FLContext)

   .. py:method:: stop_controller(fl_ctx: nvflare.apis.fl_context.FLContext)

   .. py:method:: handle_event(event_type: str, fl_ctx: nvflare.apis.fl_context.FLContext)

   .. py:method:: process_result_of_unknown_task(client: nvflare.apis.client.Client, task_name: str, client_task_id: str, result: nvflare.apis.shareable.Shareable, fl_ctx: nvflare.apis.fl_context.FLContext)

.. py:class:: ModelEval(task_check_period=0.5, submit_model_timeout=600, validation_timeout: int = 6000, model_locator_id='', formatter_id='', submit_model_task_name=AppConstants.TASK_SUBMIT_MODEL, evaluation_task_name=PTConstants.EvalTaskName, cleanup_models=False, participating_clients=None, wait_for_clients_timeout=300, cleanup_timeout=600, fatal_error_delay=5, model_id='')

   Bases: :py:obj:`nvflare.apis.impl.controller.Controller`

   Model evaluation workflow.

   :param task_check_period: How often to check for new tasks or finished tasks. Defaults to 0.5.
   :type task_check_period: float, optional
   :param submit_model_timeout: Timeout of the submit_model task, in seconds. Defaults to 600.
   :type submit_model_timeout: int, optional
   :param validation_timeout: Timeout of the validate_model task, in seconds. Defaults to 6000.
   :type validation_timeout: int, optional
   :param model_locator_id: ID of the model_locator component. Defaults to "".
   :type model_locator_id: str, optional
   :param formatter_id: ID of the formatter component. Defaults to "".
   :type formatter_id: str, optional
   :param submit_model_task_name: Name of the submit_model task. Defaults to ``AppConstants.TASK_SUBMIT_MODEL``.
   :type submit_model_task_name: str, optional
   :param evaluation_task_name: Name of the evaluation task. Defaults to ``PTConstants.EvalTaskName``.
   :type evaluation_task_name: str, optional
   :param cleanup_models: Whether models should be deleted after the run. Defaults to False.
   :type cleanup_models: bool, optional
   :param participating_clients: List of participating client names. If not provided, defaults to all clients connected at the start of the controller.
   :type participating_clients: list, optional
   :param wait_for_clients_timeout: Timeout for clients to appear, in seconds. Defaults to 300.
   :type wait_for_clients_timeout: int, optional
   :param fatal_error_delay: Time in seconds to delay before calling ``system_panic`` if a task returns an error result and ``ignore_result_error`` is set to False. Defaults to 5.
   :type fatal_error_delay: int, optional
   :param model_id: ID of the model that the evaluation is being performed under.
   :type model_id: str, required

   .. py:attribute:: flip

   .. py:method:: start_controller(fl_ctx: nvflare.apis.fl_context.FLContext)

   .. py:method:: control_flow(abort_signal: nvflare.apis.signal.Signal, fl_ctx: nvflare.apis.fl_context.FLContext)

   .. py:method:: stop_controller(fl_ctx: nvflare.apis.fl_context.FLContext)

   .. py:method:: handle_event(event_type: str, fl_ctx: nvflare.apis.fl_context.FLContext)

   .. py:method:: process_result_of_unknown_task(client: nvflare.apis.client.Client, task_name: str, client_task_id: str, result: nvflare.apis.shareable.Shareable, fl_ctx: nvflare.apis.fl_context.FLContext)

.. py:class:: InitEvaluation(model_id: str, min_clients: int = FlipConstants.MIN_CLIENTS, flip: flip.FLIP = FLIP(), cleanup_timeout: int = 600)

   Bases: :py:obj:`nvflare.apis.impl.controller.Controller`

   Initialization controller for evaluation setup, executed before the FLIP evaluation workflow. It takes the same arguments as :py:class:`InitTraining`, including a ``cleanup_timeout`` for the client cleanup task.

   :param model_id: ID of the model that the evaluation is being performed under.
   :type model_id: str
   :param min_clients: Minimum number of clients that must return successful results. Defaults to 1 (``FlipConstants.MIN_CLIENTS``).
   :type min_clients: int, optional
   :param cleanup_timeout: Timeout for image cleanup, in seconds. Defaults to 600 (10 minutes).
   :type cleanup_timeout: int, optional
   :raises ValueError:
      - when the model ID is not a valid UUID;
      - when the minimum number of clients specified is less than 1;
      - when ``cleanup_timeout`` is less than 0.

   .. py:attribute:: flip

   .. py:method:: start_controller(fl_ctx: nvflare.apis.fl_context.FLContext)

   .. py:method:: control_flow(abort_signal: nvflare.apis.signal.Signal, fl_ctx: nvflare.apis.fl_context.FLContext)

   .. py:method:: stop_controller(fl_ctx: nvflare.apis.fl_context.FLContext) -> None

   .. py:method:: process_result_of_unknown_task(client: nvflare.apis.client.Client, task_name, client_task_id, result: nvflare.apis.shareable.Shareable, fl_ctx: nvflare.apis.fl_context.FLContext) -> None

.. py:class:: InitTraining(model_id: str, min_clients: int = FlipConstants.MIN_CLIENTS, flip: flip.FLIP = FLIP(), cleanup_timeout: int = 600)

   Bases: :py:obj:`nvflare.apis.impl.controller.Controller`

   The controller that is executed before training as part of the FLIP training workflow. It sends a request to the Central Hub stating that training has started, and executes the client cleanup task.

   :param model_id: ID of the model that the training is being performed under.
   :type model_id: str
   :param min_clients: Minimum number of clients that must return successful results for the aggregation to take place. Defaults to 1 (``FlipConstants.MIN_CLIENTS``).
   :type min_clients: int, optional
   :param cleanup_timeout: Timeout for image cleanup, in seconds. Defaults to 600 (10 minutes).
   :type cleanup_timeout: int, optional
   :raises ValueError:
      - when the model ID is not a valid UUID;
      - when the minimum number of clients specified is less than 1;
      - when ``cleanup_timeout`` is less than 0.

   .. py:attribute:: flip

   .. py:method:: start_controller(fl_ctx: nvflare.apis.fl_context.FLContext)

   .. py:method:: control_flow(abort_signal: nvflare.apis.signal.Signal, fl_ctx: nvflare.apis.fl_context.FLContext)

   .. py:method:: stop_controller(fl_ctx: nvflare.apis.fl_context.FLContext) -> None

   .. py:method:: process_result_of_unknown_task(client: nvflare.apis.client.Client, task_name, client_task_id, result: nvflare.apis.shareable.Shareable, fl_ctx: nvflare.apis.fl_context.FLContext) -> None

.. py:class:: ScatterAndGather(model_id: str = '', min_clients: int = 1, num_rounds: int = 5, start_round: int = 0, wait_time_after_min_received: int = 10, aggregator_id=AppConstants.DEFAULT_AGGREGATOR_ID, persistor_id=AppConstants.DEFAULT_PERSISTOR_ID, shareable_generator_id=AppConstants.DEFAULT_SHAREABLE_GENERATOR_ID, train_task_name=AppConstants.TASK_TRAIN, train_timeout: int = 0, ignore_result_error: bool = False, fatal_error_delay: int = 5, task_check_period: float = 0.5, persist_every_n_rounds: int = 1)

   Bases: :py:obj:`nvflare.apis.impl.controller.Controller`

   The controller for the FederatedAveraging workflow.

   The ScatterAndGather workflow defines federated training on all clients. The model persistor (``persistor_id``) is used to load the initial global model, which is sent to all clients. Each client sends its updated weights after local training, and these are aggregated by the aggregator (``aggregator_id``). The shareable generator converts the aggregated weights to a ``Shareable`` and a ``Shareable`` back to weights. The model persistor also saves the model after training.

   :param model_id: ID of the model that the training is being performed under.
   :type model_id: str, required
   :param min_clients: Minimum number of clients in training. Defaults to 1.
   :type min_clients: int, optional
   :param num_rounds: The total number of training rounds. Defaults to 5.
   :type num_rounds: int, optional
   :param start_round: Start round for training. Defaults to 0.
   :type start_round: int, optional
   :param wait_time_after_min_received: Time to wait before beginning aggregation after the minimum number of contributions has been received. Defaults to 10.
   :type wait_time_after_min_received: int, optional
   :param aggregator_id: ID of the aggregator component. Defaults to "aggregator".
   :type aggregator_id: str, optional
   :param persistor_id: ID of the persistor component. Defaults to "persistor".
   :type persistor_id: str, optional
   :param shareable_generator_id: ID of the shareable generator component. Defaults to "shareable_generator".
   :type shareable_generator_id: str, optional
   :param train_task_name: Name of the train task. Defaults to "train".
   :type train_task_name: str, optional
   :param train_timeout: Time to wait for clients to do local training. Defaults to 0.
   :type train_timeout: int, optional
   :param ignore_result_error: Whether this controller can proceed if a result has errors. Defaults to False.
   :type ignore_result_error: bool, optional
   :param fatal_error_delay: Time in seconds to delay before calling ``system_panic`` if a task returns an error result and ``ignore_result_error`` is set to False. Defaults to 5.
   :type fatal_error_delay: int, optional
   :param task_check_period: Interval for checking the status of tasks. Defaults to 0.5.
   :type task_check_period: float, optional
   :param persist_every_n_rounds: Persist the global model every n rounds. Defaults to 1. If n is 0, the model is not persisted.
   :type persist_every_n_rounds: int, optional
   :raises TypeError: when any of the input arguments does not have the correct type.
   :raises ValueError: when any of the input arguments is out of range or in an incorrect format.

   .. py:attribute:: flip

   .. py:attribute:: model_id
      :value: ''

   .. py:attribute:: aggregator_id

   .. py:attribute:: persistor_id

   .. py:attribute:: shareable_generator_id

   .. py:attribute:: train_task_name

   .. py:attribute:: aggregator
      :value: None

   .. py:attribute:: persistor
      :value: None

   .. py:attribute:: shareable_gen
      :value: None

   .. py:method:: start_controller(fl_ctx: nvflare.apis.fl_context.FLContext) -> None

   .. py:method:: control_flow(abort_signal: nvflare.apis.signal.Signal, fl_ctx: nvflare.apis.fl_context.FLContext) -> None

   .. py:method:: stop_controller(fl_ctx: nvflare.apis.fl_context.FLContext) -> None

   .. py:method:: handle_event(event_type: str, fl_ctx: nvflare.apis.fl_context.FLContext)

   .. py:method:: process_result_of_unknown_task(client: nvflare.apis.client.Client, task_name, client_task_id, result: nvflare.apis.shareable.Shareable, fl_ctx: nvflare.apis.fl_context.FLContext) -> None

   .. py:method:: get_persist_state(fl_ctx: nvflare.apis.fl_context.FLContext) -> dict

   .. py:method:: restore(state_data: dict, fl_ctx: nvflare.apis.fl_context.FLContext)

.. py:class:: ScatterAndGatherLDM(model_id: str = '', min_clients: int = 1, num_rounds_ae: int = 5, num_rounds_dm: int = 5, start_round: int = 0, model_locator_id='', wait_time_after_min_received: int = 10, aggregator_id=AppConstants.DEFAULT_AGGREGATOR_ID, persistor_id=AppConstants.DEFAULT_PERSISTOR_ID, shareable_generator_id=AppConstants.DEFAULT_SHAREABLE_GENERATOR_ID, train_task_name=AppConstants.TASK_TRAIN, train_timeout: int = 0, ignore_result_error: bool = True, fatal_error_delay: int = 5, task_check_period: float = 0.5, persist_every_n_rounds: int = 1)

   Bases: :py:obj:`nvflare.apis.impl.controller.Controller`

   The controller for the dual-phase FederatedAveraging workflow used to train latent diffusion models (LDM), which consist of an autoencoder and a diffusion model trained for ``num_rounds_ae`` and ``num_rounds_dm`` rounds respectively.

   As in :py:class:`ScatterAndGather`, the workflow defines federated training on all clients. The model persistor (``persistor_id``) is used to load the initial global model, which is sent to all clients. Each client sends its updated weights after local training, and these are aggregated by the aggregator (``aggregator_id``). The shareable generator converts the aggregated weights to a ``Shareable`` and a ``Shareable`` back to weights. The model persistor also saves the model after training.

   :param model_id: ID of the model that the training is being performed under.
   :type model_id: str, required
   :param min_clients: Minimum number of clients in training. Defaults to 1.
   :type min_clients: int, optional
   :param num_rounds_ae: The total number of training rounds for the autoencoder. Defaults to 5.
   :type num_rounds_ae: int, optional
   :param num_rounds_dm: The total number of training rounds for the diffusion model. Defaults to 5.
   :type num_rounds_dm: int, optional
   :param start_round: Start round for training. Defaults to 0.
   :type start_round: int, optional
   :param model_locator_id: ID of the model locator component. Defaults to "".
   :type model_locator_id: str, optional
   :param wait_time_after_min_received: Time to wait before beginning aggregation after the minimum number of contributions has been received. Defaults to 10.
   :type wait_time_after_min_received: int, optional
   :param aggregator_id: ID of the aggregator component. Defaults to "aggregator".
   :type aggregator_id: str, optional
   :param persistor_id: ID of the persistor component. Defaults to "persistor".
   :type persistor_id: str, optional
   :param shareable_generator_id: ID of the shareable generator component. Defaults to "shareable_generator".
   :type shareable_generator_id: str, optional
   :param train_task_name: Name of the train task. Defaults to "train".
   :type train_task_name: str, optional
   :param train_timeout: Time to wait for clients to do local training. Defaults to 0.
   :type train_timeout: int, optional
   :param ignore_result_error: Whether this controller can proceed if a result has errors. Defaults to True.
   :type ignore_result_error: bool, optional
   :param fatal_error_delay: Time in seconds to delay before calling ``system_panic`` if a task returns an error result and ``ignore_result_error`` is set to False. Defaults to 5.
   :type fatal_error_delay: int, optional
   :param task_check_period: Interval for checking the status of tasks. Defaults to 0.5.
   :type task_check_period: float, optional
   :param persist_every_n_rounds: Persist the global model every n rounds. Defaults to 1. If n is 0, the model is not persisted.
   :type persist_every_n_rounds: int, optional
   :raises TypeError: when any of the input arguments does not have the correct type.
   :raises ValueError: when any of the input arguments is out of range or in an incorrect format.

   .. py:attribute:: flip

   .. py:attribute:: model_id
      :value: ''

   .. py:attribute:: aggregator_id

   .. py:attribute:: persistor_id

   .. py:attribute:: model_locator_id
      :value: ''

   .. py:attribute:: shareable_generator_id

   .. py:attribute:: train_task_name

   .. py:attribute:: aggregator
      :value: None

   .. py:attribute:: persistor
      :value: None

   .. py:attribute:: shareable_gen
      :value: None

   .. py:attribute:: ignore_result_error
      :value: True

   .. py:method:: start_controller(fl_ctx: nvflare.apis.fl_context.FLContext) -> None

   .. py:method:: locate_server_models(fl_ctx: nvflare.apis.fl_context.FLContext) -> bool

      Locate server models for the current task.

      :param fl_ctx: FL context of the current run.
      :type fl_ctx: FLContext
      :returns: *bool* -- whether the server models were located.

   .. py:method:: control_flow(abort_signal: nvflare.apis.signal.Signal, fl_ctx: nvflare.apis.fl_context.FLContext) -> None

   .. py:method:: stop_controller(fl_ctx: nvflare.apis.fl_context.FLContext) -> None

   .. py:method:: handle_event(event_type: str, fl_ctx: nvflare.apis.fl_context.FLContext)

   .. py:method:: process_result_of_unknown_task(client: nvflare.apis.client.Client, task_name, client_task_id, result: nvflare.apis.shareable.Shareable, fl_ctx: nvflare.apis.fl_context.FLContext) -> None

   .. py:method:: get_persist_state(fl_ctx: nvflare.apis.fl_context.FLContext) -> dict

   .. py:method:: restore(state_data: dict, fl_ctx: nvflare.apis.fl_context.FLContext)
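
   The two training phases and the round-related parameters can be illustrated with a small
   planning sketch. It rests on assumptions this reference does not confirm: that all
   autoencoder rounds run before the diffusion model rounds (the usual LDM order, since the
   diffusion model is trained in the autoencoder's latent space), that rounds are numbered
   globally across both phases with ``start_round`` skipping rounds already completed, and that
   the model is persisted after round ``r`` when ``(r + 1) % persist_every_n_rounds == 0``. The
   names ``PlannedRound`` and ``ldm_round_schedule`` are hypothetical.

   .. code-block:: python

      from typing import List, NamedTuple


      class PlannedRound(NamedTuple):
          round_num: int  # global round index across both phases (assumed numbering)
          phase: str      # "autoencoder" or "diffusion"
          persist: bool   # whether the global model would be saved after this round


      def ldm_round_schedule(
          num_rounds_ae: int = 5,
          num_rounds_dm: int = 5,
          start_round: int = 0,
          persist_every_n_rounds: int = 1,
      ) -> List[PlannedRound]:
          """Plan the remaining rounds of a dual-phase run (illustrative sketch only)."""
          if min(num_rounds_ae, num_rounds_dm, start_round, persist_every_n_rounds) < 0:
              raise ValueError("round counts, start_round and persist_every_n_rounds must be >= 0")

          plan = []
          for rnd in range(start_round, num_rounds_ae + num_rounds_dm):
              # Assumed order: all autoencoder rounds first, then the diffusion model rounds.
              phase = "autoencoder" if rnd < num_rounds_ae else "diffusion"
              # As documented, persist_every_n_rounds == 0 disables persistence.
              persist = persist_every_n_rounds > 0 and (rnd + 1) % persist_every_n_rounds == 0
              plan.append(PlannedRound(rnd, phase, persist))
          return plan


      for planned in ldm_round_schedule(num_rounds_ae=2, num_rounds_dm=3, persist_every_n_rounds=2):
          print(planned)

   With ``num_rounds_ae=2`` and ``num_rounds_dm=3`` the plan has five rounds: rounds 0 and 1
   train the autoencoder and rounds 2 to 4 the diffusion model, with the model persisted after
   rounds 1 and 3. Passing ``start_round=3`` would leave only rounds 3 and 4, which is how a
   resumed run (for example after ``restore``) might skip completed work under these
   assumptions.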