
DanielWay17/Domain-Generalization


ACIRL: Adversarial Causal Intervention for Representation Learning

Python · PyTorch · MLflow · License: MIT


Abstract

Domain Generalization (DG) aims to learn models that transfer to unseen domains without using any target-domain data during training. Prior works (FACT, CIRL) apply passive Fourier augmentation: they randomly mix amplitude components between domains, with no explicit adversarial pressure.

This work proposes Adversarial Fourier Augmentation (AFA), which actively searches for the hardest amplitude perturbations via gradient ascent, forcing the encoder to learn robust, shape-biased representations instead of relying on shortcut style cues. ACIRL combines AFA with a GAN regularizer on the causal feature space and is evaluated on PACS (4 domains, 7 classes) and OfficeHome (4 domains, 65 classes).
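A minimal PyTorch sketch of the idea behind AFA, assuming the perturbation is a multiplicative scaling of the amplitude spectrum optimized with a few signed gradient-ascent (PGD-style) steps while the phase, which carries most of the shape information, is held fixed. The function names and the `steps`, `step_size`, and `eps` parameters are hypothetical and not taken from the repository; the actual AFA implementation may parameterize or constrain the perturbation differently.

```python
import torch
import torch.nn.functional as F


def split_amplitude_phase(x: torch.Tensor):
    """Amplitude and phase of the 2-D FFT of a batch of images (N, C, H, W)."""
    spec = torch.fft.fft2(x, dim=(-2, -1))
    return spec.abs(), spec.angle()


def recombine(amplitude: torch.Tensor, phase: torch.Tensor) -> torch.Tensor:
    """Inverse 2-D FFT from amplitude and phase; keep the real part."""
    spec = torch.polar(amplitude, phase)
    return torch.fft.ifft2(spec, dim=(-2, -1)).real


def adversarial_fourier_augment(model, x, y, steps=3, step_size=0.1, eps=0.5):
    """Hypothetical AFA sketch: gradient ascent on an amplitude-only perturbation.

    The perturbed amplitude is A * (1 + delta) with delta in [-eps, eps], so the
    amplitude stays non-negative for eps < 1. The phase is never modified.
    """
    amplitude, phase = split_amplitude_phase(x)
    amplitude, phase = amplitude.detach(), phase.detach()
    delta = torch.zeros_like(amplitude, requires_grad=True)
    for _ in range(steps):
        x_adv = recombine(amplitude * (1 + delta), phase)
        loss = F.cross_entropy(model(x_adv), y)
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            delta += step_size * grad.sign()  # ascend: make the loss larger
            delta.clamp_(-eps, eps)           # stay inside the perturbation budget
    with torch.no_grad():
        return recombine(amplitude * (1 + delta), phase)
```

Passive FACT-style mixing would instead replace `amplitude` with a random interpolation between two images' amplitudes; the adversarial variant chooses the amplitude change that hurts the current model most.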


Key Contributions

| # | Contribution |
|---|---|
| 1 | Adversarial Fourier Augmentation (AFA): active causal intervention via gradient ascent on the amplitude spectrum; improves shape bias over the passive mixing used by FACT/CIRL |
| 2 | GAN-regularized causal space: a pluggable Vanilla / WGAN / WGAN-GP module aligns features to a Gaussian prior via the L_lid loss |
| 3 | OfficeHome support: full configs (ResNet-18, ResNet-50), data preparation, and evaluation |
| 4 | MLflow + DagHub tracking: all hyperparameters, metrics, and checkpoints are logged per run |
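A sketch of how a WGAN-GP-style regularizer could push encoder features toward a Gaussian prior, assuming the encoder plays the generator role and samples from N(0, I) serve as the "real" distribution. This wiring is an assumption: the method diagram also shows a separate generator G(z), and the β · L_Normal term is omitted here because its form is not specified. The penalty weight `gp_weight = 10` is the value used in the original WGAN-GP paper, not a repository setting.

```python
import torch


def gradient_penalty(critic, real, fake):
    """WGAN-GP: push the critic's gradient norm toward 1 on random interpolates."""
    alpha = torch.rand(real.size(0), 1, device=real.device)
    interp = (alpha * real + (1 - alpha) * fake).detach().requires_grad_(True)
    scores = critic(interp)
    (grads,) = torch.autograd.grad(scores.sum(), interp, create_graph=True)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()


def critic_loss(critic, features, gp_weight=10.0):
    """Critic step: Gaussian prior samples are 'real', encoder features are 'fake'."""
    prior = torch.randn_like(features)
    fake = features.detach()  # do not update the encoder during the critic step
    wasserstein = critic(fake).mean() - critic(prior).mean()
    return wasserstein + gp_weight * gradient_penalty(critic, prior, fake)


def encoder_adversarial_loss(critic, features):
    """Encoder step (an L_GAN-like term): make features score like prior samples."""
    return -critic(features).mean()
```

With gan.critic_steps = 5, a training loop would take five critic optimizer steps on `critic_loss` for every encoder step on `encoder_adversarial_loss`.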

Method Overview

flowchart TD
    subgraph INPUT["Input"]
        direction LR
        IMG["Source Domain Images"]
        FA["Fourier Augmentation"]
    end

    subgraph BACKBONE["Feature Extraction"]
        ENC["Encoder\n— ResNet-18 / ResNet-50 —"]
    end

    subgraph DISENTANGLE["Causal Disentanglement"]
        MSK["Masker\nGumbel-Softmax  ·  top-k selection"]
        FSUP["f_sup\nCausal Features"]
        FINF["f_inf\nSpurious Features"]
    end

    subgraph CLASSIFY["Classification"]
        CSUP["Classifier  C_sup"]
        CINF["Classifier  C_inf"]
    end

    subgraph GAN_MOD["Generative Regularization"]
        GEN["Generator  G(z)"]
        DISC["Discriminator  D"]
    end

    subgraph LOSSES["Objective"]
        L1["L_cls_sup"]
        L2["L_cls_inf"]
        L3["L_Fac"]
        L4["L_lid  =  L_GAN  +  β · L_Normal"]
        LT["L_total  =  L_cls_sup  +  L_cls_inf  +  λ · L_Fac  +  L_lid"]
    end

    IMG --> ENC
    FA  --> ENC
    ENC -->|"f"| MSK
    MSK -->|"f_sup"| FSUP
    MSK -->|"f_inf"| FINF
    FSUP --> CSUP --> L1
    FINF --> CINF --> L2
    ENC -->|"f_ori vs f_aug"| L3
    ENC -->|"f"| GEN
    GEN --> DISC --> L4
    L1 --> LT
    L2 --> LT
    L3 --> LT
    L4 --> LT

    classDef inputNode   fill:#e8f4fd,stroke:#2c82c9,color:#1a4a7a,font-weight:bold
    classDef backboneNode fill:#eafaf1,stroke:#27ae60,color:#1a5c38,font-weight:bold
    classDef maskNode    fill:#fef9e7,stroke:#d4ac0d,color:#7d6608,font-weight:bold
    classDef featNode    fill:#f5f5f5,stroke:#95a5a6,color:#2c3e50
    classDef classNode   fill:#eaf4fb,stroke:#2980b9,color:#1a4a7a
    classDef ganNode     fill:#f4ecf7,stroke:#8e44ad,color:#5b2c6f,font-weight:bold
    classDef lossNode    fill:#fdedec,stroke:#e74c3c,color:#922b21
    classDef totalNode   fill:#2c3e50,stroke:#2c3e50,color:#ffffff,font-weight:bold

    class IMG,FA inputNode
    class ENC backboneNode
    class MSK maskNode
    class FSUP,FINF featNode
    class CSUP,CINF classNode
    class GEN,DISC ganNode
    class L1,L2,L3,L4 lossNode
    class LT totalNode
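The objective in the diagram can be sketched as follows. CIRL's factorization loss pushes the cross-correlation matrix between original and augmented features toward the identity; this sketch writes L_Fac in that spirit, with a Barlow-Twins-style down-weighting of the off-diagonal terms (`off_diag_weight = 0.005`), which is our assumption rather than the repository's exact formula. `lam` defaults to lam_const = 5.0 from the configuration table, and L_lid is passed in precomputed.

```python
import torch
import torch.nn.functional as F


def _standardize(f: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Zero mean, unit variance per feature dimension across the batch."""
    centered = f - f.mean(dim=0)
    return centered / torch.sqrt(centered.pow(2).mean(dim=0) + eps)


def factorization_loss(f_ori, f_aug, off_diag_weight: float = 0.005):
    """Push the (D, D) cross-correlation of the two views toward the identity."""
    n = f_ori.size(0)
    c = _standardize(f_ori).T @ _standardize(f_aug) / n
    diag = torch.diagonal(c)
    on_diag = (diag - 1).pow(2).sum()                # same dimension: correlated
    off_diag = (c - torch.diag(diag)).pow(2).sum()   # different dimensions: independent
    return on_diag + off_diag_weight * off_diag


def total_loss(logits_sup, logits_inf, y, f_ori, f_aug, l_lid, lam: float = 5.0):
    """L_total = L_cls_sup + L_cls_inf + lam * L_Fac + L_lid, as in the diagram."""
    l_cls_sup = F.cross_entropy(logits_sup, y)
    l_cls_inf = F.cross_entropy(logits_inf, y)
    return l_cls_sup + l_cls_inf + lam * factorization_loss(f_ori, f_aug) + l_lid
```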

Setup

Requirements: Python 3.10+, CUDA-compatible GPU.

git clone https://github.com/DanielWay17/Domain-Generalization
cd Domain-Generalization
python -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt

Tech stack: PyTorch 2.9 · torchvision 0.24 · MLflow 2.9 · DagHub 0.3 · Flake8 / Black


Datasets & Data Preparation

| Dataset | Domains | Classes | Download |
|---|---|---|---|
| PACS | Art Painting, Cartoon, Photo, Sketch | 7 | PACS |
| OfficeHome | Art, Clipart, Product, Real World | 65 | OfficeHome |

Datalist format (one image per line):

/absolute/path/to/image.jpg <label_index>
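A small parser for this format, written as a sketch: `read_datalist` is a hypothetical helper, not the repository's loader. Because the label is always the last token, splitting once from the right keeps image paths that contain spaces intact.

```python
from pathlib import Path
from typing import List, Tuple


def read_datalist(path: str) -> List[Tuple[str, int]]:
    """Parse '<image_path> <label_index>' lines into (path, label) pairs."""
    entries = []
    for lineno, line in enumerate(Path(path).read_text().splitlines(), start=1):
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        parts = line.rsplit(maxsplit=1)  # split off the trailing label only
        if len(parts) != 2:
            raise ValueError(f"{path}:{lineno}: expected '<path> <label>', got {line!r}")
        img_path, label = parts
        entries.append((img_path, int(label)))
    return entries
```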

Generate datalists:

# Kaggle (auto-scans input dir, 80/20 split)
python prepare_data_kaggle.py

# Local OfficeHome
python process_officehome.py --root /path/to/OfficeHomeDataset_10072016 --output data/datalists

# Repath existing lists after moving data
python add_prefix_path.py --prefix /new/base/path --datalist_dir data/datalists
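One plausible way to generate such datalists is sketched below, assuming the usual OfficeHome layout `<root>/<Domain>/<Class>/<image>.jpg`, labels assigned in sorted class-folder order, and a seeded random 80/20 train/val split. All function names are hypothetical, and the repository's scripts may assign labels or split the data differently.

```python
import random
from pathlib import Path
from typing import List, Sequence, Tuple

Entry = Tuple[str, int]
IMAGE_EXTS = {".jpg", ".jpeg", ".png"}


def scan_domain(domain_dir: Path) -> List[Entry]:
    """List (absolute_image_path, label) pairs for one domain folder.

    Labels follow the sorted class-folder names, which stays consistent across
    domains as long as every domain has the same class folders.
    """
    classes = sorted(p.name for p in domain_dir.iterdir() if p.is_dir())
    entries = []
    for label, cls in enumerate(classes):
        for img in sorted((domain_dir / cls).iterdir()):
            if img.suffix.lower() in IMAGE_EXTS:  # skip non-image files
                entries.append((str(img.resolve()), label))
    return entries


def split_entries(entries: Sequence[Entry], train_frac: float = 0.8, seed: int = 0):
    """Seeded shuffle, then split into (train, val); 80/20 by default."""
    shuffled = list(entries)
    random.Random(seed).shuffle(shuffled)
    cut = int(len(shuffled) * train_frac)
    return shuffled[:cut], shuffled[cut:]


def write_datalist(entries: Sequence[Entry], out_path: Path) -> None:
    """Write one '<absolute_path> <label_index>' line per image."""
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text("".join(f"{path} {label}\n" for path, label in entries))
```

For example, `split_entries(scan_domain(Path("/path/to/OfficeHomeDataset_10072016/Art")))` would return train and val lists for the Art domain.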

Training & Evaluation

# Leave-one-out training (recommended)
python shell_train.py --domain <target> --gpu 0 --author <name>

# Manual training
python train.py --source art_painting cartoon sketch \
    --target photo --input_dir data/datalists \
    --output_dir outputs --config PACS/ResNet50 --author <name>

# Evaluation
python shell_test.py --domain <target> --gpu 0

Valid domain names:

| Dataset | Domains |
|---|---|
| PACS | art_painting, cartoon, photo, sketch |
| OfficeHome | Art, Clipart, Product, RealWorld |
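Leave-one-domain-out training holds out each domain in turn as the target and trains on the remaining ones. The sketch below builds the corresponding train.py command lines using only the flags shown above; the helper name is hypothetical, the defaults mirror the manual-training example, and shell_train.py may differ in detail.

```python
import shlex
from typing import Dict, List

DOMAINS: Dict[str, List[str]] = {
    "PACS": ["art_painting", "cartoon", "photo", "sketch"],
    "OfficeHome": ["Art", "Clipart", "Product", "RealWorld"],
}


def leave_one_out_commands(dataset: str, backbone: str, author: str,
                           input_dir: str = "data/datalists",
                           output_dir: str = "outputs") -> List[str]:
    """One train.py invocation per held-out target; all other domains are sources."""
    domains = DOMAINS[dataset]
    commands = []
    for target in domains:
        sources = [d for d in domains if d != target]
        args = ["python", "train.py", "--source", *sources, "--target", target,
                "--input_dir", input_dir, "--output_dir", output_dir,
                "--config", f"{dataset}/{backbone}", "--author", author]
        commands.append(shlex.join(args))  # quotes arguments only when needed
    return commands
```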

Configuration

Config files live in config/<Dataset>/<Backbone>.py.

| Parameter | Description | PACS / ResNet-50 | OfficeHome / ResNet-50 |
|---|---|---|---|
| batch_size | Batch size | 16 | 32 |
| epoch | Total epochs | 50 | 50 |
| T | Gumbel-Softmax temperature | 5.0 | 5.5 |
| k | Top-k causal dimensions | 308 | 1228 |
| num_classes | Output classes | 7 | 65 |
| lam_const | Factorization loss weight (λ) | 5.0 | 5.0 |
| GAN_TYPE | GAN variant: vanilla, wgan, or wgan_gp | wgan | wgan |
| gan.critic_steps | Discriminator (critic) steps per generator step | 5 | 5 |
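T and k parameterize the masker that splits encoder features into f_sup and f_inf. As a stand-in for the repository's Gumbel-Softmax masker, the hypothetical `TopKGumbelMasker` below uses the Gumbel-top-k trick (add Gumbel noise, keep the k highest-scoring dimensions) with a straight-through estimator whose backward pass goes through a temperature-scaled sigmoid. The scorer architecture and the sigmoid relaxation are assumptions, not the repository's design.

```python
import torch
import torch.nn as nn


class TopKGumbelMasker(nn.Module):
    """Hypothetical masker: scores feature dimensions and keeps the top-k as causal."""

    def __init__(self, dim: int, k: int, temperature: float = 5.0):
        super().__init__()
        if not 0 < k <= dim:
            raise ValueError("k must be in (0, dim]")
        self.k = k
        self.temperature = temperature
        # Assumed scorer: a small MLP producing one logit per feature dimension.
        self.scorer = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, f: torch.Tensor):
        logits = self.scorer(f)
        if self.training:
            # Gumbel-top-k: add Gumbel(0, 1) noise so the selected set is sampled.
            u = torch.rand_like(logits).clamp(1e-6, 1 - 1e-6)
            logits = logits - torch.log(-torch.log(u))
        # Hard 0/1 mask on the k highest-scoring dimensions of each sample.
        idx = logits.topk(self.k, dim=1).indices
        hard = torch.zeros_like(logits).scatter(1, idx, 1.0)
        # Straight-through: forward uses `hard`, gradients flow through `soft`.
        soft = torch.sigmoid(logits / self.temperature)
        mask = hard + soft - soft.detach()
        f_sup = f * mask        # causal (superior) features
        f_inf = f * (1 - mask)  # spurious (inferior) features
        return f_sup, f_inf
```

For PACS with ResNet-50 one might construct `TopKGumbelMasker(dim=2048, k=308, temperature=5.0)`, since ResNet-50's pooled feature vector has 2048 dimensions.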

Experiment Tracking

All runs are logged to MLflow; connect to DagHub for remote tracking.

[Figure: MLflow Tracking]


Results

PACS, ResNet-18 (accuracy, %)

| Method | Art Painting | Cartoon | Photo | Sketch | Avg. |
|---|---|---|---|---|---|
| CIRL (original) | 86.10 | 81.12 | 95.99 | 84.27 | 86.87 |
| ACIRL (ours) | 86.00 | 81.05 | 95.75 | 85.90 | 87.18 |

OfficeHome, ResNet-18 (accuracy, %)

| Method | Art | Clipart | Product | Real World | Avg. |
|---|---|---|---|---|---|
| CIRL (original) | 61.48 | 56.28 | 75.06 | 76.64 | 67.36 |
| ACIRL (ours) | 62.15 | 57.40 | 74.55 | 75.80 | 67.48 |
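The Avg. column is the unweighted mean over the four held-out target domains. The snippet below recomputes it from the per-domain numbers in the tables above (it agrees with the reported averages up to rounding) and lists the per-domain differences between ACIRL and CIRL; the helper names are illustrative.

```python
RESULTS = {
    "PACS": {
        "CIRL": {"Art Painting": 86.10, "Cartoon": 81.12, "Photo": 95.99, "Sketch": 84.27},
        "ACIRL": {"Art Painting": 86.00, "Cartoon": 81.05, "Photo": 95.75, "Sketch": 85.90},
    },
    "OfficeHome": {
        "CIRL": {"Art": 61.48, "Clipart": 56.28, "Product": 75.06, "Real World": 76.64},
        "ACIRL": {"Art": 62.15, "Clipart": 57.40, "Product": 74.55, "Real World": 75.80},
    },
}


def average(per_domain):
    """Unweighted mean accuracy over target domains."""
    return sum(per_domain.values()) / len(per_domain)


def per_domain_delta(ours, baseline):
    """Accuracy difference (ours minus baseline) for each target domain."""
    return {d: round(ours[d] - baseline[d], 2) for d in baseline}


if __name__ == "__main__":
    for dataset, methods in RESULTS.items():
        print(dataset, {m: round(average(acc), 3) for m, acc in methods.items()})
        print("  delta:", per_domain_delta(methods["ACIRL"], methods["CIRL"]))
```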
