Optimised training, inference and memory for metalearners in multitreatment settings#896

Open
Ic3fr0g wants to merge 5 commits into uber:master from Ic3fr0g:fix/853-tlearner-model-c-train-once

Conversation


@Ic3fr0g Ic3fr0g commented May 9, 2026

Proposed changes

Optimised training, inference and memory for most metalearners in multi-treatment settings.
Fixes #853

Here's a summary of every change made across all learners:

| Learner | Change | Time saved | Space saved |
| --- | --- | --- | --- |
| T | `models_c` dict -> single `model_c` | (N-1) fits on control data + (N-1) forward passes on full X | (N-1) × `model_c` copies in RAM; (N-1) × `model_c` in pickle |
| T | `yhat_cs` dict -> scalar `yhat_c`; loop uses direct boolean index | Loop overhead only | (N-1) arrays of shape (n,) |
| X | `models_mu_c` dict -> single `model_mu_c` | (N-1) fits on control data | (N-1) × `model_mu_c` copies in RAM; (N-1) × `model_mu_c` in pickle |
| X | `y_control_pred` cached after fit; `var_c` computed once before loop | (N-1) forward passes on X_control during fit | 0 (scalar reuse) |
| X | `yhat_c_verbose` hoisted before loop in both predict methods | (N-1) forward passes on X_control during predict (verbose path) | 0 (scalar reuse) |
| DR | `yhat_c` hoisted outside group loop | (N-1) × 3 forward passes on full X per `predict()` call | (N-1) arrays of shape (n,) in `yhat_cs` dict |
| S | `X_new_c`, `X_new_t` hoisted outside group loop | (N-1) × 2 hstack ops on (n × (p+1)) | (N-1) × 2 peak arrays of shape (n, p+1) allocated and dropped each iteration |

Where N = number of treatment groups. For the common binary case (N = 1, a single treatment group plus control), there is no saving; savings are realized in multi-treatment settings, where these learners are most expensive.
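To make the T-learner row concrete, here is a minimal standalone sketch (made-up data and variable names, not the actual causalml code) showing why the old per-group control fits were redundant — every group's control model was trained on exactly the same rows:

```python
import numpy as np
from copy import deepcopy
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
n, p = 120, 3
X = rng.standard_normal((n, p))
treatment = rng.integers(0, 4, size=n)  # 0 = control, 1..3 = treatment groups
y = X[:, 0] + treatment + 0.1 * rng.standard_normal(n)
t_groups = [1, 2, 3]
control_mask = treatment == 0

# Before: one deepcopy + fit of the control model per treatment group,
# even though every fit sees exactly the same control rows.
template = LinearRegression()
models_c_old = {}
for group in t_groups:
    m = deepcopy(template)
    m.fit(X[control_mask], y[control_mask])  # identical fit, repeated N times
    models_c_old[group] = m

# After: fit the control model once and share the single instance.
model_c = deepcopy(template)
model_c.fit(X[control_mask], y[control_mask])

# The shared model reproduces every per-group control prediction.
for group in t_groups:
    assert np.allclose(models_c_old[group].predict(X), model_c.predict(X))
```

The same argument carries over to the X-learner's `model_mu_c`: any model fitted only on control rows is invariant to the treatment group being processed.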

Types of changes

What types of changes does your code introduce to CausalML?
Put an x in the boxes that apply

  • Bugfix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation Update (if none of the other choices apply)

Checklist

Put an x in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your code.

  • I have read the CONTRIBUTING doc
  • I have signed the CLA
  • Lint and unit tests pass locally with my changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have added necessary documentation (if appropriate)
  • Any dependent changes have been merged and published in downstream modules

Further comments

I started off with the T-learner, based on the issue and my own experience using this project at work, then gradually expanded the scope because I saw the same inefficiencies everywhere. Most metalearners carry avoidable memory bloat and redundant training and inference work; this PR aims to solve some of those issues.

Pickle / joblib shared-reference behavior

Pickle uses a memo dict keyed by id(obj). When the same object appears multiple times in a structure, it is serialized once, and subsequent occurrences become tiny backreference opcodes (about 2 bytes each).

So:

  • single model: 137 KB (1.00×)
  • shared-ref dict N=10: 137 KB (1.00×) ← pickle memo kicks in
  • deepcopy dict N=10: 1,359 KB (9.9×) ← 10 fully serialized copies

The overhead of a shared-ref dict vs a single object is just the dict structure + N backreference opcodes, which is negligible for real models.
But the old code ({group: deepcopy(self.model_c) for group in self.t_groups}) created N separate Python objects, so pickle saw N different id() values and serialized each one fully. The new code (self.model_c) removes the dict entirely.
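The size behavior described above is easy to verify directly. The sketch below uses a toy `Model` stand-in (not a real estimator) with an ~80 KB numpy payload:

```python
import pickle
from copy import deepcopy

import numpy as np

class Model:
    """Stand-in for a fitted model carrying a non-trivial payload."""
    def __init__(self):
        self.coef_ = np.zeros(10_000)  # ~80 KB of state

m = Model()
N = 10
single = len(pickle.dumps(m))
shared = len(pickle.dumps({g: m for g in range(N)}))            # one object, N memo refs
copied = len(pickle.dumps({g: deepcopy(m) for g in range(N)}))  # N fully serialized copies

assert shared < single + 1024     # shared-ref dict adds only dict structure + backrefs
assert copied > (N - 2) * single  # deepcopy dict scales roughly linearly with N
```

Because pickle's memo is keyed by `id(obj)`, the shared-ref dict serializes the payload once, while the deepcopied dict pays for it N times.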

Ic3fr0g added 4 commits May 9, 2026 16:50
Previously, BaseTLearner.fit() deep-copied and trained model_c once per
treatment group even though all groups share the same control data. Likewise,
BaseTLearner.predict() (and BaseTClassifier.predict()) called model_c.predict()
once per group despite the model being identical.

- fit(): compute control_mask once, deep-copy model_c once, fit it once, then
  share the single fitted instance in self.models_c across all groups.
- predict() / BaseTClassifier.predict(): call model_c.predict[_proba]() once
  before the loop and reuse the result for each group.

Adds test_BaseTLearner_model_c_trained_and_predicted_once to verify that both
fit() and predict() invoke model_c exactly once regardless of treatment group count.
Now that model_c is trained once before the loop, the per-group mask that
combined treatment and control rows is no longer needed for model_t either.
Replace the mask + treatment_filt + w == 1 chain with a direct boolean index
on the treatment group alone, which is both simpler and faster.
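The mask replacement described in this commit can be checked with a tiny example (toy data; variable names mirror the commit message, not the causalml source):

```python
import numpy as np

treatment = np.array([0, 1, 2, 0, 1, 2, 0])
group = 1

# Old pattern (sketch): combined control+treatment mask, then a w == 1
# filter within the masked rows to recover the treatment rows.
mask = (treatment == group) | (treatment == 0)
treatment_filt = treatment[mask]
w = (treatment_filt == group).astype(int)
old_idx = np.flatnonzero(mask)[w == 1]

# New pattern: index the treatment group directly.
new_idx = np.flatnonzero(treatment == group)

assert np.array_equal(old_idx, new_idx)
```

Both select exactly the rows of the current treatment group; the direct index just skips the intermediate mask and filter arrays.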
TLearner
- fit(): self.models_c dict replaced by a single self.model_c trained once
  on control data; per-group loop now only trains treatment models with a
  direct boolean index (no mask+w chain).
- predict() / BaseTClassifier.predict(): yhat_cs dict removed; yhat_c
  computed once from self.model_c before the loop.
- fit_predict() / estimate_ate(): bootstrap save/restore updated from
  models_c dict to scalar model_c.

XLearner
- fit() / BaseXClassifier.fit(): models_mu_c dict replaced by a single
  self.model_mu_c trained once; per-group loop skips redundant re-training
  and uses direct treatment_mask indexing.
- predict() / BaseXClassifier.predict(): verbose block updated to use
  self.model_mu_c directly.
- fit_predict() / estimate_ate(): bootstrap save/restore updated.

DRLearner
- predict() / BaseDRClassifier.predict(): models_mu_c is fold-specific but
  not group-specific; yhat_c hoisted outside the group loop to avoid
  running the 3-fold ensemble prediction once per group.
- estimate_ate(): unpacking updated from yhat_cs dict to scalar yhat_c.

SLearner
- predict() / BaseSClassifier.predict(): X_new_c and X_new_t (hstack of
  treatment indicator + X) are identical for every group; construction
  hoisted outside the group loop.

RLearner: no change needed — model_mu already fitted once via
cross_val_predict, not per group.
fit() (BaseXLearner and BaseXClassifier):
- Cache y_control_pred = model_mu_c.predict[_proba](X[control_mask]) once
  immediately after fitting model_mu_c.
- Derive var_c from the cached array before the group loop; assign the
  scalar to self.vars_c[group] inside the loop without re-predicting.

predict() (BaseXLearner and BaseXClassifier):
- Pre-compute yhat_c_verbose = model_mu_c.predict[_proba](X[control_mask])
  once before the group loop, guarded by the verbose condition.
- Replace the per-group self.model_mu_c.predict call in the verbose block
  with a direct reference to the cached array.
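The `y_control_pred` caching in this commit can be sketched as follows (assumed names, heavily simplified relative to the actual X-learner code):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(1)
X = rng.standard_normal((90, 2))
treatment = rng.integers(0, 3, size=90)  # 0 = control, 1..2 = treatment groups
y = X[:, 0] + treatment + 0.1 * rng.standard_normal(90)
control_mask = treatment == 0

# Fit the control outcome model once, then cache its in-sample predictions
# immediately instead of re-predicting inside the per-group loop.
model_mu_c = LinearRegression().fit(X[control_mask], y[control_mask])
y_control_pred = model_mu_c.predict(X[control_mask])

# The control variance is group-invariant: compute the scalar once...
var_c = (y[control_mask] - y_control_pred).var()
# ...and assign it per group without re-predicting.
vars_c = {group: var_c for group in (1, 2)}
assert vars_c[1] == vars_c[2] == var_c
```

The per-group loop then reads the cached scalar rather than issuing another forward pass over the control rows.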
@Ic3fr0g Ic3fr0g changed the title Optimised training time and memory for metalearners Optimised training, inference and memory for metalearners in multitreatment settings May 9, 2026
@Ic3fr0g Ic3fr0g marked this pull request as ready for review May 9, 2026 13:12
Author

Ic3fr0g commented May 9, 2026

@jeongyoonlee would greatly appreciate it if you could take a look at this PR!

@jeongyoonlee jeongyoonlee requested a review from Copilot May 9, 2026 16:23
@jeongyoonlee jeongyoonlee added the enhancement New feature or request label May 9, 2026

Copilot AI left a comment


Pull request overview

This PR optimizes multi-treatment meta-learners by removing redundant per-treatment control-model fits/predictions and hoisting shared computations to reduce training/inference time and memory footprint (addresses #853).

Changes:

  • X-learner: fit the control outcome model once, reuse control variance/predictions across treatment groups, and reduce repeated verbose-path predictions.
  • T-learner: fit the control model once and reuse a single control prediction vector across treatment groups.
  • DR-learner and S-learner: hoist group-invariant predictions/feature construction out of per-group loops to avoid repeated work and allocations.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| causalml/inference/meta/xlearner.py | Fits `model_mu_c` once on control data and reuses control-only artifacts across groups; reduces repeated predictions in verbose paths. |
| causalml/inference/meta/tlearner.py | Fits `model_c` once and reuses a single `yhat_c` across groups; adjusts `return_components` outputs accordingly. |
| causalml/inference/meta/slearner.py | Hoists `np.hstack`-built augmented design matrices outside the per-group loop to reduce repeated allocations. |
| causalml/inference/meta/drlearner.py | Computes control predictions once (fold-averaged) and reuses them across treatment groups; adjusts `return_components` outputs accordingly. |


Comment thread — causalml/inference/meta/drlearner.py (outdated), lines +625 to +628:

```diff
 if not return_components:
     return te
 else:
-    return te, yhat_cs, yhat_ts
+    return te, yhat_c, yhat_ts
```

Comment thread — causalml/inference/meta/tlearner.py (outdated):

```diff
     return te
 else:
-    return te, yhat_cs, yhat_ts
+    return te, yhat_c, yhat_ts
```
@jeongyoonlee
Collaborator

Thanks @Ic3fr0g — the math is sound and the savings are real. A few things to address before merge:

1. Silent API break in predict(return_components=True)

T-learner and DR-learner (regressor + classifier) return type changes from (te, dict, dict) to (te, ndarray, dict). CI is green only because test_BaseDRClassifier (test_meta_learners.py:1184) remaps treatment to binary 0/1, so yhat_cs[1] becomes ndarray scalar indexing that happens to return a probability in [0,1] — the assertion passes but no longer checks what it claims. Either keep the dict shape externally (build {group: yhat_c for group in self.t_groups} — pickle memo keeps it cheap), or call this out as a breaking change and update the test + docstrings + changelog.

2. Public state attributes deleted

BaseTLearner.models_c and BaseXLearner.models_mu_c (+ classifier) disappear. Anyone introspecting these or loading old pickles with new code gets AttributeError. Same fix as above — restore as shared-reference dicts.

3. self.model_c = deepcopy(self.model_c) overwrites the template

After fit(), self.model_c is the fitted model, not the unfitted template assigned in __init__. Re-fitting now deepcopies the fitted state then resets it via .fit() — wasted work, and behavior diverges for learners with warm_start=True. Same issue with self.model_mu_c in X-learner. Fix: stash the template (e.g. self._model_c_template) and always deepcopy from it.

4. No multi-treatment test for the central claim

The whole point of this PR is multi-treatment correctness, but no test exercises 3+ treatment groups against a reference. Please add one that asserts predicted te matches the pre-refactor numbers (or a known-equivalent computation) for T, X, S, and DR with N=3 treatments — green CI today does not cover this path.
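The template-stash fix suggested in point 3 can be sketched like this (a hypothetical `TLearnerSketch` class for illustration, not causalml's actual API):

```python
from copy import deepcopy

import numpy as np
from sklearn.linear_model import LinearRegression

class TLearnerSketch:
    """Hypothetical minimal class illustrating the template-stash pattern."""

    def __init__(self, learner):
        # Stash the pristine, unfitted estimator and never overwrite it.
        self._model_c_template = learner

    def fit(self, X, y, control_mask):
        # Always deepcopy from the template so every re-fit starts from a
        # clean state (this matters for estimators with warm_start=True).
        self.model_c = deepcopy(self._model_c_template)
        self.model_c.fit(X[control_mask], y[control_mask])
        return self

rng = np.random.default_rng(0)
X = rng.standard_normal((60, 2))
y = X[:, 0] + 0.1 * rng.standard_normal(60)
control_mask = rng.random(60) < 0.5

learner = TLearnerSketch(LinearRegression()).fit(X, y, control_mask)
assert not hasattr(learner._model_c_template, "coef_")  # template stays unfitted
assert hasattr(learner.model_c, "coef_")                # fitted copy lives separately
```

With `self.model_c = deepcopy(self.model_c)` instead, a second `fit()` would copy the already-fitted state before refitting, which is wasted work and changes behavior for warm-start estimators.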

Non-blocking

  • Returns docstrings still say yhat_cs (dict) — update to yhat_c (ndarray).
  • self.vars_c in X-learner is now a dict of identical scalars; consider collapsing to a single self.var_c.
  • A before/after timing on a realistic N=3 multi-treatment workload would help readers calibrate the win.

Author

Ic3fr0g commented May 9, 2026

Quite a few egregious errors on my end. Should've thought about maintaining backward compatibility properly. I think I'll proceed with this as a new feature and not break existing functionality/docstrings/tests/etc.

Let me know if you would also like me to benchmark the results. I'll make these changes in the morning and send it back to you!

Happy weekend!

…plots

- Adds test for meta-learner consistency and key attributes
Author

Ic3fr0g commented May 10, 2026

Incorporated all review comments!

Benchmarking

Code

"""
Benchmark meta-learners: timing fit() and predict() with fixed synthetic data.

Fixed setup: n=300 samples (DR 3-fold cross-fit needs enough rows per fold per arm),
p=3 features, four treatment arms (1–4) plus control (0). Treatment is balanced
(equal counts per arm) and shuffled.

Produces box plots (wall time, CPU time, peak traced memory) across repeated runs.

Usage:
    python benchmarks/benchmark_metalearners.py --reps 15
"""

from __future__ import annotations

import argparse
import gc
import time
import tracemalloc
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.linear_model import LinearRegression

_SCRIPT_DIR = Path(__file__).resolve().parent
_REPO_ROOT = _SCRIPT_DIR.parent

# Fixed benchmark geometry (do not expose via CLI)
N_SAMPLES = 300
N_FEATURES = 3
N_TREATMENT_ARMS = 4  # treatment labels 1..N plus control 0
_SEED_BASE = 42

_LEARNER_ORDER = ["T", "X", "S", "R", "DR"]
_LEARNER_COLORS = {
    "T": "#E69F00",
    "X": "#56B4E9",
    "S": "#009E73",
    "R": "#CC79A7",
    "DR": "#0072B2",
}


class _Timer:
    """Wall time, CPU time, and peak traced memory for one block."""

    def __enter__(self):
        gc.collect()
        tracemalloc.start()
        self._wall = time.perf_counter()
        self._cpu = time.process_time()
        return self

    def __exit__(self, *_):
        self.wall_s = time.perf_counter() - self._wall
        self.cpu_s = time.process_time() - self._cpu
        _, self.peak_kb = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        self.peak_kb //= 1024


def make_multi_treatment_data(n: int, p: int, n_groups: int, seed: int):
    """Balanced assignment to control (0) and arms 1..n_groups; requires ``n % (n_groups + 1) == 0``."""
    n_levels = n_groups + 1
    if n % n_levels != 0:
        raise ValueError(f"n ({n}) must be divisible by number of treatment levels ({n_levels})")

    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n, p))
    treatment = np.tile(np.arange(n_levels, dtype=int), n // n_levels)
    rng.shuffle(treatment)
    tau = np.where(treatment == 0, 0.0, treatment.astype(float))
    y = X[:, 0] + tau + 0.1 * rng.standard_normal(n)
    return X, treatment, y


def benchmark_one_pass(seed: int) -> list[dict]:
    """One repetition: timed fit then timed predict using the same fitted model."""
    from causalml.inference.meta import (
        BaseDRRegressor,
        BaseRRegressor,
        BaseSRegressor,
        BaseTRegressor,
        BaseXRegressor,
    )

    X, treatment, y = make_multi_treatment_data(
        n=N_SAMPLES,
        p=N_FEATURES,
        n_groups=N_TREATMENT_ARMS,
        seed=seed,
    )
    p_score = {
        g: np.full(N_SAMPLES, 1.0 / (N_TREATMENT_ARMS + 1))
        for g in range(1, N_TREATMENT_ARMS + 1)
    }
    rows = []

    def add(learner: str, phase: str, wall: float, cpu: float, peak_kb: float):
        rows.append(
            {
                "learner": learner,
                "phase": phase,
                "wall_s": wall,
                "cpu_s": cpu,
                "peak_kb": peak_kb,
            }
        )

    # T — reuse first fit for predict timing (no second fit)
    with _Timer() as t:
        m = BaseTRegressor(learner=LinearRegression())
        m.fit(X=X, treatment=treatment, y=y)
    add("T", "fit", t.wall_s, t.cpu_s, t.peak_kb)
    with _Timer() as t:
        m.predict(X=X)
    add("T", "predict", t.wall_s, t.cpu_s, t.peak_kb)

    # X (propensity fixed)
    with _Timer() as t:
        mx = BaseXRegressor(learner=LinearRegression())
        mx.fit(X=X, treatment=treatment, y=y, p=p_score)
    add("X", "fit", t.wall_s, t.cpu_s, t.peak_kb)
    with _Timer() as t:
        mx.predict(X=X, p=p_score)
    add("X", "predict", t.wall_s, t.cpu_s, t.peak_kb)

    # S
    with _Timer() as t:
        ms = BaseSRegressor(learner=LinearRegression())
        ms.fit(X=X, treatment=treatment, y=y)
    add("S", "fit", t.wall_s, t.cpu_s, t.peak_kb)
    with _Timer() as t:
        ms.predict(X=X)
    add("S", "predict", t.wall_s, t.cpu_s, t.peak_kb)

    # R
    with _Timer() as t:
        mr = BaseRRegressor(learner=LinearRegression(), random_state=42, cv_n_jobs=1)
        mr.fit(X=X, treatment=treatment, y=y, p=p_score, verbose=False)
    add("R", "fit", t.wall_s, t.cpu_s, t.peak_kb)
    with _Timer() as t:
        mr.predict(X=X)
    add("R", "predict", t.wall_s, t.cpu_s, t.peak_kb)

    # DR — fixed p (no propensity CV). Pass seed so KFold splits are tied to this run;
    # with large balanced n, each fold almost surely has every arm for outcome regression.
    with _Timer() as t:
        md = BaseDRRegressor(learner=LinearRegression(), control_name=0)
        md.fit(X=X, treatment=treatment, y=y, p=p_score, seed=seed)
    add("DR", "fit", t.wall_s, t.cpu_s, t.peak_kb)
    with _Timer() as t:
        md.predict(X=X)
    add("DR", "predict", t.wall_s, t.cpu_s, t.peak_kb)

    return rows


def run_benchmarks(reps: int, seed0: int = _SEED_BASE) -> pd.DataFrame:
    all_rows = []
    for r in range(reps):
        for row in benchmark_one_pass(seed=seed0 + r):
            row["rep"] = r
            all_rows.append(row)
    return pd.DataFrame(all_rows)


def plot_benchmarks(df: pd.DataFrame, out_path: Path, title_suffix: str = "") -> None:
    """2×3 grid: rows = fit / predict, cols = wall, CPU, peak memory."""
    df = df.copy()
    df["peak_mb"] = df["peak_kb"] / 1024.0
    df["learner"] = pd.Categorical(df["learner"], categories=_LEARNER_ORDER, ordered=True)
    palette = [_LEARNER_COLORS[c] for c in _LEARNER_ORDER]

    phases = ["fit", "predict"]
    metrics = [
        ("wall_s", "Wall time (s)"),
        ("cpu_s", "CPU time (s)"),
        ("peak_mb", "Peak traced memory (MB)"),
    ]

    fig, axes = plt.subplots(2, 3, figsize=(14, 7), constrained_layout=True)
    st = fig.suptitle(
        f"Meta-learner benchmarks{title_suffix}",
        fontsize=13,
        fontweight="bold",
    )

    for i, phase in enumerate(phases):
        sub = df[df["phase"] == phase]
        for j, (col, ylab) in enumerate(metrics):
            ax = axes[i, j]
            sns.boxplot(
                data=sub,
                x="learner",
                y=col,
                order=_LEARNER_ORDER,
                hue="learner",
                hue_order=_LEARNER_ORDER,
                palette=palette,
                dodge=False,
                legend=False,
                ax=ax,
                linewidth=1,
                fliersize=3,
            )
            ax.set_title(f"{phase.capitalize()} — {ylab.split()[0]}")
            ax.set_xlabel("Learner")
            ax.set_ylabel(ylab)

    plt.savefig(out_path, dpi=150, bbox_inches="tight", bbox_extra_artists=(st,))
    plt.close()


def main():
    parser = argparse.ArgumentParser(description="Benchmark meta-learners (fixed n, p, treatment arms)")
    parser.add_argument(
        "--reps",
        type=int,
        default=15,
        help="Number of repeated benchmark passes (default: 15)",
    )
    args = parser.parse_args()

    output_dir = _REPO_ROOT / "benchmarks" / "output"
    output_dir.mkdir(parents=True, exist_ok=True)

    df = run_benchmarks(reps=args.reps)
    csv_path = output_dir / "metalearner_benchmark_runs.csv"
    png_path = output_dir / "metalearner_benchmark_boxplots.png"
    df.to_csv(csv_path, index=False)

    plot_benchmarks(
        df,
        png_path,
        title_suffix=(
            f"\nn={N_SAMPLES}, p={N_FEATURES}, treatment arms={N_TREATMENT_ARMS}, reps={args.reps}"
        ),
    )

    print(f"Wrote:\n  {csv_path}\n  {png_path}")


if __name__ == "__main__":
    main()

Results

| Learner | Phase | Master (wall time) | Feature branch (wall time) | Delta (%) | Peak memory change |
| --- | --- | --- | --- | --- | --- |
| T | Fit | ~0.0057s | ~0.0037s | -35% | 46KB → 32KB (-30%) |
| T | Predict | ~0.0008s | ~0.0006s | -25% | 34KB → 26KB (-23%) |
| X | Fit | ~0.0125s | ~0.0097s | -22% | 62KB → 53KB (-14%) |
| X | Predict | ~0.0010s | ~0.0010s | 0% | No change |
| S | Fit | ~0.0031s | ~0.0032s | +3% (negligible) | No change |
| R | Fit | ~0.0123s | ~0.0122s | -1% (stable) | No change |
| DR | Predict | ~0.0033s | ~0.0026s | -21% | 57KB → 49KB (-14%) |


Labels

enhancement New feature or request


Development

Successfully merging this pull request may close these issues.

BaseTLearner fit model_c for every treatment on the same data and predict same model on same data

3 participants