
Commit ccd146b
Test bert using gluon model.
Change gluon to torch.

Revert evil work around.

Skip test.
zxybazh committed Jan 13, 2022
1 parent c6e5dac commit ccd146b
Showing 4 changed files with 65 additions and 9 deletions.
python/tvm/meta_schedule/testing/relay_workload.py (52 additions, 0 deletions)
@@ -28,6 +28,7 @@ class MODEL_TYPE(Enum):  # pylint: disable=invalid-name
     VIDEO_CLASSIFICATION = (2,)
     SEGMENTATION = (3,)
     OBJECT_DETECTION = (4,)
+    TEXT_CLASSIFICATION = (5,)


 # Specify the type of each model
@@ -95,6 +96,11 @@ class MODEL_TYPE(Enum):  # pylint: disable=invalid-name
     "r3d_18": MODEL_TYPE.VIDEO_CLASSIFICATION,
     "mc3_18": MODEL_TYPE.VIDEO_CLASSIFICATION,
     "r2plus1d_18": MODEL_TYPE.VIDEO_CLASSIFICATION,
+    # Text classification
+    "bert_tiny": MODEL_TYPE.TEXT_CLASSIFICATION,
+    "bert_base": MODEL_TYPE.TEXT_CLASSIFICATION,
+    "bert_medium": MODEL_TYPE.TEXT_CLASSIFICATION,
+    "bert_large": MODEL_TYPE.TEXT_CLASSIFICATION,
 }
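These entries route the BERT variants into the new TEXT_CLASSIFICATION branch of get_torch_model below. As a quick illustration of how the registry is consulted, here is a sketch of a lookup helper; the helper name and its defaults are illustrative only, mirroring the updated test at the end of this commit:

from tvm.meta_schedule.testing.relay_workload import MODEL_TYPE, MODEL_TYPES

def input_shape_for(name, batch_size=1, seq_length=128):
    # Text-classification workloads take a (batch_size, seq_length) tensor
    # of token ids; this helper is hypothetical, not part of the commit.
    if MODEL_TYPES[name] == MODEL_TYPE.TEXT_CLASSIFICATION:
        return (batch_size, seq_length)
    raise ValueError("Only the text-classification case is sketched: " + name)

assert input_shape_for("bert_large") == (1, 128)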


@@ -121,6 +127,8 @@ def get_torch_model(

     import torch  # type: ignore # pylint: disable=import-error,import-outside-toplevel
     from torchvision import models  # type: ignore # pylint: disable=import-error,import-outside-toplevel
+    import transformers  # type: ignore # pylint: disable=import-error,import-outside-toplevel
+    import os  # type: ignore # pylint: disable=import-error,import-outside-toplevel

     def do_trace(model, inp):
         model_trace = torch.jit.trace(model, inp)
@@ -136,6 +144,50 @@ def do_trace(model, inp):
         model = getattr(models.detection, model_name)()
     elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION:
         model = getattr(models.video, model_name)()
+    elif MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+        config_dict = {
+            "bert_tiny": transformers.BertConfig(
+                num_hidden_layers=6,
+                hidden_size=512,
+                intermediate_size=2048,
+                num_attention_heads=8,
+                return_dict=False,
+            ),
+            "bert_base": transformers.BertConfig(
+                num_hidden_layers=12,
+                hidden_size=768,
+                intermediate_size=3072,
+                num_attention_heads=12,
+                return_dict=False,
+            ),
+            "bert_medium": transformers.BertConfig(
+                num_hidden_layers=12,
+                hidden_size=1024,
+                intermediate_size=4096,
+                num_attention_heads=16,
+                return_dict=False,
+            ),
+            "bert_large": transformers.BertConfig(
+                num_hidden_layers=24,
+                hidden_size=1024,
+                intermediate_size=4096,
+                num_attention_heads=16,
+                return_dict=False,
+            ),
+        }
+        configuration = config_dict[model_name]
+        model = transformers.BertModel(configuration)
+        input_name = "input_ids"
+        A = torch.randint(10000, input_shape)
+
+        model.eval()
+        scripted_model = torch.jit.trace(model, [A], strict=False)
+
+        input_name = "input_ids"
+        shape_list = [(input_name, input_shape)]
+        mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
+        return mod, params
     else:
         raise ValueError("Unsupported model in Torch model zoo.")

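Taken together, the new branch builds an untrained BertModel from a transformers.BertConfig, traces it with strict=False (with return_dict=False the forward pass returns a plain tuple, which torch.jit.trace can consume), and imports the trace into Relay. A condensed standalone sketch of the same flow, assuming torch and transformers are installed; the sizes are copied from the bert_tiny entry above:

import torch
import transformers
from tvm import relay

# Untrained BERT encoder; return_dict=False so the output is a tuple.
config = transformers.BertConfig(
    num_hidden_layers=6,
    hidden_size=512,
    intermediate_size=2048,
    num_attention_heads=8,
    return_dict=False,
)
model = transformers.BertModel(config)
model.eval()

# Random token ids stand in for a tokenized batch: (batch_size, seq_length).
input_shape = (1, 128)
input_ids = torch.randint(10000, input_shape)

# strict=False because the traced output is a tuple rather than a tensor.
scripted_model = torch.jit.trace(model, [input_ids], strict=False)

# Import into Relay; the input name must match the traced graph's input.
shape_list = [("input_ids", input_shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)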
src/meta_schedule/space_generator/post_order_apply.cc (4 additions, 4 deletions)
@@ -95,10 +95,10 @@ class PostOrderApplyNode : public SpaceGeneratorNode {

   Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod_) final {
     using ScheduleAndUnvisitedBlocks = std::pair<tir::Schedule, Array<tir::BlockRV>>;
-    tir::Schedule sch = tir::Schedule::Traced(                          //
-        /*mod=*/mod_,                                                   //
-        /*rand_state=*/ForkSeed(&this->rand_state_),                    //
-        /*debug_mode=*/tir::kVerifySRefTree | tir::kVerifyCachedFlags,  //
+    tir::Schedule sch = tir::Schedule::Traced(        //
+        /*mod=*/mod_,                                 //
+        /*rand_state=*/ForkSeed(&this->rand_state_),  //
+        /*debug_mode=*/tir::kVerifySRefTree | tir::kVerifyCachedFlags,
         /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);

     std::vector<ScheduleAndUnvisitedBlocks> stack;
src/meta_schedule/task_scheduler/task_scheduler.cc (1 addition, 1 deletion)
@@ -124,7 +124,7 @@ void TaskSchedulerNode::Tune() {

   int running_tasks = tasks.size();
   for (int task_id; (task_id = NextTaskId()) != -1;) {
-    LOG(INFO) << "Scheduler picks Task #" << task_id << ": " << tasks[task_id]->task_name;
+    LOG(INFO) << "Scheduler picks Task #" << task_id + 1 << ": " << tasks[task_id]->task_name;
     TuneContext task = tasks[task_id];
     ICHECK(!task->is_stopped);
     ICHECK(!task->runner_futures.defined());
tests/python/unittest/test_meta_schedule_tune_relay.py (8 additions, 4 deletions)
@@ -31,7 +31,7 @@


 @pytest.mark.skip("Integration test")
-@pytest.mark.parametrize("model_name", ["resnet18"])
+@pytest.mark.parametrize("model_name", ["resnet18", "mobilenet_v2", "bert_base"])
 @pytest.mark.parametrize("batch_size", [1])
 @pytest.mark.parametrize("target", ["llvm --num-cores=16", "nvidia/geforce-rtx-3070"])
 def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str):
@@ -47,6 +47,9 @@ def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str):
         input_shape = (1, 3, 300, 300)
     elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION:
         input_shape = (batch_size, 3, 3, 299, 299)
+    elif MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
+        seq_length = 128
+        input_shape = (batch_size, seq_length)
     else:
         raise ValueError("Unsupported model: " + model_name)
     output_shape: Tuple[int, int] = (batch_size, 1000)
@@ -71,7 +74,7 @@ def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str):
         work_dir=work_dir,
     )
     for i, sch in enumerate(schs):
-        print("-" * 10 + f" Part {i}/{len(schs)} " + "-" * 10)
+        print("-" * 10 + f" Part {i+1}/{len(schs)} " + "-" * 10)
         if sch is None:
             print("No valid schedule found!")
         else:
@@ -80,5 +83,6 @@ def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str):


 if __name__ == """__main__""":
-    test_meta_schedule_tune_relay("resnet18", 1, "llvm --num-cores=16")
-    test_meta_schedule_tune_relay("resnet18", 1, "nvidia/geforce-rtx-3070")
+    # test_meta_schedule_tune_relay("resnet18", 1, "llvm --num-cores=16")
+    # test_meta_schedule_tune_relay("mobilenet_v2", 1, "nvidia/geforce-rtx-3070")
+    test_meta_schedule_tune_relay("bert_base", 1, "llvm --num-cores=16")
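Since the test keeps its @pytest.mark.skip("Integration test") marker, pytest collects but skips it; executing the file directly, as the __main__ block does, is how it gets exercised. A lighter-weight check is to load the workload on its own; this sketch assumes get_torch_model accepts the model name and input shape as used earlier in this commit:

from tvm.meta_schedule.testing.relay_workload import get_torch_model

# (1, 128) follows the new TEXT_CLASSIFICATION branch:
# batch_size=1, seq_length=128.
mod, params = get_torch_model(model_name="bert_base", input_shape=(1, 128))
print(mod)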
