Skip to content

Commit

Permalink
[MetaSchedule][M4a] User-API: Tune-TE/TIR/Relay (apache#10079)
Browse files Browse the repository at this point in the history
* Add tuning scripts for tir, te & relay.

Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>

Minor fix.

Nits.

Add back tests.

* slightly improve tune.py

Co-authored-by: Junru Shao <[email protected]>
  • Loading branch information
zxybazh and junrushao authored Jan 30, 2022
1 parent 1f9c76b commit 779dc51
Show file tree
Hide file tree
Showing 14 changed files with 1,270 additions and 26 deletions.
13 changes: 11 additions & 2 deletions python/tvm/meta_schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,19 @@
from . import database
from . import builder
from . import runner
from . import mutator
from . import postproc
from . import schedule_rule
from . import space_generator
from . import search_strategy
from . import schedule_rule
from . import integration
from . import feature_extractor
from . import cost_model
from .search_strategy import (
EvolutionarySearchConfig,
MeasureCandidate,
ReplayFuncConfig,
ReplayTraceConfig,
)
from .tune import tune_te, tune_tir, tune_relay
from .tune_context import TuneContext
from .search_strategy import MeasureCandidate
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def __init__(self, database) -> None:
self.__init_handle_by_constructor__(_ffi_api.ApplyHistoryBest, database) # type: ignore # pylint: disable=no-member


def extract_task(
def extract_task_from_relay(
mod: Union[IRModule, RelayFunc],
target: Target,
params: Optional[Dict[str, NDArray]] = None,
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""Testing utilities in meta schedule"""
from .local_rpc import LocalRPC
from .relay_workload import get_network
from .byoc_trt import relay_build_with_tensorrt
from .local_rpc import LocalRPC
from .relay_workload import MODEL_TYPE, MODEL_TYPES, get_network, get_torch_model
80 changes: 80 additions & 0 deletions python/tvm/meta_schedule/testing/relay_workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,93 @@
# specific language governing permissions and limitations
# under the License.
"""Workloads in Relay IR"""
from enum import Enum
from typing import Dict, Tuple

import tvm.relay.testing # pylint: disable=unused-import
from tvm import relay
from tvm.ir import IRModule
from tvm.runtime import NDArray

# Model types supported in Torchvision
class MODEL_TYPE(Enum): # pylint: disable=invalid-name
IMAGE_CLASSIFICATION = (1,)
VIDEO_CLASSIFICATION = (2,)
SEGMENTATION = (3,)
OBJECT_DETECTION = (4,)
TEXT_CLASSIFICATION = (5,)


# Specify the type of each model
MODEL_TYPES = {
"resnet18": MODEL_TYPE.IMAGE_CLASSIFICATION,
"mobilenet_v2": MODEL_TYPE.IMAGE_CLASSIFICATION,
"bert_base": MODEL_TYPE.TEXT_CLASSIFICATION,
}


def get_torch_model(
model_name: str,
input_shape: Tuple[int, ...],
output_shape: Tuple[int, int], # pylint: disable=unused-argument
dtype: str = "float32",
) -> Tuple[IRModule, Dict[str, NDArray]]:
"""Load model from torch model zoo
Parameters
----------
model_name : str
The name of the model to load
input_shape: Tuple[int, ...]
Tuple for input shape
output_shape: Tuple[int, int]
Tuple for output shape
dtype: str
Tensor data type
"""

assert dtype == "float32"

import torch # type: ignore # pylint: disable=import-error,import-outside-toplevel
from torchvision import models # type: ignore # pylint: disable=import-error,import-outside-toplevel
import transformers # type: ignore # pylint: disable=import-error,import-outside-toplevel
import os # type: ignore # pylint: disable=import-error,import-outside-toplevel

def do_trace(model, inp):
model.eval()
model_trace = torch.jit.trace(model, inp)
model_trace.eval()
return model_trace

# Load model from torchvision
if MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
model = transformers.BertModel(
transformers.BertConfig(
num_hidden_layers=12,
hidden_size=768,
intermediate_size=3072,
num_attention_heads=12,
return_dict=False,
)
)
model.eval()
input_data = torch.randint(10000, input_shape)
shape_list = [("input_ids", input_shape)]
scripted_model = torch.jit.trace(model, [input_data], strict=False)
elif MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION:
model = getattr(models, model_name)()
# Setup input
input_data = torch.randn(input_shape).type(torch.float32)
shape_list = [("input0", input_shape)]
# Get trace. Depending on the model type, wrapper may be necessary.
scripted_model = do_trace(model, input_data)
else:
raise ValueError("Unsupported model in Torch model zoo.")

# Convert torch model to relay module
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
return mod, params


def get_network(
name: str,
Expand Down
Loading

0 comments on commit 779dc51

Please sign in to comment.