Skip to content

Commit b7adfb6

Browse files
authored
fix: speed up Mooncake in forward mode by preallocating tangents (#915)
* fix: speed up Mooncake in forward mode by preallocating tangents * Coverage
1 parent 6b6ca3f commit b7adfb6

6 files changed

Lines changed: 110 additions & 13 deletions

File tree

DifferentiationInterface/docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ makedocs(;
3636
"api.md",
3737
"Development" => [
3838
"dev/internals.md",
39+
"dev/math.md",
3940
"dev/contributing.md",
4041
],
4142
],
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Mathematical model
2+
3+
This page recaps the mathematical model of automatic differentiation used by DI, which justifies how preparation results are constructed.
4+
It is inspired by
5+
6+
- the [documentation](https://chalk-lab.github.io/Mooncake.jl/stable/understanding_mooncake/rule_system/) of [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl)
7+
- [this Discourse answer](https://discourse.julialang.org/t/do-i-understand-enzyme-properly/97760) about [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl)
8+
9+
## Setting and hypotheses
10+
11+
Consider a mathematical function $f(x, c, s) = y$ where
12+
13+
- $x \in \mathcal{X}$ is the active argument (the one being differentiated)
14+
- $c \in \mathcal{C}$ is a constant argument (corresponds to [`Constant`](@ref) contexts)
15+
- $s \in \mathcal{S}$ is a scratch argument (corresponds to [`Cache`](@ref) contexts)
16+
- $y \in \mathcal{Y}$ is the output
17+
18+
In Julia code, some of the input arguments might be mutated, while the output may be written to as well.
19+
Therefore, the proper model is a function $\phi(x_0, c_0, s_0, y_0) = (x_1, c_1, s_1, y_1)$ where $a_0$ is the state of argument $a$ before $f$ is run, while $a_1$ is its state after $a$ is run.
20+
21+
DI makes the following hypotheses on the implementation of $f$ (aka the behavior of $\phi$):
22+
23+
1. The active argument $x$ is not mutated, so $x_1 = x_0$
24+
2. The constant argument $c$ is not mutated, so $c_1 = c_0$
25+
3. The initial value of the scratch argument $s_0$ does not matter
26+
4. The initial value of the output $y_0$ does not matter
27+
28+
## Forward mode
29+
30+
We want to compute a Jacobian-Vector Product (JVP) $\dot{y} = \left(\frac{\partial f}{\partial x}\right) \dot{x}$ where $\dot{x} \in \mathcal{X}$ is an input tangent.
31+
32+
To do that, we run our AD backend on $\phi$ with input tangents $(\dot{x}_0, \dot{c}_0, \dot{s}_0, \dot{y}_0)$ and obtain $(\dot{x}_1, \dot{c}_1, \dot{s}_1, \dot{y}_1)$.
33+
The interesting value is
34+
$$\dot{y}_1 = \frac{\partial y_1}{\partial x_0} \dot{x}_0 + \frac{\partial y_1}{\partial c_0} \dot{c}_0 + \frac{\partial y_1}{\partial s_0} \dot{s}_0 + \frac{\partial y_1}{\partial y_0} \dot{y}_0$$
35+
36+
Thanks to our hypotheses 3 and 4 on the function's implementation, $\frac{\partial y_1}{\partial s_0} = 0$ and $\frac{\partial y_1}{\partial y_0} = 0$, so we are left with:
37+
$$\dot{y}_1 = \frac{\partial y_1}{\partial x_0} \dot{x_0} + \frac{\partial y_1}{\partial c_0} \dot{c_0}$$
38+
39+
Thus, as long as $\dot{c}_0 = 0$, the output tangent $\dot{y}_1$ contains the correct JVP.
40+
Let us now look at $\dot{s}_1$ with the help of hypothesis 2:
41+
$$\dot{c}_1 = \frac{\partial c_1}{\partial x_0} \dot{x}_0 + \frac{\partial c_1}{\partial c_0} \dot{c}_0 + \frac{\partial c_1}{\partial s_0} \dot{s}_0 + \frac{\partial c_1}{\partial y_0} \dot{y}_0 = \dot{c}_0$$
42+
43+
The tangent of $c$ will always be preserved by differentiation.
44+
45+
## Reverse mode
46+
47+
We want to compute a Vector-Jacobian Product (VJP) $\bar{x} = \left(\frac{\partial f}{\partial x}\right)^* \bar{y}$ where $\bar{y} \in \mathcal{Y}$ is an output sensivity.
48+
49+
To do that, we run our AD backend on $\phi$ with output sensitivities $(\bar{x}_1, \bar{c}_1, \bar{s}_1, \bar{y}_1)$ and obtain $(\bar{x}_0, \bar{c}_0, \bar{s}_0, \bar{y}_0)$.
50+
The interesting value is
51+
$$\bar{x}_0 = \left(\frac{\partial x_1}{\partial x_0}\right)^* \bar{x}_1 + \left(\frac{\partial c_1}{\partial x_0}\right)^* \bar{c}_1 + \left(\frac{\partial s_1}{\partial x_0}\right)^* \bar{s}_1 + \left(\frac{\partial y_1}{\partial x_0}\right)^* \bar{y}_1$$
52+
53+
Thanks to our hypotheses 1 and 2 on the function's implementation, $\frac{\partial x_1}{\partial x_0} = I$ and $\frac{\partial c_1}{\partial x_0} = 0$, so we are left with:
54+
$$\bar{x}_0 = \bar{x}_1 + \left(\frac{\partial s_1}{\partial x_0}\right)^* \bar{s}_1 + \left(\frac{\partial y_1}{\partial x_0}\right)^* \bar{y}_1$$
55+
56+
Thus, as long as $\bar{x}_1 = 0$ and $\bar{s}_1 = 0$, the input sensitivity $\bar{x}_0$ contains the correct VJP.
57+
Let us now look at $\bar{s}_0$ with the help of hypothesis 3:
58+
59+
$$\bar{s}_0 = \left(\frac{\partial x_1}{\partial s_0}\right)^* \bar{x}_1 + \left(\frac{\partial c_1}{\partial s_0}\right)^* \bar{c}_1 + \left(\frac{\partial s_1}{\partial s_0}\right)^* \bar{s}_1 + \left(\frac{\partial y_1}{\partial s_0}\right)^* \bar{y}_1 = 0$$
60+
61+
The sensitivity of $s$ will always be set to $0$ by differentiation.
62+
63+
## Implementation
64+
65+
DI's preparation mechanism allows pre-allocating the memory for tangents and sensitivities, inside a `prep` object.
66+
This object is then reused across several AD calls.
67+
68+
For mutable objects, each AD call performs the following transformations on the provided shadow/dual storage (`Duplicated` for Enzyme, `Dual` / `CoDual` for Mooncake):
69+
70+
- In forward mode, $\dot{a}$ is updated from $\dot{a}_0$ to $\dot{a}_1$
71+
- In reverse mode, $\bar{a}$ is updated from $\bar{a}_1$ to $\bar{a}_0$
72+
73+
### At initialization
74+
75+
How to initialize shadow/dual memory inside `prep`?
76+
77+
- In forward mode, make sure that $\dot{c} = 0$.
78+
- In reverse mode, make sure that $\bar{x} = 0$ and $\bar{s} = 0$.
79+
80+
### At every call
81+
82+
Should the shadow/dual memory inside `prep` be reset before every AD call?
83+
84+
- In forward mode, no need ($\dot{c}$ will remain $0$ if it is initialized to $0$)
85+
- In reverse mode, just set $\bar{x} = 0$ ($\bar{s}$ will be reset to $0$ at every AD call)

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}}
3535

3636
DI.check_available(::AnyAutoMooncake{C}) where {C} = true
3737

38-
get_config(::AnyAutoMooncake{Nothing}) = Config()
39-
get_config(backend::AnyAutoMooncake{<:Config}) = backend.config
40-
38+
include("utils.jl")
4139
include("onearg.jl")
4240
include("twoarg.jl")
4341
include("forward_onearg.jl")

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
## Pushforward
22

3-
struct MooncakeOneArgPushforwardPrep{SIG, Tcache, DX} <: DI.PushforwardPrep{SIG}
3+
struct MooncakeOneArgPushforwardPrep{SIG, Tcache, DX, FT, CT} <: DI.PushforwardPrep{SIG}
44
_sig::Val{SIG}
55
cache::Tcache
66
dx_righttype::DX
7+
df::FT
8+
context_tangents::CT
79
end
810

911
function DI.prepare_pushforward_nokwarg(
@@ -20,7 +22,9 @@ function DI.prepare_pushforward_nokwarg(
2022
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
2123
)
2224
dx_righttype = zero_tangent(x)
23-
prep = MooncakeOneArgPushforwardPrep(_sig, cache, dx_righttype)
25+
df = zero_tangent(f)
26+
context_tangents = map(zero_tangent_unwrap, contexts)
27+
prep = MooncakeOneArgPushforwardPrep(_sig, cache, dx_righttype, df, context_tangents)
2428
return prep
2529
end
2630

@@ -38,9 +42,9 @@ function DI.value_and_pushforward(
3842
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
3943
y_dual = value_and_derivative!!(
4044
prep.cache,
41-
zero_dual(f),
45+
Dual(f, prep.df),
4246
Dual(x, dx_righttype),
43-
map(zero_dual DI.unwrap, contexts)...,
47+
map(Dual_unwrap, contexts, prep.context_tangents)...,
4448
)
4549
y = primal(y_dual)
4650
dy = _copy_output(tangent(y_dual))

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
## Pushforward
22

3-
struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, DX, DY} <: DI.PushforwardPrep{SIG}
3+
struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, DX, DY, FT, CT} <: DI.PushforwardPrep{SIG}
44
_sig::Val{SIG}
55
cache::Tcache
66
dx_righttype::DX
77
dy_righttype::DY
8+
df!::FT
9+
context_tangents::CT
810
end
911

1012
function DI.prepare_pushforward_nokwarg(
@@ -28,7 +30,9 @@ function DI.prepare_pushforward_nokwarg(
2830
)
2931
dx_righttype = zero_tangent(x)
3032
dy_righttype = zero_tangent(y)
31-
prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dx_righttype, dy_righttype)
33+
df! = zero_tangent(f!)
34+
context_tangents = map(zero_tangent_unwrap, contexts)
35+
prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dx_righttype, dy_righttype, df!, context_tangents)
3236
return prep
3337
end
3438

@@ -48,10 +52,10 @@ function DI.value_and_pushforward(
4852
y_dual = zero_dual(y)
4953
value_and_derivative!!(
5054
prep.cache,
51-
zero_dual(f!),
55+
Dual(f!, prep.df!),
5256
y_dual,
5357
Dual(x, dx_righttype),
54-
map(zero_dual DI.unwrap, contexts)...,
58+
map(Dual_unwrap, contexts, prep.context_tangents)...,
5559
)
5660
dy = _copy_output(tangent(y_dual))
5761
return dy
@@ -90,10 +94,10 @@ function DI.value_and_pushforward!(
9094
dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
9195
value_and_derivative!!(
9296
prep.cache,
93-
zero_dual(f!),
97+
Dual(f!, prep.df!),
9498
Dual(y, dy_righttype),
9599
Dual(x, dx_righttype),
96-
map(zero_dual DI.unwrap, contexts)...,
100+
map(Dual_unwrap, contexts, prep.context_tangents)...,
97101
)
98102
dy === dy_righttype || copyto!(dy, dy_righttype)
99103
end
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
get_config(::AnyAutoMooncake{Nothing}) = Config()
2+
get_config(backend::AnyAutoMooncake{<:Config}) = backend.config
3+
4+
@inline zero_tangent_unwrap(c::DI.Context) = zero_tangent(DI.unwrap(c))
5+
@inline Dual_unwrap(c, dc) = Dual(DI.unwrap(c), dc)

0 commit comments

Comments
 (0)