Source code for fedn.network.combiner.aggregators.fedavg

import traceback

from fedn.common.log_config import logger
from fedn.network.combiner.aggregators.aggregatorbase import AggregatorBase


[docs] class Aggregator(AggregatorBase): """Local SGD / Federated Averaging (FedAvg) aggregator. Computes a weighted mean of parameter updates. :param control: A handle to the :class: `fedn.network.combiner.updatehandler.UpdateHandler` :type control: class: `fedn.network.combiner.updatehandler.UpdateHandler` """ def __init__(self, update_handler): """Constructor method""" super().__init__(update_handler) self.name = "fedavg"
[docs] def combine_models(self, helper=None, delete_models=True, parameters=None): """Aggregate all model updates in the queue by computing an incremental weighted average of model parameters. :param helper: An instance of :class: `fedn.utils.helpers.helpers.HelperBase`, ML framework specific helper, defaults to None :type helper: class: `fedn.utils.helpers.helpers.HelperBase`, optional :param time_window: The time window for model aggregation, defaults to 180 :type time_window: int, optional :param max_nr_models: The maximum number of updates aggregated, defaults to 100 :type max_nr_models: int, optional :param delete_models: Delete models from storage after aggregation, defaults to True :type delete_models: bool, optional :return: The global model and metadata :rtype: tuple """ data = {} data["time_model_load"] = 0.0 data["time_model_aggregation"] = 0.0 model = None nr_aggregated_models = 0 total_examples = 0 logger.info("AGGREGATOR({}): Aggregating model updates... ".format(self.name)) while not self.update_handler.model_updates.empty(): try: logger.info("AGGREGATOR({}): Getting next model update from queue.".format(self.name)) model_update = self.update_handler.next_model_update() # Load model parameters and metadata logger.info("AGGREGATOR({}): Loading model metadata {}.".format(self.name, model_update.model_update_id)) model_next, metadata = self.update_handler.load_model_update(model_update, helper) logger.info("AGGREGATOR({}): Processing model update {}, metadata: {} ".format(self.name, model_update.model_update_id, metadata)) # Increment total number of examples total_examples += metadata["num_examples"] if nr_aggregated_models == 0: model = model_next else: model = helper.increment_average(model, model_next, metadata["num_examples"], total_examples) nr_aggregated_models += 1 # Delete model from storage if delete_models: self.update_handler.delete_model(model_update) except Exception as e: tb = traceback.format_exc() logger.error(f"AGGREGATOR({self.name}): Error encoutered while processing model update: {e}") logger.error(tb) data["nr_aggregated_models"] = nr_aggregated_models logger.info("AGGREGATOR({}): Aggregation completed, aggregated {} models.".format(self.name, nr_aggregated_models)) return model, data