Add a way for backends to control which ops are legal for them.
We were already hitting many cases where backends differed in terms of
the ops they wanted to treat as legal. This caused unnecessary coupling
between the backends. Examples:
- llvm#1161
- llvm#862

This PR centralizes all compilation to go through `torch_mlir.compile`
so that the logic can be kept in one place (see the sketch after the
notes below). We should eventually move these lists closer to each
backend. Cases like llvm#862, where blocking a decomposition is
necessary to avoid a crash, emphasize that the set of decompositions is
tightly coupled to the backend and should be "controlled by the
backend" rather than arbitrarily tweakable.

Also:
- Fix a small bug in the way we passed through the backendLegalOps
  option.
- Add better error messages in `torch_mlir.compile` for import errors.
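
As referenced above, here is a minimal sketch (not part of this commit) of the centralized flow; `MyModule` is a hypothetical stand-in for any scriptable `torch.nn.Module`:

import torch
import torch_mlir

class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.square(x)

# The backend-specific legal-op list is looked up inside torch_mlir.compile,
# so callers only choose an output type; internally the compile entry point
# appends the backend-legal-ops option from BACKEND_LEGAL_OPS[OutputType.TOSA]
# to the torchscript-module-to-torch-backend-pipeline invocation.
module = torch_mlir.compile(MyModule(), torch.ones(2, 3), output_type="tosa")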
silvasean committed Aug 19, 2022
1 parent f601435 commit 6f61aa3
Showing 10 changed files with 98 additions and 77 deletions.
3 changes: 2 additions & 1 deletion include/torch-mlir/Dialect/Torch/Transforms/Passes.h
@@ -90,7 +90,8 @@ std::unique_ptr<OperationPass<func::FuncOp>> createMaximizeValueSemanticsPass();
 
 std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
 
-std::unique_ptr<OperationPass<func::FuncOp>> createDecomposeComplexOpsPass();
+std::unique_ptr<OperationPass<func::FuncOp>>
+createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
 
 std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass();

4 changes: 3 additions & 1 deletion include/torch-mlir/Dialect/Torch/Transforms/Passes.td
@@ -217,7 +217,9 @@ def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {
 
 def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> {
   let summary = "Decompose complicated torch operations";
-  let constructor = "mlir::torch::Torch::createDecomposeComplexOpsPass()";
+  let constructor = [{
+    mlir::torch::Torch::createDecomposeComplexOpsPass(/*legalOps=*/{})
+  }];
   let options = [
     ListOption<"legalOps", "legal-ops", "std::string",
                "List of operation names that should be considered legal",
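A hedged sketch (not part of this commit) of exercising the new `legal-ops` list option directly from Python; it assumes `module` already holds torch-dialect IR and reuses the `run_pipeline_with_repro_report` helper that appears elsewhere in this diff, with the same option syntax as the RUN line in the new test at the bottom:

from torch_mlir.compiler_utils import run_pipeline_with_repro_report

# Run the decomposition pass nested on func.func, keeping torch.aten.square
# legal so no decomposition pattern is applied to it.
run_pipeline_with_repro_report(
    module,
    "func.func(torch-decompose-complex-ops{legal-ops=torch.aten.square})",
    "Decomposing complex ops with torch.aten.square kept legal",
)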
6 changes: 4 additions & 2 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2646,7 +2646,9 @@ class DecomposeComplexOpsPass
   }
 };
 } // namespace
 
 std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::torch::Torch::createDecomposeComplexOpsPass() {
-  return std::make_unique<DecomposeComplexOpsPass>();
+mlir::torch::Torch::createDecomposeComplexOpsPass(
+    ArrayRef<std::string> legalOps) {
+  return std::make_unique<DecomposeComplexOpsPass>(legalOps);
 }
3 changes: 2 additions & 1 deletion lib/Dialect/Torch/Transforms/Passes.cpp
@@ -140,7 +140,8 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
   // basic blocks.
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   if (options.decompose) {
-    pm.addNestedPass<func::FuncOp>(Torch::createDecomposeComplexOpsPass());
+    pm.addNestedPass<func::FuncOp>(
+        Torch::createDecomposeComplexOpsPass(options.backendLegalOps));
     pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   }
 }
42 changes: 38 additions & 4 deletions python/torch_mlir/__init__.py
@@ -6,6 +6,9 @@
 from typing import Sequence, Union, List
 from enum import Enum
 
+import sys
+from io import StringIO
+
 import torch
 
 from torch_mlir.passmanager import PassManager
@@ -116,6 +119,19 @@ def like(tensor: torch.Tensor, dynamic_axes: List[int] = None):
         return TensorPlaceholder(shape, tensor.dtype)
 
 
+# The set of ops that are considered legal for each backend.
+# These are currently quite load-bearing, since different backends might be
+# missing patterns for decomposed forms of certain ops.
+# TODO: Tighten up the definition of these "conditionally legal for backends"
+# ops in the backend contract, and move these lists somewhere deeper in the
+# compiler where each backend can "own" its set of legal ops.
+BACKEND_LEGAL_OPS = {
+    OutputType.TOSA: [],
+    OutputType.LINALG_ON_TENSORS: [],
+    OutputType.MHLO: [],
+}
+
+
 _example_arg = Union[TensorPlaceholder, torch.Tensor]


@@ -209,14 +225,32 @@ def compile(model: torch.nn.Module,
     mb = ModuleBuilder()
     import_options = ImportOptions()
     import_options.ignoreExistingTensorShapesAndDtypes = ignore_traced_shapes
-    mb.import_module(scripted._c, class_annotator, import_options)
+    try:
+        original_stderr = sys.stderr
+        sys.stderr = StringIO()
+        # Import the TorchScript module to MLIR
+        mb.import_module(scripted._c, class_annotator, import_options)
+    except Exception as e:
+        raise Exception(f"""
+PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
+### Importer C++ Exception:
+{e}
+### Importer Diagnostics:
+{sys.stderr.getvalue()}
+""") from None
+    finally:
+        sys.stderr = original_stderr
 
     if output_type == OutputType.RAW:
         return mb.module
 
-    run_pipeline_with_repro_report(mb.module,
-                                   "torchscript-module-to-torch-backend-pipeline",
-                                   "Lowering TorchScript IR -> Torch Backend IR")
+    backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
+    option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}"
+    run_pipeline_with_repro_report(
+        mb.module,
+        f"torchscript-module-to-torch-backend-pipeline{option_string}",
+        "Lowering TorchScript IR -> Torch Backend IR",
+    )
 
     if verbose:
         print("\n====================")
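A small illustration (the op names below are hypothetical, not part of this commit) of the option string that the new code above assembles and appends to the pipeline name:

backend_legal_ops = ["torch.aten.square", "torch.aten.argmax"]
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}"
assert option_string == "{backend-legal-ops=torch.aten.square,torch.aten.argmax}"
# The resulting pipeline invocation is then:
#   torchscript-module-to-torch-backend-pipeline{backend-legal-ops=torch.aten.square,torch.aten.argmax}
# mirroring the backend-legal-ops syntax used in the new FileCheck test below.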
20 changes: 5 additions & 15 deletions python/torch_mlir_e2e_test/torchscript/configs/linalg_on_tensors_backend.py
@@ -3,23 +3,18 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.
 
-import sys
 from typing import Any
-from io import StringIO
-import os
-import tempfile
 
-import numpy as np
 import torch
+import torch_mlir
 
 from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend
 from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
-from torch_mlir.compiler_utils import run_pipeline_with_repro_report
+from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders
 
 from .utils import (
     recursively_convert_to_numpy,
     recursively_convert_from_numpy,
-    convert_torchscript_module_to_torch_backend_contract_mlir,
 )


@@ -34,14 +29,9 @@ def __init__(self, backend: LinalgOnTensorsBackend):
         self.backend = backend
 
     def compile(self, program: torch.nn.Module) -> Any:
-
-        module = convert_torchscript_module_to_torch_backend_contract_mlir(
-            program)
-
-        run_pipeline_with_repro_report(
-            module,
-            "torch-backend-to-linalg-on-tensors-backend-pipeline",
-            "Lower Torch Backend IR -> Linalg-on-Tensors Backend IR")
+        example_args = convert_annotations_to_placeholders(program.forward)
+        module = torch_mlir.compile(
+            program, example_args, output_type="linalg-on-tensors")
 
         return self.backend.compile(module)

20 changes: 5 additions & 15 deletions python/torch_mlir_e2e_test/torchscript/configs/tosa_backend.py
@@ -3,22 +3,17 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.
 
-import sys
 from typing import Any
-from io import StringIO
-import os
-import tempfile
 
-import numpy as np
 import torch
+import torch_mlir
 
 from torch_mlir_e2e_test.tosa_backends.abc import TosaBackend
 from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
-from torch_mlir.compiler_utils import run_pipeline_with_repro_report
+from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders
 from .utils import (
     recursively_convert_to_numpy,
     recursively_convert_from_numpy,
-    convert_torchscript_module_to_torch_backend_contract_mlir,
 )


@@ -33,14 +28,9 @@ def __init__(self, backend: TosaBackend):
         self.backend = backend
 
     def compile(self, program: torch.nn.Module) -> Any:
-
-        module = convert_torchscript_module_to_torch_backend_contract_mlir(
-            program)
-
-        run_pipeline_with_repro_report(
-            module,
-            "torch-backend-to-tosa-backend-pipeline",
-            "Lower Torch Backend IR -> TOSA Backend IR")
+        example_args = convert_annotations_to_placeholders(program.forward)
+        module = torch_mlir.compile(
+            program, example_args, output_type="tosa")
 
         return self.backend.compile(module)

38 changes: 0 additions & 38 deletions python/torch_mlir_e2e_test/torchscript/configs/utils.py
@@ -50,41 +50,3 @@ def recursively_convert_from_numpy(o: Any):
     if isinstance(o, int):
         return o
     raise Exception(f"Unexpected Python function output: {o}")
-
-
-def convert_torchscript_module_to_torch_backend_contract_mlir(program: torch.nn.Module):
-    """Perform common lowering from TorchScript to Torch MLIR
-    Returns an MLIR module that satisfies the Torch backend contract.
-    """
-    mb = ModuleBuilder()
-    scripted = torch.jit.script(program)
-    class_annotator = ClassAnnotator()
-
-    extract_annotations(program, scripted, class_annotator)
-
-
-    # TODO: Find a way to make each of these calls own its own
-    # "debuggable error report" situation.
-    try:
-        original_stderr = sys.stderr
-        sys.stderr = StringIO()
-        # Import the TorchScript module to MLIR
-        mb.import_module(scripted._c, class_annotator)
-    except Exception as e:
-        raise Exception(f"""
-PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
-Exception:
-{e}
-Diagnostics:
-{sys.stderr.getvalue()}
-""") from None
-    finally:
-        sys.stderr = original_stderr
-
-    run_pipeline_with_repro_report(
-        mb.module,
-        "torchscript-module-to-torch-backend-pipeline",
-        "Lowering TorchScript Object Graph IR -> Torch Backend IR")
-
-    return mb.module
22 changes: 22 additions & 0 deletions python/torch_mlir_e2e_test/utils.py
@@ -0,0 +1,22 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+from torch_mlir import TensorPlaceholder
+from torch_mlir_e2e_test.torchscript.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
+
+
+def convert_annotations_to_placeholders(forward_method):
+    """Converts the annotations on a forward method into tensor placeholders.
+    These placeholders are suitable for being passed to `torch_mlir.compile`.
+    """
+    annotations = getattr(forward_method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME)
+    placeholders = []
+    # Skip the "self" annotation.
+    for annotation in annotations[1:]:
+        if not annotation[2]:
+            raise ValueError(
+                "Can only compile inputs annotated as having value semantics.")
+        placeholders.append(TensorPlaceholder(annotation[0], annotation[1]))
+    return placeholders
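A hypothetical usage sketch (not part of this commit): with the standard e2e annotations in place, the recovered placeholders can be fed straight to `torch_mlir.compile`, as the updated configs above now do. `SquareModule` is an illustrative stand-in:

import torch
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders

class SquareModule(torch.nn.Module):
    @export
    @annotate_args([
        None,                             # the "self" slot, skipped above
        ([-1, -1], torch.float32, True),  # (shape, dtype, has-value-semantics)
    ])
    def forward(self, x):
        return torch.square(x)

placeholders = convert_annotations_to_placeholders(SquareModule().forward)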
17 changes: 17 additions & 0 deletions test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir
@@ -0,0 +1,17 @@
+// RUN: torch-mlir-opt -pass-pipeline='torch-function-to-torch-backend-pipeline{backend-legal-ops=torch.aten.square,torch.aten.argmax}' -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func.func @torch.aten.square
+func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
+  // CHECK: torch.aten.square
+  %0 = torch.aten.square %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.argmax
+func.func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> {
+  %int0 = torch.constant.int 0
+  %true = torch.constant.bool true
+  // CHECK: torch.aten.argmax
+  %0 = torch.aten.argmax %arg0, %int0, %true : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?],si64>
+  return %0 : !torch.vtensor<[1,?],si64>
+}
