Source code for fedn.network.api.client

import inspect
import os

import requests

from fedn.network.combiner.hooks.serverfunctionsbase import ServerFunctionsBase

__all__ = ["APIClient"]


class APIClient:
    """An API client for interacting with the statestore and controller.

    :param host: The host of the api server.
    :type host: str
    :param port: The port of the api server. When falsy, no port is put in the URL.
    :type port: int
    :param secure: Whether to use https.
    :type secure: bool
    :param verify: Whether to verify the server certificate. Defaults to False.
    :type verify: bool
    :param token: Authentication token. When not given, the ``FEDN_AUTH_TOKEN``
        environment variable is used (if set).
    :type token: str
    :param auth_scheme: Scheme placed before the token in the ``Authorization``
        header. When not given, ``FEDN_AUTH_SCHEME`` is used, falling back to
        ``"Bearer"``.
    :type auth_scheme: str
    """

    def __init__(self, host, port=None, secure=False, verify=False, token=None, auth_scheme=None):
        self.host = host
        self.port = port
        self.secure = secure
        self.verify = verify
        self.headers = {}
        # An auth scheme passed as argument overrides the environment variable.
        # "Bearer" is the default when neither is provided.
        if not auth_scheme:
            auth_scheme = os.environ.get("FEDN_AUTH_SCHEME", "Bearer")
        # A token passed as argument overrides the environment variable.
        if not token:
            token = os.environ.get("FEDN_AUTH_TOKEN", False)
        # Only send an Authorization header when a token is actually available.
        if token:
            self.headers = {"Authorization": f"{auth_scheme} {token}"}

    def _get_url(self, endpoint):
        """Build an absolute URL for ``endpoint`` on the configured host.

        :param endpoint: Path relative to the server root (no leading slash).
        :type endpoint: str
        :return: The full URL, using https when ``secure`` is set.
        :rtype: str
        """
        protocol = "https" if self.secure else "http"
        if self.port:
            return f"{protocol}://{self.host}:{self.port}/{endpoint}"
        return f"{protocol}://{self.host}/{endpoint}"

    def _get_url_api_v1(self, endpoint):
        """Build an absolute URL for ``endpoint`` under the ``api/v1`` prefix.

        :param endpoint: Path relative to ``api/v1`` (no leading slash).
        :type endpoint: str
        :return: The full URL.
        :rtype: str
        """
        return self._get_url(f"api/v1/{endpoint}")

    # --- Clients --- #
[docs] def get_client(self, id: str): """Get a client from the statestore. :param id: The client id to get. :type id: str :return: Client. :rtype: dict """ response = requests.get(self._get_url_api_v1(f"clients/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json
[docs] def get_clients(self, n_max: int = None): """Get clients from the statestore. :param n_max: The maximum number of clients to get (If none all will be fetched). :type n_max: int return: Clients. rtype: dict """ _headers = self.headers.copy() if n_max: _headers["X-Limit"] = str(n_max) response = requests.get(self._get_url_api_v1("clients/"), verify=self.verify, headers=_headers) _json = response.json() return _json
[docs] def get_clients_count(self): """Get the number of clients in the statestore. :return: The number of clients. :rtype: dict """ response = requests.get(self._get_url_api_v1("clients/count"), verify=self.verify, headers=self.headers) _json = response.json() return _json
[docs] def get_client_config(self, checksum=True): """Get client config from controller. Optionally include the checksum. The config is used for clients to connect to the controller and ask for combiner assignment. :param checksum: Whether to include the checksum of the package. :type checksum: bool :return: The client configuration. :rtype: dict """ _params = {"checksum": "true" if checksum else "false"} response = requests.get(self._get_url("get_client_config"), params=_params, verify=self.verify, headers=self.headers) _json = response.json() return _json
[docs] def get_active_clients(self, combiner_id: str = None, n_max: int = None): """Get active clients from the statestore. :param combiner_id: The combiner id to get active clients for. :type combiner_id: str :param n_max: The maximum number of clients to get (If none all will be fetched). :type n_max: int :return: Active clients. :rtype: dict """ _params = {"status": "online"} if combiner_id: _params["combiner"] = combiner_id _headers = self.headers.copy() if n_max: _headers["X-Limit"] = str(n_max) response = requests.get(self._get_url_api_v1("clients/"), params=_params, verify=self.verify, headers=_headers) _json = response.json() return _json
# --- Combiners --- #
[docs] def get_combiner(self, id: str): """Get a combiner from the statestore. :param id: The combiner id to get. :type id: str :return: Combiner. :rtype: dict """ response = requests.get(self._get_url_api_v1(f"combiners/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json
[docs] def get_combiners(self, n_max: int = None): """Get combiners in the network. :param n_max: The maximum number of combiners to get (If none all will be fetched). :type n_max: int :return: Combiners. :rtype: dict """ _headers = self.headers.copy() if n_max: _headers["X-Limit"] = str(n_max) response = requests.get(self._get_url_api_v1("combiners/"), verify=self.verify, headers=_headers) _json = response.json() return _json
[docs] def get_combiners_count(self): """Get the number of combiners in the statestore. :return: The number of combiners. :rtype: dict """ response = requests.get(self._get_url_api_v1("combiners/count"), verify=self.verify, headers=self.headers) _json = response.json() return _json
# --- Controllers --- #
[docs] def get_controller_status(self): """Get the status of the controller. :return: The status of the controller. :rtype: dict """ response = requests.get(self._get_url("get_controller_status"), verify=self.verify, headers=self.headers) _json = response.json() return _json
# --- Models --- #
[docs] def get_model(self, id: str): """Get a model from the statestore. :param id: The id (or model property) of the model to get. :type id: str :return: Model. :rtype: dict """ response = requests.get(self._get_url_api_v1(f"models/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json
[docs] def get_models(self, session_id: str = None, n_max: int = None): """Get models from the statestore. :param session_id: The session id to get models for. (optional) :type session_id: str :param n_max: The maximum number of models to get (If none all will be fetched). :type n_max: int :return: Models. :rtype: dict """ _params = {} if session_id: _params["session_id"] = session_id _headers = self.headers.copy() if n_max: _headers["X-Limit"] = str(n_max) response = requests.get(self._get_url_api_v1("models/"), params=_params, verify=self.verify, headers=_headers) _json = response.json() return _json
[docs] def get_models_count(self): """Get the number of models in the statestore. :return: The number of models. :rtype: dict """ response = requests.get(self._get_url_api_v1("models/count"), verify=self.verify, headers=self.headers) _json = response.json() return _json
[docs] def get_active_model(self): """Get the latest model from the statestore. :return: The latest model. :rtype: dict """ _headers = self.headers.copy() _headers["X-Limit"] = "1" response = requests.get(self._get_url_api_v1("models/"), verify=self.verify, headers=_headers) _json = response.json() if "result" in _json and len(_json["result"]) > 0: return _json["result"][0] return _json
[docs] def get_model_trail(self, id: str = None, include_self: bool = True, reverse: bool = True, n_max: int = None): """Get the model trail. :param id: The id (or model property) of the model to start the trail from. (optional) :type id: str :param n_max: The maximum number of models to get (If none all will be fetched). :type n_max: int :return: Models. :rtype: dict """ if not id: model = self.get_active_model() if "id" in model: id = model["id"] else: return model _headers = self.headers.copy() _count: int = n_max if n_max else self.get_models_count() _headers["X-Limit"] = str(_count) _headers["X-Reverse"] = "true" if reverse else "false" _include_self_str: str = "true" if include_self else "false" response = requests.get(self._get_url_api_v1(f"models/{id}/ancestors?include_self={_include_self_str}"), verify=self.verify, headers=_headers) _json = response.json() return _json
[docs] def download_model(self, id: str, path: str): """Download the model with id id. :param id: The id (or model property) of the model to download. :type id: str :param path: The path to download the model to. :type path: str :return: Message with success or failure. :rtype: dict """ response = requests.get(self._get_url_api_v1(f"models/{id}/download"), verify=self.verify, headers=self.headers) if response.status_code == 200: with open(path, "wb") as file: file.write(response.content) return {"success": True, "message": "Model downloaded successfully."} else: return {"success": False, "message": "Failed to download model."}
[docs] def set_active_model(self, path): """Set the initial model in the statestore and upload to model repository. :param path: The file path of the initial model to set. :type path: str :return: A dict with success or failure message. :rtype: dict """ if path.endswith(".npz"): helper = "numpyhelper" elif path.endswith(".bin"): helper = "binaryhelper" if helper: response = requests.put(self._get_url_api_v1("helpers/active"), json={"helper": helper}, verify=self.verify, headers=self.headers) with open(path, "rb") as file: response = requests.post( self._get_url("set_initial_model"), files={"file": file}, data={"helper": helper}, verify=self.verify, headers=self.headers ) return response.json()
# --- Packages --- #
[docs] def get_package(self, id: str): """Get a compute package from the statestore. :param id: The id of the compute package to get. :type id: str :return: Package. :rtype: dict """ response = requests.get(self._get_url_api_v1(f"packages/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json
[docs] def get_packages(self, n_max: int = None): """Get compute packages from the statestore. :param n_max: The maximum number of packages to get (If none all will be fetched). :type n_max: int :return: Packages. :rtype: dict """ _headers = self.headers.copy() if n_max: _headers["X-Limit"] = str(n_max) response = requests.get(self._get_url_api_v1("packages/"), verify=self.verify, headers=_headers) _json = response.json() return _json
[docs] def get_packages_count(self): """Get the number of compute packages in the statestore. :return: The number of packages. :rtype: dict """ response = requests.get(self._get_url_api_v1("packages/count"), verify=self.verify, headers=self.headers) _json = response.json() return _json
[docs] def get_active_package(self): """Get the (active) compute package from the statestore. :return: Package. :rtype: dict """ response = requests.get(self._get_url_api_v1("packages/active"), verify=self.verify, headers=self.headers) _json = response.json() return _json
[docs] def get_package_checksum(self): """Get the checksum of the compute package. :return: The checksum. :rtype: dict """ response = requests.get(self._get_url("get_package_checksum"), verify=self.verify, headers=self.headers) _json = response.json() return _json
[docs] def download_package(self, path: str): """Download the compute package. :param path: The path to download the compute package to. :type path: str :return: Message with success or failure. :rtype: dict """ response = requests.get(self._get_url("download_package"), verify=self.verify, headers=self.headers) if response.status_code == 200: with open(path, "wb") as file: file.write(response.content) return {"success": True, "message": "Package downloaded successfully."} else: return {"success": False, "message": "Failed to download package."}
[docs] def set_active_package(self, path: str, helper: str, name: str = None, description: str = None): """Set the compute package in the statestore. :param path: The file path of the compute package to set. :type path: str :param helper: The helper type to use. :type helper: str :return: A dict with success or failure message. :rtype: dict """ with open(path, "rb") as file: response = requests.post( self._get_url("set_package"), files={"file": file}, data={"helper": helper, "name": name, "description": description}, verify=self.verify, headers=self.headers, ) _json = response.json() return _json
# --- Rounds --- #
[docs] def get_round(self, id: str): """Get a round from the statestore. :param round_id: The round id to get. :type round_id: str :return: Round (config and metrics). :rtype: dict """ response = requests.get(self._get_url_api_v1(f"rounds/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json
[docs] def get_rounds(self, n_max: int = None): """Get all rounds from the statestore. :param n_max: The maximum number of rounds to get (If none all will be fetched). :type n_max: int :return: Rounds. :rtype: dict """ _headers = self.headers.copy() if n_max: _headers["X-Limit"] = str(n_max) response = requests.get(self._get_url_api_v1("rounds/"), verify=self.verify, headers=_headers) _json = response.json() return _json
[docs] def get_rounds_count(self): """Get the number of rounds in the statestore. :return: The number of rounds. :rtype: dict """ response = requests.get(self._get_url_api_v1("rounds/count"), verify=self.verify, headers=self.headers) _json = response.json() return _json
# --- Sessions --- #
[docs] def get_session(self, id: str): """Get a session from the statestore. :param id: The session id to get. :type id: str :return: Session. :rtype: dict """ response = requests.get(self._get_url_api_v1(f"sessions/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json
[docs] def get_sessions(self, n_max: int = None): """Get sessions from the statestore. :param n_max: The maximum number of sessions to get (If none all will be fetched). :type n_max: int :return: Sessions. :rtype: dict """ _headers = self.headers.copy() if n_max: _headers["X-Limit"] = str(n_max) response = requests.get(self._get_url_api_v1("sessions/"), verify=self.verify, headers=_headers) _json = response.json() return _json
[docs] def get_sessions_count(self): """Get the number of sessions in the statestore. :return: The number of sessions. :rtype: dict """ response = requests.get(self._get_url_api_v1("sessions/count"), verify=self.verify, headers=self.headers) _json = response.json() return _json
[docs] def get_session_status(self, id: str): """Get the status of a session. :param id: The id of the session to get. :type id: str :return: The status of the session. :rtype: str """ session = self.get_session(id) if session and "status" in session: return session["status"] return "Could not retrieve session status."
[docs] def session_is_finished(self, id: str): """Check if a session with id has finished. :param id: The id of the session to get. :type id: str :return: True if session is finished, otherwise false. :rtype: bool """ status = self.get_session_status(id) return status and status.lower() == "finished"
[docs] def start_session( self, id: str = None, aggregator: str = "fedavg", aggregator_kwargs: dict = None, model_id: str = None, round_timeout: int = 180, rounds: int = 5, round_buffer_size: int = -1, delete_models: bool = True, validate: bool = True, helper: str = "", min_clients: int = 1, requested_clients: int = 8, server_functions: ServerFunctionsBase = None, ): """Start a new session. :param id: The session id to start. :type id: str :param aggregator: The aggregator plugin to use. :type aggregator: str :param model_id: The id of the initial model. :type model_id: str :param round_timeout: The round timeout to use in seconds. :type round_timeout: int :param rounds: The number of rounds to perform. :type rounds: int :param round_buffer_size: The round buffer size to use. :type round_buffer_size: int :param delete_models: Whether to delete models after each round at combiner (save storage). :type delete_models: bool :param validate: Whether to validate the model after each round. :type validate: bool :param helper: The helper type to use. :type helper: str :param min_clients: The minimum number of clients required. :type min_clients: int :param requested_clients: The requested number of clients. :type requested_clients: int :return: A dict with success or failure message and session config. 
:rtype: dict """ if model_id is None: response = requests.get(self._get_url_api_v1("models/active"), verify=self.verify, headers=self.headers) if response.status_code == 200: model_id = response.json() else: return response.json() response = requests.post( self._get_url_api_v1("sessions/"), json={ "session_id": id, "session_config": { "aggregator": aggregator, "aggregator_kwargs": aggregator_kwargs, "round_timeout": round_timeout, "buffer_size": round_buffer_size, "model_id": model_id, "delete_models_storage": delete_models, "clients_required": min_clients, "requested_clients": requested_clients, "validate": validate, "helper_type": helper, "server_functions": None if server_functions is None else inspect.getsource(server_functions), }, }, verify=self.verify, headers=self.headers, ) if response.status_code == 201: if id is None: id = response.json()["session_id"] response = requests.post( self._get_url_api_v1("sessions/start"), json={ "session_id": id, "rounds": rounds, "round_timeout": round_timeout, }, verify=self.verify, headers=self.headers, ) _json = response.json() return _json
# --- Statuses --- #
[docs] def get_status(self, id: str): """Get a status object (event) from the statestore. :param id: The id of the status to get. :type id: str :return: Status. :rtype: dict """ response = requests.get(self._get_url_api_v1(f"statuses/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json
[docs] def get_statuses(self, session_id: str = None, event_type: str = None, sender_name: str = None, sender_role: str = None, n_max: int = None): """Get statuses from the statestore. Filter by input parameters :param session_id: The session id to get statuses for. :type session_id: str :param event_type: The event type to get. :type event_type: str :param sender_name: The sender name to get. :type sender_name: str :param sender_role: The sender role to get. :type sender_role: str :param n_max: The maximum number of statuses to get (If none all will be fetched). :type n_max: int :return: Statuses """ _params = {} if session_id: _params["session_id"] = session_id if event_type: _params["type"] = event_type if sender_name: _params["sender.name"] = sender_name if sender_role: _params["sender.role"] = sender_role _headers = self.headers.copy() if n_max: _headers["X-Limit"] = str(n_max) response = requests.get(self._get_url_api_v1("statuses/"), params=_params, verify=self.verify, headers=_headers) _json = response.json() return _json
[docs] def get_statuses_count(self): """Get the number of statuses in the statestore. :return: The number of statuses. :rtype: dict """ response = requests.get(self._get_url_api_v1("statuses/count"), verify=self.verify, headers=self.headers) _json = response.json() return _json
# --- Validations --- #
[docs] def get_validation(self, id: str): """Get a validation from the statestore. :param id: The id of the validation to get. :type id: str :return: Validation. :rtype: dict """ response = requests.get(self._get_url_api_v1(f"validations/{id}"), verify=self.verify, headers=self.headers) _json = response.json() return _json
[docs] def get_validations( self, session_id: str = None, model_id: str = None, correlation_id: str = None, sender_name: str = None, sender_role: str = None, receiver_name: str = None, receiver_role: str = None, n_max: int = None, ): """Get validations from the statestore. Filter by input parameters. :param session_id: The session id to get validations for. :type session_id: str :param model_id: The model id to get validations for. :type model_id: str :param correlation_id: The correlation id to get validations for. :type correlation_id: str :param sender_name: The sender name to get validations for. :type sender_name: str :param sender_role: The sender role to get validations for. :type sender_role: str :param receiver_name: The receiver name to get validations for. :type receiver_name: str :param receiver_role: The receiver role to get validations for. :type receiver_role: str :param n_max: The maximum number of validations to get (If none all will be fetched). :type n_max: int :return: Validations. :rtype: dict """ _params = {} if session_id: _params["sessionId"] = session_id if model_id: _params["modelId"] = model_id if correlation_id: _params["correlationId"] = correlation_id if sender_name: _params["sender.name"] = sender_name if sender_role: _params["sender.role"] = sender_role if receiver_name: _params["receiver.name"] = receiver_name if receiver_role: _params["receiver.role"] = receiver_role _headers = self.headers.copy() if n_max: _headers["X-Limit"] = str(n_max) response = requests.get(self._get_url_api_v1("validations/"), params=_params, verify=self.verify, headers=_headers) _json = response.json() return _json
[docs] def get_validations_count(self): """Get the number of validations in the statestore. :return: The number of validations. :rtype: dict """ response = requests.get(self._get_url_api_v1("validations/count"), verify=self.verify, headers=self.headers) _json = response.json() return _json