Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Add ForwardDiff Inplace Overloads #114

Merged
merged 2 commits into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add ForwardDiff Inplace Overloads
  • Loading branch information
avik-pal committed Dec 26, 2023
commit fadbffac0ebf05dd322b7ed9cda71da2f2775a72
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "1.1.0"
version = "1.2.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
95 changes: 53 additions & 42 deletions src/ad.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
f = prob.f
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, <:AbstractArray},
iip, <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats,
sol.original)
end

function __nlsolve_ad(prob::NonlinearProblem{uType, iip}, alg, args...;
kwargs...) where {uType, iip}
p = value(prob.p)
if prob isa IntervalNonlinearProblem
tspan = value.(prob.tspan)
newprob = IntervalNonlinearProblem(f, tspan, p; prob.kwargs...)
else
u0 = value(prob.u0)
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
end
newprob = NonlinearProblem(prob.f, value(prob.u0), p; prob.kwargs...)

sol = solve(newprob, alg, args...; kwargs...)

uu = sol.u
f_p = scalar_nlsolve_∂f_∂p(f, uu, p)
f_x = scalar_nlsolve_∂f_∂u(f, uu, p)
f_p = __nlsolve_∂f_∂p(prob, prob.f, uu, p)
f_x = __nlsolve_∂f_∂u(prob, prob.f, uu, p)

z_arr = -inv(f_x) * f_p
z_arr = -f_x \ f_p

pp = prob.p
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
Expand All @@ -30,58 +33,66 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return sol, partials
end

function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
false, <:Dual{T, V, P}}, alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
kwargs...) where {T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
end

function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
false, <:AbstractArray{<:Dual{T, V, P}}},
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) where {T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
end

function scalar_nlsolve_∂f_∂p(f, u, p)
ff = p isa Number ? ForwardDiff.derivative :
(u isa Number ? ForwardDiff.gradient : ForwardDiff.jacobian)
return ff(Base.Fix1(f, u), p)
@inline function __nlsolve_∂f_∂p(prob, f::F, u, p) where {F}
if isinplace(prob)
__f = p -> begin
du = similar(u, promote_type(eltype(u), eltype(p)))
f(du, u, p)
return du
end
else
__f = Base.Fix1(f, u)
end
if p isa Number
return __reshape(ForwardDiff.derivative(__f, p), :, 1)
elseif u isa Number
return __reshape(ForwardDiff.gradient(__f, p), 1, :)
else
return ForwardDiff.jacobian(__f, p)
end
end

function scalar_nlsolve_∂f_∂u(f, u, p)
ff = u isa Number ? ForwardDiff.derivative : ForwardDiff.jacobian
return ff(Base.Fix2(f, p), u)
@inline function __nlsolve_∂f_∂u(prob, f::F, u, p) where {F}
if isinplace(prob)
du = similar(u)
__f = (du, u) -> f(du, u, p)
ForwardDiff.jacobian(__f, du, u)
else
__f = Base.Fix2(f, p)
if u isa Number
return ForwardDiff.derivative(__f, u)
else
return ForwardDiff.jacobian(__f, u)
end
end
end

function scalar_nlsolve_dual_soln(u::Number, partials,
@inline function __nlsolve_dual_soln(u::Number, partials,
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
return Dual{T, V, P}(u, partials)
end

function scalar_nlsolve_dual_soln(u::AbstractArray, partials,
@inline function __nlsolve_dual_soln(u::AbstractArray, partials,
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, partials))
_partials = _restructure(u, partials)
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, _partials))
end

# avoid ambiguities
for Alg in [Bisection]
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
<:Dual{T, V, P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode,
left = Dual{T, V, P}(sol.left, partials),
right = Dual{T, V, P}(sol.right, partials))
end
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
<:AbstractArray{<:Dual{T, V, P}}}, alg::$Alg, args...;
kwargs...) where {uType, iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode,
left = Dual{T, V, P}(sol.left, partials),
right = Dual{T, V, P}(sol.right, partials))
Expand Down
9 changes: 8 additions & 1 deletion src/nlsolve/halley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,14 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
setindex_trait(x) === CannotSetindex() && (A = dfx)

# Factorize Once and Reuse
dfx_fact = factorize(dfx)
dfx_fact = if dfx isa Number
dfx
else
fact = lu(dfx; check = false)
!issuccess(fact) && return build_solution(prob, alg, x, fx;
retcode = ReturnCode.Unstable)
fact
end

aᵢ = dfx_fact \ _vec(fx)
A_ = _vec(A)
Expand Down
3 changes: 3 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,6 @@ end
return AutoFiniteDiff()
end
end

@inline __reshape(x::Number, args...) = x
@inline __reshape(x::AbstractArray, args...) = reshape(x, args...)
98 changes: 0 additions & 98 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,36 +64,6 @@ const TERMINATION_CONDITIONS = [
autodiff = AutoForwardDiff())) == 0
end

@testset "[OOP] Immutable AD" begin
for p in [1.0, 100.0]
@test begin
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
res_true = sqrt(p)
all(res.u .≈ res_true)
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p))
end
end

@testset "[OOP] Scalar AD" begin
for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
res_true = sqrt(p)
res.u ≈ res_true
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u,
p) ≈ 1 / (2 * sqrt(p))
end
end

t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u ≈ sqrt(p[2] / p[1])
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u],
p) ≈ ForwardDiff.jacobian(t, p)

@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])

Expand Down Expand Up @@ -124,36 +94,6 @@ end
autodiff = AutoForwardDiff())) == 0
end

@testset "[OOP] Immutable AD" begin
for p in [1.0, 100.0]
@test begin
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
res_true = sqrt(p)
all(res.u .≈ res_true)
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p))
end
end

@testset "[OOP] Scalar AD" begin
for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
res_true = sqrt(p)
res.u ≈ res_true
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u,
p) ≈ 1 / (2 * sqrt(p))
end
end

t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u ≈ sqrt(p[2] / p[1])
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u],
p) ≈ ForwardDiff.jacobian(t, p)

@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])

Expand Down Expand Up @@ -195,44 +135,6 @@ end
@test (@ballocated $(benchmark_nlsolve_oop)($quadratic_f, 1.0, 2.0)) == allocs
end

@testset "[OOP] Immutable AD" begin
for p in [1.0, 100.0]
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)

if any(x -> isnan(x) || x <= 1e-5 || x >= 1e5, res)
@test_broken all(abs.(res) .≈ sqrt(p))
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p)) ≈ 1 / (2 * sqrt(p))
else
@test all(abs.(res) .≈ sqrt(p))
@test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p)), 1 / (2 * sqrt(p)))
end
end
end

@testset "[OOP] Scalar AD" begin
for p in 1.0:0.1:100.0
res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)

if any(x -> isnan(x), res)
@test_broken abs(res.u) ≈ sqrt(p)
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
1.0, p).u, p)) ≈ 1 / (2 * sqrt(p))
else
@test abs(res.u) ≈ sqrt(p)
@test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
1.0, p).u, p)), 1 / (2 * sqrt(p)))
end
end
end

t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u ≈ sqrt(p[2] / p[1])
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u],
p) ≈ ForwardDiff.jacobian(t, p)

@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])

Expand Down
93 changes: 93 additions & 0 deletions test/forward_ad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
using ForwardDiff, SimpleNonlinearSolve, StaticArrays, Test, LinearAlgebra

test_f!(du, u, p) = (@. du = u^2 - p)
test_f(u, p) = (@. u^2 - p)

jacobian_f(::Number, p) = 1 / (2 * √p)
jacobian_f(::Number, p::Number) = 1 / (2 * √p)
jacobian_f(u, p::Number) = one.(u) .* (1 / (2 * √p))
jacobian_f(u, p::AbstractArray) = diagm(vec(@. 1 / (2 * √p)))

function solve_with(::Val{mode}, u, alg) where {mode}
f = if mode === :iip
solve_iip(p) = solve(NonlinearProblem(test_f!, u, p), alg).u
elseif mode === :oop
solve_oop(p) = solve(NonlinearProblem(test_f, u, p), alg).u
end
return f
end

__compatible(::Any, ::Val{:oop}) = true
__compatible(::Number, ::Val{:iip}) = false
__compatible(::AbstractArray, ::Val{:iip}) = true
__compatible(::StaticArray, ::Val{:iip}) = false

__compatible(::Any, ::Number) = true
__compatible(::Number, ::AbstractArray) = false
__compatible(u::AbstractArray, p::AbstractArray) = size(u) == size(p)

__compatible(u::Number, ::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm) = true
function __compatible(u::AbstractArray,
::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm)
true
end
function __compatible(u::StaticArray,
::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm)
true
end

function __compatible(::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm,
::Val{:iip})
true
end
function __compatible(::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm,
::Val{:oop})
true
end
__compatible(::SimpleHalley, ::Val{:iip}) = false

@testset "ForwardDiff.jl Integration: $(alg)" for alg in (SimpleNewtonRaphson(),
SimpleTrustRegion(), SimpleHalley(), SimpleBroyden(), SimpleKlement(), SimpleDFSane())
us = (2.0, @SVector[1.0, 1.0], [1.0, 1.0], ones(2, 2), @SArray ones(2, 2))

@testset "Scalar AD" begin
for p in 1.0:0.1:100.0, u0 in us, mode in (:iip, :oop)
__compatible(u0, alg) || continue
__compatible(u0, Val(mode)) || continue
__compatible(alg, Val(mode)) || continue

sol = solve(NonlinearProblem(test_f, u0, p), alg)
if SciMLBase.successful_retcode(sol)
gs = abs.(ForwardDiff.derivative(solve_with(Val{mode}(), u0, alg), p))
gs_true = abs.(jacobian_f(u0, p))
if !(isapprox(gs, gs_true, atol = 1e-5))
@show sol.retcode, sol.u
@error "ForwardDiff Failed for u0=$(u0) and p=$(p) with $(alg)" forwardiff_gradient=gs true_gradient=gs_true
else
@test abs.(gs)≈abs.(gs_true) atol=1e-5
end
end
end
end

@testset "Jacobian" begin
for u0 in us, p in ([2.0, 1.0], [2.0 1.0; 3.0 4.0]), mode in (:iip, :oop)
__compatible(u0, p) || continue
__compatible(u0, alg) || continue
__compatible(u0, Val(mode)) || continue
__compatible(alg, Val(mode)) || continue

sol = solve(NonlinearProblem(test_f, u0, p), alg)
if SciMLBase.successful_retcode(sol)
gs = abs.(ForwardDiff.jacobian(solve_with(Val{mode}(), u0, alg), p))
gs_true = abs.(jacobian_f(u0, p))
if !(isapprox(gs, gs_true, atol = 1e-5))
@show sol.retcode, sol.u
@error "ForwardDiff Failed for u0=$(u0) and p=$(p) with $(alg)" forwardiff_jacobian=gs true_jacobian=gs_true
else
@test abs.(gs)≈abs.(gs_true) atol=1e-5
end
end
end
end
end
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ const GROUP = get(ENV, "GROUP", "All")

@time @testset "SimpleNonlinearSolve.jl" begin
if GROUP == "All" || GROUP == "Core"
@time @safetestset "Basic Tests + Some AD" include("basictests.jl")
@time @safetestset "Basic Tests" include("basictests.jl")
@time @safetestset "Forward AD" include("forward_ad.jl")
@time @safetestset "Matrix Resizing Tests" include("matrix_resizing_tests.jl")
@time @safetestset "Least Squares Tests" include("least_squares.jl")
@time @safetestset "23 Test Problems" include("23_test_problems.jl")
Expand Down