@@ -372,6 +372,35 @@ function _jacobian_aux(
372372 end
373373end
374374
375+ function _jacobian_aux (
376+ f_or_f!y:: FY ,
377+ prep:: PushforwardJacobianPrep{SIG, <:BatchSizeSettings{1, false, true}} ,
378+ backend:: AbstractADType ,
379+ x,
380+ contexts:: Vararg{Context, C} ,
381+ ) where {FY, SIG, C}
382+ (; batched_seeds, seed_example, pushforward_prep) = prep
383+
384+ pushforward_prep_same = prepare_pushforward_same_point (
385+ f_or_f!y... , pushforward_prep, backend, x, seed_example, contexts...
386+ )
387+
388+ jac = stack (eachindex (batched_seeds); dims = 2 ) do a
389+ dy = only (
390+ pushforward (
391+ f_or_f!y... ,
392+ pushforward_prep_same,
393+ backend,
394+ x,
395+ batched_seeds[a],
396+ contexts... ,
397+ )
398+ )
399+ return vec (dy)
400+ end
401+ return jac
402+ end
403+
375404function _jacobian_aux (
376405 f_or_f!y:: FY ,
377406 prep:: PushforwardJacobianPrep{SIG, <:BatchSizeSettings{B, false, aligned}} ,
@@ -428,6 +457,34 @@ function _jacobian_aux(
428457 end
429458end
430459
460+ function _jacobian_aux (
461+ f_or_f!y:: FY ,
462+ prep:: PullbackJacobianPrep{SIG, <:BatchSizeSettings{1, false, true}} ,
463+ backend:: AbstractADType ,
464+ x,
465+ contexts:: Vararg{Context, C} ,
466+ ) where {FY, SIG, C}
467+ (; batched_seeds, seed_example, pullback_prep) = prep
468+
469+ pullback_prep_same = prepare_pullback_same_point (
470+ f_or_f!y... , pullback_prep, backend, x, seed_example, contexts...
471+ )
472+
473+ jac = stack (eachindex (batched_seeds); dims = 1 ) do a
474+ dx = only (
475+ pullback (
476+ f_or_f!y... , pullback_prep_same, backend, x, batched_seeds[a], contexts...
477+ )
478+ )
479+ if eltype (x) <: Complex
480+ return map (conj, vec (dx))
481+ else
482+ return vec (dx)
483+ end
484+ end
485+ return jac
486+ end
487+
431488function _jacobian_aux (
432489 f_or_f!y:: FY ,
433490 prep:: PullbackJacobianPrep{SIG, <:BatchSizeSettings{B, false, aligned}} ,
0 commit comments