Implement row-scaled NVFP4 fprop recipe #2931
Conversation
Greptile Summary: This PR adds a per-token (row-scaled) NVFP4 forward quantization recipe for TransformerEngine. When enabled via `NVTE_NVFP4_ROW_SCALED_ACTIVATION`, activations are quantized using a per-row amax rather than a single per-tensor amax.
Confidence Score: 4/5 — Safe to merge for the fprop-only use case that was tested; the unsupported backward path is blocked. The row-scaled quantization logic — two-step kernel flow, per-row amax bookkeeping, neutralising the global amax before cuBLAS, and post-GEMM FP32 rescaling — is mathematically correct and verified by bitwise-exact tests on B200. The main unresolved structural issue is the positional quantizer-selection hack flagged in the review comments below.
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Mod as Module Forward
    participant Quant as NVFP4Quantizer (row_scaled=True)
    participant RowAmax as compute_rowwise_amax kernel
    participant QKernel as quantize_transpose kernel (ROW_SCALED_NVFP4)
    participant PyGEMM as general_gemm (Python)
    participant cuBLAS as cuBLAS NVFP4 GEMM
    Mod->>Quant: quantize(activation)
    Quant->>RowAmax: compute per-row abs-max into amax[0..M-1]
    RowAmax-->>Quant: amax array written to tensor
    Quant->>QKernel: quantize using amax[row_idx] per block
    QKernel-->>Quant: NVFP4 tensor (row_scaled_nvfp4=True)
    Quant-->>Mod: NVFP4Tensor with per-row amax metadata
    Mod->>PyGEMM: general_gemm(A=weight, B=activation_fp4)
    Note over PyGEMM: Detect _is_nvfp4_row_scaled_tensor(B)
    PyGEMM->>PyGEMM: Replace amax with 1.0 in A and B metadata, capture rowwise_global_scales
    PyGEMM->>cuBLAS: GEMM with amax=1.0, FP32 output
    cuBLAS-->>PyGEMM: out_fp32
    PyGEMM->>PyGEMM: out_fp32 *= rowwise_global_scales per row, add bias in FP32
    PyGEMM->>PyGEMM: cast to requested_out_dtype
    PyGEMM-->>Mod: scaled output
```
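The post-GEMM flow above can be sketched in plain NumPy. This is a hypothetical, simplified stand-in for the logic in `gemm.py`: `row_scaled_gemm` and its arguments are illustrative names, and the quantized operands are represented here as already-dequantized FP32 arrays.

```python
import numpy as np

def row_scaled_gemm(x, w, row_scales, bias=None, out_dtype=np.float32):
    """Emulate the row-scaled NVFP4 GEMM flow (illustrative sketch).

    x          : [M, K] activation (stand-in for the FP4 operand)
    w          : [N, K] weight (stand-in for the FP4 operand)
    row_scales : [M] per-row global scales captured before the GEMM
    """
    out_fp32 = x @ w.T                 # cuBLAS GEMM run with amax neutralised to 1.0
    out_fp32 = out_fp32 * row_scales[:, None]  # fold per-row global scales back in
    if bias is not None:
        out_fp32 = out_fp32 + bias     # bias added in FP32, after rescaling
    return out_fp32.astype(out_dtype)  # cast to the requested output dtype
```

The key point of the design is that cuBLAS never sees the per-row scales; they are applied as a cheap elementwise FP32 multiply on the output.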
Reviews (9). Last reviewed commit: "Minor".
```diff
 // Compute "correct" per-block encoding scaling factor
-const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32;
+const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f :
+    fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits<float>::maxNorm);
```
We have to change this here to stay aligned with the PyTorch reference.
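The numerical intent of the change can be sketched in Python (hypothetical helper names; `FP32_MAX` stands in for `Numeric_Traits<float>::maxNorm`). The reciprocal form matches the operation order of the PyTorch reference, and the clamp keeps the encoding scale finite when the per-block decode scale is very small.

```python
FP32_MAX = 3.4028234663852886e38  # largest finite FP32, analogue of maxNorm

def s_enc_b_old(s_enc, s_dec_b):
    # Previous form: direct division, can overflow to inf for tiny s_dec_b
    return 0.0 if s_dec_b == 0.0 else s_enc / s_dec_b

def s_enc_b_new(s_enc, s_dec_b):
    # New form: reciprocal ordering (as in the PyTorch reference),
    # clamped to the largest finite float so the scale never becomes inf
    if s_dec_b == 0.0:
        return 0.0
    return min(1.0 / (s_dec_b * (1.0 / s_enc)), FP32_MAX)
```

Mathematically both forms compute `S_enc / S_dec_b`; on device they round differently at FP32 precision, so the ordering matters for bitwise exactness.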
Signed-off-by: Ziang Li <[email protected]> Co-authored-by: Yigong Qin <[email protected]>
Force-pushed from 6998f64 to 5b2f606.
The following extended tests all passed:
Hi @ptrendx, I have removed the standalone implementation and extended the existing kernel to support this per-token NVFP4 recipe.
```python
with_post_rht_amax=qparams.random_hadamard_transform,
with_2d_quantization=qparams.fp4_2d_quantization,
stochastic_rounding=qparams.stochastic_rounding,
per_token_activation=self.recipe.per_token_activation and idx % 3 != 1,
```
Hardcoded `idx % 3 != 1` pattern silently misassigns per-token scaling to the wrong quantizers

The expression `self.recipe.per_token_activation and idx % 3 != 1` assumes that every third quantizer (index 1, 4, 7, …) is always a weight quantizer and should be skipped for per-token scaling. This works for standard Linear / LayerNormLinear layers, but it is not documented and will silently produce wrong results if the quantizer ordering changes (e.g., MoE layers, attention layers with additional quantizers, or future refactors). The intent should either be codified as a named constant or enforced by tagging each quantizer with its semantic role (activation vs. weight) rather than its position.
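The suggested role-based alternative could look like the following sketch. All names here (`QuantizerRole`, `wants_per_token`) are hypothetical, not TransformerEngine API; the point is that the decision keys off an explicit tag instead of the quantizer's position in a list.

```python
from enum import Enum

class QuantizerRole(Enum):
    ACTIVATION = "activation"
    WEIGHT = "weight"
    GRAD = "grad"

def wants_per_token(role, per_token_activation):
    # Hypothetical replacement for the positional `idx % 3 != 1` hack:
    # weight quantizers never opt into per-token scaling, regardless of
    # where they happen to sit in the quantizer list.
    return per_token_activation and role is not QuantizerRole.WEIGHT
```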
We should be able to get rid of this hack once #2620 is merged.
The functionality has been verified by an NVFP4 RL experiment.
timmoon10 left a comment:
Overall this is a nice feature, but we should make some changes to the core design. My biggest suggestions:
- Row-scaling should be part of the tensor data, not just hidden in the quantizer: both quantization and dequantization need to be aware of it, so it cannot live only in the quantizer.
- We should enable row-scaling based on a bool flag rather than the amax tensor shape. We should also make sure it is clearly documented.
- We should consider a better name like "1D scaling" or "row scaling" since I don't see any reason this is specific to tokens or activations.
```cpp
/*! Whether to enable per-token (per-row) NVFP4 quantization */
kNVTEQuantizationConfigNVFP4PerTokenActivation = 8,
```
We should configure this in NVTETensor rather than in NVTEQuantizationConfig. The quantization config is used for quantization, and is not available to any downstream consumers (dequant or GEMM). However, consumers need to be aware of tensor-scaling vs row-scaling. The buffer sizes are different, and getting it wrong means incorrect values or segfaults.
Dequantization handles row-scaled NVFP4, but we need to make sure that other quantized-tensor consumers also handle it. Erroring out is fine for now. Currently the only other consumers we need to handle are GEMM and attention, although we should keep this in mind if we add more features in the future.
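The "erroring out is fine for now" approach could be sketched as a guard like the one below. This is an illustrative pattern, not TE code: `check_consumer_supports_row_scaling` and the `row_scaled_nvfp4` attribute are assumed names mirroring the flag shown in the sequence diagram.

```python
def check_consumer_supports_row_scaling(tensor, consumer):
    # Hypothetical guard: consumers that have not been taught about
    # row-scaled NVFP4 (e.g. attention) fail loudly instead of reading
    # a per-row amax buffer as if it were a single scalar, which would
    # yield incorrect values or out-of-bounds reads.
    supported = {"dequantize", "gemm"}
    if getattr(tensor, "row_scaled_nvfp4", False) and consumer not in supported:
        raise NotImplementedError(
            f"Row-scaled NVFP4 is not supported by the '{consumer}' path yet"
        )
```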
```python
@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4)
def test_nvfp4_per_token_quantizer_roles():
```
This test will need to be updated once #2620 merges.
Hi @timmoon10, thank you for your comments! I have refactored accordingly and renamed this to row-scaled.
Description
Implement the row-scaled NVFP4 recipe with fprop only. Currently, the row scaling is handled by separate PyTorch code.
Quantization kernels are bitwise exact with the existing TE reference implementation.
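To make the per-row scaling idea concrete, here is an emulated NumPy sketch of the quantize/dequantize round trip. This is not the TE kernel: plain rounding stands in for the nonuniform E2M1 element encoding, and the FP8 per-block scales are omitted; only the per-row (rather than per-tensor) global scale is modelled.

```python
import numpy as np

FP4_MAX = 6.0  # largest magnitude representable by NVFP4's E2M1 element format

def quantize_row_scaled(x):
    # Each row gets its own global scale derived from that row's abs-max
    # (the compute_rowwise_amax step), instead of one scale per tensor.
    row_amax = np.abs(x).max(axis=1, keepdims=True)
    scale = np.where(row_amax == 0, 1.0, row_amax / FP4_MAX)
    q = np.clip(np.round(x / scale), -FP4_MAX, FP4_MAX)  # stand-in for E2M1
    return q, scale

def dequantize_row_scaled(q, scale):
    # Dequantization must use the same per-row scales, which is why the
    # row-scaled flag has to travel with the tensor, not the quantizer.
    return q * scale
```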
The following tests passed on B200:
Type of change
Changes
Please list the changes introduced in this PR:
- `row_scaled_activation` field in the NVFP4 recipe; can be turned on via `NVTE_NVFP4_ROW_SCALED_ACTIVATION`
- New quantization kernels folded into the existing NVFP4 quantization kernels (replacing the standalone `transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh`), bitwise exact with the existing TE PyTorch reference implementation and the per-tensor emulated NVFP4 implementation
- Extended `transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh` to correctly handle row-scaled NVFP4
- Updated `TransformerEngine/transformer_engine/pytorch/cpp_extensions/gemm.py`: if row-scaled NVFP4 is enabled, it applies separate per-row scaling in PyTorch code after the cuBLAS GEMM
- Updated `tests/cpp/operator/test_cast_nvfp4_transpose.cu` to align with the PyTorch reference numerics

Checklist: