# Source code for the round handler module.

import ast
import queue
import random
import sys
import time
import uuid
from typing import TypedDict

from fedn.common.log_config import logger
from import get_aggregator
from import load_model_from_BytesIO, serialize_model_to_BytesIO
from fedn.utils.helpers.helpers import get_helper
from fedn.utils.parameters import Parameters

class RoundConfig(TypedDict):
    """Round configuration.

    :param _job_id: A universally unique identifier for the round. Set by Combiner.
    :type _job_id: str
    :param committed_at: The time the round was committed. Set by Controller.
    :type committed_at: str
    :param task: The task to perform in the round. Set by Controller.
        Supported tasks are "training", "validation", and "inference".
    :type task: str
    :param round_id: The round identifier as str(int)
    :type round_id: str
    :param round_timeout: The round timeout in seconds. Set by user interfaces or Controller.
    :type round_timeout: str
    :param rounds: The number of rounds. Set by user interfaces.
    :type rounds: int
    :param model_id: The model identifier. Set by user interfaces or Controller (get_latest_model).
    :type model_id: str
    :param model_version: The model version. Currently not used.
    :type model_version: str
    :param model_type: The model type. Currently not used.
    :type model_type: str
    :param model_size: The size of the model. Currently not used.
    :type model_size: int
    :param model_parameters: The model parameters. Currently not used.
    :type model_parameters: dict
    :param model_metadata: The model metadata. Currently not used.
    :type model_metadata: dict
    :param session_id: The session identifier. Set by (Controller?).
    :type session_id: str
    :param helper_type: The helper type.
    :type helper_type: str
    :param aggregator: The aggregator type.
    :type aggregator: str
    """

    _job_id: str
    committed_at: str
    task: str
    round_id: str
    round_timeout: str
    rounds: int
    model_id: str
    model_version: str
    model_type: str
    model_size: int
    model_parameters: dict
    model_metadata: dict
    session_id: str
    helper_type: str
    aggregator: str
class ModelUpdateError(Exception):
    """Raised when a model update cannot be loaded."""
class RoundHandler:
    """Round handler.

    The round handler processes requests from the global controller
    to produce model updates and perform model validations.

    The aggregator is not configured by the constructor; call
    :meth:`set_aggregator` before running training rounds.

    :param storage: Model repository used to fetch and store models.
    :param server: A handle to the Combiner.
    :param modelservice: A handle to the model service.
    """

    def __init__(self, storage, server, modelservice):
        """Initialize the RoundHandler."""
        # Inbox of round configs pushed by the controller, consumed by run().
        self.round_configs = queue.Queue()
        self.storage = storage
        self.server = server
        self.modelservice = modelservice

    def set_aggregator(self, aggregator):
        """Create the aggregator used to combine model updates.

        :param aggregator: Identifier passed to ``get_aggregator``
            (presumably the name of the aggregator plugin module).
        :type aggregator: str
        """
        self.aggregator = get_aggregator(aggregator, self.storage, self.server, self.modelservice, self)

    def push_round_config(self, round_config: "RoundConfig") -> str:
        """Add a round_config (job description) to the inbox.

        The config is stamped in place with a fresh ``_job_id`` before it is queued.

        :param round_config: A dict containing the round configuration (from global controller).
        :type round_config: dict
        :return: A job id (universally unique identifier) for the round.
        :rtype: str
        """
        try:
            round_config["_job_id"] = str(uuid.uuid4())
            self.round_configs.put(round_config)
        except Exception:
            logger.error("Failed to push round config.")
            raise
        return round_config["_job_id"]

    def load_model_update(self, helper, model_id):
        """Load model update with id model_id into its memory representation.

        :param helper: An instance of :class:`fedn.utils.helpers.helpers.HelperBase`
        :type helper: class: `fedn.utils.helpers.helpers.HelperBase`
        :param model_id: The ID of the model update, UUID in str format
        :type model_id: str
        :return: The deserialized model.
        :raises ModelUpdateError: If the model update cannot be fetched or deserialized.
        """
        model_str = self.load_model_update_str(model_id)
        if not model_str:
            raise ModelUpdateError("Failed to load model.")

        try:
            return load_model_from_BytesIO(model_str.getbuffer(), helper)
        except IOError as err:
            # Previously this path only logged and then failed on an unbound
            # local at ``return``; surface the failure explicitly instead.
            logger.warning("ROUNDHANDLER: Failed to load model update {}!".format(model_id))
            raise ModelUpdateError("Failed to load model.") from err

    def load_model_update_str(self, model_id, retry=3):
        """Load model update object and return it as BytesIO.

        :param model_id: The ID of the model
        :type model_id: str
        :param retry: number of times retrying load model update, defaults to 3
        :type retry: int, optional
        :return: Updated model
        :rtype: class: `io.BytesIO`
        """
        # Try reading model update from local disk/combiner memory
        model_str = self.modelservice.temp_model_storage.get(model_id)
        # And if we cannot access that, try downloading from the server
        if model_str is None:
            model_str = self.modelservice.get_model(model_id)
            # TODO: use retrying library
            for _ in range(retry):
                # NOTE(review): the ``getsizeof(...) == 80`` test looks like a
                # heuristic for an empty buffer; it is interpreter-specific.
                if model_str and sys.getsizeof(model_str) != 80:
                    break
                logger.warning("Model download failed. retrying")
                time.sleep(1)
                model_str = self.modelservice.get_model(model_id)

        return model_str

    def waitforit(self, config, buffer_size=100, polling_interval=0.1):
        """Defines the policy for how long the server should wait before starting to aggregate models.

        The policy is as follows:
            1. Wait a maximum of time_window time until the round times out.
            2. Terminate if a preset number of model updates (buffer_size) are in the queue.

        The elapsed time is approximated by summing the polling intervals slept.

        :param config: The round config object
        :type config: dict
        :param buffer_size: The number of model updates to wait for before starting aggregation, defaults to 100
        :type buffer_size: int, optional
        :param polling_interval: The polling interval, defaults to 0.1
        :type polling_interval: float, optional
        """
        time_window = float(config["round_timeout"])

        tt = 0.0
        while tt < time_window:
            if self.aggregator.model_updates.qsize() >= buffer_size:
                break

            time.sleep(polling_interval)
            tt += polling_interval

    def _training_round(self, config, clients):
        """Send model update requests to clients and aggregate results.

        :param config: The round config object (passed to the client).
        :type config: dict
        :param clients: clients to participate in the training round
        :type clients: list
        :return: an aggregated model and associated metadata
        :rtype: model, dict
        """
        logger.info("ROUNDHANDLER: Initiating training round, participating clients: {}".format(clients))

        meta = {
            "nr_expected_updates": len(clients),
            "nr_required_updates": int(config["clients_required"]),
            "timeout": float(config["round_timeout"]),
        }

        session_id = config["session_id"]
        model_id = config["model_id"]

        # Request model updates from all active clients.
        self.server.request_model_update(session_id=session_id, model_id=model_id, config=config, clients=clients)

        # If buffer_size is -1 (default), the round terminates when/if all clients have completed.
        buffer_size = int(config["buffer_size"])
        if buffer_size == -1:
            buffer_size = len(clients)

        # Wait / block until the round termination policy has been met.
        self.waitforit(config, buffer_size=buffer_size)

        tic = time.time()
        model = None
        data = None

        try:
            helper = get_helper(config["helper_type"])
            logger.info("Config delete_models_storage: {}".format(config["delete_models_storage"]))
            # The flag arrives as a string; only the exact value "True" enables deletion.
            delete_models = config["delete_models_storage"] == "True"

            if "aggregator_kwargs" in config:
                # literal_eval only accepts Python literals, so this does not execute code.
                dict_parameters = ast.literal_eval(config["aggregator_kwargs"])
                parameters = Parameters(dict_parameters)
            else:
                parameters = None

            model, data = self.aggregator.combine_models(helper=helper, delete_models=delete_models, parameters=parameters)
        except Exception as e:
            # Best effort: a failed aggregation is reported to the caller as model=None.
            logger.warning("AGGREGATION FAILED AT COMBINER! {}".format(e))

        meta["time_combination"] = time.time() - tic
        meta["aggregation_time"] = data
        return model, meta

    def _validation_round(self, session_id, model_id, clients):
        """Send model validation requests to clients.

        :param session_id: The session identifier, forwarded with the request.
        :type session_id: str
        :param model_id: The ID of the model to validate
        :type model_id: str
        :param clients: clients to send validation requests to
        :type clients: list
        """
        self.server.request_model_validation(session_id, model_id, clients=clients)

    def _inference_round(self, session_id: str, model_id: str, clients: list):
        """Send model inference requests to clients.

        :param session_id: The session identifier, forwarded with the request.
        :type session_id: str
        :param model_id: The ID of the model to use for inference
        :type model_id: str
        :param clients: clients to send inference requests to
        :type clients: list
        """
        self.server.request_model_inference(session_id, model_id, clients=clients)

    def stage_model(self, model_id, timeout_retry=3, retry=2):
        """Download a model from persistent storage and set in modelservice.

        :param model_id: ID of the model update object to stage.
        :type model_id: str
        :param timeout_retry: Sleep before retrying download again(sec), defaults to 3
        :type timeout_retry: int, optional
        :param retry: Number of retries, defaults to 2
        :type retry: int, optional
        :raises ModelUpdateError: If storage keeps returning no data for the model.
        :raises Exception: Re-raises the storage backend error once retries are exhausted.
        """
        # If the model is already in memory at the server we do not need to do anything.
        if self.modelservice.temp_model_storage.exist(model_id):
            logger.info("Model already exists in memory, skipping model staging.")
            return
        logger.info("Model Staging, fetching model from storage...")

        # If not, download it and stage it in memory at the combiner.
        attempts = max(retry, 0) + 1
        for attempt in range(attempts):
            last_attempt = attempt == attempts - 1
            try:
                # NOTE(review): the storage call was lost in extraction;
                # reconstructed as ``get_model_stream`` — confirm against the repository API.
                model = self.storage.get_model_stream(model_id)
            except Exception:
                logger.warning("Could not fetch model from storage backend, retrying.")
                if last_attempt:
                    logger.error("Failed to stage model {} from storage backend!".format(model_id))
                    raise
            else:
                if model:
                    break
                # An empty result counts as a failed attempt too; otherwise a
                # backend that returns nothing would make this loop spin forever.
                logger.warning("Could not fetch model from storage backend, retrying.")
                if last_attempt:
                    logger.error("Failed to stage model {} from storage backend!".format(model_id))
                    raise ModelUpdateError("Failed to stage model {}: storage returned no data.".format(model_id))
            time.sleep(timeout_retry)

        self.modelservice.set_model(model, model_id)

    def _assign_round_clients(self, n, type="trainers"):
        """Obtain a list of clients(trainers or validators) to ask for updates in this round.

        :param n: Size of a random set taken from active trainers(clients), if n > "active trainers" all is used
        :type n: int
        :param type: type of clients, either "trainers" or "validators", defaults to "trainers"
        :type type: str, optional
        :return: Set of clients
        :rtype: list
        :raises ValueError: If ``type`` is not a supported client type.
        """
        # ``type`` shadows the builtin but is part of the keyword interface used by callers.
        if type == "validators":
            clients = self.server.get_active_validators()
        elif type == "trainers":
            clients = self.server.get_active_trainers()
        else:
            logger.error("(ERROR): {} is not a supported type of client".format(type))
            raise ValueError("{} is not a supported type of client".format(type))

        # If the number of requested trainers exceeds the number of available, use all available.
        n = min(n, len(clients))

        # If not, we pick a random subsample of all available clients.
        return random.sample(clients, n)

    def _check_nr_round_clients(self, config):
        """Check that the minimal number of clients required to start a round are available.

        :param config: The round config object.
        :type config: dict
        :return: True if the required number of clients are available, False otherwise.
        :rtype: bool
        """
        active = self.server.nr_active_trainers()
        if active >= int(config["clients_required"]):
            logger.info("Number of clients required ({0}) to start round met {1}.".format(config["clients_required"], active))
            return True

        logger.info("Too few clients to start round.")
        return False

    def execute_validation_round(self, session_id, model_id):
        """Coordinate a validation round for the given model.

        :param session_id: The session identifier.
        :type session_id: str
        :param model_id: The ID of the model to validate.
        :type model_id: str
        """
        logger.info("COMBINER orchestrating validation of model {}".format(model_id))
        self.stage_model(model_id)
        validators = self._assign_round_clients(self.server.max_clients, type="validators")
        self._validation_round(session_id, model_id, validators)

    def execute_inference_round(self, session_id: str, model_id: str) -> None:
        """Coordinate an inference round for the given model.

        :param session_id: The session identifier.
        :type session_id: str
        :param model_id: The ID of the model to use for inference.
        :type model_id: str
        """
        logger.info("COMBINER orchestrating inference using model {}".format(model_id))
        self.stage_model(model_id)
        # TODO: Implement inference client type
        clients = self._assign_round_clients(self.server.max_clients, type="validators")
        self._inference_round(session_id, model_id, clients)

    def execute_training_round(self, config):
        """Coordinates clients to execute training tasks.

        :param config: The round config object.
        :type config: dict
        :return: metadata about the training round.
        :rtype: dict
        """
        logger.info("Processing training round, job_id {}".format(config["_job_id"]))

        data = {
            "config": config,
            "round_id": config["round_id"],
        }

        # Download model to update and set in temp storage.
        self.stage_model(config["model_id"])

        clients = self._assign_round_clients(self.server.max_clients)
        model, meta = self._training_round(config, clients)
        data["data"] = meta

        if model is None:
            logger.warning("\t Failed to update global model in round {0}!".format(config["round_id"]))
        else:
            helper = get_helper(config["helper_type"])
            a = serialize_model_to_BytesIO(model, helper)
            try:
                # NOTE(review): the storage call was lost in extraction;
                # reconstructed as ``set_model(a.read(), ...)`` — confirm against the repository API.
                model_id = self.storage.set_model(a.read(), is_file=False)
            finally:
                # Release the serialized buffer even if the upload fails.
                a.close()
            data["model_id"] = model_id

            logger.info("TRAINING ROUND COMPLETED. Aggregated model id: {}, Job id: {}".format(model_id, config["_job_id"]))

        # Delete temp model
        self.modelservice.temp_model_storage.delete(config["model_id"])
        return data

    def _process_round_config(self, round_config):
        """Execute the task described by one round config taken from the inbox."""
        session_id = round_config["session_id"]
        model_id = round_config["model_id"]

        # Check that the minimum allowed number of clients are connected
        if not self._check_nr_round_clients(round_config):
            logger.warning("Failed to meet client allocation requirements for this round config.")
            return

        task = round_config["task"]
        if task == "training":
            tic = time.time()
            round_meta = self.execute_training_round(round_config)
            round_meta["time_exec_training"] = time.time() - tic
            round_meta["status"] = "Success"
            # NOTE(review): the right-hand side was lost in extraction;
            # reconstructed as the combiner id — confirm.
            round_meta["name"] = self.server.id
            self.server.statestore.set_round_combiner_data(round_meta)
        elif task == "validation":
            self.execute_validation_round(session_id, model_id)
        elif task == "inference":
            self.execute_inference_round(session_id, model_id)
        else:
            logger.warning("config contains unkown task type.")

    def run(self, polling_interval=1.0):
        """Main control loop. Execute rounds based on round config on the queue.

        Runs until interrupted by ``KeyboardInterrupt`` or ``SystemExit``.

        :param polling_interval: The polling interval in seconds for checking if a new job/config is available.
        :type polling_interval: float
        """
        try:
            while True:
                # Only the non-blocking get is guarded, so a queue.Empty raised
                # while processing a round is not mistaken for an empty inbox.
                try:
                    round_config = self.round_configs.get(block=False)
                except queue.Empty:
                    time.sleep(polling_interval)
                    continue

                self._process_round_config(round_config)
                self.round_configs.task_done()
        except (KeyboardInterrupt, SystemExit):
            pass