Compiling Recurrent Models with Reactant #1025
Error for full context:

```
1-element ExceptionStack:
LoadError: MethodError: no method matching Float32(::Reactant.TracedRNumber{Float32})
The type `Float32` exists, but no method is defined for this combination of argument types when trying to construct it.
Closest candidates are:
(::Type{T})(::T) where T<:Number
@ Core boot.jl:900
Float32(::IrrationalConstants.Log2π)
@ IrrationalConstants ~/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:113
Float32(::IrrationalConstants.Halfπ)
@ IrrationalConstants ~/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:113
...
Stacktrace:
[1] convert(::Type{Float32}, x::Reactant.TracedRNumber{Float32})
@ Base ./number.jl:7
[2] unsafe_store!(p::Ptr{Float32}, x::Reactant.TracedRNumber{Float32}, i::Int64)
@ Base ./pointer.jl:180
[3] setindex!(::ConcreteRArray{Float32, 2}, ::Reactant.TracedRNumber{Float32}, ::Int64, ::Int64)
@ Reactant /mnt/software/lux/Reactant.jl/src/ConcreteRArray.jl:233
[4] macro expansion
@ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:894 [inlined]
[5] macro expansion
@ ./simdloop.jl:77 [inlined]
[6] _generic_matmatmul!(C::ConcreteRArray{Float32, 2}, A::Reactant.TracedRArray{Float32, 2}, B::ConcreteRArray{Float32, 2}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
@ LinearAlgebra ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:893
[7] generic_matmatmul!
@ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:868 [inlined]
[8] _mul!
@ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:287 [inlined]
[9] mul!
@ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:285 [inlined]
[10] mul!
@ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:253 [inlined]
[11] *
@ ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:114 [inlined]
[12] muladd(A::Reactant.TracedRArray{Float32, 2}, y::ConcreteRArray{Float32, 2}, z::Reactant.TracedRArray{Float32, 1})
@ LinearAlgebra ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:180
[13] matmuladd
@ ~/.julia/packages/LuxLib/I9RHW/src/impl/matmul.jl:12 [inlined]
[14] matmuladd
@ ~/.julia/packages/LuxLib/I9RHW/src/impl/matmul.jl:7 [inlined]
[15] fused_dense
@ ~/.julia/packages/LuxLib/I9RHW/src/impl/dense.jl:6 [inlined]
[16] fused_dense_bias_activation
@ ~/.julia/packages/LuxLib/I9RHW/src/api/dense.jl:35 [inlined]
[17] RNNCell
@ ~/.julia/packages/Lux/atwzZ/src/layers/recurrent.jl:291 [inlined]
[18] (::RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True})(x::SubArray{Float32, 2, Reactant.TracedRArray{Float32, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}}, false}, ps::@NamedTuple{weight_ih::Reactant.TracedRArray{Float32, 2}, weight_hh::Reactant.TracedRArray{Float32, 2}, bias_ih::Reactant.TracedRArray{Float32, 1}, bias_hh::Reactant.TracedRArray{Float32, 1}}, st::@NamedTuple{rng::Xoshiro})
@ Lux ~/.julia/packages/Lux/atwzZ/src/layers/recurrent.jl:277
[19] apply
@ ~/.julia/packages/LuxCore/IBKvY/src/LuxCore.jl:155 [inlined]
[20] (::Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex})(x::Vector{SubArray{Float32, 2, Reactant.TracedRArray{Float32, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}}, false}}, ps::@NamedTuple{weight_ih::Reactant.TracedRArray{Float32, 2}, weight_hh::Reactant.TracedRArray{Float32, 2}, bias_ih::Reactant.TracedRArray{Float32, 1}, bias_hh::Reactant.TracedRArray{Float32, 1}}, st::@NamedTuple{rng::Xoshiro})
@ Lux ~/.julia/packages/Lux/atwzZ/src/layers/recurrent.jl:118
[21] apply
@ ~/.julia/packages/LuxCore/IBKvY/src/LuxCore.jl:155 [inlined]
[22] Recurrence
@ ~/.julia/packages/Lux/atwzZ/src/layers/recurrent.jl:114 [inlined]
[23] #apply#19
@ /mnt/software/lux/Reactant.jl/src/utils.jl:33 [inlined]
[24] apply
@ /mnt/software/lux/Reactant.jl/src/utils.jl:32 [inlined]
[25] (::Tuple{})(none::Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, none::Tuple{Reactant.TracedRArray{Float32, 3}, @NamedTuple{weight_ih::Reactant.TracedRArray{Float32, 2}, weight_hh::Reactant.TracedRArray{Float32, 2}, bias_ih::Reactant.TracedRArray{Float32, 1}, bias_hh::Reactant.TracedRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}})
@ Base.Experimental ./<missing>:0
[26] (::Reactant.var"#26#35"{Bool, typeof(Reactant.apply), Tuple{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, ConcreteRArray{Float32, 3}, @NamedTuple{weight_ih::ConcreteRArray{Float32, 2}, weight_hh::ConcreteRArray{Float32, 2}, bias_ih::ConcreteRArray{Float32, 1}, bias_hh::ConcreteRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}}, Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, Tuple{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, Reactant.TracedRArray{Float32, 3}, @NamedTuple{weight_ih::Reactant.TracedRArray{Float32, 2}, weight_hh::Reactant.TracedRArray{Float32, 2}, bias_ih::Reactant.TracedRArray{Float32, 1}, bias_hh::Reactant.TracedRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}}})()
@ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:139
[27] block!(f::Reactant.var"#26#35"{Bool, typeof(Reactant.apply), Tuple{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, ConcreteRArray{Float32, 3}, @NamedTuple{weight_ih::ConcreteRArray{Float32, 2}, weight_hh::ConcreteRArray{Float32, 2}, bias_ih::ConcreteRArray{Float32, 1}, bias_hh::ConcreteRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}}, Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, Tuple{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, Reactant.TracedRArray{Float32, 3}, @NamedTuple{weight_ih::Reactant.TracedRArray{Float32, 2}, weight_hh::Reactant.TracedRArray{Float32, 2}, bias_ih::Reactant.TracedRArray{Float32, 1}, bias_hh::Reactant.TracedRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}}}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
[28] make_mlir_fn(f::Function, args::Tuple{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, ConcreteRArray{Float32, 3}, @NamedTuple{weight_ih::ConcreteRArray{Float32, 2}, weight_hh::ConcreteRArray{Float32, 2}, bias_ih::ConcreteRArray{Float32, 1}, bias_hh::ConcreteRArray{Float32, 1}}, @NamedTuple{rng::Xoshiro}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool)
@ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:112
[29] make_mlir_fn(f::Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, args::Vector{Any}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool)
@ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:48
[30] make_mlir_fn
@ /mnt/software/lux/Reactant.jl/src/utils.jl:36 [inlined]
[31] #6
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:270 [inlined]
[32] block!(f::Reactant.Compiler.var"#6#11"{Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, Vector{Any}}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
[33] #5
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:269 [inlined]
[34] mmodule!(f::Reactant.Compiler.var"#5#10"{Reactant.MLIR.IR.Module, Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, Vector{Any}}, blk::Reactant.MLIR.IR.Module)
@ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Module.jl:93
[35] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, args::Vector{Any}; optimize::Bool)
@ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:266
[36] compile_mlir!
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:265 [inlined]
[37] #2
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:260 [inlined]
[38] context!(f::Reactant.Compiler.var"#2#3"{@Kwargs{optimize::Bool}, Recurrence{Static.False, RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, typeof(zeros32), Static.True}, BatchLastIndex}, Vector{Any}}, ctx::Reactant.MLIR.IR.Context)
@ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:76
[39] #compile_mlir#1
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:258 [inlined]
[40] top-level scope
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:419
[41] eval
@ ./boot.jl:430 [inlined]
[42] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
@ Base ./loading.jl:2643
[43] invokelatest(::Any, ::Any, ::Vararg{Any}; kwargs::@Kwargs{})
@ Base ./essentials.jl:1055
[44] invokelatest(::Any, ::Any, ::Vararg{Any})
@ Base ./essentials.jl:1052
[45] inlineeval(m::Module, code::String, code_line::Int64, code_column::Int64, file::String; softscope::Bool)
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:271
[46] (::VSCodeServer.var"#69#74"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:181
[47] withpath(f::VSCodeServer.var"#69#74"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams}, path::String)
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/repl.jl:276
[48] (::VSCodeServer.var"#68#73"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:179
[49] hideprompt(f::VSCodeServer.var"#68#73"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/repl.jl:38
[50] #67
@ ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:150 [inlined]
[51] with_logstate(f::VSCodeServer.var"#67#72"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams}, logstate::Base.CoreLogging.LogState)
@ Base.CoreLogging ./logging/logging.jl:522
[52] with_logger
@ ./logging/logging.jl:632 [inlined]
[53] (::VSCodeServer.var"#66#71"{VSCodeServer.ReplRunCodeRequestParams})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:263
[54] #invokelatest#2
@ ./essentials.jl:1055 [inlined]
[55] invokelatest(::Any)
@ Base ./essentials.jl:1052
in expression starting at /mnt/software/lux/Reactant.jl/envs/lux/rnn.jl:7
```
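For reference, the setup that triggers this looks roughly like the following. This is an assumed reconstruction of `rnn.jl` (not shown in the issue); the shapes and types are taken from the stack trace, and the `to_rarray` / `@code_hlo` usage is my guess at the invocation.

```julia
# Assumed reproduction sketch: Recurrence{RNNCell} with Float32 input of
# size (in_dims=4, seq_len=2, batch=12), matching the traced types above.
using Lux, Reactant, Random

model = Recurrence(RNNCell(4 => 4))
ps, st = Lux.setup(Random.Xoshiro(123), model)   # st carries the rng::Xoshiro seen in the trace
x = rand(Float32, 4, 2, 12)                      # BatchLastIndex: (features, time, batch)

x_ra = Reactant.ConcreteRArray(x)
ps_ra = Reactant.to_rarray(ps)                   # assumed helper for the parameter NamedTuple

# Tracing the cell's fused dense falls through to the generic scalar
# matmul loop and raises the MethodError above.
Reactant.@code_hlo model(x_ra, ps_ra, st)
```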
The other probably nicer way is to just write the …
@avik-pal from the error log above it looks like we don't override the in-place `mul!` method (and we should). Specifically:
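A minimal sketch of the kind of overload meant, assuming Reactant's traced `*` and broadcasting (which lower to stablehlo ops) accept these operands; the `mlir_data` assignment is an assumption about Reactant internals, and the real patch would also need methods for the mixed Concrete/Traced operand combinations seen in the trace:

```julia
# Sketch only: route the 5-arg mul! to traced operations instead of
# LinearAlgebra's scalar fallback loop, which setindex!-es TracedRNumbers
# into a concrete output buffer.
using LinearAlgebra
using Reactant: TracedRArray

function LinearAlgebra.mul!(C::TracedRArray{T,2}, A::TracedRArray{T,2},
                            B::TracedRArray{T,2}, α::Number, β::Number) where {T}
    tmp = α .* (A * B) .+ β .* C   # all traced ops; no scalar writes
    C.mlir_data = tmp.mlir_data    # in-place update of the traced value (assumed internal)
    return C
end
```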
After the patches it is going to be:

Module:
```mlir
module attributes {transform.with_named_sequence} {
func.func @main(%arg0: tensor<12x2x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4xf32>, %arg4: tensor<4xf32>) -> tensor<12x4xf32> {
%c = stablehlo.constant dense<11> : tensor<i64>
%c_0 = stablehlo.constant dense<10> : tensor<i64>
%c_1 = stablehlo.constant dense<9> : tensor<i64>
%c_2 = stablehlo.constant dense<8> : tensor<i64>
%c_3 = stablehlo.constant dense<7> : tensor<i64>
%c_4 = stablehlo.constant dense<6> : tensor<i64>
%c_5 = stablehlo.constant dense<5> : tensor<i64>
%c_6 = stablehlo.constant dense<4> : tensor<i64>
%c_7 = stablehlo.constant dense<3> : tensor<i64>
%c_8 = stablehlo.constant dense<2> : tensor<i64>
%c_9 = stablehlo.constant dense<1> : tensor<i64>
%c_10 = stablehlo.constant dense<0> : tensor<i64>
%cst = stablehlo.constant dense<0.000000e+00> : tensor<1x1xf32>
%cst_11 = stablehlo.constant dense<0.000000e+00> : tensor<4x12xf32>
%0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<12x2x4xf32>) -> tensor<4x2x12xf32>
%1 = stablehlo.dynamic_update_slice %cst_11, %cst, %c_10, %c_10 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%2 = stablehlo.dynamic_update_slice %1, %cst, %c_9, %c_10 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%3 = stablehlo.dynamic_update_slice %2, %cst, %c_8, %c_10 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%4 = stablehlo.dynamic_update_slice %3, %cst, %c_7, %c_10 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%5 = stablehlo.dynamic_update_slice %4, %cst, %c_10, %c_9 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%6 = stablehlo.dynamic_update_slice %5, %cst, %c_9, %c_9 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%7 = stablehlo.dynamic_update_slice %6, %cst, %c_8, %c_9 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%8 = stablehlo.dynamic_update_slice %7, %cst, %c_7, %c_9 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%9 = stablehlo.dynamic_update_slice %8, %cst, %c_10, %c_8 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%10 = stablehlo.dynamic_update_slice %9, %cst, %c_9, %c_8 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%11 = stablehlo.dynamic_update_slice %10, %cst, %c_8, %c_8 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%12 = stablehlo.dynamic_update_slice %11, %cst, %c_7, %c_8 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%13 = stablehlo.dynamic_update_slice %12, %cst, %c_10, %c_7 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%14 = stablehlo.dynamic_update_slice %13, %cst, %c_9, %c_7 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%15 = stablehlo.dynamic_update_slice %14, %cst, %c_8, %c_7 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%16 = stablehlo.dynamic_update_slice %15, %cst, %c_7, %c_7 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%17 = stablehlo.dynamic_update_slice %16, %cst, %c_10, %c_6 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%18 = stablehlo.dynamic_update_slice %17, %cst, %c_9, %c_6 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%19 = stablehlo.dynamic_update_slice %18, %cst, %c_8, %c_6 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%20 = stablehlo.dynamic_update_slice %19, %cst, %c_7, %c_6 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%21 = stablehlo.dynamic_update_slice %20, %cst, %c_10, %c_5 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%22 = stablehlo.dynamic_update_slice %21, %cst, %c_9, %c_5 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%23 = stablehlo.dynamic_update_slice %22, %cst, %c_8, %c_5 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%24 = stablehlo.dynamic_update_slice %23, %cst, %c_7, %c_5 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%25 = stablehlo.dynamic_update_slice %24, %cst, %c_10, %c_4 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%26 = stablehlo.dynamic_update_slice %25, %cst, %c_9, %c_4 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%27 = stablehlo.dynamic_update_slice %26, %cst, %c_8, %c_4 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%28 = stablehlo.dynamic_update_slice %27, %cst, %c_7, %c_4 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%29 = stablehlo.dynamic_update_slice %28, %cst, %c_10, %c_3 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%30 = stablehlo.dynamic_update_slice %29, %cst, %c_9, %c_3 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%31 = stablehlo.dynamic_update_slice %30, %cst, %c_8, %c_3 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%32 = stablehlo.dynamic_update_slice %31, %cst, %c_7, %c_3 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%33 = stablehlo.dynamic_update_slice %32, %cst, %c_10, %c_2 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%34 = stablehlo.dynamic_update_slice %33, %cst, %c_9, %c_2 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%35 = stablehlo.dynamic_update_slice %34, %cst, %c_8, %c_2 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%36 = stablehlo.dynamic_update_slice %35, %cst, %c_7, %c_2 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%37 = stablehlo.dynamic_update_slice %36, %cst, %c_10, %c_1 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%38 = stablehlo.dynamic_update_slice %37, %cst, %c_9, %c_1 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%39 = stablehlo.dynamic_update_slice %38, %cst, %c_8, %c_1 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%40 = stablehlo.dynamic_update_slice %39, %cst, %c_7, %c_1 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%41 = stablehlo.dynamic_update_slice %40, %cst, %c_10, %c_0 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%42 = stablehlo.dynamic_update_slice %41, %cst, %c_9, %c_0 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%43 = stablehlo.dynamic_update_slice %42, %cst, %c_8, %c_0 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%44 = stablehlo.dynamic_update_slice %43, %cst, %c_7, %c_0 : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%45 = stablehlo.dynamic_update_slice %44, %cst, %c_10, %c : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%46 = stablehlo.dynamic_update_slice %45, %cst, %c_9, %c : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%47 = stablehlo.dynamic_update_slice %46, %cst, %c_8, %c : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%48 = stablehlo.dynamic_update_slice %47, %cst, %c_7, %c : (tensor<4x12xf32>, tensor<1x1xf32>, tensor<i64>, tensor<i64>) -> tensor<4x12xf32>
%49 = stablehlo.dot_general %arg2, %48, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x4xf32>, tensor<4x12xf32>) -> tensor<4x12xf32>
%50 = stablehlo.broadcast_in_dim %arg4, dims = [0] : (tensor<4xf32>) -> tensor<4x12xf32>
%51 = stablehlo.add %49, %50 : tensor<4x12xf32>
%52 = stablehlo.slice %0 [0:4, 0:1, 0:12] : (tensor<4x2x12xf32>) -> tensor<4x1x12xf32>
%53 = stablehlo.transpose %52, dims = [2, 1, 0] : (tensor<4x1x12xf32>) -> tensor<12x1x4xf32>
%54 = stablehlo.reshape %53 : (tensor<12x1x4xf32>) -> tensor<12x4xf32>
%55 = stablehlo.dot_general %arg1, %54, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<4x4xf32>, tensor<12x4xf32>) -> tensor<4x12xf32>
%56 = stablehlo.broadcast_in_dim %arg3, dims = [0] : (tensor<4xf32>) -> tensor<4x12xf32>
%57 = stablehlo.add %55, %56 : tensor<4x12xf32>
%58 = stablehlo.add %51, %57 : tensor<4x12xf32>
%59 = stablehlo.tanh %58 : tensor<4x12xf32>
%60 = stablehlo.dot_general %arg2, %59, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x4xf32>, tensor<4x12xf32>) -> tensor<4x12xf32>
%61 = stablehlo.add %60, %50 : tensor<4x12xf32>
%62 = stablehlo.slice %0 [0:4, 1:2, 0:12] : (tensor<4x2x12xf32>) -> tensor<4x1x12xf32>
%63 = stablehlo.transpose %62, dims = [2, 1, 0] : (tensor<4x1x12xf32>) -> tensor<12x1x4xf32>
%64 = stablehlo.reshape %63 : (tensor<12x1x4xf32>) -> tensor<12x4xf32>
%65 = stablehlo.dot_general %arg1, %64, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<4x4xf32>, tensor<12x4xf32>) -> tensor<4x12xf32>
%66 = stablehlo.add %65, %56 : tensor<4x12xf32>
%67 = stablehlo.add %61, %66 : tensor<4x12xf32>
%68 = stablehlo.tanh %67 : tensor<4x12xf32>
%69 = stablehlo.transpose %68, dims = [1, 0] : (tensor<4x12xf32>) -> tensor<12x4xf32>
return %69 : tensor<12x4xf32>
}
}
```

With the mapslices PR this should be more compact.
The issue originates from the following function moving the data to a `ConcreteRArray` instead of a `TracedRArray` when run inside a compilation context. I could use `@reactant_override` to define custom dispatches for use inside `Reactant.compile`, but I am not sure that should be the recommended usage.
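For illustration, the workaround described might look like the following; the macro syntax is assumed, and `materialize_slices` is a hypothetical stand-in for the unnamed function that moves data to a `ConcreteRArray`:

```julia
# Hypothetical sketch of the @reactant_override workaround; the real target
# function and exact macro usage may differ.
Reactant.@reactant_override function materialize_slices(x::Reactant.TracedRArray)
    # Keep the value traced instead of copying it out to a ConcreteRArray,
    # so downstream ops stay inside the compiled graph.
    return x
end
```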