Domain Generalization (DG) aims to learn models that transfer to unseen domains without any target data. Prior work (FACT, CIRL) applies passive Fourier augmentation: amplitude components are mixed at random between domains, with no explicit adversarial pressure.
This work proposes Adversarial Fourier Augmentation (AFA), which actively searches for the hardest amplitude perturbations via gradient ascent, forcing the encoder to learn robust, shape-biased representations rather than shortcut style cues. AFA is combined with a GAN regularizer on the causal feature space; the resulting method, ACIRL, is evaluated on PACS (4 domains, 7 classes) and OfficeHome (4 domains, 65 classes).
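One way to read the gradient-ascent search is as an inner maximization over a bounded amplitude perturbation. The formulation below is a sketch under assumptions and not necessarily the exact objective used in this repository: the multiplicative perturbation $\delta$, the budget $\epsilon$, the step size $\eta$, the projection $\Pi$, and the use of the classification loss as the ascent objective are all illustrative choices.

```latex
% Sketch only: \delta, \epsilon, \eta, \Pi and the choice of loss are assumptions.
\begin{aligned}
\mathcal{F}(x) &= A(x) \odot e^{\,i P(x)}
  && \text{(amplitude } A \text{, phase } P \text{)} \\
\hat{x}(\delta) &= \mathcal{F}^{-1}\!\big( A(x) \odot (1 + \delta) \odot e^{\,i P(x)} \big),
  && \|\delta\|_\infty \le \epsilon \\
\delta^{(t+1)} &= \Pi_{\|\delta\|_\infty \le \epsilon}
  \Big( \delta^{(t)} + \eta \, \nabla_{\delta}\,
  \mathcal{L}_{\mathrm{cls}}\big( \hat{x}(\delta^{(t)}),\, y \big) \Big)
\end{aligned}
```

Only the amplitude is perturbed and the phase is left untouched, which is the usual motivation for Fourier augmentation: the spatial layout of the image is largely preserved while its style statistics shift. Passive mixing, by contrast, draws the new amplitude at random instead of following the gradient.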
| # | Contribution |
|---|---|
| 1 | Adversarial Fourier Augmentation (AFA): active causal intervention via gradient ascent on the amplitude spectrum; improves shape bias over the passive mixing used by FACT and CIRL |
| 2 | GAN-regularized causal space: a pluggable Vanilla / WGAN / WGAN-GP module aligns features to a Gaussian prior via the L_lid loss (L_GAN + β · L_Normal) |
| 3 | OfficeHome support: full configs (ResNet18, ResNet50), data prep, and evaluation |
| 4 | MLflow + DagHub tracking: all hyperparameters, metrics, and checkpoints are logged per run |
flowchart TD
subgraph INPUT["Input"]
direction LR
IMG["Source Domain Images"]
FA["Fourier Augmentation"]
end
subgraph BACKBONE["Feature Extraction"]
ENC["Encoder\n— ResNet-18 / ResNet-50 —"]
end
subgraph DISENTANGLE["Causal Disentanglement"]
MSK["Masker\nGumbel-Softmax · top-k selection"]
FSUP["f_sup\nCausal Features"]
FINF["f_inf\nSpurious Features"]
end
subgraph CLASSIFY["Classification"]
CSUP["Classifier C_sup"]
CINF["Classifier C_inf"]
end
subgraph GAN_MOD["Generative Regularization"]
GEN["Generator G(z)"]
DISC["Discriminator D"]
end
subgraph LOSSES["Objective"]
L1["L_cls_sup"]
L2["L_cls_inf"]
L3["L_Fac"]
L4["L_lid = L_GAN + β · L_Normal"]
LT["L_total = L_cls_sup + L_cls_inf + λ · L_Fac + L_lid"]
end
IMG --> ENC
FA --> ENC
ENC -->|"f"| MSK
MSK -->|"f_sup"| FSUP
MSK -->|"f_inf"| FINF
FSUP --> CSUP --> L1
FINF --> CINF --> L2
ENC -->|"f_ori vs f_aug"| L3
ENC -->|"f"| GEN
GEN --> DISC --> L4
L1 --> LT
L2 --> LT
L3 --> LT
L4 --> LT
classDef inputNode fill:#e8f4fd,stroke:#2c82c9,color:#1a4a7a,font-weight:bold
classDef backboneNode fill:#eafaf1,stroke:#27ae60,color:#1a5c38,font-weight:bold
classDef maskNode fill:#fef9e7,stroke:#d4ac0d,color:#7d6608,font-weight:bold
classDef featNode fill:#f5f5f5,stroke:#95a5a6,color:#2c3e50
classDef classNode fill:#eaf4fb,stroke:#2980b9,color:#1a4a7a
classDef ganNode fill:#f4ecf7,stroke:#8e44ad,color:#5b2c6f,font-weight:bold
classDef lossNode fill:#fdedec,stroke:#e74c3c,color:#922b21
classDef totalNode fill:#2c3e50,stroke:#2c3e50,color:#ffffff,font-weight:bold
class IMG,FA inputNode
class ENC backboneNode
class MSK maskNode
class FSUP,FINF featNode
class CSUP,CINF classNode
class GEN,DISC ganNode
class L1,L2,L3,L4 lossNode
class LT totalNode
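
The objective nodes in the diagram combine as plain weighted sums. The sketch below restates those two formulas in Python using scalar stand-ins for the loss values; the function names are hypothetical and the numbers in the usage lines are made up for illustration (only `lam = 5.0` matches the `lam_const` setting listed in the configuration table, and the value of `beta` is an assumption).

```python
def lid_loss(l_gan: float, l_normal: float, beta: float) -> float:
    """L_lid = L_GAN + beta * L_Normal, as written in the diagram."""
    return l_gan + beta * l_normal


def total_loss(
    l_cls_sup: float,
    l_cls_inf: float,
    l_fac: float,
    l_lid: float,
    lam: float,
) -> float:
    """L_total = L_cls_sup + L_cls_inf + lam * L_Fac + L_lid."""
    return l_cls_sup + l_cls_inf + lam * l_fac + l_lid


# Illustrative values only; beta = 2.0 is an assumption, not a repo default.
l_lid = lid_loss(l_gan=0.3, l_normal=0.1, beta=2.0)
print(total_loss(l_cls_sup=1.0, l_cls_inf=2.0, l_fac=0.5, l_lid=l_lid, lam=5.0))
```

With `lam = 5.0`, the factorization term is weighted five times as heavily as either classification loss, so it can dominate the total whenever `L_Fac` is of comparable magnitude.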
Requirements: Python 3.10+, CUDA-compatible GPU.
git clone https://github.com/DanielWay17/Domain-Generalization
cd Domain-Generalization
python -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt

Tech stack: PyTorch 2.9 · torchvision 0.24 · MLflow 2.9 · DagHub 0.3 · Flake8 / Black
| Dataset | Domains | Classes | Download |
|---|---|---|---|
| PACS | Art Painting, Cartoon, Photo, Sketch | 7 | PACS |
| OfficeHome | Art, Clipart, Product, Real World | 65 | OfficeHome |
Datalist format (one image per line):
/absolute/path/to/image.jpg <label_index>
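
For reference, a datalist in this format can be read with a few lines of Python. This is a minimal sketch, not the loader used by the repository; `parse_datalist_line` and `read_datalist` are hypothetical helper names. It splits on the last whitespace only, on the assumption that image paths may themselves contain spaces (for example a `Real World` directory).

```python
from typing import List, Tuple


def parse_datalist_line(line: str) -> Tuple[str, int]:
    """Split '<path> <label_index>' on the last run of whitespace.

    Splitting from the right keeps paths that contain spaces intact.
    """
    path, label = line.rsplit(maxsplit=1)
    return path, int(label)


def read_datalist(datalist_path: str) -> List[Tuple[str, int]]:
    """Read every non-empty line of a datalist file into (path, label) pairs."""
    samples = []
    with open(datalist_path, encoding="utf-8") as handle:
        for line in handle:
            if line.strip():  # skip blank lines
                samples.append(parse_datalist_line(line))
    return samples
```

A line with a missing or non-integer label raises `ValueError`, which surfaces malformed datalists early rather than during training.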
Generate datalists:
# Kaggle (auto-scans input dir, 80/20 split)
python prepare_data_kaggle.py
# Local OfficeHome
python process_officehome.py --root /path/to/OfficeHomeDataset_10072016 --output data/datalists
# Repath existing lists after moving data
python add_prefix_path.py --prefix /new/base/path --datalist_dir data/datalists

# Leave-one-out training (recommended)
python shell_train.py --domain <target> --gpu 0 --author <name>
# Manual training
python train.py --source art_painting cartoon sketch \
--target photo --input_dir data/datalists \
--output_dir outputs --config PACS/ResNet50 --author <name>
# Evaluation
python shell_test.py --domain <target> --gpu 0

Valid domain names:
| Dataset | Domains |
|---|---|
| PACS | art_painting cartoon photo sketch |
| OfficeHome | Art Clipart Product RealWorld |
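
In leave-one-out training, each domain is held out in turn as the target while the remaining three act as sources. The sketch below shows that split using the domain names from the table; `leave_one_out` is a hypothetical helper and is not claimed to match how `shell_train.py` selects its sources.

```python
from typing import Dict, List, Tuple

# Domain names exactly as accepted on the command line (see the table above).
DOMAINS: Dict[str, List[str]] = {
    "PACS": ["art_painting", "cartoon", "photo", "sketch"],
    "OfficeHome": ["Art", "Clipart", "Product", "RealWorld"],
}


def leave_one_out(dataset: str, target: str) -> Tuple[List[str], str]:
    """Return (source_domains, target_domain) for one leave-one-out run."""
    domains = DOMAINS[dataset]
    if target not in domains:
        raise ValueError(f"{target!r} is not a valid {dataset} domain: {domains}")
    sources = [domain for domain in domains if domain != target]
    return sources, target


# One run per held-out domain.
for held_out in DOMAINS["PACS"]:
    sources, target = leave_one_out("PACS", held_out)
    print(f"train on {sources}, evaluate on {target}")
```

Holding out `photo` yields the sources `art_painting`, `cartoon`, and `sketch`, which is the combination passed to `train.py` in the manual training command.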
Config files live in config/<Dataset>/<Backbone>.py.
| Parameter | Description | PACS/R50 | OfficeHome/R50 |
|---|---|---|---|
| `batch_size` | Batch size | 16 | 32 |
| `epoch` | Total epochs | 50 | 50 |
| `T` | Gumbel-Softmax temperature | 5.0 | 5.5 |
| `k` | Top-k causal dimensions | 308 | 1228 |
| `num_classes` | Output classes | 7 | 65 |
| `lam_const` | Factorization loss weight | 5.0 | 5.0 |
| `GAN_TYPE` | `vanilla` / `wgan` / `wgan_gp` | `wgan` | `wgan` |
| `gan.critic_steps` | Discriminator steps per G step | 5 | 5 |
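
As an illustration, the PACS/R50 column could be expressed as a plain Python dictionary like the one below. The layout is hypothetical: the key names and values come from the table, but the nesting of `gan` (inferred from the dotted name `gan.critic_steps`), the top-level placement of `GAN_TYPE`, and the `validate` helper are assumptions, so the actual file under `config/PACS/` may be organized differently.

```python
# Hypothetical layout; values are the PACS/R50 column of the table above.
config = {
    "batch_size": 16,
    "epoch": 50,
    "T": 5.0,            # Gumbel-Softmax temperature
    "k": 308,            # number of top-k causal dimensions
    "num_classes": 7,
    "lam_const": 5.0,    # factorization loss weight
    "GAN_TYPE": "wgan",
    "gan": {"critic_steps": 5},  # discriminator steps per generator step
}

VALID_GAN_TYPES = {"vanilla", "wgan", "wgan_gp"}


def validate(cfg: dict) -> None:
    """Fail fast on values the table rules out (illustrative checks only)."""
    if cfg["GAN_TYPE"] not in VALID_GAN_TYPES:
        raise ValueError(f"GAN_TYPE must be one of {sorted(VALID_GAN_TYPES)}")
    if cfg["k"] <= 0:
        raise ValueError("k must be a positive number of feature dimensions")
    if cfg["gan"]["critic_steps"] < 1:
        raise ValueError("gan.critic_steps must be at least 1")


validate(config)
```

Checking the configuration before a 50-epoch run starts is cheap and catches typos such as an unsupported `GAN_TYPE` up front.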
All runs log to MLflow; connect to DagHub for remote tracking.

PACS:

| Method | Art Painting | Cartoon | Photo | Sketch | Avg. |
|---|---|---|---|---|---|
| CIRL (original) | 86.10 | 81.12 | 95.99 | 84.27 | 86.87 |
| ACIRL (ours) | 86.00 | 81.05 | 95.75 | 85.90 | 87.18 |

OfficeHome:

| Method | Art | Clipart | Product | Real World | Avg. |
|---|---|---|---|---|---|
| CIRL (original) | 61.48 | 56.28 | 75.06 | 76.64 | 67.36 |
| ACIRL (ours) | 62.15 | 57.40 | 74.55 | 75.80 | 67.48 |
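
The Avg. column is the unweighted mean of the four per-domain scores. The short script below recomputes it from the numbers in the two tables and prints the difference between ACIRL and CIRL; the variable and function names are illustrative only.

```python
from statistics import mean
from typing import Dict, List

# Per-domain scores copied from the two results tables.
PACS: Dict[str, List[float]] = {
    "CIRL (original)": [86.10, 81.12, 95.99, 84.27],
    "ACIRL (ours)": [86.00, 81.05, 95.75, 85.90],
}
OFFICEHOME: Dict[str, List[float]] = {
    "CIRL (original)": [61.48, 56.28, 75.06, 76.64],
    "ACIRL (ours)": [62.15, 57.40, 74.55, 75.80],
}


def average_accuracy(per_domain: List[float]) -> float:
    """Unweighted mean over held-out domains."""
    return mean(per_domain)


for name, table in (("PACS", PACS), ("OfficeHome", OFFICEHOME)):
    baseline = average_accuracy(table["CIRL (original)"])
    ours = average_accuracy(table["ACIRL (ours)"])
    print(f"{name}: CIRL {baseline:.3f}, ACIRL {ours:.3f}, difference {ours - baseline:+.3f}")
```

The averages differ by roughly 0.3 points on PACS and 0.1 points on OfficeHome. On PACS the gain comes almost entirely from Sketch (84.27 to 85.90), while the other three domains are slightly lower than the CIRL baseline.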