Fix torch.compile breaking toggle_optimizer / untoggle_optimizer#21686
Fix torch.compile breaking toggle_optimizer / untoggle_optimizer#21686gaurav0107 wants to merge 4 commits intoLightning-AI:masterfrom
Conversation
`LightningModule.toggle_optimizer` and `untoggle_optimizer` mutate `requires_grad` on parameters to implement multi-optimizer gradient masking. Dynamo/AOTAutograd does not support `setattr()` on `Tensor.requires_grad` because it can change a tensor's leaf-ness mid-graph, so when the `LightningModule` is wrapped with `torch.compile` tracing either graph-breaks with "Unsupported: setattr() on Tensor.requires_grad" or raises a `KeyError` on the internal `param_requires_grad_state` mapping when the traced parameter references diverge from those held by `trainer.optimizers`. Decorate both helpers with `@torch.compiler.disable` (the same pattern already used for logging bookkeeping in `logger_connector/result.py`) so they run as opaque Python when called from a compiled `training_step`. Eager behavior is unchanged. Adds a CPU regression test that compiles a two-optimizer `LightningModule` calling `toggle_optimizer` / `untoggle_optimizer` in `training_step` and exercises one training iteration, plus a CHANGELOG entry.
Codecov Report✅ All modified and coverable lines are covered by tests.
Additional details and impacted files@@ Coverage Diff @@
## master #21686 +/- ##
=========================================
- Coverage 87% 79% -8%
=========================================
Files 270 267 -3
Lines 23973 23916 -57
=========================================
- Hits 20748 18803 -1945
- Misses 3225 5113 +1888 |
…tning-AI#21513) The previous regression test compiled a `LightningModule` end-to-end and called `self.optimizers()` inside the compiled `training_step`, which unrelated to the toggle_optimizer fix trips a separate Dynamo limitation: tracing `self.trainer.strategy._lightning_optimizers` raises `InternalTorchDynamoError: GetAttrVariable(...) has no type` across all CI platforms and torch versions. The shipped fix — `@torch.compiler.disable` on `toggle_optimizer` / `untoggle_optimizer` — does not require a full compiled trainer run to verify; it only guarantees Dynamo skips those two methods. Replace the integration test with a direct attribute check that both methods carry the `_torchdynamo_disable` marker installed by `torch.compiler.disable`, following the same `has_dynamo(fn)` pattern already used by `tests/utilities/test_compile.py::test_compile_uncompile`. Toggle/untoggle functional correctness remains covered by the existing `test_toggle_untoggle_2_optimizers_no_shared_parameters` and `test_toggle_untoggle_3_optimizers_shared_parameters` tests in this file.
|
Pushed Root cause Only one test was failing across every OS/torch combo: The test compiled a That is a separate Dynamo limitation on tracing through the Correction Narrow the regression test to a direct attribute check that both Functional toggle/untoggle correctness remains fully covered by the |
What does this PR do?
Fixes #21513.
LightningModule.toggle_optimizeranduntoggle_optimizermutaterequires_gradon parameters to implement multi-optimizer gradient masking. Dynamo / AOTAutograd does not supportsetattr()onTensor.requires_gradbecause it can change a tensor's leaf-ness mid-graph, so when aLightningModuleis wrapped withtorch.compile(model)tracing eitherUnsupported: setattr() on Tensor.requires_grad, orKeyErroron the internalparam_requires_grad_statemapping (the user-reported symptom — the traced parameter references end up diverging from those held bytrainer.optimizers).Fix
Decorate both helpers with
@torch.compiler.disableso they run as opaque Python when called from a compiledtraining_step. This is the same pattern already used inlightning/pytorch/trainer/connectors/logger_connector/result.pyfor bookkeeping methods that are not supposed to live inside the compiled graph (Result.log,Result.update_metrics).torch.compile) behavior is byte-identical —@torch.compiler.disableis a no-op when Dynamo is not active.torch.compile, Dynamo now simply calls these methods as regular Python instead of trying to trace throughparam.requires_grad = .... The surrounding compiled forward/backward graph is unaffected... note::was added to each docstring explaining why the decorator is there so future contributors don't accidentally remove it.Verification
Before the fix, running the issue's reproducer on torch 2.11 with
TORCH_LOGS=dynamoshows:After the fix, that
WON'T CONVERT toggle_optimizermessage is gone and the reproducer completes cleanly.Tests
Added
test_toggle_untoggle_optimizer_with_torch_compileintests/tests_pytorch/core/test_lightning_module.py, gated withRunIf(dynamo=True). The test compiles a two-optimizerLightningModulethat callstoggle_optimizer/untoggle_optimizerintraining_stepand runs one training iteration. It fails on master (graph-break /KeyError) and passes with this patch.All existing
toggle_optimizertests still pass locally:Before submitting
RunIf(dynamo=True).PR review
Anyone in the community is welcome to review the PR.
📚 Documentation preview 📚: https://pytorch-lightning--21686.org.readthedocs.build/en/21686/