Bug: cuEquivariance-jax does not work on B200 / B300 with CUDA 12 #209

Description

@guyujun

The triangle_multiplicative_update Triton kernel does not work with Haiku and JAX on NVIDIA B200/B300 GPUs, although it works on H200, H100, H800, RTX 5090, RTX 4090, and RTX PRO 6000.

My test code:

import jax
from cuequivariance_jax import triangle_multiplicative_update
import haiku as hk
import jax.random as jrandom

class TritonNetwork(hk.Module):
    def __init__(self, name=None):
        super().__init__(name=name)
    
    def triangle_multiplicative_update_block(self, pair_act):
        assert len(pair_act.shape) == 4
        key = jax.random.key(0)
        
        pair_act = triangle_multiplicative_update(
            x=pair_act, 
            direction='outgoing',   # 'outgoing' or 'incoming'
            key=key
        )
        return pair_act

    def __call__(self, batch):
        act = self.triangle_multiplicative_update_block(batch['feat'])
        return act


def forward_triton(batch):
    network = TritonNetwork()
    return network(batch)

forward = hk.transform(forward_triton)
L = 256
feat_shape = (1, L, L, 128)

key = jrandom.PRNGKey(0)
feat_batch = {'feat': jrandom.normal(key, feat_shape)}
print('step1:')
params = forward.init(key, feat_batch)

print('step2:')
for i in range(10):
    key, apply_key = jax.random.split(key, 2)
    output = forward.apply(params, apply_key, feat_batch)
    # print(output)
    print(output.shape)

The code randomly crashes or hangs.
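Because the failure is intermittent (sometimes a crash, sometimes a hang), it may help to run the reproduction repeatedly in a child process with a timeout and record how each run ends. This is a minimal sketch using only the Python standard library; the function names classify_run and tally, the 120-second timeout, and the file name repro.py are illustrative assumptions, not part of cuEquivariance.

```python
import subprocess
import sys


def classify_run(cmd, timeout_s=120.0):
    """Run `cmd` once in a child process and classify the outcome.

    Returns "ok" if it exits with status 0, "crash" if it exits with a
    non-zero status (including being killed by a signal), or "hang" if it
    does not finish within `timeout_s` seconds.
    """
    try:
        result = subprocess.run(cmd, capture_output=True, timeout=timeout_s)
    except subprocess.TimeoutExpired:
        # subprocess.run kills and reaps the child before re-raising TimeoutExpired.
        return "hang"
    return "ok" if result.returncode == 0 else "crash"


def tally(cmd, runs=10, timeout_s=120.0):
    """Repeat `cmd` several times and count each outcome."""
    counts = {"ok": 0, "crash": 0, "hang": 0}
    for _ in range(runs):
        counts[classify_run(cmd, timeout_s)] += 1
    return counts


if __name__ == "__main__":
    # Hypothetical file name: save the reproduction script above as repro.py.
    print(tally([sys.executable, "repro.py"], runs=5))
```

Running the script in a fresh process each time also keeps one hung GPU context from affecting the next attempt.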

However, falling back to the JAX-based implementation of triangle_multiplicative_update works.

I did this by changing the fallback option of the sigmoid_gated_dual_gemm call inside triangle_multiplicative_update, in cuequivariance_jax/triangle/triangle_multiplicative_update.py:

# Gated dual gemm
ab = sigmoid_gated_dual_gemm(
    x,
    g_in_weight,
    p_in_weight,
    b1=g_in_bias,
    b2=p_in_bias,
    mask=mask,
    transpose_out=True,
    precision=precision,
    fallback=False,  # <- changed this line to False to use the JAX-based code
)
a, b = jnp.split(ab, 2, axis=0)

So I think something in the sigmoid_gated_dual_gemm kernel does not work on NVIDIA B200/B300 GPUs.

However, I can't give more details because I don't have access to the kernel's source code.
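Until the kernel issue is resolved, an alternative to editing the installed library is to decide per device whether to request the JAX fallback. This sketch assumes that B200 and B300 report CUDA compute capability 10.x, while the GPUs listed as working report other major versions (for example 9.0 for H100). The helper names parse_compute_capability and needs_fallback, and the set of affected major versions, are hypothetical and based only on the GPUs named in this report.

```python
def parse_compute_capability(cc):
    """Parse a compute capability string such as "10.0" into (major, minor)."""
    major, minor = cc.split(".")
    return int(major), int(minor)


# Assumption: only GPUs with compute capability 10.x (e.g. B200/B300) are
# affected, inferred solely from the working/failing GPUs in this report.
AFFECTED_MAJOR_VERSIONS = frozenset({10})


def needs_fallback(cc):
    """Return True if the JAX fallback should be used on a device with capability `cc`."""
    major, _ = parse_compute_capability(cc)
    return major in AFFECTED_MAJOR_VERSIONS
```

On a CUDA device, JAX exposes the capability string as jax.devices()[0].compute_capability, so needs_fallback(jax.devices()[0].compute_capability) could be passed as the fallback argument of triangle_multiplicative_update, assuming the installed version accepts one.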


My environment:

Name                         Version      Build     Channel
python                       3.11
cuequivariance-jax           0.7.0rc2     pypi_0    pypi
cuequivariance-ops-jax-cu12  0.7.0        pypi_0    pypi
jax                          0.6.0        pypi_0    pypi
jax-cuda12-pjrt              0.6.0        pypi_0    pypi
jax-cuda12-plugin            0.6.0        pypi_0    pypi
jax-triton                   0.3.0        pypi_0    pypi
jaxlib                       0.6.0        pypi_0    pypi
jaxtyping                    0.2.34       pypi_0    pypi
nvidia-cublas-cu12           12.9.1.4     pypi_0    pypi
nvidia-cuda-cupti-cu12       12.9.79      pypi_0    pypi
nvidia-cuda-nvcc-cu12        12.9.86      pypi_0    pypi
nvidia-cuda-nvrtc-cu12       12.9.86      pypi_0    pypi
nvidia-cuda-runtime-cu12     12.9.79      pypi_0    pypi
nvidia-cudnn-cu12            9.15.0.57    pypi_0    pypi
nvidia-cufft-cu12            11.4.1.4     pypi_0    pypi
nvidia-cusolver-cu12         11.7.5.82    pypi_0    pypi
nvidia-cusparse-cu12         12.5.10.65   pypi_0    pypi
nvidia-ml-py                 13.580.82    pypi_0    pypi
nvidia-nccl-cu12             2.28.7       pypi_0    pypi
nvidia-nvjitlink-cu12        12.9.86      pypi_0    pypi
nvidia-nvshmem-cu12          3.4.5        pypi_0    pypi
haiku                        0.0.15
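A version list like the one above can also be regenerated without conda, using importlib.metadata from the standard library. This is a sketch: the package names are a subset copied from the list above, Haiku is looked up under its PyPI distribution name dm-haiku, and the helper name collect_versions is illustrative.

```python
import importlib.metadata
import platform

# Distribution names taken from the environment list above (Haiku is published as dm-haiku).
PACKAGES = [
    "cuequivariance-jax",
    "cuequivariance-ops-jax-cu12",
    "jax",
    "jaxlib",
    "jax-cuda12-plugin",
    "jax-cuda12-pjrt",
    "jax-triton",
    "dm-haiku",
    "nvidia-cudnn-cu12",
    "nvidia-cublas-cu12",
    "nvidia-cuda-runtime-cu12",
]


def collect_versions(packages):
    """Map each distribution name to its installed version, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    print(f"python {platform.python_version()}")
    for name, version in collect_versions(PACKAGES).items():
        print(f"{name:<30} {version or 'not installed'}")
```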
