From 10f6a1c513f1800e8455312c6823f4dd71a1e1e4 Mon Sep 17 00:00:00 2001 From: Hongyu Chen Date: Fri, 12 Jun 2026 19:19:26 -0400 Subject: [PATCH 1/5] fix: make NC metadata initialization robust to corner cases --- fedgraph/federated_methods.py | 106 ++++++++++++++++++++++++-- fedgraph/trainer_class.py | 65 +++++++++++----- fedgraph/utils_nc.py | 13 +++- tests/unit/test_federated_methods.py | 69 ++++++++++++++++- tests/unit/test_trainer_class.py | 110 +++++++++++++++++++++++++-- tests/unit/test_utils_nc.py | 72 +++++++++++++++++- 6 files changed, 394 insertions(+), 41 deletions(-) diff --git a/fedgraph/federated_methods.py b/fedgraph/federated_methods.py index cb6e25b..98532b1 100644 --- a/fedgraph/federated_methods.py +++ b/fedgraph/federated_methods.py @@ -49,6 +49,66 @@ LOWRANK_AVAILABLE = False +def _resolve_nc_class_num( + use_huggingface: bool, + trainer_information: list, + loaded_class_num: Optional[int] = None, +) -> int: + if not use_huggingface: + if loaded_class_num is None: + raise ValueError("class_num is required when NC data is loaded centrally") + return int(loaded_class_num) + + metadata_values = { + int(info["class_num"]) + for info in trainer_information + if info.get("class_num") is not None + } + if len(metadata_values) > 1: + raise ValueError("Hugging Face trainers report inconsistent class_num values") + if metadata_values: + return metadata_values.pop() + + label_nums = [ + int(info["label_num"]) + for info in trainer_information + if info.get("label_num") is not None + ] + if not label_nums: + raise ValueError( + "Cannot infer class_num from Hugging Face trainer data because all " + "train and test label tensors are empty" + ) + return max(label_nums) + + +def _resolve_nc_global_node_num( + use_huggingface: bool, + trainer_information: list, + loaded_global_node_num: Optional[int] = None, +) -> int: + if not use_huggingface: + if loaded_global_node_num is None: + raise ValueError( + "global_node_num is required when NC data is loaded centrally" + ) + return int(loaded_global_node_num) + + metadata_values = { + int(info["global_node_num"]) + for info in trainer_information + if info.get("global_node_num") is not None + } + if len(metadata_values) > 1: + raise ValueError( + "Hugging Face trainers report inconsistent global_node_num values" + ) + if metadata_values: + return metadata_values.pop() + + return sum(int(info["features_num"]) for info in trainer_information) + + def run_fedgraph(args: attridict) -> None: """ Run the training process for the specified task. @@ -206,6 +266,7 @@ def run_NC(args: attridict, data: Any = None) -> None: in_com_train_node_local_indexes=in_com_train_node_local_indexes, in_com_test_node_local_indexes=in_com_test_node_local_indexes, n_trainer=args.n_trainer, + class_num=class_num, args=args, ) @@ -289,8 +350,6 @@ def get_memory_usage(self): Trainer.remote( # type: ignore rank=i, args_hidden=args_hidden, - # global_node_num=len(features), - # class_num=class_num, device=device, args=args, local_node_index=split_node_indexes[i], @@ -305,6 +364,8 @@ def get_memory_usage(self): features=features[split_node_indexes[i]], idx_train=in_com_train_node_local_indexes[i], idx_test=in_com_test_node_local_indexes[i], + global_node_num=len(features), + class_num=class_num, ) for i in range(args.n_trainer) ] @@ -314,8 +375,16 @@ def get_memory_usage(self): ] # Extract necessary details from trainer information - global_node_num = sum([info["features_num"] for info in trainer_information]) - class_num = max([info["label_num"] for info in trainer_information]) + global_node_num = _resolve_nc_global_node_num( + args.use_huggingface, + trainer_information, + None if args.use_huggingface else len(features), + ) + class_num = _resolve_nc_class_num( + args.use_huggingface, + trainer_information, + None if args.use_huggingface else class_num, + ) feature_shape = trainer_information[0]["feature_shape"] train_data_weights = [ @@ -882,6 +951,8 @@ def __init__(self, *args: Any, **kwds: Any): features=features[split_node_indexes[i]], idx_train=in_com_train_node_local_indexes[i], idx_test=in_com_test_node_local_indexes[i], + global_node_num=len(features), + class_num=class_num, ) for i in range(args.n_trainer) ] @@ -891,8 +962,16 @@ def __init__(self, *args: Any, **kwds: Any): ray.get(trainers[i].get_info.remote()) for i in range(len(trainers)) ] - global_node_num = sum([info["features_num"] for info in trainer_information]) - class_num = max([info["label_num"] for info in trainer_information]) + global_node_num = _resolve_nc_global_node_num( + args.use_huggingface, + trainer_information, + None if args.use_huggingface else len(features), + ) + class_num = _resolve_nc_class_num( + args.use_huggingface, + trainer_information, + None if args.use_huggingface else class_num, + ) train_data_weights = [ info["len_in_com_train_node_local_indexes"] for info in trainer_information @@ -1074,6 +1153,7 @@ def run_NC_lowrank(args: attridict, data: Any = None) -> None: in_com_train_node_local_indexes=in_com_train_node_local_indexes, in_com_test_node_local_indexes=in_com_test_node_local_indexes, n_trainer=args.n_trainer, + class_num=class_num, args=args, ) @@ -1131,6 +1211,8 @@ def __init__(self, *args: Any, **kwds: Any): features=features[split_node_indexes[i]], idx_train=in_com_train_node_local_indexes[i], idx_test=in_com_test_node_local_indexes[i], + global_node_num=len(features), + class_num=class_num, ) for i in range(args.n_trainer) ] @@ -1140,8 +1222,16 @@ def __init__(self, *args: Any, **kwds: Any): ray.get(trainers[i].get_info.remote()) for i in range(len(trainers)) ] - global_node_num = sum([info["features_num"] for info in trainer_information]) - class_num = max([info["label_num"] for info in trainer_information]) + global_node_num = _resolve_nc_global_node_num( + args.use_huggingface, + trainer_information, + None if args.use_huggingface else len(features), + ) + class_num = _resolve_nc_class_num( + args.use_huggingface, + trainer_information, + None if args.use_huggingface else class_num, + ) train_data_weights = [ info["len_in_com_train_node_local_indexes"] for info in trainer_information diff --git a/fedgraph/trainer_class.py b/fedgraph/trainer_class.py index 069eb21..d1ce18c 100644 --- a/fedgraph/trainer_class.py +++ b/fedgraph/trainer_class.py @@ -1,6 +1,7 @@ import logging import random import time +import warnings from io import BytesIO from typing import Any, Dict, List, Union @@ -13,6 +14,7 @@ import torch.nn.functional as F import torch_geometric from huggingface_hub import hf_hub_download +from huggingface_hub.errors import EntryNotFoundError from torch_geometric.data import Data from torch_geometric.loader import NeighborLoader from torchmetrics.functional.retrieval import retrieval_auroc @@ -40,10 +42,15 @@ def load_trainer_data_from_hugging_face(trainer_id, args): repo_name = f"FedGraph/fedgraph_{args.dataset}_{args.n_trainer}trainer_{args.num_hops}hop_iid_beta_{args.iid_beta}_trainer_id_{trainer_id}" - def download_and_load_tensor(file_name): - file_path = hf_hub_download( - repo_id=repo_name, repo_type="dataset", filename=file_name - ) + def download_and_load_tensor(file_name, optional=False): + try: + file_path = hf_hub_download( + repo_id=repo_name, repo_type="dataset", filename=file_name + ) + except EntryNotFoundError: + if optional: + return None + raise with open(file_path, "rb") as f: buffer = BytesIO(f.read()) tensor = torch.load(buffer, weights_only=False) @@ -61,6 +68,14 @@ def download_and_load_tensor(file_name): features = download_and_load_tensor("features.pt") in_com_train_node_local_indexes = download_and_load_tensor("idx_train.pt") in_com_test_node_local_indexes = download_and_load_tensor("idx_test.pt") + global_node_num = download_and_load_tensor("global_node_num.pt", optional=True) + class_num = download_and_load_tensor("class_num.pt", optional=True) + if global_node_num is None or class_num is None: + warnings.warn( + "Hugging Face trainer data does not contain global metadata; " + "falling back to inference for compatibility with existing repositories.", + UserWarning, + ) return ( local_node_index, communicate_node_global_index, @@ -70,6 +85,8 @@ def download_and_load_tensor(file_name): features, in_com_train_node_local_indexes, in_com_test_node_local_indexes, + global_node_num, + class_num, ) @@ -135,6 +152,8 @@ def __init__( features: torch.Tensor = None, idx_train: torch.Tensor = None, idx_test: torch.Tensor = None, + global_node_num: int = None, + class_num: int = None, ): # from gnn_models import GCN_Graph_Classification torch.manual_seed(rank) @@ -157,6 +176,8 @@ def __init__( features, idx_train, idx_test, + global_node_num, + class_num, ) = load_trainer_data_from_hugging_face(rank, args) self.rank = rank # rank = trainer ID @@ -187,22 +208,30 @@ def __init__( self.args = args self.model = None self.optimizer = None - self.global_node_num = None - self.class_num = None + self.global_node_num = ( + int(global_node_num.item()) + if isinstance(global_node_num, torch.Tensor) + else global_node_num + ) + self.class_num = ( + int(class_num.item()) if isinstance(class_num, torch.Tensor) else class_num + ) self.feature_aggregation = None if self.args.method == "FedAvg": # print("Loading feature as the feature aggregation for fedavg method") self.feature_aggregation = self.features def get_info(self): - # assert self.train_labels.numel() > 0, "train_labels is empty" - # assert self.test_labels.numel() > 0, "test_labels is empty" + label_nums = [ + int(labels.max().item()) + 1 + for labels in (self.train_labels, self.test_labels) + if labels.numel() > 0 + ] return { "features_num": len(self.features), - "label_num": max( - self.train_labels.max().item(), self.test_labels.max().item() - ) - + 1, + "label_num": max(label_nums, default=None), + "global_node_num": self.global_node_num, + "class_num": self.class_num, "feature_shape": self.features.shape[1], "len_in_com_train_node_local_indexes": len(self.idx_train), "len_in_com_test_node_local_indexes": len(self.idx_test), @@ -320,14 +349,14 @@ def get_local_feature_sum(self) -> torch.Tensor: normalized_sum : torch.Tensor The normalized sum of features of 1-hop neighbors for each node """ - # Use global_node_num if available, otherwise infer from communicate_node_index - global_node_num = getattr(self, 'global_node_num', None) - if global_node_num is None: - global_node_num = self.communicate_node_index.max().item() + 1 - + if self.global_node_num is None: + raise RuntimeError( + "Trainer model metadata must be initialized before feature aggregation" + ) + # Create a large matrix with known local node features new_feature_for_trainer = torch.zeros( - global_node_num, self.features.shape[1] + self.global_node_num, self.features.shape[1] ).to(self.device) new_feature_for_trainer[self.local_node_index] = self.features diff --git a/fedgraph/utils_nc.py b/fedgraph/utils_nc.py index ee50409..1acf2f7 100644 --- a/fedgraph/utils_nc.py +++ b/fedgraph/utils_nc.py @@ -10,8 +10,7 @@ import scipy.sparse as sp import torch import torch_geometric - -# from huggingface_hub import HfApi, HfFolder, hf_hub_download, upload_file +from huggingface_hub import HfApi, get_token def normalize(mx: sp.csc_matrix) -> sp.csr_matrix: @@ -467,10 +466,12 @@ def save_trainer_data_to_hugging_face( features, in_com_train_node_local_indexes, in_com_test_node_local_indexes, + global_node_num, + class_num, args, ): repo_name = f"FedGraph/fedgraph_{args.dataset}_{args.n_trainer}trainer_{args.num_hops}hop_iid_beta_{args.iid_beta}_trainer_id_{trainer_id}" - user = HfFolder.get_token() + user = get_token() api = HfApi() try: @@ -501,6 +502,8 @@ def save_tensor_to_hf(tensor, file_name): save_tensor_to_hf(features, "features.pt") save_tensor_to_hf(in_com_train_node_local_indexes, "idx_train.pt") save_tensor_to_hf(in_com_test_node_local_indexes, "idx_test.pt") + save_tensor_to_hf(torch.tensor(global_node_num), "global_node_num.pt") + save_tensor_to_hf(torch.tensor(class_num), "class_num.pt") print(f"Uploaded data for trainer {trainer_id}") @@ -514,8 +517,10 @@ def save_all_trainers_data( in_com_train_node_local_indexes, in_com_test_node_local_indexes, n_trainer, + class_num, args, ): + global_node_num = len(features) for i in range(n_trainer): save_trainer_data_to_hugging_face( trainer_id=i, @@ -531,5 +536,7 @@ def save_all_trainers_data( features=features[split_node_indexes[i]], in_com_train_node_local_indexes=in_com_train_node_local_indexes[i], in_com_test_node_local_indexes=in_com_test_node_local_indexes[i], + global_node_num=global_node_num, + class_num=class_num, args=args, ) diff --git a/tests/unit/test_federated_methods.py b/tests/unit/test_federated_methods.py index c7fdd0b..251cb20 100644 --- a/tests/unit/test_federated_methods.py +++ b/tests/unit/test_federated_methods.py @@ -5,6 +5,8 @@ import attridict from fedgraph.federated_methods import ( + _resolve_nc_class_num, + _resolve_nc_global_node_num, run_fedgraph, run_fedgraph_enhanced, run_NC, @@ -13,6 +15,71 @@ ) +class TestResolveNCClassNum: + def test_uses_authoritative_loaded_class_num(self): + trainer_information = [{"label_num": 2}, {"label_num": None}] + + assert _resolve_nc_class_num(False, trainer_information, 7) == 7 + + def test_infers_huggingface_class_num_from_nonempty_trainers(self): + trainer_information = [ + {"label_num": None}, + {"label_num": 3}, + {"label_num": 5}, + ] + + assert _resolve_nc_class_num(True, trainer_information) == 5 + + def test_uses_huggingface_class_num_metadata(self): + trainer_information = [ + {"class_num": 7, "label_num": 2}, + {"class_num": 7, "label_num": 5}, + ] + + assert _resolve_nc_class_num(True, trainer_information) == 7 + + def test_rejects_inconsistent_huggingface_class_num_metadata(self): + trainer_information = [{"class_num": 3}, {"class_num": 4}] + + with pytest.raises(ValueError, match="inconsistent class_num"): + _resolve_nc_class_num(True, trainer_information) + + def test_rejects_huggingface_data_without_labels(self): + trainer_information = [{"label_num": None}, {"label_num": None}] + + with pytest.raises(ValueError, match="all train and test label tensors are empty"): + _resolve_nc_class_num(True, trainer_information) + + +class TestResolveNCGlobalNodeNum: + def test_uses_authoritative_loaded_node_count(self): + trainer_information = [{"features_num": 40}, {"features_num": 50}] + + assert _resolve_nc_global_node_num(False, trainer_information, 100) == 100 + + def test_infers_huggingface_node_count_from_owned_features(self): + trainer_information = [{"features_num": 40}, {"features_num": 60}] + + assert _resolve_nc_global_node_num(True, trainer_information) == 100 + + def test_uses_huggingface_global_node_num_metadata(self): + trainer_information = [ + {"global_node_num": 100, "features_num": 40}, + {"global_node_num": 100, "features_num": 50}, + ] + + assert _resolve_nc_global_node_num(True, trainer_information) == 100 + + def test_rejects_inconsistent_huggingface_global_node_num_metadata(self): + trainer_information = [ + {"global_node_num": 100}, + {"global_node_num": 101}, + ] + + with pytest.raises(ValueError, match="inconsistent global_node_num"): + _resolve_nc_global_node_num(True, trainer_information) + + class TestRunFedgraph: """Test run_fedgraph main orchestration function.""" @@ -426,4 +493,4 @@ def test_data_loading_logic(self, mock_data_loader): run_fedgraph(args) mock_data_loader.assert_not_called() - mock_run_nc.assert_called_once_with(args, None) \ No newline at end of file + mock_run_nc.assert_called_once_with(args, None) diff --git a/tests/unit/test_trainer_class.py b/tests/unit/test_trainer_class.py index 05c74e3..88e81d9 100644 --- a/tests/unit/test_trainer_class.py +++ b/tests/unit/test_trainer_class.py @@ -3,6 +3,7 @@ import numpy as np from unittest.mock import Mock, patch, MagicMock import ray +from huggingface_hub.errors import EntryNotFoundError from fedgraph.trainer_class import ( Trainer_General, @@ -34,7 +35,9 @@ def test_load_trainer_data_success(self, mock_torch_load, mock_open, mock_hf_dow torch.randn(20), # test_labels torch.randn(100, 10), # features torch.randn(80), # in_com_train_node_local_indexes - torch.randn(20) # in_com_test_node_local_indexes + torch.randn(20), # in_com_test_node_local_indexes + torch.tensor(100), # global_node_num + torch.tensor(3), # class_num ] mock_torch_load.side_effect = mock_tensors @@ -46,11 +49,11 @@ def test_load_trainer_data_success(self, mock_torch_load, mock_open, mock_hf_dow result = load_trainer_data_from_hugging_face(trainer_id=0, args=args) - assert len(result) == 8 + assert len(result) == 10 assert all(isinstance(tensor, torch.Tensor) for tensor in result) # Verify calls - assert mock_hf_download.call_count == 8 + assert mock_hf_download.call_count == 10 expected_repo = "FedGraph/fedgraph_cora_5trainer_2hop_iid_beta_0.5_trainer_id_0" mock_hf_download.assert_any_call( repo_id=expected_repo, @@ -58,6 +61,31 @@ def test_load_trainer_data_success(self, mock_torch_load, mock_open, mock_hf_dow filename="local_node_index.pt" ) + @patch('fedgraph.trainer_class.hf_hub_download') + @patch('builtins.open') + @patch('torch.load') + def test_load_existing_repo_without_global_metadata( + self, mock_torch_load, mock_open, mock_hf_download + ): + mock_file = Mock() + mock_file.read.return_value = b"test_tensor_data" + mock_open.return_value.__enter__.return_value = mock_file + mock_torch_load.side_effect = [torch.tensor([i]) for i in range(8)] + + def download_side_effect(*, filename, **kwargs): + if filename in {"global_node_num.pt", "class_num.pt"}: + raise EntryNotFoundError("missing optional metadata") + return "/tmp/test_file.pt" + + mock_hf_download.side_effect = download_side_effect + args = Mock(dataset="cora", n_trainer=5, num_hops=2, iid_beta=0.5) + + with pytest.warns(UserWarning, match="falling back to inference"): + result = load_trainer_data_from_hugging_face(trainer_id=0, args=args) + + assert len(result) == 10 + assert result[-2:] == (None, None) + class TestTrainerGeneral: """Test Trainer_General class.""" @@ -102,13 +130,17 @@ def test_trainer_init_with_data(self, mock_load_data): test_labels=self.test_labels, features=self.features, idx_train=self.idx_train, - idx_test=self.idx_test + idx_test=self.idx_test, + global_node_num=100, + class_num=3, ) assert trainer.rank == self.rank assert trainer.device == self.device assert trainer.args_hidden == self.args_hidden assert trainer.local_step == self.args.local_step + assert trainer.global_node_num == 100 + assert trainer.class_num == 3 assert torch.equal(trainer.local_node_index, self.local_node_index.to(self.device)) assert trainer.feature_aggregation is not None # Should be set for FedAvg mock_load_data.assert_not_called() @@ -124,7 +156,9 @@ def test_trainer_init_without_data(self, mock_load_data): self.test_labels, self.features, self.idx_train, - self.idx_test + self.idx_test, + torch.tensor(100), + torch.tensor(3), ) trainer = Trainer_General( @@ -135,6 +169,8 @@ def test_trainer_init_without_data(self, mock_load_data): ) assert trainer.rank == self.rank + assert trainer.global_node_num == 100 + assert trainer.class_num == 3 mock_load_data.assert_called_once_with(self.rank, self.args) def test_get_info(self): @@ -161,6 +197,38 @@ def test_get_info(self): assert info["features_num"] == len(self.features) expected_label_num = max(self.train_labels.max().item(), self.test_labels.max().item()) + 1 assert info["label_num"] == expected_label_num + + @pytest.mark.parametrize( + ("train_labels", "test_labels", "expected_label_num"), + [ + (torch.tensor([], dtype=torch.long), torch.tensor([0, 2]), 3), + (torch.tensor([0, 1]), torch.tensor([], dtype=torch.long), 2), + ( + torch.tensor([], dtype=torch.long), + torch.tensor([], dtype=torch.long), + None, + ), + ], + ) + def test_get_info_handles_empty_labels( + self, train_labels, test_labels, expected_label_num + ): + trainer = Trainer_General( + rank=self.rank, + args_hidden=self.args_hidden, + device=self.device, + args=self.args, + local_node_index=self.local_node_index, + communicate_node_index=self.communicate_node_index, + adj=self.adj, + train_labels=train_labels, + test_labels=test_labels, + features=self.features, + idx_train=torch.arange(len(train_labels)), + idx_test=torch.arange(len(test_labels)), + ) + + assert trainer.get_info()["label_num"] == expected_label_num @patch('fedgraph.trainer_class.GCN') @patch('fedgraph.trainer_class.GCN_arxiv') @@ -224,6 +292,7 @@ def test_update_params(self): def test_get_local_feature_sum(self): """Test get_local_feature_sum method.""" + local_features = self.features[self.local_node_index] trainer = Trainer_General( rank=self.rank, args_hidden=self.args_hidden, @@ -234,11 +303,12 @@ def test_get_local_feature_sum(self): adj=self.adj, train_labels=self.train_labels, test_labels=self.test_labels, - features=self.features, + features=local_features, idx_train=self.idx_train, idx_test=self.idx_test ) - + trainer.global_node_num = len(self.features) + # Mock the get_1hop_feature_sum function with patch('fedgraph.trainer_class.get_1hop_feature_sum') as mock_get_1hop: mock_get_1hop.return_value = torch.randn(100, 50) @@ -247,6 +317,30 @@ def test_get_local_feature_sum(self): assert isinstance(result, torch.Tensor) mock_get_1hop.assert_called_once() + feature_matrix = mock_get_1hop.call_args.args[0] + assert feature_matrix.shape == ( + trainer.global_node_num, + local_features.shape[1], + ) + + def test_get_local_feature_sum_requires_global_node_num(self): + trainer = Trainer_General( + rank=self.rank, + args_hidden=self.args_hidden, + device=self.device, + args=self.args, + local_node_index=self.local_node_index, + communicate_node_index=self.communicate_node_index, + adj=self.adj, + train_labels=self.train_labels, + test_labels=self.test_labels, + features=self.features[self.local_node_index], + idx_train=self.idx_train, + idx_test=self.idx_test, + ) + + with pytest.raises(RuntimeError, match="metadata must be initialized"): + trainer.get_local_feature_sum() def test_load_feature_aggregation(self): """Test load_feature_aggregation method.""" @@ -662,4 +756,4 @@ def test_trainer_general_full_workflow(self, mock_gcn_class): assert isinstance(params, tuple) # Test rank retrieval - assert trainer.get_rank() == rank \ No newline at end of file + assert trainer.get_rank() == rank diff --git a/tests/unit/test_utils_nc.py b/tests/unit/test_utils_nc.py index 8cbb520..11d8b46 100644 --- a/tests/unit/test_utils_nc.py +++ b/tests/unit/test_utils_nc.py @@ -14,7 +14,9 @@ community_partition_non_iid, get_in_comm_indexes, get_1hop_feature_sum, - increment_dir + increment_dir, + save_all_trainers_data, + save_trainer_data_to_hugging_face, ) @@ -78,7 +80,7 @@ def test_intersect1d_all_same(self): result = intersect1d(t1, t2) assert torch.equal(result.sort().values, torch.tensor([1, 2, 3])) - + def test_setdiff1d_basic(self): """Test setdiff1d function with basic tensors.""" t1 = torch.tensor([1, 2, 3, 4]) @@ -109,6 +111,70 @@ def test_setdiff1d_complete_difference(self): assert torch.equal(result.sort().values, expected.sort().values) +class TestSaveTrainerDataToHuggingFace: + @patch("fedgraph.utils_nc.get_token") + @patch("fedgraph.utils_nc.HfApi") + def test_uploads_global_metadata(self, mock_hf_api, mock_get_token): + api = mock_hf_api.return_value + mock_get_token.return_value = "token" + args = Mock( + dataset="cora", + n_trainer=2, + num_hops=1, + iid_beta=0.5, + ) + + save_trainer_data_to_hugging_face( + trainer_id=0, + local_node_index=torch.tensor([0, 1]), + communicate_node_global_index=torch.tensor([0, 1, 2]), + global_edge_index_client=torch.tensor([[0, 1], [1, 2]]), + train_labels=torch.tensor([0]), + test_labels=torch.tensor([1]), + features=torch.randn(2, 4), + in_com_train_node_local_indexes=torch.tensor([0]), + in_com_test_node_local_indexes=torch.tensor([1]), + global_node_num=3, + class_num=2, + args=args, + ) + + uploaded_files = { + call.kwargs["path_in_repo"] for call in api.upload_file.call_args_list + } + assert "global_node_num.pt" in uploaded_files + assert "class_num.pt" in uploaded_files + + @patch("fedgraph.utils_nc.save_trainer_data_to_hugging_face") + def test_bulk_save_passes_consistent_global_metadata(self, mock_save_trainer): + features = torch.randn(4, 3) + labels = torch.tensor([0, 1, 2, 1]) + + save_all_trainers_data( + split_node_indexes=[torch.tensor([0, 1]), torch.tensor([2, 3])], + communicate_node_global_indexes=[ + torch.tensor([0, 1]), + torch.tensor([2, 3]), + ], + global_edge_indexes_clients=[ + torch.tensor([[0], [1]]), + torch.tensor([[2], [3]]), + ], + labels=labels, + features=features, + in_com_train_node_local_indexes=[torch.tensor([0]), torch.tensor([0])], + in_com_test_node_local_indexes=[torch.tensor([1]), torch.tensor([1])], + n_trainer=2, + class_num=3, + args=Mock(), + ) + + assert mock_save_trainer.call_count == 2 + for call in mock_save_trainer.call_args_list: + assert call.kwargs["global_node_num"] == 4 + assert call.kwargs["class_num"] == 3 + + class TestLabelDirichletPartition: """Test label_dirichlet_partition function.""" @@ -407,4 +473,4 @@ def test_graph_communication_workflow(self): # Verify all trainers have some data for i in range(n_trainers): assert i in communicate_indexes - assert i in edge_indexes \ No newline at end of file + assert i in edge_indexes From 5d073055cc3d88aea7b3945821dd091f9bcdd3c0 Mon Sep 17 00:00:00 2001 From: Hongyu Chen Date: Wed, 17 Jun 2026 15:35:52 -0400 Subject: [PATCH 2/5] Fix NC server training and encrypted aggregation --- fedgraph/federated_methods.py | 73 +----- fedgraph/server_class.py | 165 +++++++++----- .../integration/test_fedgraph_integration.py | 13 +- tests/unit/test_server_class.py | 211 +++++++++++++++--- 4 files changed, 299 insertions(+), 163 deletions(-) diff --git a/fedgraph/federated_methods.py b/fedgraph/federated_methods.py index 98532b1..cdca011 100644 --- a/fedgraph/federated_methods.py +++ b/fedgraph/federated_methods.py @@ -536,68 +536,10 @@ def get_memory_usage(self): print("global_rounds", args.global_rounds) global_acc_list = [] for i in range(args.global_rounds): - # Pure training phase - forward + gradient descent only - pure_training_start = time.time() - - # Execute only training (forward + gradient descent) - train_refs = [trainer.train.remote(i) for trainer in server.trainers] - ray.get(train_refs) - - pure_training_end = time.time() - round_training_time = pure_training_end - pure_training_start + round_stats = server.train(i) + round_training_time = round_stats["training_time"] + round_comm_time = round_stats["communication_time"] total_pure_training_time += round_training_time - - # Communication phase - parameter aggregation and broadcast - comm_start = time.time() - - if args.use_encryption: - # Encrypted parameter aggregation - encrypted_params = [ - trainer.get_encrypted_params.remote() for trainer in server.trainers - ] - params_list = ray.get(encrypted_params) - - # Server-side aggregation - aggregated_params, metadata, _ = server.aggregate_encrypted_params( - params_list - ) - - # Distribute aggregated parameters - decrypt_refs = [ - trainer.load_encrypted_params.remote((aggregated_params, metadata), i) - for trainer in server.trainers - ] - ray.get(decrypt_refs) - else: - # Regular parameter aggregation - # Get parameters from all trainers - params_refs = [trainer.get_params.remote() for trainer in server.trainers] - param_results = ray.get(params_refs) - - # Aggregate parameters on server - avoid in-place operations - server.zero_params() - - # Move model to CPU for aggregation - server.model = server.model.to("cpu") - - # Aggregate parameters safely - for param_result in param_results: - for p, mp in zip(param_result, server.model.parameters()): - mp.data = mp.data + p.cpu() - - # Move back to device and average - server.model = server.model.to(server.device) - - # Average the parameters - with torch.no_grad(): - for p in server.model.parameters(): - p.data = p.data / len(server.trainers) - - # Broadcast updated parameters to all trainers - server.broadcast_params(i) - - comm_end = time.time() - round_comm_time = comm_end - comm_start total_communication_time += round_comm_time # Testing phase (not counted in training or communication time) @@ -613,10 +555,9 @@ def get_memory_usage(self): f"Round {i+1}: Training Time = {round_training_time:.2f}s, Communication Time = {round_comm_time:.2f}s" ) - model_size_mb = server.get_model_size() / (1024 * 1024) monitor.add_train_comm_cost( - upload_mb=model_size_mb * args.n_trainer, - download_mb=model_size_mb * args.n_trainer, + upload_mb=round_stats["upload_size"] / (1024 * 1024), + download_mb=round_stats["download_size"] / (1024 * 1024), ) monitor.train_time_end() total_time = time.time() - training_start @@ -670,10 +611,6 @@ def get_memory_usage(self): else: training_upload = training_download = 0 training_comm_cost = training_upload + training_download - monitor.add_train_comm_cost( - upload_mb=training_upload, - download_mb=training_download, - ) print("\nTraining Phase Metrics:") print( f"Total Training Time: {total_pure_training_time:.2f} seconds" diff --git a/fedgraph/server_class.py b/fedgraph/server_class.py index b3fdf9f..ae639dc 100644 --- a/fedgraph/server_class.py +++ b/fedgraph/server_class.py @@ -130,28 +130,6 @@ def zero_params(self) -> None: for p in self.model.parameters(): p.zero_() - def prepare_params_for_encryption(self, params): - processed_params = [] - metadata = [] - - for param in params: - param_min = param.min() - param_max = param.max() - param_range = param_max - param_min - - # handle division by 0 - if param_range == 0: - normalized = param - param_min - else: - normalized = (param - param_min) / param_range - - scaled = normalized * 1000 - - processed_params.append(scaled) - metadata.append({"min": param_min, "range": param_range}) - - return processed_params, metadata - def aggregate_encrypted_feature_sums(self, encrypted_sums): aggregation_start = time.time() @@ -167,43 +145,82 @@ def aggregate_encrypted_feature_sums(self, encrypted_sums): def aggregate_encrypted_params(self, encrypted_params_list): aggregation_start = time.time() - first_params, metadata = encrypted_params_list[0] + if not encrypted_params_list: + raise ValueError("encrypted_params_list must not be empty") + + first_params, _ = encrypted_params_list[0] n_layers = len(first_params) + if n_layers == 0: + raise ValueError("encrypted parameter payload must contain at least one layer") + + layer_shapes = [] + layer_scale_values = [[] for _ in range(n_layers)] + validated_params = [] + + for trainer_idx, (trainer_params, trainer_metadata) in enumerate( + encrypted_params_list + ): + if len(trainer_params) != n_layers: + raise ValueError("all trainers must provide the same number of layers") + if len(trainer_metadata) != n_layers: + raise ValueError("metadata length must match encrypted parameter layers") + + validated_metadata = [] + for layer_idx, metadata in enumerate(trainer_metadata): + shape = metadata.get("shape") + if shape is None: + raise ValueError("encrypted parameter metadata must include shape") + normalized_shape = tuple(shape) + + try: + scale = float(metadata["scale"]) + except (KeyError, TypeError, ValueError) as exc: + raise ValueError( + "encrypted parameter metadata must include numeric scale" + ) from exc + if not np.isfinite(scale) or scale <= 0: + raise ValueError("encrypted parameter scale must be positive") + + if trainer_idx == 0: + layer_shapes.append(shape) + elif normalized_shape != tuple(layer_shapes[layer_idx]): + raise ValueError( + "encrypted parameter shapes must match across trainers" + ) + + layer_scale_values[layer_idx].append(scale) + validated_metadata.append({"shape": shape, "scale": scale}) + + validated_params.append((trainer_params, validated_metadata)) + + aggregation_metadata = [ + {"shape": layer_shapes[layer_idx], "scale": max(layer_scale_values[layer_idx])} + for layer_idx in range(n_layers) + ] - # each layer aggregated_params = [] for layer_idx in range(n_layers): - agg_layer = ts.ckks_vector_from( - self.he_context, encrypted_params_list[0][0][layer_idx] - ) + common_scale = aggregation_metadata[layer_idx]["scale"] + agg_layer = None - for trainer_params, _ in encrypted_params_list[1:]: + for trainer_params, trainer_metadata in validated_params: + trainer_scale = trainer_metadata[layer_idx]["scale"] next_layer = ts.ckks_vector_from( self.he_context, trainer_params[layer_idx] ) - agg_layer += next_layer + next_layer *= common_scale / trainer_scale + + if agg_layer is None: + agg_layer = next_layer + else: + agg_layer += next_layer # average - agg_layer *= 1.0 / self.num_of_trainers + agg_layer *= 1.0 / len(validated_params) aggregated_params.append(agg_layer.serialize()) aggregation_time = time.time() - aggregation_start - return aggregated_params, metadata, aggregation_time - - def get_encrypted_params(self): - params = [p.data.cpu().detach() for p in self.model.parameters()] - - # normalize and scale - processed_params, metadata = self.prepare_params_for_encryption(params) - - encrypted_params = [] - for param in processed_params: - param_list = param.flatten().tolist() - - encrypted = ts.ckks_vector(self.he_context, param_list).serialize() - encrypted_params.append(encrypted) - - return encrypted_params, metadata + return aggregated_params, aggregation_metadata, aggregation_time @torch.no_grad() def train( @@ -211,17 +228,22 @@ def train( current_global_epoch: int, sampling_type: str = "random", sample_ratio: float = 1, - ) -> None: + ) -> dict: """ - Training round which performs aggregating parameters from sampled trainers (by index), - updating the central model, and then broadcasting the updated parameters - back to all trainers. + Run one federated training round and return timing and communication statistics. Parameters ---------- current_global_epoch : int The current global epoch number during the federated learning process. + + Returns + ------- + dict + Per-round training time, communication time, and transfer sizes. """ + training_start = time.time() + if self.use_encryption: if not hasattr(self, "aggregation_stats"): self.aggregation_stats = [] @@ -230,6 +252,9 @@ def train( trainer.train.remote(current_global_epoch) for trainer in self.trainers ] ray.get(train_refs) + training_time = time.time() - training_start + + communication_start = time.time() encryption_start = time.time() print("Starting encrypted parameter aggregation...") encrypted_params = [ @@ -238,7 +263,6 @@ def train( # Wait for all trainers and collect parameters params_list = [] - encryption_times = [] enc_sizes = [] while encrypted_params: ready, encrypted_params = ray.wait(encrypted_params) @@ -258,7 +282,6 @@ def train( agg_size = sum(len(p) for p in aggregated_params) # Distribute back to trainers - decryption_start = time.time() decrypt_refs = [ trainer.load_encrypted_params.remote( (aggregated_params, metadata), current_global_epoch @@ -266,24 +289,33 @@ def train( for trainer in self.trainers ] decryption_times = ray.get(decrypt_refs) + communication_time = time.time() - communication_start round_metrics = { + "training_time": training_time, + "communication_time": communication_time, "encryption_time": encryption_time, "decryption_times": decryption_times, "aggregation_time": agg_time, "upload_size": sum(enc_sizes), "download_size": agg_size * len(self.trainers), + "num_trainers": len(self.trainers), } self.aggregation_stats.append(round_metrics) + return round_metrics else: # normal training logic # print( # f"Training round: {current_global_epoch}, sampling rate: {sample_ratio}" # ) assert 0 < sample_ratio <= 1, "Sample ratio must be between 0 and 1" + if sampling_type not in {"random", "uniform"}: + raise ValueError("sampling_type must be either 'random' or 'uniform'") num_samples = int(self.num_of_trainers * sample_ratio) - if sampling_type == "random": + if sample_ratio == 1: + selected_trainers_indices = list(range(self.num_of_trainers)) + elif sampling_type == "random": selected_trainers_indices = random.sample( range(self.num_of_trainers), num_samples ) @@ -298,11 +330,14 @@ def train( for i in range(num_samples) ] - else: - raise ValueError("sampling_type must be either 'random' or 'uniform'") - - for trainer_idx in selected_trainers_indices: + train_refs = [ self.trainers[trainer_idx].train.remote(current_global_epoch) + for trainer_idx in selected_trainers_indices + ] + ray.get(train_refs) + training_time = time.time() - training_start + + communication_start = time.time() params = [ self.trainers[trainer_idx].get_params.remote() @@ -328,6 +363,15 @@ def train( p /= num_samples self.broadcast_params(current_global_epoch) + communication_time = time.time() - communication_start + model_size = self.get_model_size() + return { + "training_time": training_time, + "communication_time": communication_time, + "upload_size": model_size * num_samples, + "download_size": model_size * self.num_of_trainers, + "num_trainers": num_samples, + } def broadcast_params(self, current_global_epoch: int) -> None: """ @@ -338,10 +382,13 @@ def broadcast_params(self, current_global_epoch: int) -> None: current_global_epoch : int The current global epoch number during the federated learning process. """ - for trainer in self.trainers: + update_refs = [ trainer.update_params.remote( tuple(self.model.parameters()), current_global_epoch - ) # run in submit order + ) + for trainer in self.trainers + ] + ray.get(update_refs) def get_model_size(self) -> float: """Return total model parameter size in bytes (assumes float32).""" diff --git a/tests/integration/test_fedgraph_integration.py b/tests/integration/test_fedgraph_integration.py index bc09fb4..a3fc5f3 100644 --- a/tests/integration/test_fedgraph_integration.py +++ b/tests/integration/test_fedgraph_integration.py @@ -158,11 +158,13 @@ def test_nc_server_aggregation(self, mock_aggre_gcn): # Test parameter broadcast mock_model.state_dict.return_value = {'weight': torch.randn(64, 50)} - server.broadcast_params(current_global_epoch=1) + mock_model.parameters.return_value = [torch.randn(64, 50)] + with patch("fedgraph.server_class.ray.get"): + server.broadcast_params(current_global_epoch=1) # Verify all trainers received updates for trainer in trainers: - trainer.update_params.assert_called() + trainer.update_params.remote.assert_called() @pytest.mark.integration @@ -483,8 +485,9 @@ def test_data_trainer_server_interaction(self): trainer.update_params = Mock() mock_server_model.state_dict.return_value = {'weight': torch.randn(32, 20)} - server.broadcast_params(current_global_epoch=1) - trainer.update_params.assert_called() + with patch("fedgraph.server_class.ray.get"): + server.broadcast_params(current_global_epoch=1) + trainer.update_params.remote.assert_called() def test_monitor_integration(self): """Test monitoring system integration.""" @@ -564,4 +567,4 @@ def test_memory_efficiency(self): assert len(difference) > 0 # Clean up - del large_features, large_edge_index, intersection, difference \ No newline at end of file + del large_features, large_edge_index, intersection, difference diff --git a/tests/unit/test_server_class.py b/tests/unit/test_server_class.py index 5cf05a1..85bb29b 100644 --- a/tests/unit/test_server_class.py +++ b/tests/unit/test_server_class.py @@ -155,11 +155,181 @@ def test_broadcast_params(self): for trainer in server.trainers: trainer.update_params = Mock() - server.broadcast_params(current_global_epoch=1) + with patch("fedgraph.server_class.ray.get") as mock_ray_get: + server.broadcast_params(current_global_epoch=1) # Verify that all trainers received parameter updates for trainer in server.trainers: - trainer.update_params.assert_called_once() + trainer.update_params.remote.assert_called_once() + mock_ray_get.assert_called_once() + + def test_train_runs_complete_plaintext_round(self): + """Test local training, subset aggregation, and broadcast as one round.""" + server = Server.__new__(Server) + server.model = torch.nn.Linear(1, 1, bias=False) + server.device = torch.device("cpu") + server.use_encryption = False + server.num_of_trainers = 3 + + trainers = [] + param_refs = {} + parameter_values = [1.0, 100.0, 5.0] + for index, value in enumerate(parameter_values): + trainer = Mock() + trainer.train.remote.return_value = f"train-{index}" + trainer.get_params.remote.return_value = f"params-{index}" + trainer.update_params.remote.return_value = f"update-{index}" + param_refs[f"params-{index}"] = (torch.tensor([[value]]),) + trainers.append(trainer) + server.trainers = trainers + + def resolve(refs): + if isinstance(refs, list): + return [resolve(ref) for ref in refs] + return param_refs.get(refs, True) + + def wait_for_one(refs, **_kwargs): + return refs[:1], refs[1:] + + with patch("fedgraph.server_class.random.sample", return_value=[0, 2]), patch( + "fedgraph.server_class.ray.get", side_effect=resolve + ), patch( + "fedgraph.server_class.ray.wait", side_effect=wait_for_one + ), patch( + "fedgraph.server_class.time.time", side_effect=[10.0, 12.0, 12.0, 15.0] + ): + round_stats = server.train(4, sample_ratio=0.75) + + assert server.model.weight.item() == pytest.approx(3.0) + assert round_stats == { + "training_time": 2.0, + "communication_time": 3.0, + "upload_size": 8, + "download_size": 12, + "num_trainers": 2, + } + trainers[0].train.remote.assert_called_once_with(4) + trainers[1].train.remote.assert_not_called() + trainers[2].train.remote.assert_called_once_with(4) + for trainer in trainers: + trainer.update_params.remote.assert_called_once() + + def test_train_runs_complete_encrypted_round(self): + """Test encrypted training returns actual communication statistics.""" + server = Server.__new__(Server) + server.model = torch.nn.Linear(1, 1, bias=False) + server.device = torch.device("cpu") + server.use_encryption = True + server.num_of_trainers = 2 + server.aggregation_stats = [] + + trainers = [] + encrypted_refs = {} + decryption_refs = {} + metadata = [{"shape": torch.Size([1, 1]), "scale": 100.0}] + for index, payload in enumerate([b"one", b"two"]): + trainer = Mock() + trainer.train.remote.return_value = f"train-{index}" + trainer.get_encrypted_params.remote.return_value = f"encrypted-{index}" + trainer.load_encrypted_params.remote.return_value = f"decrypted-{index}" + encrypted_refs[f"encrypted-{index}"] = ([payload], metadata) + decryption_refs[f"decrypted-{index}"] = 0.1 + trainers.append(trainer) + server.trainers = trainers + server.aggregate_encrypted_params = Mock( + return_value=([b"avg"], metadata, 0.25) + ) + + def resolve(refs): + if isinstance(refs, list): + return [resolve(ref) for ref in refs] + if refs in encrypted_refs: + return encrypted_refs[refs] + return decryption_refs.get(refs, True) + + def wait_for_one(refs, **_kwargs): + return refs[:1], refs[1:] + + with patch( + "fedgraph.server_class.ray.get", side_effect=resolve + ), patch( + "fedgraph.server_class.ray.wait", side_effect=wait_for_one + ), patch( + "fedgraph.server_class.time.time", + side_effect=[10.0, 12.0, 12.0, 12.5, 13.0, 15.0], + ): + round_stats = server.train(4) + + assert round_stats["training_time"] == 2.0 + assert round_stats["communication_time"] == 3.0 + assert round_stats["encryption_time"] == 0.5 + assert round_stats["upload_size"] == 6 + assert round_stats["download_size"] == 6 + assert round_stats["num_trainers"] == 2 + assert server.aggregation_stats == [round_stats] + for trainer in trainers: + trainer.train.remote.assert_called_once_with(4) + trainer.load_encrypted_params.remote.assert_called_once() + + def test_aggregate_encrypted_params_rescales_to_common_layer_scale(self): + """Test encrypted params are rescaled before cross-trainer averaging.""" + server = Server.__new__(Server) + server.he_context = object() + + class FakeCKKSVector: + def __init__(self, value): + self.value = float(value) + + def __iadd__(self, other): + self.value += other.value + return self + + def __imul__(self, scalar): + self.value *= scalar + return self + + def serialize(self): + return self.value + + encrypted_params_list = [ + ([200.0], [{"shape": torch.Size([1]), "scale": 100.0}]), + ([4000.0], [{"shape": torch.Size([1]), "scale": 1000.0}]), + ] + + with patch( + "fedgraph.server_class.ts.ckks_vector_from", + side_effect=lambda _context, payload: FakeCKKSVector(payload), + ), patch("fedgraph.server_class.time.time", side_effect=[10.0, 13.0]): + aggregated_params, metadata, aggregation_time = ( + server.aggregate_encrypted_params(encrypted_params_list) + ) + + assert aggregated_params == [pytest.approx(3000.0)] + assert metadata == [{"shape": torch.Size([1]), "scale": 1000.0}] + assert aggregation_time == 3.0 + + def test_aggregate_encrypted_params_rejects_mismatched_shapes(self): + """Test encrypted averaging fails clearly for incompatible layers.""" + server = Server.__new__(Server) + + encrypted_params_list = [ + ([1.0], [{"shape": torch.Size([1]), "scale": 100.0}]), + ([2.0], [{"shape": torch.Size([2]), "scale": 100.0}]), + ] + + with pytest.raises(ValueError, match="shapes must match"): + server.aggregate_encrypted_params(encrypted_params_list) + + def test_aggregate_encrypted_params_rejects_invalid_scale(self): + """Test encrypted averaging requires positive per-layer scales.""" + server = Server.__new__(Server) + + encrypted_params_list = [ + ([1.0], [{"shape": torch.Size([1]), "scale": 0.0}]), + ] + + with pytest.raises(ValueError, match="scale must be positive"): + server.aggregate_encrypted_params(encrypted_params_list) def test_get_model_size(self): """Test get_model_size method.""" @@ -187,32 +357,6 @@ def test_get_model_size(self): assert isinstance(model_size, float) assert model_size > 0 - def test_prepare_params_for_encryption(self): - """Test prepare_params_for_encryption method.""" - with patch('fedgraph.server_class.AggreGCN') as mock_gcn: - mock_model = Mock() - mock_gcn.return_value = mock_model - mock_model.to.return_value = mock_model - - server = Server( - feature_dim=self.feature_dim, - args_hidden=self.args_hidden, - class_num=self.class_num, - device=self.device, - trainers=self.mock_trainers, - args=self.args - ) - - params = (torch.randn(10, 5), torch.randn(10)) - - result = server.prepare_params_for_encryption(params) - - assert isinstance(result, list) - assert len(result) == len(params) - for item in result: - assert isinstance(item, list) # Flattened and converted to list - - class TestServerGC: """Test Server_GC class for graph classification.""" @@ -433,6 +577,10 @@ def test_server_trainer_interaction(self, mock_gcn_class): 'layer1.weight': torch.randn(32, 50), 'layer1.bias': torch.randn(32) } + mock_model.parameters.return_value = [ + torch.randn(32, 50), + torch.randn(32), + ] # Create mock trainers trainers = [] @@ -457,11 +605,12 @@ def test_server_trainer_interaction(self, mock_gcn_class): ) # Test parameter broadcast - server.broadcast_params(current_global_epoch=1) + with patch("fedgraph.server_class.ray.get"): + server.broadcast_params(current_global_epoch=1) # Verify all trainers received updates for trainer in trainers: - trainer.update_params.assert_called_once() + trainer.update_params.remote.assert_called_once() # Test model size computation model_size = server.get_model_size() @@ -520,4 +669,4 @@ def test_server_gc_clustering_workflow(self): server.aggregate_clusterwise(trainer_clusters) # Verify workflow completed successfully - assert True \ No newline at end of file + assert True From 6213a859d589653bf7f36395f4440297118eacc6 Mon Sep 17 00:00:00 2001 From: Hongyu Chen Date: Wed, 17 Jun 2026 16:20:05 -0400 Subject: [PATCH 3/5] Fix duplicate trainer test metric appends --- fedgraph/trainer_class.py | 5 +---- tests/unit/test_trainer_class.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/fedgraph/trainer_class.py b/fedgraph/trainer_class.py index d1ce18c..5c335fd 100644 --- a/fedgraph/trainer_class.py +++ b/fedgraph/trainer_class.py @@ -650,10 +650,7 @@ def train(self, current_global_round: int) -> None: self.train_losses.append(loss_train) self.train_accs.append(acc_train) # print(f"acc_train: {acc_train}") - loss_test, acc_test = self.local_test() - self.test_losses.append(loss_test) - self.test_accs.append(acc_test) - # print(f"current round: {current_global_round}, acc_test: {acc_test}") + self.local_test() def local_test(self) -> list: """ diff --git a/tests/unit/test_trainer_class.py b/tests/unit/test_trainer_class.py index 88e81d9..af94659 100644 --- a/tests/unit/test_trainer_class.py +++ b/tests/unit/test_trainer_class.py @@ -392,8 +392,9 @@ def test_get_params(self): assert isinstance(params, tuple) mock_model.state_dict.assert_called_once() + @patch('fedgraph.trainer_class.test') @patch('fedgraph.trainer_class.train') - def test_train_method(self, mock_train_func): + def test_train_method(self, mock_train_func, mock_test_func): """Test train method.""" trainer = Trainer_General( rank=self.rank, @@ -420,12 +421,16 @@ def test_train_method(self, mock_train_func): self.args.batch_size = 0 # Ensure no batching for this test mock_train_func.return_value = (0.5, 0.85) # loss, accuracy + mock_test_func.return_value = (0.3, 0.9) # loss, accuracy trainer.train(current_global_round=1) - mock_train_func.assert_called() - assert len(trainer.train_losses) > 0 - assert len(trainer.train_accs) > 0 + assert mock_train_func.call_count == trainer.local_step + assert mock_test_func.call_count == trainer.local_step + assert len(trainer.train_losses) == trainer.local_step + assert len(trainer.train_accs) == trainer.local_step + assert len(trainer.test_losses) == trainer.local_step + assert len(trainer.test_accs) == trainer.local_step @patch('fedgraph.trainer_class.test') def test_local_test(self, mock_test_func): From 045ac47ed1b4d4553b5398cd26ecf87862924863 Mon Sep 17 00:00:00 2001 From: Hongyu Chen Date: Fri, 19 Jun 2026 18:05:57 -0400 Subject: [PATCH 4/5] Fix NC hop validation and encrypted feature sharing --- README.md | 2 +- .../save_graph_node_classification.py | 2 +- fedgraph/federated_methods.py | 47 ++++++++++++++++--- fedgraph/server_class.py | 26 ++++++++++ fedgraph/trainer_class.py | 37 --------------- fedgraph/utils_nc.py | 11 ++++- quickstart.py | 2 +- .../integration/test_fedgraph_integration.py | 4 +- tests/unit/test_federated_methods.py | 11 +++++ tests/unit/test_server_class.py | 41 ++++++++++++++-- tests/unit/test_utils_nc.py | 31 +++++++++--- tutorials/FGL_NC.py | 2 +- tutorials/FGL_NC_HE.py | 2 +- 13 files changed, 157 insertions(+), 61 deletions(-) diff --git a/README.md b/README.md index 2ad3798..bdcf0ea 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ config = { "batch_size": -1, # -1 indicates full batch training # Model Structure "num_layers": 2, - "num_hops": 1, # Number of n-hop neighbors for client communication + "num_hops": 2, # Supported NC communication modes: 0 for FedAvg, 2 for FedGCN # Resource and Hardware Settings "gpu": False, "num_cpus_per_trainer": 1, diff --git a/docs/dev_script/save_graph_node_classification.py b/docs/dev_script/save_graph_node_classification.py index f438c2d..37c8b49 100644 --- a/docs/dev_script/save_graph_node_classification.py +++ b/docs/dev_script/save_graph_node_classification.py @@ -224,7 +224,7 @@ def run(): parser.add_argument("-n", "--n_trainer", default=5, type=int) parser.add_argument("-g", "--gpu", action="store_true") # if -g, use gpu parser.add_argument("-iid_b", "--iid_beta", default=10000, type=float) -parser.add_argument("-nhop", "--num_hops", default=1, type=int) +parser.add_argument("-nhop", "--num_hops", default=2, type=int) args = parser.parse_args() diff --git a/fedgraph/federated_methods.py b/fedgraph/federated_methods.py index cdca011..a56095c 100644 --- a/fedgraph/federated_methods.py +++ b/fedgraph/federated_methods.py @@ -109,6 +109,19 @@ def _resolve_nc_global_node_num( return sum(int(info["features_num"]) for info in trainer_information) +def _validate_nc_num_hops(args: Any) -> None: + if not hasattr(args, "num_hops"): + return + + if args.num_hops not in (0, 2): + raise ValueError( + "FedGraph NC currently only supports num_hops=0 for FedAvg and " + "num_hops=2 for FedGCN-style training. num_hops=1 is not " + "supported because the current implementation is equivalent to " + "the 2-hop path." + ) + + def run_fedgraph(args: attridict) -> None: """ Run the training process for the specified task. @@ -140,6 +153,9 @@ def run_fedgraph(args: attridict) -> None: "Cannot use both encryption and low-rank compression simultaneously" ) + if args.fedgraph_task == "NC": + _validate_nc_num_hops(args) + # Load data if args.fedgraph_task != "NC" or not args.use_huggingface: data = data_loader(args) @@ -195,6 +211,9 @@ def run_fedgraph_enhanced(args: attridict) -> None: else: print("=== Using Standard FedGraph ===") + if args.fedgraph_task == "NC": + _validate_nc_num_hops(args) + # Load data if args.fedgraph_task != "NC" or not args.use_huggingface: data = data_loader(args) @@ -231,6 +250,8 @@ def run_NC(args: attridict, data: Any = None) -> None: Configuration arguments data: tuple """ + _validate_nc_num_hops(args) + monitor = Monitor(use_cluster=args.use_cluster) monitor.init_time_start() @@ -446,16 +467,26 @@ def get_memory_usage(self): aggregated_result, aggregation_time, ) = server.aggregate_encrypted_feature_sums(encrypted_sums) - agg_size = len(aggregated_result[0]) - load_feature_refs = [ - trainer.load_encrypted_feature_aggregation.remote(aggregated_result) - for trainer in server.trainers - ] + load_feature_refs = [] + download_sizes = [] + for i in range(args.n_trainer): + communicate_nodes = ( + communicate_node_global_indexes[i].clone().detach().to(device) + ) + trainer_aggregation = server.mask_encrypted_feature_sum( + aggregated_result, communicate_nodes + ) + download_sizes.append(len(trainer_aggregation[0])) + load_feature_refs.append( + server.trainers[i].load_encrypted_feature_aggregation.remote( + trainer_aggregation + ) + ) decryption_times = ray.get(load_feature_refs) pretrain_time = time.time() - pretrain_start pretrain_upload = sum(enc_sizes) / (1024 * 1024) # MB - pretrain_download = agg_size * len(server.trainers) / (1024 * 1024) # MB + pretrain_download = sum(download_sizes) / (1024 * 1024) # MB pretrain_comm_cost = pretrain_upload + pretrain_download # print performance metrics @@ -807,6 +838,8 @@ def run_NC_dp(args: attridict, data: Any = None) -> None: """ Enhanced NC training with Differential Privacy support for FedGCN pre-training. """ + _validate_nc_num_hops(args) + monitor = Monitor(use_cluster=args.use_cluster) monitor.init_time_start() @@ -1042,6 +1075,8 @@ def run_NC_lowrank(args: attridict, data: Any = None) -> None: "Low-rank compression modules not available. Please implement the low-rank functionality in fedgraph.low_rank" ) + _validate_nc_num_hops(args) + print("=== Running NC with Low-Rank Compression ===") print(f"Low-rank method: {getattr(args, 'lowrank_method', 'fixed')}") if hasattr(args, "lowrank_method"): diff --git a/fedgraph/server_class.py b/fedgraph/server_class.py index ae639dc..b6d49d4 100644 --- a/fedgraph/server_class.py +++ b/fedgraph/server_class.py @@ -142,6 +142,32 @@ def aggregate_encrypted_feature_sums(self, encrypted_sums): return (first_sum.serialize(), shape), time.time() - aggregation_start + def mask_encrypted_feature_sum(self, encrypted_feature_sum, node_indexes): + encrypted_sum, shape = encrypted_feature_sum + if len(shape) != 2: + raise ValueError("encrypted feature sum shape must be two-dimensional") + + num_nodes = int(shape[0]) + feature_dim = int(shape[1]) + if torch.is_tensor(node_indexes): + selected_nodes = node_indexes.detach().to("cpu").long().flatten() + else: + selected_nodes = torch.as_tensor(node_indexes, dtype=torch.long).flatten() + + if selected_nodes.numel() > 0: + if ( + selected_nodes.min().item() < 0 + or selected_nodes.max().item() >= num_nodes + ): + raise ValueError("node_indexes contains values outside feature sum shape") + + mask = torch.zeros((num_nodes, feature_dim), dtype=torch.float64) + mask[selected_nodes] = 1.0 + + masked_sum = ts.ckks_vector_from(self.he_context, encrypted_sum) + masked_sum *= mask.flatten().tolist() + return masked_sum.serialize(), shape + def aggregate_encrypted_params(self, encrypted_params_list): aggregation_start = time.time() diff --git a/fedgraph/trainer_class.py b/fedgraph/trainer_class.py index 5c335fd..78ac134 100644 --- a/fedgraph/trainer_class.py +++ b/fedgraph/trainer_class.py @@ -372,43 +372,6 @@ def get_local_feature_sum(self) -> torch.Tensor: return one_hop_neighbor_feature_sum - def get_local_feature_sum_og(self) -> torch.Tensor: - """ - Computes the sum of features of all 1-hop neighbors for each node, used for plain text version. - - Returns - ------- - one_hop_neighbor_feature_sum : torch.Tensor - The sum of features of 1-hop neighbors for each node - """ - - computation_start = time.time() - new_feature_for_trainer = torch.zeros( - self.global_node_num, self.features.shape[1] - ).to(self.device) - new_feature_for_trainer[self.local_node_index] = self.features - one_hop_neighbor_feature_sum = get_1hop_feature_sum( - new_feature_for_trainer, self.adj, self.device - ) - computation_time = time.time() - computation_start - - data_size = ( - one_hop_neighbor_feature_sum.element_size() - * one_hop_neighbor_feature_sum.nelement() - ) - - print(f"Trainer {self.rank} - Computation time: {computation_time:.4f} seconds") - print(f"Trainer {self.rank} - Data size: {data_size / 1024:.2f} KB") - print(f"Trainer {self.rank} - Feature sum statistics:") - print(f"Shape: {one_hop_neighbor_feature_sum.shape}") - print(f"Mean: {one_hop_neighbor_feature_sum.mean().item():.6f}") - print(f"Std: {one_hop_neighbor_feature_sum.std().item():.6f}") - print(f"Min: {one_hop_neighbor_feature_sum.min().item():.6f}") - print(f"Max: {one_hop_neighbor_feature_sum.max().item():.6f}") - print(f"Non-zeros: {(one_hop_neighbor_feature_sum != 0).sum().item()}") - - return one_hop_neighbor_feature_sum, computation_time, data_size - def load_feature_aggregation(self, feature_aggregation: torch.Tensor) -> None: """ Loads the aggregated features into the trainer. Used for plain text version diff --git a/fedgraph/utils_nc.py b/fedgraph/utils_nc.py index 1acf2f7..d529045 100644 --- a/fedgraph/utils_nc.py +++ b/fedgraph/utils_nc.py @@ -298,6 +298,14 @@ def get_in_comm_indexes( edge_indexes_clients : list A list of tensors representing the edges between nodes within each client's subgraph. """ + if L_hop not in (0, 2): + raise ValueError( + "FedGraph NC currently only supports num_hops=0 for FedAvg and " + "num_hops=2 for FedGCN-style training. num_hops=1 is not " + "supported because the current implementation is equivalent to " + "the 2-hop path." + ) + communicate_node_indexes = [] in_com_train_node_indexes = [] edge_indexes_clients = [] @@ -315,7 +323,7 @@ def get_in_comm_indexes( ) del _ del __ - elif L_hop == 1 or L_hop == 2: + elif L_hop == 2: ( communicate_node_index, current_edge_index, @@ -414,6 +422,7 @@ def get_1hop_feature_sum( (num_nodes, num_nodes), ).to(device) summed_features = torch.sparse.mm(adjacency_matrix.float(), node_features) + # Self features are included only if edge_index already contains self-loops. else: for node in range(num_nodes): neighbor_indices = torch.where( diff --git a/quickstart.py b/quickstart.py index 31d14da..db519df 100644 --- a/quickstart.py +++ b/quickstart.py @@ -35,7 +35,7 @@ "batch_size": -1, # -1 indicates full batch training # Model Structure "num_layers": 2, - "num_hops": 1, # Number of n-hop neighbors for client communication + "num_hops": 2, # Supported NC communication modes: 0 for FedAvg, 2 for FedGCN # Resource and Hardware Settings "gpu": False, "num_cpus_per_trainer": 1, diff --git a/tests/integration/test_fedgraph_integration.py b/tests/integration/test_fedgraph_integration.py index a3fc5f3..73fef28 100644 --- a/tests/integration/test_fedgraph_integration.py +++ b/tests/integration/test_fedgraph_integration.py @@ -140,7 +140,7 @@ def test_nc_server_aggregation(self, mock_aggre_gcn): # Create server args = Mock() - args.num_hops = 1 + args.num_hops = 2 args.dataset = "cora" args.num_layers = 2 @@ -478,7 +478,7 @@ def test_data_trainer_server_interaction(self): class_num=3, device=torch.device('cpu'), trainers=[trainer], - args=Mock(num_hops=1, dataset="test", num_layers=2) + args=Mock(num_hops=2, dataset="test", num_layers=2) ) # Test parameter flow diff --git a/tests/unit/test_federated_methods.py b/tests/unit/test_federated_methods.py index 251cb20..b3b9cc8 100644 --- a/tests/unit/test_federated_methods.py +++ b/tests/unit/test_federated_methods.py @@ -91,6 +91,7 @@ def setup_method(self): self.args.method = "FedAvg" self.args.use_encryption = False self.args.use_huggingface = False + self.args.num_hops = 2 @patch('fedgraph.federated_methods.data_loader') @patch('fedgraph.federated_methods.run_NC') @@ -162,6 +163,16 @@ def test_run_fedgraph_lowrank_encryption_conflict(self): with pytest.raises(ValueError, match="Cannot use both encryption and low-rank compression simultaneously"): run_fedgraph(self.args) + + @patch('fedgraph.federated_methods.data_loader') + def test_run_fedgraph_rejects_unsupported_nc_num_hops(self, mock_data_loader): + """Test that ambiguous 1-hop NC mode is rejected before data loading.""" + self.args.num_hops = 1 + + with pytest.raises(ValueError, match="num_hops=1 is not supported"): + run_fedgraph(self.args) + + mock_data_loader.assert_not_called() @patch('fedgraph.federated_methods.data_loader') @patch('fedgraph.federated_methods.run_NC') diff --git a/tests/unit/test_server_class.py b/tests/unit/test_server_class.py index 85bb29b..9ad5c18 100644 --- a/tests/unit/test_server_class.py +++ b/tests/unit/test_server_class.py @@ -33,14 +33,14 @@ def setup_method(self): # Mock args self.args = Mock() - self.args.num_hops = 1 + self.args.num_hops = 2 self.args.dataset = "cora" self.args.num_layers = 2 self.args.method = "FedAvg" @patch('fedgraph.server_class.AggreGCN') def test_server_init_with_hops(self, mock_aggre_gcn): - """Test Server initialization with num_hops >= 1.""" + """Test Server initialization with FedGCN-style num_hops.""" mock_model = Mock() mock_aggre_gcn.return_value = mock_model mock_model.to.return_value = mock_model @@ -308,6 +308,41 @@ def serialize(self): assert metadata == [{"shape": torch.Size([1]), "scale": 1000.0}] assert aggregation_time == 3.0 + def test_mask_encrypted_feature_sum_keeps_only_selected_rows(self): + """Test encrypted feature sums are masked before trainer download.""" + server = Server.__new__(Server) + server.he_context = object() + + class FakeCKKSVector: + def __init__(self, values): + self.values = [float(value) for value in values] + + def __imul__(self, mask): + self.values = [ + value * float(mask_value) + for value, mask_value in zip(self.values, mask) + ] + return self + + def serialize(self): + return self.values + + encrypted_feature_sum = ( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + torch.Size([3, 2]), + ) + + with patch( + "fedgraph.server_class.ts.ckks_vector_from", + side_effect=lambda _context, payload: FakeCKKSVector(payload), + ): + masked_sum, shape = server.mask_encrypted_feature_sum( + encrypted_feature_sum, torch.tensor([0, 2]) + ) + + assert masked_sum == [1.0, 2.0, 0.0, 0.0, 5.0, 6.0] + assert shape == torch.Size([3, 2]) + def test_aggregate_encrypted_params_rejects_mismatched_shapes(self): """Test encrypted averaging fails clearly for incompatible layers.""" server = Server.__new__(Server) @@ -564,7 +599,7 @@ def test_server_trainer_interaction(self, mock_gcn_class): device = torch.device('cpu') args = Mock() - args.num_hops = 1 + args.num_hops = 2 args.dataset = "cora" args.num_layers = 2 args.method = "FedAvg" diff --git a/tests/unit/test_utils_nc.py b/tests/unit/test_utils_nc.py index 11d8b46..53a9cac 100644 --- a/tests/unit/test_utils_nc.py +++ b/tests/unit/test_utils_nc.py @@ -120,7 +120,7 @@ def test_uploads_global_metadata(self, mock_hf_api, mock_get_token): args = Mock( dataset="cora", n_trainer=2, - num_hops=1, + num_hops=2, iid_beta=0.5, ) @@ -296,7 +296,7 @@ def test_get_in_comm_indexes_basic(self): ] n_trainers = 3 - num_hops = 1 + num_hops = 2 idx_train = torch.tensor([0, 2, 4]) idx_test = torch.tensor([1, 3]) @@ -309,14 +309,31 @@ def test_get_in_comm_indexes_basic(self): (communicate_node_global_indexes, in_com_train_node_local_indexes, in_com_test_node_local_indexes, global_edge_indexes_clients) = result - assert isinstance(communicate_node_global_indexes, dict) - assert isinstance(in_com_train_node_local_indexes, dict) - assert isinstance(in_com_test_node_local_indexes, dict) - assert isinstance(global_edge_indexes_clients, dict) + assert isinstance(communicate_node_global_indexes, list) + assert isinstance(in_com_train_node_local_indexes, list) + assert isinstance(in_com_test_node_local_indexes, list) + assert isinstance(global_edge_indexes_clients, list) assert len(communicate_node_global_indexes) == n_trainers assert len(global_edge_indexes_clients) == n_trainers + def test_get_in_comm_indexes_rejects_unsupported_one_hop(self): + """Test that the old ambiguous 1-hop NC mode is rejected.""" + edge_index = torch.tensor([[0, 1], [1, 2]]) + split_node_indexes = [torch.tensor([0, 1])] + idx_train = torch.tensor([0]) + idx_test = torch.tensor([1]) + + with pytest.raises(ValueError, match="num_hops=1 is not supported"): + get_in_comm_indexes( + edge_index, + split_node_indexes, + 1, + 1, + idx_train, + idx_test, + ) + class TestGet1hopFeatureSum: """Test get_1hop_feature_sum function.""" @@ -452,7 +469,7 @@ def test_graph_communication_workflow(self): ]) n_trainers = 3 - num_hops = 1 + num_hops = 2 idx_train = torch.arange(0, 7) idx_test = torch.arange(7, 10) diff --git a/tutorials/FGL_NC.py b/tutorials/FGL_NC.py index ab17311..b552c2b 100644 --- a/tutorials/FGL_NC.py +++ b/tutorials/FGL_NC.py @@ -33,7 +33,7 @@ "batch_size": -1, # -1 indicates full batch training # Model Structure "num_layers": 2, - "num_hops": 1, # Number of n-hop neighbors for client communication + "num_hops": 2, # Supported NC communication modes: 0 for FedAvg, 2 for FedGCN # Resource and Hardware Settings "gpu": False, "num_cpus_per_trainer": 1, diff --git a/tutorials/FGL_NC_HE.py b/tutorials/FGL_NC_HE.py index 686a9c4..132e7a0 100644 --- a/tutorials/FGL_NC_HE.py +++ b/tutorials/FGL_NC_HE.py @@ -33,7 +33,7 @@ "batch_size": -1, # -1 indicates full batch training # Model Structure "num_layers": 2, - "num_hops": 1, # Number of n-hop neighbors for client communication + "num_hops": 2, # Supported NC communication modes: 0 for FedAvg, 2 for FedGCN # Resource and Hardware Settings "gpu": False, "num_cpus_per_trainer": 1, From 2891fafe951ff0ecedadd209a0a362c85989f215 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 23 Jun 2026 22:42:47 +0000 Subject: [PATCH 5/5] Fix OpenFHE smoke test compatibility --- pytest.ini | 3 ++- test_openfhe_smoke.py | 4 ++-- tests/test_smoke_e2e.py | 2 +- tests/test_threshold_ckks_min.py | 24 ++++++++++++------------ 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/pytest.ini b/pytest.ini index 99b8fc9..f355f09 100644 --- a/pytest.ini +++ b/pytest.ini @@ -10,7 +10,8 @@ markers = slow: Slow running tests gpu: Tests requiring GPU ray: Tests requiring Ray cluster + timeout: Tests with explicit runtime timeout filterwarnings = ignore::DeprecationWarning ignore::UserWarning - ignore::PendingDeprecationWarning \ No newline at end of file + ignore::PendingDeprecationWarning diff --git a/test_openfhe_smoke.py b/test_openfhe_smoke.py index 3403baf..45575e6 100644 --- a/test_openfhe_smoke.py +++ b/test_openfhe_smoke.py @@ -22,8 +22,8 @@ def test_basic_ckks(): print("✅ Context created") # Enable basic features - cc.Enable(openfhe.PKESchemeFeature.PKE) - cc.Enable(openfhe.PKESchemeFeature.SHE) + cc.Enable(openfhe.PKE) + cc.Enable(openfhe.LEVELEDSHE) print("✅ Features enabled") # Generate keys diff --git a/tests/test_smoke_e2e.py b/tests/test_smoke_e2e.py index 962fa79..1974033 100644 --- a/tests/test_smoke_e2e.py +++ b/tests/test_smoke_e2e.py @@ -50,7 +50,7 @@ def _base_cora_config(**overrides): "n_trainer": 2, "batch_size": -1, "num_layers": 2, - "num_hops": 1, + "num_hops": 2, "gpu": False, "num_cpus_per_trainer": 1, "num_gpus_per_trainer": 0, diff --git a/tests/test_threshold_ckks_min.py b/tests/test_threshold_ckks_min.py index 7ee6bf0..4995b7f 100644 --- a/tests/test_threshold_ckks_min.py +++ b/tests/test_threshold_ckks_min.py @@ -11,8 +11,14 @@ def make_cc(): params.SetFirstModSize(60) params.SetScalingTechnique(openfhe.FLEXIBLEAUTOEXT) cc = openfhe.GenCryptoContext(params) - for f in ("PKE", "SHE", "LEVELEDSHE", "MULTIPARTY"): - cc.Enable(getattr(openfhe.PKESchemeFeature, f)) + for feature in ( + openfhe.PKE, + openfhe.KEYSWITCH, + openfhe.LEVELEDSHE, + openfhe.ADVANCEDSHE, + openfhe.MULTIPARTY, + ): + cc.Enable(feature) return cc def test_two_party_threshold_ckks_add(): @@ -23,22 +29,16 @@ def test_two_party_threshold_ckks_add(): pk0 = kp_lead.publicKey sk0 = kp_lead.secretKey - # Non-lead + # Non-lead. Its public key is the joint public key. kp_main = cc.MultipartyKeyGen(pk0) - pk1 = kp_main.publicKey + joint_pk = kp_main.publicKey sk1 = kp_main.secretKey - # Finalize joint PK on lead - kp_final = cc.MultipartyKeyGen(pk1) - joint_pk = kp_final.publicKey - # Data x = [0.1, 0.2, 0.3] y = [0.05, 0.1, 0.15] - scale = 2**50 - pt_x = cc.MakeCKKSPackedPlaintext(x, scale) - pt_y = cc.MakeCKKSPackedPlaintext(y, scale) - + pt_x = cc.MakeCKKSPackedPlaintext(x) + pt_y = cc.MakeCKKSPackedPlaintext(y) ct_x = cc.Encrypt(joint_pk, pt_x) ct_y = cc.Encrypt(joint_pk, pt_y) ct_sum = cc.EvalAdd(ct_x, ct_y)