
[StableHLO] failed to legalize operation ttir.convolution #1342

Closed
mmanzoorTT opened this issue Nov 20, 2024 · 2 comments

Assignees: ajakovljevicTT
Labels: bug (Something isn't working), stablehlo conversion bug (Bugs in StableHLO conversion)

Comments

@mmanzoorTT (Contributor)

The stablehlo.convolution op is lowered to ttir.convolution, but lowering to TTNN fails because the ttir.convolution op is marked illegal.

StableHLO graph

module {
  func.func @main(%arg0: tensor<1x256x512xbf16>, %arg1: tensor<1024x256x1xbf16>, %arg2: tensor<1024xbf16>) -> tensor<1x1024x512xbf16> {
    %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 0]x[o, i, 0]->[b, f, 0], window = {stride = [1], pad = [[0, 0]], rhs_dilate = [1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x256x512xbf16>, tensor<1024x256x1xbf16>) -> tensor<1x1024x512xbf16>
    %1 = stablehlo.reshape %arg2 : (tensor<1024xbf16>) -> tensor<1024x1xbf16>
    %2 = stablehlo.broadcast_in_dim %0, dims = [0, 1, 2] : (tensor<1x1024x512xbf16>) -> tensor<1x1024x512xbf16>
    %3 = stablehlo.broadcast_in_dim %1, dims = [1, 2] : (tensor<1024x1xbf16>) -> tensor<1x1024x512xbf16>
    %4 = stablehlo.add %2, %3 : tensor<1x1024x512xbf16>
    return %4 : tensor<1x1024x512xbf16>
  }
}

TTIR graph

"#any_device_tile = #tt.operand_constraint<dram|l1|tile|none|interleaved|single_bank|height_sharded|width_sharded|block_sharded|any_layout|any_device_tile>
module {
  func.func @main(%arg0: tensor<1x256x512xbf16>, %arg1: tensor<1024x256x1xbf16>, %arg2: tensor<1024xbf16>) -> tensor<1x1024x512xbf16> {
    %0 = tensor.empty() : tensor<1x1024x512xbf16>
    %1 = ""ttir.convolution""(%arg0, %arg1, %0) <{batch_group_count = 1 : i64, convolution_layout = #ttir<convolution_layout input_batch = 0, input_feature = 1, input_spatial_dimensions = 2, kernel_output_feature = 0, kernel_input_feature = 1, kernel_spatial_dimensions = 2, output_batch = 0, output_feature = 1, output_spatial_dimensions = 2>, feature_group_count = 1 : i64, input_dilation = array<i64: 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile], padding = array<i64: 0, 0>, weight_dilation = array<i64: 1>, window_reversal = array<i1: false>, window_strides = array<i64: 1>}> : (tensor<1x256x512xbf16>, tensor<1024x256x1xbf16>, tensor<1x1024x512xbf16>) -> tensor<1x1024x512xbf16>
    %2 = tensor.empty() : tensor<1024x1xbf16>
    %3 = ""ttir.reshape""(%arg2, %2) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [1024 : i32, 1 : i32]}> : (tensor<1024xbf16>, tensor<1024x1xbf16>) -> tensor<1024x1xbf16>
    %4 = tensor.empty() : tensor<1x1024x512xbf16>
    %5 = ""ttir.broadcast""(%3, %4) <{dimension = [1, 2], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<1024x1xbf16>, tensor<1x1024x512xbf16>) -> tensor<1x1024x512xbf16>
    %6 = tensor.empty() : tensor<1x1024x512xbf16>
    %7 = ""ttir.add""(%1, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<1x1024x512xbf16>, tensor<1x1024x512xbf16>, tensor<1x1024x512xbf16>) -> tensor<1x1024x512xbf16>
    return %7 : tensor<1x1024x512xbf16>
  }
}

Error

results/mlir_tests/ttir/aten::convolution_0.mlir:5:10: error: failed to legalize operation 'ttir.convolution' that was explicitly marked illegal
    %1 = "ttir.convolution"(%arg0, %arg1, %0) <{batch_group_count = 1 : i64, convolution_layout = #ttir<convolution_layout input_batch = 0, input_feature = 1, input_spatial_dimensions = 2, kernel_output_feature = 0, kernel_input_feature = 1, kernel_spatial_dimensions = 2, output_batch = 0, output_feature = 1, output_spatial_dimensions = 2>, feature_group_count = 1 : i64, input_dilation = array<i64: 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile], padding = array<i64: 0, 0>, weight_dilation = array<i64: 1>, window_reversal = array<i1: false>, window_strides = array<i64: 1>}> : (tensor<1x256x512xbf16>, tensor<1024x256x1xbf16>, tensor<1x1024x512xbf16>) -> tensor<1x1024x512xbf16>
         ^
results/mlir_tests/ttir/aten::convolution_0.mlir:5:10: note: see current operation: %1 = "ttir.convolution"(%arg0, %arg1, %0) <{batch_group_count = 1 : i64, convolution_layout = #ttir<convolution_layout input_batch = 0, input_feature = 1, input_spatial_dimensions = 2, kernel_output_feature = 0, kernel_input_feature = 1, kernel_spatial_dimensions = 2, output_batch = 0, output_feature = 1, output_spatial_dimensions = 2>, feature_group_count = 1 : i64, input_dilation = array<i64: 1>, operand_constraints = [#tt.operand_constraint<dram|l1|tile|none|interleaved|single_bank|height_sharded|width_sharded|block_sharded|any_layout|any_device_tile>, #tt.operand_constraint<dram|l1|tile|none|interleaved|single_bank|height_sharded|width_sharded|block_sharded|any_layout|any_device_tile>, #tt.operand_constraint<dram|l1|tile|none|interleaved|single_bank|height_sharded|width_sharded|block_sharded|any_layout|any_device_tile>], padding = array<i64: 0, 0>, weight_dilation = array<i64: 1>, window_reversal = array<i1: false>, window_strides = array<i64: 1>}> : (tensor<1x256x512xbf16>, tensor<1024x256x1xbf16>, tensor<1x1024x512xbf16>) -> tensor<1x1024x512xbf16>
mmanzoorTT added the bug and stablehlo conversion bug labels on Nov 20, 2024
mmanzoorTT added this to the [Third Party] HLO + XLA milestone on Nov 20, 2024
ajakovljevicTT self-assigned this on Nov 21, 2024
@ajakovljevicTT (Contributor) commented Nov 22, 2024

In talks with @LPanosTT, we agreed that it makes sense to reshape the 1D tensors to 2D in order to do the convolution, and then reshape the output back. This is the same approach tt-metal currently takes for ttnn.Conv1d on the Python side; see the sketch below.
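
A minimal StableHLO-level sketch of that idea applied to the graph above, assuming a trailing unit spatial dimension is appended (the value names and the StableHLO-level framing are illustrative only; the actual fix would perform the analogous reshape / conv2d / reshape rewrite on the ttir ops during TTIR-to-TTNN conversion):

module {
  func.func @main(%arg0: tensor<1x256x512xbf16>, %arg1: tensor<1024x256x1xbf16>) -> tensor<1x1024x512xbf16> {
    // Append a unit spatial dimension so the 1-D convolution becomes a 2-D convolution.
    %lhs = stablehlo.reshape %arg0 : (tensor<1x256x512xbf16>) -> tensor<1x256x512x1xbf16>
    %rhs = stablehlo.reshape %arg1 : (tensor<1024x256x1xbf16>) -> tensor<1024x256x1x1xbf16>
    // Same convolution as before, with a second (trivial) spatial dimension.
    %conv = stablehlo.convolution(%lhs, %rhs) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x256x512x1xbf16>, tensor<1024x256x1x1xbf16>) -> tensor<1x1024x512x1xbf16>
    // Drop the unit spatial dimension to recover the original 1-D output shape.
    %out = stablehlo.reshape %conv : (tensor<1x1024x512x1xbf16>) -> tensor<1x1024x512xbf16>
    return %out : tensor<1x1024x512xbf16>
  }
}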

In addition, I will open an issue to the tt-metal folks to see if they can provide a c++ api to conv1d, which would make things easier on our side by shifting the reshapes into tt-metal. Issue: tenstorrent/tt-metal#15452

@ajakovljevicTT (Contributor)

Closed with #1438
