
Implement row-scaled NVFP4 fprop recipe #2931

Open
zianglih wants to merge 43 commits into NVIDIA:main from zianglih:fp4-per-token

Conversation

@zianglih
Contributor

@zianglih zianglih commented Apr 27, 2026

Description

Implement a per-token row-scaled NVFP4 recipe, fprop only.
Currently, the per-row scaling is handled by separate PyTorch code after the GEMM.
The quantization kernels are bitwise exact with the existing TE reference implementation.

The following tests passed on B200:

python3 -m pytest --tb=auto tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
python3 -m pytest --tb=auto tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
python3 -m pytest --tb=auto tests/pytorch/test_backward_override.py
python3 -m pytest --tb=auto tests/pytorch/test_sanity.py
python3 -m pytest --tb=auto tests/pytorch/test_recipe.py
python3 -m pytest --tb=auto tests/pytorch/test_torch_compile.py
python3 -m pytest --tb=auto tests/pytorch/test_cpu_offloading.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto tests/pytorch/test_cuda_graphs.py
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=tests/pytorch/debug/test_configs/dummy_feature.yaml NVTE_TEST_NVINSPECT_FEATURE_DIRS=transformer_engine/debug/features PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto tests/pytorch/test_sanity.py

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add a row_scaled_activation field to the NVFP4 recipe, which can also be enabled via the NVTE_NVFP4_ROW_SCALED_ACTIVATION environment variable
  • New per-token NVFP4 quantize kernels in transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh, bitwise exact with the existing TE PyTorch reference implementation and the per-tensor NVFP4 emulated implementation; the new quantization kernels are folded into the existing NVFP4 quantization kernels (see the sketch after this list)
  • Expand the dequant kernel transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh to correctly handle row-scaled NVFP4
  • In transformer_engine/pytorch/cpp_extensions/gemm.py, if row-scaled NVFP4 is enabled, the per-token scaling is applied by separate PyTorch code after the cuBLAS GEMM
  • Broad test coverage by expanding 7 Python and 2 C++ test files
  • Modify the 1D quant reference implementation in tests/cpp/operator/test_cast_nvfp4_transpose.cu to align with the PyTorch reference numerics
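
For intuition, a minimal PyTorch sketch of the per-row amax step the new kernels add (illustrative only, not the kernel code; the zero-amax replacement mirrors the reference quantizer behavior described in the review below):

import torch

def rowwise_amax(x: torch.Tensor) -> torch.Tensor:
    # One FP32 abs-max per row (token), replacing the single global amax.
    amax = x.abs().amax(dim=-1, keepdim=True).float()
    # Zero rows would yield zero encode scales; replace with 1.0 so the
    # scale stays finite.
    return torch.where(amax == 0, torch.ones_like(amax), amax)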

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@zianglih zianglih marked this pull request as draft April 27, 2026 06:24
@greptile-apps
Contributor

greptile-apps Bot commented Apr 27, 2026

Greptile Summary

This PR adds a per-token (row-scaled) NVFP4 forward quantization recipe for TransformerEngine. When enabled via NVFP4BlockScaling(row_scaled_activation=True) or the NVTE_NVFP4_ROW_SCALED_ACTIVATION env var, forward activation quantizers store one FP32 amax per tensor row instead of a single global scalar; the row-scaled GEMM path then applies those per-token scales in FP32 after the cuBLAS NVFP4 GEMM.
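
A usage sketch (assuming the recipe is consumed through the usual te.fp8_autocast context; the module and sizes are illustrative):

import os
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import NVFP4BlockScaling

# Enable via the recipe flag...
recipe = NVFP4BlockScaling(row_scaled_activation=True)
# ...or globally via the env-var gate this PR adds:
# os.environ["NVTE_NVFP4_ROW_SCALED_ACTIVATION"] = "1"

layer = te.Linear(1024, 1024).cuda()
x = torch.randn(32, 1024, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)  # fprop quantizes activations with per-row amaxes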

  • New compute_rowwise_amax_kernel computes per-row absolute maxima on the GPU; the existing quantize_transpose kernel gains a ROW_SCALED_NVFP4 template parameter that reads these row amaxes in place of the global amax during quantization.
  • general_gemm and general_grouped_gemm detect row-scaled B tensors, neutralise the global amax before passing data to cuBLAS, and multiply the FP32 output by the per-row scales afterward; the feature is guarded to fprop-only with an explicit RuntimeError.
  • All storage classes (NVFP4TensorStorage, GroupedTensorStorage, NVFP4Quantizer, C++ TensorWrapper) are extended with a row_scaled_nvfp4 flag that propagates through allocation, cloning, serialisation, and C++/Python bindings.
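
A minimal PyTorch sketch of the post-GEMM rescaling step (illustrative; the real logic lives inside general_gemm, and the helper name here is hypothetical):

import torch

def apply_row_scales(out_fp32, rowwise_global_scales, bias=None,
                     out_dtype=torch.bfloat16):
    # cuBLAS ran the NVFP4 GEMM with the global amax neutralised to 1.0,
    # producing an FP32 output; apply the captured per-row scales in FP32.
    out_fp32 = out_fp32 * rowwise_global_scales.view(-1, 1)  # [M, N] * [M, 1]
    if bias is not None:
        out_fp32 = out_fp32 + bias.float()  # bias added in FP32
    return out_fp32.to(out_dtype)  # cast to the requested output dtype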

Confidence Score: 4/5

Safe to merge for the fprop-only use case that was tested; the unsupported backward path is blocked by a RuntimeError. The remaining bare assert guards can be silently disabled under -O, but they protect corners not reachable in the tested training configurations.

The row-scaled quantization logic — two-step kernel flow, per-row amax bookkeeping, neutralising global amax before cuBLAS, and post-GEMM FP32 rescaling — is mathematically correct and verified by bitwise-exact tests on B200. The main unresolved structural issue is that several assert checks survived the conversion of the grad guard to RuntimeError, meaning optimised Python builds would silently enter undefined states for currently unsupported combinations. The recipe class also still lacks a constructor guard that would catch incompatible backward_override settings at object creation time.

transformer_engine/pytorch/cpp_extensions/gemm.py — the cluster of assert statements in the row-scaled branch; transformer_engine/common/recipe/__init__.py — missing __post_init__ validation for the backward_override requirement.
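
To make the assert concern concrete, a toy guard (names hypothetical): bare assert statements are stripped when Python runs with -O, while an explicit raise survives, which is why the grad guard's RuntimeError is the safer pattern:

def gemm_guard_demo(gelu: bool = False) -> None:
    # Stripped entirely under `python -O`; the unsupported path would
    # then continue silently into undefined behaviour:
    assert not gelu, "GELU fusion unsupported with row-scaled NVFP4"
    # Survives -O; this is the pattern the PR's grad guard uses:
    if gelu:
        raise RuntimeError("GELU fusion unsupported with row-scaled NVFP4")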

Important Files Changed

Filename Overview
transformer_engine/pytorch/cpp_extensions/gemm.py Adds the row-scaled NVFP4 path to both general_gemm and general_grouped_gemm. The grad guard correctly uses RuntimeError, but eight other guards (GELU, accumulate, CommOverlap, etc.) still use bare assert and can be silently disabled.
transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh Adds compute_rowwise_amax_kernel and the ROW_SCALED_NVFP4 template branch inside quantize_transpose_nvfp4_kernel. The two-step flow (row-amax first, quantize second on the same stream) is correct; bounds checks and noop handling are consistent with existing patterns.
transformer_engine/common/recipe/__init__.py Adds the row_scaled_activation field and env-var gate. No __post_init__ validation that row_scaled_activation=True requires a non-default backward_override; users who omit the override will hit an AssertionError deep in the first backward pass (flagged in a previous review thread; see the sketch after this table).
transformer_engine/pytorch/csrc/quantizer.cpp Propagates row_scaled_nvfp4 through create_tensor, create_grouped_tensor, convert_and_update_tensor, and quantize_impl. Amax buffer sizing and shape registration are updated consistently; set_amax now uses dynamic shape rather than hard-coded {1}.
transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py Cleanly adds _row_scaled_nvfp4 field, propagates it through __new__, copy_from_storage, get_metadata, and make_like. Mismatch check in copy_from_storage prevents silent data corruption.
transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh Extends dequantize_fp4_kernel with a row_scaled_nvfp4 flag that indexes tensor_amax[y] instead of tensor_amax[0]. The pre-existing if (y >= N) return; guard prevents out-of-bounds reads; a pre-launch NVTE_CHECK validates amax size.
transformer_engine/pytorch/quantization.py Forwards row_scaled_nvfp4 to forward quantizers via the idx % 3 != 1 positional heuristic; backward quantizers explicitly receive False. The positional heuristic is fragile (noted in a prior review thread); the backward path is safe.
transformer_engine/pytorch/tensor/nvfp4_tensor.py Adds row_scaled_nvfp4 to NVFP4Quantizer and wires it into tensor allocation and cloning. is_quantizable returns False for row-scaled quantizers without an explanatory comment, which may confuse future maintainers.
transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py Updates the reference quantizer to support row-scaled amaxes, including correct per-row amax computation, zero-amax replacement, and gemm_ref output-scale reshaping. Logic is consistent with the kernel-level implementation.
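
A sketch of the missing constructor guard flagged above for transformer_engine/common/recipe/__init__.py (simplified dataclass; the exact validity condition and field types are assumptions):

from dataclasses import dataclass

@dataclass
class NVFP4BlockScaling:  # simplified; the real recipe has more fields
    row_scaled_activation: bool = False
    backward_override: object = None  # stand-in for the real type

    def __post_init__(self):
        # Fail at construction instead of deep inside the first backward pass.
        if self.row_scaled_activation and self.backward_override is None:
            raise ValueError(
                "row_scaled_activation=True requires a non-default "
                "backward_override (backward is not implemented)"
            )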

Sequence Diagram

sequenceDiagram
    participant Mod as Module Forward
    participant Quant as NVFP4Quantizer (row_scaled=True)
    participant RowAmax as compute_rowwise_amax kernel
    participant QKernel as quantize_transpose kernel (ROW_SCALED_NVFP4)
    participant PyGEMM as general_gemm (Python)
    participant cuBLAS as cuBLAS NVFP4 GEMM

    Mod->>Quant: quantize(activation)
    Quant->>RowAmax: compute per-row abs-max to amax[0..M-1]
    RowAmax-->>Quant: amax array written to tensor
    Quant->>QKernel: quantize using amax[row_idx] per block
    QKernel-->>Quant: NVFP4 tensor (row_scaled_nvfp4=True)
    Quant-->>Mod: NVFP4Tensor with per-row amax metadata

    Mod->>PyGEMM: general_gemm(A=weight, B=activation_fp4)
    Note over PyGEMM: Detect _is_nvfp4_row_scaled_tensor(B)
    PyGEMM->>PyGEMM: Replace amax with 1.0 in A and B metadata, capture rowwise_global_scales
    PyGEMM->>cuBLAS: GEMM with amax=1.0, producing FP32 output
    cuBLAS-->>PyGEMM: out_fp32
    PyGEMM->>PyGEMM: out_fp32 *= rowwise_global_scales per row, add bias in FP32
    PyGEMM->>PyGEMM: cast to requested_out_dtype
    PyGEMM-->>Mod: scaled output

Reviews (9): last reviewed commit "Minor"

Comment thread transformer_engine/pytorch/cpp_extensions/gemm.py Outdated
// Compute "correct" per-block encoding scaling factor
// Before:
const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32;
// After (clamped to stay aligned with the PyTorch reference numerics):
const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f
    ? 0.f
    : fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits<float>::maxNorm);
Contributor Author

We have to change this to stay aligned with the PyTorch reference.

@zianglih zianglih marked this pull request as ready for review April 27, 2026 09:14
@zianglih zianglih marked this pull request as draft May 2, 2026 18:22
zianglih and others added 14 commits May 2, 2026 11:27
Signed-off-by: Ziang Li <[email protected]>
@ziang-and ziang-and force-pushed the fp4-per-token branch 2 times, most recently from 6998f64 to 5b2f606 Compare May 2, 2026 19:10
zianglih added 5 commits May 2, 2026 16:33
Signed-off-by: Ziang Li <[email protected]>
@zianglih
Contributor Author

zianglih commented May 2, 2026

The following extended tests all passed:

python3 -m pytest --tb=auto tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
python3 -m pytest --tb=auto tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
python3 -m pytest --tb=auto tests/pytorch/test_backward_override.py
python3 -m pytest --tb=auto tests/pytorch/test_sanity.py
python3 -m pytest --tb=auto tests/pytorch/test_recipe.py
python3 -m pytest --tb=auto tests/pytorch/test_torch_compile.py
python3 -m pytest --tb=auto tests/pytorch/test_cpu_offloading.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto tests/pytorch/test_cuda_graphs.py
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=tests/pytorch/debug/test_configs/dummy_feature.yaml NVTE_TEST_NVINSPECT_FEATURE_DIRS=transformer_engine/debug/features PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto tests/pytorch/test_sanity.py

cd /root/TransformerEngine/tests/cpp
cmake --build build -j200
TEST_BIN="$(find build -type f -name test_operator -perm -u+x | head -n 1)"
"$TEST_BIN" --gtest_filter='*FusedCastTransposeNVFP4*:*DequantizeNVFP4*'

zianglih added 6 commits May 4, 2026 23:42
Signed-off-by: Ziang Li <[email protected]>
@zianglih zianglih marked this pull request as ready for review May 5, 2026 08:12
@zianglih
Contributor Author

zianglih commented May 5, 2026

Hi @ptrendx , I have removed the standalone implementation and extended the existing kernels to support this per-token NVFP4 recipe.

with_post_rht_amax=qparams.random_hadamard_transform,
with_2d_quantization=qparams.fp4_2d_quantization,
stochastic_rounding=qparams.stochastic_rounding,
per_token_activation=self.recipe.per_token_activation and idx % 3 != 1,
Contributor

P1: The hardcoded idx % 3 != 1 pattern silently misassigns per-token scaling to the wrong quantizers

The expression self.recipe.per_token_activation and idx % 3 != 1 assumes that every third quantizer (index 1, 4, 7, …) is always a weight quantizer and should be skipped for per-token scaling. This works for standard Linear / LayerNormLinear layers but is not documented and will silently produce wrong results if the quantizer ordering changes (e.g., MoE layers, attention layers with additional quantizers, or future refactors). The intent should either be codified as a named constant or enforced by tagging each quantizer by its semantic role (activation vs. weight) rather than position.
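
A sketch of the role-tagging alternative the comment suggests (enum and wiring are hypothetical):

from enum import Enum, auto

class QuantizerRole(Enum):
    ACTIVATION = auto()
    WEIGHT = auto()
    GRAD = auto()

def wants_per_token(role: QuantizerRole, per_token_activation: bool) -> bool:
    # Explicit role check instead of the positional idx % 3 != 1 heuristic,
    # so layers with different quantizer orderings stay correct.
    return per_token_activation and role is QuantizerRole.ACTIVATION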

Collaborator

We should be able to get rid of this hack once #2620 is merged.

Signed-off-by: Ziang Li <[email protected]>
@zianglih
Contributor Author

zianglih commented May 5, 2026

The functionality has been verified by an NVFP4 RL experiment.

Collaborator

@timmoon10 timmoon10 left a comment


Overall this is a nice feature, but we should make some changes to the core design. My biggest suggestions:

  • Row-scaling should be part of the tensor data and not just hidden in the quantizer. You need to be aware of it for both quantization and dequantization, so it can't live only in the quantizer (see the sketch after this list).
  • We should enable row-scaling based on a bool flag rather than the amax tensor shape. We should also make sure it is clearly documented.
  • We should consider a better name like "1D scaling" or "row scaling" since I don't see any reason this is specific to tokens or activations.
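
A minimal sketch of the flag-on-the-tensor shape of this suggestion (names hypothetical; the PR ultimately adds a row_scaled_nvfp4 flag to the storage classes):

from dataclasses import dataclass
import torch

@dataclass
class NVFP4TensorStorageSketch:
    data: torch.Tensor              # packed FP4 payload
    amax: torch.Tensor              # [1] for tensor scaling, [M] for row scaling
    row_scaled_nvfp4: bool = False  # explicit flag: consumers (dequant, GEMM)
                                    # branch on this, never on the amax shape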

Comment thread transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh
Comment thread transformer_engine/common/include/transformer_engine/transformer_engine.h Outdated
Comment on lines +373 to +374
/*! Whether to enable per-token (per-row) NVFP4 quantization */
kNVTEQuantizationConfigNVFP4PerTokenActivation = 8,
Collaborator

We should configure this in NVTETensor rather than in NVTEQuantizationConfig. The quantization config is used for quantization, and is not available to any downstream consumers (dequant or GEMM). However, consumers need to be aware of tensor-scaling vs row-scaling. The buffer sizes are different, and getting it wrong means incorrect values or segfaults.

Collaborator

Dequantization handles row-scaled NVFP4, but we need to make sure that other quantized tensor consumers also handle it. Erroring out is fine for now. Currently the only other consumers we need to handle are GEMM and attention, although we should keep this mind if we add more features in the future.

Comment thread transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh Outdated
Comment thread transformer_engine/pytorch/cpp_extensions/gemm.py Outdated
Comment thread transformer_engine/pytorch/csrc/extensions/cast.cpp Outdated
Comment thread transformer_engine/pytorch/csrc/extensions/cast.cpp Outdated
Comment thread tests/pytorch/test_recipe.py Outdated


@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4)
def test_nvfp4_per_token_quantizer_roles():
Collaborator

This test will need to be updated once #2620 merges.

Comment thread transformer_engine/pytorch/tensor/nvfp4_tensor.py
@zianglih zianglih marked this pull request as draft May 5, 2026 23:07
zianglih added 13 commits May 5, 2026 17:10
Signed-off-by: Ziang Li <[email protected]>
@zianglih zianglih marked this pull request as ready for review May 6, 2026 06:08
@zianglih
Contributor Author

zianglih commented May 6, 2026

Hi @timmoon10 , thank you for your comments! I have refactored accordingly and renamed this to row_scaled_nvfp4, dropping "per-token". In the recipe we still have row_scaled_activation, which indicates a 1D×2D GEMM where only the activation is switched to 1D (row) scaling. The control flow is now based on bool flags instead of being implicitly inferred from shapes.

@zianglih zianglih changed the title Implement per-token NVFP4 fprop recipe Implement row scaled NVFP4 fprop recipe May 6, 2026
@zianglih zianglih changed the title Implement row scaled NVFP4 fprop recipe Implement row-scaled NVFP4 fprop recipe May 6, 2026

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.
