Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 6b682af

Browse files
authored
Merge pull request #148 from SciML/ap/di
Use DifferentiationInterface
2 parents 0d300d3 + 29ae939 commit 6b682af

13 files changed

+184
-329
lines changed

Project.toml

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "1.8.1"
4+
version = "1.9.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1010
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1111
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
12+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1213
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1314
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1415
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -17,21 +18,18 @@ MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
1718
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1819
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1920
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
21+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2022
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2123

2224
[weakdeps]
2325
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
24-
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
2526
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
26-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2727
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2828
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2929

3030
[extensions]
3131
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
32-
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
3332
SimpleNonlinearSolveReverseDiffExt = "ReverseDiff"
34-
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
3533
SimpleNonlinearSolveTrackerExt = "Tracker"
3634
SimpleNonlinearSolveZygoteExt = "Zygote"
3735

@@ -41,13 +39,14 @@ AllocCheck = "0.1.1"
4139
Aqua = "0.8"
4240
ArrayInterface = "7.9"
4341
CUDA = "5.2"
44-
ChainRulesCore = "1.22"
42+
ChainRulesCore = "1.23"
4543
ConcreteStructs = "0.2.3"
4644
DiffEqBase = "6.149"
4745
DiffResults = "1.1"
46+
DifferentiationInterface = "0.4"
4847
ExplicitImports = "1.5.0"
4948
FastClosures = "0.3.2"
50-
FiniteDiff = "2.22"
49+
FiniteDiff = "2.23.1"
5150
ForwardDiff = "0.10.36"
5251
LinearAlgebra = "1.10"
5352
LinearSolve = "2.30"
@@ -59,13 +58,14 @@ PrecompileTools = "1.2"
5958
Random = "1.10"
6059
ReTestItems = "1.23"
6160
Reexport = "1.2"
62-
ReverseDiff = "1.15"
61+
ReverseDiff = "1.15.3"
6362
SciMLBase = "2.37.0"
6463
SciMLSensitivity = "7.58"
64+
Setfield = "1.1.1"
6565
StaticArrays = "1.9"
6666
StaticArraysCore = "1.4.2"
6767
Test = "1.10"
68-
Tracker = "0.2.32"
68+
Tracker = "0.2.33"
6969
Zygote = "0.6.69"
7070
julia = "1.10"
7171

ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl

-20
This file was deleted.

ext/SimpleNonlinearSolveStaticArraysExt.jl

-7
This file was deleted.

src/SimpleNonlinearSolve.jl

+10-5
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@ module SimpleNonlinearSolve
33
using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidations
44

55
@recompile_invalidations begin
6-
using ADTypes: ADTypes, AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff
6+
using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff,
7+
AutoPolyesterForwardDiff
78
using ArrayInterface: ArrayInterface
89
using ConcreteStructs: @concrete
910
using DiffEqBase: DiffEqBase, AbstractNonlinearTerminationMode,
1011
AbstractSafeNonlinearTerminationMode,
1112
AbstractSafeBestNonlinearTerminationMode, AbsNormTerminationMode,
1213
NONLINEARSOLVE_DEFAULT_NORM
14+
using DifferentiationInterface: DifferentiationInterface
1315
using DiffResults: DiffResults
1416
using FastClosures: @closure
1517
using FiniteDiff: FiniteDiff
@@ -18,13 +20,16 @@ using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidati
1820
mul!, norm, transpose
1921
using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
2022
using Reexport: @reexport
21-
using SciMLBase: SciMLBase, IntervalNonlinearProblem, NonlinearFunction,
22-
NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode, init,
23-
remake, solve, AbstractNonlinearAlgorithm, build_solution, isinplace,
24-
_unwrap_val
23+
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
24+
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
25+
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
26+
build_solution, isinplace, _unwrap_val
27+
using Setfield: @set!
2528
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size
2629
end
2730

31+
const DI = DifferentiationInterface
32+
2833
@reexport using SciMLBase
2934

3035
abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end

src/ad.jl

+36-58
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,15 @@
1-
function SciMLBase.solve(
2-
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
3-
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
4-
alg::AbstractSimpleNonlinearSolveAlgorithm,
5-
args...;
6-
kwargs...) where {T, V, P, iip}
7-
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
8-
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
9-
return SciMLBase.build_solution(
10-
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
11-
end
12-
13-
function SciMLBase.solve(
14-
prob::NonlinearLeastSquaresProblem{
15-
<:AbstractArray, iip, <:Union{<:AbstractArray{<:Dual{T, V, P}}}},
16-
alg::AbstractSimpleNonlinearSolveAlgorithm,
17-
args...;
18-
kwargs...) where {T, V, P, iip}
19-
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
20-
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
21-
return SciMLBase.build_solution(
22-
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
1+
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
2+
@eval function SciMLBase.solve(
3+
prob::$(pType){<:Union{Number, <:AbstractArray}, iip,
4+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
5+
alg::AbstractSimpleNonlinearSolveAlgorithm,
6+
args...;
7+
kwargs...) where {T, V, P, iip}
8+
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
9+
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
10+
return SciMLBase.build_solution(
11+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
12+
end
2313
end
2414

2515
for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
@@ -47,8 +37,7 @@ function __nlsolve_ad(
4737
tspan = value.(prob.tspan)
4838
newprob = IntervalNonlinearProblem(prob.f, tspan, p; prob.kwargs...)
4939
else
50-
u0 = value(prob.u0)
51-
newprob = NonlinearProblem(prob.f, u0, p; prob.kwargs...)
40+
newprob = remake(prob; p, u0 = value(prob.u0))
5241
end
5342

5443
sol = solve(newprob, alg, args...; kwargs...)
@@ -73,20 +62,16 @@ function __nlsolve_ad(
7362
end
7463

7564
function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...)
76-
p = value(prob.p)
77-
u0 = value(prob.u0)
78-
newprob = NonlinearLeastSquaresProblem(prob.f, u0, p; prob.kwargs...)
79-
65+
newprob = remake(prob; p = value(prob.p), u0 = value(prob.u0))
8066
sol = solve(newprob, alg, args...; kwargs...)
81-
8267
uu = sol.u
8368

8469
# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
8570
# nested autodiff as the last resort
8671
if SciMLBase.has_vjp(prob.f)
8772
if isinplace(prob)
8873
_F = @closure (du, u, p) -> begin
89-
resid = similar(du, length(sol.resid))
74+
resid = __similar(du, length(sol.resid))
9075
prob.f(resid, u, p)
9176
prob.f.vjp(du, resid, u, p)
9277
du .*= 2
@@ -101,9 +86,9 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..
10186
elseif SciMLBase.has_jac(prob.f)
10287
if isinplace(prob)
10388
_F = @closure (du, u, p) -> begin
104-
J = similar(du, length(sol.resid), length(u))
89+
J = __similar(du, length(sol.resid), length(u))
10590
prob.f.jac(J, u, p)
106-
resid = similar(du, length(sol.resid))
91+
resid = __similar(du, length(sol.resid))
10792
prob.f(resid, u, p)
10893
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
10994
return nothing
@@ -116,43 +101,40 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..
116101
else
117102
if isinplace(prob)
118103
_F = @closure (du, u, p) -> begin
119-
resid = similar(du, length(sol.resid))
120-
res = DiffResults.DiffResult(
121-
resid, similar(du, length(sol.resid), length(u)))
122104
_f = @closure (du, u) -> prob.f(du, u, p)
123-
ForwardDiff.jacobian!(res, _f, resid, u)
124-
mul!(reshape(du, 1, :), vec(DiffResults.value(res))',
125-
DiffResults.jacobian(res), 2, false)
105+
resid = __similar(du, length(sol.resid))
106+
v, J = DI.value_and_jacobian(_f, resid, AutoForwardDiff(), u)
107+
mul!(reshape(du, 1, :), vec(v)', J, 2, false)
126108
return nothing
127109
end
128110
else
129111
# For small problems, nesting ForwardDiff is actually quite fast
130112
if __is_extension_loaded(Val(:Zygote)) && (length(uu) + length(sol.resid) 50)
131-
_F = @closure (u, p) -> __zygote_compute_nlls_vjp(prob.f, u, p)
113+
# TODO: Remove once DI has the value_and_pullback_split defined
114+
_F = @closure (u, p) -> begin
115+
_f = Base.Fix2(prob.f, p)
116+
return __zygote_compute_nlls_vjp(_f, u, p)
117+
end
132118
else
133119
_F = @closure (u, p) -> begin
134-
T = promote_type(eltype(u), eltype(p))
135-
res = DiffResults.DiffResult(similar(u, T, size(sol.resid)),
136-
similar(u, T, length(sol.resid), length(u)))
137-
ForwardDiff.jacobian!(res, Base.Fix2(prob.f, p), u)
138-
return reshape(
139-
2 .* vec(DiffResults.value(res))' * DiffResults.jacobian(res),
140-
size(u))
120+
_f = Base.Fix2(prob.f, p)
121+
v, J = DI.value_and_jacobian(_f, AutoForwardDiff(), u)
122+
return reshape(2 .* vec(v)' * J, size(u))
141123
end
142124
end
143125
end
144126
end
145127

146-
f_p = __nlsolve_∂f_∂p(prob, _F, uu, p)
147-
f_x = __nlsolve_∂f_∂u(prob, _F, uu, p)
128+
f_p = __nlsolve_∂f_∂p(prob, _F, uu, newprob.p)
129+
f_x = __nlsolve_∂f_∂u(prob, _F, uu, newprob.p)
148130

149131
z_arr = -f_x \ f_p
150132

151133
pp = prob.p
152134
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
153135
if uu isa Number
154136
partials = sum(sumfun, zip(z_arr, pp))
155-
elseif p isa Number
137+
elseif pp isa Number
156138
partials = sumfun((z_arr, pp))
157139
else
158140
partials = sum(sumfun, zip(eachcol(z_arr), pp))
@@ -164,7 +146,7 @@ end
164146
@inline function __nlsolve_∂f_∂p(prob, f::F, u, p) where {F}
165147
if isinplace(prob)
166148
__f = p -> begin
167-
du = similar(u, promote_type(eltype(u), eltype(p)))
149+
du = __similar(u, promote_type(eltype(u), eltype(p)))
168150
f(du, u, p)
169151
return du
170152
end
@@ -182,16 +164,12 @@ end
182164

183165
@inline function __nlsolve_∂f_∂u(prob, f::F, u, p) where {F}
184166
if isinplace(prob)
185-
du = similar(u)
186-
__f = (du, u) -> f(du, u, p)
187-
ForwardDiff.jacobian(__f, du, u)
167+
__f = @closure (du, u) -> f(du, u, p)
168+
return ForwardDiff.jacobian(__f, __similar(u), u)
188169
else
189170
__f = Base.Fix2(f, p)
190-
if u isa Number
191-
return ForwardDiff.derivative(__f, u)
192-
else
193-
return ForwardDiff.jacobian(__f, u)
194-
end
171+
u isa Number && return ForwardDiff.derivative(__f, u)
172+
return ForwardDiff.jacobian(__f, u)
195173
end
196174
end
197175

src/nlsolve/dfsane.jl

+9-6
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...
7777
α_1 = one(T)
7878
f_1 = fx_norm
7979

80-
history_f_k = if x isa SArray ||
81-
(x isa Number && __is_extension_loaded(Val(:StaticArrays)))
82-
ones(SVector{M, T}) * fx_norm
83-
else
84-
fill(fx_norm, M)
85-
end
80+
history_f_k = x isa SArray ? ones(SVector{M, T}) * fx_norm :
81+
__history_vec(fx_norm, Val(M))
8682

8783
# Generate the cache
8884
@bb x_cache = similar(x)
@@ -150,6 +146,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...
150146
# Store function value
151147
if history_f_k isa SVector
152148
history_f_k = Base.setindex(history_f_k, fx_norm_new, mod1(k, M))
149+
elseif history_f_k isa NTuple
150+
@set! history_f_k[mod1(k, M)] = fx_norm_new
153151
else
154152
history_f_k[mod1(k, M)] = fx_norm_new
155153
end
@@ -158,3 +156,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...
158156

159157
return build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
160158
end
159+
160+
@inline @generated function __history_vec(fx_norm, ::Val{M}) where {M}
161+
M 11 && return :(fill(fx_norm, M)) # Julia can't specialize here
162+
return :(ntuple(Returns(fx_norm), $(M)))
163+
end

0 commit comments

Comments
 (0)