Skip to content

Commit

Permalink
Generalize Base._cat to non-Val, typed Base._cat_t and implemen…
Browse files Browse the repository at this point in the history
…t `typed_hcat`, `typed_vcat`, `typed_hvcat`, `typed_hvncat` (#163)

* Remove `Val` constraint on `Base._cat` signature

* Remove `Val` constraint on `maybe_expand_dims`

* fix: update src/TracedRArray.jl

* Generalize `Base._cat` implementation on `TracedRArray` to typed `Base._cat_t`

* Update src/TracedRArray.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Fix collection type passed to `stablehlo.concatenate`

* Test `cat` methods

* Test result eltype on `*cat` methods

* Fix conversion of integer arrays to `ConcreteRArray`s

* Format code

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Fix `_typed_cat`, `_typed_hcat`, `typed_hvcat` dispatches

* Fix `hvcat`

* Convert to target eltype before cat

* Fix `typed_hcat` tests

* Test `typed_hvncat` on vectors

* Refactor tests

* Add more test cases

* Refactor tests

* Fix typo

---------

Co-authored-by: Avik Pal <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 6, 2024
1 parent 9904590 commit f2c0e8a
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 29 deletions.
83 changes: 69 additions & 14 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -761,32 +761,87 @@ function _copyto!(dest::TracedRArray, bc::Broadcasted)
return dest
end

function Base._cat(dims::Val{D}, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N,D}
@assert D isa Integer "Support for non-integer dimensions is not implemented yet."
dispatch_val(x) = x
dispatch_val(::Val{D}) where {D} = D

# MLIR expects the dimension `D` to be ≤ the rank of the input tensors
A = maybe_expand_dims(A, dims)
Bs = maybe_expand_dims.(Bs, (dims,))
@inline function Base._typed_vcat(
::Type{T}, X::Base.AbstractVecOrTuple{<:TracedRArray}
) where {T}
return Base._cat_t(Val(1), T, X...)
end
@inline function Base._typed_hcat(
::Type{T}, X::Base.AbstractVecOrTuple{<:TracedRArray}
) where {T}
return Base._cat_t(Val(2), T, X...)
end

# `Base.typed_hvcat` is overloaded for `AbstractVecOrMat` using `setindex!` that breaks Reactant
# generic implementation uses `typed_hcat` and `typed_vcat` which is alright
@inline function Base.typed_hvcat(
::Type{T}, rows::Tuple{Vararg{Int}}, as::TracedRArray...
) where {T}
return invoke(
Base.typed_hvcat, Tuple{Type{T},Tuple{Vararg{Int}},Vararg{Any}}, T, rows, as...
)
end

function Base._typed_hvncat(
T::Type, dims::NTuple{N,Int}, row_first::Bool, as::TracedRArray...
) where {N}
As = if row_first
perm = [2, 1, 3:N...]
dims = [dims[2], dims[1], dims[3:end]...]
permutedims(reshape(collect(as), dims...), perm)
else
reshape(collect(as), dims)
end

for d in 1:N
Bs = Array{Any,N - d}(undef, size(As)[2:end]...)

for (i, col) in
zip(eachindex(Bs), eachslice(As; dims=Tuple(2:ndims(As)), drop=true))
# TODO row_first affects the flattening?
Bs[i] = Base._cat_t(d, T, col...)
end

As = Bs
end

return only(As)
end

function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T}
dims = dispatch_val(dims)
@assert dims isa Integer "Support for non-integer dimensions is not implemented yet."

# MLIR expects the dimension `dims` to be ≤ the rank of the input tensors
X = maybe_expand_dims.(X, (dims,))

catdims = Base.dims2cat(dims)
shape = Base.cat_size_shape(catdims, A, Bs...)
RT = Base.promote_eltype(A, Bs...)
Res = TracedRArray{RT,length(shape)}(
shape = Base.cat_size_shape(catdims, X...)
RT = Base.promote_eltype(T, X...)

# convert to the target eltype
X = map(Base.Fix1(promote_to, TracedRArray{RT,length(shape)}), X)

return TracedRArray{RT,length(shape)}(
(),
MLIR.IR.result(
# TODO maybe we should do some conversion?
MLIR.Dialects.stablehlo.concatenate(
[A.mlir_data, [B.mlir_data for B in Bs]...];
collect(get_mlir_data.(X));
result_0=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)),
dimension=D - 1, # stablehlo expects this to be zero-indexed
dimension=dims - 1, # stablehlo expects this to be zero-indexed
),
1,
),
shape,
)
return Res
end

function maybe_expand_dims(x::AbstractArray{T,N}, ::Val{D}) where {T,N,D}
D N && return x
return reshape(x, ntuple(i -> i N ? size(x, i) : 1, Val(D)))
function maybe_expand_dims(x::AbstractArray{T,N}, dims) where {T,N}
dims = dispatch_val(dims)
dims N && return x
return reshape(x, ntuple(i -> i N ? size(x, i) : 1, dims))
end
2 changes: 1 addition & 1 deletion src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ function make_tracer(
if haskey(seen, prev)
return seen[prev]
end
if mode == ArrayToConcrete && eltype(RT) <: AbstractFloat
if mode == ArrayToConcrete && eltype(RT) <: Union{AbstractFloat,Integer}
return seen[prev] = ConcreteRArray(prev)
end
TT = traced_type(eltype(RT), (), Val(mode))
Expand Down
81 changes: 67 additions & 14 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,20 +210,73 @@ end
end

@testset "concatenation" begin
x = ones(2, 4, 3)
x_concrete = Reactant.to_rarray(x)

cat1(x) = vcat(x, x, x)
cat2(x) = hcat(x, x, x)
cat3(x) = cat(x, x, x; dims=Val(3))

cat1_compiled = @compile cat1(x_concrete)
cat2_compiled = @compile cat2(x_concrete)
cat3_compiled = @compile cat3(x_concrete)

@test cat1(x) cat1_compiled(x_concrete)
@test cat2(x) cat2_compiled(x_concrete)
@test cat3(x) cat3_compiled(x_concrete)
@testset "$(ndims(x))-dim" for x in [
fill(true),
[true, false],
[true false],
[true true; true false],
[
true true true true; true true true false;;;
true true false true; true true false false;;;
true false true true; true false true false
],
]
x_concrete = Reactant.to_rarray(x)

# NOTE [,,,] is a call to `vect`, not `*cat`
# f = Reactant.compile((x_concrete,)) do x
# return [x, x, x]
# end
# @test f(x_concrete) ≈ ones(3)

# vcat
test_vcat(x) = [x; x; x]
f = @compile test_vcat(x_concrete)
@test f(x_concrete) == test_vcat(x)
@test eltype(f(x_concrete)) === Bool

# hcat
test_hcat(x) = [x x x]
f = @compile test_hcat(x_concrete)
@test f(x_concrete) == test_hcat(x)
@test eltype(f(x_concrete)) === Bool

# hvcat
test_hvcat(x) = [x x x; x x x]
f = @compile test_hvcat(x_concrete)
@test f(x_concrete) == test_hvcat(x)
@test eltype(f(x_concrete)) === Bool

# hvncat
test_hvncat(x) = [x x x; x x x;;; x x x; x x x]
f = @compile test_hvncat(x_concrete)
@test f(x_concrete) == test_hvncat(x)
@test eltype(f(x_concrete)) === Bool

# typed_vcat
test_typed_vcat(x) = Int[x; x; x]
f = @compile test_typed_vcat(x_concrete)
@test f(x_concrete) == test_typed_vcat(x)
@test eltype(f(x_concrete)) === Int

# typed_hcat
test_typed_hcat(x) = Int[x x x]
f = @compile test_typed_hcat(x_concrete)
@test f(x_concrete) == test_typed_hcat(x)
@test eltype(f(x_concrete)) === Int

# typed_hvcat
test_typed_hvcat(x) = Int[x x x; x x x]
f = @compile test_typed_hvcat(x_concrete)
@test f(x_concrete) == test_typed_hvcat(x)
@test eltype(f(x_concrete)) === Int

# typed_hvncat
test_typed_hvncat(x) = Int[x x x; x x x;;; x x x; x x x]
f = @compile test_typed_hvncat(x_concrete)
@test f(x_concrete) == test_typed_hvncat(x)
@test eltype(f(x_concrete)) === Int
end
end

function update_on_copy(x)
Expand Down

1 comment on commit f2c0e8a

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reactant.jl Benchmarks

Benchmark suite Current: f2c0e8a Previous: 9904590 Ratio
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1315729546 ns 1349734013 ns 0.97
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Lux 212083499 ns 206256598 ns 1.03
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant 5286469750 ns 5640616179 ns 0.94
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux 23583347555 ns 21088566845 ns 1.12
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1254858296 ns 1271785039.5 ns 0.99
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Lux 8478570 ns 8622303 ns 0.98
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant 1636237670 ns 1620712407 ns 1.01
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux 2376437823 ns 2515728657 ns 0.94
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1266018905 ns 1309552971.5 ns 0.97
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Lux 84820407 ns 88993975 ns 0.95
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant 2170879105 ns 2241028351 ns 0.97
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux 4675094299 ns 5778081364 ns 0.81
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1263496480 ns 1267184417.5 ns 1.00
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Lux 7782824 ns 7556327.5 ns 1.03
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant 1467043032.5 ns 1485636377.5 ns 0.99
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux 1685775445 ns 1618719051 ns 1.04
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1306815930 ns 1277921381 ns 1.02
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Lux 11611908 ns 11579756 ns 1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant 1752808523 ns 1771132012 ns 0.99
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux 2463987825.5 ns 2573621234.5 ns 0.96
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1325877558.5 ns 1308609555 ns 1.01
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Lux 90330187 ns 86522099 ns 1.04
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant 2213119086 ns 2232948717 ns 0.99
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux 4023816395 ns 4522528342 ns 0.89
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1270812264 ns 1299910120.5 ns 0.98
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Lux 113097539 ns 107942560 ns 1.05
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant 3042643080 ns 3089121407 ns 0.98
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux 8210106924.5 ns 12146634131 ns 0.68
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1324054039 ns 1317150051 ns 1.01
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Lux 127669686.5 ns 121564537 ns 1.05
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant 3203794253 ns 3268410602 ns 0.98
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux 11004907984 ns 10888829737 ns 1.01
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1299288245 ns 1345840157 ns 0.97
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Lux 96277750 ns 78626182 ns 1.22
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant 2155333265.5 ns 2037797510.5 ns 1.06
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux 2863535293.5 ns 2617867593 ns 1.09

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.