
Draft: Extended Tensor Parallelism #2960

Draft

jiemingz wants to merge 43 commits into NVIDIA:main from jiemingz:cuda_graph_final

Conversation


@jiemingz jiemingz commented May 5, 2026

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

jiemingz and others added 30 commits March 5, 2026 14:45
… avoid allocating fresh memory each iteration.
…G allocations

  1. Release NCCL Work C++ tensor refs promptly in ETPShardHandle.wait()
     (self.handle = None) so wgrad buffers are freed when RS is waited,
     not held until optimizer.step().

  2. Use cache buffers for sync all-gather path (not just async). The old
     code passed out_buffers=None for sync gathers, allocating ~22 GB/iter
     of fresh tensors. Sync gathers now reuse the same ETPWeightCache
     buffers as the async prefetch path.

  3. Add standalone wgrad input buffer pool (_wgrad_buf_pool) for expert
     chain. get_wgrad_tensor() draws from the pool; buffers are returned
     after RS is waited via _wgrad_input_bufs stash in _wait_reduce_scatter.
     Reduces expert wgrad peak from ~4 GB (held until optimizer) to ~640 MB
     (16 buffers reused across all MoE layers).

  4. Stash _wgrad_input_bufs for all chains (not just expert) so ungraphed
     dense weights (output layer) also drop Python refs at _wait_reduce_scatter
     instead of surviving until calc_params_l2_norm.

  5. Fix tensor comparison crash in cache.release(): use identity check
     (any(b is slot.buf ...)) instead of tensor == which returns a
     multi-element bool tensor.
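
  A minimal, self-contained illustration of the identity-check fix in (5); the pool here is hypothetical, but the `is` vs `==` behavior is standard PyTorch:

    import torch

    def pool_contains(pool, buf):
        # `b is buf` compares Python object identity. `b == buf` would be an
        # elementwise tensor comparison, and calling any() on the resulting
        # multi-element bool tensor raises "Boolean value of Tensor with
        # more than one element is ambiguous".
        return any(b is buf for b in pool)

    pool = [torch.empty(4), torch.empty(4)]
    assert pool_contains(pool, pool[0])
    assert not pool_contains(pool, torch.empty(4))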
…backward

  For nvfp4, batched_all_gather_and_prefetch_bwd() returns NVFP4TensorStorage
  objects with internal sub-tensors (columnwise data/scale_inv). The local
  `weights` variable kept them alive until function return, wasting memory
  through the wgrad phase. Delete `weights` immediately after the dgrad GEMM
  (the last consumer), saving weight_sizes for the fuse_wgrad_accumulation=False
  fallback path.
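
  The lifetime point generalizes to any Python function: a local name keeps its tensors alive until return unless explicitly dropped. A toy sketch (a plain matmul stands in for the dgrad GEMM; names are made up):

    import torch

    def backward_sketch(grad_output, weights):
        weight_sizes = [w.shape for w in weights]  # all the fallback path needs
        dgrad = grad_output @ weights[0].t()       # stand-in for the dgrad GEMM
        del weights  # last consumer done: free the gathered storage before wgrad
        return dgrad, weight_sizes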
…TP AG on IB in expert backward

  ag_stream.wait_stream(main_stream) before batched_all_gather_and_prefetch_bwd
  in grouped_linear.py backward. With --overlap-grad-reduce + CG, DDP backward
  hooks fire a reduce-scatter (IB, main_stream) that races with the EETP
  all-gather (IB, ag_stream), causing an NCCL deadlock at 64+ GPU IB scale.
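
  A sketch of the ordering fix described above, assuming a CUDA device; the real call site wires this around batched_all_gather_and_prefetch_bwd:

    import torch

    main_stream = torch.cuda.current_stream()
    ag_stream = torch.cuda.Stream()

    # Make the side-stream all-gather start only after work already queued
    # on the main stream (e.g. the DDP reduce-scatter), so every rank sees
    # the same inter-collective order on the IB communicator.
    ag_stream.wait_stream(main_stream)
    with torch.cuda.stream(ag_stream):
        pass  # issue the EETP all-gather here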
…ze_wgrad DDP hook trigger

  Re-add register_grad_accum_hook() to store DDP backward hook on ETP params.
  _finalize_wgrad now calls the hook after RS wait + main_grad.add_(), firing
  DDP register_grad_ready at the correct serialization point. This replaces
  the previous approach of skipping DDP hooks entirely for ETP params.

  param.grad = dummy_grad is a Python attr set (does NOT trigger autograd's
  grad accumulator); the explicit _grad_accum_hook() call is required.
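
  A toy sketch of the mechanism (the hook registry and names here are invented, not the PR's API): store the DDP hook, then fire it explicitly, since a plain `param.grad = ...` attribute write never runs autograd's grad accumulator:

    import torch

    _hooks = {}

    def register_grad_accum_hook(param, fn):
        _hooks[param] = fn  # stored, not attached to autograd

    param = torch.nn.Parameter(torch.zeros(4))
    register_grad_accum_hook(param, lambda p: print("register_grad_ready"))

    # ... after the RS is waited and main_grad.add_() has run ...
    _hooks[param](param)  # explicit call at the correct serialization point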
    to drop Python refs to wgrad input buffers immediately. The async RS
    still holds C++ refs via NCCL Work until _wait_reduce_scatter. Reduces
    peak memory during graph capture warmup (~320 MB per MoE layer).
  Three related fixes exposed by adding embedding / output_layer to the
  UNGRAPHED prefetch chain:

  1. iter-2 NaN (consumer-side race on AG output buffer). The async AG
     prefetch was issued from main_stream; NCCL's caller-stream preEvent
     queued behind pending CG/compute and NCCL started late, leaving the
     consumer GEMM reading a partially-written buffer.
     Fix: wrap the async issue in `with torch.cuda.stream(ag_stream):`
     in both all_gather_and_prefetch (fwd) and all_gather_and_prefetch_bwd
     (bwd). Added a state guard in _get_prefetched_weight that asserts an
     AG was issued for this consume cycle — catches silent stale-cache
     reads from misconfigured _need_weight_prefetch flags.

  2. Unbounded _wgrad_buf_pool growth. _wait_reduce_scatter pushed the
     wgrad input buffer into the pool unconditionally, but callers that
     don't acquire via _wgrad_pool_get (Megatron layers.py wgrad GEMM,
     aten F.embedding backward) never popped — every iter leaked N fresh
     buffers into the pool.
     Fix: tag pool-owned buffers at _wgrad_pool_get; _wgrad_pool_put
     no-ops on foreign buffers, letting the caching allocator recycle.
     Side effect: throughput 80 → 580 TFLOPs/GPU (pool thrash eliminated).

  3. ag/rs streams partitioned by (chain_id, NCCL group). UNGRAPHED chain
     can span multiple communicators (ETP vs EETP); sharing a single
     user-level stream forced cross-group NCCL ops to serialize. Stream
     dicts are now keyed on (chain_id, id(group)); adds
     get_{ag,rs}_streams_for_chain() helpers.
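
  A minimal sketch of the (chain_id, group)-keyed stream pool from fix 3; this approximates the get_{ag,rs}_streams_for_chain() helpers rather than reproducing them:

    import torch

    _ag_streams = {}

    def get_ag_stream(chain_id, group):
        # One user-level stream per (chain, communicator) pair, so ETP and
        # EETP collectives in the same chain no longer serialize on a
        # shared stream.
        key = (chain_id, id(group))
        if key not in _ag_streams:
            _ag_streams[key] = torch.cuda.Stream()
        return _ag_streams[key]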
…empool —

   but only for params on the GRAPHED chain. UNGRAPHED-chain params (embedding,
   output_layer, and MoE paths whose scope is not captured) run eagerly and don't
   need their quantized storage in the CG mempool.

   (2) _ETP_PARAMS already contains every individual expert (appended per weight_name in
   wrap_module_params_etp), so iterate it directly — no weight_list unroll needed.
fanshiqing and others added 13 commits April 23, 2026 04:09
  Replaces the per-expert (zero_amax + amax + D2D amax replicate) chain in the
  ETP coalesced-amax path with a pair of multi-tensor kernel launches. The
  compute kernel writes rowwise and columnwise amax directly (atomicMaxFloat),
  eliminating the per-expert D2D copy.
Changes:
    - Lazy-cache ag_stream / rs_stream on self (resolved once from
      chain_id + group; prior path hit a dict lookup every call).
    - Cache quantizers / dtypes / etp_group on the anchor weight
      (rebuilt via list comprehensions on every _all_gather_weight call).
    - Consolidate _multi_amax_quantizer_list into _cached_quantizers
      (single cache shared between Tier-2 amax and _all_gather_weight).
    - Gate the duplicate-output-buffer assertion in batched AG behind
      ETP_CONFIG.check_param_states (was running O(N) per call).
    - Drop a dead `out_buffers is not None` check (always a list).
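
  A sketch of the lazy-cache pattern in the first bullet, reusing the get_ag_stream() helper sketched earlier; the attribute names are guesses:

    class _StreamHolder:
        def __init__(self, chain_id, group):
            self.chain_id = chain_id
            self.group = group
            self._ag_stream = None

        def ag_stream(self):
            # Resolve once via the (chain_id, group)-keyed lookup, then
            # serve the cached attribute; the prior path hit the dict on
            # every call.
            if self._ag_stream is None:
                self._ag_stream = get_ag_stream(self.chain_id, self.group)
            return self._ag_stream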
…c86fd

The `with torch.cuda.stream(target.ag_stream):` wrapper re-routed NCCL's preEvent onto an idle stream, so the AG raced the caller-stream writer (quantize / sharded-weight update).

The AG is now issued on the caller's stream; _wait_param_gather keeps ag_stream.

Verified: 5000+ steps clean on TP2ETP2_EP2EETP2 nvfp4, 1/4 Ultra,
32xGB200.
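
A sketch of the corrected issue/wait split, assuming torch.distributed is initialized with the NCCL backend; the function names are illustrative:

    import torch
    import torch.distributed as dist

    def issue_all_gather(out, shard, group):
        # Issue on the caller's current stream so NCCL's preEvent captures
        # the producer that writes `shard` (quantize / sharded-weight
        # update), not an idle side stream.
        return dist.all_gather_into_tensor(out, shard, group=group, async_op=True)

    def wait_param_gather(work, ag_stream):
        # Completion is still consumed on ag_stream, as before.
        with torch.cuda.stream(ag_stream):
            work.wait()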
  Move gradient accumulation from caller stream to rs_stream inside
  _wait_reduce_scatter(finalize_grad=True). The add_ starts right
  after NCCL RS (concurrent with Phase 1 AG drain) instead of after
  it, avoiding SM-saturation that blocks cross-graph overlap.
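
  A sketch of the change, with stream-safety bookkeeping (record_stream and friends) omitted; names are illustrative:

    import torch

    def wait_reduce_scatter(rs_work, rs_stream, main_grad, rs_out):
        # Queue the accumulation on rs_stream right behind the NCCL RS, so
        # it overlaps the Phase 1 AG drain instead of running afterwards on
        # the SM-saturated caller stream.
        with torch.cuda.stream(rs_stream):
            rs_work.wait()
            main_grad.add_(rs_out)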
@jiemingz jiemingz changed the title from Extended Tensor Parallelism to Draft: Extended Tensor Parallelism on May 5, 2026
@jiemingz jiemingz marked this pull request as draft May 5, 2026 22:58
Contributor

greptile-apps Bot commented May 5, 2026

Greptile Summary

This PR introduces Extended Tensor Parallelism (ETP), a new weight-sharding strategy that all-gathers weights on demand during the forward and backward passes instead of replicating them across ranks. It ships a new 1724-line extended_tensor_parallelism.py module, a new nvte_multi_compute_amax CUDA kernel for fused per-expert amax computation, and a split-phase NVFP4 quantize API (compute_amax_nvfp4 / quantize_cast_only_nvfp4) that coalesces N per-expert amax allreduces into a single NCCL call.

  • ETP core (extended_tensor_parallelism.py): ETPShardedParam subclasses nn.Parameter with prefetch linked lists, async AG/RS streams, a ticket-based buffer cache, and CUDA-graph chain classification (GRAPHED vs. UNGRAPHED). Two runtime bugs are present: tensor is undefined in wrap_module_params_etp's else branch (triggered when pad_for_alignment=0), and ETPShardedParam._finalize_wgrad is called inside wait_async_comms(finalize_after_drain=True) but never defined.
  • Distributed changes (distributed.py): _all_gather_nvfp4/_all_gather_mxfp8 gain output_tensor and grouped parameters to support external coalescing managers; a new grouped_gather_along_first_dim batches N expert weight gathers into a single NCCL coalescing region (see the rough sketch after this list). Several leftover # TODO / #TODO: jiemingz debug comments remain.
  • Module integration (grouped_linear.py, layernorm_linear.py, linear.py): each gains an etp_group constructor parameter wiring ETP weight gather/scatter into their existing autograd functions.
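
A rough sketch of the coalesced grouped-gather idea (not the PR's actual grouped_gather_along_first_dim signature; _coalescing_manager is private PyTorch API):

    import torch
    import torch.distributed as dist

    def grouped_gather_sketch(shards, group):
        world = dist.get_world_size(group)
        outs = [
            torch.empty((world * s.shape[0], *s.shape[1:]), dtype=s.dtype, device=s.device)
            for s in shards
        ]
        # Batch N expert-weight all-gathers into a single NCCL coalescing
        # region, replacing N separate collective launches.
        with dist._coalescing_manager(group=group, async_ops=True) as cm:
            for out, s in zip(outs, shards):
                dist.all_gather_into_tensor(out, s, group=group)
        return outs, cm  # caller invokes cm.wait() to complete the batch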

Confidence Score: 2/5

Not safe to merge as-is: one code path crashes on init when pad_for_alignment=0, and a second code path in wait_async_comms(finalize_after_drain=True) calls a method that does not exist.

Two separate runtime failures exist in the new ETP module. The wrap_module_params_etp else branch uses tensor before it is ever assigned, causing immediate NameError during model initialization. Separately, wait_async_comms(finalize_after_drain=True) has a fallback path that calls ETPShardedParam._finalize_wgrad(...), a static method referenced in docstrings but with no implementation anywhere in the class, causing AttributeError whenever a param reaches that branch without an in-flight RS handle.

extended_tensor_parallelism.py requires the most scrutiny: the missing _finalize_wgrad implementation and the tensor NameError in wrap_module_params_etp both live there. distributed.py has several leftover debug TODO comments that should be cleaned up before merge.

Important Files Changed

  • transformer_engine/pytorch/module/extended_tensor_parallelism.py
    Core ETP implementation: 1724-line new file with weight sharding, async AG/RS, coalesced amax, CUDA graph chain classification — contains two runtime bugs: tensor NameError in wrap_module_params_etp when pad_for_alignment=0, and missing _finalize_wgrad method called in wait_async_comms.
  • transformer_engine/pytorch/distributed.py
    Extended reduce_scatter_along_first_dim with pre-allocated output, refactored _all_gather_nvfp4/_all_gather_mxfp8 for grouped/coalesced gather, added grouped_gather_along_first_dim — logic correct but several leftover TODO/debug comments remain.
  • transformer_engine/pytorch/module/grouped_linear.py
    Adds etp_group parameter and ETP forward/backward paths; saves sharded weights instead of gathered in context; adds batched gather/scatter wrappers — unused import traceback present.
  • transformer_engine/pytorch/module/layernorm_linear.py
    Adds etp_group parameter and ETP fwd/bwd paths; fixes out_features hardcoding with out.shape[-1] for variable-width outputs; stores sharded weight in context instead of gathered.
  • transformer_engine/common/recipe/multi_amax.cu
    New CUDA kernel nvte_multi_compute_amax: fuses N per-expert (zero_amax + amax + D2D replicate) into two multi-tensor launches; uses 2D grid (y=tensor index, x=work chunks); validates dtype homogeneity and amax buffer presence.
  • transformer_engine/pytorch/csrc/extensions/cast.cpp
    Adds compute_amax_nvfp4, quantize_cast_only_nvfp4, and compute_multi_amax_nvfp4 — split-phase quantize API for coalesced amax allreduce; NVFP4-only and with_rht=false restriction enforced.
  • transformer_engine/pytorch/csrc/quantizer.cpp
    Adds compute_amax_only and quantize_cast_only to NVFP4Quantizer; adds skip_amax_reduction flag to quantize_impl — clean split-phase separation of amax compute from allreduce from cast.
  • transformer_engine/common/include/transformer_engine/recipe.h
    Adds nvte_multi_compute_amax declaration with clear documentation of semantics and constraints (shared dtype, amax buffer required, internal chunking).
  • transformer_engine/pytorch/module/linear.py
    Minor ETP integration: passes etp_size through to the autograd function and adjusts wgrad/dgrad paths conditionally.
  • tests/pytorch/distributed/test_etp.py
    1411-line test suite covering ETP forward/backward correctness for NVFP4/MXFP8/BF16, coalesced amax path, prefetch chain, and grouped-expert scenarios.

Sequence Diagram

sequenceDiagram
    participant Caller as Training Loop
    participant ETP as ETPShardedParam
    participant Cache as ETPWeightCache
    participant NCCL as NCCL (AG Stream)
    participant Quant as NVFP4Quantizer
    participant GEMM

    Note over Caller,GEMM: Forward Pass (ETP weight gather + GEMM)
    Caller->>ETP: all_gather_and_prefetch(fwd=True)
    ETP->>Quant: compute_amax_only (per-expert, fused)
    ETP->>NCCL: all_reduce_coalesced(amax tensors)
    ETP->>Quant: quantize_cast_only (per-expert)
    ETP->>Cache: get(ag_ticket_fwd)
    ETP->>NCCL: grouped_gather_along_first_dim (async)
    ETP-->>Caller: gathered_weight plus ETPShardHandle
    Caller->>ETP: prefetch next_w async on ag_stream
    Caller->>GEMM: forward GEMM with input and gathered_weight

    Note over Caller,GEMM: Backward Pass (dgrad + wgrad + RS)
    Caller->>ETP: all_gather_and_prefetch_bwd()
    ETP->>Cache: get(ag_ticket_bwd)
    ETP->>NCCL: gather with skip_weight_cast=True async
    Caller->>GEMM: dgrad GEMM with grad_output and gathered_weight
    Caller->>GEMM: wgrad GEMM with input and grad_output
    Caller->>ETP: wgrad_reduce_scatter(wgrad)
    ETP->>NCCL: reduce_scatter async on rs_stream
    ETP->>ETP: main_grad.add_ rs_result deferred

Reviews (1): Last reviewed commit: "ETP: pad full tensor before sharding ins..."

Comment on lines +478 to +481
shard_size = tensor.shape[0] // etp_group.size()
shard = tensor[etp_rank * shard_size: (etp_rank + 1) * shard_size]
etp_shard = ETPShardedParam(shard.clone())


P1 tensor is undefined in the else branch

tensor = param.data is only assigned inside the if ETP_CONFIG.pad_for_alignment > 0: block (line 467), but the else branch at line 478 references tensor without initializing it. When ETP_CONFIG.pad_for_alignment == 0 (e.g., after update_config(pad_for_alignment=0)), this raises NameError: name 'tensor' is not defined and ETP initialization crashes entirely.
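
One possible shape of the fix, binding `tensor` before the branch; `_pad_for_alignment` is a placeholder for whatever the padding path at line 467 does:

    tensor = param.data  # bind unconditionally
    if ETP_CONFIG.pad_for_alignment > 0:
        tensor = _pad_for_alignment(tensor)  # hypothetical padding helper
    shard_size = tensor.shape[0] // etp_group.size()
    shard = tensor[etp_rank * shard_size: (etp_rank + 1) * shard_size]
    etp_shard = ETPShardedParam(shard.clone())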

Comment on lines +1638 to +1645
if not skip_rs:
    param._wait_reduce_scatter(finalize_grad=finalize_after_drain)
    if finalize_after_drain and not getattr(param, '_already_finalized', False):
        cache = get_global_ETP_cache()
        param.rs_event.wait()
        for w in param._weights:
            ETPShardedParam._finalize_wgrad(w, cache.get(w._rs_ticket))
            cache.release(w._rs_ticket)

P1 ETPShardedParam._finalize_wgrad doesn't exist

wait_async_comms(finalize_after_drain=True) reaches the fallback branch on line 1644 when param._wgrad_rs_handle is None (i.e., the param had no in-flight async RS). In that case _wait_reduce_scatter never sets _already_finalized = True, so the if finalize_after_drain and not getattr(param, '_already_finalized', False): block executes and calls ETPShardedParam._finalize_wgrad(...), which is not defined anywhere in the class — causing AttributeError at runtime. The docstring says "Falls back to caller-stream _finalize_wgrad", confirming this is a real planned path, not dead code.
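
A guess at the missing method's contract, pieced together from this call site and the earlier commit message ("_finalize_wgrad now calls the hook after RS wait + main_grad.add_()"); this is not the author's implementation:

    @staticmethod
    def _finalize_wgrad(weight, rs_out):
        # Accumulate the reduce-scattered wgrad, then fire the stored DDP
        # hook so register_grad_ready lands at the right point.
        weight.main_grad.add_(rs_out)
        hook = getattr(weight, "_grad_accum_hook", None)
        if hook is not None:
            hook()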

Comment on lines +1256 to +1265
# TODO
# # Fix the interleaved transposed data from gathering along first dim.
# out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size)
# out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size)
out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size))
out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size))

# Optionally pad the scaling inverse if needed.
out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv)
# # Optionally pad the scaling inverse if needed.
# out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv)
out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv))

P2 Left-in debug comments obscure the intent of the change. The original assignment pattern was replaced with copy_() to avoid rebinding the tensor reference (needed for CUDA graph capture), but the commented-out code and # TODO are leftover noise.

Suggested change
# TODO
# # Fix the interleaved transposed data from gathering along first dim.
# out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size)
# out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size)
out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size))
out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size))
# Optionally pad the scaling inverse if needed.
out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv)
# # Optionally pad the scaling inverse if needed.
# out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv)
out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv))
# Fix the interleaved transposed data from gathering along first dim.
out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size))
out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size))
# Optionally pad the scaling inverse if needed.
out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv))

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines 1415 to +1418
# Transfer amax to output.
out._amax_rowwise = inp._amax_rowwise
#TODO: jiemingz
# out._amax_rowwise = inp._amax_rowwise
out._amax_rowwise.copy_(inp._amax_rowwise)

P2 The #TODO: jiemingz comment and the commented-out code are left-in debug markers. The assignment was intentionally switched to copy_() to keep the output buffer's identity stable for CUDA graph capture; the comment should explain that intent instead.

Suggested change
# Transfer amax to output.
out._amax_rowwise = inp._amax_rowwise
#TODO: jiemingz
# out._amax_rowwise = inp._amax_rowwise
out._amax_rowwise.copy_(inp._amax_rowwise)
# Transfer amax to output. Use copy_() to preserve buffer identity
# (required for CUDA graph capture).
out._amax_rowwise.copy_(inp._amax_rowwise)

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines 6 to 9
from typing import Union, Optional, Callable, Tuple, List
from itertools import chain
import traceback
import warnings

P2 traceback is imported but not used anywhere in this file.

Suggested change
from typing import Union, Optional, Callable, Tuple, List
from itertools import chain
import traceback
import warnings
from typing import Union, Optional, Callable, Tuple, List
from itertools import chain
import warnings

Comment on lines +1433 to +1468
@staticmethod
def _buf_bytes(shape, dtype) -> int:
    """Estimate buffer size in bytes."""
    numel = 1
    for d in shape:
        numel *= d
    bpe = ETPWeightCache._BYTES_PER_ELEMENT.get(dtype, None)
    return numel * bpe

def _allocate_buffer(self, param: 'ETPShardedParam', dtype, reduce_scatter, fwd) -> torch.Tensor:
    if reduce_scatter:
        out_shape = param._sharded_padded_shape
    else:
        out_shape = param._unsharded_shape_padded

    if not isinstance(dtype, torch.dtype):
        quantizer = param._quantizer
        assert quantizer is not None
        param._quantizer.set_usage(rowwise=fwd, columnwise=not fwd)

        buf = param._quantizer.make_empty(
            out_shape,
            dtype=torch.bfloat16,
            device=torch.cuda.current_device(),
        )
    else:
        buf = torch.empty(
            out_shape, dtype=dtype, device=param.device, memory_format=torch.contiguous_format
        )

    buf_bytes = self._buf_bytes(out_shape, dtype)
    self._total_bytes += buf_bytes
    print_rank_0(
        f"[ETP Cache] +{buf_bytes / 1024**2:.1f} MB (shape={out_shape}, dtype={dtype}) "
        f"total={self._total_bytes / 1024**2:.1f} MB id: {id(buf)} fwd: {fwd}"
    )

P2 _buf_bytes silently returns None for unrecognized dtypes, crashing on arithmetic

_BYTES_PER_ELEMENT only maps torch.bfloat16, torch.float16, torch.float32, tex.DType.kFloat4E2M1, and tex.DType.kFloat8E4M3. If a weight uses tex.DType.kFloat8E5M2 (or any other unregistered type), _buf_bytes returns None, and self._total_bytes += None on line 1464 raises TypeError, crashing every buffer allocation for that ETP param.
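
One way to fail at the lookup instead of the later arithmetic, keeping the method otherwise as-is:

    @staticmethod
    def _buf_bytes(shape, dtype) -> int:
        """Estimate buffer size in bytes, failing loudly on unmapped dtypes."""
        numel = 1
        for d in shape:
            numel *= d
        bpe = ETPWeightCache._BYTES_PER_ELEMENT.get(dtype)
        if bpe is None:
            raise KeyError(f"_BYTES_PER_ELEMENT has no entry for dtype {dtype!r}")
        return numel * bpe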

Comment on lines +1463 to +1469
buf_bytes = self._buf_bytes(out_shape, dtype)
self._total_bytes += buf_bytes
print_rank_0(
    f"[ETP Cache] +{buf_bytes / 1024**2:.1f} MB (shape={out_shape}, dtype={dtype}) "
    f"total={self._total_bytes / 1024**2:.1f} MB id: {id(buf)} fwd: {fwd}"
)
return buf

P2 Unconditional rank-0 logging on every buffer allocation

_allocate_buffer always calls print_rank_0(...) regardless of any debug flag. In a production run with many ETP params this produces a burst of unfiltered output to stdout on every model init, including during CUDA graph capture warm-up. Consider gating this behind ETP_CONFIG.debug_numerics > 0 or a dedicated verbose flag, consistent with the rest of the ETP debug paths.
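
For example, gating the log as suggested (assuming ETP_CONFIG.debug_numerics is the appropriate flag):

    if ETP_CONFIG.debug_numerics > 0:
        print_rank_0(
            f"[ETP Cache] +{buf_bytes / 1024**2:.1f} MB (shape={out_shape}, dtype={dtype}) "
            f"total={self._total_bytes / 1024**2:.1f} MB id: {id(buf)} fwd: {fwd}"
        )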
