1
- function SciMLBase. solve (
2
- prob:: NonlinearProblem {<: Union{Number, <:AbstractArray} , iip,
3
- <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
4
- alg:: AbstractSimpleNonlinearSolveAlgorithm ,
5
- args... ;
6
- kwargs... ) where {T, V, P, iip}
7
- sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
8
- dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
9
- return SciMLBase. build_solution (
10
- prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
11
- end
12
-
13
- function SciMLBase. solve (
14
- prob:: NonlinearLeastSquaresProblem {
15
- <: AbstractArray , iip, <: Union{<:AbstractArray{<:Dual{T, V, P}}} },
16
- alg:: AbstractSimpleNonlinearSolveAlgorithm ,
17
- args... ;
18
- kwargs... ) where {T, V, P, iip}
19
- sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
20
- dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
21
- return SciMLBase. build_solution (
22
- prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
1
+ for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
2
+ @eval function SciMLBase. solve (
3
+ prob:: $ (pType){<: Union{Number, <:AbstractArray} , iip,
4
+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
5
+ alg:: AbstractSimpleNonlinearSolveAlgorithm ,
6
+ args... ;
7
+ kwargs... ) where {T, V, P, iip}
8
+ sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
9
+ dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
10
+ return SciMLBase. build_solution (
11
+ prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
12
+ end
23
13
end
24
14
25
15
for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
@@ -47,8 +37,7 @@ function __nlsolve_ad(
47
37
tspan = value .(prob. tspan)
48
38
newprob = IntervalNonlinearProblem (prob. f, tspan, p; prob. kwargs... )
49
39
else
50
- u0 = value (prob. u0)
51
- newprob = NonlinearProblem (prob. f, u0, p; prob. kwargs... )
40
+ newprob = remake (prob; p, u0 = value (prob. u0))
52
41
end
53
42
54
43
sol = solve (newprob, alg, args... ; kwargs... )
@@ -73,20 +62,16 @@ function __nlsolve_ad(
73
62
end
74
63
75
64
function __nlsolve_ad (prob:: NonlinearLeastSquaresProblem , alg, args... ; kwargs... )
76
- p = value (prob. p)
77
- u0 = value (prob. u0)
78
- newprob = NonlinearLeastSquaresProblem (prob. f, u0, p; prob. kwargs... )
79
-
65
+ newprob = remake (prob; p = value (prob. p), u0 = value (prob. u0))
80
66
sol = solve (newprob, alg, args... ; kwargs... )
81
-
82
67
uu = sol. u
83
68
84
69
# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
85
70
# nested autodiff as the last resort
86
71
if SciMLBase. has_vjp (prob. f)
87
72
if isinplace (prob)
88
73
_F = @closure (du, u, p) -> begin
89
- resid = similar (du, length (sol. resid))
74
+ resid = __similar (du, length (sol. resid))
90
75
prob. f (resid, u, p)
91
76
prob. f. vjp (du, resid, u, p)
92
77
du .*= 2
@@ -101,9 +86,9 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..
101
86
elseif SciMLBase. has_jac (prob. f)
102
87
if isinplace (prob)
103
88
_F = @closure (du, u, p) -> begin
104
- J = similar (du, length (sol. resid), length (u))
89
+ J = __similar (du, length (sol. resid), length (u))
105
90
prob. f. jac (J, u, p)
106
- resid = similar (du, length (sol. resid))
91
+ resid = __similar (du, length (sol. resid))
107
92
prob. f (resid, u, p)
108
93
mul! (reshape (du, 1 , :), vec (resid)' , J, 2 , false )
109
94
return nothing
@@ -116,43 +101,40 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..
116
101
else
117
102
if isinplace (prob)
118
103
_F = @closure (du, u, p) -> begin
119
- resid = similar (du, length (sol. resid))
120
- res = DiffResults. DiffResult (
121
- resid, similar (du, length (sol. resid), length (u)))
122
104
_f = @closure (du, u) -> prob. f (du, u, p)
123
- ForwardDiff . jacobian! (res, _f, resid, u )
124
- mul! ( reshape (du, 1 , :), vec (DiffResults . value (res)) ' ,
125
- DiffResults . jacobian (res) , 2 , false )
105
+ resid = __similar (du, length (sol . resid) )
106
+ v, J = DI . value_and_jacobian (_f, resid, AutoForwardDiff (), u)
107
+ mul! ( reshape (du, 1 , :), vec (v) ' , J , 2 , false )
126
108
return nothing
127
109
end
128
110
else
129
111
# For small problems, nesting ForwardDiff is actually quite fast
130
112
if __is_extension_loaded (Val (:Zygote )) && (length (uu) + length (sol. resid) ≥ 50 )
131
- _F = @closure (u, p) -> __zygote_compute_nlls_vjp (prob. f, u, p)
113
+ # TODO : Remove once DI has the value_and_pullback_split defined
114
+ _F = @closure (u, p) -> begin
115
+ _f = Base. Fix2 (prob. f, p)
116
+ return __zygote_compute_nlls_vjp (_f, u, p)
117
+ end
132
118
else
133
119
_F = @closure (u, p) -> begin
134
- T = promote_type (eltype (u), eltype (p))
135
- res = DiffResults. DiffResult (similar (u, T, size (sol. resid)),
136
- similar (u, T, length (sol. resid), length (u)))
137
- ForwardDiff. jacobian! (res, Base. Fix2 (prob. f, p), u)
138
- return reshape (
139
- 2 .* vec (DiffResults. value (res))' * DiffResults. jacobian (res),
140
- size (u))
120
+ _f = Base. Fix2 (prob. f, p)
121
+ v, J = DI. value_and_jacobian (_f, AutoForwardDiff (), u)
122
+ return reshape (2 .* vec (v)' * J, size (u))
141
123
end
142
124
end
143
125
end
144
126
end
145
127
146
- f_p = __nlsolve_∂f_∂p (prob, _F, uu, p)
147
- f_x = __nlsolve_∂f_∂u (prob, _F, uu, p)
128
+ f_p = __nlsolve_∂f_∂p (prob, _F, uu, newprob . p)
129
+ f_x = __nlsolve_∂f_∂u (prob, _F, uu, newprob . p)
148
130
149
131
z_arr = - f_x \ f_p
150
132
151
133
pp = prob. p
152
134
sumfun = ((z, p),) -> map (zᵢ -> zᵢ * ForwardDiff. partials (p), z)
153
135
if uu isa Number
154
136
partials = sum (sumfun, zip (z_arr, pp))
155
- elseif p isa Number
137
+ elseif pp isa Number
156
138
partials = sumfun ((z_arr, pp))
157
139
else
158
140
partials = sum (sumfun, zip (eachcol (z_arr), pp))
164
146
@inline function __nlsolve_∂f_∂p (prob, f:: F , u, p) where {F}
165
147
if isinplace (prob)
166
148
__f = p -> begin
167
- du = similar (u, promote_type (eltype (u), eltype (p)))
149
+ du = __similar (u, promote_type (eltype (u), eltype (p)))
168
150
f (du, u, p)
169
151
return du
170
152
end
@@ -182,16 +164,12 @@ end
182
164
183
165
@inline function __nlsolve_∂f_∂u (prob, f:: F , u, p) where {F}
184
166
if isinplace (prob)
185
- du = similar (u)
186
- __f = (du, u) -> f (du, u, p)
187
- ForwardDiff. jacobian (__f, du, u)
167
+ __f = @closure (du, u) -> f (du, u, p)
168
+ return ForwardDiff. jacobian (__f, __similar (u), u)
188
169
else
189
170
__f = Base. Fix2 (f, p)
190
- if u isa Number
191
- return ForwardDiff. derivative (__f, u)
192
- else
193
- return ForwardDiff. jacobian (__f, u)
194
- end
171
+ u isa Number && return ForwardDiff. derivative (__f, u)
172
+ return ForwardDiff. jacobian (__f, u)
195
173
end
196
174
end
197
175
0 commit comments