flip.nvflare.controllers.scatter_and_gather_ldm
===============================================

.. py:module:: flip.nvflare.controllers.scatter_and_gather_ldm


Classes
-------

.. autoapisummary::

   flip.nvflare.controllers.scatter_and_gather_ldm.ScatterAndGatherLDM


Module Contents
---------------

.. py:class:: ScatterAndGatherLDM(model_id: str = '', min_clients: int = 1, num_rounds_ae: int = 5, num_rounds_dm: int = 5, start_round: int = 0, model_locator_id='', wait_time_after_min_received: int = 10, aggregator_id=AppConstants.DEFAULT_AGGREGATOR_ID, persistor_id=AppConstants.DEFAULT_PERSISTOR_ID, shareable_generator_id=AppConstants.DEFAULT_SHAREABLE_GENERATOR_ID, train_task_name=AppConstants.TASK_TRAIN, train_timeout: int = 0, ignore_result_error: bool = True, fatal_error_delay: int = 5, task_check_period: float = 0.5, persist_every_n_rounds: int = 1)

   Bases: :py:obj:`nvflare.apis.impl.controller.Controller`

   The controller for the Federated Averaging workflow.

   The ScatterAndGather workflow defines federated training on all clients. The model persistor
   (``persistor_id``) is used to load the initial global model, which is sent to all clients. Each
   client sends its updated weights after local training, and these are aggregated by the
   aggregator (``aggregator_id``). The shareable generator is used to convert the aggregated
   weights into a ``Shareable``, and a ``Shareable`` back into weights. The model persistor also
   saves the model after training. This LDM variant runs separate training rounds for the
   autoencoder (``num_rounds_ae``) and for the diffusion model (``num_rounds_dm``).

   :param model_id: ID of the model that the training is being performed under.
   :type model_id: str, required
   :param min_clients: Minimum number of clients required for training. Defaults to 1.
   :type min_clients: int, optional
   :param num_rounds_ae: The total number of training rounds for the autoencoder. Defaults to 5.
   :type num_rounds_ae: int, optional
   :param num_rounds_dm: The total number of training rounds for the diffusion model. Defaults to 5.
   :type num_rounds_dm: int, optional
   :param start_round: Round at which training starts. Defaults to 0.
   :type start_round: int, optional
   :param model_locator_id: ID of the model locator component. Defaults to "".
   :type model_locator_id: str, optional
   :param wait_time_after_min_received: Time in seconds to wait, after the minimum number of client contributions has been received, before aggregation begins. Defaults to 10.
   :type wait_time_after_min_received: int, optional
   :param aggregator_id: ID of the aggregator component. Defaults to "aggregator".
   :type aggregator_id: str, optional
   :param persistor_id: ID of the persistor component. Defaults to "persistor".
   :type persistor_id: str, optional
   :param shareable_generator_id: ID of the shareable generator component. Defaults to "shareable_generator".
   :type shareable_generator_id: str, optional
   :param train_task_name: Name of the train task. Defaults to "train".
   :type train_task_name: str, optional
   :param train_timeout: Time in seconds to wait for clients to complete local training. Defaults to 0.
   :type train_timeout: int, optional
   :param ignore_result_error: Whether this controller can proceed if a result has errors. Defaults to True.
   :type ignore_result_error: bool, optional
   :param fatal_error_delay: Time in seconds to delay before calling ``system_panic`` if a task returns an error result and ``ignore_result_error`` is set to False. Defaults to 5.
   :type fatal_error_delay: int, optional
   :param task_check_period: Interval in seconds for checking the status of tasks. Defaults to 0.5.
   :type task_check_period: float, optional
   :param persist_every_n_rounds: Persist the global model every n rounds. If n is 0, the model is not persisted. Defaults to 1.
   :type persist_every_n_rounds: int, optional
   :raises TypeError: When any input argument does not have the correct type.
   :raises ValueError: When any input argument is out of range or in an incorrect format.


   .. py:attribute:: flip


   .. py:attribute:: model_id
      :value: ''


   .. py:attribute:: aggregator_id


   .. py:attribute:: persistor_id


   .. py:attribute:: model_locator_id
      :value: ''


   .. py:attribute:: shareable_generator_id


   .. py:attribute:: train_task_name


   .. py:attribute:: aggregator
      :value: None


   .. py:attribute:: persistor
      :value: None


   .. py:attribute:: shareable_gen
      :value: None


   .. py:attribute:: ignore_result_error
      :value: True


   .. py:method:: start_controller(fl_ctx: nvflare.apis.fl_context.FLContext) -> None


   .. py:method:: locate_server_models(fl_ctx: nvflare.apis.fl_context.FLContext) -> bool

      Locate server models for the current task.

      :param fl_ctx: The FL context for the current task.
      :type fl_ctx: FLContext

      :returns: *bool*


   .. py:method:: control_flow(abort_signal: nvflare.apis.signal.Signal, fl_ctx: nvflare.apis.fl_context.FLContext) -> None


   .. py:method:: stop_controller(fl_ctx: nvflare.apis.fl_context.FLContext) -> None


   .. py:method:: handle_event(event_type: str, fl_ctx: nvflare.apis.fl_context.FLContext)


   .. py:method:: process_result_of_unknown_task(client: nvflare.apis.client.Client, task_name, client_task_id, result: nvflare.apis.shareable.Shareable, fl_ctx: nvflare.apis.fl_context.FLContext) -> None


   .. py:method:: get_persist_state(fl_ctx: nvflare.apis.fl_context.FLContext) -> dict


   .. py:method:: restore(state_data: dict, fl_ctx: nvflare.apis.fl_context.FLContext)
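
The two training stages set by ``num_rounds_ae`` and ``num_rounds_dm`` can be pictured as a
single round schedule. The sketch below is hypothetical: it assumes the autoencoder rounds run
before the diffusion model rounds and that ``start_round`` indexes into that combined schedule.
Neither assumption is confirmed by this page, and ``stage_for_round`` and ``remaining_rounds``
are illustrative helpers, not part of the FLIP or NVFlare API.

.. code-block:: python

   from typing import List


   def stage_for_round(round_num: int, num_rounds_ae: int = 5, num_rounds_dm: int = 5) -> str:
       """Return the model trained in a given global round under the ASSUMED schedule.

       Hypothetical helper: assumes all autoencoder rounds come first,
       followed by all diffusion model rounds.
       """
       if round_num < 0:
           raise ValueError("round_num must be non-negative")
       if round_num < num_rounds_ae:
           return "autoencoder"
       if round_num < num_rounds_ae + num_rounds_dm:
           return "diffusion_model"
       raise ValueError("round_num is past the last scheduled round")


   def remaining_rounds(start_round: int = 0, num_rounds_ae: int = 5, num_rounds_dm: int = 5) -> List[str]:
       """List the stage of every round still to run when training starts at start_round."""
       total = num_rounds_ae + num_rounds_dm
       return [stage_for_round(r, num_rounds_ae, num_rounds_dm) for r in range(start_round, total)]

With the defaults, ``remaining_rounds()`` yields five autoencoder rounds followed by five diffusion
model rounds; starting later simply drops the rounds that have already run.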
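
The ``persist_every_n_rounds`` rule (persist every n rounds, never when n is 0) reduces to a
modulo check. This is a minimal sketch assuming rounds are counted from 0, so round ``n - 1`` is
the first one persisted; ``should_persist`` is a hypothetical helper, not a method of
``ScatterAndGatherLDM``.

.. code-block:: python

   def should_persist(current_round: int, persist_every_n_rounds: int = 1) -> bool:
       """Decide whether the global model would be saved after current_round (0-based).

       Hypothetical helper mirroring the documented rule; n == 0 disables
       periodic persistence entirely.
       """
       if persist_every_n_rounds < 0:
           raise ValueError("persist_every_n_rounds must be >= 0")
       if persist_every_n_rounds == 0:
           return False
       # Rounds are 0-based, so adding 1 makes round n - 1 the first persisted round.
       return (current_round + 1) % persist_every_n_rounds == 0

For example, with ``persist_every_n_rounds=2`` the rounds persisted among the first six are
1, 3 and 5. The final save after training, done by the model persistor, is separate from this
periodic check.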
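
In a Federated Averaging workflow, the component referenced by ``aggregator_id`` combines the
client updates into a weighted mean. The sketch below shows that technique in plain Python under
stated assumptions: weights are lists of floats keyed by layer name, and each client reports its
number of local training samples. ``federated_average`` is a hypothetical helper and does not
reproduce the NVFlare aggregator interface.

.. code-block:: python

   from typing import Dict, List, Tuple

   # Hypothetical weight format: layer name -> flat list of parameter values.
   Weights = Dict[str, List[float]]


   def federated_average(contributions: List[Tuple[Weights, int]]) -> Weights:
       """Average client weights, weighting each client by its sample count."""
       if not contributions:
           raise ValueError("no client contributions to aggregate")
       total_samples = sum(n for _, n in contributions)
       if total_samples <= 0:
           raise ValueError("total sample count must be positive")

       averaged: Weights = {}
       for name in contributions[0][0]:
           acc = [0.0] * len(contributions[0][0][name])
           for weights, n in contributions:
               # Each client's contribution is scaled by its share of all samples.
               for i, value in enumerate(weights[name]):
                   acc[i] += value * n / total_samples
           averaged[name] = acc
       return averaged

A client with three times as many samples pulls the average three times as hard, which is the
defining property of Federated Averaging compared with a plain unweighted mean.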
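
The constructor arguments are typically supplied through the ``workflows`` list of an NVFlare
server configuration (``config_fed_server.json``), where each entry has an ``id``, a class
``path`` and an ``args`` mapping. The snippet below builds such an entry as a sketch; only the
class path is taken from this page, while the ``id``, the ``model_id`` placeholder and the
argument values are illustrative assumptions.

.. code-block:: python

   import json

   # Hypothetical workflow entry; "id" and all argument values are illustrative only.
   workflow = {
       "id": "scatter_and_gather_ldm",
       "path": "flip.nvflare.controllers.scatter_and_gather_ldm.ScatterAndGatherLDM",
       "args": {
           "model_id": "<your-model-id>",
           "min_clients": 2,
           "num_rounds_ae": 5,
           "num_rounds_dm": 5,
           "aggregator_id": "aggregator",
           "persistor_id": "persistor",
           "shareable_generator_id": "shareable_generator",
           "train_task_name": "train",
           "persist_every_n_rounds": 1,
       },
   }

   server_config = {"workflows": [workflow]}
   print(json.dumps(server_config, indent=2))

Arguments omitted from ``args`` fall back to the defaults listed in the class signature, so only
the values that differ from those defaults need to appear in the configuration.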