flip.nvflare.controllers.scatter_and_gather_ldm

Classes

ScatterAndGatherLDM

The controller for FederatedAveraging Workflow.

Module Contents

class flip.nvflare.controllers.scatter_and_gather_ldm.ScatterAndGatherLDM(model_id: str = '', min_clients: int = 1, num_rounds_ae: int = 5, num_rounds_dm: int = 5, start_round: int = 0, model_locator_id='', wait_time_after_min_received: int = 10, aggregator_id=AppConstants.DEFAULT_AGGREGATOR_ID, persistor_id=AppConstants.DEFAULT_PERSISTOR_ID, shareable_generator_id=AppConstants.DEFAULT_SHAREABLE_GENERATOR_ID, train_task_name=AppConstants.TASK_TRAIN, train_timeout: int = 0, ignore_result_error: bool = True, fatal_error_delay: int = 5, task_check_period: float = 0.5, persist_every_n_rounds: int = 1)

Bases: nvflare.apis.impl.controller.Controller

The controller for FederatedAveraging Workflow.

The ScatterAndGather workflow defines federated training on all clients. The model persistor (persistor_id) loads the initial global model, which is sent to all clients. After local training, each client sends back its updated weights, which are then aggregated (aggregator_id). The shareable generator (shareable_generator_id) converts the aggregated weights to a shareable and a shareable back to weights. The model persistor also saves the model after training.
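
The round described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the controller's actual code: local_train, fed_avg, and run_round are hypothetical names, the dict-of-lists weight format stands in for NVFlare shareables, and weighting each client by its sample count is an assumption about what the aggregator does.

```python
from typing import Dict, List, Tuple

Weights = Dict[str, List[float]]


def local_train(global_weights: Weights, shift: float) -> Weights:
    """Hypothetical stand-in for a client's local training: shifts every weight."""
    return {name: [w + shift for w in values] for name, values in global_weights.items()}


def fed_avg(updates: List[Tuple[Weights, int]]) -> Weights:
    """Sample-weighted average of client weights (the 'gather' step)."""
    total = sum(n for _, n in updates)
    if total <= 0:
        raise ValueError("total sample count must be positive")
    first_weights = updates[0][0]
    averaged: Weights = {}
    for name, values in first_weights.items():
        averaged[name] = [
            sum(w[name][i] * n for w, n in updates) / total
            for i in range(len(values))
        ]
    return averaged


def run_round(global_weights: Weights, clients: List[Tuple[float, int]]) -> Weights:
    """One scatter-and-gather round: broadcast, train locally, aggregate."""
    updates = [(local_train(global_weights, shift), n) for shift, n in clients]
    return fed_avg(updates)


global_weights = {"layer": [0.0, 1.0]}
clients = [(1.0, 10), (3.0, 30)]  # (hypothetical training shift, sample count)
new_weights = run_round(global_weights, clients)
print(new_weights)
```

In the real workflow the broadcast, the client results, and the aggregation all pass through the components named by persistor_id, shareable_generator_id, and aggregator_id rather than through direct function calls.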

Parameters:
  • model_id (str, required) – ID of the model that the training is being performed under.

  • min_clients (int, optional) – Min number of clients in training. Defaults to 1.

  • num_rounds_ae (int, optional) – The total number of training rounds for autoencoder. Defaults to 5.

  • num_rounds_dm (int, optional) – The total number of training rounds for diffusion model. Defaults to 5.

  • start_round (int, optional) – Start round for training. Defaults to 0.

  • model_locator_id (str, optional) – ID of the model locator component. Defaults to “”.

  • wait_time_after_min_received (int, optional) – Time to wait before beginning aggregation once contributions from the minimum number of clients have been received. Defaults to 10.

  • aggregator_id (str, optional) – ID of the aggregator component. Defaults to “aggregator”.

  • persistor_id (str, optional) – ID of the persistor component. Defaults to “persistor”.

  • shareable_generator_id (str, optional) – ID of the shareable generator. Defaults to “shareable_generator”.

  • train_task_name (str, optional) – Name of the train task. Defaults to “train”.

  • train_timeout (int, optional) – Time to wait for clients to do local training. Defaults to 0.

  • ignore_result_error (bool, optional) – Whether this controller can proceed if a result has errors. Defaults to True.

  • fatal_error_delay (int, optional) – Time in seconds to delay before calling ‘system_panic’ if a task returns an error result and ignore_result_error is set to False. Defaults to 5.

  • task_check_period (float, optional) – interval for checking status of tasks. Defaults to 0.5.

  • persist_every_n_rounds (int, optional) – Persist the global model every n rounds. Defaults to 1. If n is 0, the model is not persisted.
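
How the round-related parameters might fit together can be sketched as follows. This is an illustration under stated assumptions, not the controller's code: round_plan is a hypothetical helper, and the sketch assumes that the autoencoder rounds run before the diffusion model rounds, that round counting restarts for each phase, and that the model is persisted after every n-th completed round of a phase.

```python
from typing import List, Tuple


def round_plan(
    num_rounds_ae: int = 5,
    num_rounds_dm: int = 5,
    persist_every_n_rounds: int = 1,
) -> List[Tuple[int, str, bool]]:
    """Return (round_index, phase, persist_after_round) for each planned round."""
    plan = []
    # Assumed ordering: autoencoder first, then diffusion model.
    phases = (("autoencoder", num_rounds_ae), ("diffusion_model", num_rounds_dm))
    for phase, num_rounds in phases:
        for round_index in range(num_rounds):
            # n == 0 disables persistence; otherwise persist on every n-th round.
            persist = (
                persist_every_n_rounds != 0
                and (round_index + 1) % persist_every_n_rounds == 0
            )
            plan.append((round_index, phase, persist))
    return plan


plan = round_plan(num_rounds_ae=2, num_rounds_dm=3, persist_every_n_rounds=2)
for entry in plan:
    print(entry)
```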

Raises:
  • TypeError – when any of input arguments does not have correct type

  • ValueError – when any of input arguments is out of range or are in an incorrect format
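
The distinction between the two exceptions can be illustrated with a small validation sketch. The helper names check_int, check_controller_args, and error_name are hypothetical, and the specific bounds (at least one client, non-negative round counts) are assumptions rather than the controller's documented ranges.

```python
def check_int(name: str, value: object, minimum: int) -> None:
    """Raise TypeError for a non-int and ValueError for a value below minimum."""
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(value, bool) or not isinstance(value, int):
        raise TypeError(f"{name} must be int but got {type(value).__name__}")
    if value < minimum:
        raise ValueError(f"{name} must be >= {minimum} but got {value}")


def check_controller_args(
    min_clients=1,
    num_rounds_ae=5,
    num_rounds_dm=5,
    start_round=0,
    persist_every_n_rounds=1,
) -> None:
    """Validate a subset of the constructor arguments (bounds are assumed)."""
    check_int("min_clients", min_clients, minimum=1)
    check_int("num_rounds_ae", num_rounds_ae, minimum=0)
    check_int("num_rounds_dm", num_rounds_dm, minimum=0)
    check_int("start_round", start_round, minimum=0)
    check_int("persist_every_n_rounds", persist_every_n_rounds, minimum=0)


def error_name(**kwargs) -> str:
    """Return the name of the exception raised for kwargs, or 'ok'."""
    try:
        check_controller_args(**kwargs)
    except (TypeError, ValueError) as err:
        return type(err).__name__
    return "ok"


print(error_name())                    # defaults from the signature
print(error_name(min_clients="2"))     # wrong type
print(error_name(num_rounds_ae=-1))    # out of range
```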

flip
model_id = ''
aggregator_id
persistor_id
model_locator_id = ''
shareable_generator_id
train_task_name
aggregator = None
persistor = None
shareable_gen = None
ignore_result_error = True
start_controller(fl_ctx: nvflare.apis.fl_context.FLContext) -> None
locate_server_models(fl_ctx: nvflare.apis.fl_context.FLContext) -> bool

Locate server models for the current task.

Parameters:

fl_ctx (FLContext) – The FL context for the current run.

Returns:

bool – True if the server models were located, False otherwise.

control_flow(abort_signal: nvflare.apis.signal.Signal, fl_ctx: nvflare.apis.fl_context.FLContext) -> None
stop_controller(fl_ctx: nvflare.apis.fl_context.FLContext) -> None
handle_event(event_type: str, fl_ctx: nvflare.apis.fl_context.FLContext)
process_result_of_unknown_task(client: nvflare.apis.client.Client, task_name, client_task_id, result: nvflare.apis.shareable.Shareable, fl_ctx: nvflare.apis.fl_context.FLContext) -> None
get_persist_state(fl_ctx: nvflare.apis.fl_context.FLContext) -> dict
restore(state_data: dict, fl_ctx: nvflare.apis.fl_context.FLContext)