import copy
import uuid
from datetime import datetime
import pymongo
from google.protobuf.json_format import MessageToDict
from fedn.common.log_config import logger
from fedn.network.state import ReducerStateToString, StringToReducerState
class MongoStateStore:
    """Statestore implementation using MongoDB.

    All data lives in the database named by ``network_id``. Collections are
    grouped under ``network.*`` (reducer, combiners, clients, storage) and
    ``control.*`` (package, state, model, sessions, rounds, validations,
    status).

    :param network_id: The network id, also used as the MongoDB database name.
    :type network_id: str
    :param config: Keyword arguments passed to ``pymongo.MongoClient``
        (connection strings and security credentials).
    :type config: dict
    """

    def __init__(self, network_id, config):
        """Connect to MongoDB, bind all collections and create indexes.

        :param network_id: The network id / database name.
        :type network_id: str
        :param config: Keyword arguments passed to ``pymongo.MongoClient``.
        :type config: dict
        :raises Exception: Any connection error is logged and re-raised.
        """
        self.__inited = False
        try:
            self.config = config
            self.network_id = network_id
            self.mdb = self.connect()

            # FEDn network
            self.network = self.mdb["network"]
            self.reducer = self.network["reducer"]
            self.combiners = self.network["combiners"]
            self.clients = self.network["clients"]
            self.storage = self.network["storage"]

            # Control
            self.control = self.mdb["control"]
            self.package = self.control["package"]
            # NOTE: this instance attribute shadows the ``state()`` method, so
            # ``instance.state`` is this collection, not the method.
            self.state = self.control["state"]
            self.model = self.control["model"]
            self.sessions = self.control["sessions"]
            self.rounds = self.control["rounds"]
            self.validations = self.control["validations"]

            # Logging
            self.status = self.control["status"]

            self.__inited = True
        except Exception as e:
            logger.error("FAILED TO CONNECT TO MONGODB, {}".format(e))
            # Clear partially bound handles before propagating the error.
            self.state = None
            self.model = None
            self.control = None
            self.network = None
            self.combiners = None
            self.clients = None
            raise

        self.init_index()

    def connect(self):
        """Establish a client connection to MongoDB.

        Uses ``self.config`` as keyword arguments to ``pymongo.MongoClient`` and
        ``self.network_id`` as the database name.

        :return: Database handle for ``self.network_id``.
        :rtype: pymongo.database.Database
        :raises Exception: Whatever pymongo raises if the server is unreachable.
        """
        mc = pymongo.MongoClient(**self.config)
        try:
            # server_info() forces a round trip, so a dead connection fails
            # here instead of on the first real query.
            mc.server_info()
        except Exception:
            # Do not leak the client's background resources on failure.
            mc.close()
            raise
        return mc[self.network_id]

    def init_index(self):
        """Create the descending indexes on ``package.id`` and ``clients.client_id``."""
        self.package.create_index([("id", pymongo.DESCENDING)])
        self.clients.create_index([("client_id", pymongo.DESCENDING)])

    def is_inited(self):
        """Check if the statestore is initialized.

        :return: True if all collections were bound successfully, else False.
        :rtype: bool
        """
        return self.__inited

    def get_config(self):
        """Retrieve the statestore config.

        NOTE: ``mongo_config`` is the raw client config and may contain
        credentials; avoid exposing it to untrusted consumers.

        :return: The statestore config.
        :rtype: dict
        """
        return {
            "type": "MongoDB",
            "mongo_config": self.config,
            "network_id": self.network_id,
        }

    def state(self):
        """Get the current reducer state.

        NOTE: ``__init__`` binds the ``state`` collection to ``self.state``,
        which shadows this method on instances; it is only reachable as
        ``MongoStateStore.state(instance)``.

        :return: The stored state converted with ``StringToReducerState``,
            or None if no state has been recorded yet.
        """
        doc = self.state.find_one({"current_state": {"$exists": True}})
        if doc is None:
            return None
        return StringToReducerState(doc["current_state"])

    def transition(self, state):
        """Transition to a new reducer state.

        The state lives in one document identified by ``{"state": "current_state"}``.
        Its ``current_state`` field holds ``ReducerStateToString(state)``, which
        is the field :meth:`state` reads.

        :param state: The new state.
        :type state: ReducerState
        :return: The pymongo ``UpdateResult`` if the state changed, otherwise None.
        """
        new_state = ReducerStateToString(state)
        current = self.state.find_one({"state": "current_state"})
        if current is None or current.get("current_state") != new_state:
            # Only the value field is updated, so the selector field stays
            # intact and the same document is reused on every transition.
            return self.state.update_one(
                {"state": "current_state"},
                {"$set": {"current_state": new_state}},
                True,
            )
        logger.info("Not updating state, already in {}".format(new_state))
        return None

    def get_sessions(self, limit=None, skip=None, sort_key="_id", sort_order=pymongo.DESCENDING):
        """Get all sessions.

        :param limit: The maximum number of sessions to return.
        :type limit: int
        :param skip: The number of sessions to skip.
        :type skip: int
        :param sort_key: The key to sort by.
        :type sort_key: str
        :param sort_order: The sort order.
        :type sort_order: pymongo.ASCENDING or pymongo.DESCENDING
        :return: Dictionary with ``result`` (cursor of sessions) and ``count`` (total sessions).
        :rtype: dict
        """
        cursor = self.sessions.find()
        # Pagination is only applied when both bounds are given.
        if limit is not None and skip is not None:
            cursor = cursor.limit(int(limit)).skip(int(skip))
        return {
            "result": cursor.sort(sort_key, sort_order),
            "count": self.sessions.count_documents({}),
        }

    def get_session(self, session_id):
        """Get session with id.

        :param session_id: The session id.
        :type session_id: str
        :return: The session document, or None if not found.
        :rtype: dict
        """
        return self.sessions.find_one({"session_id": session_id})

    def get_session_status(self, session_id):
        """Get the session status.

        :param session_id: The session id.
        :type session_id: str
        :return: The session status, or None if the session does not exist or
            has no status yet (``create_session`` stores only the id).
        :rtype: str
        """
        session = self.sessions.find_one({"session_id": session_id})
        if session is None:
            return None
        return session.get("status")

    def set_latest_model(self, model_id, session_id=None):
        """Register a new model and make it the current model.

        Inserts a ``models`` document, updates ``current_model`` and appends the
        model to ``model_trail``.

        :param model_id: The model id.
        :type model_id: str
        :param session_id: The session that produced the model. None means the
            model was uploaded by a user and gets no parent.
        :type session_id: str
        :return: None
        """
        committed_at = datetime.now()
        parent_model = None
        # A session-produced model descends from whatever model is current now.
        if session_id is not None:
            current_model = self.model.find_one({"key": "current_model"})
            if current_model is not None:
                parent_model = current_model.get("model")

        self.model.insert_one(
            {
                "key": "models",
                "model": model_id,
                "parent_model": parent_model,
                "session_id": session_id,
                "committed_at": committed_at,
            }
        )
        self.model.update_one({"key": "current_model"}, {"$set": {"model": model_id}}, True)
        self.model.update_one(
            {"key": "model_trail"},
            {
                "$push": {
                    "model": model_id,
                    "committed_at": str(committed_at),
                }
            },
            True,
        )

    def get_initial_model(self):
        """Return the model_id of the first model in the model trail.

        :return: The initial model id. None if no model is found or it is blank.
        :rtype: str
        """
        result = self.model.find_one({"key": "model_trail"}, sort=[("committed_at", pymongo.ASCENDING)])
        if result is None:
            return None
        try:
            model_id = result["model"][0]
        except (KeyError, IndexError, TypeError):
            return None
        # Treat blank ids as "no model", consistent with get_latest_model.
        if model_id in ("", " "):
            return None
        return model_id

    def get_latest_model(self):
        """Return the model_id of the current model.

        :return: The latest model id. None if no model is found or it is blank.
        :rtype: str
        """
        result = self.model.find_one({"key": "current_model"})
        if result is None:
            return None
        model_id = result.get("model")
        if model_id in ("", " "):
            return None
        return model_id

    def set_current_model(self, model_id: str):
        """Set the current model in statestore.

        :param model_id: The id of a model already registered in ``models``.
        :type model_id: str
        :return: True if the current model was updated, False if the model is
            unknown or a database error occurred.
        :rtype: bool
        """
        try:
            committed_at = datetime.now()
            existing_model = self.model.find_one({"key": "models", "model": model_id})
            if existing_model is None:
                # Nothing was updated, so do not report success.
                return False
            self.model.update_one({"key": "current_model"}, {"$set": {"model": model_id, "committed_at": committed_at, "session_id": None}}, True)
        except Exception as e:
            logger.error("ERROR: {}".format(e))
            return False
        return True

    def get_latest_round(self):
        """Get the round document with the highest ``_id``.

        :return: The round document (normally the most recently created), or None.
        :rtype: dict
        """
        return self.rounds.find_one(sort=[("_id", pymongo.DESCENDING)])

    def get_round(self, id):
        """Get round with id.

        :param id: id of round to get (converted to str for the lookup).
        :type id: int or str
        :return: The round document, or None if not found.
        :rtype: dict
        """
        return self.rounds.find_one({"round_id": str(id)})

    def get_rounds(self):
        """Get all rounds.

        :return: Cursor over all round documents.
        :rtype: pymongo.cursor.Cursor
        """
        return self.rounds.find()

    def get_validations(self, **kwargs):
        """Get validations from the database.

        :param kwargs: query to filter validations
        :type kwargs: dict
        :return: Cursor over validations matching the query.
        :rtype: pymongo.cursor.Cursor
        """
        return self.validations.find(kwargs)

    def set_active_compute_package(self, id: str):
        """Set the active compute package in statestore.

        Copies the package-trail entry with the given id into the ``active``
        document, creating that document if it does not exist yet.

        :param id: The id of the compute package (not document _id).
        :type id: str
        :return: True if successful, False if the package is unknown or an error occurred.
        :rtype: bool
        """
        try:
            doc = self.package.find_one({"id": id}, {"_id": False, "key": False})
            if doc is None:
                return False
            doc["key"] = "active"
            # upsert so the first activation also works when no active package exists.
            self.package.replace_one({"key": "active"}, doc, upsert=True)
        except Exception as e:
            logger.error("ERROR: {}".format(e))
            return False
        return True

    def set_compute_package(self, file_name: str, storage_file_name: str, helper_type: str, name: str = None, description: str = None):
        """Register a compute package and make it the active one.

        Updates the ``active`` document and appends a ``package_trail`` entry
        with a freshly generated id.

        :param file_name: The file_name of the compute package.
        :type file_name: str
        :param storage_file_name: The name under which the package is stored.
        :type storage_file_name: str
        :param helper_type: The helper used by the package.
        :type helper_type: str
        :param name: Optional display name.
        :type name: str
        :param description: Optional description.
        :type description: str
        :return: True if successful.
        :rtype: bool
        """
        obj = {
            "file_name": file_name,
            "storage_file_name": storage_file_name,
            "helper": helper_type,
            "committed_at": datetime.now(),
            "name": name,
            "description": description,
            "id": str(uuid.uuid4()),
        }
        self.package.update_one(
            {"key": "active"},
            {"$set": obj},
            True,
        )
        self.package.insert_one({"key": "package_trail", **obj})
        return True

    def get_compute_package(self):
        """Get the active compute package.

        :return: The active compute package (without ``key``/``_id``), or None.
        :rtype: dict
        """
        try:
            return self.package.find_one({"key": "active"}, {"key": False, "_id": False})
        except Exception as e:
            logger.error("ERROR: {}".format(e))
            return None

    def list_compute_packages(self, limit: int = None, skip: int = None, sort_key="committed_at", sort_order=pymongo.DESCENDING):
        """List compute packages in the statestore (paginated).

        :param limit: The maximum number of compute packages to return.
        :type limit: int
        :param skip: The number of compute packages to skip.
        :type skip: int
        :param sort_key: The key to sort by.
        :type sort_key: str
        :param sort_order: The sort order.
        :type sort_order: pymongo.ASCENDING or pymongo.DESCENDING
        :return: Dictionary of compute packages in result and count, or None on error.
        :rtype: dict
        """
        result = None
        count = None
        find_option = {"key": "package_trail"}
        projection = {"key": False, "_id": False}
        try:
            cursor = self.package.find(find_option, projection)
            if limit is not None and skip is not None:
                # Cast like the other list methods; query-string values arrive as str.
                cursor = cursor.limit(int(limit)).skip(int(skip))
            result = cursor.sort(sort_key, sort_order)
            count = self.package.count_documents(find_option)
        except Exception as e:
            logger.error("ERROR: {}".format(e))
            return None
        return {
            "result": result or [],
            "count": count or 0,
        }

    def set_helper(self, helper):
        """Set the active helper package in statestore.

        :param helper: The name of the helper package. See helper.py for available helpers.
        :type helper: str
        :return: None
        """
        self.package.update_one({"key": "active"}, {"$set": {"helper": helper}}, True)

    def get_helper(self):
        """Get the helper of the active compute package.

        :return: The active helper, or None if there is no active package
            (e.g. a local compute package is used) or the helper is blank.
        :rtype: str
        """
        ret = self.package.find_one({"key": "active"})
        if ret is None:
            return None
        helper = ret.get("helper")
        if helper in ("", " "):
            return None
        return helper

    def list_models(
        self,
        session_id=None,
        limit=None,
        skip=None,
        sort_key="committed_at",
        sort_order=pymongo.DESCENDING,
    ):
        """List all models in the statestore.

        :param session_id: Only list models produced by this session (all if None).
        :type session_id: str
        :param limit: The maximum number of models to return.
        :type limit: int
        :param skip: The number of models to skip.
        :type skip: int
        :param sort_key: The key to sort by.
        :type sort_key: str
        :param sort_order: The sort order.
        :type sort_order: pymongo.ASCENDING or pymongo.DESCENDING
        :return: Dictionary with ``result`` (cursor of models) and ``count``.
        :rtype: dict
        """
        find_option = {"key": "models"}
        if session_id is not None:
            find_option["session_id"] = session_id
        projection = {"_id": False, "key": False}
        cursor = self.model.find(find_option, projection)
        if limit is not None and skip is not None:
            cursor = cursor.limit(int(limit)).skip(int(skip))
        return {
            "result": cursor.sort(sort_key, sort_order),
            "count": self.model.count_documents(find_option),
        }

    def get_model_trail(self):
        """Get the model trail.

        NOTE: dicts keep insertion order, so the trail order is preserved, but a
        model id that occurs twice keeps only its last timestamp.

        :return: dictionary of model_id: committed_at, or None if no trail exists.
        :rtype: dict
        """
        result = self.model.find_one({"key": "model_trail"})
        if result is None:
            return None
        try:
            return dict(zip(result["model"], result["committed_at"]))
        except KeyError:
            return None

    def get_model_ancestors(self, model_id: str, limit: int):
        """Get the model ancestors by following ``parent_model`` links.

        :param model_id: The model id.
        :type model_id: str
        :param limit: The maximum number of ancestors to return.
        :type limit: int
        :return: Ancestor documents, nearest parent first.
        :rtype: list
        """
        model = self.model.find_one({"key": "models", "model": model_id})
        current_model_id = model.get("parent_model") if model is not None else None
        result = []
        for _ in range(limit):
            if current_model_id is None:
                break
            model = self.model.find_one({"key": "models", "model": current_model_id})
            if model is None:
                # Broken chain: the parent is not registered, stop walking.
                break
            result.append(model)
            current_model_id = model.get("parent_model")
        return result

    def get_model_descendants(self, model_id: str, limit: int):
        """Get the model descendants by following children via ``parent_model``.

        Only the first child found at each step is followed, so the result is a
        single chain.

        :param model_id: The model id.
        :type model_id: str
        :param limit: The maximum number of descendants to return.
        :type limit: int
        :return: Descendant documents, furthest descendant first.
        :rtype: list
        """
        model = self.model.find_one({"key": "models", "model": model_id})
        current_model_id = model["model"] if model is not None else None
        result: list = []
        for _ in range(limit):
            if current_model_id is None:
                break
            child = self.model.find_one({"key": "models", "parent_model": current_model_id})
            if child is None:
                # No further children: the chain ends here.
                break
            result.append(child)
            current_model_id = child["model"]
        result.reverse()
        return result

    def get_model(self, model_id):
        """Get model with id.

        :param model_id: id of model to get
        :type model_id: str
        :return: The model document, or None if not found.
        :rtype: dict
        """
        return self.model.find_one({"key": "models", "model": model_id})

    def get_events(self, **kwargs):
        """Get events from the status collection, newest first.

        ``limit`` and ``skip`` are taken out of ``kwargs`` and only applied when
        both are given; the remaining kwargs form the query.

        :param kwargs: query to filter events, plus optional limit/skip
        :type kwargs: dict
        :return: Dictionary with ``result`` (cursor of events) and ``count``.
        :rtype: dict
        """
        limit = kwargs.pop("limit", None)
        skip = kwargs.pop("skip", None)
        cursor = self.status.find(kwargs, {"_id": False}).sort("timestamp", pymongo.DESCENDING)
        if limit is not None and skip is not None:
            cursor = cursor.limit(int(limit)).skip(int(skip))
        return {
            "result": cursor,
            "count": self.status.count_documents(kwargs),
        }

    def get_storage_backend(self):
        """Get the enabled storage backend.

        :return: The storage backend config (without ``_id``), or None if none is enabled.
        :rtype: dict
        """
        return self.storage.find_one({"status": "enabled"}, projection={"_id": False})

    def set_storage_backend(self, config):
        """Set the storage backend, marking it enabled.

        :param config: The storage backend configuration; must contain ``storage_type``.
            It is deep-copied and not modified.
        :type config: dict
        :return: None
        """
        config = copy.deepcopy(config)
        config["updated_at"] = str(datetime.now())
        config["status"] = "enabled"
        self.storage.update_one({"storage_type": config["storage_type"]}, {"$set": config}, True)

    def set_reducer(self, reducer_data):
        """Set the reducer in the statestore.

        :param reducer_data: dictionary of reducer config; must contain ``name``.
            ``updated_at`` is added to it in place.
        :type reducer_data: dict
        :return: None
        """
        reducer_data["updated_at"] = str(datetime.now())
        self.reducer.update_one({"name": reducer_data["name"]}, {"$set": reducer_data}, True)

    def get_reducer(self):
        """Get the reducer config.

        :return: The reducer document, or None if missing or on error.
        :rtype: dict
        """
        try:
            return self.reducer.find_one()
        except Exception as e:
            logger.error("ERROR: {}".format(e))
            return None

    def get_combiner(self, name):
        """Get combiner by name.

        :param name: name of combiner to get.
        :type name: str
        :return: The combiner document, or None if missing or on error.
        :rtype: dict
        """
        try:
            return self.combiners.find_one({"name": name})
        except Exception as e:
            logger.error("ERROR: {}".format(e))
            return None

    def get_combiners(self, limit=None, skip=None, sort_key="updated_at", sort_order=pymongo.DESCENDING, projection=None):
        """Get all combiners.

        :param limit: The maximum number of combiners to return.
        :type limit: int
        :param skip: The number of combiners to skip.
        :type skip: int
        :param sort_key: The key to sort by.
        :type sort_key: str
        :param sort_order: The sort order.
        :type sort_order: pymongo.ASCENDING or pymongo.DESCENDING
        :param projection: The projection (None returns all fields).
        :type projection: dict
        :return: Dictionary of combiners in result and count, or None on error.
        :rtype: dict
        """
        try:
            cursor = self.combiners.find({}, projection)
            if limit is not None and skip is not None:
                cursor = cursor.limit(int(limit)).skip(int(skip))
            result = cursor.sort(sort_key, sort_order)
            count = self.combiners.count_documents({})
        except Exception as e:
            logger.error("ERROR: {}".format(e))
            return None
        return {
            "result": result,
            "count": count,
        }

    def set_combiner(self, combiner_data):
        """Set combiner in statestore.

        :param combiner_data: dictionary of combiner config; must contain ``name``.
            ``updated_at`` is added to it in place.
        :type combiner_data: dict
        :return: None
        """
        combiner_data["updated_at"] = str(datetime.now())
        self.combiners.update_one({"name": combiner_data["name"]}, {"$set": combiner_data}, True)

    def delete_combiner(self, combiner):
        """Delete a combiner from statestore.

        Errors are logged, not raised.

        :param combiner: name of combiner to delete.
        :type combiner: str
        :return: None
        """
        try:
            self.combiners.delete_one({"name": combiner})
        except Exception as e:
            logger.error(
                "Failed to delete combiner: {}".format(combiner),
            )
            logger.error("ERROR: {}".format(e))

    def set_client(self, client_data):
        """Set (upsert) a client in statestore.

        Fields whose value is None are not written. If ``client_id`` is missing
        or None, a new UUID is generated and stored back into ``client_data``
        so the caller can read it.

        :param client_data: dictionary of client config.
        :type client_data: dict
        :return: None
        """
        client_data["updated_at"] = str(datetime.now())
        # A None id would otherwise match/upsert a shared null-id document.
        if client_data.get("client_id") is None:
            client_data["client_id"] = str(uuid.uuid4())
        update = {k: v for k, v in client_data.items() if v is not None}
        self.clients.update_one({"client_id": client_data["client_id"]}, {"$set": update}, upsert=True)

    def get_client(self, client_id):
        """Get client by client_id.

        :param client_id: client_id of client to get.
        :type client_id: str
        :return: The client document. None if not found or on error.
        :rtype: dict
        """
        try:
            # Clients are keyed (and indexed) by ``client_id``; see set_client.
            return self.clients.find_one({"client_id": client_id})
        except Exception as e:
            logger.error("ERROR: {}".format(e))
            return None

    def list_clients(self, limit=None, skip=None, status=None, sort_key="last_seen", sort_order=pymongo.DESCENDING):
        """List all clients registered on the network.

        :param limit: The maximum number of clients to return.
        :type limit: int
        :param skip: The number of clients to skip.
        :type skip: int
        :param status: online | offline (all clients if None)
        :type status: str
        :param sort_key: The key to sort by.
        :type sort_key: str
        :param sort_order: The sort order.
        :type sort_order: pymongo.ASCENDING or pymongo.DESCENDING
        :return: Dictionary with ``result`` and ``count``; both None on error.
        :rtype: dict
        """
        result = None
        count = None
        try:
            find = {} if status is None else {"status": status}
            projection = {"_id": False, "updated_at": False}
            cursor = self.clients.find(find, projection)
            if limit is not None and skip is not None:
                cursor = cursor.limit(int(limit)).skip(int(skip))
            result = cursor.sort(sort_key, sort_order)
            count = self.clients.count_documents(find)
        except Exception as e:
            logger.error("{}".format(e))
        return {
            "result": result,
            "count": count,
        }

    def list_combiners_data(self, combiners, sort_key="count", sort_order=pymongo.DESCENDING):
        """Count clients per combiner.

        With a list of combiners, only online clients of those combiners are
        counted; with None, all clients are grouped by combiner.

        :param combiners: list of combiners to get data for, or None.
        :type combiners: list
        :param sort_key: The key to sort by.
        :type sort_key: str
        :param sort_order: The sort order.
        :type sort_order: pymongo.ASCENDING or pymongo.DESCENDING
        :return: Aggregation cursor of ``{"_id": combiner, "count": n}``, or None on error.
        """
        result = None
        try:
            pipeline = [
                {"$group": {"_id": "$combiner", "count": {"$sum": 1}}},
                # Tie-break on combiner name for a stable order.
                {"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}},
            ]
            if combiners is not None:
                pipeline.insert(0, {"$match": {"combiner": {"$in": combiners}, "status": "online"}})
            result = self.clients.aggregate(pipeline)
        except Exception as e:
            logger.error(e)
        return result

    def report_status(self, msg):
        """Write status message to the database.

        :param msg: The status message (a protobuf message, converted with MessageToDict).
        """
        data = MessageToDict(msg)
        if self.status is not None:
            self.status.insert_one(data)

    def report_validation(self, validation):
        """Write model validation to database.

        :param validation: The model validation (a protobuf message, converted with MessageToDict).
        """
        data = MessageToDict(validation)
        if self.validations is not None:
            self.validations.insert_one(data)

    def drop_status(self):
        """Drop the status collection."""
        # pymongo Collection objects refuse truth-value testing; compare with None.
        if self.status is not None:
            self.status.drop()

    def create_session(self, id=None):
        """Create a new session object.

        :param id: The ID of the created session; a random UUID if falsy.
        :type id: uuid, str
        """
        session_id = id or uuid.uuid4()
        self.sessions.insert_one({"session_id": str(session_id)})

    def create_round(self, round_data):
        """Create a new round.

        :param round_data: Dictionary with round data.
        :type round_data: dict
        """
        # TODO: Add check if round_id already exists
        self.rounds.insert_one(round_data)

    def set_session_config(self, id: str, config) -> None:
        """Append a configuration to the session's ``session_config`` list.

        :param id: The session id
        :type id: str
        :param config: Session configuration
        :type config: dict
        """
        self.sessions.update_one({"session_id": str(id)}, {"$push": {"session_config": config}}, True)

    def set_session_config_v2(self, id: str, config) -> None:
        """Set (overwrite) the session's ``session_config``.

        Added to accommodate the newer session config structure.

        :param id: The session id
        :type id: str
        :param config: Session configuration
        :type config: dict
        """
        self.sessions.update_one({"session_id": str(id)}, {"$set": {"session_config": config}}, True)

    def set_session_status(self, id, status):
        """Set session status.

        :param id: The session unique identifier
        :type id: str
        :param status: The status of the session.
        :type status: str
        """
        self.sessions.update_one({"session_id": str(id)}, {"$set": {"status": status}}, True)

    def set_round_combiner_data(self, data):
        """Append combiner round controller data to its round.

        :param data: The combiner data; must contain ``round_id``.
        :type data: dict
        """
        self.rounds.update_one({"round_id": str(data["round_id"])}, {"$push": {"combiners": data}}, True)

    def set_round_config(self, round_id, round_config):
        """Set round configuration.

        :param round_id: The round unique identifier
        :type round_id: str
        :param round_config: The round configuration
        :type round_config: dict
        """
        self.rounds.update_one({"round_id": round_id}, {"$set": {"round_config": round_config}}, True)

    def set_round_status(self, round_id, round_status):
        """Set round status.

        :param round_id: The round unique identifier
        :type round_id: str
        :param round_status: The status of the round.
        :type round_status: str
        """
        self.rounds.update_one({"round_id": round_id}, {"$set": {"status": round_status}}, True)

    def set_round_data(self, round_id, round_data):
        """Update round metadata.

        :param round_id: The round unique identifier
        :type round_id: str
        :param round_data: The round metadata
        :type round_data: dict
        """
        self.rounds.update_one({"round_id": round_id}, {"$set": {"round_data": round_data}}, True)

    def update_client_status(self, clients, status):
        """Update status and ``last_seen`` for several clients in statestore.

        :param clients: The client ids to update.
        :type clients: list
        :param status: The client status
        :type status: str
        :return: None
        """
        filter_query = {"client_id": {"$in": clients}}
        update_query = {"$set": {"last_seen": datetime.now(), "status": status}}
        self.clients.update_many(filter_query, update_query)