
Finding a good server_stepsize and client_stepsize for SCAFFOLD #30

@adityak714

Description


Tuning for SCAFFOLD Server + Client Learning Rates $\eta_g$ and $\eta_i$?

Hello, for my project I am training a binary classifier with a residual CNN (data dimension (12, 4096)), and I am simulating federated training with FedAvg, FedProx, and SCAFFOLD as the aggregation strategies.

In my current academic project, your benchmarking client.py and server.py have been more than useful for setting up the groundwork, getting familiar with non-IID FL, and running detailed federated simulations. Thank you very much!

The answer probably depends heavily on the data and the learning task, so here is more detail about the training process so far. In centralized training, the model reached 60% precision-recall AUC in 5 epochs. With Federated Averaging and 1 local epoch, the performance matches; with more local epochs (5), precision-recall AUC reaches around 0.75-0.8. I am using the same dataset at different levels of non-IID-ness: the partition sizes stay the same, but the spread of the features is imbalanced between clients.

FedProx reaches the same performance, with less volatile client training losses in the non-IID scenarios I tried. With SCAFFOLD, however, the performance gain is very small, and all client training losses stay at the same level for 10+ communication rounds. Federated Averaging and FedProx seem to reach the elbow of the training loss curve in around 5-6 communication rounds.

Have you also encountered this issue with SCAFFOLD? In the SCAFFOLD implementation, I found that the server step size is fixed to 1.0, while the client step size can be set as a configuration hyperparameter.

  • Is there a way to know how to tune the server or client step size, and which signs during training indicate that it should be increased or decreased?
  • Did you have to change the server or client learning rate for SCAFFOLD in particular, compared to classic Federated Averaging?

Environment

  • Python 3.12.3, Flower 1.26.1
  • The files below are in a folder named pytorchexample. With a Ray session already running, run python -m pytorchexample.run_scaffold. The data loaders are valid torch.utils.data.DataLoader objects.

Optimizer specific to SCAFFOLD:

class ScaffoldOptimizer(SGD):
    """Implements SGD optimizer step function as defined in the SCAFFOLD paper."""

    def __init__(self, grads, step_size, momentum, weight_decay):
        super().__init__(
            grads, lr=step_size, momentum=momentum, weight_decay=weight_decay
        )

    def step_custom(self, server_cv, client_cv):
        """Implement the custom step function for SCAFFOLD."""
        # y_i = y_i - \eta * (g_i + c - c_i)  -->
        # y_i = y_i - \eta*(g_i + \mu*b_{t}) - \eta*(c - c_i)
        self.step()
        for group in self.param_groups:
            for par, s_cv, c_cv in zip(group["params"], server_cv, client_cv):
                print("parameter vs. s_cv - c_cv", par.grad.norm(), (s_cv.to(par.device) - c_cv.to(par.device)).norm())
                par.data.add_(s_cv.to(par.device) - c_cv.to(par.device), alpha=-group["lr"])
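
The optimizer above applies the control-variate correction after the momentum step, so the correction does not pass through the momentum buffer. A minimal scalar sketch of what that means, assuming PyTorch-style momentum (`buf = mu * buf + g`, zero dampening) and no weight decay; all function names here are hypothetical:

```python
def momentum_step(p, g, buf, lr, mu):
    """One SGD step with momentum and zero dampening."""
    buf = mu * buf + g
    return p - lr * buf, buf


def two_stage(steps, g, corr, lr, mu):
    """Momentum step on g, then subtract lr * corr (as in step_custom)."""
    p, buf = 0.0, 0.0
    for _ in range(steps):
        p, buf = momentum_step(p, g, buf, lr, mu)
        p -= lr * corr
    return p


def corrected_gradient(steps, g, corr, lr, mu):
    """Momentum step on the corrected gradient g + corr."""
    p, buf = 0.0, 0.0
    for _ in range(steps):
        p, buf = momentum_step(p, g + corr, buf, lr, mu)
    return p


# corr stands for c - c_i; gradient and correction are held constant.
no_mom = (two_stage(3, 1.0, 0.5, 0.1, 0.0),
          corrected_gradient(3, 1.0, 0.5, 0.1, 0.0))
with_mom = (two_stage(3, 1.0, 0.5, 0.1, 0.9),
            corrected_gradient(3, 1.0, 0.5, 0.1, 0.9))
print(no_mom, with_mom)
```

Without momentum the two variants coincide. With momentum they differ, and the local displacement x - y_i is no longer K * eta times an average gradient, which may matter for the 1 / (K * eta) factor used in the client control-variate update.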

Server Logic:

class ScaffoldStrategy(FedAvg):
    """Implement custom strategy for SCAFFOLD based on FedAvg class."""

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using weighted average."""
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}
        
        combined_parameters_all_updates = [
            parameters_to_ndarrays(fit_res.parameters) for _, fit_res in results
        ]
        len_combined_parameter = len(combined_parameters_all_updates[0])
        num_examples_all_updates = [fit_res.num_examples for _, fit_res in results]
        # Zip parameters and num_examples
        weights_results = [
            (update[: len_combined_parameter // 2], num_examples)
            for update, num_examples in zip(
                combined_parameters_all_updates, num_examples_all_updates
            )
        ]
        # Aggregate parameters with an unweighted mean over clients.
        # Weighted alternative: parameters_aggregated = aggregate(weights_results)
        param_updates = [update[0] for update in weights_results]
        parameters_aggregated = [np.mean(layer, axis=0) for layer in zip(*param_updates)]


        # Zip client_cv_updates and num_examples
        client_cv_updates_and_num_examples = [
            (update[len_combined_parameter // 2 :], num_examples)
            for update, num_examples in zip(
                combined_parameters_all_updates, num_examples_all_updates
            )
        ]
        # Aggregate control-variate updates with an unweighted mean over clients.
        # Weighted alternative:
        # aggregated_cv_update = aggregate(client_cv_updates_and_num_examples)
        cv_updates = [update[0] for update in client_cv_updates_and_num_examples]
        aggregated_cv_update = [np.mean(layer, axis=0) for layer in zip(*cv_updates)]

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        #elif server_round == 1:  # Only log this warning once
        #    log(WARNING, "No fit_metrics_aggregation_fn provided")

        return (
            ndarrays_to_parameters(parameters_aggregated + aggregated_cv_update),
            metrics_aggregated,
        )

    # pylint: disable=too-many-locals
    def fit_round(
        self,
        server_round: int,
        timeout: Optional[float],
    ) -> Optional[
        Tuple[Optional[Parameters], Dict[str, Scalar], FitResultsAndFailures]
    ]:
        """Perform a single round of federated averaging."""
        # Get clients and their respective instructions from strategy
        client_instructions = self.strategy.configure_fit(
            server_round=server_round,
            parameters=update_parameters_with_cv(self.parameters, self.server_cv),
            client_manager=self._client_manager,
        )

        if not client_instructions:
            log(INFO, "fit_round %s: no clients selected, cancel", server_round)
            return None
        log(
            DEBUG,
            "fit_round %s: strategy sampled %s clients (out of %s)",
            server_round,
            len(client_instructions),
            self._client_manager.num_available(),
        )

        # Collect `fit` results from all clients participating in this round
        results, failures = fit_clients(
            client_instructions=client_instructions,
            max_workers=self.max_workers,
            timeout=timeout,
            group_id=server_round,
        )
        log(
            DEBUG,
            "fit_round %s received %s results and %s failures",
            server_round,
            len(results),
            len(failures),
        )

        # Aggregate training results
        aggregated_result: Tuple[Optional[Parameters], Dict[str, Scalar]] = (
            self.strategy.aggregate_fit(server_round, results, failures)
        )

        #aggregated_result_arrays_combined = []
        if aggregated_result[0] is None:
            return None
        
        aggregated_result_arrays_combined = parameters_to_ndarrays(
            aggregated_result[0]
        )
        aggregated_parameters = aggregated_result_arrays_combined[
            : len(aggregated_result_arrays_combined) // 2
        ]
        aggregated_cv_update = aggregated_result_arrays_combined[
            len(aggregated_result_arrays_combined) // 2 :
        ]

        # convert server cv into ndarrays
        server_cv_np = [cv.numpy() for cv in self.server_cv]
        # update server cv
        total_clients = len(self._client_manager.all())
        cv_multiplier = len(results) / total_clients
        self.server_cv = [
            torch.from_numpy(cv + cv_multiplier * aggregated_cv_update[i])
            for i, cv in enumerate(server_cv_np)
        ]

        # update parameters x = x + global_lr* aggregated_update
        curr_params = parameters_to_ndarrays(self.parameters)
        updated_params = [
            x + (self.global_lr*aggregated_parameters[i]) for i, x in enumerate(curr_params)
        ]
        parameters_updated = ndarrays_to_parameters(updated_params)

        # metrics
        metrics_aggregated = aggregated_result[1]
        return parameters_updated, metrics_aggregated, (results, failures)


def update_parameters_with_cv(
    parameters: Parameters, s_cv: List[torch.Tensor]
) -> Parameters:
    """Extend the list of parameters with the server control variate."""
    # extend the list of parameters arrays with the cv arrays
    cv_np = [cv.numpy() for cv in s_cv]
    parameters_np = parameters_to_ndarrays(parameters)
    parameters_np.extend(cv_np)
    return ndarrays_to_parameters(parameters_np)
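
The server-side update performed in fit_round can be written out on plain Python lists as a minimal sketch. The function name `server_update` and its arguments are hypothetical, and each list entry stands in for one parameter array:

```python
def server_update(x, c, delta_ys, delta_cs, global_lr, total_clients):
    """Apply one SCAFFOLD server step.

    x, c:      current global parameters and server control variate
    delta_ys:  per-client parameter deltas (y_i - x)
    delta_cs:  per-client control-variate deltas (c_i_new - c_i)
    """
    sampled = len(delta_ys)
    mean_dy = [sum(col) / sampled for col in zip(*delta_ys)]
    mean_dc = [sum(col) / sampled for col in zip(*delta_cs)]
    # x = x + global_lr * mean(delta_y)
    new_x = [xj + global_lr * d for xj, d in zip(x, mean_dy)]
    # c = c + (|S| / N) * mean(delta_c), where |S| clients were sampled
    new_c = [cj + (sampled / total_clients) * d for cj, d in zip(c, mean_dc)]
    return new_x, new_c


new_x, new_c = server_update(
    x=[1.0, 2.0],
    c=[0.0, 0.0],
    delta_ys=[[0.2, -0.4], [0.4, 0.0]],
    delta_cs=[[1.0, 1.0], [3.0, -1.0]],
    global_lr=1.0,
    total_clients=4,
)
print(new_x, new_c)
```

With 2 of 4 clients sampled, the control-variate update is scaled by 0.5, whereas the parameter update is scaled only by the global learning rate.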

Client Logic:

class FlowerClientScaffold(fl.client.NumPyClient):
    """Flower client implementing scaffold."""

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        cid: int,
        net: torch.nn.Module,
        num_partitions: int,
        batch_size: int,
        val,
        device: torch.device,
        num_epochs: int,
        learning_rate: float,
        momentum: float,
        weight_decay: float,
        save_dir: str = "",
    ) -> None:
        self.cid = cid
        self.net = net
        self.num_partitions = num_partitions
        self.batch_size = batch_size
        self.val = val
        self.device = device
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.weight_decay = weight_decay
        
        # initialize client control variate with 0 and shape of the network parameters
        self.client_cv = []
        
        # load_datasets
        self.trainloader = None
        self.valloader = None
        
        for param in self.net.parameters():
            self.client_cv.append(torch.zeros(param.shape))
        # save cv to directory
        if save_dir == "":
            save_dir = "client_cvs"
        self.dir = save_dir
        if not os.path.exists(self.dir):
            os.makedirs(self.dir)
    
    def get_parameters(self, config: Dict[str, Scalar]):
        """Return the current local model parameters."""
        return [val.detach().cpu().numpy() for val in self.net.parameters()]

    def set_parameters(self, parameters):
        """Set the local model parameters using given ones."""
        #params_dict = zip(list(self.net.state_dict().keys()), parameters)
        #state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        params_list = list(self.net.parameters())
        if len(parameters) != len(params_list):
            raise ValueError(f"Expected {len(params_list)} but got {len(parameters)}")
        
        with torch.no_grad():
            for i, param in enumerate(params_list):
                param.copy_(torch.as_tensor(parameters[i]).to(param.device))
        #self.net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config: Dict[str, Scalar]):
        """Implement distributed fit function for a given client for SCAFFOLD."""
        # the first half are model parameters and the second are the server_cv
        server_cv = parameters[len(parameters) // 2 :]
        parameters = parameters[: len(parameters) // 2]
        self.set_parameters(parameters)
        self.client_cv = []
        #for param in self.net.parameters():
        #    self.client_cv.append(param.clone().detach())
        # load client control variate
        if os.path.exists(f"{self.dir}/client_cv_{self.cid}.pt"):
            self.client_cv = torch.load(f"{self.dir}/client_cv_{self.cid}.pt")
        else:
            self.client_cv = [torch.zeros_like(param) for param in self.net.parameters()]
        # convert the server control variate to a list of tensors
        server_cv = [torch.Tensor(cv) for cv in server_cv]

        #print(self.cid)
        trainloader, valloader = load_datasets(
            self.cid, 
            self.num_partitions, 
            self.batch_size,
            partitioning="dirichlet",
            val=self.val, device=self.device
        )

        print(f"[client {self.cid}] starting train_scaffold with batches {len(trainloader)}")
        start_time = time.time()
        train_loss, net, count = train_scaffold({
                "net": self.net,
                "partition_id": self.cid,
                "trainloader": trainloader,
                "valloader": valloader,
                "epochs": self.num_epochs, 
                "lr": self.learning_rate, 
                "batch_size": self.batch_size
            },
            server_cv,
            self.client_cv,
        )
        end_time = time.time()
        training_time = end_time - start_time
        print("training done!")
        
        x = parameters
        y_i = self.get_parameters(config={})
        c_i_n = []
        server_update_x = []
        server_update_c = []

        # update client control variate: c_i_new = c_i - c + (x - y_i) / (K * eta),
        # where K = num_epochs * len(trainloader) is the number of local steps
        for c_i_j, c_j, x_j, y_i_j in zip(self.client_cv, server_cv, x, y_i):
            #print(c_i_j.device, c_j.device)
            c_i_n.append(
                c_i_j.cpu()
                - c_j.cpu()
                + (1.0 / (self.learning_rate * self.num_epochs * len(trainloader)))
                * (x_j - y_i_j)
            )
            # y_i - x, c_i_n - c_i for the server
            server_update_x.append((y_i_j - x_j))
            server_update_c.append((c_i_n[-1] - c_i_j.cpu()).cpu().numpy())
        self.client_cv = c_i_n
        torch.save(self.client_cv, f"{self.dir}/client_cv_{self.cid}.pt")

        combined_updates = server_update_x + server_update_c

        metrics = {
            "train_loss": train_loss,
            "num-examples": len(trainloader.dataset),
            "training_time": training_time,
            "local-epochs": self.num_epochs,
            "partition_id": self.cid
        }
        #metric_record = MetricRecord(metrics)
        #content = RecordDict({"arrays": model_record, "metrics": metric_record})

        with open(f'{self.dir}/clients{self.num_partitions}-partitioningdirichlet{self.val}-loceps{self.num_epochs}.txt', "a") as logger:
            logger.write(f"{str(metrics)}\n")
        
        return (
            combined_updates,
            len(trainloader.dataset),
            {},
        )

    def evaluate(self, parameters, config: Dict[str, Scalar]):
        """Evaluate using given parameters."""
        self.set_parameters(parameters)
        
        trainloader, valloader = load_datasets(
            self.cid, 
            self.num_partitions, 
            self.batch_size,
            partitioning="dirichlet",
            val=self.val, device=self.device
        )
        self.trainloader = trainloader
        self.valloader = valloader

        loss, acc = test(self.net, self.valloader, self.device)
        metrics = {
            "eval_loss": float(loss),
            "eval_acc": float(acc),
            "num-examples": len(self.valloader.dataset)
        }
        with open(f'{self.dir}/clients{self.num_partitions}-partitioningdirichlet{self.val}-loceps{self.num_epochs}.txt', "a") as logger:
            logger.write(f"{str(metrics)}\n")
        return float(loss), len(self.valloader.dataset), {"accuracy": float(acc)}
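
One quick check on the client control-variate update in fit is whether K matches the number of optimizer steps actually taken. The following is a minimal scalar sketch, assuming plain SGD without momentum and a constant gradient; `local_sgd` and `new_client_cv` are hypothetical names:

```python
def local_sgd(x, grad_fn, lr, num_steps):
    """Run num_steps plain SGD steps starting from x."""
    y = x
    for _ in range(num_steps):
        y -= lr * grad_fn(y)
    return y


def new_client_cv(c_i, c, x, y_i, lr, num_steps):
    """c_i_new = c_i - c + (x - y_i) / (K * eta)."""
    return c_i - c + (x - y_i) / (lr * num_steps)


g = 2.0   # constant gradient, for illustration only
x = 1.0
y = local_sgd(x, lambda _: g, lr=0.01, num_steps=20)

# With K equal to the true number of steps, the update recovers g.
cv_ok = new_client_cv(0.0, 0.0, x, y, lr=0.01, num_steps=20)
# With K miscounted (e.g. epochs instead of batches), it is off by 20x.
cv_off = new_client_cv(0.0, 0.0, x, y, lr=0.01, num_steps=1)
print(cv_ok, cv_off)
```

If train_scaffold takes exactly one optimizer step per batch, then num_epochs * len(trainloader) is the right K. Otherwise the control variates end up mis-scaled relative to the gradients they are meant to correct.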
