diff --git a/.gitignore b/.gitignore index 11cd44c593..26f2141f33 100644 --- a/.gitignore +++ b/.gitignore @@ -187,3 +187,4 @@ dmypy.json # Visual Studio Code Files .vscode .devcontainer +.vs diff --git a/docs/DevicePlacement-NNPA.md b/docs/DevicePlacement-NNPA.md index b10f8cf7d0..420028da04 100644 --- a/docs/DevicePlacement-NNPA.md +++ b/docs/DevicePlacement-NNPA.md @@ -8,7 +8,7 @@ Device placement is how the compiler place one operation on CPU or NNPA. There are two ways to know which device an operation is placed on: - Using `onnx-mlir --EmitONNXIR --maccel=NNPA model.onnx`, or -- Using `onnx-mlir --save-device-placement-file=cfg.json model.onnx`. +- Using `onnx-mlir --nnpa-save-device-placement-file=cfg.json model.onnx`. 1. Using `--EmitONNXIR --maccel=NNPA` @@ -25,7 +25,7 @@ Below is an example of the output of `--EmitONNXIR --maccel=NNPA`: %3 = "onnx.Sigmoid"(%2) {device="nnpa", onnx_node_name = "Sigmoid_0"} : (tensor) -> tensor ``` -2. Using `--save-device-placement-file=cfg.json` +2. Using `--nnpa-save-device-placement-file=cfg.json` The option is to save the device placement configuration into a JSON file. This option is convenient when users don't want to interrupt the compilation. @@ -63,15 +63,15 @@ Below is one example of a JSON file: ## Set device placement manually. -We allow users to force one opeartion to run on a specific device. However, at this moment, only placing on CPU is guaranted to be successful done. It means that even when `device=NNPA` is specified, it is not guaranted that the operation will run on NNPA. +We allow users to force one operation to run on a specific device. However, at this moment, only placing on CPU is guaranted to be successful done. It means that even when `device=NNPA` is specified, it is not guaranted that the operation will run on NNPA. There are two ways to change device of an operation: - by editing the output of `--EmitONNXIR --maccel=NNPA` directly and compile again. -- by passing a JSON file for device placement to the compiler by using `--load-device-placement-file=json`. +- by passing a JSON file for device placement to the compiler by using `--nnpa-load-device-placement-file=json`. For the former option, it is straighforward, just changing the value of the `device` attribute of an operation, for example, changing `device=nnpa` to `device=cpu`. -For the later option, users can obtain a template file from `--save-device-placement-file`, and use it as the starting point of modification. +For the later option, users can obtain a template file from `--nnpa-save-device-placement-file`, and use it as the starting point of modification. We use C++ std::regex_match function to match operations based on `node_type` and `onnx_node_name`. Both `node_type` and `onnx_node_name` must be satisfied. The JSON file will contain a list of records for each operation matching. The order of the records does matter. If one operation matches a record and is set device, it will not be set device again even when it matches the later records in the list. If one operation does not match a record but matches a later record, the operation is still set device by the later record. In other words, the device of an operation is set by the first matched record. @@ -161,7 +161,7 @@ func.func @test_load_config_file_all_on_cpu(%arg0: tensor) -> tensor< "onnx_node_name": "Sigmoid_0" }, { - "device": "nnpa", + "device": "cpu", "node_type": "onnx.Relu", "onnx_node_name": "Relu_(1|2)" } diff --git a/docs/Dialects/krnl.md b/docs/Dialects/krnl.md index 00f7624ff9..9a28e3250a 100644 --- a/docs/Dialects/krnl.md +++ b/docs/Dialects/krnl.md @@ -313,37 +313,6 @@ intend to optimize. | :----: | ----------- | «unnamed» | variadic of any type -### `krnl.dim` (KrnlDimOp) - -_Krnl dimensions operation._ - -Emits the dimension of a MemRef independent of the MemRef alloc: - -``` -"krnl.dim"(%memref, %index) -``` - -The index identifies the dimension within the shape which is going to be emitted. -Initially the krnl.dim operation depends on the alloc of the MemRef. -Unlike the std.dim operation which maintains a dependency on the alloc of the MemRef, the dimension emitted by krnl.dim will not depend on the alloc operation of the MemRef once the krnl.dim operation is lowered. - -Any changes to the original MemRef size after the krnl.dim has been lowered will not be picked up by the emitted dimension. This allows the original MemRef to be safely modified via code transformations or affine map normalization without the risk of changing the value already emitted via krnl.dim. - -Traits: MemRefsNormalizable - -#### Operands: - -| Operand | Description | -| :-----: | ----------- | -| `alloc` | memref of any type values -| `index` | index - -#### Results: - -| Result | Description | -| :----: | ----------- | -| `dimension` | index - ### `krnl.entry_point` (KrnlEntryPointOp) _Indicate ONNX entry point_ @@ -429,34 +398,6 @@ current tile being iterated over. | :----: | ----------- | | `ind_var_vals` | variadic of any type -### `krnl.getref` (KrnlGetRefOp) - -_Krnl a MemRef from within another MemRef starting at a specific offset._ - - Retrieves a MemRef from within another MemRef: - -``` - "krnl.getref"(%memref, %offset) -``` - The offset is an integer which is used as an index into the input MemRef. It works - just like an array index. - -Traits: MemRefsNormalizable - -#### Operands: - -| Operand | Description | -| :-----: | ----------- | -| `mempool` | memref of any type values -| `offset` | integer -| `value` | variadic of index - -#### Results: - -| Result | Description | -| :----: | ----------- | -| `output` | memref of any type values - ### `krnl.global` (KrnlGlobalOp) _Krnl global operation_ @@ -917,6 +858,28 @@ are nested imperfectly between an "eager" and a "lazy" loop. Traits: SingleBlock, SingleBlockImplicitTerminator +### `krnl.noValue` (KrnlNoneOp) + +_An operation representing the absence of a value._ + +This operation can be used to represent the absence of a value. It is +typically used as an argument to operators that have optional parameters, +and converted into nullptr while krnl to llvm lowering. +Typically it is used for optional arguments used in KrnlCallop. + +#### Attributes: + + + + +
AttributeMLIR TypeDescription
value::mlir::UnitAttrunit attribute
+ +#### Results: + +| Result | Description | +| :----: | ----------- | +| `none_val` | none type + ### `krnl.parallel` (KrnlParallelOp) _Mark Krnl loops as parallel loops_ @@ -1212,30 +1175,6 @@ Traits: MemRefsNormalizable | `seq` | memref of any type values | `index` | index -### `krnl.shape` (KrnlShapeOp) - -_Krnl operation to retrieve the shape of a MemRef._ - -Extracts the shape of a MemRef: -``` - "krnl.shape"(%memref) -``` -The return result is of `shape.type`. - -Traits: MemRefsNormalizable - -#### Operands: - -| Operand | Description | -| :-----: | ----------- | -| `alloc` | memref of any type values - -#### Results: - -| Result | Description | -| :----: | ----------- | -| `shape` | memref of any type values - ### `krnl.specialized_kernel` (KrnlSpecializedKernel) _Krnl specialized kernel op_ diff --git a/docs/Dialects/onnx.md b/docs/Dialects/onnx.md index 17d6d982fb..1fb17c0965 100644 --- a/docs/Dialects/onnx.md +++ b/docs/Dialects/onnx.md @@ -3883,7 +3883,7 @@ Effects: MemoryEffects::Effect{} | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 32-bit float values or tensor of 64-bit float values +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values #### Results: @@ -3907,7 +3907,7 @@ Effects: MemoryEffects::Effect{} | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of bfloat16 type values +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values #### Results: diff --git a/docs/SupportedONNXOps-NNPA.md b/docs/SupportedONNXOps-NNPA.md index 04697e27a8..c18b2c0b32 100644 --- a/docs/SupportedONNXOps-NNPA.md +++ b/docs/SupportedONNXOps-NNPA.md @@ -3,11 +3,11 @@ # Supported ONNX Operation for Target *NNPA*. -Onnx-mlir currently supports ONNX operations targeting up to opset 19. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes. +Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes. * Operations are defined by the [ONNX Standard](https://github.com/onnx/onnx/blob/main/docs/Operators.md). -* **Supported Opsets** indicates the lowest and highest opset a model may have for onnx-mlir to support compiling a model with the operator. - * A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 19. +* **Supported Opsets** indicates the lowest and highest opset a model may have for onnx-mlir to support compiling a model with the operator. + * A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 20. NNPA has hardware limitations in dimension index size and tensor size, which are described in [NNPALimit.h](../src/Accelerators/NNPA/Support/NNPALimit.h). They are large enough for normal use cases, but if your model exceeds the limitations, CPU is used instead of NNPA. diff --git a/docs/SupportedONNXOps-cpu.md b/docs/SupportedONNXOps-cpu.md index 7dfea58161..f6f7fdaeee 100644 --- a/docs/SupportedONNXOps-cpu.md +++ b/docs/SupportedONNXOps-cpu.md @@ -6,7 +6,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes. * Operations are defined by the [ONNX Standard](https://github.com/onnx/onnx/blob/main/docs/Operators.md). -* **Supported Opsets** indicates the lowest and highest opset a model may have for onnx-mlir to support compiling a model with the operator. +* **Supported Opsets** indicates the lowest and highest opset a model may have for onnx-mlir to support compiling a model with the operator. * A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 20. @@ -36,7 +36,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **BitwiseOr** |18 - * | | | | **BitwiseXor** |18 - * | | | | **BlackmanWindow** |none | | | | -| **Cast** |6 - 18 |Cast only between float and double types. Only ppc64le and MacOS platforms support float16. | | +| **Cast** |6 - * |Cast only between float and double types. Only ppc64le and MacOS platforms support float16. | | | **CastLike** |19 - * |CastLike only between float and double types. Only ppc64le and MacOS platforms support float16. | | | **CastMap** |none | | | | | **CategoryMapper** |none | | | | @@ -48,15 +48,15 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **Compress** |9 - * | | | | **Concat** |6 - * | | | | **ConcatFromSequence** |none | | | | -| **Constant** |6 - 18 | | | -| **ConstantOfShape** |9 - * | | | +| **Constant** |6 - * | | | +| **ConstantOfShape** |9 - 19 | | | | **Conv** |6 - * | | | | **ConvInteger** |none | | | | | **ConvTranspose** |6 - * |Unknown dimension in spatial dimensions (such as H and W) not supported. | | | **Cos** |7 - * | | | | **Cosh** |9 - * | | | | **CumSum** |11 - * | | | -| **DFT** |none | | | | +| **DFT** |17 - 19 | | | | **DeformConv** |none | | | | | **DepthToSpace** |13 - * | | | | **DequantizeLinear** |10 - * |Only support for per-tensor or layer dequantization. No support for per-axis dequantization. | | @@ -67,7 +67,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **DynamicQuantizeLinear** |11 - * | | | | **Einsum** |12 - * |Limited to the types supported by ReduceSum and MatMul (which we decompose to in most cases) which exclude integers with width < 32. | | | **Elu** |6 - * | | | -| **Equal** |7 - 18 | | | +| **Equal** |7 - * | | | | **Erf** |9 - * | | | | **Exp** |6 - * | | | | **Expand** |8 - * | | | @@ -98,8 +98,8 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **If** |16 - * |Sequence and Optional outputs are not supported. | | | **Imputer** |none | | | | | **InstanceNormalization** |6 - * | | | -| **IsInf** |10 - * | | | -| **IsNaN** |9 - * | | | +| **IsInf** |20 - * |Currently no support for float16 infinity value. Only for float32 and float64. | | +| **IsNaN** |20 - * | | | | **LRN** |6 - * | | | | **LSTM** |7 - * | | | | **LabelEncoder** |none | | | | @@ -142,11 +142,11 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **OptionalHasElement** |none | | | | | **Or** |7 - * | | | | **PRelu** |6 - * | | | -| **Pad** |6 - 18 |axes input not supported. | | +| **Pad** |6 - * |axes input not supported. | | | **Pow** |7 - * |No support for power with integer types. | | | **QLinearConv** |none | | | | | **QLinearMatMul** |none | | | | -| **QuantizeLinear** |10 - 18 |Do not support per-axis and i8 quantization. | | +| **QuantizeLinear** |10 - * |Do not support per-axis and i8 quantization. | | | **RNN** |7 - * | | | | **RandomNormal** |none | | | | | **RandomNormalLike** |none | | | | @@ -158,15 +158,15 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **ReduceL2** |13 - * |do_not_keep_dim not supported. | | | **ReduceLogSum** |13 - * |do_not_keep_dim not supported. | | | **ReduceLogSumExp** |13 - * |do_not_keep_dim not supported. | | -| **ReduceMax** |6 - * |do_not_keep_dim not supported. | | +| **ReduceMax** |6 - 19 |do_not_keep_dim not supported. | | | **ReduceMean** |6 - * |do_not_keep_dim not supported. | | -| **ReduceMin** |6 - * |do_not_keep_dim not supported. | | +| **ReduceMin** |6 - 19 |do_not_keep_dim not supported. | | | **ReduceProd** |13 - * |do_not_keep_dim not supported. | | | **ReduceSum** |6 - * |Default axis and do_not_keep_dim not supported. |Default axis and do_not_keep_dim temporarily removed due to changes in onnx 1.8.1. | | **ReduceSumSquare** |13 - * |Default axis and do_not_keep_dim not supported. | | | **Relu** |6 - * | | | | **Reshape** |6 - * |allowzero not supported. | | -| **Resize** |10 - 18 |Missing support for linear, cubic, crop, pytorch_half_pixel, and floor. Attributes antialias, axes and keep_aspect_ratio_policy are not supported. | | +| **Resize** |10 - * |Missing support for linear, cubic, crop, pytorch_half_pixel, and floor. Attributes antialias, axes and keep_aspect_ratio_policy are not supported. | | | **ReverseSequence** |10 - * | | | | **RoiAlign** |none | | | | | **Round** |11 - * | | | @@ -193,13 +193,13 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **Sin** |7 - * | | | | **Sinh** |9 - * | | | | **Size** |13 - * | | | -| **Slice** |13 - 18 |Axis must be a constant argument. |Add tests to slices, currently have none. | +| **Slice** |13 - * |Axis must be a constant argument. |Add tests to slices, currently have none. | | **Softmax** |6 - * | | | | **SoftmaxCrossEntropyLoss** |none | | | | | **Softplus** |6 - * | | | | **Softsign** |6 - * | | | | **SpaceToDepth** |13 - * | |Example works, the other is imprecise. To investigate. | -| **Split** |6 - 18 |Does not support static and dynamic shape, zero size splits. |Temporally removed due to changes in onnx 1.8.1. | +| **Split** |6 - * |Does not support static and dynamic shape, zero size splits. |Temporally removed due to changes in onnx 1.8.1. | | **SplitToSequence** |none | | | | | **Sqrt** |6 - * | | | | **Squeeze** |6 - * |Does not support static and dynamic shape. |Temporally removed due to changes in onnx 1.8.1. | diff --git a/docs/mnist_example/requirements.txt b/docs/mnist_example/requirements.txt index a2b91adde9..fe023b5220 100644 --- a/docs/mnist_example/requirements.txt +++ b/docs/mnist_example/requirements.txt @@ -1,4 +1,4 @@ numpy~=1.22.2 -pillow~=10.0.1 +pillow~=10.2.0 torch~=2.0.0 torchvision~=0.15.1 diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp index eed4ae0437..a9eba08122 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp @@ -54,7 +54,8 @@ llvm::cl::opt nnpaLoadDevicePlacementFile{ llvm::cl::desc( "Load device placement configuration from a JSON file. To " "have a template for the JSON file, use " - "-save-device-placement-file=cfg.json. Note that we can use regex for " + "--nnpa-save-device-placement-file=cfg.json. Note that we can use " + "regex for " "string values in the JSON file to match operations. The compiler uses " "C++ std::regex_match function for matching."), llvm::cl::init(""), llvm::cl::cat(OnnxMlirOptions)}; diff --git a/src/Accelerators/NNPA/Runtime/CMakeLists.txt b/src/Accelerators/NNPA/Runtime/CMakeLists.txt index 6afd581f64..cb183a9168 100644 --- a/src/Accelerators/NNPA/Runtime/CMakeLists.txt +++ b/src/Accelerators/NNPA/Runtime/CMakeLists.txt @@ -21,6 +21,6 @@ set_target_properties(RuntimeNNPA PROPERTIES LANGUAGE C POSITION_INDEPENDENT_CODE TRUE - COMPILE_OPTIONS "-O3;-fopenmp" + COMPILE_OPTIONS "-O3" ) diff --git a/src/Accelerators/NNPA/Runtime/zDNNExtension/Elementwise.c b/src/Accelerators/NNPA/Runtime/zDNNExtension/Elementwise.c index bd8db036d6..7d81770c87 100644 --- a/src/Accelerators/NNPA/Runtime/zDNNExtension/Elementwise.c +++ b/src/Accelerators/NNPA/Runtime/zDNNExtension/Elementwise.c @@ -67,19 +67,17 @@ static zdnn_status zdnn_unary_elementwise_common(const zdnn_ztensor *input, // We split e1 or e2 in (e4, e3, e2, e1). SplitAxis axis = selectSplitAxis(input); - SplitInfo splitInfoX = {.fullZTensor = input, - .axis = axis, - .numOfElemsPerTile = OMZTensorSplitSize}; - SplitInfo splitInfoY = {.fullZTensor = output, - .axis = axis, - .numOfElemsPerTile = OMZTensorSplitSize}; - initSplitInfo(&splitInfoX, true, "UnaryElementwise X"); - initSplitInfo(&splitInfoY, true, "UnaryElementwise Y"); + uint32_t splitSize = OMZTensorSplitSize; + SplitInfo siX, siY; + initSplitInfo(&siX, input, axis, splitSize, /*allocTileBuffers=*/true, + "UnaryElementwise X"); + initSplitInfo(&siY, output, axis, splitSize, /*allocTileBuffers=*/true, + "UnaryElementwise Y"); // Copy data from input to tiles. if (OMZTensorSplitDebug) start_time = clock(); - copyData(&splitInfoX, FULL_TO_TILES); + copyData(&siX, FULL_TO_TILES); if (OMZTensorSplitDebug) { end_time = clock(); splitTime = ((float)(end_time - start_time) / (float)CLOCKS_PER_SEC) * 1000; @@ -88,23 +86,24 @@ static zdnn_status zdnn_unary_elementwise_common(const zdnn_ztensor *input, // Call zdnn op on each tile. if (OMZTensorSplitDebug) start_time = clock(); - for (uint32_t i = 0; i < splitInfoX.numOfTiles; ++i) { - zdnn_ztensor *zxTensor = splitInfoX.tiles + i; - zdnn_ztensor *zyTensor = splitInfoY.tiles + i; + for (uint32_t i = 0; i < getNumOfTiles(&siX); ++i) { + zdnn_ztensor *zx = getTile(&siX, i); + zdnn_ztensor *zy = getTile(&siY, i); zdnn_status status; if (opType == ZDNN_EXP_EXT) - status = zdnn_exp(zxTensor, zyTensor); + status = zdnn_exp(zx, zy); else if (opType == ZDNN_LOG_EXT) - status = zdnn_log(zxTensor, zyTensor); + status = zdnn_log(zx, zy); else if (opType == ZDNN_RELU_EXT) - status = zdnn_relu(zxTensor, clippingValue, zyTensor); + status = zdnn_relu(zx, clippingValue, zy); else if (opType == ZDNN_SIGMOID_EXT) - status = zdnn_sigmoid(zxTensor, zyTensor); + status = zdnn_sigmoid(zx, zy); else if (opType == ZDNN_TANH_EXT) - status = zdnn_tanh(zxTensor, zyTensor); + status = zdnn_tanh(zx, zy); else status = ZDNN_UNAVAILABLE_FUNCTION; - assert(status == ZDNN_OK); + if (status != ZDNN_OK) + return status; } if (OMZTensorSplitDebug) { end_time = clock(); @@ -115,14 +114,14 @@ static zdnn_status zdnn_unary_elementwise_common(const zdnn_ztensor *input, // Copy data from tiles to the output. if (OMZTensorSplitDebug) start_time = clock(); - copyData(&splitInfoY, TILES_TO_FULL); + copyData(&siY, TILES_TO_FULL); if (OMZTensorSplitDebug) { end_time = clock(); mergeTime = ((float)(end_time - start_time) / (float)CLOCKS_PER_SEC) * 1000; } - FreeSplitInfoData(&splitInfoX); - FreeSplitInfoData(&splitInfoY); + freeSplitInfoData(&siX); + freeSplitInfoData(&siY); if (OMZTensorSplitDebug) printf( @@ -142,24 +141,20 @@ static zdnn_status zdnn_binary_elementwise_common(const zdnn_ztensor *inputA, // We split e1 or e2 in (e4, e3, e2, e1). SplitAxis axis = selectSplitAxis(inputA); - SplitInfo splitInfoA = {.fullZTensor = inputA, - .axis = axis, - .numOfElemsPerTile = OMZTensorSplitSize}; - SplitInfo splitInfoB = {.fullZTensor = inputB, - .axis = axis, - .numOfElemsPerTile = OMZTensorSplitSize}; - SplitInfo splitInfoY = {.fullZTensor = output, - .axis = axis, - .numOfElemsPerTile = OMZTensorSplitSize}; - initSplitInfo(&splitInfoA, true, "BinaryElementwise A"); - initSplitInfo(&splitInfoB, true, "BinaryElementwise B"); - initSplitInfo(&splitInfoY, true, "BinaryElementwise Y"); + uint32_t splitSize = OMZTensorSplitSize; + SplitInfo siA, siB, siY; + initSplitInfo(&siA, inputA, axis, splitSize, /*allocTileBuffers=*/true, + "BinaryElementwise A"); + initSplitInfo(&siB, inputB, axis, splitSize, /*allocTileBuffers=*/true, + "BinaryElementwise B"); + initSplitInfo(&siY, output, axis, splitSize, /*allocTileBuffers=*/true, + "BinaryElementwise Y"); // Copy data from inputs into tiles. if (OMZTensorSplitDebug) start_time = clock(); - copyData(&splitInfoA, FULL_TO_TILES); - copyData(&splitInfoB, FULL_TO_TILES); + copyData(&siA, FULL_TO_TILES); + copyData(&siB, FULL_TO_TILES); if (OMZTensorSplitDebug) { end_time = clock(); splitTime = ((float)(end_time - start_time) / (float)CLOCKS_PER_SEC) * 1000; @@ -168,26 +163,27 @@ static zdnn_status zdnn_binary_elementwise_common(const zdnn_ztensor *inputA, // Call zdnn op on each tile. if (OMZTensorSplitDebug) start_time = clock(); - for (uint32_t i = 0; i < splitInfoA.numOfTiles; ++i) { - zdnn_ztensor *zaTensor = splitInfoA.tiles + i; - zdnn_ztensor *zbTensor = splitInfoB.tiles + i; - zdnn_ztensor *zyTensor = splitInfoY.tiles + i; + for (uint32_t i = 0; i < getNumOfTiles(&siA); ++i) { + zdnn_ztensor *za = getTile(&siA, i); + zdnn_ztensor *zb = getTile(&siB, i); + zdnn_ztensor *zy = getTile(&siY, i); zdnn_status status; if (opType == ZDNN_ADD_EXT) - status = zdnn_add(zaTensor, zbTensor, zyTensor); + status = zdnn_add(za, zb, zy); else if (opType == ZDNN_SUB_EXT) - status = zdnn_sub(zaTensor, zbTensor, zyTensor); + status = zdnn_sub(za, zb, zy); else if (opType == ZDNN_MUL_EXT) - status = zdnn_mul(zaTensor, zbTensor, zyTensor); + status = zdnn_mul(za, zb, zy); else if (opType == ZDNN_DIV_EXT) - status = zdnn_div(zaTensor, zbTensor, zyTensor); + status = zdnn_div(za, zb, zy); else if (opType == ZDNN_MAX_EXT) - status = zdnn_max(zaTensor, zbTensor, zyTensor); + status = zdnn_max(za, zb, zy); else if (opType == ZDNN_MIN_EXT) - status = zdnn_min(zaTensor, zbTensor, zyTensor); + status = zdnn_min(za, zb, zy); else status = ZDNN_UNAVAILABLE_FUNCTION; - assert(status == ZDNN_OK); + if (status != ZDNN_OK) + return status; } if (OMZTensorSplitDebug) { end_time = clock(); @@ -198,15 +194,15 @@ static zdnn_status zdnn_binary_elementwise_common(const zdnn_ztensor *inputA, // Copy data from tiles to the output. if (OMZTensorSplitDebug) start_time = clock(); - copyData(&splitInfoY, TILES_TO_FULL); + copyData(&siY, TILES_TO_FULL); if (OMZTensorSplitDebug) { end_time = clock(); mergeTime = ((float)(end_time - start_time) / (float)CLOCKS_PER_SEC) * 1000; } - FreeSplitInfoData(&splitInfoA); - FreeSplitInfoData(&splitInfoB); - FreeSplitInfoData(&splitInfoY); + freeSplitInfoData(&siA); + freeSplitInfoData(&siB); + freeSplitInfoData(&siY); if (OMZTensorSplitDebug) printf("[BinaryElementwise] split, %f, compute, %f, merge, %f " diff --git a/src/Accelerators/NNPA/Runtime/zDNNExtension/MatMul.c b/src/Accelerators/NNPA/Runtime/zDNNExtension/MatMul.c index 7afd3ecc4e..22683a2cd2 100644 --- a/src/Accelerators/NNPA/Runtime/zDNNExtension/MatMul.c +++ b/src/Accelerators/NNPA/Runtime/zDNNExtension/MatMul.c @@ -61,32 +61,25 @@ static zdnn_status zdnn_matmul_op_common(const zdnn_ztensor *inputA, // For a MatMul of A(M,N)*B(N,P)+C(P), // We split M that is e2 in (e4, e3, e2, e1), and P that is e1. - SplitInfo splitInfoA = {.fullZTensor = inputA, - .axis = E2, - .numOfElemsPerTile = OMZTensorSplitSize}; - SplitInfo splitInfoB = {.fullZTensor = inputB, - .axis = E1, - .numOfElemsPerTile = OMZTensorSplitSize}; - SplitInfo splitInfoC = {.fullZTensor = inputC, - .axis = E1, - .numOfElemsPerTile = OMZTensorSplitSize}; - SplitInfo splitInfoY = {.fullZTensor = output, - .axis = E2, - .numOfElemsPerTile = OMZTensorSplitSize}; + uint32_t splitSize = OMZTensorSplitSize; + SplitInfo siA, siB, siC, siY; + initSplitInfo( + &siA, inputA, E2, splitSize, /*allocTileBuffers=*/true, "MatMul A"); + initSplitInfo( + &siB, inputB, E1, splitSize, /*allocTileBuffers=*/true, "MatMul B"); + initSplitInfo( + &siC, inputC, E1, splitSize, /*allocTileBuffers=*/true, "MatMul C"); + initSplitInfo( + &siY, output, E2, splitSize, /*allocTileBuffers=*/true, "MatMul Y"); if (OMZTensorSplitDebug) { gettimeofday(&start_t, NULL); } - initSplitInfo(&splitInfoA, true, "MatMul A"); - initSplitInfo(&splitInfoB, true, "MatMul B"); - initSplitInfo(&splitInfoC, true, "MatMul C"); - initSplitInfo(&splitInfoY, true, "MatMul Y"); - // Copy data from A, B, C into their tiles. - copyData(&splitInfoA, FULL_TO_TILES); - copyData(&splitInfoB, FULL_TO_TILES); - copyData(&splitInfoC, FULL_TO_TILES); + copyData(&siA, FULL_TO_TILES); + copyData(&siB, FULL_TO_TILES); + copyData(&siC, FULL_TO_TILES); if (OMZTensorSplitDebug) { gettimeofday(&start_t1, NULL); @@ -94,22 +87,20 @@ static zdnn_status zdnn_matmul_op_common(const zdnn_ztensor *inputA, // Call zdnn_matmul_op on each tile. // Iterate over the tiles along the first dim of A. - for (uint32_t i = 0; i < splitInfoA.numOfTiles; ++i) { - zdnn_ztensor *zaTensor = splitInfoA.tiles + i; - zdnn_ztensor *zyTensor = splitInfoY.tiles + i; - - SplitInfo splitInfoYB = {.fullZTensor = zyTensor, - .axis = E1, - .numOfElemsPerTile = OMZTensorSplitSize}; - initSplitInfo(&splitInfoYB, true, "MatMul YB"); + for (uint32_t i = 0; i < getNumOfTiles(&siA); ++i) { + zdnn_ztensor *za = getTile(&siA, i); + zdnn_ztensor *zy = getTile(&siY, i); + SplitInfo siYB; + initSplitInfo( + &siYB, zy, E1, splitSize, /*allocTileBuffers=*/true, "MatMul YB"); // Iterate over the tiles along the second dim of B. - for (uint32_t j = 0; j < splitInfoB.numOfTiles; ++j) { - zdnn_ztensor *zbTensor = splitInfoB.tiles + j; - zdnn_ztensor *zcTensor = splitInfoC.tiles + j; - zdnn_ztensor *zybTensor = splitInfoYB.tiles + j; - zdnn_status status = call_zdnn_matmul_op( - zaTensor, zbTensor, zcTensor, opType, zybTensor, isBcast); + for (uint32_t j = 0; j < getNumOfTiles(&siB); ++j) { + zdnn_ztensor *zb = getTile(&siB, j); + zdnn_ztensor *zc = getTile(&siC, j); + zdnn_ztensor *zyb = getTile(&siYB, j); + zdnn_status status = + call_zdnn_matmul_op(za, zb, zc, opType, zyb, isBcast); assert(status == ZDNN_OK); if (OMZTensorSplitDebug) { int cpuId = 0; @@ -122,8 +113,8 @@ static zdnn_status zdnn_matmul_op_common(const zdnn_ztensor *inputA, printf("thread [%u, %u] is on cpu %d\n", i, j, cpuId); } } - copyData(&splitInfoYB, TILES_TO_FULL); - FreeSplitInfoData(&splitInfoYB); + copyData(&siYB, TILES_TO_FULL); + freeSplitInfoData(&siYB); } if (OMZTensorSplitDebug) { @@ -133,13 +124,13 @@ static zdnn_status zdnn_matmul_op_common(const zdnn_ztensor *inputA, } // Copy data from the tiles back to the full ztensor. - copyData(&splitInfoY, TILES_TO_FULL); + copyData(&siY, TILES_TO_FULL); // Free temporary buffers. - FreeSplitInfoData(&splitInfoA); - FreeSplitInfoData(&splitInfoB); - FreeSplitInfoData(&splitInfoC); - FreeSplitInfoData(&splitInfoY); + freeSplitInfoData(&siA); + freeSplitInfoData(&siB); + freeSplitInfoData(&siC); + freeSplitInfoData(&siY); if (OMZTensorSplitDebug) { gettimeofday(&end_t, NULL); diff --git a/src/Accelerators/NNPA/Runtime/zDNNExtension/Softmax.c b/src/Accelerators/NNPA/Runtime/zDNNExtension/Softmax.c index 969e6bfabc..3386dccb1b 100644 --- a/src/Accelerators/NNPA/Runtime/zDNNExtension/Softmax.c +++ b/src/Accelerators/NNPA/Runtime/zDNNExtension/Softmax.c @@ -42,19 +42,17 @@ zdnn_status zdnn_softmax_ext(const zdnn_ztensor *input, void *save_area, clock_t start_time = 0, end_time = 0; // We split e4 in (e4, e3, e2, e1) to reuse the full buffer. - SplitInfo splitInfoX = {.fullZTensor = input, - .axis = E4, - .numOfElemsPerTile = OMZTensorSplitSize}; - SplitInfo splitInfoY = {.fullZTensor = output, - .axis = E4, - .numOfElemsPerTile = OMZTensorSplitSize}; - initSplitInfo(&splitInfoX, true, "Softmax X"); - initSplitInfo(&splitInfoY, true, "Softmax Y"); + uint32_t splitSize = OMZTensorSplitSize; + SplitInfo siX, siY; + initSplitInfo( + &siX, input, E4, splitSize, /*allocTileBuffers=*/true, "Softmax X"); + initSplitInfo( + &siY, output, E4, splitSize, /*allocTileBuffers=*/true, "Softmax Y"); // Copy data from input to tiles. if (OMZTensorSplitDebug) start_time = clock(); - copyData(&splitInfoX, FULL_TO_TILES); + copyData(&siX, FULL_TO_TILES); if (OMZTensorSplitDebug) { end_time = clock(); splitTime = ((float)(end_time - start_time) / (float)CLOCKS_PER_SEC) * 1000; @@ -64,11 +62,11 @@ zdnn_status zdnn_softmax_ext(const zdnn_ztensor *input, void *save_area, // TODO: could we reuse save_area in particular in the parallel scenario? if (OMZTensorSplitDebug) start_time = clock(); - for (uint32_t i = 0; i < splitInfoX.numOfTiles; ++i) { - zdnn_ztensor *zxTensor = splitInfoX.tiles + i; - zdnn_ztensor *zyTensor = splitInfoY.tiles + i; - zdnn_status status = zdnn_softmax(zxTensor, - (splitInfoX.reuseFullZTensor) ? save_area : NULL, act_func, zyTensor); + for (uint32_t i = 0; i < getNumOfTiles(&siX); ++i) { + zdnn_ztensor *zx = getTile(&siX, i); + zdnn_ztensor *zy = getTile(&siY, i); + zdnn_status status = zdnn_softmax( + zx, (siX.reuseFullZTensor) ? save_area : NULL, act_func, zy); assert(status == ZDNN_OK); } if (OMZTensorSplitDebug) { @@ -80,14 +78,14 @@ zdnn_status zdnn_softmax_ext(const zdnn_ztensor *input, void *save_area, // Copy data from tiles to the output. if (OMZTensorSplitDebug) start_time = clock(); - copyData(&splitInfoY, TILES_TO_FULL); + copyData(&siY, TILES_TO_FULL); if (OMZTensorSplitDebug) { end_time = clock(); mergeTime = ((float)(end_time - start_time) / (float)CLOCKS_PER_SEC) * 1000; } - FreeSplitInfoData(&splitInfoX); - FreeSplitInfoData(&splitInfoY); + freeSplitInfoData(&siX); + freeSplitInfoData(&siY); if (OMZTensorSplitDebug) printf("[Softmax] split, %f, compute, %f, merge, %f (milliseconds)\n", diff --git a/src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.c b/src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.c index 7ee4a19e58..eea1068bcb 100644 --- a/src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.c +++ b/src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.c @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// #include +#include #include #include #include @@ -58,6 +59,33 @@ static bool ZTensorSplitDebugFromEnv() { return enabled; } +// malloc_aligned_4k is from zdnn. +static void *malloc_aligned_4k(size_t size) { + // Request one more page + size of a pointer from the OS. + unsigned short extra_allocation = + (AIU_PAGESIZE_IN_BYTES - 1) + sizeof(void *); + + // Make sure size is reasonable. + if (!size || size > SIZE_MAX) { + return NULL; + } + + void *ptr = malloc(size + extra_allocation); + if (!ptr) { + perror("Error during malloc"); + fprintf(stderr, "errno = %d\n", errno); + return ptr; + } + + // Find the 4k boundary after ptr. + void *aligned_ptr = (void *)(((uintptr_t)ptr + extra_allocation) & + ~(AIU_PAGESIZE_IN_BYTES - 1)); + // Put the original malloc'd address right before aligned_ptr. + ((void **)aligned_ptr)[-1] = ptr; + + return aligned_ptr; +} + void zDNNExtensionInit() { OMZTensorSplitEnabled = ZTensorSplitEnabledFromEnv(); OMZTensorSplitDebug = ZTensorSplitDebugFromEnv(); @@ -103,27 +131,53 @@ static void getMappedShape(const zdnn_ztensor *t, MappedShape *shape) { (uint64_t)shape->d2 * (uint64_t)shape->d1 * (uint64_t)AIU_2BYTE_CELL_SIZE; uint64_t sizeFromBuffer = t->buffer_size; - assert(sizeFromDim == sizeFromBuffer && "buffer size mismatched"); + if (sizeFromDim != sizeFromBuffer) + assert(false && "buffer size mismatched"); } static uint32_t getMappedNumOfElemsPerTile(const SplitInfo *splitInfo) { // Mapping: (e4, e3, e2, e1) -> (e4, e1/64, e3, e2/32, 32, 64) switch (splitInfo->axis) { - case E4: + case (E4): return splitInfo->numOfElemsPerTile; - case E3: + case (E3): return splitInfo->numOfElemsPerTile; - case E2: + case (E2): return CEIL(splitInfo->numOfElemsPerTile, AIU_STICKS_PER_PAGE); - case E1: + case (E1): return CEIL(splitInfo->numOfElemsPerTile, AIU_2BYTE_CELLS_PER_STICK); - default: - omUnreachable(); } + omUnreachable(); return 0; } -zdnn_status initTileWithAlloc(const SplitInfo *splitInfo, uint32_t tileID) { +uint32_t getMDIS() { return zdnn_get_nnpa_max_dim_idx_size(); } + +zdnn_ztensor *getTile(const SplitInfo *splitInfo, uint32_t tileID) { + return splitInfo->tiles + tileID; +} + +uint32_t getNumOfTiles(const SplitInfo *splitInfo) { + return splitInfo->numOfTiles; +} + +zdnn_status allocTileBuffer(zdnn_ztensor *tile) { + if (!(tile->buffer = malloc_aligned_4k(tile->buffer_size))) + return ZDNN_ALLOCATION_FAILURE; + return ZDNN_OK; +} + +void freeTileBuffer(zdnn_ztensor *tile) { + if (tile->buffer) + zdnn_free_ztensor_buffer(tile); +} + +void *getTileBuffer(zdnn_ztensor *tile) { return tile->buffer; } + +void setTileBuffer(zdnn_ztensor *tile, void *buffer) { tile->buffer = buffer; } + +zdnn_status initTile( + const SplitInfo *splitInfo, uint32_t tileID, bool allocBuffer) { const zdnn_ztensor *fullZTensor = splitInfo->fullZTensor; SplitAxis axis = splitInfo->axis; @@ -223,11 +277,16 @@ zdnn_status initTileWithAlloc(const SplitInfo *splitInfo, uint32_t tileID) { if (status != ZDNN_OK) return status; + // Initialize the tile. + zdnn_init_ztensor(preTransDesc, transDesc, tile); + + // The tile is already transformed. + tile->is_transformed = true; + + // Set a buffer size for the tile. + tile->buffer_size = zdnn_getsize_ztensor(transDesc); if (splitInfo->reuseFullBuffer) { // No need to alloc buffers if reuseFullZTensor. - zdnn_init_ztensor(preTransDesc, transDesc, tile); - // Set a buffer size for the tile. - tile->buffer_size = zdnn_getsize_ztensor(transDesc); // Set a buffer for the tile. // All tiles except the last one have the same buffer size. // The offset for the last tile is simple "totalSize - lastSize". @@ -240,27 +299,26 @@ zdnn_status initTileWithAlloc(const SplitInfo *splitInfo, uint32_t tileID) { assert( ((reuseBufferOffset + tile->buffer_size) <= fullZTensor->buffer_size) && "Tile buffer is outside the original buffer"); - status = ZDNN_OK; - } else { - // Init a zTensor with malloc. - status = zdnn_init_ztensor_with_malloc(preTransDesc, transDesc, tile); + return ZDNN_OK; } - tile->is_transformed = true; - return status; + if (allocBuffer) + return allocTileBuffer(tile); + + return ZDNN_OK; } void freeTileData(const SplitInfo *splitInfo, uint32_t tileID) { - zdnn_ztensor *t = splitInfo->tiles + tileID; + zdnn_ztensor *tile = splitInfo->tiles + tileID; // Free the tile buffer if it has its own buffer. if (!splitInfo->reuseFullBuffer) - zdnn_free_ztensor_buffer(t); + freeTileBuffer(tile); // Free the tile descriptors it has its own ztensor. if (!splitInfo->reuseFullZTensor) { // We allocated one buffer for both two descriptors, so just one free is // enought. - if (t->pre_transformed_desc) - free(t->pre_transformed_desc); + if (tile->pre_transformed_desc) + free(tile->pre_transformed_desc); } } @@ -461,20 +519,16 @@ static void copyDataForTileScalar( return; } -bool initSplitInfo(SplitInfo *splitInfo, bool initTiles, const char *tag) { - // Check required information. - assert((splitInfo->axis == E1 || splitInfo->axis == E2 || - splitInfo->axis == E3 || splitInfo->axis == E4) && - "Invalid split axis"); - assert(splitInfo->fullZTensor && "The full ztensor is null"); - assert(splitInfo->numOfElemsPerTile && "numOfElemsPerTile was not set"); - - // fullZTensor. - const zdnn_ztensor *fullZTensor = splitInfo->fullZTensor; - zdnn_data_layouts layout = fullZTensor->transformed_desc->layout; +bool initSplitInfo(SplitInfo *splitInfo, const zdnn_ztensor *fullZTensor, + SplitAxis axis, uint32_t numOfElemsPerTile, bool allocTileBuffers, + const char *tag) { + splitInfo->axis = axis; + splitInfo->fullZTensor = fullZTensor; + splitInfo->numOfElemsPerTile = numOfElemsPerTile; // Splitting has not yet been supported for the following cases, so redirect // to the original zdnn function by setting splitInfo->numOfTiles = 1. + zdnn_data_layouts layout = fullZTensor->transformed_desc->layout; bool isNotSupported = (layout == ZDNN_FICO) || (layout == ZDNN_BIDIR_ZRH) || (layout == ZDNN_BIDIR_FICO) || (layout == ZDNN_ZRH) || (layout == ZDNN_4DS); @@ -484,7 +538,7 @@ bool initSplitInfo(SplitInfo *splitInfo, bool initTiles, const char *tag) { splitInfo->numOfTiles = 1; else { uint32_t totalNumOfElems = getUnmappedDim(fullZTensor, splitInfo->axis); - splitInfo->numOfTiles = CEIL(totalNumOfElems, splitInfo->numOfElemsPerTile); + splitInfo->numOfTiles = CEIL(totalNumOfElems, numOfElemsPerTile); } // reuseFullZTensor. @@ -492,7 +546,7 @@ bool initSplitInfo(SplitInfo *splitInfo, bool initTiles, const char *tag) { // No split benefit. splitInfo->reuseFullZTensor = true; splitInfo->reuseFullBuffer = true; - splitInfo->tiles = fullZTensor; + splitInfo->tiles = (zdnn_ztensor *)fullZTensor; if (OMZTensorSplitDebug) printSplitInfo(splitInfo, tag); return false; @@ -502,7 +556,7 @@ bool initSplitInfo(SplitInfo *splitInfo, bool initTiles, const char *tag) { // reuseFullBuffer. // (e4, e3, e2, e1) -> (d6=e4, d5=e1/64, d4=e3, d3=e2/32, d2=32, d1=64) splitInfo->reuseFullBuffer = false; - if (splitInfo->axis == E4) { + if (axis == E4) { // Always reuse if splitting on e4 (batchsize). splitInfo->reuseFullBuffer = true; } else { @@ -510,15 +564,15 @@ bool initSplitInfo(SplitInfo *splitInfo, bool initTiles, const char *tag) { MappedShape shapeOfFull; getMappedShape(splitInfo->fullZTensor, &shapeOfFull); if (shapeOfFull.d6 == 1) { - if (splitInfo->axis == E1) { + if (axis == E1) { splitInfo->reuseFullBuffer = true; } else { if (shapeOfFull.d5 == 1) { - if (splitInfo->axis == E3) { + if (axis == E3) { splitInfo->reuseFullBuffer = true; } else { if (shapeOfFull.d4 == 1) { - if (splitInfo->axis == E2) + if (axis == E2) splitInfo->reuseFullBuffer = true; } } @@ -531,11 +585,10 @@ bool initSplitInfo(SplitInfo *splitInfo, bool initTiles, const char *tag) { splitInfo->tiles = malloc(splitInfo->numOfTiles * sizeof(zdnn_ztensor)); assert(splitInfo->tiles && "Failed to allocate tile ztensors"); - if (initTiles) { - for (uint32_t i = 0; i < splitInfo->numOfTiles; ++i) { - zdnn_status status = initTileWithAlloc(splitInfo, i); - assert(status == ZDNN_OK && "Failed to initialize a tile"); - } + for (uint32_t i = 0; i < splitInfo->numOfTiles; ++i) { + zdnn_status status = initTile(splitInfo, i, allocTileBuffers); + if (status != ZDNN_OK) + assert(false && "Failed to initialize a tile"); } if (OMZTensorSplitDebug) @@ -544,7 +597,7 @@ bool initSplitInfo(SplitInfo *splitInfo, bool initTiles, const char *tag) { return true; } -void FreeSplitInfoData(SplitInfo *splitInfo) { +void freeSplitInfoData(SplitInfo *splitInfo) { if (splitInfo->reuseFullZTensor) return; diff --git a/src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.h b/src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.h index 12a2c7655b..11304c3f6b 100644 --- a/src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.h +++ b/src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.h @@ -147,6 +147,13 @@ inline void omUnreachable() { */ void getUnmappedShape(const zdnn_ztensor *t, UnmappedShape *shape); +/** + * \brief Get the NNPA maximum dimension index size. + * + * @return the NNPA maximum dimension index size. + */ +uint32_t getMDIS(); + /** * \brief Initialize a SplitInfo struct. * @@ -156,25 +163,19 @@ void getUnmappedShape(const zdnn_ztensor *t, UnmappedShape *shape); * numOfElemsPerTile to be defined as they are effectively input parameters of * this function. * - * Make sure to call FreeSplitInfoData to free buffers. + * Make sure to call freeSplitInfoData to free buffers. * * @param splitInfo information for splitting - * @param initTiles whether initialize ztensors for tiles or not. + * @param splitInfo the full ztensor that will be splitted + * @param axis dimension to split fullZTensor + * @param numOfElemsPerTile value is used to split the axis equally + * @param allocTileBuffers whether alloc buffers for the ztensor tiles or not * @param tag a string to use when printing debug info - * @return true if the ztensor is splitable. Otherwise, false + * @return true if the full ztensor is splitable. Otherwise, false */ -bool initSplitInfo(SplitInfo *splitInfo, bool initTile, const char *tag); - -/** - * \brief Initialize a SplitInfo struct. - * - * This will initialize a ztensor for a specific tile. - * - * @param splitInfo information for splitting - * @param tileID the id of a tile in the range of [0, numOfTiles - 1] - * @return zdnn_status - */ -zdnn_status initTileWithAlloc(const SplitInfo *splitInfo, uint32_t tileID); +bool initSplitInfo(SplitInfo *splitInfo, const zdnn_ztensor *fullZTensor, + SplitAxis axis, uint32_t numOfElemsPerTile, bool allocTileBuffers, + const char *tag); /** * \brief Free ztensor tile data. @@ -193,7 +194,7 @@ void freeTileData(const SplitInfo *splitInfo, uint32_t tileID); * * @param splitInfo split information */ -void FreeSplitInfoData(SplitInfo *splitInfo); +void freeSplitInfoData(SplitInfo *splitInfo); /** * \brief Print SplitInfo. @@ -202,6 +203,53 @@ void FreeSplitInfoData(SplitInfo *splitInfo); */ void printSplitInfo(const SplitInfo *splitInfo, const char *tag); +/** + * \brief Allocate memory for the tile buffer. + * + * @param tile tile->buffer will point to the allocated buffer + * @return zdnn_status + */ +zdnn_status allocTileBuffer(zdnn_ztensor *tile); + +/** + * \brief Free memory for the tile buffer. + * + * @param tile ztensor tile + */ +void freeTileBuffer(zdnn_ztensor *tile); + +/** + * \brief Get a pointer pointing to the tile buffer. + * + * @param tile ztensor of the tile + * @return a pointer to the tile buffer + */ +void *getTileBuffer(zdnn_ztensor *tile); + +/** + * \brief Set the tile buffer pointing to the given buffer. + * + * @param tile ztensor of the tile. + */ +void setTileBuffer(zdnn_ztensor *tile, void *buffer); + +/** + * \brief Get a pointer pointing to a tile. + * + * @param splitInfo information for splitting + * @param tileID the id of a tile in the range of [0, numOfTiles - 1] + * @return a pointer to the tile. + */ +zdnn_ztensor *getTile(const SplitInfo *splitInfo, uint32_t tileID); + +/** + * \brief Get the number of tiles. + * + * @param splitInfo information for splitting + * @return the number of tiles. + */ +uint32_t getNumOfTiles(const SplitInfo *splitInfo); + /** * \brief Copy data between the full ztensor and its tiles. * diff --git a/src/Builder/OpBuildTable.inc b/src/Builder/OpBuildTable.inc index e1ebf9c53b..fa8e8f8926 100644 --- a/src/Builder/OpBuildTable.inc +++ b/src/Builder/OpBuildTable.inc @@ -90,8 +90,8 @@ op_dialect_version_map_["Identity"] = {19}; op_dialect_version_map_["If"] = {19}; op_dialect_version_map_["Imputer"] = {1}; op_dialect_version_map_["InstanceNormalization"] = {6}; -op_dialect_version_map_["IsInf"] = {10}; -op_dialect_version_map_["IsNaN"] = {13}; +op_dialect_version_map_["IsInf"] = {20}; +op_dialect_version_map_["IsNaN"] = {20}; op_dialect_version_map_["LayerNormalization"] = {17}; op_dialect_version_map_["LRN"] = {13}; op_dialect_version_map_["LSTM"] = {14}; diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp index 9cd51e0dd4..e3322c49f2 100644 --- a/src/Compiler/CompilerPasses.cpp +++ b/src/Compiler/CompilerPasses.cpp @@ -193,9 +193,6 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, // from ONNX dialect to Standard dialect exposes additional canonicalization // opportunities. pm.addPass(mlir::createCanonicalizerPass()); - pm.addNestedPass( - onnx_mlir::createDisconnectKrnlDimFromAllocPass()); - pm.addPass(mlir::createCanonicalizerPass()); } void addKrnlToAffinePasses(mlir::PassManager &pm) { @@ -315,4 +312,4 @@ void addPasses(mlir::OwningOpRef &module, mlir::PassManager &pm, addKrnlToLLVMPasses(pm, outputNameNoExt, /*enableCSE=*/true); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/Compiler/CompilerUtils.cpp b/src/Compiler/CompilerUtils.cpp index 3768644c2b..a4dc5ab7ba 100644 --- a/src/Compiler/CompilerUtils.cpp +++ b/src/Compiler/CompilerUtils.cpp @@ -653,8 +653,10 @@ static void outputModule(mlir::OwningOpRef &module, raw_ostream &os, mlir::OpPrintingFlags flags; if (preserveLocations) flags.enableDebugInfo(); - if (largeElementLimit >= 0) + if (largeElementLimit >= 0) { flags.elideLargeElementsAttrs(largeElementLimit); + flags.elideLargeResourceString(largeElementLimit); + } module->print(os, flags); } diff --git a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp index b5e70b3c73..9bb6f27439 100644 --- a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp +++ b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp @@ -773,8 +773,6 @@ void ConvertKrnlToAffinePass::runOnOperation() { ConversionTarget target(*ctx); // Legal/illegal ops. target.addIllegalOp(); - // krnl.dim operations must be lowered prior to this pass. - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/src/Conversion/KrnlToLLVM/CMakeLists.txt b/src/Conversion/KrnlToLLVM/CMakeLists.txt index f6f2703be2..52a583552f 100644 --- a/src/Conversion/KrnlToLLVM/CMakeLists.txt +++ b/src/Conversion/KrnlToLLVM/CMakeLists.txt @@ -5,7 +5,6 @@ add_onnx_mlir_library(OMKrnlToLLVM KrnlFindIndex.cpp KrnlCall.cpp KrnlEntryPoint.cpp - KrnlGetRef.cpp KrnlGlobal.cpp KrnlInstrument.cpp KrnlMemcpy.cpp diff --git a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp index 9b71bce504..3aacc57ff0 100644 --- a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp +++ b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp @@ -956,7 +956,6 @@ void populateKrnlToLLVMConversion(LLVMTypeConverter &typeConverter, krnl::populateLoweringKrnlCallOpPattern(typeConverter, patterns, ctx); krnl::populateLoweringKrnlFindIndexOpPattern(typeConverter, patterns, ctx); krnl::populateLoweringKrnlGlobalOpPattern(typeConverter, patterns, ctx); - krnl::populateLoweringKrnlGetRefOpPattern(typeConverter, patterns, ctx); krnl::populateLoweringKrnlInstrumentOpPattern(typeConverter, patterns, ctx); krnl::populateLoweringKrnlMemcpyOpPattern(typeConverter, patterns, ctx); krnl::populateLoweringKrnlPrintOpPattern(typeConverter, patterns, ctx); diff --git a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp index 6a82708bac..d19f971942 100644 --- a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp +++ b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp @@ -67,9 +67,6 @@ void populateLoweringKrnlFindIndexOpPattern( mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); -void populateLoweringKrnlGetRefOpPattern(mlir::LLVMTypeConverter &typeConverter, - mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); - void populateLoweringKrnlGlobalOpPattern(mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); diff --git a/src/Conversion/KrnlToLLVM/KrnlGetRef.cpp b/src/Conversion/KrnlToLLVM/KrnlGetRef.cpp deleted file mode 100644 index 55ec38411c..0000000000 --- a/src/Conversion/KrnlToLLVM/KrnlGetRef.cpp +++ /dev/null @@ -1,172 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - */ - -//===------ KrnlGetRefOp.cpp - Lower KrnlGetRefOp -------------------------===// -// -// Copyright 2019-2022 The IBM Research Authors. -// -// ============================================================================= -// -// This file lowers the KrnlGetRefOp operator. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Conversion/LLVMCommon/Pattern.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" - -#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" -#include "src/Dialect/Krnl/KrnlHelper.hpp" -#include "src/Dialect/Krnl/KrnlOps.hpp" -#include "src/Dialect/Mlir/DialectBuilder.hpp" -#include "src/Support/KrnlSupport.hpp" - -#include "llvm/Support/Debug.h" - -#define DEBUG_TYPE "krnl_to_llvm" - -using namespace mlir; - -namespace onnx_mlir { -namespace krnl { - -class KrnlGetRefOpLowering : public ConvertToLLVMPattern { -public: - using ConvertToLLVMPattern::createIndexAttrConstant; - using ConvertToLLVMPattern::getIndexType; - - explicit KrnlGetRefOpLowering( - LLVMTypeConverter &typeConverter, MLIRContext *context) - : ConvertToLLVMPattern( - KrnlGetRefOp::getOperationName(), context, typeConverter) {} - - LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - MLIRContext *context = rewriter.getContext(); - MultiDialectBuilder create(rewriter, loc); - - KrnlGetRefOpAdaptor operandAdaptor(operands); - - // This is the type of the krnl.getref output. This type is used - // for the type of the internal MemRef. - auto type = op->getResult(0).getType(); - auto memRefTy = type.cast(); - - // auto llvmMemRefType = typeConverter->convertType(type).cast(); - auto outputElementType = - typeConverter->convertType(memRefTy.getElementType()); - - // This is the start of the memory pool containing the output MemRef. - Type memPoolType = operandAdaptor.getMempool() - .getType() - .cast() - .getBody()[1]; - Value alignedMemPoolBase = - create.llvm.extractValue(memPoolType, operandAdaptor.getMempool(), {1}); - - // Get pointer using the offset. - auto llvmOutputElementType = outputElementType.cast(); - auto offset = operandAdaptor.getOffset(); - auto llvmMemPoolType = typeConverter->convertType(memPoolType).cast(); - auto outputMemPoolTypePtrAlloc = - create.llvm.getElemPtr(llvmMemPoolType, llvmOutputElementType, - alignedMemPoolBase, ArrayRef{offset}); - - // Bitcast to output MemRef type i.e. from i8* to the element type - // of the output MemRef. - Value outputTypedPtrAlloc = - create.llvm.bitcast(getPointerType(context, llvmOutputElementType), - outputMemPoolTypePtrAlloc); - - // Handle the static case. - if (hasAllConstantDimensions(memRefTy)) { - // Create llvm MemRef from original MemRef and fill the data pointers. - auto llvmMemRef = MemRefDescriptor::fromStaticShape( - rewriter, loc, *getTypeConverter(), memRefTy, outputTypedPtrAlloc); - - rewriter.replaceOp(op, {llvmMemRef}); - return success(); - } - - // Handle the dynamic case. - - // Compute strides and offset based on MemRef type. - int64_t alignmentOffset; - SmallVector strides; - auto successStrides = - getStridesAndOffset(memRefTy, strides, alignmentOffset); - (void)successStrides; - assert(succeeded(successStrides) && "unexpected non-strided memref"); - - // Create the memRef descriptor. - auto structType = typeConverter->convertType(memRefTy); - auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); - - // Allocated pointer, used for malloc/free. - memRefDescriptor.setAllocatedPtr(rewriter, loc, outputTypedPtrAlloc); - - // Actual aligned pointer to payload. - // TODO: support aligned MemRefs. - memRefDescriptor.setAlignedPtr(rewriter, loc, outputTypedPtrAlloc); - - // Offset in aligned pointer. - // TODO: support non-zero here in the aligned case. - - Type indexType = getIndexType(); - memRefDescriptor.setOffset( - rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, 0)); - - if (memRefTy.getRank() != 0) { - // Prepare sizes. - SmallVector sizes; - sizes.reserve(memRefTy.getRank()); - unsigned i = 0; - for (int64_t s : memRefTy.getShape()) - sizes.push_back( - s == ShapedType::kDynamic - ? operands[2 + i++] - : createIndexAttrConstant(rewriter, loc, indexType, s)); - - // Store all sizes in the descriptor. Only dynamic sizes are passed in as - // operands to AllocOp. - Value runningStride = nullptr; - auto nStrides = strides.size(); - SmallVector strideValues(nStrides, nullptr); - for (unsigned i = 0; i < nStrides; ++i) { - int64_t index = nStrides - 1 - i; - if (strides[index] == ShapedType::kDynamic) - // Identity layout map is enforced in the match function, so we - // compute: - // `runningStride *= sizes[index + 1]` - runningStride = runningStride ? rewriter.create(loc, - runningStride, sizes[index + 1]) - : createIndexAttrConstant( - rewriter, loc, indexType, 1); - else - runningStride = - createIndexAttrConstant(rewriter, loc, indexType, strides[index]); - strideValues[index] = runningStride; - } - // Fill size and stride descriptors in memref. - for (auto indexedSize : llvm::enumerate(sizes)) { - int64_t index = indexedSize.index(); - memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value()); - memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]); - } - } - - rewriter.replaceOp(op, {memRefDescriptor}); - return success(); - } -}; - -void populateLoweringKrnlGetRefOpPattern(LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, MLIRContext *ctx) { - patterns.insert(typeConverter, ctx); -} - -} // namespace krnl -} // namespace onnx_mlir diff --git a/src/Conversion/KrnlToLLVM/KrnlGlobal.cpp b/src/Conversion/KrnlToLLVM/KrnlGlobal.cpp index 0feccc1498..d888e36f94 100644 --- a/src/Conversion/KrnlToLLVM/KrnlGlobal.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlGlobal.cpp @@ -289,24 +289,42 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { Location loc = krnlGlobalOp.getLoc(); MultiDialectBuilder create(builder, loc); - ModuleOp module = krnlGlobalOp->getParentOfType(); DenseElementsAttr denseAttr = krnlGlobalOp.getValue().value().cast(); Type i8PtrType = getI8PointerType(builder.getContext()); - // Generate LLVM GlobalOps for each string in the KrnlGlobalOp dense - // attribute. - SmallVector globalOps; - for (StringRef str : denseAttr.getValues()) { - LLVM::GlobalOp globalOp = krnl::getOrCreateGlobalString( - str, loc, builder, module, getTypeConverter()); - globalOps.push_back(globalOp); + auto strs = denseAttr.getValues(); + // Collect total size of the strs. + size_t totalSize = 0; + for (StringRef str : strs) { + // Add 1 for the null terminator. + totalSize += str.size() + 1; + } + + // Concatenate all strings into one. + std::vector concatStr(totalSize); + size_t offset = 0; + std::vector offsets; + for (StringRef str : strs) { + offsets.emplace_back(offset); + std::copy(str.begin(), str.end(), concatStr.begin() + offset); + concatStr[offset + str.size()] = '\0'; + offset += str.size() + 1; } + // Create a global for the concatenated string. + StringRef data(concatStr.data(), concatStr.size()); + StringAttr llvmStringAttr = StringAttr::get(builder.getContext(), data); + auto i8Type = IntegerType::get(builder.getContext(), 8); + auto llvmArrayI8Ty = LLVM::LLVMArrayType::get(i8Type, totalSize); + LLVM::GlobalOp globalStr = create.llvm.globalOp(llvmArrayI8Ty, + /*isConstant=*/true, LLVM::Linkage::Internal, + "om.strArray." + krnlGlobalOp.getName().str(), llvmStringAttr); + // Generate an LLVM GlobalOps with an initializer region containing one // block. - auto arrayType = LLVM::LLVMArrayType::get(i8PtrType, globalOps.size()); + auto arrayType = LLVM::LLVMArrayType::get(i8PtrType, offsets.size()); auto global = create.llvm.globalOp(arrayType, /*isConstant=*/true, LLVM::Linkage::Internal, krnlGlobalOp.getName(), Attribute()); @@ -319,10 +337,15 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { int32_t index = 0; Value lastValue = array; - for (const LLVM::GlobalOp &globalOp : globalOps) { - Value strAddr = krnl::getPtrToGlobalString(globalOp, loc, builder); + Value baseAddr = create.llvm.addressOf(globalStr); + // Cast globalStr to i8Ptr. + baseAddr = create.llvm.bitcast(i8PtrType, baseAddr); + for (size_t offset : offsets) { + // Get each str with gep base, offset. + Value gepOp = + create.llvm.getElemPtr(i8PtrType, i8Type, baseAddr, {offset}); lastValue = - create.llvm.insertValue(arrayType, lastValue, strAddr, {index++}); + create.llvm.insertValue(arrayType, lastValue, gepOp, {index++}); } create.llvm._return(lastValue); diff --git a/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp b/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp index 5fc3b722bb..067412c130 100644 --- a/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp +++ b/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp @@ -257,20 +257,37 @@ struct ONNXCategoryMapperOpLowering .Case( [&](IntegerType) { inputElem = createKrnl.load(memref, loopInd); }) .Case([&](krnl::StringType stringType) { - MathBuilder createMath(createKrnl); - Value zero = createMath.constant( - createMath.getBuilder().getIntegerType(64), 0); ArrayRef shape = memref.getType().cast().getShape(); SmallVector newShape; - for (uint64_t i = 0; i < shape.size(); i++) - newShape.emplace_back( - (shape[i] == ShapedType::kDynamic) ? 1 : shape[i]); - auto memRefType = MemRefType::get( - newShape, krnl::StringType::get(elementType.getContext())); - // Sole use of krnl.getRef. - Value stringMemRef = createKrnl.getRef(memRefType, memref, zero); - inputElem = createKrnl.load(stringMemRef, loopInd); + bool hasDynamicDim = false; + for (uint64_t i = 0; i < shape.size(); i++) { + if (shape[i] == ShapedType::kDynamic) { + newShape.emplace_back(1); + hasDynamicDim = true; + } else { + newShape.emplace_back(shape[i]); + } + } + if (!hasDynamicDim) { + inputElem = createKrnl.load(memref, loopInd); + } else { + MemRefBuilder createMemRef(createKrnl); + MemRefType memRefType = MemRefType::get( + newShape, krnl::StringType::get(elementType.getContext())); + SmallVector offsets(shape.size(), 0); + SmallVector strides; + int64_t alignmentOffset; // not used, just to make the function call + // completed. + if (getStridesAndOffset(memRefType, strides, alignmentOffset) + .failed()) + llvm_unreachable("Failed to get strides"); + Value stringMemRef = + createMemRef + .subView(memRefType, memref, offsets, newShape, strides) + .getResult(); + inputElem = createKrnl.load(stringMemRef, loopInd); + } }) .Default([&](Type type) { llvm::errs() << "type: " << type << "\n"; diff --git a/src/Conversion/ONNXToKrnl/Math/Reduction.cpp b/src/Conversion/ONNXToKrnl/Math/Reduction.cpp index f8288f5708..7faf3cd63f 100644 --- a/src/Conversion/ONNXToKrnl/Math/Reduction.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Reduction.cpp @@ -487,7 +487,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { bool parallelSimd = false; int64_t innermostLoopCollapse = 0; int64_t VL = 0; - int64_t estimatedSimdLoopTripCount; + int64_t estimatedSimdLoopTripCount = 0; // With dynamic axes, use this Value maskVal = nullptr; diff --git a/src/Conversion/ONNXToStablehlo/CMakeLists.txt b/src/Conversion/ONNXToStablehlo/CMakeLists.txt index 690c58ef44..4b1b0ec002 100644 --- a/src/Conversion/ONNXToStablehlo/CMakeLists.txt +++ b/src/Conversion/ONNXToStablehlo/CMakeLists.txt @@ -56,6 +56,7 @@ add_onnx_mlir_library(OMONNXToStablehlo Tensor/Concat.cpp Tensor/Constant.cpp Tensor/DepthToSpace.cpp + Tensor/Dim.cpp Tensor/Expand.cpp Tensor/Flatten.cpp Tensor/Gather.cpp diff --git a/src/Conversion/ONNXToStablehlo/ConvertONNXToStablehlo.cpp b/src/Conversion/ONNXToStablehlo/ConvertONNXToStablehlo.cpp index 74ea09a3dc..1550214d60 100644 --- a/src/Conversion/ONNXToStablehlo/ConvertONNXToStablehlo.cpp +++ b/src/Conversion/ONNXToStablehlo/ConvertONNXToStablehlo.cpp @@ -41,6 +41,7 @@ void populateONNXToStablehloConversionPattern( populateLoweringONNXConcatOpToStablehloPattern(patterns, ctx); populateLoweringONNXConstantOpToStablehloPattern(patterns, ctx); populateLoweringONNXDepthToSpaceOpToStablehloPattern(patterns, ctx); + populateLoweringONNXDimOpToStablehloPattern(patterns, ctx); populateLoweringONNXExpandOpToStablehloPattern(patterns, ctx); populateLoweringONNXFlattenOpToStablehloPattern(patterns, ctx); populateLoweringONNXGatherOpToStablehloPattern(patterns, ctx); @@ -87,6 +88,7 @@ struct FrontendToStablehloLoweringPass void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); + registry.insert(); registry.insert(); } diff --git a/src/Conversion/ONNXToStablehlo/Math/MatMul.cpp b/src/Conversion/ONNXToStablehlo/Math/MatMul.cpp index 667860b721..79b1f8576e 100644 --- a/src/Conversion/ONNXToStablehlo/Math/MatMul.cpp +++ b/src/Conversion/ONNXToStablehlo/Math/MatMul.cpp @@ -44,13 +44,9 @@ struct ONNXMatMulOpLoweringToStablehlo : public ConversionPattern { ShapedType outputShapedType = outputType.cast(); Type elementType = outputShapedType.getElementType(); - if (llvm::any_of(outputShapedType.getShape(), ShapedType::isDynamic)) - return rewriter.notifyMatchFailure( - op, "dynamic dimensions not supported"); - Value A(operandAdaptor.getA()), B(operandAdaptor.getB()); - auto aRank = A.getType().cast().getShape().size(); - auto bRank = B.getType().cast().getShape().size(); + auto aRank = A.getType().cast().getRank(); + auto bRank = B.getType().cast().getRank(); // Size all the arrays to padded length. int paddedRank = std::max(aRank, bRank); paddedRank = std::max(paddedRank, 2); @@ -82,28 +78,54 @@ struct ONNXMatMulOpLoweringToStablehlo : public ConversionPattern { if (!bPadDims[paddedRank - 1]) bShape.push_back(bShapeList[paddedRank - 1]); - Type outputAType = RankedTensorType::get(aShape, elementType); - Type outputBType = RankedTensorType::get(bShape, elementType); - int64_t oneDPadA = aPadDims[paddedRank - 2]; int64_t oneDPadB = bPadDims[paddedRank - 1]; - Value broadcastedA; - { + // TODO: Some of the above logic could probably be absorbed into this + // function but will require more refactoring + auto broadCastTo = [&](const Value &operandToBroadcast, + const Value &operandToMatch, + ArrayRef shapeInts, int64_t oneDPad) { + Value broadcasted; + auto rank = operandToBroadcast.getType().cast().getRank(); + RankedTensorType broadCastedType = + RankedTensorType::get(shapeInts, elementType); SmallVector broadcastDimensions = llvm::to_vector<4>(llvm::seq( - paddedRank - oneDPadA - aRank, paddedRank - oneDPadA)); - broadcastedA = rewriter.createOrFold( - loc, outputAType, A, rewriter.getI64VectorAttr(broadcastDimensions)); - } - Value broadcastedB; - { - SmallVector broadcastDimensions = - llvm::to_vector<4>(llvm::seq( - paddedRank - oneDPadB - bRank, paddedRank - oneDPadB)); - broadcastedB = rewriter.createOrFold( - loc, outputBType, B, rewriter.getI64VectorAttr(broadcastDimensions)); - } + paddedRank - oneDPad - rank, paddedRank - oneDPad)); + if (!broadCastedType.hasStaticShape()) { + SmallVector dimTensors(paddedRank - oneDPad - rank); + for (int64_t i = 0; i < paddedRank - oneDPad - rank; i++) { + Value dim = rewriter.create(loc, operandToMatch, i); + dim = rewriter.create( + loc, rewriter.getI64Type(), dim); + dimTensors[i] = + rewriter.create(loc, ValueRange{dim}); + } + Value broadcastedShape = + rewriter.create(loc, operandToBroadcast); + broadcastedShape = rewriter.create(loc, + RankedTensorType::get({rank}, rewriter.getI64Type()), + broadcastedShape); + dimTensors.push_back(broadcastedShape); + Value fullShape = rewriter.create(loc, + RankedTensorType::get( + {broadCastedType.getRank()}, rewriter.getI64Type()), + dimTensors, rewriter.getI64IntegerAttr(0)); + broadcasted = rewriter.createOrFold( + loc, broadCastedType, operandToBroadcast, fullShape, + rewriter.getI64VectorAttr(broadcastDimensions)); + } else { + broadcasted = rewriter.createOrFold(loc, + broadCastedType, operandToBroadcast, + rewriter.getI64VectorAttr(broadcastDimensions)); + } + return broadcasted; + }; + + Value broadcastedA = broadCastTo(A, B, aShape, oneDPadA); + Value broadcastedB = broadCastTo(B, A, bShape, oneDPadB); + Value dotProduct; if (paddedRank > 2) dotProduct = rewriter.create(loc, outputType, diff --git a/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp b/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp index 832e72b973..5618f9962c 100644 --- a/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp +++ b/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp @@ -179,6 +179,8 @@ void populateLoweringONNXConcatOpToStablehloPattern( RewritePatternSet &, MLIRContext *); void populateLoweringONNXConstantOpToStablehloPattern( RewritePatternSet &, MLIRContext *); +void populateLoweringONNXDimOpToStablehloPattern( + RewritePatternSet &, MLIRContext *); void populateLoweringONNXDepthToSpaceOpToStablehloPattern( RewritePatternSet &, MLIRContext *); void populateLoweringONNXExpandOpToStablehloPattern( diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp new file mode 100644 index 0000000000..2e40c2ade6 --- /dev/null +++ b/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp @@ -0,0 +1,65 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===----------------- Dim.cpp - Lowering Dim Op ----------------===// +// +// Copyright 2022-2024 +// +// ============================================================================= +// +// This file lowers the ONNXDim operator to the Tensor dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +struct ONNXDimOpLoweringToStablehlo : public ConversionPattern { + ONNXDimOpLoweringToStablehlo(MLIRContext *ctx) + : ConversionPattern(ONNXDimOp::getOperationName(), 1, ctx) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + Location loc = op->getLoc(); + ONNXDimOp dimOp = cast(op); + int64_t axisLit = dimOp.getAxis(); + + // Check that axisLit is a valid dimension index + Value tensorArg = operands[0]; + assert(tensorArg.getType().isa() && + "Expected ranked tensor type"); + + int64_t rank = tensorArg.getType().cast().getRank(); + + assert((axisLit >= 0 && axisLit < rank) && + "Axis must be in the range [0, input tensor rank - 1]"); + + Value inputShape = rewriter.create(loc, tensorArg); + Value dimValue = + rewriter.create(loc, inputShape, axisLit); + Type dimType = dimOp.getDim().getType(); + Type indexValueType = dimType.cast().getElementType(); + Value castedIndex = + rewriter.create(loc, indexValueType, dimValue); + Value indexTensor = rewriter.create( + loc, dimType, ArrayRef{castedIndex}); + rewriter.replaceOp(op, indexTensor); + return success(); + } +}; + +} // namespace + +void populateLoweringONNXDimOpToStablehloPattern( + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert(ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Gather.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Gather.cpp index 46937a9f7d..8ecc5d2962 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/Gather.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/Gather.cpp @@ -40,7 +40,8 @@ struct ONNXGatherOpLoweringToStablehlo : public ConversionPattern { shapeHelper.computeShapeAndAssertOnFailure(); Type outputType = *op->result_type_begin(); - assert(isRankedShapedType(outputType) && "Expected Ranked ShapedType"); + if (!isRankedShapedType(outputType)) + return rewriter.notifyMatchFailure(op, "Expected Ranked ShapedType"); // Operands and attributes. Value data = operandAdaptor.getData(); @@ -56,7 +57,7 @@ struct ONNXGatherOpLoweringToStablehlo : public ConversionPattern { // start indices Value zero = getShapedZero(loc, rewriter, indices); Value axisDimSize; - if (inputType.hasStaticShape()) { + if (!inputType.isDynamicDim(axisLit)) { int64_t axisDimSizeLit = inputType.getShape()[axisLit]; axisDimSize = getShapedInt(loc, rewriter, axisDimSizeLit, indices); } else { @@ -66,6 +67,9 @@ struct ONNXGatherOpLoweringToStablehlo : public ConversionPattern { rewriter.create(loc, inputShape, axisLit); Value axisDimSizeValue = rewriter.create( loc, indicesType.getElementType(), axisDimSizeIndexValue); + axisDimSizeValue = rewriter.create(loc, + RankedTensorType::get({}, indicesType.getElementType()), + axisDimSizeValue); axisDimSize = rewriter.create(loc, indicesType, axisDimSizeValue, indicesShape, rewriter.getI64TensorAttr({})); diff --git a/src/Dialect/Krnl/DialectBuilder.cpp b/src/Dialect/Krnl/DialectBuilder.cpp index 05f4a23f4e..e3b933d73e 100644 --- a/src/Dialect/Krnl/DialectBuilder.cpp +++ b/src/Dialect/Krnl/DialectBuilder.cpp @@ -211,19 +211,10 @@ void KrnlBuilder::matmul(Value A, ValueRange aStart, Value B, ValueRange bStart, globalUBs[1], globalUBs[2], simdize, unroll, overCompute); } -Value KrnlBuilder::dim(Type type, Value alloc, Value index) const { - return b().create(loc(), type, alloc, index); -} - KrnlMovableOp KrnlBuilder::movable() const { return b().create(loc()); } -KrnlGetRefOp KrnlBuilder::getRef( - Type type, Value memref, Value offset, ValueRange indices) const { - return b().create(loc(), type, memref, offset, indices); -} - Value KrnlBuilder::constant(MemRefType type, StringRef name, std::optional value, std::optional offset, std::optional alignment) const { diff --git a/src/Dialect/Krnl/DialectBuilder.hpp b/src/Dialect/Krnl/DialectBuilder.hpp index eb0622b2db..5957a23855 100644 --- a/src/Dialect/Krnl/DialectBuilder.hpp +++ b/src/Dialect/Krnl/DialectBuilder.hpp @@ -128,13 +128,8 @@ struct KrnlBuilder : public DialectBuilder { mlir::ValueRange globalUBs, bool simdize, bool unroll, bool overCompute) const; - mlir::Value dim(mlir::Type type, mlir::Value alloc, mlir::Value index) const; - mlir::KrnlMovableOp movable() const; - mlir::KrnlGetRefOp getRef(mlir::Type type, mlir::Value memref, - mlir::Value offset, mlir::ValueRange indices = {}) const; - mlir::Value constant(mlir::MemRefType type, mlir::StringRef name, std::optional value, std::optional offset = std::nullopt, diff --git a/src/Dialect/Krnl/Krnl.td b/src/Dialect/Krnl/Krnl.td index 84f5106b27..31f5a9949d 100644 --- a/src/Dialect/Krnl/Krnl.td +++ b/src/Dialect/Krnl/Krnl.td @@ -368,37 +368,6 @@ def KrnlGlobalOp : Op { let results = (outs AnyTypeOf<[AnyMemRef]>:$output); } -def KrnlGetRefOp : Op { - let summary = "Krnl a MemRef from within another MemRef starting at a specific offset."; - let description = [{ - Retrieves a MemRef from within another MemRef: - -``` - "krnl.getref"(%memref, %offset) -``` - The offset is an integer which is used as an index into the input MemRef. It works - just like an array index. - }]; - - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$mempool, AnyInteger:$offset, - Variadic:$value); - let results = (outs AnyTypeOf<[AnyMemRef]>:$output); - - let builders = [ - OpBuilder<(ins "Type":$resultType, "Value":$mempool, "Value":$offset), [{ - build($_builder, $_state, resultType, mempool, offset, {}); - }]>, - ]; - - let extraClassDeclaration = [{ - /// Returns the symbolic operands (the ones in square brackets), which bind - /// to the symbols of the memref's layout map. - operand_range getDynamicSizes() { - return {operand_begin() + 2, operand_end()}; - } - }]; -} - def KrnlBlockOp : Op { let summary = "Krnl block operation"; let description = [{ @@ -522,41 +491,6 @@ def KrnlParallelOp : Op { }]; } - -def KrnlDimOp : Op { - let summary = "Krnl dimensions operation."; - let description = [{ - Emits the dimension of a MemRef independent of the MemRef alloc: - - ``` - "krnl.dim"(%memref, %index) - ``` - - The index identifies the dimension within the shape which is going to be emitted. - Initially the krnl.dim operation depends on the alloc of the MemRef. - Unlike the std.dim operation which maintains a dependency on the alloc of the MemRef, the dimension emitted by krnl.dim will not depend on the alloc operation of the MemRef once the krnl.dim operation is lowered. - - Any changes to the original MemRef size after the krnl.dim has been lowered will not be picked up by the emitted dimension. This allows the original MemRef to be safely modified via code transformations or affine map normalization without the risk of changing the value already emitted via krnl.dim. - }]; - - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$alloc, Index:$index); - let results = (outs Index:$dimension); -} - -def KrnlShapeOp : Op { - let summary = "Krnl operation to retrieve the shape of a MemRef."; - let description = [{ - Extracts the shape of a MemRef: - ``` - "krnl.shape"(%memref) - ``` - The return result is of `shape.type`. - }]; - - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$alloc); - let results = (outs AnyTypeOf<[AnyMemRef]>:$shape); -} - def KrnlErfOp : Op { let summary = "Krnl erf scalar operation"; let description = [{ diff --git a/src/Dialect/Mlir/DialectBuilder.cpp b/src/Dialect/Mlir/DialectBuilder.cpp index 593ab7a4a3..5daa065b4c 100644 --- a/src/Dialect/Mlir/DialectBuilder.cpp +++ b/src/Dialect/Mlir/DialectBuilder.cpp @@ -1333,6 +1333,14 @@ memref::ViewOp MemRefBuilder::view(Value input, int64_t byteOffset, loc(), outputType, input, offset, outputDynSymbols); } +memref::SubViewOp MemRefBuilder::subView(MemRefType outputType, Value val, + llvm::SmallVectorImpl &offsets, + llvm::SmallVectorImpl &sizes, + llvm::SmallVectorImpl &strides) const { + return b().create( + loc(), outputType, val, offsets, sizes, strides); +} + memref::SubViewOp MemRefBuilder::subView(Value input, llvm::SmallVectorImpl &offsetsIE, llvm::SmallVectorImpl &sizesIE, diff --git a/src/Dialect/Mlir/DialectBuilder.hpp b/src/Dialect/Mlir/DialectBuilder.hpp index f209de0b06..39a8fffbbf 100644 --- a/src/Dialect/Mlir/DialectBuilder.hpp +++ b/src/Dialect/Mlir/DialectBuilder.hpp @@ -345,6 +345,13 @@ struct MemRefBuilder final : DialectBuilder { mlir::memref::ViewOp view(mlir::Value input, int64_t byteOffset, mlir::MemRefType outputType, mlir::ValueRange outputDynSymbols) const; + // Create a subview of val. + mlir::memref::SubViewOp subView(mlir::MemRefType outputType, mlir::Value val, + llvm::SmallVectorImpl &offsets, // Offset for each val dims. + llvm::SmallVectorImpl &sizes, // Sizes for each val dims. + llvm::SmallVectorImpl &strides) // Stride for each val dims. + const; + // Create a subview of val. Size of 1 => remove that dim. mlir::memref::SubViewOp subView(mlir::Value val, llvm::SmallVectorImpl &offsets, // Offset for each val dims. diff --git a/src/Dialect/ONNX/CMakeLists.txt b/src/Dialect/ONNX/CMakeLists.txt index 4be0256ad7..3917c94aa4 100644 --- a/src/Dialect/ONNX/CMakeLists.txt +++ b/src/Dialect/ONNX/CMakeLists.txt @@ -7,12 +7,12 @@ if (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "s390x") endif() add_subdirectory(ElementsAttr) +add_subdirectory(ONNXOps) +add_subdirectory(Transforms) add_onnx_mlir_dialect(ONNX onnx) add_onnx_mlir_dialect_doc(onnx ONNX.td) -add_onnx_mlir_rewriter(Rewrite) - add_onnx_mlir_library(OMONNXOps # Top files for ONNX dialect DialectBuilder.cpp @@ -22,9 +22,9 @@ add_onnx_mlir_library(OMONNXOps ONNXDimAnalysis.cpp ONNXOps.cpp ONNXOps/OpHelper.cpp + ONNXOps/Canonicalize.cpp ONNXOps/ShapeHelper.cpp ONNXTypes.cpp - Rewrite.cpp # Support for shape inference and verifiers ONNXOps/Additional/ConcatShapeTranspose.cpp @@ -104,7 +104,7 @@ add_onnx_mlir_library(OMONNXOps DEPENDS OMHasOnnxSubgraphOpInterfaceIncGen OMONNXIncGen - OMONNXRewriteIncGen + OMONNXCanonicalizeIncGen OMResultTypeInferenceOpInterfaceIncGen OMShapeInferenceOpInterfaceIncGen diff --git a/src/Dialect/ONNX/ONNXDimAnalysis.cpp b/src/Dialect/ONNX/ONNXDimAnalysis.cpp index 3f950d8848..c1b06859d8 100644 --- a/src/Dialect/ONNX/ONNXDimAnalysis.cpp +++ b/src/Dialect/ONNX/ONNXDimAnalysis.cpp @@ -869,8 +869,9 @@ void DimAnalysis::visitDim( // CastOp if (auto castOp = dyn_cast(op)) { - DimAnalysis::DimT newSameDim(castOp.getInput(), dimIndex); - sameDims.insert(newSameDim); + if (auto d = insertDimWhenUseful(castOp.getInput(), dimIndex, sameDims)) + LLVM_DEBUG(llvm::dbgs() << " - Added a new dim(" << d.value().first + << ", " << d.value().second << ")\n"); return; } diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index ed8c25c749..943218962d 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -3387,7 +3387,7 @@ def ONNXIsInfOp:ONNX_Op<"IsInf", let description = [{ Map infinity to true and other values to false. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F32]>, TensorOf<[F64]>]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$X, DefaultValuedAttr:$detect_negative, DefaultValuedAttr:$detect_positive); let results = (outs TensorOf<[I1]>:$Y); @@ -3419,7 +3419,7 @@ def ONNXIsNaNOp:ONNX_Op<"IsNaN", let description = [{ Returns which elements of the input are NaN. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>]>:$X); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$X); let results = (outs TensorOf<[I1]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { diff --git a/src/Dialect/ONNX/ONNXOps/Additional/Dim.cpp b/src/Dialect/ONNX/ONNXOps/Additional/Dim.cpp index 2f0a513b35..7afdc6bbf4 100644 --- a/src/Dialect/ONNX/ONNXOps/Additional/Dim.cpp +++ b/src/Dialect/ONNX/ONNXOps/Additional/Dim.cpp @@ -37,12 +37,16 @@ LogicalResult ONNXDimOpShapeHelper::computeShape() { //===----------------------------------------------------------------------===// LogicalResult ONNXDimOp::verify() { - // Input data must be ranked. if (!hasShapeAndRank(this->getData())) - return failure(); - // Axis must be in [0, rank -1]. + return emitOpError("input must have shape and rank."); + int64_t axis = this->getAxis(); - return failure((axis < 0) || (axis >= getRank(this->getData().getType()))); + if ((axis < 0) || (axis >= getRank(this->getData().getType()))) + return emitOpError("attribute ") + << ONNXDimOp::getAxisAttrName() << " value is " << axis + << ", accepted range is [0, " + << getRank(this->getData().getType()) - 1 << "]."; + return success(); } //===----------------------------------------------------------------------===// diff --git a/src/Dialect/ONNX/ONNXOps/CMakeLists.txt b/src/Dialect/ONNX/ONNXOps/CMakeLists.txt new file mode 100644 index 0000000000..6ccb259555 --- /dev/null +++ b/src/Dialect/ONNX/ONNXOps/CMakeLists.txt @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 + +add_onnx_mlir_rewriter(Canonicalize) + diff --git a/src/Dialect/ONNX/Rewrite.cpp b/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp similarity index 98% rename from src/Dialect/ONNX/Rewrite.cpp rename to src/Dialect/ONNX/ONNXOps/Canonicalize.cpp index a8d7f097d6..ba8ebafb9d 100644 --- a/src/Dialect/ONNX/Rewrite.cpp +++ b/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp @@ -103,6 +103,21 @@ SmallVector transposeVariadicInput(PatternRewriter &rewriter, return transposedInputs; } +// Cast a variadic input using the given `saturate` and `to`. +SmallVector castVariadicInput(PatternRewriter &rewriter, Location loc, + ValueRange inputs, IntegerAttr saturate, TypeAttr to) { + SmallVector castInputs; + for (Value inp : inputs) { + ShapedType inpType = inp.getType().cast(); + assert(inpType && "Type is not ShapedType"); + ONNXCastOp castOp = rewriter.create(loc, + UnrankedTensorType::get(inpType.getElementType()), inp, saturate, to); + (void)castOp.inferShapes([](Region ®ion) {}); + castInputs.emplace_back(castOp.getResult()); + } + return castInputs; +} + // Check if all values are produced by ONNXTransposeOp. bool areProducedByTransposeOp(ValueRange values) { return llvm::all_of(values, [](Value v) { @@ -260,7 +275,7 @@ bool matchShapeAddMatMul(Value v, Value &matA, Value &biasB, /// Include the patterns defined in the Declarative Rewrite framework. // ============================================================================= -#include "src/Dialect/ONNX/ONNXRewrite.inc" +#include "src/Dialect/ONNX/ONNXOps/ONNXCanonicalize.inc" // ============================================================================= // Rewrite pattern for elementwise binary ops (not handled in Rewrite.td). @@ -1516,6 +1531,8 @@ void ONNXAndOp::getCanonicalizationPatterns( void ONNXCastOp::getCanonicalizationPatterns( RewritePatternSet &result, MLIRContext *context) { result.insert(context); + result.insert(context); + result.insert(context); // TODO: Reintroduce pattern for sound type combinations, see issue #2210. // result.insert(context); } diff --git a/src/Dialect/ONNX/Rewrite.td b/src/Dialect/ONNX/ONNXOps/Canonicalize.td similarity index 98% rename from src/Dialect/ONNX/Rewrite.td rename to src/Dialect/ONNX/ONNXOps/Canonicalize.td index fc24f35cc6..ad3c498ecf 100644 --- a/src/Dialect/ONNX/Rewrite.td +++ b/src/Dialect/ONNX/ONNXOps/Canonicalize.td @@ -232,6 +232,9 @@ class RankXMinusRankYIs: Constraint< def TransposeVariadicInput: NativeCodeCall< "onnx_mlir::transposeVariadicInput($_builder, $_loc, $0, $1)">; +def CastVariadicInput: NativeCodeCall< + "onnx_mlir::castVariadicInput($_builder, $_loc, $0, $1, $2)">; + // Check whether two variables are equal. def Equal: Constraint, "are equal">; @@ -432,6 +435,21 @@ def CastEliminationPattern : Pat< // (ONNXCastOp (ONNXCastOp $arg, $_), $type), // (ONNXCastOp $arg, $type)>; +// Do cast on concat's inputs instead of output in order to propagate +// cast operations together, which brings concat close to reshape +// because concat is used for shape in reshape. +def SwapCastConcatPattern: Pat< + (ONNXCastOp (ONNXConcatOp $inputs, $axis), $saturate, $to), + (ONNXConcatOp (CastVariadicInput $inputs, $saturate, $to), $axis) +>; + +// Do cast on slice's inputs instead of output in order to propagate +// cast operations together, which brings slice close to reshape. +def SwapCastSlicePattern: Pat< + (ONNXCastOp (ONNXSliceOp $data, $starts, $ends, $axes, $steps), $saturate, $to), + (ONNXSliceOp (ONNXCastOp $data, $saturate, $to), $starts, $ends, $axes, $steps) +>; + //===----------------------------------------------------------------------===// // Canonicalization for ONNXLayoutTransformOp //===----------------------------------------------------------------------===// diff --git a/src/Dialect/ONNX/ONNXOps/Math/ElementwiseBroadcast.cpp b/src/Dialect/ONNX/ONNXOps/Math/ElementwiseBroadcast.cpp index 22392cebb8..0c3092d66d 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/ElementwiseBroadcast.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/ElementwiseBroadcast.cpp @@ -4,7 +4,7 @@ //===------------------ ElementwiseBroadcast.cpp - ONNX Operations --------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -299,6 +299,10 @@ LogicalResult ONNXModOp::verify() { // must be set to 1. if (elementType.isa() && (getFmod() != 1)) return emitOpError("fmod must be 1 when the input type is floating point"); + // Verify that when the input type is integer, then `fmod` attribute + // must be set to 0. + if (elementType.isa() && (getFmod() != 0)) + return emitOpError("fmod must be 0 when the input type is an integer"); return success(); } diff --git a/src/Transform/ONNX/CMakeLists.txt b/src/Dialect/ONNX/Transforms/CMakeLists.txt similarity index 98% rename from src/Transform/ONNX/CMakeLists.txt rename to src/Dialect/ONNX/Transforms/CMakeLists.txt index 2575a01517..36813f8576 100644 --- a/src/Transform/ONNX/CMakeLists.txt +++ b/src/Dialect/ONNX/Transforms/CMakeLists.txt @@ -4,25 +4,6 @@ add_onnx_mlir_rewriter(Decompose) add_onnx_mlir_rewriter(ConstProp) add_onnx_mlir_rewriter(ConvOpt) -add_onnx_mlir_library(OMONNXRewrite - ConstProp.cpp - ConvOpt.cpp - Decompose.cpp - DecomposeEinsum.cpp - ScrubDisposablePass.cpp - SetONNXNodeName.cpp - Recompose.cpp - - DEPENDS - OMONNXDecomposeIncGen - OMONNXConstPropIncGen - OMONNXConvOptIncGen - - LINK_LIBS PUBLIC - MLIRTransformUtils - OMONNXOps - ) - add_onnx_mlir_library(OMShapeInference ShapeInference.cpp @@ -44,7 +25,6 @@ add_onnx_mlir_library(OMShapeInferencePass ) add_onnx_mlir_library(OMInstrumentONNX - InstrumentPass.cpp InstrumentONNXSignaturePass.cpp INCLUDE_DIRS PUBLIC @@ -52,11 +32,29 @@ add_onnx_mlir_library(OMInstrumentONNX LINK_LIBS PUBLIC OMONNXOps - OMKrnlOps MLIRPass OMOptionUtils ) +add_onnx_mlir_library(OMONNXRewrite + ConstProp.cpp + ConvOpt.cpp + Decompose.cpp + DecomposeEinsum.cpp + ScrubDisposablePass.cpp + SetONNXNodeName.cpp + Recompose.cpp + + DEPENDS + OMONNXDecomposeIncGen + OMONNXConstPropIncGen + OMONNXConvOptIncGen + + LINK_LIBS PUBLIC + MLIRTransformUtils + OMONNXOps + ) + add_onnx_mlir_library(OMOpTransform ONNXOpTransformPass.cpp diff --git a/src/Transform/ONNX/ConstProp.cpp b/src/Dialect/ONNX/Transforms/ConstProp.cpp similarity index 98% rename from src/Transform/ONNX/ConstProp.cpp rename to src/Dialect/ONNX/Transforms/ConstProp.cpp index 116a41d653..b8b5585665 100644 --- a/src/Transform/ONNX/ConstProp.cpp +++ b/src/Dialect/ONNX/Transforms/ConstProp.cpp @@ -4,7 +4,7 @@ //===----------- ONNXConstProp.cpp - ONNX High Level Rewriting ------------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -13,9 +13,6 @@ // //===----------------------------------------------------------------------===// -#include "src/Transform/ONNX/ConstProp.hpp" -#include "src/Pass/Passes.hpp" - #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" @@ -32,6 +29,8 @@ #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" #include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" #include "src/Dialect/ONNX/OnnxElementsAttrBuilder.hpp" +#include "src/Dialect/ONNX/Transforms/ConstProp.hpp" +#include "src/Pass/Passes.hpp" #include "src/Support/TypeUtilities.hpp" #include @@ -217,6 +216,28 @@ struct ElementWiseBinaryOpImpl { static T eval(T lhs, T rhs) { return std::max(lhs, rhs); } }; +template <> +struct ElementWiseBinaryOpImpl> { + static int64_t eval(int64_t lhs, int64_t rhs) { + // The original calculation for mod + int64_t mod = lhs % rhs; + // Handle the case when one of the int values are negative + // If both int values are positive or multiples of each other, we can + // calculate as normal + if ((mod != 0) && ((lhs < 0) ^ (rhs < 0))) + return (mod + rhs); + return mod; + } +}; + +template <> +struct ElementWiseBinaryOpImpl> { + static double eval(double lhs, double rhs) { + // Rounding to match the results of the backend tests + return (std::floor(fmod(lhs, rhs) * 1000000000) / 1000000000); + } +}; + template struct ElementWiseBinaryOpImpl { static bool eval(T lhs, T rhs) { return lhs == rhs; } @@ -945,7 +966,7 @@ Value ConstPropNonZero( // Pattern definition. //===----------------------------------------------------------------------===// -#include "src/Transform/ONNX/ONNXConstProp.inc" +#include "src/Dialect/ONNX/Transforms/ONNXConstProp.inc" //===----------------------------------------------------------------------===// // Code to perform constant propagation for split. diff --git a/src/Transform/ONNX/ConstProp.hpp b/src/Dialect/ONNX/Transforms/ConstProp.hpp similarity index 100% rename from src/Transform/ONNX/ConstProp.hpp rename to src/Dialect/ONNX/Transforms/ConstProp.hpp diff --git a/src/Transform/ONNX/ConstProp.td b/src/Dialect/ONNX/Transforms/ConstProp.td similarity index 97% rename from src/Transform/ONNX/ConstProp.td rename to src/Dialect/ONNX/Transforms/ConstProp.td index 329ed8905f..46a4fc1d3f 100644 --- a/src/Transform/ONNX/ConstProp.td +++ b/src/Dialect/ONNX/Transforms/ConstProp.td @@ -2,7 +2,7 @@ //===- ONNXConstProp.td - Rewriting for Constant Propagation in ONNX Ops -*- tablegen -===// // -// Copyright 2019-2020 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -166,6 +166,9 @@ def CreateGreaterOrEqualOfTwoConst : def CreatePowOfTwoConst : NativeCodeCall<"ConstPropElementwiseBinary($_builder, $0, $1, $2)">; +def CreateModOfTwoConst : + NativeCodeCall<"ConstPropElementwiseBinary($_builder, $0, $1, $2)">; + def CreateWhereOfThreeConst : NativeCodeCall<"ConstPropWhere($_builder, $0, $1, $2, $3)">; @@ -522,7 +525,7 @@ def DivOnesOnRhs : NamedPat<"DivOnesOnRhs", ]>; //===----------------------------------------------------------------------===// -// Constant propagate ONNXEqualOp +// Constant propagation for ONNXEqualOp //===----------------------------------------------------------------------===// def EqualConstProp : NamedPat<"EqualConstProp", @@ -537,7 +540,7 @@ def EqualConstProp : NamedPat<"EqualConstProp", (IsIntOrFloatType:$lhs), (SatisfiesExpansionBound:$result)]>; //===----------------------------------------------------------------------===// -// Constant propagate ONNXLessOp +// Constant propagation for ONNXLessOp //===----------------------------------------------------------------------===// def LessConstPropPattern : NamedPat<"LessConstPropPattern", @@ -549,7 +552,7 @@ def LessConstPropPattern : NamedPat<"LessConstPropPattern", (SatisfiesExpansionBound:$result)]>; //===----------------------------------------------------------------------===// -// Constant propagate ONNXGreaterOp +// Constant propagation for ONNXGreaterOp //===----------------------------------------------------------------------===// def GreaterConstPropPattern : NamedPat<"GreaterConstPropPattern", @@ -561,7 +564,7 @@ def GreaterConstPropPattern : NamedPat<"GreaterConstPropPattern", (SatisfiesExpansionBound:$result)]>; //===----------------------------------------------------------------------===// -// Constant propagate ONNXLessOrEqualOp +// Constant propagation for ONNXLessOrEqualOp //===----------------------------------------------------------------------===// def LessOrEqualConstPropPattern : NamedPat<"LessOrEqualConstPropPattern", @@ -573,7 +576,7 @@ def LessOrEqualConstPropPattern : NamedPat<"LessOrEqualConstPropPattern", (SatisfiesExpansionBound:$result)]>; //===----------------------------------------------------------------------===// -// Constant propagate ONNXGreaterOrEqualOp +// Constant propagation for ONNXGreaterOrEqualOp //===----------------------------------------------------------------------===// def GreaterOrEqualConstPropPattern : NamedPat<"GreaterOrEqualConstPropPattern", @@ -584,6 +587,19 @@ def GreaterOrEqualConstPropPattern : NamedPat<"GreaterOrEqualConstPropPattern", [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), (SatisfiesExpansionBound:$result)]>; +//===----------------------------------------------------------------------===// +// Constant propagation for ONNXModOp +//===----------------------------------------------------------------------===// + +def ModConstPropPattern : NamedPat<"ModConstPropPattern", + (ONNXModOp:$modOp + (ONNXConstantOp:$A $_, $_, $_, $_, $_, $_, $_, $_), + (ONNXConstantOp:$B $_, $_, $_, $_, $_, $_, $_, $_), + $fmod), + (CreateModOfTwoConst $modOp, $A, $B), + [(IsFromDenseONNXConstantOp:$A), (IsFromDenseONNXConstantOp:$B), + (SatisfiesExpansionBound:$modOp)]>; + //===----------------------------------------------------------------------===// // Patterns for Where. //===----------------------------------------------------------------------===// diff --git a/src/Transform/ONNX/ConvOpt.cpp b/src/Dialect/ONNX/Transforms/ConvOpt.cpp similarity index 99% rename from src/Transform/ONNX/ConvOpt.cpp rename to src/Dialect/ONNX/Transforms/ConvOpt.cpp index 1c0f201dc9..7aafaa3fcd 100644 --- a/src/Transform/ONNX/ConvOpt.cpp +++ b/src/Dialect/ONNX/Transforms/ConvOpt.cpp @@ -13,9 +13,6 @@ // //===----------------------------------------------------------------------===// -#include "src/Transform/ONNX/ConvOpt.hpp" -#include "src/Pass/Passes.hpp" - #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -25,6 +22,8 @@ #include "src/Dialect/ONNX/ONNXLayoutHelper.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" +#include "src/Dialect/ONNX/Transforms/ConvOpt.hpp" +#include "src/Pass/Passes.hpp" #include "src/Support/TypeUtilities.hpp" // Enables a minimum of printing. @@ -105,7 +104,7 @@ bool ExpressONNXConvOpAsMatmul(ONNXConvOp convOp, bool verbose = 0) { namespace { /// Include the patterns defined in the Declarative Rewrite framework. -#include "src/Transform/ONNX/ONNXConvOpt.inc" +#include "src/Dialect/ONNX/Transforms/ONNXConvOpt.inc" /* Pattern: when we have a convolution with filter of 1x1, stride 1, dilation of diff --git a/src/Transform/ONNX/ConvOpt.hpp b/src/Dialect/ONNX/Transforms/ConvOpt.hpp similarity index 100% rename from src/Transform/ONNX/ConvOpt.hpp rename to src/Dialect/ONNX/Transforms/ConvOpt.hpp diff --git a/src/Transform/ONNX/ConvOpt.td b/src/Dialect/ONNX/Transforms/ConvOpt.td similarity index 100% rename from src/Transform/ONNX/ConvOpt.td rename to src/Dialect/ONNX/Transforms/ConvOpt.td diff --git a/src/Transform/ONNX/Decompose.cpp b/src/Dialect/ONNX/Transforms/Decompose.cpp similarity index 99% rename from src/Transform/ONNX/Decompose.cpp rename to src/Dialect/ONNX/Transforms/Decompose.cpp index 94862b6054..0b34d4f80c 100644 --- a/src/Transform/ONNX/Decompose.cpp +++ b/src/Dialect/ONNX/Transforms/Decompose.cpp @@ -20,9 +20,6 @@ // //===----------------------------------------------------------------------===// -#include "src/Transform/ONNX/Decompose.hpp" -#include "src/Pass/Passes.hpp" - #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -33,8 +30,10 @@ #include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" #include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" +#include "src/Dialect/ONNX/Transforms/Decompose.hpp" +#include "src/Dialect/ONNX/Transforms/DecomposeEinsum.hpp" +#include "src/Pass/Passes.hpp" #include "src/Support/TypeUtilities.hpp" -#include "src/Transform/ONNX/DecomposeEinsum.hpp" #define DEBUG_TYPE "decompose" @@ -486,7 +485,7 @@ Value normalizeConstantOp( namespace { /// Include the patterns defined in the Declarative Rewrite framework. -#include "src/Transform/ONNX/ONNXDecompose.inc" +#include "src/Dialect/ONNX/Transforms/ONNXDecompose.inc" #ifdef ONNX_MLIR_ENABLE_STABLEHLO @@ -905,7 +904,8 @@ struct GroupNormIntoLayerNormPattern Type inputShapeType = RankedTensorType::get({inputRank}, rewriter.getI64Type()); Value inputShape = create.onnx.shape(inputShapeType, input); - Value Y = create.onnx.reshape(inputType, layerNormY, inputShape); + Type outputType = groupNormOp.getY().getType(); + Value Y = create.onnx.reshape(outputType, layerNormY, inputShape); // Replace operation. rewriter.replaceOp(groupNormOp, Y); return success(); diff --git a/src/Transform/ONNX/Decompose.hpp b/src/Dialect/ONNX/Transforms/Decompose.hpp similarity index 100% rename from src/Transform/ONNX/Decompose.hpp rename to src/Dialect/ONNX/Transforms/Decompose.hpp diff --git a/src/Transform/ONNX/Decompose.td b/src/Dialect/ONNX/Transforms/Decompose.td similarity index 100% rename from src/Transform/ONNX/Decompose.td rename to src/Dialect/ONNX/Transforms/Decompose.td diff --git a/src/Transform/ONNX/DecomposeEinsum.cpp b/src/Dialect/ONNX/Transforms/DecomposeEinsum.cpp similarity index 99% rename from src/Transform/ONNX/DecomposeEinsum.cpp rename to src/Dialect/ONNX/Transforms/DecomposeEinsum.cpp index 4e8acd5ac5..4d86634c49 100644 --- a/src/Transform/ONNX/DecomposeEinsum.cpp +++ b/src/Dialect/ONNX/Transforms/DecomposeEinsum.cpp @@ -8,7 +8,7 @@ // //===----------------------------------------------------------------------===// -#include "src/Transform/ONNX/DecomposeEinsum.hpp" +#include "src/Dialect/ONNX/Transforms/DecomposeEinsum.hpp" #include "src/Dialect/Mlir/DialectBuilder.hpp" #include "src/Dialect/ONNX/DialectBuilder.hpp" #include "src/Dialect/ONNX/ONNXOps/Math/EinsumHelper.hpp" diff --git a/src/Transform/ONNX/DecomposeEinsum.hpp b/src/Dialect/ONNX/Transforms/DecomposeEinsum.hpp similarity index 100% rename from src/Transform/ONNX/DecomposeEinsum.hpp rename to src/Dialect/ONNX/Transforms/DecomposeEinsum.hpp diff --git a/src/Transform/ONNX/InstrumentONNXSignaturePass.cpp b/src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp similarity index 87% rename from src/Transform/ONNX/InstrumentONNXSignaturePass.cpp rename to src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp index 94b4bc8449..3f23b1242a 100644 --- a/src/Transform/ONNX/InstrumentONNXSignaturePass.cpp +++ b/src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp @@ -2,15 +2,14 @@ * SPDX-License-Identifier: Apache-2.0 */ -//===------- InstrumentONNXSignaturePass.cpp - Instrumentation -//---------------------===// +//===------- InstrumentONNXSignaturePass.cpp - Instrumentation ------------===// // // Copyright 2022 The IBM Research Authors. // // ============================================================================= // -// This file implements a Function level pass that inserts krnl print statements -// that print the operation name and its input type signature at runtime. +// This file implements a Function level pass that inserts statements that print +// the operation name and its input type signature at runtime. // //===----------------------------------------------------------------------===// @@ -25,8 +24,6 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/raw_ostream.h" -#include "src/Dialect/Krnl/DialectBuilder.hpp" -#include "src/Dialect/Krnl/KrnlOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" #include "src/Interface/ShapeInferenceOpInterface.hpp" @@ -37,7 +34,7 @@ using namespace mlir; namespace { /*! - * This pass insert KrnlPrint and KrnlPrintTensor before each ONNX ops to print + * This pass insert ONNXPrintSignatureOp before each ONNX ops to print * an operation name and input operand type signatures at runtime. */ diff --git a/src/Transform/ONNX/ONNXHybridTransformPass.cpp b/src/Dialect/ONNX/Transforms/ONNXHybridTransformPass.cpp similarity index 95% rename from src/Transform/ONNX/ONNXHybridTransformPass.cpp rename to src/Dialect/ONNX/Transforms/ONNXHybridTransformPass.cpp index 23343ffa52..ccfd5fe154 100644 --- a/src/Transform/ONNX/ONNXHybridTransformPass.cpp +++ b/src/Dialect/ONNX/Transforms/ONNXHybridTransformPass.cpp @@ -20,13 +20,13 @@ #include "llvm/ADT/StringSet.h" #include "src/Dialect/ONNX/ONNXOps.hpp" +#include "src/Dialect/ONNX/Transforms/ConstProp.hpp" +#include "src/Dialect/ONNX/Transforms/ConvOpt.hpp" +#include "src/Dialect/ONNX/Transforms/Decompose.hpp" +#include "src/Dialect/ONNX/Transforms/Recompose.hpp" +#include "src/Dialect/ONNX/Transforms/ShapeInference.hpp" #include "src/Interface/ShapeInferenceOpInterface.hpp" #include "src/Pass/Passes.hpp" -#include "src/Transform/ONNX/ConstProp.hpp" -#include "src/Transform/ONNX/ConvOpt.hpp" -#include "src/Transform/ONNX/Decompose.hpp" -#include "src/Transform/ONNX/Recompose.hpp" -#include "src/Transform/ONNX/ShapeInference.hpp" #include diff --git a/src/Transform/ONNX/ONNXOpTransformPass.cpp b/src/Dialect/ONNX/Transforms/ONNXOpTransformPass.cpp similarity index 100% rename from src/Transform/ONNX/ONNXOpTransformPass.cpp rename to src/Dialect/ONNX/Transforms/ONNXOpTransformPass.cpp diff --git a/src/Transform/ONNX/ONNXPreKrnlVerifyPass.cpp b/src/Dialect/ONNX/Transforms/ONNXPreKrnlVerifyPass.cpp similarity index 98% rename from src/Transform/ONNX/ONNXPreKrnlVerifyPass.cpp rename to src/Dialect/ONNX/Transforms/ONNXPreKrnlVerifyPass.cpp index 9eee9b8828..20b840a847 100644 --- a/src/Transform/ONNX/ONNXPreKrnlVerifyPass.cpp +++ b/src/Dialect/ONNX/Transforms/ONNXPreKrnlVerifyPass.cpp @@ -22,7 +22,6 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/raw_ostream.h" -#include "src/Dialect/Krnl/KrnlOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Interface/ShapeInferenceOpInterface.hpp" #include "src/Pass/Passes.hpp" diff --git a/src/Transform/ONNX/Recompose.cpp b/src/Dialect/ONNX/Transforms/Recompose.cpp similarity index 99% rename from src/Transform/ONNX/Recompose.cpp rename to src/Dialect/ONNX/Transforms/Recompose.cpp index 282f4aa9cc..e95bdb3a3d 100644 --- a/src/Transform/ONNX/Recompose.cpp +++ b/src/Dialect/ONNX/Transforms/Recompose.cpp @@ -20,9 +20,6 @@ // //===----------------------------------------------------------------------===// -#include "src/Transform/ONNX/Recompose.hpp" -#include "src/Pass/Passes.hpp" - #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -33,6 +30,8 @@ #include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" #include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" +#include "src/Dialect/ONNX/Transforms/Recompose.hpp" +#include "src/Pass/Passes.hpp" #include "src/Support/TypeUtilities.hpp" #define DEBUG_TYPE "recompose" @@ -41,7 +40,7 @@ using namespace mlir; namespace { /// Include the patterns defined in the Declarative Rewrite framework. -// #include "src/Transform/ONNX/ONNXRecompose.inc" +// #include "src/Dialect/ONNX/Transforms/ONNXRecompose.inc" struct RecomposeLayerNormFromMulPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/src/Transform/ONNX/Recompose.hpp b/src/Dialect/ONNX/Transforms/Recompose.hpp similarity index 100% rename from src/Transform/ONNX/Recompose.hpp rename to src/Dialect/ONNX/Transforms/Recompose.hpp diff --git a/src/Transform/ONNX/ScrubDisposablePass.cpp b/src/Dialect/ONNX/Transforms/ScrubDisposablePass.cpp similarity index 100% rename from src/Transform/ONNX/ScrubDisposablePass.cpp rename to src/Dialect/ONNX/Transforms/ScrubDisposablePass.cpp diff --git a/src/Transform/ONNX/SetONNXNodeName.cpp b/src/Dialect/ONNX/Transforms/SetONNXNodeName.cpp similarity index 100% rename from src/Transform/ONNX/SetONNXNodeName.cpp rename to src/Dialect/ONNX/Transforms/SetONNXNodeName.cpp diff --git a/src/Transform/ONNX/ShapeInference.cpp b/src/Dialect/ONNX/Transforms/ShapeInference.cpp similarity index 100% rename from src/Transform/ONNX/ShapeInference.cpp rename to src/Dialect/ONNX/Transforms/ShapeInference.cpp diff --git a/src/Transform/ONNX/ShapeInference.hpp b/src/Dialect/ONNX/Transforms/ShapeInference.hpp similarity index 100% rename from src/Transform/ONNX/ShapeInference.hpp rename to src/Dialect/ONNX/Transforms/ShapeInference.hpp diff --git a/src/Transform/ONNX/ShapeInferencePass.cpp b/src/Dialect/ONNX/Transforms/ShapeInferencePass.cpp similarity index 96% rename from src/Transform/ONNX/ShapeInferencePass.cpp rename to src/Dialect/ONNX/Transforms/ShapeInferencePass.cpp index c7b0f31dda..bc43bfbb2b 100644 --- a/src/Transform/ONNX/ShapeInferencePass.cpp +++ b/src/Dialect/ONNX/Transforms/ShapeInferencePass.cpp @@ -19,8 +19,8 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "src/Dialect/ONNX/Transforms/ShapeInference.hpp" #include "src/Pass/Passes.hpp" -#include "src/Transform/ONNX/ShapeInference.hpp" using namespace mlir; diff --git a/src/Transform/ONNX/SimplifyShapeRelatedOps.cpp b/src/Dialect/ONNX/Transforms/SimplifyShapeRelatedOps.cpp similarity index 100% rename from src/Transform/ONNX/SimplifyShapeRelatedOps.cpp rename to src/Dialect/ONNX/Transforms/SimplifyShapeRelatedOps.cpp diff --git a/src/Transform/ONNX/StandardFuncReturnPass.cpp b/src/Dialect/ONNX/Transforms/StandardFuncReturnPass.cpp similarity index 97% rename from src/Transform/ONNX/StandardFuncReturnPass.cpp rename to src/Dialect/ONNX/Transforms/StandardFuncReturnPass.cpp index a58bb7746c..d9196d04ab 100644 --- a/src/Transform/ONNX/StandardFuncReturnPass.cpp +++ b/src/Dialect/ONNX/Transforms/StandardFuncReturnPass.cpp @@ -16,8 +16,8 @@ #include "mlir/Transforms/Passes.h" #include "src/Dialect/ONNX/ONNXOps.hpp" +#include "src/Dialect/ONNX/Transforms/ShapeInference.hpp" #include "src/Pass/Passes.hpp" -#include "src/Transform/ONNX/ShapeInference.hpp" using namespace mlir; diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index 754f2ee179..819f2845c4 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -96,12 +96,6 @@ std::unique_ptr createLowerToStablehloPass(); std::unique_ptr createLowerToStablehloPass(bool enableUnroll); #endif -/// Pass for lowering krnl.dim operations to standard dialect. -std::unique_ptr createDisconnectKrnlDimFromAllocPass(); - -/// Pass for lowering krnl.shape operation. -std::unique_ptr createLowerKrnlShapePass(); - /// Pass for eliding the values of global Krnl operations. std::unique_ptr createElideConstGlobalValuePass(); diff --git a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp index 34534d4564..2e4f8e2c98 100644 --- a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp +++ b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp @@ -96,10 +96,6 @@ void registerOMPasses(int optLevel) { return createProcessScfParallelPrivatePass(); }); - mlir::registerPass([]() -> std::unique_ptr { - return createElideConstGlobalValuePass(); - }); - mlir::registerPass([]() -> std::unique_ptr { return krnl::createConvertSeqToMemrefPass(); }); @@ -112,14 +108,6 @@ void registerOMPasses(int optLevel) { return krnl::createConvertKrnlToLLVMPass(); }); - mlir::registerPass([]() -> std::unique_ptr { - return createDisconnectKrnlDimFromAllocPass(); - }); - - mlir::registerPass([]() -> std::unique_ptr { - return createLowerKrnlShapePass(); - }); - mlir::registerPass([]() -> std::unique_ptr { return createSimplifyShapeRelatedOpsPass(); }); diff --git a/src/Transform/CMakeLists.txt b/src/Transform/CMakeLists.txt index b794e6c36a..240f74b4e5 100644 --- a/src/Transform/CMakeLists.txt +++ b/src/Transform/CMakeLists.txt @@ -1,43 +1,30 @@ # SPDX-License-Identifier: Apache-2.0 -add_subdirectory(ONNX) - -add_onnx_mlir_library(OMElideKrnlGlobalConstants - ElideKrnlGlobalConstants.cpp - - LINK_LIBS PUBLIC - OMKrnlOps - MLIRTransformUtils - ) - -add_onnx_mlir_library(OMDisconnectKrnlDimFromAlloc - DisconnectKrnlDimFromAlloc.cpp +add_onnx_mlir_library(OMLowerKrnlRegion + LowerKrnlRegion.cpp LINK_LIBS PUBLIC OMSupport MLIRTransformUtils ) -add_onnx_mlir_library(OMLowerKrnlShape - LowerKrnlShape.cpp + add_onnx_mlir_library(OMScfParallelPrivateRegion + ProcessScfParallelPrivate.cpp LINK_LIBS PUBLIC OMSupport MLIRTransformUtils ) -add_onnx_mlir_library(OMLowerKrnlRegion - LowerKrnlRegion.cpp +add_onnx_mlir_library(OMInstrument + InstrumentPass.cpp - LINK_LIBS PUBLIC - OMSupport - MLIRTransformUtils - ) - - add_onnx_mlir_library(OMScfParallelPrivateRegion - ProcessScfParallelPrivate.cpp + INCLUDE_DIRS PUBLIC + ${ONNX_MLIR_SRC_ROOT}/include LINK_LIBS PUBLIC - OMSupport - MLIRTransformUtils + OMONNXOps + OMKrnlOps + MLIRPass + OMOptionUtils ) diff --git a/src/Transform/DisconnectKrnlDimFromAlloc.cpp b/src/Transform/DisconnectKrnlDimFromAlloc.cpp deleted file mode 100644 index 92351d7edd..0000000000 --- a/src/Transform/DisconnectKrnlDimFromAlloc.cpp +++ /dev/null @@ -1,159 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - */ - -//===-------- DisconnectKrnlDimFromAlloc.cpp ------------------------------===// -// -// Copyright 2019-2020 The IBM Research Authors. -// -// ============================================================================= -// -// This pass enables the lowering of the krnl.dim operation to a series of -// instruction which do not depend on the alloc of the MemRef whose dim is -// being taken. The krnl.dim operation works in the presence of MemRefs -// which contain affine maps by ignoring the map if present. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "src/Dialect/Krnl/KrnlOps.hpp" -#include "src/Dialect/Mlir/DialectBuilder.hpp" -#include "src/Pass/Passes.hpp" -#include "src/Support/KrnlSupport.hpp" - -using namespace mlir; -using namespace onnx_mlir; - -namespace { - -/*! - * RewritePattern that replaces: - * %0 = alloc(%d) : memref, #map> - * %1 = krnl.dim(%0, 0) : (memref, #map>, index) -> index - * %2 = krnl.dim(%0, 1) : (memref, #map>, index) -> index - * %3 = add %1, %2 - * with: - * %0 = alloc(%d) : memref, #map> - * %2 = constant 10 : index - * %3 = add %d, %2 - * - * When the first argument of the krnl.dim is an input argument - * i.e. it is not the output of an alloc operation, we emit either - * the constant or the strandard dim operation depending on whether - * the dimension is static or dynamic. - * - * function(%arg0 : memref>) { - * %0 = krnl.dim(%arg0, 0) : (memref>, index) -> index - * %1 = krnl.dim(%arg0, 1) : memref> - * } - * - * - * becomes: - * - * function(%arg0 : memref>) { - * %0 = dim %arg0, 0 : (memref>, index) -> index - * %1 = constant 10 : index - * } - * - * The following case is not supported: - * - * function(%arg0 : memref, #map>) { - * %0 = krnl.dim(%arg0, 0) : (memref, #map>, index) -> index - * } - */ - -class DisconnectKrnlDimFromAlloc : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite( - KrnlDimOp krnlDimOp, PatternRewriter &rewriter) const override { - Location loc = krnlDimOp.getLoc(); - - // If index is not constant, return failure. - arith::ConstantOp indexOp = - dyn_cast(krnlDimOp.getIndex().getDefiningOp()); - if (!indexOp) - return failure(); - - // Get the integer value of the index. - int64_t index = indexOp->getAttrOfType("value").getInt(); - - // Get the shape of the MemRef argument. - auto memRefType = krnlDimOp.getAlloc().getType().dyn_cast(); - auto memRefShape = memRefType.getShape(); - int64_t rank = memRefShape.size(); - assert(index >= 0 && index < rank && "Index must be in bounds"); - - // Get the defining operation of the first argument of krnl.dim. - // If this operation is not an alloc, and the value comes from the - // list of input arguments, the support is limited to MemRefs without - // maps. - auto firstArgDefOp = krnlDimOp.getAlloc().getDefiningOp(); - - MultiDialectBuilder create(rewriter, loc); - - Value result; - if (!memRefType.isDynamicDim(index)) { - // If dimension is static, then we can just emit the constant value. - result = create.math.constantIndex(memRefShape[index]); - } else if (firstArgDefOp && isa(firstArgDefOp)) { - // Get defining operation for the MemRef argument. - memref::AllocOp allocOp = - dyn_cast(krnlDimOp.getAlloc().getDefiningOp()); - - // If dimension is dynamic we need to return the input alloc Value which - // corresponds to it. - int64_t dynDimIdx = getAllocArgIndex(allocOp, index); - assert(dynDimIdx >= 0 && - dynDimIdx < (int64_t)allocOp.getOperands().size() && - "Dynamic index outside range of alloc argument list."); - result = allocOp.getOperands()[dynDimIdx]; - } else if (memRefType.getLayout().isIdentity()) { - // Use a standard DimOp since no map is present. - result = create.mem.dim(krnlDimOp.getAlloc(), krnlDimOp.getIndex()); - } else - llvm_unreachable( - "dynamic sized MemRef with map must be defined by an AllocOp"); - - rewriter.replaceOp(krnlDimOp, result); - - return success(); - } -}; - -/*! - * Function pass that disconnects krnl.dim emission from its MemRef alloc. - */ -class DisconnectKrnlDimFromAllocPass - : public PassWrapper> { -public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DisconnectKrnlDimFromAllocPass) - - StringRef getArgument() const override { return "lower-krnl-shape-to-std"; } - - StringRef getDescription() const override { - return "Lowers krnl shape-related operations."; - } - - void runOnOperation() override { - auto function = getOperation(); - - ConversionTarget target(getContext()); - RewritePatternSet patterns(&getContext()); - patterns.insert(&getContext()); - - if (failed(applyPatternsAndFoldGreedily(function, std::move(patterns)))) - signalPassFailure(); - } -}; -} // namespace - -std::unique_ptr onnx_mlir::createDisconnectKrnlDimFromAllocPass() { - return std::make_unique(); -} diff --git a/src/Transform/ElideKrnlGlobalConstants.cpp b/src/Transform/ElideKrnlGlobalConstants.cpp deleted file mode 100644 index 5c5bca4d8c..0000000000 --- a/src/Transform/ElideKrnlGlobalConstants.cpp +++ /dev/null @@ -1,115 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - */ - -//===- ElideKrnlGlobalConstants.cpp - Krnl Constant lobal Value Elision ---===// -// -// Copyright 2019-2022 The IBM Research Authors. -// -// ============================================================================= -// -// In practice, the constant values of Global Krnl operations may be large -// enough to hinder the readability of the MLIR intermediate representation. -// -// This file creates a pass which elides the explicit values of constant -// global operations. This pass has purely cosmetic purposes and should only be -// run to obtain a compact representation of the program when emitting Krnl -// dialect code. This pass should never be invoked on code meant to be run. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "src/Dialect/Krnl/DialectBuilder.hpp" -#include "src/Dialect/Krnl/KrnlOps.hpp" -#include "src/Pass/Passes.hpp" -#include "src/Support/KrnlSupport.hpp" - -#include "ElideKrnlGlobalConstants.hpp" - -using namespace mlir; -using namespace onnx_mlir; - -constexpr uint64_t KrnlConstGlobalValueElision::kDefaultElisionThreshold; - -mlir::LogicalResult KrnlConstGlobalValueElision::matchAndRewrite( - mlir::KrnlGlobalOp op, mlir::PatternRewriter &rewriter) const { - Location loc = op.getLoc(); - - // Only elide if value is available. - if (!op.getValue().has_value()) - return success(); - - // Only elide dense and dense resource attributes. - if (!(op.getValue()->isa() || - op.getValue()->isa())) - return success(); - - MultiDialectBuilder create(rewriter, loc); - - bool elide = false; - - if (op.getValue()->isa()) { - const auto &valAttr = - op.getValueAttr().dyn_cast_or_null(); - if (valAttr.getNumElements() > elisionThreshold && !valAttr.isSplat()) { - elide = true; - } - } else { - const auto &valAttr = - op.getValueAttr().dyn_cast_or_null(); - if (valAttr.getNumElements() > elisionThreshold) { - elide = true; - } - } - - if (elide) { - IntegerAttr offsetAttr = op.getOffset() ? op.getOffsetAttr() : nullptr; - IntegerAttr alignmentAttr = - op.getAlignment() ? op.getAlignmentAttr() : nullptr; - auto newGlobalOp = - create.krnl.constant(op.getResult().getType().cast(), - op.getName(), std::nullopt, offsetAttr, alignmentAttr); - rewriter.replaceOp(op, newGlobalOp); - } - - return success(); -} - -namespace { -/*! - * Function pass that performs constant value elision of Krnl globals. - */ -class ElideConstGlobalValuePass : public PassWrapper> { -public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ElideConstGlobalValuePass) - - StringRef getArgument() const override { return "elide-krnl-constants"; } - - StringRef getDescription() const override { - return "Elide the constant values of the Global Krnl operations."; - } - - void runOnOperation() override { - auto function = getOperation(); - - ConversionTarget target(getContext()); - RewritePatternSet patterns(&getContext()); - patterns.insert( - &getContext(), KrnlConstGlobalValueElision::kDefaultElisionThreshold); - // No need to test, its ok to fail the apply. - LogicalResult res = - applyPatternsAndFoldGreedily(function, std::move(patterns)); - assert((succeeded(res) || failed(res)) && "remove unused var warning"); - } -}; - -} // namespace - -std::unique_ptr onnx_mlir::createElideConstGlobalValuePass() { - return std::make_unique(); -} diff --git a/src/Transform/ElideKrnlGlobalConstants.hpp b/src/Transform/ElideKrnlGlobalConstants.hpp deleted file mode 100644 index 9cdaf7c9b1..0000000000 --- a/src/Transform/ElideKrnlGlobalConstants.hpp +++ /dev/null @@ -1,39 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - */ - -#pragma once - -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "src/Dialect/Krnl/KrnlOps.hpp" - -/*! - * RewritePattern that replaces existing constant Krnl global values - * with a similar operation which preserves all attributes except the value - * attribute. - */ -class KrnlConstGlobalValueElision - : public mlir::OpRewritePattern { -public: - /* - * A threshold value specifying the maximum number of elements a constant - * operation can hold as an attribute. If the number exceeds this threshold, - * constants will be packed together and, in the case where `move-to-file` - * option is enabled, stored as a binary file on disk. This can help preserve - * readability of IR dump and improve compilation speed. - */ - static constexpr uint64_t kDefaultElisionThreshold = 32; - - int64_t elisionThreshold; - - using mlir::OpRewritePattern::OpRewritePattern; - - explicit KrnlConstGlobalValueElision( - mlir::MLIRContext *context, int64_t elisionThreshold) - : OpRewritePattern(context), elisionThreshold(elisionThreshold) {} - - mlir::LogicalResult matchAndRewrite( - mlir::KrnlGlobalOp op, mlir::PatternRewriter &rewriter) const override; -}; diff --git a/src/Transform/ONNX/InstrumentPass.cpp b/src/Transform/InstrumentPass.cpp similarity index 100% rename from src/Transform/ONNX/InstrumentPass.cpp rename to src/Transform/InstrumentPass.cpp diff --git a/src/Transform/LowerKrnlShape.cpp b/src/Transform/LowerKrnlShape.cpp deleted file mode 100644 index f43840e5b0..0000000000 --- a/src/Transform/LowerKrnlShape.cpp +++ /dev/null @@ -1,113 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - */ - -//===-------- LowerKrnlShape.cpp ------------------------------------------===// -// -// Copyright 2019-2022 The IBM Research Authors. -// -// ============================================================================= -// -// This pass enables the lowering of the krnl.shape operation to use Shape -// dialect operations. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "src/Dialect/Krnl/DialectBuilder.hpp" -#include "src/Dialect/Krnl/KrnlOps.hpp" -#include "src/Pass/Passes.hpp" -#include "src/Support/KrnlSupport.hpp" - -using namespace mlir; -using namespace onnx_mlir; - -namespace { - -/*! - * RewritePattern that replaces: - * %0 = alloc(%d) : memref, #map> - * %1 = krnl.shape(%0) : memref> -> !shape.shape - * with: - * %0 = alloc(%d) : memref, #map> - * %c0 = constant 0 : index - * %1 = krnl.dim(%0, %c0) : memref, #map>, index - * %c1 = constant 1 : index - * %2 = krnl.dim(%0, %c1) : memref, #map>, index - * %shape = shape.from_extents %1, %2 - */ - -class LowerKrnlShape : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite( - KrnlShapeOp krnlShapeOp, PatternRewriter &rewriter) const override { - Location loc = krnlShapeOp.getLoc(); - size_t rank = krnlShapeOp.getAlloc() - .getType() - .dyn_cast() - .getShape() - .size(); - - MultiDialectBuilder create( - rewriter, loc); - - // Create MemRef to hold shape information. - auto memRefType = - MemRefType::get({static_cast(rank)}, rewriter.getIndexType()); - memref::AllocOp newMemRefAlloc = create.mem.alloc(memRefType); - - for (size_t idx = 0; idx < rank; idx++) { - Value index = create.math.constantIndex(idx); - Value operand = create.krnl.dim( - rewriter.getIndexType(), krnlShapeOp.getAlloc(), index); - - // Store value in the new MemRef. - Value idxValue = create.math.constant(rewriter.getIndexType(), idx); - SmallVector indexArg = {idxValue}; - rewriter.create( - loc, operand, newMemRefAlloc, indexArg); - } - - rewriter.replaceOp(krnlShapeOp, newMemRefAlloc.getResult()); - - return success(); - } -}; - -/*! - * Function pass that emits the shape of a MemRef. - */ -class LowerKrnlShapePass - : public PassWrapper> { -public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerKrnlShapePass) - - StringRef getArgument() const override { return "lower-krnl-shape"; } - - StringRef getDescription() const override { - return "Lower krnl.shape operation to use Shape dialect operations."; - } - - void runOnOperation() override { - auto function = getOperation(); - - ConversionTarget target(getContext()); - RewritePatternSet patterns(&getContext()); - patterns.insert(&getContext()); - - if (failed(applyPatternsAndFoldGreedily(function, std::move(patterns)))) - signalPassFailure(); - } -}; -} // namespace - -// TODO: integrate with other passes if needed. -std::unique_ptr onnx_mlir::createLowerKrnlShapePass() { - return std::make_unique(); -} diff --git a/test/backend/common.py b/test/backend/common.py index a0ea488299..9f8aca6b15 100644 --- a/test/backend/common.py +++ b/test/backend/common.py @@ -63,7 +63,7 @@ def execute_commands(cmds, dynamic_inputs_dims): first_dim = False else: env_string += "," + str(dim_index) - my_env["TEST_IMPORTER_FORCE_DYNAMIC"] = env_string + my_env["IMPORTER_FORCE_DYNAMIC"] = env_string subprocess.run(cmds, env=my_env, check=True) diff --git a/test/backend/inference_backend.py b/test/backend/inference_backend.py index 1d9d9605c5..b73c571ccf 100644 --- a/test/backend/inference_backend.py +++ b/test/backend/inference_backend.py @@ -394,7 +394,6 @@ def get_test_models(): }, # ==OP== Cast # ==MIN== 6 - # ==UNSUPPORTED== 19 # ==LIM== Cast only between float and double types. Only ppc64le and MacOS platforms support float16. "test_cast_FLOAT_to_DOUBLE_cpu": { STATIC_SHAPE: {}, @@ -437,35 +436,41 @@ def get_test_models(): # ==LIM== CastLike only between float and double types. Only ppc64le and MacOS platforms support float16. "test_castlike_FLOAT_to_DOUBLE_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, "test_castlike_DOUBLE_to_FLOAT_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, "test_castlike_FLOAT_to_FLOAT16_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, FLOAT16: {}, }, "test_castlike_FLOAT16_to_FLOAT_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, FLOAT16: {}, }, "test_castlike_FLOAT16_to_DOUBLE_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, FLOAT16: {}, }, "test_castlike_DOUBLE_to_FLOAT16_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, FLOAT16: {}, }, @@ -615,11 +620,11 @@ def get_test_models(): }, # ==OP== Constant # ==MIN== 1 - # ==UNSUPPORTED== 19 # By def, no dynamic shapes. "test_constant_cpu": {STATIC_SHAPE: {}}, # ==OP== ConstantOfShape # ==MIN== 9 + # ==UNSUPPORTED== 20 # By def, no dynamic shapes. "test_constantofshape_float_ones_cpu": {STATIC_SHAPE: {}}, "test_constantofshape_int_zeros_cpu": {STATIC_SHAPE: {}}, @@ -664,47 +669,56 @@ def get_test_models(): # TODO: Support unknown dimensions in spatial dimensions "test_convtranspose_1d_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {0: {0, 1}, 1: {0, 1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {0: {0, 1}, 1: {0, 1}}, CONSTANT_INPUT: {1}, }, "test_convtranspose_3d_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {0: {0, 1}, 1: {0, 1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {0: {0, 1}, 1: {0, 1}}, CONSTANT_INPUT: {1}, }, "test_convtranspose_autopad_same_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {0: {0, 1}, 1: {0, 1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {0: {0, 1}, 1: {0, 1}}, CONSTANT_INPUT: {1}, }, "test_convtranspose_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {0: {0, 1}, 1: {0, 1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {0: {0, 1}, 1: {0, 1}}, CONSTANT_INPUT: {1}, }, "test_convtranspose_dilations_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {0: {0, 1}, 1: {0, 1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {0: {0, 1}, 1: {0, 1}}, CONSTANT_INPUT: {1}, }, "test_convtranspose_kernel_shape_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {0: {0, 1}, 1: {0, 1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {0: {0, 1}, 1: {0, 1}}, CONSTANT_INPUT: {1}, }, "test_convtranspose_output_shape_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {0: {0, 1}, 1: {0, 1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {0: {0, 1}, 1: {0, 1}}, CONSTANT_INPUT: {1}, }, "test_convtranspose_pad_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {0: {0, 1}, 1: {0, 1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {0: {0, 1}, 1: {0, 1}}, CONSTANT_INPUT: {1}, }, "test_convtranspose_pads_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {0: {0, 1}, 1: {0, 1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {0: {0, 1}, 1: {0, 1}}, CONSTANT_INPUT: {1}, }, # ==OP== Cos @@ -769,6 +783,8 @@ def get_test_models(): CONSTANT_INPUT: {-1}, }, # ==OP== DFT + # ==MIN== 17 + # ==UNSUPPORTED== 20 # "test_dft_axis_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # "test_dft_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # "test_dft_inverse_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, @@ -776,12 +792,14 @@ def get_test_models(): # ==MIN== 13 "test_depthtospace_example_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, "test_depthtospace_crd_mode_example_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, # ==OP== DequantizeLinear @@ -875,27 +893,32 @@ def get_test_models(): # ==LIM== Limited to the types supported by ReduceSum and MatMul (which we decompose to in most cases) which exclude integers with width < 32 "test_einsum_batch_diagonal_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, "test_einsum_batch_matmul_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, "test_einsum_inner_prod_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, "test_einsum_sum_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, "test_einsum_transpose_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, # ==OP== Elu @@ -917,7 +940,6 @@ def get_test_models(): }, # ==OP== Equal # ==MIN== 7 - # ==UNSUPPORTED== 19 "test_equal_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, @@ -1072,17 +1094,20 @@ def get_test_models(): # ==MIN== 11 "test_gathernd_example_int32_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, "test_gathernd_example_float32_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, "test_gathernd_example_int32_batch_dim1_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, # ==OP== Gelu @@ -1366,16 +1391,19 @@ def get_test_models(): # ==MIN== 6 "test_instancenorm_example_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, "test_instancenorm_epsilon_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, # ==OP== IsInf - # ==MIN== 10 + # ==MIN== 20 + # ==LIM== Currently no support for float16 infinity value. Only for float32 and float64. "test_isinf_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, @@ -1391,13 +1419,15 @@ def get_test_models(): DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, + # "test_isinf_float16_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # ==OP== IsNaN - # ==MIN== 9 + # ==MIN== 20 "test_isnan_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, + # "test_isnan_float16_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # ==OP== LayerNormalization # ==MIN== 17 "test_layer_normalization_2d_axis0_cpu": { @@ -2244,21 +2274,23 @@ def get_test_models(): }, # ==OP== Pad # ==MIN== 2 - # ==UNSUPPORTED== 19 # ==LIM== axes input not supported "test_constant_pad_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, "test_edge_pad_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, "test_reflect_pad_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, # ==OP== Pow @@ -2298,7 +2330,6 @@ def get_test_models(): }, # ==OP== QuantizeLinear # ==MIN== 10 - # ==UNSUPPORTED== 19 # ==LIM== Do not support per-axis and i8 quantization. # "test_quantizelinear_axis_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, "test_quantizelinear_cpu": { @@ -2430,6 +2461,7 @@ def get_test_models(): }, # ==OP== ReduceMax # ==MIN== 1 + # ==UNSUPPORTED== 20 # ==LIM== do_not_keep_dim not supported. # "test_reduce_max_default_axes_keepdim_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # "test_reduce_max_default_axes_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, @@ -2484,6 +2516,7 @@ def get_test_models(): }, # ==OP== ReduceMin # ==MIN== 1 + # ==UNSUPPORTED== 20 # ==LIM== do_not_keep_dim not supported. # "test_reduce_min_default_axes_keepdims_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # "test_reduce_min_default_axes_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, @@ -2653,7 +2686,6 @@ def get_test_models(): }, # ==OP== Resize # ==MIN== 10 - # ==UNSUPPORTED== 19 # ==LIM== Missing support for linear, cubic, crop, pytorch_half_pixel, and floor. Attributes antialias, axes and keep_aspect_ratio_policy are not supported. # Resize # All test cases in onnx v1.11.0. yes for currently supported @@ -2908,7 +2940,6 @@ def get_test_models(): }, # ==OP== Slice # ==MIN== 13 - # ==UNSUPPORTED== 19 # ==LIM== Axis must be a constant argument. # ==TODO== Add tests to slices, currently have none. # (makes Axis a runtime argument, which is not supported). @@ -2987,12 +3018,12 @@ def get_test_models(): # "test_spacetodepth_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, "test_spacetodepth_example_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + # Issue #2639: Dynamic test fails. Need to be fixed. + # DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, # ==OP== Split # ==MIN== 2 - # ==UNSUPPORTED== 19 # ==LIM== Does not support static and dynamic shape, zero size splits. # ==TODO== Temporally removed due to changes in onnx 1.8.1 # "test_split_equal_parts_1d_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, @@ -3107,17 +3138,17 @@ def get_test_models(): # ==MIN== 10 "test_top_k_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + DYNAMIC_SHAPE: {0: {-1}}, CONSTANT_INPUT: {-1}, }, "test_top_k_smallest_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + DYNAMIC_SHAPE: {0: {-1}}, CONSTANT_INPUT: {-1}, }, "test_top_k_negative_axis_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + DYNAMIC_SHAPE: {0: {-1}}, CONSTANT_INPUT: {-1}, }, # ==OP== Transpose @@ -3355,12 +3386,12 @@ def get_test_models(): # float16 LLVM instructions that are unsupported on some platforms. "test_onnxmlir_top_k_float16_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + DYNAMIC_SHAPE: {0: {-1}}, CONSTANT_INPUT: {-1}, }, "test_onnxmlir_top_k_smallest_float16_cpu": { STATIC_SHAPE: {}, - DYNAMIC_SHAPE: {-1: {-1}}, + DYNAMIC_SHAPE: {0: {-1}}, CONSTANT_INPUT: {-1}, }, } diff --git a/test/backend/variables.py b/test/backend/variables.py index 7923aa01eb..5cd9b0a416 100644 --- a/test/backend/variables.py +++ b/test/backend/variables.py @@ -44,7 +44,6 @@ def get_args_from_env(): TEST_VERBOSE = os.getenv("TEST_VERBOSE") TEST_CASE_CHECK = os.getenv("TEST_CASE_CHECK") TEST_INVOKECONVERTER = os.getenv("TEST_INVOKECONVERTER") - TEST_IMPORTER_FORCE_DYNAMIC = os.getenv("TEST_IMPORTER_FORCE_DYNAMIC") # Force input tensors to constants. Set this to a list of input indices. # E.g. # - "0, 2" for the first and third input tensors. diff --git a/test/mlir/conversion/krnl_to_llvm/krnl_category_mapper.mlir b/test/mlir/conversion/krnl_to_llvm/krnl_category_mapper.mlir index ebb2954990..9711b01c79 100644 --- a/test/mlir/conversion/krnl_to_llvm/krnl_category_mapper.mlir +++ b/test/mlir/conversion/krnl_to_llvm/krnl_category_mapper.mlir @@ -78,22 +78,19 @@ func.func private @test_category_mapper_string_to_int64(%arg0: memref<2x2x!krnl. // CHECK-DAG: llvm.func @strncmp(!llvm.ptr, !llvm.ptr, i64) -> i32 // CHECK-DAG: llvm.func @strlen(!llvm.ptr) -> i64 // CHECK-DAG: llvm.func @find_index_str(!llvm.ptr, !llvm.ptr, !llvm.ptr, i32) -> i64 - // CHECK-DAG: llvm.mlir.global internal constant @om_cat("cat") - // CHECK-DAG: llvm.mlir.global internal constant @om_dog("dog") - // CHECK-DAG: llvm.mlir.global internal constant @om_cow("cow") - // CHECK: llvm.mlir.global internal constant @cats_strings{{.*}}() {addr_space = 0 : i32, alignment = 16 : i64} : !llvm.array<3 x ptr> { - // CHECK: [[ARRAY:%.+]] = llvm.mlir.undef : !llvm.array<3 x ptr> - // CHECK: [[CAT_ADDR:%.+]] = llvm.mlir.addressof @om_cat : !llvm.ptr - // CHECK: [[CAT_GEP:%.+]] = llvm.bitcast [[CAT_ADDR]] : !llvm.ptr to !llvm.ptr - // CHECK: [[CAT_INS_VAL:%.+]] = llvm.insertvalue [[CAT_GEP]], [[ARRAY]][0] : !llvm.array<3 x ptr> - // CHECK: [[DOG_ADDR:%.+]] = llvm.mlir.addressof @om_dog : !llvm.ptr - // CHECK: [[DOG_GEP:%.+]] = llvm.bitcast [[DOG_ADDR]] : !llvm.ptr to !llvm.ptr - // CHECK: [[DOG_INS_VAL:%.+]] = llvm.insertvalue [[DOG_GEP]], [[CAT_INS_VAL]][1] : !llvm.array<3 x ptr> - // CHECK: [[COW_ADDR:%.+]] = llvm.mlir.addressof @om_cow : !llvm.ptr - // CHECK: [[COW_GEP:%.+]] = llvm.bitcast [[COW_ADDR]] : !llvm.ptr to !llvm.ptr - // CHECK: [[COW_INS_VAL:%.+]] = llvm.insertvalue [[COW_GEP]], [[DOG_INS_VAL]][2] : !llvm.array<3 x ptr> - // CHECK: llvm.return [[COW_INS_VAL]] : !llvm.array<3 x ptr> - // CHECK: } + // CHECK-DAG: llvm.mlir.global internal constant @om.strArray.cats_strings("cat\00dog\00cow\00") {addr_space = 0 : i32} + // CHECK: llvm.mlir.global internal constant @cats_strings() {addr_space = 0 : i32, alignment = 16 : i64} : !llvm.array<3 x ptr> { + // CHECK: [[ARRAY:%.+]] = llvm.mlir.undef : !llvm.array<3 x ptr> + // CHECK: [[BASE_ADDR:%.+]] = llvm.mlir.addressof @om.strArray.cats_strings : !llvm.ptr + // CHECK: [[I8_BASE_ADDR:%.+]] = llvm.bitcast %1 : !llvm.ptr to !llvm.ptr + // CHECK: [[CAT_GEP:%.+]] = llvm.getelementptr [[I8_BASE_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr, i8 + // CHECK: [[CAT_INS_VAL:%.+]] = llvm.insertvalue [[CAT_GEP]], [[ARRAY]][0] : !llvm.array<3 x ptr> + // CHECK: [[DOG_GEP:%.+]] = llvm.getelementptr [[I8_BASE_ADDR]][4] : (!llvm.ptr) -> !llvm.ptr, i8 + // CHECK: [[DOG_INS_VAL:%.+]] = llvm.insertvalue [[DOG_GEP]], [[CAT_INS_VAL]][1] : !llvm.array<3 x ptr> + // CHECK: [[COW_GEP:%.+]] = llvm.getelementptr [[I8_BASE_ADDR]][8] : (!llvm.ptr) -> !llvm.ptr, i8 + // CHECK: [[COW_INS_VAL:%.+]] = llvm.insertvalue [[COW_GEP]], [[DOG_INS_VAL]][2] : !llvm.array<3 x ptr> + // CHECK: llvm.return [[COW_INS_VAL]] : !llvm.array<3 x ptr> + // CHECK: } // CHECK-DAG: llvm.mlir.global internal constant @cats_int64s{{.*}}(dense<[1, 2, 3]> : tensor<3xi64>) {addr_space = 0 : i32, alignment = 16 : i64} : !llvm.array<3 x i64> // CHECK-DAG: llvm.mlir.global internal constant @V{{.*}}(dense<[1, 2, 0]> : tensor<3xi32>) {addr_space = 0 : i32, alignment = 16 : i64} : !llvm.array<3 x i32> // CHECK-DAG: llvm.mlir.global internal constant @G{{.*}}(dense<[1, 0, -3]> : tensor<3xi32>) {addr_space = 0 : i32, alignment = 16 : i64} : !llvm.array<3 x i32> @@ -161,23 +158,20 @@ func.func private @test_category_mapper_int64_to_string(%arg0: memref<2x2xi64>) return %0 : memref<2x2x!krnl.string> // CHECK-DAG: llvm.func @find_index_i64(i64, !llvm.ptr, !llvm.ptr, i32) -> i64 - // CHECK-DAG: llvm.mlir.global internal constant @om_none("none") - // CHECK-DAG: llvm.mlir.global internal constant @om_cat("cat") - // CHECK-DAG: llvm.mlir.global internal constant @om_dog("dog") - // CHECK-DAG: llvm.mlir.global internal constant @om_cow("cow") - // CHECK: llvm.mlir.global internal constant @cats_strings{{.*}}() {addr_space = 0 : i32, alignment = 16 : i64} : !llvm.array<3 x ptr> { - // CHECK: [[ARRAY:%.+]] = llvm.mlir.undef : !llvm.array<3 x ptr> - // CHECK: [[CAT_ADDR:%.+]] = llvm.mlir.addressof @om_cat : !llvm.ptr - // CHECK: [[CAT_GEP:%.+]] = llvm.bitcast [[CAT_ADDR]] : !llvm.ptr to !llvm.ptr - // CHECK: [[CAT_INS_VAL:%.+]] = llvm.insertvalue [[CAT_GEP]], [[ARRAY]][0] : !llvm.array<3 x ptr> - // CHECK: [[DOG_ADDR:%.+]] = llvm.mlir.addressof @om_dog : !llvm.ptr - // CHECK: [[DOG_GEP:%.+]] = llvm.bitcast [[DOG_ADDR]] : !llvm.ptr to !llvm.ptr - // CHECK: [[DOG_INS_VAL:%.+]] = llvm.insertvalue [[DOG_GEP]], [[CAT_INS_VAL]][1] : !llvm.array<3 x ptr> - // CHECK: [[COW_ADDR:%.+]] = llvm.mlir.addressof @om_cow : !llvm.ptr - // CHECK: [[COW_GEP:%.+]] = llvm.bitcast [[COW_ADDR]] : !llvm.ptr to !llvm.ptr - // CHECK: [[COW_INS_VAL:%.+]] = llvm.insertvalue [[COW_GEP]], [[DOG_INS_VAL]][2] : !llvm.array<3 x ptr> - // CHECK: llvm.return [[COW_INS_VAL]] : !llvm.array<3 x ptr> - // CHECK: } + // CHECK-DAG: llvm.mlir.global internal constant @om.strArray.default_string("none\00") {addr_space = 0 : i32} + // CHECK-DAG: llvm.mlir.global internal constant @om.strArray.cats_strings("cat\00dog\00cow\00") {addr_space = 0 : i32} + // CHECK: llvm.mlir.global internal constant @cats_strings() {addr_space = 0 : i32, alignment = 16 : i64} : !llvm.array<3 x ptr> { + // CHECK: [[ARRAY:%.+]] = llvm.mlir.undef : !llvm.array<3 x ptr> + // CHECK: [[BASE_ADDR:%.+]] = llvm.mlir.addressof @om.strArray.cats_strings : !llvm.ptr + // CHECK: [[I8_BASE_ADDR:%.+]] = llvm.bitcast %1 : !llvm.ptr to !llvm.ptr + // CHECK: [[CAT_GEP:%.+]] = llvm.getelementptr [[I8_BASE_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr, i8 + // CHECK: [[CAT_INS_VAL:%.+]] = llvm.insertvalue [[CAT_GEP]], [[ARRAY]][0] : !llvm.array<3 x ptr> + // CHECK: [[DOG_GEP:%.+]] = llvm.getelementptr [[I8_BASE_ADDR]][4] : (!llvm.ptr) -> !llvm.ptr, i8 + // CHECK: [[DOG_INS_VAL:%.+]] = llvm.insertvalue [[DOG_GEP]], [[CAT_INS_VAL]][1] : !llvm.array<3 x ptr> + // CHECK: [[COW_GEP:%.+]] = llvm.getelementptr [[I8_BASE_ADDR]][8] : (!llvm.ptr) -> !llvm.ptr, i8 + // CHECK: [[COW_INS_VAL:%.+]] = llvm.insertvalue [[COW_GEP]], [[DOG_INS_VAL]][2] : !llvm.array<3 x ptr> + // CHECK: llvm.return [[COW_INS_VAL]] : !llvm.array<3 x ptr> + // CHECK: } // CHECK-DAG: llvm.mlir.global internal constant @cats_int64s{{.*}}(dense<[1, 2, 3]> : tensor<3xi64>) {addr_space = 0 : i32, alignment = 16 : i64} : !llvm.array<3 x i64> // CHECK-DAG: llvm.mlir.global internal constant @V{{.*}}(dense<[2, 1, 0]> : tensor<3xi32>) {addr_space = 0 : i32, alignment = 16 : i64} : !llvm.array<3 x i32> // CHECK-DAG: llvm.mlir.global internal constant @G{{.*}}(dense<[-1, 1, 0]> : tensor<3xi32>) {addr_space = 0 : i32, alignment = 16 : i64} : !llvm.array<3 x i32> diff --git a/test/mlir/conversion/krnl_to_llvm/krnl_getref_lowering.mlir b/test/mlir/conversion/krnl_to_llvm/krnl_getref_lowering.mlir deleted file mode 100644 index 4562618406..0000000000 --- a/test/mlir/conversion/krnl_to_llvm/krnl_getref_lowering.mlir +++ /dev/null @@ -1,88 +0,0 @@ -// RUN: onnx-mlir-opt -O3 --convert-krnl-to-affine --convert-krnl-to-llvm %s -split-input-file | FileCheck %s - -func.func @test_getref_lowering(%arg0: memref<2x2xf32>) -> memref<2x2xf32> { - %c13_i64 = arith.constant 13 : i64 - %1 = memref.alloc() : memref<10x10xf32> - %2 = "krnl.getref"(%1, %c13_i64) : (memref<10x10xf32>, i64) -> memref<2x2xf32> - return %2 : memref<2x2xf32> - - // CHECK-LABEL: test_getref_lowering - // CHECK: %[[OFFSET:.+]] = llvm.mlir.constant(13 : i64) : i64 - // CHECK: [[CONST_10_0:%.+]] = llvm.mlir.constant(10 : index) : i64 - // CHECK: [[CONST_10_1:%.+]] = llvm.mlir.constant(10 : index) : i64 - // CHECK: [[CONST_1:%.+]] = llvm.mlir.constant(1 : index) : i64 - // CHECK: %[[CONST_100:.+]] = llvm.mlir.constant(100 : index) : i64 - // CHECK: [[FLOAT_STAR:%.+]] = llvm.mlir.zero : !llvm.ptr - // CHECK: [[ELEM1:%.+]] = llvm.getelementptr [[FLOAT_STAR]][%[[CONST_100]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 - // CHECK: [[ELEM_SIZE:%.+]] = llvm.ptrtoint [[ELEM1]] : !llvm.ptr to i64 - // CHECK: [[MEMPOOL:%.+]] = llvm.call @malloc([[ELEM_SIZE]]) : (i64) -> !llvm.ptr - // CHECK: [[MEMPOOL_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[MEMPOOL]], [[MEMPOOL_MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[MEMPOOL]], [[MEMREF1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: llvm.mlir.constant - // CHECK: llvm.insertvalue - // CHECK: llvm.insertvalue - // CHECK: llvm.insertvalue - // CHECK: llvm.insertvalue - // CHECK: llvm.insertvalue - // CHECK: [[MEMPOOL1:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: [[MEMPOOL_ALLOC:%.+]] = llvm.getelementptr [[MEMPOOL1]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 - // CHECK: [[TYPED_MEMPOOL_ALLOC:%.+]] = llvm.bitcast [[MEMPOOL_ALLOC]] : !llvm.ptr to !llvm.ptr - // CHECK: [[MEMPOOLED_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL_ALLOC]], [[MEMPOOLED_MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL_ALLOC]], [[MEMREF3]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -} - -// ----- - -func.func @test_getref_lowering_dynamic(%arg0: memref<2x2xf32>) -> memref<2x?xf32> { - %c13_i64 = arith.constant 13 : i64 - %c5_index = arith.constant 5 : index - %1 = memref.alloc(%c5_index) : memref<10x?xf32> - %2 = "krnl.getref"(%1, %c13_i64, %c5_index) : (memref<10x?xf32>, i64, index) -> memref<2x?xf32> - return %2 : memref<2x?xf32> - - // CHECK-LABEL: test_getref_lowering_dynamic - // CHECK: %[[C13_I64:.+]] = llvm.mlir.constant(13 : i64) : i64 - // CHECK: %[[C5_INDEX:.+]] = llvm.mlir.constant(5 : index) : i64 - // CHECK: %[[C10_INDEX:.+]] = llvm.mlir.constant(10 : index) : i64 - // CHECK: %[[C1_INDEX:.+]] = llvm.mlir.constant(1 : index) : i64 - // CHECK: %[[MUL1:.+]] = llvm.mul %[[C5_INDEX]], %[[C10_INDEX]] : i64 - // CHECK: [[FLOAT_STAR:%.+]] = llvm.mlir.zero : !llvm.ptr - // CHECK: [[ELEM1:%.+]] = llvm.getelementptr [[FLOAT_STAR]][%[[MUL1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 - // CHECK: [[ELEM_SIZE:%.+]] = llvm.ptrtoint [[ELEM1]] : !llvm.ptr to i64 - - /// Allocate the memory pool alloc. - // CHECK: [[ALLOC:%.+]] = llvm.call @malloc([[ELEM_SIZE]]) : (i64) -> !llvm.ptr - - /// Definition of the Alloc output memref<10x?xf32> - // CHECK: [[ALLOC_MEMREF_1:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: [[ALLOC_MEMREF_2:%.+]] = llvm.insertvalue [[ALLOC]], [[ALLOC_MEMREF_1]][0] - // CHECK: [[ALLOC_MEMREF_3:%.+]] = llvm.insertvalue [[ALLOC]], [[ALLOC_MEMREF_2]][1] - // CHECK: [[CONST_0:%.+]] = llvm.mlir.constant(0 : index) : i64 - // CHECK: [[ALLOC_MEMREF_4:%.+]] = llvm.insertvalue [[CONST_0]], [[ALLOC_MEMREF_3]][2] - // CHECK: [[ALLOC_MEMREF_5:%.+]] = llvm.insertvalue %[[C10_INDEX]], [[ALLOC_MEMREF_4]][3, 0] - // CHECK: [[ALLOC_MEMREF_6:%.+]] = llvm.insertvalue %[[C5_INDEX]], [[ALLOC_MEMREF_5]][3, 1] - // CHECK: [[ALLOC_MEMREF_7:%.+]] = llvm.insertvalue %[[C5_INDEX]], [[ALLOC_MEMREF_6]][4, 0] - // CHECK: [[ALLOC_MEMREF_8:%.+]] = llvm.insertvalue %[[C1_INDEX]], [[ALLOC_MEMREF_7]][4, 1] - - /// Fetch the allocated memory from the memory pool alloc. - // CHECK: [[MEMPOOL:%.+]] = llvm.extractvalue [[ALLOC_MEMREF_8]][1] - // CHECK: [[GETREF_START:%.+]] = llvm.getelementptr [[MEMPOOL]][%[[C13_I64]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 - // CHECK: [[TYPED_GETREF_START:%.+]] = llvm.bitcast [[GETREF_START]] : !llvm.ptr to !llvm.ptr - - /// Definition of the krnl.getref output memref<2x?xf32> - // CHECK: [[GETREF_MEMREF_1:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: [[GETREF_MEMREF_2:%.+]] = llvm.insertvalue [[TYPED_GETREF_START]], [[GETREF_MEMREF_1]][0] - // CHECK: [[GETREF_MEMREF_3:%.+]] = llvm.insertvalue [[TYPED_GETREF_START]], [[GETREF_MEMREF_2]][1] - // CHECK: [[CONST_0:%.+]] = llvm.mlir.constant(0 : index) : i64 - // CHECK: [[GETREF_MEMREF_4:%.+]] = llvm.insertvalue [[CONST_0]], [[GETREF_MEMREF_3]][2] - // CHECK: [[CONST_2:%.+]] = llvm.mlir.constant(2 : index) : i64 - // CHECK: [[CONST_1:%.+]] = llvm.mlir.constant(1 : index) : i64 - // CHECK: [[MUL3:%.+]] = llvm.mul [[CONST_1]], %[[C5_INDEX]] : i64 - // CHECK: [[GETREF_MEMREF_5:%.+]] = llvm.insertvalue [[CONST_2]], [[GETREF_MEMREF_4]][3, 0] - // CHECK: [[GETREF_MEMREF_6:%.+]] = llvm.insertvalue [[MUL3]], [[GETREF_MEMREF_5]][4, 0] - // CHECK: [[GETREF_MEMREF_7:%.+]] = llvm.insertvalue %[[C5_INDEX]], [[GETREF_MEMREF_6]][3, 1] - // CHECK: [[GETREF_MEMREF_8:%.+]] = llvm.insertvalue [[CONST_1]], [[GETREF_MEMREF_7]][4, 1] - // CHECK: llvm.return [[GETREF_MEMREF_8]] -} diff --git a/test/mlir/conversion/onnx_to_krnl/ML/onnx_lowering_category_mapper.mlir b/test/mlir/conversion/onnx_to_krnl/ML/onnx_lowering_category_mapper.mlir index bfdca45202..7668130e94 100644 --- a/test/mlir/conversion/onnx_to_krnl/ML/onnx_lowering_category_mapper.mlir +++ b/test/mlir/conversion/onnx_to_krnl/ML/onnx_lowering_category_mapper.mlir @@ -8,7 +8,6 @@ func.func private @test_category_mapper_string_to_int64(%arg0 : tensor<2x2x!onnx "func.return"(%0) : (tensor<2x2xi64>) -> () // CHECK-LABEL: test_category_mapper_string_to_int64 - // CHECK-DAG: [[ZERO_i64:%.+]] = arith.constant 0 : i64 // CHECK-DAG: [[LEN:%.+]] = arith.constant 3 : i32 // CHECK-DAG: [[ALLOCA:%.+]] = memref.alloc() {alignment = 16 : i64} : memref<2x2xi64> // CHECK-DAG: [[G:%.+]] = "krnl.global"() {name = {{.*}}, shape = [3], value = dense<[1, 0, -3]> : tensor<3xi32>} : () -> memref<3xi32> @@ -20,8 +19,7 @@ func.func private @test_category_mapper_string_to_int64(%arg0 : tensor<2x2x!onnx // CHECK-DAG: [[LOOP_0:%.+]]:2 = krnl.define_loops 2 // CHECK: krnl.iterate([[LOOP_0]]#0, [[LOOP_0]]#1) with ([[LOOP_0]]#0 -> [[I_0:%.+]] = 0 to 2, [[LOOP_0]]#1 -> [[I_1:%.+]] = 0 to 2){ // CHECK: [[IVS:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0]]#0, [[LOOP_0]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) - // CHECK: [[REF:%.+]] = "krnl.getref"(%arg0, [[ZERO_i64]]) : (memref<2x2x!krnl.string>, i64) -> memref<2x2x!krnl.string> - // CHECK: [[LOAD1:%.+]] = krnl.load [[REF]]{{.}}[[IVS]]#0, [[IVS]]#1{{.}} : memref<2x2x!krnl.string> + // CHECK: [[LOAD1:%.+]] = krnl.load %arg0{{.}}[[IVS]]#0, [[IVS]]#1{{.}} : memref<2x2x!krnl.string> // CHECK: [[INDEX:%.+]] = "krnl.find_index"([[LOAD1]], [[G]], [[V]], [[LEN]]) : (!krnl.string, memref<3xi32>, memref<3xi32>, i32) -> index // CHECK: [[LOAD2:%.+]] = krnl.load [[CAT_STRINGS]]{{.}}[[INDEX]]{{.}} : memref<3x!krnl.string> // CHECK: [[STRLEN:%.+]] = "krnl.strlen"([[LOAD2]]) : (!krnl.string) -> i64 @@ -76,7 +74,6 @@ func.func private @test_rank3_category_mapper_string_to_int64(%arg0 : tensor<2x2 "func.return"(%0) : (tensor<2x2x2xi64>) -> () // CHECK-LABEL: test_rank3_category_mapper_string_to_int64 - // CHECK-DAG: [[ZERO_i64:%.+]] = arith.constant 0 : i64 // CHECK-DAG: [[LEN:%.+]] = arith.constant 3 : i32 // CHECK-DAG: [[ALLOCA:%.+]] = memref.alloc() {alignment = 16 : i64} : memref<2x2x2xi64> // CHECK-DAG: [[G:%.+]] = "krnl.global"() {name = {{.*}}, shape = [3], value = dense<[1, 0, -3]> : tensor<3xi32>} : () -> memref<3xi32> @@ -88,8 +85,7 @@ func.func private @test_rank3_category_mapper_string_to_int64(%arg0 : tensor<2x2 // CHECK-DAG: [[LOOP_0:%.+]]:3 = krnl.define_loops 3 // CHECK: krnl.iterate([[LOOP_0]]#0, [[LOOP_0]]#1, [[LOOP_0]]#2) with ([[LOOP_0]]#0 -> %arg1 = 0 to 2, [[LOOP_0]]#1 -> %arg2 = 0 to 2, [[LOOP_0]]#2 -> %arg3 = 0 to 2){ // CHECK: [[IVS:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0]]#0, [[LOOP_0]]#1, [[LOOP_0]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) - // CHECK: [[REF:%.+]] = "krnl.getref"(%arg0, [[ZERO_i64]]) : (memref<2x2x2x!krnl.string>, i64) -> memref<2x2x2x!krnl.string> - // CHECK: [[LOAD1:%.+]] = krnl.load [[REF]]{{.}}[[IVS]]#0, [[IVS]]#1, [[IVS]]#2{{.}} : memref<2x2x2x!krnl.string> + // CHECK: [[LOAD1:%.+]] = krnl.load %arg0{{.}}[[IVS]]#0, [[IVS]]#1, [[IVS]]#2{{.}} : memref<2x2x2x!krnl.string> // CHECK: [[INDEX:%.+]] = "krnl.find_index"([[LOAD1]], [[G]], [[V]], [[LEN]]) : (!krnl.string, memref<3xi32>, memref<3xi32>, i32) -> index // CHECK: [[LOAD2:%.+]] = krnl.load [[CAT_STRINGS]]{{.}}[[INDEX]]{{.}} : memref<3x!krnl.string> // CHECK: [[STRLEN:%.+]] = "krnl.strlen"([[LOAD2]]) : (!krnl.string) -> i64 diff --git a/test/mlir/conversion/onnx_to_stablehlo/Math/MatMul.mlir b/test/mlir/conversion/onnx_to_stablehlo/Math/MatMul.mlir index 65aa094b6e..5ffc96c091 100644 --- a/test/mlir/conversion/onnx_to_stablehlo/Math/MatMul.mlir +++ b/test/mlir/conversion/onnx_to_stablehlo/Math/MatMul.mlir @@ -15,6 +15,65 @@ func.func @test_onnx_to_matmul2d(%arg0 : tensor<4x8xf32>, %arg1 : tensor<8x16xf3 // ----- +func.func @test_onnx_to_matmul2d_dynM(%arg0 : tensor, %arg1 : tensor<8x16xf32>) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor, tensor<8x16xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () +} + +// CHECK-LABEL: func.func @test_onnx_to_matmul2d_dynM +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor<8x16xf32>) -> tensor { +// CHECK-DAG: [[SHAPE_A_INDEX_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor -> tensor<2xindex> +// CHECK-DAG: [[SHAPE_A_:%.+]] = arith.index_cast [[SHAPE_A_INDEX_]] : tensor<2xindex> to tensor<2xi64> +// CHECK-DAG: [[SHAPE_A_BCAST_:%.+]] = stablehlo.concatenate [[SHAPE_A_]], dim = 0 : (tensor<2xi64>) -> tensor<2xi64> +// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[SHAPE_A_BCAST_]], dims = [0, 1] : (tensor, tensor<2xi64>) -> tensor +// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.broadcast_in_dim [[PARAM_1_]], dims = [0, 1] : (tensor<8x16xf32>) -> tensor<8x16xf32> +// CHECK: [[VAR_2_:%.+]] = stablehlo.dot [[VAR_0_]], [[VAR_1_]] : (tensor, tensor<8x16xf32>) -> tensor +// CHECK: return [[VAR_2_]] : tensor +// CHECK: } + +// ----- + +func.func @test_onnx_to_matmul2d_dynN(%arg0 : tensor<4x8xf32>, %arg1 : tensor<8x?xf32>) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<4x8xf32>, tensor<8x?xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () +} + + +// CHECK-LABEL: func.func @test_onnx_to_matmul2d_dynN +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<4x8xf32>, [[PARAM_1_:%.+]]: tensor<8x?xf32>) -> tensor<4x?xf32> { +// CHECK-DAG: [[SHAPE_B_INDEX_:%.+]] = shape.shape_of [[PARAM_1_]] : tensor<8x?xf32> -> tensor<2xindex> +// CHECK-DAG: [[SHAPE_B_:%.+]] = arith.index_cast [[SHAPE_B_INDEX_]] : tensor<2xindex> to tensor<2xi64> +// CHECK-DAG: [[SHAPE_B_BCAST_:%.+]] = stablehlo.concatenate [[SHAPE_B_]], dim = 0 : (tensor<2xi64>) -> tensor<2xi64> +// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.broadcast_in_dim [[PARAM_0_]], dims = [0, 1] : (tensor<4x8xf32>) -> tensor<4x8xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_1_]], [[SHAPE_B_BCAST_]], dims = [0, 1] : (tensor<8x?xf32>, tensor<2xi64>) -> tensor<8x?xf32> +// CHECK: [[VAR_2_:%.+]] = stablehlo.dot [[VAR_0_]], [[VAR_1_]] : (tensor<4x8xf32>, tensor<8x?xf32>) -> tensor<4x?xf32> +// CHECK: return [[VAR_2_]] : tensor<4x?xf32> +// CHECK: } + +// ----- + +func.func @test_onnx_to_matmul2d_dynK(%arg0 : tensor<4x?xf32>, %arg1 : tensor) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<4x?xf32>, tensor) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () +} + + +// CHECK-LABEL: func.func @test_onnx_to_matmul2d_dynK +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<4x?xf32>, [[PARAM_1_:%.+]]: tensor) -> tensor<4x16xf32> { +// CHECK-DAG: [[SHAPE_A_INDEX_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<4x?xf32> -> tensor<2xindex> +// CHECK-DAG: [[SHAPE_A_:%.+]] = arith.index_cast [[SHAPE_A_INDEX_]] : tensor<2xindex> to tensor<2xi64> +// CHECK-DAG: [[SHAPE_A_BCAST_:%.+]] = stablehlo.concatenate [[SHAPE_A_]], dim = 0 : (tensor<2xi64>) -> tensor<2xi64> +// CHECK-DAG: [[SHAPE_B_INDEX_:%.+]] = shape.shape_of [[PARAM_1_]] : tensor -> tensor<2xindex> +// CHECK-DAG: [[SHAPE_B_:%.+]] = arith.index_cast [[SHAPE_B_INDEX_]] : tensor<2xindex> to tensor<2xi64> +// CHECK-DAG: [[SHAPE_B_BCAST_:%.+]] = stablehlo.concatenate [[SHAPE_B_]], dim = 0 : (tensor<2xi64>) -> tensor<2xi64> +// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[SHAPE_A_BCAST_]], dims = [0, 1] : (tensor<4x?xf32>, tensor<2xi64>) -> tensor<4x?xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_1_]], [[SHAPE_B_BCAST_]], dims = [0, 1] : (tensor, tensor<2xi64>) -> tensor +// CHECK: [[VAR_2_:%.+]] = stablehlo.dot [[VAR_0_]], [[VAR_1_]] : (tensor<4x?xf32>, tensor) -> tensor<4x16xf32> +// CHECK: return [[VAR_2_]] : tensor<4x16xf32> +// CHECK: } + +// ----- + func.func @test_onnx_to_matmul3d(%arg0 : tensor<100x4x8xf32>, %arg1 : tensor<100x8x16xf32>) -> tensor<*xf32> { %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<100x4x8xf32>, tensor<100x8x16xf32>) -> tensor<*xf32> "func.return"(%0) : (tensor<*xf32>) -> () @@ -45,6 +104,52 @@ func.func @test_onnx_to_matmul3dbcast(%arg0 : tensor<100x4x8xf32>, %arg1 : tenso // ----- +func.func @test_onnx_to_matmul3dbcast_dynMN(%arg0 : tensor<100x?x8xf32>, %arg1 : tensor<8x?xf32>) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<100x?x8xf32>, tensor<8x?xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () +} + +// CHECK-LABEL: func.func @test_onnx_to_matmul3dbcast_dynMN +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<100x?x8xf32>, [[PARAM_1_:%.+]]: tensor<8x?xf32>) -> tensor<100x?x?xf32> { +// CHECK-DAG: [[BDIM:%.+]] = arith.constant dense<100> : tensor<1xi64> +// CHECK-DAG: [[SHAPE_A_INDEX_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<100x?x8xf32> -> tensor<3xindex> +// CHECK-DAG: [[SHAPE_A_:%.+]] = arith.index_cast [[SHAPE_A_INDEX_]] : tensor<3xindex> to tensor<3xi64> +// CHECK-DAG: [[SHAPE_A_BCAST_:%.+]] = stablehlo.concatenate [[SHAPE_A_]], dim = 0 : (tensor<3xi64>) -> tensor<3xi64> +// CHECK-DAG: [[SHAPE_B_INDEX_:%.+]] = shape.shape_of [[PARAM_1_]] : tensor<8x?xf32> -> tensor<2xindex> +// CHECK-DAG: [[SHAPE_B_:%.+]] = arith.index_cast [[SHAPE_B_INDEX_]] : tensor<2xindex> to tensor<2xi64> +// CHECK-DAG: [[SHAPE_B_BCAST_:%.+]] = stablehlo.concatenate [[BDIM]], [[SHAPE_B_]], dim = 0 : (tensor<1xi64>, tensor<2xi64>) -> tensor<3xi64> +// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[SHAPE_A_BCAST_]], dims = [0, 1, 2] : (tensor<100x?x8xf32>, tensor<3xi64>) -> tensor<100x?x8xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_1_]], [[SHAPE_B_BCAST_]], dims = [1, 2] : (tensor<8x?xf32>, tensor<3xi64>) -> tensor<100x8x?xf32> +// CHECK: [[VAR_2_:%.+]] = stablehlo.dot_general [[VAR_0_]], [[VAR_1_]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<100x?x8xf32>, tensor<100x8x?xf32>) -> tensor<100x?x?xf32> +// CHECK: return [[VAR_2_]] : tensor<100x?x?xf32> +// CHECK: } + +// ----- + +func.func @test_onnx_to_matmul3dbcast_dynBatch(%arg0 : tensor<4x8xf32>, %arg1 : tensor) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<4x8xf32>, tensor) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () +} + +// CHECK-LABEL: func.func @test_onnx_to_matmul3dbcast_dynBatch +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<4x8xf32>, [[PARAM_1_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[BATCH_DIM_INDEX_:%.+]] = tensor.dim [[PARAM_1_]], [[C0]] : tensor +// CHECK-DAG: [[BATCH_DIM_:%.+]] = arith.index_cast [[BATCH_DIM_INDEX_]] : index to i64 +// CHECK-DAG: [[BATCH_DIM_TENSOR_:%.+]] = tensor.from_elements [[BATCH_DIM_]] : tensor<1xi64> +// CHECK-DAG: [[SHAPE_A_:%.+]] = arith.constant dense<[4, 8]> : tensor<2xi64> +// CHECK-DAG: [[SHAPE_A_BCAST_:%.+]] = stablehlo.concatenate [[BATCH_DIM_TENSOR_]], [[SHAPE_A_]], dim = 0 : (tensor<1xi64>, tensor<2xi64>) -> tensor<3xi64> +// CHECK-DAG: [[SHAPE_B_INDEX_:%.+]] = shape.shape_of [[PARAM_1_]] : tensor -> tensor<3xindex> +// CHECK-DAG: [[SHAPE_B_:%.+]] = arith.index_cast [[SHAPE_B_INDEX_]] : tensor<3xindex> to tensor<3xi64> +// CHECK-DAG: [[SHAPE_B_BCAST_:%.+]] = stablehlo.concatenate [[SHAPE_B_]], dim = 0 : (tensor<3xi64>) -> tensor<3xi64> +// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[SHAPE_A_BCAST_]], dims = [1, 2] : (tensor<4x8xf32>, tensor<3xi64>) -> tensor +// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_1_]], [[SHAPE_B_BCAST_]], dims = [0, 1, 2] : (tensor, tensor<3xi64>) -> tensor +// CHECK: [[VAR_2_:%.+]] = stablehlo.dot_general [[VAR_0_]], [[VAR_1_]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor, tensor) -> tensor +// CHECK: return [[VAR_2_]] : tensor +// CHECK: } + +// ----- + func.func @test_onnx_1d(%arg0 : tensor<6xf32>, %arg1 : tensor<6xf32>) -> tensor<*xf32> { %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<6xf32>, tensor<6xf32>) -> tensor<*xf32> "func.return"(%0) : (tensor<*xf32>) -> () diff --git a/test/mlir/conversion/onnx_to_stablehlo/Tensor/Dim.mlir b/test/mlir/conversion/onnx_to_stablehlo/Tensor/Dim.mlir new file mode 100644 index 0000000000..e3013a57e5 --- /dev/null +++ b/test/mlir/conversion/onnx_to_stablehlo/Tensor/Dim.mlir @@ -0,0 +1,46 @@ +// RUN: onnx-mlir-opt --convert-onnx-to-stablehlo --canonicalize %s -split-input-file -verify-diagnostics | FileCheck %s + +// ----- + +func.func @test_dim_1(%arg0 : tensor<5x?x1x32xf32>) -> tensor<1xi64> { + %1 = "onnx.Dim"(%arg0) { axis = 1 : si64} : (tensor<5x?x1x32xf32>) -> tensor<1xi64> + return %1 : tensor<1xi64> +} +// CHECK-LABEL: func.func @test_dim_1 +// CHECK-SAME: ([[PARAM:%.+]]: tensor<5x?x1x32xf32>) -> tensor<1xi64> { +// CHECK-NEXT: [[CONST_1:%.+]] = arith.constant 1 : index +// CHECK-NEXT: [[SHAPE:%.+]] = shape.shape_of [[PARAM]] : tensor<5x?x1x32xf32> -> tensor<4xindex> +// CHECK-NEXT: [[DIM:%.+]] = shape.get_extent [[SHAPE]], [[CONST_1]] : tensor<4xindex>, index -> index +// CHECK-NEXT: [[INDEX_CAST:%.+]] = arith.index_cast [[DIM]] : index to i64 +// CHECK-NEXT: [[FROM_ELEMENTS:%.+]] = tensor.from_elements [[INDEX_CAST]] : tensor<1xi64> +// CHECK-NEXT: return [[FROM_ELEMENTS]] : tensor<1xi64> +// CHECK: } + +// ----- + +func.func @test_dim_2(%arg0 : tensor<5x7xf32>) -> tensor<1xi64> { + %1 = "onnx.Dim"(%arg0) { axis = 0 : si64} : (tensor<5x7xf32>) -> tensor<1xi64> + return %1 : tensor<1xi64> +} + +// CHECK-LABEL: func.func @test_dim_2 +// CHECK-SAME: ([[PARAM:%.+]]: tensor<5x7xf32>) -> tensor<1xi64> { +// CHECK-NEXT: [[CONST:%.+]] = arith.constant dense<5> : tensor<1xi64> +// CHECK-NEXT: return [[CONST]] : tensor<1xi64> +// CHECK: } + +// ----- + +func.func @test_dim_invalid_1(%arg0 : tensor<5x7xf32>) -> tensor<1xi64> { + // expected-error @+1 {{attribute "axis" value is 3, accepted range is [0, 1].}} + %1 = "onnx.Dim"(%arg0) { axis = 3 : si64} : (tensor<5x7xf32>) -> tensor<1xi64> + return %1 : tensor<1xi64> +} + +// ----- + +func.func @test_dim_invalid_2(%arg0 : tensor<*xf32>) -> tensor<1xi64> { + // expected-error @+1 {{input must have shape and rank.}} + %1 = "onnx.Dim"(%arg0) { axis = 0 : si64} : (tensor<*xf32>) -> tensor<1xi64> + return %1 : tensor<1xi64> +} diff --git a/test/mlir/conversion/onnx_to_stablehlo/Tensor/Gather.mlir b/test/mlir/conversion/onnx_to_stablehlo/Tensor/Gather.mlir index 7878c6d55f..86dad1d232 100644 --- a/test/mlir/conversion/onnx_to_stablehlo/Tensor/Gather.mlir +++ b/test/mlir/conversion/onnx_to_stablehlo/Tensor/Gather.mlir @@ -21,6 +21,33 @@ func.func @test_gather_axis0(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> { // ----- +func.func @test_gather_dynamic_axis0(%arg0 : tensor) -> tensor<2x2x?xf32> { + %indices = "onnx.Constant"() {value = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64> + %0 = "onnx.Gather"(%arg0, %indices) {axis = 0 : si64} : (tensor, tensor<2x2xi64>) -> tensor<2x2x?xf32> + "func.return"(%0) : (tensor<2x2x?xf32>) -> () +} + +// CHECK-LABEL: func.func @test_gather_dynamic_axis0 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor<2x2x?xf32> { +// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<{{.}}[0, 1], [1, 2]{{.}}> : tensor<2x2xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0> : tensor<2x2xi64> +// CHECK-DAG: [[INDICES_SHAPE_:%.+]] = shape.const_shape [2, 2] : tensor<2xindex> +// CHECK-DAG: [[SHAPE_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor -> tensor<2xindex> +// CHECK-DAG: [[DIM_:%.+]] = shape.get_extent [[SHAPE_]], [[C0]] : tensor<2xindex>, index -> index +// CHECK-DAG: [[DIM_CAST_:%.+]] = arith.index_cast [[DIM_]] : index to i64 +// CHECK-DAG: [[DIM_TENSOR_:%.+]] = tensor.from_elements [[DIM_CAST_]] : tensor +// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[DIM_TENSOR_]], [[INDICES_SHAPE_]], dims = [] : (tensor, tensor<2xindex>) -> tensor<2x2xi64> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.compare LT, [[VAR_0_]], [[VAR_1_]], NOTYPE : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi1> +// CHECK-DAG: [[VAR_4_:%.+]] = stablehlo.add [[VAR_0_]], [[VAR_2_]] : tensor<2x2xi64> +// CHECK: [[VAR_5_:%.+]] = stablehlo.select [[VAR_3_]], [[VAR_4_]], [[VAR_0_]] : tensor<2x2xi1>, tensor<2x2xi64> +// CHECK: [[VAR_6_:%.+]] = "stablehlo.torch_index_select"([[PARAM_0_]], [[VAR_5_]]) {batch_dims = 0 : i64, dim = 0 : i64} : (tensor, tensor<2x2xi64>) -> tensor<2x2x?xf32> +// CHECK: return [[VAR_6_]] : tensor<2x2x?xf32> +// CHECK: } + +// ----- + func.func @test_gather_axis0neg(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> { %indices = "onnx.Constant"() {value = dense<[[0, -1], [1, 2]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64> %0 = "onnx.Gather"(%arg0, %indices) {axis = 0 : si64} : (tensor<3x2xf32>, tensor<2x2xi64>) -> tensor<2x2x2xf32> diff --git a/test/mlir/driver/buffer_loop_hoisting.mlir b/test/mlir/driver/buffer_loop_hoisting.mlir index 6a7fbe7b5e..a785814f1f 100644 --- a/test/mlir/driver/buffer_loop_hoisting.mlir +++ b/test/mlir/driver/buffer_loop_hoisting.mlir @@ -3,6 +3,11 @@ // Test that llvm.alloca is hoisted out of the loop nest. // CHECK-LABEL: test_buffer_loop_hoisting +// CHECK-NOT: llvm.br +// CHECK-NOT: llvm.cond_br +// CHECK: llvm.alloca +// CHECK-NEXT: llvm.br +// CHECK: llvm.cond_br func.func @test_buffer_loop_hoisting() { %c0_i64 = arith.constant 0 : i64 @@ -17,7 +22,8 @@ func.func @test_buffer_loop_hoisting() { %c5 = arith.constant 5 : index scf.for %arg2 = %c0 to %c20 step %c5 { %0 = memref.alloca() : memref<10x10xf32> - %1 = "krnl.getref"(%0, %c0_i64, %c0) : (memref<10x10xf32>, i64, index) -> memref<2x10xf32> + %1 = memref.dim %0, %c0 : memref<10x10xf32> + memref.dealloc %0 : memref<10x10xf32> } } } diff --git a/test/mlir/krnl/krnl_disconnect_dim_from_alloc.mlir b/test/mlir/krnl/krnl_disconnect_dim_from_alloc.mlir deleted file mode 100644 index 81c637c669..0000000000 --- a/test/mlir/krnl/krnl_disconnect_dim_from_alloc.mlir +++ /dev/null @@ -1,71 +0,0 @@ -// RUN: onnx-mlir-opt -O3 --lower-krnl-shape-to-std %s -split-input-file | FileCheck %s - -/// Lower krnl.dim when input MemRef does not have an affine map. -func.func @test_krnl_dim_lowering(%arg0: memref) -> index { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %0 = memref.dim %arg0, %c0 : memref - %1 = memref.alloc(%0) : memref - %d0 = "krnl.dim"(%1, %c0) : (memref, index) -> index - %d1 = "krnl.dim"(%1, %c1) : (memref, index) -> index - %e = arith.addi %d0, %d1 : index - return %e : index - - // CHECK-LABEL: test_krnl_dim_lowering - // CHECK-DAG: [[CONST0:%.+]] = arith.constant 0 : index - // CHECK-DAG: [[CONST10:%.+]] = arith.constant 10 : index - // CHECK: [[DIM:%.+]] = memref.dim %arg0, [[CONST0]] : memref - // CHECK: [[SUM:%.+]] = arith.addi [[DIM]], [[CONST10]] : index - // CHECK: return [[SUM]] : index -} - -// ----- - -/// Lower krnl.dim when input MemRef has an affine map. -#map = affine_map<(d0, d1) -> (d1, d0)> -func.func @test_krnl_dim_lowering_with_map(%arg0: memref) -> index { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %0 = memref.dim %arg0, %c0 : memref - %1 = memref.alloc(%0) : memref - %d0 = "krnl.dim"(%1, %c0) : (memref, index) -> index - %d1 = "krnl.dim"(%1, %c1) : (memref, index) -> index - %e = arith.addi %d0, %d1 : index - return %e : index - - // CHECK-LABEL: test_krnl_dim_lowering_with_map - // CHECK-DAG: [[CONST0:%.+]] = arith.constant 0 : index - // CHECK-DAG: [[CONST10:%.+]] = arith.constant 10 : index - // CHECK: [[DIM:%.+]] = memref.dim %arg0, [[CONST0]] : memref - // CHECK: [[SUM:%.+]] = arith.addi [[DIM]], [[CONST10]] : index - // CHECK: return [[SUM]] : index -} - -// ----- - -/// Lower krnl.dim to arith.constant when first argument of krnl.dim is an input arg -/// and the dimensions is static. -func.func @test_krnl_dim_lowering_with_const_arg(%arg0: memref<10x20xf32>) -> index { - %c0 = arith.constant 0 : index - %0 = "krnl.dim"(%arg0, %c0) : (memref<10x20xf32>, index) -> index - return %0 : index - - // CHECK-LABEL: test_krnl_dim_lowering_with_const_arg - // CHECK: [[CONST10:%.+]] = arith.constant 10 : index - // CHECK: return [[CONST10]] : index -} - -// ----- - -/// Lower krnl.dim to a standard dim operation when first argument of krnl.dim -/// is an input arg and the dimensions is dynamic. -func.func @test_krnl_dim_lowering_with_dynamic_arg(%arg0: memref<10x?xf32>) -> index { - %c0 = arith.constant 1 : index - %0 = "krnl.dim"(%arg0, %c0) : (memref<10x?xf32>, index) -> index - return %0 : index - - // CHECK-LABEL: test_krnl_dim_lowering_with_dynamic_arg - // CHECK: [[CONST1:%.+]] = arith.constant 1 : index - // CHECK: [[DIM:%.+]] = memref.dim %arg0, [[CONST1]] : memref<10x?xf32> - // CHECK: return [[DIM]] : index -} diff --git a/test/mlir/krnl/krnl_global_elision.mlir b/test/mlir/krnl/krnl_global_elision.mlir deleted file mode 100644 index dc13ef5d60..0000000000 --- a/test/mlir/krnl/krnl_global_elision.mlir +++ /dev/null @@ -1,28 +0,0 @@ -// RUN: onnx-mlir-opt --elide-krnl-constants %s -split-input-file | FileCheck %s - -// CHECK-LABEL: func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x70xf32> -func.func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x70xf32> { - %0 = "krnl.global"() {name = "constant_0", shape = [1, 70], value = dense<[[0., 1.0, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]> : tensor<1x70xf32>} : () -> memref<1x70xf32> - return %0 : memref<1x70xf32> - - // CHECK: {{.*}} = "krnl.global"() {name = "constant_00", shape = [1, 70]} : () -> memref<1x70xf32> - // CHECK: return {{.*}} : memref<1x70xf32> -} - -// ----- - -func.func @test_elide_krnl_global_constant() -> memref<1x80xf32> { - %0 = "krnl.global"() {name = "constant_0", shape = [1, 80], value = dense_resource : tensor<1x80xf32>} : () -> memref<1x80xf32> - return %0 : memref<1x80xf32> - -// CHECK: {{.*}} = "krnl.global"() {name = "constant_01", shape = [1, 80]} : () -> memref<1x80xf32> -// CHECK: return {{.*}} : memref<1x80xf32> -} - -{-# - dialect_resources: { - builtin: { - hex_constant: "0x010000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC00000400000003F0000003F8000003FC0000040000000" - } - } -#-} diff --git a/test/mlir/krnl/krnl_normalization.mlir b/test/mlir/krnl/krnl_normalization.mlir index 5e636e77b4..9c7157f782 100644 --- a/test/mlir/krnl/krnl_normalization.mlir +++ b/test/mlir/krnl/krnl_normalization.mlir @@ -15,17 +15,3 @@ func.func @test_krnl_memcpy_norm(%arg0: memref<1x16384xf32>) -> memref<1x16x4x4x return %0 : memref<1x16x4x4xf32, #map_tile> // CHECK: return [[ALLOC]] : memref<1x16x1x1x32x32xf32> } - -// CHECK-LABEL: test_getref_norm -func.func @test_getref_norm() -> () { - %c0_i64 = arith.constant 0 : i64 - %0 = memref.alloc() : memref<1x81920xf32> - %1 = memref.alloc() : memref<1x16x4x4xf32, #map_tile> - // CHECK: [[ALLOC:%.+]] = memref.alloc() : memref<1x16x1x1x32x32xf32> - %2 = "krnl.getref"(%0, %c0_i64) : (memref<1x81920xf32>, i64) -> memref<1x16x4x4xf32> - // Do something using %1 and %2 - memref.dealloc %1: memref<1x16x4x4xf32, #map_tile> - // CHECK: memref.dealloc [[ALLOC:%.+]] : memref<1x16x1x1x32x32xf32> - memref.dealloc %0: memref<1x81920xf32> - return -} diff --git a/test/mlir/krnl/krnl_shape_lowering.mlir b/test/mlir/krnl/krnl_shape_lowering.mlir deleted file mode 100644 index bdee93d0d8..0000000000 --- a/test/mlir/krnl/krnl_shape_lowering.mlir +++ /dev/null @@ -1,54 +0,0 @@ -// RUN: onnx-mlir-opt -O3 --lower-krnl-shape %s -split-input-file | FileCheck %s - -func.func @test_krnl_shape_lowering(%arg0: memref) -> index { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %0 = memref.dim %arg0, %c0 : memref - %1 = memref.alloc(%0) : memref - %shape = "krnl.shape"(%1) : (memref) -> memref<2xindex> - %e = memref.load %shape[%c0] : memref<2xindex> - return %e : index - - // CHECK-LABEL: test_krnl_shape_lowering - // CHECK: %[[CONST0:.+]] = arith.constant 0 : index - // CHECK: %[[CONST1:.+]] = arith.constant 1 : index - // CHECK: [[DIM:%.+]] = memref.dim %arg0, %[[CONST0]] : memref - // CHECK: [[ALLOC:%.+]] = memref.alloc([[DIM]]) : memref - // CHECK: [[SHAPE:%.+]] = memref.alloc() : memref<2xindex> - // CHECK: [[DIM0:%.+]] = "krnl.dim"([[ALLOC]], %[[CONST0]]) : (memref, index) -> index - // CHECK: store [[DIM0]], [[SHAPE]][%[[CONST0]]] : memref<2xindex> - // CHECK: [[DIM1:%.+]] = "krnl.dim"([[ALLOC]], %[[CONST1]]) : (memref, index) -> index - // CHECK: store [[DIM1]], [[SHAPE]][%[[CONST1]]] : memref<2xindex> - // CHECK: [[RES:%.+]] = memref.load [[SHAPE]][%[[CONST0]]] : memref<2xindex> - // CHECK: return [[RES]] : index -} - -// ----- - -// COM: check krnl.shape lowering when its input is a MemRef with affine_map. - -#map0 = affine_map<(d0, d1) -> (d0 floordiv 2, d1 floordiv 4, d0 mod 2, d1 mod 4)> - -func.func @test_krnl_shape_lowering_with_affine_map(%arg0: memref) -> index { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %0 = memref.dim %arg0, %c0 : memref - %1 = memref.alloc(%0) : memref - %shape = "krnl.shape"(%1) : (memref) -> memref<2xindex> - %e = memref.load %shape[%c0] : memref<2xindex> - return %e : index - - // CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 floordiv 2, d1 floordiv 4, d0 mod 2, d1 mod 4)> - // CHECK-LABEL: test_krnl_shape_lowering_with_affine_map - // CHECK: %[[CONST0:.+]] = arith.constant 0 : index - // CHECK: %[[CONST1:.+]] = arith.constant 1 : index - // CHECK: [[DIM:%.+]] = memref.dim %arg0, %[[CONST0]] : memref - // CHECK: [[ALLOC:%.+]] = memref.alloc([[DIM]]) : memref - // CHECK: [[SHAPE:%.+]] = memref.alloc() : memref<2xindex> - // CHECK: [[DIM0:%.+]] = "krnl.dim"([[ALLOC]], %[[CONST0]]) : (memref, index) -> index - // CHECK: store [[DIM0]], [[SHAPE]][%[[CONST0]]] : memref<2xindex> - // CHECK: [[DIM1:%.+]] = "krnl.dim"([[ALLOC]], %[[CONST1]]) : (memref, index) -> index - // CHECK: store [[DIM1]], [[SHAPE]][%[[CONST1]]] : memref<2xindex> - // CHECK: [[RES:%.+]] = memref.load [[SHAPE]][%[[CONST0]]] : memref<2xindex> - // CHECK: return [[RES]] : index -} diff --git a/test/mlir/onnx/invalid.mlir b/test/mlir/onnx/invalid.mlir index edb2eee264..f91d261eaa 100644 --- a/test/mlir/onnx/invalid.mlir +++ b/test/mlir/onnx/invalid.mlir @@ -83,6 +83,22 @@ func.func @test_concat_from_sequence_verifier_2(%arg0 : !onnx.Seq) -> tensor { + // expected-error @+1 {{input must have shape and rank}} + %1 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<*xf32>) -> tensor + "onnx.Return"(%1) : (tensor) -> () +} + +// ----- + +func.func @test_dim_verifier_2(%arg0 : tensor<5x5xf32>) -> tensor { + // expected-error @+1 {{'onnx.Dim' op attribute "axis" value is -1, accepted range is [0, 1].}} + %1 = "onnx.Dim"(%arg0) {axis = -1 : si64} : (tensor<5x5xf32>) -> tensor + "onnx.Return"(%1) : (tensor) -> () +} + +// ----- + func.func @test_dequantize_linear_verifier_1(%arg0 : tensor<5x5x1xi32>, %arg1 : tensor<3xf32>, %arg2 : tensor<3xi32>) -> tensor<*xf32> { // expected-error @+1 {{onnx.DequantizeLinear: 'axis' value is 3, accepted range is [-3, 2]}} %1 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 3 : si64} : (tensor<5x5x1xi32>, tensor<3xf32>, tensor<3xi32>) -> tensor<*xf32> diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir index 0b57530173..894e74bfa3 100644 --- a/test/mlir/onnx/onnx_canonicalization.mlir +++ b/test/mlir/onnx/onnx_canonicalization.mlir @@ -85,7 +85,7 @@ func.func @test_gemm_add_fusion_beta_zero(%arg0: tensor<128x128xf32>, %arg1: ten // ----- -//CHECK-LABEL: @test_gemm_add_fusion_rank3(%{{.*}}: tensor<128x128x256xf32>, %{{.*}}: tensor<128x128x256xf32>, %{{.*}}: tensor<256xf32>) -> tensor<*xf32> { +// CHECK-LABEL: @test_gemm_add_fusion_rank3(%{{.*}}: tensor<128x128x256xf32>, %{{.*}}: tensor<128x128x256xf32>, %{{.*}}: tensor<256xf32>) -> tensor<*xf32> { func.func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<128x128x256xf32>, %arg2: tensor<256xf32>) -> tensor<*xf32> { %cst = "onnx.NoValue"() {value} : () -> none %0 = "onnx.Gemm"(%arg0, %arg1, %cst) : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, none) -> tensor<*xf32> @@ -98,7 +98,7 @@ func.func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: ten // ----- -//CHECK-LABEL: @cast_elimination(%{{.*}}: tensor<2xf32>) -> tensor<2xf32> { +// CHECK-LABEL: @cast_elimination(%{{.*}}: tensor<2xf32>) -> tensor<2xf32> { func.func @cast_elimination(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "onnx.Cast"(%arg0) {to = f32} : (tensor<2xf32>) -> tensor<2xf32> onnx.Return %0 : tensor<2xf32> @@ -108,6 +108,37 @@ func.func @cast_elimination(%arg0: tensor<2xf32>) -> tensor<2xf32> { // ----- +func.func @cast_concat_swap(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<2xi64> { + %0 = "onnx.Concat"(%arg0, %arg1) {axis = 0 : si64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %1 = "onnx.Cast"(%0) {to = i64} : (tensor<2xi32>) -> tensor<2xi64> + onnx.Return %1 : tensor<2xi64> + +// CHECK-LABEL: func.func @cast_concat_swap +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1xi32>, [[PARAM_1_:%.+]]: tensor<1xi32>) -> tensor<2xi64> { +// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 1 : si64, to = i64} : (tensor<1xi32>) -> tensor<1xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Cast"([[PARAM_1_]]) {saturate = 1 : si64, to = i64} : (tensor<1xi32>) -> tensor<1xi64> +// CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64> +// CHECK: onnx.Return [[VAR_2_]] : tensor<2xi64> +// CHECK: } +} + +// ----- + +func.func @cast_slice_swap(%arg0: tensor<3xi32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>, %arg3: tensor<1xi64>, %arg4: tensor<1xi64>) -> tensor<1xi64> { + %0 = "onnx.Slice"(%arg0, %arg1, %arg2, %arg3, %arg4) {axis = 0 : si64} : (tensor<3xi32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi32> + %1 = "onnx.Cast"(%0) {to = i64} : (tensor<1xi32>) -> tensor<1xi64> + onnx.Return %1 : tensor<1xi64> + +// CHECK-LABEL: func.func @cast_slice_swap +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xi32>, [[PARAM_1_:%.+]]: tensor<1xi64>, [[PARAM_2_:%.+]]: tensor<1xi64>, [[PARAM_3_:%.+]]: tensor<1xi64>, [[PARAM_4_:%.+]]: tensor<1xi64>) -> tensor<1xi64> { +// CHECK: [[VAR_0_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 1 : si64, to = i64} : (tensor<3xi32>) -> tensor<*xi64> +// CHECK: [[VAR_1_:%.+]] = "onnx.Slice"([[VAR_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]]) : (tensor<*xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64> +// CHECK: onnx.Return [[VAR_1_]] : tensor<1xi64> +// CHECK: } +} + +// ----- + func.func @test_conv_batchnormtestmode_fusion_nobias(%arg0: tensor<1x3x224x224xf32>, %0: tensor<64x3x7x7xf32>, %2: tensor<64xf32>, %3: tensor<64xf32>, %4: tensor<64xf32>, %5: tensor<64xf32>) -> tensor<1x64x112x112xf32> { %cst = "onnx.NoValue"() {value} : () -> none %1 = "onnx.Conv"(%arg0, %0, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : si64, kernel_shape = [7, 7], pads = [3, 3, 3, 3], strides = [2, 2]} : (tensor<1x3x224x224xf32>, tensor<64x3x7x7xf32>, none) -> tensor<1x64x112x112xf32> diff --git a/test/mlir/onnx/onnx_constprop.mlir b/test/mlir/onnx/onnx_constprop.mlir index 724fe54824..3905b447de 100644 --- a/test/mlir/onnx/onnx_constprop.mlir +++ b/test/mlir/onnx/onnx_constprop.mlir @@ -380,7 +380,7 @@ func.func @test_div_ones(%arg0 : tensor<1x2xui8>) -> tensor<1x2xui8> { } //===----------------------------------------------------------------------===// -/// Equal tests +/// Equal test // ----- @@ -395,7 +395,7 @@ func.func @test_equal() -> tensor<3xi1> { } //===----------------------------------------------------------------------===// -/// Less tests +/// Less test // ----- @@ -410,7 +410,7 @@ func.func @test_less() -> tensor<3xi1> { } //===----------------------------------------------------------------------===// -/// Greater tests +/// Greater test // ----- @@ -425,7 +425,7 @@ func.func @test_greater() -> tensor<3xi1> { } //===----------------------------------------------------------------------===// -/// LessOrEqual tests +/// LessOrEqual test // ----- @@ -440,7 +440,7 @@ func.func @test_lessorequal() -> tensor<3xi1> { } //===----------------------------------------------------------------------===// -/// GreaterOrEqual tests +/// GreaterOrEqual test // ----- @@ -455,7 +455,67 @@ func.func @test_greaterorequal() -> tensor<3xi1> { } //===----------------------------------------------------------------------===// -/// Sqrt tests +/// Modulo tests + +// ----- + +// CHECK-LABEL: @test_modulo_int_both_neg() -> tensor +func.func @test_modulo_int_both_neg() -> tensor { + %0 = onnx.Constant dense<-7> : tensor + %1 = onnx.Constant dense<-5> : tensor + %2 = "onnx.Mod"(%0, %1) : (tensor , tensor) -> tensor + "onnx.Return"(%2) : (tensor) -> () + // CHECK: [[CONST:%.+]] = onnx.Constant dense<-2> : tensor +} + +// ----- + +// CHECK-LABEL: @test_modulo_int_neg() -> tensor +func.func @test_modulo_int_neg() -> tensor { + %0 = onnx.Constant dense<-4> : tensor + %1 = onnx.Constant dense<2> : tensor + %2 = "onnx.Mod"(%0, %1) : (tensor , tensor) -> tensor + "onnx.Return"(%2) : (tensor) -> () + // CHECK: [[CONST:%.+]] = onnx.Constant dense<0> : tensor +} + +// ----- + +// CHECK-LABEL: @test_modulo_int_pos() -> tensor +func.func @test_modulo_int_pos() -> tensor { + %0 = onnx.Constant dense<5> : tensor + %1 = onnx.Constant dense<8> : tensor + %2 = "onnx.Mod"(%0, %1) : (tensor , tensor) -> tensor + "onnx.Return"(%2) : (tensor) -> () + // CHECK: [[CONST:%.+]] = onnx.Constant dense<5> : tensor +} + +// ----- + +// CHECK-LABEL: @test_modulo_float() -> tensor<1xf32> +func.func @test_modulo_float() -> tensor<1xf32> { + %0 = onnx.Constant dense<[2.0]> : tensor<1xf32> + %1 = onnx.Constant dense<[7.0]> : tensor<1xf32> + %2 = "onnx.Mod"(%0, %1) {fmod = 1 : si64} : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32> + "onnx.Return"(%2) : (tensor<1xf32>) -> () + // CHECK: [[CONST:%.+]] = onnx.Constant dense<2.000000e+00> : tensor<1xf32> + // CHECK-NOT: {{.*}} = "onnx.Mod"{{.*}} +} + +// ----- + +// CHECK-LABEL: @test_modulo_float_mixed() -> tensor<1xf32> +func.func @test_modulo_float_mixed() -> tensor<1xf32> { + %0 = onnx.Constant dense<[-4.3]> : tensor<1xf32> + %1 = onnx.Constant dense<[2.1]> : tensor<1xf32> + %2 = "onnx.Mod"(%0, %1) {fmod = 1 : si64} : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32> + "onnx.Return"(%2) : (tensor<1xf32>) -> () + // CHECK: [[CONST:%.+]] = onnx.Constant dense<-0.100000381> : tensor<1xf32> + // CHECK-NOT: {{.*}} = "onnx.Mod"{{.*}} +} + +//===----------------------------------------------------------------------===// +/// Sqrt test // ----- @@ -468,7 +528,8 @@ func.func @test_sqrt() -> tensor<1x2xf32> { // CHECK-NOT: {{.*}} = "onnx.Sqrt"{{.*}} } -/// Relu tests +//===----------------------------------------------------------------------===// +/// Relu test // ----- diff --git a/third_party/rapidcheck b/third_party/rapidcheck index 1c91f40e64..ff6af6fc68 160000 --- a/third_party/rapidcheck +++ b/third_party/rapidcheck @@ -1 +1 @@ -Subproject commit 1c91f40e64d87869250cfb610376c629307bf77d +Subproject commit ff6af6fc683159deb51c543b065eba14dfcf329b diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index e1343f101f..7bb5c655e7 100755 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -171,8 +171,8 @@ "If": [19], "Imputer": [1], "InstanceNormalization": [6], - "IsInf": [10], - "IsNaN": [13], + "IsInf": [20], + "IsNaN": [20], "LayerNormalization": [17], "LRN": [13], "LSTM": [14],