Optimised training, inference and memory for metalearners in multitreatment settings #896
Ic3fr0g wants to merge 5 commits into uber:master
Conversation
Previously, `BaseTLearner.fit()` deep-copied and trained `model_c` once per treatment group even though all groups share the same control data. Likewise, `BaseTLearner.predict()` (and `BaseTClassifier.predict()`) called `model_c.predict()` once per group despite the model being identical.

- `fit()`: compute `control_mask` once, deep-copy `model_c` once, fit it once, then share the single fitted instance in `self.models_c` across all groups.
- `predict()` / `BaseTClassifier.predict()`: call `model_c.predict[_proba]()` once before the loop and reuse the result for each group.

Adds `test_BaseTLearner_model_c_trained_and_predicted_once` to verify that both `fit()` and `predict()` invoke `model_c` exactly once regardless of treatment group count.
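For illustration, a minimal standalone sketch of the pattern this commit introduces, assuming scikit-learn-style base learners — the function names and signatures below are hypothetical, not causalml's actual API:

```python
from copy import deepcopy

import numpy as np
from sklearn.linear_model import LinearRegression


def fit_t_learner(base_learner, X, treatment, y, control_name="control"):
    """Fit the control model once; fit one outcome model per treatment group."""
    t_groups = np.setdiff1d(np.unique(treatment), [control_name])

    # Control model: control_mask computed once, model fitted once, shared by all groups.
    control_mask = treatment == control_name
    model_c = deepcopy(base_learner)
    model_c.fit(X[control_mask], y[control_mask])

    models_t = {}
    for group in t_groups:
        models_t[group] = deepcopy(base_learner)
        models_t[group].fit(X[treatment == group], y[treatment == group])
    return model_c, models_t


def predict_t_learner(model_c, models_t, X):
    """model_c.predict() runs exactly once, regardless of the number of groups."""
    yhat_c = model_c.predict(X)  # hoisted out of the per-group loop
    return {group: model.predict(X) - yhat_c for group, model in models_t.items()}


# Toy usage with synthetic data:
rng = np.random.default_rng(0)
X = rng.normal(size=(300, 4))
treatment = rng.choice(["control", "t1", "t2"], size=300)
y = X @ rng.normal(size=4) + (treatment != "control")
model_c, models_t = fit_t_learner(LinearRegression(), X, treatment, y)
cate_per_group = predict_t_learner(model_c, models_t, X)
```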
Now that `model_c` is trained once before the loop, the per-group mask that combined treatment and control rows is no longer needed for `model_t` either. Replace the `mask` + `treatment_filt` + `w == 1` chain with a direct boolean index on the treatment group alone, which is both simpler and faster.
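A toy comparison of the two indexing styles (illustrative data; `mask` and `w` follow the convention described above):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))
treatment = np.array(["control", "t1", "t1", "control", "t2", "t1", "t2", "control"])
group, control_name = "t1", "control"

# Old chain: combined control+treatment mask, then a second filter on the
# within-mask treatment indicator w.
mask = (treatment == group) | (treatment == control_name)
w = (treatment[mask] == group).astype(int)
X_t_old = X[mask][w == 1]

# New: a single boolean index on the treatment group.
X_t_new = X[treatment == group]

assert np.array_equal(X_t_old, X_t_new)
```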
**TLearner**
- `fit()`: `self.models_c` dict replaced by a single `self.model_c` trained once on control data; per-group loop now only trains treatment models with a direct boolean index (no mask + `w` chain).
- `predict()` / `BaseTClassifier.predict()`: `yhat_cs` dict removed; `yhat_c` computed once from `self.model_c` before the loop.
- `fit_predict()` / `estimate_ate()`: bootstrap save/restore updated from `models_c` dict to scalar `model_c`.

**XLearner**
- `fit()` / `BaseXClassifier.fit()`: `models_mu_c` dict replaced by a single `self.model_mu_c` trained once; per-group loop skips redundant re-training and uses direct `treatment_mask` indexing.
- `predict()` / `BaseXClassifier.predict()`: verbose block updated to use `self.model_mu_c` directly.
- `fit_predict()` / `estimate_ate()`: bootstrap save/restore updated.

**DRLearner**
- `predict()` / `BaseDRClassifier.predict()`: `models_mu_c` is fold-specific but not group-specific; `yhat_c` hoisted outside the group loop to avoid running the 3-fold ensemble prediction once per group.
- `estimate_ate()`: unpacking updated from `yhat_cs` dict to scalar `yhat_c`.

**SLearner**
- `predict()` / `BaseSClassifier.predict()`: `X_new_c` and `X_new_t` (hstack of treatment indicator + `X`) are identical for every group; construction hoisted outside the group loop (see the sketch below).

**RLearner**: no change needed — `model_mu` already fitted once via `cross_val_predict`, not per group.
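As a concrete example of the hoisting pattern referenced in the SLearner item, a hedged sketch (not causalml's exact code), assuming the usual S-learner encoding of a 0/1 treatment indicator prepended to the features:

```python
import numpy as np


def predict_s_learner(models, X):
    """models: one fitted estimator per treatment group, each trained on
    hstack([indicator, X]). The augmented matrices are group-invariant,
    so they are built once instead of once per group."""
    n = X.shape[0]
    X_new_c = np.hstack([np.zeros((n, 1)), X])  # indicator = 0 (control)
    X_new_t = np.hstack([np.ones((n, 1)), X])   # indicator = 1 (treated)
    return {g: m.predict(X_new_t) - m.predict(X_new_c) for g, m in models.items()}
```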
`fit()` (`BaseXLearner` and `BaseXClassifier`):
- Cache `y_control_pred = model_mu_c.predict[_proba](X[control_mask])` once immediately after fitting `model_mu_c`.
- Derive `var_c` from the cached array before the group loop; assign the scalar to `self.vars_c[group]` inside the loop without re-predicting (see the sketch after this list).

`predict()` (`BaseXLearner` and `BaseXClassifier`):
- Pre-compute `yhat_c_verbose = model_mu_c.predict[_proba](X[control_mask])` once before the group loop, guarded by the verbose condition.
- Replace the per-group `self.model_mu_c.predict` call in the verbose block with a direct reference to the cached array.
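A sketch of the `fit()` caching described above, assuming regression base learners and a residual-variance convention for `var_c`; names here are illustrative, not causalml's internals:

```python
from copy import deepcopy

import numpy as np


def fit_mu_c_with_cached_variance(base_learner, X, treatment, y,
                                  t_groups, control_name="control"):
    control_mask = treatment == control_name
    model_mu_c = deepcopy(base_learner)
    model_mu_c.fit(X[control_mask], y[control_mask])

    # Cache the control prediction once, right after fitting.
    y_control_pred = model_mu_c.predict(X[control_mask])
    # Derive the residual variance once, before the group loop.
    var_c = (y[control_mask] - y_control_pred).var()

    vars_c = {}
    for group in t_groups:
        vars_c[group] = var_c  # scalar reuse; no re-prediction per group
    return model_mu_c, vars_c
```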
@jeongyoonlee, I would greatly appreciate it if you could take a look at this PR!
Pull request overview
This PR optimizes multi-treatment meta-learners by removing redundant per-treatment control-model fits/predictions and hoisting shared computations to reduce training/inference time and memory footprint (addresses #853).
Changes:
- X-learner: fit the control outcome model once, reuse control variance/predictions across treatment groups, and reduce repeated verbose-path predictions.
- T-learner: fit the control model once and reuse a single control prediction vector across treatment groups.
- DR-learner and S-learner: hoist group-invariant predictions/feature construction out of per-group loops to avoid repeated work and allocations.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `causalml/inference/meta/xlearner.py` | Fits `model_mu_c` once on control data and reuses control-only artifacts across groups; reduces repeated predictions in verbose paths. |
| `causalml/inference/meta/tlearner.py` | Fits `model_c` once and reuses a single `yhat_c` across groups; adjusts `return_components` outputs accordingly. |
| `causalml/inference/meta/slearner.py` | Hoists `np.hstack`-built augmented design matrices outside the per-group loop to reduce repeated allocations. |
| `causalml/inference/meta/drlearner.py` | Computes control predictions once (fold-averaged) and reuses them across treatment groups; adjusts `return_components` outputs accordingly. |
```diff
 if not return_components:
     return te
 else:
-    return te, yhat_cs, yhat_ts
+    return te, yhat_c, yhat_ts
```

```diff
     return te
 else:
-    return te, yhat_cs, yhat_ts
+    return te, yhat_c, yhat_ts
```
Thanks @Ic3fr0g — the math is sound and the savings are real. A few things to address before merge:

1. Silent API break in
Quite a few egregious errors on my end. Should've thought about maintaining backward compatibility properly. I think I'll proceed with this as a new feature and not break existing functionality/docstrings/tests/etc. Let me know if you would also like me to benchmark the results. I'll make these changes in the morning and send it back to you! Happy weekend!
…plots

- Adds test for meta-learner consistency and key attributes
Incorporated all review comments!

Benchmarking: code and results attached in collapsed sections.
Proposed changes
Optimised training, inference and memory for most metalearners in multi-treatment settings.
Fixes #853
Here's a summary of every change made across all learners:
| Learner | Change | Compute saved | Memory saved |
|---|---|---|---|
| T-learner | `models_c` dict -> `model_c` single model | (N-1) fits on control data + (N-1) forward passes on full `X` | (N-1) × `model_c` copies in RAM; (N-1) × `model_c` in pickle |
| T-learner | `yhat_cs` dict -> scalar `yhat_c`; loop uses direct boolean index | | (N-1) arrays of shape `(n,)` |
| X-learner | `models_mu_c` dict -> `model_mu_c` single model | (N-1) fits on control data | (N-1) × `model_mu_c` copies in RAM; (N-1) × `model_mu_c` in pickle |
| X-learner | `y_control_pred` cached after fit; `var_c` computed once before loop | (N-1) forward passes on `X_control` during fit | 0 (scalar reuse) |
| X-learner | `yhat_c_verbose` hoisted before loop in both `predict` methods | (N-1) forward passes on `X_control` during predict (verbose path) | 0 (scalar reuse) |
| DR-learner | `yhat_c` hoisted outside group loop | (N-1) × 3 forward passes on full `X` per `predict()` call | (N-1) arrays of shape `(n,)` in `yhat_cs` dict |
| S-learner | `X_new_c`, `X_new_t` hoisted outside group loop | (N-1) × 2 `hstack` ops on `(n × (p+1))` | (N-1) × 2 peak arrays of shape `(n, p+1)` allocated and dropped each iteration |

Where `N` = number of treatment groups. For the common binary case (`N = 1`), there is no saving; savings are realized in multi-treatment settings, where these learners are most expensive.

Types of changes
What types of changes does your code introduce to CausalML?
Put an `x` in the boxes that apply.

Checklist
Put an `x` in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your code.

Further comments
I started off with the T-learner based on the issue and my personal experience using this project at work. Then I gradually expanded the scope because I saw the same inefficiencies everywhere. Most metalearners carry a lot of memory bloat and redundant training and inference time; this PR aims to solve some of those issues.
Pickle / joblib shared-reference behavior
Pickle uses a memo dict keyed by `id(obj)`. When the same object appears multiple times in a structure, it is serialized once, and subsequent occurrences become tiny backreference opcodes (about 2 bytes each). So:
The overhead of a shared-ref dict vs a single object is just the dict structure + N backreference opcodes, which is negligible for real models.
But the old code (`{group: deepcopy(self.model_c) for group in self.t_groups}`) created N separate Python objects, so pickle saw N different `id()` values and serialized each one fully. The new code (`self.model_c`) removes the dict entirely.
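A quick way to check this claim (assuming scikit-learn and numpy are installed; exact byte counts will vary):

```python
import pickle
from copy import deepcopy

import numpy as np
from sklearn.ensemble import RandomForestRegressor

X, y = np.random.rand(200, 5), np.random.rand(200)
model = RandomForestRegressor(n_estimators=50, random_state=0).fit(X, y)

shared = {g: model for g in ("t1", "t2", "t3")}            # one id(), three refs
copied = {g: deepcopy(model) for g in ("t1", "t2", "t3")}  # three distinct ids

# Shared references memoize to backreference opcodes; deepcopies serialize fully.
print(len(pickle.dumps(shared)))  # roughly the size of one model
print(len(pickle.dumps(copied)))  # roughly three times that
```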