Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit fadbffa

Browse files
committed
Add ForwardDiff Inplace Overloads
1 parent da36df6 commit fadbffa

File tree

7 files changed

+160
-143
lines changed

7 files changed

+160
-143
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "1.1.0"
4+
version = "1.2.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/ad.jl

Lines changed: 53 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
1-
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
2-
f = prob.f
1+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, <:AbstractArray},
2+
iip, <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
3+
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) where {T, V, P, iip}
4+
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
5+
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
6+
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats,
7+
sol.original)
8+
end
9+
10+
function __nlsolve_ad(prob::NonlinearProblem{uType, iip}, alg, args...;
11+
kwargs...) where {uType, iip}
312
p = value(prob.p)
4-
if prob isa IntervalNonlinearProblem
5-
tspan = value.(prob.tspan)
6-
newprob = IntervalNonlinearProblem(f, tspan, p; prob.kwargs...)
7-
else
8-
u0 = value(prob.u0)
9-
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
10-
end
13+
newprob = NonlinearProblem(prob.f, value(prob.u0), p; prob.kwargs...)
1114

1215
sol = solve(newprob, alg, args...; kwargs...)
1316

1417
uu = sol.u
15-
f_p = scalar_nlsolve_∂f_∂p(f, uu, p)
16-
f_x = scalar_nlsolve_∂f_∂u(f, uu, p)
18+
f_p = __nlsolve_∂f_∂p(prob, prob.f, uu, p)
19+
f_x = __nlsolve_∂f_∂u(prob, prob.f, uu, p)
1720

18-
z_arr = -inv(f_x) * f_p
21+
z_arr = -f_x \ f_p
1922

2023
pp = prob.p
2124
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
@@ -30,58 +33,66 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
3033
return sol, partials
3134
end
3235

33-
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
34-
false, <:Dual{T, V, P}}, alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
35-
kwargs...) where {T, V, P}
36-
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
37-
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
38-
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
39-
end
40-
41-
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
42-
false, <:AbstractArray{<:Dual{T, V, P}}},
43-
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) where {T, V, P}
44-
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
45-
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
46-
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
47-
end
48-
49-
function scalar_nlsolve_∂f_∂p(f, u, p)
50-
ff = p isa Number ? ForwardDiff.derivative :
51-
(u isa Number ? ForwardDiff.gradient : ForwardDiff.jacobian)
52-
return ff(Base.Fix1(f, u), p)
36+
@inline function __nlsolve_∂f_∂p(prob, f::F, u, p) where {F}
37+
if isinplace(prob)
38+
__f = p -> begin
39+
du = similar(u, promote_type(eltype(u), eltype(p)))
40+
f(du, u, p)
41+
return du
42+
end
43+
else
44+
__f = Base.Fix1(f, u)
45+
end
46+
if p isa Number
47+
return __reshape(ForwardDiff.derivative(__f, p), :, 1)
48+
elseif u isa Number
49+
return __reshape(ForwardDiff.gradient(__f, p), 1, :)
50+
else
51+
return ForwardDiff.jacobian(__f, p)
52+
end
5353
end
5454

55-
function scalar_nlsolve_∂f_∂u(f, u, p)
56-
ff = u isa Number ? ForwardDiff.derivative : ForwardDiff.jacobian
57-
return ff(Base.Fix2(f, p), u)
55+
@inline function __nlsolve_∂f_∂u(prob, f::F, u, p) where {F}
56+
if isinplace(prob)
57+
du = similar(u)
58+
__f = (du, u) -> f(du, u, p)
59+
ForwardDiff.jacobian(__f, du, u)
60+
else
61+
__f = Base.Fix2(f, p)
62+
if u isa Number
63+
return ForwardDiff.derivative(__f, u)
64+
else
65+
return ForwardDiff.jacobian(__f, u)
66+
end
67+
end
5868
end
5969

60-
function scalar_nlsolve_dual_soln(u::Number, partials,
70+
@inline function __nlsolve_dual_soln(u::Number, partials,
6171
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
6272
return Dual{T, V, P}(u, partials)
6373
end
6474

65-
function scalar_nlsolve_dual_soln(u::AbstractArray, partials,
75+
@inline function __nlsolve_dual_soln(u::AbstractArray, partials,
6676
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
67-
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, partials))
77+
_partials = _restructure(u, partials)
78+
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, _partials))
6879
end
6980

7081
# avoid ambiguities
7182
for Alg in [Bisection]
7283
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
7384
<:Dual{T, V, P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
74-
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
75-
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
85+
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
86+
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
7687
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode,
7788
left = Dual{T, V, P}(sol.left, partials),
7889
right = Dual{T, V, P}(sol.right, partials))
7990
end
8091
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
8192
<:AbstractArray{<:Dual{T, V, P}}}, alg::$Alg, args...;
8293
kwargs...) where {uType, iip, T, V, P}
83-
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
84-
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
94+
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
95+
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
8596
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode,
8697
left = Dual{T, V, P}(sol.left, partials),
8798
right = Dual{T, V, P}(sol.right, partials))

src/nlsolve/halley.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,14 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
5555
setindex_trait(x) === CannotSetindex() && (A = dfx)
5656

5757
# Factorize Once and Reuse
58-
dfx_fact = factorize(dfx)
58+
dfx_fact = if dfx isa Number
59+
dfx
60+
else
61+
fact = lu(dfx; check = false)
62+
!issuccess(fact) && return build_solution(prob, alg, x, fx;
63+
retcode = ReturnCode.Unstable)
64+
fact
65+
end
5966

6067
aᵢ = dfx_fact \ _vec(fx)
6168
A_ = _vec(A)

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,3 +381,6 @@ end
381381
return AutoFiniteDiff()
382382
end
383383
end
384+
385+
@inline __reshape(x::Number, args...) = x
386+
@inline __reshape(x::AbstractArray, args...) = reshape(x, args...)

test/basictests.jl

Lines changed: 0 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -64,36 +64,6 @@ const TERMINATION_CONDITIONS = [
6464
autodiff = AutoForwardDiff())) == 0
6565
end
6666

67-
@testset "[OOP] Immutable AD" begin
68-
for p in [1.0, 100.0]
69-
@test begin
70-
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
71-
res_true = sqrt(p)
72-
all(res.u .≈ res_true)
73-
end
74-
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
75-
@SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p))
76-
end
77-
end
78-
79-
@testset "[OOP] Scalar AD" begin
80-
for p in 1.0:0.1:100.0
81-
@test begin
82-
res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
83-
res_true = sqrt(p)
84-
res.u ≈ res_true
85-
end
86-
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u,
87-
p) ≈ 1 / (2 * sqrt(p))
88-
end
89-
end
90-
91-
t = (p) -> [sqrt(p[2] / p[1])]
92-
p = [0.9, 50.0]
93-
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u ≈ sqrt(p[2] / p[1])
94-
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u],
95-
p) ≈ ForwardDiff.jacobian(t, p)
96-
9767
@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
9868
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
9969

@@ -124,36 +94,6 @@ end
12494
autodiff = AutoForwardDiff())) == 0
12595
end
12696

127-
@testset "[OOP] Immutable AD" begin
128-
for p in [1.0, 100.0]
129-
@test begin
130-
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
131-
res_true = sqrt(p)
132-
all(res.u .≈ res_true)
133-
end
134-
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
135-
@SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p))
136-
end
137-
end
138-
139-
@testset "[OOP] Scalar AD" begin
140-
for p in 1.0:0.1:100.0
141-
@test begin
142-
res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
143-
res_true = sqrt(p)
144-
res.u ≈ res_true
145-
end
146-
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u,
147-
p) ≈ 1 / (2 * sqrt(p))
148-
end
149-
end
150-
151-
t = (p) -> [sqrt(p[2] / p[1])]
152-
p = [0.9, 50.0]
153-
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u ≈ sqrt(p[2] / p[1])
154-
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u],
155-
p) ≈ ForwardDiff.jacobian(t, p)
156-
15797
@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
15898
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
15999

@@ -195,44 +135,6 @@ end
195135
@test (@ballocated $(benchmark_nlsolve_oop)($quadratic_f, 1.0, 2.0)) == allocs
196136
end
197137

198-
@testset "[OOP] Immutable AD" begin
199-
for p in [1.0, 100.0]
200-
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
201-
202-
if any(x -> isnan(x) || x <= 1e-5 || x >= 1e5, res)
203-
@test_broken all(abs.(res) .≈ sqrt(p))
204-
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
205-
@SVector[1.0, 1.0], p).u[end], p)) ≈ 1 / (2 * sqrt(p))
206-
else
207-
@test all(abs.(res) .≈ sqrt(p))
208-
@test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
209-
@SVector[1.0, 1.0], p).u[end], p)), 1 / (2 * sqrt(p)))
210-
end
211-
end
212-
end
213-
214-
@testset "[OOP] Scalar AD" begin
215-
for p in 1.0:0.1:100.0
216-
res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
217-
218-
if any(x -> isnan(x), res)
219-
@test_broken abs(res.u) ≈ sqrt(p)
220-
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
221-
1.0, p).u, p)) ≈ 1 / (2 * sqrt(p))
222-
else
223-
@test abs(res.u) ≈ sqrt(p)
224-
@test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
225-
1.0, p).u, p)), 1 / (2 * sqrt(p)))
226-
end
227-
end
228-
end
229-
230-
t = (p) -> [sqrt(p[2] / p[1])]
231-
p = [0.9, 50.0]
232-
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u ≈ sqrt(p[2] / p[1])
233-
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u],
234-
p) ≈ ForwardDiff.jacobian(t, p)
235-
236138
@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
237139
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
238140

test/forward_ad.jl

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
using ForwardDiff, SimpleNonlinearSolve, StaticArrays, Test, LinearAlgebra
2+
3+
test_f!(du, u, p) = (@. du = u^2 - p)
4+
test_f(u, p) = (@. u^2 - p)
5+
6+
jacobian_f(::Number, p) = 1 / (2 * p)
7+
jacobian_f(::Number, p::Number) = 1 / (2 * p)
8+
jacobian_f(u, p::Number) = one.(u) .* (1 / (2 * p))
9+
jacobian_f(u, p::AbstractArray) = diagm(vec(@. 1 / (2 * p)))
10+
11+
function solve_with(::Val{mode}, u, alg) where {mode}
12+
f = if mode === :iip
13+
solve_iip(p) = solve(NonlinearProblem(test_f!, u, p), alg).u
14+
elseif mode === :oop
15+
solve_oop(p) = solve(NonlinearProblem(test_f, u, p), alg).u
16+
end
17+
return f
18+
end
19+
20+
__compatible(::Any, ::Val{:oop}) = true
21+
__compatible(::Number, ::Val{:iip}) = false
22+
__compatible(::AbstractArray, ::Val{:iip}) = true
23+
__compatible(::StaticArray, ::Val{:iip}) = false
24+
25+
__compatible(::Any, ::Number) = true
26+
__compatible(::Number, ::AbstractArray) = false
27+
__compatible(u::AbstractArray, p::AbstractArray) = size(u) == size(p)
28+
29+
__compatible(u::Number, ::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm) = true
30+
function __compatible(u::AbstractArray,
31+
::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm)
32+
true
33+
end
34+
function __compatible(u::StaticArray,
35+
::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm)
36+
true
37+
end
38+
39+
function __compatible(::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm,
40+
::Val{:iip})
41+
true
42+
end
43+
function __compatible(::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm,
44+
::Val{:oop})
45+
true
46+
end
47+
__compatible(::SimpleHalley, ::Val{:iip}) = false
48+
49+
@testset "ForwardDiff.jl Integration: $(alg)" for alg in (SimpleNewtonRaphson(),
50+
SimpleTrustRegion(), SimpleHalley(), SimpleBroyden(), SimpleKlement(), SimpleDFSane())
51+
us = (2.0, @SVector[1.0, 1.0], [1.0, 1.0], ones(2, 2), @SArray ones(2, 2))
52+
53+
@testset "Scalar AD" begin
54+
for p in 1.0:0.1:100.0, u0 in us, mode in (:iip, :oop)
55+
__compatible(u0, alg) || continue
56+
__compatible(u0, Val(mode)) || continue
57+
__compatible(alg, Val(mode)) || continue
58+
59+
sol = solve(NonlinearProblem(test_f, u0, p), alg)
60+
if SciMLBase.successful_retcode(sol)
61+
gs = abs.(ForwardDiff.derivative(solve_with(Val{mode}(), u0, alg), p))
62+
gs_true = abs.(jacobian_f(u0, p))
63+
if !(isapprox(gs, gs_true, atol = 1e-5))
64+
@show sol.retcode, sol.u
65+
@error "ForwardDiff Failed for u0=$(u0) and p=$(p) with $(alg)" forwardiff_gradient=gs true_gradient=gs_true
66+
else
67+
@test abs.(gs)≈abs.(gs_true) atol=1e-5
68+
end
69+
end
70+
end
71+
end
72+
73+
@testset "Jacobian" begin
74+
for u0 in us, p in ([2.0, 1.0], [2.0 1.0; 3.0 4.0]), mode in (:iip, :oop)
75+
__compatible(u0, p) || continue
76+
__compatible(u0, alg) || continue
77+
__compatible(u0, Val(mode)) || continue
78+
__compatible(alg, Val(mode)) || continue
79+
80+
sol = solve(NonlinearProblem(test_f, u0, p), alg)
81+
if SciMLBase.successful_retcode(sol)
82+
gs = abs.(ForwardDiff.jacobian(solve_with(Val{mode}(), u0, alg), p))
83+
gs_true = abs.(jacobian_f(u0, p))
84+
if !(isapprox(gs, gs_true, atol = 1e-5))
85+
@show sol.retcode, sol.u
86+
@error "ForwardDiff Failed for u0=$(u0) and p=$(p) with $(alg)" forwardiff_jacobian=gs true_jacobian=gs_true
87+
else
88+
@test abs.(gs)≈abs.(gs_true) atol=1e-5
89+
end
90+
end
91+
end
92+
end
93+
end

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ const GROUP = get(ENV, "GROUP", "All")
44

55
@time @testset "SimpleNonlinearSolve.jl" begin
66
if GROUP == "All" || GROUP == "Core"
7-
@time @safetestset "Basic Tests + Some AD" include("basictests.jl")
7+
@time @safetestset "Basic Tests" include("basictests.jl")
8+
@time @safetestset "Forward AD" include("forward_ad.jl")
89
@time @safetestset "Matrix Resizing Tests" include("matrix_resizing_tests.jl")
910
@time @safetestset "Least Squares Tests" include("least_squares.jl")
1011
@time @safetestset "23 Test Problems" include("23_test_problems.jl")

0 commit comments

Comments
 (0)