Skip to content

Commit

Permalink
Simplify inverses of planar and radial layer (#126)
Browse files Browse the repository at this point in the history
Co-authored-by: Tor Erlend Fjelde <[email protected]>
  • Loading branch information
devmotion and torfjelde authored Aug 18, 2020
1 parent 5fd780c commit 2552cdf
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 63 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ForwardDiff_Tracker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
version:
- '1.0'
- '1.3'
- '1'
os:
- ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ReverseDiff.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
version:
- '1.0'
- '1.3'
- '1'
os:
- ubuntu-latest
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/Zygote.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ jobs:
strategy:
matrix:
version:
- '1.3'
- '1'
os:
- ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ os:
- linux
- osx
julia:
- 1.0
- 1.3
- 1
- nightly
matrix:
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
[compat]
ArgCheck = "1, 2"
Compat = "3"
Distributions = "0.23.3"
Distributions = "=0.23.3, =0.23.4, =0.23.5, =0.23.6, =0.23.7, =0.23.8"
MappedArrays = "0.2.2"
NNlib = "0.6, 0.7"
Reexport = "0.2"
Requires = "0.5, 1"
Roots = "0.8.4, 1"
StatsFuns = "0.8, 0.9.3"
julia = "1"
julia = "1.3"

[extras]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Expand Down
83 changes: 52 additions & 31 deletions src/bijectors/planar_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,44 +65,65 @@ function forward(flow::PlanarLayer, z::AbstractVecOrMat)
return (rv = transformed, logabsdetjac = log_det_jacobian)
end

function (ib::Inverse{<: PlanarLayer})(y::AbstractVector{<:Real})
function (ib::Inverse{<:PlanarLayer})(y::AbstractVector{<:Real})
flow = ib.orig
u_hat = get_u_hat(flow.u, flow.w)
T = promote_type(eltype(flow.u), eltype(flow.w), eltype(flow.b), eltype(y))
TV = vectorof(T)
# Define the objective functional; implemented with reference from A.1
f(y) = alpha -> (flow.w' * y) - alpha - (flow.w' * u_hat) * tanh(alpha + first(flow.b))
# Run solver
alpha::T = find_zero(f(y), zero(T), Order16())
z_para::TV = (flow.w ./ norm(flow.w, 2)) .* alpha
return (y .- u_hat .* tanh.(flow.w' * z_para .+ first(flow.b)))::TV
w = flow.w
b = first(flow.b)
u_hat = get_u_hat(flow.u, w)

# Find the scalar ``alpha`` from A.1.
wt_y = dot(w, y)
wt_u_hat = dot(w, u_hat)
alpha = find_alpha(y, wt_y, wt_u_hat, b)

return y .- u_hat .* tanh(alpha * norm(w, 2) + b)
end
function (ib::Inverse{<: PlanarLayer})(y::AbstractMatrix{<:Real})

function (ib::Inverse{<:PlanarLayer})(y::AbstractMatrix{<:Real})
flow = ib.orig
w = flow.w
b = first(flow.b)
u_hat = get_u_hat(flow.u, flow.w)
T = promote_type(eltype(flow.u), eltype(flow.w), eltype(flow.b), eltype(y))
TM = matrixof(T)
# Define the objective functional; implemented with reference from A.1
f(y) = alpha -> (flow.w' * y) - alpha - (flow.w' * u_hat) * tanh(alpha + first(flow.b))
# Run solver
alpha = mapvcat(eachcol(y)) do c
find_zero(f(c), zero(T), Order16())

# Find the scalar ``alpha`` from A.1 for each column.
wt_u_hat = dot(w, u_hat)
alphas = mapvcat(eachcol(y)) do c
find_alpha(c, dot(w, c), wt_u_hat, b)
end
z_para::TM = (flow.w ./ norm(flow.w, 2)) .* alpha'
return (y .- u_hat .* tanh.(flow.w' * z_para .+ first(flow.b)))::TM
end

function matrixof(::Type{Vector{T}}) where {T <: Real}
return Matrix{T}
end
function matrixof(::Type{T}) where {T <: Real}
return Matrix{T}
return y .- u_hat .* tanh.(alphas' .* norm(w, 2) .+ b)
end
function vectorof(::Type{Matrix{T}}) where {T <: Real}
return Vector{T}
end
function vectorof(::Type{T}) where {T <: Real}
return Vector{T}

"""
find_alpha(y::AbstractVector{<:Real}, wt_y, wt_u_hat, b)
Compute an (approximate) real-valued solution ``α`` to the equation
```math
wt_y = α + wt_u_hat tanh(α + b)
```
The uniqueness of the solution is guaranteed since ``wt_u_hat ≥ -1``.
For details see appendix A.1 of the reference.
# References
D. Rezende, S. Mohamed (2015): Variational Inference with Normalizing Flows.
arXiv:1505.05770
"""
function find_alpha(y::AbstractVector{<:Real}, wt_y, wt_u_hat, b)
    # Residual of `α + wt_u_hat * tanh(α + b) - wt_y` evaluated at `α = 0`.
    # Its sign decides which half-line is handed to the root finder.
    # NOTE(review): `y` is accepted for interface compatibility but is not read here.
    residual_at_zero = wt_u_hat * tanh(b) - wt_y
    origin = zero(residual_at_zero)

    # Half-infinite bracket with both endpoints of the same type as the residual:
    # `(0, Inf)` when the residual at zero is negative, `(-Inf, 0)` otherwise.
    bracket = if residual_at_zero < origin
        (origin, oftype(residual_at_zero, Inf))
    else
        (oftype(residual_at_zero, -Inf), origin)
    end

    # Objective whose root is the sought `α`.
    residual(α) = α + wt_u_hat * tanh(α + b) - wt_y

    return find_zero(residual, bracket)
end

logabsdetjac(flow::PlanarLayer, x) = forward(flow, x).logabsdetjac
Expand Down
62 changes: 45 additions & 17 deletions src/bijectors/radial_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,29 +70,57 @@ end

function (ib::Inverse{<:RadialLayer})(y::AbstractVector{<:Real})
flow = ib.orig
T = promote_type(eltype(flow.α_), eltype(flow.β), eltype(flow.z_0), eltype(y))
TV = vectorof(T)
z0 = flow.z_0
α = softplus(first(flow.α_)) # from A.2
β_hat = - α + softplus(first(flow.β)) # from A.2
# Define the objective functional
f(y) = r -> norm(y .- flow.z_0) - r * (1 + β_hat / (α + r)) # from eq(26)
# Run solver
rs::T = find_zero(f(y), zero(T), Order16())
return (flow.z_0 .+ (y .- flow.z_0) ./ (1 .+ β_hat ./ (α .+ rs)))::TV
α_plus_β_hat = softplus(first(flow.β)) # from A.2

# Compute the norm ``r`` from A.2.
y_minus_z0 = y .- z0
r = compute_r(y_minus_z0, α, α_plus_β_hat)

return z0 .+ ((α + r) / (α_plus_β_hat + r)) .* y_minus_z0
end

function (ib::Inverse{<:RadialLayer})(y::AbstractMatrix{<:Real})
flow = ib.orig
T = promote_type(eltype(flow.α_), eltype(flow.β), eltype(flow.z_0), eltype(y))
TM = matrixof(T)
z0 = flow.z_0
α = softplus(first(flow.α_)) # from A.2
β_hat = - α + softplus(first(flow.β)) # from A.2
# Define the objective functional
f(y) = r -> norm(y .- flow.z_0) - r * (1 + β_hat / (α + r)) # from eq(26)
# Run solver
rs = mapvcat(eachcol(y)) do c
find_zero(f(c), zero(T), Order16())
α_plus_β_hat = softplus(first(flow.β)) # from A.2

# Compute the norm ``r`` from A.2 for each column.
y_minus_z0 = y .- z0
rs = mapvcat(eachcol(y_minus_z0)) do c
return compute_r(c, α, α_plus_β_hat)
end
return (flow.z_0 .+ (y .- flow.z_0) ./ (1 .+ β_hat ./ (α .+ rs')))::TM

return z0 .+ ((α .+ rs) ./ (α_plus_β_hat .+ rs))' .* y_minus_z0
end

"""
compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat)
Compute the unique solution ``r`` to the equation
```math
\\|y_minus_z0\\|_2 = r \\left(1 + \\frac{α_plus_β_hat - α}{α + r}\\right)
```
subject to ``r ≥ 0`` and ``r ≠ α``.
Since ``α > 0`` and ``α_plus_β_hat > 0``, the solution is unique and given by
```math
r = (\\sqrt{(α_plus_β_hat - γ)^2 + 4 α γ} - (α_plus_β_hat - γ)) / 2,
```
where ``γ = \\|y_minus_z0\\|_2``. For details see appendix A.2 of the reference.
# References
D. Rezende, S. Mohamed (2015): Variational Inference with Normalizing Flows.
arXiv:1505.05770
"""
function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat)
    # γ is the Euclidean norm of the shifted input; `r` is the non-negative root of
    # the quadratic `r^2 + a * r - α * γ = 0` with `a = α_plus_β_hat - γ`.
    γ = norm(y_minus_z0)
    a = α_plus_β_hat - γ
    s = sqrt(a^2 + 4 * α * γ)

    # `(s - a) / 2` and `2αγ / (s + a)` are algebraically identical because
    # `(s - a) * (s + a) = 4αγ`. For `a > 0` the first form subtracts two nearly
    # equal numbers whenever `4αγ ≪ a^2` and loses all significant digits, so the
    # conjugate form is used there. For `a ≤ 0` the subtraction is a sum of
    # non-negative terms and the original form is kept unchanged.
    r = a > zero(a) ? (2 * α * γ) / (s + a) : (s - a) / 2
    return r
end

logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat) = forward(flow, x).logabsdetjac
Expand Down
16 changes: 8 additions & 8 deletions src/compat/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -400,17 +400,17 @@ eachcolnorm(X::TrackedMatrix) = track(eachcolnorm, X)
end
end

function matrixof(::Type{TrackedArray{T, 1, Vector{T}}}) where {T <: Real}
return TrackedArray{T, 2, Matrix{T}}
function matrixof(::Type{<:TrackedArray{T,1,Vector{T}}}) where {T<:Real}
return TrackedArray{T,2,Matrix{T}}
end
function matrixof(::Type{TrackedReal{T}}) where {T <: Real}
return TrackedArray{T, 2, Matrix{T}}
function matrixof(::Type{TrackedReal{T}}) where {T<:Real}
return TrackedArray{T,2,Matrix{T}}
end
function vectorof(::Type{TrackedArray{T, 2, Matrix{T}}}) where {T <: Real}
return TrackedArray{T, 1, Vector{T}}
function vectorof(::Type{<:TrackedArray{T,2,Matrix{T}}}) where {T<:Real}
return TrackedArray{T,1,Vector{T}}
end
function vectorof(::Type{TrackedReal{T}}) where {T <: Real}
return TrackedArray{T, 1, Vector{T}}
function vectorof(::Type{TrackedReal{T}}) where {T<:Real}
return TrackedArray{T,1,Vector{T}}
end

(b::Exp{0})(x::TrackedVector) = exp.(x)::vectorof(float(eltype(x)))
Expand Down
4 changes: 2 additions & 2 deletions test/norm_flows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ end
our_method = sum(forward(flow, z).logabsdetjac)

@test our_method ≈ forward_diff
@test inv(flow)(flow(z)) ≈ z rtol=0.2
@test (inv(flow) ∘ flow)(z) ≈ z rtol=0.2
@test inv(flow)(flow(z)) ≈ z rtol=0.25
@test (inv(flow) ∘ flow)(z) ≈ z rtol=0.25
end

w = ones(10)
Expand Down

0 comments on commit 2552cdf

Please sign in to comment.