Aggregator API Specs
The Aggregator API (i.e. `fl_aggregator`) is the central component in the federated learning process with SyftBox. This API orchestrates the federated learning workflow by initiating the computation and managing updates across distributed clients. During each training round, it collects local model updates, applies an aggregation strategy (e.g. federated averaging), and produces an improved global model that is redistributed to all clients.
Federated Averaging (`FedAvg`) is the most common method for combining client updates in federated learning. At its core, it computes a weighted average of the model updates received from clients, where each client's weight typically reflects the size of its training dataset. This simple yet effective approach lets the global model benefit from all clients' contributions while accounting for differences in their data distributions.
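To make the weighted average concrete, here is a minimal, framework-free sketch of the FedAvg computation. Plain Python lists stand in for model tensors, and the function name and data layout are illustrative only, not the aggregator's actual API:

```python
def fedavg(client_updates, client_sizes):
    """Weighted average of client model parameters (FedAvg sketch).

    client_updates: list of dicts mapping parameter name -> list of floats
                    (one dict per client, all with the same shape)
    client_sizes:   list of training-set sizes, one per client, used as weights
    """
    total = sum(client_sizes)
    aggregated = {}
    for name in client_updates[0]:
        # For each parameter entry, average across clients weighted by dataset size
        aggregated[name] = [
            sum(update[name][i] * size for update, size in zip(client_updates, client_sizes)) / total
            for i in range(len(client_updates[0][name]))
        ]
    return aggregated
```

In a real PyTorch setting the same weighted sum would run over `state_dict` tensors instead of lists, but the arithmetic is identical.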
Source Code
Find the complete `fl_aggregator` API code and some sample data to start with in our GitHub repository.
Project Structure
The `fl_aggregator` project has the following structure:
```
fl_aggregator/
├── dashboard/           # HTML + asset files of the dashboard to monitor the FL workflow
├── main.py              # Main application logic
├── requirements.txt     # Python package dependencies
├── run.sh               # Setup and run script
├── samples/
│   ├── launch_config/   # Configuration files for launching the FL flow
│   │   ├── fl_config.json
│   │   ├── global_model_weight.pt   # Global model parameters
│   │   └── model.py     # ML model architecture (shared with all clients)
│   └── test_data/       # Test data for global model evaluation
│       └── mnist_test_dataset.pt
└── utils.py             # Collection of library/utility functions used in main.py
```
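For orientation, a launch configuration might look roughly like the sketch below. The `model_arch`, `model_weight`, `model_class_name`, and `test_dataset` keys appear in the code excerpts later in this document; the remaining keys and all values are illustrative guesses, so consult the sample `fl_config.json` in the repository for the authoritative schema:

```json
{
  "project_name": "mnist_fl",
  "participants": ["client-a@example.com", "client-b@example.com"],
  "rounds": 3,
  "model_arch": "model.py",
  "model_class_name": "FLModel",
  "model_weight": "global_model_weight.pt",
  "test_dataset": "mnist_test_dataset.pt"
}
```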
Required Dependencies
- Python (version 3.12)
- SyftBox
- PyTorch
The local virtual environment is created automatically by the run script (more details here).
Running the API locally in Development Mode
It is also possible to run the API in development mode using the SyftBox local development server. For further information and instructions, please check this guide.
API Workflow
The `fl_aggregator` API operates in three main phases:
1. **Initialization**
   - Establish shared directories for model and data exchange
   - Configure client access permissions and security settings
2. **Setup and Distribution**
   - Share the federated learning configuration with clients
   - Distribute model architecture specifications
   - Deploy initial model weights
   - Launch the first training round
3. **Training Coordination**
   - Collect and aggregate local model updates from clients
   - Apply the chosen aggregation strategy (e.g. `FedAvg`)
   - Update the global model parameters
   - Evaluate model performance on private test data
   - Trigger the next training round by distributing the updated global model
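The training coordination phase can be summarized as a hypothetical round loop. The callables below are stand-ins for the aggregator's real helpers (which are shown later in this document), not part of its API:

```python
def run_training_rounds_sketch(num_rounds, collect_updates, aggregate, evaluate, distribute):
    """Illustrative outline of the coordination loop (phase 3).

    collect_updates(round_num) -> client updates for that round
    aggregate(updates)         -> new global model (e.g. via FedAvg)
    evaluate(model)            -> accuracy on private test data
    distribute(model, round_num) -> share the model, triggering the next round
    """
    for round_num in range(1, num_rounds + 1):
        updates = collect_updates(round_num)   # wait for all clients to report
        global_model = aggregate(updates)      # combine updates into a global model
        accuracy = evaluate(global_model)      # score on the aggregator's private test set
        print(f"Round {round_num}: accuracy={accuracy}")
        distribute(global_model, round_num)    # fan the model back out to clients
```

The real aggregator is event-driven (it reacts to files appearing in shared folders) rather than a blocking loop, but the per-round sequencing is the same.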
Implementation Details
Let's dive into the core implementation details of each phase. While we'll highlight the key code segments here, you can find the complete source code in our GitHub repository.
1. Initializing the API directories
The `init_aggregator` function sets up the required directory structure for the Aggregator:

- Creates three subdirectories within the `fl_aggregator` `api_data` directory:
  - `launch`: for initial setup files
  - `running`: for active training data
  - `done`: for completed training data
- Creates a private directory to store the test data used for model evaluation.
```python
def init_aggregator(client: Client) -> None:
    fl_aggregator = client.api_data("fl_aggregator")
    for folder in ["launch", "running", "done"]:
        fl_aggregator_folder = fl_aggregator / folder
        fl_aggregator_folder.mkdir(parents=True, exist_ok=True)

    # Create the private data directory for the app
    # This is where the private test data will be stored
    app_pvt_dir = get_app_private_data(client, "fl_aggregator")
    app_pvt_dir.mkdir(parents=True, exist_ok=True)
```
2. Starting the Federated Learning workflow
The following code snippet shows how the aggregator initializes the project directories.
```python
def init_project_directory(client: Client, fl_config_json_path: Path) -> None:
    # Read the fl_config.json file
    # [...] omitted

    # Create the project folder
    fl_clients_folder = proj_folder / "fl_clients"
    agg_weights_folder = proj_folder / "agg_weights"
    # ...

    # Create the folders for the participants
    for participant in participants:
        participant_folder = fl_clients_folder / participant
        participant_folder.mkdir(parents=True, exist_ok=True)
        # Give the participant write access to the project folder
        add_public_write_permission(client, participant_folder)

    # Move the config file to the project's running folder
    shutil.move(fl_config_json_path, proj_folder)

    # Move the model architecture to the project's running folder
    model_arch_src = fl_aggregator / "launch" / fl_config["model_arch"]
    shutil.move(model_arch_src, proj_folder)

    # Copy the global model weights to the project's agg_weights folder as `agg_model_round_0.pt`
    # and move the global model weights to the project's running folder
    model_weights_src = fl_aggregator / "launch" / fl_config["model_weight"]
    shutil.copy(model_weights_src, agg_weights_folder / "agg_model_round_0.pt")
    shutil.move(model_weights_src, proj_folder)

    # Initialise the metrics dashboard for the FL project
    create_metrics_dashboard(client, fl_config, participants, proj_name)
```
3. Model Aggregation and Evaluation
The following code snippet captures the logic for aggregating the model updates and evaluating the global model. The rest of the implementation details can be found in the full code linked below.
```python
def aggregate_and_evaluate(client: Client, proj_folder: Path):
    """
    1. Wait for the trained models from the clients
    2. Aggregate the trained models and place the result in the `agg_weights` folder
    3. Send the aggregated model to all the clients
    4. Repeat until all the rounds are complete
    """
    # [...] omitted

    # Aggregate the trained models (using FedAvg)
    agg_model_output_path = aggregate_model(
        fl_config, proj_folder, trained_model_paths, current_round
    )

    # Test dataset for model evaluation
    test_dataset_dir = get_app_private_data(client, "fl_aggregator")
    test_dataset_path = test_dataset_dir / fl_config["test_dataset"]

    # Load the updated global model
    model_class = load_model_class(
        proj_folder / fl_config["model_arch"], fl_config["model_class_name"]
    )
    model: nn.Module = model_class()
    model.load_state_dict(torch.load(str(agg_model_output_path), weights_only=True))

    # Evaluate the global model
    accuracy = evaluate_agg_model(model, test_dataset_path)
    print(f"Accuracy of the aggregated model for round {current_round}: {accuracy}")

    # Save the model accuracy metrics
    save_model_accuracy_metrics(client, proj_folder, current_round, accuracy)

    # Send the aggregated model to all the clients
    share_agg_model_to_peers(client, proj_folder, agg_model_output_path, participants)
```
Full implementation details, with reference to the specific inner functions used in `aggregate_and_evaluate(client, proj_folder)`:

- `aggregate_model`: implements the federated aggregation strategy
- `evaluate_agg_model`: runs model evaluation on test data
- `save_model_accuracy_metrics`: saves model performance metrics
- `share_agg_model_to_peers`: shares model updates across the distributed datasite clients
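To illustrate the last step, here is a simplified, hypothetical sketch of the model fan-out. The real `share_agg_model_to_peers` relies on SyftBox's sync and permission machinery; this sketch (its name, signature, and folder layout are assumptions based on the directory structure shown earlier) only shows the file copy into each participant's folder:

```python
import shutil
from pathlib import Path


def share_agg_model_sketch(proj_folder: Path, agg_model_path: Path, participants: list) -> None:
    """Copy the aggregated model into each participant's project folder.

    Illustrative only: SyftBox then syncs each folder to the matching
    datasite, which is what actually triggers the next training round.
    """
    for participant in participants:
        dest = proj_folder / "fl_clients" / participant / agg_model_path.name
        dest.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(agg_model_path, dest)
```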
Privacy & Security Considerations
The aggregator's access to raw model updates creates potential privacy risks through model inversion attacks, where private training data could be inferred from updates. This can be mitigated using techniques like