Skip to content

Commit 8d6dc63

Browse files
committed
fixups and use ADTypes HyperHessians
1 parent d516e7d commit 8d6dc63

6 files changed

Lines changed: 19 additions & 38 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,12 @@ Diffractor = "=0.2.6"
6767
Enzyme = "0.13.39"
6868
EnzymeCore = "0.8.8"
6969
FastDifferentiation = "0.4.3"
70-
HyperHessians = "0.2"
7170
FiniteDiff = "2.27.0"
7271
FiniteDifferences = "0.12.31"
7372
ForwardDiff = "0.10.36,1"
7473
GPUArraysCore = "0.2"
7574
GTPSA = "1.4.0"
75+
HyperHessians = "0.2"
7676
LinearAlgebra = "1"
7777
Mooncake = "0.5.1 - 0.5.24"
7878
PolyesterForwardDiff = "0.1.2"

DifferentiationInterface/docs/src/explanation/backends.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ We support the following dense backend choices from [ADTypes.jl](https://github.
1212
- [`AutoFiniteDifferences`](@extref ADTypes.AutoFiniteDifferences)
1313
- [`AutoForwardDiff`](@extref ADTypes.AutoForwardDiff)
1414
- [`AutoGTPSA`](@extref ADTypes.AutoGTPSA)
15-
- [`AutoHyperHessians`](https://github.com/KristofferC/HyperHessians.jl)
15+
- [`AutoHyperHessians`](@extref ADTypes.AutoHyperHessians)
1616
- [`AutoMooncake`](@extref ADTypes.AutoMooncake) and [`AutoMooncakeForward`](@extref ADTypes.AutoMooncake) (the latter is experimental)
1717
- [`AutoPolyesterForwardDiff`](@extref ADTypes.AutoPolyesterForwardDiff)
1818
- [`AutoReverseDiff`](@extref ADTypes.AutoReverseDiff)
@@ -71,7 +71,7 @@ Moreover, each context type is supported by a specific subset of backends:
7171
| `AutoFiniteDifferences` |||
7272
| `AutoForwardDiff` |||
7373
| `AutoGTPSA` |||
74-
| `AutoHyperHessians` || |
74+
| `AutoHyperHessians` || |
7575
| `AutoMooncake` |||
7676
| `AutoMooncakeForward` |||
7777
| `AutoPolyesterForwardDiff` |||

DifferentiationInterface/ext/DifferentiationInterfaceHyperHessiansExt/DifferentiationInterfaceHyperHessiansExt.jl

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,29 @@ module DifferentiationInterfaceHyperHessiansExt
22

33
import DifferentiationInterface as DI
44
import .DI: AutoHyperHessians
5-
using ADTypes: ForwardMode
65
using HyperHessians:
76
HVPConfig,
87
HessianConfig,
98
Chunk,
109
chunksize,
11-
pickchunksize,
1210
HyperDual,
1311
hessian,
1412
hessian!,
1513
hessian_gradient_value,
1614
hessian_gradient_value!,
17-
hessian,
1815
hvp,
1916
hvp!,
2017
hvp_gradient_value,
2118
hvp_gradient_value!
2219

2320
## Traits
2421
DI.check_available(::DI.AutoHyperHessians) = true
25-
DI.inplace_support(::DI.AutoHyperHessians) = DI.InPlaceSupported()
26-
DI.mode(::DI.AutoHyperHessians) = ForwardMode()
22+
DI.inplace_support(::DI.AutoHyperHessians) = DI.InPlaceNotSupported()
2723

28-
chunk_from_backend(backend::DI.AutoHyperHessians, x) =
29-
isnothing(backend.chunksize) ? Chunk(x) : Chunk{backend.chunksize}()
30-
chunk_from_backend(backend::DI.AutoHyperHessians, N::Integer, ::Type{T}) where {T} =
31-
isnothing(backend.chunksize) ? Chunk(pickchunksize(N, T), T) : Chunk{backend.chunksize}()
24+
chunk_from_backend(::DI.AutoHyperHessians{nothing}, x) = Chunk(x)
25+
chunk_from_backend(::DI.AutoHyperHessians{CS}, x) where {CS} = Chunk{CS}()
26+
chunk_from_backend(::DI.AutoHyperHessians{nothing}, N::Integer, ::Type{T}) where {T} = Chunk(N, T)
27+
chunk_from_backend(::DI.AutoHyperHessians{CS}, ::Integer, ::Type) where {CS} = Chunk{CS}()
3228

3329
function DI.pick_batchsize(backend::DI.AutoHyperHessians, x::AbstractArray)
3430
B = chunksize(chunk_from_backend(backend, x))
@@ -40,10 +36,11 @@ function DI.pick_batchsize(backend::DI.AutoHyperHessians, N::Integer)
4036
return DI.BatchSizeSettings{B}(N)
4137
end
4238

43-
function DI.threshold_batchsize(backend::DI.AutoHyperHessians, chunksize2::Integer)
44-
chunksize1 = backend.chunksize
45-
chunksize = isnothing(chunksize1) ? nothing : min(chunksize1, chunksize2)
46-
return DI.AutoHyperHessians(; chunksize)
39+
function DI.threshold_batchsize(::DI.AutoHyperHessians{nothing}, ::Integer)
40+
return DI.AutoHyperHessians()
41+
end
42+
function DI.threshold_batchsize(::DI.AutoHyperHessians{CS}, chunksize2::Integer) where {CS}
43+
return DI.AutoHyperHessians(; chunksize = min(CS, chunksize2))
4744
end
4845

4946
function _translate_toprep(::Type{T}, c::Union{DI.GeneralizedConstant, DI.ConstantOrCache}) where {T}
@@ -111,8 +108,8 @@ function DI.second_derivative!(
111108
contexts::Vararg{DI.Context, C},
112109
) where {C}
113110
DI.check_prep(f, prep, backend, x, contexts...)
114-
copyto!(der2, DI.second_derivative(f, prep, backend, x, contexts...))
115-
return der2
111+
new_der2 = DI.second_derivative(f, prep, backend, x, contexts...)
112+
return copyto!(der2, new_der2)
116113
end
117114

118115
function DI.value_derivative_and_second_derivative(
@@ -253,7 +250,7 @@ end
253250
function DI.hvp(
254251
f,
255252
prep::HyperHessiansHVPPrep,
256-
backend::AutoHyperHessians,
253+
backend::DI.AutoHyperHessians,
257254
x,
258255
tx::NTuple,
259256
contexts::Vararg{DI.Context, C},

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ using ADTypes:
3333
AutoReverseDiff,
3434
AutoSymbolics,
3535
AutoTracker,
36-
AutoZygote
36+
AutoZygote,
37+
AutoHyperHessians
3738
using LinearAlgebra: dot
3839

3940
include("compat.jl")
@@ -64,7 +65,6 @@ include("second_order/hessian.jl")
6465

6566
include("misc/differentiate_with.jl")
6667
include("misc/from_primitive.jl")
67-
include("misc/autohyperhessians.jl")
6868
include("misc/sparsity_detector.jl")
6969
include("misc/simple_finite_diff.jl")
7070
include("misc/zero_backends.jl")

DifferentiationInterface/src/misc/autohyperhessians.jl

Lines changed: 0 additions & 16 deletions
This file was deleted.

DifferentiationInterface/test/Back/HyperHessians/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ backends = [
1616

1717
for backend in backends
1818
@test DI.check_available(backend)
19-
@test DI.check_inplace(backend)
19+
@test !DI.check_inplace(backend)
2020
end
2121

2222
scenarios = default_scenarios(; include_constantified = true, include_cachified = true)

0 commit comments

Comments
 (0)