Tuning for SCAFFOLD Server + Client Learning Rates $\eta_g$ and $\eta_i$?
Hello, for my project I am training a binary classifier with a residual CNN on inputs of dimension (12, 4096), and I am simulating federated training with FedAvg, FedProx and SCAFFOLD as the aggregation strategies.
In my current academic project, your benchmarking client.py and server.py have been more than useful for setting up the groundwork, getting familiar with non-IID FL, and running detailed federated simulations. Thank you very much!
The answer probably depends heavily on the data and the learning task, so here is more detail on the current training process. Centralized training reached a precision-recall AUC of 0.60 after 5 epochs. Federated Averaging with 1 local epoch matches that, and with 5 local epochs the precision-recall AUC reaches around 0.75-0.8. I use the same dataset at different levels of non-IID-ness: the client dataset sizes stay the same, but the feature distribution is imbalanced between clients.
FedProx reaches the same performance, with less volatile client training losses in the non-IID scenarios I tried. With SCAFFOLD, however, the improvement is very small, and all client training losses stay flat for 10+ communication rounds. Federated Averaging and FedProx reach the elbow of the training loss curve in around 5-6 communication rounds.
Have you also encountered this issue with SCAFFOLD? For SCAFFOLD, I found that the server step size is fixed to 1.0, while the client step size can be set as a configuration hyperparameter.
- Is there a way to know how to tune the server or client step size, and which signs during training indicate whether to increase or decrease it?
- Is there a server or client learning rate that you had to change for SCAFFOLD in particular, compared to classic Federated Averaging?
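For reference, a compact statement of the SCAFFOLD updates these questions refer to, written with the symbols from the title. This is a summary under stated assumptions rather than a description of the Flower baseline: $\eta_i$ is the client step size, $K$ the number of local steps, $S$ the set of sampled clients, $N$ the total number of clients, and the client control variate uses the "option II" form that also appears in the client code.

$$
\begin{aligned}
y_i &\leftarrow y_i - \eta_i \left( g_i(y_i) + c - c_i \right) && \text{(local step, repeated } K \text{ times starting from } y_i = x\text{)} \\
c_i^{+} &= c_i - c + \frac{1}{K \eta_i} \, (x - y_i) && \text{(client control variate)} \\
x &\leftarrow x + \frac{\eta_g}{|S|} \sum_{i \in S} (y_i - x) && \text{(server model)} \\
c &\leftarrow c + \frac{1}{N} \sum_{i \in S} \left( c_i^{+} - c_i \right) && \text{(server control variate)}
\end{aligned}
$$

With $\eta_g = 1$ the server model update reduces to a plain average of the client models, which matches the fixed server step size mentioned above.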
Environment
- Python 3.12.3, Flower 1.26.1
- The files below are in a folder named `pytorchexample`. With a Ray session already running, run `python -m pytorchexample.run_scaffold`. The data loaders are valid `torch.utils.data.DataLoader` objects.
Optimizer specific to SCAFFOLD:
```python
from torch.optim import SGD


class ScaffoldOptimizer(SGD):
    """Implements SGD optimizer step function as defined in the SCAFFOLD paper."""

    def __init__(self, grads, step_size, momentum, weight_decay):
        super().__init__(
            grads, lr=step_size, momentum=momentum, weight_decay=weight_decay
        )

    def step_custom(self, server_cv, client_cv):
        """Implement the custom step function for SCAFFOLD."""
        # y_i = y_i - \eta * (g_i + c - c_i)  -->
        # y_i = y_i - \eta*(g_i + \mu*b_{t}) - \eta*(c - c_i)
        self.step()
        for group in self.param_groups:
            for par, s_cv, c_cv in zip(group["params"], server_cv, client_cv):
                # Control-variate correction c - c_i on the parameter's device
                correction = s_cv.to(par.device) - c_cv.to(par.device)
                print("parameter vs. s_cv - c_cv", par.grad.norm(), correction.norm())
                par.data.add_(correction, alpha=-group["lr"])
```
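The two comment lines in `step_custom` imply that the control-variate correction is applied outside the momentum buffer. The following is a minimal scalar sketch in pure Python (no torch; `split_step`, `folded_step` and `run` are made-up names for this illustration) that compares this split update with one that folds `c - c_i` into the gradient before momentum. It assumes a PyTorch-style momentum buffer with zero dampening and a toy quadratic loss.

```python
def split_step(y, buf, grad, corr, lr, momentum):
    """Momentum SGD step, then the correction applied outside the buffer."""
    buf = momentum * buf + grad      # momentum buffer sees only the gradient
    y = y - lr * buf                 # the regular optimizer step
    y = y - lr * corr                # extra term: -lr * (c - c_i)
    return y, buf


def folded_step(y, buf, grad, corr, lr, momentum):
    """Correction folded into the gradient, so it also enters the buffer."""
    buf = momentum * buf + (grad + corr)
    y = y - lr * buf
    return y, buf


def run(step_fn, momentum, steps=5, lr=0.1, corr=0.5):
    """Minimise f(y) = 0.5 * y**2 (gradient = y) starting from y = 1.0."""
    y, buf = 1.0, 0.0
    for _ in range(steps):
        y, buf = step_fn(y, buf, y, corr, lr, momentum)
    return y


# Gap between the two variants, without and with momentum
same = abs(run(split_step, 0.0) - run(folded_step, 0.0))
diff = abs(run(split_step, 0.9) - run(folded_step, 0.9))
print(f"momentum=0.0 gap: {same:.2e}, momentum=0.9 gap: {diff:.2e}")
```

Without momentum the two variants coincide. With momentum they diverge after the first step, so results obtained with a nonzero `momentum` are not directly comparable to the plain-SGD form of the local update.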
Server Logic:
```python
from logging import DEBUG, INFO
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from flwr.common import (
    FitRes,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy
from flwr.server.server import FitResultsAndFailures, fit_clients
from flwr.server.strategy import FedAvg


class ScaffoldStrategy(FedAvg):
    """Implement custom strategy for SCAFFOLD based on FedAvg class."""

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results (currently with an unweighted mean)."""
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        combined_parameters_all_updates = [
            parameters_to_ndarrays(fit_res.parameters) for _, fit_res in results
        ]
        len_combined_parameter = len(combined_parameters_all_updates[0])
        num_examples_all_updates = [fit_res.num_examples for _, fit_res in results]

        # Zip parameters and num_examples
        weights_results = [
            (update[: len_combined_parameter // 2], num_examples)
            for update, num_examples in zip(
                combined_parameters_all_updates, num_examples_all_updates
            )
        ]
        # Aggregate parameters.
        # Weighted update (disabled):
        #     parameters_aggregated = aggregate(weights_results)
        param_updates = [update[0] for update in weights_results]
        parameters_aggregated = [
            np.mean(layer, axis=0) for layer in zip(*param_updates)
        ]

        # Zip client_cv_updates and num_examples
        client_cv_updates_and_num_examples = [
            (update[len_combined_parameter // 2 :], num_examples)
            for update, num_examples in zip(
                combined_parameters_all_updates, num_examples_all_updates
            )
        ]
        # Weighted update (disabled):
        #     aggregated_cv_update = aggregate(client_cv_updates_and_num_examples)
        cv_updates = [update[0] for update in client_cv_updates_and_num_examples]
        aggregated_cv_update = [np.mean(layer, axis=0) for layer in zip(*cv_updates)]

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        # elif server_round == 1:  # Only log this warning once
        #     log(WARNING, "No fit_metrics_aggregation_fn provided")

        return (
            ndarrays_to_parameters(parameters_aggregated + aggregated_cv_update),
            metrics_aggregated,
        )

    # NOTE: fit_round uses self.strategy, self._client_manager and self.server_cv,
    # so it is a method of the custom server class; that class header is not
    # included in this excerpt.
    # pylint: disable=too-many-locals
    def fit_round(
        self,
        server_round: int,
        timeout: Optional[float],
    ) -> Optional[
        Tuple[Optional[Parameters], Dict[str, Scalar], FitResultsAndFailures]
    ]:
        """Perform a single round of federated averaging."""
        # Get clients and their respective instructions from strategy
        client_instructions = self.strategy.configure_fit(
            server_round=server_round,
            parameters=update_parameters_with_cv(self.parameters, self.server_cv),
            client_manager=self._client_manager,
        )
        if not client_instructions:
            log(INFO, "fit_round %s: no clients selected, cancel", server_round)
            return None
        log(
            DEBUG,
            "fit_round %s: strategy sampled %s clients (out of %s)",
            server_round,
            len(client_instructions),
            self._client_manager.num_available(),
        )

        # Collect `fit` results from all clients participating in this round
        results, failures = fit_clients(
            client_instructions=client_instructions,
            max_workers=self.max_workers,
            timeout=timeout,
            group_id=server_round,
        )
        log(
            DEBUG,
            "fit_round %s received %s results and %s failures",
            server_round,
            len(results),
            len(failures),
        )

        # Aggregate training results
        aggregated_result: Tuple[Optional[Parameters], Dict[str, Scalar]] = (
            self.strategy.aggregate_fit(server_round, results, failures)
        )
        # aggregated_result_arrays_combined = []
        if aggregated_result[0] is None:
            return None
        aggregated_result_arrays_combined = parameters_to_ndarrays(
            aggregated_result[0]
        )
        aggregated_parameters = aggregated_result_arrays_combined[
            : len(aggregated_result_arrays_combined) // 2
        ]
        aggregated_cv_update = aggregated_result_arrays_combined[
            len(aggregated_result_arrays_combined) // 2 :
        ]

        # convert server cv into ndarrays
        server_cv_np = [cv.numpy() for cv in self.server_cv]
        # update server cv: c = c + (|S| / N) * mean(delta_c)
        total_clients = len(self._client_manager.all())
        cv_multiplier = len(results) / total_clients
        self.server_cv = [
            torch.from_numpy(cv + cv_multiplier * aggregated_cv_update[i])
            for i, cv in enumerate(server_cv_np)
        ]

        # update parameters x = x + global_lr * aggregated_update
        curr_params = parameters_to_ndarrays(self.parameters)
        updated_params = [
            x + (self.global_lr * aggregated_parameters[i])
            for i, x in enumerate(curr_params)
        ]
        parameters_updated = ndarrays_to_parameters(updated_params)

        # metrics
        metrics_aggregated = aggregated_result[1]
        return parameters_updated, metrics_aggregated, (results, failures)


def update_parameters_with_cv(
    parameters: Parameters, s_cv: List[torch.Tensor]
) -> Parameters:
    """Extend the list of parameters with the server control variate."""
    # extend the list of parameters arrays with the cv arrays
    cv_np = [cv.numpy() for cv in s_cv]
    parameters_np = parameters_to_ndarrays(parameters)
    parameters_np.extend(cv_np)
    return ndarrays_to_parameters(parameters_np)
```
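To make the server-side arithmetic in `fit_round` easier to check by hand, here is a minimal pure-Python sketch in which flat lists of floats stand in for the NumPy layers. The function name `scaffold_server_update` and the toy numbers are hypothetical and not part of Flower. It assumes an unweighted mean over the sampled clients, as in `aggregate_fit` above.

```python
from typing import List, Tuple


def scaffold_server_update(
    x: List[float],
    c: List[float],
    delta_y: List[List[float]],  # one (y_i - x) vector per sampled client
    delta_c: List[List[float]],  # one (c_i_new - c_i) vector per sampled client
    total_clients: int,
    global_lr: float = 1.0,
) -> Tuple[List[float], List[float]]:
    """Return the updated (x, c) after one communication round."""
    sampled = len(delta_y)
    # Unweighted mean of the client updates, coordinate by coordinate
    mean_dy = [sum(col) / sampled for col in zip(*delta_y)]
    mean_dc = [sum(col) / sampled for col in zip(*delta_c)]
    # x = x + global_lr * mean(y_i - x)
    new_x = [xj + global_lr * dj for xj, dj in zip(x, mean_dy)]
    # c = c + (|S| / N) * mean(c_i_new - c_i)
    cv_multiplier = sampled / total_clients
    new_c = [cj + cv_multiplier * dj for cj, dj in zip(c, mean_dc)]
    return new_x, new_c


# Toy round: 2 of 4 clients sampled, model with 2 coordinates
new_x, new_c = scaffold_server_update(
    x=[1.0, 2.0],
    c=[0.0, 0.0],
    delta_y=[[-0.2, 0.4], [-0.4, 0.0]],
    delta_c=[[1.0, -2.0], [3.0, 2.0]],
    total_clients=4,
    global_lr=1.0,
)
print(new_x, new_c)
```

Comparing a hand-computed round like this against the norms logged by the real server is one way to confirm that the model halves and the control-variate halves of the combined parameter list are not swapped or scaled twice.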
Client Logic:
```python
import os
import time
from typing import Dict

import flwr as fl
import torch
from flwr.common import Scalar

# load_datasets, train_scaffold and test are project helpers not shown here.


class FlowerClientScaffold(fl.client.NumPyClient):
    """Flower client implementing scaffold."""

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        cid: int,
        net: torch.nn.Module,
        num_partitions: int,
        batch_size: int,
        val,
        device: torch.device,
        num_epochs: int,
        learning_rate: float,
        momentum: float,
        weight_decay: float,
        save_dir: str = "",
    ) -> None:
        self.cid = cid
        self.net = net
        self.num_partitions = num_partitions
        self.batch_size = batch_size
        self.val = val
        self.device = device
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.weight_decay = weight_decay
        # initialize client control variate with 0 and shape of the network parameters
        self.client_cv = []
        # load_datasets
        self.trainloader = None
        self.valloader = None
        for param in self.net.parameters():
            self.client_cv.append(torch.zeros(param.shape))
        # save cv to directory
        if save_dir == "":
            save_dir = "client_cvs"
        self.dir = save_dir
        if not os.path.exists(self.dir):
            os.makedirs(self.dir)

    def get_parameters(self, config: Dict[str, Scalar]):
        """Return the current local model parameters."""
        return [val.detach().cpu().numpy() for val in self.net.parameters()]

    def set_parameters(self, parameters):
        """Set the local model parameters using given ones."""
        # params_dict = zip(list(self.net.state_dict().keys()), parameters)
        # state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        params_list = list(self.net.parameters())
        if len(parameters) != len(params_list):
            raise ValueError(
                f"Expected {len(params_list)} but got {len(parameters)}"
            )
        with torch.no_grad():
            for i, param in enumerate(params_list):
                param.copy_(torch.as_tensor(parameters[i]).to(param.device))
        # self.net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config: Dict[str, Scalar]):
        """Implement distributed fit function for a given client for SCAFFOLD."""
        # the first half are model parameters and the second are the server_cv
        server_cv = parameters[len(parameters) // 2 :]
        parameters = parameters[: len(parameters) // 2]
        self.set_parameters(parameters)
        self.client_cv = []
        # for param in self.net.parameters():
        #     self.client_cv.append(param.clone().detach())
        # load client control variate
        if os.path.exists(f"{self.dir}/client_cv_{self.cid}.pt"):
            self.client_cv = torch.load(f"{self.dir}/client_cv_{self.cid}.pt")
        else:
            self.client_cv = [
                torch.zeros_like(param) for param in self.net.parameters()
            ]
        # convert the server control variate to a list of tensors
        server_cv = [torch.Tensor(cv) for cv in server_cv]
        # print(self.cid)
        trainloader, valloader = load_datasets(
            self.cid,
            self.num_partitions,
            self.batch_size,
            partitioning="dirichlet",
            val=self.val,
            device=self.device,
        )
        print(
            f"[client {self.cid}] starting train_scaffold with batches {len(trainloader)}"
        )
        start_time = time.time()
        train_loss, net, count = train_scaffold(
            {
                "net": self.net,
                "partition_id": self.cid,
                "trainloader": trainloader,
                "valloader": valloader,
                "epochs": self.num_epochs,
                "lr": self.learning_rate,
                "batch_size": self.batch_size,
            },
            server_cv,
            self.client_cv,
        )
        end_time = time.time()
        training_time = end_time - start_time
        print("training done!")

        x = parameters
        y_i = self.get_parameters(config={})
        c_i_n = []
        server_update_x = []
        server_update_c = []
        # update client control variate: c_i_new = c_i - c + 1/(eta*K) * (x - y_i)
        for c_i_j, c_j, x_j, y_i_j in zip(self.client_cv, server_cv, x, y_i):
            # print(c_i_j.device, c_j.device)
            c_i_n.append(
                c_i_j.cpu()
                - c_j.cpu()
                + (1.0 / (self.learning_rate * self.num_epochs * len(trainloader)))
                * (x_j - y_i_j)
            )
            # y_i - x, c_i_n - c_i for the server
            server_update_x.append((y_i_j - x_j))
            server_update_c.append((c_i_n[-1] - c_i_j.cpu()).cpu().numpy())
        self.client_cv = c_i_n
        torch.save(self.client_cv, f"{self.dir}/client_cv_{self.cid}.pt")

        combined_updates = server_update_x + server_update_c
        metrics = {
            "train_loss": train_loss,
            "num-examples": len(trainloader.dataset),
            "training_time": training_time,
            "local-epochs": self.num_epochs,
            "partition_id": self.cid,
        }
        # metric_record = MetricRecord(metrics)
        # content = RecordDict({"arrays": model_record, "metrics": metric_record})
        log_path = (
            f"{self.dir}/clients{self.num_partitions}"
            f"-partitioningdirichlet{self.val}-loceps{self.num_epochs}.txt"
        )
        with open(log_path, "a") as logger:
            logger.write(f"{str(metrics)}\n")
        return (
            combined_updates,
            len(trainloader.dataset),
            {},
        )

    def evaluate(self, parameters, config: Dict[str, Scalar]):
        """Evaluate using given parameters."""
        self.set_parameters(parameters)
        trainloader, valloader = load_datasets(
            self.cid,
            self.num_partitions,
            self.batch_size,
            partitioning="dirichlet",
            val=self.val,
            device=self.device,
        )
        self.trainloader = trainloader
        self.valloader = valloader
        loss, acc = test(self.net, self.valloader, self.device)
        metrics = {
            "eval_loss": float(loss),
            "eval_acc": float(acc),
            "num-examples": len(self.valloader.dataset),
        }
        log_path = (
            f"{self.dir}/clients{self.num_partitions}"
            f"-partitioningdirichlet{self.val}-loceps{self.num_epochs}.txt"
        )
        with open(log_path, "a") as logger:
            logger.write(f"{str(metrics)}\n")
        return float(loss), len(self.valloader.dataset), {"accuracy": float(acc)}
```
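One way to sanity-check the client control-variate update in `fit` is to run the same formula on a toy case where the answer is known. The sketch below is pure Python with scalars; `local_sgd` and `new_client_cv` are hypothetical helper names. It assumes a constant gradient `g` and starts from `c = c_i = 0`, so with plain SGD the rule `c_i_new = c_i - c + (x - y_i) / (K * lr)` should recover `g`. The second call repeats the check with a PyTorch-style momentum buffer to show how the estimate changes.

```python
def local_sgd(x, grad, lr, steps, momentum=0.0, corr=0.0):
    """Run `steps` local updates with a constant gradient and return y_i."""
    y, buf = x, 0.0
    for _ in range(steps):
        buf = momentum * buf + grad   # momentum buffer (zero dampening)
        y = y - lr * buf - lr * corr  # corr stands for c - c_i
    return y


def new_client_cv(c_i, c, x, y_i, lr, steps):
    """c_i_new = c_i - c + (x - y_i) / (K * lr), with K = number of local steps."""
    return c_i - c + (x - y_i) / (steps * lr)


x, g, lr, K = 1.0, 2.0, 0.01, 10

# Plain SGD: the new control variate equals the (constant) gradient
plain = new_client_cv(0.0, 0.0, x, local_sgd(x, g, lr, K), lr, K)

# Momentum 0.9: the parameter moves further, so the estimate is inflated
with_momentum = new_client_cv(
    0.0, 0.0, x, local_sgd(x, g, lr, K, momentum=0.9), lr, K
)
print(f"true gradient {g}, plain SGD: {plain:.3f}, momentum 0.9: {with_momentum:.3f}")
```

If the client optimizer is configured with nonzero momentum, `(x - y_i) / (K * lr)` is no longer an average gradient, so logging the norm of the saved client control variates next to the gradient norms may help decide whether the step size or the momentum setting is the first thing to revisit.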