Client API Specs

The Client API (i.e. fl_client) is responsible for training a machine learning model on local data while ensuring data privacy. While the model updates are shared with the aggregator, the training data never leaves the Datasite.

In this tutorial, we will consider an example federated learning experiment to train a neural network classifier on subsets of the MNIST dataset.

Source Code

Find the complete fl_client API code and some sample data to start with in our GitHub repository.

Project Structure

fl_client/
├── mnist_samples/ # Private dataset (MNIST subsets)
│ └── mnist_label_*.pt
├── .gitignore # Git ignore rules
├── .python-version # Python version specification
├── main.py # Main application logic
├── requirements.txt # Python package dependencies
├── run.sh # Setup and run script
└── utils.py # Collection of library/utility functions used in main.py

Required Dependencies

  • Python (version 3.12)
  • SyftBox
  • PyTorch

The local virtual environment used by the API is created automatically by the run script, run.sh (more details here).

API Workflow

The execution of the fl_client API follows these main steps:

  1. Initialization
  • Set up shared directories
  • Configure access permissions
  2. Model Training
  • Load the global model
  • Train on private local data
  3. Update Sharing
  • Save model updates for aggregation (automatically shared with the Aggregator via SyftBox)
  • Await the next round of training

Implementation Details

In this section we'll explore the key code components of each step. While we'll focus on the most important snippets here, you can find the complete implementation in the GitHub repository.

1. Directory Initialization

First, we set up the shared directories that enable communication between client and aggregator. Here's the core initialization code:

```python
def init_shared_dirs(client: Client, proj_folder: Path) -> None:
    round_weights_folder = proj_folder / "round_weights"
    agg_weights_folder = proj_folder / "agg_weights"

    round_weights_folder.mkdir(parents=True, exist_ok=True)
    agg_weights_folder.mkdir(parents=True, exist_ok=True)

    # Allow the aggregator to write aggregated weights into this folder
    add_public_write_permission(client, agg_weights_folder)

    create_project_state(client, proj_folder)
```

View full implementation

2. Local Model Training

Each client Datasite trains the model for one federated learning round using a common machine learning algorithm. This shared algorithm is provided by the aggregator, and ensures consistency across all participating clients. After training completes, the model weights are stored privately in the client's datasite.

```python
def train_model(proj_folder: Path, round_num: int, dataset_files: list[Path]) -> None:
    # [...] omitted

    # Load the ML model class from the configured Python module (as instructed by the Aggregator)
    model_class = load_model_class(proj_folder / fl_config["model_arch"])
    model: nn.Module = model_class()

    # Load the aggregated weights from the previous round
    agg_weights_file = agg_weights_folder / f"agg_model_round_{round_num - 1}.pt"
    model.load_state_dict(torch.load(agg_weights_file, weights_only=True))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=fl_config["learning_rate"])

    # [...] Training loop here, using data loaded from the input `dataset_files` paths

    torch.save(model.state_dict(), proj_folder / f"trained_model_round_{round_num}.pt")
```

View full implementation

3. Sharing Model Updates

After model training, the client transfers its model updates to the aggregator by serialising them into the appropriate path within the aggregator's datasite (i.e. <aggregator_datasite>/api_data/fl_aggregator/running/<project_name>/fl_clients/<client_datasite>). Each client's updates are stored separately, allowing the aggregator to track contributions from individual participants.

```python
def share_model_to_aggregator(client: Client, aggregator_email: str, proj_folder: Path, model_file: Path):
    """Shares the trained model to the aggregator."""
    fl_aggregator_app_path = (
        client.datasites / f"{aggregator_email}/api_data/fl_aggregator"
    )
    fl_aggregator_running_folder = fl_aggregator_app_path / "running" / proj_folder.name
    fl_aggregator_client_path = (
        fl_aggregator_running_folder / "fl_clients" / client.email
    )

    # Copy the trained model to the aggregator's client folder
    shutil.copy(model_file, fl_aggregator_client_path)
```

Privacy Considerations

The workflow's design prioritizes data privacy by sharing only model updates while keeping raw data secure on each Datasite. However, sharing model weights alone does not guarantee complete privacy, as they can be vulnerable to attacks by malicious users (e.g. Model Inversion Attacks).

To achieve stronger privacy, additional safeguards can be implemented. For example, incorporating differential privacy during training helps protect against model inversion attacks by adding calibrated noise to the learning process.