Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 06239a6

Browse files
committed
Forward Mode overloads for Least Squares Problem
1 parent 8995a23 commit 06239a6

File tree

4 files changed

+28
-2
lines changed

4 files changed

+28
-2
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "1.4.3"
4+
version = "1.5.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -22,11 +22,13 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2222
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2323
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
2424
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
25+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2526

2627
[extensions]
2728
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
2829
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
2930
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
31+
SimpleNonlinearSolveZygoteExt = "Zygote"
3032

3133
[compat]
3234
ADTypes = "0.2.6"

ext/SimpleNonlinearSolveZygoteExt.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
module SimpleNonlinearSolveZygoteExt
2+
3+
import SimpleNonlinearSolve
4+
5+
SimpleNonlinearSolve.__is_extension_loaded(::Val{:Zygote}) = true
6+
7+
end

src/ad.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
2424
end
2525
end
2626

27-
function __nlsolve_ad(prob, alg, args...; kwargs...)
27+
function __nlsolve_ad(
28+
prob::Union{IntervalNonlinearProblem, NonlinearProblem}, alg, args...; kwargs...)
2829
p = value(prob.p)
2930
if prob isa IntervalNonlinearProblem
3031
tspan = value.(prob.tspan)

test/core/least_squares_tests.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212
return.- y_target
1313
end
1414

15+
function loss_function!(resid, θ, p)
16+
= true_function(p, θ)
17+
@. resid =- y_target
18+
return
19+
end
20+
1521
θ_init = θ_true .+ 0.1
1622
prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, x)
1723

@@ -21,4 +27,14 @@
2127
sol = solve(prob_oop, solver)
2228
@test norm(sol.resid, Inf) < 1e-12
2329
end
30+
31+
prob_iip = NonlinearLeastSquaresProblem(
32+
NonlinearFunction{true}(loss_function!, resid_prototype = zeros(length(y_target))), θ_init, x)
33+
34+
@testset "Solver: $(nameof(typeof(solver)))" for solver in [
35+
SimpleNewtonRaphson(AutoForwardDiff()), SimpleGaussNewton(AutoForwardDiff()),
36+
SimpleNewtonRaphson(AutoFiniteDiff()), SimpleGaussNewton(AutoFiniteDiff())]
37+
sol = solve(prob_iip, solver)
38+
@test norm(sol.resid, Inf) < 1e-12
39+
end
2440
end

0 commit comments

Comments
 (0)