diff --git a/dedalus/core/arithmetic.py b/dedalus/core/arithmetic.py index 64daa530..8d4f160b 100644 --- a/dedalus/core/arithmetic.py +++ b/dedalus/core/arithmetic.py @@ -13,6 +13,7 @@ import numexpr as ne from collections import defaultdict from math import prod +import array_api_compat from .domain import Domain from .field import Operand, Field @@ -245,10 +246,11 @@ def choose_layout(self): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg0, arg1 = self.args # Set output layout out.preset_layout(arg0.layout) - np.add(arg0.data, arg1.data, out=out.data) + xp.add(arg0.data, arg1.data, out=out.data) # used for einsum string manipulation @@ -616,6 +618,7 @@ def __init__(self, arg0, arg1, indices=(-1,0), out=None, **kw): arg2_str = arg2_str.replace(arg2_str[indices[1]], 'z') out_str = (arg1_str + arg2_str).replace('z', '') self.einsum_str = arg1_str + '...,' + arg2_str + '...->' + out_str + '...' + self.einsum_path = None def _check_indices(self, arg0, arg1, indices): if (not isinstance(arg0, Operand)) or (not isinstance(arg1, Operand)): @@ -664,6 +667,7 @@ def GammaCoord(self, A_tensorsig, B_tensorsig, C_tensorsig): return G def operate(self, out): + xp = self.array_namespace arg0, arg1 = self.args out.preset_layout(arg0.layout) # Broadcast @@ -671,7 +675,18 @@ def operate(self, out): arg1_data = self.arg1_ghost_broadcaster.cast(arg1) # Call einsum if out.data.size: - np.einsum(self.einsum_str, arg0_data, arg1_data, out=out.data, optimize=True) + if array_api_compat.is_cupy_namespace(xp): + if self.einsum_path is None: + self.einsum_path = self.get_einsum_path(xp.asnumpy(arg0_data), xp.asnumpy(arg1_data)) + # Cupy does not support output keyword + out.data[:] = xp.einsum(self.einsum_str, arg0_data, arg1_data, optimize=self.einsum_path) + else: + if self.einsum_path is None: + self.einsum_path = self.get_einsum_path(arg0_data, arg1_data) + xp.einsum(self.einsum_str, arg0_data, arg1_data, out=out.data, optimize=self.einsum_path) + + def get_einsum_path(self, arg0_data, arg1_data): + return np.einsum_path(self.einsum_str, arg0_data, arg1_data, optimize="optimal")[0] @alias("cross") @@ -854,6 +869,7 @@ def __init__(self, arg0, arg1, out=None, **kw): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg0, arg1 = self.args # Set output layout out.preset_layout(arg0.layout) @@ -863,7 +879,7 @@ def operate(self, out): # Reshape arg data to broadcast properly for output tensorsig arg0_exp_data = arg0_data.reshape(self.arg0_exp_tshape + arg0_data.shape[len(arg0.tensorsig):]) arg1_exp_data = arg1_data.reshape(self.arg1_exp_tshape + arg1_data.shape[len(arg1.tensorsig):]) - np.multiply(arg0_exp_data, arg1_exp_data, out=out.data) + xp.multiply(arg0_exp_data, arg1_exp_data, out=out.data) class GhostBroadcaster: @@ -919,7 +935,7 @@ def __init__(self, arg0, arg1, out=None,**kw): super().__init__(arg0, arg1, out=out) self.domain = arg1.domain self.tensorsig = arg1.tensorsig - self.dtype = np.result_type(type(arg0), arg1.dtype) + self.dtype = np.result_type(arg0, arg1.dtype) @classmethod def _check_args(cls, *args, **kw): @@ -939,11 +955,12 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg0, arg1 = self.args # Set output layout out.preset_layout(arg1.layout) # Multiply argument data - np.multiply(arg0, arg1.data, out=out.data) + xp.multiply(arg0, arg1.data, out=out.data) def matrix_dependence(self, *vars): return self.args[1].matrix_dependence(*vars) diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index b25e5061..66a0ae57 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -5,6 +5,7 @@ from functools import reduce import inspect from math import prod +import array_api_compat from . import operators from ..libraries import spin_recombination @@ -14,7 +15,7 @@ from ..tools import clenshaw from ..tools.array import reshape_vector, axindex, axslice, interleave_matrices from ..tools.dispatch import MultiClass, SkipDispatchException -from ..tools.general import unify, DeferredTuple +from ..tools.general import unify, DeferredTuple, is_real_dtype, is_complex_dtype from .coords import Coordinate, CartesianCoordinates, S2Coordinates, SphericalCoordinates, PolarCoordinates, AzimuthalCoordinate, DirectProduct from .domain import Domain from .field import Operand, LockedField @@ -595,8 +596,10 @@ class Jacobi(IntervalBasis, metaclass=CachedClass): group_shape = (1,) native_bounds = (-1, 1) transforms = {} - default_dct = "fftw_dct" - default_library = "matrix" + default_cpu_library = "matrix" + default_gpu_library = "matrix" + default_cpu_dct = "fftw" + default_gpu_dct = "matrix" @classmethod def _preprocess_cache_args(cls, coord, size, bounds, a, b, a0, b0, dealias, library): @@ -631,12 +634,6 @@ def _preprocess_cache_args(cls, coord, size, bounds, a, b, a0, b0, dealias, libr dealias = tuple(dealias) if len(dealias) != 1: raise ValueError("Jacobi dealias must have length 1.") - # library: pick default based on (a0, b0) - if library is None: - if a0 == b0 == -1/2: - library = cls.default_dct - else: - library = cls.default_library return (coord, size, bounds, a, b, a0, b0, dealias, library) def __init__(self, coord, size, bounds, a, b, a0=None, b0=None, dealias=(1,), library=None): @@ -660,10 +657,30 @@ def _native_grid(self, scale): N, = self.grid_shape((scale,)) return jacobi.build_grid(N, a=self.a0, b=self.b0) + def get_library(self, dist): + """Get library for transforms.""" + if self.library is None: + if self.a0 == self.b0 == -1/2: + if dist.is_cupy_namespace: + return self.default_gpu_dct + else: + return self.default_cpu_dct + else: + if dist.is_cupy_namespace: + return self.default_gpu_library + else: + return self.default_cpu_library + else: + return self.library + @CachedMethod def transform_plan(self, dist, grid_size): """Build transform plan.""" - return self.transforms[self.library](grid_size, self.size, self.a, self.b, self.a0, self.b0) + # Shortcut trivial transforms + if grid_size == 1 or self.size == 1: + return self.transforms["matrix"](grid_size, self.size, self.a, self.b, self.a0, self.b0, dist.array_namespace, dist.dtype) + else: + return self.transforms[self.get_library(dist)](grid_size, self.size, self.a, self.b, self.a0, self.b0, dist.array_namespace, dist.dtype) # def weights(self, scales): # """Gauss-Jacobi weights.""" @@ -975,7 +992,8 @@ class FourierBase(IntervalBasis): """Base class for RealFourier and ComplexFourier.""" native_bounds = (0, 2*np.pi) - default_library = "fftw" + default_gpu_library = "cupy" + default_cpu_library = "fftw" @classmethod def _preprocess_cache_args(cls, coord, size, bounds, dealias, library): @@ -998,9 +1016,6 @@ def _preprocess_cache_args(cls, coord, size, bounds, dealias, library): dealias = tuple(dealias) if len(dealias) != 1: raise ValueError("Fourier dealias must have length 1.") - # library: pick default based on (a0, b0) - if library is None: - library = cls.default_library return (coord, size, bounds, dealias, library) def __init__(self, coord, size, bounds, dealias=(1,), library=None): @@ -1069,14 +1084,24 @@ def _native_grid(self, scale): N, = self.grid_shape((scale,)) return (2 * np.pi / N) * np.arange(N) + def get_library(self, dist): + """Get library for transforms.""" + if self.library is None: + if dist.is_cupy_namespace: + return self.default_gpu_library + else: + return self.default_cpu_library + else: + return self.library + @CachedMethod def transform_plan(self, dist, grid_size): """Build transform plan.""" # Shortcut trivial transforms if grid_size == 1 or self.size == 1: - return self.transforms['matrix'](grid_size, self.size) + return self.transforms["matrix"](grid_size, self.size, dist.array_namespace, dist.dtype) else: - return self.transforms[self.library](grid_size, self.size) + return self.transforms[self.get_library(dist)](grid_size, self.size, dist.array_namespace, dist.dtype) def forward_transform(self, field, axis, gdata, cdata): # Transform @@ -1097,9 +1122,9 @@ def Fourier(*args, dtype=None, **kw): """Factory function dispatching to RealFourier and ComplexFourier based on provided dtype.""" if dtype is None: raise ValueError("dtype must be specified") - elif dtype == np.float64: + elif is_real_dtype(dtype): return RealFourier(*args, **kw) - elif dtype == np.complex128: + elif is_complex_dtype(dtype): return ComplexFourier(*args, **kw) else: raise ValueError(f"Unrecognized dtype: {dtype}") @@ -2204,15 +2229,6 @@ def _preprocess_cache_args(cls, coordsys, shape, dtype, radii, k, alpha, dealias dealias = tuple(dealias) if len(dealias) != 2: raise ValueError("Annulus dealias must have length 2.") - # azimuth_library: pick default - if azimuth_library is None: - azimuth_library = RealFourier.default_library - # radius_library: pick default based on alpha - if radius_library is None: - if alpha[0] == alpha[1] == -1/2: - radius_library = Jacobi.default_dct - else: - radius_library = Jacobi.default_library return (coordsys, shape, dtype, radii, k, alpha, dealias, azimuth_library, radius_library) def __init__(self, coordsys, shape, dtype, radii=(1,2), k=0, alpha=(-0.5,-0.5), dealias=(1,1), azimuth_library=None, radius_library=None): @@ -6238,6 +6254,7 @@ class CartesianAdvectiveCFL(operators.AdvectiveCFL): @CachedMethod def cfl_spacing(self): + xp = self.array_namespace velocity = self.operand coordsys = velocity.tensorsig[0] spacing = [] @@ -6260,7 +6277,7 @@ def cfl_spacing(self): axis_spacing[:] = dealias * native_spacing * basis.COV.stretch elif basis is None: axis_spacing = np.inf - spacing.append(axis_spacing) + spacing.append(xp.asarray(axis_spacing)) return spacing def compute_cfl_frequency(self, velocity, out): diff --git a/dedalus/core/distributor.py b/dedalus/core/distributor.py index c4cc766f..e7fb7619 100644 --- a/dedalus/core/distributor.py +++ b/dedalus/core/distributor.py @@ -10,6 +10,8 @@ from math import prod import numbers from weakref import WeakSet +import array_api_compat +import warnings from .coords import CoordinateSystem, DirectProduct from ..tools.array import reshape_vector @@ -39,12 +41,16 @@ class Distributor: Parameters ---------- - dim : int - Dimension + coordsystems : CoordinateSystem or tuple of CoordinateSystems + Problem coordinate systems comm : MPI communicator, optional MPI communicator (default: comm world) mesh : tuple of ints, optional Process mesh for parallelization (default: 1-D mesh of available processes) + dtype : data type, optional + Default data type for fields (default: None) + array_namespace : array namespace or string, optional + Array namespace for field data (e.g. numpy or cupy, default: numpy) Attributes ---------- @@ -74,7 +80,7 @@ class Distributor: states) and the paths between them (D transforms and R transposes). """ - def __init__(self, coordsystems, comm=None, mesh=None, dtype=None): + def __init__(self, coordsystems, comm=None, mesh=None, dtype=None, array_namespace=np): # Accept single coordsys in place of tuple/list if not isinstance(coordsystems, (tuple, list)): coordsystems = (coordsystems,) @@ -115,6 +121,16 @@ def __init__(self, coordsystems, comm=None, mesh=None, dtype=None): self._build_layouts() # Keep set of weak field references self.fields = WeakSet() + # Array module + if isinstance(array_namespace, str): + self.array_namespace = getattr(array_api_compat, array_namespace) + else: + self.array_namespace = array_api_compat.array_namespace(array_namespace.zeros(0)) + self.is_numpy_namespace = array_api_compat.is_numpy_namespace(self.array_namespace) + self.is_cupy_namespace = array_api_compat.is_cupy_namespace(self.array_namespace) + # Warnings for non-Cartesian problems + if self.is_cupy_namespace and any(cs.curvilinear for cs in self.coordsystems): + warnings.warn("Non-Cartesian coordinate systems not yet supported on GPU.") @CachedAttribute def cs_by_axis(self): @@ -255,11 +271,12 @@ def IdentityTensor(self, coordsys_in, coordsys_out=None, bases=None, dtype=None) return I def local_grid(self, basis, scale=None): + xp = self.array_namespace # TODO: remove from bases and do it all here? if scale is None: scale = 1 if basis.dim == 1: - return basis.local_grid(self, scale=scale) + return xp.asarray(basis.local_grid(self, scale=scale)) else: raise ValueError("Use `local_grids` for multidimensional bases.") @@ -292,16 +309,18 @@ def local_grid(self, basis, scale=None): # return tuple(grids) def local_grids(self, *bases, scales=None): + xp = self.array_namespace scales = self.remedy_scales(scales) grids = [] for basis in bases: basis_scales = scales[self.first_axis(basis):self.last_axis(basis)+1] - grids.extend(basis.local_grids(self, scales=basis_scales)) + grids.extend(xp.asarray(basis.local_grids(self, scales=basis_scales))) return grids def local_modes(self, basis): # TODO: remove from bases and do it all here? - return basis.local_modes(self) + xp = self.array_namespace + return xp.asarray(basis.local_modes(self)) @CachedAttribute def default_nonconst_groups(self): diff --git a/dedalus/core/field.py b/dedalus/core/field.py index 44bd1a94..682d246b 100644 --- a/dedalus/core/field.py +++ b/dedalus/core/field.py @@ -7,6 +7,7 @@ from functools import partial, reduce from collections import defaultdict import numpy as np +import array_api_compat from mpi4py import MPI from scipy import sparse from scipy.sparse import linalg as splinalg @@ -483,16 +484,19 @@ def evaluate(self): def reinitialize(self, **kw): return self - @staticmethod - def _create_buffer(buffer_size): + def _create_buffer(self, buffer_size): """Create buffer for Field data.""" - if buffer_size == 0: - # FFTW doesn't like allocating size-0 arrays - return np.zeros((0,), dtype=np.float64) + xp = self.array_namespace + if xp == np: + if buffer_size == 0: + # FFTW doesn't like allocating size-0 arrays + return np.zeros((0,), dtype=np.float64) + else: + # Use FFTW SIMD aligned allocation + alloc_doubles = buffer_size // 8 + return fftw.create_buffer(alloc_doubles) else: - # Use FFTW SIMD aligned allocation - alloc_doubles = buffer_size // 8 - return fftw.create_buffer(alloc_doubles) + return xp.zeros(buffer_size) @CachedAttribute def _dealias_buffer_size(self): @@ -526,14 +530,17 @@ def preset_scales(self, scales): def preset_layout(self, layout): """Interpret buffer as data in specified layout.""" + xp = self.array_namespace layout = self.dist.get_layout_object(layout) self.layout = layout tens_shape = [vs.dim for vs in self.tensorsig] local_shape = layout.local_shape(self.domain, self.scales) total_shape = tuple(tens_shape) + tuple(local_shape) - self.data = np.ndarray(shape=total_shape, - dtype=self.dtype, - buffer=self.buffer) + # Create view into buffer + if array_api_compat.is_cupy_namespace(xp): + self.data = xp.ndarray(shape=total_shape, dtype=self.dtype, memptr=self.buffer.data) + else: + self.data = xp.ndarray(shape=total_shape, dtype=self.dtype, buffer=self.buffer) #self.global_start = layout.start(self.domain, self.scales) @@ -571,6 +578,7 @@ def __init__(self, dist, bases=None, name=None, tensorsig=None, dtype=None): dtype = dist.dtype from .domain import Domain self.dist = dist + self.array_namespace = dist.array_namespace self.name = name self.tensorsig = tensorsig self.dtype = dtype @@ -794,9 +802,15 @@ def allgather_data(self, layout=None): # Change layout if layout is not None: self.change_layout(layout) + # Convert to numpy if on GPU + xp = self.dist.array_namespace + if array_api_compat.is_cupy_namespace(xp): + data = xp.asnumpy(self.data) + else: + data = self.data.copy() # Shortcut for serial execution if self.dist.comm.size == 1: - return self.data.copy() + return data # Build global buffers tensor_shape = tuple(cs.dim for cs in self.tensorsig) global_shape = tensor_shape + self.layout.global_shape(self.domain, self.scales) @@ -805,7 +819,7 @@ def allgather_data(self, layout=None): recv_buff = np.empty_like(send_buff) # Combine data via allreduce -- easy but not communication-optimal # Should be optimized using Allgatherv if this is used past startup - send_buff[local_slices] = self.data + send_buff[local_slices] = data self.dist.comm.Allreduce(send_buff, recv_buff, op=MPI.SUM) return recv_buff @@ -813,13 +827,19 @@ def gather_data(self, root=0, layout=None): # Change layout if layout is not None: self.change_layout(layout) + # Convert to numpy if on GPU + xp = self.dist.array_namespace + if array_api_compat.is_cupy_namespace(xp): + data = xp.asnumpy(self.data) + else: + data = self.data.copy() # Shortcut for serial execution if self.dist.comm.size == 1: - return self.data.copy() + return data # TODO: Shortcut this for constant fields # Gather data # Should be optimized via Gatherv eventually - pieces = self.dist.comm.gather(self.data, root=root) + pieces = self.dist.comm.gather(data, root=root) # Assemble on root node if self.dist.comm.rank == root: ext_mesh = self.layout.ext_mesh @@ -846,7 +866,7 @@ def allreduce_data_norm(self, layout=None, order=2): if self.dist.comm.size > 1: norm = self.dist.comm.allreduce(norm, op=MPI.SUM) norm = norm ** (1 / order) - return norm + return float(norm) def allreduce_data_max(self, layout=None): return self.allreduce_data_norm(layout=layout, order=np.inf) @@ -927,6 +947,7 @@ def fill_random(self, layout=None, scales=None, seed=None, chunk_size=2**20, dis **kw : dict Other keywords passed to the distribution method. """ + xp = self.dist.array_namespace init_layout = self.layout # Set scales if requested if scales is not None: @@ -946,11 +967,10 @@ def fill_random(self, layout=None, scales=None, seed=None, chunk_size=2**20, dis spatial_slices = self.layout.slices(self.domain, self.scales) local_slices = component_slices + spatial_slices local_data = global_data[local_slices] - if self.is_real: - self.data[:] = local_data - else: - self.data.real[:] = local_data[..., 0] - self.data.imag[:] = local_data[..., 1] + if self.is_complex: + local_data = local_data[..., 0] + 1j * local_data[..., 1] + # Copy to field data + self.data[:] = xp.asarray(local_data, dtype=self.dtype) def low_pass_filter(self, shape=None, scales=None): """ diff --git a/dedalus/core/future.py b/dedalus/core/future.py index d3d9bb15..ecc27d96 100644 --- a/dedalus/core/future.py +++ b/dedalus/core/future.py @@ -51,6 +51,7 @@ def __init__(self, *args, out=None): self.original_args = tuple(args) self.out = out self.dist = unify_attributes(args, 'dist', require=False) + self.array_namespace = self.dist.array_namespace #self.domain = Domain(self.dist, self.bases) self._grid_layout = self.dist.grid_layout self._coeff_layout = self.dist.coeff_layout diff --git a/dedalus/core/operators.py b/dedalus/core/operators.py index db750d73..717f553d 100644 --- a/dedalus/core/operators.py +++ b/dedalus/core/operators.py @@ -15,6 +15,7 @@ from math import prod from ..libraries import dedalus_sphere import logging +import array_api_compat logger = logging.getLogger(__name__.split('.')[-1]) from .domain import Domain @@ -378,11 +379,12 @@ def enforce_conditions(self): arg0.require_grid_space() def operate(self, out): + xp = self.array_namespace arg0, arg1 = self.args # Multiply in grid layout out.preset_layout(arg0.layout) if out.data.size: - np.power(arg0.data, arg1, out.data) + xp.power(arg0.data, arg1, out.data) def new_operands(self, arg0, arg1, **kw): return Power(arg0, arg1) @@ -498,8 +500,9 @@ def enforce_conditions(self): self.args[i].change_layout(self.layout) def operate(self, out): + xp = self.array_namespace out.preset_layout(self.layout) - np.copyto(out.data, self.func(*self.args, **self.kw)) + xp.copyto(out.data, self.func(*self.args, **self.kw)) class UnaryGridFunction(NonlinearOperator, FutureField): @@ -829,10 +832,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg0 = self.args[0] out.preset_layout(arg0.layout) out.lock_to_layouts(self.layouts) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) def new_operand(self, operand, **kw): return Lock(operand, *self.layouts, **kw) @@ -964,6 +968,20 @@ def subspace_matrix(self, layout): # Caching layer to allow insertion of other arguments return self._subspace_matrix(layout, self.input_basis, self.output_basis, self.first_axis) + @CachedMethod + def subspace_matrix_device(self, layout): + """Build matrix operating on local subspace data on device.""" + # Caching layer to allow insertion of other arguments + matrix = self._subspace_matrix(layout, self.input_basis, self.output_basis, self.first_axis) + if array_api_compat.is_cupy_namespace(self.array_namespace): + import cupy as cp + import cupyx.scipy.sparse as csp + if sparse.issparse(matrix): + matrix = csp.csr_matrix(matrix) + else: + matrix = cp.array(matrix) + return matrix + def group_matrix(self, group): return self._group_matrix(group, self.input_basis, self.output_basis) @@ -971,7 +989,7 @@ def group_matrix(self, group): @CachedMethod def _subspace_matrix(cls, layout, input_basis, output_basis, axis, *args): if cls.subaxis_coupling[0]: - return cls._full_matrix(input_basis, output_basis, *args) + return cls._full_matrix(input_basis, output_basis, *args).astype(layout.dist.dtype) else: input_domain = Domain(layout.dist, bases=[input_basis]) output_domain = Domain(layout.dist, bases=[output_basis]) @@ -985,7 +1003,7 @@ def _subspace_matrix(cls, layout, input_basis, output_basis, axis, *args): group_blocks = [cls._group_matrix(group, input_basis, output_basis, *args) for group in groups] arg_size = layout.local_shape(input_domain, scales=1)[axis] out_size = layout.local_shape(output_domain, scales=1)[axis] - return sparse_block_diag(group_blocks, shape=(out_size, arg_size)) + return sparse_block_diag(group_blocks, shape=(out_size, arg_size)).astype(layout.dist.dtype) @staticmethod def _full_matrix(input_basis, output_basis, *args): @@ -1004,7 +1022,7 @@ def operate(self, out): # Apply matrix if arg.data.size and out.data.size: data_axis = self.last_axis + len(arg.tensorsig) - apply_matrix(self.subspace_matrix(layout), arg.data, data_axis, out=out.data) + apply_matrix(self.subspace_matrix_device(layout), arg.data, data_axis, out=out.data) else: out.data.fill(0) @@ -1539,9 +1557,10 @@ def subproblem_matrix(self, subproblem): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg = self.args[0] out.preset_layout(arg.layout) - np.copyto(out.data, arg.data) + xp.copyto(out.data, arg.data) class Convert(SpectralOperator, metaclass=MultiClass): @@ -1646,12 +1665,13 @@ def subspace_matrix(self, layout): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg = self.args[0] layout = arg.layout # Copy for grid space if layout.grid_space[self.last_axis]: out.preset_layout(layout) - np.copyto(out.data, arg.data) + xp.copyto(out.data, arg.data) # Revert to matrix application for coeff space else: super().operate(out) @@ -1794,10 +1814,13 @@ def base(self): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg = self.args[0] out.preset_layout(arg.layout) - np.einsum('ii...', arg.data, out=out.data) - + if array_api_compat.is_cupy_namespace(xp): + out.data[:] = xp.einsum('ii...', arg.data) + else: + xp.einsum('ii...', arg.data, out=out.data) class SphericalTrace(Trace): @@ -1993,6 +2016,7 @@ def subproblem_matrix(self, subproblem): def operate(self, out): """Perform operation.""" + xp = self.array_namespace operand = self.args[0] # Set output layout out.preset_layout(operand.layout) @@ -3507,10 +3531,11 @@ def subproblem_matrix(self, subproblem): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class DirectProductDivergence(Divergence): @@ -3556,10 +3581,11 @@ def subproblem_matrix(self, subproblem): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class SphericalDivergence(Divergence, SphericalEllOperator): @@ -3761,10 +3787,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class DirectProductCurl(Curl): @@ -3848,10 +3875,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class SphericalCurl(Curl, SphericalEllOperator): @@ -4074,10 +4102,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class DirectProductLaplacian(Laplacian): @@ -4119,10 +4148,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class SphericalLaplacian(Laplacian, SphericalEllOperator): diff --git a/dedalus/core/solvers.py b/dedalus/core/solvers.py index c63a5fbb..98776d3b 100644 --- a/dedalus/core/solvers.py +++ b/dedalus/core/solvers.py @@ -67,11 +67,16 @@ def __init__(self, problem, ncc_cutoff=1e-6, max_ncc_terms=None, entry_cutoff=1e self.ncc_cutoff = ncc_cutoff self.max_ncc_terms = max_ncc_terms self.entry_cutoff = entry_cutoff + # Determing matrix coupling if matrix_coupling is None: - matrix_coupling = np.array(problem.matrix_coupling) - # Couple fully separable problems along last axis by default for efficiency - if not np.any(matrix_coupling): - matrix_coupling[-1] = True + # Override with full coupling according to config option + if self.dist.is_cupy_namespace and config['matrix construction'].getboolean('COUPLE_GPU_SUBPROBLEMS'): + matrix_coupling = np.ones_like(problem.matrix_coupling, dtype=bool) + else: + matrix_coupling = np.array(problem.matrix_coupling) + # Couple fully separable problems along last axis by default for efficiency + if not np.any(matrix_coupling): + matrix_coupling[-1] = True else: # Check specified coupling for compatibility problem_coupling = np.array(problem.matrix_coupling) diff --git a/dedalus/core/subsystems.py b/dedalus/core/subsystems.py index e684c005..b8d962b3 100644 --- a/dedalus/core/subsystems.py +++ b/dedalus/core/subsystems.py @@ -11,13 +11,20 @@ from mpi4py import MPI import uuid from math import prod +import array_api_compat from .domain import Domain -from ..tools.array import zeros_with_pattern, expand_pattern, sparse_block_diag, copyto, perm_matrix, drop_empty_rows, apply_sparse, assert_sparse_pinv +from ..tools.array import zeros_with_pattern, expand_pattern, sparse_block_diag, copyto, perm_matrix, drop_empty_rows, apply_sparse, assert_sparse_pinv, copy_to_device, copy_from_device from ..tools.cache import CachedAttribute, CachedMethod from ..tools.general import replace, OrderedSet from ..tools.progress import log_progress +try: + import cupy as cp + import cupyx.scipy.sparse as csp +except ImportError: + pass + import logging logger = logging.getLogger(__name__.split('.')[-1]) @@ -118,6 +125,7 @@ def __init__(self, solver, group): self.solver = solver self.problem = problem = solver.problem self.dist = solver.dist + self.array_namespace = solver.dist.array_namespace self.dtype = problem.dtype self.group = group # Determine matrix group using solver matrix dependence @@ -191,11 +199,12 @@ def field_size(self, field): @CachedMethod def _gather_scatter_setup(self, fields): + xp = self.array_namespace # Allocate vector fsizes = tuple(self.field_size(f) for f in fields) fslices = tuple(self.field_slices(f) for f in fields) fshapes = tuple(self.field_shape(f) for f in fields) - data = np.empty(sum(fsizes), dtype=self.dtype) + data = xp.empty(sum(fsizes), dtype=self.dtype) # Make views into data fviews = [] i0 = 0 @@ -248,6 +257,7 @@ def __init__(self, solver, subsystems, group): self.subsystems = subsystems self.group = group self.dist = problem.dist + self.array_namespace = self.dist.array_namespace self.domain = problem.variables[0].domain # HACK self.dtype = problem.dtype # Cross reference from subsystems @@ -279,7 +289,8 @@ def size(self): @CachedAttribute def _compressed_buffer(self): - return np.zeros(self.shape, dtype=self.dtype) + xp = self.array_namespace + return xp.zeros(self.shape, dtype=self.dtype) def coeff_slices(self, domain): return self.subsystems[0].coeff_slices(domain) @@ -300,9 +311,10 @@ def field_size(self, field): return self.subsystems[0].field_size(field) def _build_buffer_views(self, fields): + xp = self.array_namespace # Allocate buffer fsizes = tuple(self.field_size(f) for f in fields) - buffer = np.zeros((sum(fsizes), len(self.subsystems)), dtype=self.dtype) + buffer = xp.zeros((sum(fsizes), len(self.subsystems)), dtype=self.dtype) # Make views into buffer views = [] i0 = 0 @@ -342,7 +354,7 @@ def gather_inputs(self, fields, out=None): # Gather from fields views = self._input_field_views(tuple(fields)) for buffer_view, field_view in views: - np.copyto(buffer_view, field_view) + copyto(buffer_view, field_view) # Apply right preconditioner inverse to compress inputs if out is None: out = self._compressed_buffer @@ -354,7 +366,7 @@ def gather_outputs(self, fields, out=None): # Gather from fields views = self._output_field_views(tuple(fields)) for buffer_view, field_view in views: - np.copyto(buffer_view, field_view) + copyto(buffer_view, field_view) # Apply left preconditioner to compress outputs if out is None: out = self._compressed_buffer @@ -368,7 +380,7 @@ def scatter_inputs(self, data, fields): # Scatter to fields views = self._input_field_views(tuple(fields)) for buffer_view, field_view in views: - np.copyto(field_view, buffer_view) + copyto(field_view, buffer_view) def scatter_outputs(self, data, fields): """Precondition and scatter subproblem data out to output-like field list.""" @@ -377,7 +389,7 @@ def scatter_outputs(self, data, fields): # Scatter to fields views = self._output_field_views(tuple(fields)) for buffer_view, field_view in views: - np.copyto(field_view, buffer_view) + copyto(field_view, buffer_view) def inclusion_matrices(self, bases): """List of inclusion matrices.""" @@ -555,24 +567,45 @@ def build_matrices(self, names): left_perm = left_permutation(self, eqns, bc_top=solver.bc_top, interleave_components=solver.interleave_components).tocsr() right_perm = right_permutation(self, vars, tau_left=solver.tau_left, interleave_components=solver.interleave_components).tocsr() - # Preconditioners + # Preconditioners on CPU # TODO: remove astype casting, requires dealing with used types in apply_sparse - self.pre_left = drop_empty_rows(left_perm @ valid_eqn).tocsr().astype(dtype) - self.pre_left_pinv = self.pre_left.T.tocsr().astype(dtype) - self.pre_right_pinv = drop_empty_rows(right_perm @ valid_var).tocsr().astype(dtype) - self.pre_right = self.pre_right_pinv.T.tocsr().astype(dtype) + pre_left = drop_empty_rows(left_perm @ valid_eqn).tocsr().astype(dtype) + pre_left_pinv = pre_left.T.tocsr().astype(dtype) + pre_right_pinv = drop_empty_rows(right_perm @ valid_var).tocsr().astype(dtype) + pre_right = pre_right_pinv.T.tocsr().astype(dtype) # Check preconditioner pseudoinverses - assert_sparse_pinv(self.pre_left, self.pre_left_pinv) - assert_sparse_pinv(self.pre_right, self.pre_right_pinv) + assert_sparse_pinv(pre_left, pre_left_pinv) + assert_sparse_pinv(pre_right, pre_right_pinv) # Precondition matrices for name in matrices: - matrices[name] = self.pre_left @ matrices[name] @ self.pre_right + matrices[name] = pre_left @ matrices[name] @ pre_right - # Store minimal CSR matrices for fast dot products + # Store minimal CSR matrices on CPU for name, matrix in matrices.items(): - setattr(self, '{:}_min'.format(name), matrix.tocsr()) + setattr(self, f'{name}_min', matrix.tocsr()) + + # Store device copies for fast dot products + xp = solver.dist.array_namespace + if array_api_compat.is_numpy_namespace(xp): + self.pre_left = pre_left + self.pre_left_pinv = pre_left_pinv + self.pre_right_pinv = pre_right_pinv + self.pre_right = pre_right + # Reference current CPU matrices + for name, matrix in matrices.items(): + setattr(self, f'{name}_min_device', getattr(self, f'{name}_min')) + elif array_api_compat.is_cupy_namespace(xp): + # Copy to device + self.pre_left = csp.csr_matrix(pre_left) + self.pre_left_pinv = csp.csr_matrix(pre_left_pinv) + self.pre_right_pinv = csp.csr_matrix(pre_right_pinv) + self.pre_right = csp.csr_matrix(pre_right) + for name, matrix in matrices.items(): + setattr(self, f'{name}_min_device', csp.csr_matrix(matrix)) + else: + raise ValueError("Unsupported array namespace: {}".format(xp)) # Store expanded CSR matrices for fast recombination if len(matrices) > 1: diff --git a/dedalus/core/system.py b/dedalus/core/system.py index 23cbb86b..f28cb206 100644 --- a/dedalus/core/system.py +++ b/dedalus/core/system.py @@ -12,45 +12,44 @@ class CoeffSystem: """ - Representation of a collection of fields that don't need to be transformed, - and are therefore stored as a contigous set of coefficient data for - efficient pencil and group manipulation. + Contiguous buffer for data from all subproblems. Parameters ---------- - nfields : int - Number of fields to represent - domain : domain object - Problem domain + subproblems : list of Subproblem objects + Subproblems to represent + dtype : dtype + Data type + array_namespace : array namespace + Array namespace Attributes ---------- data : ndarray - Contiguous buffer for field coefficients - - """ - - """ - var buffer - + Contiguous buffer for data from all subproblems + views : dict + Nested dictionary of views for each subproblem and subsystem """ - def __init__(self, subproblems, dtype): + def __init__(self, subproblems, dtype, array_namespace): + xp = array_namespace # Build buffer total_size = sum(sp.LHS.shape[1]*len(sp.subsystems) for sp in subproblems) - self.data = np.zeros(total_size, dtype=dtype) + self.data = xp.zeros(total_size, dtype=dtype) # Build views i0 = i1 = 0 self.views = views = {} for sp in subproblems: views[sp] = views_sp = {} + # View for each individual subsystem i00 = i0 for ss in sp.subsystems: i1 += sp.LHS.shape[1] views_sp[ss] = self.data[i0:i1] i0 = i1 i11 = i1 + # View combining all subsystems as rows in a matrix if i11 - i00 > 0: views_sp[None] = self.data[i00:i11].reshape((sp.LHS.shape[1], -1)) else: diff --git a/dedalus/core/timesteppers.py b/dedalus/core/timesteppers.py index 81da4c10..2eaa6cd8 100644 --- a/dedalus/core/timesteppers.py +++ b/dedalus/core/timesteppers.py @@ -2,10 +2,9 @@ from collections import deque, OrderedDict import numpy as np -from scipy.linalg import blas from .system import CoeffSystem -from ..tools.array import apply_sparse +from ..tools.array import apply_sparse, get_axpy # Public interface @@ -71,7 +70,8 @@ class MultistepIMEX: def __init__(self, solver): self.solver = solver - self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype) + self.xp = solver.dist.array_namespace + self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp) # Create deque for storing recent timesteps self.dt = deque([0.] * self.steps) @@ -81,16 +81,17 @@ def __init__(self, solver): self.LX = LX = deque() self.F = F = deque() for j in range(self.amax): - MX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype)) + MX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp)) for j in range(self.bmax): - LX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype)) + LX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp)) for j in range(self.cmax): - F.append(CoeffSystem(solver.subproblems, dtype=solver.dtype)) + F.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp)) # Attributes self._iteration = 0 self._LHS_params = None - self.axpy = blas.get_blas_funcs('axpy', dtype=solver.dtype) + self.axpy_xp = get_axpy(self.xp, solver.dtype) + self.axpy_np = get_axpy(np, solver.dtype) def step(self, dt, wall_time): """Advance solver by one timestep.""" @@ -110,14 +111,13 @@ def step(self, dt, wall_time): LX = self.LX F = self.F RHS = self.RHS - axpy = self.axpy # Cycle and compute timesteps self.dt.rotate() self.dt[0] = dt # Compute IMEX coefficients - a, b, c = self.compute_coefficients(self.dt, self._iteration) + a, b, c = self.compute_coefficients(self.dt, self._iteration, self.solver.dtype) self._iteration += 1 # Update RHS components and LHS matrices @@ -143,8 +143,8 @@ def step(self, dt, wall_time): evaluator.require_coeff_space(state_fields) for sp in subproblems: spX = sp.gather_inputs(state_fields) - apply_sparse(sp.M_min, spX, axis=0, out=MX0.get_subdata(sp)) - apply_sparse(sp.L_min, spX, axis=0, out=LX0.get_subdata(sp)) + apply_sparse(sp.M_min_device, spX, axis=0, out=MX0.get_subdata(sp)) + apply_sparse(sp.L_min_device, spX, axis=0, out=LX0.get_subdata(sp)) # Evaluate F(X0) evaluator.evaluate_scheduled(iteration=iteration, wall_time=wall_time, sim_time=sim_time, timestep=dt) @@ -154,16 +154,16 @@ def step(self, dt, wall_time): # Build RHS if RHS.data.size: - np.multiply(c[1], F0.data, out=RHS.data) + self.xp.multiply(c[1], F0.data, out=RHS.data) for j in range(2, len(c)): # RHS.data += c[j] * F[j-1].data - axpy(a=c[j], x=F[j-1].data, y=RHS.data) + self.axpy_xp(a=c[j], x=F[j-1].data, y=RHS.data) for j in range(1, len(a)): # RHS.data -= a[j] * MX[j-1].data - axpy(a=-a[j], x=MX[j-1].data, y=RHS.data) + self.axpy_xp(a=-a[j], x=MX[j-1].data, y=RHS.data) for j in range(1, len(b)): # RHS.data -= b[j] * LX[j-1].data - axpy(a=-b[j], x=LX[j-1].data, y=RHS.data) + self.axpy_xp(a=-b[j], x=LX[j-1].data, y=RHS.data) # Solve # Ensure coeff space before subsystem scatters @@ -171,10 +171,11 @@ def step(self, dt, wall_time): field.preset_layout('c') for sp in subproblems: if update_LHS: + # Form updated LHS matrix on CPU for factorization if STORE_EXPANDED_MATRICES: # sp.LHS.data[:] = a0*sp.M_exp.data + b0*sp.L_exp.data np.multiply(a0, sp.M_exp.data, out=sp.LHS.data) - axpy(a=b0, x=sp.L_exp.data, y=sp.LHS.data) + self.axpy_np(a=b0, x=sp.L_exp.data, y=sp.LHS.data) else: sp.LHS = (a0*sp.M_min + b0*sp.L_min) # CREATES TEMPORARY sp.LHS_solver = solver.matsolver(sp.LHS, solver) @@ -203,11 +204,11 @@ class CNAB1(MultistepIMEX): steps = 1 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k0, *rest = timesteps @@ -236,11 +237,11 @@ class SBDF1(MultistepIMEX): steps = 1 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k0, *rest = timesteps @@ -268,14 +269,14 @@ class CNAB2(MultistepIMEX): steps = 2 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): if iteration < 1: - return CNAB1.compute_coefficients(timesteps, iteration) + return CNAB1.compute_coefficients(timesteps, iteration, dtype) - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k1, k0, *rest = timesteps w1 = k1 / k0 @@ -306,14 +307,14 @@ class MCNAB2(MultistepIMEX): steps = 2 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): if iteration < 1: - return CNAB1.compute_coefficients(timesteps, iteration) + return CNAB1.compute_coefficients(timesteps, iteration, dtype) - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k1, k0, *rest = timesteps w1 = k1 / k0 @@ -345,14 +346,14 @@ class SBDF2(MultistepIMEX): steps = 2 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): if iteration < 1: - return SBDF1.compute_coefficients(timesteps, iteration) + return SBDF1.compute_coefficients(timesteps, iteration, dtype=dtype) - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k1, k0, *rest = timesteps w1 = k1 / k0 @@ -383,14 +384,14 @@ class CNLF2(MultistepIMEX): steps = 2 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): if iteration < 1: - return CNAB1.compute_coefficients(timesteps, iteration) + return CNAB1.compute_coefficients(timesteps, iteration, dtype) - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k1, k0, *rest = timesteps w1 = k1 / k0 @@ -422,14 +423,14 @@ class SBDF3(MultistepIMEX): steps = 3 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): if iteration < 2: - return SBDF2.compute_coefficients(timesteps, iteration) + return SBDF2.compute_coefficients(timesteps, iteration, dtype) - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k2, k1, k0, *rest = timesteps w2 = k2 / k1 @@ -463,14 +464,14 @@ class SBDF4(MultistepIMEX): steps = 4 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): if iteration < 3: - return SBDF3.compute_coefficients(timesteps, iteration) + return SBDF3.compute_coefficients(timesteps, iteration, dtype) - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k3, k2, k1, k0, *rest = timesteps w3 = k3 / k2 @@ -539,15 +540,22 @@ class RungeKuttaIMEX: def __init__(self, solver): self.solver = solver - self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype) + self.xp = solver.dist.array_namespace + self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp) # Create coefficient systems for multistep history - self.MX0 = CoeffSystem(solver.subproblems, dtype=solver.dtype) - self.LX = [CoeffSystem(solver.subproblems, dtype=solver.dtype) for i in range(self.stages)] - self.F = [CoeffSystem(solver.subproblems, dtype=solver.dtype) for i in range(self.stages)] + self.MX0 = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp) + self.LX = [CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp) for i in range(self.stages)] + self.F = [CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp) for i in range(self.stages)] self._LHS_params = None - self.axpy = blas.get_blas_funcs('axpy', dtype=solver.dtype) + self.axpy_xp = get_axpy(self.xp, solver.dtype) + self.axpy_np = get_axpy(np, solver.dtype) + + # Cast scheme coefficients + self.A = self.A.astype(self.solver.dtype) + self.H = self.H.astype(self.solver.dtype) + self.c = self.c.astype(self.solver.dtype) def step(self, dt, wall_time): """Advance solver by one timestep.""" @@ -572,7 +580,6 @@ def step(self, dt, wall_time): H = self.H c = self.c k = dt - axpy = self.axpy # Check on updating LHS update_LHS = (k != self._LHS_params) @@ -584,11 +591,12 @@ def step(self, dt, wall_time): # Compute M.X(n,0) and L.X(n,0) # Ensure coeff space before subsystem gathers + # TODO: add option to evaluate this matrix-free (e.g for high-bandwidth NCCs when using fast transforms) evaluator.require_coeff_space(state_fields) for sp in subproblems: spX = sp.gather_inputs(state_fields) - apply_sparse(sp.M_min, spX, axis=0, out=MX0.get_subdata(sp)) - apply_sparse(sp.L_min, spX, axis=0, out=LX0.get_subdata(sp)) + apply_sparse(sp.M_min_device, spX, axis=0, out=MX0.get_subdata(sp)) + apply_sparse(sp.L_min_device, spX, axis=0, out=LX0.get_subdata(sp)) # Compute stages # (M + k Hii L).X(n,i) = M.X(n,0) + k Aij F(n,j) - k Hij L.X(n,j) @@ -601,7 +609,7 @@ def step(self, dt, wall_time): evaluator.require_coeff_space(state_fields) for sp in subproblems: spX = sp.gather_inputs(state_fields) - apply_sparse(sp.L_min, spX, axis=0, out=LXi.get_subdata(sp)) + apply_sparse(sp.L_min_device, spX, axis=0, out=LXi.get_subdata(sp)) # Compute F(n,i-1), only doing output on first evaluation if i == 1: @@ -615,12 +623,12 @@ def step(self, dt, wall_time): # Construct RHS(n,i) if RHS.data.size: - np.copyto(RHS.data, MX0.data) + self.xp.copyto(RHS.data, MX0.data) for j in range(0, i): # RHS.data += (k * A[i,j]) * F[j].data - axpy(a=(k*A[i,j]), x=F[j].data, y=RHS.data) + self.axpy_xp(a=(k*A[i,j]), x=F[j].data, y=RHS.data) # RHS.data -= (k * H[i,j]) * LX[j].data - axpy(a=-(k*H[i,j]), x=LX[j].data, y=RHS.data) + self.axpy_xp(a=-(k*H[i,j]), x=LX[j].data, y=RHS.data) # Solve for stage k_Hii = k * H[i,i] @@ -630,10 +638,11 @@ def step(self, dt, wall_time): for sp in subproblems: # Construct LHS(n,i) if update_LHS: + # Form updated LHS matrix on CPU for factorization if STORE_EXPANDED_MATRICES: # sp.LHS.data[:] = sp.M_exp.data + k_Hii*sp.L_exp.data np.copyto(sp.LHS.data, sp.M_exp.data) - axpy(a=k_Hii, x=sp.L_exp.data, y=sp.LHS.data) + self.axpy_np(a=k_Hii, x=sp.L_exp.data, y=sp.LHS.data) else: sp.LHS = (sp.M_min + k_Hii*sp.L_min) # CREATES TEMPORARY sp.LHS_solvers[i] = solver.matsolver(sp.LHS, solver) diff --git a/dedalus/core/transforms.py b/dedalus/core/transforms.py index 00758fb2..90b11a1d 100644 --- a/dedalus/core/transforms.py +++ b/dedalus/core/transforms.py @@ -8,13 +8,16 @@ import scipy.fftpack from ..libraries import dedalus_sphere from math import prod +import array_api_compat from . import basis from ..libraries.fftw import fftw_wrappers as fftw from ..tools import jacobi -from ..tools.array import apply_matrix, apply_dense, axslice, solve_upper_sparse, apply_sparse +from ..tools.array import apply_matrix, apply_dense, axslice, solve_upper_sparse, apply_sparse, copyto from ..tools.cache import CachedAttribute from ..tools.cache import CachedMethod +from ..tools.general import float_to_complex +from ..tools.linalg_gpu import cupy_solve_upper_csr, CustomCupyUpperTriangularSolver import logging logger = logging.getLogger(__name__.split('.')[-1]) @@ -93,19 +96,25 @@ class JacobiTransform(SeparableTransform): Jacobi "a" parameter for the quadrature grid. b0 : int Jacobi "b" parameter for the quadrature grid. + array_namespace : array namespace + Array namespace for the transform. + dtype : dtype + Data type for the transform. Notes ----- TODO: We need to define the normalization we use here. """ - def __init__(self, grid_size, coeff_size, a, b, a0, b0, dealias_before_converting=None): + def __init__(self, grid_size, coeff_size, a, b, a0, b0, array_namespace, dtype, dealias_before_converting=None): self.N = grid_size self.M = coeff_size self.a = a self.b = b self.a0 = a0 self.b0 = b0 + self.array_namespace = array_namespace + self.dtype = dtype if dealias_before_converting is None: dealias_before_converting = GET_DEALIAS_BEFORE_CONVERTING() self.dealias_before_converting = dealias_before_converting @@ -118,6 +127,7 @@ class JacobiMMT(JacobiTransform, SeparableMatrixTransform): @CachedAttribute def forward_matrix(self): """Build forward transform matrix.""" + xp = self.array_namespace N, M = self.N, self.M a, a0 = self.a, self.a0 b, b0 = self.b, self.b0 @@ -141,11 +151,12 @@ def forward_matrix(self): # Truncate to specified coeff_size forward_matrix = forward_matrix[:M, :] # Ensure C ordering for fast dot products - return np.asarray(forward_matrix, order='C') + return xp.asarray(forward_matrix, order='C', dtype=self.dtype) @CachedAttribute def backward_matrix(self): """Build backward transform matrix.""" + xp = self.array_namespace N, M = self.N, self.M a, a0 = self.a, self.a0 b, b0 = self.b, self.b0 @@ -155,11 +166,11 @@ def backward_matrix(self): # Zero higher polynomials for transforms with grid_size < coeff_size polynomials[N:, :] = 0 # Transpose and ensure C ordering for fast dot products - return np.asarray(polynomials.T, order='C') + return xp.asarray(polynomials.T, order='C', dtype=self.dtype) class ComplexFourierTransform(SeparableTransform): - """ + r""" Abstract base class for complex-to-complex Fourier transforms. Parameters @@ -191,19 +202,22 @@ class ComplexFourierTransform(SeparableTransform): If M is even, the ordering is [0, 1, 2, ..., KM, -KM, -KM+1, ..., -1]. """ - def __init__(self, grid_size, coeff_size): + def __init__(self, grid_size, coeff_size, array_namespace, dtype): self.N = grid_size self.M = coeff_size self.KN = (self.N - 1) // 2 self.KM = (self.M - 1) // 2 self.Kmax = min(self.KN, self.KM) + self.array_namespace = array_namespace + self.dtype = dtype @property def wavenumbers(self): """One-dimensional global wavenumber array.""" + xp = self.array_namespace M = self.M KM = self.KM - k = np.arange(M) + k = xp.arange(M) # Wrap around Nyquist mode return (k + KM) % M - KM @@ -215,26 +229,28 @@ class ComplexFourierMMT(ComplexFourierTransform, SeparableMatrixTransform): @CachedAttribute def forward_matrix(self): """Build forward transform matrix.""" + xp = self.array_namespace K = self.wavenumbers[:, None] - X = np.arange(self.N)[None, :] + X = xp.arange(self.N)[None, :] dX = self.N / 2 / np.pi - quadrature = np.exp(-1j*K*X/dX) / self.N + quadrature = xp.exp(-1j*K*X/dX) / self.N # Zero Nyquist and higher modes for transforms with grid_size <= coeff_size - quadrature *= np.abs(K) <= self.Kmax - # Ensure C ordering for fast dot products - return np.asarray(quadrature, order='C') + quadrature *= xp.abs(K) <= self.Kmax + # Ensure C ordering for fast dot products, cast to specified dtype + return xp.asarray(quadrature, order='C', dtype=self.dtype) @CachedAttribute def backward_matrix(self): """Build backward transform matrix.""" + xp = self.array_namespace K = self.wavenumbers[None, :] - X = np.arange(self.N)[:, None] + X = xp.arange(self.N)[:, None] dX = self.N / 2 / np.pi - functions = np.exp(1j*K*X/dX) + functions = xp.exp(1j*K*X/dX) # Zero Nyquist and higher modes for transforms with grid_size <= coeff_size - functions *= np.abs(K) <= self.Kmax - # Ensure C ordering for fast dot products - return np.asarray(functions, order='C') + functions *= xp.abs(K) <= self.Kmax + # Ensure C ordering for fast dot products, cast to specified dtype + return xp.asarray(functions, order='C', dtype=self.dtype) class ComplexFFT(ComplexFourierTransform): @@ -242,29 +258,30 @@ class ComplexFFT(ComplexFourierTransform): def resize_coeffs(self, data_in, data_out, axis, rescale): """Resize and rescale coefficients in standard FFT format by intermediate padding/truncation.""" + xp = self.array_namespace M = self.M Kmax = self.Kmax if Kmax == 0: posfreq = axslice(axis, 0, 1) badfreq = axslice(axis, 1, None) if rescale is None: - np.copyto(data_out[posfreq], data_in[posfreq]) + xp.copyto(data_out[posfreq], data_in[posfreq]) data_out[badfreq] = 0 else: - np.multiply(data_in[posfreq], rescale, data_out[posfreq]) + xp.multiply(data_in[posfreq], rescale, data_out[posfreq]) data_out[badfreq] = 0 else: posfreq = axslice(axis, 0, Kmax+1) badfreq = axslice(axis, Kmax+1, -Kmax) negfreq = axslice(axis, -Kmax, None) if rescale is None: - np.copyto(data_out[posfreq], data_in[posfreq]) + xp.copyto(data_out[posfreq], data_in[posfreq]) data_out[badfreq] = 0 - np.copyto(data_out[negfreq], data_in[negfreq]) + xp.copyto(data_out[negfreq], data_in[negfreq]) else: - np.multiply(data_in[posfreq], rescale, data_out[posfreq]) + xp.multiply(data_in[posfreq], rescale, data_out[posfreq]) data_out[badfreq] = 0 - np.multiply(data_in[negfreq], rescale, data_out[negfreq]) + xp.multiply(data_in[negfreq], rescale, data_out[negfreq]) @register_transform(basis.ComplexFourier, 'scipy') @@ -289,6 +306,34 @@ def backward(self, cdata, gdata, axis): np.copyto(gdata, temp) +@register_transform(basis.ComplexFourier, 'cupy') +class CupyComplexFFT(ComplexFFT): + """Complex-to-complex FFT using scipy.fft.""" + + def __init__(self, *args, **kw): + import cupyx.scipy.fft as cufft + self.cufft = cufft + super().__init__(*args, **kw) + + def forward(self, gdata, cdata, axis): + """Apply forward transform along specified axis.""" + # Call FFT + temp = self.cufft.fft(gdata, axis=axis) # Creates temporary + # Resize and rescale for unit-amplitude normalization + self.resize_coeffs(temp, cdata, axis, rescale=1/self.N) + + def backward(self, cdata, gdata, axis): + """Apply backward transform along specified axis.""" + xp = self.array_namespace + # Resize and rescale for unit-amplitude normalization + # Need temporary to avoid overwriting problems + temp = xp.empty_like(gdata) # Creates temporary + self.resize_coeffs(cdata, temp, axis, rescale=self.N) + # Call FFT + temp = self.cufft.ifft(temp, axis=axis, overwrite_x=True) # Creates temporary + xp.copyto(gdata, temp) + + class FFTWBase: """Abstract base class for FFTW transforms.""" @@ -331,7 +376,7 @@ def backward(self, cdata, gdata, axis): class RealFourierTransform(SeparableTransform): - """ + r""" Abstract base class for real-to-real Fourier transforms. Parameters @@ -368,7 +413,7 @@ class RealFourierTransform(SeparableTransform): where the k = 0 minus-sine mode is zeroed in both directions. """ - def __init__(self, grid_size, coeff_size): + def __init__(self, grid_size, coeff_size, array_namespace, dtype): if coeff_size % 2 != 0: pass#raise ValueError("coeff_size must be even.") self.N = grid_size @@ -376,12 +421,15 @@ def __init__(self, grid_size, coeff_size): self.KN = (self.N - 1) // 2 self.KM = (self.M - 1) // 2 self.Kmax = min(self.KN, self.KM) + self.array_namespace = array_namespace + self.dtype = dtype @property def wavenumbers(self): """One-dimensional global wavenumber array.""" + xp = self.array_namespace # Repeat k's for cos and msin parts - return np.repeat(np.arange(self.KM+1), 2) + return xp.repeat(xp.arange(self.KM+1), 2) @register_transform(basis.RealFourier, 'matrix') @@ -391,37 +439,39 @@ class RealFourierMMT(RealFourierTransform, SeparableMatrixTransform): @CachedAttribute def forward_matrix(self): """Build forward transform matrix.""" + xp = self.array_namespace N = self.N M = max(2, self.M) # Account for sin and cos parts of m=0 Kmax = self.Kmax K = self.wavenumbers[::2, None] - X = np.arange(N)[None, :] + X = xp.arange(N)[None, :] dX = N / 2 / np.pi - quadrature = np.zeros((M, N)) - quadrature[0::2] = (2 / N) * np.cos(K*X/dX) - quadrature[1::2] = -(2 / N) * np.sin(K*X/dX) + quadrature = xp.zeros((M, N)) + quadrature[0::2] = (2 / N) * xp.cos(K*X/dX) + quadrature[1::2] = -(2 / N) * xp.sin(K*X/dX) quadrature[0] = 1 / N # Zero Nyquist and higher modes for transforms with grid_size <= coeff_size quadrature *= self.wavenumbers[:,None] <= self.Kmax # Ensure C ordering for fast dot products - return np.asarray(quadrature, order='C') + return xp.asarray(quadrature, order='C', dtype=self.dtype) @CachedAttribute def backward_matrix(self): """Build backward transform matrix.""" + xp = self.array_namespace N = self.N M = max(2, self.M) # Account for sin and cos parts of m=0 Kmax = self.Kmax K = self.wavenumbers[None, ::2] - X = np.arange(N)[:, None] + X = xp.arange(N)[:, None] dX = N / 2 / np.pi - functions = np.zeros((N, M)) - functions[:, 0::2] = np.cos(K*X/dX) - functions[:, 1::2] = -np.sin(K*X/dX) + functions = xp.zeros((N, M)) + functions[:, 0::2] = xp.cos(K*X/dX) + functions[:, 1::2] = -xp.sin(K*X/dX) # Zero Nyquist and higher modes for transforms with grid_size <= coeff_size functions *= self.wavenumbers[None, :] <= self.Kmax # Ensure C ordering for fast dot products - return np.asarray(functions, order='C') + return xp.asarray(functions, order='C', dtype=self.dtype) @register_transform(basis.RealFourier, 'fftpack') @@ -471,40 +521,42 @@ class RealFFT(RealFourierTransform): def unpack_rescale(self, temp, cdata, axis, rescale): """Unpack complex coefficients and rescale for unit-amplitude normalization.""" + xp = self.array_namespace Kmax = self.Kmax # Scale k = 0 cos data meancos = axslice(axis, 0, 1) - np.multiply(temp[meancos].real, rescale, cdata[meancos]) + xp.multiply(temp[meancos].real, rescale, cdata[meancos]) # Zero k = 0 msin data cdata[axslice(axis, 1, 2)] = 0 # Unpack and scale 1 < k <= Kmax data temp_posfreq = temp[axslice(axis, 1, Kmax+1)] cdata_posfreq_cos = cdata[axslice(axis, 2, 2*(Kmax+1), 2)] cdata_posfreq_msin = cdata[axslice(axis, 3, 2*(Kmax+1), 2)] - np.multiply(temp_posfreq.real, 2*rescale, cdata_posfreq_cos) - np.multiply(temp_posfreq.imag, 2*rescale, cdata_posfreq_msin) + xp.multiply(temp_posfreq.real, 2*rescale, cdata_posfreq_cos) + xp.multiply(temp_posfreq.imag, 2*rescale, cdata_posfreq_msin) # Zero k > Kmax data cdata[axslice(axis, 2*(Kmax+1), None)] = 0 def repack_rescale(self, cdata, temp, axis, rescale): """Repack into complex coefficients and rescale for unit-amplitude normalization.""" + xp = self.array_namespace Kmax = self.Kmax # Scale k = 0 data meancos = axslice(axis, 0, 1) if rescale is None: - np.copyto(temp[meancos], cdata[meancos]) + xp.copyto(temp[meancos], cdata[meancos]) else: - np.multiply(cdata[meancos], rescale, temp[meancos]) + xp.multiply(cdata[meancos], rescale, temp[meancos]) # Repack and scale 1 < k <= Kmax data temp_posfreq = temp[axslice(axis, 1, Kmax+1)] cdata_posfreq_cos = cdata[axslice(axis, 2, 2*(Kmax+1), 2)] cdata_posfreq_msin = cdata[axslice(axis, 3, 2*(Kmax+1), 2)] if rescale is None: - np.multiply(cdata_posfreq_cos, (1 / 2), temp_posfreq.real) - np.multiply(cdata_posfreq_msin, (1 / 2), temp_posfreq.imag) + xp.multiply(cdata_posfreq_cos, (1 / 2), temp_posfreq.real) + xp.multiply(cdata_posfreq_msin, (1 / 2), temp_posfreq.imag) else: - np.multiply(cdata_posfreq_cos, (rescale / 2), temp_posfreq.real) - np.multiply(cdata_posfreq_msin, (rescale / 2), temp_posfreq.imag) + xp.multiply(cdata_posfreq_cos, (rescale / 2), temp_posfreq.real) + xp.multiply(cdata_posfreq_msin, (rescale / 2), temp_posfreq.imag) # Zero k > Kmax data temp[axslice(axis, Kmax+1, None)] = 0 @@ -513,6 +565,10 @@ def repack_rescale(self, cdata, temp, axis, rescale): class ScipyRealFFT(RealFFT): """Real-to-real FFT using scipy.fft.""" + def __init__(self, *args, **kw): + super().__init__(*args, **kw) + self.complex_dtype = float_to_complex(self.dtype) + def forward(self, gdata, cdata, axis): """Apply forward transform along specified axis.""" # Call RFFT @@ -526,7 +582,7 @@ def backward(self, cdata, gdata, axis): # Rescale all modes and combine into complex form shape = list(gdata.shape) shape[axis] = N // 2 + 1 - temp = np.empty(shape=shape, dtype=np.complex128) # Creates temporary + temp = np.empty(shape=shape, dtype=self.complex_dtype) # Creates temporary # Repack into complex form and rescale self.repack_rescale(cdata, temp, axis, rescale=N) # Call IRFFT @@ -534,6 +590,38 @@ def backward(self, cdata, gdata, axis): np.copyto(gdata, temp) +@register_transform(basis.RealFourier, 'cupy') +class CupyRealFFT(RealFFT): + """Real-to-real FFT using scipy.fft.""" + + def __init__(self, *args, **kw): + import cupyx.scipy.fft as cufft + self.cufft = cufft + super().__init__(*args, **kw) + self.complex_dtype = float_to_complex(self.dtype) + + def forward(self, gdata, cdata, axis): + """Apply forward transform along specified axis.""" + # Call RFFT + temp = self.cufft.rfft(gdata, axis=axis) # Creates temporary + # Unpack from complex form and rescale + self.unpack_rescale(temp, cdata, axis, rescale=1/self.N) + + def backward(self, cdata, gdata, axis): + """Apply backward transform along specified axis.""" + xp = self.array_namespace + N = self.N + # Rescale all modes and combine into complex form + shape = list(gdata.shape) + shape[axis] = N // 2 + 1 + temp = xp.empty(shape=shape, dtype=self.complex_dtype) # Creates temporary + # Repack into complex form and rescale + self.repack_rescale(cdata, temp, axis, rescale=N) + # Call IRFFT + temp = self.cufft.irfft(temp, axis=axis, n=N, overwrite_x=True) # Creates temporary + xp.copyto(gdata, temp) + + @register_transform(basis.RealFourier, 'fftw') class FFTWRealFFT(FFTWBase, RealFFT): """Real-to-real FFT using FFTW.""" @@ -630,7 +718,7 @@ def backward(self, cdata, gdata, axis): class CosineTransform(SeparableTransform): - """ + r""" Abstract base class for cosine transforms. Parameters @@ -768,6 +856,33 @@ def backward(self, cdata, gdata, axis): np.copyto(gdata, temp) +class CupyDCT(FastCosineTransform): + """Fast cosine transform using cupy fft.""" + + def __init__(self, *args, **kw): + import cupyx.scipy.fft as cufft + self.cufft = cufft + super().__init__(*args, **kw) + + def forward(self, gdata, cdata, axis): + """Apply forward transform along specified axis.""" + # Call DCT + temp = self.cufft.dct(gdata, type=2, axis=axis) # Creates temporary + # Resize and rescale for unit-ampltidue normalization + self.resize_rescale_forward(temp, cdata, axis, self.Kmax) + + def backward(self, cdata, gdata, axis): + """Apply backward transform along specified axis.""" + xp = self.array_namespace + # Resize and rescale for unit-amplitude normalization + # Need temporary to avoid overwriting problems + temp = xp.empty_like(gdata) # Creates temporary + self.resize_rescale_backward(cdata, temp, axis, self.Kmax) + # Call IDCT + temp = self.cufft.dct(temp, type=3, axis=axis, overwrite_x=True) # Creates temporary + copyto(gdata, temp) + + #@register_transform(basis.Cosine, 'fftw') class FFTWDCT(FFTWBase, FastCosineTransform): """Fast cosine transform using FFTW.""" @@ -804,11 +919,11 @@ class FastChebyshevTransform(JacobiTransform): Subclasses should inherit from this class, then a FastCosineTransform subclass. """ - def __init__(self, grid_size, coeff_size, a, b, a0, b0, **kw): + def __init__(self, grid_size, coeff_size, a, b, a0, b0, array_namespace, dtype, **kw): if not a0 == b0 == -1/2: raise ValueError("Fast Chebshev transform requires a0 == b0 == -1/2.") # Jacobi initialization - super().__init__(grid_size, coeff_size, a, b, a0, b0, **kw) + super().__init__(grid_size, coeff_size, a, b, a0, b0, array_namespace, dtype, **kw) # DCT initialization to set scaling factors if a != a0 or b != b0: # Modify coeff_size to avoid truncation before conversion @@ -831,15 +946,22 @@ def __init__(self, grid_size, coeff_size, a, b, a0, b0, **kw): else: # Conversion matrices if self.dealias_before_converting and (self.M_orig < self.N): # truncate prior to conversion matrix - self.forward_conversion = jacobi.conversion_matrix(self.M_orig, a0, b0, a, b).tocsr() + self.forward_conversion = jacobi.conversion_matrix(self.M_orig, a0, b0, a, b).tocsr().astype(dtype) else: # input to conversion matrix not truncated self.forward_conversion = jacobi.conversion_matrix(self.N, a0, b0, a, b) self.forward_conversion.resize(self.M_orig, self.N) - self.forward_conversion = self.forward_conversion.tocsr() - self.backward_conversion = jacobi.conversion_matrix(self.M_orig, a0, b0, a, b).tocsr() + self.forward_conversion = self.forward_conversion.tocsr().astype(dtype) + self.backward_conversion = jacobi.conversion_matrix(self.M_orig, a0, b0, a, b).tocsr().astype(dtype) self.backward_conversion.sum_duplicates() # for faster solve_upper self.resize_rescale_forward = self._resize_rescale_forward_convert self.resize_rescale_backward = self._resize_rescale_backward_convert + if array_api_compat.is_cupy_namespace(self.array_namespace): + import cupyx.scipy.sparse as csp + self.forward_conversion = csp.csr_matrix(self.forward_conversion) + self.backward_conversion = csp.csr_matrix(self.backward_conversion) + self.forward_conversion.sum_duplicates() + self.backward_conversion.sum_duplicates() + self.backward_conversion_LU = CustomCupyUpperTriangularSolver(self.backward_conversion) def _resize_rescale_forward(self, data_in, data_out, axis, Kmax): """Resize by padding/trunction and rescale to unit amplitude.""" @@ -881,7 +1003,10 @@ def _resize_rescale_backward_convert(self, data_in, data_out, axis, Kmax_DCT): # Truncate input before conversion data_in[badfreq] = 0 # Ultraspherical conversion - solve_upper_sparse(self.backward_conversion, data_in, axis, out=data_in) + if array_api_compat.is_cupy_namespace(self.array_namespace): + cupy_solve_upper_csr(self.backward_conversion_LU, data_in, axis, out=data_in) + else: + solve_upper_sparse(self.backward_conversion, data_in, axis, out=data_in) # Change sign of odd modes if Kmax_orig > 0: posfreq_odd = axslice(axis, 1, Kmax_orig+1, 2) @@ -890,18 +1015,24 @@ def _resize_rescale_backward_convert(self, data_in, data_out, axis, Kmax_DCT): super().resize_rescale_backward(data_in, data_out, axis, Kmax_orig) -@register_transform(basis.Jacobi, 'scipy_dct') +@register_transform(basis.Jacobi, 'scipy') class ScipyFastChebyshevTransform(FastChebyshevTransform, ScipyDCT): """Fast ultraspherical transform using scipy.fft and spectral conversion.""" pass # Implementation is complete via inheritance -@register_transform(basis.Jacobi, 'fftw_dct') +@register_transform(basis.Jacobi, 'fftw') class FFTWFastChebyshevTransform(FastChebyshevTransform, FFTWDCT): """Fast ultraspherical transform using scipy.fft and spectral conversion.""" pass # Implementation is complete via inheritance +@register_transform(basis.Jacobi, 'cupy') +class CupyFastChebyshevTransform(FastChebyshevTransform, CupyDCT): + """Fast ultraspherical transform using cupy fft and spectral conversion.""" + pass # Implementation is complete via inheritance + + # class ScipyDST(PolynomialTransform): # def forward_reduced(self): diff --git a/dedalus/dedalus.cfg b/dedalus/dedalus.cfg index edd0595c..04861e6b 100644 --- a/dedalus/dedalus.cfg +++ b/dedalus/dedalus.cfg @@ -31,9 +31,6 @@ [transforms] - # Default transform library (scipy, fftw) - DEFAULT_LIBRARY = fftw - # Transform multiple fields together when possible GROUP_TRANSFORMS = False @@ -71,6 +68,9 @@ [matrix construction] + # Fully couple GPU subproblems + COUPLE_GPU_SUBPROBLEMS = True + # Put BC rows at the top of the matrix BC_TOP = False diff --git a/dedalus/extras/flow_tools.py b/dedalus/extras/flow_tools.py index bc6798ba..b99a5069 100644 --- a/dedalus/extras/flow_tools.py +++ b/dedalus/extras/flow_tools.py @@ -86,7 +86,7 @@ def __init__(self, solver, cadence=1): self.solver = solver self.cadence = cadence - self.reducer = GlobalArrayReducer(solver.dist.comm_cart) + self.reducer = GlobalArrayReducer(solver.dist.comm_cart, solver.dtype) self.properties = solver.evaluator.add_dictionary_handler(iter=cadence) def add_property(self, property, name, precompute_integral=False): @@ -181,7 +181,7 @@ def __init__(self, solver, initial_dt, cadence=1, safety=1., max_dt=np.inf, self.min_change = min_change self.threshold = threshold - self.reducer = GlobalArrayReducer(self.solver.dist.comm_cart) + self.reducer = GlobalArrayReducer(self.solver.dist.comm_cart, solver.dtype) self.frequencies = self.solver.evaluator.add_dictionary_handler(iter=cadence) def compute_dt(self): diff --git a/dedalus/libraries/matsolvers.py b/dedalus/libraries/matsolvers.py index f301d4f2..ede93a1b 100644 --- a/dedalus/libraries/matsolvers.py +++ b/dedalus/libraries/matsolvers.py @@ -5,7 +5,12 @@ import scipy.sparse as sp import scipy.sparse.linalg as spla from functools import partial - +import array_api_compat +try: + import cupyx.scipy.sparse.linalg as cupy_spla + cupy_available = True +except ImportError: + cupy_available = False matsolvers = {} def add_solver(solver): @@ -144,6 +149,21 @@ def __init__(self, matrix, solver=None): relax=self.relax, panel_size=self.panel_size, options=self.options) + # Cupy conversion + if array_api_compat.is_cupy_namespace(solver.dist.array_namespace): + # Avoid cupy splu which requires GPU matrices but transfers them to factorize on CPU + # Run same typecheck as cupy splu + if matrix.dtype.char not in 'fdFD': + raise TypeError('Invalid dtype (actual: {})'.format(self.LU.dtype)) + # Build cupy factorization from scipy factorization of CPU matrices + self.LU = cupy_spla.SuperLU(self.LU) + self.LU.spsm_L_descr = None + self.LU.spsm_U_descr = None + self.solve = self.cupy_solve + + def cupy_solve(self, vector): + from dedalus.tools.linalg_gpu import custom_SuperLU_solve + return custom_SuperLU_solve(self.LU, vector, trans=self.trans) def solve(self, vector): return self.LU.solve(vector, trans=self.trans) @@ -225,6 +245,9 @@ class SparseInverse(SparseSolver): def __init__(self, matrix, solver=None): self.matrix_inverse = spla.inv(matrix.tocsc()) + # Cupy conversion + if array_api_compat.is_cupy_namespace(solver.dist.array_namespace): + self.matrix_inverse = cupy_spla.inv(matrix.tocsc()) def solve(self, vector): return self.matrix_inverse @ vector diff --git a/dedalus/tools/array.py b/dedalus/tools/array.py index ab9caf88..e137f75d 100644 --- a/dedalus/tools/array.py +++ b/dedalus/tools/array.py @@ -5,7 +5,10 @@ import scipy.sparse as sp from scipy.sparse import _sparsetools from scipy.sparse import linalg as spla +from scipy.linalg import blas from math import prod +from ..tools import linalg_gpu +import array_api_compat from .config import config from . import linalg as cython_linalg @@ -76,10 +79,20 @@ def expand_pattern(input, pattern): def apply_matrix(matrix, array, axis, **kw): """Apply matrix along any axis of an array.""" - if sparse.isspmatrix(matrix): - return apply_sparse(matrix, array, axis, **kw) + xp = array_api_compat.array_namespace(array) + if array_api_compat.is_numpy_namespace(xp): + if sparse.issparse(matrix): + return apply_sparse(matrix, array, axis, **kw) + else: + return apply_dense(matrix, array, axis, **kw) + elif array_api_compat.is_cupy_namespace(xp): + import cupyx.scipy.sparse as csp + if csp.issparse(matrix): + return apply_sparse(matrix, array, axis, **kw) + else: + return apply_dense(matrix, array, axis, **kw) else: - return apply_dense(matrix, array, axis, **kw) + raise ValueError("Unsupported array type") def apply_dense_einsum(matrix, array, axis, optimize=True, **kw): @@ -173,14 +186,14 @@ def apply_sparse(matrix, array, axis, out=None, check_shapes=False, num_threads= Apply sparse matrix along any axis of an array. Must be out of place if ouptut is specified. """ - # Check matrix - if not isinstance(matrix, sparse.csr_matrix): - raise ValueError("Matrix must be in CSR format.") + xp = array_api_compat.array_namespace(array) + matrix.sum_duplicates() + matrix.has_canonical_format = True # Check output if out is None: out_shape = list(array.shape) out_shape[axis] = matrix.shape[0] - out = np.empty(out_shape, dtype=array.dtype) + out = xp.empty(out_shape, dtype=array.dtype) elif out is array: raise ValueError("Cannot apply in place") # Check shapes @@ -189,17 +202,27 @@ def apply_sparse(matrix, array, axis, out=None, check_shapes=False, num_threads= raise ValueError("Axis out of bounds.") if matrix.shape[1] != array.shape[axis] or matrix.shape[0] != out.shape[axis]: raise ValueError("Matrix shape mismatch.") - # Old way if requested - if OLD_CSR_MATVECS and array.ndim == 2 and axis == 0: - out.fill(0) - return csr_matvecs(matrix, array, out) - # Promote datatypes - # TODO: find way to optimize this with fused types - matrix_data = matrix.data - if matrix_data.dtype != out.dtype: - matrix_data = matrix_data.astype(out.dtype) - # Call cython routine - cython_linalg.apply_csr(matrix.indptr, matrix.indices, matrix_data, array, out, axis, num_threads) + # Dispatch on array type + if array_api_compat.is_numpy_namespace(xp): + # Check matrix + if not isinstance(matrix, sparse.csr_matrix): + raise ValueError("Matrix must be in CSR format.") + # Old way if requested + if OLD_CSR_MATVECS and array.ndim == 2 and axis == 0: + out.fill(0) + return csr_matvecs(matrix, array, out) + # Promote datatypes + # TODO: find way to optimize this with fused types + matrix_data = matrix.data + if matrix_data.dtype != out.dtype: + matrix_data = matrix_data.astype(out.dtype) + # Call cython routine + cython_linalg.apply_csr(matrix.indptr, matrix.indices, matrix_data, array, out, axis, num_threads) + elif array_api_compat.is_cupy_namespace(xp): + # TODO: check matrix format here without import cupy + linalg_gpu.cupy_apply_csr(matrix, array, axis, out) + else: + raise ValueError("Unsupported array type") return out @@ -208,28 +231,40 @@ def solve_upper_sparse(matrix, rhs, axis, out=None, check_shapes=False, num_thre Solve upper triangular sparse matrix along any axis of an array. Matrix assumed to be nonzero on the diagonals. """ - # Check matrix - if not isinstance(matrix, sparse.csr_matrix): - raise ValueError("Matrix must be in CSR format.") - if not matrix._has_canonical_format: # avoid property hook (without underscore) - matrix.sum_duplicates() - # Setup output = rhs + xp = array_api_compat.array_namespace(rhs) + matrix.sum_duplicates() + matrix.has_canonical_format = True + # Check output if out is None: - out = np.copy(rhs) - elif out is not rhs: - np.copyto(out, rhs) - # Promote datatypes - matrix_data = matrix.data - if matrix_data.dtype != rhs.dtype: - matrix_data = matrix_data.astype(rhs.dtype) - # Check shapes - if check_shapes: - if not (0 <= axis < rhs.ndim): - raise ValueError("Axis out of bounds.") - if not (matrix.shape[0] == matrix.shape[1] == rhs.shape[axis]): - raise ValueError("Matrix shape mismatch.") - # Call cython routine - cython_linalg.solve_upper_csr(matrix.indptr, matrix.indices, matrix_data, out, axis, num_threads) + out = xp.empty_like(rhs) + # Dispatch on array type + if array_api_compat.is_numpy_namespace(xp): + # Check matrix + if not isinstance(matrix, sparse.csr_matrix): + raise ValueError("Matrix must be in CSR format.") + if not matrix._has_canonical_format: # avoid property hook (without underscore) + matrix.sum_duplicates() + # Setup output = rhs + copyto(out, rhs) + # Promote datatypes + matrix_data = matrix.data + if matrix_data.dtype != rhs.dtype: + matrix_data = matrix_data.astype(rhs.dtype) + # Check shapes + if check_shapes: + if not (0 <= axis < rhs.ndim): + raise ValueError("Axis out of bounds.") + if not (matrix.shape[0] == matrix.shape[1] == rhs.shape[axis]): + raise ValueError("Matrix shape mismatch.") + # Call cython routine + cython_linalg.solve_upper_csr(matrix.indptr, matrix.indices, matrix_data, out, axis, num_threads) + elif array_api_compat.is_cupy_namespace(xp): + if not matrix._has_canonical_format: # avoid property hook (without underscore) + matrix.sum_duplicates() + linalg_gpu.cupy_solve_upper_csr(matrix, rhs, axis, out) + else: + raise ValueError("Unsupported array type") + return out def csr_matvec(A_csr, x_vec, out_vec): @@ -353,6 +388,22 @@ def copyto(dest, src): dest[:] = src +def copy_to_device(dest, src): + xp = array_api_compat.array_namespace(dest) + if array_api_compat.is_cupy_namespace(xp): + src = xp.asarray(src) + dest[:] = src + else: + dest[:] = src + + +def copy_from_device(dest, src): + if array_api_compat.is_cupy_array(src): + src.get(out=dest) + else: + dest[:] = src + + def perm_matrix(perm, M=None, source_index=False, sparse=True): """ Build sparse permutation matrix from permutation vector. @@ -474,3 +525,12 @@ def assert_sparse_pinv(A, B): if not sparse_allclose((B @ A).conj().T, B @ A): raise AssertionError("Not a pseudoinverse") + +def get_axpy(array_namespace, dtype): + if array_api_compat.is_numpy_namespace(array_namespace): + return blas.get_blas_funcs('axpy', dtype=dtype) + elif array_api_compat.is_cupy_namespace(array_namespace): + from cupy.cublas import axpy as cublas_axpy + return cublas_axpy + else: + raise ValueError("Unsupported array namespace") diff --git a/dedalus/tools/general.py b/dedalus/tools/general.py index 18eb5ee4..9b8b5746 100644 --- a/dedalus/tools/general.py +++ b/dedalus/tools/general.py @@ -124,3 +124,15 @@ def is_complex_dtype(dtype): dtype = dtype.type return np.iscomplexobj(dtype()) + +def float_to_complex(dtype): + itemsize = np.dtype(dtype).itemsize + complex_dtype = np.dtype(f'complex{16*itemsize}') + return complex_dtype.type + + +def complex_to_float(dtype): + itemsize = np.dtype(dtype).itemsize + float_dtype = np.dtype(f'float{4*itemsize}') + return float_dtype.type + diff --git a/dedalus/tools/linalg_gpu.py b/dedalus/tools/linalg_gpu.py new file mode 100644 index 00000000..4e1c3a33 --- /dev/null +++ b/dedalus/tools/linalg_gpu.py @@ -0,0 +1,549 @@ +"""Linear algebra routines using cupy.""" + +import numpy as np +import math +try: + import cupy as cp + import cupyx.scipy.sparse as csp + import cupyx.scipy.sparse.linalg as cupy_spla + from cupyx import jit + cupy_available = True +except ImportError: + # Mock jit so module can still be imported without cupy + class jit: + @staticmethod + def rawkernel(): + def decorator(func): + return func + return decorator + cupy_available = False + + +def cupy_apply_csr(matrix, array, axis, out): + """Apply CSR matrix to arbitrary axis of array.""" + if not cupy_available: + raise ImportError("cupy must be installed to use GPU linear algebra") + # Check matrix format + if not isinstance(matrix, csp.csr_matrix): + # TODO: avoid this explicit conversion + print('WARNING: converting matrix to CSR format') + matrix = csp.csr_matrix(matrix) + #raise ValueError("Matrix must be in CSR format.") + # Switch by dimension + ndim = array.ndim + if ndim == 1: + if axis == 0: + out[:] = matrix.dot(array) + else: + raise ValueError("axis must be 0 for 1D arrays") + elif ndim == 2: + if axis == 0: + if array.shape[1] == 1: + out[:,0] = matrix.dot(array[:,0]) + else: + out[:] = matrix.dot(array) + elif axis == 1: + if array.shape[0] == 1: + out[0,:] = matrix.dot(array[0,:]) + else: + out[:] = matrix.dot(array.T).T + else: + raise ValueError("axis must be 0 or 1 for 2D arrays") + else: + # Treat as 3D array with specified axis in the middle + # Compute equivalent shape (N1, N2, N3) + if ndim == 3 and axis == 1: + N1 = array.shape[0] + N2 = array.shape[1] + N3 = array.shape[2] + else: + N1 = int(np.prod(array.shape[:axis])) + N2 = array.shape[axis] + N3 = int(np.prod(array.shape[axis+1:])) + # Dispatch to cupy routines + if N1 == 1: + if N3 == 1: + # (1, N2, 1) -> (N2,) + x1 = array.reshape((N2,)) + temp = matrix.dot(x1) + out[:] = temp.reshape(out.shape) + else: + # (1, N2, N3) -> (N2, N3) + x2 = array.reshape((N2, N3)) + temp = matrix.dot(x2) + out[:] = temp.reshape(out.shape) + else: + if N3 == 1: + # (N1, N2, 1) -> (N1, N2) + x2 = array.reshape((N1, N2)) + temp = matrix.dot(x2.T).T + out[:] = temp.reshape(out.shape) + else: + # (N1, N2, N3) + x3 = array.reshape((N1, N2, N3)) + y3 = out.reshape(((N1, matrix.shape[0], N3))) + cupy_apply_csr_mid(matrix, x3, y3) + + +@jit.rawkernel() +def apply_csr_mid_kernel(data, indices, indptr, x3, y3, N1, N2i, N2o, N3): + n1 = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x # batch index + n3 = jit.blockIdx.y * jit.blockDim.y + jit.threadIdx.y # output column index + if n1 >= N1 or n3 >= N3: + return + # Loop over output rows = CSR matrix rows + for i in range(N2o): + acc = 0 * y3[n1, i, n3] # get right type + start = indptr[i] + end = indptr[i + 1] + for k in range(start, end): + j = indices[k] + acc += data[k] * x3[n1, j, n3] + y3[n1, i, n3] = acc + +def cupy_apply_csr_mid(matrix, array, out): + N1, N2i, N3 = array.shape + N2o = matrix.shape[0] + N1 = cp.uint32(N1) + N2i = cp.uint32(N2i) + N3 = cp.uint32(N3) + N2o = cp.uint32(N2o) + # Choose thread/block config + threads_y = min(1024, N3) # maximize concurrency along n3 + threads_x = 1024 // threads_y # make block have 1024 threads + blockdim = (threads_x, threads_y) + blocks_x = (N1 + threads_x - 1) // threads_x + blocks_y = (N3 + threads_y - 1) // threads_y + griddim = (blocks_x, blocks_y) + # Launch kernel + apply_csr_mid_kernel(griddim, blockdim, (matrix.data, matrix.indices, matrix.indptr, array, out, N1, N2i, N2o, N3)) + + +def custom_spsm(a, b, alpha=1.0, lower=True, unit_diag=False, transa=False, spsm_descr=None): + """Custom spsm wrapper to save spsm_descr, since spsm_analysis takes lots of time.""" + """Solves a sparse triangular linear system op(a) * x = alpha * op(b). + + Args: + a (cupyx.scipy.sparse.csr_matrix or cupyx.scipy.sparse.coo_matrix): + Sparse matrix with dimension ``(M, M)``. + b (cupy.ndarray): Dense matrix with dimension ``(M, K)``. + alpha (float or complex): Coefficient. + lower (bool): + True: ``a`` is lower triangle matrix. + False: ``a`` is upper triangle matrix. + unit_diag (bool): + True: diagonal part of ``a`` has unit elements. + False: diagonal part of ``a`` has non-unit elements. + transa (bool or str): True, False, 'N', 'T' or 'H'. + 'N' or False: op(a) == ``a``. + 'T' or True: op(a) == ``a.T``. + 'H': op(a) == ``a.conj().T``. + """ + import cupyx + from cupyx import cusparse + import cupy as _cupy + import numpy as _numpy + from cupy._core import _dtype + from cupy_backends.cuda.libs import cusparse as _cusparse + from cupy.cuda import device as _device + from cupyx.cusparse import SpMatDescriptor, DnMatDescriptor + if not cusparse.check_availability('spsm'): + raise RuntimeError('spsm is not available.') + + # Canonicalise transa + if transa is False: + transa = 'N' + elif transa is True: + transa = 'T' + elif transa not in 'NTH': + raise ValueError(f'Unknown transa (actual: {transa})') + + # Check A's type and sparse format + if cupyx.scipy.sparse.isspmatrix_csr(a): + pass + elif cupyx.scipy.sparse.isspmatrix_csc(a): + if transa == 'N': + a = a.T + transa = 'T' + elif transa == 'T': + a = a.T + transa = 'N' + elif transa == 'H': + a = a.conj().T + transa = 'N' + lower = not lower + elif cupyx.scipy.sparse.isspmatrix_coo(a): + pass + else: + raise ValueError('a must be CSR, CSC or COO sparse matrix') + assert a.has_canonical_format + + # Check B's ndim + if b.ndim == 1: + is_b_vector = True + b = b.reshape(-1, 1) + elif b.ndim == 2: + is_b_vector = False + else: + raise ValueError('b.ndim must be 1 or 2') + + # Check shapes + if not (a.shape[0] == a.shape[1] == b.shape[0]): + raise ValueError('mismatched shape') + + # Check dtypes + dtype = a.dtype + if dtype.char not in 'fdFD': + raise TypeError('Invalid dtype (actual: {})'.format(dtype)) + if dtype != b.dtype: + raise TypeError('dtype mismatch') + + # Prepare fill mode + if lower is True: + fill_mode = _cusparse.CUSPARSE_FILL_MODE_LOWER + elif lower is False: + fill_mode = _cusparse.CUSPARSE_FILL_MODE_UPPER + else: + raise ValueError('Unknown lower (actual: {})'.format(lower)) + + # Prepare diag type + if unit_diag is False: + diag_type = _cusparse.CUSPARSE_DIAG_TYPE_NON_UNIT + elif unit_diag is True: + diag_type = _cusparse.CUSPARSE_DIAG_TYPE_UNIT + else: + raise ValueError('Unknown unit_diag (actual: {})'.format(unit_diag)) + + # Prepare op_a + if transa == 'N': + op_a = _cusparse.CUSPARSE_OPERATION_NON_TRANSPOSE + elif transa == 'T': + op_a = _cusparse.CUSPARSE_OPERATION_TRANSPOSE + else: # transa == 'H' + if dtype.char in 'fd': + op_a = _cusparse.CUSPARSE_OPERATION_TRANSPOSE + else: + op_a = _cusparse.CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE + + # Prepare op_b + if b._f_contiguous: + op_b = _cusparse.CUSPARSE_OPERATION_NON_TRANSPOSE + elif b._c_contiguous: + if _cusparse.get_build_version() < 11701: # earlier than CUDA 11.6 + raise ValueError('b must be F-contiguous.') + b = b.T + op_b = _cusparse.CUSPARSE_OPERATION_TRANSPOSE + else: + raise ValueError('b must be F-contiguous or C-contiguous.') + + # Allocate space for matrix C. Note that it is known cusparseSpSM requires + # the output matrix zero initialized. + m, _ = a.shape + if op_b == _cusparse.CUSPARSE_OPERATION_NON_TRANSPOSE: + _, n = b.shape + else: + n, _ = b.shape + c_shape = m, n + c = _cupy.zeros(c_shape, dtype=a.dtype, order='f') + + # Prepare descriptors and other parameters + handle = _device.get_cusparse_handle() + mat_a = SpMatDescriptor.create(a) + mat_b = DnMatDescriptor.create(b) + mat_c = DnMatDescriptor.create(c) + if spsm_descr is None: + spsm_descr = _cusparse.spSM_createDescr() + new_spsm_descr = True + else: + spsm_descr, buff = spsm_descr + new_spsm_descr = False + alpha = _numpy.array(alpha, dtype=c.dtype).ctypes + cuda_dtype = _dtype.to_cuda_dtype(c.dtype) + algo = _cusparse.CUSPARSE_SPSM_ALG_DEFAULT + + try: + # Specify Lower|Upper fill mode + mat_a.set_attribute(_cusparse.CUSPARSE_SPMAT_FILL_MODE, fill_mode) + + # Specify Unit|Non-Unit diagonal type + mat_a.set_attribute(_cusparse.CUSPARSE_SPMAT_DIAG_TYPE, diag_type) + + # Allocate the workspace needed by the succeeding phases + # Always calculate workspace (buff_size can change even for same spsm + # descriptor) + buff_size = _cusparse.spSM_bufferSize( + handle, op_a, op_b, alpha.data, mat_a.desc, mat_b.desc, + mat_c.desc, cuda_dtype, algo, spsm_descr) + + need_analysis = new_spsm_descr + if new_spsm_descr: + buff = _cupy.empty(buff_size, dtype=_cupy.int8) + else: + # Check if buff size grew from that in the cache + if buff is None or buff.size < buff_size: + buff = _cupy.empty(buff_size, dtype=_cupy.int8) + # buff changed so need the analysis phase + need_analysis = True + + # Perform the analysis phase + if need_analysis: + _cusparse.spSM_analysis( + handle, op_a, op_b, alpha.data, mat_a.desc, mat_b.desc, + mat_c.desc, cuda_dtype, algo, spsm_descr, buff.data.ptr) + + # Executes the solve phase + _cusparse.spSM_solve( + handle, op_a, op_b, alpha.data, mat_a.desc, mat_b.desc, + mat_c.desc, cuda_dtype, algo, spsm_descr, buff.data.ptr) + + # Reshape back if B was a vector + if is_b_vector: + c = c.reshape(-1) + + return c, (spsm_descr, buff) + + finally: + # Destroy matrix/vector descriptors + #_cusparse.spSM_destroyDescr(spsm_descr) + pass + + +def custom_SuperLU_solve(self, rhs, trans='N', spsm_descr=None): + """Custom SuperLU solve wrapper to save spsm_descr, since spsm_analysis takes lots of time.""" + """Solves linear system of equations with one or several right-hand sides. + + Args: + rhs (cupy.ndarray): Right-hand side(s) of equation with dimension + ``(M)`` or ``(M, K)``. + trans (str): 'N', 'T' or 'H'. + 'N': Solves ``A * x = rhs``. + 'T': Solves ``A.T * x = rhs``. + 'H': Solves ``A.conj().T * x = rhs``. + + Returns: + cupy.ndarray: + Solution vector(s) + """ # NOQA + from cupyx import cusparse + import cupy + from cupyx.scipy.sparse.linalg._solve import _should_use_spsm + + if not isinstance(rhs, cupy.ndarray): + raise TypeError('ojb must be cupy.ndarray') + if rhs.ndim not in (1, 2): + raise ValueError('rhs.ndim must be 1 or 2 (actual: {})'. + format(rhs.ndim)) + if rhs.shape[0] != self.shape[0]: + raise ValueError('shape mismatch (self.shape: {}, rhs.shape: {})' + .format(self.shape, rhs.shape)) + if trans not in ('N', 'T', 'H'): + raise ValueError('trans must be \'N\', \'T\', or \'H\'') + + if cusparse.check_availability('spsm') and _should_use_spsm(rhs): + def spsm(A, B, lower, transa, spsm_descr): + return custom_spsm(A, B, lower=lower, transa=transa, spsm_descr=spsm_descr) + sm = spsm + else: + raise NotImplementedError + + x = rhs.astype(self.L.dtype) + if trans == 'N': + if self.perm_r is not None: + if x.ndim == 2 and x._f_contiguous: + x = x.T[:, self._perm_r_rev].T # want to keep f-order + else: + x = x[self._perm_r_rev] + x, self.spsm_L_descr = sm(self.L, x, lower=True, transa=trans, spsm_descr=self.spsm_L_descr) + x, self.spsm_U_descr = sm(self.U, x, lower=False, transa=trans, spsm_descr=self.spsm_U_descr) + if self.perm_c is not None: + x = x[self.perm_c] + else: + if self.perm_c is not None: + if x.ndim == 2 and x._f_contiguous: + x = x.T[:, self._perm_c_rev].T # want to keep f-order + else: + x = x[self._perm_c_rev] + x, self.spsm_U_descr = sm(self.U, x, lower=False, transa=trans, spsm_descr=self.spsm_U_descr) + x, self.spsm_L_descr = sm(self.L, x, lower=True, transa=trans, spsm_descr=self.spsm_L_descr) + if self.perm_r is not None: + x = x[self.perm_r] + + if not x._f_contiguous: + # For compatibility with SciPy + x = x.copy(order='F') + return x + + +class CustomCupyUpperTriangularSolver: + """Hacky class to save spsm_descr for reuse in spsm for triangular solves.""" + + def __init__(self, matrix): + # Check matrix format + if not isinstance(matrix, csp.csr_matrix): + # TODO: avoid this explicit conversion + matrix = csp.csr_matrix(matrix) + print('WARNING: converting matrix to CSR format') + #raise ValueError("Matrix must be in CSR format.") + self.matrix = matrix + self.spsm_descr = None + + def solve(self, b, lower=True, overwrite_A=False, overwrite_b=False, + unit_diagonal=False): + """Solves a sparse triangular system ``A x = b``. + + Args: + A (cupyx.scipy.sparse.spmatrix): + Sparse matrix with dimension ``(M, M)``. + b (cupy.ndarray): + Dense vector or matrix with dimension ``(M)`` or ``(M, K)``. + lower (bool): + Whether ``A`` is a lower or upper triangular matrix. + If True, it is lower triangular, otherwise, upper triangular. + overwrite_A (bool): + (not supported) + overwrite_b (bool): + Allows overwriting data in ``b``. + unit_diagonal (bool): + If True, diagonal elements of ``A`` are assumed to be 1 and will + not be referenced. + + Returns: + cupy.ndarray: + Solution to the system ``A x = b``. The shape is the same as ``b``. + """ + from cupyx import cusparse + from cupyx.scipy import sparse + import cupy + from cupyx.scipy.sparse.linalg._solve import _should_use_spsm + + A = self.matrix + + if not (cusparse.check_availability('spsm') or + cusparse.check_availability('csrsm2')): + raise NotImplementedError + + if not sparse.isspmatrix(A): + raise TypeError('A must be cupyx.scipy.sparse.spmatrix') + if not isinstance(b, cupy.ndarray): + raise TypeError('b must be cupy.ndarray') + if A.shape[0] != A.shape[1]: + raise ValueError(f'A must be a square matrix (A.shape: {A.shape})') + if b.ndim not in [1, 2]: + raise ValueError(f'b must be 1D or 2D array (b.shape: {b.shape})') + if A.shape[0] != b.shape[0]: + raise ValueError('The size of dimensions of A must be equal to the ' + 'size of the first dimension of b ' + f'(A.shape: {A.shape}, b.shape: {b.shape})') + if A.dtype.char not in 'fdFD': + raise TypeError(f'unsupported dtype (actual: {A.dtype})') + + if cusparse.check_availability('spsm') and _should_use_spsm(b): + if not (sparse.isspmatrix_csr(A) or + sparse.isspmatrix_csc(A) or + sparse.isspmatrix_coo(A)): + warnings.warn('CSR, CSC or COO format is required. Converting to ' + 'CSR format.', sparse.SparseEfficiencyWarning) + A = A.tocsr() + A.sum_duplicates() + x, self.spsm_descr = custom_spsm(A, b, lower=lower, unit_diag=unit_diagonal, spsm_descr=self.spsm_descr) + elif cusparse.check_availability('csrsm2'): + if not (sparse.isspmatrix_csr(A) or sparse.isspmatrix_csc(A)): + warnings.warn('CSR or CSC format is required. Converting to CSR ' + 'format.', sparse.SparseEfficiencyWarning) + A = A.tocsr() + A.sum_duplicates() + + if (overwrite_b and A.dtype == b.dtype and + (b._c_contiguous or b._f_contiguous)): + x = b + else: + x = b.astype(A.dtype, copy=True) + + cusparse.csrsm2(A, x, lower=lower, unit_diag=unit_diagonal) + else: + assert False + + # TODO: Check if need this (breaks things for float32?) + # if x.dtype.char in 'fF': + # # Note: This is for compatibility with SciPy. + # dtype = numpy.promote_types(x.dtype, 'float64') + # x = x.astype(dtype) + return x + + +def cupy_solve_upper_csr(matrix, array, axis, out): + """Solve upper triangular CSR matrix along specified axis of an array.""" + # Switch by dimension + ndim = array.ndim + if ndim == 1: + if axis == 0: + cupy_solve_upper_csr_vec(matrix, array, out) + else: + raise ValueError("axis must be 0 for 1D arrays") + elif ndim == 2: + if axis == 0: + if array.shape[1] == 1: + cupy_solve_upper_csr_vec(matrix, array[:,0], out[:,0]) + else: + cupy_solve_upper_csr_first(matrix, array, out) + elif axis == 1: + if array.shape[0] == 1: + cupy_solve_upper_csr_vec(matrix, array[0,:], out[0,:]) + else: + cupy_solve_upper_csr_last(matrix, array, out) + else: + raise ValueError("axis must be 0 or 1 for 2D arrays") + else: + # Treat as 3D array with specified axis in the middle + # Compute equivalent shape (N1, N2, N3) + if ndim == 3 and axis == 1: + N1 = shape[0] + N2 = shape[1] + N3 = shape[2] + else: + N1 = int(np.prod(array.shape[:axis])) + N2 = array.shape[axis] + N3 = int(np.prod(array.shape[axis+1:])) + # Dispatch to cupy routines + if N1 == 1: + if N3 == 1: + # (1, N2, 1) -> (N2,) + x1 = array.reshape((N2,)) + y1 = out.reshape((N2,)) + cupy_solve_upper_csr_vec(matrix, x1, y1) + else: + # (1, N2, N3) -> (N2, N3) + x2 = array.reshape((N2, N3)) + y2 = out.reshape((N2, N3)) + cupy_solve_upper_csr_first(matrix, x2, y2) + else: + if N3 == 1: + # (N1, N2, 1) -> (N1, N2) + x2 = array.reshape((N1, N2)) + y2 = out.reshape((N1, N2)) + cupy_solve_upper_csr_last(matrix, x2, y2) + else: + # (N1, N2, N3) + x3 = array.reshape((N1, N2, N3)) + y3 = out.reshape((N1, N2, N3)) + cupy_solve_upper_csr_mid(matrix, x3, y3) + + +def cupy_solve_upper_csr_vec(matrix, vec, out): + """Solve upper triangular CSR matrix along a vector.""" + out[:] = matrix.solve(vec, lower=False) + + +def cupy_solve_upper_csr_first(matrix, array, out): + """Solve upper triangular CSR matrix along first axis of 2D array.""" + out[:] = matrix.solve(array, lower=False) + + +def cupy_solve_upper_csr_last(matrix, array, out): + """Solve upper triangular CSR matrix along last axis of 2D array.""" + out.T[:] = matrix.solve(array.T, lower=False) + + +def cupy_solve_upper_csr_mid(matrix, array, out): + """Solve upper triangular CSR matrix along middle axis of 3D array.""" + raise NotImplementedError diff --git a/docs/pages/gpu.rst b/docs/pages/gpu.rst new file mode 100644 index 00000000..92cac910 --- /dev/null +++ b/docs/pages/gpu.rst @@ -0,0 +1,39 @@ +GPU Support +*********** + +.. note:: + GPU support is currently experimental and may not be fully functional for all use cases. + Cartesian problems should generally work, but curvilinear problems are not yet supported. + +Dedalus supports GPU acceleration on NVIDIA and AMD GPUs through the use of the CuPy array library. +Specifically, Dedalus utilizes the "array-api-compat" library to provide a unified interface for Numpy and CuPy array types and operations. + +Installation +------------ + +In addition to the regular Dedalus dependencies, you will need CuPy installed to enable GPU support. CuPy itself can be installed using pip, but requires a compatible CUDA or ROCm environment to be installed on your system. + +Selecting GPUs +-------------- + +To utilize GPU acceleration, your Dedalus script needs to import CuPy and pass it as the "array_namespace" keyword to the Distributor object. +Fields built with that Distributor will then use CuPy arrays for their data and computations. +No other explicit changes are required to most scripts, but some care may need to be taken for setting field initial conditions, etc., and converting from other Numpy arrays in your scripts to CuPy arrays. + +.. code-block:: python + import dedalus.public as d3 + import cupy as cp + ... + dist = d3.Distributor(coords, dtype=np.float64, array_namespace=cp) + + +Guidelines +---------- + +GPU support is preliminary and many standard Dedalus features have not yet been fully implemented or optimized for GPUs. +Please keep in mind the following guidelines for best performance under the current capabilities: +- Curvilinear problems are not yet supported. +- Distributed GPUs (combining MPI will multiple GPUs) is not yet supported. +- Single and double precision, real and complex, are supported. +- By default, GPU problems are treated collectively as a single subproblem to speed up linear algebra on the GPU (looping explicitly over subproblems is slow). The trade off is that this can make matrix factorizations quite slow, so we strongly recommend using constant timesteps when possible to avoid refactorizations. We hope to address this issue shortly. + diff --git a/docs/pages/user_guide.rst b/docs/pages/user_guide.rst index d8ac3fed..942ad97b 100644 --- a/docs/pages/user_guide.rst +++ b/docs/pages/user_guide.rst @@ -8,6 +8,7 @@ General user guide: problem_formulations performance_tips + gpu configuration troubleshooting changes_from_v2 diff --git a/setup.py b/setup.py index 7291ba0e..36641924 100644 --- a/setup.py +++ b/setup.py @@ -181,6 +181,7 @@ def read(rel_path): # Runtime requirements install_requires = [ + "array-api-compat", "docopt", "h5py >= 3.0.0", "matplotlib >= 3.7.0",