# Source code for scaleout.client.edge_client

"""EdgeClient class for interacting with the Scaleout network."""

import enum
import json
import signal
import threading
import time
import traceback
import uuid
from contextlib import contextmanager
from typing import Any, Callable, Dict, Optional, Tuple
from datetime import datetime

from scaleout.utils.dist import VERSION
from scaleoututil.utils.model import ScaleoutModel
from scaleoututil.utils.url import assemble_endpoint_url
import psutil
import requests

from scaleoututil.auth.token_manager import TokenManager

import scaleoututil.grpc.scaleout_pb2 as scaleout_msg
from scaleoututil.config import (
    SCALEOUT_AUTH_SCHEME,
    SCALEOUT_CONNECT_API_SECURE,
    SCALEOUT_CLIENT_STATUS_REPORTING,
    SCALEOUT_CLIENT_SEND_TELEMETRY,
    SCALEOUT_GRACEFUL_CLIENT_CONNECTION,
    SCALEOUT_CHECK_COMPATIBILITY,
    SCALEOUT_CLIENT_TASK_POLLING_INTERVAL,
)
from scaleoututil.logging import ScaleoutLogger
from scaleout.client.grpc_handler import GrpcConnectionOptions, GrpcHandler, RetryException
from scaleoututil.utils.http_status_codes import (
    HTTP_STATUS_BAD_REQUEST,
    HTTP_STATUS_NOT_ACCEPTABLE,
    HTTP_STATUS_NOT_FOUND,
    HTTP_STATUS_OK,
    HTTP_STATUS_PACKAGE_MISSING,
    HTTP_STATUS_UNAUTHORIZED,
    HTTP_STATUS_SERVER_ERROR,
)
from scaleout.client.logging_context import LoggingContext
from scaleout.client.task_receiver import StoppedException, TaskReceiver, UnknownTaskType
from scaleoututil.grpc.tasktype import TaskType

# Default timeout for HTTP requests to the Scaleout API (see connect_to_api).
REQUEST_TIMEOUT = 10  # seconds


class ConnectToApiResult(enum.Enum):
    """Possible outcomes of a client's attempt to connect to the Scaleout API."""

    Assigned = 0  # Client accepted and assigned to a controller.
    ComputePackageMissing = 1  # Server has no remote compute package yet.
    UnAuthorized = 2  # Token rejected by the server.
    UnMatchedConfig = 3  # Client configuration rejected by the server.
    IncorrectUrl = 4  # Endpoint not found (HTTP 404).
    UnknownError = 5  # Any other failure mode.


class GracefulExitException(Exception):
    """Raised from the SIGTERM handler to unwind the client loops for a graceful shutdown."""


class EdgeClient:
    """Client for interacting with the Scaleout network."""

    def __init__(
        self,
        train_callback: Optional[Callable[[ScaleoutModel, Dict], Tuple[Optional[ScaleoutModel], Dict]]] = None,
        validate_callback: Optional[Callable[[ScaleoutModel], Dict]] = None,
        predict_callback: Optional[Callable[[ScaleoutModel], Dict]] = None,
    ) -> None:
        """Initialize the EdgeClient.

        Args:
            train_callback: Called with (model, client_settings); returns (updated model, metadata).
            validate_callback: Called with a model; returns validation metrics.
            predict_callback: Called with a model; returns prediction output.
        """
        # Fix: name/client_id start as None and are only assigned later (see
        # set_name/set_client_id), so the annotations must be Optional[str]
        # rather than the previous plain `str`.
        self.name: Optional[str] = None
        self.client_id: Optional[str] = None
        self.train_callback = train_callback
        self.validate_callback = validate_callback
        self.predict_callback = predict_callback
        self.grpc_handler: Optional[GrpcHandler] = None
        self.package_path: str = "."
        # Thread-local storage so each worker thread carries its own LoggingContext.
        self._current_logging_context = threading.local()
        self.task_receiver = TaskReceiver(self, self._run_task_callback, polling_interval=SCALEOUT_CLIENT_TASK_POLLING_INTERVAL)
        # Maps "Custom_<name>" task types to user-registered callbacks.
        self.registered_callbacks: Dict[str, Callable[[scaleout_msg.TaskRequest], Dict]] = {}
        self.token_manager: Optional[TokenManager] = None
@property def current_logging_context(self) -> Optional[LoggingContext]: """Get the current logging context for the running thread.""" return getattr(self._current_logging_context, "value", None) @current_logging_context.setter def current_logging_context(self, context: LoggingContext) -> None: """Set the current logging context for the running thread.""" self._current_logging_context.value = context
[docs] def set_train_callback(self, callback: callable) -> None: """Set the train callback.""" self.train_callback = callback
[docs] def set_validate_callback(self, callback: callable) -> None: """Set the validate callback.""" self.validate_callback = callback
[docs] def set_predict_callback(self, callback: callable) -> None: """Set the predict callback.""" self.predict_callback = callback
[docs] def set_custom_callback(self, callback_name: str, callback: Callable[[scaleout_msg.TaskRequest], Dict]) -> None: """Set a custom task callback.""" if not callback_name.startswith("Custom_"): callback_name = "Custom_" + callback_name self.registered_callbacks[callback_name] = callback ScaleoutLogger().info(f"Registered custom callback: {callback_name}")
[docs] def remove_custom_callback(self, callback_name: str) -> None: """Remove a custom task callback.""" if not callback_name.startswith("Custom_"): callback_name = "Custom_" + callback_name if callback_name in self.registered_callbacks: del self.registered_callbacks[callback_name] ScaleoutLogger().info(f"Removed custom callback: {callback_name}") else: ScaleoutLogger().warning(f"Custom callback {callback_name} not found")
def _get_current_token(self) -> Optional[str]: """Get the current access token, refreshing if needed.""" if self.token_manager: return self.token_manager.get_access_token() return None def _init_token_manager(self, token: str, url: str, token_refresh_callback: Optional[Callable[[str, str, datetime], None]] = None) -> None: """Initialize the token manager with the provided token.""" if self.token_manager is None: token_endpoint = assemble_endpoint_url(url, "api/auth", "refresh") self.token_manager = TokenManager(refresh_token=token, token_endpoint=token_endpoint, on_token_refresh=token_refresh_callback)
[docs] def connect_to_api( self, url: str, json: dict = None, token: Optional[str] = None, token_refresh_callback: Optional[Callable[[str, str, datetime], None]] = None ) -> Tuple[ConnectToApiResult, Any]: """Connect to the Scaleout API. Accepts a refresh token, instantiates TokenManager, and uses access token.""" if token: self._init_token_manager(token, url, token_refresh_callback) current_token = self._get_current_token() url_endpoint = assemble_endpoint_url(url, "api/v1/clients/add") ScaleoutLogger().info(f"Connecting to API endpoint: {url_endpoint}") if SCALEOUT_CHECK_COMPATIBILITY: json["client_version"] = VERSION try: response = requests.post( url=url_endpoint, json=json, allow_redirects=True, headers={"Authorization": f"{SCALEOUT_AUTH_SCHEME} {current_token}"}, timeout=REQUEST_TIMEOUT, verify=SCALEOUT_CONNECT_API_SECURE, ) if response.status_code == HTTP_STATUS_OK: ScaleoutLogger().info("Connect to Scaleout API - Client assigned to controller") json_response = response.json() self.set_client_id(json_response["client_id"]) self.set_name(json.get("name", json_response["client_id"])) combiner_config = GrpcConnectionOptions.from_dict(json_response) return ConnectToApiResult.Assigned, combiner_config if response.status_code == HTTP_STATUS_PACKAGE_MISSING: json_response = response.json() ScaleoutLogger().info("Connect to Scaleout API - Remote compute package missing.") return ConnectToApiResult.ComputePackageMissing, json_response if response.status_code == HTTP_STATUS_UNAUTHORIZED: ScaleoutLogger().error("Connect to Scaleout API - Unauthorized") return ConnectToApiResult.UnAuthorized, "Unauthorized" if response.status_code == HTTP_STATUS_BAD_REQUEST: try: json_response = response.json() except Exception: json_response = {} msg = json_response.get("message", "Unknown error") ScaleoutLogger().error(f"Connect to Scaleout API - {msg}") return ConnectToApiResult.UnMatchedConfig, msg if response.status_code == HTTP_STATUS_NOT_ACCEPTABLE: try: json_response = 
response.json() except Exception: json_response = {} msg = json_response.get("message", "Unknown error") ScaleoutLogger().error(f"Connect to Scaleout API - {msg}") return ConnectToApiResult.UnMatchedConfig, msg if response.status_code == HTTP_STATUS_NOT_FOUND: ScaleoutLogger().error("Connect to Scaleout API - Incorrect URL") return ConnectToApiResult.IncorrectUrl, "Incorrect URL" if response.status_code == HTTP_STATUS_SERVER_ERROR: response_json = response.json() msg = response_json.get("message", "Unknown server error") ScaleoutLogger().error(f"Connect to Scaleout API - Server error: {msg}") return ConnectToApiResult.UnknownError, f"Server error: {msg}" except Exception as e: ScaleoutLogger().error(f"Connect to Scaleout API - Error occurred: {str(e)}") return ConnectToApiResult.UnknownError, str(e)
[docs] def init_grpchandler( self, config: GrpcConnectionOptions, token: Optional[str] = None, url: Optional[str] = None, token_refresh_callback: Optional[Callable[[str, str, datetime], None]] = None, ) -> bool: """Initialize the GRPC handler. Accepts a refresh token, instantiates TokenManager, and uses access token.""" if token and url: self._init_token_manager(token, url, token_refresh_callback) try: self.grpc_handler = GrpcHandler(self, host=config.host, port=config.port) if SCALEOUT_CHECK_COMPATIBILITY: success, server_version, msg = self.grpc_handler.check_version_compatibility() if not success: ScaleoutLogger().error(f"Client version: {VERSION} compatibility check failed with Server version: {server_version}. {msg}") return False ScaleoutLogger().info("Successfully initialized GRPC connection") return True except Exception as e: ScaleoutLogger().error(f"Could not initialize GRPC connection: {e}") return False
def _send_heartbeats(self, client_name: str, client_id: str, update_frequency: float = 2.0) -> None: """Send heartbeats to the server.""" self.grpc_handler.send_heartbeats(client_name=client_name, client_id=client_id, update_frequency=update_frequency) def _listen_to_task_stream(self, client_id: str) -> None: """Listen to the task stream.""" self.grpc_handler.listen_to_task_stream(client_id=client_id, callback=self._task_stream_callback)
[docs] def default_telemetry_loop(self, update_frequency: float = 5.0) -> None: """Send default telemetry data.""" send_telemetry = True while send_telemetry: memory_usage = psutil.virtual_memory().percent cpu_usage = psutil.cpu_percent() try: success = self.log_telemetry(telemetry={"memory_usage": memory_usage, "cpu_usage": cpu_usage}) except RetryException as e: ScaleoutLogger().error(f"Sending telemetry failed: {e}") success = False if not success: ScaleoutLogger().error("Telemetry failed.") send_telemetry = False time.sleep(update_frequency)
[docs] @contextmanager def logging_context(self, context: LoggingContext): """Set the logging context.""" prev_context = self.current_logging_context self.current_logging_context = context try: yield finally: self.current_logging_context = prev_context
def _task_stream_callback(self, request: scaleout_msg.TaskRequest) -> None: """Handle task stream callbacks.""" if request.type == TaskType.ModelUpdate.value: self.update_local_model(request) elif request.type == TaskType.Validation.value: self.validate_global_model(request) elif request.type == TaskType.Prediction.value: self.predict_global_model(request) return {} def _run_task_callback(self, request: scaleout_msg.TaskRequest) -> Dict: if request.type in (t.value for t in TaskType): return self._task_stream_callback(request) elif TaskType.is_custom_task(request.type): return self._handle_custom_task(request) else: ScaleoutLogger().error(f"Invalid task type: {request.type}") raise Exception(f"Invalid task type: {request.type}") def _handle_custom_task(self, request: scaleout_msg.TaskRequest) -> Dict: if request.type in self.registered_callbacks: with self.logging_context(LoggingContext(request=request)): params = json.loads(request.data) if request.data else {} try: result = self.registered_callbacks[request.type](params) except Exception as e: ScaleoutLogger().error(f"Custom task callback failed with exception: {e}") traceback.print_exc() return None return result else: ScaleoutLogger().warning(f"Unknown task type: {request.type}") raise UnknownTaskType(f"Unknown task type: {request.type}")
[docs] def update_local_model(self, request: scaleout_msg.TaskRequest) -> None: """Update the local model.""" with self.logging_context(LoggingContext(request=request)): model_id = request.model_id model_update_id = str(uuid.uuid4()) tic = time.time() in_model = self.get_model_from_combiner(model_id=model_id) if in_model is None: ScaleoutLogger().error("Could not retrieve model from combiner. Aborting training request.") return fetch_model_time = time.time() - tic ScaleoutLogger().info(f"FETCH_MODEL: {fetch_model_time}") if not self.train_callback: ScaleoutLogger().error("No train callback set") return if SCALEOUT_CLIENT_STATUS_REPORTING: self.send_status( f"\t Starting processing of training request for model_id {model_id}", log_level=scaleout_msg.LogLevel.INFO, type="MODEL_UPDATE", ) ScaleoutLogger().info(f"Running train callback with model ID: {model_id}") client_settings = json.loads(request.data).get("client_settings", {}) tic = time.time() try: out_model, meta = self.train_callback(in_model, client_settings) except StoppedException: return except Exception as e: ScaleoutLogger().error(f"Train callback failed with exception: {e}") traceback.print_exc() return if out_model is None: ScaleoutLogger().error("Train callback returned None model. Aborting training request.") return meta["processing_time"] = time.time() - tic tic = time.time() out_model.model_id = model_update_id self.send_model_to_combiner(model=out_model) meta["upload_model"] = time.time() - tic ScaleoutLogger().info("UPLOAD_MODEL: {0}".format(meta["upload_model"])) meta["fetch_model"] = fetch_model_time meta["config"] = request.data self.grpc_handler.send_model_update( model_id=model_id, model_update_id=model_update_id, meta=meta, correlation_id=request.correlation_id, round_id=request.round_id, session_id=request.session_id, ) if SCALEOUT_CLIENT_STATUS_REPORTING: self.send_status( "Model update completed.", log_level=scaleout_msg.LogLevel.AUDIT, type="MODEL_UPDATE", )
[docs] def validate_global_model(self, request: scaleout_msg.TaskRequest) -> None: """Validate the global model.""" with self.logging_context(LoggingContext(request=request)): model_id = request.model_id if SCALEOUT_CLIENT_STATUS_REPORTING: self.send_status( f"Processing validate request for model_id {model_id}", log_level=scaleout_msg.LogLevel.INFO, type="MODEL_VALIDATION", ) in_model = self.get_model_from_combiner(model_id=model_id) if in_model is None: ScaleoutLogger().error("Could not retrieve model from combiner. Aborting validation request.") return if not self.validate_callback: ScaleoutLogger().error("No validate callback set") return ScaleoutLogger().debug(f"Running validate callback with model ID: {model_id}") try: metrics = self.validate_callback(in_model) except StoppedException: return except Exception as e: ScaleoutLogger().error(f"Validation callback failed with exception: {e}") traceback.print_exc() return if metrics is not None: # Send validation result: bool = self.grpc_handler.send_model_validation( model_id=request.model_id, metrics=json.dumps(metrics), correlation_id=request.correlation_id, session_id=request.session_id, ) if result and SCALEOUT_CLIENT_STATUS_REPORTING: self.send_status( "Model validation completed.", log_level=scaleout_msg.LogLevel.AUDIT, type="MODEL_VALIDATION", ) elif SCALEOUT_CLIENT_STATUS_REPORTING: self.send_status( f"Client {self.client_id} failed to complete model validation.", log_level=scaleout_msg.LogLevel.WARNING, type="MODEL_VALIDATION", )
[docs] def predict_global_model(self, request: scaleout_msg.TaskRequest) -> None: """Predict using the global model.""" with self.logging_context(LoggingContext(request=request)): model_id = request.model_id model = self.get_model_from_combiner(model_id=model_id) if model is None: ScaleoutLogger().error("Could not retrieve model from combiner. Aborting prediction request.") return if not self.predict_callback: ScaleoutLogger().error("No predict callback set") return ScaleoutLogger().info(f"Running predict callback with model ID: {model_id}") try: prediction = self.predict_callback(model) except Exception as e: ScaleoutLogger().error(f"Predict callback failed with exception: {e}") traceback.print_exc() return self.grpc_handler.send_model_prediction( model_id=request.model_id, prediction_output=json.dumps(prediction), correlation_id=request.correlation_id, session_id=request.session_id )
[docs] def log_metric(self, metrics: dict, step: int = None, commit: bool = True, check_task_abort=True, context: LoggingContext = None) -> bool: """Log the metrics to the server. Args: metrics (dict): The metrics to log. step (int, optional): The step number. If provided the context step will be set to this value. If not provided, the step from the context will be used. commit (bool, optional): Whether or not to increment the step. Defaults to True. check_task_abort (bool, optional): Whether or not to check for task abort. Defaults to True. context (LoggingContext, optional): The logging context to use. Defaults to None, which uses the current context. Returns: bool: True if the metrics were logged successfully, False otherwise. """ context = context or self.current_logging_context if context is None: ScaleoutLogger().error("Missing context for logging metric.") return False if step is None: step = context.step else: context.step = step if commit: context.step += 1 message = self.grpc_handler.create_metric_message( metrics=metrics, model_id=context.model_id, step=step, round_id=context.round_id, session_id=context.session_id, ) success = self.grpc_handler.send_model_metric(message) if check_task_abort: self.task_receiver.check_abort() return success
[docs] def log_attributes(self, attributes: dict, check_task_abort: bool = True) -> bool: """Log the attributes to the server. Args: attributes (dict): The attributes to log. check_task_abort (bool, optional): Whether or not to check for task abort. Defaults to True. Returns: bool: True if the attributes were logged successfully, False otherwise. """ message = scaleout_msg.AttributeMessage() message.client_id = self.client_id message.timestamp.GetCurrentTime() for key, value in attributes.items(): message.attributes.add(key=key, value=value) success = self.grpc_handler.send_attributes(message) if check_task_abort: self.task_receiver.check_abort() return success
[docs] def log_telemetry( self, telemetry: dict, check_task_abort: bool = True, ) -> bool: """Log the telemetry data to the server. Args: telemetry (dict): The telemetry data to log. check_task_abort (bool, optional): Whether or not to check for task abort. Defaults to True. Returns: bool: True if the telemetry data was logged successfully, False otherwise. """ message = scaleout_msg.TelemetryMessage() message.client_id = self.client_id message.timestamp.GetCurrentTime() for key, value in telemetry.items(): message.telemetries.add(key=key, value=value) success = self.grpc_handler.send_telemetry(message) if check_task_abort: self.task_receiver.check_abort() return success
[docs] def check_task_abort(self) -> None: """Check if the ongoing task has been aborted. This function should be called periodically from the task callback to ensure that the task can be interrupted if needed. If called from a thread that do not run the task, this function is a no-op. Raises: StoppedException: If the task was aborted. """ self.task_receiver.check_abort()
[docs] def set_name(self, name: str) -> None: """Set the client name.""" ScaleoutLogger().info(f"Setting client name to: {name}") self.name = name
[docs] def set_client_id(self, client_id: str) -> None: """Set the client ID.""" ScaleoutLogger().info(f"Setting client ID to: {client_id}") self.client_id = client_id
[docs] def run(self, with_heartbeat=False, with_polling=True) -> None: """Run the client.""" # Handle SIGTERM for graceful shutdown def _handle_sigterm(signum, frame): raise GracefulExitException() signal.signal(signal.SIGTERM, _handle_sigterm) if with_heartbeat: threading.Thread(target=self._send_heartbeats, args=(self.name, self.client_id), daemon=True).start() if SCALEOUT_CLIENT_SEND_TELEMETRY: threading.Thread(target=self.default_telemetry_loop, daemon=True).start() try: if with_polling: self._run_polling_client() else: self._listen_to_task_stream(client_id=self.client_id) except KeyboardInterrupt: ScaleoutLogger().info("Client stopped by user.") except GracefulExitException: ScaleoutLogger().info("Client stopping gracefully.")
    def _run_polling_client(self) -> None:
        """Drive the polling task loop, handling SIGTERM and Ctrl+C shutdown semantics."""
        self.task_receiver.start()
        ScaleoutLogger().info("Task receiver started.")
        if SCALEOUT_GRACEFUL_CLIENT_CONNECTION:
            self.grpc_handler.connect()
        while True:
            try:
                ScaleoutLogger().info("Client is running. Press Ctrl+C to stop.")
                # Blocks until the task manager thread exits (or an interrupt fires).
                self.task_receiver.wait_on_manager_thread()
                ScaleoutLogger().info("Task manager thread has exited. Stopping client.")
                break
            except GracefulExitException:
                # SIGTERM: abort any in-flight task, then exit immediately.
                ScaleoutLogger().info("SIGTERM received, shutting down gracefully...")
                if not self.task_receiver.has_current_task():
                    ScaleoutLogger().info("No ongoing task to abort. Exiting...")
                    break
                self.task_receiver.abort_current_task()
                break
            except KeyboardInterrupt:
                # First Ctrl+C aborts the current task but keeps the client alive;
                # a second Ctrl+C within the grace window stops it outright.
                ScaleoutLogger().info("KeyboardInterrupt received, aborting current task...")
                if not self.task_receiver.has_current_task():
                    ScaleoutLogger().info("No ongoing task to abort. Exiting client.")
                    break
                self.task_receiver.abort_current_task()
                ScaleoutLogger().info("To completely stop the client, press Ctrl+C again within 5 seconds...")
                try:
                    time.sleep(5)
                except KeyboardInterrupt:
                    ScaleoutLogger().info("Second KeyboardInterrupt received, stopping client immediately...")
                    break
        if SCALEOUT_GRACEFUL_CLIENT_CONNECTION:
            self.grpc_handler.disconnect()
[docs] def get_model_from_combiner(self, model_id: str) -> ScaleoutModel: """Get the model from the combiner.""" return self.grpc_handler.get_model_from_combiner(model_id=model_id)
[docs] def send_model_to_combiner(self, model: ScaleoutModel) -> scaleout_msg.ModelResponse: """Send the model to the combiner.""" return self.grpc_handler.send_model_to_combiner(model=model)
[docs] def send_status( self, msg: str, log_level: scaleout_msg.LogLevel = scaleout_msg.LogLevel.INFO, type: Optional[str] = None, ) -> None: """Send the status.""" self.grpc_handler.send_status(msg, log_level, type)