[Performance] Recompute backward + narrow canonicalization for RNN backends #3752
Open
vmoens wants to merge 2 commits into
…ckends

Adds a memory-saving recompute mode and a small set of helpers driven by the Isaac RNN PPO follow-ups (TODO: isaac_rnn_scan_triton_memory_followups).

- New ``recurrent_recompute: Literal["none", "full"]`` kwarg on ``LSTMModule`` / ``GRUModule``, orthogonal to ``recurrent_backend``.
- Triton ``_LSTMFn`` / ``_GRUFn`` accept ``recompute``: drops the per-step gate buffers (``save_i/f/g/o/save_tanhc`` for LSTM, ``save_r/z/n/gh_n`` for GRU) from ``save_for_backward`` and replays the fwd kernel into scratch during backward. Saves ~5 ``[B, T, H_pad]`` activation tensors per layer.
- Scan ``_lstm_scan_with_resets`` / ``_gru_scan_with_resets`` swap the ``torch._higher_order_ops.scan`` HOP for a python time-loop wrapped in ``torch.utils.checkpoint.checkpoint(use_reentrant=False)`` when ``recurrent_recompute='full'``. Parameter ``.clone()`` calls are dropped along that path; gradients match the cuDNN ``"pad"`` backend to float precision.
- ``LSTMModule.canonicalize`` / ``GRUModule.canonicalize`` plus ``torchrl.modules.canonicalize_rnn_subset`` apply ``contiguous(canonical=True)`` only to the module's in/out keys, so callers can avoid materializing transient full-batch copies of unrelated TensorDict leaves.
- ``torchrl.cuda_memory_stats``, ``reset_cuda_peak_stats`` and the ``cuda_memory_profile`` context manager replace experiment-side reinventions of ``torch.cuda.memory_*`` and pair with ``timeit`` for per-phase peak accounting.
- Tests in ``test/modules/test_rnn.py`` cover scan-recompute parity vs pad, triton recompute parity, the canonicalize subset method and free function, the recompute-rejected-for-pad guard and invalid-value rejection (CUDA-only tests are skipped where appropriate). ``test/test_utils.py`` covers the new memory helpers.
- New benchmark ``benchmarks/bench_rnn_recompute_memory.py`` reports peak allocated GB for ``recompute=none`` vs ``recompute='full'`` per (rnn, backend).

Co-Authored-By: Claude Opus 4.7 <[email protected]>
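The scan-side recompute described in this commit (a python time-loop wrapped in a non-reentrant checkpoint) can be sketched roughly as follows. This is a minimal illustration, not the torchrl implementation: a plain tanh RNN cell stands in for the LSTM/GRU cell, and the names `rnn_time_loop`, `run_once` and the `is_init` reset mask are hypothetical. Only `torch.utils.checkpoint.checkpoint(..., use_reentrant=False)` is taken from the commit message.

```python
import torch
from torch.utils.checkpoint import checkpoint


def rnn_time_loop(x, is_init, h0, w_ih, w_hh):
    """Python time-loop over x of shape [B, T, I] (hypothetical stand-in cell)."""
    h = h0
    outs = []
    for t in range(x.shape[1]):
        # Reset the hidden state wherever a new trajectory starts (assumed semantics).
        h = torch.where(is_init[:, t, None], torch.zeros_like(h), h)
        h = torch.tanh(x[:, t] @ w_ih + h @ w_hh)
        outs.append(h)
    return torch.stack(outs, dim=1)


def run_once(recompute):
    """Run forward + backward, optionally under activation checkpointing."""
    torch.manual_seed(0)
    x = torch.randn(3, 5, 4)
    is_init = torch.zeros(3, 5, dtype=torch.bool)
    is_init[:, 0] = True
    is_init[1, 3] = True
    h0 = torch.zeros(3, 6)
    w_ih = torch.randn(4, 6, requires_grad=True)
    w_hh = torch.randn(6, 6, requires_grad=True)
    if recompute:
        # Intermediate per-step activations are not kept; the loop is replayed in backward.
        out = checkpoint(rnn_time_loop, x, is_init, h0, w_ih, w_hh, use_reentrant=False)
    else:
        out = rnn_time_loop(x, is_init, h0, w_ih, w_hh)
    out.sum().backward()
    return out.detach(), w_ih.grad, w_hh.grad
```

Because the checkpointed path replays the same operations during backward, outputs and gradients from `run_once(True)` and `run_once(False)` should agree up to floating-point tolerance; the trade is extra forward compute for lower peak activation memory.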
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3752

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV
There is 1 currently active SEV.

✅ No Failures
As of commit d69a009 with merge base 0a01ee8.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
…'t read them

The Triton LSTM/GRU forward was unconditionally allocating save_i/save_f/save_g/save_o/save_tanhc (LSTM) and save_r/save_z/save_n/save_gh_n (GRU) and writing to them from the kernel, even under recompute=True (where ctx.save_for_backward drops them) and under torch.no_grad / no-tracked-input calls (where backward never runs). Under CudaGraphModule / CUDA graph capture, these otherwise-dead allocations get attributed to the graph's private memory pool and inflate persistent memory_allocated() -- defeating the recompute mode's memory win.

- Add ``SAVE_GATES: tl.constexpr = True`` to ``_lstm_fwd_kernel`` and ``_gru_fwd_kernel``. Wrap the ``tl.store(save_*)`` lines in ``if SAVE_GATES:`` so the stores are dead-code-eliminated when False.
- ``_LSTMFn`` / ``_GRUFn`` now take ``save_gates`` instead of ``recompute``. When False, allocate a single 4-byte placeholder for the kernel's gate-save pointer args (instead of 5 / 4 buffers of ``[B, T, H_pad]``) and save the minimal set in ``save_for_backward``. ``ctx.recompute`` is derived from ``not save_gates`` so the existing backward replay path is unchanged.
- ``lstm_triton`` / ``gru_triton`` compute ``save_gates`` via the new ``_resolve_save_gates`` helper outside ``Function.apply`` (where ``torch.is_grad_enabled()`` is meaningful): ``save_gates = not recompute and is_grad_enabled() and any(t.requires_grad)``. Public wrapper signatures are unchanged.
- Backward kernel-replay sites pass ``SAVE_GATES=True`` explicitly so the recompute path stores into its fresh scratch buffers as before.
- Test ``test_resolve_save_gates`` covers the four truth-table cases (recompute, no_grad, no-grad-input, optional ``None`` inputs).

Co-Authored-By: Claude Opus 4.7 <[email protected]>
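The ``save_gates`` rule quoted above can be illustrated with a small standalone sketch. The function name `resolve_save_gates_sketch` and its signature are assumptions made for illustration; the actual ``_resolve_save_gates`` helper may differ in both. Only the boolean expression and the handling of optional ``None`` inputs come from the commit message.

```python
import torch


def resolve_save_gates_sketch(recompute, *tensors):
    """Decide whether the forward kernel should store per-step gate buffers.

    Hypothetical stand-in for the helper described in the commit message:
    save_gates = not recompute and is_grad_enabled() and any(t.requires_grad).
    """
    if recompute:
        # Backward replays the forward kernel, so nothing needs to be saved.
        return False
    if not torch.is_grad_enabled():
        # Under torch.no_grad() backward never runs.
        return False
    # Optional inputs may be None; only tracked tensors count.
    return any(t is not None and t.requires_grad for t in tensors)


tracked = torch.zeros(2, requires_grad=True)
untracked = torch.zeros(2)

print(resolve_save_gates_sketch(False, tracked, None))    # tracked input, grad enabled
print(resolve_save_gates_sketch(True, tracked))           # recompute requested
print(resolve_save_gates_sketch(False, untracked, None))  # no tracked input
with torch.no_grad():
    print(resolve_save_gates_sketch(False, tracked))      # grad disabled
```

Evaluating the rule outside ``Function.apply`` matters because grad mode is disabled inside a custom ``autograd.Function.forward``, so ``torch.is_grad_enabled()`` would not reflect the caller's context there.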
Summary

Lands the four follow-ups from `TODO/isaac_rnn_scan_triton_memory_followups.md` in a single change. The Isaac RNN PPO runs were showing scan/triton reserving 3-4x the cuDNN baseline's max-allocated memory; the gap was dominated by per-step backward activations and learner-boundary whole-TD canonicalization.

- `_LSTMFn` / `_GRUFn`. New `recompute: bool` arg drops the 5 LSTM gate buffers (`save_i/f/g/o/save_tanhc`) / 4 GRU buffers (`save_r/z/n/gh_n`) from `save_for_backward` and replays the forward kernel into scratch during backward. Saves ~5 `[B, T, H_pad]` activation tensors per layer.
- `_lstm_scan_with_resets` / `_gru_scan_with_resets`. When `recurrent_recompute='full'`, swaps `torch._higher_order_ops.scan` for a python time-loop wrapped in `torch.utils.checkpoint.checkpoint(use_reentrant=False)` and drops parameter `.clone()` calls along that path. Side-effect: the python-loop+checkpoint path matches the cuDNN `pad` backend to float precision (the `_scan` HOP has pre-existing gradient drift of ~0.17 vs cuDNN; tests pin scan-recompute against `pad`, not against scan-no-recompute).
- `LSTMModule.canonical_keys` + `LSTMModule.canonicalize(data)`, twin methods on `GRUModule`, and a `torchrl.modules.canonicalize_rnn_subset(data, modules)` free function. Applies `contiguous(canonical=True)` only to the module's in/out keys so rewards/advantages/log-probs/value targets keep their layout.
- `torchrl/_utils.py`: `cuda_memory_stats(device)`, `reset_cuda_peak_stats(device)`, and `cuda_memory_profile(label, *, device, log, reset_peaks)` context manager. CPU/MPS-safe (returns zeros / no-ops). Re-exported from `torchrl`.

Public API

New kwarg on both modules, orthogonal to `recurrent_backend`: `recurrent_recompute: Literal["none", "full"]`. `backend="pad"` rejects non-`"none"` values (cuDNN manages its own backward workspace).

Test plan

- `pytest test/modules/test_rnn.py::TestLSTMModule test/modules/test_rnn.py::TestGRUModule` (CPU) -- 140 passed, 33 CUDA/triton-only skipped.
- `pytest test/test_utils.py::TestCudaMemoryHelpers` -- 5 passed, 2 CUDA-only skipped.
- Scan-recompute parity vs the `pad` backend to `atol=1e-5` for `num_layers in {1, 2}` on both LSTM and GRU.
- `test_lstm_triton_recompute_parity`, `test_gru_triton_recompute_parity` (forward + grads vs full-save triton path).
- `benchmarks/bench_rnn_recompute_memory.py --rnn lstm --backend triton --batch 4096 --seq-len 32 --hidden 256` on a CUDA box to confirm `max_allocated_gb` drops with `recompute='full'`.

Notes

`recurrent_recompute='chunk'` with chunk-size knob, reduced-precision (fp16/bf16) gate storage, and `is_contiguous()` guards on internal `_pad_last(...).contiguous()` helpers.

Plan: `/Users/vmoens/.claude/plans/sharded-gliding-tower.md`

🤖 Generated with Claude Code
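For a rough sense of scale, the size of the dropped gate buffers at the benchmark shape from the test plan (`--batch 4096 --seq-len 32 --hidden 256`) can be estimated with simple arithmetic. This is a back-of-the-envelope sketch, not a measurement: it assumes fp32 storage (4 bytes per element) and that `H_pad` equals the hidden size of 256, neither of which is stated in the PR. The helper name `gate_buffer_bytes` is hypothetical.

```python
def gate_buffer_bytes(batch, seq_len, hidden_pad, n_buffers, bytes_per_elem=4):
    """Bytes held by n_buffers tensors of shape [B, T, H_pad] (fp32 assumed)."""
    return batch * seq_len * hidden_pad * n_buffers * bytes_per_elem


MIB = 1024 ** 2

# 5 buffers for LSTM (save_i/f/g/o/save_tanhc), 4 for GRU (save_r/z/n/gh_n).
lstm_mib = gate_buffer_bytes(4096, 32, 256, n_buffers=5) / MIB
gru_mib = gate_buffer_bytes(4096, 32, 256, n_buffers=4) / MIB

print(f"LSTM gate buffers per layer: {lstm_mib:.0f} MiB")  # 640 MiB
print(f"GRU gate buffers per layer: {gru_mib:.0f} MiB")    # 512 MiB
```

Under these assumptions each `[B, T, H_pad]` tensor is 128 MiB, so the recompute path would avoid holding roughly 640 MiB (LSTM) or 512 MiB (GRU) per layer between forward and backward. Actual peak savings depend on allocator behavior and on the scratch buffers used during replay.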