import torch
from cuequivariance_torch.operations.spherical_harmonics import (
SphericalHarmonics as LibrarySphericalHarmonics,
)
# Everything in this reproduction is placed on the first CUDA device.
device = torch.device("cuda:0")

# Spherical-harmonics module restricted to degree l = 0, built with
# normalize=True and float32 math.
sph_mod = LibrarySphericalHarmonics(
    [0],
    normalize=True,
    device=device,
    math_dtype=torch.float32,
)

# 16 random 3-vectors, each divided by its own norm; the result is detached
# so x carries no autograd history into the transforms below.
raw = torch.randn(16, 3, device=device)
x = torch.div(raw, raw.norm(dim=-1, keepdim=True)).detach()
def f(x_in):
    """Reduce the spherical harmonics of ``x_in`` to a single scalar.

    Evaluates the module-level ``sph_mod`` on ``x_in``, squares the result
    element-wise, sums every entry, and squares that sum.
    """
    return sph_mod(x_in).pow(2).sum().pow(2)
def via_vjp(x_in):
    """Return the summed first-order gradient of ``f`` at ``x_in``.

    ``torch.func.vjp`` returns ``f(x_in)`` together with its vector-Jacobian
    product function. Pulling back a cotangent of ones gives the gradient
    with respect to ``x_in``; it is summed to a scalar so that the result can
    itself be passed to ``torch.func.grad``.
    """
    out, vjp_fn = torch.func.vjp(f, x_in)
    return vjp_fn(torch.ones_like(out))[0].sum()
# Differentiate via_vjp (itself a gradient computation) to request
# second-order derivatives; this is the call that raises the RuntimeError
# shown in the traceback that follows.
torch.func.grad(via_vjp)(x)
File "/home/ryan/src/dir/repro_spharm_double_backward.py", line 27, in <module>
torch.func.grad(via_vjp)(x)
File "/home/ryan/anaconda3/envs/environment/lib/python3.12/site-packages/torch/_functorch/apis.py", line 398, in wrapper
return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/environment/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py", line 1407, in grad_impl
results = grad_and_value_impl(func, argnums, has_aux, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/environment/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 48, in fn
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/environment/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py", line 1391, in grad_and_value_impl
flat_grad_input = _autograd_grad(
^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/environment/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py", line 142, in _autograd_grad
grad_inputs = torch.autograd.grad(
^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/environment/lib/python3.12/site-packages/torch/autograd/__init__.py", line 503, in grad
result = _engine_run_backward(
^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/environment/lib/python3.12/site-packages/torch/autograd/graph.py", line 829, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/environment/lib/python3.12/site-packages/torch/autograd/function.py", line 311, in apply
return user_fn(self, *args)
^^^^^^^^^^^^^^^^^^^^
File "/home/ryan/anaconda3/envs/environment/lib/python3.12/site-packages/torch/_library/autograd.py", line 198, in new_backward
raise RuntimeError(
RuntimeError: Expected the return from backward to be of the same structure as the inputs. Got: TreeSpec(tuple, None, [*,
*,
*,
*,
*,
*,
*,
*,
*,
*,
*,
*,
*,
*,
*,
*,
*,
*,
*,
*,
TreeSpec(list, None, [*,
*]),
*]) (return from backward), TreeSpec(tuple, None, [*,
*,
*,
*,
*,
*,
*,
*,
*,
*,
*,
*,
TreeSpec(list, None, []),
TreeSpec(list, None, []),
TreeSpec(list, None, []),
TreeSpec(list, None, []),
TreeSpec(list, None, []),
*,
*,
*,
TreeSpec(list, None, [*,
*]),
*]) (inputs)
As a result, we will be required to document in our paper's architectural compatibility matrix that cuequivariance does not currently compose with functional transformation layers for higher-order automatic differentiation out-of-the-box. We are providing this for full visibility to the maintainers and to establish a clear tracking reference for the open-source community as these advanced scenarios evolve.
Describe the bug
When computing second-order derivatives of l=0 spherical harmonics in eager mode with `torch.func` transforms, the call fails with a `RuntimeError` raised from the backward pass.
This code results in
Expected behavior
This should compute the second-order derivative of the spherical harmonics without raising an error.
GPU HW/SW(please complete the following information):
-For performance issues, please also add memory, power limit, clock: 33% 55C P0 55W / 180W | 7060MiB / 16311MiB
Real-World Motivation:
This defect was isolated during an end-to-end runtime evaluation for an upcoming machine learning systems publication comparing user-space hardware orchestration paradigms with optimized library backends.
We are highly motivated to feature cuequivariance as a baseline for our higher-order derivative execution sweeps. However, this and previously filed `torch.func` transformation failures are currently blocking us from gathering complete head-to-head performance datasets for these workloads. As a result, we will be required to document in our paper's architectural compatibility matrix that cuequivariance does not currently compose with functional transformation layers for higher-order automatic differentiation out-of-the-box. We are providing this for full visibility to the maintainers and to establish a clear tracking reference for the open-source community as these advanced scenarios evolve.