Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
1356f9f
Add array namespace option for field buffers
kburns May 27, 2025
f7fb6c9
Add array-api-compat to setup.py
kburns May 27, 2025
63648b8
Allow specifying array namespace by string
kburns May 27, 2025
ad50b20
Try fixing cupy allocation from buffer
kburns May 27, 2025
3f8692d
Fix cupy check
kburns May 27, 2025
8b42957
Add cupy-based complex fourier MMT
kburns May 27, 2025
bebf8dd
Fix transform lookup
kburns May 27, 2025
f04bdf6
Make fill_random array and dtype compatible
kburns May 27, 2025
64b91c4
Work on cupy real fourier MMTs
kburns May 27, 2025
295158a
Generalize Fourier basis for more dtypes
kburns May 27, 2025
b59d1ec
Add cupy complex FFT
kburns May 28, 2025
00bd266
Add cupy real fft
kburns May 28, 2025
6036983
Fix dtype conversion
kburns May 28, 2025
19e2350
Add array compat for basic arithmetic
kburns May 28, 2025
a3bda50
Beginning adding array_compat to operators
kburns May 28, 2025
28f76e7
Quick implementation of apply_sparse for cupy
kburns Jul 22, 2025
adecfa9
Make einsum in dot compatible with cupy
kburns Jul 22, 2025
6b9ff23
Add custom kernel for cupy csr middle dot product
kburns Jul 22, 2025
8ee70d3
Convert local grids/modes to device arrays
kburns Jul 22, 2025
689e41f
Explicitly cast data norms to float
kburns Jul 22, 2025
c226735
Cast grid spacing to device array in cartesian cfl
kburns Jul 22, 2025
ad30363
Convert field data gathers to numpy on gpu
kburns Jul 22, 2025
b7a7188
Fix subsystem gather/scatter to copy to/from gpu
kburns Jul 22, 2025
ffe0719
Allow for non-contiguous device copy
kburns Jul 22, 2025
07f43fd
Fix cupy csr kernel for double instead of float
kburns Jul 22, 2025
799064a
Move subsystems, coeff systems, and matrices to GPU
kburns Jul 25, 2025
d99b5b5
Build custom cupy superlu wrapper to reuse spsm descriptors
kburns Jul 25, 2025
ef66c99
Move all operator matrices to device. Add Chebyshev transforms
kburns Jul 25, 2025
746a2a6
Make einsum in trace compatible with cupy
csskene Dec 16, 2025
18e713a
Fix dtype for MultiplyNumberField
csskene Mar 10, 2026
9ed5d0f
Specify dtype for the CFL reducer
csskene Apr 28, 2026
9fbba8d
Ensure timestepping coefficients are the correct dtype
csskene Apr 28, 2026
ce4a1e4
Convert Jacobi conversion matrices to specified dtype
csskene Apr 28, 2026
2d2a0be
Specify dtype for GlobalFlowProperty reducer
csskene Apr 29, 2026
0bf770d
Check if buff size grows even for same spsm descriptor. If it does ma…
csskene Apr 29, 2026
4cc47ac
Add custom kernel for apply_csr_mid for dtype float32
csskene Apr 29, 2026
b85f270
Convert subspace matrices to specified dtype
csskene Apr 29, 2026
5aa0f1f
Use jit.rawkernel for apply_csr_mid_kernel
csskene May 12, 2026
a124d8e
Accumulate into acc in apply_csr_mid_kernel
csskene May 12, 2026
99655af
Update distributor docstring
kburns May 12, 2026
b6a8813
Start simplifying transform library defaults (broken for curvilinear)
kburns May 12, 2026
be949c6
Suppress docstring warnings
kburns May 12, 2026
092fe7f
Add mock jit so linalg_gpu is still importable without cupy
kburns May 12, 2026
9bcf92a
Add config option for default gpu subproblem coupling
kburns May 12, 2026
69039c8
Replace np with xp in timesteppers
csskene May 15, 2026
9b1bddd
Set matrix transform as default for Chebyshev
kburns Jun 4, 2026
75e43d2
Separate axpy for cpu and device in timesteppers
kburns Jun 19, 2026
392b0e1
Try caching einsum_path for dot product
kburns Jun 19, 2026
721fdfd
Add warning for non-cartesian problems
kburns Jun 24, 2026
40fd3b4
Add GPU doc page
kburns Jun 24, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions dedalus/core/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import numexpr as ne
from collections import defaultdict
from math import prod
import array_api_compat

from .domain import Domain
from .field import Operand, Field
Expand Down Expand Up @@ -245,10 +246,11 @@ def choose_layout(self):

def operate(self, out):
    """Perform operation.

    Sums the data of the two arguments into ``out.data`` using the
    operator's array namespace, after giving ``out`` the layout of the
    first argument.
    """
    xp = self.array_namespace
    arg0, arg1 = self.args
    # Set output layout
    out.preset_layout(arg0.layout)
    # Add through the configured namespace only; a hard-coded numpy call
    # here would bypass the namespace selected for the field data.
    xp.add(arg0.data, arg1.data, out=out.data)


# used for einsum string manipulation
Expand Down Expand Up @@ -616,6 +618,7 @@ def __init__(self, arg0, arg1, indices=(-1,0), out=None, **kw):
arg2_str = arg2_str.replace(arg2_str[indices[1]], 'z')
out_str = (arg1_str + arg2_str).replace('z', '')
self.einsum_str = arg1_str + '...,' + arg2_str + '...->' + out_str + '...'
self.einsum_path = None

def _check_indices(self, arg0, arg1, indices):
if (not isinstance(arg0, Operand)) or (not isinstance(arg1, Operand)):
Expand Down Expand Up @@ -664,14 +667,26 @@ def GammaCoord(self, A_tensorsig, B_tensorsig, C_tensorsig):
return G

def operate(self, out):
    """Perform operation.

    Contracts the broadcast data of the two arguments with
    ``self.einsum_str`` and stores the result in ``out.data``. The
    contraction path is computed on the first non-empty call and cached
    on ``self.einsum_path`` for reuse.
    """
    xp = self.array_namespace
    arg0, arg1 = self.args
    out.preset_layout(arg0.layout)
    # Broadcast
    arg0_data = self.arg0_ghost_broadcaster.cast(arg0)
    arg1_data = self.arg1_ghost_broadcaster.cast(arg1)
    # Nothing to contract when the local output is empty
    if not out.data.size:
        return
    on_cupy = array_api_compat.is_cupy_namespace(xp)
    # Build the contraction path once; later calls reuse the cached path
    if self.einsum_path is None:
        if on_cupy:
            # The path helper calls np.einsum_path, so hand it host copies
            self.einsum_path = self.get_einsum_path(xp.asnumpy(arg0_data), xp.asnumpy(arg1_data))
        else:
            self.einsum_path = self.get_einsum_path(arg0_data, arg1_data)
    # Call einsum
    if on_cupy:
        # Cupy does not support output keyword
        out.data[:] = xp.einsum(self.einsum_str, arg0_data, arg1_data, optimize=self.einsum_path)
    else:
        xp.einsum(self.einsum_str, arg0_data, arg1_data, out=out.data, optimize=self.einsum_path)

def get_einsum_path(self, arg0_data, arg1_data):
    """Return the "optimal" numpy contraction path for ``self.einsum_str``.

    Only the path list (first item of the ``np.einsum_path`` result) is
    returned; the human-readable report is discarded.
    """
    path_info = np.einsum_path(
        self.einsum_str,
        arg0_data,
        arg1_data,
        optimize="optimal",
    )
    return path_info[0]


@alias("cross")
Expand Down Expand Up @@ -854,6 +869,7 @@ def __init__(self, arg0, arg1, out=None, **kw):

def operate(self, out):
"""Perform operation."""
xp = self.array_namespace
arg0, arg1 = self.args
# Set output layout
out.preset_layout(arg0.layout)
Expand All @@ -863,7 +879,7 @@ def operate(self, out):
# Reshape arg data to broadcast properly for output tensorsig
arg0_exp_data = arg0_data.reshape(self.arg0_exp_tshape + arg0_data.shape[len(arg0.tensorsig):])
arg1_exp_data = arg1_data.reshape(self.arg1_exp_tshape + arg1_data.shape[len(arg1.tensorsig):])
np.multiply(arg0_exp_data, arg1_exp_data, out=out.data)
xp.multiply(arg0_exp_data, arg1_exp_data, out=out.data)


class GhostBroadcaster:
Expand Down Expand Up @@ -919,7 +935,7 @@ def __init__(self, arg0, arg1, out=None,**kw):
super().__init__(arg0, arg1, out=out)
self.domain = arg1.domain
self.tensorsig = arg1.tensorsig
self.dtype = np.result_type(type(arg0), arg1.dtype)
self.dtype = np.result_type(arg0, arg1.dtype)

@classmethod
def _check_args(cls, *args, **kw):
Expand All @@ -939,11 +955,12 @@ def enforce_conditions(self):

def operate(self, out):
    """Perform operation.

    Multiplies the data of the second argument by the first argument
    (``arg0`` is passed to the multiply call as-is, without ``.data``)
    and writes the product into ``out.data``, after giving ``out`` the
    layout of the second argument.
    """
    xp = self.array_namespace
    arg0, arg1 = self.args
    # Set output layout
    out.preset_layout(arg1.layout)
    # Multiply argument data through the configured namespace only
    xp.multiply(arg0, arg1.data, out=out.data)

def matrix_dependence(self, *vars):
return self.args[1].matrix_dependence(*vars)
Expand Down
73 changes: 45 additions & 28 deletions dedalus/core/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import reduce
import inspect
from math import prod
import array_api_compat

from . import operators
from ..libraries import spin_recombination
Expand All @@ -14,7 +15,7 @@
from ..tools import clenshaw
from ..tools.array import reshape_vector, axindex, axslice, interleave_matrices
from ..tools.dispatch import MultiClass, SkipDispatchException
from ..tools.general import unify, DeferredTuple
from ..tools.general import unify, DeferredTuple, is_real_dtype, is_complex_dtype
from .coords import Coordinate, CartesianCoordinates, S2Coordinates, SphericalCoordinates, PolarCoordinates, AzimuthalCoordinate, DirectProduct
from .domain import Domain
from .field import Operand, LockedField
Expand Down Expand Up @@ -595,8 +596,10 @@ class Jacobi(IntervalBasis, metaclass=CachedClass):
group_shape = (1,)
native_bounds = (-1, 1)
transforms = {}
default_dct = "fftw_dct"
default_library = "matrix"
default_cpu_library = "matrix"
default_gpu_library = "matrix"
default_cpu_dct = "fftw"
default_gpu_dct = "matrix"

@classmethod
def _preprocess_cache_args(cls, coord, size, bounds, a, b, a0, b0, dealias, library):
Expand Down Expand Up @@ -631,12 +634,6 @@ def _preprocess_cache_args(cls, coord, size, bounds, a, b, a0, b0, dealias, libr
dealias = tuple(dealias)
if len(dealias) != 1:
raise ValueError("Jacobi dealias must have length 1.")
# library: pick default based on (a0, b0)
if library is None:
if a0 == b0 == -1/2:
library = cls.default_dct
else:
library = cls.default_library
return (coord, size, bounds, a, b, a0, b0, dealias, library)

def __init__(self, coord, size, bounds, a, b, a0=None, b0=None, dealias=(1,), library=None):
Expand All @@ -660,10 +657,30 @@ def _native_grid(self, scale):
N, = self.grid_shape((scale,))
return jacobi.build_grid(N, a=self.a0, b=self.b0)

def get_library(self, dist):
    """Get library for transforms.

    An explicitly set ``self.library`` is returned unchanged. Otherwise
    the default is chosen from the class defaults: the DCT defaults when
    ``a0 == b0 == -1/2``, the general defaults otherwise, with the GPU or
    CPU variant selected by ``dist.is_cupy_namespace``.
    """
    # Explicit user choice takes precedence over any default
    if self.library is not None:
        return self.library
    if self.a0 == self.b0 == -1/2:
        gpu_default, cpu_default = self.default_gpu_dct, self.default_cpu_dct
    else:
        gpu_default, cpu_default = self.default_gpu_library, self.default_cpu_library
    return gpu_default if dist.is_cupy_namespace else cpu_default

@CachedMethod
def transform_plan(self, dist, grid_size):
    """Build transform plan.

    Looks up the plan class in ``self.transforms`` and instantiates it
    with the grid size, the basis size and Jacobi parameters, and the
    distributor's array namespace and dtype. The "matrix" entry is used
    whenever the grid size or the basis size is 1; otherwise the key
    comes from ``self.get_library(dist)``.
    """
    # Shortcut trivial transforms
    if grid_size == 1 or self.size == 1:
        library = "matrix"
    else:
        # self.library may be None here, so resolve the default per distributor
        library = self.get_library(dist)
    plan_class = self.transforms[library]
    return plan_class(grid_size, self.size, self.a, self.b, self.a0, self.b0, dist.array_namespace, dist.dtype)

# def weights(self, scales):
# """Gauss-Jacobi weights."""
Expand Down Expand Up @@ -975,7 +992,8 @@ class FourierBase(IntervalBasis):
"""Base class for RealFourier and ComplexFourier."""

native_bounds = (0, 2*np.pi)
default_library = "fftw"
default_gpu_library = "cupy"
default_cpu_library = "fftw"

@classmethod
def _preprocess_cache_args(cls, coord, size, bounds, dealias, library):
Expand All @@ -998,9 +1016,6 @@ def _preprocess_cache_args(cls, coord, size, bounds, dealias, library):
dealias = tuple(dealias)
if len(dealias) != 1:
raise ValueError("Fourier dealias must have length 1.")
# library: pick default based on (a0, b0)
if library is None:
library = cls.default_library
return (coord, size, bounds, dealias, library)

def __init__(self, coord, size, bounds, dealias=(1,), library=None):
Expand Down Expand Up @@ -1069,14 +1084,24 @@ def _native_grid(self, scale):
N, = self.grid_shape((scale,))
return (2 * np.pi / N) * np.arange(N)

def get_library(self, dist):
    """Get library for transforms.

    Returns ``self.library`` when one was set; otherwise returns the
    class default for GPU or CPU depending on ``dist.is_cupy_namespace``.
    """
    # Explicit user choice takes precedence over any default
    if self.library is not None:
        return self.library
    if dist.is_cupy_namespace:
        return self.default_gpu_library
    return self.default_cpu_library

@CachedMethod
def transform_plan(self, dist, grid_size):
    """Build transform plan.

    Looks up the plan class in ``self.transforms`` and instantiates it
    with the grid size, the basis size, and the distributor's array
    namespace and dtype. The "matrix" entry is used whenever the grid
    size or the basis size is 1; otherwise the key comes from
    ``self.get_library(dist)``.
    """
    # Shortcut trivial transforms
    if grid_size == 1 or self.size == 1:
        library = "matrix"
    else:
        # self.library may be None here, so resolve the default per distributor
        library = self.get_library(dist)
    plan_class = self.transforms[library]
    return plan_class(grid_size, self.size, dist.array_namespace, dist.dtype)

def forward_transform(self, field, axis, gdata, cdata):
# Transform
Expand All @@ -1097,9 +1122,9 @@ def Fourier(*args, dtype=None, **kw):
"""Factory function dispatching to RealFourier and ComplexFourier based on provided dtype."""
if dtype is None:
raise ValueError("dtype must be specified")
elif dtype == np.float64:
elif is_real_dtype(dtype):
return RealFourier(*args, **kw)
elif dtype == np.complex128:
elif is_complex_dtype(dtype):
return ComplexFourier(*args, **kw)
else:
raise ValueError(f"Unrecognized dtype: {dtype}")
Expand Down Expand Up @@ -2204,15 +2229,6 @@ def _preprocess_cache_args(cls, coordsys, shape, dtype, radii, k, alpha, dealias
dealias = tuple(dealias)
if len(dealias) != 2:
raise ValueError("Annulus dealias must have length 2.")
# azimuth_library: pick default
if azimuth_library is None:
azimuth_library = RealFourier.default_library
# radius_library: pick default based on alpha
if radius_library is None:
if alpha[0] == alpha[1] == -1/2:
radius_library = Jacobi.default_dct
else:
radius_library = Jacobi.default_library
return (coordsys, shape, dtype, radii, k, alpha, dealias, azimuth_library, radius_library)

def __init__(self, coordsys, shape, dtype, radii=(1,2), k=0, alpha=(-0.5,-0.5), dealias=(1,1), azimuth_library=None, radius_library=None):
Expand Down Expand Up @@ -6238,6 +6254,7 @@ class CartesianAdvectiveCFL(operators.AdvectiveCFL):

@CachedMethod
def cfl_spacing(self):
xp = self.array_namespace
velocity = self.operand
coordsys = velocity.tensorsig[0]
spacing = []
Expand All @@ -6260,7 +6277,7 @@ def cfl_spacing(self):
axis_spacing[:] = dealias * native_spacing * basis.COV.stretch
elif basis is None:
axis_spacing = np.inf
spacing.append(axis_spacing)
spacing.append(xp.asarray(axis_spacing))
return spacing

def compute_cfl_frequency(self, velocity, out):
Expand Down
31 changes: 25 additions & 6 deletions dedalus/core/distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from math import prod
import numbers
from weakref import WeakSet
import array_api_compat
import warnings

from .coords import CoordinateSystem, DirectProduct
from ..tools.array import reshape_vector
Expand Down Expand Up @@ -39,12 +41,16 @@ class Distributor:

Parameters
----------
dim : int
Dimension
coordsystems : CoordinateSystem or tuple of CoordinateSystems
Problem coordinate systems
comm : MPI communicator, optional
MPI communicator (default: comm world)
mesh : tuple of ints, optional
Process mesh for parallelization (default: 1-D mesh of available processes)
dtype : data type, optional
Default data type for fields (default: None)
array_namespace : array namespace or string, optional
Array namespace for field data (e.g. numpy or cupy, default: numpy)

Attributes
----------
Expand Down Expand Up @@ -74,7 +80,7 @@ class Distributor:
states) and the paths between them (D transforms and R transposes).
"""

def __init__(self, coordsystems, comm=None, mesh=None, dtype=None):
def __init__(self, coordsystems, comm=None, mesh=None, dtype=None, array_namespace=np):
# Accept single coordsys in place of tuple/list
if not isinstance(coordsystems, (tuple, list)):
coordsystems = (coordsystems,)
Expand Down Expand Up @@ -115,6 +121,16 @@ def __init__(self, coordsystems, comm=None, mesh=None, dtype=None):
self._build_layouts()
# Keep set of weak field references
self.fields = WeakSet()
# Array module
if isinstance(array_namespace, str):
self.array_namespace = getattr(array_api_compat, array_namespace)
else:
self.array_namespace = array_api_compat.array_namespace(array_namespace.zeros(0))
self.is_numpy_namespace = array_api_compat.is_numpy_namespace(self.array_namespace)
self.is_cupy_namespace = array_api_compat.is_cupy_namespace(self.array_namespace)
# Warnings for non-Cartesian problems
if self.is_cupy_namespace and any(cs.curvilinear for cs in self.coordsystems):
warnings.warn("Non-Cartesian coordinate systems not yet supported on GPU.")

@CachedAttribute
def cs_by_axis(self):
Expand Down Expand Up @@ -255,11 +271,12 @@ def IdentityTensor(self, coordsys_in, coordsys_out=None, bases=None, dtype=None)
return I

def local_grid(self, basis, scale=None):
    """Return the local grid of a 1D basis as an array in this distributor's namespace.

    Parameters
    ----------
    basis : basis object
        Must have ``dim == 1``.
    scale : number, optional
        Scale passed to ``basis.local_grid`` (default: 1).

    Raises
    ------
    ValueError
        If the basis is multidimensional.
    """
    # TODO: remove from bases and do it all here?
    if basis.dim != 1:
        raise ValueError("Use `local_grids` for multidimensional bases.")
    if scale is None:
        scale = 1
    xp = self.array_namespace
    # Convert so callers always receive an array from the configured namespace
    return xp.asarray(basis.local_grid(self, scale=scale))

Expand Down Expand Up @@ -292,16 +309,18 @@ def local_grid(self, basis, scale=None):
# return tuple(grids)

def local_grids(self, *bases, scales=None):
    """Return the local grids of the given bases as a flat list of namespace arrays.

    For each basis, the scales for the axes it spans
    (``first_axis`` through ``last_axis``) are sliced out of the remedied
    scales and passed to ``basis.local_grids``. Each returned grid is
    converted on its own with ``asarray`` from this distributor's array
    namespace.
    """
    xp = self.array_namespace
    scales = self.remedy_scales(scales)
    grids = []
    for basis in bases:
        basis_scales = scales[self.first_axis(basis):self.last_axis(basis)+1]
        # Convert grid by grid: the per-axis grids of one basis can have
        # different shapes, so converting the whole tuple at once would
        # try to stack them into a single array.
        grids.extend(xp.asarray(grid) for grid in basis.local_grids(self, scales=basis_scales))
    return grids

def local_modes(self, basis):
    """Return ``basis.local_modes(self)`` as an array in this distributor's namespace."""
    # TODO: remove from bases and do it all here?
    xp = self.array_namespace
    return xp.asarray(basis.local_modes(self))

@CachedAttribute
def default_nonconst_groups(self):
Expand Down
Loading