Out of shared memory on Blackwell architectures with embedding size 128 #244

Description

@psturmfels

Describe the bug
The documentation (https://docs.nvidia.com/cuda/cuequivariance/api/generated/cuequivariance_torch.triangle_attention.html) states that:

Triangle attention kernel supports: all hidden_dim<=32 and divisible by 4 for tf32/fp32, and for all hidden_dim<=128 and divisible by 8 for bf16/fp16

I'm running cuequivariance_torch==0.8.1, cuequivariance-ops-cu13==0.8.1, and torch==2.10.0+cu130 with CUDA 13.0 on an NVIDIA RTX PRO 6000 Blackwell Server Edition GPU. When I run the kernel on a float16 input with hidden_dim=128, I get a CUDA error (see below for the full trace).
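Read literally, the quoted constraint can be encoded as a small pre-flight check. This is a minimal sketch: `documented_support` and `DOCUMENTED_LIMITS` are hypothetical names, not part of cuequivariance_torch, and the check only restates the documentation, not what the kernel actually accepts on a given GPU.

```python
# Hypothetical pre-flight check (not part of cuequivariance_torch) that encodes
# the hidden_dim constraints quoted from the triangle_attention documentation.
DOCUMENTED_LIMITS = {
    # dtype name: (max hidden_dim, required divisor)
    "float32": (32, 4),
    "tf32": (32, 4),
    "float16": (128, 8),
    "bfloat16": (128, 8),
}


def documented_support(hidden_dim: int, dtype: str) -> bool:
    """Return True if the docs claim the kernel supports this hidden_dim/dtype pair."""
    max_dim, divisor = DOCUMENTED_LIMITS[dtype]
    return 0 < hidden_dim <= max_dim and hidden_dim % divisor == 0


print(documented_support(128, "float16"))  # True: the configuration in this report
print(documented_support(128, "float32"))  # False: fp32/tf32 are capped at 32
```

Under this reading, float16 with hidden_dim=128 sits inside the documented envelope, which suggests a gap between the documentation and the kernel's limits on this GPU rather than a usage error.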

To Reproduce

import math

import torch
from cuequivariance_torch import triangle_attention


def main():
    device = torch.device("cuda")
    # Set up dimensions
    batch_size, seq_len, num_heads, hidden_dim = 2, 512, 4, 128
    # Create input tensors on GPU with float16 precision
    q = torch.randn(
        batch_size, seq_len, num_heads, seq_len, hidden_dim, device=device, dtype=torch.float16, requires_grad=True
    )
    k = torch.randn(
        batch_size, seq_len, num_heads, seq_len, hidden_dim, device=device, dtype=torch.float16, requires_grad=True
    )
    v = torch.randn(
        batch_size, seq_len, num_heads, seq_len, hidden_dim, device=device, dtype=torch.float16, requires_grad=True
    )
    bias = torch.randn(
        batch_size, 1, num_heads, seq_len, seq_len, device=device, dtype=torch.float16, requires_grad=True
    )
    # Create optional mask
    mask = torch.rand(batch_size, seq_len, 1, 1, seq_len, device=device) < 0.5
    # Calculate scale
    scale = 1 / math.sqrt(hidden_dim)
    # Forward pass
    output, lse, max_val = triangle_attention(q=q, k=k, v=v, bias=bias, mask=mask, scale=scale, return_aux=True)
    print(output.shape)
    # Create gradient tensor and perform backward pass
    grad_out = torch.randn_like(output)
    output.backward(grad_out)
    # Access gradients
    print(q.grad.shape)
    print(k.grad.shape)
    print(v.grad.shape)
    print(bias.grad.shape)


if __name__ == "__main__":
    main()
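To narrow down which hidden_dim values fail on this GPU, the script above could be wrapped in a sweep. This is a sketch under stated assumptions: `run_triangle_attention` and `sweep_hidden_dims` are hypothetical helpers, the smaller default seq_len assumes the failure depends on hidden_dim rather than sequence length (pass 512 to match the original), and it assumes a failed trial does not leave the CUDA context in a sticky error state (if it does, run each value in a fresh process).

```python
import math


def run_triangle_attention(hidden_dim, seq_len=64, batch_size=1, num_heads=4):
    """Forward + backward using the same call pattern as the repro script.

    Imports are local so the sweep logic below can be reused without a GPU.
    """
    import torch
    from cuequivariance_torch import triangle_attention

    device = torch.device("cuda")
    shape = (batch_size, seq_len, num_heads, seq_len, hidden_dim)
    q, k, v = (
        torch.randn(shape, device=device, dtype=torch.float16, requires_grad=True)
        for _ in range(3)
    )
    bias = torch.randn(
        batch_size, 1, num_heads, seq_len, seq_len,
        device=device, dtype=torch.float16, requires_grad=True,
    )
    mask = torch.rand(batch_size, seq_len, 1, 1, seq_len, device=device) < 0.5
    output, _lse, _max_val = triangle_attention(
        q=q, k=k, v=v, bias=bias, mask=mask,
        scale=1 / math.sqrt(hidden_dim), return_aux=True,
    )
    output.backward(torch.randn_like(output))
    torch.cuda.synchronize()  # surface asynchronous errors inside this trial


def sweep_hidden_dims(run, dims):
    """Map each hidden_dim to None on success, or the first line of the error."""
    results = {}
    for d in dims:
        try:
            run(d)
            results[d] = None
        except RuntimeError as exc:
            results[d] = (str(exc).splitlines() or [""])[0]
    return results


# On the GPU:
# print(sweep_hidden_dims(run_triangle_attention, range(8, 129, 8)))
```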

Expected behavior
Based on the documentation, I expected this input (float16, hidden_dim=128) to run without errors in both the forward and backward passes.

Screenshots
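The full stack trace is below. The forward pass completes and prints the output shape; the error is raised in the backward pass, inside triangle_attention_bwd, at the cudaFuncSetAttribute call that sets the kernel's maximum dynamic shared-memory size.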

/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/cuequivariance_ops_torch/triangle_attention.py:165: UserWarning: Non-SM100f kernel expects bias to be float32 so it's going to be cast to torch.float32. Check if you can change your code for maximum performance.
  warnings.warn(
torch.Size([2, 512, 4, 512, 96])
Traceback (most recent call last):
  File "/net/home/pascal/sandbox-pascal/scripts/misc/test_kernel.py", line 42, in <module>
    main()
  File "/net/home/pascal/sandbox-pascal/scripts/misc/test_kernel.py", line 33, in main
    output.backward(grad_out)
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_tensor.py", line 630, in backward
    torch.autograd.backward(
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/autograd/__init__.py", line 364, in backward
    _engine_run_backward(
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/autograd/function.py", line 317, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_library/autograd.py", line 78, in backward
    result = info._backward_fn(ctx, *grads)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/cuequivariance_ops_torch/triangle_attention.py", line 432, in _backward
    d_q, d_k, d_v, dbias = torch.ops.cuequivariance.triangle_attention_bwd(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_ops.py", line 1209, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_library/autograd.py", line 112, in autograd_impl
    result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_library/autograd.py", line 41, in forward_no_grad
    result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_ops.py", line 826, in redispatch
    return self._handle.redispatch_boxed(keyset, *args, **kwargs)  # type: ignore[return-value]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 347, in backend_impl
    result = self._backend_fns[device_type](*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_compile.py", line 54, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1181, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 382, in wrapped_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/cuequivariance_ops_torch/triangle_attention.py", line 292, in _
    ops.triangle_attention_bwd(
RuntimeError: CUDA error: "invalid argument" at bwd_fmha.cu:193
Failed call: cudaFuncSetAttribute( (const void*)*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)
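Per the CUDA Runtime API documentation for cudaFuncAttributeMaxDynamicSharedMemorySize, the requested dynamic size plus the kernel's static shared memory may not exceed the device attribute cudaDevAttrMaxSharedMemoryPerBlockOptin; a larger request is presumably what produces "invalid argument" here, matching the issue title. The sketch below queries that per-device limit through the cuda-bindings package from the environment list. `query_smem_optin_limit` and `smem_request_fits` are hypothetical helpers, and the kernel's actual smem_size is not exposed by this code; it would have to come from the library or a debugger.

```python
def query_smem_optin_limit(device: int = 0) -> int:
    """Max shared memory (bytes) one block may opt into on `device`.

    Uses cuda-bindings (cuda.bindings.runtime); imported lazily so the
    comparison helper below works without a GPU.
    """
    from cuda.bindings import runtime as cudart

    err, value = cudart.cudaDeviceGetAttribute(
        cudart.cudaDeviceAttr.cudaDevAttrMaxSharedMemoryPerBlockOptin, device
    )
    if err != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(f"cudaDeviceGetAttribute failed: {err}")
    return value


def smem_request_fits(dynamic_bytes: int, static_bytes: int, optin_limit: int) -> bool:
    """Dynamic + static shared memory must stay within the opt-in limit."""
    return 0 <= dynamic_bytes and dynamic_bytes + static_bytes <= optin_limit


# On the GPU:
# print(query_smem_optin_limit(0))
```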

GPU HW/SW information:

  • CUDA toolkit versions:
      ◦ cuda-bindings==13.0.3 ; sys_platform == 'linux'
      ◦ nvidia-cuda-nvrtc==13.0.88 ; sys_platform == 'linux'
      ◦ nvidia-cuda-runtime==13.0.96 ; sys_platform == 'linux'
      ◦ nvidia-cuda-cupti==13.0.85 ; sys_platform == 'linux'
  • torch==2.10.0+cu130
  • Driver version: 13.0
  • full name of GPU: NVIDIA RTX PRO 6000 Blackwell Server Edition
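The same environment details could be gathered automatically with a small stdlib-only sketch. It also records the GPU's compute capability, which would help tell whether the SM100-family path mentioned in the warning above applies. The package list comes from this report; `collect_env` is a hypothetical helper, and torch is imported only if it is installed. PyTorch's own `python -m torch.utils.collect_env` covers the torch side in more detail.

```python
import importlib.metadata
import platform

PACKAGES = ["cuequivariance-torch", "cuequivariance-ops-cu13", "torch", "cuda-bindings"]


def package_version(name: str) -> str:
    try:
        return importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError:
        return "not installed"


def collect_env() -> dict:
    """Collect package versions and, if torch + CUDA are present, GPU details."""
    info = {"python": platform.python_version()}
    info.update({pkg: package_version(pkg) for pkg in PACKAGES})
    try:
        import torch
    except ImportError:
        return info
    info["torch.version.cuda"] = torch.version.cuda
    if torch.cuda.is_available():
        info["gpu"] = torch.cuda.get_device_name(0)
        major, minor = torch.cuda.get_device_capability(0)
        info["compute_capability"] = f"{major}.{minor}"
    return info


if __name__ == "__main__":
    for key, value in collect_env().items():
        print(f"{key}: {value}")
```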

Any advice would be appreciated. Let me know if I can add more context.

Labels: bug (Something isn't working)

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions