|
| 1 | + |
| 2 | +module QuadGKEnzymeExt |
| 3 | + |
| 4 | +using QuadGK, Enzyme, LinearAlgebra |
| 5 | + |
| 6 | +function Enzyme.EnzymeRules.augmented_primal(config, ofunc::Const{typeof(quadgk)}, ::Type{RT}, f, segs::Annotation{T}...; kws...) where {RT, T} |
| 7 | + prims = map(x->x.val, segs) |
| 8 | + |
| 9 | + retres, segbuf = if f isa Const |
| 10 | + if EnzymeRules.needs_primal(config) |
| 11 | + quadgk(f.val, prims...; kws...), nothing |
| 12 | + else |
| 13 | + nothing |
| 14 | + end |
| 15 | + else |
| 16 | + I, E, segbuf = quadgk_segbuf(f.val, prims...; kws...) |
| 17 | + if EnzymeRules.needs_primal(config) |
| 18 | + (I, E), segbuf |
| 19 | + else |
| 20 | + nothing, segbuf |
| 21 | + end |
| 22 | + end |
| 23 | + |
| 24 | + dres = if !Enzyme.EnzymeRules.needs_shadow(config) |
| 25 | + nothing |
| 26 | + elseif EnzymeRules.width(config) == 1 |
| 27 | + zero.(res...) |
| 28 | + else |
| 29 | + ntuple(Val(EnzymeRules.width(config))) do i |
| 30 | + Base.@_inline_meta |
| 31 | + zero.(res...) |
| 32 | + end |
| 33 | + end |
| 34 | + |
| 35 | + cache = if RT <: Duplicated || RT <: DuplicatedNoNeed || RT <: BatchDuplicated || RT <: BatchDuplicatedNoNeed |
| 36 | + dres |
| 37 | + else |
| 38 | + nothing |
| 39 | + end |
| 40 | + cache2 = segbuf, cache |
| 41 | + |
| 42 | + return Enzyme.EnzymeRules.AugmentedReturn{ |
| 43 | + Enzyme.EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing, |
| 44 | + Enzyme.EnzymeRules.needs_shadow(config) ? (Enzyme.EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{Enzyme.EnzymeRules.width(config), eltype(RT)}) : Nothing, |
| 45 | + typeof(cache2) |
| 46 | + }(retres, dres, cache2) |
| 47 | +end |
| 48 | + |
| 49 | +function call(f, x) |
| 50 | + f(x) |
| 51 | +end |
| 52 | + |
| 53 | +# Wrapper around a function f that allows it to act as a vector space, and hence be usable as |
| 54 | +# an integrand, where the vector operations act on the closed-over parameters of f that are |
| 55 | +# begin differentiated with respect to. In particular, if we have a closure f = x -> g(x, p), and we want |
| 56 | +# to differentiate with respect to p, then our reverse (vJp) rule needs an integrand given by the |
| 57 | +# Jacobian-vector product (pullback) vᵀ∂g/∂p. But Enzyme wraps this in a closure so that it is the |
| 58 | +# same "shape" as f, whereas to integrate it we need to be able to treat it as a vector space. |
| 59 | +# ClosureVector calls Enzyme.Compiler.recursive_add, which is an internal function that "unwraps" |
| 60 | +# the closure to access the internal state, which can then be added/subtracted/scaled. |
| 61 | +struct ClosureVector{F} |
| 62 | + f::F |
| 63 | +end |
| 64 | + |
| 65 | +@inline function guaranteed_nonactive(::Type{T}) where T |
| 66 | + rt = Enzyme.Compiler.active_reg_inner(T, (), nothing) |
| 67 | + return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState |
| 68 | +end |
| 69 | + |
| 70 | +function Base.:+(a::CV, b::CV) where {CV <: ClosureVector} |
| 71 | + Enzyme.Compiler.recursive_add(a, b, identity, guaranteed_nonactive)::CV |
| 72 | +end |
| 73 | + |
| 74 | +function Base.:-(a::CV, b::CV) where {CV <: ClosureVector} |
| 75 | + Enzyme.Compiler.recursive_add(a, b, x->-x, guaranteed_nonactive)::CV |
| 76 | +end |
| 77 | + |
| 78 | +function Base.:*(a::Number, b::CV) where {CV <: ClosureVector} |
| 79 | + # b + (a-1) * b = a * b |
| 80 | + Enzyme.Compiler.recursive_add(b, b, x->(a-1)*x, guaranteed_nonactive)::CV |
| 81 | +end |
| 82 | + |
| 83 | +function Base.:*(a::ClosureVector, b::Number) |
| 84 | + return b*a |
| 85 | +end |
| 86 | + |
| 87 | +function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f::Union{Const, Active}, segs::Annotation{T}...; kws...) where {T} |
| 88 | + df = if f isa Const |
| 89 | + nothing |
| 90 | + else |
| 91 | + segbuf = cache[1] |
| 92 | + fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T}) |
| 93 | + _df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x |
| 94 | + tape, prim, shad = fwd(Const(call), f, Const(x)) |
| 95 | + drev = rev(Const(call), f, Const(x), dres.val[1], tape) |
| 96 | + return ClosureVector(drev[1][1]) |
| 97 | + end |
| 98 | + _df.f |
| 99 | + end |
| 100 | + dsegs1 = segs[1] isa Const ? nothing : -LinearAlgebra.dot(f.val(segs[1].val), dres.val[1]) |
| 101 | + dsegsn = segs[end] isa Const ? nothing : LinearAlgebra.dot(f.val(segs[end].val), dres.val[1]) |
| 102 | + return (df, # f |
| 103 | + dsegs1, |
| 104 | + ntuple(i -> nothing, Val(length(segs)-2))..., |
| 105 | + dsegsn) |
| 106 | +end |
| 107 | + |
| 108 | +function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Type{<:Union{Duplicated, BatchDuplicated}}, cache, f::Union{Const, Active}, segs::Annotation{T}...; kws...) where {T} |
| 109 | + dres = cache[2] |
| 110 | + df = if f isa Const |
| 111 | + nothing |
| 112 | + else |
| 113 | + segbuf = cache[1] |
| 114 | + fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T}) |
| 115 | + _df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x |
| 116 | + tape, prim, shad = fwd(Const(call), f, Const(x)) |
| 117 | + shad .= dres |
| 118 | + drev = rev(Const(call), f, Const(x), tape) |
| 119 | + return ClosureVector(drev[1][1]) |
| 120 | + end |
| 121 | + _df.f |
| 122 | + end |
| 123 | + dsegs1 = segs[1] isa Const ? nothing : -LinearAlgebra.dot(f.val(segs[1].val), dres) |
| 124 | + dsegsn = segs[end] isa Const ? nothing : LinearAlgebra.dot(f.val(segs[end].val), dres) |
| 125 | + Enzyme.make_zero!(dres) |
| 126 | + return (df, # f |
| 127 | + dsegs1, |
| 128 | + ntuple(i -> nothing, Val(length(segs)-2))..., |
| 129 | + dsegsn) |
| 130 | +end |
| 131 | + |
| 132 | +end # module |
0 commit comments