`jaxopt.ScipyMinimize` unconditionally assumes the callback signature `callback(xk)`. According to the SciPy documentation, that is not the signature used by `trust-constr`, which calls the callback as `callback(xk, state)`. This mismatch causes an error that was already reported two years ago in #428.
I initially (and incorrectly) assumed this was a SciPy bug and reported it there (scipy/scipy#23570).
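To check the signature difference independently of jaxopt, the following sketch calls `scipy.optimize.minimize` directly and records how many positional arguments SciPy passes to the callback for `BFGS` and for `trust-constr`. The helper `record_arity` is hypothetical and exists only for this illustration; the callback takes `*extra` so that it does not crash on either call style.

```python
import numpy as np
from scipy.optimize import minimize


def fun(x):
    # Same tiny quadratic as the reproduction below, in plain NumPy.
    return float(np.sum((x - 3.0) ** 2))


def record_arity(method):
    """Hypothetical helper: record how many positional args the callback receives."""
    counts = []

    def callback(xk, *extra):
        # xk is always passed; trust-constr additionally passes a state object.
        counts.append(1 + len(extra))

    minimize(fun, np.zeros(2), method=method, callback=callback,
             options={"maxiter": 5})
    return counts


arity = {m: set(record_arity(m)) for m in ["BFGS", "trust-constr"]}
# BFGS passes only xk, while trust-constr passes (xk, state).
print(arity)
```

Recent SciPy releases also accept a callback whose single parameter is named `intermediate_result` (it then receives an `OptimizeResult`). A callback written as `callback(xk)` with one positional parameter, however, works for most methods but not for `trust-constr`.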
Code to reproduce the issue.
```python
# Show how callbacks behave across SciPy methods via jaxopt.ScipyMinimize.
# Records the loss at each iteration; prints [FAIL] for each method whose run raises.
import jax
import jax.numpy as jnp
from jaxopt import ScipyMinimize

# Tiny quadratic: f(x) = sum((x - 3)^2)
def fun(x):
    return jnp.sum((x - 3.0) ** 2)

METHODS = [
    "CG", "BFGS", "Newton-CG", "L-BFGS-B",
    "Nelder-Mead", "Powell",
    "TNC", "SLSQP", "COBYLA", "trust-constr",
    # "dogleg", "trust-ncg", "trust-krylov", "trust-exact",  # need hess or hessp
]

def run_one(method: str):
    iter_losses = []

    # Simple SciPy-style callback: callback(xk)
    def callback(xk):
        val = fun(xk)
        iter_losses.append(float(val))

    # IMPORTANT: pass the callback to the CONSTRUCTOR; disable jit to allow a Python callback
    solver = ScipyMinimize(
        fun=fun,
        method=method,
        callback=callback,
        jit=False,
        maxiter=5,  # keep it fast
        tol=1e-6,
    )
    x0 = jnp.array([0.0, 0.0])
    try:
        solver.run(x0)
    except Exception as e:
        print(f" [FAIL] {method} -> {type(e).__name__}: {e}")

def main():
    for m in METHODS:
        run_one(m)

if __name__ == "__main__":
    main()
```
Output:

```
[FAIL] COBYLA -> AttributeError: nit
[FAIL] trust-constr -> TypeError: ScipyMinimize._run.<locals>.scipy_callback() takes 1 positional argument but 2 were given
```
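One possible direction for a fix, sketched with plain SciPy under stated assumptions and not taken from jaxopt's actual code, is for the internal wrapper to accept and discard any extra positional arguments. That way both `trust-constr`'s `(xk, state)` call and the other methods' `(xk)` call reach the user's one-argument callback. The name `make_scipy_callback` and its `unravel` parameter are hypothetical: `unravel` stands in for the step where jaxopt's wrapper converts SciPy's flat NumPy vector back into the user's parameters. The COBYLA `AttributeError: nit` is consistent with the returned `OptimizeResult` lacking an `nit` field, so the sketch reads it with `.get` (an `OptimizeResult` is a `dict` subclass).

```python
import numpy as np
from scipy.optimize import minimize


def make_scipy_callback(user_callback, unravel=lambda x: x):
    """Hypothetical wrapper adapting a user callback(params) to any SciPy method.

    Most methods call callback(xk); trust-constr calls callback(xk, state).
    Accepting *extra and ignoring it makes both call styles work.
    """
    def scipy_callback(xk, *extra):
        user_callback(unravel(xk))

    return scipy_callback


def fun(x):
    return float(np.sum((x - 3.0) ** 2))


def run(method):
    losses = []
    cb = make_scipy_callback(lambda params: losses.append(fun(params)))
    res = minimize(fun, np.zeros(2), method=method, callback=cb,
                   options={"maxiter": 5})
    # Some methods may not report an iteration count, so avoid res.nit.
    return losses, res.get("nit")


results = {m: run(m) for m in ["BFGS", "trust-constr"]}
for method, (losses, nit) in results.items():
    print(f"{method}: {len(losses)} callback calls, nit={nit}")
```

Ignoring `*extra` keeps the user-facing contract `callback(params)` unchanged. A wrapper that wanted to expose `trust-constr`'s state object would need a different, opt-in signature.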