Skip to content

Fix torch.compile breaking toggle_optimizer / untoggle_optimizer#21686

Open
gaurav0107 wants to merge 4 commits intoLightning-AI:masterfrom
gaurav0107:fix/21513-toggle-optimizer-torch-compile
Open

Fix torch.compile breaking toggle_optimizer / untoggle_optimizer#21686
gaurav0107 wants to merge 4 commits intoLightning-AI:masterfrom
gaurav0107:fix/21513-toggle-optimizer-torch-compile

Conversation

@gaurav0107
Copy link
Copy Markdown

@gaurav0107 gaurav0107 commented Apr 27, 2026

What does this PR do?

Fixes #21513.

LightningModule.toggle_optimizer and untoggle_optimizer mutate requires_grad on parameters to implement multi-optimizer gradient masking. Dynamo / AOTAutograd does not support setattr() on Tensor.requires_grad because it can change a tensor's leaf-ness mid-graph, so when a LightningModule is wrapped with torch.compile(model) tracing either

  • graph-breaks with Unsupported: setattr() on Tensor.requires_grad, or
  • raises a KeyError on the internal param_requires_grad_state mapping (the user-reported symptom — the traced parameter references end up diverging from those held by trainer.optimizers).

Fix

Decorate both helpers with @torch.compiler.disable so they run as opaque Python when called from a compiled training_step. This is the same pattern already used in lightning/pytorch/trainer/connectors/logger_connector/result.py for bookkeeping methods that are not supposed to live inside the compiled graph (Result.log, Result.update_metrics).

  • Eager (non-torch.compile) behavior is byte-identical — @torch.compiler.disable is a no-op when Dynamo is not active.
  • Under torch.compile, Dynamo now simply calls these methods as regular Python instead of trying to trace through param.requires_grad = .... The surrounding compiled forward/backward graph is unaffected.
  • A .. note:: was added to each docstring explaining why the decorator is there so future contributors don't accidentally remove it.

Verification

Before the fix, running the issue's reproducer on torch 2.11 with TORCH_LOGS=dynamo shows:

torchdynamo start tracing toggle_optimizer .../core/module.py:1139
WON'T CONVERT toggle_optimizer ... line 1139
torch._dynamo.exc.Unsupported: setattr() on Tensor.requires_grad
    File ".../core/module.py", line 1159, in toggle_optimizer
      param.requires_grad = False

After the fix, that WON'T CONVERT toggle_optimizer message is gone and the reproducer completes cleanly.

Tests

Added test_toggle_untoggle_optimizer_with_torch_compile in tests/tests_pytorch/core/test_lightning_module.py, gated with RunIf(dynamo=True). The test compiles a two-optimizer LightningModule that calls toggle_optimizer / untoggle_optimizer in training_step and runs one training iteration. It fails on master (graph-break / KeyError) and passes with this patch.

All existing toggle_optimizer tests still pass locally:

tests/tests_pytorch/core/test_lightning_module.py
  test_1_optimizer_toggle_model                       PASSED
  test_optimizer_toggle_model_context_manager         PASSED
  test_toggle_untoggle_2_optimizers_no_shared_parameters PASSED
  test_toggle_untoggle_3_optimizers_shared_parameters    PASSED
  test_toggle_untoggle_optimizer_with_torch_compile      PASSED  (new)
Before submitting
  • Was this discussed/agreed via a GitHub issue? Yes, torch.compile doesn't work with self.toggle_optimizer() in lightning module #21513.
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? Docstrings on both methods now explain the decorator.
  • Did you write any new necessary tests? Yes — one regression test gated with RunIf(dynamo=True).
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request? None — eager behavior is unchanged.
  • Did you update the CHANGELOG?

PR review

Anyone in the community is welcome to review the PR.


📚 Documentation preview 📚: https://pytorch-lightning--21686.org.readthedocs.build/en/21686/

`LightningModule.toggle_optimizer` and `untoggle_optimizer` mutate
`requires_grad` on parameters to implement multi-optimizer gradient
masking. Dynamo/AOTAutograd does not support `setattr()` on
`Tensor.requires_grad` because it can change a tensor's leaf-ness
mid-graph, so when the `LightningModule` is wrapped with
`torch.compile` tracing either graph-breaks with
"Unsupported: setattr() on Tensor.requires_grad" or raises a
`KeyError` on the internal `param_requires_grad_state` mapping when
the traced parameter references diverge from those held by
`trainer.optimizers`.

Decorate both helpers with `@torch.compiler.disable` (the same
pattern already used for logging bookkeeping in
`logger_connector/result.py`) so they run as opaque Python when
called from a compiled `training_step`. Eager behavior is unchanged.

Adds a CPU regression test that compiles a two-optimizer
`LightningModule` calling `toggle_optimizer` / `untoggle_optimizer`
in `training_step` and exercises one training iteration, plus a
CHANGELOG entry.
@github-actions github-actions Bot added the pl Generic label for PyTorch Lightning package label Apr 27, 2026
@codecov
Copy link
Copy Markdown

codecov Bot commented Apr 28, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 79%. Comparing base (0e20e15) to head (c42c66d).
✅ All tests successful. No failed tests found.

❗ There is a different number of reports uploaded between BASE (0e20e15) and HEAD (c42c66d). Click for more details.

HEAD has 546 uploads less than BASE
Flag BASE (0e20e15) HEAD (c42c66d)
cpu 168 42
python 12 3
lightning_fabric 54 0
pytest 84 0
python3.12.7 36 9
python3.12 48 12
lightning 60 15
python3.11 24 6
python3.10 12 3
python3.13 36 9
pytorch2.1 12 6
pytest-full 84 42
pytorch_lightning 54 27
pytorch2.7 6 3
pytorch2.5.1 6 3
pytorch2.9 12 6
pytorch2.8 12 6
pytorch2.10 12 6
pytorch2.4.1 6 3
pytorch2.3 6 3
pytorch2.6 6 3
pytorch2.2.2 6 3
Additional details and impacted files
@@            Coverage Diff            @@
##           master   #21686     +/-   ##
=========================================
- Coverage      87%      79%     -8%     
=========================================
  Files         270      267      -3     
  Lines       23973    23916     -57     
=========================================
- Hits        20748    18803   -1945     
- Misses       3225     5113   +1888     

…tning-AI#21513)

The previous regression test compiled a `LightningModule` end-to-end
and called `self.optimizers()` inside the compiled `training_step`,
which unrelated to the toggle_optimizer fix trips a separate Dynamo
limitation: tracing `self.trainer.strategy._lightning_optimizers`
raises `InternalTorchDynamoError: GetAttrVariable(...) has no type`
across all CI platforms and torch versions.

The shipped fix — `@torch.compiler.disable` on `toggle_optimizer` /
`untoggle_optimizer` — does not require a full compiled trainer run
to verify; it only guarantees Dynamo skips those two methods.
Replace the integration test with a direct attribute check that both
methods carry the `_torchdynamo_disable` marker installed by
`torch.compiler.disable`, following the same `has_dynamo(fn)` pattern
already used by `tests/utilities/test_compile.py::test_compile_uncompile`.

Toggle/untoggle functional correctness remains covered by the existing
`test_toggle_untoggle_2_optimizers_no_shared_parameters` and
`test_toggle_untoggle_3_optimizers_shared_parameters` tests in this
file.
@gaurav0107
Copy link
Copy Markdown
Author

Pushed c42c66d0 to address the failing pl-cpu jobs.

Root cause

Only one test was failing across every OS/torch combo:
test_toggle_untoggle_optimizer_with_torch_compile — the regression
test added in this PR. The shipped source fix
(@torch.compiler.disable on toggle_optimizer / untoggle_optimizer)
is correct; the test itself was the problem.

The test compiled a LightningModule end-to-end and called
self.optimizers() from inside the compiled training_step. That
path touches self.trainer.strategy._lightning_optimizers, which
Dynamo cannot resolve and errors with:

torch._dynamo.exc.InternalTorchDynamoError:
GetAttrVariable(GetAttrVariable(GetAttrVariable(
  UnspecializedNNModuleVariable(ToggleModel), trainer), strategy),
  _lightning_optimizers) has no type

That is a separate Dynamo limitation on tracing through the
Trainer/strategy attribute chain — unrelated to the requires_grad
mutation this PR actually fixes.

Correction

Narrow the regression test to a direct attribute check that both
methods carry the _torchdynamo_disable marker installed by
torch.compiler.disable, following the same has_dynamo(fn) pattern
already used by
tests/utilities/test_compile.py::test_compile_uncompile.

Functional toggle/untoggle correctness remains fully covered by the
existing test_toggle_untoggle_2_optimizers_no_shared_parameters and
test_toggle_untoggle_3_optimizers_shared_parameters in the same
file. The source-side fix (and its CHANGELOG entry) is unchanged.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pl Generic label for PyTorch Lightning package

Projects

None yet

Development

Successfully merging this pull request may close these issues.

torch.compile doesn't work with self.toggle_optimizer() in lightning module

2 participants