Skip to content

Commit

Permalink
Rework some inplace stuff. (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
pkofod authored Jun 11, 2018
1 parent e5c314f commit d714ee0
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 52 deletions.
1 change: 1 addition & 0 deletions src/NLSolversBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ function x_of_nans(x)
x_out
end

include("objective_types/inplace_factory.jl")
include("objective_types/abstract.jl")
include("objective_types/nondifferentiable.jl")
include("objective_types/oncedifferentiable.jl")
Expand Down
70 changes: 70 additions & 0 deletions src/objective_types/inplace_factory.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""
f!_from_f(f, F::Abstractarray)
Return an inplace version of f
"""
function f!_from_f(f, F::AbstractArray, inplace)
if inplace
return f
else
return function ff!(F, x)
copyto!(F, f(x))
F
end
end
end
function df!_from_df(g, F::Real, inplace)
if inplace
return g
else
return function gg!(G, x)
gx = g(x)
copyto!(G, gx)
G
end
end
end
function df!_from_df(j, F::AbstractArray, inplace)
if inplace
return j
else
return function jj!(J, x)
jx = j(x)
copyto!(J, jx)
J
end
end
end
function fdf!_from_fdf(fg, F::Real, inplace)
if inplace
return fg
else
return function ffgg!(G, x)
f, g = fg(x)
copyto!(G, g)
f
end
end
end
function fdf!_from_fdf(fj, F::AbstractArray, inplace)
if inplace
return fj
else
return function ffjj!(F, J, x)
f, j = fj(x)
copyto!(J, j)
copyto!(F, f)
end
end
end
function h!_from_h(h, F::Real, inplace)
if inplace
return h
else
return function hh!(H, x)
h = h(x)
copyto!(H, h)
H
end
end
end
54 changes: 10 additions & 44 deletions src/objective_types/oncedifferentiable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,41 +11,6 @@ mutable struct OnceDifferentiable{TF, TDF, TX} <: AbstractObjective
df_calls::Vector{Int}
end

# This should be refactored to be reused in incomplete.jl
function f!_from_f(f, x, F::AbstractArray)
return function ff!(F, x)
copyto!(F, f(x))
F
end
end
function df!_from_df(g, x, F::Real)
return function gg!(G, x)
gx = g(x)
copyto!(G, gx)
G
end
end
function df!_from_df(j, x, F::AbstractArray)
return function jj!(J, x)
jx = j(x)
copyto!(J, jx)
J
end
end
function fdf!_from_fdf(fg, x, F::Real)
return function ffgg!(G, x)
f, g = fg(x)
copyto!(G, g)
f
end
end
function fdf!_from_fdf(fj, x, F::AbstractArray)
return function ffjj!(F, J, x)
f, j = fj(x)
copyto!(J, j)
copyto!(F, f)
end
end
### Only the objective
# Ambiguity
OnceDifferentiable(f, x::AbstractArray,
Expand All @@ -57,7 +22,7 @@ OnceDifferentiable(f, x::AbstractArray,
function OnceDifferentiable(f, x::AbstractArray,
F::AbstractArray, DF::AbstractArray = alloc_DF(x, F);
inplace = true, autodiff = :finite)
f! = inplace ? f : f!_from_f(f, x, F)
f! = f!_from_f(f, F, inplace)

OnceDifferentiable(f!, x::AbstractArray, F::AbstractArray, DF, autodiff)
end
Expand Down Expand Up @@ -163,7 +128,8 @@ function OnceDifferentiable(f, df,
inplace = true)


df! = inplace ? df : df!_from_df(df, x, F)
df! = df!_from_df(df, F, inplace)

fdf! = make_fdf(x, F, f, df!)

OnceDifferentiable(f, df!, fdf!, x, F, DF)
Expand All @@ -175,8 +141,8 @@ function OnceDifferentiable(f, j,
J::AbstractArray = alloc_DF(x, F);
inplace = true)

f! = inplace ? f : f!_from_f(f, x, F)
j! = inplace ? j : df!_from_df(j, x, F)
f! = f!_from_f(f, F, inplace)
j! = df!_from_df(j, F, inplace)
fj! = make_fdf(x, F, f!, j!)

OnceDifferentiable(f!, j!, fj!, x, F, J)
Expand All @@ -191,8 +157,8 @@ function OnceDifferentiable(f, df, fdf,
inplace = true)

# f is never "inplace" since F is scalar
df! = inplace ? df : df!_from_df(df, x, F)
fdf! = inplace ? fdf : fdf!_from_fdf(fdf, x, F)
df! = df!_from_df(df, F, inplace)
fdf! = fdf!_from_fdf(fdf, F, inplace)

x_f, x_df = x_of_nans(x), x_of_nans(x)

Expand All @@ -208,9 +174,9 @@ function OnceDifferentiable(f, df, fdf,
DF::AbstractArray = alloc_DF(x, F);
inplace = true)

f = inplace ? f : f!_from_f(f, x, F)
df! = inplace ? df : df!_from_df(df, x, F)
fdf! = inplace ? fdf : fdf!_from_fdf(fdf, x, F)
f = f!_from_f(f, F, inplace)
df! = df!_from_df(df, F, inplace)
fdf! = fdf!_from_fdf(fdf, F, inplace)

x_f, x_df = x_of_nans(x), x_of_nans(x)

Expand Down
33 changes: 25 additions & 8 deletions src/objective_types/twicedifferentiable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,44 @@ mutable struct TwiceDifferentiable{T,TDF,TH,TX} <: AbstractObjective
h_calls::Vector{Int}
end
# compatibility with old constructor
function TwiceDifferentiable(f, g!, fg!, h!, x::TX, F::T = real(zero(eltype(x))), G::TG = similar(x), H::TH = alloc_H(x)) where {T, TG, TH, TX}
function TwiceDifferentiable(f, g, fg, h, x::TX, F::T = real(zero(eltype(x))), G::TG = similar(x), H::TH = alloc_H(x); inplace = true) where {T, TG, TH, TX}
x_f, x_df, x_h = x_of_nans(x), x_of_nans(x), x_of_nans(x)

g! = df!_from_df(g, F, inplace)
fg! = fdf!_from_fdf(fg, F, inplace)
h! = h!_from_h(h, F, inplace)

TwiceDifferentiable{T,TG,TH,TX}(f, g!, fg!, h!,
copy(F), similar(G), copy(H),
x_f, x_df, x_h,
[0,], [0,], [0,])
end

function TwiceDifferentiable(f, g!, h!, x::AbstractVector{TX}, F::Real = real(zero(eltype(x))), G = similar(x), H = alloc_H(x)) where {TX}
function TwiceDifferentiable(f, g, h,
x::AbstractVector{TX},
F::Real = real(zero(eltype(x))),
G = similar(x),
H = alloc_H(x); inplace = true) where {TX}

g! = df!_from_df(g, F, inplace)
h! = h!_from_h(h, F, inplace)

fg! = make_fdf(x, F, f, g!)

return TwiceDifferentiable(f, g!, fg!, h!, x, F, G, H)
end



function TwiceDifferentiable(f, g!, x_seed::AbstractVector{T}, F::Real = real(zero(T)); autodiff = :finite) where T
function TwiceDifferentiable(f, g,
x_seed::AbstractVector{T},
F::Real = real(zero(T)); autodiff = :finite, inplace = true) where T
n_x = length(x_seed)
function fg!(storage, x)
g!(storage, x)
return f(x)
end

g! = df!_from_df(g, F, inplace)

fg! = make_fdf(x_seed, F, f, g!)

if autodiff == :finite
# TODO: Create / request Hessian functionality in DiffEqDiffTools?
# (Or is it better to use the finite difference Jacobian of a gradient?)
Expand Down Expand Up @@ -78,7 +95,7 @@ function TwiceDifferentiable(d::OnceDifferentiable, x_seed::AbstractVector{T} =
end

function TwiceDifferentiable(f, x::AbstractVector, F::Real = real(zero(eltype(x)));
autodiff = :finite)
autodiff = :finite, inplace = true)
if autodiff == :finite
# TODO: Allow user to specify Val{:central}, Val{:forward}, Val{:complex}
gcache = DiffEqDiffTools.GradientCache(x, x, Val{:central})
Expand Down
24 changes: 24 additions & 0 deletions test/kwargs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,30 @@
vg4 = value_gradient!(fi3, xr)
@test vg1[1] vg2[1] vg3[1] vg4[1]
@test vg1[2] vg2[2] vg3[2] vg4[2]

ft1 = TwiceDifferentiable(exponential, rand(2), 0.0)
ftia1 = TwiceDifferentiable(exponential, rand(2); inplace = false)
fti1 = TwiceDifferentiable(exponential, exponential_gradient, rand(2); inplace = false)
fti2 = TwiceDifferentiable(exponential, exponential_gradient, rand(2); inplace = false)
fti3 = TwiceDifferentiable(exponential, exponential_gradient, exponential_hessian,
rand(2); inplace = false)
fti4 = TwiceDifferentiable(exponential, exponential_gradient, exponential_value_gradient,
exponential_hessian, rand(2); inplace = false)

@test value!(ft1, xr) value!(ftia1, xr) value!(fti1, xr) value!(fti2, xr) value!(fti3, xr) value!(fti4, xr)
@test gradient!(ft1, xr) gradient!(ftia1, xr)
@test gradient!(ftia1, xr) gradient!(fti1, xr) gradient!(fti2, xr) gradient!(fti3, xr) gradient!(fti4, xr)
vg1 = value_gradient!(ftia1, xr)
vg2 = value_gradient!(fti1, xr)
vg3 = value_gradient!(fti2, xr)
vg4 = value_gradient!(fti3, xr)
vg5 = value_gradient!(fti4, xr)
@test vg1[1] vg2[1] vg3[1] vg4[1] vg5[1]
@test vg1[2] vg2[2] vg3[2] vg4[2] vg5[2]
@test hessian!(ft1, xr) hessian!(ftia1, xr)
@test hessian!(fti1, xr) hessian!(fti2, xr)
@test hessian!(fti3, xr) hessian!(fti4, xr)

# R^N → R^N
f1 = OnceDifferentiable(exponential_gradient!, rand(2), rand(2))
fia1 = OnceDifferentiable(exponential_gradient, rand(2), rand(2); inplace = false)
Expand Down

0 comments on commit d714ee0

Please sign in to comment.