Source code for flip.nvflare.components.pt_model_locator

# Copyright (c) 2026 Guy's and St Thomas' NHS Foundation Trust & King's College London
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json
import os
from typing import List, Union

import torch
import torch.cuda
from nvflare.apis.dxo import DXO, DataKind
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.abstract.model import model_learnable_to_dxo
from nvflare.app_common.abstract.model_locator import ModelLocator
from nvflare.app_opt.pt import PTModelPersistenceFormatManager

from flip.constants import FlipConstants, PTConstants


class PTModelLocator(ModelLocator):
    """Locate the persisted server-side PyTorch model and return it as a DXO.

    The only model name this locator recognises is ``PTConstants.PTServerName``.

    Args:
        exclude_vars: Forwarded to
            ``PTModelPersistenceFormatManager.to_model_learnable`` to filter the
            variables placed in the learnable. ``None`` keeps every variable.
        model: Model object whose class name is recorded in the default train
            configuration. When ``None``, ``models.get_model()`` is called to
            build one.
    """

    def __init__(self, exclude_vars=None, model=None):
        super().__init__()
        if model is None:
            # Deferred import: ``models`` is not a top-level dependency of this
            # module (presumably supplied by the job's custom code — confirm).
            from models import get_model

            model = get_model()
        self.model = model
        self.exclude_vars = exclude_vars

    def get_model_names(self, fl_ctx: FLContext) -> List[str]:
        """Return the names of the models this locator can serve.

        Args:
            fl_ctx: Current FL context (unused).

        Returns:
            A single-element list containing ``PTConstants.PTServerName``.
        """
        return [PTConstants.PTServerName]

    def locate_model(self, model_name, fl_ctx: FLContext) -> Union[DXO, None]:
        """Load the saved server model from the job's app directory.

        The file is looked up directly in the app directory when
        ``FlipConstants.LOCAL_DEV`` is truthy, otherwise in its ``model``
        sub-directory.

        Args:
            model_name: Name of the requested model.
            fl_ctx: Current FL context; provides the workspace and job id.

        Returns:
            The model as a DXO, or ``None`` when the name is not recognised, the
            file is missing, or any error occurs while loading (the error is
            logged, not raised).
        """
        if model_name != PTConstants.PTServerName:
            self.log_error(fl_ctx, f"PTModelLocator doesn't recognize name: {model_name}", fire_event=False)
            return None

        try:
            server_run_dir = fl_ctx.get_engine().get_workspace().get_app_dir(fl_ctx.get_job_id())
            self.log_info(fl_ctx, f"Server run directory: {server_run_dir}")

            if FlipConstants.LOCAL_DEV:
                model_path = os.path.join(server_run_dir, PTConstants.PTFileModelName)
            else:
                model_path = os.path.join(server_run_dir, "model", PTConstants.PTFileModelName)

            if not os.path.exists(model_path):
                self.log_error(fl_ctx, f"Model file not found at {model_path}", fire_event=False)
                return None

            device = "cuda" if torch.cuda.is_available() else "cpu"
            # weights_only=True restricts unpickling to tensors and plain
            # containers, matching how the other locators in this module load
            # checkpoints.
            data = torch.load(model_path, map_location=device, weights_only=True)

            # Explicit None check: some module containers are falsy when empty.
            if self.model is not None:
                default_train_conf = {"train": {"model": type(self.model).__name__}}
            else:
                default_train_conf = None

            persistence_manager = PTModelPersistenceFormatManager(data, default_train_conf=default_train_conf)
            # Honour the exclusion filter given to the constructor.
            ml = persistence_manager.to_model_learnable(exclude_vars=self.exclude_vars)
            return model_learnable_to_dxo(ml)
        except Exception as e:
            # Boundary handler: a locator failure must not take the workflow
            # down, so report it and signal "no model".
            self.log_error(fl_ctx, f"Error in retrieving {model_name}: {e}", fire_event=False)
            return None
class InitialPTModelLocator(ModelLocator):
    """Locate a pre-existing model file to start from, if one is available.

    The job's app directory is searched first, then a per-job folder under the
    safehouse directory. The only model name recognised is
    ``PTConstants.PTServerName``.

    Args:
        exclude_vars: Forwarded to
            ``PTModelPersistenceFormatManager.to_model_learnable`` to filter the
            variables placed in the learnable. ``None`` keeps every variable.
        model: Model object whose class name is recorded in the default train
            configuration. When ``None``, ``models.get_model()`` is called to
            build one.
    """

    # Fallback root searched when the app directory holds no model file.
    _SAFEHOUSE_DIR = "/safehouse"

    def __init__(self, exclude_vars=None, model=None):
        super().__init__()
        if model is None:
            # Deferred import: ``models`` is not a top-level dependency of this
            # module (presumably supplied by the job's custom code — confirm).
            from models import get_model

            model = get_model()
        self.model = model
        self.exclude_vars = exclude_vars

    def get_model_names(self, fl_ctx: FLContext) -> List[str]:
        """Return the names of the models this locator can serve.

        Args:
            fl_ctx: Current FL context (unused).

        Returns:
            A single-element list containing ``PTConstants.PTServerName``.
        """
        return [PTConstants.PTServerName]

    def locate_model(self, model_name, fl_ctx: FLContext) -> Union[DXO, None]:
        """Find and load an existing model file for the current job.

        Args:
            model_name: Name of the requested model.
            fl_ctx: Current FL context; provides the workspace and job id.

        Returns:
            The object produced by ``to_model_learnable``, or ``None`` when the
            name is not recognised, no file is found in either location, or
            loading fails (failures are logged, not raised).

            NOTE(review): the annotation says DXO, but the learnable itself is
            returned without conversion. Kept as-is for caller compatibility —
            confirm which type callers expect.
        """
        self.log_info(fl_ctx, f"Trying to locate the model {model_name}")

        if model_name != PTConstants.PTServerName:
            self.log_error(
                fl_ctx,
                f"InitialPTModelLocator doesn't recognize name: {model_name}",
                fire_event=False,
            )
            return None

        try:
            job_id = fl_ctx.get_job_id()
            server_run_dir = fl_ctx.get_engine().get_workspace().get_app_dir(job_id)
            model_path = os.path.join(server_run_dir, PTConstants.PTFileModelName)
            self.log_info(fl_ctx, model_path)

            if not os.path.exists(model_path):
                self.log_info(fl_ctx, f"Model does not exist at {model_path}")
                # Second chance: per-job folder under the safehouse root.
                model_path = os.path.join(self._SAFEHOUSE_DIR, job_id, PTConstants.PTFileModelName)
                if not os.path.exists(model_path):
                    self.log_info(fl_ctx, f"Model does not exist at safehouse ({model_path})")
                    return None

            device = "cuda" if torch.cuda.is_available() else "cpu"
            data = torch.load(
                model_path,
                map_location=device,
                weights_only=True,
            )

            # Explicit None check: some module containers are falsy when empty.
            if self.model is not None:
                default_train_conf = {"train": {"model": type(self.model).__name__}}
                self.log_info(fl_ctx, f"Default train conf: {default_train_conf}")
            else:
                default_train_conf = None

            try:
                persistence_manager = PTModelPersistenceFormatManager(data, default_train_conf=default_train_conf)
                # Honour the exclusion filter given to the constructor.
                ml = persistence_manager.to_model_learnable(exclude_vars=self.exclude_vars)
            except RuntimeError:
                # Treated as "no usable initial model" rather than an error.
                self.log_info(fl_ctx, f"Could not load the weights from {model_path} into the model. ")
                return None

            return ml
        except Exception as e:
            # Boundary handler: report the failure and signal "no model".
            self.log_error(fl_ctx, f"Error in retrieving {model_name}: {e}", fire_event=False)
            return None
class EvaluationPTModelLocator(ModelLocator):
    """Load every model listed in the job's ``config.json`` for evaluation.

    Checkpoints are read once, cached on the instance, and returned together
    as a single collection DXO keyed by model name.

    Args:
        exclude_vars: Forwarded to
            ``PTModelPersistenceFormatManager.to_model_learnable`` to filter the
            variables placed in each learnable. ``None`` keeps every variable.
    """

    def __init__(self, exclude_vars=None):
        super().__init__()
        # Cache of {model name: loaded checkpoint}; None until a load succeeds.
        self.models = None
        self.exclude_vars = exclude_vars

    def locate_model(self, fl_ctx: FLContext) -> Union[DXO, None]:
        """Return all configured models as one collection DXO.

        Args:
            fl_ctx: Current FL context; provides the workspace and job id.

        Returns:
            A DXO of kind ``DataKind.COLLECTION`` mapping each model name to
            its DXO, or ``None`` when the configuration has no ``"models"``
            entry or a checkpoint file is missing (both are logged as errors).
        """
        if self.models is None:
            loaded = self._load_models(fl_ctx)
            if loaded is None:
                return None
            # Publish the cache only after every model loaded, so a failed
            # attempt is retried on the next call instead of serving a partial
            # collection.
            self.models = loaded

        all_model_dxo = {}
        for model_name, weight in self.models.items():
            persistence_manager = PTModelPersistenceFormatManager(weight, default_train_conf=None)
            # Honour the exclusion filter given to the constructor.
            ml = persistence_manager.to_model_learnable(exclude_vars=self.exclude_vars)
            all_model_dxo[model_name] = model_learnable_to_dxo(ml)

        return DXO(data_kind=DataKind.COLLECTION, data=all_model_dxo)

    def _load_models(self, fl_ctx: FLContext):
        """Read ``custom/config.json`` and load each configured checkpoint.

        Sets ``self.config`` and, when the ``"models"`` entry exists,
        ``self.model_names``.

        Returns:
            A dict mapping model name to the loaded checkpoint, or ``None``
            after logging an error when the ``"models"`` entry or a checkpoint
            file is missing.
        """
        app_dir = fl_ctx.get_engine().get_workspace().get_app_dir(fl_ctx.get_job_id())
        custom_dir = os.path.join(app_dir, "custom")
        config_path = os.path.join(custom_dir, "config.json")
        with open(config_path, "r") as file:
            self.config = json.load(file)

        if "models" not in self.config:
            self.log_error(
                fl_ctx,
                "In this pipeline, there must be a models key-element object in the config.json file, "
                "pointing to the getter function as well the architecture.",
                fire_event=True,
            )
            # Nothing to load; stop here rather than failing later on a None cache.
            return None

        models_config = self.config["models"]
        self.model_names = models_config.keys()

        # Deferred import: ``models`` is not a top-level dependency of this
        # module (presumably supplied by the job's custom code — confirm).
        from models import model_paths

        device = "cuda" if torch.cuda.is_available() else "cpu"
        loaded = {}
        for name, model_config in models_config.items():
            checkpoint_path = os.path.join(custom_dir, model_config["checkpoint"])
            if not os.path.isfile(checkpoint_path):
                self.log_error(
                    fl_ctx,
                    f"Model checkpoint for model {name} not found at {checkpoint_path}",
                    fire_event=True,
                )
                # Abort instead of calling torch.load on a missing file.
                return None

            net = model_paths[model_config["path"]]
            weights = torch.load(
                checkpoint_path,
                weights_only=True,
                map_location=device,
            )
            try:
                # Strict load checks that the checkpoint fits the architecture.
                net.load_state_dict(weights, strict=True)
            except Exception as e:
                # Best effort: the mismatch is reported but the weights are
                # still kept, as in the original flow.
                self.log_error(
                    fl_ctx,
                    f"The weights for network {name} could not be loaded into the object: {e}",
                    fire_event=True,
                )
            loaded[name] = weights

        return loaded