Skip to content

Commit

Permalink
Improve planar layer (#158)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Jan 9, 2021
1 parent 06fa669 commit ca4069a
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 151 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.8.11"
version = "0.8.12"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
1 change: 1 addition & 0 deletions src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ function getlogp(d::InverseWishart, Xcf, X)
return -0.5 * ((d.df + dim(d) + 1) * logdet(Xcf) + tr(Xcf \ Ψ)) + d.logc0
end

include("utils.jl")
include("interface.jl")

# Broadcasting here breaks Tracker for some reason
Expand Down
134 changes: 84 additions & 50 deletions src/bijectors/planar_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ using NNlib: softplus

# TODO: add docstring

mutable struct PlanarLayer{T1<:AbstractVector{<:Real}, T2<:Union{Real, AbstractVector{<:Real}}} <: Bijector{1}
struct PlanarLayer{T1<:AbstractVector{<:Real}, T2<:Union{Real, AbstractVector{<:Real}}} <: Bijector{1}
w::T1
u::T1
b::T2
Expand All @@ -23,74 +23,103 @@ function Base.:(==)(b1::PlanarLayer, b2::PlanarLayer)
return b1.w == b2.w && b1.u == b2.u && b1.b == b2.b
end

function get_u_hat(u, w)
# To preserve invertibility
x = w' * u
return u .+ (planar_flow_m(x) - x) .* w ./ sum(abs2, w) # from A.1
end

function PlanarLayer(dims::Int, wrapper=identity)
w = wrapper(randn(dims))
u = wrapper(randn(dims))
b = wrapper(randn(1))
return PlanarLayer(w, u, b)
end

planar_flow_m(x) = -1 + softplus(x) # for planar flow from A.1
ψ(z, w, b) = (1 .- tanh.(w' * z .+ b).^2) .* w # for planar flow from eq(11)
"""
get_u_hat(u::AbstractVector{<:Real}, w::AbstractVector{<:Real})
# An internal version of transform that returns intermediate variables
function _transform(flow::PlanarLayer, z::AbstractVecOrMat)
return _planar_transform(flow.u, flow.w, first(flow.b), z)
end
function _planar_transform(u, w, b, z)
u_hat = get_u_hat(u, w)
transformed = z .+ u_hat .* tanh.(w' * z .+ b) # from eq(10)
return (transformed = transformed, u_hat = u_hat)
end
Return a tuple of vector ``û`` that guarantees invertibility of the planar layer, and
scalar ``wᵀ û``.
(b::PlanarLayer)(z) = _transform(b, z).transformed
# Mathematical background
function forward(flow::PlanarLayer, z::AbstractVecOrMat)
transformed, u_hat = _transform(flow, z)
# Compute log_det_jacobian
psi = ψ(z, flow.w, first(flow.b)) .+ zero(eltype(u_hat))
if psi isa AbstractVector
T = eltype(psi)
else
T = typeof(vec(psi))
end
log_det_jacobian::T = log.(abs.(1 .+ psi' * u_hat)) # from eq(12)
return (rv = transformed, logabsdetjac = log_det_jacobian)
According to appendix A.1, vector ``û`` defined by
```math
û(w, u) = u + (\\log(1 + \\exp{(wᵀu)}) - 1 - wᵀu) \\frac{w}{\\|w\\|²}
```
guarantees that the planar layer ``f(z) = z + û tanh(wᵀz + b)`` is invertible for all ``w, u ∈ ℝᵈ`` and ``b ∈ ℝ``.
We can rewrite ``û`` as
```math
û = u + (\\log(1 + \\exp{(-wᵀu)}) - 1) \\frac{w}{\\|w\\|²}.
```
Additionally, we obtain
```math
wᵀû = wᵀu + \\log(1 + \\exp{(-wᵀu)}) - 1 = \\log(1 + \\exp{(wᵀu)}) - 1.
```
# References
D. Rezende, S. Mohamed (2015): Variational Inference with Normalizing Flows.
arXiv:1505.05770
"""
function get_u_hat(u::AbstractVector{<:Real}, w::AbstractVector{<:Real})
wT_u = dot(w, u)
= u .+ ((softplus(-wT_u) - 1) / sum(abs2, w)) .* w
wT_û = softplus(wT_u) - 1
return û, wT_û
end

function (ib::Inverse{<:PlanarLayer})(y::AbstractVector{<:Real})
flow = ib.orig
# An internal version of the transform in eq. (10) that returns intermediate variables
function _transform(flow::PlanarLayer, z::AbstractVecOrMat{<:Real})
w = flow.w
b = first(flow.b)
u_hat = get_u_hat(flow.u, w)
û, wT_û = get_u_hat(flow.u, w)
wT_z = aT_b(w, z)
transformed = z .+.* tanh.(wT_z .+ b)
return (transformed = transformed, wT_û = wT_û, wT_z = wT_z)
end

(b::PlanarLayer)(z) = _transform(b, z).transformed

#=
Log-determinant of the Jacobian of the planar layer
The log-determinant of the Jacobian of the planar layer ``f(z) = z + û tanh(wᵀz + b)``
is given by
```math
\\log |det ∂f(z)/∂z| = \\log |1 + ûᵀsech²(wᵀz + b)w| = \\log |1 + sech²(wᵀz + b) wᵀû|.
```
Since ``0 < sech²(x) ≤ 1`` and
```math
wᵀû = wᵀu + \\log(1 + \\exp{(-wᵀu)}) - 1 = \\log(1 + \\exp{(wᵀu)}) - 1 > -1,
```
we get
```math
\\log |det ∂f(z)/∂z| = \\log(1 + sech²(wᵀz + b) wᵀû).
```
=#
function forward(flow::PlanarLayer, z::AbstractVecOrMat{<:Real})
transformed, wT_û, wT_z = _transform(flow, z)

# Find the scalar ``alpha`` from A.1.
wt_y = dot(w, y)
wt_u_hat = dot(w, u_hat)
alpha = find_alpha(wt_y, wt_u_hat, b)
# Compute ``\\log |det ∂f(z)/∂z|`` (see above).
b = first(flow.b)
log_det_jacobian = log1p.(wT_û .* abs2.(sech.(_vec(wT_z) .+ b)))

return y .- u_hat .* tanh(alpha + b)
return (rv = transformed, logabsdetjac = log_det_jacobian)
end

function (ib::Inverse{<:PlanarLayer})(y::AbstractMatrix{<:Real})
function (ib::Inverse{<:PlanarLayer})(y::AbstractVecOrMat{<:Real})
flow = ib.orig
w = flow.w
b = first(flow.b)
u_hat = get_u_hat(flow.u, flow.w)
û, wT_û = get_u_hat(flow.u, w)

# Find the scalar ``alpha`` from A.1 for each column.
wt_u_hat = dot(w, u_hat)
alphas = mapvcat(eachcol(y)) do c
find_alpha(dot(w, c), wt_u_hat, b)
end
# Find the scalar ``α`` by solving ``wᵀy = α + wᵀû tanh(α + b)``
# (eq. (23) from appendix A.1).
wT_y = aT_b(w, y)
α = find_alpha.(wT_y, wT_û, b)

# Compute ``z = y - û tanh(α + b)``.
z = y .-.* tanh.(α .+ b)

return y .- u_hat .* tanh.(reshape(alphas, 1, :) .+ b)
return z
end

"""
Expand Down Expand Up @@ -121,16 +150,21 @@ which implies ``α̂ ∈ [wt_y - |wt_u_hat|, wt_y + |wt_u_hat|]``.
D. Rezende, S. Mohamed (2015): Variational Inference with Normalizing Flows.
arXiv:1505.05770
"""
function find_alpha(wt_y, wt_u_hat, b)
# Compute the initial bracket.
function find_alpha(wt_y::Real, wt_u_hat::Real, b::Real)
# Compute the initial bracket
_wt_y, _wt_u_hat, _b = promote(wt_y, wt_u_hat, b)
initial_bracket = (_wt_y - abs(_wt_u_hat), _wt_y + abs(_wt_u_hat))

# Try to solve the root-finding problem, i.e., compute a final bracket
prob = NonlinearSolve.NonlinearProblem{false}(initial_bracket) do α, _
α + _wt_u_hat * tanh+ _b) - _wt_y
end
alpha = NonlinearSolve.solve(prob, NonlinearSolve.Falsi()).left
return alpha
sol = NonlinearSolve.solve(prob, NonlinearSolve.Falsi())
if sol.retcode === NonlinearSolve.MAXITERS_EXCEED
@warn "Planar layer: root finding algorithm did not converge" sol
end

return sol.left
end

logabsdetjac(flow::PlanarLayer, x) = forward(flow, x).logabsdetjac
Expand Down
96 changes: 0 additions & 96 deletions src/compat/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,102 +225,6 @@ end
end
end

for header in [
(:(u::TrackedArray), :w),
(:u, :(w::TrackedArray)),
(:(u::TrackedArray), :(w::TrackedArray)),
]
@eval begin
function get_u_hat($(header...))
if u isa TrackedArray
T = typeof(u)
else
T = typeof(w)
end
x = w' * u
return (u .+ (planar_flow_m(x) - x) .* w ./ sum(abs2, w))::T
end
end
end

for header in [
(:(z::TrackedArray), :w, :b),
(:z, :(w::TrackedArray), :b),
(:z, :w, :(b::TrackedReal)),
(:(z::TrackedArray), :(w::TrackedArray), :b),
(:(z::TrackedArray), :w, :(b::TrackedReal)),
(:z, :(w::TrackedArray), :(b::TrackedReal)),
(:(z::TrackedArray), :(w::TrackedArray), :(b::TrackedReal)),
]
@eval begin
function ψ($(header...))
if z isa AbstractMatrix
if z isa TrackedMatrix
T = typeof(z)
elseif w isa TrackedVector
T = matrixof(typeof(w))
else
T = matrixof(typeof(b))
end
else
if z isa TrackedVector
T = typeof(z)
elseif w isa TrackedVector
T = typeof(w)
else
T = vectorof(typeof(b))
end
end
return ((1 .- tanh.(w' * z .+ b).^2) .* w)::T # for planar flow from eq(11)
end
end
end

for header in [
(:(u::TrackedArray), :w, :b, :(z::AbstractVecOrMat)),
(:u, :(w::TrackedArray), :b, :(z::AbstractVecOrMat)),
(:u, :w, :(b::TrackedReal), :(z::AbstractVecOrMat)),
(:u, :w, :b, :(z::TrackedVecOrMat)),
(:(u::TrackedArray), :(w::TrackedArray), :b, :(z::AbstractVecOrMat)),
(:(u::TrackedArray), :w, :(b::TrackedReal), :(z::AbstractVecOrMat)),
(:(u::TrackedArray), :w, :b, :(z::TrackedVecOrMat)),
(:u, :(w::TrackedArray), :(b::TrackedReal), :(z::AbstractVecOrMat)),
(:u, :(w::TrackedArray), :b, :(z::TrackedVecOrMat)),
(:u, :w, :(b::TrackedArray), :(z::TrackedVecOrMat)),
(:(u::TrackedArray), :(w::TrackedArray), :(b::TrackedReal), :(z::AbstractVecOrMat)),
(:(u::TrackedArray), :(w::TrackedArray), :b, :(z::TrackedVecOrMat)),
(:(u::TrackedArray), :w, :(b::TrackedReal), :(z::TrackedVecOrMat)),
(:u, :(w::TrackedArray), :(b::TrackedReal), :(z::TrackedVecOrMat)),
(:(u::TrackedArray), :(w::TrackedArray), :(b::TrackedReal), :(z::TrackedVecOrMat)),
]
@eval begin
function _planar_transform($(header...))
u_hat = get_u_hat(u, w)
if z isa AbstractVector
temp = w' * z + b + zero(eltype(u_hat))
if z isa TrackedVector
T = typeof(z)
elseif u_hat isa TrackedVector
T = typeof(u_hat)
else
T = vectorof(typeof(temp))
end
else
temp = w' * z .+ (b + zero(eltype(u_hat)))
if z isa TrackedMatrix
T = typeof(z)
elseif u_hat isa TrackedVector
T = matrixof(typeof(u_hat))
else
T = matrixof(typeof(temp'))
end
end
transformed::T = z .+ u_hat .* tanh.(temp) # from eq(10)
return (transformed = transformed, u_hat = u_hat)
end
end
end

for header in [
(:(α_::TrackedReal), , :z_0, :(z::AbstractVector)),
(:α_, :(β::TrackedReal), :z_0, :(z::AbstractVector)),
Expand Down
8 changes: 8 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# `permutedims` seems to work better with AD (cf. KernelFunctions.jl)
aT_b(a::AbstractVector{<:Real}, b::AbstractMatrix{<:Real}) = permutedims(a) * b
# `permutedims` can't be used here since scalar output is desired
aT_b(a::AbstractVector{<:Real}, b::AbstractVector{<:Real}) = dot(a, b)

# flatten arrays with fallback for scalars
_vec(x::AbstractArray{<:Real}) = vec(x)
_vec(x::Real) = x
6 changes: 2 additions & 4 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,10 @@ end
# FIXME: `SimplexBijector` results in ∞ gradient if not in the domain
if !contains(t -> t isa SimplexBijector, b)
b_logjac_ad = [logabsdet(ForwardDiff.jacobian(b, xs[:, i]))[1] for i = 1:size(xs, 2)]
tol = isclosedform(b) ? 1e-9 : 1e-1
@test logabsdetjac(b, xs) b_logjac_ad rtol=tol atol=tol
@test logabsdetjac(b, xs) b_logjac_ad atol=1e-9

ib_logjac_ad = [logabsdet(ForwardDiff.jacobian(ib, ys[:, i]))[1] for i = 1:size(ys, 2)]
tol = isclosedform(ib) ? 1e-9 : 1e-1
@test logabsdetjac(ib, ys) ib_logjac_ad rtol=tol atol=tol
@test logabsdetjac(ib, ys) ib_logjac_ad atol=1e-9
end
else
error("tests not implemented yet")
Expand Down

2 comments on commit ca4069a

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/27635

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.8.12 -m "<description of version>" ca4069a6bbc04646df350f5ad604350203fb815f
git push origin v0.8.12

Please sign in to comment.