diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 1c87ceb204..2e46c509e5 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -426,7 +426,7 @@ julia> model2[:β] == model2[2] true ``` """ -struct Parallel{F, T} +struct Parallel{F, T<:Union{NamedTuple,Tuple}} connection::F layers::T end @@ -444,8 +444,8 @@ end @functor Parallel -(m::Parallel)(x) = mapreduce(f -> f(x), m.connection, Tuple(m.layers)) -(m::Parallel)(xs...) = mapreduce((f, x) -> f(x), m.connection, Tuple(m.layers), xs) +(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...) +(m::Parallel)(xs...) = m.connection(map((f, x) -> f(x), Tuple(m.layers), xs)...) (m::Parallel)(xs::Tuple) = m(xs...) Base.getindex(m::Parallel, i) = m.layers[i] diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 0476895498..777d48e56e 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -200,8 +200,8 @@ import Flux: activations @testset "vararg input" begin inputs = randn(10), randn(5), randn(4) - @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,) - @test size(Parallel(+; a = Dense(10, 2), b = Dense(5, 2), c = Dense(4, 2))(inputs)) == (2,) + @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs...)) == (2,) + @test size(Parallel(+; a = Dense(10, 2), b = Dense(5, 2), c = Dense(4, 2))(inputs...)) == (2,) end @testset "named access" begin @@ -242,6 +242,26 @@ import Flux: activations @test gs[2].x ≈ gs_reg[2].x @test gs[3].x ≈ gs_reg[3].x end + + @testset "Multiple Inputs/ Multiple Outputs" begin + + struct MIMO{T} + W::T + end + + (m::MIMO)(x, y) = m.W * x, m.W * y, x * y + x = (rand(3,3), rand(3,3)) + + p = Parallel((x...) -> identity.(x), + MIMO(rand(3,3)), + MIMO(rand(3,3))) + + (m::MIMO)(x::Tuple) = m(x...) + mimo_output1 = p(x, x) + @test_broken all(p(((x,),)) .== mimo_output1) # to check for N layers 1 input case + @test length(mimo_output1) == 2 + @test all(x -> length(x) == 3, mimo_output1) + end end @testset "Embedding" begin diff --git a/test/outputsize.jl b/test/outputsize.jl index 2c90811dcb..43d12afb85 100644 --- a/test/outputsize.jl +++ b/test/outputsize.jl @@ -45,13 +45,13 @@ end m = Parallel(vcat, Dense(2, 4, relu), Dense(3, 6, relu)) @test outputsize(m, (2,), (3,)) == (10,) @test outputsize(m, ((2,), (3,))) == (10,) - @test outputsize(m, (2,), (3,); padbatch=true) == (10, 1) + @test outputsize(m, (2,), (3,); padbatch = true) == (10, 1) @test outputsize(m, (2,7), (3,7)) == (10, 7) m = Chain(m, Dense(10, 13, tanh), softmax) @test outputsize(m, (2,), (3,)) == (13,) @test outputsize(m, ((2,), (3,))) == (13,) - @test outputsize(m, (2,), (3,); padbatch=true) == (13, 1) + @test outputsize(m, (2,), (3,); padbatch = true) == (13, 1) @test outputsize(m, (2,7), (3,7)) == (13, 7) end