Simplify inverses of planar and radial layer #126

Merged · 12 commits · Aug 18, 2020
.github/workflows/ForwardDiff_Tracker.yml (2 changes: 1 addition & 1 deletion)

@@ -13,7 +13,7 @@ jobs:
     strategy:
       matrix:
         version:
-          - '1.0'
+          - '1.3'
           - '1'
         os:
           - ubuntu-latest
.github/workflows/ReverseDiff.yml (2 changes: 1 addition & 1 deletion)

@@ -13,7 +13,7 @@ jobs:
     strategy:
       matrix:
         version:
-          - '1.0'
+          - '1.3'
           - '1'
         os:
           - ubuntu-latest
.github/workflows/Zygote.yml (1 change: 1 addition & 0 deletions)

@@ -13,6 +13,7 @@ jobs:
     strategy:
       matrix:
         version:
+          - '1.3'
           - '1'
         os:
           - ubuntu-latest
.travis.yml (2 changes: 1 addition & 1 deletion)

@@ -7,7 +7,7 @@ os:
   - linux
   - osx
 julia:
-  - 1.0
+  - 1.3
   - 1
   - nightly
 matrix:
Project.toml (4 changes: 2 additions & 2 deletions)

@@ -20,14 +20,14 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
 [compat]
 ArgCheck = "1, 2"
 Compat = "3"
-Distributions = "0.23.3"
+Distributions = "=0.23.3, =0.23.4, =0.23.5, =0.23.6, =0.23.7, =0.23.8"
 MappedArrays = "0.2.2"
 NNlib = "0.6, 0.7"
 Reexport = "0.2"
 Requires = "0.5, 1"
 Roots = "0.8.4, 1"
 StatsFuns = "0.8, 0.9.3"
-julia = "1"
+julia = "1.3"
 
 [extras]
 Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"

Member Author commented on the pinned Distributions bound:

This avoids pulling in FillArrays 0.9. The proper fix is to adjust Zygote's compat bounds instead, but I'm not sure how long it will take for the Zygote PR to be released. In the meantime, this should avoid the current test failures.

It would be good to fix this before the next release, so that these odd compat bounds don't end up in the registry (which might otherwise lead to users getting outdated versions of Bijectors).
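A note on the `[compat]` syntax above: a bare bound like `"0.23.3"` uses Pkg's default caret semantics and admits every release in `[0.23.3, 0.24.0)`, which would let the resolver pick newer 0.23.x releases, whereas `=` entries pin exact versions. Below is a minimal sketch of the two specifier styles using Pkg's unexported `Pkg.Types.semver_spec` parser; it is an illustration, not code from this PR:

```julia
using Pkg

# Default caret semantics: "0.23.3" admits every release in [0.23.3, 0.24.0).
caret = Pkg.Types.semver_spec("0.23.3")
@assert v"0.23.9" in caret

# Equality specifiers admit only the exact versions listed.
pinned = Pkg.Types.semver_spec("=0.23.3, =0.23.4, =0.23.5, =0.23.6, =0.23.7, =0.23.8")
@assert v"0.23.5" in pinned
@assert v"0.23.9" ∉ pinned
```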
src/bijectors/planar_layer.jl (83 changes: 52 additions & 31 deletions)

@@ -65,44 +65,65 @@ function forward(flow::PlanarLayer, z::AbstractVecOrMat)
     return (rv = transformed, logabsdetjac = log_det_jacobian)
 end
 
-function (ib::Inverse{<: PlanarLayer})(y::AbstractVector{<:Real})
+function (ib::Inverse{<:PlanarLayer})(y::AbstractVector{<:Real})
     flow = ib.orig
-    u_hat = get_u_hat(flow.u, flow.w)
-    T = promote_type(eltype(flow.u), eltype(flow.w), eltype(flow.b), eltype(y))
-    TV = vectorof(T)
-    # Define the objective functional; implemented with reference from A.1
-    f(y) = alpha -> (flow.w' * y) - alpha - (flow.w' * u_hat) * tanh(alpha + first(flow.b))
-    # Run solver
-    alpha::T = find_zero(f(y), zero(T), Order16())
-    z_para::TV = (flow.w ./ norm(flow.w, 2)) .* alpha
-    return (y .- u_hat .* tanh.(flow.w' * z_para .+ first(flow.b)))::TV
+    w = flow.w
+    b = first(flow.b)
+    u_hat = get_u_hat(flow.u, w)
+
+    # Find the scalar ``alpha`` from A.1.
+    wt_y = dot(w, y)
+    wt_u_hat = dot(w, u_hat)
+    alpha = find_alpha(y, wt_y, wt_u_hat, b)
+
+    return y .- u_hat .* tanh(alpha * norm(w, 2) + b)
 end
-function (ib::Inverse{<: PlanarLayer})(y::AbstractMatrix{<:Real})
+
+function (ib::Inverse{<:PlanarLayer})(y::AbstractMatrix{<:Real})
     flow = ib.orig
+    w = flow.w
+    b = first(flow.b)
     u_hat = get_u_hat(flow.u, flow.w)
-    T = promote_type(eltype(flow.u), eltype(flow.w), eltype(flow.b), eltype(y))
-    TM = matrixof(T)
-    # Define the objective functional; implemented with reference from A.1
-    f(y) = alpha -> (flow.w' * y) - alpha - (flow.w' * u_hat) * tanh(alpha + first(flow.b))
-    # Run solver
-    alpha = mapvcat(eachcol(y)) do c
-        find_zero(f(c), zero(T), Order16())
+
+    # Find the scalar ``alpha`` from A.1 for each column.
+    wt_u_hat = dot(w, u_hat)
+    alphas = mapvcat(eachcol(y)) do c
+        find_alpha(c, dot(w, c), wt_u_hat, b)
     end
-    z_para::TM = (flow.w ./ norm(flow.w, 2)) .* alpha'
-    return (y .- u_hat .* tanh.(flow.w' * z_para .+ first(flow.b)))::TM
+
+    return y .- u_hat .* tanh.(alphas' .* norm(w, 2) .+ b)
 end
 
-function matrixof(::Type{Vector{T}}) where {T <: Real}
-    return Matrix{T}
-end
-function matrixof(::Type{T}) where {T <: Real}
-    return Matrix{T}
-end
-function vectorof(::Type{Matrix{T}}) where {T <: Real}
-    return Vector{T}
-end
-function vectorof(::Type{T}) where {T <: Real}
-    return Vector{T}
-end
+"""
+    find_alpha(y::AbstractVector{<:Real}, wt_y, wt_u_hat, b)
+
+Compute an (approximate) real-valued solution ``α`` to the equation
+```math
+wt_y = α + wt_u_hat * tanh(α + b)
+```
+
+The uniqueness of the solution is guaranteed since ``wt_u_hat ≥ -1``.
+For details see appendix A.1 of the reference.
+
+# References
+
+D. Rezende, S. Mohamed (2015): Variational Inference with Normalizing Flows.
+arXiv:1505.05770
+"""
+function find_alpha(y::AbstractVector{<:Real}, wt_y, wt_u_hat, b)
+    # Compute the initial bracket ((-Inf, 0) or (0, Inf)).
+    f0 = wt_u_hat * tanh(b) - wt_y
+    zero_f0 = zero(f0)
+    if f0 < zero_f0
+        initial_bracket = (zero_f0, oftype(f0, Inf))
+    else
+        initial_bracket = (oftype(f0, -Inf), zero_f0)
+    end
+    alpha = find_zero(initial_bracket) do x
+        x + wt_u_hat * tanh(x + b) - wt_y
+    end
+
+    return alpha
+end
 
 logabsdetjac(flow::PlanarLayer, x) = forward(flow, x).logabsdetjac
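A note on why the bracket in `find_alpha` is valid: the objective `f(x) = x + wt_u_hat * tanh(x + b) - wt_y` has derivative `1 + wt_u_hat / cosh(x + b)^2`, which is positive whenever `wt_u_hat > -1`, so `f` is strictly increasing and has exactly one root; the sign of `f(0) = wt_u_hat * tanh(b) - wt_y` then determines whether that root lies in `(0, Inf)` or `(-Inf, 0)`. A self-contained sketch with made-up inputs (not code from this PR):

```julia
using Roots

# Hypothetical inputs; in the layer these come from w, u_hat, b, and y.
wt_y, wt_u_hat, b = 2.0, 0.5, -0.3

# Strictly increasing in x since wt_u_hat > -1, hence a unique root.
f(x) = x + wt_u_hat * tanh(x + b) - wt_y

# The sign of f(0) selects the half-line containing the root.
bracket = f(0.0) < 0 ? (0.0, Inf) : (-Inf, 0.0)
alpha = find_zero(f, bracket)  # bracketing methods accept infinite endpoints

@assert abs(f(alpha)) < 1e-8
```

A sign-change bracket makes the solve globally convergent, which is what allows dropping the old `Order16()` iteration started from a zero initial guess.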
src/bijectors/radial_layer.jl (62 changes: 45 additions & 17 deletions)

@@ -70,29 +70,57 @@ end
 
 function (ib::Inverse{<:RadialLayer})(y::AbstractVector{<:Real})
     flow = ib.orig
-    T = promote_type(eltype(flow.α_), eltype(flow.β), eltype(flow.z_0), eltype(y))
-    TV = vectorof(T)
+    z0 = flow.z_0
     α = softplus(first(flow.α_)) # from A.2
-    β_hat = - α + softplus(first(flow.β)) # from A.2
-    # Define the objective functional
-    f(y) = r -> norm(y .- flow.z_0) - r * (1 + β_hat / (α + r)) # from eq(26)
-    # Run solver
-    rs::T = find_zero(f(y), zero(T), Order16())
-    return (flow.z_0 .+ (y .- flow.z_0) ./ (1 .+ β_hat ./ (α .+ rs)))::TV
+    α_plus_β_hat = softplus(first(flow.β)) # from A.2
+
+    # Compute the norm ``r`` from A.2.
+    y_minus_z0 = y .- z0
+    r = compute_r(y_minus_z0, α, α_plus_β_hat)
+
+    return z0 .+ ((α + r) / (α_plus_β_hat + r)) .* y_minus_z0
 end
 
 function (ib::Inverse{<:RadialLayer})(y::AbstractMatrix{<:Real})
     flow = ib.orig
-    T = promote_type(eltype(flow.α_), eltype(flow.β), eltype(flow.z_0), eltype(y))
-    TM = matrixof(T)
+    z0 = flow.z_0
     α = softplus(first(flow.α_)) # from A.2
-    β_hat = - α + softplus(first(flow.β)) # from A.2
-    # Define the objective functional
-    f(y) = r -> norm(y .- flow.z_0) - r * (1 + β_hat / (α + r)) # from eq(26)
-    # Run solver
-    rs = mapvcat(eachcol(y)) do c
-        find_zero(f(c), zero(T), Order16())
+    α_plus_β_hat = softplus(first(flow.β)) # from A.2
+
+    # Compute the norm ``r`` from A.2 for each column.
+    y_minus_z0 = y .- z0
+    rs = mapvcat(eachcol(y_minus_z0)) do c
+        return compute_r(c, α, α_plus_β_hat)
     end
-    return (flow.z_0 .+ (y .- flow.z_0) ./ (1 .+ β_hat ./ (α .+ rs')))::TM
+
+    return z0 .+ ((α .+ rs) ./ (α_plus_β_hat .+ rs))' .* y_minus_z0
 end
 
+"""
+    compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat)
+
+Compute the unique solution ``r`` to the equation
+```math
+\\|y_minus_z0\\|_2 = r \\left(1 + \\frac{α_plus_β_hat - α}{α + r}\\right)
+```
+subject to ``r ≥ 0`` and ``r ≠ α``.
+
+Since ``α > 0`` and ``α_plus_β_hat > 0``, the solution is unique and given by
+```math
+r = (\\sqrt{(α_plus_β_hat - γ)^2 + 4 α γ} - (α_plus_β_hat - γ)) / 2,
+```
+where ``γ = \\|y_minus_z0\\|_2``. For details see appendix A.2 of the reference.
+
+# References
+
+D. Rezende, S. Mohamed (2015): Variational Inference with Normalizing Flows.
+arXiv:1505.05770
+"""
+function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat)
+    γ = norm(y_minus_z0)
+    a = α_plus_β_hat - γ
+    r = (sqrt(a^2 + 4 * α * γ) - a) / 2
+    return r
+end
 
 logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat) = forward(flow, x).logabsdetjac
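Where the closed form in `compute_r` comes from: with `γ = norm(y_minus_z0)` and `β̂ = α_plus_β_hat - α`, clearing the denominator in `γ = r * (1 + β̂ / (α + r))` yields the quadratic `r^2 + (α_plus_β_hat - γ) * r - α * γ = 0`. The product of its roots is `-α * γ ≤ 0`, so exactly one root is non-negative, and that is the expression in the docstring. A quick numerical check with arbitrary values (not code from this PR):

```julia
using LinearAlgebra

# Hypothetical parameters; α and α_plus_β_hat are positive by construction.
α, α_plus_β_hat = 1.2, 0.8
y_minus_z0 = [0.5, -1.0, 2.0]

γ = norm(y_minus_z0)
a = α_plus_β_hat - γ
r = (sqrt(a^2 + 4 * α * γ) - a) / 2  # closed form from the docstring

# r must be the non-negative solution of γ = r * (1 + β̂ / (α + r)).
β̂ = α_plus_β_hat - α
@assert r ≥ 0
@assert γ ≈ r * (1 + β̂ / (α + r))
```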
src/compat/tracker.jl (16 changes: 8 additions & 8 deletions)

@@ -400,17 +400,17 @@ eachcolnorm(X::TrackedMatrix) = track(eachcolnorm, X)
     end
 end
 
-function matrixof(::Type{TrackedArray{T, 1, Vector{T}}}) where {T <: Real}
-    return TrackedArray{T, 2, Matrix{T}}
+function matrixof(::Type{<:TrackedArray{T,1,Vector{T}}}) where {T<:Real}
+    return TrackedArray{T,2,Matrix{T}}
 end
-function matrixof(::Type{TrackedReal{T}}) where {T <: Real}
-    return TrackedArray{T, 2, Matrix{T}}
+function matrixof(::Type{TrackedReal{T}}) where {T<:Real}
+    return TrackedArray{T,2,Matrix{T}}
 end
-function vectorof(::Type{TrackedArray{T, 2, Matrix{T}}}) where {T <: Real}
-    return TrackedArray{T, 1, Vector{T}}
+function vectorof(::Type{<:TrackedArray{T,2,Matrix{T}}}) where {T<:Real}
+    return TrackedArray{T,1,Vector{T}}
 end
-function vectorof(::Type{TrackedReal{T}}) where {T <: Real}
-    return TrackedArray{T, 1, Vector{T}}
+function vectorof(::Type{TrackedReal{T}}) where {T<:Real}
+    return TrackedArray{T,1,Vector{T}}
 end
 
 (b::Exp{0})(x::TrackedVector) = exp.(x)::vectorof(float(eltype(x)))
test/norm_flows.jl (4 changes: 2 additions & 2 deletions)

@@ -27,8 +27,8 @@ end
     our_method = sum(forward(flow, z).logabsdetjac)
 
     @test our_method ≈ forward_diff
-    @test inv(flow)(flow(z)) ≈ z rtol=0.2
-    @test (inv(flow) ∘ flow)(z) ≈ z rtol=0.2
+    @test inv(flow)(flow(z)) ≈ z rtol=0.25
+    @test (inv(flow) ∘ flow)(z) ≈ z rtol=0.25
 end
 
 w = ones(10)
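For context, the tests above exercise the round-trip property that the new bracketed solvers are meant to preserve. A usage sketch of that property (layer constructor and shapes assumed from the surrounding test file, not taken verbatim from it):

```julia
using Bijectors, Random

Random.seed!(1)

flow = PlanarLayer(2)        # 2-dimensional planar flow
z = randn(2, 20)             # one sample per column

y = flow(z)                  # forward transform
z_recovered = inv(flow)(y)   # numerical inverse via the new root find

@assert isapprox(z_recovered, z; rtol=0.25)
```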