Skip to content

Commit b8a65b4

Browse files
wsmosesgiordanostevengj
authored
Add Enzyme reverse rules (#110)
* Add Enzyme reverse rules * fix * fixup * Add test project file * gate per extension package * Update test/runtests.jl Co-authored-by: Mosè Giordano <[email protected]> * Update test/runtests.jl Co-authored-by: Mosè Giordano <[email protected]> * Update test/Project.toml Co-authored-by: Mosè Giordano <[email protected]> * Update Project.toml Co-authored-by: Mosè Giordano <[email protected]> * Add actual file * Update QuadGKEnzymeExt.jl * Update ext/QuadGKEnzymeExt.jl Co-authored-by: Steven G. Johnson <[email protected]> * fixup * fixup * Bump minimum to 1.9 * Update QuadGKEnzymeExt.jl * Update runtests.jl --------- Co-authored-by: Mosè Giordano <[email protected]> Co-authored-by: Steven G. Johnson <[email protected]>
1 parent 9b1acdb commit b8a65b4

File tree

6 files changed

+186
-7
lines changed

6 files changed

+186
-7
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
fail-fast: false
2323
matrix:
2424
version:
25-
- '1.2'
25+
- '1.9'
2626
- '1'
2727
# - 'nightly'
2828
os:

Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ version = "2.10.1"
66
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88

9+
[weakdeps]
10+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
11+
12+
[extensions]
13+
QuadGKEnzymeExt = "Enzyme"
14+
915
[compat]
1016
DataStructures = "0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19"
1117
julia = "1.2"

ext/QuadGKEnzymeExt.jl

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
2+
module QuadGKEnzymeExt
3+
4+
using QuadGK, Enzyme, LinearAlgebra
5+
6+
function Enzyme.EnzymeRules.augmented_primal(config, ofunc::Const{typeof(quadgk)}, ::Type{RT}, f, segs::Annotation{T}...; kws...) where {RT, T}
7+
prims = map(x->x.val, segs)
8+
9+
retres, segbuf = if f isa Const
10+
if EnzymeRules.needs_primal(config)
11+
quadgk(f.val, prims...; kws...), nothing
12+
else
13+
nothing
14+
end
15+
else
16+
I, E, segbuf = quadgk_segbuf(f.val, prims...; kws...)
17+
if EnzymeRules.needs_primal(config)
18+
(I, E), segbuf
19+
else
20+
nothing, segbuf
21+
end
22+
end
23+
24+
dres = if !Enzyme.EnzymeRules.needs_shadow(config)
25+
nothing
26+
elseif EnzymeRules.width(config) == 1
27+
zero.(res...)
28+
else
29+
ntuple(Val(EnzymeRules.width(config))) do i
30+
Base.@_inline_meta
31+
zero.(res...)
32+
end
33+
end
34+
35+
cache = if RT <: Duplicated || RT <: DuplicatedNoNeed || RT <: BatchDuplicated || RT <: BatchDuplicatedNoNeed
36+
dres
37+
else
38+
nothing
39+
end
40+
cache2 = segbuf, cache
41+
42+
return Enzyme.EnzymeRules.AugmentedReturn{
43+
Enzyme.EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing,
44+
Enzyme.EnzymeRules.needs_shadow(config) ? (Enzyme.EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{Enzyme.EnzymeRules.width(config), eltype(RT)}) : Nothing,
45+
typeof(cache2)
46+
}(retres, dres, cache2)
47+
end
48+
49+
function call(f, x)
50+
f(x)
51+
end
52+
53+
# Wrapper around a function f that allows it to act as a vector space, and hence be usable as
54+
# an integrand, where the vector operations act on the closed-over parameters of f that are
55+
# begin differentiated with respect to. In particular, if we have a closure f = x -> g(x, p), and we want
56+
# to differentiate with respect to p, then our reverse (vJp) rule needs an integrand given by the
57+
# Jacobian-vector product (pullback) vᵀ∂g/∂p. But Enzyme wraps this in a closure so that it is the
58+
# same "shape" as f, whereas to integrate it we need to be able to treat it as a vector space.
59+
# ClosureVector calls Enzyme.Compiler.recursive_add, which is an internal function that "unwraps"
60+
# the closure to access the internal state, which can then be added/subtracted/scaled.
61+
struct ClosureVector{F}
62+
f::F
63+
end
64+
65+
@inline function guaranteed_nonactive(::Type{T}) where T
66+
rt = Enzyme.Compiler.active_reg_inner(T, (), nothing)
67+
return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState
68+
end
69+
70+
function Base.:+(a::CV, b::CV) where {CV <: ClosureVector}
71+
Enzyme.Compiler.recursive_add(a, b, identity, guaranteed_nonactive)::CV
72+
end
73+
74+
function Base.:-(a::CV, b::CV) where {CV <: ClosureVector}
75+
Enzyme.Compiler.recursive_add(a, b, x->-x, guaranteed_nonactive)::CV
76+
end
77+
78+
function Base.:*(a::Number, b::CV) where {CV <: ClosureVector}
79+
# b + (a-1) * b = a * b
80+
Enzyme.Compiler.recursive_add(b, b, x->(a-1)*x, guaranteed_nonactive)::CV
81+
end
82+
83+
function Base.:*(a::ClosureVector, b::Number)
84+
return b*a
85+
end
86+
87+
function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f::Union{Const, Active}, segs::Annotation{T}...; kws...) where {T}
88+
df = if f isa Const
89+
nothing
90+
else
91+
segbuf = cache[1]
92+
fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T})
93+
_df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x
94+
tape, prim, shad = fwd(Const(call), f, Const(x))
95+
drev = rev(Const(call), f, Const(x), dres.val[1], tape)
96+
return ClosureVector(drev[1][1])
97+
end
98+
_df.f
99+
end
100+
dsegs1 = segs[1] isa Const ? nothing : -LinearAlgebra.dot(f.val(segs[1].val), dres.val[1])
101+
dsegsn = segs[end] isa Const ? nothing : LinearAlgebra.dot(f.val(segs[end].val), dres.val[1])
102+
return (df, # f
103+
dsegs1,
104+
ntuple(i -> nothing, Val(length(segs)-2))...,
105+
dsegsn)
106+
end
107+
108+
function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Type{<:Union{Duplicated, BatchDuplicated}}, cache, f::Union{Const, Active}, segs::Annotation{T}...; kws...) where {T}
109+
dres = cache[2]
110+
df = if f isa Const
111+
nothing
112+
else
113+
segbuf = cache[1]
114+
fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T})
115+
_df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x
116+
tape, prim, shad = fwd(Const(call), f, Const(x))
117+
shad .= dres
118+
drev = rev(Const(call), f, Const(x), tape)
119+
return ClosureVector(drev[1][1])
120+
end
121+
_df.f
122+
end
123+
dsegs1 = segs[1] isa Const ? nothing : -LinearAlgebra.dot(f.val(segs[1].val), dres)
124+
dsegsn = segs[end] isa Const ? nothing : LinearAlgebra.dot(f.val(segs[end].val), dres)
125+
Enzyme.make_zero!(dres)
126+
return (df, # f
127+
dsegs1,
128+
ntuple(i -> nothing, Val(length(segs)-2))...,
129+
dsegsn)
130+
end
131+
132+
end # module

src/api.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,15 @@ function quadgk!(f!, result, a::T,b::T,c::T...; atol=nothing, rtol=nothing, maxe
132132
return quadgk(f, a, b, c...; atol=atol, rtol=rtol, maxevals=maxevals, order=order, norm=norm, segbuf=segbuf, eval_segbuf=eval_segbuf)
133133
end
134134

135+
struct Counter{F}
136+
f::F
137+
count::Base.RefValue{Int}
138+
end
139+
function (c::Counter{F})(args...) where F
140+
c.count[] += 1
141+
c.f(args...)
142+
end
143+
135144
"""
136145
quadgk_count(f, args...; kws...)
137146
@@ -146,12 +155,9 @@ it may be possible to mathematically transform the problem in some way
146155
to improve the convergence rate.
147156
"""
148157
function quadgk_count(f, args...; kws...)
149-
count = 0
150-
i = quadgk(args...; kws...) do x
151-
count += 1
152-
f(x)
153-
end
154-
return (i..., count)
158+
counter = Counter(f, Ref(0))
159+
i = quadgk(counter, args...; kws...)
160+
return (i..., counter.count[])
155161
end
156162

157163
"""

test/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[deps]
2+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
4+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

test/runtests.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,3 +443,34 @@ quadgk_segbuf_printnull(args...; kws...) = quadgk_segbuf_print(devnull, args...;
443443
@inferred QuadGK.to_segbuf([0,1])
444444
@inferred QuadGK.to_segbuf([(0,1+3im)])
445445
end
446+
447+
# Extension package only supported in 1.9+
448+
@static if VERSION >= v"1.9"
449+
using Enzyme
450+
f1(x) = quadgk(cos, 0., x)[1]
451+
f2(x) = quadgk(cos, x, 1)[1]
452+
f3(x) = quadgk(y->cos(x * y), 0., 1.)[1]
453+
454+
f1_count(x) = quadgk_count(cos, 0., x)[1]
455+
f2_count(x) = quadgk_count(cos, x, 1)[1]
456+
f3_count(x) = quadgk_count(y->cos(x * y), 0., 1.)[1]
457+
458+
f_vec(x) = sum(quadgk(y->[cos(x[1] * y), cos(x[2] * y)], 0., 1.)[1])
459+
460+
@testset "Enzyme" begin
461+
@test cos(0.3) Enzyme.autodiff(Reverse, f1, Active(0.3))[1][1]
462+
@test -cos(0.3) Enzyme.autodiff(Reverse, f2, Active(0.3))[1][1]
463+
@test (0.3 * cos(0.3) - sin(0.3))/(0.3*0.3) Enzyme.autodiff(Reverse, f3, Active(0.3))[1][1]
464+
465+
@test cos(0.3) Enzyme.autodiff(Reverse, f1_count, Active(0.3))[1][1]
466+
@test -cos(0.3) Enzyme.autodiff(Reverse, f2_count, Active(0.3))[1][1]
467+
@test (0.3 * cos(0.3) - sin(0.3))/(0.3*0.3) Enzyme.autodiff(Reverse, f3_count, Active(0.3))[1][1]
468+
469+
x = [0.3, 0.7]
470+
dx = [0.0, 0.0]
471+
f_vec(x)
472+
# TODO custom rule with mixed vector returns not yet supported x/ref https://github.com/EnzymeAD/Enzyme.jl/issues/1692
473+
@test_throws Enzyme.Compiler.EnzymeRuntimeException autodiff(Reverse, f_vec, Duplicated(x, dx))
474+
# @test dx ≈ [(0.3 * cos(0.3) - sin(0.3))/(0.3*0.3), (0.7 * cos(0.7) - sin(0.7))/(0.7*0.7)]
475+
end
476+
end

0 commit comments

Comments
 (0)