Source code for fedn.network.controller.controlbase

import os
import uuid
from abc import ABC, abstractmethod
from time import sleep

import fedn.utils.helpers.helpers
from fedn.common.log_config import logger
from fedn.network.api.network import Network
from fedn.network.combiner.interfaces import CombinerUnavailableError
from fedn.network.combiner.roundhandler import RoundConfig
from fedn.network.state import ReducerState
from fedn.network.storage.s3.repository import Repository

# Maximum number of tries to connect to statestore and retrieve storage configuration
MAX_TRIES_BACKEND = os.getenv("MAX_TRIES_BACKEND", 10)


[docs] class UnsupportedStorageBackend(Exception): pass
[docs] class MisconfiguredStorageBackend(Exception): pass
[docs] class MisconfiguredHelper(Exception): pass
[docs] class ControlBase(ABC): """Base class and interface for a global controller. Override this class to implement a global training strategy (control). :param statestore: The statestore object. :type statestore: :class:`fedn.network.statestore.statestorebase.StateStoreBase` """ @abstractmethod def __init__(self, statestore): """Constructor.""" self._state = ReducerState.setup self.statestore = statestore if self.statestore.is_inited(): self.network = Network(self, statestore) try: not_ready = True tries = 0 while not_ready: storage_config = self.statestore.get_storage_backend() if storage_config: not_ready = False else: logger.warning("Storage backend not configured, waiting...") sleep(5) tries += 1 if tries > MAX_TRIES_BACKEND: raise Exception except Exception: logger.error("Failed to retrive storage configuration, exiting.") raise MisconfiguredStorageBackend() if storage_config["storage_type"] == "S3": self.model_repository = Repository(storage_config["storage_config"]) else: logger.error("Unsupported storage backend, exiting.") raise UnsupportedStorageBackend() if self.statestore.is_inited(): self._state = ReducerState.idle
[docs] @abstractmethod def session(self, config): pass
[docs] @abstractmethod def round(self, config, round_number): pass
[docs] @abstractmethod def reduce(self, combiners): pass
[docs] def get_helper(self): """Get a helper instance from global config. :return: Helper instance. :rtype: :class:`fedn.utils.plugins.helperbase.HelperBase` """ helper_type = self.statestore.get_helper() helper = fedn.utils.helpers.helpers.get_helper(helper_type) if not helper: raise MisconfiguredHelper("Unsupported helper type {}, please configure compute_package.helper !".format(helper_type)) return helper
[docs] def get_state(self): """Get the current state of the controller. :return: The current state. :rtype: :class:`fedn.network.state.ReducerState` """ return self._state
[docs] def idle(self): """Check if the controller is idle. :return: True if idle, False otherwise. :rtype: bool """ if self._state == ReducerState.idle: return True else: return False
[docs] def get_model_info(self): """:return:""" return self.statestore.get_model_trail()
# TODO: remove use statestore.get_events() instead
[docs] def get_events(self): """:return:""" return self.statestore.get_events()
[docs] def get_latest_round_id(self): last_round = self.statestore.get_latest_round() if not last_round: return 0 else: return last_round["round_id"]
[docs] def get_latest_round(self): round = self.statestore.get_latest_round() return round
[docs] def get_compute_package_name(self): """:return:""" definition = self.statestore.get_compute_package() if definition: try: package_name = definition["storage_file_name"] return package_name except (IndexError, KeyError): logger.error("No context filename set for compute context definition") return None else: return None
[docs] def set_compute_package(self, filename, path): """Persist the configuration for the compute package.""" self.model_repository.set_compute_package(filename, path)
[docs] def get_compute_package(self, compute_package=""): """:param compute_package: :return: """ if compute_package == "": compute_package = self.get_compute_package_name() if compute_package: return self.model_repository.get_compute_package(compute_package) else: return None
[docs] def create_session(self, config: RoundConfig, status: str = "Initialized") -> None: """Initialize a new session in backend db.""" if "session_id" not in config.keys(): session_id = uuid.uuid4() config["session_id"] = str(session_id) else: session_id = config["session_id"] self.statestore.create_session(id=session_id) self.statestore.set_session_config(session_id, config) self.statestore.set_session_status(session_id, status)
[docs] def set_session_status(self, session_id, status): """Set the round round stats. :param round_id: The round unique identifier :type round_id: str :param status: The status :type status: str """ self.statestore.set_session_status(session_id, status)
[docs] def get_session_status(self, session_id): """Get the status of a session. :param session_id: The session unique identifier :type session_id: str :return: The status :rtype: str """ return self.statestore.get_session_status(session_id)
[docs] def set_session_config(self, session_id: str, config: dict): """Set the model id for a session. :param session_id: The session unique identifier :type session_id: str :param config: The session config :type config: dict """ self.statestore.set_session_config_v2(session_id, config)
[docs] def create_round(self, round_data): """Initialize a new round in backend db.""" self.statestore.create_round(round_data)
[docs] def set_round_data(self, round_id, round_data): """Set round data. :param round_id: The round unique identifier :type round_id: str :param round_data: The status :type status: dict """ self.statestore.set_round_data(round_id, round_data)
[docs] def set_round_status(self, round_id, status): """Set the round round stats. :param round_id: The round unique identifier :type round_id: str :param status: The status :type status: str """ self.statestore.set_round_status(round_id, status)
[docs] def set_round_config(self, round_id, round_config: RoundConfig): """Upate round in backend db. :param round_id: The round unique identifier :type round_id: str :param round_config: The round configuration :type round_config: dict """ self.statestore.set_round_config(round_id, round_config)
[docs] def request_model_updates(self, combiners): """Ask Combiner server to produce a model update. :param combiners: A list of combiners :type combiners: tuple (combiner, combiner_round_config) """ cl = [] for combiner, combiner_round_config in combiners: response = combiner.submit(combiner_round_config) cl.append((combiner, response)) return cl
[docs] def commit(self, model_id, model=None, session_id=None): """Commit a model to the global model trail. The model commited becomes the lastest consensus model. :param model_id: Unique identifier for the model to commit. :type model_id: str (uuid) :param model: The model object to commit :type model: BytesIO :param session_id: Unique identifier for the session :type session_id: str """ helper = self.get_helper() if model is not None: logger.info("Saving model file temporarily to disk...") outfile_name = helper.save(model) logger.info("CONTROL: Uploading model to Minio...") model_id = self.model_repository.set_model(outfile_name, is_file=True) logger.info("CONTROL: Deleting temporary model file...") os.unlink(outfile_name) logger.info("Committing model {} to global model trail in statestore...".format(model_id)) self.statestore.set_latest_model(model_id, session_id)
[docs] def get_combiner(self, name): for combiner in self.network.get_combiners(): if combiner.name == name: return combiner return None
[docs] def get_participating_combiners(self, combiner_round_config): """Assemble a list of combiners able to participate in a round as descibed by combiner_round_config. """ combiners = [] for combiner in self.network.get_combiners(): try: # Current gRPC endpoint only returns active clients (both trainers and validators) nr_active_clients = len(combiner.list_active_clients()) except CombinerUnavailableError: self._handle_unavailable_combiner(combiner) continue clients_required = int(combiner_round_config["clients_required"]) is_participating = self.evaluate_round_participation_policy(clients_required, nr_active_clients) if is_participating: combiners.append((combiner, combiner_round_config)) return combiners
[docs] def evaluate_round_participation_policy(self, clients_required: int, nr_active_clients: int) -> bool: """Evaluate policy for combiner round-participation. A combiner participates if it is responsive and reports enough active clients to participate in the round. """ if clients_required <= nr_active_clients: return True else: return False
[docs] def evaluate_round_start_policy(self, combiners: list): """Check if the policy to start a round is met. :param combiners: A list of combiners :type combiners: list :return: True if the round policy is mer, otherwise False :rtype: bool """ if len(combiners) > 0: return True else: return False
[docs] def evaluate_round_validity_policy(self, round): """Check if the round is valid. At the end of the round, before committing a model to the global model trail, we check if the round validity policy has been met. This can involve e.g. asserting that a certain number of combiners have reported in an updated model, or that criteria on model performance have been met. :param round: The round object :rtype round: dict :return: True if the policy is met, otherwise False :rtype: bool """ model_ids = [] for combiner in round["combiners"]: try: model_ids.append(combiner["model_id"]) except KeyError: pass if len(model_ids) == 0: return False return True
[docs] def state(self): """Get the current state of the controller. :return: The state :rype: str """ return self._state