@@ -2,33 +2,29 @@ module DifferentiationInterfaceHyperHessiansExt
22
33import DifferentiationInterface as DI
44import . DI: AutoHyperHessians
5- using ADTypes: ForwardMode
65using HyperHessians:
76 HVPConfig,
87 HessianConfig,
98 Chunk,
109 chunksize,
11- pickchunksize,
1210 HyperDual,
1311 hessian,
1412 hessian!,
1513 hessian_gradient_value,
1614 hessian_gradient_value!,
17- hessian,
1815 hvp,
1916 hvp!,
2017 hvp_gradient_value,
2118 hvp_gradient_value!
2219
2320# # Traits
2421DI. check_available (:: DI.AutoHyperHessians ) = true
25- DI. inplace_support (:: DI.AutoHyperHessians ) = DI. InPlaceSupported ()
26- DI. mode (:: DI.AutoHyperHessians ) = ForwardMode ()
22+ DI. inplace_support (:: DI.AutoHyperHessians ) = DI. InPlaceNotSupported ()
2723
28- chunk_from_backend (backend :: DI.AutoHyperHessians , x) =
29- isnothing (backend . chunksize) ? Chunk ( x) : Chunk {backend.chunksize } ()
30- chunk_from_backend (backend :: DI.AutoHyperHessians , N:: Integer , :: Type{T} ) where {T} =
31- isnothing (backend . chunksize) ? Chunk ( pickchunksize (N, T), T) : Chunk {backend.chunksize } ()
24+ chunk_from_backend (:: DI.AutoHyperHessians{nothing} , x) = Chunk (x)
25+ chunk_from_backend ( :: DI.AutoHyperHessians{CS} , x) where {CS} = Chunk {CS } ()
26+ chunk_from_backend (:: DI.AutoHyperHessians{nothing} , N:: Integer , :: Type{T} ) where {T} = Chunk (N, T)
27+ chunk_from_backend ( :: DI.AutoHyperHessians{CS} , :: Integer , :: Type ) where {CS} = Chunk {CS } ()
3228
3329function DI. pick_batchsize (backend:: DI.AutoHyperHessians , x:: AbstractArray )
3430 B = chunksize (chunk_from_backend (backend, x))
@@ -40,10 +36,11 @@ function DI.pick_batchsize(backend::DI.AutoHyperHessians, N::Integer)
4036 return DI. BatchSizeSettings {B} (N)
4137end
4238
43- function DI. threshold_batchsize (backend:: DI.AutoHyperHessians , chunksize2:: Integer )
44- chunksize1 = backend. chunksize
45- chunksize = isnothing (chunksize1) ? nothing : min (chunksize1, chunksize2)
46- return DI. AutoHyperHessians (; chunksize)
39+ function DI. threshold_batchsize (:: DI.AutoHyperHessians{nothing} , :: Integer )
40+ return DI. AutoHyperHessians ()
41+ end
42+ function DI. threshold_batchsize (:: DI.AutoHyperHessians{CS} , chunksize2:: Integer ) where {CS}
43+ return DI. AutoHyperHessians (; chunksize = min (CS, chunksize2))
4744end
4845
4946function _translate_toprep (:: Type{T} , c:: Union{DI.GeneralizedConstant, DI.ConstantOrCache} ) where {T}
@@ -111,8 +108,8 @@ function DI.second_derivative!(
111108 contexts:: Vararg{DI.Context, C} ,
112109 ) where {C}
113110 DI. check_prep (f, prep, backend, x, contexts... )
114- copyto! (der2, DI. second_derivative (f, prep, backend, x, contexts... ) )
115- return der2
111+ new_der2 = DI. second_derivative (f, prep, backend, x, contexts... )
112+ return copyto! ( der2, new_der2)
116113end
117114
118115function DI. value_derivative_and_second_derivative (
253250function DI. hvp (
254251 f,
255252 prep:: HyperHessiansHVPPrep ,
256- backend:: AutoHyperHessians ,
253+ backend:: DI. AutoHyperHessians ,
257254 x,
258255 tx:: NTuple ,
259256 contexts:: Vararg{DI.Context, C} ,
0 commit comments