Source code for fedn.network.clients.client

import base64
import io
import json
import os
import queue
import re
import shutil
import ssl
import sys
import tempfile
import threading
import time
import uuid
from datetime import datetime
from distutils.dir_util import copy_tree
from io import BytesIO

import grpc
from google.protobuf.json_format import MessageToJson

import fedn.common.net.grpc.fedn_pb2 as fedn
import fedn.common.net.grpc.fedn_pb2_grpc as rpc
from fedn.network.clients.connect import ConnectorClient, Status
from fedn.network.clients.package import PackageRuntime
from fedn.network.clients.state import ClientState, ClientStateToString
from fedn.utils.dispatcher import Dispatcher
from fedn.utils.helpers import get_helper
from fedn.utils.logger import Logger

CHUNK_SIZE = 1024 * 1024
VALID_NAME_REGEX = '^[a-zA-Z0-9_-]*$'


class GrpcAuth(grpc.AuthMetadataPlugin):
    """gRPC call-credentials plugin that attaches a token-based authorization header.

    Instances are passed to :func:`grpc.metadata_call_credentials` so every
    RPC on the channel carries an ``authorization: Token <key>`` entry.
    """

    def __init__(self, key):
        # API token forwarded with every call.
        self._key = key

    def __call__(self, context, callback):
        """Invoked by gRPC for each call; supply the authorization metadata."""
        callback((('authorization', f'Token {self._key}'),), None)
class Client:
    """FEDn Client.

    Service running on client/datanodes in a federation, receiving and
    handling model update and model validation requests.

    :param config: A configuration dictionary containing connection information for
        the discovery service (controller) and settings governing e.g.
        client-combiner assignment behavior.
    :type config: dict
    """

    def __init__(self, config):
        """Initialize the client: register with the controller, attach to a
        combiner, set up the compute-package dispatcher, and start the
        processing threads."""
        self.state = None
        self.error_state = False
        self._attached = False
        self._missed_heartbeat = 0
        self.config = config

        self.connector = ConnectorClient(host=config['discover_host'],
                                         port=config['discover_port'],
                                         token=config['token'],
                                         name=config['name'],
                                         remote_package=config['remote_compute_context'],
                                         force_ssl=config['force_ssl'],
                                         verify=config['verify'],
                                         combiner=config['preferred_combiner'],
                                         id=config['client_id'])

        # Validate client name
        match = re.search(VALID_NAME_REGEX, config['name'])
        if not match:
            raise ValueError('Unallowed character in client name. Allowed characters: a-z, A-Z, 0-9, _, -.')

        self.name = config['name']

        # Per-run working directory, named by start timestamp.
        dirname = time.strftime("%Y%m%d-%H%M%S")
        self.run_path = os.path.join(os.getcwd(), dirname)
        os.mkdir(self.run_path)

        self.logger = Logger(
            to_file=config['logfile'], file_path=self.run_path)
        self.started_at = datetime.now()
        self.logs = []

        # Queue of ('train' | 'validate', request) tasks consumed by process_request().
        self.inbox = queue.Queue()

        # Attach to the FEDn network (get combiner)
        client_config = self._attach()

        self._initialize_dispatcher(config)

        self._initialize_helper(client_config)
        if not self.helper:
            print("Failed to retrive helper class settings! {}".format(
                client_config), flush=True)

        self._subscribe_to_combiner(config)

        self.state = ClientState.idle

    def _assign(self):
        """Contacts the controller and asks for combiner assignment.

        :return: A configuration dictionary containing connection information for combiner.
        :rtype: dict
        """
        print("Asking for assignment!", flush=True)
        while True:
            status, response = self.connector.assign()
            if status == Status.TryAgain:
                print(response, flush=True)
                time.sleep(5)
                continue
            if status == Status.Assigned:
                client_config = response
                break
            if status == Status.UnAuthorized:
                print(response, flush=True)
                sys.exit("Exiting: Unauthorized")
            if status == Status.UnMatchedConfig:
                print(response, flush=True)
                sys.exit("Exiting: UnMatchedConfig")
            # Unknown status: wait and poll again.
            time.sleep(5)
            print(".", end=' ', flush=True)

        print("Got assigned!", flush=True)
        print("Received combiner config: {}".format(client_config), flush=True)
        return client_config

    def _add_grpc_metadata(self, key, value):
        """Add or replace a key-value pair in the gRPC call metadata.

        :param key: The key of the metadata.
        :type key: str
        :param value: The value of the metadata.
        :type value: str
        """
        # Check if metadata exists and add if not
        if not hasattr(self, 'metadata'):
            self.metadata = ()

        # Check if metadata key already exists and replace value if so
        for i, (k, v) in enumerate(self.metadata):
            if k == key:
                # Replace value
                self.metadata = self.metadata[:i] + ((key, value),) + self.metadata[i + 1:]
                return

        # Set metadata using tuple concatenation
        self.metadata += ((key, value),)

    def _connect(self, client_config):
        """Connect to assigned combiner.

        :param client_config: A configuration dictionary containing connection
            information for the combiner.
        :type client_config: dict
        """
        # TODO use the client_config['certificate'] for setting up secure comms'
        host = client_config['host']
        # Add host to gRPC metadata
        self._add_grpc_metadata('grpc-server', host)
        print("CLIENT: Using metadata: {}".format(self.metadata), flush=True)
        port = client_config['port']
        secure = False
        if client_config['fqdn'] is not None:
            host = client_config['fqdn']
            # assuming https if fqdn is used
            port = 443
        print(f"CLIENT: Connecting to combiner host: {host}:{port}", flush=True)

        if client_config['certificate']:
            # Certificate provided by the reducer (base64-encoded).
            print("CLIENT: using certificate from Reducer for GRPC channel")
            secure = True
            cert = base64.b64decode(
                client_config['certificate'])  # .decode('utf-8')
            credentials = grpc.ssl_channel_credentials(root_certificates=cert)
            channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials)
        elif os.getenv("FEDN_GRPC_ROOT_CERT_PATH"):
            # Root certificate supplied out-of-band via environment variable.
            secure = True
            print("CLIENT: using root certificate from environment variable for GRPC channel")
            with open(os.environ["FEDN_GRPC_ROOT_CERT_PATH"], 'rb') as f:
                credentials = grpc.ssl_channel_credentials(f.read())
            channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials)
        elif self.config['secure']:
            # Fetch the server certificate directly and optionally attach
            # token-based call credentials.
            secure = True
            print("CLIENT: using CA certificate for GRPC channel")
            cert = ssl.get_server_certificate((host, port))
            credentials = grpc.ssl_channel_credentials(cert.encode('utf-8'))
            if self.config['token']:
                token = self.config['token']
                auth_creds = grpc.metadata_call_credentials(GrpcAuth(token))
                channel = grpc.secure_channel("{}:{}".format(host, str(port)),
                                              grpc.composite_channel_credentials(credentials, auth_creds))
            else:
                channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials)
        else:
            print("CLIENT: using insecure GRPC channel")
            if port == 443:
                port = 80
            channel = grpc.insecure_channel("{}:{}".format(
                host, str(port)))

        self.channel = channel

        self.connectorStub = rpc.ConnectorStub(channel)
        self.combinerStub = rpc.CombinerStub(channel)
        self.modelStub = rpc.ModelServiceStub(channel)

        print("Client: {} connected {} to {}:{}".format(self.name,
                                                        "SECURED" if secure else "INSECURE",
                                                        host,
                                                        port), flush=True)

        print("Client: Using {} compute package.".format(
            client_config["package"]))

    def _disconnect(self):
        """Disconnect from the combiner."""
        self.channel.close()

    def _detach(self):
        """Detach from the FEDn network (disconnect from combiner)"""
        # FIX: the original printed the warning but still cleared the flag and
        # closed the channel; return early instead of disconnecting twice.
        if not self._attached:
            print("Client is not attached.", flush=True)
            return

        # Setting _attached to False will make all processing threads return
        self._attached = False
        # Close gRPC connection to combiner
        self._disconnect()

    def _attach(self):
        """Attach to the FEDn network (connect to combiner)"""
        # Ask controller for a combiner and connect to that combiner.
        if self._attached:
            print("Client is already attached. ", flush=True)
            return None

        client_config = self._assign()
        self._connect(client_config)

        if client_config:
            self._attached = True
        return client_config

    def _initialize_helper(self, client_config):
        """Initialize the helper class for the client.

        :param client_config: A configuration dictionary containing connection information for
            the discovery service (controller) and settings governing e.g.
            client-combiner assignment behavior.
        :type client_config: dict
        :return:
        """
        # FIX: always define self.helper; the original left the attribute
        # unset when 'helper_type' was missing, crashing __init__'s check.
        self.helper = None
        if 'helper_type' in client_config.keys():
            self.helper = get_helper(client_config['helper_type'])

    def _subscribe_to_combiner(self, config):
        """Listen to combiner message stream and start all processing threads.

        :param config: A configuration dictionary containing connection information for
            the discovery service (controller) and settings governing e.g.
            client-combiner assignment behavior.
        """
        # Start sending heartbeats to the combiner.
        threading.Thread(target=self._send_heartbeat, kwargs={
            'update_frequency': config['heartbeat_interval']}, daemon=True).start()

        # Start listening for combiner training and validation messages
        if config['trainer']:
            threading.Thread(
                target=self._listen_to_model_update_request_stream, daemon=True).start()
        if config['validator']:
            threading.Thread(
                target=self._listen_to_model_validation_request_stream, daemon=True).start()
        self._attached = True

        # Start processing the client message inbox
        threading.Thread(target=self.process_request, daemon=True).start()

    def _initialize_dispatcher(self, config):
        """Initialize the dispatcher for the client.

        :param config: A configuration dictionary containing connection information for
            the discovery service (controller) and settings governing e.g.
            client-combiner assignment behavior.
        :type config: dict
        :return:
        """
        if config['remote_compute_context']:
            pr = PackageRuntime(os.getcwd(), os.getcwd())

            retval = None
            tries = 10

            while tries > 0:
                retval = pr.download(
                    host=config['discover_host'],
                    port=config['discover_port'],
                    token=config['token'],
                    force_ssl=config['force_ssl'],
                    secure=config['secure']
                )
                if retval:
                    break
                time.sleep(60)
                print("No compute package available... retrying in 60s Trying {} more times.".format(
                    tries), flush=True)
                tries -= 1

            if retval:
                if 'checksum' not in config:
                    print(
                        "\nWARNING: Skipping security validation of local package!, make sure you trust the package source.\n",
                        flush=True)
                else:
                    checks_out = pr.validate(config['checksum'])
                    if not checks_out:
                        print("Validation was enforced and invalid, client closing!")
                        self.error_state = True
                        return

            if retval:
                pr.unpack()

            self.dispatcher = pr.dispatcher(self.run_path)
            try:
                print("Running Dispatcher for entrypoint: startup", flush=True)
                self.dispatcher.run_cmd("startup")
            except KeyError:
                # No 'startup' entrypoint defined in the package — optional.
                pass
        else:
            # TODO: Deprecate
            dispatch_config = {'entry_points': {
                'predict': {'command': 'python3 predict.py'},
                'train': {'command': 'python3 train.py'},
                'validate': {'command': 'python3 validate.py'}}}
            from_path = os.path.join(os.getcwd(), 'client')
            # FIX: distutils.dir_util.copy_tree is deprecated (PEP 632);
            # shutil.copytree with dirs_exist_ok=True copies into the
            # already-created run_path the same way.
            shutil.copytree(from_path, self.run_path, dirs_exist_ok=True)
            self.dispatcher = Dispatcher(dispatch_config, self.run_path)
[docs] def get_model(self, id): """Fetch a model from the assigned combiner. Downloads the model update object via a gRPC streaming channel. :param id: The id of the model update object. :type id: str :return: The model update object. :rtype: BytesIO """ data = BytesIO() for part in self.modelStub.Download(fedn.ModelRequest(id=id), metadata=self.metadata): if part.status == fedn.ModelStatus.IN_PROGRESS: data.write(part.data) if part.status == fedn.ModelStatus.OK: return data if part.status == fedn.ModelStatus.FAILED: return None return data
[docs] def set_model(self, model, id): """Send a model update to the assigned combiner. Uploads the model updated object via a gRPC streaming channel, Upload. :param model: The model update object. :type model: BytesIO :param id: The id of the model update object. :type id: str :return: The model update object. :rtype: BytesIO """ if not isinstance(model, BytesIO): bt = BytesIO() for d in model.stream(32 * 1024): bt.write(d) else: bt = model bt.seek(0, 0) def upload_request_generator(mdl): """Generator function for model upload requests. :param mdl: The model update object. :type mdl: BytesIO :return: A model update request. :rtype: fedn.ModelRequest """ while True: b = mdl.read(CHUNK_SIZE) if b: result = fedn.ModelRequest( data=b, id=id, status=fedn.ModelStatus.IN_PROGRESS) else: result = fedn.ModelRequest( id=id, status=fedn.ModelStatus.OK) yield result if not b: break result = self.modelStub.Upload(upload_request_generator(bt), metadata=self.metadata) return result
def _listen_to_model_update_request_stream(self): """Subscribe to the model update request stream. :return: None :rtype: None """ r = fedn.ClientAvailableMessage() r.sender.name = self.name r.sender.role = fedn.WORKER # Add client to metadata self._add_grpc_metadata('client', self.name) while True: try: for request in self.combinerStub.ModelUpdateRequestStream(r, metadata=self.metadata): if request.sender.role == fedn.COMBINER: # Process training request self._send_status("Received model update request.", log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_UPDATE_REQUEST, request=request) self.inbox.put(('train', request)) if not self._attached: return except grpc.RpcError as e: _ = e.code() except grpc.RpcError: # TODO: make configurable timeout = 5 time.sleep(timeout) except Exception: raise if not self._attached: return def _listen_to_model_validation_request_stream(self): """Subscribe to the model validation request stream. :return: None :rtype: None """ r = fedn.ClientAvailableMessage() r.sender.name = self.name r.sender.role = fedn.WORKER while True: try: for request in self.combinerStub.ModelValidationRequestStream(r, metadata=self.metadata): # Process validation request _ = request.model_id self._send_status("Recieved model validation request.", log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_VALIDATION_REQUEST, request=request) self.inbox.put(('validate', request)) except grpc.RpcError: # TODO: make configurable timeout = 5 time.sleep(timeout) except Exception: raise if not self._attached: return def _process_training_request(self, model_id): """Process a training (model update) request. :param model_id: The model id of the model to be updated. :type model_id: str :return: The model id of the updated model, or None if the update failed. And a dict with metadata. 
:rtype: tuple """ self._send_status( "\t Starting processing of training request for model_id {}".format(model_id)) self.state = ClientState.training try: meta = {} tic = time.time() mdl = self.get_model(str(model_id)) meta['fetch_model'] = time.time() - tic inpath = self.helper.get_tmp_path() with open(inpath, 'wb') as fh: fh.write(mdl.getbuffer()) outpath = self.helper.get_tmp_path() tic = time.time() # TODO: Check return status, fail gracefully self.dispatcher.run_cmd("train {} {}".format(inpath, outpath)) meta['exec_training'] = time.time() - tic tic = time.time() out_model = None with open(outpath, "rb") as fr: out_model = io.BytesIO(fr.read()) # Push model update to combiner server updated_model_id = uuid.uuid4() self.set_model(out_model, str(updated_model_id)) meta['upload_model'] = time.time() - tic # Read the metadata file with open(outpath+'-metadata', 'r') as fh: training_metadata = json.loads(fh.read()) meta['training_metadata'] = training_metadata os.unlink(inpath) os.unlink(outpath) os.unlink(outpath+'-metadata') except Exception as e: print("ERROR could not process training request due to error: {}".format( e), flush=True) updated_model_id = None meta = {'status': 'failed', 'error': str(e)} self.state = ClientState.idle return updated_model_id, meta def _process_validation_request(self, model_id, is_inference): """Process a validation request. :param model_id: The model id of the model to be validated. :type model_id: str :param is_inference: True if the validation is an inference request, False if it is a validation request. :type is_inference: bool :return: The validation metrics, or None if validation failed. 
:rtype: dict """ # Figure out cmd if is_inference: cmd = 'infer' else: cmd = 'validate' self._send_status( f"Processing {cmd} request for model_id {model_id}") self.state = ClientState.validating try: model = self.get_model(str(model_id)) inpath = self.helper.get_tmp_path() with open(inpath, "wb") as fh: fh.write(model.getbuffer()) _, outpath = tempfile.mkstemp() self.dispatcher.run_cmd(f"{cmd} {inpath} {outpath}") with open(outpath, "r") as fh: validation = json.loads(fh.read()) os.unlink(inpath) os.unlink(outpath) except Exception as e: print("Validation failed with exception {}".format(e), flush=True) raise self.state = ClientState.idle return None self.state = ClientState.idle return validation
[docs] def process_request(self): """Process training and validation tasks. """ while True: if not self._attached: return try: (task_type, request) = self.inbox.get(timeout=1.0) if task_type == 'train': tic = time.time() self.state = ClientState.training model_id, meta = self._process_training_request( request.model_id) processing_time = time.time()-tic meta['processing_time'] = processing_time meta['config'] = request.data if model_id is not None: # Send model update to combiner update = fedn.ModelUpdate() update.sender.name = self.name update.sender.role = fedn.WORKER update.receiver.name = request.sender.name update.receiver.role = request.sender.role update.model_id = request.model_id update.model_update_id = str(model_id) update.timestamp = str(datetime.now()) update.correlation_id = request.correlation_id update.meta = json.dumps(meta) # TODO: Check responses _ = self.combinerStub.SendModelUpdate(update, metadata=self.metadata) self._send_status("Model update completed.", log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_UPDATE, request=update) else: self._send_status("Client {} failed to complete model update.", log_level=fedn.Status.WARNING, request=request) self.state = ClientState.idle self.inbox.task_done() elif task_type == 'validate': self.state = ClientState.validating metrics = self._process_validation_request( request.model_id, request.is_inference) if metrics is not None: # Send validation validation = fedn.ModelValidation() validation.sender.name = self.name validation.sender.role = fedn.WORKER validation.receiver.name = request.sender.name validation.receiver.role = request.sender.role validation.model_id = str(request.model_id) validation.data = json.dumps(metrics) self.str = str(datetime.now()) validation.timestamp = self.str validation.correlation_id = request.correlation_id _ = self.combinerStub.SendModelValidation( validation, metadata=self.metadata) # Set status type if request.is_inference: status_type = fedn.StatusType.INFERENCE 
else: status_type = fedn.StatusType.MODEL_VALIDATION self._send_status("Model validation completed.", log_level=fedn.Status.AUDIT, type=status_type, request=validation) else: self._send_status("Client {} failed to complete model validation.".format(self.name), log_level=fedn.Status.WARNING, request=request) self.state = ClientState.idle self.inbox.task_done() except queue.Empty: pass
def _handle_combiner_failure(self): """ Register failed combiner connection.""" self._missed_heartbeat += 1 if self._missed_heartbeat > self.config['reconnect_after_missed_heartbeat']: self._detach() def _send_heartbeat(self, update_frequency=2.0): """Send a heartbeat to the combiner. :param update_frequency: The frequency of the heartbeat in seconds. :type update_frequency: float :return: None if the client is detached. :rtype: None """ while True: heartbeat = fedn.Heartbeat(sender=fedn.Client( name=self.name, role=fedn.WORKER)) try: self.connectorStub.SendHeartbeat(heartbeat, metadata=self.metadata) self._missed_heartbeat = 0 except grpc.RpcError as e: status_code = e.code() print("CLIENT heartbeat: GRPC ERROR {} retrying..".format( status_code.name), flush=True) self._handle_combiner_failure() time.sleep(update_frequency) if not self._attached: return def _send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None): """Send status message. :param msg: The message to send. :type msg: str :param log_level: The log level of the message. :type log_level: fedn.Status.INFO, fedn.Status.WARNING, fedn.Status.ERROR :param type: The type of the message. :type type: str :param request: The request message. :type request: fedn.Request """ status = fedn.Status() status.timestamp = str(datetime.now()) status.sender.name = self.name status.sender.role = fedn.WORKER status.log_level = log_level status.status = str(msg) if type is not None: status.type = type if request is not None: status.data = MessageToJson(request) self.logs.append( "{} {} LOG LEVEL {} MESSAGE {}".format(str(datetime.now()), status.sender.name, status.log_level, status.status)) _ = self.connectorStub.SendStatus(status, metadata=self.metadata)
[docs] def run(self): """ Run the client. """ try: cnt = 0 old_state = self.state while True: time.sleep(1) cnt += 1 if self.state != old_state: print("{}:CLIENT in {} state".format(datetime.now().strftime( '%Y-%m-%d %H:%M:%S'), ClientStateToString(self.state)), flush=True) if cnt > 5: print("{}:CLIENT active".format( datetime.now().strftime('%Y-%m-%d %H:%M:%S')), flush=True) cnt = 0 if not self._attached: print("Detatched from combiner.", flush=True) # TODO: Implement a check/condition to ulitmately close down if too many reattachment attepts have failed. s self._attach() self._subscribe_to_combiner(self.config) if self.error_state: return except KeyboardInterrupt: print("Ok, exiting..")