Federated Learning
Federated Learning Basics
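Federated learning enables multiple parties to collaboratively train a model without sharing their private data. Each participant trains on its own data and shares only model updates (weights or gradients), which a server aggregates into a shared global model. Because these updates can still leak information about the underlying data, the sections below add privacy and efficiency techniques on top of this basic scheme.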
Federated Averaging (FedAvg)
import numpy as np
from typing import List, Dict
import torch
import torch.nn as nn
class FederatedClient:
    """A participant that trains a local model on data it never shares.

    Args:
        model: The client's ``nn.Module``. Prefer giving each client its own
            instance (e.g. a ``copy.deepcopy`` of the global model).
        data: Input for the local dataset; it is fed to ``model`` in a single
            full-batch forward call.
        labels: Targets for ``nn.CrossEntropyLoss``.
        lr: SGD learning rate (default 0.01).
    """

    def __init__(self, model, data, labels, lr=0.01):
        self.model = model
        self.data = data
        self.labels = labels
        self.optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss()

    def local_train(self, global_weights, epochs=5):
        """Load ``global_weights`` and run ``epochs`` full-batch SGD steps.

        Args:
            global_weights: State dict used to initialise the local model.
            epochs: Number of optimizer steps; each uses the whole local
                dataset as one batch.

        Returns:
            A dict of detached *copies* of the trained state. ``state_dict()``
            returns tensors that share storage with the live parameters, so
            without copying, later training (by this client, or by another
            client sharing the same module) would overwrite weights the server
            has already collected.
        """
        self.model.load_state_dict(global_weights)
        # The module may have been left in eval mode elsewhere.
        self.model.train()
        for _ in range(epochs):
            self.optimizer.zero_grad()
            loss = self.criterion(self.model(self.data), self.labels)
            loss.backward()
            self.optimizer.step()
        return {
            name: tensor.detach().clone()
            for name, tensor in self.model.state_dict().items()
        }
class FederatedServer:
    """Coordinates Federated Averaging (FedAvg) rounds over registered clients."""

    def __init__(self, global_model):
        self.global_model = global_model
        self.clients: List["FederatedClient"] = []

    def add_client(self, client: "FederatedClient"):
        """Register ``client`` to participate in every subsequent round."""
        self.clients.append(client)

    def aggregate(self, client_weights: List[Dict], sample_counts=None) -> Dict:
        """Average client state dicts key by key.

        Args:
            client_weights: One state dict per client; all must have the same
                keys and tensor shapes.
            sample_counts: Optional non-negative per-client weights (typically
                local dataset sizes, as in FedAvg). ``None`` gives a plain mean.

        Returns:
            A state dict whose tensors keep the dtype of the first client's
            tensors. Non-floating buffers are averaged in float32 and rounded.

        Raises:
            ValueError: if ``client_weights`` is empty, or ``sample_counts`` has
                the wrong length, contains negatives, or sums to zero.
        """
        if not client_weights:
            raise ValueError("aggregate() needs at least one client state dict")

        coeffs = None
        if sample_counts is not None:
            if len(sample_counts) != len(client_weights):
                raise ValueError(
                    f"got {len(sample_counts)} sample counts for "
                    f"{len(client_weights)} client state dicts"
                )
            coeffs = torch.as_tensor(sample_counts, dtype=torch.float64)
            if (coeffs < 0).any() or coeffs.sum() <= 0:
                raise ValueError("sample_counts must be non-negative with a positive sum")
            coeffs = coeffs / coeffs.sum()

        avg_weights = {}
        for key, reference in client_weights[0].items():
            is_float = reference.is_floating_point()
            work_dtype = reference.dtype if is_float else torch.float32
            stacked = torch.stack([weights[key].to(work_dtype) for weights in client_weights])
            if coeffs is None:
                averaged = stacked.mean(dim=0)
            else:
                # One coefficient per client, broadcast over that client's tensor.
                scale = coeffs.to(device=stacked.device, dtype=work_dtype)
                scale = scale.view(-1, *([1] * (stacked.dim() - 1)))
                averaged = (stacked * scale).sum(dim=0)
            if not is_float:
                averaged = averaged.round()
            avg_weights[key] = averaged.to(reference.dtype)
        return avg_weights

    def federated_round(self, rounds=10):
        """Run ``rounds`` FedAvg rounds and return the updated global model.

        In each round every client trains from the same snapshot of the global
        weights. The results are averaged, weighted by each client's number of
        labels (its local sample count), and loaded into the global model.

        Raises:
            ValueError: if a round is run with no registered clients.
        """
        for round_num in range(rounds):
            # Snapshot once so all clients start identically, even if a
            # client's module aliases the global model.
            global_weights = {
                name: tensor.detach().clone()
                for name, tensor in self.global_model.state_dict().items()
            }
            client_weights = [client.local_train(global_weights) for client in self.clients]
            sample_counts = [len(client.labels) for client in self.clients]
            new_weights = self.aggregate(client_weights, sample_counts)
            self.global_model.load_state_dict(new_weights)
            print(f"Round {round_num + 1} completed")
        return self.global_model
import copy

# NOTE(review): `global_model` (an nn.Module) and `client_datasets` (an
# iterable of (data, labels) pairs) are placeholders defined elsewhere.
server = FederatedServer(global_model)
for data, labels in client_datasets:
    # Each client gets its own copy of the architecture. Sharing one module
    # makes every client's returned state_dict alias the same tensors.
    server.add_client(FederatedClient(copy.deepcopy(global_model), data, labels))
final_model = server.federated_round(rounds=5)
Differential Privacy
import torch
class DPFederatedClient:
    """Client that privatises its update by gradient clipping plus Gaussian noise.

    Args:
        model: The client's ``nn.Module``.
        data: Input for the local dataset (used as one full batch).
        labels: Targets for ``nn.CrossEntropyLoss``.
        epsilon: Privacy parameter epsilon; must be > 0.
        delta: Privacy parameter delta; must lie in (0, 1).

    Raises:
        ValueError: if ``epsilon`` or ``delta`` is out of range.
    """

    def __init__(self, model, data, labels, epsilon=1.0, delta=1e-5):
        if epsilon <= 0:
            raise ValueError(f"epsilon must be positive, got {epsilon}")
        if not 0 < delta < 1:
            raise ValueError(f"delta must be in (0, 1), got {delta}")
        self.model = model
        self.data = data
        self.labels = labels
        self.epsilon = epsilon
        self.delta = delta

    def _noise_std(self, sensitivity):
        """Return sigma of the classical Gaussian mechanism for L2 ``sensitivity``.

        sigma = sensitivity * sqrt(2 * ln(1.25 / delta)) / epsilon. The classical
        theorem is stated for epsilon < 1; for larger epsilon treat this as an
        approximation.
        """
        import math

        return sensitivity * math.sqrt(2.0 * math.log(1.25 / self.delta)) / self.epsilon

    def compute_private_gradients(self, global_weights, max_grad_norm=1.0, lr=0.01):
        """Take one privatised full-batch SGD step from ``global_weights``.

        The batch gradient is clipped to L2 norm ``max_grad_norm``. Gaussian
        noise calibrated to (epsilon, delta) is added, treating
        ``max_grad_norm`` as the sensitivity. The noisy gradient is then applied
        with step size ``lr``.

        NOTE(review): clipping is applied to the whole-batch gradient, not per
        example, so this bounds the influence of the client's batch as a whole
        rather than giving per-record DP — confirm the intended privacy unit.

        Args:
            global_weights: State dict to start from.
            max_grad_norm: Clipping bound C for the gradient's global L2 norm.
            lr: Step size used to apply the privatised gradient.

        Returns:
            Despite the method name, a state dict of *weights* (detached copies)
            after the privatised step, so it can be averaged like any other
            client's weights.
        """
        self.model.load_state_dict(global_weights)
        self.model.train()
        # Clear stale gradients; backward() accumulates into .grad.
        self.model.zero_grad()
        loss = nn.CrossEntropyLoss()(self.model(self.data), self.labels)
        loss.backward()

        # Scales every gradient by min(1, C / (||g|| + 1e-6)).
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)

        sigma = self._noise_std(max_grad_norm)
        with torch.no_grad():
            for param in self.model.parameters():
                if param.grad is None:
                    continue
                param.grad.add_(torch.randn_like(param.grad) * sigma)
                param.add_(param.grad, alpha=-lr)

        return {
            name: tensor.detach().clone()
            for name, tensor in self.model.state_dict().items()
        }
Secure Aggregation
import hashlib
from typing import Tuple
class SecureAggregator:
    """Pairwise additive masking so that only the aggregate of updates is revealed.

    Every ordered client pair (i, j), i != j, has a mask m(i, j). Client i adds
    m(i, j) and subtracts m(j, i) for each peer j. Summed over all clients,
    each mask appears once with each sign and cancels, up to floating-point
    rounding for float updates.

    NOTE(review): mask seeds come from the public string "i-j", so anyone who
    knows the client indices can recompute them, and this object holds every
    mask itself. This demonstrates the cancellation arithmetic only. A real
    protocol derives pairwise seeds through key agreement between clients.
    """

    def __init__(self, num_clients: int):
        self.num_clients = num_clients
        self.masks = {}  # (i, j) -> mask tensor

    def generate_masks(self, shape=1000) -> Dict[Tuple[int, int], torch.Tensor]:
        """Create one deterministic mask per ordered client pair.

        Args:
            shape: Mask shape (int or tuple); must match the updates that will
                be masked. Defaults to 1000 elements.

        Returns:
            ``self.masks``, keyed by ``(i, j)`` tuples.
        """
        for i in range(self.num_clients):
            for j in range(self.num_clients):
                if i == j:
                    continue
                digest = hashlib.sha256(f"{i}-{j}".encode()).digest()
                generator = torch.Generator()
                generator.manual_seed(int.from_bytes(digest[:4], 'big'))
                self.masks[(i, j)] = torch.randn(shape, generator=generator)
        return self.masks

    def apply_mask(self, client_id: int, update: torch.Tensor) -> torch.Tensor:
        """Return a masked copy of ``update`` for client ``client_id``.

        If no masks exist yet, they are generated with ``update``'s shape, so
        an update is never sent unmasked just because setup was skipped.

        Raises:
            ValueError: if ``client_id`` is outside ``[0, num_clients)``.
        """
        if not 0 <= client_id < self.num_clients:
            raise ValueError(
                f"client_id {client_id} out of range for {self.num_clients} clients"
            )
        if not self.masks:
            self.generate_masks(update.shape)

        masked_update = update.clone()
        for peer in range(self.num_clients):
            if peer == client_id:
                continue
            outgoing = self.masks.get((client_id, peer))
            if outgoing is not None:
                masked_update += outgoing.to(device=masked_update.device, dtype=masked_update.dtype)
            incoming = self.masks.get((peer, client_id))
            if incoming is not None:
                masked_update -= incoming.to(device=masked_update.device, dtype=masked_update.dtype)
        return masked_update

    def aggregate_masked(self, masked_updates: List[torch.Tensor]) -> torch.Tensor:
        """Average the masked updates; masks cancel only if every client contributed.

        Raises:
            ValueError: if the list is empty or its length differs from
                ``num_clients`` (pairwise masks would not cancel).
        """
        if not masked_updates or len(masked_updates) != self.num_clients:
            raise ValueError(
                f"expected {self.num_clients} masked updates, got {len(masked_updates)}"
            )
        return torch.stack(masked_updates).mean(dim=0)
# NOTE(review): `updates` is a placeholder for the clients' update tensors,
# defined elsewhere. generate_masks() builds 1000-element masks, so each update
# is assumed to be a flat tensor of 1000 values — TODO confirm.
# Size the aggregator from the actual updates; masks only cancel when every
# registered client contributes.
aggregator = SecureAggregator(num_clients=len(updates))
# Masks must exist before apply_mask() can use them.
aggregator.generate_masks()
masked_updates = [aggregator.apply_mask(i, update) for i, update in enumerate(updates)]
aggregated = aggregator.aggregate_masked(masked_updates)
Communication Efficiency
class CompressedCommunicator:
    """Top-k sparsification of model updates to reduce communication cost.

    Args:
        compression_rate: Fraction of each tensor's entries to keep, in (0, 1].

    Raises:
        ValueError: if ``compression_rate`` is outside (0, 1].

    NOTE(review): dropped entries are discarded. Accumulating the residual
    locally (error feedback) is a common remedy if convergence suffers.
    """

    def __init__(self, compression_rate=0.1):
        if not 0 < compression_rate <= 1:
            raise ValueError(f"compression_rate must be in (0, 1], got {compression_rate}")
        self.compression_rate = compression_rate

    def top_k_compress(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Keep the largest-magnitude entries of ``tensor``.

        Returns:
            ``(values, indices)``, where ``indices`` index the flattened tensor.
            A non-empty tensor keeps at least one entry; an empty tensor yields
            two empty tensors.
        """
        flat = tensor.flatten()
        numel = flat.numel()
        if numel == 0:
            return flat.clone(), torch.zeros(0, dtype=torch.long, device=flat.device)
        k = min(numel, max(1, int(numel * self.compression_rate)))
        _, indices = torch.topk(flat.abs(), k)
        return flat[indices], indices

    def decompress(self, values: torch.Tensor, indices: torch.Tensor, shape) -> torch.Tensor:
        """Scatter ``values`` into a zero tensor of ``shape``.

        The result has the same dtype and device as ``values``.
        """
        dense = values.new_zeros(shape).flatten()
        dense[indices] = values
        return dense.reshape(shape)

    def sparsify_update(self, update: Dict[str, torch.Tensor]) -> Dict:
        """Compress every tensor of ``update``.

        Returns:
            ``{name: {"values", "indices", "shape"}}`` for each tensor.
        """
        compressed = {}
        for name, tensor in update.items():
            values, indices = self.top_k_compress(tensor)
            compressed[name] = {"values": values, "indices": indices, "shape": tensor.shape}
        return compressed

    def densify_update(self, compressed: Dict) -> Dict[str, torch.Tensor]:
        """Invert ``sparsify_update``, returning dense tensors keyed by name."""
        return {
            name: self.decompress(entry["values"], entry["indices"], entry["shape"])
            for name, entry in compressed.items()
        }
# Keep only the top 5% (by magnitude) of the entries in every tensor of the update.
# NOTE(review): `model_update` is a placeholder for a name -> tensor mapping
# defined elsewhere; confirm before running.
communicator = CompressedCommunicator(0.05)
compressed_update = communicator.sparsify_update(update=model_update)
Best Practices
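- Use secure aggregation so the server sees only the combined result, not individual client updates
- Add differential privacy (update clipping plus calibrated noise) for formal privacy guarantees
- Compress updates (e.g., top-k sparsification) to reduce communication costs
- Handle non-IID data distributions, e.g., with sample-weighted aggregation or proximal regularization (FedProx)
- Monitor convergence and tune hyperparameters (learning rate, local epochs, number of participating clients)
- Add Byzantine fault tolerance, e.g., robust aggregation such as coordinate-wise median or trimmed mean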