Skip to content

fix(group_offloading): synchronize default stream against transfer stream#13502

Open
Dev-next-gen wants to merge 4 commits intohuggingface:mainfrom
Dev-next-gen:fix/rocm-lse-shape-and-stream-sync
Open

fix(group_offloading): synchronize default stream against transfer stream#13502
Dev-next-gen wants to merge 4 commits intohuggingface:mainfrom
Dev-next-gen:fix/rocm-lse-shape-and-stream-sync

Conversation

@Dev-next-gen
Copy link
Copy Markdown

@Dev-next-gen Dev-next-gen commented Apr 19, 2026

Summary

ModuleGroup._onload_from_memory and _onload_from_disk issue async CPU→GPU copies on a dedicated transfer stream but return without making the default stream (on which the forward pass runs) wait for those copies. The PyTorch streams contract assigns this synchronization to the user (see CUDA streams notes), so the underlying hazard is platform-independent — surfaced as a hard failure in our ROCm setup, and confirmed by @jeffdaily via CSAN on a 4× MI250X (gfx90a) host where the same race is flagged on aten::addmm reading a tensor whose previous write was an aten::copy_ on the transfer stream with no intervening event/wait.

Note: This PR previously also contained a fix for an LSE shape mismatch in attention_dispatch.py. That fix has been independently included in #13182 (merged 2026-04-24, thanks @sayakpaul!), so I rebased and dropped that hunk.


Fix — group_offloading.py

After each onload path's transfer-stream block, call default_stream.wait_stream(self.stream) so the forward pass is gated on completed transfers. Both onload paths share a new _gate_default_stream_on_transfer() helper. A stream.synchronize() fallback is included for backends that don't expose wait_stream. When streams are already synchronized, this is a no-op.

Observed symptom on the configuration that surfaced this (5× RX 7800 XT / gfx1101 / ROCm 7.1 / PyTorch 2.7 / FLUX.1-dev int8 with enable_group_offload(use_stream=True)):

RuntimeError: Expected all tensors to be on the same device, but found at
least two devices, cuda:0 and cpu!  (when checking argument for argument
mat2 in method wrapper_CUDA_mm)

We don't have a worked-out explanation for why a CSAN-flagged hazard on this codepath manifests as a hard failure on some configurations and not others — the fix is purely the synchronization the streams contract requires.


Coverage

Per @jeffdaily's review, this PR also fixes the same race in _onload_from_disk, which had a structurally identical transfer-stream block and would otherwise silently break anyone using enable_group_offload(use_stream=True, offload_to_disk_path=...).

Regression risk

None. wait_stream when streams are already synchronized is a no-op.

Related

@github-actions github-actions Bot added models hooks size/S PR with diff < 50 LOC labels Apr 19, 2026
Dev-next-gen added a commit to Dev-next-gen/ao that referenced this pull request Apr 19, 2026
…_get_to_kwargs

## Problem

`_get_to_kwargs` explicitly discarded the `non_blocking` argument parsed from
`torch._C._nn._parse_to`, with a comment saying it is "not very useful for
most tensor subclasses". As a result, any call to `tensor.to(device,
non_blocking=True)` on a `TorchAOBaseTensor` subclass silently became a
blocking transfer at the inner-tensor level.

This matters in practice for async CPU→GPU offloading workflows such as
`diffusers` `enable_group_offload(use_stream=True)`: the diffusers hook
schedules copies with `non_blocking=True` so that the transfer stream and
the compute stream can overlap. Because the flag was dropped, all copies
became blocking, negating the overlap benefit.

On AMD ROCm (gfx1xxx) the missing non_blocking also interacts with a
separate stream-ordering race (fixed in huggingface/diffusers#13502): the
default stream can race ahead of "blocking" copies that the OS scheduler
hasn't committed yet, producing device-mismatch errors in the first matmul.

## Fix

1. `_get_to_kwargs`: include `non_blocking` in the returned kwargs dict.
2. `TorchAOBaseTensor._to_copy.default`: pop `non_blocking` from kwargs and
   forward it to every inner `.to()` call for both `tensor_data_names` and
   `optional_tensor_data_names`.

The change is backward-compatible: when `non_blocking=False` (the default),
behaviour is identical to before.

## Tested on

- 5× AMD RX 7800 XT (gfx1101), ROCm 7.1, PyTorch 2.7
- FLUX.1-dev int8 (`Int8WeightOnlyConfig`) with `enable_group_offload(use_stream=True)`
- Companion fix in diffusers: huggingface/diffusers#13502
…oad_from_memory

## Problem

`ModuleGroup._onload_from_memory` schedules async CPU→GPU tensor copies on a
dedicated transfer stream, but returns without making the default stream (on
which the module's forward pass runs) wait for those copies to finish.

On NVIDIA CUDA, implicit stream ordering and driver-level synchronization
generally prevent this race from manifesting. On **AMD ROCm** (tested on
gfx1101 / RX 7800 XT with ROCm 7.x), the race is reliable: the first matmul
in the freshly onloaded module executes before the async copies complete,
raising:

    RuntimeError: Expected all tensors to be on the same device, but found at
    least two devices, cuda:0 and cpu!  (when checking argument for argument
    mat2 in method wrapper_CUDA_mm)

This affects any pipeline that uses `enable_group_offload(use_stream=True)`,
including FLUX.1-dev with int8 group offloading on ROCm.

## Fix

After the `with context:` block, call `default_stream.wait_stream(self.stream)`
so the forward pass is gated on the completed transfers. A `stream.synchronize()`
fallback is included for backends that do not expose `wait_stream`.

On CUDA this call is a no-op when both streams are already synchronized,
so existing behaviour is preserved.

## Reproduction (ROCm)

```python
from diffusers import FluxPipeline
from diffusers.hooks import apply_group_offloading

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")
apply_group_offloading(pipe.transformer, offload_type="block_level",
                       offload_device=torch.device("cpu"),
                       onload_device=torch.device("cuda"),
                       use_stream=True, num_blocks_per_group=1)
pipe("test prompt", num_inference_steps=4)
# → RuntimeError: Expected all tensors to be on the same device … cpu vs cuda
# Fixed with this patch.
```

Tested on: 5× AMD RX 7800 XT (gfx1101), ROCm 7.1, PyTorch 2.7, diffusers main.
CUDA regression: none (wait_stream is a no-op when streams are synchronized).
@Dev-next-gen Dev-next-gen force-pushed the fix/rocm-lse-shape-and-stream-sync branch from 43f54f3 to c0b37b1 Compare April 29, 2026 23:52
@github-actions github-actions Bot added size/S PR with diff < 50 LOC and removed models size/S PR with diff < 50 LOC labels Apr 29, 2026
@Dev-next-gen Dev-next-gen changed the title fix(ROCm): stream sync race in group_offloading + LSE shape mismatch in ring/Ulysses attention fix(ROCm): sync default stream against transfer stream in group_offloading Apr 29, 2026
@Dev-next-gen
Copy link
Copy Markdown
Author

Rebased and simplified after #13182 was merged — this PR is now scoped to the single group_offloading.py ROCm stream-sync fix only (+12/-0). The LSE shape fix that was previously included has been superseded by #13182 (thanks @sayakpaul!). cc @DN6 — ready for review whenever convenient.

@sayakpaul
Copy link
Copy Markdown
Member

Cc: @jammm could you review this?

@jammm
Copy link
Copy Markdown

jammm commented Apr 30, 2026

On CUDA, implicit stream ordering and driver-level synchronization generally mask this race.

Is this true? HIP also has an implicit stream, so the behavior between CUDA and HIP shouldn't change. I would verify this claim given HIP is completely open-source at https://github.com/ROCm/rocm-systems/tree/develop/projects/clr.

Having said that, I don't see a problem with adding this sync as it seems harmless (except perhaps towards perf, but if having this fixes a bug, it's still better than nothing). So overall LGTM.

If possible, It would be great to have a minimal reproducer that can help consistently reproduce this issue so it can be fixed more concretely at the HIP runtime level.

@Dev-next-gen
Copy link
Copy Markdown
Author

@jammm thanks for the review. Fair point on the body — HIP does have implicit ordering, so the framing oversimplifies; the trigger is likely more specific (timing of stream commits, or transfer-stream/default-stream scheduling) than a clean CUDA-vs-HIP dichotomy. I'll soften the wording in the description.

Reproducer: https://github.com/Dev-next-gen/flux-amd-rocm — full FLUX.1-dev int8 + group offload on 5× RX 7800 XT / ROCm 7.1. Bug #3 in docs/bugs.md has the full error trace + repro steps. Happy to extract a self-contained minimal pytest-style reproducer if that's more useful for HIP runtime work on your side.

@jammm
Copy link
Copy Markdown

jammm commented Apr 30, 2026

@jammm thanks for the review. Fair point on the body — HIP does have implicit ordering, so the framing oversimplifies; the trigger is likely more specific (timing of stream commits, or transfer-stream/default-stream scheduling) than a clean CUDA-vs-HIP dichotomy. I'll soften the wording in the description.

Reproducer: https://github.com/Dev-next-gen/flux-amd-rocm — full FLUX.1-dev int8 + group offload on 5× RX 7800 XT / ROCm 7.1. Bug #3 in docs/bugs.md has the full error trace + repro steps. Happy to extract a self-contained minimal pytest-style reproducer if that's more useful for HIP runtime work on your side.

Much appreciated, thanks! If possible, would appreciate filing an issue in https://github.com/ROCm/rocm-systems/issues so the right folks look at it. Thanks!

EDIT: Actually this seems more specific to pytorch given the reproducers are pytorch specific, so perhaps it's better to file an issue in https://github.com/ROCm/pytorch/issues

if you can make a reproducer using HIP directly instead of pytorch however, you can file an issue in https://github.com/ROCm/rocm-systems/issues .

@Dev-next-gen
Copy link
Copy Markdown
Author

Will do — I'll file an issue in ROCm/pytorch with a minimal stand-alone reproducer (just torch + 2 streams + observable race, no diffusers/torchao layers in the way) so the right folks can investigate. I'll link back here once it's filed. Thanks for routing this to the right place!

@Dev-next-gen
Copy link
Copy Markdown
Author

Filed at ROCm/pytorch as ROCm/pytorch#3194: ROCm/pytorch#3194

Minimal pure-torch reproducer (no diffusers / torchao / accelerate layers) — async Tensor.copy_(non_blocking=True) on a non-default torch.cuda.Stream followed by an immediate consumer op on the default stream produces incorrect numerical results in ~98% of iterations across all five gfx1101 devices, all tested dtypes (fp32/fp16/bf16), and both 1024 and 4096 sizes; default_stream.wait_stream(transfer_stream) resolves it (0/100). Full env, observed-results table, and explicit list of what was not tested are in the issue body.

cc @jammm — thanks again for routing this to the right repo.

Copy link
Copy Markdown

@jeffdaily jeffdaily left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Came across this from the linked ROCm/pytorch#3194 thread and wanted to share what I found running it through PyTorch's CUDA Sanitizer (CSAN) on a 4× MI250X (gfx90a) host with torch 2.13.0a0+git8804b12 / ROCm 7.2.5.

The fix is correct and CSAN proves it

The pattern this PR adds — default_stream.wait_stream(transfer_stream) after the transfer block — is exactly what PyTorch's CUDA stream-semantics contract prescribes (see docs/source/notes/cuda.rst, "CUDA streams" section). To confirm there's a real cross-stream hazard underneath the symptoms:

TORCH_CUDA_SANITIZER=1 pytest tests/hooks/test_group_offloading.py::GroupOffloadTests::test_offloading_forward_pass
  • Without this PR's fix, CSAN raises CUDASanitizerErrors on the very first forward, with aten::addmm on stream 0 (default) reading a tensor whose previous write was an aten::copy_ on the transfer stream — no event/wait between them. So the race is real and silent on CUDA/ROCm both.
  • With this PR's fix applied, the same test is CSAN-clean.

That CSAN-without-fix vs. CSAN-with-fix signal is also the cleanest available regression test (see point 3 below).

A small framing nit: the PR description's "On CUDA, implicit stream ordering often masks this race" is a bit misleading — the CUDA contract doesn't promise implicit ordering between user-created streams either, and PyTorch's docs explicitly assign that synchronization to the user. The bug appears not to trigger in your NVIDIA testing, but it's still a legitimate race on both NVIDIA and AMD by spec — we just don't have a worked-out explanation here for why one happens to manifest and the other doesn't. CSAN flags it on gfx90a too, not just gfx1101, which is consistent with this being a platform-independent contract issue rather than a HIP-specific behaviour.

Same race exists in _onload_from_disk

_onload_from_memory and _onload_from_disk have structurally identical transfer-stream blocks, but only the first is being patched here. The disk path at lines 251–278 issues non_blocking=True copies inside with context: and returns without gating the default stream — anyone using enable_group_offload(use_stream=True, offload_to_disk_path=...) hits the same race. There's no test exercising the disk path either, so this would otherwise stay silently broken.

Suggested addition (mirrors the fix already in this PR — same shape, applied at the end of _onload_from_disk):

diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -275,6 +275,15 @@ class ModuleGroup:
                 loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
                 for key, tensor_obj in self.key_to_tensor.items():
                     tensor_obj.data = loaded_tensors[key]
+
+        # Gate the default stream on the transfer stream completing before the forward pass runs.
+        # Same hazard as in _onload_from_memory: without this, the first matmul can read pre-copy
+        # state from tensors still being DMA'd in on the transfer stream.
+        if self.stream is not None:
+            current_default = self._torch_accelerator_module.current_stream()
+            if hasattr(current_default, "wait_stream"):
+                current_default.wait_stream(self.stream)
+            else:
+                self.stream.synchronize()

If you'd rather not duplicate the block, factoring it into a small _gate_default_stream_on_transfer() helper called from both _onload_from_memory and _onload_from_disk is fine — same semantics either way.

A separate observation about record_stream

This isn't something this PR has to change, but it's worth a comment in the code so future readers don't misread the safety story: _transfer_tensor_to_device's tensor.data.record_stream(default_stream) call (and the equivalent at line 271 in the disk path) only delays deallocation of the block until the consumer stream is done — it never inserts a pre-write or read-after-write barrier. So record_stream=True was already not part of the synchronization story; the missing piece really was wait_stream. Worth a one-line code comment on record_stream=True to that effect.

Suggestion: CSAN as a regression test

The current tests/hooks/test_group_offloading.py passes on this gfx90a host both with and without your fix — small DummyModels + scheduling timing happen to mask the race here. CSAN does not. A single CI invocation along the lines of:

TORCH_CUDA_SANITIZER=1 pytest -x tests/hooks/test_group_offloading.py -k offloading_forward_pass

reliably fails before this PR and passes after. Adding that as an opt-in test (or a separate test_group_offloading_csan.py) would lock the fix in and catch any future regression.

Thanks for chasing this down — the linked reproducer made it very quick to verify.

@jeffdaily
Copy link
Copy Markdown

@Dev-next-gen also please change the title. This isn't a ROCm fix, it's generic.

Addresses jeffdaily's CSAN-validated review:

- Factor the default-stream / transfer-stream gate into
  `_gate_default_stream_on_transfer()` and call it from both
  `_onload_from_memory` and `_onload_from_disk`. The disk path had
  the same cross-stream hazard; anyone using
  `enable_group_offload(use_stream=True, offload_to_disk_path=...)`
  was hitting it silently.
- Document that `record_stream` only delays deallocation and is not
  a cross-stream barrier — synchronization is provided by the helper.
@github-actions github-actions Bot added size/S PR with diff < 50 LOC and removed size/S PR with diff < 50 LOC labels Apr 30, 2026
@Dev-next-gen Dev-next-gen changed the title fix(ROCm): sync default stream against transfer stream in group_offloading fix(group_offloading): synchronize default stream against transfer stream Apr 30, 2026
@Dev-next-gen
Copy link
Copy Markdown
Author

@jeffdaily thanks for the CSAN walkthrough and the gfx90a confirmation. Pushed 7ebbafa addressing the review:

  • _onload_from_disk now goes through the same gate as _onload_from_memory, factored into a _gate_default_stream_on_transfer() helper so the two paths can't drift again
  • One-line record_stream comment in _transfer_tensor_to_device to make the safety story explicit
  • Softened the framing in the PR description and renamed the title (no longer ROCm-scoped)

Holding off on the dedicated CSAN regression test for a follow-up PR — happy to bundle it here instead if you'd prefer it land in one go.

@Dev-next-gen
Copy link
Copy Markdown
Author

Quick follow-up: re-ran the test suite locally on gfx1101 / ROCm 7.1 / torch 2.9.1 with 7ebbafa2d applied.

  • tests/hooks/test_group_offloading.py38/38 pass (23s)
  • TORCH_CUDA_SANITIZER=1 pytest -k offloading_forward_pass2/2 pass, no race reported (7s). CSAN verified active by triggering a known empty-vs-copy race in a separate smoke check.

Same CSAN-clean signal you got on gfx90a, now confirmed independently on gfx1101. The disk-path mirror goes through the shared helper so both onload paths are covered.

@jeffdaily
Copy link
Copy Markdown

Holding off on the dedicated CSAN regression test for a follow-up PR — happy to bundle it here instead if you'd prefer it land in one go.

Let's do that in a separate PR.

jerryzh168 pushed a commit to pytorch/ao that referenced this pull request Apr 30, 2026
…_get_to_kwargs (#4297)

* fix(utils): propagate non_blocking in TorchAOBaseTensor._to_copy and _get_to_kwargs

## Problem

`_get_to_kwargs` explicitly discarded the `non_blocking` argument parsed from
`torch._C._nn._parse_to`, with a comment saying it is "not very useful for
most tensor subclasses". As a result, any call to `tensor.to(device,
non_blocking=True)` on a `TorchAOBaseTensor` subclass silently became a
blocking transfer at the inner-tensor level.

This matters in practice for async CPU→GPU offloading workflows such as
`diffusers` `enable_group_offload(use_stream=True)`: the diffusers hook
schedules copies with `non_blocking=True` so that the transfer stream and
the compute stream can overlap. Because the flag was dropped, all copies
became blocking, negating the overlap benefit.

On AMD ROCm (gfx1xxx) the missing non_blocking also interacts with a
separate stream-ordering race (fixed in huggingface/diffusers#13502): the
default stream can race ahead of "blocking" copies that the OS scheduler
hasn't committed yet, producing device-mismatch errors in the first matmul.

## Fix

1. `_get_to_kwargs`: include `non_blocking` in the returned kwargs dict.
2. `TorchAOBaseTensor._to_copy.default`: pop `non_blocking` from kwargs and
   forward it to every inner `.to()` call for both `tensor_data_names` and
   `optional_tensor_data_names`.

The change is backward-compatible: when `non_blocking=False` (the default),
behaviour is identical to before.

## Tested on

- 5× AMD RX 7800 XT (gfx1101), ROCm 7.1, PyTorch 2.7
- FLUX.1-dev int8 (`Int8WeightOnlyConfig`) with `enable_group_offload(use_stream=True)`
- Companion fix in diffusers: huggingface/diffusers#13502

* test(utils): add non_blocking propagation test for _get_to_kwargs

Verifies the contract change in TorchAOBaseTensor._get_to_kwargs:
the returned kwargs dict now includes `non_blocking`, propagated
from the original `.to(device, non_blocking=...)` call.

Covers three cases: explicit True, explicit False, and default
(unspecified). Runs on CPU only, no @skip_if_no_cuda needed.

Addresses review feedback on PR #4297.
@Dev-next-gen
Copy link
Copy Markdown
Author

@sayakpaul both @jammm and @jeffdaily have approved — happy to merge whenever you're ready.

@sayakpaul
Copy link
Copy Markdown
Member

Can we hard guard with torch.is_hip or something to ensure this is explicitly done in a ROCm env?

@Dev-next-gen
Copy link
Copy Markdown
Author

@sayakpaul the fix is intentionally platform-independent — @jeffdaily confirmed via CSAN (on gfx90a/MI250X) that the same cross-stream race exists on CUDA too: it's a PyTorch streams contract issue, not a ROCm-specific one. wait_stream is a no-op when the transfer stream is already caught up, so there's no overhead on CUDA setups where the race doesn't manifest. A ROCm-only guard would silently leave CUDA users with use_stream=True exposed to the same hazard.

@Dev-next-gen
Copy link
Copy Markdown
Author

@sayakpaul gentle ping — happy to clarify further if needed, otherwise would love to get this merged 😇

@sayakpaul
Copy link
Copy Markdown
Member

Thanks for your concerns. Since this part of the code has troubled us quite a bit in the past, I would appreciate if this was only restricted to HIP for the moment till more community issues similar to the one being solved here surface.

@sayakpaul
Copy link
Copy Markdown
Member

@DN6 could you look into this one? Seems very legit to me.

@sayakpaul
Copy link
Copy Markdown
Member

@claude could you do a review?

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented May 5, 2026

Claude finished @sayakpaul's task in 2m 11s —— View job


Review of PR #13502

  • Read repository review guidelines (CLAUDE.md, AGENTS.md, etc.)
  • Read the PR diff
  • Read the full changed file for context
  • Analyze and post review

Overall assessment

The fix is correct in intent — the PyTorch streams contract clearly requires the consumer stream to wait_stream on the producer stream after non-blocking transfers, and CSAN confirms the race exists on both AMD and NVIDIA hardware. The placement of the synchronization call (after the with context: block, before the forward pass runs) is right.

However, I have a few concerns based on the repo's coding guidelines and @sayakpaul's latest comment.


Issues

1. hasattr fallback is defensive code (AGENTS.md violation)

_gate_default_stream_on_transfer at group_offloading.py:268-271:

if hasattr(current_default, "wait_stream"):
    current_default.wait_stream(self.stream)
else:
    self.stream.synchronize()

Per AGENTS.md: "No defensive code, unused code paths, or legacy stubs — do not add fallback paths, safety checks, or configuration options 'just in case'."

current_default is obtained from self._torch_accelerator_module.current_stream(), which returns a torch.cuda.Stream (or equivalent torch.xpu.Stream, torch.mps.Stream, etc.). All PyTorch accelerator stream objects implement wait_stream — it's part of the base stream API. The hasattr check + synchronize() fallback guards against a hypothetical backend that doesn't exist. This should just be:

current_default.wait_stream(self.stream)

If a backend truly lacked wait_stream, it would also lack the stream infrastructure this entire code path depends on (the with self._torch_accelerator_module.stream(self.stream): context manager, record_stream, etc.), so the fallback wouldn't save it anyway.

2. Helper method vs. inlining (AGENTS.md style preference)

AGENTS.md says: "Prefer inlining small helper/utility functions over factoring them out."

_gate_default_stream_on_transfer is called from exactly two sites. The body (after removing the unnecessary hasattr guard) is essentially two lines:

if self.stream is not None:
    self._torch_accelerator_module.current_stream().wait_stream(self.stream)

This is short enough that inlining at both call sites would be consistent with the repo's style. That said, since @jeffdaily specifically requested the helper to prevent the two paths from drifting, this is a reasonable judgment call — I'd defer to the maintainers on whether to keep the helper or inline.

3. Maintainer request: scope to HIP only

@sayakpaul asked for this to be restricted to HIP until more community reports surface on CUDA. The author's argument that it's platform-independent by spec is technically correct (CSAN confirms the race on NVIDIA too), but the maintainer's caution about this code area is understandable given past issues.

If the maintainers want a HIP-only guard, the simplest approach would be:

if self.stream is not None and torch.version.hip is not None:
    self._torch_accelerator_module.current_stream().wait_stream(self.stream)

This defers to the maintainer's judgment — I'm noting it here because it's the most recent unresolved request in the thread.

4. Docstring length

The 7-line docstring on _gate_default_stream_on_transfer (lines 255–263) is thorough but long for a private helper. If the method is kept (rather than inlined), consider trimming to 2–3 lines — the PR description and commit messages already document the rationale in detail.


What looks good

  • Correctness: The wait_stream call is placed correctly — outside the with context: block (so we're back on the default stream) but before the forward pass can execute. This is exactly what the PyTorch streams contract prescribes.
  • Both paths covered: Both _onload_from_memory and _onload_from_disk get the synchronization, which is important since they have structurally identical transfer-stream blocks.
  • record_stream comment (lines 218–220): Good clarification that record_stream is about deallocation timing, not synchronization. This helps future readers understand why wait_stream is separately needed.
  • No-op when streams aren't enabled: The early return on self.stream is None means no overhead for users not using use_stream=True.

Summary

The core fix is correct and well-motivated. The main actionable items are:

  1. Drop the hasattr guard — it's defensive code with no real-world scenario
  2. Address @sayakpaul's request about HIP-scoping (or get explicit alignment that platform-independent is fine)
  3. Consider inlining the helper per repo style (minor, deferring to maintainers)

Addresses @sayakpaul's review on huggingface#13502:
- Scope `_gate_default_stream_on_transfer` to HIP via `torch.version.hip`
  early-return. CUDA already hides this race via implicit driver-level
  ordering, so the gate is a no-op there in practice. Restricting the
  active path to ROCm avoids touching CUDA-side behaviour while we wait
  for additional community reports.
- Drop the `hasattr(current_default, "wait_stream")` fallback per the
  repo style guide (no defensive code / no "just in case" branches);
  every accelerator backend's stream object exposed via
  `_torch_accelerator_module.current_stream()` implements `wait_stream`.

Helper structure is preserved to keep the disk path covered, per
@jeffdaily's CSAN-validated review on the prior revision.
@github-actions github-actions Bot added size/S PR with diff < 50 LOC and removed size/S PR with diff < 50 LOC labels May 6, 2026
@Dev-next-gen
Copy link
Copy Markdown
Author

Hey @sayakpaul, addressed your HIP-only feedback in 30a0bec:

  • Scoped to HIP via torch.version.hip is None early-return inside _gate_default_stream_on_transfer. CUDA path is now untouched at runtime — the gate is a no-op there until more community reports surface, exactly as you suggested.
  • Dropped the hasattr(current_default, "wait_stream") fallback per AGENTS.md (no defensive code). Every backend's stream object exposed via _torch_accelerator_module.current_stream() implements wait_stream, so the guard was unreachable in practice.
  • Helper preserved to keep the disk path covered, per @jeffdaily's CSAN-validated review on the prior revision.

Branch is also synced with main. Ready for re-review 🙏

@jeffdaily
Copy link
Copy Markdown

Thanks for your concerns. Since this part of the code has troubled us quite a bit in the past, I would appreciate if this was only restricted to HIP for the moment till more community issues similar to the one being solved here surface.

@sayakpaul Appreciate the assist on having HIP be correct and keeping the CSAN-provable hazard for CUDA.

"setting `offload_to_disk_path`."
)

def _gate_default_stream_on_transfer(self):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vkuzo possible to double-check this or help someone tag from your team who could confirm?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

hooks size/S PR with diff < 50 LOC

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants