forked from llvm/torch-mlir
Add a way for backends to control which ops are legal for them.
We were already hitting many cases where backends differed in terms of the legal ops that they wanted, which caused unnecessary coupling between the backends. Examples:
- llvm#1161
- llvm#862

This PR centralizes all compilation to go through `torch_mlir.compile` so that we can keep the logic centralized there. We should move these lists closer to each backend. Especially cases like llvm#862, where blocking a decomposition is necessary to avoid a crash, emphasize that the set of decompositions is tightly coupled to the backend and should be "controlled by the backend", not something arbitrarily tweakable.

Also:
- Fix a small bug in the way we passed through the backendLegalOps option.
- Add better error messages in `torch_mlir.compile` for import errors.
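For illustration, here is a minimal sketch of the centralized flow described above. It assumes the `torch_mlir.compile` entry point and `OutputType` enum of the Python API from this era; the module and tensor shapes are hypothetical. The point is that the caller only selects an output type, and the backend-legal op set is chosen inside `torch_mlir.compile`:

# Minimal sketch (assumed API): compile the same module for two backends and
# let torch_mlir.compile select the backend-legal ops for each output type.
import torch
import torch_mlir

class SquareModule(torch.nn.Module):
    def forward(self, x):
        return torch.square(x)

example_input = torch.ones(3, 4)

# The set of ops kept backend-legal (and therefore not decomposed) is chosen
# inside torch_mlir.compile based on the output type; callers do not pass it.
tosa_module = torch_mlir.compile(SquareModule(), example_input,
                                 output_type=torch_mlir.OutputType.TOSA)
linalg_module = torch_mlir.compile(SquareModule(), example_input,
                                   output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)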
Showing 10 changed files with 98 additions and 77 deletions.
@@ -0,0 +1,22 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from torch_mlir import TensorPlaceholder
from torch_mlir_e2e_test.torchscript.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME


def convert_annotations_to_placeholders(forward_method):
    """Converts the annotations on a forward method into tensor placeholders.
    These placeholders are suitable for being passed to `torch_mlir.compile`.
    """
    annotations = getattr(forward_method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME)
    placeholders = []
    # Skip the "self" annotation.
    for annotation in annotations[1:]:
        if not annotation[2]:
            raise ValueError(
                "Can only compile inputs annotated as having value semantics.")
        placeholders.append(TensorPlaceholder(annotation[0], annotation[1]))
    return placeholders
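A hypothetical usage sketch of this helper, assuming the `export`/`annotate_args` decorators from `torch_mlir_e2e_test.torchscript.annotations` and the `torch_mlir.compile` entry point; the module, shapes, and dtype are illustrative only:

import torch
import torch_mlir
from torch_mlir_e2e_test.torchscript.annotations import export, annotate_args

class ArgmaxModule(torch.nn.Module):
    @export
    @annotate_args([
        None,                             # the "self" entry, skipped by the helper
        ([-1, -1], torch.float32, True),  # dynamic 2-D float input with value semantics
    ])
    def forward(self, x):
        return torch.argmax(x, dim=0, keepdim=True)

module = ArgmaxModule()
# Each placeholder carries the annotated shape/dtype and can be passed
# directly to torch_mlir.compile in place of example tensors.
placeholders = convert_annotations_to_placeholders(module.forward)
compiled = torch_mlir.compile(module, placeholders,
                              output_type=torch_mlir.OutputType.TORCH)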
test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
// RUN: torch-mlir-opt -pass-pipeline='torch-function-to-torch-backend-pipeline{backend-legal-ops=torch.aten.square,torch.aten.argmax}' -split-input-file %s | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.square
func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
  // CHECK: torch.aten.square
  %0 = torch.aten.square %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
  return %0 : !torch.vtensor<[?,?,?],f32>
}

// CHECK-LABEL: func.func @torch.aten.argmax
func.func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> {
  %int0 = torch.constant.int 0
  %true = torch.constant.bool true
  // CHECK: torch.aten.argmax
  %0 = torch.aten.argmax %arg0, %int0, %true : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?],si64>
  return %0 : !torch.vtensor<[1,?],si64>
}
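This test exercises the new `backend-legal-ops` pipeline option: because `torch.aten.square` and `torch.aten.argmax` are declared backend-legal, the decomposition pass leaves both ops in place, which the CHECK lines confirm.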