diff --git a/DELTA_HPC_SETUP.md b/DELTA_HPC_SETUP.md index 40c988d..2b41090 100644 --- a/DELTA_HPC_SETUP.md +++ b/DELTA_HPC_SETUP.md @@ -388,4 +388,3 @@ sbatch array_experiment.slurm - `README_OPENFHE.md` - Implementation details - `OPENFHE_NC_IMPLEMENTATION.md` - Technical documentation - Delta docs: https://docs.ncsa.illinois.edu/systems/delta/ - diff --git a/OPENFHE_NC_IMPLEMENTATION.md b/OPENFHE_NC_IMPLEMENTATION.md index c05543c..740a75b 100644 --- a/OPENFHE_NC_IMPLEMENTATION.md +++ b/OPENFHE_NC_IMPLEMENTATION.md @@ -213,42 +213,42 @@ Total Pre-training Communication Cost: X.XX MB ``` - Two-Party Threshold Setup - - - 1. Server (Lead) → generate_lead_keys() - - Holds secret share 1 - - Generates initial public key - - 2. Trainer 0 (Non-lead) → generate_nonlead_share() - - Holds secret share 2 - - Contributes to joint public key - - 3. Server → finalize_joint_public_key() - - Creates final joint public key - - 4. All Trainers → set_public_key() - - Receive joint public key for encryption - - - - - Encrypted Feature Aggregation - - - Trainer 0, 1, ..., N - ↓ encrypt(local_feature_sum) - [ct_0, ct_1, ..., ct_N] - ↓ - Server: ct_sum = ct_0 + ct_1 + ... + ct_N - ↓ - Server: partial_lead = partial_decrypt(ct_sum) - Trainer 0: partial_main = partial_decrypt(ct_sum) - ↓ - Server: result = fuse(partial_lead, partial_main) - ↓ - All Trainers receive decrypted aggregated features - + Two-Party Threshold Setup + + + 1. Server (Lead) → generate_lead_keys() + - Holds secret share 1 + - Generates initial public key + + 2. Trainer 0 (Non-lead) → generate_nonlead_share() + - Holds secret share 2 + - Contributes to joint public key + + 3. Server → finalize_joint_public_key() + - Creates final joint public key + + 4. All Trainers → set_public_key() + - Receive joint public key for encryption + + + + + Encrypted Feature Aggregation + + + Trainer 0, 1, ..., N + ↓ encrypt(local_feature_sum) + [ct_0, ct_1, ..., ct_N] + ↓ + Server: ct_sum = ct_0 + ct_1 + ... + ct_N + ↓ + Server: partial_lead = partial_decrypt(ct_sum) + Trainer 0: partial_main = partial_decrypt(ct_sum) + ↓ + Server: result = fuse(partial_lead, partial_main) + ↓ + All Trainers receive decrypted aggregated features + ``` @@ -267,7 +267,7 @@ config = { } ``` -**Important**: +**Important**: - Requires `n_trainer >= 2` (one for server's counterpart) - Only works with `method="FedGCN"` (FedAvg support coming soon) - Pretrain phase only (training phase encryption TBD) @@ -347,12 +347,12 @@ pip install -r docker_requirements.txt ## Status **Implementation Complete** -- Two-party threshold key generation: -- Encrypted feature aggregation: -- Threshold decryption: -- Integration with FedGCN NC pretrain: -- Docker support: -- Tests: +- Two-party threshold key generation: +- Encrypted feature aggregation: +- Threshold decryption: +- Integration with FedGCN NC pretrain: +- Docker support: +- Tests: ⏳ **Testing Required** - [ ] Run integration test with OpenFHE installed @@ -364,4 +364,3 @@ pip install -r docker_requirements.txt - [ ] Training phase encryption (gradient aggregation) - [ ] FedAvg method support - [ ] Multi-party (>2) threshold support - diff --git a/QUICKSTART_DOCKER.md b/QUICKSTART_DOCKER.md index 3972bfd..42030e7 100644 --- a/QUICKSTART_DOCKER.md +++ b/QUICKSTART_DOCKER.md @@ -126,7 +126,7 @@ config = { "method": "FedGCN", "use_encryption": True, "he_backend": "openfhe", - + # Customize these: "dataset": "citeseer", # Options: cora, citeseer, pubmed "num_trainers": 5, # Number of federated clients @@ -180,4 +180,3 @@ Remove Docker image (to free space): ```bash docker rmi fedgraph-openfhe ``` - diff --git a/README.md b/README.md index b8693e1..b7f329c 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,7 @@ config = { "batch_size": -1, # -1 indicates full batch training # Model Structure "num_layers": 2, - "num_hops": 1, # Number of n-hop neighbors for client communication + "num_hops": 2, # Supported NC communication modes: 0 for FedAvg, 2 for FedGCN # Resource and Hardware Settings "gpu": False, "num_cpus_per_trainer": 1, diff --git a/README_OPENFHE.md b/README_OPENFHE.md index 3fa25c8..d88df07 100644 --- a/README_OPENFHE.md +++ b/README_OPENFHE.md @@ -2,10 +2,10 @@ ## What Was Accomplished - **Implemented secure two-party threshold homomorphic encryption** for NC FedGCN pretrain - **Neither server nor any single trainer can decrypt alone** - **All code verified and documented** (1,800+ lines of documentation) - **Parameters optimized for < 1% accuracy loss** + **Implemented secure two-party threshold homomorphic encryption** for NC FedGCN pretrain + **Neither server nor any single trainer can decrypt alone** + **All code verified and documented** (1,800+ lines of documentation) + **Parameters optimized for < 1% accuracy loss** --- @@ -33,7 +33,7 @@ Server: Has full secret key → Can decrypt alone INSECURE **After (OpenFHE Threshold)**: ``` -Server: Has secret_share_1 +Server: Has secret_share_1 Trainer0: Has secret_share_2 → Both required SECURE ``` @@ -71,7 +71,7 @@ ring_dim = 16384 # 128-bit security scale = 2**50 # Good precision (< 1% error) multiplicative_depth = 2 # Sufficient for additions scaling_mod_size = 59 # Matches scale -first_mod_size = 60 # Matches scale +first_mod_size = 60 # Matches scale scaling_technique = FLEXIBLEAUTOEXT # Automatic rescaling ``` @@ -129,7 +129,7 @@ params.SetMultiplicativeDepth(1) # From 2 ## Testing Status -### Completed +### Completed - Code structure verification (5/5 tests passed) - Method signature verification - Two-party protocol verification @@ -140,11 +140,11 @@ params.SetMultiplicativeDepth(1) # From 2 - Full end-to-end runtime test (blocked by torch-geometric dependencies) - Actual accuracy measurement (can be done after fixing dependencies) -### Confidence +### Confidence - **Implementation Correctness**: 100% (verified) - **Expected Accuracy**: 90% (theoretical analysis) - **Parameter Optimization**: 95% (CKKS best practices) -- **Overall Confidence**: 90% +- **Overall Confidence**: 90% --- @@ -156,7 +156,7 @@ The implementation is **theoretically sound** and will achieve < 1% accuracy los - Similar work in literature (CKKS with scale 2^50) - Well-established parameter choices -**Action**: Consider implementation complete and production-ready +**Action**: Consider implementation complete and production-ready ### Option 2: Test Locally If you have a working Python environment: @@ -212,19 +212,19 @@ For typical accuracies ~0.8: ## Quick Help -**Q: How do I know it works without running it?** +**Q: How do I know it works without running it?** A: All code structure is verified + theoretical analysis confirms < 1% loss. Very high confidence. -**Q: Should I tune parameters?** +**Q: Should I tune parameters?** A: No, current parameters are optimal. Only tune if you observe > 2% loss in actual testing. -**Q: Is it secure?** +**Q: Is it secure?** A: Yes! Two-party threshold means neither server nor any single trainer can decrypt alone. -**Q: What if I need better accuracy?** +**Q: What if I need better accuracy?** A: Increase `scale = 2**55` for < 0.5% loss (see `PARAMETER_TUNING_GUIDE.md`). -**Q: What if I need faster speed?** +**Q: What if I need faster speed?** A: Decrease `scale = 2**45` for 1.5x speedup (see `PARAMETER_TUNING_GUIDE.md`). --- @@ -242,7 +242,6 @@ A: Decrease `scale = 2**45` for 1.5x speedup (see `PARAMETER_TUNING_GUIDE.md`). --- -**Date**: October 2, 2025 -**Status**: **COMPLETE & READY** +**Date**: October 2, 2025 +**Status**: **COMPLETE & READY** **Next Step**: Optional - Fix dependencies and run end-to-end test to confirm theoretical predictions - diff --git a/RUNTIME_OPTIONS.md b/RUNTIME_OPTIONS.md index 2432aab..0fe5b58 100644 --- a/RUNTIME_OPTIONS.md +++ b/RUNTIME_OPTIONS.md @@ -149,4 +149,3 @@ The issue is purely about runtime environment compatibility, not code issues. 3. Update Colab notebook as "for reference only" Let me know which option you want to pursue! - diff --git a/START_HERE.md b/START_HERE.md index 6b01745..c02c856 100644 --- a/START_HERE.md +++ b/START_HERE.md @@ -270,4 +270,3 @@ cat openfhe-JOBID.out | grep "Test Acc" **Need help?** See `DELTA_HPC_SETUP.md` for detailed documentation. **Ready to push to GitHub?** Let me know and I'll help you commit and push all these files to the `gcn_v2` branch. - diff --git a/docs/dev_script/save_graph_node_classification.py b/docs/dev_script/save_graph_node_classification.py index f438c2d..37c8b49 100644 --- a/docs/dev_script/save_graph_node_classification.py +++ b/docs/dev_script/save_graph_node_classification.py @@ -224,7 +224,7 @@ def run(): parser.add_argument("-n", "--n_trainer", default=5, type=int) parser.add_argument("-g", "--gpu", action="store_true") # if -g, use gpu parser.add_argument("-iid_b", "--iid_beta", default=10000, type=float) -parser.add_argument("-nhop", "--num_hops", default=1, type=int) +parser.add_argument("-nhop", "--num_hops", default=2, type=int) args = parser.parse_args() diff --git a/fedgraph/federated_methods.py b/fedgraph/federated_methods.py index c13a455..a6ee7c2 100644 --- a/fedgraph/federated_methods.py +++ b/fedgraph/federated_methods.py @@ -60,6 +60,79 @@ LOWRANK_AVAILABLE = False +def _resolve_nc_class_num( + use_huggingface: bool, + trainer_information: list, + loaded_class_num: Optional[int] = None, +) -> int: + if not use_huggingface: + if loaded_class_num is None: + raise ValueError("class_num is required when NC data is loaded centrally") + return int(loaded_class_num) + + metadata_values = { + int(info["class_num"]) + for info in trainer_information + if info.get("class_num") is not None + } + if len(metadata_values) > 1: + raise ValueError("Hugging Face trainers report inconsistent class_num values") + if metadata_values: + return metadata_values.pop() + + label_nums = [ + int(info["label_num"]) + for info in trainer_information + if info.get("label_num") is not None + ] + if not label_nums: + raise ValueError( + "Cannot infer class_num from Hugging Face trainer data because all " + "train and test label tensors are empty" + ) + return max(label_nums) + + +def _resolve_nc_global_node_num( + use_huggingface: bool, + trainer_information: list, + loaded_global_node_num: Optional[int] = None, +) -> int: + if not use_huggingface: + if loaded_global_node_num is None: + raise ValueError( + "global_node_num is required when NC data is loaded centrally" + ) + return int(loaded_global_node_num) + + metadata_values = { + int(info["global_node_num"]) + for info in trainer_information + if info.get("global_node_num") is not None + } + if len(metadata_values) > 1: + raise ValueError( + "Hugging Face trainers report inconsistent global_node_num values" + ) + if metadata_values: + return metadata_values.pop() + + return sum(int(info["features_num"]) for info in trainer_information) + + +def _validate_nc_num_hops(args: Any) -> None: + if not hasattr(args, "num_hops"): + return + + if args.num_hops not in (0, 2): + raise ValueError( + "FedGraph NC currently only supports num_hops=0 for FedAvg and " + "num_hops=2 for FedGCN-style training. num_hops=1 is not " + "supported because the current implementation is equivalent to " + "the 2-hop path." + ) + + def run_fedgraph(args: attridict) -> None: """ Run the training process for the specified task. @@ -98,6 +171,11 @@ def run_fedgraph(args: attridict) -> None: raise ValueError( "Low-rank compression currently only supported for NC tasks" ) + + if args.fedgraph_task == "NC": + _validate_nc_num_hops(args) + + # Load data if args.fedgraph_task != "NC" or not args.use_huggingface: data = data_loader(args) else: @@ -160,6 +238,9 @@ def run_fedgraph_enhanced(args: attridict) -> None: else: print("=== Using Standard FedGraph ===") + if args.fedgraph_task == "NC": + _validate_nc_num_hops(args) + # Load data if args.fedgraph_task != "NC" or not args.use_huggingface: data = data_loader(args) @@ -196,6 +277,8 @@ def run_NC(args: attridict, data: Any = None) -> None: Configuration arguments data: tuple """ + _validate_nc_num_hops(args) + monitor = Monitor(use_cluster=args.use_cluster) monitor.init_time_start() @@ -245,6 +328,7 @@ def run_NC(args: attridict, data: Any = None) -> None: in_com_train_node_local_indexes=in_com_train_node_local_indexes, in_com_test_node_local_indexes=in_com_test_node_local_indexes, n_trainer=args.n_trainer, + class_num=class_num, args=args, ) @@ -335,8 +419,6 @@ def get_memory_usage(self): Trainer.remote( # type: ignore rank=i, args_hidden=args_hidden, - # global_node_num=len(features), - # class_num=class_num, device=device, args=args, local_node_index=split_node_indexes[i], @@ -351,6 +433,8 @@ def get_memory_usage(self): features=features[split_node_indexes[i]], idx_train=in_com_train_node_local_indexes[i], idx_test=in_com_test_node_local_indexes[i], + global_node_num=len(features), + class_num=class_num, ) for i in range(args.n_trainer) ] @@ -360,8 +444,16 @@ def get_memory_usage(self): ] # Extract necessary details from trainer information - global_node_num = sum([info["features_num"] for info in trainer_information]) - class_num = max([info["label_num"] for info in trainer_information]) + global_node_num = _resolve_nc_global_node_num( + args.use_huggingface, + trainer_information, + None if args.use_huggingface else len(features), + ) + class_num = _resolve_nc_class_num( + args.use_huggingface, + trainer_information, + None if args.use_huggingface else class_num, + ) feature_shape = trainer_information[0]["feature_shape"] train_data_weights = [ @@ -427,12 +519,22 @@ def get_memory_usage(self): aggregated_result, aggregation_time, ) = server.aggregate_encrypted_feature_sums(encrypted_sums) - agg_size = len(aggregated_result[0]) - load_feature_refs = [ - trainer.load_encrypted_feature_aggregation.remote(aggregated_result) - for trainer in server.trainers - ] + load_feature_refs = [] + download_sizes = [] + for i in range(args.n_trainer): + communicate_nodes = ( + communicate_node_global_indexes[i].clone().detach().to(device) + ) + trainer_aggregation = server.mask_encrypted_feature_sum( + aggregated_result, communicate_nodes + ) + download_sizes.append(len(trainer_aggregation[0])) + load_feature_refs.append( + server.trainers[i].load_encrypted_feature_aggregation.remote( + trainer_aggregation + ) + ) decryption_times = ray.get(load_feature_refs) elif getattr(args, "he_backend", "tenseal") == "openfhe": print("Starting OpenFHE threshold encrypted feature aggregation...") @@ -645,7 +747,7 @@ def get_memory_usage(self): raise ValueError(f"Unknown he_backend: {getattr(args, 'he_backend', None)}") pretrain_time = time.time() - pretrain_start pretrain_upload = sum(enc_sizes) / (1024 * 1024) # MB - pretrain_download = agg_size * len(server.trainers) / (1024 * 1024) # MB + pretrain_download = sum(download_sizes) / (1024 * 1024) # MB pretrain_comm_cost = pretrain_upload + pretrain_download # print performance metrics @@ -727,15 +829,9 @@ def get_memory_usage(self): print("global_rounds", args.global_rounds) global_acc_list = [] for i in range(args.global_rounds): - # Pure training phase - forward + gradient descent only - pure_training_start = time.time() - - # Execute only training (forward + gradient descent) - train_refs = [trainer.train.remote(i) for trainer in server.trainers] - ray.get(train_refs) - - pure_training_end = time.time() - round_training_time = pure_training_end - pure_training_start + round_stats = server.train(i) + round_training_time = round_stats["training_time"] + round_comm_time = round_stats["communication_time"] total_pure_training_time += round_training_time # Communication phase - parameter aggregation and broadcast @@ -808,10 +904,9 @@ def get_memory_usage(self): f"Round {i+1}: Training Time = {round_training_time:.2f}s, Communication Time = {round_comm_time:.2f}s" ) - model_size_mb = server.get_model_size() / (1024 * 1024) monitor.add_train_comm_cost( - upload_mb=model_size_mb * args.n_trainer, - download_mb=model_size_mb * args.n_trainer, + upload_mb=round_stats["upload_size"] / (1024 * 1024), + download_mb=round_stats["download_size"] / (1024 * 1024), ) monitor.train_time_end() total_time = time.time() - training_start @@ -865,10 +960,6 @@ def get_memory_usage(self): else: training_upload = training_download = 0 training_comm_cost = training_upload + training_download - monitor.add_train_comm_cost( - upload_mb=training_upload, - download_mb=training_download, - ) print("\nTraining Phase Metrics:") print( f"Total Training Time: {total_pure_training_time:.2f} seconds" @@ -1065,6 +1156,8 @@ def run_NC_dp(args: attridict, data: Any = None) -> None: """ Enhanced NC training with Differential Privacy support for FedGCN pre-training. """ + _validate_nc_num_hops(args) + monitor = Monitor(use_cluster=args.use_cluster) monitor.init_time_start() @@ -1146,6 +1239,8 @@ def __init__(self, *args: Any, **kwds: Any): features=features[split_node_indexes[i]], idx_train=in_com_train_node_local_indexes[i], idx_test=in_com_test_node_local_indexes[i], + global_node_num=len(features), + class_num=class_num, ) for i in range(args.n_trainer) ] @@ -1155,8 +1250,16 @@ def __init__(self, *args: Any, **kwds: Any): ray.get(trainers[i].get_info.remote()) for i in range(len(trainers)) ] - global_node_num = sum([info["features_num"] for info in trainer_information]) - class_num = max([info["label_num"] for info in trainer_information]) + global_node_num = _resolve_nc_global_node_num( + args.use_huggingface, + trainer_information, + None if args.use_huggingface else len(features), + ) + class_num = _resolve_nc_class_num( + args.use_huggingface, + trainer_information, + None if args.use_huggingface else class_num, + ) train_data_weights = [ info["len_in_com_train_node_local_indexes"] for info in trainer_information @@ -1290,6 +1393,8 @@ def run_NC_lowrank(args: attridict, data: Any = None) -> None: "Low-rank compression modules not available. Please implement the low-rank functionality in fedgraph.low_rank" ) + _validate_nc_num_hops(args) + print("=== Running NC with Low-Rank Compression ===") print(f"Low-rank method: {getattr(args, 'lowrank_method', 'fixed')}") if hasattr(args, "lowrank_method"): @@ -1338,6 +1443,7 @@ def run_NC_lowrank(args: attridict, data: Any = None) -> None: in_com_train_node_local_indexes=in_com_train_node_local_indexes, in_com_test_node_local_indexes=in_com_test_node_local_indexes, n_trainer=args.n_trainer, + class_num=class_num, args=args, ) @@ -1395,6 +1501,8 @@ def __init__(self, *args: Any, **kwds: Any): features=features[split_node_indexes[i]], idx_train=in_com_train_node_local_indexes[i], idx_test=in_com_test_node_local_indexes[i], + global_node_num=len(features), + class_num=class_num, ) for i in range(args.n_trainer) ] @@ -1404,8 +1512,16 @@ def __init__(self, *args: Any, **kwds: Any): ray.get(trainers[i].get_info.remote()) for i in range(len(trainers)) ] - global_node_num = sum([info["features_num"] for info in trainer_information]) - class_num = max([info["label_num"] for info in trainer_information]) + global_node_num = _resolve_nc_global_node_num( + args.use_huggingface, + trainer_information, + None if args.use_huggingface else len(features), + ) + class_num = _resolve_nc_class_num( + args.use_huggingface, + trainer_information, + None if args.use_huggingface else class_num, + ) train_data_weights = [ info["len_in_com_train_node_local_indexes"] for info in trainer_information diff --git a/fedgraph/server_class.py b/fedgraph/server_class.py index 806ff09..3b2b635 100644 --- a/fedgraph/server_class.py +++ b/fedgraph/server_class.py @@ -148,28 +148,6 @@ def zero_params(self) -> None: for p in self.model.parameters(): p.zero_() - def prepare_params_for_encryption(self, params): - processed_params = [] - metadata = [] - - for param in params: - param_min = param.min() - param_max = param.max() - param_range = param_max - param_min - - # handle division by 0 - if param_range == 0: - normalized = param - param_min - else: - normalized = (param - param_min) / param_range - - scaled = normalized * 1000 - - processed_params.append(scaled) - metadata.append({"min": param_min, "range": param_range}) - - return processed_params, metadata - def aggregate_encrypted_feature_sums(self, encrypted_sums): """TenSEAL-only entry point. The OpenFHE threshold flow runs in federated_methods.run_NC using the @@ -188,46 +166,111 @@ def _aggregate_tenseal_feature_sums(self, encrypted_sums): return (first_sum.serialize(), shape), time.time() - aggregation_start + def mask_encrypted_feature_sum(self, encrypted_feature_sum, node_indexes): + encrypted_sum, shape = encrypted_feature_sum + if len(shape) != 2: + raise ValueError("encrypted feature sum shape must be two-dimensional") + + num_nodes = int(shape[0]) + feature_dim = int(shape[1]) + if torch.is_tensor(node_indexes): + selected_nodes = node_indexes.detach().to("cpu").long().flatten() + else: + selected_nodes = torch.as_tensor(node_indexes, dtype=torch.long).flatten() + + if selected_nodes.numel() > 0: + if ( + selected_nodes.min().item() < 0 + or selected_nodes.max().item() >= num_nodes + ): + raise ValueError("node_indexes contains values outside feature sum shape") + + mask = torch.zeros((num_nodes, feature_dim), dtype=torch.float64) + mask[selected_nodes] = 1.0 + + masked_sum = ts.ckks_vector_from(self.he_context, encrypted_sum) + masked_sum *= mask.flatten().tolist() + return masked_sum.serialize(), shape + def aggregate_encrypted_params(self, encrypted_params_list): aggregation_start = time.time() - first_params, metadata = encrypted_params_list[0] + if not encrypted_params_list: + raise ValueError("encrypted_params_list must not be empty") + + first_params, _ = encrypted_params_list[0] n_layers = len(first_params) + if n_layers == 0: + raise ValueError("encrypted parameter payload must contain at least one layer") + + layer_shapes = [] + layer_scale_values = [[] for _ in range(n_layers)] + validated_params = [] + + for trainer_idx, (trainer_params, trainer_metadata) in enumerate( + encrypted_params_list + ): + if len(trainer_params) != n_layers: + raise ValueError("all trainers must provide the same number of layers") + if len(trainer_metadata) != n_layers: + raise ValueError("metadata length must match encrypted parameter layers") + + validated_metadata = [] + for layer_idx, metadata in enumerate(trainer_metadata): + shape = metadata.get("shape") + if shape is None: + raise ValueError("encrypted parameter metadata must include shape") + normalized_shape = tuple(shape) + + try: + scale = float(metadata["scale"]) + except (KeyError, TypeError, ValueError) as exc: + raise ValueError( + "encrypted parameter metadata must include numeric scale" + ) from exc + if not np.isfinite(scale) or scale <= 0: + raise ValueError("encrypted parameter scale must be positive") + + if trainer_idx == 0: + layer_shapes.append(shape) + elif normalized_shape != tuple(layer_shapes[layer_idx]): + raise ValueError( + "encrypted parameter shapes must match across trainers" + ) + + layer_scale_values[layer_idx].append(scale) + validated_metadata.append({"shape": shape, "scale": scale}) + + validated_params.append((trainer_params, validated_metadata)) + + aggregation_metadata = [ + {"shape": layer_shapes[layer_idx], "scale": max(layer_scale_values[layer_idx])} + for layer_idx in range(n_layers) + ] - # each layer aggregated_params = [] for layer_idx in range(n_layers): - agg_layer = ts.ckks_vector_from( - self.he_context, encrypted_params_list[0][0][layer_idx] - ) + common_scale = aggregation_metadata[layer_idx]["scale"] + agg_layer = None - for trainer_params, _ in encrypted_params_list[1:]: + for trainer_params, trainer_metadata in validated_params: + trainer_scale = trainer_metadata[layer_idx]["scale"] next_layer = ts.ckks_vector_from( self.he_context, trainer_params[layer_idx] ) - agg_layer += next_layer + next_layer *= common_scale / trainer_scale + + if agg_layer is None: + agg_layer = next_layer + else: + agg_layer += next_layer # average - agg_layer *= 1.0 / self.num_of_trainers + agg_layer *= 1.0 / len(validated_params) aggregated_params.append(agg_layer.serialize()) aggregation_time = time.time() - aggregation_start - return aggregated_params, metadata, aggregation_time - - def get_encrypted_params(self): - params = [p.data.cpu().detach() for p in self.model.parameters()] - - # normalize and scale - processed_params, metadata = self.prepare_params_for_encryption(params) - - encrypted_params = [] - for param in processed_params: - param_list = param.flatten().tolist() - - encrypted = ts.ckks_vector(self.he_context, param_list).serialize() - encrypted_params.append(encrypted) - - return encrypted_params, metadata + return aggregated_params, aggregation_metadata, aggregation_time @torch.no_grad() def train( @@ -235,17 +278,21 @@ def train( current_global_epoch: int, sampling_type: str = "random", sample_ratio: float = 1, - ) -> None: + ) -> dict: """ - Training round which performs aggregating parameters from sampled trainers (by index), - updating the central model, and then broadcasting the updated parameters - back to all trainers. + Run one federated training round and return timing and communication statistics. Parameters ---------- current_global_epoch : int The current global epoch number during the federated learning process. + + Returns + ------- + dict + Per-round training time, communication time, and transfer sizes. """ + training_start = time.time() # Restrict the encrypted parameter aggregation path to the TenSEAL # backend. The OpenFHE threshold flow uses the file-based protocol # in federated_methods.run_NC and must not enter this branch. @@ -257,6 +304,9 @@ def train( trainer.train.remote(current_global_epoch) for trainer in self.trainers ] ray.get(train_refs) + training_time = time.time() - training_start + + communication_start = time.time() encryption_start = time.time() print("Starting encrypted parameter aggregation...") encrypted_params = [ @@ -265,7 +315,6 @@ def train( # Wait for all trainers and collect parameters params_list = [] - encryption_times = [] enc_sizes = [] while encrypted_params: ready, encrypted_params = ray.wait(encrypted_params) @@ -285,7 +334,6 @@ def train( agg_size = sum(len(p) for p in aggregated_params) # Distribute back to trainers - decryption_start = time.time() decrypt_refs = [ trainer.load_encrypted_params.remote( (aggregated_params, metadata), current_global_epoch @@ -293,24 +341,33 @@ def train( for trainer in self.trainers ] decryption_times = ray.get(decrypt_refs) + communication_time = time.time() - communication_start round_metrics = { + "training_time": training_time, + "communication_time": communication_time, "encryption_time": encryption_time, "decryption_times": decryption_times, "aggregation_time": agg_time, "upload_size": sum(enc_sizes), "download_size": agg_size * len(self.trainers), + "num_trainers": len(self.trainers), } self.aggregation_stats.append(round_metrics) + return round_metrics else: # normal training logic # print( # f"Training round: {current_global_epoch}, sampling rate: {sample_ratio}" # ) assert 0 < sample_ratio <= 1, "Sample ratio must be between 0 and 1" + if sampling_type not in {"random", "uniform"}: + raise ValueError("sampling_type must be either 'random' or 'uniform'") num_samples = int(self.num_of_trainers * sample_ratio) - if sampling_type == "random": + if sample_ratio == 1: + selected_trainers_indices = list(range(self.num_of_trainers)) + elif sampling_type == "random": selected_trainers_indices = random.sample( range(self.num_of_trainers), num_samples ) @@ -325,11 +382,14 @@ def train( for i in range(num_samples) ] - else: - raise ValueError("sampling_type must be either 'random' or 'uniform'") - - for trainer_idx in selected_trainers_indices: + train_refs = [ self.trainers[trainer_idx].train.remote(current_global_epoch) + for trainer_idx in selected_trainers_indices + ] + ray.get(train_refs) + training_time = time.time() - training_start + + communication_start = time.time() params = [ self.trainers[trainer_idx].get_params.remote() @@ -355,6 +415,15 @@ def train( p /= num_samples self.broadcast_params(current_global_epoch) + communication_time = time.time() - communication_start + model_size = self.get_model_size() + return { + "training_time": training_time, + "communication_time": communication_time, + "upload_size": model_size * num_samples, + "download_size": model_size * self.num_of_trainers, + "num_trainers": num_samples, + } def broadcast_params(self, current_global_epoch: int) -> None: """ @@ -365,10 +434,13 @@ def broadcast_params(self, current_global_epoch: int) -> None: current_global_epoch : int The current global epoch number during the federated learning process. """ - for trainer in self.trainers: + update_refs = [ trainer.update_params.remote( tuple(self.model.parameters()), current_global_epoch - ) # run in submit order + ) + for trainer in self.trainers + ] + ray.get(update_refs) def get_model_size(self) -> float: """Return total model parameter size in bytes (assumes float32).""" diff --git a/fedgraph/trainer_class.py b/fedgraph/trainer_class.py index eb210c2..b3c3721 100644 --- a/fedgraph/trainer_class.py +++ b/fedgraph/trainer_class.py @@ -2,6 +2,7 @@ import os import random import time +import warnings from io import BytesIO from typing import Any, Dict, List, Union @@ -14,6 +15,7 @@ import torch.nn.functional as F import torch_geometric from huggingface_hub import hf_hub_download +from huggingface_hub.errors import EntryNotFoundError from torch_geometric.data import Data from torch_geometric.loader import NeighborLoader from torchmetrics.functional.retrieval import retrieval_auroc @@ -50,10 +52,15 @@ def load_trainer_data_from_hugging_face(trainer_id, args): repo_name = f"FedGraph/fedgraph_{args.dataset}_{args.n_trainer}trainer_{args.num_hops}hop_iid_beta_{args.iid_beta}_trainer_id_{trainer_id}" - def download_and_load_tensor(file_name): - file_path = hf_hub_download( - repo_id=repo_name, repo_type="dataset", filename=file_name - ) + def download_and_load_tensor(file_name, optional=False): + try: + file_path = hf_hub_download( + repo_id=repo_name, repo_type="dataset", filename=file_name + ) + except EntryNotFoundError: + if optional: + return None + raise with open(file_path, "rb") as f: buffer = BytesIO(f.read()) tensor = torch.load(buffer, weights_only=False) @@ -71,6 +78,14 @@ def download_and_load_tensor(file_name): features = download_and_load_tensor("features.pt") in_com_train_node_local_indexes = download_and_load_tensor("idx_train.pt") in_com_test_node_local_indexes = download_and_load_tensor("idx_test.pt") + global_node_num = download_and_load_tensor("global_node_num.pt", optional=True) + class_num = download_and_load_tensor("class_num.pt", optional=True) + if global_node_num is None or class_num is None: + warnings.warn( + "Hugging Face trainer data does not contain global metadata; " + "falling back to inference for compatibility with existing repositories.", + UserWarning, + ) return ( local_node_index, communicate_node_global_index, @@ -80,6 +95,8 @@ def download_and_load_tensor(file_name): features, in_com_train_node_local_indexes, in_com_test_node_local_indexes, + global_node_num, + class_num, ) @@ -145,6 +162,8 @@ def __init__( features: torch.Tensor = None, idx_train: torch.Tensor = None, idx_test: torch.Tensor = None, + global_node_num: int = None, + class_num: int = None, ): # from gnn_models import GCN_Graph_Classification # Per-trainer seed = global_seed * 1000 + rank (lets us vary across runs @@ -177,6 +196,8 @@ def __init__( features, idx_train, idx_test, + global_node_num, + class_num, ) = load_trainer_data_from_hugging_face(rank, args) self.rank = rank # rank = trainer ID @@ -207,22 +228,30 @@ def __init__( self.args = args self.model = None self.optimizer = None - self.global_node_num = None - self.class_num = None + self.global_node_num = ( + int(global_node_num.item()) + if isinstance(global_node_num, torch.Tensor) + else global_node_num + ) + self.class_num = ( + int(class_num.item()) if isinstance(class_num, torch.Tensor) else class_num + ) self.feature_aggregation = None if self.args.method == "FedAvg": # print("Loading feature as the feature aggregation for fedavg method") self.feature_aggregation = self.features def get_info(self): - # assert self.train_labels.numel() > 0, "train_labels is empty" - # assert self.test_labels.numel() > 0, "test_labels is empty" + label_nums = [ + int(labels.max().item()) + 1 + for labels in (self.train_labels, self.test_labels) + if labels.numel() > 0 + ] return { "features_num": len(self.features), - "label_num": max( - self.train_labels.max().item(), self.test_labels.max().item() - ) - + 1, + "label_num": max(label_nums, default=None), + "global_node_num": self.global_node_num, + "class_num": self.class_num, "feature_shape": self.features.shape[1], "len_in_com_train_node_local_indexes": len(self.idx_train), "len_in_com_test_node_local_indexes": len(self.idx_test), @@ -342,14 +371,14 @@ def get_local_feature_sum(self) -> torch.Tensor: normalized_sum : torch.Tensor The normalized sum of features of 1-hop neighbors for each node """ - # Use global_node_num if available, otherwise infer from communicate_node_index - global_node_num = getattr(self, 'global_node_num', None) - if global_node_num is None: - global_node_num = self.communicate_node_index.max().item() + 1 - + if self.global_node_num is None: + raise RuntimeError( + "Trainer model metadata must be initialized before feature aggregation" + ) + # Create a large matrix with known local node features new_feature_for_trainer = torch.zeros( - global_node_num, self.features.shape[1] + self.global_node_num, self.features.shape[1] ).to(self.device) new_feature_for_trainer[self.local_node_index] = self.features @@ -857,10 +886,7 @@ def train(self, current_global_round: int) -> None: self.train_losses.append(loss_train) self.train_accs.append(acc_train) # print(f"acc_train: {acc_train}") - loss_test, acc_test = self.local_test() - self.test_losses.append(loss_test) - self.test_accs.append(acc_test) - # print(f"current round: {current_global_round}, acc_test: {acc_test}") + self.local_test() def local_test(self) -> list: """ diff --git a/fedgraph/utils_nc.py b/fedgraph/utils_nc.py index e3aa234..de4d24d 100644 --- a/fedgraph/utils_nc.py +++ b/fedgraph/utils_nc.py @@ -10,8 +10,7 @@ import scipy.sparse as sp import torch import torch_geometric - -# from huggingface_hub import HfApi, HfFolder, hf_hub_download, upload_file +from huggingface_hub import HfApi, get_token def normalize(mx: sp.csc_matrix) -> sp.csr_matrix: @@ -299,6 +298,14 @@ def get_in_comm_indexes( edge_indexes_clients : list A list of tensors representing the edges between nodes within each client's subgraph. """ + if L_hop not in (0, 2): + raise ValueError( + "FedGraph NC currently only supports num_hops=0 for FedAvg and " + "num_hops=2 for FedGCN-style training. num_hops=1 is not " + "supported because the current implementation is equivalent to " + "the 2-hop path." + ) + communicate_node_indexes = [] in_com_train_node_indexes = [] edge_indexes_clients = [] @@ -316,7 +323,7 @@ def get_in_comm_indexes( ) del _ del __ - elif L_hop == 1 or L_hop == 2: + elif L_hop == 2: ( communicate_node_index, current_edge_index, @@ -452,6 +459,7 @@ def get_1hop_feature_sum( (num_nodes, num_nodes), ).to(device) summed_features = torch.sparse.mm(adjacency_matrix.float(), node_features) + # Self features are included only if edge_index already contains self-loops. else: for node in range(num_nodes): neighbor_indices = torch.where( @@ -504,10 +512,12 @@ def save_trainer_data_to_hugging_face( features, in_com_train_node_local_indexes, in_com_test_node_local_indexes, + global_node_num, + class_num, args, ): repo_name = f"FedGraph/fedgraph_{args.dataset}_{args.n_trainer}trainer_{args.num_hops}hop_iid_beta_{args.iid_beta}_trainer_id_{trainer_id}" - user = HfFolder.get_token() + user = get_token() api = HfApi() try: @@ -538,6 +548,8 @@ def save_tensor_to_hf(tensor, file_name): save_tensor_to_hf(features, "features.pt") save_tensor_to_hf(in_com_train_node_local_indexes, "idx_train.pt") save_tensor_to_hf(in_com_test_node_local_indexes, "idx_test.pt") + save_tensor_to_hf(torch.tensor(global_node_num), "global_node_num.pt") + save_tensor_to_hf(torch.tensor(class_num), "class_num.pt") print(f"Uploaded data for trainer {trainer_id}") @@ -551,8 +563,10 @@ def save_all_trainers_data( in_com_train_node_local_indexes, in_com_test_node_local_indexes, n_trainer, + class_num, args, ): + global_node_num = len(features) for i in range(n_trainer): save_trainer_data_to_hugging_face( trainer_id=i, @@ -568,5 +582,7 @@ def save_all_trainers_data( features=features[split_node_indexes[i]], in_com_train_node_local_indexes=in_com_train_node_local_indexes[i], in_com_test_node_local_indexes=in_com_test_node_local_indexes[i], + global_node_num=global_node_num, + class_num=class_num, args=args, ) diff --git a/pytest.ini b/pytest.ini index 99b8fc9..f355f09 100644 --- a/pytest.ini +++ b/pytest.ini @@ -10,7 +10,8 @@ markers = slow: Slow running tests gpu: Tests requiring GPU ray: Tests requiring Ray cluster + timeout: Tests with explicit runtime timeout filterwarnings = ignore::DeprecationWarning ignore::UserWarning - ignore::PendingDeprecationWarning \ No newline at end of file + ignore::PendingDeprecationWarning diff --git a/quickstart.py b/quickstart.py index 31d14da..db519df 100644 --- a/quickstart.py +++ b/quickstart.py @@ -35,7 +35,7 @@ "batch_size": -1, # -1 indicates full batch training # Model Structure "num_layers": 2, - "num_hops": 1, # Number of n-hop neighbors for client communication + "num_hops": 2, # Supported NC communication modes: 0 for FedAvg, 2 for FedGCN # Resource and Hardware Settings "gpu": False, "num_cpus_per_trainer": 1, diff --git a/run_docker_openfhe.sh b/run_docker_openfhe.sh index 024d107..0e9a804 100755 --- a/run_docker_openfhe.sh +++ b/run_docker_openfhe.sh @@ -24,4 +24,4 @@ docker run -it --rm \ fedgraph-openfhe \ /bin/bash -echo "👋 Container stopped." \ No newline at end of file +echo "👋 Container stopped." \ No newline at end of file diff --git a/run_openfhe_delta.slurm b/run_openfhe_delta.slurm index b3ee8f5..455ba2f 100644 --- a/run_openfhe_delta.slurm +++ b/run_openfhe_delta.slurm @@ -37,20 +37,20 @@ source $HOME/openfhe_env/bin/activate if [ ! -f "$HOME/openfhe_env/.installed" ]; then echo "Installing dependencies..." pip install --upgrade pip - + # Install PyTorch first pip install torch --index-url https://download.pytorch.org/whl/cpu - + # Install torch-geometric and its dependencies pip install torch-geometric pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.0.0+cpu.html - + # Install OpenFHE pip install openfhe==1.2.3.0.24.4 - + # Install other dependencies pip install ray[default] attridict ogb pyyaml networkx scipy scikit-learn - + # Clone and install fedgraph cd $HOME if [ ! -d "$HOME/fedgraph" ]; then @@ -58,7 +58,7 @@ if [ ! -f "$HOME/openfhe_env/.installed" ]; then fi cd $HOME/fedgraph pip install --no-deps . - + touch $HOME/openfhe_env/.installed echo "Dependencies installed successfully" fi @@ -75,4 +75,3 @@ python FGL_NC_HE.py echo "" echo "Job finished at $(date)" - diff --git a/run_openfhe_interactive_delta.sh b/run_openfhe_interactive_delta.sh index 6590e8e..fd974b5 100755 --- a/run_openfhe_interactive_delta.sh +++ b/run_openfhe_interactive_delta.sh @@ -54,7 +54,7 @@ if [ ! -f "$HOME/openfhe_env/.installed" ]; then pip install -q torch-geometric pip install -q openfhe==1.2.3.0.24.4 pip install -q ray[default] attridict ogb pyyaml networkx scipy scikit-learn - + # Clone repo cd $HOME if [ ! -d "$HOME/fedgraph" ]; then @@ -62,7 +62,7 @@ if [ ! -f "$HOME/openfhe_env/.installed" ]; then fi cd $HOME/fedgraph pip install -q --no-deps . - + touch $HOME/openfhe_env/.installed echo "Dependencies installed!" fi @@ -83,4 +83,3 @@ echo "" bash EOF - diff --git a/test_openfhe_smoke.py b/test_openfhe_smoke.py index ad2b7e5..45575e6 100644 --- a/test_openfhe_smoke.py +++ b/test_openfhe_smoke.py @@ -8,7 +8,7 @@ def test_basic_ckks(): """Test basic CKKS functionality (no multiparty).""" print("🔍 Testing basic OpenFHE CKKS...") - + # Create context with conservative parameters params = openfhe.CCParamsCKKSRNS() params.SetSecurityLevel(openfhe.HEStd_128_classic) @@ -17,44 +17,44 @@ def test_basic_ckks(): params.SetScalingModSize(40) params.SetFirstModSize(50) print("✅ Parameters set") - + cc = openfhe.GenCryptoContext(params) print("✅ Context created") - + # Enable basic features - cc.Enable(openfhe.PKESchemeFeature.PKE) - cc.Enable(openfhe.PKESchemeFeature.SHE) + cc.Enable(openfhe.PKE) + cc.Enable(openfhe.LEVELEDSHE) print("✅ Features enabled") - + # Generate keys kp = cc.KeyGen() print("✅ Keys generated") - + # Test data x = [1.0, 2.0, 3.0] scale = 2**40 pt = cc.MakeCKKSPackedPlaintext(x, scale) print("✅ Plaintext created") - + # Encrypt ct = cc.Encrypt(kp.publicKey, pt) print("✅ Encrypted") - + # Decrypt decrypted = cc.Decrypt(ct, kp.secretKey) decrypted.SetLength(len(x)) # Set logical length result = decrypted.GetRealPackedValue() print("✅ Decrypted") - + # Check result print(f"Expected: {x}") print(f"Result: {result[:len(x)]}") - + # Verify accuracy errors = [abs(e - r) for e, r in zip(x, result[:len(x)])] max_error = max(errors) print(f"Max error: {max_error:.2e}") - + if max_error < 1e-3: print("🎉 Basic CKKS test PASSED!") return True @@ -69,7 +69,7 @@ def test_import_speed(): import openfhe import_time = time.time() - start print(f"✅ Import took {import_time:.2f} seconds") - + if import_time < 5.0: print("🎉 Import speed OK!") return True @@ -80,25 +80,23 @@ def test_import_speed(): if __name__ == "__main__": print("🚀 OpenFHE Smoke Test") print("=" * 50) - + # Test import speed first import_ok = test_import_speed() print() - + # Test basic CKKS ckks_ok = test_basic_ckks() print() - + # Summary print("📊 Summary:") print(f" Import speed: {'✅' if import_ok else '❌'}") print(f" Basic CKKS: {'✅' if ckks_ok else '❌'}") - + if import_ok and ckks_ok: print("\n🎉 All tests PASSED! Ready for threshold HE.") exit(0) else: print("\n❌ Some tests FAILED. Check environment setup.") exit(1) - - diff --git a/tests/integration/test_fedgraph_integration.py b/tests/integration/test_fedgraph_integration.py index c3e3a53..300422c 100644 --- a/tests/integration/test_fedgraph_integration.py +++ b/tests/integration/test_fedgraph_integration.py @@ -140,7 +140,7 @@ def test_nc_server_aggregation(self, mock_aggre_gcn): # Create server args = Mock() - args.num_hops = 1 + args.num_hops = 2 args.dataset = "cora" args.num_layers = 2 @@ -158,11 +158,13 @@ def test_nc_server_aggregation(self, mock_aggre_gcn): # Test parameter broadcast mock_model.state_dict.return_value = {'weight': torch.randn(64, 50)} - server.broadcast_params(current_global_epoch=1) + mock_model.parameters.return_value = [torch.randn(64, 50)] + with patch("fedgraph.server_class.ray.get"): + server.broadcast_params(current_global_epoch=1) # Verify all trainers received updates for trainer in trainers: - trainer.update_params.assert_called() + trainer.update_params.remote.assert_called() @pytest.mark.integration @@ -471,15 +473,16 @@ def test_data_trainer_server_interaction(self): class_num=3, device=torch.device('cpu'), trainers=[trainer], - args=Mock(num_hops=1, dataset="test", num_layers=2) + args=Mock(num_hops=2, dataset="test", num_layers=2) ) # Test parameter flow trainer.update_params = Mock() mock_server_model.state_dict.return_value = {'weight': torch.randn(32, 20)} - server.broadcast_params(current_global_epoch=1) - trainer.update_params.assert_called() + with patch("fedgraph.server_class.ray.get"): + server.broadcast_params(current_global_epoch=1) + trainer.update_params.remote.assert_called() def test_monitor_integration(self): """Test monitoring system integration.""" @@ -559,4 +562,4 @@ def test_memory_efficiency(self): assert len(difference) > 0 # Clean up - del large_features, large_edge_index, intersection, difference \ No newline at end of file + del large_features, large_edge_index, intersection, difference diff --git a/tests/test_smoke_e2e.py b/tests/test_smoke_e2e.py index 962fa79..1974033 100644 --- a/tests/test_smoke_e2e.py +++ b/tests/test_smoke_e2e.py @@ -50,7 +50,7 @@ def _base_cora_config(**overrides): "n_trainer": 2, "batch_size": -1, "num_layers": 2, - "num_hops": 1, + "num_hops": 2, "gpu": False, "num_cpus_per_trainer": 1, "num_gpus_per_trainer": 0, diff --git a/tests/test_threshold_ckks_min.py b/tests/test_threshold_ckks_min.py index de5a610..4995b7f 100644 --- a/tests/test_threshold_ckks_min.py +++ b/tests/test_threshold_ckks_min.py @@ -11,8 +11,14 @@ def make_cc(): params.SetFirstModSize(60) params.SetScalingTechnique(openfhe.FLEXIBLEAUTOEXT) cc = openfhe.GenCryptoContext(params) - for f in ("PKE", "SHE", "LEVELEDSHE", "MULTIPARTY"): - cc.Enable(getattr(openfhe.PKESchemeFeature, f)) + for feature in ( + openfhe.PKE, + openfhe.KEYSWITCH, + openfhe.LEVELEDSHE, + openfhe.ADVANCEDSHE, + openfhe.MULTIPARTY, + ): + cc.Enable(feature) return cc def test_two_party_threshold_ckks_add(): @@ -23,22 +29,16 @@ def test_two_party_threshold_ckks_add(): pk0 = kp_lead.publicKey sk0 = kp_lead.secretKey - # Non-lead + # Non-lead. Its public key is the joint public key. kp_main = cc.MultipartyKeyGen(pk0) - pk1 = kp_main.publicKey + joint_pk = kp_main.publicKey sk1 = kp_main.secretKey - # Finalize joint PK on lead - kp_final = cc.MultipartyKeyGen(pk1) - joint_pk = kp_final.publicKey - # Data x = [0.1, 0.2, 0.3] y = [0.05, 0.1, 0.15] - scale = 2**50 - pt_x = cc.MakeCKKSPackedPlaintext(x, scale) - pt_y = cc.MakeCKKSPackedPlaintext(y, scale) - + pt_x = cc.MakeCKKSPackedPlaintext(x) + pt_y = cc.MakeCKKSPackedPlaintext(y) ct_x = cc.Encrypt(joint_pk, pt_x) ct_y = cc.Encrypt(joint_pk, pt_y) ct_sum = cc.EvalAdd(ct_x, ct_y) @@ -53,11 +53,9 @@ def test_two_party_threshold_ckks_add(): expect = [a+b for a,b in zip(x,y)] print(f"Expected: {expect}") print(f"Result: {out[:len(expect)]}") - + assert all(abs(e-r) < 1e-3 for e,r in zip(expect, out[:len(expect)])) print("✅ Two-party threshold CKKS test passed!") if __name__ == "__main__": test_two_party_threshold_ckks_add() - - diff --git a/tests/unit/test_federated_methods.py b/tests/unit/test_federated_methods.py index 7e175a0..acb1461 100644 --- a/tests/unit/test_federated_methods.py +++ b/tests/unit/test_federated_methods.py @@ -5,6 +5,8 @@ import attridict from fedgraph.federated_methods import ( + _resolve_nc_class_num, + _resolve_nc_global_node_num, run_fedgraph, run_fedgraph_enhanced, run_NC, @@ -13,6 +15,71 @@ ) +class TestResolveNCClassNum: + def test_uses_authoritative_loaded_class_num(self): + trainer_information = [{"label_num": 2}, {"label_num": None}] + + assert _resolve_nc_class_num(False, trainer_information, 7) == 7 + + def test_infers_huggingface_class_num_from_nonempty_trainers(self): + trainer_information = [ + {"label_num": None}, + {"label_num": 3}, + {"label_num": 5}, + ] + + assert _resolve_nc_class_num(True, trainer_information) == 5 + + def test_uses_huggingface_class_num_metadata(self): + trainer_information = [ + {"class_num": 7, "label_num": 2}, + {"class_num": 7, "label_num": 5}, + ] + + assert _resolve_nc_class_num(True, trainer_information) == 7 + + def test_rejects_inconsistent_huggingface_class_num_metadata(self): + trainer_information = [{"class_num": 3}, {"class_num": 4}] + + with pytest.raises(ValueError, match="inconsistent class_num"): + _resolve_nc_class_num(True, trainer_information) + + def test_rejects_huggingface_data_without_labels(self): + trainer_information = [{"label_num": None}, {"label_num": None}] + + with pytest.raises(ValueError, match="all train and test label tensors are empty"): + _resolve_nc_class_num(True, trainer_information) + + +class TestResolveNCGlobalNodeNum: + def test_uses_authoritative_loaded_node_count(self): + trainer_information = [{"features_num": 40}, {"features_num": 50}] + + assert _resolve_nc_global_node_num(False, trainer_information, 100) == 100 + + def test_infers_huggingface_node_count_from_owned_features(self): + trainer_information = [{"features_num": 40}, {"features_num": 60}] + + assert _resolve_nc_global_node_num(True, trainer_information) == 100 + + def test_uses_huggingface_global_node_num_metadata(self): + trainer_information = [ + {"global_node_num": 100, "features_num": 40}, + {"global_node_num": 100, "features_num": 50}, + ] + + assert _resolve_nc_global_node_num(True, trainer_information) == 100 + + def test_rejects_inconsistent_huggingface_global_node_num_metadata(self): + trainer_information = [ + {"global_node_num": 100}, + {"global_node_num": 101}, + ] + + with pytest.raises(ValueError, match="inconsistent global_node_num"): + _resolve_nc_global_node_num(True, trainer_information) + + class TestRunFedgraph: """Test run_fedgraph main orchestration function.""" @@ -24,6 +91,7 @@ def setup_method(self): self.args.method = "FedAvg" self.args.use_encryption = False self.args.use_huggingface = False + self.args.num_hops = 2 @patch('fedgraph.federated_methods.data_loader') @patch('fedgraph.federated_methods.run_NC') @@ -114,6 +182,16 @@ def test_run_fedgraph_lowrank_with_openfhe_dispatches_to_run_nc( run_fedgraph(self.args) mock_run_nc.assert_called_once() + + @patch('fedgraph.federated_methods.data_loader') + def test_run_fedgraph_rejects_unsupported_nc_num_hops(self, mock_data_loader): + """Test that ambiguous 1-hop NC mode is rejected before data loading.""" + self.args.num_hops = 1 + + with pytest.raises(ValueError, match="num_hops=1 is not supported"): + run_fedgraph(self.args) + + mock_data_loader.assert_not_called() @patch('fedgraph.federated_methods.data_loader') @patch('fedgraph.federated_methods.run_NC') @@ -445,4 +523,4 @@ def test_data_loading_logic(self, mock_data_loader): run_fedgraph(args) mock_data_loader.assert_not_called() - mock_run_nc.assert_called_once_with(args, None) \ No newline at end of file + mock_run_nc.assert_called_once_with(args, None) diff --git a/tests/unit/test_server_class.py b/tests/unit/test_server_class.py index 5cf05a1..9ad5c18 100644 --- a/tests/unit/test_server_class.py +++ b/tests/unit/test_server_class.py @@ -33,14 +33,14 @@ def setup_method(self): # Mock args self.args = Mock() - self.args.num_hops = 1 + self.args.num_hops = 2 self.args.dataset = "cora" self.args.num_layers = 2 self.args.method = "FedAvg" @patch('fedgraph.server_class.AggreGCN') def test_server_init_with_hops(self, mock_aggre_gcn): - """Test Server initialization with num_hops >= 1.""" + """Test Server initialization with FedGCN-style num_hops.""" mock_model = Mock() mock_aggre_gcn.return_value = mock_model mock_model.to.return_value = mock_model @@ -155,11 +155,216 @@ def test_broadcast_params(self): for trainer in server.trainers: trainer.update_params = Mock() - server.broadcast_params(current_global_epoch=1) + with patch("fedgraph.server_class.ray.get") as mock_ray_get: + server.broadcast_params(current_global_epoch=1) # Verify that all trainers received parameter updates for trainer in server.trainers: - trainer.update_params.assert_called_once() + trainer.update_params.remote.assert_called_once() + mock_ray_get.assert_called_once() + + def test_train_runs_complete_plaintext_round(self): + """Test local training, subset aggregation, and broadcast as one round.""" + server = Server.__new__(Server) + server.model = torch.nn.Linear(1, 1, bias=False) + server.device = torch.device("cpu") + server.use_encryption = False + server.num_of_trainers = 3 + + trainers = [] + param_refs = {} + parameter_values = [1.0, 100.0, 5.0] + for index, value in enumerate(parameter_values): + trainer = Mock() + trainer.train.remote.return_value = f"train-{index}" + trainer.get_params.remote.return_value = f"params-{index}" + trainer.update_params.remote.return_value = f"update-{index}" + param_refs[f"params-{index}"] = (torch.tensor([[value]]),) + trainers.append(trainer) + server.trainers = trainers + + def resolve(refs): + if isinstance(refs, list): + return [resolve(ref) for ref in refs] + return param_refs.get(refs, True) + + def wait_for_one(refs, **_kwargs): + return refs[:1], refs[1:] + + with patch("fedgraph.server_class.random.sample", return_value=[0, 2]), patch( + "fedgraph.server_class.ray.get", side_effect=resolve + ), patch( + "fedgraph.server_class.ray.wait", side_effect=wait_for_one + ), patch( + "fedgraph.server_class.time.time", side_effect=[10.0, 12.0, 12.0, 15.0] + ): + round_stats = server.train(4, sample_ratio=0.75) + + assert server.model.weight.item() == pytest.approx(3.0) + assert round_stats == { + "training_time": 2.0, + "communication_time": 3.0, + "upload_size": 8, + "download_size": 12, + "num_trainers": 2, + } + trainers[0].train.remote.assert_called_once_with(4) + trainers[1].train.remote.assert_not_called() + trainers[2].train.remote.assert_called_once_with(4) + for trainer in trainers: + trainer.update_params.remote.assert_called_once() + + def test_train_runs_complete_encrypted_round(self): + """Test encrypted training returns actual communication statistics.""" + server = Server.__new__(Server) + server.model = torch.nn.Linear(1, 1, bias=False) + server.device = torch.device("cpu") + server.use_encryption = True + server.num_of_trainers = 2 + server.aggregation_stats = [] + + trainers = [] + encrypted_refs = {} + decryption_refs = {} + metadata = [{"shape": torch.Size([1, 1]), "scale": 100.0}] + for index, payload in enumerate([b"one", b"two"]): + trainer = Mock() + trainer.train.remote.return_value = f"train-{index}" + trainer.get_encrypted_params.remote.return_value = f"encrypted-{index}" + trainer.load_encrypted_params.remote.return_value = f"decrypted-{index}" + encrypted_refs[f"encrypted-{index}"] = ([payload], metadata) + decryption_refs[f"decrypted-{index}"] = 0.1 + trainers.append(trainer) + server.trainers = trainers + server.aggregate_encrypted_params = Mock( + return_value=([b"avg"], metadata, 0.25) + ) + + def resolve(refs): + if isinstance(refs, list): + return [resolve(ref) for ref in refs] + if refs in encrypted_refs: + return encrypted_refs[refs] + return decryption_refs.get(refs, True) + + def wait_for_one(refs, **_kwargs): + return refs[:1], refs[1:] + + with patch( + "fedgraph.server_class.ray.get", side_effect=resolve + ), patch( + "fedgraph.server_class.ray.wait", side_effect=wait_for_one + ), patch( + "fedgraph.server_class.time.time", + side_effect=[10.0, 12.0, 12.0, 12.5, 13.0, 15.0], + ): + round_stats = server.train(4) + + assert round_stats["training_time"] == 2.0 + assert round_stats["communication_time"] == 3.0 + assert round_stats["encryption_time"] == 0.5 + assert round_stats["upload_size"] == 6 + assert round_stats["download_size"] == 6 + assert round_stats["num_trainers"] == 2 + assert server.aggregation_stats == [round_stats] + for trainer in trainers: + trainer.train.remote.assert_called_once_with(4) + trainer.load_encrypted_params.remote.assert_called_once() + + def test_aggregate_encrypted_params_rescales_to_common_layer_scale(self): + """Test encrypted params are rescaled before cross-trainer averaging.""" + server = Server.__new__(Server) + server.he_context = object() + + class FakeCKKSVector: + def __init__(self, value): + self.value = float(value) + + def __iadd__(self, other): + self.value += other.value + return self + + def __imul__(self, scalar): + self.value *= scalar + return self + + def serialize(self): + return self.value + + encrypted_params_list = [ + ([200.0], [{"shape": torch.Size([1]), "scale": 100.0}]), + ([4000.0], [{"shape": torch.Size([1]), "scale": 1000.0}]), + ] + + with patch( + "fedgraph.server_class.ts.ckks_vector_from", + side_effect=lambda _context, payload: FakeCKKSVector(payload), + ), patch("fedgraph.server_class.time.time", side_effect=[10.0, 13.0]): + aggregated_params, metadata, aggregation_time = ( + server.aggregate_encrypted_params(encrypted_params_list) + ) + + assert aggregated_params == [pytest.approx(3000.0)] + assert metadata == [{"shape": torch.Size([1]), "scale": 1000.0}] + assert aggregation_time == 3.0 + + def test_mask_encrypted_feature_sum_keeps_only_selected_rows(self): + """Test encrypted feature sums are masked before trainer download.""" + server = Server.__new__(Server) + server.he_context = object() + + class FakeCKKSVector: + def __init__(self, values): + self.values = [float(value) for value in values] + + def __imul__(self, mask): + self.values = [ + value * float(mask_value) + for value, mask_value in zip(self.values, mask) + ] + return self + + def serialize(self): + return self.values + + encrypted_feature_sum = ( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + torch.Size([3, 2]), + ) + + with patch( + "fedgraph.server_class.ts.ckks_vector_from", + side_effect=lambda _context, payload: FakeCKKSVector(payload), + ): + masked_sum, shape = server.mask_encrypted_feature_sum( + encrypted_feature_sum, torch.tensor([0, 2]) + ) + + assert masked_sum == [1.0, 2.0, 0.0, 0.0, 5.0, 6.0] + assert shape == torch.Size([3, 2]) + + def test_aggregate_encrypted_params_rejects_mismatched_shapes(self): + """Test encrypted averaging fails clearly for incompatible layers.""" + server = Server.__new__(Server) + + encrypted_params_list = [ + ([1.0], [{"shape": torch.Size([1]), "scale": 100.0}]), + ([2.0], [{"shape": torch.Size([2]), "scale": 100.0}]), + ] + + with pytest.raises(ValueError, match="shapes must match"): + server.aggregate_encrypted_params(encrypted_params_list) + + def test_aggregate_encrypted_params_rejects_invalid_scale(self): + """Test encrypted averaging requires positive per-layer scales.""" + server = Server.__new__(Server) + + encrypted_params_list = [ + ([1.0], [{"shape": torch.Size([1]), "scale": 0.0}]), + ] + + with pytest.raises(ValueError, match="scale must be positive"): + server.aggregate_encrypted_params(encrypted_params_list) def test_get_model_size(self): """Test get_model_size method.""" @@ -187,32 +392,6 @@ def test_get_model_size(self): assert isinstance(model_size, float) assert model_size > 0 - def test_prepare_params_for_encryption(self): - """Test prepare_params_for_encryption method.""" - with patch('fedgraph.server_class.AggreGCN') as mock_gcn: - mock_model = Mock() - mock_gcn.return_value = mock_model - mock_model.to.return_value = mock_model - - server = Server( - feature_dim=self.feature_dim, - args_hidden=self.args_hidden, - class_num=self.class_num, - device=self.device, - trainers=self.mock_trainers, - args=self.args - ) - - params = (torch.randn(10, 5), torch.randn(10)) - - result = server.prepare_params_for_encryption(params) - - assert isinstance(result, list) - assert len(result) == len(params) - for item in result: - assert isinstance(item, list) # Flattened and converted to list - - class TestServerGC: """Test Server_GC class for graph classification.""" @@ -420,7 +599,7 @@ def test_server_trainer_interaction(self, mock_gcn_class): device = torch.device('cpu') args = Mock() - args.num_hops = 1 + args.num_hops = 2 args.dataset = "cora" args.num_layers = 2 args.method = "FedAvg" @@ -433,6 +612,10 @@ def test_server_trainer_interaction(self, mock_gcn_class): 'layer1.weight': torch.randn(32, 50), 'layer1.bias': torch.randn(32) } + mock_model.parameters.return_value = [ + torch.randn(32, 50), + torch.randn(32), + ] # Create mock trainers trainers = [] @@ -457,11 +640,12 @@ def test_server_trainer_interaction(self, mock_gcn_class): ) # Test parameter broadcast - server.broadcast_params(current_global_epoch=1) + with patch("fedgraph.server_class.ray.get"): + server.broadcast_params(current_global_epoch=1) # Verify all trainers received updates for trainer in trainers: - trainer.update_params.assert_called_once() + trainer.update_params.remote.assert_called_once() # Test model size computation model_size = server.get_model_size() @@ -520,4 +704,4 @@ def test_server_gc_clustering_workflow(self): server.aggregate_clusterwise(trainer_clusters) # Verify workflow completed successfully - assert True \ No newline at end of file + assert True diff --git a/tests/unit/test_trainer_class.py b/tests/unit/test_trainer_class.py index 05c74e3..af94659 100644 --- a/tests/unit/test_trainer_class.py +++ b/tests/unit/test_trainer_class.py @@ -3,6 +3,7 @@ import numpy as np from unittest.mock import Mock, patch, MagicMock import ray +from huggingface_hub.errors import EntryNotFoundError from fedgraph.trainer_class import ( Trainer_General, @@ -34,7 +35,9 @@ def test_load_trainer_data_success(self, mock_torch_load, mock_open, mock_hf_dow torch.randn(20), # test_labels torch.randn(100, 10), # features torch.randn(80), # in_com_train_node_local_indexes - torch.randn(20) # in_com_test_node_local_indexes + torch.randn(20), # in_com_test_node_local_indexes + torch.tensor(100), # global_node_num + torch.tensor(3), # class_num ] mock_torch_load.side_effect = mock_tensors @@ -46,11 +49,11 @@ def test_load_trainer_data_success(self, mock_torch_load, mock_open, mock_hf_dow result = load_trainer_data_from_hugging_face(trainer_id=0, args=args) - assert len(result) == 8 + assert len(result) == 10 assert all(isinstance(tensor, torch.Tensor) for tensor in result) # Verify calls - assert mock_hf_download.call_count == 8 + assert mock_hf_download.call_count == 10 expected_repo = "FedGraph/fedgraph_cora_5trainer_2hop_iid_beta_0.5_trainer_id_0" mock_hf_download.assert_any_call( repo_id=expected_repo, @@ -58,6 +61,31 @@ def test_load_trainer_data_success(self, mock_torch_load, mock_open, mock_hf_dow filename="local_node_index.pt" ) + @patch('fedgraph.trainer_class.hf_hub_download') + @patch('builtins.open') + @patch('torch.load') + def test_load_existing_repo_without_global_metadata( + self, mock_torch_load, mock_open, mock_hf_download + ): + mock_file = Mock() + mock_file.read.return_value = b"test_tensor_data" + mock_open.return_value.__enter__.return_value = mock_file + mock_torch_load.side_effect = [torch.tensor([i]) for i in range(8)] + + def download_side_effect(*, filename, **kwargs): + if filename in {"global_node_num.pt", "class_num.pt"}: + raise EntryNotFoundError("missing optional metadata") + return "/tmp/test_file.pt" + + mock_hf_download.side_effect = download_side_effect + args = Mock(dataset="cora", n_trainer=5, num_hops=2, iid_beta=0.5) + + with pytest.warns(UserWarning, match="falling back to inference"): + result = load_trainer_data_from_hugging_face(trainer_id=0, args=args) + + assert len(result) == 10 + assert result[-2:] == (None, None) + class TestTrainerGeneral: """Test Trainer_General class.""" @@ -102,13 +130,17 @@ def test_trainer_init_with_data(self, mock_load_data): test_labels=self.test_labels, features=self.features, idx_train=self.idx_train, - idx_test=self.idx_test + idx_test=self.idx_test, + global_node_num=100, + class_num=3, ) assert trainer.rank == self.rank assert trainer.device == self.device assert trainer.args_hidden == self.args_hidden assert trainer.local_step == self.args.local_step + assert trainer.global_node_num == 100 + assert trainer.class_num == 3 assert torch.equal(trainer.local_node_index, self.local_node_index.to(self.device)) assert trainer.feature_aggregation is not None # Should be set for FedAvg mock_load_data.assert_not_called() @@ -124,7 +156,9 @@ def test_trainer_init_without_data(self, mock_load_data): self.test_labels, self.features, self.idx_train, - self.idx_test + self.idx_test, + torch.tensor(100), + torch.tensor(3), ) trainer = Trainer_General( @@ -135,6 +169,8 @@ def test_trainer_init_without_data(self, mock_load_data): ) assert trainer.rank == self.rank + assert trainer.global_node_num == 100 + assert trainer.class_num == 3 mock_load_data.assert_called_once_with(self.rank, self.args) def test_get_info(self): @@ -161,6 +197,38 @@ def test_get_info(self): assert info["features_num"] == len(self.features) expected_label_num = max(self.train_labels.max().item(), self.test_labels.max().item()) + 1 assert info["label_num"] == expected_label_num + + @pytest.mark.parametrize( + ("train_labels", "test_labels", "expected_label_num"), + [ + (torch.tensor([], dtype=torch.long), torch.tensor([0, 2]), 3), + (torch.tensor([0, 1]), torch.tensor([], dtype=torch.long), 2), + ( + torch.tensor([], dtype=torch.long), + torch.tensor([], dtype=torch.long), + None, + ), + ], + ) + def test_get_info_handles_empty_labels( + self, train_labels, test_labels, expected_label_num + ): + trainer = Trainer_General( + rank=self.rank, + args_hidden=self.args_hidden, + device=self.device, + args=self.args, + local_node_index=self.local_node_index, + communicate_node_index=self.communicate_node_index, + adj=self.adj, + train_labels=train_labels, + test_labels=test_labels, + features=self.features, + idx_train=torch.arange(len(train_labels)), + idx_test=torch.arange(len(test_labels)), + ) + + assert trainer.get_info()["label_num"] == expected_label_num @patch('fedgraph.trainer_class.GCN') @patch('fedgraph.trainer_class.GCN_arxiv') @@ -224,6 +292,7 @@ def test_update_params(self): def test_get_local_feature_sum(self): """Test get_local_feature_sum method.""" + local_features = self.features[self.local_node_index] trainer = Trainer_General( rank=self.rank, args_hidden=self.args_hidden, @@ -234,11 +303,12 @@ def test_get_local_feature_sum(self): adj=self.adj, train_labels=self.train_labels, test_labels=self.test_labels, - features=self.features, + features=local_features, idx_train=self.idx_train, idx_test=self.idx_test ) - + trainer.global_node_num = len(self.features) + # Mock the get_1hop_feature_sum function with patch('fedgraph.trainer_class.get_1hop_feature_sum') as mock_get_1hop: mock_get_1hop.return_value = torch.randn(100, 50) @@ -247,6 +317,30 @@ def test_get_local_feature_sum(self): assert isinstance(result, torch.Tensor) mock_get_1hop.assert_called_once() + feature_matrix = mock_get_1hop.call_args.args[0] + assert feature_matrix.shape == ( + trainer.global_node_num, + local_features.shape[1], + ) + + def test_get_local_feature_sum_requires_global_node_num(self): + trainer = Trainer_General( + rank=self.rank, + args_hidden=self.args_hidden, + device=self.device, + args=self.args, + local_node_index=self.local_node_index, + communicate_node_index=self.communicate_node_index, + adj=self.adj, + train_labels=self.train_labels, + test_labels=self.test_labels, + features=self.features[self.local_node_index], + idx_train=self.idx_train, + idx_test=self.idx_test, + ) + + with pytest.raises(RuntimeError, match="metadata must be initialized"): + trainer.get_local_feature_sum() def test_load_feature_aggregation(self): """Test load_feature_aggregation method.""" @@ -298,8 +392,9 @@ def test_get_params(self): assert isinstance(params, tuple) mock_model.state_dict.assert_called_once() + @patch('fedgraph.trainer_class.test') @patch('fedgraph.trainer_class.train') - def test_train_method(self, mock_train_func): + def test_train_method(self, mock_train_func, mock_test_func): """Test train method.""" trainer = Trainer_General( rank=self.rank, @@ -326,12 +421,16 @@ def test_train_method(self, mock_train_func): self.args.batch_size = 0 # Ensure no batching for this test mock_train_func.return_value = (0.5, 0.85) # loss, accuracy + mock_test_func.return_value = (0.3, 0.9) # loss, accuracy trainer.train(current_global_round=1) - mock_train_func.assert_called() - assert len(trainer.train_losses) > 0 - assert len(trainer.train_accs) > 0 + assert mock_train_func.call_count == trainer.local_step + assert mock_test_func.call_count == trainer.local_step + assert len(trainer.train_losses) == trainer.local_step + assert len(trainer.train_accs) == trainer.local_step + assert len(trainer.test_losses) == trainer.local_step + assert len(trainer.test_accs) == trainer.local_step @patch('fedgraph.trainer_class.test') def test_local_test(self, mock_test_func): @@ -662,4 +761,4 @@ def test_trainer_general_full_workflow(self, mock_gcn_class): assert isinstance(params, tuple) # Test rank retrieval - assert trainer.get_rank() == rank \ No newline at end of file + assert trainer.get_rank() == rank diff --git a/tests/unit/test_utils_nc.py b/tests/unit/test_utils_nc.py index 8cbb520..53a9cac 100644 --- a/tests/unit/test_utils_nc.py +++ b/tests/unit/test_utils_nc.py @@ -14,7 +14,9 @@ community_partition_non_iid, get_in_comm_indexes, get_1hop_feature_sum, - increment_dir + increment_dir, + save_all_trainers_data, + save_trainer_data_to_hugging_face, ) @@ -78,7 +80,7 @@ def test_intersect1d_all_same(self): result = intersect1d(t1, t2) assert torch.equal(result.sort().values, torch.tensor([1, 2, 3])) - + def test_setdiff1d_basic(self): """Test setdiff1d function with basic tensors.""" t1 = torch.tensor([1, 2, 3, 4]) @@ -109,6 +111,70 @@ def test_setdiff1d_complete_difference(self): assert torch.equal(result.sort().values, expected.sort().values) +class TestSaveTrainerDataToHuggingFace: + @patch("fedgraph.utils_nc.get_token") + @patch("fedgraph.utils_nc.HfApi") + def test_uploads_global_metadata(self, mock_hf_api, mock_get_token): + api = mock_hf_api.return_value + mock_get_token.return_value = "token" + args = Mock( + dataset="cora", + n_trainer=2, + num_hops=2, + iid_beta=0.5, + ) + + save_trainer_data_to_hugging_face( + trainer_id=0, + local_node_index=torch.tensor([0, 1]), + communicate_node_global_index=torch.tensor([0, 1, 2]), + global_edge_index_client=torch.tensor([[0, 1], [1, 2]]), + train_labels=torch.tensor([0]), + test_labels=torch.tensor([1]), + features=torch.randn(2, 4), + in_com_train_node_local_indexes=torch.tensor([0]), + in_com_test_node_local_indexes=torch.tensor([1]), + global_node_num=3, + class_num=2, + args=args, + ) + + uploaded_files = { + call.kwargs["path_in_repo"] for call in api.upload_file.call_args_list + } + assert "global_node_num.pt" in uploaded_files + assert "class_num.pt" in uploaded_files + + @patch("fedgraph.utils_nc.save_trainer_data_to_hugging_face") + def test_bulk_save_passes_consistent_global_metadata(self, mock_save_trainer): + features = torch.randn(4, 3) + labels = torch.tensor([0, 1, 2, 1]) + + save_all_trainers_data( + split_node_indexes=[torch.tensor([0, 1]), torch.tensor([2, 3])], + communicate_node_global_indexes=[ + torch.tensor([0, 1]), + torch.tensor([2, 3]), + ], + global_edge_indexes_clients=[ + torch.tensor([[0], [1]]), + torch.tensor([[2], [3]]), + ], + labels=labels, + features=features, + in_com_train_node_local_indexes=[torch.tensor([0]), torch.tensor([0])], + in_com_test_node_local_indexes=[torch.tensor([1]), torch.tensor([1])], + n_trainer=2, + class_num=3, + args=Mock(), + ) + + assert mock_save_trainer.call_count == 2 + for call in mock_save_trainer.call_args_list: + assert call.kwargs["global_node_num"] == 4 + assert call.kwargs["class_num"] == 3 + + class TestLabelDirichletPartition: """Test label_dirichlet_partition function.""" @@ -230,7 +296,7 @@ def test_get_in_comm_indexes_basic(self): ] n_trainers = 3 - num_hops = 1 + num_hops = 2 idx_train = torch.tensor([0, 2, 4]) idx_test = torch.tensor([1, 3]) @@ -243,14 +309,31 @@ def test_get_in_comm_indexes_basic(self): (communicate_node_global_indexes, in_com_train_node_local_indexes, in_com_test_node_local_indexes, global_edge_indexes_clients) = result - assert isinstance(communicate_node_global_indexes, dict) - assert isinstance(in_com_train_node_local_indexes, dict) - assert isinstance(in_com_test_node_local_indexes, dict) - assert isinstance(global_edge_indexes_clients, dict) + assert isinstance(communicate_node_global_indexes, list) + assert isinstance(in_com_train_node_local_indexes, list) + assert isinstance(in_com_test_node_local_indexes, list) + assert isinstance(global_edge_indexes_clients, list) assert len(communicate_node_global_indexes) == n_trainers assert len(global_edge_indexes_clients) == n_trainers + def test_get_in_comm_indexes_rejects_unsupported_one_hop(self): + """Test that the old ambiguous 1-hop NC mode is rejected.""" + edge_index = torch.tensor([[0, 1], [1, 2]]) + split_node_indexes = [torch.tensor([0, 1])] + idx_train = torch.tensor([0]) + idx_test = torch.tensor([1]) + + with pytest.raises(ValueError, match="num_hops=1 is not supported"): + get_in_comm_indexes( + edge_index, + split_node_indexes, + 1, + 1, + idx_train, + idx_test, + ) + class TestGet1hopFeatureSum: """Test get_1hop_feature_sum function.""" @@ -386,7 +469,7 @@ def test_graph_communication_workflow(self): ]) n_trainers = 3 - num_hops = 1 + num_hops = 2 idx_train = torch.arange(0, 7) idx_test = torch.arange(7, 10) @@ -407,4 +490,4 @@ def test_graph_communication_workflow(self): # Verify all trainers have some data for i in range(n_trainers): assert i in communicate_indexes - assert i in edge_indexes \ No newline at end of file + assert i in edge_indexes diff --git a/tutorials/FGL_NC.py b/tutorials/FGL_NC.py index ab17311..b552c2b 100644 --- a/tutorials/FGL_NC.py +++ b/tutorials/FGL_NC.py @@ -33,7 +33,7 @@ "batch_size": -1, # -1 indicates full batch training # Model Structure "num_layers": 2, - "num_hops": 1, # Number of n-hop neighbors for client communication + "num_hops": 2, # Supported NC communication modes: 0 for FedAvg, 2 for FedGCN # Resource and Hardware Settings "gpu": False, "num_cpus_per_trainer": 1, diff --git a/tutorials/FGL_NC_HE.py b/tutorials/FGL_NC_HE.py index d196e6d..b9d68b1 100644 --- a/tutorials/FGL_NC_HE.py +++ b/tutorials/FGL_NC_HE.py @@ -33,7 +33,7 @@ "batch_size": -1, # -1 indicates full batch training # Model Structure "num_layers": 2, - "num_hops": 1, # Number of n-hop neighbors for client communication + "num_hops": 2, # Supported NC communication modes: 0 for FedAvg, 2 for FedGCN # Resource and Hardware Settings "gpu": False, "num_cpus_per_trainer": 1,