From 5c9a0276a41fc0801145a56e71b471a1c46cde42 Mon Sep 17 00:00:00 2001 From: mloubout Date: Fri, 1 May 2026 13:33:32 -0400 Subject: [PATCH] api: add option to inject without increment --- devito/operations/interpolators.py | 17 ++++++++------- .../self_adjoint/sa_03_iso_correctness.ipynb | 6 +++--- examples/seismic/utils.py | 2 +- tests/test_interpolation.py | 21 +++++++++++++++++++ 4 files changed, 35 insertions(+), 11 deletions(-) diff --git a/devito/operations/interpolators.py b/devito/operations/interpolators.py index 160b33fdee..4c597a2492 100644 --- a/devito/operations/interpolators.py +++ b/devito/operations/interpolators.py @@ -166,18 +166,20 @@ class Injection(UnevaluatedSparseOperation): __rargs__ = ('field', 'expr', 'implicit_dims') + UnevaluatedSparseOperation.__rargs__ - def __new__(cls, field, expr, implicit_dims, interpolator): + def __new__(cls, field, expr, increment, implicit_dims, interpolator): obj = super().__new__(cls, interpolator) # TODO: unused now, but will be necessary to compute the adjoint obj.field = field obj.expr = expr + obj.increment = increment obj.implicit_dims = implicit_dims return obj def operation(self, **kwargs): return self.interpolator._inject(expr=self.expr, field=self.field, + increment=self.increment, implicit_dims=self.implicit_dims) def __repr__(self): @@ -372,7 +374,7 @@ def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None) @check_radius @check_coords - def inject(self, field, expr, implicit_dims=None): + def inject(self, field, expr, increment=True, implicit_dims=None): """ Generate equations injecting an arbitrary expression into a field. @@ -387,7 +389,7 @@ def inject(self, field, expr, implicit_dims=None): injection expression, but that should be honored when constructing the operator. """ - return Injection(field, expr, implicit_dims, self) + return Injection(field, expr, increment, implicit_dims, self) def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None): """ @@ -439,7 +441,7 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None return temps + summands + last - def _inject(self, field, expr, implicit_dims=None): + def _inject(self, field, expr, increment=True, implicit_dims=None): """ Generate equations injecting an arbitrary expression into a field. @@ -489,9 +491,10 @@ def _inject(self, field, expr, implicit_dims=None): pos_only=variables, subdomain=subdomain) # Substitute coordinate base symbols into the interpolation coefficients - eqns = [Inc(_field.xreplace(idx_subs), - (self._weights(subdomain=subdomain) * _expr).xreplace(idx_subs), - implicit_dims=implicit_dims) + ecls = Inc if increment else Eq + eqns = [ecls(_field.xreplace(idx_subs), + (self._weights(subdomain=subdomain) * _expr).xreplace(idx_subs), + implicit_dims=implicit_dims) for (_field, _expr) in zip(fields, _exprs, strict=True)] return temps + eqns diff --git a/examples/seismic/self_adjoint/sa_03_iso_correctness.ipynb b/examples/seismic/self_adjoint/sa_03_iso_correctness.ipynb index c4933e85ee..feba7f7824 100644 --- a/examples/seismic/self_adjoint/sa_03_iso_correctness.ipynb +++ b/examples/seismic/self_adjoint/sa_03_iso_correctness.ipynb @@ -578,9 +578,9 @@ "output_type": "stream", "text": [ "Operator `IsoFwdOperator` ran in 0.04 s\n", - "No source type defined, returning uninitiallized (zero) source\n", + "No source type defined, returning uninitialized (zero) source\n", "Operator `IsoAdjOperator` ran in 0.03 s\n", - "No source type defined, returning uninitiallized (zero) source\n", + "No source type defined, returning uninitialized (zero) source\n", "Operator `IsoAdjOperator` ran in 0.03 s\n" ] }, @@ -639,7 +639,7 @@ "output_type": "stream", "text": [ "Operator `IsoFwdOperator` ran in 0.03 s\n", - "No source type defined, returning uninitiallized (zero) source\n", + "No source type defined, returning uninitialized (zero) source\n", "Operator `IsoAdjOperator` ran in 0.03 s\n" ] }, diff --git a/examples/seismic/utils.py b/examples/seismic/utils.py index 468e44ddd6..b4eb98a5dc 100644 --- a/examples/seismic/utils.py +++ b/examples/seismic/utils.py @@ -194,7 +194,7 @@ def src(self): def new_src(self, name='src', src_type='self', coordinates=None): coords = coordinates or self.src_positions if self.src_type is None or src_type is None: - warning("No source type defined, returning uninitiallized (zero) source") + warning("No source type defined, returning uninitialized (zero) source") src = PointSource(name=name, grid=self.grid, time_range=self.time_axis, npoint=self.nsrc, coordinates=coords, diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index 7c605a599b..1ba6bf095f 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -440,6 +440,27 @@ def test_inject(shape, coords, result, npoints=19): assert np.allclose(a.data[indices], result, rtol=1.e-5) +@pytest.mark.parametrize('shape, coords', [ + ((11, 11), [(.1, .9), (.4, .4)]), + ((11, 11, 11), [(.1, .9), (.4, .4), (.4, .4)]) +]) +def test_inject_no_incr(shape, coords, npoints=9): + a = unit_box(shape=shape) + a.data[:] = 2. + p = points(a.grid, coords, npoints=npoints) + + p.data[:] = 3. + expr = p.inject(a, p, increment=False) + op = Operator(expr, subs=a.grid.spacing_map) + + op(a=a) + + indices = [slice(4, 5, 1) for _ in coords] + indices[0] = slice(1, -1, 1) + # Should be 3 at the points + assert np.allclose(a.data[indices], 3, rtol=1.e-5) + + @pytest.mark.parametrize('shape, coords, nexpr, result', [ ((11, 11), [(.05, .95), (.45, .45)], 1, 1.), ((11, 11), [(.05, .95), (.45, .45)], 2, 1.),