Client API Specs
The Client API (i.e. `fl_client`) is responsible for training a machine learning model on local data while ensuring data privacy.
While the model updates are shared with the aggregator, the training data never leaves the datasite.
In this tutorial, we will consider an example federated learning experiment to train a neural network classifier on subsets of the MNIST dataset.
Source Code
Find the complete `fl_client` API code and some sample data to start with in our GitHub repository.
Project Structure
```
fl_client/
├── mnist_samples/     # Private dataset (MNIST subsets)
│   └── mnist_label_*.pt
├── .gitignore         # Git ignore rules
├── .python-version    # Python version specification
├── main.py            # Main application logic
├── requirements.txt   # Python package dependencies
├── run.sh             # Setup and run script
└── utils.py           # Utility functions used in main.py
```
Required Dependencies
- Python (version 3.12)
- SyftBox
- PyTorch
The local virtual environment used by the API is created automatically by the run script (`run.sh`).
API Workflow
The execution of the `fl_client` API follows these main steps:
- Initialization
  - Set up shared directories
  - Configure access permissions
- Model Training
  - Load the global model
  - Train on private local data
- Update Sharing
  - Save model updates for aggregation (automatically shared with the Aggregator via SyftBox)
  - Await the next round of training
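The round lifecycle above can be sketched as a small decision function: depending on which files exist, the client either waits for the aggregator, trains, or shares its update. The file names mirror those used by `fl_client`, but `next_action` itself is illustrative and not part of the actual API:

```python
from pathlib import Path
import tempfile

def next_action(proj_folder: Path, round_num: int) -> str:
    """Decide what the client should do for the given round (illustrative sketch)."""
    agg_weights = proj_folder / "agg_weights" / f"agg_model_round_{round_num - 1}.pt"
    trained = proj_folder / f"trained_model_round_{round_num}.pt"
    if not agg_weights.exists():
        return "await_aggregator"  # previous round's global model not yet shared
    if not trained.exists():
        return "train"             # global model available, local update missing
    return "share_update"          # local update ready to copy to the aggregator

proj = Path(tempfile.mkdtemp())
(proj / "agg_weights").mkdir(parents=True)
print(next_action(proj, 1))  # -> "await_aggregator"
```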
Implementation Details
In this section we'll explore the key code components of each step. While we'll focus on the most important snippets here, you can find the complete implementation in the GitHub repository.
1. Directory Initialization
First, we set up the shared directories that enable communication between client and aggregator. Here's the core initialization code:
```python
def init_shared_dirs(client: Client, proj_folder: Path) -> None:
    # Folders for this round's local updates and the aggregated global weights
    round_weights_folder = proj_folder / "round_weights"
    agg_weights_folder = proj_folder / "agg_weights"
    round_weights_folder.mkdir(parents=True, exist_ok=True)
    agg_weights_folder.mkdir(parents=True, exist_ok=True)

    # Allow the aggregator to write the aggregated weights into agg_weights
    add_public_write_permission(client, agg_weights_folder)
    create_project_state(client, proj_folder)
```
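The folder layout this creates can be exercised in isolation, without SyftBox. The sketch below keeps only the `mkdir` logic, dropping the SyftBox-specific permission and project-state calls (the project name is made up for illustration):

```python
from pathlib import Path
import tempfile

def init_shared_dirs_local(proj_folder: Path) -> None:
    # Same layout as init_shared_dirs, minus the SyftBox-specific calls
    (proj_folder / "round_weights").mkdir(parents=True, exist_ok=True)
    (proj_folder / "agg_weights").mkdir(parents=True, exist_ok=True)

proj = Path(tempfile.mkdtemp()) / "mnist_fl_project"  # hypothetical project name
init_shared_dirs_local(proj)
print(sorted(p.name for p in proj.iterdir()))  # -> ['agg_weights', 'round_weights']
```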
2. Local Model Training
Each client Datasite trains the model for one federated learning round using a common machine learning algorithm. This shared algorithm is provided by the aggregator and ensures consistency across all participating clients. After training completes, the model weights are stored privately in the client's datasite.
```python
def train_model(proj_folder: Path, round_num: int, dataset_files: list[Path]) -> None:
    # [...] omitted

    # Load the ML model class from the Python module configured by the Aggregator
    model_class = load_model_class(proj_folder / fl_config["model_arch"])
    model: nn.Module = model_class()

    # Load the aggregated weights from the previous round
    agg_weights_file = agg_weights_folder / f"agg_model_round_{round_num - 1}.pt"
    model.load_state_dict(torch.load(agg_weights_file, weights_only=True))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=fl_config["learning_rate"])

    # [...] Training loop here, reading samples from the input `dataset_files` paths

    torch.save(model.state_dict(), proj_folder / f"trained_model_round_{round_num}.pt")
```
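The round structure (load previous weights, train, save under a round-numbered name) can be illustrated without PyTorch. This framework-agnostic sketch fits a toy 1-D model `y = w * x` by gradient descent, using JSON files in place of `.pt` checkpoints; the file names are hypothetical stand-ins for those in the snippet above:

```python
import json
import tempfile
from pathlib import Path

def train_round(proj_folder: Path, round_num: int,
                data: list[tuple[float, float]]) -> float:
    """Toy stand-in for train_model: fit y = w * x by gradient descent."""
    # Load the aggregated weight from the previous round (JSON instead of .pt)
    agg_file = proj_folder / "agg_weights" / f"agg_model_round_{round_num - 1}.json"
    w = json.loads(agg_file.read_text())["w"]

    lr = 0.1
    for _ in range(100):
        # Gradient of the mean squared error with respect to w
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad

    # Save this round's trained weight for the aggregator to pick up
    out = proj_folder / f"trained_model_round_{round_num}.json"
    out.write_text(json.dumps({"w": w}))
    return w

proj = Path(tempfile.mkdtemp())
(proj / "agg_weights").mkdir()
(proj / "agg_weights" / "agg_model_round_0.json").write_text(json.dumps({"w": 0.0}))
w = train_round(proj, 1, [(1.0, 2.0), (2.0, 4.0)])
print(round(w, 4))  # -> 2.0, the true slope of the toy data
```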
3. Sharing Model Updates
After model training, the client transfers model updates to the aggregator, serialised into the appropriate paths within the datasite (i.e. `<aggregator_datasite>/api_data/fl_aggregator/running/<project_name>/fl_clients/<client_datasite>`).
Each client's updates are stored separately, allowing the aggregator to track contributions from individual participants.
```python
def share_model_to_aggregator(
    client: Client, aggregator_email: str, proj_folder: Path, model_file: Path
) -> None:
    """Shares the trained model to the aggregator."""
    fl_aggregator_app_path = (
        client.datasites / f"{aggregator_email}/api_data/fl_aggregator"
    )
    fl_aggregator_running_folder = fl_aggregator_app_path / "running" / proj_folder.name
    fl_aggregator_client_path = (
        fl_aggregator_running_folder / "fl_clients" / client.email
    )

    # Copy the trained model to the aggregator's client folder
    shutil.copy(model_file, fl_aggregator_client_path)
```
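The destination path assembled above can be checked in isolation with plain `pathlib`. The root directory, email addresses, and project name below are made up for illustration:

```python
from pathlib import Path

datasites = Path("/tmp/datasites")           # stand-in for client.datasites
aggregator_email = "aggregator@example.com"  # hypothetical
client_email = "client_a@example.com"        # hypothetical
proj_name = "mnist_fl"                       # hypothetical project name

dest = (
    datasites
    / aggregator_email / "api_data" / "fl_aggregator"
    / "running" / proj_name
    / "fl_clients" / client_email
)
print(dest)
# -> /tmp/datasites/aggregator@example.com/api_data/fl_aggregator/running/mnist_fl/fl_clients/client_a@example.com
```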
Privacy Considerations
The workflow's design prioritizes data privacy by sharing only model updates while keeping raw data secure on each Datasite. However, sharing model weights alone does not guarantee complete privacy, as they can be vulnerable to attacks by malicious users (e.g. model inversion attacks).
To achieve stronger privacy, additional safeguards can be implemented. For example, incorporating differential privacy during training helps protect against model inversion attacks by adding calibrated noise to the learning process.
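As a rough illustration of that idea, the DP-SGD recipe clips each example's gradient to a maximum L2 norm and adds Gaussian noise before averaging. The sketch below shows only that clipping-and-noising step in plain Python; the function names and parameter values are illustrative and not calibrated to any particular privacy budget:

```python
import math
import random

def clip_gradient(grad: list[float], max_norm: float) -> list[float]:
    """Scale the gradient down so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [g * scale for g in grad]

def privatize(grads: list[list[float]], max_norm: float, noise_std: float,
              rng: random.Random) -> list[float]:
    """Clip per-example gradients, sum them, add Gaussian noise, then average."""
    clipped = [clip_gradient(g, max_norm) for g in grads]
    dim = len(grads[0])
    total = [sum(g[i] for g in clipped) for i in range(dim)]
    noisy = [t + rng.gauss(0.0, noise_std * max_norm) for t in total]
    return [n / len(grads) for n in noisy]
```

Clipping bounds each example's influence on the update, which is what makes the added noise meaningful for privacy.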