Skip to content

Commit

Permalink
Add help for run_benchmark (#2361)
Browse files Browse the repository at this point in the history
Summary:
Show more help messages for Tritonbench

```
$ python run_benchmark.py triton --help
usage: run_benchmark.py [-h] [--op OP] [--mode {fwd,bwd,fwd_bwd}] [--bwd] [--fwd_bwd] [--device DEVICE] [--warmup WARMUP] [--iter ITER] [--csv] [--dump-csv] [--skip-print] [--plot] [--ci] [--metrics METRICS] [--only ONLY] [--baseline BASELINE]
                        [--num-inputs NUM_INPUTS] [--keep-going] [--input-id INPUT_ID] [--test-only] [--dump-ir]

options:
  -h, --help            show this help message and exit
  --op OP               Operator to benchmark.
  --mode {fwd,bwd,fwd_bwd}
                        Test mode (fwd, bwd, or fwd_bwd).
  --bwd                 Run backward pass.
  --fwd_bwd             Run both forward and backward pass.
  --device DEVICE       Device to benchmark.
  --warmup WARMUP       Num of warmup runs for reach benchmark run.
  --iter ITER           Num of reps for each benchmark run.
  --csv                 Print result as csv.
  --dump-csv            Dump result as csv.
  --skip-print          Skip printing result.
  --plot                Plot the result.
  --ci                  Run in the CI mode.
  --metrics METRICS     Metrics to collect, split with comma. E.g., --metrics latency,tflops,speedup.
  --only ONLY           Specify one or multiple operator implementations to run.
  --baseline BASELINE   Override default baseline.
  --num-inputs NUM_INPUTS
                        Number of example inputs.
  --keep-going
  --input-id INPUT_ID   Specify the start input id to run. For example, --input-id 0 runs only the first available input sample.When used together like --input-id <X> --num-inputs <Y>, start from the input id <X> and run <Y> different inputs.
  --test-only           Run this under test mode, potentially skipping expensive steps like autotuning.
  --dump-ir             Dump Triton IR
```

```
$ python run_benchmark.py triton --op gemm --num-inputs 1 --only triton_tutorial_matmul
      (M, N, K)    triton_tutorial_matmul-latency
---------------  --------------------------------
(256, 256, 256)                         0.0033702
```

Pull Request resolved: #2361

Reviewed By: jananisriram

Differential Revision: D59374656

Pulled By: xuzhao9

fbshipit-source-id: 139f865895d7550a3475a1a8b4bed037a9ecc769
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Jul 9, 2024
1 parent 555d05a commit bb52940
Show file tree
Hide file tree
Showing 17 changed files with 153 additions and 130 deletions.
2 changes: 1 addition & 1 deletion run_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def list_benchmarks() -> Dict[str, str]:

def run():
available_benchmarks = list_benchmarks()
parser = argparse.ArgumentParser(description="Run a TorchBench user benchmark")
parser = argparse.ArgumentParser(description="Run a TorchBench user benchmark", add_help=False)
parser.add_argument(
"bm_name",
choices=available_benchmarks.keys(),
Expand Down
7 changes: 3 additions & 4 deletions torchbenchmark/operators/addmm/operator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import csv
import os
import statistics
import argparse
from typing import Any, Callable, Generator, List, Optional, Tuple

import numpy
Expand Down Expand Up @@ -70,8 +69,8 @@ class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["tflops", "best_config"]
DEFAULT_PRECISION = "bf16"

def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
super().__init__(mode=mode, device=device, extra_args=extra_args)
def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
super().__init__(tb_args, extra_args)
addmm_args = parse_args(self.extra_args)
if addmm_args.m and addmm_args.n and addmm_args.k:
self.shapes = [(addmm_args.m, addmm_args.k, addmm_args.n)]
Expand Down
5 changes: 2 additions & 3 deletions torchbenchmark/operators/flash_attention/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,8 @@ def parse_op_args(args: List[str]):
class Operator(BenchmarkOperator):
DEFAULT_PRECISION = "bf16"

def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None):
# pass the framework level args (e.g., device, is_training, dtype) to the parent class
super().__init__(mode=mode, device=device, extra_args=extra_args)
def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
super().__init__(tb_args, extra_args)
args = parse_op_args(self.extra_args)
self.BATCH = args.batch
self.H = args.n_heads
Expand Down
6 changes: 3 additions & 3 deletions torchbenchmark/operators/fp8_gemm/fp8_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from triton.runtime.jit import reinterpret

from typing import Any
from typing import Any, Optional, List

from torchbenchmark.util.triton_op import (
BenchmarkOperator,
Expand All @@ -27,8 +27,8 @@ def parse_args(args):
class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["tflops", "gbps", "latency"]

def __init__(self, mode, device, extra_args):
super().__init__(mode=mode, device=device, extra_args=extra_args)
def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
super().__init__(tb_args, extra_args)
self.extra_args = parse_args(extra_args)

def get_input_iter(self):
Expand Down
4 changes: 2 additions & 2 deletions torchbenchmark/operators/fp8_gemm_blockwise/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["tflops", "speedup", "accuracy"]
DEFAULT_PRECISION = "fp32"

def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
super().__init__(mode=mode, device=device, extra_args=extra_args)
def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
super().__init__(tb_args, extra_args)
addmm_args = parse_args(self.extra_args)
if addmm_args.m and addmm_args.n and addmm_args.k:
self.shapes = [(addmm_args.m, addmm_args.n, addmm_args.k)]
Expand Down
4 changes: 2 additions & 2 deletions torchbenchmark/operators/fp8_gemm_rowwise/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["tflops", "speedup", "accuracy"]
DEFAULT_PRECISION = "fp32"

def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
super().__init__(mode=mode, device=device, extra_args=extra_args)
def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
super().__init__(tb_args, extra_args)
addmm_args = parse_args(self.extra_args)
if addmm_args.m and addmm_args.n and addmm_args.k:
self.shapes = [(addmm_args.m, addmm_args.n, addmm_args.k)]
Expand Down
5 changes: 3 additions & 2 deletions torchbenchmark/operators/gather_gemv/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
gather + gemv is the primary kernel driving mixtral perf.
"""

import argparse
import csv
import os
import statistics
Expand Down Expand Up @@ -38,8 +39,8 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
* 1e-6
)

def __init__(self, mode: str, device: str, extra_args: List[str] = []):
super().__init__(mode=mode, device=device, extra_args=extra_args)
def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
super().__init__(tb_args, extra_args)

@register_benchmark(baseline=True)
def test_0(self, p1, p2, p3) -> Callable:
Expand Down
5 changes: 3 additions & 2 deletions torchbenchmark/operators/gemm/operator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import argparse
import csv
import os
import statistics
Expand Down Expand Up @@ -78,8 +79,8 @@ class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["latency", "speedup", "accuracy", "tflops", "best_config"]
DEFAULT_PRECISION = "fp16"

def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
super().__init__(mode=mode, device=device, extra_args=extra_args)
def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
super().__init__(tb_args, extra_args)
gemm_args = parse_args(self.extra_args)
if gemm_args.input:
self.shapes = read_shapes_from_csv(gemm_args.input)
Expand Down
6 changes: 3 additions & 3 deletions torchbenchmark/operators/int4_gemm/int4_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import triton.ops
import triton.language as tl

from typing import Any
from typing import Any, Optional, List

from torchbenchmark.util.triton_op import (
BenchmarkOperator,
Expand All @@ -27,8 +27,8 @@
class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["tflops", "gbps", "latency", "best_config"]

def __init__(self, mode, device, extra_args):
super().__init__(mode=mode, device=device, extra_args=extra_args)
def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
super().__init__(tb_args, extra_args)
# `Group size` and `inner K tiles` are defaults from gpt-fast.
self.group_size = 32
self.inner_k_tiles = 8
Expand Down
4 changes: 2 additions & 2 deletions torchbenchmark/operators/jagged_mean/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ class Operator(BenchmarkOperator):
False # enables GPU/CPU sync (for methods like NestedTensor unbind)
)

def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
super().__init__(mode=mode, device=device, extra_args=extra_args)
def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
super().__init__(tb_args, extra_args)
self.sizes = list(range(2, 12, 4)) + list(
range(12, 23, 3)
) # bias towards larger sizes, which are more representative of real-world shapes
Expand Down
4 changes: 2 additions & 2 deletions torchbenchmark/operators/jagged_sum/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ class Operator(BenchmarkOperator):
False # enables GPU/CPU sync (for methods like NestedTensor unbind)
)

def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
super().__init__(mode=mode, device=device, extra_args=extra_args)
def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
super().__init__(tb_args, extra_args)
self.sizes = list(range(2, 12, 4)) + list(
range(12, 23, 3)
) # bias towards larger sizes, which are more representative of real-world shapes
Expand Down
4 changes: 2 additions & 2 deletions torchbenchmark/operators/sum/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ class Operator(BenchmarkOperator):

DEFAULT_METRICS = ["latency", "accuracy", "best_config"]

def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
super().__init__(mode=mode, device=device, extra_args=extra_args)
def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
super().__init__(tb_args, extra_args)
args = parse_op_args(self.extra_args)
self.input_dim = args.input_dim
self.reduce_dim = args.reduce_dim
Expand Down
6 changes: 3 additions & 3 deletions torchbenchmark/operators/template_attention/operator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

import argparse
import csv
import os
import statistics
Expand Down Expand Up @@ -29,8 +29,8 @@
class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["latency", "speedup", "accuracy"]

def __init__(self, mode: str, device: str, extra_args: List[str] = []):
super().__init__(mode=mode, device=device, extra_args=extra_args)
def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
super().__init__(tb_args, extra_args)
self.shapes = BUILDIN_SHAPES

@register_benchmark(baseline=True)
Expand Down
5 changes: 3 additions & 2 deletions torchbenchmark/operators/test_op/operator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import argparse
from typing import Generator, List, Optional

import torch
Expand All @@ -14,8 +15,8 @@ class Operator(BenchmarkOperator):

DEFAULT_METRICS = ["test_metric"]

def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
super().__init__(mode=mode, device=device, extra_args=extra_args)
def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
super().__init__(tb_args, extra_args)

@register_benchmark(label="new_op_label")
def test_op(self, x: torch.Tensor):
Expand Down
6 changes: 3 additions & 3 deletions torchbenchmark/operators/welford/operator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

import argparse
import csv
import os
import statistics
Expand Down Expand Up @@ -38,8 +38,8 @@
class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["latency", "speedup", "accuracy"]

def __init__(self, mode: str, device: str, extra_args: List[str] = []):
super().__init__(mode=mode, device=device, extra_args=extra_args)
def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
super().__init__(tb_args, extra_args)
self.shapes = BUILDIN_SHAPES

@register_benchmark()
Expand Down
86 changes: 13 additions & 73 deletions torchbenchmark/util/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
"fp16": torch.float16,
"bf16": torch.bfloat16,
}

_RANGE_NAME = "tritonbench_range"

class Mode(Enum):
FWD = "fwd"
Expand Down Expand Up @@ -380,7 +380,6 @@ def _inner(self, *args, **kwargs):

return decorator


def register_metric(
# Metrics that only apply to non-baseline impls
# E.g., accuracy, speedup
Expand Down Expand Up @@ -408,100 +407,47 @@ def _inner(self, *args, **kwargs):

return decorator


def parse_args(
default_metrics: List[str],
args: List[str],
) -> Tuple[argparse.Namespace, List[str]]:
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument(
"--metrics",
default=",".join(default_metrics),
help="Metrics to collect, split with comma. E.g., --metrics latency,tflops,speedup.",
)
parser.add_argument(
"--only",
default=None,
help="Specify one or multiple operator implementations to run."
)
parser.add_argument(
"--baseline",
type=str,
default=None,
help="Override default baseline."
)
parser.add_argument(
"--num-inputs",
type=int,
help="Number of example inputs.",
)
parser.add_argument(
"--keep-going",
action="store_true",
)
parser.add_argument(
"--input-id",
type=int,
default=0,
help="Specify the start input id to run. " \
"For example, --input-id 0 runs only the first available input sample." \
"When used together like --input-id <X> --num-inputs <Y>, start from the input id <X> " \
"and run <Y> different inputs."
)
parser.add_argument(
"--test-only",
action="store_true",
help="Run this under test mode, potentially skipping expensive steps like autotuning."
)
parser.add_argument(
"--dump-ir",
action="store_true",
help="Dump Triton IR",
)
return parser.parse_known_args(args)

class PostInitProcessor(type):
def __call__(cls, *args, **kwargs):
obj = type.__call__(cls, *args, **kwargs)
obj.__post__init__()
return obj


_RANGE_NAME = "tritonbench_range"


class BenchmarkOperator(metaclass=PostInitProcessor):
mode: Mode = Mode.FWD
test: str = "eval"
device: str = "cuda"
# By default, only collect latency metrics
# Each operator can override to define their own default metrics
DEFAULT_METRICS = ["latency"]
required_metrics: List[str]
_input_iter: Optional[Generator] = None
extra_args: List[str] = []
example_inputs: Any = None
use_cuda_graphs: bool = True

# By default, only collect latency metrics
# Each operator can override to define their own default metrics
DEFAULT_METRICS = ["latency"]

"""
A base class for adding operators to torch benchmark.
"""

def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None):
def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]]=None):
set_random_seed()
self.name = _find_op_name_from_module_path(self.__class__.__module__)
self._raw_extra_args = copy.deepcopy(extra_args)
self.tb_args = tb_args
# we accept both "fwd" and "eval"
if mode == "fwd":
if self.tb_args.mode == "fwd":
self.mode = Mode.FWD
elif mode == "fwd_bwd":
elif self.tb_args.mode == "fwd_bwd":
self.mode = Mode.FWD_BWD
else:
assert (
mode == "bwd"
self.tb_args.mode == "bwd"
), f"We only accept 3 test modes: fwd(eval), fwd_bwd(train), or bwd."
self.mode = Mode.BWD
self.dargs, unprocessed_args = parse_decoration_args(self, extra_args)
self.device = tb_args.device
self.required_metrics = list(set(tb_args.metrics.split(","))) if tb_args.metrics else self.DEFAULT_METRICS
self.dargs, self.extra_args = parse_decoration_args(self, extra_args)
if self.name not in REGISTERED_X_VALS:
REGISTERED_X_VALS[self.name] = "x_val"
# This will be changed by the time we apply the decoration args
Expand All @@ -510,17 +456,11 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None)
[x for x in REGISTERED_METRICS.get(self.name, []) if x not in BUILTIN_METRICS]
)
self.DEFAULT_METRICS = list(set(self.DEFAULT_METRICS))
self.tb_args, self.extra_args = parse_args(
self.DEFAULT_METRICS,
unprocessed_args
)
if self.tb_args.baseline:
BASELINE_BENCHMARKS[self.name] = self.tb_args.baseline
self.required_metrics = list(set(self.tb_args.metrics.split(",")))
self._only = _split_params_by_comma(self.tb_args.only)
self._input_id = self.tb_args.input_id
self._num_inputs = self.tb_args.num_inputs
self.device = device

# Run the post initialization
def __post__init__(self):
Expand Down
Loading

0 comments on commit bb52940

Please sign in to comment.