Skip to content

Commit ee88aa9

Browse files
committed
updates
1 parent e440536 commit ee88aa9

2 files changed

Lines changed: 9 additions & 5 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceHyperHessiansExt/DifferentiationInterfaceHyperHessiansExt.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ function DI.pick_batchsize(backend::DI.AutoHyperHessians, N::Integer)
4040
return DI.BatchSizeSettings{B}(N)
4141
end
4242

43+
function DI.threshold_batchsize(backend::DI.AutoHyperHessians, chunksize2::Integer)
44+
chunksize1 = backend.chunksize
45+
chunksize = isnothing(chunksize1) ? nothing : min(chunksize1, chunksize2)
46+
return DI.AutoHyperHessians(; chunksize)
47+
end
48+
4349
function _translate_toprep(::Type{T}, c::Union{DI.GeneralizedConstant, DI.ConstantOrCache}) where {T}
4450
return nothing
4551
end

DifferentiationInterface/test/Back/HyperHessians/test.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,17 @@ for backend in backends
1919
@test DI.check_inplace(backend)
2020
end
2121

22-
excluded_ops = [:pushforward, :pullback, :jacobian, :derivative, :gradient]
23-
24-
scenarios = default_scenarios(; include_constantified = true)
22+
scenarios = default_scenarios(; include_constantified = true, include_cachified = true)
2523

2624
test_differentiation(
2725
backends, scenarios;
28-
excluded = excluded_ops, logging = LOGGING,
26+
excluded = FIRST_ORDER, logging = LOGGING,
2927
)
3028

3129
test_differentiation(
3230
DI.AutoHyperHessians(), scenarios;
3331
correctness = false,
3432
type_stability = safetypestab(:prepared),
35-
excluded = excluded_ops,
33+
excluded = FIRST_ORDER,
3634
logging = LOGGING,
3735
)

0 commit comments

Comments
 (0)