From 1356f9f06ea45e6124cb1669c30ab29bed812fd4 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 27 May 2025 14:47:38 +0100 Subject: [PATCH 01/50] Add array namespace option for field buffers --- dedalus/core/distributor.py | 6 +++++- dedalus/core/field.py | 23 ++++++++++++++--------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/dedalus/core/distributor.py b/dedalus/core/distributor.py index c4cc766f..6b2aedbd 100644 --- a/dedalus/core/distributor.py +++ b/dedalus/core/distributor.py @@ -10,6 +10,7 @@ from math import prod import numbers from weakref import WeakSet +import array_api_compat from .coords import CoordinateSystem, DirectProduct from ..tools.array import reshape_vector @@ -74,7 +75,7 @@ class Distributor: states) and the paths between them (D transforms and R transposes). """ - def __init__(self, coordsystems, comm=None, mesh=None, dtype=None): + def __init__(self, coordsystems, comm=None, mesh=None, dtype=None, array_namespace=np): # Accept single coordsys in place of tuple/list if not isinstance(coordsystems, (tuple, list)): coordsystems = (coordsystems,) @@ -115,6 +116,9 @@ def __init__(self, coordsystems, comm=None, mesh=None, dtype=None): self._build_layouts() # Keep set of weak field references self.fields = WeakSet() + # Array module + x = array_namespace.zeros(0) + self.array_namespace = array_api_compat.array_namespace(x) @CachedAttribute def cs_by_axis(self): diff --git a/dedalus/core/field.py b/dedalus/core/field.py index 44bd1a94..5ff5d8c5 100644 --- a/dedalus/core/field.py +++ b/dedalus/core/field.py @@ -483,16 +483,19 @@ def evaluate(self): def reinitialize(self, **kw): return self - @staticmethod - def _create_buffer(buffer_size): + def _create_buffer(self, buffer_size): """Create buffer for Field data.""" - if buffer_size == 0: - # FFTW doesn't like allocating size-0 arrays - return np.zeros((0,), dtype=np.float64) + xp = self.array_namespace + if xp == np: + if buffer_size == 0: + # FFTW doesn't like allocating size-0 arrays + return np.zeros((0,), dtype=np.float64) + else: + # Use FFTW SIMD aligned allocation + alloc_doubles = buffer_size // 8 + return fftw.create_buffer(alloc_doubles) else: - # Use FFTW SIMD aligned allocation - alloc_doubles = buffer_size // 8 - return fftw.create_buffer(alloc_doubles) + return xp.zeros(buffer_size) @CachedAttribute def _dealias_buffer_size(self): @@ -526,12 +529,13 @@ def preset_scales(self, scales): def preset_layout(self, layout): """Interpret buffer as data in specified layout.""" + xp = self.array_namespace layout = self.dist.get_layout_object(layout) self.layout = layout tens_shape = [vs.dim for vs in self.tensorsig] local_shape = layout.local_shape(self.domain, self.scales) total_shape = tuple(tens_shape) + tuple(local_shape) - self.data = np.ndarray(shape=total_shape, + self.data = xp.ndarray(shape=total_shape, dtype=self.dtype, buffer=self.buffer) #self.global_start = layout.start(self.domain, self.scales) @@ -571,6 +575,7 @@ def __init__(self, dist, bases=None, name=None, tensorsig=None, dtype=None): dtype = dist.dtype from .domain import Domain self.dist = dist + self.array_namespace = dist.array_namespace self.name = name self.tensorsig = tensorsig self.dtype = dtype From f7fb6c9d34e666171b494eae301bf8ebc0adc57b Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 27 May 2025 14:49:15 +0100 Subject: [PATCH 02/50] Add array-api-compat to setup.py --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 7291ba0e..36641924 100644 --- a/setup.py +++ b/setup.py @@ -181,6 +181,7 @@ def read(rel_path): # Runtime requirements install_requires = [ + "array-api-compat", "docopt", "h5py >= 3.0.0", "matplotlib >= 3.7.0", From 63648b82b1e80a56750ef36b0c41beedba6a7361 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 27 May 2025 15:15:45 +0100 Subject: [PATCH 03/50] Allow specifying array namespace by string --- dedalus/core/distributor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dedalus/core/distributor.py b/dedalus/core/distributor.py index 6b2aedbd..9b2a3ff7 100644 --- a/dedalus/core/distributor.py +++ b/dedalus/core/distributor.py @@ -117,8 +117,10 @@ def __init__(self, coordsystems, comm=None, mesh=None, dtype=None, array_namespa # Keep set of weak field references self.fields = WeakSet() # Array module - x = array_namespace.zeros(0) - self.array_namespace = array_api_compat.array_namespace(x) + if isinstance(array_namespace, str): + self.array_namespace = getattr(array_api_compat, array_namespace) + else: + self.array_namespace = array_api_compat.array_namespace(array_namespace.zeros(0)) @CachedAttribute def cs_by_axis(self): From ad50b2031dabfb40c912ddb251f52051f3904a9b Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 27 May 2025 15:16:18 +0100 Subject: [PATCH 04/50] Try fixing cupy allocation from buffer --- dedalus/core/field.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dedalus/core/field.py b/dedalus/core/field.py index 5ff5d8c5..8f9b26c3 100644 --- a/dedalus/core/field.py +++ b/dedalus/core/field.py @@ -535,9 +535,11 @@ def preset_layout(self, layout): tens_shape = [vs.dim for vs in self.tensorsig] local_shape = layout.local_shape(self.domain, self.scales) total_shape = tuple(tens_shape) + tuple(local_shape) - self.data = xp.ndarray(shape=total_shape, - dtype=self.dtype, - buffer=self.buffer) + # Handle cupy allocation + if xp.__name__ == "cupy": + self.data = xp.ndarray(shape=total_shape, dtype=self.dtype, memptr=self.buffer.data) + else: + self.data = xp.ndarray(shape=total_shape, dtype=self.dtype, buffer=self.buffer) #self.global_start = layout.start(self.domain, self.scales) From 3f8692da29e6f4d57d4150523675234390c6d275 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 27 May 2025 15:31:50 +0100 Subject: [PATCH 05/50] Fix cupy check --- dedalus/core/field.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dedalus/core/field.py b/dedalus/core/field.py index 8f9b26c3..f30fb556 100644 --- a/dedalus/core/field.py +++ b/dedalus/core/field.py @@ -7,6 +7,7 @@ from functools import partial, reduce from collections import defaultdict import numpy as np +import array_api_compat from mpi4py import MPI from scipy import sparse from scipy.sparse import linalg as splinalg @@ -535,8 +536,8 @@ def preset_layout(self, layout): tens_shape = [vs.dim for vs in self.tensorsig] local_shape = layout.local_shape(self.domain, self.scales) total_shape = tuple(tens_shape) + tuple(local_shape) - # Handle cupy allocation - if xp.__name__ == "cupy": + # Create view into buffer + if array_api_compat.is_cupy_namespace(xp): self.data = xp.ndarray(shape=total_shape, dtype=self.dtype, memptr=self.buffer.data) else: self.data = xp.ndarray(shape=total_shape, dtype=self.dtype, buffer=self.buffer) From 8b4295725fbc79cac189e425e70e5367b38dfe94 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 27 May 2025 15:48:34 +0100 Subject: [PATCH 06/50] Add cupy-based complex fourier MMT --- dedalus/core/basis.py | 6 ++++-- dedalus/core/transforms.py | 36 +++++++++++++++++++++--------------- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index b25e5061..d2f41c54 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -1072,11 +1072,13 @@ def _native_grid(self, scale): @CachedMethod def transform_plan(self, dist, grid_size): """Build transform plan.""" + xp = dist.array_namespace + xp_name = xp.__name__.split('.')[-1] # Shortcut trivial transforms if grid_size == 1 or self.size == 1: - return self.transforms['matrix'](grid_size, self.size) + return self.transforms[f"matrix-{xp_name}"](grid_size, self.size) else: - return self.transforms[self.library](grid_size, self.size) + return self.transforms[f"{self.library}-{xp_name}"](grid_size, self.size) def forward_transform(self, field, axis, gdata, cdata): # Transform diff --git a/dedalus/core/transforms.py b/dedalus/core/transforms.py index 00758fb2..db6f2595 100644 --- a/dedalus/core/transforms.py +++ b/dedalus/core/transforms.py @@ -191,50 +191,56 @@ class ComplexFourierTransform(SeparableTransform): If M is even, the ordering is [0, 1, 2, ..., KM, -KM, -KM+1, ..., -1]. """ - def __init__(self, grid_size, coeff_size): + def __init__(self, grid_size, coeff_size, array_namespace, dtype): self.N = grid_size self.M = coeff_size self.KN = (self.N - 1) // 2 self.KM = (self.M - 1) // 2 self.Kmax = min(self.KN, self.KM) + self.array_namespace = array_namespace + self.dtype = dtype @property def wavenumbers(self): """One-dimensional global wavenumber array.""" + xp = self.array_namespace M = self.M KM = self.KM - k = np.arange(M) + k = xp.arange(M) # Wrap around Nyquist mode return (k + KM) % M - KM -@register_transform(basis.ComplexFourier, 'matrix') +@register_transform(basis.ComplexFourier, 'matrix-numpy') +@register_transform(basis.ComplexFourier, 'matrix-cupy') class ComplexFourierMMT(ComplexFourierTransform, SeparableMatrixTransform): """Complex-to-complex Fourier MMT.""" @CachedAttribute def forward_matrix(self): """Build forward transform matrix.""" + xp = self.array_namespace K = self.wavenumbers[:, None] - X = np.arange(self.N)[None, :] + X = xp.arange(self.N)[None, :] dX = self.N / 2 / np.pi - quadrature = np.exp(-1j*K*X/dX) / self.N + quadrature = xp.exp(-1j*K*X/dX) / self.N # Zero Nyquist and higher modes for transforms with grid_size <= coeff_size - quadrature *= np.abs(K) <= self.Kmax - # Ensure C ordering for fast dot products - return np.asarray(quadrature, order='C') + quadrature *= xp.abs(K) <= self.Kmax + # Ensure C ordering for fast dot products, cast to specified dtype + return xp.asarray(quadrature, order='C', dtype=self.dtype) @CachedAttribute def backward_matrix(self): """Build backward transform matrix.""" + xp = self.array_namespace K = self.wavenumbers[None, :] - X = np.arange(self.N)[:, None] + X = xp.arange(self.N)[:, None] dX = self.N / 2 / np.pi - functions = np.exp(1j*K*X/dX) + functions = xp.exp(1j*K*X/dX) # Zero Nyquist and higher modes for transforms with grid_size <= coeff_size - functions *= np.abs(K) <= self.Kmax - # Ensure C ordering for fast dot products - return np.asarray(functions, order='C') + functions *= xp.abs(K) <= self.Kmax + # Ensure C ordering for fast dot products, cast to specified dtype + return xp.asarray(functions, order='C', dtype=self.dtype) class ComplexFFT(ComplexFourierTransform): @@ -267,7 +273,7 @@ def resize_coeffs(self, data_in, data_out, axis, rescale): np.multiply(data_in[negfreq], rescale, data_out[negfreq]) -@register_transform(basis.ComplexFourier, 'scipy') +@register_transform(basis.ComplexFourier, 'scipy-numpy') class ScipyComplexFFT(ComplexFFT): """Complex-to-complex FFT using scipy.fft.""" @@ -299,7 +305,7 @@ def __init__(self, *args, rigor=None, **kw): super().__init__(*args, **kw) -@register_transform(basis.ComplexFourier, 'fftw') +@register_transform(basis.ComplexFourier, 'fftw-numpy') class FFTWComplexFFT(FFTWBase, ComplexFFT): """Complex-to-complex FFT using FFTW.""" From bebf8dd7ef6337c19f75a41443406d0fef650823 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 27 May 2025 15:52:28 +0100 Subject: [PATCH 07/50] Fix transform lookup --- dedalus/core/basis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index d2f41c54..b89b7b91 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -1076,9 +1076,9 @@ def transform_plan(self, dist, grid_size): xp_name = xp.__name__.split('.')[-1] # Shortcut trivial transforms if grid_size == 1 or self.size == 1: - return self.transforms[f"matrix-{xp_name}"](grid_size, self.size) + return self.transforms[f"matrix-{xp_name}"](grid_size, self.size, dist.array_namespace, dist.dtype) else: - return self.transforms[f"{self.library}-{xp_name}"](grid_size, self.size) + return self.transforms[f"{self.library}-{xp_name}"](grid_size, self.size, dist.array_namespace, dist.dtype) def forward_transform(self, field, axis, gdata, cdata): # Transform From f04bdf6ef6e7e34a8de9e7cb5f9b7fd1b522ae29 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 27 May 2025 16:16:32 +0100 Subject: [PATCH 08/50] Make fill_random array and dtype compatible --- dedalus/core/field.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dedalus/core/field.py b/dedalus/core/field.py index f30fb556..0e4fadd0 100644 --- a/dedalus/core/field.py +++ b/dedalus/core/field.py @@ -935,6 +935,7 @@ def fill_random(self, layout=None, scales=None, seed=None, chunk_size=2**20, dis **kw : dict Other keywords passed to the distribution method. """ + xp = self.dist.array_namespace init_layout = self.layout # Set scales if requested if scales is not None: @@ -954,11 +955,10 @@ def fill_random(self, layout=None, scales=None, seed=None, chunk_size=2**20, dis spatial_slices = self.layout.slices(self.domain, self.scales) local_slices = component_slices + spatial_slices local_data = global_data[local_slices] - if self.is_real: - self.data[:] = local_data - else: - self.data.real[:] = local_data[..., 0] - self.data.imag[:] = local_data[..., 1] + if self.is_complex: + local_data = local_data[..., 0] + 1j * local_data[..., 1] + # Copy to field data + self.data[:] = xp.asarray(local_data, dtype=self.dtype) def low_pass_filter(self, shape=None, scales=None): """ From 64b91c4251ebfa415b51ea618b681e90fc289bc6 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 27 May 2025 16:16:55 +0100 Subject: [PATCH 09/50] Work on cupy real fourier MMTs --- dedalus/core/transforms.py | 40 ++++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/dedalus/core/transforms.py b/dedalus/core/transforms.py index db6f2595..066b4bb4 100644 --- a/dedalus/core/transforms.py +++ b/dedalus/core/transforms.py @@ -374,7 +374,7 @@ class RealFourierTransform(SeparableTransform): where the k = 0 minus-sine mode is zeroed in both directions. """ - def __init__(self, grid_size, coeff_size): + def __init__(self, grid_size, coeff_size, array_namespace, dtype): if coeff_size % 2 != 0: pass#raise ValueError("coeff_size must be even.") self.N = grid_size @@ -382,55 +382,61 @@ def __init__(self, grid_size, coeff_size): self.KN = (self.N - 1) // 2 self.KM = (self.M - 1) // 2 self.Kmax = min(self.KN, self.KM) + self.array_namespace = array_namespace + self.dtype = dtype @property def wavenumbers(self): """One-dimensional global wavenumber array.""" + xp = self.array_namespace # Repeat k's for cos and msin parts - return np.repeat(np.arange(self.KM+1), 2) + return xp.repeat(xp.arange(self.KM+1), 2) -@register_transform(basis.RealFourier, 'matrix') +@register_transform(basis.RealFourier, 'matrix-numpy') +@register_transform(basis.RealFourier, 'matrix-cupy') class RealFourierMMT(RealFourierTransform, SeparableMatrixTransform): """Real-to-real Fourier MMT.""" @CachedAttribute def forward_matrix(self): """Build forward transform matrix.""" + xp = self.array_namespace N = self.N M = max(2, self.M) # Account for sin and cos parts of m=0 Kmax = self.Kmax K = self.wavenumbers[::2, None] - X = np.arange(N)[None, :] + X = xp.arange(N)[None, :] dX = N / 2 / np.pi - quadrature = np.zeros((M, N)) - quadrature[0::2] = (2 / N) * np.cos(K*X/dX) - quadrature[1::2] = -(2 / N) * np.sin(K*X/dX) + quadrature = xp.zeros((M, N)) + quadrature[0::2] = (2 / N) * xp.cos(K*X/dX) + quadrature[1::2] = -(2 / N) * xp.sin(K*X/dX) quadrature[0] = 1 / N # Zero Nyquist and higher modes for transforms with grid_size <= coeff_size quadrature *= self.wavenumbers[:,None] <= self.Kmax # Ensure C ordering for fast dot products - return np.asarray(quadrature, order='C') + return xp.asarray(quadrature, order='C', dtype=self.dtype) @CachedAttribute def backward_matrix(self): """Build backward transform matrix.""" + xp = self.array_namespace N = self.N M = max(2, self.M) # Account for sin and cos parts of m=0 Kmax = self.Kmax K = self.wavenumbers[None, ::2] - X = np.arange(N)[:, None] + X = xp.arange(N)[:, None] dX = N / 2 / np.pi - functions = np.zeros((N, M)) - functions[:, 0::2] = np.cos(K*X/dX) - functions[:, 1::2] = -np.sin(K*X/dX) + functions = xp.zeros((N, M)) + functions[:, 0::2] = xp.cos(K*X/dX) + functions[:, 1::2] = -xp.sin(K*X/dX) # Zero Nyquist and higher modes for transforms with grid_size <= coeff_size functions *= self.wavenumbers[None, :] <= self.Kmax # Ensure C ordering for fast dot products - return np.asarray(functions, order='C') + return xp.asarray(functions, order='C', dtype=self.dtype) -@register_transform(basis.RealFourier, 'fftpack') +@register_transform(basis.RealFourier, 'fftpack-numpy') class FFTPACKRealFFT(RealFourierTransform): """Real-to-real FFT using scipy.fftpack.""" @@ -515,7 +521,7 @@ def repack_rescale(self, cdata, temp, axis, rescale): temp[axslice(axis, Kmax+1, None)] = 0 -@register_transform(basis.RealFourier, 'scipy') +@register_transform(basis.RealFourier, 'scipy-numpy') class ScipyRealFFT(RealFFT): """Real-to-real FFT using scipy.fft.""" @@ -540,7 +546,7 @@ def backward(self, cdata, gdata, axis): np.copyto(gdata, temp) -@register_transform(basis.RealFourier, 'fftw') +@register_transform(basis.RealFourier, 'fftw-numpy') class FFTWRealFFT(FFTWBase, RealFFT): """Real-to-real FFT using FFTW.""" @@ -571,7 +577,7 @@ def backward(self, cdata, gdata, axis): plan.backward(temp, gdata) -@register_transform(basis.RealFourier, 'fftw_hc') +@register_transform(basis.RealFourier, 'fftw_hc-numpy') class FFTWHalfComplexFFT(FFTWBase, RealFourierTransform): """Real-to-real FFT using FFTW half-complex DFT.""" From 295158ad41449ec2b5200d8eff4d36ea8f99811c Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 27 May 2025 16:26:02 +0100 Subject: [PATCH 10/50] Generalize Fourier basis for more dtypes --- dedalus/core/basis.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index b89b7b91..f7f63db6 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -14,7 +14,7 @@ from ..tools import clenshaw from ..tools.array import reshape_vector, axindex, axslice, interleave_matrices from ..tools.dispatch import MultiClass, SkipDispatchException -from ..tools.general import unify, DeferredTuple +from ..tools.general import unify, DeferredTuple, is_real_dtype, is_complex_dtype from .coords import Coordinate, CartesianCoordinates, S2Coordinates, SphericalCoordinates, PolarCoordinates, AzimuthalCoordinate, DirectProduct from .domain import Domain from .field import Operand, LockedField @@ -1099,9 +1099,9 @@ def Fourier(*args, dtype=None, **kw): """Factory function dispatching to RealFourier and ComplexFourier based on provided dtype.""" if dtype is None: raise ValueError("dtype must be specified") - elif dtype == np.float64: + elif is_real_dtype(dtype): return RealFourier(*args, **kw) - elif dtype == np.complex128: + elif is_complex_dtype(dtype): return ComplexFourier(*args, **kw) else: raise ValueError(f"Unrecognized dtype: {dtype}") From b59d1ec6abda61934fd67cca75354c4947f27ad3 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Wed, 28 May 2025 11:13:42 +0100 Subject: [PATCH 11/50] Add cupy complex FFT --- dedalus/core/transforms.py | 40 ++++++++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/dedalus/core/transforms.py b/dedalus/core/transforms.py index 066b4bb4..4e2aa839 100644 --- a/dedalus/core/transforms.py +++ b/dedalus/core/transforms.py @@ -248,29 +248,30 @@ class ComplexFFT(ComplexFourierTransform): def resize_coeffs(self, data_in, data_out, axis, rescale): """Resize and rescale coefficients in standard FFT format by intermediate padding/truncation.""" + xp = self.array_namespace M = self.M Kmax = self.Kmax if Kmax == 0: posfreq = axslice(axis, 0, 1) badfreq = axslice(axis, 1, None) if rescale is None: - np.copyto(data_out[posfreq], data_in[posfreq]) + xp.copyto(data_out[posfreq], data_in[posfreq]) data_out[badfreq] = 0 else: - np.multiply(data_in[posfreq], rescale, data_out[posfreq]) + xp.multiply(data_in[posfreq], rescale, data_out[posfreq]) data_out[badfreq] = 0 else: posfreq = axslice(axis, 0, Kmax+1) badfreq = axslice(axis, Kmax+1, -Kmax) negfreq = axslice(axis, -Kmax, None) if rescale is None: - np.copyto(data_out[posfreq], data_in[posfreq]) + xp.copyto(data_out[posfreq], data_in[posfreq]) data_out[badfreq] = 0 - np.copyto(data_out[negfreq], data_in[negfreq]) + xp.copyto(data_out[negfreq], data_in[negfreq]) else: - np.multiply(data_in[posfreq], rescale, data_out[posfreq]) + xp.multiply(data_in[posfreq], rescale, data_out[posfreq]) data_out[badfreq] = 0 - np.multiply(data_in[negfreq], rescale, data_out[negfreq]) + xp.multiply(data_in[negfreq], rescale, data_out[negfreq]) @register_transform(basis.ComplexFourier, 'scipy-numpy') @@ -295,6 +296,33 @@ def backward(self, cdata, gdata, axis): np.copyto(gdata, temp) +@register_transform(basis.ComplexFourier, 'scipy-cupy') +class CupyComplexFFT(ComplexFFT): + """Complex-to-complex FFT using scipy.fft.""" + + def __init__(self, *args, **kw): + import cupyx.scipy.fft as cufft + self.cufft = cufft + super().__init__(*args, **kw) + + def forward(self, gdata, cdata, axis): + """Apply forward transform along specified axis.""" + # Call FFT + temp = self.cufft.fft(gdata, axis=axis) # Creates temporary + # Resize and rescale for unit-amplitude normalization + self.resize_coeffs(temp, cdata, axis, rescale=1/self.N) + + def backward(self, cdata, gdata, axis): + """Apply backward transform along specified axis.""" + # Resize and rescale for unit-amplitude normalization + # Need temporary to avoid overwriting problems + temp = np.empty_like(gdata) # Creates temporary + self.resize_coeffs(cdata, temp, axis, rescale=self.N) + # Call FFT + temp = self.cufft.ifft(temp, axis=axis, overwrite_x=True) # Creates temporary + np.copyto(gdata, temp) + + class FFTWBase: """Abstract base class for FFTW transforms.""" From 00bd2664d3cb348ab765ac8c8808f6a92e31e0b0 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Wed, 28 May 2025 14:35:10 +0100 Subject: [PATCH 12/50] Add cupy real fft --- dedalus/core/transforms.py | 64 +++++++++++++++++++++++++++++++------- dedalus/tools/general.py | 12 +++++++ 2 files changed, 64 insertions(+), 12 deletions(-) diff --git a/dedalus/core/transforms.py b/dedalus/core/transforms.py index 4e2aa839..5616c1c7 100644 --- a/dedalus/core/transforms.py +++ b/dedalus/core/transforms.py @@ -15,6 +15,7 @@ from ..tools.array import apply_matrix, apply_dense, axslice, solve_upper_sparse, apply_sparse from ..tools.cache import CachedAttribute from ..tools.cache import CachedMethod +from ..tools.general import float_to_complex import logging logger = logging.getLogger(__name__.split('.')[-1]) @@ -314,13 +315,14 @@ def forward(self, gdata, cdata, axis): def backward(self, cdata, gdata, axis): """Apply backward transform along specified axis.""" + xp = self.array_namespace # Resize and rescale for unit-amplitude normalization # Need temporary to avoid overwriting problems - temp = np.empty_like(gdata) # Creates temporary + temp = xp.empty_like(gdata) # Creates temporary self.resize_coeffs(cdata, temp, axis, rescale=self.N) # Call FFT temp = self.cufft.ifft(temp, axis=axis, overwrite_x=True) # Creates temporary - np.copyto(gdata, temp) + xp.copyto(gdata, temp) class FFTWBase: @@ -511,40 +513,42 @@ class RealFFT(RealFourierTransform): def unpack_rescale(self, temp, cdata, axis, rescale): """Unpack complex coefficients and rescale for unit-amplitude normalization.""" + xp = self.array_namespace Kmax = self.Kmax # Scale k = 0 cos data meancos = axslice(axis, 0, 1) - np.multiply(temp[meancos].real, rescale, cdata[meancos]) + xp.multiply(temp[meancos].real, rescale, cdata[meancos]) # Zero k = 0 msin data cdata[axslice(axis, 1, 2)] = 0 # Unpack and scale 1 < k <= Kmax data temp_posfreq = temp[axslice(axis, 1, Kmax+1)] cdata_posfreq_cos = cdata[axslice(axis, 2, 2*(Kmax+1), 2)] cdata_posfreq_msin = cdata[axslice(axis, 3, 2*(Kmax+1), 2)] - np.multiply(temp_posfreq.real, 2*rescale, cdata_posfreq_cos) - np.multiply(temp_posfreq.imag, 2*rescale, cdata_posfreq_msin) + xp.multiply(temp_posfreq.real, 2*rescale, cdata_posfreq_cos) + xp.multiply(temp_posfreq.imag, 2*rescale, cdata_posfreq_msin) # Zero k > Kmax data cdata[axslice(axis, 2*(Kmax+1), None)] = 0 def repack_rescale(self, cdata, temp, axis, rescale): """Repack into complex coefficients and rescale for unit-amplitude normalization.""" + xp = self.array_namespace Kmax = self.Kmax # Scale k = 0 data meancos = axslice(axis, 0, 1) if rescale is None: - np.copyto(temp[meancos], cdata[meancos]) + xp.copyto(temp[meancos], cdata[meancos]) else: - np.multiply(cdata[meancos], rescale, temp[meancos]) + xp.multiply(cdata[meancos], rescale, temp[meancos]) # Repack and scale 1 < k <= Kmax data temp_posfreq = temp[axslice(axis, 1, Kmax+1)] cdata_posfreq_cos = cdata[axslice(axis, 2, 2*(Kmax+1), 2)] cdata_posfreq_msin = cdata[axslice(axis, 3, 2*(Kmax+1), 2)] if rescale is None: - np.multiply(cdata_posfreq_cos, (1 / 2), temp_posfreq.real) - np.multiply(cdata_posfreq_msin, (1 / 2), temp_posfreq.imag) + xp.multiply(cdata_posfreq_cos, (1 / 2), temp_posfreq.real) + xp.multiply(cdata_posfreq_msin, (1 / 2), temp_posfreq.imag) else: - np.multiply(cdata_posfreq_cos, (rescale / 2), temp_posfreq.real) - np.multiply(cdata_posfreq_msin, (rescale / 2), temp_posfreq.imag) + xp.multiply(cdata_posfreq_cos, (rescale / 2), temp_posfreq.real) + xp.multiply(cdata_posfreq_msin, (rescale / 2), temp_posfreq.imag) # Zero k > Kmax data temp[axslice(axis, Kmax+1, None)] = 0 @@ -553,6 +557,10 @@ def repack_rescale(self, cdata, temp, axis, rescale): class ScipyRealFFT(RealFFT): """Real-to-real FFT using scipy.fft.""" + def __init__(self, *args, **kw): + super().__init__(*args, **kw) + self.complex_dtype = float_to_complex(self.dtype) + def forward(self, gdata, cdata, axis): """Apply forward transform along specified axis.""" # Call RFFT @@ -566,7 +574,7 @@ def backward(self, cdata, gdata, axis): # Rescale all modes and combine into complex form shape = list(gdata.shape) shape[axis] = N // 2 + 1 - temp = np.empty(shape=shape, dtype=np.complex128) # Creates temporary + temp = np.empty(shape=shape, dtype=self.complex_dtype) # Creates temporary # Repack into complex form and rescale self.repack_rescale(cdata, temp, axis, rescale=N) # Call IRFFT @@ -574,6 +582,38 @@ def backward(self, cdata, gdata, axis): np.copyto(gdata, temp) +@register_transform(basis.RealFourier, 'scipy-cupy') +class CupyRealFFT(RealFFT): + """Real-to-real FFT using scipy.fft.""" + + def __init__(self, *args, **kw): + import cupyx.scipy.fft as cufft + self.cufft = cufft + super().__init__(*args, **kw) + self.complex_dtype = float_to_complex(self.dtype) + + def forward(self, gdata, cdata, axis): + """Apply forward transform along specified axis.""" + # Call RFFT + temp = self.cufft.rfft(gdata, axis=axis) # Creates temporary + # Unpack from complex form and rescale + self.unpack_rescale(temp, cdata, axis, rescale=1/self.N) + + def backward(self, cdata, gdata, axis): + """Apply backward transform along specified axis.""" + xp = self.array_namespace + N = self.N + # Rescale all modes and combine into complex form + shape = list(gdata.shape) + shape[axis] = N // 2 + 1 + temp = xp.empty(shape=shape, dtype=self.complex_dtype) # Creates temporary + # Repack into complex form and rescale + self.repack_rescale(cdata, temp, axis, rescale=N) + # Call IRFFT + temp = self.cufft.irfft(temp, axis=axis, n=N, overwrite_x=True) # Creates temporary + xp.copyto(gdata, temp) + + @register_transform(basis.RealFourier, 'fftw-numpy') class FFTWRealFFT(FFTWBase, RealFFT): """Real-to-real FFT using FFTW.""" diff --git a/dedalus/tools/general.py b/dedalus/tools/general.py index 18eb5ee4..5e22f9b2 100644 --- a/dedalus/tools/general.py +++ b/dedalus/tools/general.py @@ -124,3 +124,15 @@ def is_complex_dtype(dtype): dtype = dtype.type return np.iscomplexobj(dtype()) + +def float_to_complex(dtype): + itemsize = np.dtype(dtype).itemsize + complex_dtype = np.dtype(f'complex{itemsize*2}') + return complex_dtype.type + + +def complex_to_float(dtype): + itemsize = np.dtype(dtype).itemsize + float_dtype = np.dtype(f'float{itemsize//2}') + return float_dtype.type + From 60369836ffb6b4a93917ad70a3387961136afffc Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Wed, 28 May 2025 15:07:20 +0100 Subject: [PATCH 13/50] Fix dtype conversion --- dedalus/tools/general.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dedalus/tools/general.py b/dedalus/tools/general.py index 5e22f9b2..9b8b5746 100644 --- a/dedalus/tools/general.py +++ b/dedalus/tools/general.py @@ -127,12 +127,12 @@ def is_complex_dtype(dtype): def float_to_complex(dtype): itemsize = np.dtype(dtype).itemsize - complex_dtype = np.dtype(f'complex{itemsize*2}') + complex_dtype = np.dtype(f'complex{16*itemsize}') return complex_dtype.type def complex_to_float(dtype): itemsize = np.dtype(dtype).itemsize - float_dtype = np.dtype(f'float{itemsize//2}') + float_dtype = np.dtype(f'float{4*itemsize}') return float_dtype.type From 19e235077d5ff38ea39cee4eb4d8116c85b2e6c2 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Wed, 28 May 2025 15:20:25 +0100 Subject: [PATCH 14/50] Add array compat for basic arithmetic --- dedalus/core/arithmetic.py | 9 ++++++--- dedalus/core/future.py | 1 + 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/dedalus/core/arithmetic.py b/dedalus/core/arithmetic.py index 64daa530..b92b7ce0 100644 --- a/dedalus/core/arithmetic.py +++ b/dedalus/core/arithmetic.py @@ -245,10 +245,11 @@ def choose_layout(self): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg0, arg1 = self.args # Set output layout out.preset_layout(arg0.layout) - np.add(arg0.data, arg1.data, out=out.data) + xp.add(arg0.data, arg1.data, out=out.data) # used for einsum string manipulation @@ -854,6 +855,7 @@ def __init__(self, arg0, arg1, out=None, **kw): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg0, arg1 = self.args # Set output layout out.preset_layout(arg0.layout) @@ -863,7 +865,7 @@ def operate(self, out): # Reshape arg data to broadcast properly for output tensorsig arg0_exp_data = arg0_data.reshape(self.arg0_exp_tshape + arg0_data.shape[len(arg0.tensorsig):]) arg1_exp_data = arg1_data.reshape(self.arg1_exp_tshape + arg1_data.shape[len(arg1.tensorsig):]) - np.multiply(arg0_exp_data, arg1_exp_data, out=out.data) + xp.multiply(arg0_exp_data, arg1_exp_data, out=out.data) class GhostBroadcaster: @@ -939,11 +941,12 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg0, arg1 = self.args # Set output layout out.preset_layout(arg1.layout) # Multiply argument data - np.multiply(arg0, arg1.data, out=out.data) + xp.multiply(arg0, arg1.data, out=out.data) def matrix_dependence(self, *vars): return self.args[1].matrix_dependence(*vars) diff --git a/dedalus/core/future.py b/dedalus/core/future.py index d3d9bb15..ecc27d96 100644 --- a/dedalus/core/future.py +++ b/dedalus/core/future.py @@ -51,6 +51,7 @@ def __init__(self, *args, out=None): self.original_args = tuple(args) self.out = out self.dist = unify_attributes(args, 'dist', require=False) + self.array_namespace = self.dist.array_namespace #self.domain = Domain(self.dist, self.bases) self._grid_layout = self.dist.grid_layout self._coeff_layout = self.dist.coeff_layout From a3bda50034ee9198766c39f09fff5f7577aa7495 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Wed, 28 May 2025 15:30:15 +0100 Subject: [PATCH 15/50] Beginning adding array_compat to operators --- dedalus/core/operators.py | 37 +++++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/dedalus/core/operators.py b/dedalus/core/operators.py index db750d73..662d2e1d 100644 --- a/dedalus/core/operators.py +++ b/dedalus/core/operators.py @@ -378,11 +378,12 @@ def enforce_conditions(self): arg0.require_grid_space() def operate(self, out): + xp = self.array_namespace arg0, arg1 = self.args # Multiply in grid layout out.preset_layout(arg0.layout) if out.data.size: - np.power(arg0.data, arg1, out.data) + xp.power(arg0.data, arg1, out.data) def new_operands(self, arg0, arg1, **kw): return Power(arg0, arg1) @@ -498,8 +499,9 @@ def enforce_conditions(self): self.args[i].change_layout(self.layout) def operate(self, out): + xp = self.array_namespace out.preset_layout(self.layout) - np.copyto(out.data, self.func(*self.args, **self.kw)) + xp.copyto(out.data, self.func(*self.args, **self.kw)) class UnaryGridFunction(NonlinearOperator, FutureField): @@ -829,10 +831,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg0 = self.args[0] out.preset_layout(arg0.layout) out.lock_to_layouts(self.layouts) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) def new_operand(self, operand, **kw): return Lock(operand, *self.layouts, **kw) @@ -1539,9 +1542,10 @@ def subproblem_matrix(self, subproblem): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg = self.args[0] out.preset_layout(arg.layout) - np.copyto(out.data, arg.data) + xp.copyto(out.data, arg.data) class Convert(SpectralOperator, metaclass=MultiClass): @@ -1646,12 +1650,13 @@ def subspace_matrix(self, layout): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg = self.args[0] layout = arg.layout # Copy for grid space if layout.grid_space[self.last_axis]: out.preset_layout(layout) - np.copyto(out.data, arg.data) + xp.copyto(out.data, arg.data) # Revert to matrix application for coeff space else: super().operate(out) @@ -1794,9 +1799,10 @@ def base(self): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg = self.args[0] out.preset_layout(arg.layout) - np.einsum('ii...', arg.data, out=out.data) + xp.einsum('ii...', arg.data, out=out.data) class SphericalTrace(Trace): @@ -1993,6 +1999,7 @@ def subproblem_matrix(self, subproblem): def operate(self, out): """Perform operation.""" + xp = self.array_namespace operand = self.args[0] # Set output layout out.preset_layout(operand.layout) @@ -3507,10 +3514,11 @@ def subproblem_matrix(self, subproblem): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class DirectProductDivergence(Divergence): @@ -3556,10 +3564,11 @@ def subproblem_matrix(self, subproblem): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class SphericalDivergence(Divergence, SphericalEllOperator): @@ -3761,10 +3770,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class DirectProductCurl(Curl): @@ -3848,10 +3858,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class SphericalCurl(Curl, SphericalEllOperator): @@ -4074,10 +4085,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class DirectProductLaplacian(Laplacian): @@ -4119,10 +4131,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class SphericalLaplacian(Laplacian, SphericalEllOperator): From 28f76e7cfca7d5b8ce1647d208f549df5d94debb Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 22 Jul 2025 11:58:36 -0400 Subject: [PATCH 16/50] Quick implementation of apply_sparse for cupy --- dedalus/tools/array.py | 40 +++++++++++------- dedalus/tools/linalg_gpu.py | 84 +++++++++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 15 deletions(-) create mode 100644 dedalus/tools/linalg_gpu.py diff --git a/dedalus/tools/array.py b/dedalus/tools/array.py index ab9caf88..9749f66c 100644 --- a/dedalus/tools/array.py +++ b/dedalus/tools/array.py @@ -6,6 +6,8 @@ from scipy.sparse import _sparsetools from scipy.sparse import linalg as spla from math import prod +from ..tools import linalg_gpu +import array_api_compat from .config import config from . import linalg as cython_linalg @@ -173,14 +175,12 @@ def apply_sparse(matrix, array, axis, out=None, check_shapes=False, num_threads= Apply sparse matrix along any axis of an array. Must be out of place if ouptut is specified. """ - # Check matrix - if not isinstance(matrix, sparse.csr_matrix): - raise ValueError("Matrix must be in CSR format.") + xp = array_api_compat.array_namespace(array) # Check output if out is None: out_shape = list(array.shape) out_shape[axis] = matrix.shape[0] - out = np.empty(out_shape, dtype=array.dtype) + out = xp.empty(out_shape, dtype=array.dtype) elif out is array: raise ValueError("Cannot apply in place") # Check shapes @@ -189,17 +189,27 @@ def apply_sparse(matrix, array, axis, out=None, check_shapes=False, num_threads= raise ValueError("Axis out of bounds.") if matrix.shape[1] != array.shape[axis] or matrix.shape[0] != out.shape[axis]: raise ValueError("Matrix shape mismatch.") - # Old way if requested - if OLD_CSR_MATVECS and array.ndim == 2 and axis == 0: - out.fill(0) - return csr_matvecs(matrix, array, out) - # Promote datatypes - # TODO: find way to optimize this with fused types - matrix_data = matrix.data - if matrix_data.dtype != out.dtype: - matrix_data = matrix_data.astype(out.dtype) - # Call cython routine - cython_linalg.apply_csr(matrix.indptr, matrix.indices, matrix_data, array, out, axis, num_threads) + # Dispatch on array type + if array_api_compat.is_numpy_namespace(xp): + # Check matrix + if not isinstance(matrix, sparse.csr_matrix): + raise ValueError("Matrix must be in CSR format.") + # Old way if requested + if OLD_CSR_MATVECS and array.ndim == 2 and axis == 0: + out.fill(0) + return csr_matvecs(matrix, array, out) + # Promote datatypes + # TODO: find way to optimize this with fused types + matrix_data = matrix.data + if matrix_data.dtype != out.dtype: + matrix_data = matrix_data.astype(out.dtype) + # Call cython routine + cython_linalg.apply_csr(matrix.indptr, matrix.indices, matrix_data, array, out, axis, num_threads) + elif array_api_compat.is_cupy_namespace(xp): + # TODO: check matrix format here without import cupy + linalg_gpu.cupy_apply_csr(matrix, array, axis, out) + else: + raise ValueError("Unsupported array type") return out diff --git a/dedalus/tools/linalg_gpu.py b/dedalus/tools/linalg_gpu.py new file mode 100644 index 00000000..095eabaf --- /dev/null +++ b/dedalus/tools/linalg_gpu.py @@ -0,0 +1,84 @@ + +import numpy as np +try: + import cupy as cp + import cupyx.scipy.sparse as csp + HAVE_CUPY = True +except ImportError: + HAVE_CUPY = False + + +def cupy_apply_csr(matrix, array, axis, out): + """Apply CSR matrix to arbitrary axis of array.""" + if not HAVE_CUPY: + raise ImportError("cupy must be installed to use GPU linear algebra") + # Check matrix format + if not isinstance(matrix, csp.csr_matrix): + raise ValueError("Matrix must be in CSR format.") + # Switch by dimension + ndim = array.ndim + if ndim == 1: + if axis == 0: + out[:] = cupy_apply_csr_vec(matrix, array) + else: + raise ValueError("axis must be 0 for 1D arrays") + elif ndim == 2: + if axis == 0: + if array.shape[1] == 1: + out[:,0] = cupy_apply_csr_vec(matrix, array[:,0]) + else: + out[:] = cupy_apply_csr_first(matrix, array) + elif axis == 1: + if array.shape[0] == 1: + out[0,:] = cupy_apply_csr_vec(matrix, array[0,:]) + else: + out[:] = cupy_apply_csr_last(matrix, array) + else: + raise ValueError("axis must be 0 or 1 for 2D arrays") + else: + # Treat as 3D array with specified axis in the middle + # Compute equivalent shape (N1, N2, N3) + if ndim == 3 and axis == 1: + N1 = array.shape[0] + N2 = array.shape[1] + N3 = array.shape[2] + else: + N1 = int(np.prod(array.shape[:axis])) + N2 = array.shape[axis] + N3 = int(np.prod(array.shape[axis+1:])) + # Dispatch to cupy routines + if N1 == 1: + if N3 == 1: + # (1, N2, 1) -> (N2,) + x1 = array.reshape((N2,)) + temp = cupy_apply_csr_vec(matrix, x1) + out[:] = temp.reshape(out.shape) + else: + # (1, N2, N3) -> (N2, N3) + x2 = array.reshape((N2, N3)) + temp = cupy_apply_csr_first(matrix, x2) + out[:] = temp.reshape(out.shape) + else: + if N3 == 1: + # (N1, N2, 1) -> (N1, N2) + x2 = array.reshape((N1, N2)) + temp = cupy_apply_csr_last(matrix, x2) + out[:] = temp.reshape(out.shape) + else: + # (N1, N2, N3) + x3 = array.reshape((N1, N2, N3)) + y3 = out.reshape(((N1, matrix.shape[0], N3))) + for n1 in range(N1): + y3[n1] = cupy_apply_csr_first(matrix, x3[n1]) + + +def cupy_apply_csr_vec(matrix, vec): + return matrix.dot(vec) + +def cupy_apply_csr_first(matrix, array): + return matrix.dot(array) + +def cupy_apply_csr_last(matrix, array): + return matrix.dot(array.T).T + + From adecfa9e650c86959831ae2cab57a46fd94b0d0d Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 22 Jul 2025 15:09:41 -0400 Subject: [PATCH 17/50] Make einsum in dot compatible with cupy --- dedalus/core/arithmetic.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dedalus/core/arithmetic.py b/dedalus/core/arithmetic.py index b92b7ce0..3b8398c7 100644 --- a/dedalus/core/arithmetic.py +++ b/dedalus/core/arithmetic.py @@ -13,6 +13,7 @@ import numexpr as ne from collections import defaultdict from math import prod +import array_api_compat from .domain import Domain from .field import Operand, Field @@ -665,6 +666,7 @@ def GammaCoord(self, A_tensorsig, B_tensorsig, C_tensorsig): return G def operate(self, out): + xp = self.array_namespace arg0, arg1 = self.args out.preset_layout(arg0.layout) # Broadcast @@ -672,7 +674,11 @@ def operate(self, out): arg1_data = self.arg1_ghost_broadcaster.cast(arg1) # Call einsum if out.data.size: - np.einsum(self.einsum_str, arg0_data, arg1_data, out=out.data, optimize=True) + if array_api_compat.is_cupy_namespace(xp): + # Cupy does not support output keyword + out.data[:] = xp.einsum(self.einsum_str, arg0_data, arg1_data, optimize=True) + else: + xp.einsum(self.einsum_str, arg0_data, arg1_data, out=out.data, optimize=True) @alias("cross") From 6b9ff232af4805c8f308e1d3ca5c36c704f43fd3 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 22 Jul 2025 15:10:06 -0400 Subject: [PATCH 18/50] Add custom kernel for cupy csr middle dot product --- dedalus/tools/linalg_gpu.py | 75 ++++++++++++++++++++++++++++--------- 1 file changed, 58 insertions(+), 17 deletions(-) diff --git a/dedalus/tools/linalg_gpu.py b/dedalus/tools/linalg_gpu.py index 095eabaf..ad0b5098 100644 --- a/dedalus/tools/linalg_gpu.py +++ b/dedalus/tools/linalg_gpu.py @@ -1,5 +1,7 @@ +"""Linear algebra routines using cupy.""" import numpy as np +import math try: import cupy as cp import cupyx.scipy.sparse as csp @@ -14,25 +16,27 @@ def cupy_apply_csr(matrix, array, axis, out): raise ImportError("cupy must be installed to use GPU linear algebra") # Check matrix format if not isinstance(matrix, csp.csr_matrix): - raise ValueError("Matrix must be in CSR format.") + # TODO: avoid this explicit conversion + matrix = csp.csr_matrix(matrix) + #raise ValueError("Matrix must be in CSR format.") # Switch by dimension ndim = array.ndim if ndim == 1: if axis == 0: - out[:] = cupy_apply_csr_vec(matrix, array) + out[:] = matrix.dot(array) else: raise ValueError("axis must be 0 for 1D arrays") elif ndim == 2: if axis == 0: if array.shape[1] == 1: - out[:,0] = cupy_apply_csr_vec(matrix, array[:,0]) + out[:,0] = matrix.dot(array[:,0]) else: - out[:] = cupy_apply_csr_first(matrix, array) + out[:] = matrix.dot(array) elif axis == 1: if array.shape[0] == 1: - out[0,:] = cupy_apply_csr_vec(matrix, array[0,:]) + out[0,:] = matrix.dot(array[0,:]) else: - out[:] = cupy_apply_csr_last(matrix, array) + out[:] = matrix.dot(array.T).T else: raise ValueError("axis must be 0 or 1 for 2D arrays") else: @@ -51,34 +55,71 @@ def cupy_apply_csr(matrix, array, axis, out): if N3 == 1: # (1, N2, 1) -> (N2,) x1 = array.reshape((N2,)) - temp = cupy_apply_csr_vec(matrix, x1) + temp = matrix.dot(x1) out[:] = temp.reshape(out.shape) else: # (1, N2, N3) -> (N2, N3) x2 = array.reshape((N2, N3)) - temp = cupy_apply_csr_first(matrix, x2) + temp = matrix.dot(x2) out[:] = temp.reshape(out.shape) else: if N3 == 1: # (N1, N2, 1) -> (N1, N2) x2 = array.reshape((N1, N2)) - temp = cupy_apply_csr_last(matrix, x2) + temp = matrix.dot(x2.T).T out[:] = temp.reshape(out.shape) else: # (N1, N2, N3) x3 = array.reshape((N1, N2, N3)) y3 = out.reshape(((N1, matrix.shape[0], N3))) - for n1 in range(N1): - y3[n1] = cupy_apply_csr_first(matrix, x3[n1]) + cupy_apply_csr_mid(matrix, x3, y3) -def cupy_apply_csr_vec(matrix, vec): - return matrix.dot(vec) +# Kernel for applying CSR matrix with parallelization over n1 and n3 +apply_csr_mid_kernel = cp.RawKernel( + r''' + extern "C" __global__ void apply_csr_mid_kernel( + const float* data, // CSR data of shape (nnz,) + const int* indices, // CSR column indices (nnz,) + const int* indptr, // CSR row pointers (N2o + 1,) + const float* input, // shape (N1, N2i, N3) + float* output, // shape (N1, N2o, N3) + int N1, int N2i, int N2o, int N3) + { + int n1 = blockIdx.x * blockDim.x + threadIdx.x ; // batch index + int n3 = blockIdx.y * blockDim.y + threadIdx.y; // output column index -def cupy_apply_csr_first(matrix, array): - return matrix.dot(array) + if (n1 >= N1 || n3 >= N3) return; -def cupy_apply_csr_last(matrix, array): - return matrix.dot(array.T).T + // Loop over output rows = CSR matrix rows + for (int i = 0; i < N2o; ++i) { + float acc = 0.0f; + int start = indptr[i]; + int end = indptr[i + 1]; + for (int k = start; k < end; ++k) { + int j = indices[k]; // input column + float val = data[k]; + acc += val * input[n1 * N2i * N3 + j * N3 + n3]; + } + + output[n1 * N2o * N3 + i * N3 + n3] = acc; + } + } + ''', + 'apply_csr_mid_kernel') + + +def cupy_apply_csr_mid(matrix, array, out): + N1, N2i, N3 = array.shape + N2o = matrix.shape[0] + # Choose thread/block config + threads_y = min(1024, N3) # maximize concurrency along n3 + threads_x = 1024 // threads_y # make block have 1024 threads + blockdim = (threads_x, threads_y) + blocks_x = (N1 + threads_x - 1) // threads_x + blocks_y = (N3 + threads_y - 1) // threads_y + griddim = (blocks_x, blocks_y) + # Launch kernel + apply_csr_mid_kernel(griddim, blockdim, (matrix.data, matrix.indices, matrix.indptr, array, out, N1, N2i, N2o, N3)) From 8ee70d35f2bc7f5c3fa1870397a85a3b0ca69e26 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 22 Jul 2025 15:47:23 -0400 Subject: [PATCH 19/50] Convert local grids/modes to device arrays --- dedalus/core/distributor.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/dedalus/core/distributor.py b/dedalus/core/distributor.py index 9b2a3ff7..99ada2e9 100644 --- a/dedalus/core/distributor.py +++ b/dedalus/core/distributor.py @@ -261,11 +261,12 @@ def IdentityTensor(self, coordsys_in, coordsys_out=None, bases=None, dtype=None) return I def local_grid(self, basis, scale=None): + xp = self.array_namespace # TODO: remove from bases and do it all here? if scale is None: scale = 1 if basis.dim == 1: - return basis.local_grid(self, scale=scale) + return xp.asarray(basis.local_grid(self, scale=scale)) else: raise ValueError("Use `local_grids` for multidimensional bases.") @@ -298,16 +299,18 @@ def local_grid(self, basis, scale=None): # return tuple(grids) def local_grids(self, *bases, scales=None): + xp = self.array_namespace scales = self.remedy_scales(scales) grids = [] for basis in bases: basis_scales = scales[self.first_axis(basis):self.last_axis(basis)+1] - grids.extend(basis.local_grids(self, scales=basis_scales)) + grids.extend(xp.asarray(basis.local_grids(self, scales=basis_scales))) return grids def local_modes(self, basis): # TODO: remove from bases and do it all here? - return basis.local_modes(self) + xp = self.array_namespace + return xp.asarray(basis.local_modes(self)) @CachedAttribute def default_nonconst_groups(self): From 689e41fa23a331ba943125096a6e4dc2294e9f52 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 22 Jul 2025 15:55:21 -0400 Subject: [PATCH 20/50] Explicitly cast data norms to float --- dedalus/core/field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dedalus/core/field.py b/dedalus/core/field.py index 0e4fadd0..e9cb3ef0 100644 --- a/dedalus/core/field.py +++ b/dedalus/core/field.py @@ -854,7 +854,7 @@ def allreduce_data_norm(self, layout=None, order=2): if self.dist.comm.size > 1: norm = self.dist.comm.allreduce(norm, op=MPI.SUM) norm = norm ** (1 / order) - return norm + return float(norm) def allreduce_data_max(self, layout=None): return self.allreduce_data_norm(layout=layout, order=np.inf) From c226735e2cdca3a2a55d6b24dd43e64a226ad5f6 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 22 Jul 2025 15:56:10 -0400 Subject: [PATCH 21/50] Cast grid spacing to device array in cartesian cfl --- dedalus/core/basis.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index f7f63db6..b2d80870 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -6240,6 +6240,7 @@ class CartesianAdvectiveCFL(operators.AdvectiveCFL): @CachedMethod def cfl_spacing(self): + xp = self.array_namespace velocity = self.operand coordsys = velocity.tensorsig[0] spacing = [] @@ -6262,7 +6263,7 @@ def cfl_spacing(self): axis_spacing[:] = dealias * native_spacing * basis.COV.stretch elif basis is None: axis_spacing = np.inf - spacing.append(axis_spacing) + spacing.append(xp.asarray(axis_spacing)) return spacing def compute_cfl_frequency(self, velocity, out): From ad30363a1d26f0f495782dd895c28dc50e7e2822 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 22 Jul 2025 16:02:30 -0400 Subject: [PATCH 22/50] Convert field data gathers to numpy on gpu --- dedalus/core/field.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/dedalus/core/field.py b/dedalus/core/field.py index e9cb3ef0..682d246b 100644 --- a/dedalus/core/field.py +++ b/dedalus/core/field.py @@ -802,9 +802,15 @@ def allgather_data(self, layout=None): # Change layout if layout is not None: self.change_layout(layout) + # Convert to numpy if on GPU + xp = self.dist.array_namespace + if array_api_compat.is_cupy_namespace(xp): + data = xp.asnumpy(self.data) + else: + data = self.data.copy() # Shortcut for serial execution if self.dist.comm.size == 1: - return self.data.copy() + return data # Build global buffers tensor_shape = tuple(cs.dim for cs in self.tensorsig) global_shape = tensor_shape + self.layout.global_shape(self.domain, self.scales) @@ -813,7 +819,7 @@ def allgather_data(self, layout=None): recv_buff = np.empty_like(send_buff) # Combine data via allreduce -- easy but not communication-optimal # Should be optimized using Allgatherv if this is used past startup - send_buff[local_slices] = self.data + send_buff[local_slices] = data self.dist.comm.Allreduce(send_buff, recv_buff, op=MPI.SUM) return recv_buff @@ -821,13 +827,19 @@ def gather_data(self, root=0, layout=None): # Change layout if layout is not None: self.change_layout(layout) + # Convert to numpy if on GPU + xp = self.dist.array_namespace + if array_api_compat.is_cupy_namespace(xp): + data = xp.asnumpy(self.data) + else: + data = self.data.copy() # Shortcut for serial execution if self.dist.comm.size == 1: - return self.data.copy() + return data # TODO: Shortcut this for constant fields # Gather data # Should be optimized via Gatherv eventually - pieces = self.dist.comm.gather(self.data, root=root) + pieces = self.dist.comm.gather(data, root=root) # Assemble on root node if self.dist.comm.rank == root: ext_mesh = self.layout.ext_mesh From b7a7188f0ed57731538863f671aa745c3267ec96 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 22 Jul 2025 16:24:55 -0400 Subject: [PATCH 23/50] Fix subsystem gather/scatter to copy to/from gpu --- dedalus/core/subsystems.py | 10 +++++----- dedalus/tools/array.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/dedalus/core/subsystems.py b/dedalus/core/subsystems.py index e684c005..9de08cf6 100644 --- a/dedalus/core/subsystems.py +++ b/dedalus/core/subsystems.py @@ -13,7 +13,7 @@ from math import prod from .domain import Domain -from ..tools.array import zeros_with_pattern, expand_pattern, sparse_block_diag, copyto, perm_matrix, drop_empty_rows, apply_sparse, assert_sparse_pinv +from ..tools.array import zeros_with_pattern, expand_pattern, sparse_block_diag, copyto, perm_matrix, drop_empty_rows, apply_sparse, assert_sparse_pinv, copy_to_device, copy_from_device from ..tools.cache import CachedAttribute, CachedMethod from ..tools.general import replace, OrderedSet from ..tools.progress import log_progress @@ -342,7 +342,7 @@ def gather_inputs(self, fields, out=None): # Gather from fields views = self._input_field_views(tuple(fields)) for buffer_view, field_view in views: - np.copyto(buffer_view, field_view) + copy_from_device(buffer_view, field_view) # Apply right preconditioner inverse to compress inputs if out is None: out = self._compressed_buffer @@ -354,7 +354,7 @@ def gather_outputs(self, fields, out=None): # Gather from fields views = self._output_field_views(tuple(fields)) for buffer_view, field_view in views: - np.copyto(buffer_view, field_view) + copy_from_device(buffer_view, field_view) # Apply left preconditioner to compress outputs if out is None: out = self._compressed_buffer @@ -368,7 +368,7 @@ def scatter_inputs(self, data, fields): # Scatter to fields views = self._input_field_views(tuple(fields)) for buffer_view, field_view in views: - np.copyto(field_view, buffer_view) + copy_to_device(field_view, buffer_view) def scatter_outputs(self, data, fields): """Precondition and scatter subproblem data out to output-like field list.""" @@ -377,7 +377,7 @@ def scatter_outputs(self, data, fields): # Scatter to fields views = self._output_field_views(tuple(fields)) for buffer_view, field_view in views: - np.copyto(field_view, buffer_view) + copy_to_device(field_view, buffer_view) def inclusion_matrices(self, bases): """List of inclusion matrices.""" diff --git a/dedalus/tools/array.py b/dedalus/tools/array.py index 9749f66c..399b4950 100644 --- a/dedalus/tools/array.py +++ b/dedalus/tools/array.py @@ -363,6 +363,20 @@ def copyto(dest, src): dest[:] = src +def copy_to_device(dest, src): + if array_api_compat.is_cupy_array(dest): + dest.set(src) + else: + dest[:] = src + + +def copy_from_device(dest, src): + if array_api_compat.is_cupy_array(src): + src.get(out=dest) + else: + dest[:] = src + + def perm_matrix(perm, M=None, source_index=False, sparse=True): """ Build sparse permutation matrix from permutation vector. From ffe0719b6f7043ab5a4710c109ec8ed37a679620 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 22 Jul 2025 17:17:45 -0400 Subject: [PATCH 24/50] Allow for non-contiguous device copy --- dedalus/tools/array.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dedalus/tools/array.py b/dedalus/tools/array.py index 399b4950..bfef6a95 100644 --- a/dedalus/tools/array.py +++ b/dedalus/tools/array.py @@ -364,8 +364,10 @@ def copyto(dest, src): def copy_to_device(dest, src): - if array_api_compat.is_cupy_array(dest): - dest.set(src) + xp = array_api_compat.array_namespace(dest) + if array_api_compat.is_cupy_namespace(xp): + src = xp.asarray(src) + dest[:] = src else: dest[:] = src From 07f43fd587b6b4ed68a70b42583c0104dbebe422 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 22 Jul 2025 17:18:20 -0400 Subject: [PATCH 25/50] Fix cupy csr kernel for double instead of float --- dedalus/tools/linalg_gpu.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dedalus/tools/linalg_gpu.py b/dedalus/tools/linalg_gpu.py index ad0b5098..c4e1cde9 100644 --- a/dedalus/tools/linalg_gpu.py +++ b/dedalus/tools/linalg_gpu.py @@ -79,11 +79,11 @@ def cupy_apply_csr(matrix, array, axis, out): apply_csr_mid_kernel = cp.RawKernel( r''' extern "C" __global__ void apply_csr_mid_kernel( - const float* data, // CSR data of shape (nnz,) + const double* data, // CSR data of shape (nnz,) const int* indices, // CSR column indices (nnz,) const int* indptr, // CSR row pointers (N2o + 1,) - const float* input, // shape (N1, N2i, N3) - float* output, // shape (N1, N2o, N3) + const double* input, // shape (N1, N2i, N3) + double* output, // shape (N1, N2o, N3) int N1, int N2i, int N2o, int N3) { int n1 = blockIdx.x * blockDim.x + threadIdx.x ; // batch index @@ -93,13 +93,13 @@ def cupy_apply_csr(matrix, array, axis, out): // Loop over output rows = CSR matrix rows for (int i = 0; i < N2o; ++i) { - float acc = 0.0f; + double acc = 0; int start = indptr[i]; int end = indptr[i + 1]; for (int k = start; k < end; ++k) { int j = indices[k]; // input column - float val = data[k]; + double val = data[k]; acc += val * input[n1 * N2i * N3 + j * N3 + n3]; } From 799064ae063fe2d5fe55f81ae7cbf582436d9a31 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Fri, 25 Jul 2025 12:00:46 -0400 Subject: [PATCH 26/50] Move subsystems, coeff systems, and matrices to GPU --- dedalus/core/subsystems.py | 67 ++++++++++++++++++++++++--------- dedalus/core/system.py | 31 ++++++++------- dedalus/core/timesteppers.py | 36 +++++++++--------- dedalus/libraries/matsolvers.py | 21 ++++++++++- dedalus/tools/array.py | 10 +++++ dedalus/tools/linalg_gpu.py | 6 +-- 6 files changed, 117 insertions(+), 54 deletions(-) diff --git a/dedalus/core/subsystems.py b/dedalus/core/subsystems.py index 9de08cf6..b8d962b3 100644 --- a/dedalus/core/subsystems.py +++ b/dedalus/core/subsystems.py @@ -11,6 +11,7 @@ from mpi4py import MPI import uuid from math import prod +import array_api_compat from .domain import Domain from ..tools.array import zeros_with_pattern, expand_pattern, sparse_block_diag, copyto, perm_matrix, drop_empty_rows, apply_sparse, assert_sparse_pinv, copy_to_device, copy_from_device @@ -18,6 +19,12 @@ from ..tools.general import replace, OrderedSet from ..tools.progress import log_progress +try: + import cupy as cp + import cupyx.scipy.sparse as csp +except ImportError: + pass + import logging logger = logging.getLogger(__name__.split('.')[-1]) @@ -118,6 +125,7 @@ def __init__(self, solver, group): self.solver = solver self.problem = problem = solver.problem self.dist = solver.dist + self.array_namespace = solver.dist.array_namespace self.dtype = problem.dtype self.group = group # Determine matrix group using solver matrix dependence @@ -191,11 +199,12 @@ def field_size(self, field): @CachedMethod def _gather_scatter_setup(self, fields): + xp = self.array_namespace # Allocate vector fsizes = tuple(self.field_size(f) for f in fields) fslices = tuple(self.field_slices(f) for f in fields) fshapes = tuple(self.field_shape(f) for f in fields) - data = np.empty(sum(fsizes), dtype=self.dtype) + data = xp.empty(sum(fsizes), dtype=self.dtype) # Make views into data fviews = [] i0 = 0 @@ -248,6 +257,7 @@ def __init__(self, solver, subsystems, group): self.subsystems = subsystems self.group = group self.dist = problem.dist + self.array_namespace = self.dist.array_namespace self.domain = problem.variables[0].domain # HACK self.dtype = problem.dtype # Cross reference from subsystems @@ -279,7 +289,8 @@ def size(self): @CachedAttribute def _compressed_buffer(self): - return np.zeros(self.shape, dtype=self.dtype) + xp = self.array_namespace + return xp.zeros(self.shape, dtype=self.dtype) def coeff_slices(self, domain): return self.subsystems[0].coeff_slices(domain) @@ -300,9 +311,10 @@ def field_size(self, field): return self.subsystems[0].field_size(field) def _build_buffer_views(self, fields): + xp = self.array_namespace # Allocate buffer fsizes = tuple(self.field_size(f) for f in fields) - buffer = np.zeros((sum(fsizes), len(self.subsystems)), dtype=self.dtype) + buffer = xp.zeros((sum(fsizes), len(self.subsystems)), dtype=self.dtype) # Make views into buffer views = [] i0 = 0 @@ -342,7 +354,7 @@ def gather_inputs(self, fields, out=None): # Gather from fields views = self._input_field_views(tuple(fields)) for buffer_view, field_view in views: - copy_from_device(buffer_view, field_view) + copyto(buffer_view, field_view) # Apply right preconditioner inverse to compress inputs if out is None: out = self._compressed_buffer @@ -354,7 +366,7 @@ def gather_outputs(self, fields, out=None): # Gather from fields views = self._output_field_views(tuple(fields)) for buffer_view, field_view in views: - copy_from_device(buffer_view, field_view) + copyto(buffer_view, field_view) # Apply left preconditioner to compress outputs if out is None: out = self._compressed_buffer @@ -368,7 +380,7 @@ def scatter_inputs(self, data, fields): # Scatter to fields views = self._input_field_views(tuple(fields)) for buffer_view, field_view in views: - copy_to_device(field_view, buffer_view) + copyto(field_view, buffer_view) def scatter_outputs(self, data, fields): """Precondition and scatter subproblem data out to output-like field list.""" @@ -377,7 +389,7 @@ def scatter_outputs(self, data, fields): # Scatter to fields views = self._output_field_views(tuple(fields)) for buffer_view, field_view in views: - copy_to_device(field_view, buffer_view) + copyto(field_view, buffer_view) def inclusion_matrices(self, bases): """List of inclusion matrices.""" @@ -555,24 +567,45 @@ def build_matrices(self, names): left_perm = left_permutation(self, eqns, bc_top=solver.bc_top, interleave_components=solver.interleave_components).tocsr() right_perm = right_permutation(self, vars, tau_left=solver.tau_left, interleave_components=solver.interleave_components).tocsr() - # Preconditioners + # Preconditioners on CPU # TODO: remove astype casting, requires dealing with used types in apply_sparse - self.pre_left = drop_empty_rows(left_perm @ valid_eqn).tocsr().astype(dtype) - self.pre_left_pinv = self.pre_left.T.tocsr().astype(dtype) - self.pre_right_pinv = drop_empty_rows(right_perm @ valid_var).tocsr().astype(dtype) - self.pre_right = self.pre_right_pinv.T.tocsr().astype(dtype) + pre_left = drop_empty_rows(left_perm @ valid_eqn).tocsr().astype(dtype) + pre_left_pinv = pre_left.T.tocsr().astype(dtype) + pre_right_pinv = drop_empty_rows(right_perm @ valid_var).tocsr().astype(dtype) + pre_right = pre_right_pinv.T.tocsr().astype(dtype) # Check preconditioner pseudoinverses - assert_sparse_pinv(self.pre_left, self.pre_left_pinv) - assert_sparse_pinv(self.pre_right, self.pre_right_pinv) + assert_sparse_pinv(pre_left, pre_left_pinv) + assert_sparse_pinv(pre_right, pre_right_pinv) # Precondition matrices for name in matrices: - matrices[name] = self.pre_left @ matrices[name] @ self.pre_right + matrices[name] = pre_left @ matrices[name] @ pre_right - # Store minimal CSR matrices for fast dot products + # Store minimal CSR matrices on CPU for name, matrix in matrices.items(): - setattr(self, '{:}_min'.format(name), matrix.tocsr()) + setattr(self, f'{name}_min', matrix.tocsr()) + + # Store device copies for fast dot products + xp = solver.dist.array_namespace + if array_api_compat.is_numpy_namespace(xp): + self.pre_left = pre_left + self.pre_left_pinv = pre_left_pinv + self.pre_right_pinv = pre_right_pinv + self.pre_right = pre_right + # Reference current CPU matrices + for name, matrix in matrices.items(): + setattr(self, f'{name}_min_device', getattr(self, f'{name}_min')) + elif array_api_compat.is_cupy_namespace(xp): + # Copy to device + self.pre_left = csp.csr_matrix(pre_left) + self.pre_left_pinv = csp.csr_matrix(pre_left_pinv) + self.pre_right_pinv = csp.csr_matrix(pre_right_pinv) + self.pre_right = csp.csr_matrix(pre_right) + for name, matrix in matrices.items(): + setattr(self, f'{name}_min_device', csp.csr_matrix(matrix)) + else: + raise ValueError("Unsupported array namespace: {}".format(xp)) # Store expanded CSR matrices for fast recombination if len(matrices) > 1: diff --git a/dedalus/core/system.py b/dedalus/core/system.py index 23cbb86b..f28cb206 100644 --- a/dedalus/core/system.py +++ b/dedalus/core/system.py @@ -12,45 +12,44 @@ class CoeffSystem: """ - Representation of a collection of fields that don't need to be transformed, - and are therefore stored as a contigous set of coefficient data for - efficient pencil and group manipulation. + Contiguous buffer for data from all subproblems. Parameters ---------- - nfields : int - Number of fields to represent - domain : domain object - Problem domain + subproblems : list of Subproblem objects + Subproblems to represent + dtype : dtype + Data type + array_namespace : array namespace + Array namespace Attributes ---------- data : ndarray - Contiguous buffer for field coefficients - - """ - - """ - var buffer - + Contiguous buffer for data from all subproblems + views : dict + Nested dictionary of views for each subproblem and subsystem """ - def __init__(self, subproblems, dtype): + def __init__(self, subproblems, dtype, array_namespace): + xp = array_namespace # Build buffer total_size = sum(sp.LHS.shape[1]*len(sp.subsystems) for sp in subproblems) - self.data = np.zeros(total_size, dtype=dtype) + self.data = xp.zeros(total_size, dtype=dtype) # Build views i0 = i1 = 0 self.views = views = {} for sp in subproblems: views[sp] = views_sp = {} + # View for each individual subsystem i00 = i0 for ss in sp.subsystems: i1 += sp.LHS.shape[1] views_sp[ss] = self.data[i0:i1] i0 = i1 i11 = i1 + # View combining all subsystems as rows in a matrix if i11 - i00 > 0: views_sp[None] = self.data[i00:i11].reshape((sp.LHS.shape[1], -1)) else: diff --git a/dedalus/core/timesteppers.py b/dedalus/core/timesteppers.py index 81da4c10..162a2d32 100644 --- a/dedalus/core/timesteppers.py +++ b/dedalus/core/timesteppers.py @@ -2,10 +2,9 @@ from collections import deque, OrderedDict import numpy as np -from scipy.linalg import blas from .system import CoeffSystem -from ..tools.array import apply_sparse +from ..tools.array import apply_sparse, get_axpy # Public interface @@ -71,7 +70,8 @@ class MultistepIMEX: def __init__(self, solver): self.solver = solver - self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype) + xp = solver.dist.array_namespace + self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp) # Create deque for storing recent timesteps self.dt = deque([0.] * self.steps) @@ -81,16 +81,16 @@ def __init__(self, solver): self.LX = LX = deque() self.F = F = deque() for j in range(self.amax): - MX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype)) + MX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp)) for j in range(self.bmax): - LX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype)) + LX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp)) for j in range(self.cmax): - F.append(CoeffSystem(solver.subproblems, dtype=solver.dtype)) + F.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp)) # Attributes self._iteration = 0 self._LHS_params = None - self.axpy = blas.get_blas_funcs('axpy', dtype=solver.dtype) + self.axpy = get_axpy(xp, solver.dtype) def step(self, dt, wall_time): """Advance solver by one timestep.""" @@ -143,8 +143,8 @@ def step(self, dt, wall_time): evaluator.require_coeff_space(state_fields) for sp in subproblems: spX = sp.gather_inputs(state_fields) - apply_sparse(sp.M_min, spX, axis=0, out=MX0.get_subdata(sp)) - apply_sparse(sp.L_min, spX, axis=0, out=LX0.get_subdata(sp)) + apply_sparse(sp.M_min_device, spX, axis=0, out=MX0.get_subdata(sp)) + apply_sparse(sp.L_min_device, spX, axis=0, out=LX0.get_subdata(sp)) # Evaluate F(X0) evaluator.evaluate_scheduled(iteration=iteration, wall_time=wall_time, sim_time=sim_time, timestep=dt) @@ -539,15 +539,16 @@ class RungeKuttaIMEX: def __init__(self, solver): self.solver = solver - self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype) + xp = solver.dist.array_namespace + self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp) # Create coefficient systems for multistep history - self.MX0 = CoeffSystem(solver.subproblems, dtype=solver.dtype) - self.LX = [CoeffSystem(solver.subproblems, dtype=solver.dtype) for i in range(self.stages)] - self.F = [CoeffSystem(solver.subproblems, dtype=solver.dtype) for i in range(self.stages)] + self.MX0 = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp) + self.LX = [CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp) for i in range(self.stages)] + self.F = [CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp) for i in range(self.stages)] self._LHS_params = None - self.axpy = blas.get_blas_funcs('axpy', dtype=solver.dtype) + self.axpy = get_axpy(xp, solver.dtype) def step(self, dt, wall_time): """Advance solver by one timestep.""" @@ -584,11 +585,12 @@ def step(self, dt, wall_time): # Compute M.X(n,0) and L.X(n,0) # Ensure coeff space before subsystem gathers + # TODO: add option to evaluate this matrix-free (e.g for high-bandwidth NCCs when using fast transforms) evaluator.require_coeff_space(state_fields) for sp in subproblems: spX = sp.gather_inputs(state_fields) - apply_sparse(sp.M_min, spX, axis=0, out=MX0.get_subdata(sp)) - apply_sparse(sp.L_min, spX, axis=0, out=LX0.get_subdata(sp)) + apply_sparse(sp.M_min_device, spX, axis=0, out=MX0.get_subdata(sp)) + apply_sparse(sp.L_min_device, spX, axis=0, out=LX0.get_subdata(sp)) # Compute stages # (M + k Hii L).X(n,i) = M.X(n,0) + k Aij F(n,j) - k Hij L.X(n,j) @@ -601,7 +603,7 @@ def step(self, dt, wall_time): evaluator.require_coeff_space(state_fields) for sp in subproblems: spX = sp.gather_inputs(state_fields) - apply_sparse(sp.L_min, spX, axis=0, out=LXi.get_subdata(sp)) + apply_sparse(sp.L_min_device, spX, axis=0, out=LXi.get_subdata(sp)) # Compute F(n,i-1), only doing output on first evaluation if i == 1: diff --git a/dedalus/libraries/matsolvers.py b/dedalus/libraries/matsolvers.py index f301d4f2..2544f8e9 100644 --- a/dedalus/libraries/matsolvers.py +++ b/dedalus/libraries/matsolvers.py @@ -5,7 +5,12 @@ import scipy.sparse as sp import scipy.sparse.linalg as spla from functools import partial - +import array_api_compat +try: + import cupyx.scipy.sparse.linalg as cupy_spla + cupy_available = True +except ImportError: + cupy_available = False matsolvers = {} def add_solver(solver): @@ -144,6 +149,17 @@ def __init__(self, matrix, solver=None): relax=self.relax, panel_size=self.panel_size, options=self.options) + # Cupy conversion + if array_api_compat.is_cupy_namespace(solver.dist.array_namespace): + # Avoid cupy splu which requires GPU matrices but transfers them to factorize on CPU + # Run same typecheck as cupy splu + if matrix.dtype.char not in 'fdFD': + raise TypeError('Invalid dtype (actual: {})'.format(self.LU.dtype)) + # Build cupy factorization from scipy factorization of CPU matrices + self.LU = cupy_spla.SuperLU(self.LU) + sp.save_npz("block1024.npz", matrix) + print(self.LU.shape) + print(self.LU.nnz) def solve(self, vector): return self.LU.solve(vector, trans=self.trans) @@ -225,6 +241,9 @@ class SparseInverse(SparseSolver): def __init__(self, matrix, solver=None): self.matrix_inverse = spla.inv(matrix.tocsc()) + # Cupy conversion + if array_api_compat.is_cupy_namespace(solver.dist.array_namespace): + self.matrix_inverse = cupy_spla.inv(matrix.tocsc()) def solve(self, vector): return self.matrix_inverse @ vector diff --git a/dedalus/tools/array.py b/dedalus/tools/array.py index bfef6a95..2d6a69e6 100644 --- a/dedalus/tools/array.py +++ b/dedalus/tools/array.py @@ -5,6 +5,7 @@ import scipy.sparse as sp from scipy.sparse import _sparsetools from scipy.sparse import linalg as spla +from scipy.linalg import blas from math import prod from ..tools import linalg_gpu import array_api_compat @@ -500,3 +501,12 @@ def assert_sparse_pinv(A, B): if not sparse_allclose((B @ A).conj().T, B @ A): raise AssertionError("Not a pseudoinverse") + +def get_axpy(array_namespace, dtype): + if array_api_compat.is_numpy_namespace(array_namespace): + return blas.get_blas_funcs('axpy', dtype=dtype) + elif array_api_compat.is_cupy_namespace(array_namespace): + from cupy.cublas import axpy as cublas_axpy + return cublas_axpy + else: + raise ValueError("Unsupported array namespace") diff --git a/dedalus/tools/linalg_gpu.py b/dedalus/tools/linalg_gpu.py index c4e1cde9..95ccfe52 100644 --- a/dedalus/tools/linalg_gpu.py +++ b/dedalus/tools/linalg_gpu.py @@ -5,14 +5,14 @@ try: import cupy as cp import cupyx.scipy.sparse as csp - HAVE_CUPY = True + cupy_available = True except ImportError: - HAVE_CUPY = False + cupy_available = False def cupy_apply_csr(matrix, array, axis, out): """Apply CSR matrix to arbitrary axis of array.""" - if not HAVE_CUPY: + if not cupy_available: raise ImportError("cupy must be installed to use GPU linear algebra") # Check matrix format if not isinstance(matrix, csp.csr_matrix): From d99b5b51762021301b7df7023c4d665ca3a30183 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Fri, 25 Jul 2025 13:49:34 -0400 Subject: [PATCH 27/50] Build custom cupy superlu wrapper to reuse spsm descriptors --- dedalus/libraries/matsolvers.py | 10 +- dedalus/tools/linalg_gpu.py | 244 ++++++++++++++++++++++++++++++++ 2 files changed, 251 insertions(+), 3 deletions(-) diff --git a/dedalus/libraries/matsolvers.py b/dedalus/libraries/matsolvers.py index 2544f8e9..ede93a1b 100644 --- a/dedalus/libraries/matsolvers.py +++ b/dedalus/libraries/matsolvers.py @@ -157,9 +157,13 @@ def __init__(self, matrix, solver=None): raise TypeError('Invalid dtype (actual: {})'.format(self.LU.dtype)) # Build cupy factorization from scipy factorization of CPU matrices self.LU = cupy_spla.SuperLU(self.LU) - sp.save_npz("block1024.npz", matrix) - print(self.LU.shape) - print(self.LU.nnz) + self.LU.spsm_L_descr = None + self.LU.spsm_U_descr = None + self.solve = self.cupy_solve + + def cupy_solve(self, vector): + from dedalus.tools.linalg_gpu import custom_SuperLU_solve + return custom_SuperLU_solve(self.LU, vector, trans=self.trans) def solve(self, vector): return self.LU.solve(vector, trans=self.trans) diff --git a/dedalus/tools/linalg_gpu.py b/dedalus/tools/linalg_gpu.py index 95ccfe52..220d3615 100644 --- a/dedalus/tools/linalg_gpu.py +++ b/dedalus/tools/linalg_gpu.py @@ -123,3 +123,247 @@ def cupy_apply_csr_mid(matrix, array, out): # Launch kernel apply_csr_mid_kernel(griddim, blockdim, (matrix.data, matrix.indices, matrix.indptr, array, out, N1, N2i, N2o, N3)) + +def custom_spsm(a, b, alpha=1.0, lower=True, unit_diag=False, transa=False, spsm_descr=None): + """Custom spsm wrapper to save spsm_descr, since spsm_analysis takes lots of time.""" + """Solves a sparse triangular linear system op(a) * x = alpha * op(b). + + Args: + a (cupyx.scipy.sparse.csr_matrix or cupyx.scipy.sparse.coo_matrix): + Sparse matrix with dimension ``(M, M)``. + b (cupy.ndarray): Dense matrix with dimension ``(M, K)``. + alpha (float or complex): Coefficient. + lower (bool): + True: ``a`` is lower triangle matrix. + False: ``a`` is upper triangle matrix. + unit_diag (bool): + True: diagonal part of ``a`` has unit elements. + False: diagonal part of ``a`` has non-unit elements. + transa (bool or str): True, False, 'N', 'T' or 'H'. + 'N' or False: op(a) == ``a``. + 'T' or True: op(a) == ``a.T``. + 'H': op(a) == ``a.conj().T``. + """ + import cupyx + from cupyx import cusparse + import cupy as _cupy + import numpy as _numpy + from cupy._core import _dtype + from cupy_backends.cuda.libs import cusparse as _cusparse + from cupy.cuda import device as _device + from cupyx.cusparse import SpMatDescriptor, DnMatDescriptor + if not cusparse.check_availability('spsm'): + raise RuntimeError('spsm is not available.') + + # Canonicalise transa + if transa is False: + transa = 'N' + elif transa is True: + transa = 'T' + elif transa not in 'NTH': + raise ValueError(f'Unknown transa (actual: {transa})') + + # Check A's type and sparse format + if cupyx.scipy.sparse.isspmatrix_csr(a): + pass + elif cupyx.scipy.sparse.isspmatrix_csc(a): + if transa == 'N': + a = a.T + transa = 'T' + elif transa == 'T': + a = a.T + transa = 'N' + elif transa == 'H': + a = a.conj().T + transa = 'N' + lower = not lower + elif cupyx.scipy.sparse.isspmatrix_coo(a): + pass + else: + raise ValueError('a must be CSR, CSC or COO sparse matrix') + assert a.has_canonical_format + + # Check B's ndim + if b.ndim == 1: + is_b_vector = True + b = b.reshape(-1, 1) + elif b.ndim == 2: + is_b_vector = False + else: + raise ValueError('b.ndim must be 1 or 2') + + # Check shapes + if not (a.shape[0] == a.shape[1] == b.shape[0]): + raise ValueError('mismatched shape') + + # Check dtypes + dtype = a.dtype + if dtype.char not in 'fdFD': + raise TypeError('Invalid dtype (actual: {})'.format(dtype)) + if dtype != b.dtype: + raise TypeError('dtype mismatch') + + # Prepare fill mode + if lower is True: + fill_mode = _cusparse.CUSPARSE_FILL_MODE_LOWER + elif lower is False: + fill_mode = _cusparse.CUSPARSE_FILL_MODE_UPPER + else: + raise ValueError('Unknown lower (actual: {})'.format(lower)) + + # Prepare diag type + if unit_diag is False: + diag_type = _cusparse.CUSPARSE_DIAG_TYPE_NON_UNIT + elif unit_diag is True: + diag_type = _cusparse.CUSPARSE_DIAG_TYPE_UNIT + else: + raise ValueError('Unknown unit_diag (actual: {})'.format(unit_diag)) + + # Prepare op_a + if transa == 'N': + op_a = _cusparse.CUSPARSE_OPERATION_NON_TRANSPOSE + elif transa == 'T': + op_a = _cusparse.CUSPARSE_OPERATION_TRANSPOSE + else: # transa == 'H' + if dtype.char in 'fd': + op_a = _cusparse.CUSPARSE_OPERATION_TRANSPOSE + else: + op_a = _cusparse.CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE + + # Prepare op_b + if b._f_contiguous: + op_b = _cusparse.CUSPARSE_OPERATION_NON_TRANSPOSE + elif b._c_contiguous: + if _cusparse.get_build_version() < 11701: # earlier than CUDA 11.6 + raise ValueError('b must be F-contiguous.') + b = b.T + op_b = _cusparse.CUSPARSE_OPERATION_TRANSPOSE + else: + raise ValueError('b must be F-contiguous or C-contiguous.') + + # Allocate space for matrix C. Note that it is known cusparseSpSM requires + # the output matrix zero initialized. + m, _ = a.shape + if op_b == _cusparse.CUSPARSE_OPERATION_NON_TRANSPOSE: + _, n = b.shape + else: + n, _ = b.shape + c_shape = m, n + c = _cupy.zeros(c_shape, dtype=a.dtype, order='f') + + # Prepare descriptors and other parameters + handle = _device.get_cusparse_handle() + mat_a = SpMatDescriptor.create(a) + mat_b = DnMatDescriptor.create(b) + mat_c = DnMatDescriptor.create(c) + if spsm_descr is None: + spsm_descr = _cusparse.spSM_createDescr() + new_spsm_descr = True + else: + spsm_descr, buff = spsm_descr + new_spsm_descr = False + alpha = _numpy.array(alpha, dtype=c.dtype).ctypes + cuda_dtype = _dtype.to_cuda_dtype(c.dtype) + algo = _cusparse.CUSPARSE_SPSM_ALG_DEFAULT + + try: + # Specify Lower|Upper fill mode + mat_a.set_attribute(_cusparse.CUSPARSE_SPMAT_FILL_MODE, fill_mode) + + # Specify Unit|Non-Unit diagonal type + mat_a.set_attribute(_cusparse.CUSPARSE_SPMAT_DIAG_TYPE, diag_type) + + # Allocate the workspace needed by the succeeding phases + if new_spsm_descr: + buff_size = _cusparse.spSM_bufferSize( + handle, op_a, op_b, alpha.data, mat_a.desc, mat_b.desc, + mat_c.desc, cuda_dtype, algo, spsm_descr) + buff = _cupy.empty(buff_size, dtype=_cupy.int8) + + # Perform the analysis phase + if new_spsm_descr: + _cusparse.spSM_analysis( + handle, op_a, op_b, alpha.data, mat_a.desc, mat_b.desc, + mat_c.desc, cuda_dtype, algo, spsm_descr, buff.data.ptr) + + # Executes the solve phase + _cusparse.spSM_solve( + handle, op_a, op_b, alpha.data, mat_a.desc, mat_b.desc, + mat_c.desc, cuda_dtype, algo, spsm_descr, buff.data.ptr) + + # Reshape back if B was a vector + if is_b_vector: + c = c.reshape(-1) + + return c, (spsm_descr, buff) + + finally: + # Destroy matrix/vector descriptors + #_cusparse.spSM_destroyDescr(spsm_descr) + pass + + +def custom_SuperLU_solve(self, rhs, trans='N', spsm_descr=None): + """Custom SuperLU solve wrapper to save spsm_descr, since spsm_analysis takes lots of time.""" + """Solves linear system of equations with one or several right-hand sides. + + Args: + rhs (cupy.ndarray): Right-hand side(s) of equation with dimension + ``(M)`` or ``(M, K)``. + trans (str): 'N', 'T' or 'H'. + 'N': Solves ``A * x = rhs``. + 'T': Solves ``A.T * x = rhs``. + 'H': Solves ``A.conj().T * x = rhs``. + + Returns: + cupy.ndarray: + Solution vector(s) + """ # NOQA + from cupyx import cusparse + import cupy + from cupyx.scipy.sparse.linalg._solve import _should_use_spsm + + if not isinstance(rhs, cupy.ndarray): + raise TypeError('ojb must be cupy.ndarray') + if rhs.ndim not in (1, 2): + raise ValueError('rhs.ndim must be 1 or 2 (actual: {})'. + format(rhs.ndim)) + if rhs.shape[0] != self.shape[0]: + raise ValueError('shape mismatch (self.shape: {}, rhs.shape: {})' + .format(self.shape, rhs.shape)) + if trans not in ('N', 'T', 'H'): + raise ValueError('trans must be \'N\', \'T\', or \'H\'') + + if cusparse.check_availability('spsm') and _should_use_spsm(rhs): + def spsm(A, B, lower, transa, spsm_descr): + return custom_spsm(A, B, lower=lower, transa=transa, spsm_descr=spsm_descr) + sm = spsm + else: + raise NotImplementedError + + x = rhs.astype(self.L.dtype) + if trans == 'N': + if self.perm_r is not None: + if x.ndim == 2 and x._f_contiguous: + x = x.T[:, self._perm_r_rev].T # want to keep f-order + else: + x = x[self._perm_r_rev] + x, self.spsm_L_descr = sm(self.L, x, lower=True, transa=trans, spsm_descr=self.spsm_L_descr) + x, self.spsm_U_descr = sm(self.U, x, lower=False, transa=trans, spsm_descr=self.spsm_U_descr) + if self.perm_c is not None: + x = x[self.perm_c] + else: + if self.perm_c is not None: + if x.ndim == 2 and x._f_contiguous: + x = x.T[:, self._perm_c_rev].T # want to keep f-order + else: + x = x[self._perm_c_rev] + x, self.spsm_U_descr = sm(self.U, x, lower=False, transa=trans, spsm_descr=self.spsm_U_descr) + x, self.spsm_L_descr = sm(self.L, x, lower=True, transa=trans, spsm_descr=self.spsm_L_descr) + if self.perm_r is not None: + x = x[self.perm_r] + + if not x._f_contiguous: + # For compatibility with SciPy + x = x.copy(order='F') + return x From ef66c9910f66ccf8c5d66df7e564d631cc88a1b1 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Fri, 25 Jul 2025 16:46:49 -0400 Subject: [PATCH 28/50] Move all operator matrices to device. Add Chebyshev transforms --- dedalus/core/basis.py | 8 +- dedalus/core/operators.py | 17 +++- dedalus/core/transforms.py | 74 +++++++++++++-- dedalus/tools/array.py | 72 ++++++++++----- dedalus/tools/linalg_gpu.py | 176 ++++++++++++++++++++++++++++++++++++ 5 files changed, 311 insertions(+), 36 deletions(-) diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index b2d80870..f323a014 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -663,7 +663,13 @@ def _native_grid(self, scale): @CachedMethod def transform_plan(self, dist, grid_size): """Build transform plan.""" - return self.transforms[self.library](grid_size, self.size, self.a, self.b, self.a0, self.b0) + xp = dist.array_namespace + xp_name = xp.__name__.split('.')[-1] + # Shortcut trivial transforms + if grid_size == 1 or self.size == 1: + return self.transforms[f"matrix-{xp_name}"](grid_size, self.size, self.a, self.b, self.a0, self.b0, dist.array_namespace, dist.dtype) + else: + return self.transforms[f"{self.library}-{xp_name}"](grid_size, self.size, self.a, self.b, self.a0, self.b0, dist.array_namespace, dist.dtype) # def weights(self, scales): # """Gauss-Jacobi weights.""" diff --git a/dedalus/core/operators.py b/dedalus/core/operators.py index 662d2e1d..c0d3416e 100644 --- a/dedalus/core/operators.py +++ b/dedalus/core/operators.py @@ -15,6 +15,7 @@ from math import prod from ..libraries import dedalus_sphere import logging +import array_api_compat logger = logging.getLogger(__name__.split('.')[-1]) from .domain import Domain @@ -967,6 +968,20 @@ def subspace_matrix(self, layout): # Caching layer to allow insertion of other arguments return self._subspace_matrix(layout, self.input_basis, self.output_basis, self.first_axis) + @CachedMethod + def subspace_matrix_device(self, layout): + """Build matrix operating on local subspace data on device.""" + # Caching layer to allow insertion of other arguments + matrix = self._subspace_matrix(layout, self.input_basis, self.output_basis, self.first_axis) + if array_api_compat.is_cupy_namespace(self.array_namespace): + import cupy as cp + import cupyx.scipy.sparse as csp + if sparse.issparse(matrix): + matrix = csp.csr_matrix(matrix) + else: + matrix = cp.array(matrix) + return matrix + def group_matrix(self, group): return self._group_matrix(group, self.input_basis, self.output_basis) @@ -1007,7 +1022,7 @@ def operate(self, out): # Apply matrix if arg.data.size and out.data.size: data_axis = self.last_axis + len(arg.tensorsig) - apply_matrix(self.subspace_matrix(layout), arg.data, data_axis, out=out.data) + apply_matrix(self.subspace_matrix_device(layout), arg.data, data_axis, out=out.data) else: out.data.fill(0) diff --git a/dedalus/core/transforms.py b/dedalus/core/transforms.py index 5616c1c7..0433d802 100644 --- a/dedalus/core/transforms.py +++ b/dedalus/core/transforms.py @@ -8,14 +8,16 @@ import scipy.fftpack from ..libraries import dedalus_sphere from math import prod +import array_api_compat from . import basis from ..libraries.fftw import fftw_wrappers as fftw from ..tools import jacobi -from ..tools.array import apply_matrix, apply_dense, axslice, solve_upper_sparse, apply_sparse +from ..tools.array import apply_matrix, apply_dense, axslice, solve_upper_sparse, apply_sparse, copyto from ..tools.cache import CachedAttribute from ..tools.cache import CachedMethod from ..tools.general import float_to_complex +from ..tools.linalg_gpu import cupy_solve_upper_csr, CustomCupyUpperTriangularSolver import logging logger = logging.getLogger(__name__.split('.')[-1]) @@ -94,31 +96,39 @@ class JacobiTransform(SeparableTransform): Jacobi "a" parameter for the quadrature grid. b0 : int Jacobi "b" parameter for the quadrature grid. + array_namespace : array namespace + Array namespace for the transform. + dtype : dtype + Data type for the transform. Notes ----- TODO: We need to define the normalization we use here. """ - def __init__(self, grid_size, coeff_size, a, b, a0, b0, dealias_before_converting=None): + def __init__(self, grid_size, coeff_size, a, b, a0, b0, array_namespace, dtype, dealias_before_converting=None): self.N = grid_size self.M = coeff_size self.a = a self.b = b self.a0 = a0 self.b0 = b0 + self.array_namespace = array_namespace + self.dtype = dtype if dealias_before_converting is None: dealias_before_converting = GET_DEALIAS_BEFORE_CONVERTING() self.dealias_before_converting = dealias_before_converting -@register_transform(basis.Jacobi, 'matrix') +@register_transform(basis.Jacobi, 'matrix-numpy') +@register_transform(basis.Jacobi, 'matrix-cupy') class JacobiMMT(JacobiTransform, SeparableMatrixTransform): """Jacobi polynomial MMTs.""" @CachedAttribute def forward_matrix(self): """Build forward transform matrix.""" + xp = self.array_namespace N, M = self.N, self.M a, a0 = self.a, self.a0 b, b0 = self.b, self.b0 @@ -142,11 +152,12 @@ def forward_matrix(self): # Truncate to specified coeff_size forward_matrix = forward_matrix[:M, :] # Ensure C ordering for fast dot products - return np.asarray(forward_matrix, order='C') + return xp.asarray(forward_matrix, order='C', dtype=self.dtype) @CachedAttribute def backward_matrix(self): """Build backward transform matrix.""" + xp = self.array_namespace N, M = self.N, self.M a, a0 = self.a, self.a0 b, b0 = self.b, self.b0 @@ -156,7 +167,7 @@ def backward_matrix(self): # Zero higher polynomials for transforms with grid_size < coeff_size polynomials[N:, :] = 0 # Transpose and ensure C ordering for fast dot products - return np.asarray(polynomials.T, order='C') + return xp.asarray(polynomials.T, order='C', dtype=self.dtype) class ComplexFourierTransform(SeparableTransform): @@ -848,6 +859,33 @@ def backward(self, cdata, gdata, axis): np.copyto(gdata, temp) +class CupyDCT(FastCosineTransform): + """Fast cosine transform using cupy fft.""" + + def __init__(self, *args, **kw): + import cupyx.scipy.fft as cufft + self.cufft = cufft + super().__init__(*args, **kw) + + def forward(self, gdata, cdata, axis): + """Apply forward transform along specified axis.""" + # Call DCT + temp = self.cufft.dct(gdata, type=2, axis=axis) # Creates temporary + # Resize and rescale for unit-ampltidue normalization + self.resize_rescale_forward(temp, cdata, axis, self.Kmax) + + def backward(self, cdata, gdata, axis): + """Apply backward transform along specified axis.""" + xp = self.array_namespace + # Resize and rescale for unit-amplitude normalization + # Need temporary to avoid overwriting problems + temp = xp.empty_like(gdata) # Creates temporary + self.resize_rescale_backward(cdata, temp, axis, self.Kmax) + # Call IDCT + temp = self.cufft.dct(temp, type=3, axis=axis, overwrite_x=True) # Creates temporary + copyto(gdata, temp) + + #@register_transform(basis.Cosine, 'fftw') class FFTWDCT(FFTWBase, FastCosineTransform): """Fast cosine transform using FFTW.""" @@ -884,11 +922,11 @@ class FastChebyshevTransform(JacobiTransform): Subclasses should inherit from this class, then a FastCosineTransform subclass. """ - def __init__(self, grid_size, coeff_size, a, b, a0, b0, **kw): + def __init__(self, grid_size, coeff_size, a, b, a0, b0, array_namespace, dtype, **kw): if not a0 == b0 == -1/2: raise ValueError("Fast Chebshev transform requires a0 == b0 == -1/2.") # Jacobi initialization - super().__init__(grid_size, coeff_size, a, b, a0, b0, **kw) + super().__init__(grid_size, coeff_size, a, b, a0, b0, array_namespace, dtype, **kw) # DCT initialization to set scaling factors if a != a0 or b != b0: # Modify coeff_size to avoid truncation before conversion @@ -920,6 +958,13 @@ def __init__(self, grid_size, coeff_size, a, b, a0, b0, **kw): self.backward_conversion.sum_duplicates() # for faster solve_upper self.resize_rescale_forward = self._resize_rescale_forward_convert self.resize_rescale_backward = self._resize_rescale_backward_convert + if array_api_compat.is_cupy_namespace(self.array_namespace): + import cupyx.scipy.sparse as csp + self.forward_conversion = csp.csr_matrix(self.forward_conversion) + self.backward_conversion = csp.csr_matrix(self.backward_conversion) + self.forward_conversion.sum_duplicates() + self.backward_conversion.sum_duplicates() + self.backward_conversion_LU = CustomCupyUpperTriangularSolver(self.backward_conversion) def _resize_rescale_forward(self, data_in, data_out, axis, Kmax): """Resize by padding/trunction and rescale to unit amplitude.""" @@ -961,7 +1006,10 @@ def _resize_rescale_backward_convert(self, data_in, data_out, axis, Kmax_DCT): # Truncate input before conversion data_in[badfreq] = 0 # Ultraspherical conversion - solve_upper_sparse(self.backward_conversion, data_in, axis, out=data_in) + if array_api_compat.is_cupy_namespace(self.array_namespace): + cupy_solve_upper_csr(self.backward_conversion_LU, data_in, axis, out=data_in) + else: + solve_upper_sparse(self.backward_conversion, data_in, axis, out=data_in) # Change sign of odd modes if Kmax_orig > 0: posfreq_odd = axslice(axis, 1, Kmax_orig+1, 2) @@ -970,18 +1018,24 @@ def _resize_rescale_backward_convert(self, data_in, data_out, axis, Kmax_DCT): super().resize_rescale_backward(data_in, data_out, axis, Kmax_orig) -@register_transform(basis.Jacobi, 'scipy_dct') +@register_transform(basis.Jacobi, 'scipy_dct-numpy') class ScipyFastChebyshevTransform(FastChebyshevTransform, ScipyDCT): """Fast ultraspherical transform using scipy.fft and spectral conversion.""" pass # Implementation is complete via inheritance -@register_transform(basis.Jacobi, 'fftw_dct') +@register_transform(basis.Jacobi, 'fftw_dct-numpy') class FFTWFastChebyshevTransform(FastChebyshevTransform, FFTWDCT): """Fast ultraspherical transform using scipy.fft and spectral conversion.""" pass # Implementation is complete via inheritance +@register_transform(basis.Jacobi, 'scipy_dct-cupy') +class CupyFastChebyshevTransform(FastChebyshevTransform, CupyDCT): + """Fast ultraspherical transform using cupy fft and spectral conversion.""" + pass # Implementation is complete via inheritance + + # class ScipyDST(PolynomialTransform): # def forward_reduced(self): diff --git a/dedalus/tools/array.py b/dedalus/tools/array.py index 2d6a69e6..e137f75d 100644 --- a/dedalus/tools/array.py +++ b/dedalus/tools/array.py @@ -79,10 +79,20 @@ def expand_pattern(input, pattern): def apply_matrix(matrix, array, axis, **kw): """Apply matrix along any axis of an array.""" - if sparse.isspmatrix(matrix): - return apply_sparse(matrix, array, axis, **kw) + xp = array_api_compat.array_namespace(array) + if array_api_compat.is_numpy_namespace(xp): + if sparse.issparse(matrix): + return apply_sparse(matrix, array, axis, **kw) + else: + return apply_dense(matrix, array, axis, **kw) + elif array_api_compat.is_cupy_namespace(xp): + import cupyx.scipy.sparse as csp + if csp.issparse(matrix): + return apply_sparse(matrix, array, axis, **kw) + else: + return apply_dense(matrix, array, axis, **kw) else: - return apply_dense(matrix, array, axis, **kw) + raise ValueError("Unsupported array type") def apply_dense_einsum(matrix, array, axis, optimize=True, **kw): @@ -177,6 +187,8 @@ def apply_sparse(matrix, array, axis, out=None, check_shapes=False, num_threads= Must be out of place if ouptut is specified. """ xp = array_api_compat.array_namespace(array) + matrix.sum_duplicates() + matrix.has_canonical_format = True # Check output if out is None: out_shape = list(array.shape) @@ -219,28 +231,40 @@ def solve_upper_sparse(matrix, rhs, axis, out=None, check_shapes=False, num_thre Solve upper triangular sparse matrix along any axis of an array. Matrix assumed to be nonzero on the diagonals. """ - # Check matrix - if not isinstance(matrix, sparse.csr_matrix): - raise ValueError("Matrix must be in CSR format.") - if not matrix._has_canonical_format: # avoid property hook (without underscore) - matrix.sum_duplicates() - # Setup output = rhs + xp = array_api_compat.array_namespace(rhs) + matrix.sum_duplicates() + matrix.has_canonical_format = True + # Check output if out is None: - out = np.copy(rhs) - elif out is not rhs: - np.copyto(out, rhs) - # Promote datatypes - matrix_data = matrix.data - if matrix_data.dtype != rhs.dtype: - matrix_data = matrix_data.astype(rhs.dtype) - # Check shapes - if check_shapes: - if not (0 <= axis < rhs.ndim): - raise ValueError("Axis out of bounds.") - if not (matrix.shape[0] == matrix.shape[1] == rhs.shape[axis]): - raise ValueError("Matrix shape mismatch.") - # Call cython routine - cython_linalg.solve_upper_csr(matrix.indptr, matrix.indices, matrix_data, out, axis, num_threads) + out = xp.empty_like(rhs) + # Dispatch on array type + if array_api_compat.is_numpy_namespace(xp): + # Check matrix + if not isinstance(matrix, sparse.csr_matrix): + raise ValueError("Matrix must be in CSR format.") + if not matrix._has_canonical_format: # avoid property hook (without underscore) + matrix.sum_duplicates() + # Setup output = rhs + copyto(out, rhs) + # Promote datatypes + matrix_data = matrix.data + if matrix_data.dtype != rhs.dtype: + matrix_data = matrix_data.astype(rhs.dtype) + # Check shapes + if check_shapes: + if not (0 <= axis < rhs.ndim): + raise ValueError("Axis out of bounds.") + if not (matrix.shape[0] == matrix.shape[1] == rhs.shape[axis]): + raise ValueError("Matrix shape mismatch.") + # Call cython routine + cython_linalg.solve_upper_csr(matrix.indptr, matrix.indices, matrix_data, out, axis, num_threads) + elif array_api_compat.is_cupy_namespace(xp): + if not matrix._has_canonical_format: # avoid property hook (without underscore) + matrix.sum_duplicates() + linalg_gpu.cupy_solve_upper_csr(matrix, rhs, axis, out) + else: + raise ValueError("Unsupported array type") + return out def csr_matvec(A_csr, x_vec, out_vec): diff --git a/dedalus/tools/linalg_gpu.py b/dedalus/tools/linalg_gpu.py index 220d3615..a64f51f1 100644 --- a/dedalus/tools/linalg_gpu.py +++ b/dedalus/tools/linalg_gpu.py @@ -5,6 +5,7 @@ try: import cupy as cp import cupyx.scipy.sparse as csp + import cupyx.scipy.sparse.linalg as cupy_spla cupy_available = True except ImportError: cupy_available = False @@ -17,6 +18,7 @@ def cupy_apply_csr(matrix, array, axis, out): # Check matrix format if not isinstance(matrix, csp.csr_matrix): # TODO: avoid this explicit conversion + print('WARNING: converting matrix to CSR format') matrix = csp.csr_matrix(matrix) #raise ValueError("Matrix must be in CSR format.") # Switch by dimension @@ -367,3 +369,177 @@ def spsm(A, B, lower, transa, spsm_descr): # For compatibility with SciPy x = x.copy(order='F') return x + + +class CustomCupyUpperTriangularSolver: + """Hacky class to save spsm_descr for reuse in spsm for triangular solves.""" + + def __init__(self, matrix): + # Check matrix format + if not isinstance(matrix, csp.csr_matrix): + # TODO: avoid this explicit conversion + matrix = csp.csr_matrix(matrix) + print('WARNING: converting matrix to CSR format') + #raise ValueError("Matrix must be in CSR format.") + self.matrix = matrix + self.spsm_descr = None + + def solve(self, b, lower=True, overwrite_A=False, overwrite_b=False, + unit_diagonal=False): + """Solves a sparse triangular system ``A x = b``. + + Args: + A (cupyx.scipy.sparse.spmatrix): + Sparse matrix with dimension ``(M, M)``. + b (cupy.ndarray): + Dense vector or matrix with dimension ``(M)`` or ``(M, K)``. + lower (bool): + Whether ``A`` is a lower or upper triangular matrix. + If True, it is lower triangular, otherwise, upper triangular. + overwrite_A (bool): + (not supported) + overwrite_b (bool): + Allows overwriting data in ``b``. + unit_diagonal (bool): + If True, diagonal elements of ``A`` are assumed to be 1 and will + not be referenced. + + Returns: + cupy.ndarray: + Solution to the system ``A x = b``. The shape is the same as ``b``. + """ + from cupyx import cusparse + from cupyx.scipy import sparse + import cupy + from cupyx.scipy.sparse.linalg._solve import _should_use_spsm + + A = self.matrix + + if not (cusparse.check_availability('spsm') or + cusparse.check_availability('csrsm2')): + raise NotImplementedError + + if not sparse.isspmatrix(A): + raise TypeError('A must be cupyx.scipy.sparse.spmatrix') + if not isinstance(b, cupy.ndarray): + raise TypeError('b must be cupy.ndarray') + if A.shape[0] != A.shape[1]: + raise ValueError(f'A must be a square matrix (A.shape: {A.shape})') + if b.ndim not in [1, 2]: + raise ValueError(f'b must be 1D or 2D array (b.shape: {b.shape})') + if A.shape[0] != b.shape[0]: + raise ValueError('The size of dimensions of A must be equal to the ' + 'size of the first dimension of b ' + f'(A.shape: {A.shape}, b.shape: {b.shape})') + if A.dtype.char not in 'fdFD': + raise TypeError(f'unsupported dtype (actual: {A.dtype})') + + if cusparse.check_availability('spsm') and _should_use_spsm(b): + if not (sparse.isspmatrix_csr(A) or + sparse.isspmatrix_csc(A) or + sparse.isspmatrix_coo(A)): + warnings.warn('CSR, CSC or COO format is required. Converting to ' + 'CSR format.', sparse.SparseEfficiencyWarning) + A = A.tocsr() + A.sum_duplicates() + x, self.spsm_descr = custom_spsm(A, b, lower=lower, unit_diag=unit_diagonal, spsm_descr=self.spsm_descr) + elif cusparse.check_availability('csrsm2'): + if not (sparse.isspmatrix_csr(A) or sparse.isspmatrix_csc(A)): + warnings.warn('CSR or CSC format is required. Converting to CSR ' + 'format.', sparse.SparseEfficiencyWarning) + A = A.tocsr() + A.sum_duplicates() + + if (overwrite_b and A.dtype == b.dtype and + (b._c_contiguous or b._f_contiguous)): + x = b + else: + x = b.astype(A.dtype, copy=True) + + cusparse.csrsm2(A, x, lower=lower, unit_diag=unit_diagonal) + else: + assert False + + if x.dtype.char in 'fF': + # Note: This is for compatibility with SciPy. + dtype = numpy.promote_types(x.dtype, 'float64') + x = x.astype(dtype) + return x + + +def cupy_solve_upper_csr(matrix, array, axis, out): + """Solve upper triangular CSR matrix along specified axis of an array.""" + # Switch by dimension + ndim = array.ndim + if ndim == 1: + if axis == 0: + cupy_solve_upper_csr_vec(matrix, array, out) + else: + raise ValueError("axis must be 0 for 1D arrays") + elif ndim == 2: + if axis == 0: + if array.shape[1] == 1: + cupy_solve_upper_csr_vec(matrix, array[:,0], out[:,0]) + else: + cupy_solve_upper_csr_first(matrix, array, out) + elif axis == 1: + if array.shape[0] == 1: + cupy_solve_upper_csr_vec(matrix, array[0,:], out[0,:]) + else: + cupy_solve_upper_csr_last(matrix, array, out) + else: + raise ValueError("axis must be 0 or 1 for 2D arrays") + else: + # Treat as 3D array with specified axis in the middle + # Compute equivalent shape (N1, N2, N3) + if ndim == 3 and axis == 1: + N1 = shape[0] + N2 = shape[1] + N3 = shape[2] + else: + N1 = int(np.prod(array.shape[:axis])) + N2 = array.shape[axis] + N3 = int(np.prod(array.shape[axis+1:])) + # Dispatch to cupy routines + if N1 == 1: + if N3 == 1: + # (1, N2, 1) -> (N2,) + x1 = array.reshape((N2,)) + y1 = out.reshape((N2,)) + cupy_solve_upper_csr_vec(matrix, x1, y1) + else: + # (1, N2, N3) -> (N2, N3) + x2 = array.reshape((N2, N3)) + y2 = out.reshape((N2, N3)) + cupy_solve_upper_csr_first(matrix, x2, y2) + else: + if N3 == 1: + # (N1, N2, 1) -> (N1, N2) + x2 = array.reshape((N1, N2)) + y2 = out.reshape((N1, N2)) + cupy_solve_upper_csr_last(matrix, x2, y2) + else: + # (N1, N2, N3) + x3 = array.reshape((N1, N2, N3)) + y3 = out.reshape((N1, N2, N3)) + cupy_solve_upper_csr_mid(matrix, x3, y3) + + +def cupy_solve_upper_csr_vec(matrix, vec, out): + """Solve upper triangular CSR matrix along a vector.""" + out[:] = matrix.solve(vec, lower=False) + + +def cupy_solve_upper_csr_first(matrix, array, out): + """Solve upper triangular CSR matrix along first axis of 2D array.""" + out[:] = matrix.solve(array, lower=False) + + +def cupy_solve_upper_csr_last(matrix, array, out): + """Solve upper triangular CSR matrix along last axis of 2D array.""" + out.T[:] = matrix.solve(array.T, lower=False) + + +def cupy_solve_upper_csr_mid(matrix, array, out): + """Solve upper triangular CSR matrix along middle axis of 3D array.""" + raise NotImplementedError From 746a2a6c687ee436a502b393d0b422954a3d4aee Mon Sep 17 00:00:00 2001 From: "Calum S. Skene" Date: Tue, 16 Dec 2025 11:51:26 +0000 Subject: [PATCH 29/50] Make einsum in trace compatible with cupy --- dedalus/core/operators.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dedalus/core/operators.py b/dedalus/core/operators.py index c0d3416e..5540e86a 100644 --- a/dedalus/core/operators.py +++ b/dedalus/core/operators.py @@ -1817,8 +1817,10 @@ def operate(self, out): xp = self.array_namespace arg = self.args[0] out.preset_layout(arg.layout) - xp.einsum('ii...', arg.data, out=out.data) - + if array_api_compat.is_cupy_namespace(xp): + out.data[:] = xp.einsum('ii...', arg.data) + else: + xp.einsum('ii...', arg.data, out=out.data) class SphericalTrace(Trace): From 18e713a56208b7c00dd2d9c5ba078da7fe727a59 Mon Sep 17 00:00:00 2001 From: "Calum S. Skene" Date: Tue, 10 Mar 2026 15:26:12 +0000 Subject: [PATCH 30/50] Fix dtype for MultiplyNumberField --- dedalus/core/arithmetic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dedalus/core/arithmetic.py b/dedalus/core/arithmetic.py index 3b8398c7..570f9505 100644 --- a/dedalus/core/arithmetic.py +++ b/dedalus/core/arithmetic.py @@ -927,7 +927,7 @@ def __init__(self, arg0, arg1, out=None,**kw): super().__init__(arg0, arg1, out=out) self.domain = arg1.domain self.tensorsig = arg1.tensorsig - self.dtype = np.result_type(type(arg0), arg1.dtype) + self.dtype = np.result_type(arg0, arg1.dtype) @classmethod def _check_args(cls, *args, **kw): From 9ed5d0f75b5c9cd6c6f522ea389a0e5430bf1276 Mon Sep 17 00:00:00 2001 From: "Calum S. Skene" Date: Tue, 28 Apr 2026 16:04:34 +0100 Subject: [PATCH 31/50] Specify dtype for the CFL reducer --- dedalus/extras/flow_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dedalus/extras/flow_tools.py b/dedalus/extras/flow_tools.py index bc6798ba..18358867 100644 --- a/dedalus/extras/flow_tools.py +++ b/dedalus/extras/flow_tools.py @@ -181,7 +181,7 @@ def __init__(self, solver, initial_dt, cadence=1, safety=1., max_dt=np.inf, self.min_change = min_change self.threshold = threshold - self.reducer = GlobalArrayReducer(self.solver.dist.comm_cart) + self.reducer = GlobalArrayReducer(self.solver.dist.comm_cart, solver.dtype) self.frequencies = self.solver.evaluator.add_dictionary_handler(iter=cadence) def compute_dt(self): From 9fbba8d89680557ee49b8dc04797d48591364654 Mon Sep 17 00:00:00 2001 From: "Calum S. Skene" Date: Tue, 28 Apr 2026 16:39:18 +0100 Subject: [PATCH 32/50] Ensure timestepping coefficients are the correct dtype --- dedalus/core/timesteppers.py | 83 +++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 39 deletions(-) diff --git a/dedalus/core/timesteppers.py b/dedalus/core/timesteppers.py index 162a2d32..1f87bd4c 100644 --- a/dedalus/core/timesteppers.py +++ b/dedalus/core/timesteppers.py @@ -117,7 +117,7 @@ def step(self, dt, wall_time): self.dt[0] = dt # Compute IMEX coefficients - a, b, c = self.compute_coefficients(self.dt, self._iteration) + a, b, c = self.compute_coefficients(self.dt, self._iteration, self.solver.dtype) self._iteration += 1 # Update RHS components and LHS matrices @@ -203,11 +203,11 @@ class CNAB1(MultistepIMEX): steps = 1 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k0, *rest = timesteps @@ -236,11 +236,11 @@ class SBDF1(MultistepIMEX): steps = 1 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k0, *rest = timesteps @@ -268,14 +268,14 @@ class CNAB2(MultistepIMEX): steps = 2 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): if iteration < 1: - return CNAB1.compute_coefficients(timesteps, iteration) + return CNAB1.compute_coefficients(timesteps, iteration, dtype) - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k1, k0, *rest = timesteps w1 = k1 / k0 @@ -306,14 +306,14 @@ class MCNAB2(MultistepIMEX): steps = 2 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): if iteration < 1: - return CNAB1.compute_coefficients(timesteps, iteration) + return CNAB1.compute_coefficients(timesteps, iteration, dtype) - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k1, k0, *rest = timesteps w1 = k1 / k0 @@ -345,14 +345,14 @@ class SBDF2(MultistepIMEX): steps = 2 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): if iteration < 1: - return SBDF1.compute_coefficients(timesteps, iteration) + return SBDF1.compute_coefficients(timesteps, iteration, dtype=dtype) - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k1, k0, *rest = timesteps w1 = k1 / k0 @@ -383,14 +383,14 @@ class CNLF2(MultistepIMEX): steps = 2 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): if iteration < 1: - return CNAB1.compute_coefficients(timesteps, iteration) + return CNAB1.compute_coefficients(timesteps, iteration, dtype) - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k1, k0, *rest = timesteps w1 = k1 / k0 @@ -422,14 +422,14 @@ class SBDF3(MultistepIMEX): steps = 3 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): if iteration < 2: - return SBDF2.compute_coefficients(timesteps, iteration) + return SBDF2.compute_coefficients(timesteps, iteration, dtype) - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k2, k1, k0, *rest = timesteps w2 = k2 / k1 @@ -463,14 +463,14 @@ class SBDF4(MultistepIMEX): steps = 4 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): if iteration < 3: - return SBDF3.compute_coefficients(timesteps, iteration) + return SBDF3.compute_coefficients(timesteps, iteration, dtype) - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k3, k2, k1, k0, *rest = timesteps w3 = k3 / k2 @@ -550,6 +550,11 @@ def __init__(self, solver): self._LHS_params = None self.axpy = get_axpy(xp, solver.dtype) + # Cast scheme coefficients + self.A = self.A.astype(self.solver.dtype) + self.H = self.H.astype(self.solver.dtype) + self.c = self.c.astype(self.solver.dtype) + def step(self, dt, wall_time): """Advance solver by one timestep.""" From ce4a1e40fa05c6def8b5c2debc81d0ac82edd88c Mon Sep 17 00:00:00 2001 From: "Calum S. Skene" Date: Tue, 28 Apr 2026 17:09:20 +0100 Subject: [PATCH 33/50] Convert Jacobi conversion matrices to specified dtype --- dedalus/core/transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dedalus/core/transforms.py b/dedalus/core/transforms.py index 0433d802..6b196c12 100644 --- a/dedalus/core/transforms.py +++ b/dedalus/core/transforms.py @@ -949,12 +949,12 @@ def __init__(self, grid_size, coeff_size, a, b, a0, b0, array_namespace, dtype, else: # Conversion matrices if self.dealias_before_converting and (self.M_orig < self.N): # truncate prior to conversion matrix - self.forward_conversion = jacobi.conversion_matrix(self.M_orig, a0, b0, a, b).tocsr() + self.forward_conversion = jacobi.conversion_matrix(self.M_orig, a0, b0, a, b).tocsr().astype(dtype) else: # input to conversion matrix not truncated self.forward_conversion = jacobi.conversion_matrix(self.N, a0, b0, a, b) self.forward_conversion.resize(self.M_orig, self.N) - self.forward_conversion = self.forward_conversion.tocsr() - self.backward_conversion = jacobi.conversion_matrix(self.M_orig, a0, b0, a, b).tocsr() + self.forward_conversion = self.forward_conversion.tocsr().astype(dtype) + self.backward_conversion = jacobi.conversion_matrix(self.M_orig, a0, b0, a, b).tocsr().astype(dtype) self.backward_conversion.sum_duplicates() # for faster solve_upper self.resize_rescale_forward = self._resize_rescale_forward_convert self.resize_rescale_backward = self._resize_rescale_backward_convert From 2d2a0becbf812839fae469d599edd3a04f1c8406 Mon Sep 17 00:00:00 2001 From: "Calum S. Skene" Date: Wed, 29 Apr 2026 09:09:41 +0100 Subject: [PATCH 34/50] Specify dtype for GlobalFlowProperty reducer --- dedalus/extras/flow_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dedalus/extras/flow_tools.py b/dedalus/extras/flow_tools.py index 18358867..b99a5069 100644 --- a/dedalus/extras/flow_tools.py +++ b/dedalus/extras/flow_tools.py @@ -86,7 +86,7 @@ def __init__(self, solver, cadence=1): self.solver = solver self.cadence = cadence - self.reducer = GlobalArrayReducer(solver.dist.comm_cart) + self.reducer = GlobalArrayReducer(solver.dist.comm_cart, solver.dtype) self.properties = solver.evaluator.add_dictionary_handler(iter=cadence) def add_property(self, property, name, precompute_integral=False): From 0bf770dde26cf214d73a3ca630d292267a1ad2d8 Mon Sep 17 00:00:00 2001 From: "Calum S. Skene" Date: Wed, 29 Apr 2026 09:42:10 +0100 Subject: [PATCH 35/50] Check if buff size grows even for same spsm descriptor. If it does make a new buff and recompute the analysis. --- dedalus/tools/linalg_gpu.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/dedalus/tools/linalg_gpu.py b/dedalus/tools/linalg_gpu.py index a64f51f1..e2de9009 100644 --- a/dedalus/tools/linalg_gpu.py +++ b/dedalus/tools/linalg_gpu.py @@ -276,14 +276,24 @@ def custom_spsm(a, b, alpha=1.0, lower=True, unit_diag=False, transa=False, spsm mat_a.set_attribute(_cusparse.CUSPARSE_SPMAT_DIAG_TYPE, diag_type) # Allocate the workspace needed by the succeeding phases - if new_spsm_descr: - buff_size = _cusparse.spSM_bufferSize( - handle, op_a, op_b, alpha.data, mat_a.desc, mat_b.desc, - mat_c.desc, cuda_dtype, algo, spsm_descr) + # Always calculate workspace (buff_size can change even for same spsm + # descriptor) + buff_size = _cusparse.spSM_bufferSize( + handle, op_a, op_b, alpha.data, mat_a.desc, mat_b.desc, + mat_c.desc, cuda_dtype, algo, spsm_descr) + + need_analysis = new_spsm_descr + if new_spsm_descr: buff = _cupy.empty(buff_size, dtype=_cupy.int8) + else: + # Check if buff size grew from that in the cache + if buff is None or buff.size < buff_size: + buff = _cupy.empty(buff_size, dtype=_cupy.int8) + # buff changed so need the analysis phase + need_analysis = True # Perform the analysis phase - if new_spsm_descr: + if need_analysis: _cusparse.spSM_analysis( handle, op_a, op_b, alpha.data, mat_a.desc, mat_b.desc, mat_c.desc, cuda_dtype, algo, spsm_descr, buff.data.ptr) From 4cc47acd551d5db234fdc26b4c58cbeb3181c0c3 Mon Sep 17 00:00:00 2001 From: "Calum S. Skene" Date: Wed, 29 Apr 2026 09:55:36 +0100 Subject: [PATCH 36/50] Add custom kernel for apply_csr_mid for dtype float32 --- dedalus/tools/linalg_gpu.py | 55 +++++++++++++++++++++++++++++++------ 1 file changed, 47 insertions(+), 8 deletions(-) diff --git a/dedalus/tools/linalg_gpu.py b/dedalus/tools/linalg_gpu.py index e2de9009..ca80c07a 100644 --- a/dedalus/tools/linalg_gpu.py +++ b/dedalus/tools/linalg_gpu.py @@ -78,9 +78,9 @@ def cupy_apply_csr(matrix, array, axis, out): # Kernel for applying CSR matrix with parallelization over n1 and n3 -apply_csr_mid_kernel = cp.RawKernel( +apply_csr_mid_kernel_f64 = cp.RawKernel( r''' - extern "C" __global__ void apply_csr_mid_kernel( + extern "C" __global__ void apply_csr_mid_kernel_f64( const double* data, // CSR data of shape (nnz,) const int* indices, // CSR column indices (nnz,) const int* indptr, // CSR row pointers (N2o + 1,) @@ -109,8 +109,40 @@ def cupy_apply_csr(matrix, array, axis, out): } } ''', - 'apply_csr_mid_kernel') + 'apply_csr_mid_kernel_f64') +apply_csr_mid_kernel_f32 = cp.RawKernel( + r''' + extern "C" __global__ void apply_csr_mid_kernel_f32( + const float* data, // CSR data of shape (nnz,) + const int* indices, // CSR column indices (nnz,) + const int* indptr, // CSR row pointers (N2o + 1,) + const float* input, // shape (N1, N2i, N3) + float* output, // shape (N1, N2o, N3) + int N1, int N2i, int N2o, int N3) + { + int n1 = blockIdx.x * blockDim.x + threadIdx.x ; // batch index + int n3 = blockIdx.y * blockDim.y + threadIdx.y; // output column index + + if (n1 >= N1 || n3 >= N3) return; + + // Loop over output rows = CSR matrix rows + for (int i = 0; i < N2o; ++i) { + float acc = 0; + int start = indptr[i]; + int end = indptr[i + 1]; + + for (int k = start; k < end; ++k) { + int j = indices[k]; // input column + float val = data[k]; + acc += val * input[n1 * N2i * N3 + j * N3 + n3]; + } + + output[n1 * N2o * N3 + i * N3 + n3] = acc; + } + } + ''', + 'apply_csr_mid_kernel_f32') def cupy_apply_csr_mid(matrix, array, out): N1, N2i, N3 = array.shape @@ -123,7 +155,13 @@ def cupy_apply_csr_mid(matrix, array, out): blocks_y = (N3 + threads_y - 1) // threads_y griddim = (blocks_x, blocks_y) # Launch kernel - apply_csr_mid_kernel(griddim, blockdim, (matrix.data, matrix.indices, matrix.indptr, array, out, N1, N2i, N2o, N3)) + if matrix.dtype == cp.float64: + apply_csr_mid_kernel_f64(griddim, blockdim, (matrix.data, matrix.indices, matrix.indptr, array, out, N1, N2i, N2o, N3)) + elif matrix.dtype == cp.float32: + apply_csr_mid_kernel_f32(griddim, blockdim, (matrix.data, matrix.indices, matrix.indptr, array, out, N1, N2i, N2o, N3)) + else: + raise NotImplementedError(f'No apply_csr_mid_kernel for dtype {matrix.dtype}') + def custom_spsm(a, b, alpha=1.0, lower=True, unit_diag=False, transa=False, spsm_descr=None): @@ -470,10 +508,11 @@ def solve(self, b, lower=True, overwrite_A=False, overwrite_b=False, else: assert False - if x.dtype.char in 'fF': - # Note: This is for compatibility with SciPy. - dtype = numpy.promote_types(x.dtype, 'float64') - x = x.astype(dtype) + # TODO: Check if need this (breaks things for float32?) + # if x.dtype.char in 'fF': + # # Note: This is for compatibility with SciPy. + # dtype = numpy.promote_types(x.dtype, 'float64') + # x = x.astype(dtype) return x From b85f2703df9f61cbd659960e02d1adff59cd9933 Mon Sep 17 00:00:00 2001 From: "Calum S. Skene" Date: Wed, 29 Apr 2026 10:03:11 +0100 Subject: [PATCH 37/50] Convert subspace matrices to specified dtype --- dedalus/core/operators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dedalus/core/operators.py b/dedalus/core/operators.py index 5540e86a..717f553d 100644 --- a/dedalus/core/operators.py +++ b/dedalus/core/operators.py @@ -989,7 +989,7 @@ def group_matrix(self, group): @CachedMethod def _subspace_matrix(cls, layout, input_basis, output_basis, axis, *args): if cls.subaxis_coupling[0]: - return cls._full_matrix(input_basis, output_basis, *args) + return cls._full_matrix(input_basis, output_basis, *args).astype(layout.dist.dtype) else: input_domain = Domain(layout.dist, bases=[input_basis]) output_domain = Domain(layout.dist, bases=[output_basis]) @@ -1003,7 +1003,7 @@ def _subspace_matrix(cls, layout, input_basis, output_basis, axis, *args): group_blocks = [cls._group_matrix(group, input_basis, output_basis, *args) for group in groups] arg_size = layout.local_shape(input_domain, scales=1)[axis] out_size = layout.local_shape(output_domain, scales=1)[axis] - return sparse_block_diag(group_blocks, shape=(out_size, arg_size)) + return sparse_block_diag(group_blocks, shape=(out_size, arg_size)).astype(layout.dist.dtype) @staticmethod def _full_matrix(input_basis, output_basis, *args): From 5aa0f1f26846ae648c6c7d5a956e4888e6e545d3 Mon Sep 17 00:00:00 2001 From: "Calum S. Skene" Date: Tue, 12 May 2026 15:20:26 +0100 Subject: [PATCH 38/50] Use jit.rawkernel for apply_csr_mid_kernel --- dedalus/tools/linalg_gpu.py | 93 ++++++++----------------------------- 1 file changed, 20 insertions(+), 73 deletions(-) diff --git a/dedalus/tools/linalg_gpu.py b/dedalus/tools/linalg_gpu.py index ca80c07a..f1de3d90 100644 --- a/dedalus/tools/linalg_gpu.py +++ b/dedalus/tools/linalg_gpu.py @@ -6,6 +6,7 @@ import cupy as cp import cupyx.scipy.sparse as csp import cupyx.scipy.sparse.linalg as cupy_spla + from cupyx import jit cupy_available = True except ImportError: cupy_available = False @@ -77,76 +78,28 @@ def cupy_apply_csr(matrix, array, axis, out): cupy_apply_csr_mid(matrix, x3, y3) -# Kernel for applying CSR matrix with parallelization over n1 and n3 -apply_csr_mid_kernel_f64 = cp.RawKernel( - r''' - extern "C" __global__ void apply_csr_mid_kernel_f64( - const double* data, // CSR data of shape (nnz,) - const int* indices, // CSR column indices (nnz,) - const int* indptr, // CSR row pointers (N2o + 1,) - const double* input, // shape (N1, N2i, N3) - double* output, // shape (N1, N2o, N3) - int N1, int N2i, int N2o, int N3) - { - int n1 = blockIdx.x * blockDim.x + threadIdx.x ; // batch index - int n3 = blockIdx.y * blockDim.y + threadIdx.y; // output column index - - if (n1 >= N1 || n3 >= N3) return; - - // Loop over output rows = CSR matrix rows - for (int i = 0; i < N2o; ++i) { - double acc = 0; - int start = indptr[i]; - int end = indptr[i + 1]; - - for (int k = start; k < end; ++k) { - int j = indices[k]; // input column - double val = data[k]; - acc += val * input[n1 * N2i * N3 + j * N3 + n3]; - } - - output[n1 * N2o * N3 + i * N3 + n3] = acc; - } - } - ''', - 'apply_csr_mid_kernel_f64') - -apply_csr_mid_kernel_f32 = cp.RawKernel( - r''' - extern "C" __global__ void apply_csr_mid_kernel_f32( - const float* data, // CSR data of shape (nnz,) - const int* indices, // CSR column indices (nnz,) - const int* indptr, // CSR row pointers (N2o + 1,) - const float* input, // shape (N1, N2i, N3) - float* output, // shape (N1, N2o, N3) - int N1, int N2i, int N2o, int N3) - { - int n1 = blockIdx.x * blockDim.x + threadIdx.x ; // batch index - int n3 = blockIdx.y * blockDim.y + threadIdx.y; // output column index - - if (n1 >= N1 || n3 >= N3) return; - - // Loop over output rows = CSR matrix rows - for (int i = 0; i < N2o; ++i) { - float acc = 0; - int start = indptr[i]; - int end = indptr[i + 1]; - - for (int k = start; k < end; ++k) { - int j = indices[k]; // input column - float val = data[k]; - acc += val * input[n1 * N2i * N3 + j * N3 + n3]; - } - - output[n1 * N2o * N3 + i * N3 + n3] = acc; - } - } - ''', - 'apply_csr_mid_kernel_f32') +@jit.rawkernel() +def apply_csr_mid_kernel(data, indices, indptr, x3, y3, N1, N2i, N2o, N3): + n1 = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x # batch index + n3 = jit.blockIdx.y * jit.blockDim.y + jit.threadIdx.y # output column index + if n1 >= N1 or n3 >= N3: + return + # Loop over output rows = CSR matrix rows + for i in range(N2o): + y3[n1, i, n3] = 0 + start = indptr[i] + end = indptr[i + 1] + for k in range(start, end): + j = indices[k] + y3[n1, i, n3] += data[k] * x3[n1, j, n3] def cupy_apply_csr_mid(matrix, array, out): N1, N2i, N3 = array.shape N2o = matrix.shape[0] + N1 = cp.uint32(N1) + N2i = cp.uint32(N2i) + N3 = cp.uint32(N3) + N2o = cp.uint32(N2o) # Choose thread/block config threads_y = min(1024, N3) # maximize concurrency along n3 threads_x = 1024 // threads_y # make block have 1024 threads @@ -155,13 +108,7 @@ def cupy_apply_csr_mid(matrix, array, out): blocks_y = (N3 + threads_y - 1) // threads_y griddim = (blocks_x, blocks_y) # Launch kernel - if matrix.dtype == cp.float64: - apply_csr_mid_kernel_f64(griddim, blockdim, (matrix.data, matrix.indices, matrix.indptr, array, out, N1, N2i, N2o, N3)) - elif matrix.dtype == cp.float32: - apply_csr_mid_kernel_f32(griddim, blockdim, (matrix.data, matrix.indices, matrix.indptr, array, out, N1, N2i, N2o, N3)) - else: - raise NotImplementedError(f'No apply_csr_mid_kernel for dtype {matrix.dtype}') - + apply_csr_mid_kernel(griddim, blockdim, (matrix.data, matrix.indices, matrix.indptr, array, out, N1, N2i, N2o, N3)) def custom_spsm(a, b, alpha=1.0, lower=True, unit_diag=False, transa=False, spsm_descr=None): From a124d8efdcbf76e0680db1db333fdd47dd5c932c Mon Sep 17 00:00:00 2001 From: "Calum S. Skene" Date: Tue, 12 May 2026 15:30:55 +0100 Subject: [PATCH 39/50] Accumulate into acc in apply_csr_mid_kernel --- dedalus/tools/linalg_gpu.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dedalus/tools/linalg_gpu.py b/dedalus/tools/linalg_gpu.py index f1de3d90..395557b8 100644 --- a/dedalus/tools/linalg_gpu.py +++ b/dedalus/tools/linalg_gpu.py @@ -86,12 +86,13 @@ def apply_csr_mid_kernel(data, indices, indptr, x3, y3, N1, N2i, N2o, N3): return # Loop over output rows = CSR matrix rows for i in range(N2o): - y3[n1, i, n3] = 0 + acc = 0 * y3[n1, i, n3] # get right type start = indptr[i] end = indptr[i + 1] for k in range(start, end): j = indices[k] - y3[n1, i, n3] += data[k] * x3[n1, j, n3] + acc += data[k] * x3[n1, j, n3] + y3[n1, i, n3] = acc def cupy_apply_csr_mid(matrix, array, out): N1, N2i, N3 = array.shape From 99655afe01fe8369452389f8db1d4278c4ee48bd Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 12 May 2026 10:34:48 -0400 Subject: [PATCH 40/50] Update distributor docstring --- dedalus/core/distributor.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/dedalus/core/distributor.py b/dedalus/core/distributor.py index 99ada2e9..cfb2a11c 100644 --- a/dedalus/core/distributor.py +++ b/dedalus/core/distributor.py @@ -40,12 +40,16 @@ class Distributor: Parameters ---------- - dim : int - Dimension + coordsystems : CoordinateSystem or tuple of CoordinateSystems + Problem coordinate systems comm : MPI communicator, optional MPI communicator (default: comm world) mesh : tuple of ints, optional Process mesh for parallelization (default: 1-D mesh of available processes) + dtype : data type, optional + Default data type for fields (default: None) + array_namespace : array namespace or string, optional + Array namespace for field data (e.g. numpy or cupy, default: numpy) Attributes ---------- From b6a88139b03266e3af1d847e956638b48d967c81 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 12 May 2026 11:05:13 -0400 Subject: [PATCH 41/50] Start simplifying transform library defaults (broken for curvilinear) --- dedalus/core/basis.py | 66 +++++++++++++++++++++----------------- dedalus/core/transforms.py | 31 ++++++++---------- 2 files changed, 51 insertions(+), 46 deletions(-) diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index f323a014..c475a81e 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -5,6 +5,7 @@ from functools import reduce import inspect from math import prod +import array_api_compat from . import operators from ..libraries import spin_recombination @@ -595,8 +596,10 @@ class Jacobi(IntervalBasis, metaclass=CachedClass): group_shape = (1,) native_bounds = (-1, 1) transforms = {} - default_dct = "fftw_dct" - default_library = "matrix" + default_cpu_library = "matrix" + default_gpu_library = "matrix" + default_cpu_dct = "fftw" + default_gpu_dct = "cupy" @classmethod def _preprocess_cache_args(cls, coord, size, bounds, a, b, a0, b0, dealias, library): @@ -631,12 +634,6 @@ def _preprocess_cache_args(cls, coord, size, bounds, a, b, a0, b0, dealias, libr dealias = tuple(dealias) if len(dealias) != 1: raise ValueError("Jacobi dealias must have length 1.") - # library: pick default based on (a0, b0) - if library is None: - if a0 == b0 == -1/2: - library = cls.default_dct - else: - library = cls.default_library return (coord, size, bounds, a, b, a0, b0, dealias, library) def __init__(self, coord, size, bounds, a, b, a0=None, b0=None, dealias=(1,), library=None): @@ -660,16 +657,30 @@ def _native_grid(self, scale): N, = self.grid_shape((scale,)) return jacobi.build_grid(N, a=self.a0, b=self.b0) + def get_library(self, dist): + """Get library for transforms.""" + if self.library is None: + if self.a0 == self.b0 == -1/2: + if array_api_compat.is_cupy_namespace(dist.array_namespace): + return self.default_gpu_dct + else: + return self.default_cpu_dct + else: + if array_api_compat.is_cupy_namespace(dist.array_namespace): + return self.default_gpu_library + else: + return self.default_cpu_library + else: + return self.library + @CachedMethod def transform_plan(self, dist, grid_size): """Build transform plan.""" - xp = dist.array_namespace - xp_name = xp.__name__.split('.')[-1] # Shortcut trivial transforms if grid_size == 1 or self.size == 1: - return self.transforms[f"matrix-{xp_name}"](grid_size, self.size, self.a, self.b, self.a0, self.b0, dist.array_namespace, dist.dtype) + return self.transforms["matrix"](grid_size, self.size, self.a, self.b, self.a0, self.b0, dist.array_namespace, dist.dtype) else: - return self.transforms[f"{self.library}-{xp_name}"](grid_size, self.size, self.a, self.b, self.a0, self.b0, dist.array_namespace, dist.dtype) + return self.transforms[self.get_library(dist)](grid_size, self.size, self.a, self.b, self.a0, self.b0, dist.array_namespace, dist.dtype) # def weights(self, scales): # """Gauss-Jacobi weights.""" @@ -981,7 +992,8 @@ class FourierBase(IntervalBasis): """Base class for RealFourier and ComplexFourier.""" native_bounds = (0, 2*np.pi) - default_library = "fftw" + default_gpu_library = "cupy" + default_cpu_library = "fftw" @classmethod def _preprocess_cache_args(cls, coord, size, bounds, dealias, library): @@ -1004,9 +1016,6 @@ def _preprocess_cache_args(cls, coord, size, bounds, dealias, library): dealias = tuple(dealias) if len(dealias) != 1: raise ValueError("Fourier dealias must have length 1.") - # library: pick default based on (a0, b0) - if library is None: - library = cls.default_library return (coord, size, bounds, dealias, library) def __init__(self, coord, size, bounds, dealias=(1,), library=None): @@ -1075,16 +1084,24 @@ def _native_grid(self, scale): N, = self.grid_shape((scale,)) return (2 * np.pi / N) * np.arange(N) + def get_library(self, dist): + """Get library for transforms.""" + if self.library is None: + if array_api_compat.is_cupy_namespace(dist.array_namespace): + return self.default_gpu_library + else: + return self.default_cpu_library + else: + return self.library + @CachedMethod def transform_plan(self, dist, grid_size): """Build transform plan.""" - xp = dist.array_namespace - xp_name = xp.__name__.split('.')[-1] # Shortcut trivial transforms if grid_size == 1 or self.size == 1: - return self.transforms[f"matrix-{xp_name}"](grid_size, self.size, dist.array_namespace, dist.dtype) + return self.transforms["matrix"](grid_size, self.size, dist.array_namespace, dist.dtype) else: - return self.transforms[f"{self.library}-{xp_name}"](grid_size, self.size, dist.array_namespace, dist.dtype) + return self.transforms[self.get_library(dist)](grid_size, self.size, dist.array_namespace, dist.dtype) def forward_transform(self, field, axis, gdata, cdata): # Transform @@ -2212,15 +2229,6 @@ def _preprocess_cache_args(cls, coordsys, shape, dtype, radii, k, alpha, dealias dealias = tuple(dealias) if len(dealias) != 2: raise ValueError("Annulus dealias must have length 2.") - # azimuth_library: pick default - if azimuth_library is None: - azimuth_library = RealFourier.default_library - # radius_library: pick default based on alpha - if radius_library is None: - if alpha[0] == alpha[1] == -1/2: - radius_library = Jacobi.default_dct - else: - radius_library = Jacobi.default_library return (coordsys, shape, dtype, radii, k, alpha, dealias, azimuth_library, radius_library) def __init__(self, coordsys, shape, dtype, radii=(1,2), k=0, alpha=(-0.5,-0.5), dealias=(1,1), azimuth_library=None, radius_library=None): diff --git a/dedalus/core/transforms.py b/dedalus/core/transforms.py index 6b196c12..4198eeb7 100644 --- a/dedalus/core/transforms.py +++ b/dedalus/core/transforms.py @@ -120,8 +120,7 @@ def __init__(self, grid_size, coeff_size, a, b, a0, b0, array_namespace, dtype, self.dealias_before_converting = dealias_before_converting -@register_transform(basis.Jacobi, 'matrix-numpy') -@register_transform(basis.Jacobi, 'matrix-cupy') +@register_transform(basis.Jacobi, 'matrix') class JacobiMMT(JacobiTransform, SeparableMatrixTransform): """Jacobi polynomial MMTs.""" @@ -223,8 +222,7 @@ def wavenumbers(self): return (k + KM) % M - KM -@register_transform(basis.ComplexFourier, 'matrix-numpy') -@register_transform(basis.ComplexFourier, 'matrix-cupy') +@register_transform(basis.ComplexFourier, 'matrix') class ComplexFourierMMT(ComplexFourierTransform, SeparableMatrixTransform): """Complex-to-complex Fourier MMT.""" @@ -286,7 +284,7 @@ def resize_coeffs(self, data_in, data_out, axis, rescale): xp.multiply(data_in[negfreq], rescale, data_out[negfreq]) -@register_transform(basis.ComplexFourier, 'scipy-numpy') +@register_transform(basis.ComplexFourier, 'scipy') class ScipyComplexFFT(ComplexFFT): """Complex-to-complex FFT using scipy.fft.""" @@ -308,7 +306,7 @@ def backward(self, cdata, gdata, axis): np.copyto(gdata, temp) -@register_transform(basis.ComplexFourier, 'scipy-cupy') +@register_transform(basis.ComplexFourier, 'cupy') class CupyComplexFFT(ComplexFFT): """Complex-to-complex FFT using scipy.fft.""" @@ -346,7 +344,7 @@ def __init__(self, *args, rigor=None, **kw): super().__init__(*args, **kw) -@register_transform(basis.ComplexFourier, 'fftw-numpy') +@register_transform(basis.ComplexFourier, 'fftw') class FFTWComplexFFT(FFTWBase, ComplexFFT): """Complex-to-complex FFT using FFTW.""" @@ -434,8 +432,7 @@ def wavenumbers(self): return xp.repeat(xp.arange(self.KM+1), 2) -@register_transform(basis.RealFourier, 'matrix-numpy') -@register_transform(basis.RealFourier, 'matrix-cupy') +@register_transform(basis.RealFourier, 'matrix') class RealFourierMMT(RealFourierTransform, SeparableMatrixTransform): """Real-to-real Fourier MMT.""" @@ -477,7 +474,7 @@ def backward_matrix(self): return xp.asarray(functions, order='C', dtype=self.dtype) -@register_transform(basis.RealFourier, 'fftpack-numpy') +@register_transform(basis.RealFourier, 'fftpack') class FFTPACKRealFFT(RealFourierTransform): """Real-to-real FFT using scipy.fftpack.""" @@ -564,7 +561,7 @@ def repack_rescale(self, cdata, temp, axis, rescale): temp[axslice(axis, Kmax+1, None)] = 0 -@register_transform(basis.RealFourier, 'scipy-numpy') +@register_transform(basis.RealFourier, 'scipy') class ScipyRealFFT(RealFFT): """Real-to-real FFT using scipy.fft.""" @@ -593,7 +590,7 @@ def backward(self, cdata, gdata, axis): np.copyto(gdata, temp) -@register_transform(basis.RealFourier, 'scipy-cupy') +@register_transform(basis.RealFourier, 'cupy') class CupyRealFFT(RealFFT): """Real-to-real FFT using scipy.fft.""" @@ -625,7 +622,7 @@ def backward(self, cdata, gdata, axis): xp.copyto(gdata, temp) -@register_transform(basis.RealFourier, 'fftw-numpy') +@register_transform(basis.RealFourier, 'fftw') class FFTWRealFFT(FFTWBase, RealFFT): """Real-to-real FFT using FFTW.""" @@ -656,7 +653,7 @@ def backward(self, cdata, gdata, axis): plan.backward(temp, gdata) -@register_transform(basis.RealFourier, 'fftw_hc-numpy') +@register_transform(basis.RealFourier, 'fftw_hc') class FFTWHalfComplexFFT(FFTWBase, RealFourierTransform): """Real-to-real FFT using FFTW half-complex DFT.""" @@ -1018,19 +1015,19 @@ def _resize_rescale_backward_convert(self, data_in, data_out, axis, Kmax_DCT): super().resize_rescale_backward(data_in, data_out, axis, Kmax_orig) -@register_transform(basis.Jacobi, 'scipy_dct-numpy') +@register_transform(basis.Jacobi, 'scipy') class ScipyFastChebyshevTransform(FastChebyshevTransform, ScipyDCT): """Fast ultraspherical transform using scipy.fft and spectral conversion.""" pass # Implementation is complete via inheritance -@register_transform(basis.Jacobi, 'fftw_dct-numpy') +@register_transform(basis.Jacobi, 'fftw') class FFTWFastChebyshevTransform(FastChebyshevTransform, FFTWDCT): """Fast ultraspherical transform using scipy.fft and spectral conversion.""" pass # Implementation is complete via inheritance -@register_transform(basis.Jacobi, 'scipy_dct-cupy') +@register_transform(basis.Jacobi, 'cupy') class CupyFastChebyshevTransform(FastChebyshevTransform, CupyDCT): """Fast ultraspherical transform using cupy fft and spectral conversion.""" pass # Implementation is complete via inheritance From be949c668d506dc5a38ab0e998c86b39df1e715f Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 12 May 2026 11:26:42 -0400 Subject: [PATCH 42/50] Suppress docstring warnings --- dedalus/core/transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dedalus/core/transforms.py b/dedalus/core/transforms.py index 4198eeb7..90b11a1d 100644 --- a/dedalus/core/transforms.py +++ b/dedalus/core/transforms.py @@ -170,7 +170,7 @@ def backward_matrix(self): class ComplexFourierTransform(SeparableTransform): - """ + r""" Abstract base class for complex-to-complex Fourier transforms. Parameters @@ -376,7 +376,7 @@ def backward(self, cdata, gdata, axis): class RealFourierTransform(SeparableTransform): - """ + r""" Abstract base class for real-to-real Fourier transforms. Parameters @@ -718,7 +718,7 @@ def backward(self, cdata, gdata, axis): class CosineTransform(SeparableTransform): - """ + r""" Abstract base class for cosine transforms. Parameters From 092fe7f47e2af070f24a0f09f70862a4d06479ba Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 12 May 2026 11:27:00 -0400 Subject: [PATCH 43/50] Add mock jit so linalg_gpu is still importable without cupy --- dedalus/tools/linalg_gpu.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/dedalus/tools/linalg_gpu.py b/dedalus/tools/linalg_gpu.py index 395557b8..4e1c3a33 100644 --- a/dedalus/tools/linalg_gpu.py +++ b/dedalus/tools/linalg_gpu.py @@ -9,6 +9,13 @@ from cupyx import jit cupy_available = True except ImportError: + # Mock jit so module can still be imported without cupy + class jit: + @staticmethod + def rawkernel(): + def decorator(func): + return func + return decorator cupy_available = False @@ -269,7 +276,7 @@ def custom_spsm(a, b, alpha=1.0, lower=True, unit_diag=False, transa=False, spsm mat_c.desc, cuda_dtype, algo, spsm_descr) need_analysis = new_spsm_descr - if new_spsm_descr: + if new_spsm_descr: buff = _cupy.empty(buff_size, dtype=_cupy.int8) else: # Check if buff size grew from that in the cache From 9bcf92a8e5c735fdeaeeb829e93de1f849baceba Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 12 May 2026 11:36:50 -0400 Subject: [PATCH 44/50] Add config option for default gpu subproblem coupling --- dedalus/core/basis.py | 6 +++--- dedalus/core/distributor.py | 2 ++ dedalus/core/solvers.py | 13 +++++++++---- dedalus/dedalus.cfg | 6 +++--- 4 files changed, 17 insertions(+), 10 deletions(-) diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index c475a81e..61fdf85b 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -661,12 +661,12 @@ def get_library(self, dist): """Get library for transforms.""" if self.library is None: if self.a0 == self.b0 == -1/2: - if array_api_compat.is_cupy_namespace(dist.array_namespace): + if dist.is_cupy_namespace: return self.default_gpu_dct else: return self.default_cpu_dct else: - if array_api_compat.is_cupy_namespace(dist.array_namespace): + if dist.is_cupy_namespace: return self.default_gpu_library else: return self.default_cpu_library @@ -1087,7 +1087,7 @@ def _native_grid(self, scale): def get_library(self, dist): """Get library for transforms.""" if self.library is None: - if array_api_compat.is_cupy_namespace(dist.array_namespace): + if dist.is_cupy_namespace: return self.default_gpu_library else: return self.default_cpu_library diff --git a/dedalus/core/distributor.py b/dedalus/core/distributor.py index cfb2a11c..6dcadfb4 100644 --- a/dedalus/core/distributor.py +++ b/dedalus/core/distributor.py @@ -125,6 +125,8 @@ def __init__(self, coordsystems, comm=None, mesh=None, dtype=None, array_namespa self.array_namespace = getattr(array_api_compat, array_namespace) else: self.array_namespace = array_api_compat.array_namespace(array_namespace.zeros(0)) + self.is_numpy_namespace = array_api_compat.is_numpy_namespace(self.array_namespace) + self.is_cupy_namespace = array_api_compat.is_cupy_namespace(self.array_namespace) @CachedAttribute def cs_by_axis(self): diff --git a/dedalus/core/solvers.py b/dedalus/core/solvers.py index c63a5fbb..98776d3b 100644 --- a/dedalus/core/solvers.py +++ b/dedalus/core/solvers.py @@ -67,11 +67,16 @@ def __init__(self, problem, ncc_cutoff=1e-6, max_ncc_terms=None, entry_cutoff=1e self.ncc_cutoff = ncc_cutoff self.max_ncc_terms = max_ncc_terms self.entry_cutoff = entry_cutoff + # Determing matrix coupling if matrix_coupling is None: - matrix_coupling = np.array(problem.matrix_coupling) - # Couple fully separable problems along last axis by default for efficiency - if not np.any(matrix_coupling): - matrix_coupling[-1] = True + # Override with full coupling according to config option + if self.dist.is_cupy_namespace and config['matrix construction'].getboolean('COUPLE_GPU_SUBPROBLEMS'): + matrix_coupling = np.ones_like(problem.matrix_coupling, dtype=bool) + else: + matrix_coupling = np.array(problem.matrix_coupling) + # Couple fully separable problems along last axis by default for efficiency + if not np.any(matrix_coupling): + matrix_coupling[-1] = True else: # Check specified coupling for compatibility problem_coupling = np.array(problem.matrix_coupling) diff --git a/dedalus/dedalus.cfg b/dedalus/dedalus.cfg index edd0595c..04861e6b 100644 --- a/dedalus/dedalus.cfg +++ b/dedalus/dedalus.cfg @@ -31,9 +31,6 @@ [transforms] - # Default transform library (scipy, fftw) - DEFAULT_LIBRARY = fftw - # Transform multiple fields together when possible GROUP_TRANSFORMS = False @@ -71,6 +68,9 @@ [matrix construction] + # Fully couple GPU subproblems + COUPLE_GPU_SUBPROBLEMS = True + # Put BC rows at the top of the matrix BC_TOP = False From 69039c85e12b8ccb3cbe8acb20825ba658ee9342 Mon Sep 17 00:00:00 2001 From: "Calum S. Skene" Date: Fri, 15 May 2026 10:57:51 +0100 Subject: [PATCH 45/50] Replace np with xp in timesteppers --- dedalus/core/timesteppers.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/dedalus/core/timesteppers.py b/dedalus/core/timesteppers.py index 1f87bd4c..0a8c20a9 100644 --- a/dedalus/core/timesteppers.py +++ b/dedalus/core/timesteppers.py @@ -70,8 +70,8 @@ class MultistepIMEX: def __init__(self, solver): self.solver = solver - xp = solver.dist.array_namespace - self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp) + self.xp = solver.dist.array_namespace + self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp) # Create deque for storing recent timesteps self.dt = deque([0.] * self.steps) @@ -81,16 +81,16 @@ def __init__(self, solver): self.LX = LX = deque() self.F = F = deque() for j in range(self.amax): - MX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp)) + MX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp)) for j in range(self.bmax): - LX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp)) + LX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp)) for j in range(self.cmax): - F.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp)) + F.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp)) # Attributes self._iteration = 0 self._LHS_params = None - self.axpy = get_axpy(xp, solver.dtype) + self.axpy = get_axpy(self.xp, solver.dtype) def step(self, dt, wall_time): """Advance solver by one timestep.""" @@ -154,7 +154,7 @@ def step(self, dt, wall_time): # Build RHS if RHS.data.size: - np.multiply(c[1], F0.data, out=RHS.data) + self.xp.multiply(c[1], F0.data, out=RHS.data) for j in range(2, len(c)): # RHS.data += c[j] * F[j-1].data axpy(a=c[j], x=F[j-1].data, y=RHS.data) @@ -173,7 +173,7 @@ def step(self, dt, wall_time): if update_LHS: if STORE_EXPANDED_MATRICES: # sp.LHS.data[:] = a0*sp.M_exp.data + b0*sp.L_exp.data - np.multiply(a0, sp.M_exp.data, out=sp.LHS.data) + self.xp.multiply(a0, sp.M_exp.data, out=sp.LHS.data) axpy(a=b0, x=sp.L_exp.data, y=sp.LHS.data) else: sp.LHS = (a0*sp.M_min + b0*sp.L_min) # CREATES TEMPORARY @@ -539,16 +539,16 @@ class RungeKuttaIMEX: def __init__(self, solver): self.solver = solver - xp = solver.dist.array_namespace - self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp) + self.xp = solver.dist.array_namespace + self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp) # Create coefficient systems for multistep history - self.MX0 = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp) - self.LX = [CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp) for i in range(self.stages)] - self.F = [CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp) for i in range(self.stages)] + self.MX0 = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp) + self.LX = [CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp) for i in range(self.stages)] + self.F = [CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp) for i in range(self.stages)] self._LHS_params = None - self.axpy = get_axpy(xp, solver.dtype) + self.axpy = get_axpy(self.xp, solver.dtype) # Cast scheme coefficients self.A = self.A.astype(self.solver.dtype) @@ -622,7 +622,7 @@ def step(self, dt, wall_time): # Construct RHS(n,i) if RHS.data.size: - np.copyto(RHS.data, MX0.data) + self.xp.copyto(RHS.data, MX0.data) for j in range(0, i): # RHS.data += (k * A[i,j]) * F[j].data axpy(a=(k*A[i,j]), x=F[j].data, y=RHS.data) @@ -639,7 +639,7 @@ def step(self, dt, wall_time): if update_LHS: if STORE_EXPANDED_MATRICES: # sp.LHS.data[:] = sp.M_exp.data + k_Hii*sp.L_exp.data - np.copyto(sp.LHS.data, sp.M_exp.data) + self.xp.copyto(sp.LHS.data, sp.M_exp.data) axpy(a=k_Hii, x=sp.L_exp.data, y=sp.LHS.data) else: sp.LHS = (sp.M_min + k_Hii*sp.L_min) # CREATES TEMPORARY From 9b1bddddd67a5ddb35257eabcf372e32990e9978 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Thu, 4 Jun 2026 12:07:18 +0200 Subject: [PATCH 46/50] Set matrix transform as default for Chebyshev --- dedalus/core/basis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index 61fdf85b..66a0ae57 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -599,7 +599,7 @@ class Jacobi(IntervalBasis, metaclass=CachedClass): default_cpu_library = "matrix" default_gpu_library = "matrix" default_cpu_dct = "fftw" - default_gpu_dct = "cupy" + default_gpu_dct = "matrix" @classmethod def _preprocess_cache_args(cls, coord, size, bounds, a, b, a0, b0, dealias, library): From 75e43d2d07f785e2646631e2bdb928f1cd568f9e Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Fri, 19 Jun 2026 10:54:34 -0400 Subject: [PATCH 47/50] Separate axpy for cpu and device in timesteppers --- dedalus/core/timesteppers.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/dedalus/core/timesteppers.py b/dedalus/core/timesteppers.py index 0a8c20a9..2eaa6cd8 100644 --- a/dedalus/core/timesteppers.py +++ b/dedalus/core/timesteppers.py @@ -90,7 +90,8 @@ def __init__(self, solver): # Attributes self._iteration = 0 self._LHS_params = None - self.axpy = get_axpy(self.xp, solver.dtype) + self.axpy_xp = get_axpy(self.xp, solver.dtype) + self.axpy_np = get_axpy(np, solver.dtype) def step(self, dt, wall_time): """Advance solver by one timestep.""" @@ -110,7 +111,6 @@ def step(self, dt, wall_time): LX = self.LX F = self.F RHS = self.RHS - axpy = self.axpy # Cycle and compute timesteps self.dt.rotate() @@ -157,13 +157,13 @@ def step(self, dt, wall_time): self.xp.multiply(c[1], F0.data, out=RHS.data) for j in range(2, len(c)): # RHS.data += c[j] * F[j-1].data - axpy(a=c[j], x=F[j-1].data, y=RHS.data) + self.axpy_xp(a=c[j], x=F[j-1].data, y=RHS.data) for j in range(1, len(a)): # RHS.data -= a[j] * MX[j-1].data - axpy(a=-a[j], x=MX[j-1].data, y=RHS.data) + self.axpy_xp(a=-a[j], x=MX[j-1].data, y=RHS.data) for j in range(1, len(b)): # RHS.data -= b[j] * LX[j-1].data - axpy(a=-b[j], x=LX[j-1].data, y=RHS.data) + self.axpy_xp(a=-b[j], x=LX[j-1].data, y=RHS.data) # Solve # Ensure coeff space before subsystem scatters @@ -171,10 +171,11 @@ def step(self, dt, wall_time): field.preset_layout('c') for sp in subproblems: if update_LHS: + # Form updated LHS matrix on CPU for factorization if STORE_EXPANDED_MATRICES: # sp.LHS.data[:] = a0*sp.M_exp.data + b0*sp.L_exp.data - self.xp.multiply(a0, sp.M_exp.data, out=sp.LHS.data) - axpy(a=b0, x=sp.L_exp.data, y=sp.LHS.data) + np.multiply(a0, sp.M_exp.data, out=sp.LHS.data) + self.axpy_np(a=b0, x=sp.L_exp.data, y=sp.LHS.data) else: sp.LHS = (a0*sp.M_min + b0*sp.L_min) # CREATES TEMPORARY sp.LHS_solver = solver.matsolver(sp.LHS, solver) @@ -548,7 +549,8 @@ def __init__(self, solver): self.F = [CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp) for i in range(self.stages)] self._LHS_params = None - self.axpy = get_axpy(self.xp, solver.dtype) + self.axpy_xp = get_axpy(self.xp, solver.dtype) + self.axpy_np = get_axpy(np, solver.dtype) # Cast scheme coefficients self.A = self.A.astype(self.solver.dtype) @@ -578,7 +580,6 @@ def step(self, dt, wall_time): H = self.H c = self.c k = dt - axpy = self.axpy # Check on updating LHS update_LHS = (k != self._LHS_params) @@ -625,9 +626,9 @@ def step(self, dt, wall_time): self.xp.copyto(RHS.data, MX0.data) for j in range(0, i): # RHS.data += (k * A[i,j]) * F[j].data - axpy(a=(k*A[i,j]), x=F[j].data, y=RHS.data) + self.axpy_xp(a=(k*A[i,j]), x=F[j].data, y=RHS.data) # RHS.data -= (k * H[i,j]) * LX[j].data - axpy(a=-(k*H[i,j]), x=LX[j].data, y=RHS.data) + self.axpy_xp(a=-(k*H[i,j]), x=LX[j].data, y=RHS.data) # Solve for stage k_Hii = k * H[i,i] @@ -637,10 +638,11 @@ def step(self, dt, wall_time): for sp in subproblems: # Construct LHS(n,i) if update_LHS: + # Form updated LHS matrix on CPU for factorization if STORE_EXPANDED_MATRICES: # sp.LHS.data[:] = sp.M_exp.data + k_Hii*sp.L_exp.data - self.xp.copyto(sp.LHS.data, sp.M_exp.data) - axpy(a=k_Hii, x=sp.L_exp.data, y=sp.LHS.data) + np.copyto(sp.LHS.data, sp.M_exp.data) + self.axpy_np(a=k_Hii, x=sp.L_exp.data, y=sp.LHS.data) else: sp.LHS = (sp.M_min + k_Hii*sp.L_min) # CREATES TEMPORARY sp.LHS_solvers[i] = solver.matsolver(sp.LHS, solver) From 392b0e1ed51d8d405f7b7fe0391ad94f3222687a Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Fri, 19 Jun 2026 11:19:37 -0400 Subject: [PATCH 48/50] Try caching einsum_path for dot product --- dedalus/core/arithmetic.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/dedalus/core/arithmetic.py b/dedalus/core/arithmetic.py index 570f9505..8d4f160b 100644 --- a/dedalus/core/arithmetic.py +++ b/dedalus/core/arithmetic.py @@ -618,6 +618,7 @@ def __init__(self, arg0, arg1, indices=(-1,0), out=None, **kw): arg2_str = arg2_str.replace(arg2_str[indices[1]], 'z') out_str = (arg1_str + arg2_str).replace('z', '') self.einsum_str = arg1_str + '...,' + arg2_str + '...->' + out_str + '...' + self.einsum_path = None def _check_indices(self, arg0, arg1, indices): if (not isinstance(arg0, Operand)) or (not isinstance(arg1, Operand)): @@ -675,10 +676,17 @@ def operate(self, out): # Call einsum if out.data.size: if array_api_compat.is_cupy_namespace(xp): + if self.einsum_path is None: + self.einsum_path = self.get_einsum_path(xp.asnumpy(arg0_data), xp.asnumpy(arg1_data)) # Cupy does not support output keyword - out.data[:] = xp.einsum(self.einsum_str, arg0_data, arg1_data, optimize=True) + out.data[:] = xp.einsum(self.einsum_str, arg0_data, arg1_data, optimize=self.einsum_path) else: - xp.einsum(self.einsum_str, arg0_data, arg1_data, out=out.data, optimize=True) + if self.einsum_path is None: + self.einsum_path = self.get_einsum_path(arg0_data, arg1_data) + xp.einsum(self.einsum_str, arg0_data, arg1_data, out=out.data, optimize=self.einsum_path) + + def get_einsum_path(self, arg0_data, arg1_data): + return np.einsum_path(self.einsum_str, arg0_data, arg1_data, optimize="optimal")[0] @alias("cross") From 721fdfd74e463ff21e740791940607a0a9473e79 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Wed, 24 Jun 2026 09:40:07 -0400 Subject: [PATCH 49/50] Add warning for non-cartesian problems --- dedalus/core/distributor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dedalus/core/distributor.py b/dedalus/core/distributor.py index 6dcadfb4..e7fb7619 100644 --- a/dedalus/core/distributor.py +++ b/dedalus/core/distributor.py @@ -11,6 +11,7 @@ import numbers from weakref import WeakSet import array_api_compat +import warnings from .coords import CoordinateSystem, DirectProduct from ..tools.array import reshape_vector @@ -127,6 +128,9 @@ def __init__(self, coordsystems, comm=None, mesh=None, dtype=None, array_namespa self.array_namespace = array_api_compat.array_namespace(array_namespace.zeros(0)) self.is_numpy_namespace = array_api_compat.is_numpy_namespace(self.array_namespace) self.is_cupy_namespace = array_api_compat.is_cupy_namespace(self.array_namespace) + # Warnings for non-Cartesian problems + if self.is_cupy_namespace and any(cs.curvilinear for cs in self.coordsystems): + warnings.warn("Non-Cartesian coordinate systems not yet supported on GPU.") @CachedAttribute def cs_by_axis(self): From 40fd3b4977019f1d1756d950860233294528f752 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Wed, 24 Jun 2026 09:56:08 -0400 Subject: [PATCH 50/50] Add GPU doc page --- docs/pages/gpu.rst | 39 +++++++++++++++++++++++++++++++++++++++ docs/pages/user_guide.rst | 1 + 2 files changed, 40 insertions(+) create mode 100644 docs/pages/gpu.rst diff --git a/docs/pages/gpu.rst b/docs/pages/gpu.rst new file mode 100644 index 00000000..92cac910 --- /dev/null +++ b/docs/pages/gpu.rst @@ -0,0 +1,39 @@ +GPU Support +*********** + +.. note:: + GPU support is currently experimental and may not be fully functional for all use cases. + Cartesian problems should generally work, but curvilinear problems are not yet supported. + +Dedalus supports GPU acceleration on NVIDIA and AMD GPUs through the use of the CuPy array library. +Specifically, Dedalus utilizes the "array-api-compat" library to provide a unified interface for Numpy and CuPy array types and operations. + +Installation +------------ + +In addition to the regular Dedalus dependencies, you will need CuPy installed to enable GPU support. CuPy itself can be installed using pip, but requires a compatible CUDA or ROCm environment to be installed on your system. + +Selecting GPUs +-------------- + +To utilize GPU acceleration, your Dedalus script needs to import CuPy and pass it as the "array_namespace" keyword to the Distributor object. +Fields built with that Distributor will then use CuPy arrays for their data and computations. +No other explicit changes are required to most scripts, but some care may need to be taken for setting field initial conditions, etc., and converting from other Numpy arrays in your scripts to CuPy arrays. + +.. code-block:: python + import dedalus.public as d3 + import cupy as cp + ... + dist = d3.Distributor(coords, dtype=np.float64, array_namespace=cp) + + +Guidelines +---------- + +GPU support is preliminary and many standard Dedalus features have not yet been fully implemented or optimized for GPUs. +Please keep in mind the following guidelines for best performance under the current capabilities: +- Curvilinear problems are not yet supported. +- Distributed GPUs (combining MPI will multiple GPUs) is not yet supported. +- Single and double precision, real and complex, are supported. +- By default, GPU problems are treated collectively as a single subproblem to speed up linear algebra on the GPU (looping explicitly over subproblems is slow). The trade off is that this can make matrix factorizations quite slow, so we strongly recommend using constant timesteps when possible to avoid refactorizations. We hope to address this issue shortly. + diff --git a/docs/pages/user_guide.rst b/docs/pages/user_guide.rst index d8ac3fed..942ad97b 100644 --- a/docs/pages/user_guide.rst +++ b/docs/pages/user_guide.rst @@ -8,6 +8,7 @@ General user guide: problem_formulations performance_tips + gpu configuration troubleshooting changes_from_v2