Skip to content

[Performance] Recompute backward + narrow canonicalization for RNN backends#3752

Open
vmoens wants to merge 2 commits into
mainfrom
vmoens/rnn-recompute-memory
Open

[Performance] Recompute backward + narrow canonicalization for RNN backends#3752
vmoens wants to merge 2 commits into
mainfrom
vmoens/rnn-recompute-memory

Conversation

@vmoens
Copy link
Copy Markdown
Collaborator

@vmoens vmoens commented May 14, 2026

Summary

Lands the four follow-ups from TODO/isaac_rnn_scan_triton_memory_followups.md in a single change. The Isaac RNN PPO runs were showing scan/triton reserving 3-4x the cuDNN baseline's max-allocated memory; the gap was dominated by per-step backward activations and learner-boundary whole-TD canonicalization.

  • Triton recompute backward for _LSTMFn / _GRUFn. New recompute: bool arg drops the 5 LSTM gate buffers (save_i/f/g/o/save_tanhc) / 4 GRU buffers (save_r/z/n/gh_n) from save_for_backward and replays the forward kernel into scratch during backward. ~5 [B, T, H_pad] activation tensors saved per layer.
  • Scan checkpointing for _lstm_scan_with_resets / _gru_scan_with_resets. When recurrent_recompute='full', swaps torch._higher_order_ops.scan for a python time-loop wrapped in torch.utils.checkpoint.checkpoint(use_reentrant=False), drops parameter .clone() calls along that path. Side-effect: the python-loop+checkpoint path matches the cuDNN pad backend to float precision (_scan HOP has pre-existing gradient drift ~0.17 vs cuDNN; tests pin scan-recompute against pad, not against scan-no-recompute).
  • Narrow canonicalization: LSTMModule.canonical_keys + LSTMModule.canonicalize(data), twin methods on GRUModule, and a torchrl.modules.canonicalize_rnn_subset(data, modules) free function. Applies contiguous(canonical=True) only to the module's in/out keys so rewards/advantages/log-probs/value targets keep their layout.
  • Public memory helpers in torchrl/_utils.py: cuda_memory_stats(device), reset_cuda_peak_stats(device), and cuda_memory_profile(label, *, device, log, reset_peaks) context manager. CPU/MPS-safe (returns zeros / no-ops). Re-exported from torchrl.

Public API

New kwarg on both modules, orthogonal to recurrent_backend:

recurrent_recompute: Literal[\"none\", \"full\"] = \"none\"

backend=\"pad\" rejects non-\"none\" values (cuDNN manages its own backward workspace).

Test plan

  • pytest test/modules/test_rnn.py::TestLSTMModule test/modules/test_rnn.py::TestGRUModule (CPU) -- 140 passed, 33 CUDA/triton-only skipped.
  • pytest test/test_utils.py::TestCudaMemoryHelpers -- 5 passed, 2 CUDA-only skipped.
  • Scan-recompute output + grads match the cuDNN pad backend to atol=1e-5 for num_layers in {1, 2} on both LSTM and GRU.
  • CUDA-required tests need a GPU runner: test_lstm_triton_recompute_parity, test_gru_triton_recompute_parity (forward + grads vs full-save triton path).
  • Benchmark benchmarks/bench_rnn_recompute_memory.py --rnn lstm --backend triton --batch 4096 --seq-len 32 --hidden 256 on a CUDA box to confirm max_allocated_gb drops with recompute='full'.

Notes

  • Deferred to follow-ups: rl-matteo experiment-script adoption of the new helpers, recurrent_recompute='chunk' with chunk-size knob, reduced-precision (fp16/bf16) gate storage, and is_contiguous() guards on internal _pad_last(...).contiguous() helpers.

Plan: /Users/vmoens/.claude/plans/sharded-gliding-tower.md

🤖 Generated with Claude Code

…ckends

Adds a memory-saving recompute mode and a small set of helpers driven by
the Isaac RNN PPO follow-ups (TODO: isaac_rnn_scan_triton_memory_followups).

- New ``recurrent_recompute: Literal["none", "full"]`` kwarg on
  ``LSTMModule`` / ``GRUModule``, orthogonal to ``recurrent_backend``.
- Triton ``_LSTMFn`` / ``_GRUFn`` accept ``recompute``: drops the per-step
  gate buffers (``save_i/f/g/o/save_tanhc`` for LSTM, ``save_r/z/n/gh_n``
  for GRU) from ``save_for_backward`` and replays the fwd kernel into
  scratch during backward. Saves ~5 ``[B, T, H_pad]`` activation tensors
  per layer.
- Scan ``_lstm_scan_with_resets`` / ``_gru_scan_with_resets`` swap the
  ``torch._higher_order_ops.scan`` HOP for a python time-loop wrapped in
  ``torch.utils.checkpoint.checkpoint(use_reentrant=False)`` when
  ``recurrent_recompute='full'``. Parameter ``.clone()`` calls are dropped
  along that path; gradients match the cuDNN ``"pad"`` backend to float
  precision.
- ``LSTMModule.canonicalize`` / ``GRUModule.canonicalize`` plus
  ``torchrl.modules.canonicalize_rnn_subset`` apply
  ``contiguous(canonical=True)`` only to the module's in/out keys, so
  callers can avoid materializing transient full-batch copies of
  unrelated TensorDict leaves.
- ``torchrl.cuda_memory_stats``, ``reset_cuda_peak_stats`` and the
  ``cuda_memory_profile`` context manager replace experiment-side
  reinventions of ``torch.cuda.memory_*`` and pair with ``timeit`` for
  per-phase peak accounting.
- Tests in ``test/modules/test_rnn.py`` cover scan-recompute parity vs
  pad, triton recompute parity, the canonicalize subset method and free
  function, the recompute-rejected-for-pad guard and invalid-value
  rejection (CUDA-only tests are skipped where appropriate).
  ``test/test_utils.py`` covers the new memory helpers.
- New benchmark ``benchmarks/bench_rnn_recompute_memory.py`` reports
  peak allocated GB for ``recompute=none`` vs ``recompute='full'`` per
  (rnn, backend).

Co-Authored-By: Claude Opus 4.7 <[email protected]>
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented May 14, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3752

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit d69a009 with merge base 0a01ee8 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 14, 2026
@github-actions github-actions Bot added Performance Performance issue or suggestion for improvement Documentation Improvements or additions to documentation Benchmarks rl/benchmark changes Modules Integrations/torch_geometric Integrations labels May 14, 2026
…'t read them

The Triton LSTM/GRU forward was unconditionally allocating
save_i/save_f/save_g/save_o/save_tanhc (LSTM) and
save_r/save_z/save_n/save_gh_n (GRU) and writing to them from the
kernel, even under recompute=True (where ctx.save_for_backward drops
them) and under torch.no_grad / no-tracked-input calls (where backward
never runs). Under CudaGraphModule / CUDA graph capture, these
otherwise-dead allocations get attributed to the graph's private memory
pool and inflate persistent memory_allocated() -- defeating the
recompute mode's memory win.

- Add ``SAVE_GATES: tl.constexpr = True`` to ``_lstm_fwd_kernel`` and
  ``_gru_fwd_kernel``. Wrap the ``tl.store(save_*)`` lines in
  ``if SAVE_GATES:`` so the stores are dead-code-eliminated when False.
- ``_LSTMFn`` / ``_GRUFn`` now take ``save_gates`` instead of
  ``recompute``. When False, allocate a single 4-byte placeholder for
  the kernel's gate-save pointer args (instead of 5 / 4 buffers of
  ``[B, T, H_pad]``) and save the minimal set in ``save_for_backward``.
  ``ctx.recompute`` is derived from ``not save_gates`` so the existing
  backward replay path is unchanged.
- ``lstm_triton`` / ``gru_triton`` compute ``save_gates`` via the new
  ``_resolve_save_gates`` helper outside ``Function.apply`` (where
  ``torch.is_grad_enabled()`` is meaningful):
  ``save_gates = not recompute and is_grad_enabled() and any(t.requires_grad)``.
  Public wrapper signatures are unchanged.
- Backward kernel-replay sites pass ``SAVE_GATES=True`` explicitly so
  the recompute path stores into its fresh scratch buffers as before.
- Test ``test_resolve_save_gates`` covers the four truth-table cases
  (recompute, no_grad, no-grad-input, optional ``None`` inputs).

Co-Authored-By: Claude Opus 4.7 <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Benchmarks rl/benchmark changes CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Documentation Improvements or additions to documentation Integrations/torch_geometric Integrations Modules Performance Performance issue or suggestion for improvement

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant