From 8ea80fd58a1923fd053c2a31d5b7b1dcc44a45c2 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Mon, 23 Aug 2021 20:30:54 +0530 Subject: [PATCH 1/7] make parallel vararg --- src/layers/basic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 0a53b70415..c2ed9f0202 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -416,7 +416,7 @@ Parallel(connection, layers...) = Parallel(connection, layers) @functor Parallel (m::Parallel)(x::AbstractArray) = mapreduce(f -> f(x), m.connection, m.layers) -(m::Parallel)(xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> f(x), m.connection, m.layers, xs) +(m::Parallel)(xs::AbstractArray...) = m.connection(map((f, x) -> f(x), m.layers, xs)...) (m::Parallel)(xs::Tuple) = m(xs...) Base.getindex(m::Parallel, i::Integer) = m.layers[i] From f1dab10808cdcf96075e135f23752e057f7cbd77 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Mon, 23 Aug 2021 20:48:33 +0530 Subject: [PATCH 2/7] store layers as tuple --- src/layers/basic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 07bd2c1fcc..2e46c509e5 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -426,7 +426,7 @@ julia> model2[:β] == model2[2] true ``` """ -struct Parallel{F, T} +struct Parallel{F, T<:Union{NamedTuple,Tuple}} connection::F layers::T end From cc08fc5085d5b8c63c26c93acc78815b5c59b6e8 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 24 Aug 2021 11:12:05 +0530 Subject: [PATCH 3/7] rm tuple method --- src/layers/basic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 2e46c509e5..0c05d042d7 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -446,7 +446,7 @@ end (m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...) (m::Parallel)(xs...) = m.connection(map((f, x) -> f(x), Tuple(m.layers), xs)...) -(m::Parallel)(xs::Tuple) = m(xs...) +# (m::Parallel)(xs::Tuple) = m(xs...) Base.getindex(m::Parallel, i) = m.layers[i] Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...) From 37dd17e4ad7042f1d8c86c9a38967563734d241d Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 24 Aug 2021 11:17:40 +0530 Subject: [PATCH 4/7] test vararg input differently from tuples --- test/layers/basic.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 0476895498..f9a675bcef 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -200,8 +200,8 @@ import Flux: activations @testset "vararg input" begin inputs = randn(10), randn(5), randn(4) - @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,) - @test size(Parallel(+; a = Dense(10, 2), b = Dense(5, 2), c = Dense(4, 2))(inputs)) == (2,) + @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs...)) == (2,) + @test size(Parallel(+; a = Dense(10, 2), b = Dense(5, 2), c = Dense(4, 2))(inputs...)) == (2,) end @testset "named access" begin From 182631580c544e1ef5911cafdbe1b410102cb13b Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 24 Aug 2021 11:40:02 +0530 Subject: [PATCH 5/7] mimo test --- test/layers/basic.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index f9a675bcef..875341f5fd 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -242,6 +242,27 @@ import Flux: activations @test gs[2].x ≈ gs_reg[2].x @test gs[3].x ≈ gs_reg[3].x end + + @testset "Multiple Inputs/ Multiple Outputs" begin + + struct MIMO{T} + W::T + end + + (m::MIMO)(x, y) = m.W * x, m.W * y, x * y + x = (rand(3,3), rand(3,3)) + + p = Parallel((x...) -> identity.(x), + MIMO(rand(3,3)), + MIMO(rand(3,3))) + + (m::MIMO)(x::Tuple) = m(x...) + mimo_output1 = p(x, x) + mimo_ouput2 = p(x) # to check for N layers 1 input case + @test mimo_output1 ≈ mimo_output2 + @test length(mimo_output1) == 2 + @test all(x -> length(x) == 3, mimo_output1) + end end @testset "Embedding" begin From bf2e87684ef6507a8afb06a3cc05fafc81ba0a47 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 24 Aug 2021 11:56:40 +0530 Subject: [PATCH 6/7] typo --- test/layers/basic.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 875341f5fd..cad3f1eed6 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -258,8 +258,7 @@ import Flux: activations (m::MIMO)(x::Tuple) = m(x...) mimo_output1 = p(x, x) - mimo_ouput2 = p(x) # to check for N layers 1 input case - @test mimo_output1 ≈ mimo_output2 + mimo_output2 = p(x) # to check for N layers 1 input case @test length(mimo_output1) == 2 @test all(x -> length(x) == 3, mimo_output1) end From 0c4cf948cb23bede116ca3fbf7c1cbfbc2f0bde1 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 25 Aug 2021 03:29:51 +0530 Subject: [PATCH 7/7] broken parallel test --- src/layers/basic.jl | 2 +- test/layers/basic.jl | 2 +- test/outputsize.jl | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 0c05d042d7..2e46c509e5 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -446,7 +446,7 @@ end (m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...) (m::Parallel)(xs...) = m.connection(map((f, x) -> f(x), Tuple(m.layers), xs)...) -# (m::Parallel)(xs::Tuple) = m(xs...) +(m::Parallel)(xs::Tuple) = m(xs...) Base.getindex(m::Parallel, i) = m.layers[i] Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index cad3f1eed6..777d48e56e 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -258,7 +258,7 @@ import Flux: activations (m::MIMO)(x::Tuple) = m(x...) mimo_output1 = p(x, x) - mimo_output2 = p(x) # to check for N layers 1 input case + @test_broken all(p(((x,),)) .== mimo_output1) # to check for N layers 1 input case @test length(mimo_output1) == 2 @test all(x -> length(x) == 3, mimo_output1) end diff --git a/test/outputsize.jl b/test/outputsize.jl index 2c90811dcb..43d12afb85 100644 --- a/test/outputsize.jl +++ b/test/outputsize.jl @@ -45,13 +45,13 @@ end m = Parallel(vcat, Dense(2, 4, relu), Dense(3, 6, relu)) @test outputsize(m, (2,), (3,)) == (10,) @test outputsize(m, ((2,), (3,))) == (10,) - @test outputsize(m, (2,), (3,); padbatch=true) == (10, 1) + @test outputsize(m, (2,), (3,); padbatch = true) == (10, 1) @test outputsize(m, (2,7), (3,7)) == (10, 7) m = Chain(m, Dense(10, 13, tanh), softmax) @test outputsize(m, (2,), (3,)) == (13,) @test outputsize(m, ((2,), (3,))) == (13,) - @test outputsize(m, (2,), (3,); padbatch=true) == (13, 1) + @test outputsize(m, (2,), (3,); padbatch = true) == (13, 1) @test outputsize(m, (2,7), (3,7)) == (13, 7) end