import inspect
import os
import re
from datetime import datetime
from typing import Callable, Dict, Optional
import uuid
import requests
from scaleoututil.auth.token_manager import TokenManager
from scaleoututil.auth.token_cache import TokenCache
from scaleoututil.logging import ScaleoutLogger
from scaleoututil.serverfunctions.serverfunctionsbase import ServerFunctionsBase
from scaleoututil.utils.url import parse_url
[docs]
class Scaleout:
"""Python client for interacting with the statestore and controller.
:param host: The host of the api server.
:type host: str
:param port: The port of the api server.
:type port: int
:param secure: Whether to use https.
:type secure: bool
:param verify: Whether to verify the server certificate.
:type verify: bool
"""
[docs]
def __init__(
    self,
    host: str,
    port: int = None,
    secure: Optional[bool] = None,
    verify: bool = True,
    token: str = None,
    auth_scheme: str = None,
    token_endpoint: str = None,
    access_token_provider: Optional[Callable[[], Optional[str]]] = None,
):
    """Initialize the Scaleout client.

    :param host: The host of the api server. It is passed through ``parse_url``; a protocol found in it overrides ``secure``.
    :type host: str
    :param port: The port of the api server. Takes precedence over a port embedded in ``host``.
    :type port: int
    :param secure: Whether to use https. If omitted it is inferred (https when the port is 443).
    :type secure: bool
    :param verify: Whether to verify the server SSL certificate (default: True). Set to False for development with self-signed certificates.
    :type verify: bool
    :param token: Refresh token for automatic token management. Can also be set via SCALEOUT_AUTH_TOKEN env var.
        If neither is given, a refresh token stored in the token cache is used when present.
    :type token: str
    :param auth_scheme: Authorization scheme (e.g., 'Bearer'). Falls back to the SCALEOUT_AUTH_SCHEME env var, then 'Bearer'.
    :type auth_scheme: str
    :param token_endpoint: Token refresh endpoint URL. If not provided, will be constructed from host.
    :type token_endpoint: str
    :param access_token_provider: Optional zero-argument callable returning an access token (or None).
        When set, request headers are built from it instead of from the token manager.
    :type access_token_provider: Callable[[], Optional[str]]
    """
    _protocol, _host, _port, _ = parse_url(host)
    if _protocol is not None and secure is not None:
        ScaleoutLogger().warning("Both protocol in host and secure parameter provided. Using protocol from host.")
    if _protocol is not None:
        # An explicit protocol in the host string wins over the secure flag.
        secure = _protocol == "https"
    host = _host
    if _port is not None and port is not None and _port != port:
        ScaleoutLogger().warning(f"Both port in host URL (:{_port}) and port parameter ({port}) provided. Using port parameter: {port}")
    port = port if port else _port
    if secure is None:
        if port == 443:
            secure = True
            ScaleoutLogger().debug("Port 443 detected, automatically using HTTPS. To use HTTP explicitly, set secure=False.")
        else:
            secure = False
    if port is None:
        port = 443 if secure else 80
        ScaleoutLogger().debug(f"No port specified, using default port {port} for {'HTTPS' if secure else 'HTTP'}.")
    self.host = host
    self.port = port
    self.secure = secure
    self.verify = verify
    self.access_token_provider = access_token_provider
    self.auth_scheme = auth_scheme or os.environ.get("SCALEOUT_AUTH_SCHEME", "Bearer")
    self.token_manager: Optional[TokenManager] = None
    self.headers = {}
    # Get token from args or env
    if not token:
        token = os.environ.get("SCALEOUT_AUTH_TOKEN", None)
    # The cache id is derived from host and port; characters outside
    # [A-Za-z0-9_.-] in the host are replaced with underscores.
    safe_host = re.sub(r"[^\w.\-]", "_", host)
    token_cache_id = f"api-client-{safe_host}-{port}"
    token_cache = TokenCache(cache_id=token_cache_id, cache_dir=os.environ.get("SCALEOUT_TOKEN_CACHE_DIR", None))
    if not token and token_cache.exists():
        cached_data = token_cache.load()
        if cached_data and cached_data.get("refresh_token"):
            token = cached_data["refresh_token"]
            ScaleoutLogger().info(f"Loaded refresh token from cache: {token_cache.cache_file}")
    # Clean token if it has a scheme prefix ("<scheme> <token>"). A token that
    # only carries stray surrounding whitespace has a single part, so indexing
    # [1] unconditionally would raise IndexError.
    if token and " " in token:
        token_parts = token.split()
        if len(token_parts) > 1:
            token = token_parts[1]
        else:
            token = token_parts[0] if token_parts else None
    # Initialize TokenManager with refresh token
    if token:
        # Construct token endpoint if not provided
        if not token_endpoint:
            token_endpoint = self._get_url("api/auth/refresh")

        # Create token refresh callback to save tokens to cache
        def on_token_refresh(access_token: str, refresh_token: str, expires_at: datetime) -> None:
            """Callback to save tokens when they are refreshed."""
            try:
                token_cache.save(access_token, refresh_token, expires_at)
                ScaleoutLogger().debug(f"API client tokens updated in cache: {token_cache.cache_file}")
            except Exception as e:
                # Caching is best-effort; a failed save must not break the refresh.
                ScaleoutLogger().warning(f"Failed to save API client tokens to cache: {e}")

        # TokenManager will automatically fetch the first access token using the refresh token
        try:
            self.token_manager = TokenManager(
                refresh_token=token,
                token_endpoint=token_endpoint,
                verify_ssl=verify,
                role="admin",
                on_token_refresh=on_token_refresh,
            )
        except (RuntimeError, requests.exceptions.RequestException) as e:
            ScaleoutLogger().warning(
                f"Token authentication failed ({e}). Continuing without authentication — this is expected if the target environment has no auth system."
            )
            self.token_manager = None
def _get_url(self, endpoint):
    """Build an absolute URL for ``endpoint`` on the configured server.

    The scheme follows ``self.secure``; the port is only written into the
    URL when ``self.port`` is truthy.

    :param endpoint: Path relative to the server root (no leading slash).
    :return: The full URL.
    :rtype: str
    """
    scheme = "https" if self.secure else "http"
    authority = f"{self.host}:{self.port}" if self.port else f"{self.host}"
    return f"{scheme}://{authority}/{endpoint}"
def _get_url_api_v1(self, endpoint):
    """Build the URL of ``endpoint`` under the ``api/v1/`` prefix.

    :param endpoint: Path relative to ``api/v1/``.
    :return: The full URL as produced by ``_get_url``.
    """
    versioned_endpoint = "api/v1/{}".format(endpoint)
    return self._get_url(versioned_endpoint)
def _get_headers(self, additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    """Assemble the request headers for the next API call.

    The source of the Authorization header is chosen in this order: the
    ``access_token_provider`` callable, then the token manager, then the
    static ``self.headers``.

    :param additional_headers: Optional additional headers to merge on top.
    :type additional_headers: dict
    :return: Headers dictionary to send with the request.
    :rtype: dict
    """
    if self.access_token_provider:
        access_token = self.access_token_provider()
        # A provider that yields nothing results in no Authorization header.
        headers = {"Authorization": f"{self.auth_scheme} {access_token}"} if access_token else {}
    elif self.token_manager:
        # The token manager supplies the Authorization header itself.
        headers = self.token_manager.get_auth_header()
    else:
        # Copy so that merging below never touches the static headers.
        headers = dict(self.headers)
    if additional_headers:
        for header_name, header_value in additional_headers.items():
            headers[header_name] = header_value
    return headers
def _perform_chunked_upload(self, file_path: str) -> str:
    """Upload ``file_path`` in chunks and return the resulting file_token.

    The upload starts with 900 KB chunks. Whenever the attempt fails with an
    HTTP 413 the chunk size is halved and the upload is retried, until the
    next size would fall below 256 KB.

    :param file_path: Path of the file to upload.
    :type file_path: str
    :return: The file token returned by ``_do_chunked_upload``.
    :rtype: str
    :raises FileNotFoundError: If ``file_path`` does not exist.
    :raises RuntimeError: If even the smallest supported chunk size is rejected with 413.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Upload failed: File not found ({file_path})")
    base_name = os.path.basename(file_path)
    total_bytes = os.path.getsize(file_path)
    chunk_bytes = 900 * 1024  # Initial chunk size
    floor_bytes = 256 * 1024  # Below this the upload is abandoned
    while True:
        try:
            return self._do_chunked_upload(file_path, base_name, total_bytes, chunk_bytes)
        except requests.exceptions.HTTPError as err:
            reply = getattr(err, "response", None)
            # Anything other than "payload too large" is not retried.
            if reply is None or reply.status_code != 413:
                raise
            halved = chunk_bytes // 2
            if halved < floor_bytes:
                raise RuntimeError(f"Upload failed: proxy rejects all supported chunk sizes (minimum {floor_bytes // 1024} KB).")
            ScaleoutLogger().warning(f"Proxy rejected chunk size (413). Retrying with {halved // 1024} KB chunks...")
            chunk_bytes = halved
def _do_chunked_upload(self, file_path: str, file_name: str, file_size: int, chunk_size: int) -> str:
    """Run one init -> chunks -> complete upload cycle with a fixed chunk size.

    :param file_path: Path of the file to read.
    :type file_path: str
    :param file_name: File name reported to the backend.
    :type file_name: str
    :param file_size: File size in bytes reported to the backend.
    :type file_size: int
    :param chunk_size: Number of bytes read and sent per chunk request.
    :type chunk_size: int
    :return: The file_token issued by the backend when the upload completes.
    :rtype: str
    :raises requests.exceptions.HTTPError: If any request returns an error status. A 413 on a chunk
        is re-raised unchanged so the caller can retry with a smaller chunk size.
    :raises RuntimeError: If the backend returns no upload_id or no file_token.
    """
    # The same headers are reused for every request of this upload.
    headers = self._get_headers()
    # 1. Initialize Upload
    init_url = self._get_url_api_v1("file-upload/init")
    init_response = requests.post(
        init_url,
        json={"file_name": file_name, "file_size": file_size, "chunk_size": chunk_size},
        headers=headers,
        verify=self.verify,
    )
    init_response.raise_for_status()
    upload_id = init_response.json().get("upload_id")
    if not upload_id:
        raise RuntimeError("Failed to receive an upload_id from the backend")
    # 2. Upload Chunks
    chunk_url = self._get_url_api_v1(f"file-upload/{upload_id}/chunk")
    with open(file_path, "rb") as f:
        chunk_index = 0
        while True:
            chunk_data = f.read(chunk_size)
            if not chunk_data:
                break
            chunk_headers = headers.copy()
            chunk_headers["X-Chunk-Index"] = str(chunk_index)
            try:
                chunk_resp = requests.post(chunk_url, data=chunk_data, headers=chunk_headers, verify=self.verify)
                chunk_resp.raise_for_status()
            except requests.exceptions.HTTPError as e:
                if getattr(e, "response", None) is not None and e.response.status_code == 413:
                    abort_url = self._get_url_api_v1(f"file-upload/{upload_id}/abort")
                    # The abort is best-effort cleanup. If it fails it must not
                    # replace the 413, which the caller needs to see in order to
                    # retry with a smaller chunk size.
                    try:
                        requests.post(abort_url, headers=headers, verify=self.verify)
                    except requests.exceptions.RequestException as abort_error:
                        ScaleoutLogger().warning(f"Failed to abort upload {upload_id}: {abort_error}")
                raise
            chunk_index += 1
    # 3. Complete Upload
    complete_url = self._get_url_api_v1(f"file-upload/{upload_id}/complete")
    complete_resp = requests.post(complete_url, headers=headers, verify=self.verify)
    complete_resp.raise_for_status()
    json_resp = complete_resp.json()
    file_token = json_resp.get("file_token")
    if not file_token:
        raise RuntimeError("Failed to fetch file_token. Backend completed upload but emitted no token.")
    return file_token
# --- Clients --- #
[docs]
def get_client(self, id: str):
    """Fetch a single client from the statestore.

    :param id: The client id to look up.
    :type id: str
    :return: The decoded JSON body of the response.
    :rtype: dict
    """
    url = self._get_url_api_v1(f"clients/{id}")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers())
    return reply.json()
[docs]
def get_clients(self, n_max: int = None):
    """List clients from the statestore.

    :param n_max: The maximum number of clients to get. When falsy, no ``X-Limit`` header is sent.
    :type n_max: int
    :return: The decoded JSON body of the response.
    :rtype: dict
    """
    limit_headers = {"X-Limit": str(n_max)} if n_max else {}
    url = self._get_url_api_v1("clients/")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers(limit_headers))
    return reply.json()
[docs]
def get_clients_count(self, only_active: bool = False):
    """Query the number of clients in the statestore.

    :param only_active: When True, the request is filtered with ``status=online``.
    :type only_active: bool
    :return: The decoded JSON body of the count response.
    """
    query = {"status": "online"} if only_active else {}
    url = self._get_url_api_v1("clients/count")
    reply = requests.get(url, params=query, verify=self.verify, headers=self._get_headers())
    return reply.json()
[docs]
def get_client_config(self, checksum=True):
    """Fetch the client configuration from the controller.

    The config is used for clients to connect to the controller and ask for combiner assignment.

    :param checksum: Whether to include the checksum of the package (sent as ``"true"``/``"false"``).
    :type checksum: bool
    :return: The decoded JSON body of the response.
    :rtype: dict
    """
    checksum_flag = "true" if checksum else "false"
    url = self._get_url_api_v1("clients/config")
    reply = requests.get(url, params={"checksum": checksum_flag}, verify=self.verify, headers=self._get_headers())
    return reply.json()
[docs]
def get_active_clients(self, combiner_id: str = None, n_max: int = None):
    """List clients whose status is ``online``.

    :param combiner_id: Optional combiner id to restrict the listing to.
    :type combiner_id: str
    :param n_max: The maximum number of clients to get. When falsy, no ``X-Limit`` header is sent.
    :type n_max: int
    :return: The decoded JSON body of the response.
    :rtype: dict
    """
    query = {"status": "online"}
    if combiner_id:
        query.update(combiner=combiner_id)
    limit_headers = {"X-Limit": str(n_max)} if n_max else {}
    url = self._get_url_api_v1("clients/")
    reply = requests.get(url, params=query, verify=self.verify, headers=self._get_headers(limit_headers))
    return reply.json()
# --- Combiners --- #
[docs]
def get_combiner(self, id: str):
    """Fetch a single combiner from the statestore.

    :param id: The combiner id to look up.
    :type id: str
    :return: The decoded JSON body of the response.
    :rtype: dict
    """
    url = self._get_url_api_v1(f"combiners/{id}")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers())
    return reply.json()
[docs]
def get_combiners(self, n_max: int = None):
    """List the combiners in the network.

    :param n_max: The maximum number of combiners to get. When falsy, no ``X-Limit`` header is sent.
    :type n_max: int
    :return: The decoded JSON body of the response.
    :rtype: dict
    """
    limit_headers = {"X-Limit": str(n_max)} if n_max else {}
    url = self._get_url_api_v1("combiners/")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers(limit_headers))
    return reply.json()
[docs]
def get_combiners_count(self):
    """Query the number of combiners in the statestore.

    :return: The decoded JSON body of the count response.
    """
    url = self._get_url_api_v1("combiners/count")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers())
    return reply.json()
# --- Controllers --- #
[docs]
def get_controller_status(self):
    """Query the status of the controller.

    Unlike most calls this one targets ``get_controller_status`` at the
    server root rather than under ``api/v1/``.

    :return: The decoded JSON body of the response.
    :rtype: dict
    """
    url = self._get_url("get_controller_status")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers())
    return reply.json()
# --- Models --- #
[docs]
def get_model(self, id: str):
    """Fetch a single model from the statestore.

    :param id: The id (or model property) of the model to get.
    :type id: str
    :return: The decoded JSON body of the response.
    :rtype: dict
    """
    url = self._get_url_api_v1(f"models/{id}")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers())
    return reply.json()
[docs]
def get_models(self, session_id: str = None, n_max: int = None):
    """List models from the statestore.

    :param session_id: Optional session id to filter the models by.
    :type session_id: str
    :param n_max: The maximum number of models to get. When falsy, no ``X-Limit`` header is sent.
    :type n_max: int
    :return: The decoded JSON body of the response.
    :rtype: dict
    """
    query = {"session_id": session_id} if session_id else {}
    limit_headers = {"X-Limit": str(n_max)} if n_max else {}
    url = self._get_url_api_v1("models/")
    reply = requests.get(url, params=query, verify=self.verify, headers=self._get_headers(limit_headers))
    return reply.json()
[docs]
def get_models_count(self):
    """Query the number of models in the statestore.

    :return: The decoded JSON body of the count response.
    """
    url = self._get_url_api_v1("models/count")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers())
    return reply.json()
[docs]
def get_active_model(self):
    """Fetch the most recently committed model from the statestore.

    The model listing is requested with a limit of one, sorted by
    ``committed_at`` in descending order.

    :return: The first entry of ``result`` when the response has a non-empty ``result``;
        otherwise the decoded JSON body unchanged.
    :rtype: dict
    """
    newest_first = {"X-Limit": "1", "X-Sort-Key": "committed_at", "X-Sort-Order": "desc"}
    url = self._get_url_api_v1("models/")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers(newest_first))
    body = reply.json()
    if "result" not in body or len(body["result"]) <= 0:
        return body
    return body["result"][0]
[docs]
def get_model_trail(self, id: str = None, include_self: bool = True, reverse: bool = True, n_max: int = None):
    """Fetch the ancestors of a model.

    :param id: The id (or model property) of the model to start the trail from.
        When omitted, the model returned by ``get_active_model`` is used.
    :type id: str
    :param include_self: Sent as the ``include_self`` query parameter (``"true"``/``"false"``).
    :type include_self: bool
    :param reverse: Sent as the ``X-Reverse`` header (``"true"``/``"false"``).
    :type reverse: bool
    :param n_max: The maximum number of models to get. When falsy, the value of ``get_models_count`` is used as the limit.
    :type n_max: int
    :return: The decoded JSON body of the response. If no id is given and the active model
        has no ``model_id``, the active-model response is returned as is.
    :rtype: dict
    """
    if not id:
        latest = self.get_active_model()
        if "model_id" not in latest:
            return latest
        id = latest["model_id"]
    limit = n_max if n_max else self.get_models_count()
    reverse_flag = "true" if reverse else "false"
    trail_headers = {"X-Limit": str(limit), "X-Reverse": reverse_flag}
    include_flag = "true" if include_self else "false"
    url = self._get_url_api_v1(f"models/{id}/ancestors?include_self={include_flag}")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers(trail_headers))
    return reply.json()
[docs]
def download_model(self, id: str, path: str):
    """Download a model and write it to ``path``.

    :param id: The id (or model property) of the model to download.
    :type id: str
    :param path: The path to download the model to.
    :type path: str
    :return: A dict with a ``success`` flag and a ``message``; success requires a 200 response.
    :rtype: dict
    """
    url = self._get_url_api_v1(f"models/{id}/download")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers())
    if reply.status_code != 200:
        return {"success": False, "message": "Failed to download model."}
    with open(path, "wb") as target:
        target.write(reply.content)
    return {"success": True, "message": "Model downloaded successfully."}
[docs]
def set_active_model(self, path):
    """Set the initial model in the statestore and upload to model repository.

    The helper is inferred from the file extension: ``.npz`` selects
    ``numpyhelper`` and ``.bin`` selects ``binaryhelper``. For any other
    extension no helper is selected: the active helper is left untouched and
    the model is registered with ``helper`` set to None.

    :param path: The file path of the initial model to set.
    :type path: str
    :return: The decoded JSON body of the model registration response.
    :rtype: dict
    """
    # Default to None so an unrecognised extension does not leave the name
    # unbound (which previously raised UnboundLocalError below).
    helper = None
    if path.endswith(".npz"):
        helper = "numpyhelper"
    elif path.endswith(".bin"):
        helper = "binaryhelper"
    if helper:
        # The response of the helper update is not inspected.
        requests.put(self._get_url_api_v1("helpers/active"), json={"helper": helper}, verify=self.verify, headers=self._get_headers())
    file_token = self._perform_chunked_upload(path)
    response = requests.post(
        self._get_url_api_v1("models/"), data={"helper": helper, "file_token": file_token}, verify=self.verify, headers=self._get_headers()
    )
    return response.json()
# --- Packages --- #
[docs]
def get_package(self, id: str):
    """Fetch a single compute package from the statestore.

    :param id: The id of the compute package to get.
    :type id: str
    :return: The decoded JSON body of the response.
    :rtype: dict
    """
    url = self._get_url_api_v1(f"packages/{id}")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers())
    return reply.json()
[docs]
def get_packages(self, n_max: int = None):
    """List compute packages from the statestore.

    :param n_max: The maximum number of packages to get. When falsy, no ``X-Limit`` header is sent.
    :type n_max: int
    :return: The decoded JSON body of the response.
    :rtype: dict
    """
    limit_headers = {"X-Limit": str(n_max)} if n_max else {}
    url = self._get_url_api_v1("packages/")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers(limit_headers))
    return reply.json()
[docs]
def get_packages_count(self):
    """Query the number of compute packages in the statestore.

    :return: The decoded JSON body of the count response.
    """
    url = self._get_url_api_v1("packages/count")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers())
    return reply.json()
[docs]
def get_active_package(self):
    """Fetch the active compute package from the statestore.

    :return: The decoded JSON body of the response.
    :rtype: dict
    """
    url = self._get_url_api_v1("packages/active")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers())
    return reply.json()
[docs]
def get_package_checksum(self):
    """Fetch the checksum of the compute package.

    :return: The decoded JSON body of the response.
    :rtype: dict
    """
    url = self._get_url_api_v1("packages/checksum")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers())
    return reply.json()
[docs]
def download_package(self, path: str):
    """Download the compute package and write it to ``path``.

    :param path: The path to download the compute package to.
    :type path: str
    :return: A dict with a ``success`` flag and a ``message``; success requires a 200 response.
    :rtype: dict
    """
    url = self._get_url_api_v1("packages/download")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers())
    if reply.status_code != 200:
        return {"success": False, "message": "Failed to download package."}
    with open(path, "wb") as target:
        target.write(reply.content)
    return {"success": True, "message": "Package downloaded successfully."}
[docs]
def set_active_package(self, path: str, helper: str, name: str, description: str = ""):
    """Upload a compute package and register it in the statestore.

    The file is uploaded in chunks first; the resulting file token is then
    posted together with the package metadata.

    :param path: The file path of the compute package to set.
    :type path: str
    :param helper: The helper type to use.
    :type helper: str
    :param name: The name sent with the package.
    :type name: str
    :param description: The description sent with the package.
    :type description: str
    :return: The decoded JSON body of the registration response.
    :rtype: dict
    """
    upload_token = self._perform_chunked_upload(path)
    url = self._get_url_api_v1("packages/")
    form_fields = {
        "helper": helper,
        "name": name,
        "description": description,
        "file_token": upload_token,
        "file_name": os.path.basename(path),
    }
    reply = requests.post(url, data=form_fields, verify=self.verify, headers=self._get_headers())
    return reply.json()
# --- Rounds --- #
[docs]
def get_round(self, id: str):
    """Fetch a single round from the statestore.

    :param id: The round id to get.
    :type id: str
    :return: The decoded JSON body of the response.
    :rtype: dict
    """
    url = self._get_url_api_v1(f"rounds/{id}")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers())
    return reply.json()
[docs]
def get_rounds(self, n_max: int = None, filter: Dict = None):
    """List rounds from the statestore.

    :param n_max: The maximum number of rounds to get. When falsy, no ``X-Limit`` header is sent.
    :type n_max: int
    :param filter: Optional mapping sent as query parameters; every value is converted with ``str``.
    :type filter: dict
    :return: The decoded JSON body of the response.
    :rtype: dict
    """
    limit_headers = {"X-Limit": str(n_max)} if n_max else {}
    query = {key: str(value) for key, value in filter.items()} if filter else {}
    url = self._get_url_api_v1("rounds/")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers(limit_headers), params=query)
    return reply.json()
[docs]
def get_rounds_count(self):
    """Query the number of rounds in the statestore.

    :return: The decoded JSON body of the count response.
    """
    url = self._get_url_api_v1("rounds/count")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers())
    return reply.json()
# --- Sessions --- #
[docs]
def get_session(self, id: str = None, name: str = None):
    """Get a session from the statestore.

    When both are given, ``name`` takes precedence over ``id``.

    :param id: The session id to get.
    :type id: str
    :param name: The session name to get.
    :type name: str
    :return: The session. For a lookup by name this is the first match, or
        ``{"message": "Session not found."}`` when nothing matches. When neither
        ``id`` nor ``name`` is given, ``{"message": "No id or name provided."}``.
    :rtype: dict
    """
    if name:
        # Pass the name through ``params`` so that requests URL-encodes it;
        # names containing characters such as ``&``, ``#`` or ``+`` would
        # otherwise corrupt the query string.
        response = requests.get(self._get_url_api_v1("sessions"), params={"name": name}, verify=self.verify, headers=self._get_headers())
        _json = response.json()
        if "result" in _json and len(_json["result"]) > 0:
            _json = _json["result"][0]
        else:
            _json = {"message": "Session not found."}
    elif id:
        response = requests.get(self._get_url_api_v1(f"sessions/{id}"), verify=self.verify, headers=self._get_headers())
        _json = response.json()
    else:
        _json = {"message": "No id or name provided."}
    return _json
[docs]
def get_sessions(self, n_max: int = None, name: str = None):
    """Get sessions from the statestore.

    :param n_max: The maximum number of sessions to get. When falsy, no ``X-Limit`` header is sent.
    :type n_max: int
    :param name: Optional session name to filter by.
    :type name: str
    :return: The decoded JSON body of the response.
    :rtype: dict
    """
    additional_headers = {}
    if n_max:
        additional_headers["X-Limit"] = str(n_max)
    # Send the name via ``params`` so that requests URL-encodes it instead of
    # concatenating the raw value into the query string.
    _params = {"name": name} if name else {}
    response = requests.get(self._get_url_api_v1("sessions/"), params=_params, verify=self.verify, headers=self._get_headers(additional_headers))
    _json = response.json()
    return _json
[docs]
def get_sessions_count(self):
    """Query the number of sessions in the statestore.

    :return: The decoded JSON body of the count response.
    """
    url = self._get_url_api_v1("sessions/count")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers())
    return reply.json()
[docs]
def get_session_status(self, id: str):
    """Look up the status of a session.

    :param id: The id of the session to get.
    :type id: str
    :return: The ``status`` entry of the session, or the text
        ``"Could not retrieve session status."`` when the session is empty or has no status.
    :rtype: str
    """
    record = self.get_session(id)
    if not record or "status" not in record:
        return "Could not retrieve session status."
    return record["status"]
[docs]
def session_is_finished(self, id: str):
    """Check if a session with id has finished.

    The status is compared case-insensitively with ``"finished"``.

    :param id: The id of the session to get.
    :type id: str
    :return: True if session is finished, otherwise False. A missing, empty
        or non-string status counts as not finished.
    :rtype: bool
    """
    status = self.get_session_status(id)
    # Always return a real bool: ``status and ...`` would leak None or ""
    # to the caller, and ``.lower()`` would fail on a non-string status.
    return isinstance(status, str) and status.lower() == "finished"
[docs]
def run_custom_command(self, command_type: str, blocking: bool = False, timeout: int = None, client_ids: list = None, parameters: dict = None):
    """Post a custom command to the ``control/run_command`` endpoint.

    :param command_type: The command type. ``"Custom_"`` is prepended unless already present.
    :type command_type: str
    :param blocking: Sent as the ``blocking`` field of the request body.
    :type blocking: bool
    :param timeout: Sent as the ``timeout`` field of the request body.
    :type timeout: int
    :param client_ids: Sent as the ``client_ids`` field of the request body.
    :type client_ids: list
    :param parameters: Extra entries merged into the request body; they override the fields above on key collision.
    :type parameters: dict
    :return: The decoded JSON body on a 200 response; otherwise a dict with ``success`` False and a message.
    :rtype: dict
    """
    prefix = "Custom_"
    if not command_type.startswith(prefix):
        command_type = prefix + command_type
    url = self._get_url_api_v1("control/run_command")
    payload = {"command_type": command_type, "blocking": blocking, "timeout": timeout, "client_ids": client_ids}
    payload.update(parameters or {})
    reply = requests.post(url, json=payload, verify=self.verify, headers=self._get_headers())
    if reply.status_code != 200:
        return {"success": False, "message": f"Failed to run command. Status code: {reply.status_code}, Response: {reply.text}"}
    return reply.json()
[docs]
def start_session(
    self,
    name: str = None,
    aggregator: str = "fedavg",
    aggregator_kwargs: dict = None,
    model_id: str = None,
    round_timeout: int = 180,
    rounds: int = 5,
    round_buffer_size: int = -1,
    delete_models: bool = True,
    validate: bool = True,
    helper: str = None,
    min_clients: int = 1,
    requested_clients: int = 8,
    server_functions: ServerFunctionsBase = None,
):
    """Create a session and start it.

    :param name: The name of the session
    :type name: str
    :param aggregator: The aggregator plugin to use.
    :type aggregator: str
    :param aggregator_kwargs: Sent as ``aggregator_kwargs`` in the session config.
    :type aggregator_kwargs: dict
    :param model_id: The id of the initial model. When None, the most recently committed model is looked up and used.
    :type model_id: str
    :param round_timeout: The round timeout to use in seconds.
    :type round_timeout: int
    :param rounds: The number of rounds to perform.
    :type rounds: int
    :param round_buffer_size: The round buffer size to use.
    :type round_buffer_size: int
    :param delete_models: Whether to delete models after each round at combiner (save storage).
    :type delete_models: bool
    :param validate: Whether to validate the model after each round.
    :type validate: bool
    :param helper: The helper type to use.
    :type helper: str
    :param min_clients: The minimum number of clients required.
    :type min_clients: int
    :param requested_clients: The requested number of clients.
    :type requested_clients: int
    :param server_functions: Optional object whose source code (via ``inspect.getsource``) is sent with the session config.
    :type server_functions: ServerFunctionsBase
    :return: A dict with success or failure message and session config.
    :rtype: dict
    """
    if model_id is None:
        # No seed model given: fall back to the newest committed model.
        newest_first = {"X-Limit": "1", "X-Sort-Key": "committed_at", "X-Sort-Order": "desc"}
        listing = requests.get(self._get_url_api_v1("models/"), verify=self.verify, headers=self._get_headers(newest_first))
        if listing.status_code != 200:
            return {"message": "No models found in the repository"}
        listing_body = listing.json()
        if "result" not in listing_body or len(listing_body["result"]) <= 0:
            return {"message": "No models found in the repository"}
        model_id = listing_body["result"][0]["model_id"]

    create_url = self._get_url_api_v1("sessions/")
    session_config = {
        "aggregator": aggregator,
        "aggregator_kwargs": aggregator_kwargs,
        "rounds": rounds,
        "round_timeout": round_timeout,
        "buffer_size": round_buffer_size,
        "model_id": model_id,
        "delete_models_storage": delete_models,
        "clients_required": min_clients,
        "requested_clients": requested_clients,
        "validate": validate,
        "helper_type": helper,
        "server_functions": None if server_functions is None else inspect.getsource(server_functions),
    }
    created = requests.post(
        create_url,
        json={"name": name, "seed_model_id": model_id, "session_config": session_config},
        verify=self.verify,
        headers=self._get_headers(),
    )
    if created.status_code != 201:
        # Creation did not succeed: hand the backend's answer to the caller.
        return created.json()

    session_id = created.json()["session_id"]
    started = requests.post(
        self._get_url_api_v1("sessions/start"),
        json={"session_id": session_id, "rounds": rounds, "round_timeout": round_timeout},
        verify=self.verify,
        headers=self._get_headers(),
    )
    # Try to parse JSON, but handle the case where it fails
    try:
        outcome = started.json()
        outcome["session_id"] = session_id
        return outcome
    except requests.exceptions.JSONDecodeError:
        # Handle invalid JSON response
        return {"success": started.status_code < 400, "session_id": session_id, "message": f"Session started with status code {started.status_code}"}
[docs]
def continue_session(self, session_id: str, rounds: int = 5, round_timeout: int = 180):
    """Run further rounds on an existing, finished session.

    :param session_id: The id of the session to continue.
    :type session_id: str
    :param rounds: The number of rounds to perform. Must be greater than 0.
    :type rounds: int
    :param round_timeout: The round timeout to use in seconds. Must be greater than 0.
    :type round_timeout: int
    :return: The decoded JSON body of the start request, or a dict with a ``message``
        when an argument is invalid, the session is not found, or it is not finished.
    :rtype: dict
    """

    def _rejected(reason):
        # Every early exit shares the same single-key shape.
        return {"message": reason}

    if not session_id:
        return _rejected("No session id provided.")
    if rounds is None or rounds <= 0:
        return _rejected("Invalid number of rounds provided. Must be greater than 0.")
    if round_timeout is None or round_timeout <= 0:
        return _rejected("Invalid round timeout provided. Must be greater than 0.")
    # The session must exist ...
    existing = self.get_session(session_id)
    if not existing or "session_id" not in existing:
        return _rejected("Session not found.")
    # ... and must have finished before it can be continued.
    if not self.session_is_finished(session_id):
        return _rejected("Session is already running.")
    start_request = {"session_id": session_id, "rounds": rounds, "round_timeout": round_timeout}
    reply = requests.post(
        self._get_url_api_v1("sessions/start"),
        json=start_request,
        verify=self.verify,
        headers=self._get_headers(),
    )
    return reply.json()
# --- Statuses --- #
[docs]
def get_status(self, id: str):
    """Fetch a single status object (event) from the statestore.

    :param id: The id of the status to get.
    :type id: str
    :return: The decoded JSON body of the response.
    :rtype: dict
    """
    url = self._get_url_api_v1(f"statuses/{id}")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers())
    return reply.json()
[docs]
def get_statuses(self, session_id: str = None, event_type: str = None, sender_name: str = None, sender_role: str = None, n_max: int = None):
    """List statuses from the statestore, filtered by the given arguments.

    Only truthy filter arguments are sent as query parameters.

    :param session_id: The session id to get statuses for.
    :type session_id: str
    :param event_type: The event type to get (sent as ``type``).
    :type event_type: str
    :param sender_name: The sender name to get (sent as ``sender.name``).
    :type sender_name: str
    :param sender_role: The sender role to get (sent as ``sender.role``).
    :type sender_role: str
    :param n_max: The maximum number of statuses to get. When falsy, no ``X-Limit`` header is sent.
    :type n_max: int
    :return: The decoded JSON body of the response.
    """
    candidates = {
        "session_id": session_id,
        "type": event_type,
        "sender.name": sender_name,
        "sender.role": sender_role,
    }
    query = {field: value for field, value in candidates.items() if value}
    limit_headers = {"X-Limit": str(n_max)} if n_max else {}
    url = self._get_url_api_v1("statuses/")
    reply = requests.get(url, params=query, verify=self.verify, headers=self._get_headers(limit_headers))
    return reply.json()
[docs]
def get_statuses_count(self):
    """Query the number of statuses in the statestore.

    :return: The decoded JSON body of the count response.
    """
    url = self._get_url_api_v1("statuses/count")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers())
    return reply.json()
# --- Validations --- #
[docs]
def get_validation(self, id: str):
    """Fetch a single validation from the statestore.

    :param id: The id of the validation to get.
    :type id: str
    :return: The decoded JSON body of the response.
    :rtype: dict
    """
    url = self._get_url_api_v1(f"validations/{id}")
    reply = requests.get(url, verify=self.verify, headers=self._get_headers())
    return reply.json()
[docs]
def get_validations(
    self,
    session_id: str = None,
    model_id: str = None,
    correlation_id: str = None,
    sender_name: str = None,
    sender_role: str = None,
    sender_client_id: str = None,
    n_max: int = None,
):
    """List validations from the statestore, filtered by the given arguments.

    Only truthy filter arguments are sent as query parameters.

    :param session_id: The session id to get validations for.
    :type session_id: str
    :param model_id: The model id to get validations for.
    :type model_id: str
    :param correlation_id: The correlation id to get validations for.
    :type correlation_id: str
    :param sender_name: The sender name to get validations for (sent as ``sender.name``).
    :type sender_name: str
    :param sender_role: The sender role to get validations for (sent as ``sender.role``).
    :type sender_role: str
    :param sender_client_id: The sender client id to get validations for (sent as ``sender.client_id``).
    :type sender_client_id: str
    :param n_max: The maximum number of validations to get. When falsy, no ``X-Limit`` header is sent.
    :type n_max: int
    :return: The decoded JSON body of the response.
    :rtype: dict
    """
    candidates = {
        "session_id": session_id,
        "model_id": model_id,
        "correlation_id": correlation_id,
        "sender.name": sender_name,
        "sender.role": sender_role,
        "sender.client_id": sender_client_id,
    }
    query = {field: value for field, value in candidates.items() if value}
    limit_headers = {"X-Limit": str(n_max)} if n_max else {}
    url = self._get_url_api_v1("validations/")
    reply = requests.get(url, params=query, verify=self.verify, headers=self._get_headers(limit_headers))
    return reply.json()
[docs]
def get_validations_count(self, session_id: str = None):
    """Get the number of validations in the statestore.

    :param session_id: Optional session id; when given, it is sent as the ``session_id`` query parameter.
    :type session_id: str
    :return: The decoded JSON body of the count response.
    """
    _params = {}
    if session_id:
        _params["session_id"] = session_id
    # A second, unfiltered request used to follow the return below; it was
    # unreachable and has been removed.
    response = requests.get(self._get_url_api_v1("validations/count"), params=_params, verify=self.verify, headers=self._get_headers())
    _json = response.json()
    return _json
# --- Predictions --- #
[docs]
def get_predictions(
    self,
    model_id: str = None,
    correlation_id: str = None,
    client_id: str = None,
    n_max: int = None,
):
    """Fetch predictions from the statestore, optionally filtered.

    Every filter is optional; filters with a falsy value are left out of the request.

    :param model_id: The model id to get predictions for.
    :type model_id: str
    :param correlation_id: The correlation id to get predictions for.
    :type correlation_id: str
    :param client_id: The client id to get predictions for.
    :type client_id: str
    :param n_max: The maximum number of predictions to get. Sent as the ``X-Limit``
        header; when not given no limit header is sent.
    :type n_max: int
    :return: Predictions (parsed JSON response).
    :rtype: dict
    """
    candidates = (
        ("model_id", model_id),
        ("correlation_id", correlation_id),
        ("client_id", client_id),
    )
    # Keep only the filters the caller actually supplied.
    query = {key: value for key, value in candidates if value}
    limit_header = {"X-Limit": str(n_max)} if n_max else {}
    reply = requests.get(
        self._get_url_api_v1("predictions/"),
        params=query,
        verify=self.verify,
        headers=self._get_headers(limit_header),
    )
    return reply.json()
[docs]
def start_predictions(self, prediction_id: str = None, model_id: str = None):
    """Start predictions for a model.

    :param prediction_id: The id to use for the prediction. A random UUID4 string
        is generated when none is given.
    :type prediction_id: str
    :param model_id: The model id to start predictions for.
    :type model_id: str
    :return: A dict with success or failure message.
    :rtype: dict
    """
    payload = {
        # Fall back to a freshly generated id when the caller did not pick one.
        "prediction_id": prediction_id or str(uuid.uuid4()),
        "model_id": model_id,
    }
    endpoint = self._get_url_api_v1("predictions/start")
    reply = requests.post(endpoint, json=payload, verify=self.verify, headers=self._get_headers())
    return reply.json()
# --- Attributes --- #
[docs]
def get_current_attributes(self, node_list):
    """Get the current attributes of one or more nodes.

    :param node_list: A list of node ids, or a single node id (client_id or combiner_id)
        given as a string.
    :type node_list: list|str
    :raises ValueError: If node_list is neither a list nor a string, or is an empty list.
    :return: The current attributes of the nodes.
    :rtype: dict
    """
    if isinstance(node_list, str):
        # A single node id is accepted as shorthand for a one-element list.
        node_list = [node_list]
    elif not isinstance(node_list, list):
        raise ValueError("node_list must be a list or string")
    if not node_list:
        raise ValueError("node_list must not be empty")
    payload = {"node_ids": node_list}
    reply = requests.post(
        self._get_url_api_v1("attributes/current"),
        json=payload,
        verify=self.verify,
        headers=self._get_headers(),
    )
    return reply.json()
[docs]
def add_attributes(self, attribute: dict) -> dict:
    """Add or update node attributes via the controller API.

    :param attribute: A dict matching AttributeDTO.schema, e.g.:

        .. code-block:: text

            {
                "key": "charging",
                "value": "true",
                "sender": {
                    "name": "",
                    "role": "",
                    "client_id": "abc123"   # or "combiner_id": "abc123"
                }
            }

    :type attribute: dict
    :raises requests.HTTPError: If the server responds with an error status
        (raised by ``raise_for_status``).
    :return: Parsed JSON response from the server.
    :rtype: dict
    """
    reply = requests.post(
        self._get_url_api_v1("attributes/"),
        json=attribute,
        headers=self._get_headers(),
        verify=self.verify,
    )
    # Surface HTTP error statuses as exceptions before decoding the body.
    reply.raise_for_status()
    return reply.json()
[docs]
def add_status(self, status: dict) -> dict:
    """Submit a status entry to the controller API.

    :param status: A dict matching StatusDTO.schema, e.g.:

        .. code-block:: json

            {
                "type": "MODEL_UPDATE",
                "log_level": "INFO",
                "status": "Training complete",
                "sender": {
                    "client_id": "abc-123"
                }
            }

    :type status: dict
    :raises requests.HTTPError: If the server responds with an error status
        (raised by ``raise_for_status``).
    :return: Parsed JSON response from the server.
    :rtype: dict
    """
    endpoint = self._get_url_api_v1("statuses/")
    request_headers = self._get_headers()
    reply = requests.post(endpoint, json=status, headers=request_headers, verify=self.verify)
    # Surface HTTP error statuses as exceptions before decoding the body.
    reply.raise_for_status()
    return reply.json()
[docs]
def get_attribute_trail(self, node_id: str):
    """Get the full attribute history for a given node.

    :param node_id: The node_id (client or combiner) to fetch attributes for.
    :type node_id: str
    :raises ValueError: If node_id is not a string or is an empty string.
    :raises requests.HTTPError: If the server responds with an error status
        (raised by ``raise_for_status``).
    :return: All matching attributes for the node.
    :rtype: dict
    """
    if not (isinstance(node_id, str) and node_id):
        raise ValueError("node_id must be a non-empty string")
    # The existing attributes/list endpoint is reused, filtered on node_id.
    reply = requests.post(
        self._get_url_api_v1("attributes/list"),
        json={"node_id": node_id},
        headers=self._get_headers(),
        verify=self.verify,
    )
    reply.raise_for_status()
    return reply.json()
### Control Functions ###
[docs]
def step_current_session(self):
    """Continue the current session.

    Sends a POST request without a body to the ``control/continue`` endpoint.

    :return: A dict with success or failure message.
    :rtype: dict
    """
    endpoint = self._get_url_api_v1("control/continue")
    reply = requests.post(endpoint, verify=self.verify, headers=self._get_headers())
    return reply.json()
[docs]
def stop_current_session(self):
    """Stop the current session.

    Sends a POST request without a body to the ``control/stop`` endpoint.

    :return: A dict with success or failure message.
    :rtype: dict
    """
    endpoint = self._get_url_api_v1("control/stop")
    reply = requests.post(endpoint, verify=self.verify, headers=self._get_headers())
    return reply.json()