Code example for stop_gradient()
#580
-
Could someone please help me with a code example of how to use stop_gradient()?
-
import mlx.core as mx

def fun(x):
    return mx.exp(x) + mx.stop_gradient(mx.exp(x))

print(mx.grad(fun)(mx.array(1.0)))

Gives array(2.71828, dtype=float32). So there you would only get the gradient through the first mx.exp. The stop_gradient term is treated as a constant.

Compare to:

def fun(x):
    return mx.exp(x) + mx.exp(x)

print(mx.grad(fun)(mx.array(1.0)))

Which gives you twice the result of the first, because the gradient flows through both paths: array(5.43656, dtype=float32).
-
@awni It may be helpful to have a no_grad decorator to simplify this, for example:

import functools
import mlx.core as mx

def no_grad(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        out = fn(*args, **kwargs)
        # A single array output: block gradients through it.
        if isinstance(out, mx.array):
            return mx.stop_gradient(out)
        # A tuple/list output: block gradients through each array element.
        elif isinstance(out, (tuple, list)):
            return type(out)(
                mx.stop_gradient(x) if isinstance(x, mx.array) else x for x in out
            )
        # Anything else is returned unchanged.
        else:
            return out
    return wrapper

@no_grad
def fun(x):
    return mx.exp(x) + mx.exp(x)

print(mx.grad(fun)(mx.array(1.0)))
# array(0, dtype=float32)