Skip to content

Commit 5996df5

Browse files
ErikQQYgdalle
andauthored
fix: overloaded_input_type for one-element vector input (#954)
* fix: overloaded_input_type for one-element vector input * Fix bugs in overloaded input type --------- Co-authored-by: Guillaume Dalle <[email protected]>
1 parent d6b05e4 commit 5996df5

4 files changed

Lines changed: 17 additions & 9 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,15 @@ end
2222
function DI.overloaded_input_type(prep::ForwardDiffOneArgDerivativePrep)
2323
return DI.overloaded_input_type(prep.pushforward_prep)
2424
end
25-
DI.overloaded_input_type(prep::ForwardDiffTwoArgDerivativePrep) = typeof(prep.config.duals)
25+
function DI.overloaded_input_type(
26+
prep::ForwardDiffTwoArgDerivativePrep{SIG, X, <:DerivativeConfig{T}}
27+
) where {SIG, X, T}
28+
return typeof(Dual{T}(one(X), one(X)))
29+
end
2630

2731
## Gradient
2832
DI.overloaded_input_type(prep::ForwardDiffGradientPrep) = typeof(prep.config.duals)
2933

3034
## Jacobian
31-
DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep) = typeof(prep.config.duals[2])
35+
DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep) = typeof(prep.config.duals)
3236
DI.overloaded_input_type(prep::ForwardDiffTwoArgJacobianPrep) = typeof(prep.config.duals[2])

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,20 +182,21 @@ end
182182

183183
### Prepared
184184

185-
struct ForwardDiffTwoArgDerivativePrep{SIG, C, CD} <: DI.DerivativePrep{SIG}
185+
struct ForwardDiffTwoArgDerivativePrep{SIG, X, C, CD} <: DI.DerivativePrep{SIG}
186186
_sig::Val{SIG}
187+
x::X
187188
config::C
188189
contexts_dual::CD
189190
end
190191

191192
function DI.prepare_derivative_nokwarg(
192-
strict::Val, f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context, C}
193-
) where {F, C}
193+
strict::Val, f!::F, y, backend::AutoForwardDiff, x::X, contexts::Vararg{DI.Context, C}
194+
) where {F, C, X}
194195
_sig = DI.signature(f!, y, backend, x, contexts...; strict)
195196
tag = get_tag(f!, backend, x)
196197
config = DerivativeConfig(nothing, y, x, tag)
197198
contexts_dual = translate_toprep(dual_type(config), contexts)
198-
return ForwardDiffTwoArgDerivativePrep(_sig, config, contexts_dual)
199+
return ForwardDiffTwoArgDerivativePrep(_sig, copy(x), config, contexts_dual)
199200
end
200201

201202
function DI.prepare!_derivative(

DifferentiationInterface/src/misc/overloading.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
overloaded_input_type(prep)
33
4-
If it exists, return the overloaded input type which will be passed to the differentiated function when preparation result `prep` is reused.
4+
If it exists, return the overloaded input type (for the differentiated argument `x`) which will be passed to the differentiated function when preparation result `prep` is reused.
55
66
!!! danger
77

DifferentiationInterface/test/Back/ForwardDiff/test.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ end
104104
@test DI.overloaded_input_type(prepare_derivative(copy, backend, x)) ==
105105
ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 1}
106106
@test DI.overloaded_input_type(prepare_derivative(copyto!, y, backend, x)) ==
107-
Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 1}}
107+
ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 1}
108108

109109
# Gradient
110110
x = [1.0, 1.0]
@@ -114,12 +114,15 @@ end
114114
# Jacobian
115115
x = [1.0, 0.0, 0.0]
116116
@test DI.overloaded_input_type(prepare_jacobian(copy, backend, x)) ==
117-
ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 3}
117+
Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 3}}
118118
@test DI.overloaded_input_type(prepare_jacobian(copyto!, similar(x), backend, x)) ==
119119
Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 3}}
120120
@test DI.overloaded_input_type(
121121
prepare_jacobian(copyto!, similar(x), sparse_backend, x)
122122
) == Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 1}}
123+
# Jacobian with one-element input
124+
@test DI.overloaded_input_type(prepare_jacobian(copy, backend, [1.0])) ==
125+
Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 1}}
123126
end;
124127

125128
include("benchmark.jl")

0 commit comments

Comments
 (0)