Compiling Recurrent Models with Reactant #1025

Closed

avik-pal opened this issue Nov 4, 2024 · 4 comments · Fixed by #1026

Comments

@avik-pal
Member

avik-pal commented Nov 4, 2024

using Lux, Reactant, Random

model = Recurrence(RNNCell(4 => 4))
ps, st = Lux.setup(Xoshiro(123), model) |> Reactant.to_rarray
x = rand(Float32, 4, 16, 12) |> Reactant.ConcreteRArray

@code_hlo model(x, ps, st)

The issue originates from the following function, which moves the data to a ConcreteRArray instead of a TracedRArray when run inside a compilation context. I could use @reactant_override to define custom dispatches for use inside Reactant.compile, but I am not sure that should be the recommended usage.

function init_rnn_hidden_state(rng::AbstractRNG, rnn, x::AbstractMatrix)
    # TODO: Once we support moving `rng` to the device, we can directly initialize on the
    #       device
    return rnn.init_state(rng, rnn.out_dims, Base.size(x, 2)) |> get_device(x)
end
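
For illustration, a minimal sketch of the @reactant_override route mentioned above; the exact macro form and the choice to simply skip the device transfer are assumptions for illustration, not the merged fix:

using Lux, Reactant, Random

# Hypothetical sketch (the precise @reactant_override syntax is an
# assumption): redefine the initializer inside Reactant's compilation
# context so the hidden state is never moved to a ConcreteRArray.
Reactant.@reactant_override function Lux.init_rnn_hidden_state(
        rng::AbstractRNG, rnn, x::AbstractMatrix)
    # Skip the `get_device(x)` transfer; inside tracing, the host array
    # can be promoted to a traced value when it first meets a TracedRArray.
    return rnn.init_state(rng, rnn.out_dims, Base.size(x, 2))
end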
@avik-pal
Member Author

avik-pal commented Nov 4, 2024

The full error, for context:

1-element ExceptionStack:
LoadError: MethodError: no method matching Float32(::Reactant.TracedRNumber{Float32})
The type `Float32` exists, but no method is defined for this combination of argument types when trying to construct it.

Closest candidates are:
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:900
  Float32(::IrrationalConstants.Log2π)
   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:113
  Float32(::IrrationalConstants.Halfπ)
   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:113
  ...

Stacktrace:
  [1] convert(::Type{Float32}, x::Reactant.TracedRNumber{Float32})
    @ Base ./number.jl:7
  [2] unsafe_store!(p::Ptr{Float32}, x::Reactant.TracedRNumber{Float32}, i::Int64)
    @ Base ./pointer.jl:180
  [3] setindex!(::ConcreteRArray{Float32, 2}, ::Reactant.TracedRNumber{Float32}, ::Int64, ::Int64)
    @ Reactant /mnt/software/lux/Reactant.jl/src/ConcreteRArray.jl:233
  [4] macro expansion
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:894 [inlined]
  [5] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [6] _generic_matmatmul!(C::ConcreteRArray{Float32, 2}, A::Reactant.TracedRArray{Float32, 2}, B::ConcreteRArray{Float32, 2}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:893
  [7] generic_matmatmul!
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:868 [inlined]
  [8] _mul!
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:287 [inlined]
  [9] mul!
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:285 [inlined]
 [10] mul!
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:253 [inlined]
 [11] *
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:114 [inlined]
 [12] muladd(A::Reactant.TracedRArray{Float32, 2}, y::ConcreteRArray{Float32, 2}, z::Reactant.TracedRArray{Float32, 1})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:180
 [13] matmuladd
    @ ~/.julia/packages/LuxLib/I9RHW/src/impl/matmul.jl:12 [inlined]
 [14] matmuladd
    @ ~/.julia/packages/LuxLib/I9RHW/src/impl/matmul.jl:7 [inlined]
 [15] fused_dense
    @ ~/.julia/packages/LuxLib/I9RHW/src/impl/dense.jl:6 [inlined]
 [16] fused_dense_bias_activation
    @ ~/.julia/packages/LuxLib/I9RHW/src/api/dense.jl:35 [inlined]
 [17] RNNCell
    @ ~/.julia/packages/Lux/atwzZ/src/layers/recurrent.jl:291 [inlined]
 [18] (::RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True})(x::SubArray{Float32, 2, Reactant.TracedRArray{Float32, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}}, false}, ps::@NamedTuple{weight_ih::Reactant.TracedRArray{Float32, 2}, weight_hh::Reactant.TracedRArray{Float32, 2}, bias_ih::Reactant.TracedRArray{Float32, 1}, bias_hh::Reactant.TracedRArray{Float32, 1}}, st::@NamedTuple{rng::Xoshiro})
    @ Lux ~/.julia/packages/Lux/atwzZ/src/layers/recurrent.jl:277
 [19] apply
    @ ~/.julia/packages/LuxCore/IBKvY/src/LuxCore.jl:155 [inlined]
 [20] (::Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex})(x::Vector{SubArray{Float32, 2, Reactant.TracedRArray{Float32, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}}, false}}, ps::@NamedTuple{weight_ih::Reactant.TracedRArray{Float32, 2}, weight_hh::Reactant.TracedRArray{Float32, 2}, bias_ih::Reactant.TracedRArray{Float32, 1}, bias_hh::Reactant.TracedRArray{Float32, 1}}, st::@NamedTuple{rng::Xoshiro})
    @ Lux ~/.julia/packages/Lux/atwzZ/src/layers/recurrent.jl:118
 [21] apply
    @ ~/.julia/packages/LuxCore/IBKvY/src/LuxCore.jl:155 [inlined]
 [22] Recurrence
    @ ~/.julia/packages/Lux/atwzZ/src/layers/recurrent.jl:114 [inlined]
 [23] #apply#19
    @ /mnt/software/lux/Reactant.jl/src/utils.jl:33 [inlined]
 [24] apply
    @ /mnt/software/lux/Reactant.jl/src/utils.jl:32 [inlined]
 [25] (::Tuple{})(none::Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, none::Tuple{Reactant.TracedRArray{Float32, 3}, @NamedTuple{weight_ih::Reactant.TracedRArray{Float32, 2}, weight_hh::Reactant.TracedRArray{Float32, 2}, bias_ih::Reactant.TracedRArray{Float32, 1}, bias_hh::Reactant.TracedRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}})
    @ Base.Experimental ./<missing>:0
 [26] (::Reactant.var"#26#35"{Bool, typeof(Reactant.apply), Tuple{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, ConcreteRArray{Float32, 3}, @NamedTuple{weight_ih::ConcreteRArray{Float32, 2}, weight_hh::ConcreteRArray{Float32, 2}, bias_ih::ConcreteRArray{Float32, 1}, bias_hh::ConcreteRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}}, Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, Tuple{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, Reactant.TracedRArray{Float32, 3}, @NamedTuple{weight_ih::Reactant.TracedRArray{Float32, 2}, weight_hh::Reactant.TracedRArray{Float32, 2}, bias_ih::Reactant.TracedRArray{Float32, 1}, bias_hh::Reactant.TracedRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}}})()
    @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:139
 [27] block!(f::Reactant.var"#26#35"{Bool, typeof(Reactant.apply), Tuple{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, ConcreteRArray{Float32, 3}, @NamedTuple{weight_ih::ConcreteRArray{Float32, 2}, weight_hh::ConcreteRArray{Float32, 2}, bias_ih::ConcreteRArray{Float32, 1}, bias_hh::ConcreteRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}}, Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, Tuple{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, Reactant.TracedRArray{Float32, 3}, @NamedTuple{weight_ih::Reactant.TracedRArray{Float32, 2}, weight_hh::Reactant.TracedRArray{Float32, 2}, bias_ih::Reactant.TracedRArray{Float32, 1}, bias_hh::Reactant.TracedRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
 [28] make_mlir_fn(f::Function, args::Tuple{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, ConcreteRArray{Float32, 3}, @NamedTuple{weight_ih::ConcreteRArray{Float32, 2}, weight_hh::ConcreteRArray{Float32, 2}, bias_ih::ConcreteRArray{Float32, 1}, bias_hh::ConcreteRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool)
    @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:112
 [29] make_mlir_fn(f::Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, args::Vector{Any}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool)
    @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:48
 [30] make_mlir_fn
    @ /mnt/software/lux/Reactant.jl/src/utils.jl:36 [inlined]
 [31] #6
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:270 [inlined]
 [32] block!(f::Reactant.Compiler.var"#6#11"{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, Vector{Any}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
 [33] #5
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:269 [inlined]
 [34] mmodule!(f::Reactant.Compiler.var"#5#10"{Reactant.MLIR.IR.Module, Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, Vector{Any}}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Module.jl:93
 [35] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, args::Vector{Any}; optimize::Bool)
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:266
 [36] compile_mlir!
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:265 [inlined]
 [37] #2
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:260 [inlined]
 [38] context!(f::Reactant.Compiler.var"#2#3"{@Kwargs{optimize::Bool}, Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, Vector{Any}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:76
 [39] #compile_mlir#1
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:258 [inlined]
 [40] top-level scope
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:419
 [41] eval
    @ ./boot.jl:430 [inlined]
 [42] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base ./loading.jl:2643
 [43] invokelatest(::Any, ::Any, ::Vararg{Any}; kwargs::@Kwargs{})
    @ Base ./essentials.jl:1055
 [44] invokelatest(::Any, ::Any, ::Vararg{Any})
    @ Base ./essentials.jl:1052
 [45] inlineeval(m::Module, code::String, code_line::Int64, code_column::Int64, file::String; softscope::Bool)
    @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:271
 [46] (::VSCodeServer.var"#69#74"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:181
 [47] withpath(f::VSCodeServer.var"#69#74"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams}, path::String)
    @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/repl.jl:276
 [48] (::VSCodeServer.var"#68#73"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:179
 [49] hideprompt(f::VSCodeServer.var"#68#73"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})
    @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/repl.jl:38
 [50] #67
    @ ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:150 [inlined]
 [51] with_logstate(f::VSCodeServer.var"#67#72"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams}, logstate::Base.CoreLogging.LogState)
    @ Base.CoreLogging ./logging/logging.jl:522
 [52] with_logger
    @ ./logging/logging.jl:632 [inlined]
 [53] (::VSCodeServer.var"#66#71"{VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:263
 [54] #invokelatest#2
    @ ./essentials.jl:1055 [inlined]
 [55] invokelatest(::Any)
    @ Base ./essentials.jl:1052
in expression starting at /mnt/software/lux/Reactant.jl/envs/lux/rnn.jl:7

@avik-pal
Member Author

avik-pal commented Nov 4, 2024

The other, probably nicer, way is to write the init function with a copyto! into an array allocated with similar.
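
A minimal sketch of that alternative, assuming the hidden state can be allocated from x with similar (so it inherits x's array type, traced or concrete) and filled with copyto!; illustrative, not necessarily the code merged in #1026:

using Random

# Sketch: allocate the state from `x` so it has x's array type, then copy
# the host-initialized values in. Inside Reactant's compilation context,
# `similar(x, ...)` would yield a traced array, avoiding the
# ConcreteRArray detour that `get_device(x)` caused.
function init_rnn_hidden_state(rng::AbstractRNG, rnn, x::AbstractMatrix)
    h = similar(x, rnn.out_dims, Base.size(x, 2))
    copyto!(h, rnn.init_state(rng, rnn.out_dims, Base.size(x, 2)))
    return h
end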

@wsmoses
Contributor

wsmoses commented Nov 5, 2024

@avik-pal from the error log above it looks like we don't override the in-place mul! method (and we should).

Specifically:

  [9] mul!
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:285 [inlined]
 [10] mul!
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:253 [inlined]
 [11] *
    @ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:114 [inlined]
 [12] muladd(A::Reactant.TracedRArray{Float32, 2}, y::ConcreteRArray{Float32, 2}, z::Reactant.TracedRArray{Float32, 1})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:180
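
For illustration, one plausible shape of such an override, a minimal sketch rather than the actual patch; the real fix would also need methods covering mixed ConcreteRArray/TracedRArray operands like the ones in frames [6]–[12]:

using LinearAlgebra, Reactant

# Minimal sketch, not the actual patch: a 5-arg mul! for traced arrays so
# dispatch never falls through to LinearAlgebra's scalar
# _generic_matmatmul! loop (frame [6] in the trace above).
function LinearAlgebra.mul!(C::Reactant.TracedRArray{T,2},
        A::Reactant.TracedRArray{T,2}, B::Reactant.TracedRArray{T,2},
        α::Number, β::Number) where {T}
    # Traced `*` and broadcasting lower to stablehlo ops instead of
    # element-wise setindex! on the output.
    C .= α .* (A * B) .+ β .* C
    return C
end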

@avik-pal
Member Author

avik-pal commented Nov 7, 2024

After the patches, the compiled module is going to be:

Module:
module attributes {transform.with_named_sequence} {
  func.func @main(%arg0: tensor<12x2x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4xf32>, %arg4: tensor<4xf32>) -> tensor<12x4xf32> {
    %c = stablehlo.constant dense<11> : tensor<i64>
    %c_0 = stablehlo.constant dense<10> : tensor<i64>
    %c_1 = stablehlo.constant dense<9> : tensor<i64>
    %c_2 = stablehlo.constant dense<8> : tensor<i64>
    %c_3 = stablehlo.constant dense<7> : tensor<i64>
    %c_4 = stablehlo.constant dense<6> : tensor<i64>
    %c_5 = stablehlo.constant dense<5> : tensor<i64>
    %c_6 = stablehlo.constant dense<4> : tensor<i64>
    %c_7 = stablehlo.constant dense<3> : tensor<i64>
    %c_8 = stablehlo.constant dense<2> : tensor<i64>
    %c_9 = stablehlo.constant dense<1> : tensor<i64>
    %c_10 = stablehlo.constant dense<0> : tensor<i64>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<1x1xf32>
    %cst_11 = stablehlo.constant dense<0.000000e+00> : tensor<4x12xf32>
    %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<12x2x4xf32>) -> tensor<4x2x12xf32>
    %1 = stablehlo.dynamic_update_slice %cst_11, %cst, %c_10, %c_10 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %2 = stablehlo.dynamic_update_slice %1, %cst, %c_9, %c_10 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %3 = stablehlo.dynamic_update_slice %2, %cst, %c_8, %c_10 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %4 = stablehlo.dynamic_update_slice %3, %cst, %c_7, %c_10 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %5 = stablehlo.dynamic_update_slice %4, %cst, %c_10, %c_9 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %6 = stablehlo.dynamic_update_slice %5, %cst, %c_9, %c_9 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %7 = stablehlo.dynamic_update_slice %6, %cst, %c_8, %c_9 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %8 = stablehlo.dynamic_update_slice %7, %cst, %c_7, %c_9 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %9 = stablehlo.dynamic_update_slice %8, %cst, %c_10, %c_8 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %10 = stablehlo.dynamic_update_slice %9, %cst, %c_9, %c_8 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %11 = stablehlo.dynamic_update_slice %10, %cst, %c_8, %c_8 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %12 = stablehlo.dynamic_update_slice %11, %cst, %c_7, %c_8 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %13 = stablehlo.dynamic_update_slice %12, %cst, %c_10, %c_7 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %14 = stablehlo.dynamic_update_slice %13, %cst, %c_9, %c_7 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %15 = stablehlo.dynamic_update_slice %14, %cst, %c_8, %c_7 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %16 = stablehlo.dynamic_update_slice %15, %cst, %c_7, %c_7 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %17 = stablehlo.dynamic_update_slice %16, %cst, %c_10, %c_6 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %18 = stablehlo.dynamic_update_slice %17, %cst, %c_9, %c_6 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %19 = stablehlo.dynamic_update_slice %18, %cst, %c_8, %c_6 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %20 = stablehlo.dynamic_update_slice %19, %cst, %c_7, %c_6 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %21 = stablehlo.dynamic_update_slice %20, %cst, %c_10, %c_5 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %22 = stablehlo.dynamic_update_slice %21, %cst, %c_9, %c_5 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %23 = stablehlo.dynamic_update_slice %22, %cst, %c_8, %c_5 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %24 = stablehlo.dynamic_update_slice %23, %cst, %c_7, %c_5 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %25 = stablehlo.dynamic_update_slice %24, %cst, %c_10, %c_4 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %26 = stablehlo.dynamic_update_slice %25, %cst, %c_9, %c_4 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %27 = stablehlo.dynamic_update_slice %26, %cst, %c_8, %c_4 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %28 = stablehlo.dynamic_update_slice %27, %cst, %c_7, %c_4 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %29 = stablehlo.dynamic_update_slice %28, %cst, %c_10, %c_3 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %30 = stablehlo.dynamic_update_slice %29, %cst, %c_9, %c_3 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %31 = stablehlo.dynamic_update_slice %30, %cst, %c_8, %c_3 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %32 = stablehlo.dynamic_update_slice %31, %cst, %c_7, %c_3 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %33 = stablehlo.dynamic_update_slice %32, %cst, %c_10, %c_2 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %34 = stablehlo.dynamic_update_slice %33, %cst, %c_9, %c_2 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %35 = stablehlo.dynamic_update_slice %34, %cst, %c_8, %c_2 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %36 = stablehlo.dynamic_update_slice %35, %cst, %c_7, %c_2 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %37 = stablehlo.dynamic_update_slice %36, %cst, %c_10, %c_1 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %38 = stablehlo.dynamic_update_slice %37, %cst, %c_9, %c_1 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %39 = stablehlo.dynamic_update_slice %38, %cst, %c_8, %c_1 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %40 = stablehlo.dynamic_update_slice %39, %cst, %c_7, %c_1 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %41 = stablehlo.dynamic_update_slice %40, %cst, %c_10, %c_0 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %42 = stablehlo.dynamic_update_slice %41, %cst, %c_9, %c_0 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %43 = stablehlo.dynamic_update_slice %42, %cst, %c_8, %c_0 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %44 = stablehlo.dynamic_update_slice %43, %cst, %c_7, %c_0 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %45 = stablehlo.dynamic_update_slice %44, %cst, %c_10, %c : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %46 = stablehlo.dynamic_update_slice %45, %cst, %c_9, %c : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %47 = stablehlo.dynamic_update_slice %46, %cst, %c_8, %c : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %48 = stablehlo.dynamic_update_slice %47, %cst, %c_7, %c : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
    %49 = stablehlo.dot_general %arg2, %48, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x4xf32>, tensor<4x12xf32>) -> tensor<4x12xf32>
    %50 = stablehlo.broadcast_in_dim %arg4, dims = [0] : (tensor<4xf32>) -> tensor<4x12xf32>
    %51 = stablehlo.add %49, %50 : tensor<4x12xf32>
    %52 = stablehlo.slice %0 [0:4, 0:1, 0:12] : (tensor<4x2x12xf32>) -> tensor<4x1x12xf32>
    %53 = stablehlo.transpose %52, dims = [2, 1, 0] : (tensor<4x1x12xf32>) -> tensor<12x1x4xf32>
    %54 = stablehlo.reshape %53 : (tensor<12x1x4xf32>) -> tensor<12x4xf32>
    %55 = stablehlo.dot_general %arg1, %54, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<4x4xf32>, tensor<12x4xf32>) -> tensor<4x12xf32>
    %56 = stablehlo.broadcast_in_dim %arg3, dims = [0] : (tensor<4xf32>) -> tensor<4x12xf32>
    %57 = stablehlo.add %55, %56 : tensor<4x12xf32>
    %58 = stablehlo.add %51, %57 : tensor<4x12xf32>
    %59 = stablehlo.tanh %58 : tensor<4x12xf32>
    %60 = stablehlo.dot_general %arg2, %59, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x4xf32>, tensor<4x12xf32>) -> tensor<4x12xf32>
    %61 = stablehlo.add %60, %50 : tensor<4x12xf32>
    %62 = stablehlo.slice %0 [0:4, 1:2, 0:12] : (tensor<4x2x12xf32>) -> tensor<4x1x12xf32>
    %63 = stablehlo.transpose %62, dims = [2, 1, 0] : (tensor<4x1x12xf32>) -> tensor<12x1x4xf32>
    %64 = stablehlo.reshape %63 : (tensor<12x1x4xf32>) -> tensor<12x4xf32>
    %65 = stablehlo.dot_general %arg1, %64, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<4x4xf32>, tensor<12x4xf32>) -> tensor<4x12xf32>
    %66 = stablehlo.add %65, %56 : tensor<4x12xf32>
    %67 = stablehlo.add %61, %66 : tensor<4x12xf32>
    %68 = stablehlo.tanh %67 : tensor<4x12xf32>
    %69 = stablehlo.transpose %68, dims = [1, 0] : (tensor<4x12xf32>) -> tensor<12x4xf32>
    return %69 : tensor<12x4xf32>
  }
}

With the mapslices PR, this should be more compact.
