FLoX Client-Controller Logic
FLoX is designed to be a highly modular and customizable framework for serverless federated learning (FL) processes. It is built on top of a general 10-step abstraction for FL processes, grouped into controller-side and client-side logical steps. FLoX provides two abstract classes, FloxControllerLogic and FloxClientLogic, which wrap the logic for the controller side and the client side, respectively. Each step runs either on the Controller (S) or on the clients (C). All 10 steps are listed below with their corresponding functions:
1. (S) Model Initialization: on_model_init
2. (S) Model Sharing: on_model_broadcast
3. (C) Receiving Model on Client: on_model_receive
4. (C) Data Fetching: on_data_retrieve
5. (C) Local Model Training: on_model_fit
6. (C) Model Parameter Submission: on_model_send
7. (S) Receiving Model on Controller: on_model_receive
8. (S) Model Aggregation: on_model_aggregate
9. (S) Model Updating: on_model_update
10. (S) Model Evaluation: on_model_evaluate
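The split between the two abstract classes can be sketched as Python abstract base classes: six controller-side (S) methods and four client-side (C) methods, one per step. This is a minimal sketch. The class and method names come from the list above, but every parameter and return type is an assumption and not necessarily FLoX's actual signature.

```python
from abc import ABC, abstractmethod
from typing import Any


class FloxControllerLogic(ABC):
    """Controller-side (S) steps; signatures are illustrative assumptions."""

    @abstractmethod
    def on_model_init(self) -> None: ...

    @abstractmethod
    def on_model_broadcast(self) -> list: ...

    @abstractmethod
    def on_model_receive(self, tasks: list) -> Any: ...

    @abstractmethod
    def on_model_aggregate(self, results: Any) -> Any: ...

    @abstractmethod
    def on_model_update(self, weights: Any) -> None: ...

    @abstractmethod
    def on_model_evaluate(self) -> Any: ...


class FloxClientLogic(ABC):
    """Client-side (C) steps; signatures are illustrative assumptions."""

    @abstractmethod
    def on_model_receive(self, payload: Any) -> Any: ...

    @abstractmethod
    def on_data_retrieve(self, config: dict) -> Any: ...

    @abstractmethod
    def on_model_fit(self, model: Any, data: Any) -> Any: ...

    @abstractmethod
    def on_model_send(self, weights: Any) -> Any: ...
```

Because both classes declare every step as abstract, a concrete controller or client cannot be instantiated until it implements all of its steps.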
Below is a mermaid.js figure showing the sequence of the logic steps, followed by a description of each step.
```mermaid
flowchart
    id1[[on_model_init]]
    id2[/on_model_broadcast\]
    id3[\on_model_receive/]
    id4[(on_data_retrieve)]
    id5{on_model_fit}
    id6[on_model_send]
    id7[on_model_receive]
    id8[on_model_aggregate]
    id9[on_model_update]
    id10[on_model_evaluate]
    subgraph controller
        id1-->id2
        id7-->id8
        id8-->id9
        id9-->id10
        id10-->id2
    end
    subgraph client
        direction TB
        id3-->id4
        id4-->id5
        id5-->id6
    end
    id2-->|model params|id3
    id6-->|model params|id7
```
1. Model Initialization
Controller.on_model_init() is where one provides initial setup code that needs to run only once, rather than in every FL round. In flox.controllers.TensorflowController, we use .on_model_init() to set up variables that were not provided by the user but are reused in all subsequent FL rounds.
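A minimal sketch of the "run once, reuse every round" idea, assuming a hypothetical SketchController whose user may omit a per-endpoint sample count. The attribute names and the default of 100 are illustrative assumptions, not FLoX options.

```python
from typing import Optional


class SketchController:
    """Hypothetical controller showing only the one-time on_model_init step."""

    def __init__(self, endpoint_ids: list, num_rounds: int,
                 samples_per_endpoint: Optional[list] = None) -> None:
        self.endpoint_ids = endpoint_ids
        self.num_rounds = num_rounds
        self.samples_per_endpoint = samples_per_endpoint

    def on_model_init(self) -> None:
        # Fill in values the user did not supply; later rounds reuse them.
        if self.samples_per_endpoint is None:
            # Assumption: default to the same sample budget on every endpoint.
            self.samples_per_endpoint = [100] * len(self.endpoint_ids)
        if len(self.samples_per_endpoint) != len(self.endpoint_ids):
            raise ValueError("need one sample count per endpoint")
        # Later steps would append one evaluation result per round here.
        self.history = []
```

Doing this validation once in on_model_init keeps the per-round methods free of repeated default handling.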
2. Model Broadcasting
Model broadcasting happens in Controller.on_model_broadcast(), where all variables are assembled into a config file, data is encrypted if necessary, and the tasks are deployed to the clients using an Executor, such as FuncXExecutor for remote execution or ThreadPoolExecutor for local execution. The method should return a list of futures/tasks that Controller.on_model_receive() can later parse once the clients return their results.
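A local sketch of the broadcast step using the standard library's ThreadPoolExecutor. It builds one config per endpoint, submits one task each, and returns the futures. The helper client_run_round and the config keys are hypothetical, and encryption is omitted. A remote executor such as FuncXExecutor may need the target endpoint specified differently. Here we assume only a concurrent.futures-style submit(fn, *args).

```python
from concurrent.futures import ThreadPoolExecutor


def client_run_round(config: dict) -> dict:
    """Hypothetical stand-in for the client function the Executor runs."""
    # A real client would train here; this just echoes part of its config.
    return {"endpoint": config["endpoint"], "round": config["round"]}


def on_model_broadcast(executor, endpoint_ids: list, weights: list,
                       round_num: int) -> list:
    """Build one config per endpoint and submit one task per endpoint."""
    futures = []
    for endpoint in endpoint_ids:
        config = {
            "endpoint": endpoint,
            "weights": list(weights),  # copy so clients cannot share a list
            "round": round_num,
        }
        # submit() returns a Future right away; results arrive later.
        futures.append(executor.submit(client_run_round, config))
    return futures


with ThreadPoolExecutor(max_workers=2) as pool:
    tasks = on_model_broadcast(pool, ["ep-a", "ep-b"], [0.0, 0.0], round_num=1)
    print([task.result()["endpoint"] for task in tasks])  # ['ep-a', 'ep-b']
```

Returning the futures instead of their results lets the Controller submit to every client before waiting on any of them.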
3. Receiving Model on Client
Once clients receive the model and the config with the necessary parameters, Client.on_model_receive() is responsible for initial actions such as decrypting the data if it is encrypted.
4. Data Retrieval
Client.on_data_retrieve() is where clients retrieve and prepare their data for training.
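A hypothetical on_data_retrieve that takes a client's shard of a shared list of (x, y) pairs, subsamples it, and splits features from labels. Round-robin sharding by client_index and the num_samples parameter are illustrative assumptions. A real client would typically load data that already lives on its own endpoint.

```python
import random


def on_data_retrieve(dataset: list, client_index: int, num_clients: int,
                     num_samples: int, seed: int = 0) -> tuple:
    """Take this client's shard, subsample it, and split x from y."""
    # Round-robin sharding keeps shards disjoint across clients.
    shard = dataset[client_index::num_clients]
    rng = random.Random(seed)  # seeded so the subsample is reproducible
    sample = rng.sample(shard, min(num_samples, len(shard)))
    x = [features for features, _ in sample]
    y = [label for _, label in sample]
    return x, y
```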
5. Local Model Training
Client.on_model_fit() is where the training process is defined and executed.
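As a dependency-free stand-in for a real framework's training loop, this hypothetical on_model_fit runs full-batch gradient descent on a one-feature linear model, y = w*x + b, starting from the global weights it received. The model, the mean-squared-error loss, and the hyperparameters are all assumptions for illustration.

```python
def on_model_fit(weights: list, x: list, y: list, lr: float = 0.01,
                 epochs: int = 100) -> list:
    """Full-batch gradient descent on y = w*x + b; returns new [w, b]."""
    w, b = weights  # start from the received global weights
    n = len(x)
    for _ in range(epochs):
        residuals = [w * xi + b - yi for xi, yi in zip(x, y)]
        # Gradients of the mean squared error with respect to w and b.
        grad_w = 2 * sum(r * xi for r, xi in zip(residuals, x)) / n
        grad_b = 2 * sum(residuals) / n
        w -= lr * grad_w
        b -= lr * grad_b
    return [w, b]


xs, ys = [0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0]  # points on y = 2x + 1
new_weights = on_model_fit([0.0, 0.0], xs, ys, lr=0.05, epochs=2000)
print([round(v, 3) for v in new_weights])  # [2.0, 1.0]
```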
6. Model Parameter Submission
Once clients have finished retrieving their local data and updating their copy of the global model, the new model weights are returned to the Controller. Client.on_model_send() handles steps such as encrypting the data before it is sent back.
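This sketch shows where an encryption hook would sit in on_model_send, together with the matching decode step that Step 3 and Step 7 would need on the receiving side. The names and message layout are hypothetical. Base64 is used only as a reversible placeholder: it is an encoding, not encryption, and provides no confidentiality.

```python
import base64
import json


def on_model_send(weights: list, encrypt: bool = False) -> dict:
    """Serialize weights and optionally pass them through a cipher hook."""
    payload = json.dumps(weights).encode("utf-8")
    if encrypt:
        # PLACEHOLDER ONLY: base64 is reversible by anyone.
        # A real deployment would call an actual cipher here.
        payload = base64.b64encode(payload)
    return {"encrypted": encrypt, "payload": payload}


def decode_payload(message: dict) -> list:
    """Inverse of on_model_send, as the receiving side would apply it."""
    payload = message["payload"]
    if message["encrypted"]:
        payload = base64.b64decode(payload)
    return json.loads(payload.decode("utf-8"))
```

Keeping the encode and decode steps as a matched pair makes it easy to swap the placeholder for a real cipher in one place.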
7. Receiving Model on Controller
Once the Controller receives the results back from the clients, Controller.on_model_receive() parses the results and decrypts them if necessary.
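A minimal sketch of the Controller's receive step, assuming each client task returns a dict with hypothetical "weights" and "num_samples" keys. It waits on the futures produced by the broadcast step and regroups the results into parallel lists for aggregation. The fake_client helper exists only to produce example futures.

```python
from concurrent.futures import ThreadPoolExecutor


def on_model_receive(tasks: list) -> dict:
    """Collect every client's result and split it into parallel lists."""
    # result() blocks until each task finishes and re-raises client errors.
    results = [task.result() for task in tasks]
    # Decryption of each result would go here if Step 6 encrypted it.
    return {
        "weights": [r["weights"] for r in results],
        "num_samples": [r["num_samples"] for r in results],
    }


def fake_client(weights: list, num_samples: int) -> dict:
    """Hypothetical client result shape: weights plus local sample count."""
    return {"weights": weights, "num_samples": num_samples}


with ThreadPoolExecutor(max_workers=2) as pool:
    tasks = [pool.submit(fake_client, [1.0, 2.0], 10),
             pool.submit(fake_client, [3.0, 4.0], 30)]
    parsed = on_model_receive(tasks)
print(parsed["num_samples"])  # [10, 30]
```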
8. Model Aggregation
Controller.on_model_aggregate() takes the parsed results from Controller.on_model_receive() and aggregates the weights from the endpoints.
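The text does not say which aggregation rule is used. As one common choice, this sketch assumes FedAvg-style aggregation: a per-parameter average of the client weight vectors, weighted by each client's number of training samples. Flat lists stand in for framework tensors.

```python
def on_model_aggregate(weights: list, num_samples: list) -> list:
    """FedAvg-style average of client weight vectors, weighted by samples."""
    total = sum(num_samples)
    if total == 0:
        raise ValueError("no training samples reported")
    num_params = len(weights[0])
    if any(len(w) != num_params for w in weights):
        raise ValueError("all clients must return the same number of weights")
    averaged = [0.0] * num_params
    for client_weights, n in zip(weights, num_samples):
        for i, value in enumerate(client_weights):
            # Each client contributes in proportion to its share of samples.
            averaged[i] += value * n / total
    return averaged


print(on_model_aggregate([[1.0, 2.0], [3.0, 4.0]], [10, 30]))  # [2.5, 3.5]
```

Weighting by sample count means a client that trained on more data pulls the global model further. With equal counts, this reduces to a plain mean.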
9. Model Updating
Controller.on_model_update() simply takes the new weights from Controller.on_model_aggregate() and assigns them to the global model.
10. Model Evaluation
Finally, Controller.on_model_evaluate() evaluates the model using a user-provided testing dataset and reports the results. The entire loop from Step 2 to Step 10 is then repeated for as many rounds as the user specified.
More on Controllers and Clients
Each Controller has a .run_federated_learning() method that iteratively calls each controller method to carry out the Federated Learning rounds. Each Client has a .run_round() method that calls its client methods to carry out a single FL round. This .run_round() method is the function that gets submitted to the Executor, and it should return the updated model weights.
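The two driver methods can be sketched end to end with hypothetical ToyClient and ToyController classes running on a local ThreadPoolExecutor. This is a sketch under simplifying assumptions. "Training" here just moves each weight halfway toward the mean of the client's local data, and evaluation only records the weights. The point is the call order: Step 1 runs once, then Steps 2 to 10 repeat each round. Steps 3 to 6 happen inside run_round, which is what the Executor runs.

```python
from concurrent.futures import ThreadPoolExecutor


class ToyClient:
    """Hypothetical client: run_round chains Steps 3-6 for a single round."""

    def __init__(self, target: float, num_samples: int):
        self.target, self.num_samples = target, num_samples

    def on_model_receive(self, config):
        return config["weights"]  # nothing to decrypt in this toy

    def on_data_retrieve(self, config):
        return [self.target] * self.num_samples  # stand-in local dataset

    def on_model_fit(self, weights, data):
        # Toy "training": move every weight halfway toward the data mean.
        mean = sum(data) / len(data)
        return [w + 0.5 * (mean - w) for w in weights]

    def on_model_send(self, weights):
        return {"weights": weights, "num_samples": self.num_samples}

    def run_round(self, config):
        weights = self.on_model_receive(config)
        data = self.on_data_retrieve(config)
        return self.on_model_send(self.on_model_fit(weights, data))


class ToyController:
    """Hypothetical controller: run_federated_learning drives every round."""

    def __init__(self, clients, executor, num_rounds):
        self.clients, self.executor, self.num_rounds = clients, executor, num_rounds

    def on_model_init(self):
        self.weights, self.history = [0.0], []

    def on_model_broadcast(self):
        config = {"weights": list(self.weights)}
        return [self.executor.submit(c.run_round, config) for c in self.clients]

    def on_model_receive(self, tasks):
        return [t.result() for t in tasks]

    def on_model_aggregate(self, results):
        total = sum(r["num_samples"] for r in results)
        return [sum(r["weights"][i] * r["num_samples"] for r in results) / total
                for i in range(len(self.weights))]

    def on_model_update(self, new_weights):
        self.weights = new_weights

    def on_model_evaluate(self):
        # A real controller would score a test set; this records the weights.
        self.history.append(list(self.weights))

    def run_federated_learning(self):
        self.on_model_init()                 # Step 1 runs once
        for _ in range(self.num_rounds):     # Steps 2-10 repeat each round
            tasks = self.on_model_broadcast()
            results = self.on_model_receive(tasks)
            self.on_model_update(self.on_model_aggregate(results))
            self.on_model_evaluate()
        return self.weights


with ThreadPoolExecutor(max_workers=2) as pool:
    controller = ToyController([ToyClient(1.0, 10), ToyClient(3.0, 30)],
                               pool, num_rounds=3)
    final = controller.run_federated_learning()
print(final)  # [2.1875]
```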
We also use Model Trainers to handle Machine Learning (ML)-related computations. Each Model Trainer should implement four methods, .fit(), .evaluate(), .set_weights(), and .get_weights(), which both the Controller and the Client call to fit the model, evaluate it, set new weights, and get the current weights, respectively.
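The four-method contract can be sketched as an abstract class plus a trivial implementation. ModelTrainer and MeanTrainer are hypothetical names, the (x, y) parameters are assumptions, and the toy model simply predicts the mean label so that the example has no framework dependency.

```python
from abc import ABC, abstractmethod


class ModelTrainer(ABC):
    """Sketch of the four-method trainer interface described above."""

    @abstractmethod
    def fit(self, x, y): ...

    @abstractmethod
    def evaluate(self, x, y): ...

    @abstractmethod
    def get_weights(self): ...

    @abstractmethod
    def set_weights(self, weights): ...


class MeanTrainer(ModelTrainer):
    """Toy trainer whose single 'weight' is a constant prediction."""

    def __init__(self):
        self.weights = [0.0]

    def fit(self, x, y):
        self.weights = [sum(y) / len(y)]  # best constant under squared error

    def evaluate(self, x, y):
        pred = self.weights[0]
        return sum((pred - yi) ** 2 for yi in y) / len(y)  # mean squared error

    def get_weights(self):
        return list(self.weights)  # copy so callers cannot mutate state

    def set_weights(self, weights):
        self.weights = list(weights)
```

Because the Controller only ever calls these four methods, swapping in a framework-specific trainer should not require changes to the round logic, as long as the method parameters line up.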
We implemented the abstract base classes in flox.logic.base_client.py and flox.logic.base_controller.py, and a base class for ML Model Trainers in flox.logic.base_model_trainer.py.
To handle most of the FL Controller-side computations, we implemented MainController under flox/controllers.
Initially, we had full controller implementations for each ML framework (TensorFlow, PyTorch). However, this led to a lot of code duplication, and the differences between the requirements of different ML frameworks were small. Thus, we moved most of the shared functionality into MainController and left only a few methods that need to be extended for specific ML frameworks.
For example, because ML framework-specific clients require different variables for training (e.g., our implementation of the TensorFlow training loop requires input_shape, while PyTorch does not), we created the create_config() method, which should return a dictionary of the variables that the framework-specific Client needs for training. The framework-specific Model Trainers may also differ in which parameters their methods (such as .evaluate() and .set_weights()) accept. Users can therefore override MainController's default implementations of on_model_evaluate() and on_model_update(), which call those Model Trainer methods, to pass different parameters.
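The create_config() pattern can be sketched with two hypothetical stand-ins, BaseControllerSketch and TensorflowControllerSketch. These are not the real MainController or TensorflowController. The shared broadcast logic calls one small hook, and a framework-specific subclass extends only that hook to add input_shape. The other config keys are assumptions.

```python
class BaseControllerSketch:
    """Hypothetical stand-in: shared logic here, one small overridable hook."""

    def __init__(self, endpoint_ids: list, num_samples: int):
        self.endpoint_ids, self.num_samples = endpoint_ids, num_samples

    def create_config(self, weights) -> dict:
        # Variables that every framework-specific client needs.
        return {"weights": weights, "num_samples": self.num_samples}

    def build_configs(self, weights) -> list:
        # Shared broadcast logic: call the hook once per endpoint.
        return [dict(self.create_config(weights), endpoint=e)
                for e in self.endpoint_ids]


class TensorflowControllerSketch(BaseControllerSketch):
    """Adds input_shape, which the text says the TensorFlow loop requires."""

    def __init__(self, endpoint_ids: list, num_samples: int, input_shape: tuple):
        super().__init__(endpoint_ids, num_samples)
        self.input_shape = input_shape

    def create_config(self, weights) -> dict:
        config = super().create_config(weights)
        config["input_shape"] = self.input_shape
        return config
```

Only the hook differs between frameworks, so build_configs and everything that calls it can be written once in the shared base class.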
For a concrete example, first look at MainController, and then see how PyTorchController and TensorflowController extend and override MainController differently to be compatible with their corresponding clients and Model Trainers.
We provide practical examples built on these classes to illustrate how all of these steps come together:
flox.examples.quickstart_pytorch uses PyTorchController, PyTorchClient, and PyTorchTrainer to run a Federated Learning workflow on PyTorch.
flox.examples.quickstart_tensorflow uses TensorflowController, TensorflowClient, and TensorflowTrainer to run a Federated Learning workflow on TensorFlow.
Issues & Points of Improvement
MainController reduces code duplication on the Controller side, but the same is not true on the Client side. For example, consider PyTorchClient and TensorflowClient under flox/clients. All of their core methods differ in implementation and in the parameters they accept, so their .run_round() methods differ as well. If we wanted to start timing how long each function takes to run using time.time(), we would need to add that code to both PyTorchClient and TensorflowClient, which complicates code maintenance. It would be nice to have a MainClient that provides more structure for implementing and maintaining Clients, but it is not clear how to do so, since the existing Clients share very little in their implementation.
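One possible direction, offered only as a hypothetical sketch: give every client step the same signature by passing framework-specific parameters through the config dict. A single MainClientSketch.run_round can then be written once and time each step. MainClientSketch and EchoClient are illustrative names, not FLoX classes. The sketch uses time.perf_counter() rather than time.time(), because perf_counter is intended for measuring elapsed intervals.

```python
import time
from abc import ABC, abstractmethod


class MainClientSketch(ABC):
    """Hypothetical template: run_round is written once, steps are overridden."""

    def __init__(self):
        self.timings = {}

    def _timed(self, name, fn, *args):
        start = time.perf_counter()
        result = fn(*args)
        self.timings[name] = time.perf_counter() - start  # seconds
        return result

    def run_round(self, config: dict):
        # Framework-specific values travel inside config, so every
        # subclass shares these four signatures.
        model = self._timed("on_model_receive", self.on_model_receive, config)
        data = self._timed("on_data_retrieve", self.on_data_retrieve, config)
        model = self._timed("on_model_fit", self.on_model_fit, model, data, config)
        return self._timed("on_model_send", self.on_model_send, model, config)

    @abstractmethod
    def on_model_receive(self, config): ...

    @abstractmethod
    def on_data_retrieve(self, config): ...

    @abstractmethod
    def on_model_fit(self, model, data, config): ...

    @abstractmethod
    def on_model_send(self, model, config): ...


class EchoClient(MainClientSketch):
    """Trivial subclass used only to exercise the template."""

    def on_model_receive(self, config):
        return list(config["weights"])

    def on_data_retrieve(self, config):
        return config.get("data", [])

    def on_model_fit(self, model, data, config):
        return [w + len(data) for w in model]

    def on_model_send(self, model, config):
        return {"weights": model}
```

The trade-off is that config becomes an untyped bag of parameters. However, cross-cutting concerns such as timing or logging then live in one run_round instead of one per framework.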
Another point of concern is the coupling between the Clients, Model Trainers, and Controllers for each ML framework, since they sometimes require different parameters. This coupling complicates system management for the user, who needs to implement or extend three classes to run an FL experiment on a new ML framework. I have wondered whether Model Trainer needs to be a class at all, and whether using plain functions would make it less complicated. However, having it as a class makes it easier to track variables and keeps all ML-related functions and variables together.