diff --git a/.github/workflows/ForwardDiff_Tracker.yml b/.github/workflows/ForwardDiff_Tracker.yml index b185c06d..cf30e2b5 100644 --- a/.github/workflows/ForwardDiff_Tracker.yml +++ b/.github/workflows/ForwardDiff_Tracker.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: version: - - '1.0' + - '1.3' - '1' os: - ubuntu-latest diff --git a/.github/workflows/ReverseDiff.yml b/.github/workflows/ReverseDiff.yml index 30748f36..6c5abb98 100644 --- a/.github/workflows/ReverseDiff.yml +++ b/.github/workflows/ReverseDiff.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: version: - - '1.0' + - '1.3' - '1' os: - ubuntu-latest diff --git a/.github/workflows/Zygote.yml b/.github/workflows/Zygote.yml index 53ecf7ca..f03c17ab 100644 --- a/.github/workflows/Zygote.yml +++ b/.github/workflows/Zygote.yml @@ -13,6 +13,7 @@ jobs: strategy: matrix: version: + - '1.3' - '1' os: - ubuntu-latest diff --git a/.travis.yml b/.travis.yml index 876d6c25..08d163f5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,7 +7,7 @@ os: - linux - osx julia: - - 1.0 + - 1.3 - 1 - nightly matrix: diff --git a/Project.toml b/Project.toml index 40dfcc4d..9ca5e86b 100644 --- a/Project.toml +++ b/Project.toml @@ -20,14 +20,14 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [compat] ArgCheck = "1, 2" Compat = "3" -Distributions = "0.23.3" +Distributions = "=0.23.3, =0.23.4, =0.23.5, =0.23.6, =0.23.7, =0.23.8" MappedArrays = "0.2.2" NNlib = "0.6, 0.7" Reexport = "0.2" Requires = "0.5, 1" Roots = "0.8.4, 1" StatsFuns = "0.8, 0.9.3" -julia = "1" +julia = "1.3" [extras] Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" diff --git a/src/bijectors/planar_layer.jl b/src/bijectors/planar_layer.jl index b40166c7..fbb24f41 100644 --- a/src/bijectors/planar_layer.jl +++ b/src/bijectors/planar_layer.jl @@ -65,44 +65,65 @@ function forward(flow::PlanarLayer, z::AbstractVecOrMat) return (rv = transformed, logabsdetjac = log_det_jacobian) end -function (ib::Inverse{<: PlanarLayer})(y::AbstractVector{<:Real}) +function (ib::Inverse{<:PlanarLayer})(y::AbstractVector{<:Real}) flow = ib.orig - u_hat = get_u_hat(flow.u, flow.w) - T = promote_type(eltype(flow.u), eltype(flow.w), eltype(flow.b), eltype(y)) - TV = vectorof(T) - # Define the objective functional; implemented with reference from A.1 - f(y) = alpha -> (flow.w' * y) - alpha - (flow.w' * u_hat) * tanh(alpha + first(flow.b)) - # Run solver - alpha::T = find_zero(f(y), zero(T), Order16()) - z_para::TV = (flow.w ./ norm(flow.w, 2)) .* alpha - return (y .- u_hat .* tanh.(flow.w' * z_para .+ first(flow.b)))::TV + w = flow.w + b = first(flow.b) + u_hat = get_u_hat(flow.u, w) + + # Find the scalar ``alpha`` from A.1. + wt_y = dot(w, y) + wt_u_hat = dot(w, u_hat) + alpha = find_alpha(y, wt_y, wt_u_hat, b) + + return y .- u_hat .* tanh(alpha * norm(w, 2) + b) end -function (ib::Inverse{<: PlanarLayer})(y::AbstractMatrix{<:Real}) + +function (ib::Inverse{<:PlanarLayer})(y::AbstractMatrix{<:Real}) flow = ib.orig + w = flow.w + b = first(flow.b) u_hat = get_u_hat(flow.u, flow.w) - T = promote_type(eltype(flow.u), eltype(flow.w), eltype(flow.b), eltype(y)) - TM = matrixof(T) - # Define the objective functional; implemented with reference from A.1 - f(y) = alpha -> (flow.w' * y) - alpha - (flow.w' * u_hat) * tanh(alpha + first(flow.b)) - # Run solver - alpha = mapvcat(eachcol(y)) do c - find_zero(f(c), zero(T), Order16()) + + # Find the scalar ``alpha`` from A.1 for each column. + wt_u_hat = dot(w, u_hat) + alphas = mapvcat(eachcol(y)) do c + find_alpha(c, dot(w, c), wt_u_hat, b) end - z_para::TM = (flow.w ./ norm(flow.w, 2)) .* alpha' - return (y .- u_hat .* tanh.(flow.w' * z_para .+ first(flow.b)))::TM -end -function matrixof(::Type{Vector{T}}) where {T <: Real} - return Matrix{T} -end -function matrixof(::Type{T}) where {T <: Real} - return Matrix{T} + return y .- u_hat .* tanh.(alphas' .* norm(w, 2) .+ b) end -function vectorof(::Type{Matrix{T}}) where {T <: Real} - return Vector{T} -end -function vectorof(::Type{T}) where {T <: Real} - return Vector{T} + +""" + find_alpha(y::AbstractVector{<:Real}, wt_y, wt_u_hat, b) + +Compute an (approximate) real-valued solution ``α`` to the equation +```math +wt_y = α + wt_u_hat tanh(α + b) +``` + +The uniqueness of the solution is guaranteed since ``wt_u_hat ≥ -1``. +For details see appendix A.1 of the reference. + +# References + +D. Rezende, S. Mohamed (2015): Variational Inference with Normalizing Flows. +arXiv:1505.05770 +""" +function find_alpha(y::AbstractVector{<:Real}, wt_y, wt_u_hat, b) + # Compute the initial bracket ((-Inf, 0) or (0, Inf)) + f0 = wt_u_hat * tanh(b) - wt_y + zero_f0 = zero(f0) + if f0 < zero_f0 + initial_bracket = (zero_f0, oftype(f0, Inf)) + else + initial_bracket = (oftype(f0, -Inf), zero_f0) + end + alpha = find_zero(initial_bracket) do x + x + wt_u_hat * tanh(x + b) - wt_y + end + + return alpha end logabsdetjac(flow::PlanarLayer, x) = forward(flow, x).logabsdetjac diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index 43e7c26f..b29ac8d2 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -70,29 +70,57 @@ end function (ib::Inverse{<:RadialLayer})(y::AbstractVector{<:Real}) flow = ib.orig - T = promote_type(eltype(flow.α_), eltype(flow.β), eltype(flow.z_0), eltype(y)) - TV = vectorof(T) + z0 = flow.z_0 α = softplus(first(flow.α_)) # from A.2 - β_hat = - α + softplus(first(flow.β)) # from A.2 - # Define the objective functional - f(y) = r -> norm(y .- flow.z_0) - r * (1 + β_hat / (α + r)) # from eq(26) - # Run solver - rs::T = find_zero(f(y), zero(T), Order16()) - return (flow.z_0 .+ (y .- flow.z_0) ./ (1 .+ β_hat ./ (α .+ rs)))::TV + α_plus_β_hat = softplus(first(flow.β)) # from A.2 + + # Compute the norm ``r`` from A.2. + y_minus_z0 = y .- z0 + r = compute_r(y_minus_z0, α, α_plus_β_hat) + + return z0 .+ ((α + r) / (α_plus_β_hat + r)) .* y_minus_z0 end + function (ib::Inverse{<:RadialLayer})(y::AbstractMatrix{<:Real}) flow = ib.orig - T = promote_type(eltype(flow.α_), eltype(flow.β), eltype(flow.z_0), eltype(y)) - TM = matrixof(T) + z0 = flow.z_0 α = softplus(first(flow.α_)) # from A.2 - β_hat = - α + softplus(first(flow.β)) # from A.2 - # Define the objective functional - f(y) = r -> norm(y .- flow.z_0) - r * (1 + β_hat / (α + r)) # from eq(26) - # Run solver - rs = mapvcat(eachcol(y)) do c - find_zero(f(c), zero(T), Order16()) + α_plus_β_hat = softplus(first(flow.β)) # from A.2 + + # Compute the norm ``r`` from A.2 for each column. + y_minus_z0 = y .- z0 + rs = mapvcat(eachcol(y_minus_z0)) do c + return compute_r(c, α, α_plus_β_hat) end - return (flow.z_0 .+ (y .- flow.z_0) ./ (1 .+ β_hat ./ (α .+ rs')))::TM + + return z0 .+ ((α .+ rs) ./ (α_plus_β_hat .+ rs))' .* y_minus_z0 +end + +""" + compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat) + +Compute the unique solution ``r`` to the equation +```math +\\|y_minus_z0\\|_2 = r \\left(1 + \\frac{α_plus_β_hat - α}{α + r}\\right) +``` +subject to ``r ≥ 0`` and ``r ≠ α``. + +Since ``α > 0`` and ``α_plus_β_hat > 0``, the solution is unique and given by +```math +r = (\\sqrt{(α_plus_β_hat - γ)^2 + 4 α γ} - (α_plus_β_hat - γ)) / 2, +``` +where ``γ = \\|y_minus_z0\\|_2``. For details see appendix A.2 of the reference. + +# References + +D. Rezende, S. Mohamed (2015): Variational Inference with Normalizing Flows. +arXiv:1505.05770 +""" +function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat) + γ = norm(y_minus_z0) + a = α_plus_β_hat - γ + r = (sqrt(a^2 + 4 * α * γ) - a) / 2 + return r end logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat) = forward(flow, x).logabsdetjac diff --git a/src/compat/tracker.jl b/src/compat/tracker.jl index 20696441..f9c11c2b 100644 --- a/src/compat/tracker.jl +++ b/src/compat/tracker.jl @@ -400,17 +400,17 @@ eachcolnorm(X::TrackedMatrix) = track(eachcolnorm, X) end end -function matrixof(::Type{TrackedArray{T, 1, Vector{T}}}) where {T <: Real} - return TrackedArray{T, 2, Matrix{T}} +function matrixof(::Type{<:TrackedArray{T,1,Vector{T}}}) where {T<:Real} + return TrackedArray{T,2,Matrix{T}} end -function matrixof(::Type{TrackedReal{T}}) where {T <: Real} - return TrackedArray{T, 2, Matrix{T}} +function matrixof(::Type{TrackedReal{T}}) where {T<:Real} + return TrackedArray{T,2,Matrix{T}} end -function vectorof(::Type{TrackedArray{T, 2, Matrix{T}}}) where {T <: Real} - return TrackedArray{T, 1, Vector{T}} +function vectorof(::Type{<:TrackedArray{T,2,Matrix{T}}}) where {T<:Real} + return TrackedArray{T,1,Vector{T}} end -function vectorof(::Type{TrackedReal{T}}) where {T <: Real} - return TrackedArray{T, 1, Vector{T}} +function vectorof(::Type{TrackedReal{T}}) where {T<:Real} + return TrackedArray{T,1,Vector{T}} end (b::Exp{0})(x::TrackedVector) = exp.(x)::vectorof(float(eltype(x))) diff --git a/test/norm_flows.jl b/test/norm_flows.jl index 4ee46774..6504bad5 100644 --- a/test/norm_flows.jl +++ b/test/norm_flows.jl @@ -27,8 +27,8 @@ end our_method = sum(forward(flow, z).logabsdetjac) @test our_method ≈ forward_diff - @test inv(flow)(flow(z)) ≈ z rtol=0.2 - @test (inv(flow) ∘ flow)(z) ≈ z rtol=0.2 + @test inv(flow)(flow(z)) ≈ z rtol=0.25 + @test (inv(flow) ∘ flow)(z) ≈ z rtol=0.25 end w = ones(10)