Fix Torch-TRT backend and userbenchmark (#1957)
Summary:
- Improve error checking in benchmark
- Improve argument handling and processing in backend

Pull Request resolved: #1957

Reviewed By: aaronenyeshi

Differential Revision: D49912033

Pulled By: xuzhao9

fbshipit-source-id: 8709e738344885b7af6619ca13f437b67ef380dd
gs-olive authored and facebook-github-bot committed Oct 4, 2023
1 parent 9b2f1ee commit 3f47fef
Showing 2 changed files with 39 additions and 31 deletions.
6 changes: 5 additions & 1 deletion torchbenchmark/util/backends/trt.py
@@ -29,7 +29,7 @@ def parse_torch_trt_args(backend_args: List[str]):
    arg_parser.add_argument(
        "--ir",
        type=str,
-        help="Which internal representation to use: {'ts', 'dynamo_compile', 'fx_ts_compat', ...}",
+        help="Which internal representation to use: {'torch_compile', 'dynamo', 'ts', ...}",
    )
    args, unknown = arg_parser.parse_known_args(backend_args)
@@ -125,6 +125,10 @@ def _torch_trt():
            enabled_precisions={torch_dtype_precision},
            **torch_trt_kwargs,
        )
+
+        # Trigger compilation
+        trt_module(*example_inputs)
+
        model.set_module(trt_module)

    return _torch_trt, backend_args
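
The warm-up call added above matters because the torch.compile-style paths in Torch-TRT compile lazily: without an eager invocation, the first timed iteration would absorb the TensorRT engine build. A minimal sketch of the pattern, assuming torch_tensorrt is installed (the model and input shapes below are illustrative placeholders, not taken from the benchmark):

# Minimal sketch: compile eagerly, then warm up once so that timed runs
# measure steady-state inference rather than engine construction.
# The model and input shapes are illustrative placeholders.
import torch
import torch_tensorrt

model = torch.nn.Linear(8, 4).cuda().eval()
example_inputs = (torch.randn(2, 8, device="cuda"),)

trt_module = torch_tensorrt.compile(
    model,
    ir="torch_compile",  # the same IR the benchmark defaults to
    inputs=example_inputs,
    enabled_precisions={torch.float32},
)

# First call triggers the actual TensorRT engine build
trt_module(*example_inputs)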
64 changes: 34 additions & 30 deletions userbenchmark/torch_trt/run.py
@@ -1,32 +1,30 @@
import argparse
-import traceback
-import torch
-
-import numpy as np
-
import json
import os
import time
+import traceback
from datetime import datetime
from typing import List, Union

-from torchbenchmark.util.experiment.instantiator import (
-    TorchBenchModelConfig,
-    load_model_isolated,
-    list_models,
-)
+import numpy as np
+import torch
from torchbenchmark import (
+    ModelNotFoundError,
    ModelTask,
    load_canary_model_by_name,
    load_model_by_name,
-    ModelNotFoundError,
)
+from torchbenchmark.util.experiment.instantiator import (
+    TorchBenchModelConfig,
+    list_models,
+    load_model_isolated,
+)
from torchbenchmark.util.model import BenchmarkModel


def cli(args: List[str]):
    """Parse input arguments, extracting model specification and batch size"""
-    arg_parser = argparse.ArgumentParser(args)
+    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--model",
        help="Full or partial name of a model to run. If partial, picks the first match.",
@@ -153,8 +151,8 @@ def run_one_step(
    )

    # Get median times for GPU and CPU Walltime
-    gpu_time = np.median(list(map(lambda x: x[0], result_summary)))
-    cpu_walltime = np.median(list(map(lambda x: x[1], result_summary)))
+    gpu_time = np.median([x[0] for x in result_summary])
+    cpu_walltime = np.median([x[1] for x in result_summary])

    # Differentiate model attribute access based on input type
    if isinstance(model, ModelTask):
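
Beyond being more idiomatic than map plus lambda, the comprehensions make explicit that result_summary is a list of (gpu_time, cpu_walltime) pairs and that the benchmark reports medians, which are robust to warm-up stragglers. A sketch of how such pairs are typically collected with CUDA events (helper and variable names are illustrative, not the benchmark's actual implementation):

# Sketch of collecting per-iteration (gpu_time, cpu_walltime) pairs with
# CUDA events; names are illustrative, not the benchmark's actual helpers.
import time

import numpy as np
import torch

def time_iterations(fn, num_iter: int = 10):
    result_summary = []
    for _ in range(num_iter):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        t0 = time.perf_counter()
        start_event.record()
        fn()
        end_event.record()
        torch.cuda.synchronize()
        t1 = time.perf_counter()
        gpu_ms = start_event.elapsed_time(end_event)  # GPU time in ms
        cpu_ms = (t1 - t0) * 1000.0                   # CPU walltime in ms
        result_summary.append((gpu_ms, cpu_ms))
    # Medians are robust to outliers such as the first post-compile call
    return (
        np.median([x[0] for x in result_summary]),
        np.median([x[1] for x in result_summary]),
    )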
@@ -195,7 +193,7 @@ def run(args: List[str]):
        ir_idx = unknown_args.index("--ir")
        selected_ir = unknown_args[ir_idx + 1]
    except (ValueError, IndexError):
-        # If no IR was specified, default to torch.compile
+        # If no IR was specified, default to torch.compile
        selected_ir = "torch_compile"
        unknown_args.append("--ir")
        unknown_args.append(selected_ir)
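
The except clause covers both failure modes of the lookup: list.index raises ValueError when "--ir" is absent, and the positional read raises IndexError when "--ir" is the final token with no value after it. Either way the code falls back to torch_compile and appends an explicit pair, so downstream consumers always see "--ir <value>". A condensed sketch of the pattern (the helper name is illustrative):

# Sketch of the scan-and-default pattern: afterwards, unknown_args is
# guaranteed to carry an explicit "--ir <value>" pair.
def ensure_ir(unknown_args):
    try:
        ir_idx = unknown_args.index("--ir")     # ValueError if flag missing
        selected_ir = unknown_args[ir_idx + 1]  # IndexError if no value
    except (ValueError, IndexError):
        selected_ir = "torch_compile"
        unknown_args += ["--ir", selected_ir]
    return selected_ir

assert ensure_ir(["--some_flag"]) == "torch_compile"
assert ensure_ir(["--ir", "dynamo"]) == "dynamo"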
@@ -232,7 +230,8 @@ def run(args: List[str]):
            extra_args=[
                "--backend",
            ]
-            + unknown_args,
+            + unknown_args
+            + ["--truncate_long_and_double"],
        )

all_metrics = run_single_model(
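
The "--truncate_long_and_double" flag is appended here and below because TensorRT has no native int64 or float64 support; the flag asks Torch-TRT to downcast such tensors to int32/float32 at compile time. A sketch of the underlying compile option, assuming the TorchScript (ts) frontend of torch_tensorrt (the embedding model is illustrative):

# Sketch of the compile-time setting behind "--truncate_long_and_double",
# assuming the TorchScript frontend; the model here is illustrative.
import torch
import torch_tensorrt

model = torch.nn.Embedding(10, 4).cuda().eval()  # takes int64 indices
inputs = (torch.randint(0, 10, (2, 3), device="cuda"),)

trt_module = torch_tensorrt.compile(
    model,
    ir="ts",
    inputs=inputs,
    enabled_precisions={torch.float32},
    truncate_long_and_double=True,  # downcast int64 -> int32, float64 -> float32
)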
@@ -247,6 +246,7 @@ def run(args: List[str]):

    # For all models, use ModelTask instances
    for model_name in list_models():
+        # Add optional argument `truncate_long_and_double`, as required for many models
        config = TorchBenchModelConfig(
            name=model_name,
            test="eval",
@@ -255,26 +255,30 @@
            extra_args=[
                "--backend",
            ]
-            + unknown_args,
+            + unknown_args
+            + ["--truncate_long_and_double"],
        )

        try:
            Model = load_model_isolated(config=config)
-        except ValueError as e:
-            print(
-                f"Loading model {model_name} failed with:\n{e}\nSkipping the model."
-            )
-            continue
-
-        metrics = run_single_model(
-            Model,
-            selected_ir,
-            parsed_args["num_warmup"],
-            parsed_args["num_iter"],
-        )
-        all_metrics = {**all_metrics, **metrics}
+            print(f"\nLoading model {model_name} succeeded.\n")
+            metrics = run_single_model(
+                Model,
+                selected_ir,
+                parsed_args["num_warmup"],
+                parsed_args["num_iter"],
+            )
+            all_metrics = {**all_metrics, **metrics}

-        # Delete model instance and clean up workspace
-        del Model
+            # Delete model instance and clean up workspace
+            del Model
+
+        except Exception as e:
+            traceback.print_exc()
+            print(
+                f"\nLoading model {model_name} failed with:\n{e}\nSkipping the model.\n"
+            )
+            continue

    save_metrics(all_metrics)
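
save_metrics' body is outside this diff, but given the json, os, and datetime imports it plausibly serializes the accumulated dictionary to a timestamped JSON file, as TorchBench userbenchmarks conventionally do. A hypothetical sketch (schema and output location are assumptions, not the benchmark's actual implementation):

# Hypothetical sketch of save_metrics, consistent with the json/os/datetime
# imports above; the schema and output path are assumptions.
import json
import os
from datetime import datetime

def save_metrics(all_metrics):
    result = {
        "name": "torch_trt",
        "run_ts": datetime.now().strftime("%Y%m%d%H%M%S"),
        "metrics": all_metrics,
    }
    output_dir = os.path.join(".userbenchmark", "torch_trt")
    os.makedirs(output_dir, exist_ok=True)
    fname = f"metrics-{result['run_ts']}.json"
    with open(os.path.join(output_dir, fname), "w") as f:
        json.dump(result, f, indent=4)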
