# Source code for flox.clients.MainClient

import os
import platform
from datetime import datetime
from timeit import default_timer as timer
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import psutil

from flox.common import NDArrays
from flox.logic import FloxClientLogic


class MainClient(FloxClientLogic):
    """Client-side logic that runs one training round and reports the result.

    ``run_round`` drives the hooks defined here in order: retrieve data,
    process it, fit the model, and package the fit results together with
    timing and host information. Subclasses customise behaviour by
    overriding individual hooks.

    ``on_model_fit`` is called by ``run_round`` but is not defined in this
    class; presumably it is provided by ``FloxClientLogic`` or a subclass.
    """

    def on_model_receive(self):
        """Hook with no default behaviour; this implementation does nothing."""
        pass

    def retrieve_local_data(self, config: dict) -> tuple:
        """Load ``x_train`` and ``y_train`` from ``.npy`` files on disk.

        Override this method if locally stored data has to be retrieved
        differently.

        Parameters
        ----------
        config : dict
            Must contain ``"path_dir"`` (directory holding the files) and
            ``"x_train_filename"`` / ``"y_train_filename"`` (file names
            inside that directory).

        Returns
        -------
        tuple
            ``(x_train, y_train)`` as returned by ``np.load``.

        Raises
        ------
        KeyError
            If one of the required keys is missing from ``config``.
        """
        path_dir = config["path_dir"]
        x_train_path_file = os.sep.join([path_dir, config["x_train_filename"]])
        y_train_path_file = os.sep.join([path_dir, config["y_train_filename"]])

        # Load the files; the context managers close each handle as soon as
        # the array has been read.
        with open(x_train_path_file, "rb") as f:
            x_train = np.load(f)

        with open(y_train_path_file, "rb") as f:
            y_train = np.load(f)

        return (x_train, y_train)

    def retrieve_framework_data(self, config=None):
        """Placeholder for retrieving data through a framework.

        Parameters
        ----------
        config : dict, optional
            Round configuration. ``on_data_retrieve`` passes it positionally,
            so the parameter has to exist; it defaults to ``None`` so that
            calls without arguments keep working.

        Raises
        ------
        NotImplementedError
            Always; subclasses are expected to override this method.
        """
        raise NotImplementedError("Method not implemented")

    def on_data_retrieve(self, config):
        """Fetch the raw training data selected by ``config["data_source"]``.

        Parameters
        ----------
        config : dict
            Must contain ``"data_source"``, either ``"local"`` or
            ``"framework"``, plus whatever the selected retriever requires.

        Returns
        -------
        object
            Whatever the selected retriever returns.

        Raises
        ------
        KeyError
            If ``"data_source"`` is missing or is not one of the two
            supported values.
        """
        data_source = config["data_source"]

        if data_source == "local":
            return self.retrieve_local_data(config)
        if data_source == "framework":
            return self.retrieve_framework_data(config)

        # Possibly allow for a custom data_source and processing function?
        # Until then, only the two options above are accepted.
        raise KeyError("Please choose one of data sources: ['local', 'framework']")

    def on_data_process(self, data, config):
        """Return ``data`` unchanged; override to add preprocessing.

        Parameters
        ----------
        data : object
            Raw training data produced by ``on_data_retrieve``.
        config : dict
            Round configuration (unused by this implementation).

        Returns
        -------
        object
            The same ``data`` object that was passed in.
        """
        return data

    def on_model_send(
        self,
        fit_results,
        training_data=None,
        config=None,
        task_runtime=None,
        task_start_timestamp=None,
        task_finish_timestamp=None,
        data_processing_runtime=None,
        training_runtime=None,
        endpoint_physical_cores=None,
        endpoint_logical_cores=None,
        endpoint_physical_memory=None,
        platform_name=None,
    ):
        """Bundle the fit results and round metadata into a single dict.

        Parameters
        ----------
        fit_results : object
            Stored under ``"model_weights"``.
        training_data : object, optional
            Forwarded to ``get_number_of_samples``.
        config : dict, optional
            Forwarded to ``get_number_of_samples``.
        task_runtime, data_processing_runtime, training_runtime : optional
            Stored under the keys of the same name.
        task_start_timestamp, task_finish_timestamp : optional
            Stored under the keys of the same name.
        endpoint_physical_cores, endpoint_logical_cores : optional
            Stored under the keys of the same name.
        endpoint_physical_memory : optional
            Stored under the key of the same name.
        platform_name : optional
            Stored under ``"endpoint_platform_name"``.

        Returns
        -------
        dict
            The values above plus ``"samples_count"`` as computed by
            ``get_number_of_samples``.
        """
        return {
            "model_weights": fit_results,
            "samples_count": self.get_number_of_samples(training_data, config),
            "task_runtime": task_runtime,
            "task_start_timestamp": task_start_timestamp,
            "task_finish_timestamp": task_finish_timestamp,
            "data_processing_runtime": data_processing_runtime,
            "training_runtime": training_runtime,
            "endpoint_physical_cores": endpoint_physical_cores,
            "endpoint_logical_cores": endpoint_logical_cores,
            "endpoint_physical_memory": endpoint_physical_memory,
            "endpoint_platform_name": platform_name,
        }

    def get_number_of_samples(self, training_data, config):
        """Return the sample count declared in ``config``.

        Parameters
        ----------
        training_data : object
            Unused by this implementation; available to overrides.
        config : dict or None
            Round configuration.

        Returns
        -------
        object or None
            ``config["num_samples"]`` when present; ``None`` when the key is
            absent or when ``config`` itself is ``None``.
        """
        # on_model_send defaults config to None, so guard before calling .get.
        if config is None:
            return None
        return config.get("num_samples")

    def run_round(self, config, model_trainer):
        """Run one round: retrieve data, process it, fit, and build the report.

        Parameters
        ----------
        config : dict
            Round configuration passed to every hook.
        model_trainer : object
            Passed through to ``on_model_fit``.

        Returns
        -------
        dict
            The payload built by ``on_model_send``.
        """
        # NOTE(review): these repeat the module-level imports. Presumably the
        # method must be self-contained when executed remotely; confirm
        # before removing them.
        import platform
        from datetime import datetime
        from timeit import default_timer as timer

        import psutil

        task_start_time = timer()
        task_start_timestamp = datetime.utcnow()

        # Data retrieval and processing are timed together.
        data_processing_start = timer()
        raw_training_data = self.on_data_retrieve(config)
        processed_training_data = self.on_data_process(raw_training_data, config)
        data_processing_runtime = timer() - data_processing_start

        training_start = timer()
        fit_results = self.on_model_fit(model_trainer, config, processed_training_data)
        training_runtime = timer() - training_start

        # The task clock stops before the host queries below, so they are not
        # counted in task_runtime.
        task_runtime = timer() - task_start_time
        task_finish_timestamp = datetime.utcnow()

        physical_cores = psutil.cpu_count(logical=False)
        logical_cores = psutil.cpu_count(logical=True)
        physical_memory = psutil.virtual_memory().total
        platform_name = platform.node()

        task_results = self.on_model_send(
            fit_results,
            training_data=processed_training_data,
            config=config,
            task_runtime=task_runtime,
            task_start_timestamp=task_start_timestamp,
            task_finish_timestamp=task_finish_timestamp,
            data_processing_runtime=data_processing_runtime,
            training_runtime=training_runtime,
            endpoint_physical_cores=physical_cores,
            endpoint_logical_cores=logical_cores,
            endpoint_physical_memory=physical_memory,
            platform_name=platform_name,
        )

        return task_results