
Spherical Harmonics Call Fails Internally For 2nd Order Derivatives #286


@rwkeane

Describe the bug
When computing second-order derivatives of l=0 spherical harmonics in eager mode (composing torch.func.grad with torch.func.vjp), the call fails internally with a RuntimeError.

import torch
from cuequivariance_torch.operations.spherical_harmonics import (
    SphericalHarmonics as LibrarySphericalHarmonics,
)

device = torch.device("cuda:0")
sph_mod = LibrarySphericalHarmonics(
    [0],
    normalize=True,
    device=device,
    math_dtype=torch.float32,
)

raw = torch.randn(16, 3, device=device)
x = (raw / raw.norm(dim=-1, keepdim=True)).detach()

def f(x_in):
    return sph_mod(x_in).pow(2).sum().pow(2)

def via_vjp(x_in):
    out, vjp_fn = torch.func.vjp(f, x_in)
    return vjp_fn(torch.ones_like(out))[0].sum()

torch.func.grad(via_vjp)(x)

This code fails with the following traceback:

  File "/home/ryan/src/dir/repro_spharm_double_backward.py", line 27, in <module>
    torch.func.grad(via_vjp)(x)
  File "/home/ryan/anaconda3/envs/environment/lib/python3.12/site-packages/torch/_functorch/apis.py", line 398, in wrapper
    return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/environment/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py", line 1407, in grad_impl
    results = grad_and_value_impl(func, argnums, has_aux, args, kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/environment/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 48, in fn
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/environment/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py", line 1391, in grad_and_value_impl
    flat_grad_input = _autograd_grad(
                      ^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/environment/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py", line 142, in _autograd_grad
    grad_inputs = torch.autograd.grad(
                  ^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/environment/lib/python3.12/site-packages/torch/autograd/__init__.py", line 503, in grad
    result = _engine_run_backward(
             ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/environment/lib/python3.12/site-packages/torch/autograd/graph.py", line 829, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/environment/lib/python3.12/site-packages/torch/autograd/function.py", line 311, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ryan/anaconda3/envs/environment/lib/python3.12/site-packages/torch/_library/autograd.py", line 198, in new_backward
    raise RuntimeError(
RuntimeError: Expected the return from backward to be of the same structure as the inputs. Got: TreeSpec(tuple, None, [*,
  *,
  *,
  *,
  *,
  *,
  *,
  *,
  *,
  *,
  *,
  *,
  *,
  *,
  *,
  *,
  *,
  *,
  *,
  *,
  TreeSpec(list, None, [*,
    *]),
  *]) (return from backward), TreeSpec(tuple, None, [*,
  *,
  *,
  *,
  *,
  *,
  *,
  *,
  *,
  *,
  *,
  *,
  TreeSpec(list, None, []),
  TreeSpec(list, None, []),
  TreeSpec(list, None, []),
  TreeSpec(list, None, []),
  TreeSpec(list, None, []),
  *,
  *,
  *,
  TreeSpec(list, None, [*,
    *]),
  *]) (inputs)
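Comparing the two TreeSpecs in the error: both have 22 top-level entries. They differ only at zero-based positions 12 to 16, where the op's inputs are empty lists but the backward returned a single leaf. The pure-Python sketch below transcribes both top-level structures from the message and locates the mismatch. `mismatched_positions` is a hypothetical helper that only models the comparison; it is not PyTorch's actual check. This suggests, without confirming, that the backward registered for the failing op returns one value (possibly None) for list-typed arguments that happen to be empty, rather than an empty list.

```python
LEAF = "*"  # a pytree leaf, as printed in the error message

# Top-level structure of the backward's return value, transcribed from the error.
returned = [LEAF] * 20 + [[LEAF, LEAF]] + [LEAF]

# Top-level structure of the op's inputs, transcribed from the same error.
expected = [LEAF] * 12 + [[] for _ in range(5)] + [LEAF] * 3 + [[LEAF, LEAF]] + [LEAF]


def mismatched_positions(got, want):
    """Return the top-level indices where the nested structure of `got` differs from `want`."""
    if len(got) != len(want):
        raise ValueError(f"length differs: {len(got)} vs {len(want)}")
    return [i for i, (g, w) in enumerate(zip(got, want)) if g != w]


print(len(returned), len(expected))  # 22 22
print(mismatched_positions(returned, expected))  # [12, 13, 14, 15, 16]
```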

Expected behavior
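The call should succeed and return the second-order derivative with respect to x (a tensor with the same shape as x) instead of raising an error.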
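For l=0 the spherical harmonic is a constant. This makes `f` in the repro constant in `x`, so the expected result of `torch.func.grad(via_vjp)(x)` is a zero tensor of shape (16, 3). The sketch below swaps in a hypothetical pure-PyTorch stand-in (`sph_l0_standin`, with an assumed constant of 1.0; the real module's normalization constant does not change the derivatives). It computes the same quantity two ways: through the repro's torch.func composition, and through plain torch.autograd with `create_graph=True` (`via_autograd` is also a hypothetical helper). It runs on CPU and does not exercise cuequivariance itself.

```python
import torch

Y0_CONSTANT = 1.0  # assumption: any constant value yields the same (zero) derivatives


def sph_l0_standin(x_in):
    # Hypothetical pure-PyTorch stand-in for SphericalHarmonics([0], normalize=True).
    # The 0.0 * term keeps the constant output attached to the autograd graph.
    x_hat = x_in / x_in.norm(dim=-1, keepdim=True)
    return 0.0 * x_hat[..., :1] + Y0_CONSTANT


def f(x_in):
    return sph_l0_standin(x_in).pow(2).sum().pow(2)


def via_vjp(x_in):
    # Same composition as the repro: gradient of the summed first derivative.
    out, vjp_fn = torch.func.vjp(f, x_in)
    return vjp_fn(torch.ones_like(out))[0].sum()


def via_autograd(fn, x_in):
    # Same quantity with plain torch.autograd and create_graph=True, without torch.func.
    x_in = x_in.detach().requires_grad_(True)
    (g,) = torch.autograd.grad(fn(x_in), x_in, create_graph=True)
    if not g.requires_grad:
        # The first derivative is constant in x, so the second derivative is zero.
        return torch.zeros_like(x_in)
    (h,) = torch.autograd.grad(g.sum(), x_in, allow_unused=True)
    return torch.zeros_like(x_in) if h is None else h


raw = torch.randn(16, 3)
x = (raw / raw.norm(dim=-1, keepdim=True)).detach()

result_func = torch.func.grad(via_vjp)(x)
result_autograd = via_autograd(f, x)
print(result_func.shape)  # torch.Size([16, 3])
print(bool((result_func == 0).all()), bool((result_autograd == 0).all()))  # True True
```

Substituting the real module for `sph_l0_standin` gives a quick way to check whether the failure is specific to the torch.func path. That comparison has not been run here.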

GPU HW/SW:

  • CUDA Version: 13.0
  • torch or ngc docker version: PyTorch 2.8
  • Driver Version: 580.126.09
  • full name of GPU: NVIDIA GeForce RTX 5060 Ti
  • Memory, power limit, clock (nvidia-smi): 33% 55C P0 55W / 180W | 7060MiB / 16311MiB

⚠️ Academic Benchmarking Impact & Context

Real-World Motivation:
This defect was isolated during an end-to-end runtime evaluation for an upcoming machine learning systems publication that compares user-space hardware orchestration paradigms with optimized library backends.

We would like to feature cuequivariance as a baseline in our higher-order derivative execution sweeps. However, this issue and the previously filed torch.func transformation failures currently block us from collecting complete head-to-head performance data for these workloads.

As a result, we will have to note in our paper's architectural compatibility matrix that cuequivariance does not currently compose out of the box with torch.func transforms for higher-order automatic differentiation. We are sharing this for visibility to the maintainers and as a tracking reference for the open-source community as support for these scenarios evolves.
