import tempfile
import threading
from typing import BinaryIO, Iterable
import scaleoututil.grpc.scaleout_pb2 as scaleout_msg
from scaleoututil.utils.checksum import compute_checksum_from_stream
from scaleoututil.helpers.plugins.numpyhelper import Helper
CHUNK_SIZE = 1 * 1024 * 1024  # 1 MiB chunk size for reading/writing files
SPOOLED_MAX_SIZE = 10 * 1024 * 1024  # 10 MiB in-memory limit for spooled temporary files (larger data rolls over to disk)
[docs]
class ScaleoutModel:
    """The ScaleoutModel class is the primary model representation in the Scaleout framework.

    A ScaleoutModel object contains a data object (``tempfile.SpooledTemporaryFile``)
    that holds the serialized model parameters. The model parameters dict can be
    extracted from the data object, or be used to create a model object.
    Packing/unpacking of the model parameters is done by a helper, which needs to be
    provided either to the class (the ``helper`` attribute) or to the method.
    """

    def __init__(self):
        """Initializes an empty ScaleoutModel object."""
        # SpooledTemporaryFile keeps the data in memory until it grows beyond
        # SPOOLED_MAX_SIZE, after which it transparently rolls over to disk.
        self._data = tempfile.SpooledTemporaryFile(SPOOLED_MAX_SIZE)
        # Serializes seek/read access to the shared ``_data`` stream (see get_stream).
        self._data_lock = threading.RLock()
        # Not assigned anywhere else in this class; left for callers to set.
        self.model_id = None
        self.helper = None
        # Lazily computed and cached by the ``checksum`` property.
        self._checksum = None

    @property
    def checksum(self) -> str:
        """Returns the checksum of the model data.

        The checksum is computed on first access and cached afterwards.
        """
        if self._checksum is None:
            # The copy is only needed while hashing, so close it right away
            # instead of leaving the temporary file to the garbage collector.
            with self.get_stream() as stream:
                self._checksum = compute_checksum_from_stream(stream)
        return self._checksum

    def verify_checksum(self, checksum: str) -> bool:
        """Verifies the checksum of the model data.

        :param checksum: Expected checksum, or None to skip verification.
        :return: True if ``checksum`` is None or equals the checksum of the model data.
        """
        return checksum is None or self.checksum == checksum

    def get_stream(self):
        """Returns a stream of the model data.

        To avoid concurrency issues, a new stream (an independent copy of the data,
        positioned at the start) is created each time this method is called. The
        copy is made while holding the internal lock. The caller owns the returned
        stream and should close it when done.

        :return: A new ``tempfile.SpooledTemporaryFile`` holding a copy of the model data.
        """
        new_stream = tempfile.SpooledTemporaryFile(SPOOLED_MAX_SIZE)
        try:
            with self._data_lock:
                self._data.seek(0)
                while chunk := self._data.read(CHUNK_SIZE):
                    new_stream.write(chunk)
                # Leave the internal stream rewound for the next reader.
                self._data.seek(0)
            new_stream.seek(0)
        except BaseException:
            # Do not leak the half-written temporary file if the copy fails.
            new_stream.close()
            raise
        return new_stream

    def get_stream_unsafe(self):
        """Returns the internal stream of the model data, rewound to the start.

        This method is not thread-safe and should be used with caution: it neither
        takes the internal lock nor copies the data. Note that a cached checksum is
        not invalidated if the data is modified through the returned stream.
        """
        self._data.seek(0)
        return self._data

    def get_model_params(self, helper=None):
        """Returns the model parameters as a dictionary.

        :param helper: Helper used to unpack the data. If given, it is also stored
            on the model and replaces any previously stored helper.
        :return: Whatever ``helper.load`` returns for the model data.
        :raises ValueError: If no helper was passed and none is stored on the model.
        """
        if helper is not None:
            self.helper = helper
        # Validate before copying the data so the error path does not pay for
        # (and leak) a full copy of the model.
        if self.helper is None:
            raise ValueError("No helper provided to unpack model parameters.")
        # NOTE(review): the stream is intentionally not closed here, since a helper
        # may return an object that reads lazily from it - confirm per helper.
        return self.helper.load(self.get_stream())

    def save_to_file(self, file_path: str):
        """Saves the model data to a file.

        :param file_path: Path of the file to write; it is opened in ``"wb"`` mode,
            so an existing file is overwritten.
        """
        with open(file_path, "wb") as file, self.get_stream() as stream:
            while chunk := stream.read(CHUNK_SIZE):
                file.write(chunk)

    def get_filechunk_stream(self, chunk_size=CHUNK_SIZE):
        """Returns a generator that yields chunks of the model data.

        As this is a generator, the copy of the model data is taken when iteration
        starts, not when this method is called. The copy is closed once the
        generator is exhausted or closed.

        :param chunk_size: Maximum number of bytes per yielded chunk.
        :return: Generator of ``scaleout_msg.FileChunk`` messages.
        """
        with self.get_stream() as stream:
            while chunk := stream.read(chunk_size):
                yield scaleout_msg.FileChunk(data=chunk)

    @staticmethod
    def from_model_params(model_params: dict, helper=None) -> "ScaleoutModel":
        """Creates a ScaleoutModel from model parameters.

        :param model_params: The model parameters to pack.
        :param helper: Helper used to pack the parameters. If None, the numpy
            helper is used as default.
        :return: A new ScaleoutModel holding the packed parameters. The helper that
            was used (including the default one) is stored on the model, so
            ``get_model_params()`` can unpack the data without further arguments.
        """
        if helper is None:
            # No helper provided, using numpy helper as default
            helper = Helper()
        model_reference = ScaleoutModel()
        # Record the helper actually used to pack the data; recording it before
        # the default was applied left ``helper`` as None on the model.
        model_reference.helper = helper
        helper.save(model_params, model_reference._data)
        model_reference._data.seek(0)
        return model_reference

    @staticmethod
    def from_file(file_path: str) -> "ScaleoutModel":
        """Creates a ScaleoutModel from a file.

        :param file_path: Path of the file to read; it is opened in ``"rb"`` mode.
        :return: A new ScaleoutModel holding a copy of the file contents.
        """
        with open(file_path, "rb") as file:
            return ScaleoutModel.from_stream(file)

    @staticmethod
    def from_stream(stream: BinaryIO) -> "ScaleoutModel":
        """Creates a ScaleoutModel from a stream.

        The stream is read from its current position until it is exhausted; it is
        neither rewound nor closed by this method.

        :param stream: Binary stream to read the model data from.
        :return: A new ScaleoutModel holding a copy of the data read.
        """
        model_reference = ScaleoutModel()
        while chunk := stream.read(CHUNK_SIZE):
            model_reference._data.write(chunk)
        model_reference._data.seek(0)
        return model_reference

    @staticmethod
    def from_filechunk_stream(filechunk_stream: Iterable[scaleout_msg.FileChunk]) -> "ScaleoutModel":
        """Creates a ScaleoutModel from a filechunk stream.

        :param filechunk_stream: Iterable of ``FileChunk`` messages; the ``data``
            payloads are concatenated in iteration order.
        :return: A new ScaleoutModel holding the concatenated data.
        """
        model_reference = ScaleoutModel()
        for chunk in filechunk_stream:
            # Chunks with an empty payload carry no data and are skipped.
            if chunk.data:
                model_reference._data.write(chunk.data)
        model_reference._data.seek(0)
        return model_reference