Make value_and_pushforward_function consistent with other APIs #120

Closed
wants to merge 2 commits
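In effect, this PR moves the primal value out of the returned closure: `value_and_pushforward_function` now returns the value eagerly together with a pushforward function, matching `value_and_gradient`, `value_and_jacobian`, and `value_and_hessian`. A minimal before/after sketch (the ForwardDiff backend and the test function are illustrative assumptions, not part of this PR):

```julia
import AbstractDifferentiation as AD
using ForwardDiff  # loads a concrete backend for AD

backend = AD.ForwardDiffBackend()
f(x) = sum(abs2, x)
x = [1.0, 2.0, 3.0]
t = [1.0, 0.0, 0.0]  # tangent direction

# Before this PR: a single closure returned both the value and the pushforward output.
# v, p = AD.value_and_pushforward_function(backend, f, x)(t)

# After this PR: the value comes back immediately, and the closure returns
# only the pushforward output.
v, pf = AD.value_and_pushforward_function(backend, f, x)
p = pf(t)  # pushforward output: the JVP of f at x along t, numerically dot(2x, t) == 2.0
```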
46 changes: 23 additions & 23 deletions src/AbstractDifferentiation.jl
@@ -118,7 +118,7 @@ end
AD.value_and_gradient(ab::AD.AbstractBackend, f, xs...)

Return the tuple `(v, gs)` of the function value `v = f(xs...)` and the gradients `gs = AD.gradient(ab, f, xs...)`.

See also [`AbstractDifferentiation.gradient`](@ref).
"""
function value_and_gradient(ab::AbstractBackend, f, xs...)
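For comparison, the calling convention this PR aligns with, sketched with an assumed ForwardDiff backend:

```julia
import AbstractDifferentiation as AD
using ForwardDiff

backend = AD.ForwardDiffBackend()
g(x) = sum(abs2, x)

v, grads = AD.value_and_gradient(backend, g, [1.0, 2.0])
# v == 5.0; grads is a tuple with one gradient per input, here ([2.0, 4.0],)
```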
@@ -130,7 +130,7 @@ end
AD.value_and_jacobian(ab::AD.AbstractBackend, f, xs...)

Return the tuple `(v, Js)` of the function value `v = f(xs...)` and the Jacobians `Js = AD.jacobian(ab, f, xs...)`.

See also [`AbstractDifferentiation.jacobian`](@ref).
"""
function value_and_jacobian(ab::AbstractBackend, f, xs...)
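The same convention for Jacobians (illustrative sketch; backend and function are assumptions):

```julia
import AbstractDifferentiation as AD
using ForwardDiff

backend = AD.ForwardDiffBackend()
h(x) = [sum(x), prod(x)]

v, Js = AD.value_and_jacobian(backend, h, [2.0, 3.0])
# v == [5.0, 6.0]; Js is a tuple of Jacobians, here ([1.0 1.0; 3.0 2.0],)
```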
@@ -144,7 +144,7 @@ end

Return the tuple `(v, H)` of the function value `v = f(x)` and the Hessian `H = AD.hessian(ab, f, x)`.

See also [`AbstractDifferentiation.hessian`](@ref).
"""
function value_and_hessian(ab::AbstractBackend, f, x)
if x isa Tuple
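And for Hessians, a sketch under the same assumed backend (note the single-input restriction in the signature above):

```julia
import AbstractDifferentiation as AD
using ForwardDiff

backend = AD.ForwardDiffBackend()
q(x) = sum(abs2, x)

v, H = AD.value_and_hessian(backend, q, [1.0, 2.0])
# v == 5.0; H matches AD.hessian(backend, q, [1.0, 2.0]); the Hessian here is 2I
```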
@@ -178,7 +178,7 @@ end

"""
AD.value_gradient_and_hessian(ab::AD.AbstractBackend, f, x)

Return the tuple `(v, g, H)` of the function value `v = f(x)`, the gradient `g = AD.gradient(ab, f, x)`, and the Hessian `H = AD.hessian(ab, f, x)`.

See also [`AbstractDifferentiation.gradient`](@ref) and [`AbstractDifferentiation.hessian`](@ref).
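A sketch of the combined call, under the same assumed setup as above:

```julia
import AbstractDifferentiation as AD
using ForwardDiff

backend = AD.ForwardDiffBackend()
r(x) = sum(abs2, x)

v, g, H = AD.value_gradient_and_hessian(backend, r, [1.0, 2.0])
# v == 5.0; g matches AD.gradient(...) and H matches AD.hessian(...),
# here ([2.0, 4.0],) and ([2.0 0.0; 0.0 2.0],)
```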
@@ -219,9 +219,9 @@ end

"""
AD.pushforward_function(ab::AD.AbstractBackend, f, xs...)

Return the pushforward function `pf` of the function `f` at the inputs `xs` using backend `ab`.

The pushforward function `pf` accepts as input a `Tuple` of tangents, one for each element in `xs`.
If `xs` consists of a single element, `pf` can also accept a single tangent instead of a 1-tuple.
"""
@@ -246,7 +246,7 @@ end

"""
AD.value_and_pushforward_function(ab::AD.AbstractBackend, f, xs...)

Return a function that, given tangents `ts`, computes the tuple `(v, p)` of the function value `v = f(xs...)` and the output `p` of the pushforward function `AD.pushforward_function(ab, f, xs...)` applied to `ts`.

See also [`AbstractDifferentiation.pushforward_function`](@ref).
@@ -256,13 +256,13 @@ function value_and_pushforward_function(ab::AbstractBackend, f, xs...)
value = f(xs...)
pf_function = pushforward_function(lowest(ab), f, xs...)

- return ds -> begin
+ return value, ds -> begin
if !(ds isa Tuple)
ds = (ds,)
end
@assert length(ds) == n
pf = pf_function(ds)
- return value, pf
+ return pf
end
end

@@ -285,8 +285,8 @@ end
"""
AD.pullback_function(ab::AD.AbstractBackend, f, xs...)

Return the pullback function `pb` of the function `f` at the inputs `xs` using backend `ab`.

The pullback function `pb` accepts as input a `Tuple` of cotangents, one for each output of `f`.
If `f` has a single output, `pb` can also accept a single input instead of a 1-tuple.
"""
@@ -511,9 +511,9 @@ end

"""
AD.lazy_derivative(ab::AbstractBackend, f, xs::Number...)

Return an operator `ld` for multiplying by the derivative of `f` at `xs`.

You can apply the operator by multiplication e.g. `ld * y` where `y` is a number if `f` has a single input, a tuple of the same length as `xs` if `f` has multiple inputs, or an array of numbers/tuples.
"""
function lazy_derivative(ab::AbstractBackend, f, xs::Number...)
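Sketch (assumed backend; the exact wrapper returned by the operator is left unspecified here):

```julia
import AbstractDifferentiation as AD
using ForwardDiff

backend = AD.ForwardDiffBackend()
ld = AD.lazy_derivative(backend, sin, 0.0)
y = ld * 2.0  # scales 2.0 by the derivative cos(0.0) == 1.0
```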
@@ -522,9 +522,9 @@ end

"""
AD.lazy_gradient(ab::AbstractBackend, f, xs...)

Return an operator `lg` for multiplying by the gradient of `f` at `xs`.

You can apply the operator by multiplication e.g. `lg * y` where `y` is a number if `f` has a single input or a tuple of the same length as `xs` if `f` has multiple inputs.
"""
function lazy_gradient(ab::AbstractBackend, f, xs...)
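Sketch for a single vector input (assumed backend):

```julia
import AbstractDifferentiation as AD
using ForwardDiff

backend = AD.ForwardDiffBackend()
lg = AD.lazy_gradient(backend, x -> sum(abs2, x), [1.0, 2.0])
y = lg * 3.0  # scales the gradient [2.0, 4.0] by the number 3.0
```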
@@ -533,9 +533,9 @@ end

"""
AD.lazy_hessian(ab::AbstractBackend, f, x)

Return an operator `lh` for multiplying by the Hessian of the scalar-valued function `f` at `x`.

You can apply the operator by multiplication e.g. `lh * y` or `y' * lh` where `y` is a number or a vector of the appropriate length.
"""
function lazy_hessian(ab::AbstractBackend, f, xs...)
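Sketch of both application sides (assumed backend; the function must be scalar-valued):

```julia
import AbstractDifferentiation as AD
using ForwardDiff

backend = AD.ForwardDiffBackend()
lh = AD.lazy_hessian(backend, x -> sum(abs2, x), [1.0, 2.0])
hv = lh * [1.0, 0.0]   # Hessian-vector product; here 2I * [1, 0] == [2, 0]
vh = [1.0, 0.0]' * lh  # vector-Hessian product from the left
```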
@@ -544,10 +544,10 @@ end

"""
AD.lazy_jacobian(ab::AbstractBackend, f, xs...)

Return an operator `lj` for multiplying by the Jacobian of `f` at `xs`.

You can apply the operator by multiplication e.g. `lj * y` or `y' * lj` where `y` is a number, vector or tuple of numbers and/or vectors.
If `f` has multiple inputs, `y` in `lj * y` should be a tuple.
If `f` has multiple outputs, `y` in `y' * lj` should be a tuple.
Otherwise, it should be a scalar or a vector of the appropriate length.
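Sketch for a single-input, vector-valued function (assumed backend); mathematically, `lj * v` is a Jacobian-vector product and `w' * lj` a vector-Jacobian product:

```julia
import AbstractDifferentiation as AD
using ForwardDiff

backend = AD.ForwardDiffBackend()
lj = AD.lazy_jacobian(backend, x -> [sum(x), prod(x)], [2.0, 3.0])
jv = lj * [1.0, 0.0]    # J * v; here [1.0, 3.0]
vj = [1.0, 0.0]' * lj   # w' * J; here [1.0 1.0]
```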
@@ -614,7 +614,7 @@ function define_pushforward_function_and_friends(fdef)
elseif eltype(identity_like) <: AbstractMatrix
# needed for the computation of the Hessian and Jacobian
ret = hcat.(mapslices(identity_like[1]; dims=1) do cols
# cols loop over basis states
pf = pff((cols,))
if typeof(pf) <: AbstractVector
# to make the hcat. work / get correct matrix-like, non-flat output dimension
@@ -650,7 +650,7 @@ function define_value_and_pullback_function_and_friends(fdef)
elseif eltype(identity_like) <: AbstractMatrix
# needed for Hessian computation:
# value is a (grad,). Then, identity_like is a (matrix,).
# cols loops over columns of the matrix
return vcat.(mapslices(identity_like[1]; dims=1) do cols
adjoint.(pbf((cols,)))
end...)
20 changes: 11 additions & 9 deletions test/test_utils.jl
@@ -145,7 +145,7 @@

function test_hessians(backend; multiple_inputs=false, test_types=true)
if multiple_inputs
# ... but
error("multiple_inputs=true is not supported.")
else
# explicit test that AbstractDifferentiation throws an error
@@ -207,11 +207,15 @@

pf1 = map(v -> AD.pushforward_function(backend, fjac, xvec, yvec)(v), vaug)
((valvec1, pf2x), (valvec2, pf2y)) = map(
- v -> AD.value_and_pushforward_function(backend, fjac, xvec, yvec)(v), vaug
+ v -> begin
+     vl, jvf_1 = AD.value_and_pushforward_function(backend, fjac, xvec, yvec)
+     vl, jvf_1(v)
+ end, vaug
)
else
pf1 = AD.pushforward_function(backend, fjac, xvec, yvec)(v)
- valvec, pf2 = AD.value_and_pushforward_function(backend, fjac, xvec, yvec)(v)
+ valvec, jvf_2 = AD.value_and_pushforward_function(backend, fjac, xvec, yvec)
+ pf2 = jvf_2(v)
((valvec1, pf2x), (valvec2, pf2y)) = (valvec, pf2[1]), (valvec, pf2[2])
end

@@ -234,12 +238,10 @@
@test yvec == yvec2
end

- valvec1, pf1 = AD.value_and_pushforward_function(backend, x -> fjac(x, yvec), xvec)(
-     v[1]
- )
- valvec2, pf2 = AD.value_and_pushforward_function(backend, y -> fjac(xvec, y), yvec)(
-     v[2]
- )
+ valvec1, jvf1 = AD.value_and_pushforward_function(backend, x -> fjac(x, yvec), xvec)
+ pf1 = jvf1(v[1])
+ valvec2, jvf2 = AD.value_and_pushforward_function(backend, y -> fjac(xvec, y), yvec)
+ pf2 = jvf2(v[2])

if test_types
@test valvec1 isa Vector{Float64}