Skip to content

Commit 07e4a54

Browse files
authored
Fix inner preparation behavior for Mooncake (#948)
1 parent 45fad0e commit 07e4a54

2 files changed

Lines changed: 17 additions & 0 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ using Mooncake:
3434
const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}}
3535

3636
DI.check_available(::AnyAutoMooncake{C}) where {C} = true
37+
DI.inner_preparation_behavior(::AutoMooncakeForward) = DI.PrepareInnerSimple()
3738

3839
include("utils.jl")
3940
include("onearg.jl")

DifferentiationInterface/test/Back/Mooncake/test.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,22 @@ test_differentiation(
2828
logging = LOGGING,
2929
);
3030

31+
EXCLUDED = @static if VERSION v"1.11-" && VERSION v"1.12-"
32+
# testing only :hessian on 1.11 due to an opaque closure bug.
33+
# this is potentially the same issue as discussed in
34+
# https://github.com/chalk-lab/MistyClosures.jl/pull/12#issue-3278662295
35+
[FIRST_ORDER..., :hvp, :second_derivative]
36+
else
37+
[FIRST_ORDER...]
38+
end
39+
40+
# Test second-order differentiation (forward-over-reverse)
41+
test_differentiation(
42+
[SecondOrder(AutoMooncakeForward(; config = nothing), AutoMooncake(; config = nothing))],
43+
excluded = EXCLUDED,
44+
logging = true,
45+
)
46+
3147
@testset "NamedTuples" begin
3248
ps = (; A = rand(5), B = rand(5))
3349
myfun(ps) = sum(ps.A .* ps.B)

0 commit comments

Comments
 (0)