Source code for flip.utils.model_weights_handling

# Copyright (c) 2026 Guy's and St Thomas' NHS Foundation Trust & King's College London
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

from collections import OrderedDict
from typing import TYPE_CHECKING

from nvflare.apis.dxo import DXO, DataKind, MetaKey

if TYPE_CHECKING:
    pass


def get_model_weights_diff(original_weights: OrderedDict, new_weights: OrderedDict, iterations: int) -> DXO:
    """Compute the per-key weight differences to send as a weight update.

    For every key in ``new_weights`` the result holds
    ``new_weights[key] - original_weights[key]``. ``torch.Tensor`` values on
    either side are converted to NumPy arrays first; any other value is used
    as-is. The type check is done per value, so a mapping that mixes tensors
    and arrays is handled, and an empty ``new_weights`` yields an empty diff
    instead of raising ``StopIteration``.

    Args:
        original_weights (OrderedDict): weights coming from the server (before training).
            Must contain every key present in ``new_weights``; extra keys are ignored.
        new_weights (OrderedDict): weights obtained after training.
        iterations (int): number of iterations, stored in the DXO meta under
            ``MetaKey.NUM_STEPS_CURRENT_ROUND``.

    Returns:
        DXO: DXO of kind ``DataKind.WEIGHT_DIFF`` containing the weight updates for the server.

    Raises:
        KeyError: if a key of ``new_weights`` is missing from ``original_weights``.
    """
    # Imported lazily, as in the original, so importing this module does not
    # itself require torch.
    import torch

    def _as_numpy(value):
        """Return ``value`` as a NumPy array if it is a tensor, else unchanged."""
        if isinstance(value, torch.Tensor):
            # detach() is required: Tensor.numpy() raises a RuntimeError on
            # tensors that require grad.
            return value.detach().cpu().numpy()
        return value

    weight_diff = {
        name: _as_numpy(updated) - _as_numpy(original_weights[name])
        for name, updated in new_weights.items()
    }

    return DXO(
        data_kind=DataKind.WEIGHT_DIFF,
        data=weight_diff,
        meta={MetaKey.NUM_STEPS_CURRENT_ROUND: iterations},
    )