"""Source code for fedn.network.combiner.round."""

import queue
import random
import sys
import time
import uuid

from fedn.common.log_config import logger
from fedn.network.combiner.aggregators.aggregatorbase import get_aggregator
from fedn.utils.helpers import get_helper


class ModelUpdateError(Exception):
    """Raised when a model update cannot be loaded or staged."""
    pass
class RoundController:
    """Round controller.

    The round controller receives round configurations from the global
    controller and coordinates model updates, aggregation and model
    validations.

    :param aggregator_name: The name of the aggregator plugin module.
    :type aggregator_name: str
    :param storage: Model repository for :class: `fedn.network.combiner.Combiner`
    :type storage: class: `fedn.common.storage.s3.s3repo.S3ModelRepository`
    :param server: A handle to the Combiner class :class: `fedn.network.combiner.Combiner`
    :type server: class: `fedn.network.combiner.Combiner`
    :param modelservice: A handle to the model service :class: `fedn.network.combiner.modelservice.ModelService`
    :type modelservice: class: `fedn.network.combiner.modelservice.ModelService`
    """

    def __init__(self, aggregator_name, storage, server, modelservice):
        """Initialize the RoundController."""
        # Inbox of round configurations posted by the global controller.
        self.round_configs = queue.Queue()
        self.storage = storage
        self.server = server
        self.modelservice = modelservice
        # The aggregator plugin is resolved by name and is handed references
        # to the combiner internals it needs to operate.
        self.aggregator = get_aggregator(
            aggregator_name, self.storage, self.server, self.modelservice, self)
[docs] def push_round_config(self, round_config): """Add a round_config (job description) to the inbox. :param round_config: A dict containing the round configuration (from global controller). :type round_config: dict :return: A job id (universally unique identifier) for the round. :rtype: str """ try: round_config['_job_id'] = str(uuid.uuid4()) self.round_configs.put(round_config) except Exception: logger.warning( "ROUNDCONTROL: Failed to push round config.") raise return round_config['_job_id']
[docs] def load_model_update(self, helper, model_id): """Load model update in its native format. :param helper: An instance of :class: `fedn.utils.helpers.HelperBase`, ML framework specific helper, defaults to None :type helper: class: `fedn.utils.helpers.HelperBase` :param model_id: The ID of the model update, UUID in str format :type model_id: str """ model_str = self.load_model_update_str(model_id) if model_str: try: model = self.modelservice.load_model_from_BytesIO(model_str.getbuffer(), helper) except IOError: logger.warning( "AGGREGATOR({}): Failed to load model!".format(self.name)) else: raise ModelUpdateError("Failed to load model.") return model
[docs] def load_model_update_str(self, model_id, retry=3): """Load model update object and return it as BytesIO. :param model_id: The ID of the model :type model_id: str :param retry: number of times retrying load model update, defaults to 3 :type retry: int, optional :return: Updated model :rtype: class: `io.BytesIO` """ # Try reading model update from local disk/combiner memory model_str = self.modelservice.models.get(model_id) # And if we cannot access that, try downloading from the server if model_str is None: model_str = self.modelservice.get_model(model_id) # TODO: use retrying library tries = 0 while tries < retry: tries += 1 if not model_str or sys.getsizeof(model_str) == 80: logger.warning( "ROUNDCONTROL: Model download failed. retrying") time.sleep(1) model_str = self.modelservice.get_model(model_id) return model_str
[docs] def waitforit(self, config, buffer_size=100, polling_interval=0.1): """ Defines the policy for how long the server should wait before starting to aggregate models. The policy is as follows: 1. Wait a maximum of time_window time until the round times out. 2. Terminate if a preset number of model updates (buffer_size) are in the queue. :param config: The round config object :type config: dict :param buffer_size: The number of model updates to wait for before starting aggregation, defaults to 100 :type buffer_size: int, optional :param polling_interval: The polling interval, defaults to 0.1 :type polling_interval: float, optional """ time_window = float(config['round_timeout']) tt = 0.0 while tt < time_window: if self.aggregator.model_updates.qsize() >= buffer_size: break time.sleep(polling_interval) tt += polling_interval
def _training_round(self, config, clients): """Send model update requests to clients and aggregate results. :param config: The round config object (passed to the client). :type config: dict :param clients: clients to participate in the training round :type clients: list :return: an aggregated model and associated metadata :rtype: model, dict """ logger.info( "ROUNDCONTROL: Initiating training round, participating clients: {}".format(clients)) meta = {} meta['nr_expected_updates'] = len(clients) meta['nr_required_updates'] = int(config['clients_required']) meta['timeout'] = float(config['round_timeout']) # Request model updates from all active clients. self.server.request_model_update(config, clients=clients) # If buffer_size is -1 (default), the round terminates when/if all clients have completed. if int(config['buffer_size']) == -1: buffer_size = len(clients) else: buffer_size = int(config['buffer_size']) # Wait / block until the round termination policy has been met. self.waitforit(config, buffer_size=buffer_size) tic = time.time() model = None data = None try: helper = get_helper(config['helper_type']) logger.info("ROUNDCONTROL: Config delete_models_storage: {}".format(config['delete_models_storage'])) if config['delete_models_storage'] == 'True': delete_models = True else: delete_models = False model, data = self.aggregator.combine_models(helper=helper, delete_models=delete_models) except Exception as e: logger.warning("AGGREGATION FAILED AT COMBINER! {}".format(e)) meta['time_combination'] = time.time() - tic meta['aggregation_time'] = data return model, meta def _validation_round(self, config, clients, model_id): """Send model validation requests to clients. :param config: The round config object (passed to the client). :type config: dict :param clients: clients to send validation requests to :type clients: list :param model_id: The ID of the model to validate :type model_id: str """ self.server.request_model_validation(model_id, config, clients)
[docs] def stage_model(self, model_id, timeout_retry=3, retry=2): """Download a model from persistent storage and set in modelservice. :param model_id: ID of the model update object to stage. :type model_id: str :param timeout_retry: Sleep before retrying download again(sec), defaults to 3 :type timeout_retry: int, optional :param retry: Number of retries, defaults to 2 :type retry: int, optional """ # If the model is already in memory at the server we do not need to do anything. if self.modelservice.models.exist(model_id): logger.info("ROUNDCONTROL: Model already exists in memory, skipping model staging.") return logger.info("ROUNDCONTROL: Model Staging, fetching model from storage...") # If not, download it and stage it in memory at the combiner. tries = 0 while True: try: model = self.storage.get_model_stream(model_id) if model: break except Exception: logger.info("ROUNDCONTROL: Could not fetch model from storage backend, retrying.") time.sleep(timeout_retry) tries += 1 if tries > retry: logger.info( "ROUNDCONTROL: Failed to stage model {} from storage backend!".format(model_id)) raise self.modelservice.set_model(model, model_id)
def _assign_round_clients(self, n, type="trainers"): """ Obtain a list of clients(trainers or validators) to ask for updates in this round. :param n: Size of a random set taken from active trainers(clients), if n > "active trainers" all is used :type n: int :param type: type of clients, either "trainers" or "validators", defaults to "trainers" :type type: str, optional :return: Set of clients :rtype: list """ if type == "validators": clients = self.server.get_active_validators() elif type == "trainers": clients = self.server.get_active_trainers() else: logger.info( "ROUNDCONTROL(ERROR): {} is not a supported type of client".format(type)) raise # If the number of requested trainers exceeds the number of available, use all available. if n > len(clients): n = len(clients) # If not, we pick a random subsample of all available clients. clients = random.sample(clients, n) return clients def _check_nr_round_clients(self, config, timeout=0.0): """Check that the minimal number of clients required to start a round are available. :param config: The round config object. :type config: dict :param timeout: Timeout in seconds, defaults to 0.0 :type timeout: float, optional :return: True if the required number of clients are available, False otherwise. :rtype: bool """ ready = False t = 0.0 while not ready: active = self.server.nr_active_trainers() if active >= int(config['clients_requested']): return True else: logger.info("waiting for {} clients to get started, currently: {}".format( int(config['clients_requested']) - active, active)) if t >= timeout: if active >= int(config['clients_required']): return True else: return False time.sleep(1.0) t += 1.0 return ready
[docs] def execute_validation_round(self, round_config): """ Coordinate validation rounds as specified in config. :param round_config: The round config object. :type round_config: dict """ model_id = round_config['model_id'] logger.info( "COMBINER orchestrating validation of model {}".format(model_id)) self.stage_model(model_id) validators = self._assign_round_clients( self.server.max_clients, type="validators") self._validation_round(round_config, validators, model_id)
[docs] def execute_training_round(self, config): """ Coordinates clients to execute training tasks. :param config: The round config object. :type config: dict :return: metadata about the training round. :rtype: dict """ logger.info( "ROUNDCONTROL: Processing training round, job_id {}".format(config['_job_id'])) data = {} data['config'] = config data['round_id'] = config['round_id'] # Make sure the model to update is available on this combiner. self.stage_model(config['model_id']) clients = self._assign_round_clients(self.server.max_clients) model, meta = self._training_round(config, clients) data['data'] = meta if model is None: logger.warning( "\t Failed to update global model in round {0}!".format(config['round_id'])) if model is not None: helper = get_helper(config['helper_type']) a = self.modelservice.serialize_model_to_BytesIO(model, helper) # Send aggregated model to server model_id = str(uuid.uuid4()) self.modelservice.set_model(a, model_id) a.close() data['model_id'] = model_id logger.info( "ROUNDCONTROL: TRAINING ROUND COMPLETED. Aggregated model id: {}, Job id: {}".format(model_id, config['_job_id'])) return data
    def run(self, polling_interval=1.0):
        """ Main control loop. Execute rounds based on round config on the queue.

        :param polling_interval: The polling interval in seconds for checking if a new job/config is available.
        :type polling_interval: float
        """
        try:
            while True:
                try:
                    # Non-blocking pop: an empty inbox raises queue.Empty,
                    # which is handled below by sleeping and polling again.
                    round_config = self.round_configs.get(block=False)

                    # Check that the minimum allowed number of clients are connected
                    ready = self._check_nr_round_clients(round_config)
                    round_meta = {}

                    if ready:
                        if round_config['task'] == 'training':
                            tic = time.time()
                            round_meta = self.execute_training_round(round_config)
                            round_meta['time_exec_training'] = time.time() - \
                                tic
                            round_meta['status'] = "Success"
                            round_meta['name'] = self.server.id
                            # Persist per-combiner round results to the state store.
                            self.server.statestore.set_round_combiner_data(round_meta)
                        elif round_config['task'] == 'validation' or round_config['task'] == 'inference':
                            # Validation and inference share the same flow here.
                            self.execute_validation_round(round_config)
                        else:
                            logger.warning(
                                "ROUNDCONTROL: Round config contains unkown task type.")
                    else:
                        # Not enough clients: record the failure but keep the loop alive.
                        round_meta = {}
                        round_meta['status'] = "Failed"
                        round_meta['reason'] = "Failed to meet client allocation requirements for this round config."
                        logger.warning(
                            "ROUNDCONTROL: {0}".format(round_meta['reason']))

                    self.round_configs.task_done()
                except queue.Empty:
                    time.sleep(polling_interval)

        except (KeyboardInterrupt, SystemExit):
            # Graceful shutdown on interrupt signals.
            pass