-
Notifications
You must be signed in to change notification settings - Fork 44
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: value_and_grad
with support for auxiliary data
#1890
Conversation
7662be5
to
4aa8499
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, just wondering if we should consider an implementation that doesn't depend on the autograd internals? here's something I cooked up where the value is temporarily stored in aux_data
and then passed to the return.
import autograd as ag
VAL_KEY = "VALUE"
def f(x):
return x**2
def f_aux(x):
return f(x), dict(x=x)
def value_and_grad(f, has_aux: bool=False, **kwargs):
def f_with_aux(*args, **kwargs):
f_out = f(*args, **kwargs)
if has_aux:
val, aux = f_out
else:
val = f_out
aux = {}
aux[VAL_KEY] = val
return val, aux
grad_f_fn = ag.grad_and_aux(f_with_aux)
def _value_and_grad(*args, **kwargs):
grad_f, aux = grad_f_fn(*args, **kwargs)
value = aux.pop(VAL_KEY)
if has_aux:
return (value, aux), grad_f
else:
return value, grad_f
return _value_and_grad
x0 = 1.0
f(x0)
value_and_grad(f)(x0)
# (<autograd.numpy.numpy_boxes.ArrayBox at 0x105dadfc0>, 2.0)
value_and_grad(f_aux, has_aux=True)(x0)
# ((<autograd.numpy.numpy_boxes.ArrayBox at 0x105e20200>,
# {'x': <autograd.numpy.numpy_boxes.ArrayBox at 0x105e201c0>}),
# 2.0)
not sure if it's totally correct due to the presence of ArrayBox/
Toughts?
Are these really autograd internals? I don't think I'm using any hidden functions 😄 |
ah ok, that makes sense. then this implementation is good. I'd just also say we should probably make this importable through |
4aa8499
to
996fa36
Compare
Yeah good idea, done. |
This will be "needed" for the metasurface notebook, but it's useful to have in general.