Draft: Extended Tensor Parallelism #2960
Conversation
Signed-off-by: Jieming Zhang <[email protected]>
… avoid allocating fresh memory each iteration.
…G allocations
1. Release NCCL Work C++ tensor refs promptly in ETPShardHandle.wait()
(self.handle = None) so wgrad buffers are freed when RS is waited,
not held until optimizer.step().
2. Use cache buffers for sync all-gather path (not just async). The old
code passed out_buffers=None for sync gathers, allocating ~22 GB/iter
of fresh tensors. Sync gathers now reuse the same ETPWeightCache
buffers as the async prefetch path.
3. Add standalone wgrad input buffer pool (_wgrad_buf_pool) for expert
chain. get_wgrad_tensor() draws from the pool; buffers are returned
after RS is waited via _wgrad_input_bufs stash in _wait_reduce_scatter.
Reduces expert wgrad peak from ~4 GB (held until optimizer) to ~640 MB
(16 buffers reused across all MoE layers).
4. Stash _wgrad_input_bufs for all chains (not just expert) so ungraphed
dense weights (output layer) also drop Python refs at _wait_reduce_scatter
instead of surviving until calc_params_l2_norm.
5. Fix tensor comparison crash in cache.release(): use an identity check
   (any(b is slot.buf ...)) instead of tensor ==, which returns a
   multi-element bool tensor (see the sketch below).
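A minimal sketch of the identity-based membership test from item 5; the slot container and attribute names are illustrative, not the cache's actual internals.

```python
import torch

def buffer_in_slots(buf: torch.Tensor, slots) -> bool:
    # `buf == slot.buf` builds an element-wise bool tensor; using that in a
    # boolean context raises "Boolean value of Tensor with more than one
    # element is ambiguous". Comparing object identity avoids the crash and
    # skips an unnecessary elementwise kernel.
    return any(buf is slot.buf for slot in slots)
```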
…backward

For nvfp4, batched_all_gather_and_prefetch_bwd() returns NVFP4TensorStorage objects with internal sub-tensors (columnwise data/scale_inv). The local `weights` variable kept them alive until function return, wasting memory through the wgrad phase. Delete `weights` immediately after the dgrad GEMM (the last consumer), saving weight_sizes for the fuse_wgrad_accumulation=False fallback path.
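A hedged sketch of that ordering; `gather_bwd`, `dgrad_gemm`, and `wgrad_gemm` are placeholder callables, not the real Transformer Engine entry points.

```python
def backward_dgrad_then_wgrad(grad_output, inputs, params,
                              gather_bwd, dgrad_gemm, wgrad_gemm):
    """Sketch of the early-release ordering described in the commit above."""
    weights = gather_bwd(params)                # NVFP4 storages holding columnwise data/scale_inv
    dgrad = dgrad_gemm(grad_output, weights)    # the dgrad GEMM is the last consumer of `weights`
    weight_sizes = [w.size() for w in weights]  # keep only cheap metadata for the
                                                # fuse_wgrad_accumulation=False fallback path
    del weights                                 # drop the Python ref so the sub-tensors can be freed
    wgrad = wgrad_gemm(inputs, grad_output)     # wgrad phase no longer holds the gathered weights
    return dgrad, wgrad, weight_sizes
```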
…TP AG on IB in expert backward

Call ag_stream.wait_stream(main_stream) before batched_all_gather_and_prefetch_bwd in grouped_linear.py backward. With --overlap-grad-reduce + CG, DDP backward hooks fire a reduce-scatter (IB, main_stream) that races with the EETP all-gather (IB, ag_stream), causing NCCL deadlock at 64+ GPU IB scale.
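A small sketch of the ordering fix, assuming plain torch.cuda stream APIs; `gather_fn` is a placeholder for batched_all_gather_and_prefetch_bwd.

```python
import torch

def issue_backward_all_gather(ag_stream: torch.cuda.Stream, gather_fn, params):
    main_stream = torch.cuda.current_stream()
    # Make the all-gather queued on ag_stream start only after work already
    # queued on the main stream (including the DDP reduce-scatter launched by
    # backward hooks), so the two IB collectives cannot interleave and
    # deadlock inside NCCL.
    ag_stream.wait_stream(main_stream)
    with torch.cuda.stream(ag_stream):
        return gather_fn(params)
```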
…ze_wgrad DDP hook trigger

Re-add register_grad_accum_hook() to store the DDP backward hook on ETP params. _finalize_wgrad now calls the hook after RS wait + main_grad.add_(), firing DDP register_grad_ready at the correct serialization point. This replaces the previous approach of skipping DDP hooks entirely for ETP params. param.grad = dummy_grad is a plain Python attribute set (it does NOT trigger autograd's grad accumulator); the explicit _grad_accum_hook() call is required.
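A hedged sketch of the hook storage and replay; register_grad_accum_hook and _grad_accum_hook follow the commit text, while the function bodies and call site are simplified assumptions.

```python
def register_grad_accum_hook(param, ddp_hook):
    # Stash DDP's backward hook on the ETP param instead of letting autograd
    # fire it; autograd never sees the real wgrad for ETP params.
    param._grad_accum_hook = ddp_hook

def finalize_wgrad(param, rs_result, dummy_grad):
    # Sketch of the serialization point: accumulate the reduce-scattered
    # wgrad, then fire DDP's register_grad_ready by hand. Setting param.grad
    # is a plain attribute write and does NOT trigger the autograd grad
    # accumulator, so the explicit hook call is required.
    param.main_grad.add_(rs_result)
    param.grad = dummy_grad
    param._grad_accum_hook()
```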
to drop Python refs to wgrad input buffers immediately. The async RS
still holds C++ refs via NCCL Work until _wait_reduce_scatter. Reduces
peak memory during graph capture warmup (~320 MB per MoE layer).
This reverts commit e09983c.
Three related fixes exposed by adding embedding / output_layer to the
UNGRAPHED prefetch chain:
1. iter-2 NaN (consumer-side race on AG output buffer). The async AG
prefetch was issued from main_stream; NCCL's caller-stream preEvent
queued behind pending CG/compute and NCCL started late, leaving the
consumer GEMM reading a partially-written buffer.
Fix: wrap the async issue in `with torch.cuda.stream(target.ag_stream):`
in both all_gather_and_prefetch (fwd) and all_gather_and_prefetch_bwd
(bwd). Added a state guard in _get_prefetched_weight that asserts an
AG was issued for this consume cycle — catches silent stale-cache
reads from misconfigured _need_weight_prefetch flags.
2. Unbounded _wgrad_buf_pool growth. _wait_reduce_scatter pushed the
wgrad input buffer into the pool unconditionally, but callers that
don't acquire via _wgrad_pool_get (Megatron layers.py wgrad GEMM,
aten F.embedding backward) never popped — every iter leaked N fresh
buffers into the pool.
Fix: tag pool-owned buffers at _wgrad_pool_get; _wgrad_pool_put
no-ops on foreign buffers, letting the caching allocator recycle them
(see the sketch below).
Side effect: throughput 80 → 580 TFLOPs/GPU (pool thrash eliminated).
3. ag/rs streams partitioned by (chain_id, NCCL group). UNGRAPHED chain
can span multiple communicators (ETP vs EETP); sharing a single
user-level stream forced cross-group NCCL ops to serialize. Stream
dicts are now keyed on (chain_id, id(group)); adds
get_{ag,rs}_streams_for_chain() helpers.
Signed-off-by: Jieming Zhang <[email protected]>
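A hypothetical sketch of the pool-ownership tagging from fix 2 of the commit above; the class, tag attribute, and matching logic are illustrative stand-ins for _wgrad_buf_pool / _wgrad_pool_get / _wgrad_pool_put.

```python
import torch

class WgradBufferPool:
    """Sketch of a pool that only reclaims buffers it handed out itself."""

    def __init__(self):
        self._pool = []

    def get(self, shape, dtype, device):
        # Reuse a matching pooled buffer when one exists; otherwise allocate
        # a fresh one and tag it as pool-owned.
        for i, buf in enumerate(self._pool):
            if tuple(buf.shape) == tuple(shape) and buf.dtype == dtype:
                return self._pool.pop(i)
        buf = torch.empty(shape, dtype=dtype, device=device)
        buf._etp_pool_owned = True  # assumed tag name, set only by this pool
        return buf

    def put(self, buf):
        # No-op for foreign buffers (e.g. tensors produced by Megatron's wgrad
        # GEMM or aten F.embedding backward); the caching allocator recycles
        # those on its own, so the pool cannot grow without bound.
        if getattr(buf, "_etp_pool_owned", False):
            self._pool.append(buf)
```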
…empool — but only for params on the GRAPHED chain. UNGRAPHED-chain params (embedding, output_layer, and MoE paths whose scope is not captured) run eagerly and don't need their quantized storage in the CG mempool. (2) _ETP_PARAMS already contains every individual expert (appended per weight_name in wrap_module_params_etp), so iterate it directly — no weight_list unroll needed.
Replaces the per-expert (zero_amax + amax + D2D amax replicate) chain in the ETP coalesced-amax path with a pair of multi-tensor kernel launches. The compute kernel writes rowwise and columnwise amax directly (atomicMaxFloat), eliminating the per-expert D2D copy.
Changes:
- Lazy-cache ag_stream / rs_stream on self (resolved once from
chain_id + group; the prior path hit a dict lookup every call); see the
sketch after this list.
- Cache quantizers / dtypes / etp_group on the anchor weight (previously
rebuilt via list comprehensions on every _all_gather_weight call).
- Consolidate _multi_amax_quantizer_list into _cached_quantizers
(single cache shared between Tier-2 amax and _all_gather_weight).
- Gate the duplicate-output-buffer assertion in batched AG behind
ETP_CONFIG.check_param_states (was running O(N) per call).
- Drop a dead `out_buffers is not None` check (it is always a list).
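A minimal sketch of the lazy-cache pattern behind the first bullet; the attribute names are illustrative, and the get_{ag,rs}_streams_for_chain() helpers are the ones introduced earlier in this PR.

```python
def _get_comm_streams(self):
    # Resolve the AG/RS streams once from (chain_id, group) and keep them on
    # self; later calls skip the per-call dict lookup entirely.
    if getattr(self, "_cached_streams", None) is None:
        self._cached_streams = (
            get_ag_streams_for_chain(self.chain_id, self.etp_group),
            get_rs_streams_for_chain(self.chain_id, self.etp_group),
        )
    return self._cached_streams
```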
…c86fd

The `with torch.cuda.stream(target.ag_stream):` wrapper re-routed NCCL's preEvent onto an idle stream, so the AG raced the caller-stream writer (quantize / sharded-weight update). Issue the AG on the caller's stream instead; _wait_param_gather keeps ag_stream. Verified: 5000+ steps clean on TP2ETP2_EP2EETP2 nvfp4, 1/4 Ultra, 32x GB200.
…treams with explicit producer event
Move gradient accumulation from caller stream to rs_stream inside _wait_reduce_scatter(finalize_grad=True). The add_ starts right after NCCL RS (concurrent with Phase 1 AG drain) instead of after it, avoiding SM-saturation that blocks cross-graph overlap.
Signed-off-by: Jieming Zhang <[email protected]>
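A hedged sketch of that scheduling, assuming a torch.distributed async work handle; the function name and signature are illustrative.

```python
import torch

def wait_rs_and_accumulate(rs_stream: torch.cuda.Stream, rs_handle,
                           main_grad: torch.Tensor, rs_result: torch.Tensor):
    # Sketch: finish the NCCL reduce-scatter and run the accumulation on
    # rs_stream, so the add_ overlaps the Phase 1 AG drain instead of
    # saturating SMs on the caller stream afterwards. The caller must still
    # synchronize with rs_stream (event or wait_stream) before reading
    # main_grad.
    with torch.cuda.stream(rs_stream):
        rs_handle.wait()            # blocks rs_stream until the async RS completes
        main_grad.add_(rs_result)
```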
Greptile Summary

This PR introduces Extended Tensor Parallelism (ETP), a new weight-sharding strategy that all-gathers weights on demand during the forward and backward passes instead of replicating them across ranks. It ships a new 1724-line ETP module.
Confidence Score: 2/5

Not safe to merge as-is: one code path crashes on init when ETP_CONFIG.pad_for_alignment == 0. Two separate runtime failures exist in the new ETP module.
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller as Training Loop
    participant ETP as ETPShardedParam
    participant Cache as ETPWeightCache
    participant NCCL as NCCL (AG Stream)
    participant Quant as NVFP4Quantizer
    participant GEMM

    Note over Caller,GEMM: Forward Pass (ETP weight gather + GEMM)
    Caller->>ETP: all_gather_and_prefetch(fwd=True)
    ETP->>Quant: compute_amax_only (per-expert, fused)
    ETP->>NCCL: all_reduce_coalesced(amax tensors)
    ETP->>Quant: quantize_cast_only (per-expert)
    ETP->>Cache: get(ag_ticket_fwd)
    ETP->>NCCL: grouped_gather_along_first_dim (async)
    ETP-->>Caller: gathered_weight plus ETPShardHandle
    Caller->>ETP: prefetch next_w async on ag_stream
    Caller->>GEMM: forward GEMM with input and gathered_weight

    Note over Caller,GEMM: Backward Pass (dgrad + wgrad + RS)
    Caller->>ETP: all_gather_and_prefetch_bwd()
    ETP->>Cache: get(ag_ticket_bwd)
    ETP->>NCCL: gather with skip_weight_cast=True async
    Caller->>GEMM: dgrad GEMM with grad_output and gathered_weight
    Caller->>GEMM: wgrad GEMM with input and grad_output
    Caller->>ETP: wgrad_reduce_scatter(wgrad)
    ETP->>NCCL: reduce_scatter async on rs_stream
    ETP->>ETP: main_grad.add_ rs_result deferred
```
Reviews (1): Last reviewed commit: "ETP: pad full tensor before sharding ins..."
```python
shard_size = tensor.shape[0] // etp_group.size()
shard = tensor[etp_rank * shard_size: (etp_rank + 1) * shard_size]
etp_shard = ETPShardedParam(shard.clone())
```
tensor is undefined in the else branch
tensor = param.data is only assigned inside the if ETP_CONFIG.pad_for_alignment > 0: block (line 467), but the else branch at line 478 references tensor without initializing it. When ETP_CONFIG.pad_for_alignment == 0 (e.g., after update_config(pad_for_alignment=0)), this raises NameError: name 'tensor' is not defined and ETP initialization crashes entirely.
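One possible shape of the fix, sketched as a fragment around the quoted lines; pad_tensor_for_alignment is a hypothetical stand-in for whatever padding helper the real branch calls.

```python
# Sketch of a possible fix: bind `tensor` before branching so the
# pad_for_alignment == 0 path has a defined source tensor.
tensor = param.data
if ETP_CONFIG.pad_for_alignment > 0:
    # Hypothetical helper name; the real padded branch does the equivalent.
    tensor = pad_tensor_for_alignment(tensor, ETP_CONFIG.pad_for_alignment)

shard_size = tensor.shape[0] // etp_group.size()
shard = tensor[etp_rank * shard_size: (etp_rank + 1) * shard_size]
etp_shard = ETPShardedParam(shard.clone())
```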
```python
if not skip_rs:
    param._wait_reduce_scatter(finalize_grad=finalize_after_drain)
if finalize_after_drain and not getattr(param, '_already_finalized', False):
    cache = get_global_ETP_cache()
    param.rs_event.wait()
    for w in param._weights:
        ETPShardedParam._finalize_wgrad(w, cache.get(w._rs_ticket))
        cache.release(w._rs_ticket)
```
ETPShardedParam._finalize_wgrad doesn't exist
wait_async_comms(finalize_after_drain=True) reaches the fallback branch on line 1644 when param._wgrad_rs_handle is None (i.e., the param had no in-flight async RS). In that case _wait_reduce_scatter never sets _already_finalized = True, so the if finalize_after_drain and not getattr(param, '_already_finalized', False): block executes and calls ETPShardedParam._finalize_wgrad(...), which is not defined anywhere in the class — causing AttributeError at runtime. The docstring says "Falls back to caller-stream _finalize_wgrad", confirming this is a real planned path, not dead code.
```python
# TODO
# # Fix the interleaved transposed data from gathering along first dim.
# out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size)
# out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size)
out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size))
out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size))

# Optionally pad the scaling inverse if needed.
out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv)
# # Optionally pad the scaling inverse if needed.
# out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv)
out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv))
```
Left-in debug comments obscure the intent of the change. The original assignment pattern was replaced with copy_() to avoid rebinding the tensor reference (needed for CUDA graph capture), but the commented-out code and # TODO are leftover noise.
Suggested change:

```python
# Fix the interleaved transposed data from gathering along first dim.
out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size))
out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size))
# Optionally pad the scaling inverse if needed.
out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv))
```
```python
# Transfer amax to output.
out._amax_rowwise = inp._amax_rowwise
#TODO: jiemingz
# out._amax_rowwise = inp._amax_rowwise
out._amax_rowwise.copy_(inp._amax_rowwise)
```
The #TODO: jiemingz comment and the commented-out code are left-in debug markers. The assignment was intentionally switched to copy_() to keep the output buffer's identity stable for CUDA graph capture; the comment should explain that intent instead.
Suggested change:

```python
# Transfer amax to output. Use copy_() to preserve buffer identity
# (required for CUDA graph capture).
out._amax_rowwise.copy_(inp._amax_rowwise)
```
```python
@staticmethod
def _buf_bytes(shape, dtype) -> int:
    """Estimate buffer size in bytes."""
    numel = 1
    for d in shape:
        numel *= d
    bpe = ETPWeightCache._BYTES_PER_ELEMENT.get(dtype, None)
    return numel * bpe

def _allocate_buffer(self, param: 'ETPShardedParam', dtype, reduce_scatter, fwd) -> torch.Tensor:
    if reduce_scatter:
        out_shape = param._sharded_padded_shape
    else:
        out_shape = param._unsharded_shape_padded

    if not isinstance(dtype, torch.dtype):
        quantizer = param._quantizer
        assert quantizer is not None
        param._quantizer.set_usage(rowwise=fwd, columnwise=not fwd)

        buf = param._quantizer.make_empty(
            out_shape,
            dtype=torch.bfloat16,
            device=torch.cuda.current_device(),
        )
    else:
        buf = torch.empty(
            out_shape, dtype=dtype, device=param.device, memory_format=torch.contiguous_format
        )

    buf_bytes = self._buf_bytes(out_shape, dtype)
    self._total_bytes += buf_bytes
    print_rank_0(
        f"[ETP Cache] +{buf_bytes / 1024**2:.1f} MB (shape={out_shape}, dtype={dtype}) "
        f"total={self._total_bytes / 1024**2:.1f} MB id: {id(buf)} fwd: {fwd}"
    )
```
_buf_bytes silently returns None for unrecognized dtypes, crashing on arithmetic
_BYTES_PER_ELEMENT only maps torch.bfloat16, torch.float16, torch.float32, tex.DType.kFloat4E2M1, and tex.DType.kFloat8E4M3. If a weight uses tex.DType.kFloat8E5M2 (or any other unregistered type), _buf_bytes returns None, and self._total_bytes += None on line 1464 raises TypeError, crashing every buffer allocation for that ETP param.
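One defensive variant of _buf_bytes, sketched on the assumption that plain torch dtypes can derive their element size directly and that unregistered quantized dtypes should fail loudly instead of returning None.

```python
@staticmethod
def _buf_bytes(shape, dtype) -> int:
    """Estimate buffer size in bytes (sketch of a defensive variant)."""
    numel = 1
    for d in shape:
        numel *= d
    if isinstance(dtype, torch.dtype):
        # Plain torch dtypes carry their own element size.
        return numel * torch.empty((), dtype=dtype).element_size()
    bpe = ETPWeightCache._BYTES_PER_ELEMENT.get(dtype)
    if bpe is None:
        # Fail with a clear message instead of returning None and crashing
        # later on `self._total_bytes += None`.
        raise ValueError(f"ETPWeightCache._buf_bytes: unregistered dtype {dtype}")
    return numel * bpe
```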
```python
buf_bytes = self._buf_bytes(out_shape, dtype)
self._total_bytes += buf_bytes
print_rank_0(
    f"[ETP Cache] +{buf_bytes / 1024**2:.1f} MB (shape={out_shape}, dtype={dtype}) "
    f"total={self._total_bytes / 1024**2:.1f} MB id: {id(buf)} fwd: {fwd}"
)
return buf
```
Unconditional rank-0 logging on every buffer allocation
_allocate_buffer always calls print_rank_0(...) regardless of any debug flag. In a production run with many ETP params this produces a burst of unfiltered output to stdout on every model init, including during CUDA graph capture warm-up. Consider gating this behind ETP_CONFIG.debug_numerics > 0 or a dedicated verbose flag, consistent with the rest of the ETP debug paths.
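A minimal sketch of the suggested gating; ETP_CONFIG.debug_numerics is the reviewer's proposed flag and may not be the final knob.

```python
buf_bytes = self._buf_bytes(out_shape, dtype)
self._total_bytes += buf_bytes
# Assumed gating flag per the suggestion above; a dedicated verbose flag
# would work equally well.
if ETP_CONFIG.debug_numerics > 0:
    print_rank_0(
        f"[ETP Cache] +{buf_bytes / 1024**2:.1f} MB (shape={out_shape}, dtype={dtype}) "
        f"total={self._total_bytes / 1024**2:.1f} MB id: {id(buf)} fwd: {fwd}"
    )
return buf
```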
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: