Skip to content

Commit

Permalink
fix half of FluxML#125
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Feb 25, 2022
1 parent 67c24cc commit 18fdd64
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 3 additions & 1 deletion src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ end

for i = 0:2, c = combinations([:AbstractArray, :TrackedArray, :Number], i), f = [:hcat, :vcat]
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) =
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::AbstractArray...) =
track($f, $(cnames...), x, xs...)
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Number...) =
track($f, $(cnames...), x, xs...)
end

Expand Down
4 changes: 3 additions & 1 deletion test/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,11 @@ end

@testset "scalars" begin
@test vcat(param([1, 2, 3]), 1) isa TrackedArray
@test vcat(param(1), 2) isa TrackedArray

# These two are ambiguity errors on Julia 1.8
@test vcat(1, param([1, 2, 3])) isa TrackedArray
@test hcat(1, param([1 2 3;])) isa TrackedArray
@test vcat(param(1), 2) isa TrackedArray
end

end
Expand Down

0 comments on commit 18fdd64

Please sign in to comment.