Automatically collect best_config
Summary:
Previously, users had to supply their own implementation of the `best_config` metric and manually map each benchmark name to its kernel's config. We now monkeypatch `Autotuner.run` to detect whether a benchmark target has run an autotuned Triton kernel. This lets us log `best_config` automatically as a BenchmarkOperator-level generic metric.
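
In miniature, the capture pattern looks like this (a sketch only — the standalone `capture_best_config` helper is illustrative; the real change adds an equivalent method to `BenchmarkOperator`, shown in the `triton_op.py` diff below):

```python
from unittest import mock

from triton.runtime import Autotuner


def capture_best_config(fn):
    """Run fn and return the winning config of any Autotuner it triggers."""
    original_run = Autotuner.run
    autotuner = None

    def run_and_capture(self, *args, **kwargs):
        # Remember which Autotuner instance ran, then defer to the real run().
        nonlocal autotuner
        autotuner = self
        original_run(self, *args, **kwargs)

    # The patch is active only while the benchmark target executes.
    with mock.patch.object(Autotuner, "run", run_and_capture):
        fn()

    # all_kwargs() serializes every config field (block sizes, num_warps,
    # num_stages, ...), so fields added upstream come along for free.
    return autotuner.best_config.all_kwargs() if autotuner is not None else None
```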

I've also removed `dump_autotuner_best_config` in favor of Triton's `Config.all_kwargs()` method. This should ensure that any new config parameters are serialized automatically (assuming upstream keeps `all_kwargs()` exhaustive).

One minor downside is that we now emit a `best_config` column regardless of whether we are benchmarking a Triton kernel. (Previously, some operator implementations specified `skip_baseline=True` to skip this column just for the baseline, but in general we don't know whether the baseline is a Triton kernel.) A possible follow-up would be to omit a column from the CSV when all of its values are `None`; a sketch follows.
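
A minimal sketch of that follow-up, assuming the output rows are materialized as a list of dicts before the CSV is written (that representation is an assumption, not the current code path):

```python
from typing import Any, Dict, List


def drop_all_none_columns(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Keep only columns with at least one non-None value across all rows,
    # e.g. drop best_config when no benchmarked impl was an autotuned kernel.
    keep = {k for row in rows for k, v in row.items() if v is not None}
    return [{k: v for k, v in row.items() if k in keep} for row in rows]
```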

Reviewed By: xuzhao9

Differential Revision: D59176133

fbshipit-source-id: d55be479671424a32879afba9dbf3c2808ec52e6
int3 authored and facebook-github-bot committed Jul 8, 2024
1 parent afdc319 commit 555d05a
Showing 6 changed files with 29 additions and 90 deletions.
12 changes: 1 addition & 11 deletions torchbenchmark/operators/addmm/operator.py
@@ -12,7 +12,6 @@
 from torchbenchmark.util.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,
-    dump_autotuner_best_config,
     register_benchmark,
     register_metric,
     register_x_val,
@@ -68,7 +67,7 @@
 
 
 class Operator(BenchmarkOperator):
-    DEFAULT_METRICS = ["tflops"]
+    DEFAULT_METRICS = ["tflops", "best_config"]
     DEFAULT_PRECISION = "bf16"
 
     def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
@@ -133,15 +132,6 @@ def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
         k, n = mat2.size()
         return (m, n, k)
 
-    @register_metric(skip_baseline=True)
-    def best_config(
-        self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
-    ) -> str:
-        if "triton_addmm" in str(fn_name):
-            return dump_autotuner_best_config(_addmm_fwd)
-        else:
-            return ""
-
     def get_input_iter(self) -> Generator:
         for shape in self.shapes:
             m, k, n = shape
17 changes: 1 addition & 16 deletions torchbenchmark/operators/gemm/operator.py
@@ -14,7 +14,6 @@
     register_benchmark,
     register_metric,
     register_x_val,
-    dump_autotuner_best_config,
 )
 
 from .data_io import parse_args, read_shapes_from_csv
@@ -76,7 +75,7 @@
 ]
 
 class Operator(BenchmarkOperator):
-    DEFAULT_METRICS = ["latency", "speedup", "accuracy", "tflops"]
+    DEFAULT_METRICS = ["latency", "speedup", "accuracy", "tflops", "best_config"]
     DEFAULT_PRECISION = "fp16"
 
     def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
@@ -180,20 +179,6 @@ def gbps(
         numel = numel * a.element_size() / 1e9
         return numel / metrics.latency * 1e3
 
-    @register_metric(skip_baseline=True)
-    def best_config(
-        self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
-    ) -> str:
-        if "triton_tutorial_matmul" in str(fn_name):
-            return dump_autotuner_best_config(triton_tutorial_matmul_kernel)
-        elif "triton_ops_matmul" in str(fn_name):
-            return dump_autotuner_best_config(kernels._kernel)
-        elif "hstu_triton_matmul" in str(fn_name):
-            import hammer
-            return dump_autotuner_best_config(hammer.ops.triton.triton_matmul._epilogue_mm)
-        else:
-            return ""
-
     @register_metric()
     def tflops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
8 changes: 1 addition & 7 deletions torchbenchmark/operators/int4_gemm/int4_gemm.py
@@ -25,7 +25,7 @@
 
 
 class Operator(BenchmarkOperator):
-    DEFAULT_METRICS = ["tflops", "gbps", "latency"]
+    DEFAULT_METRICS = ["tflops", "gbps", "latency", "best_config"]
 
     def __init__(self, mode, device, extra_args):
         super().__init__(mode=mode, device=device, extra_args=extra_args)
@@ -80,12 +80,6 @@ def triton(self, x, w, scales_and_zeros):
         w_int4 = pack_2xint4(w).T.contiguous().T
         return lambda: matmul(x, w_int4)
 
-    @register_metric()
-    def best_config(self, fn, inputs, metrics):
-        if "triton" in str(fn):
-            return str(matmul_kernel.best_config)
-        return ""
-
     @register_metric()
     def gbps(self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> float:
         def nbytes(t):
25 changes: 1 addition & 24 deletions torchbenchmark/operators/jagged_sum/operator.py
@@ -11,7 +11,6 @@
 from torchbenchmark.util.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,
-    dump_autotuner_best_config,
     register_benchmark,
     register_metric,
 )
@@ -121,7 +120,7 @@ def execute_kernel_variable_length_loop(x, sum_then_buffer):
 
 class Operator(BenchmarkOperator):
 
-    DEFAULT_METRICS = ["latency", "accuracy"]
+    DEFAULT_METRICS = ["latency", "accuracy", "best_config"]
     use_cuda_graphs = (
         False  # enables GPU/CPU sync (for methods like NestedTensor unbind)
     )
@@ -310,28 +309,6 @@ def input_shape(
             f"sparsity: {example_inputs[4]}",  # sparsity
         )  # return (B, '*', M, max seqlen, sparsity) for each example input
 
-    @register_metric(skip_baseline=True)
-    def best_config(
-        self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics
-    ) -> str:
-        fn_name_str = str(fn_name).split(".")[1]
-
-        if self.sum_then_buffer:
-            if "simple_fused" in fn_name_str:
-                return dump_autotuner_best_config(
-                    triton_jagged_sum_kernel_simple_fused_sum_then_buffer
-                )
-            return dump_autotuner_best_config(
-                triton_jagged_sum_kernel_variable_length_loop_sum_then_buffer
-            )
-        if "simple_fused" in fn_name_str:
-            return dump_autotuner_best_config(
-                triton_jagged_sum_kernel_simple_fused_buffer_then_sum
-            )
-        return dump_autotuner_best_config(
-            triton_jagged_sum_kernel_variable_length_loop_buffer_then_sum
-        )
-
     def plot(self):
         str_B, str_M, str_seqlen, str_sparsity = f"-B-{self.B}", f"-M-{self.M}", f"-seqlen-{self.seqlen}", f"-sparsity-{self.sparsity}"
         if self.B is None:
22 changes: 1 addition & 21 deletions torchbenchmark/operators/sum/operator.py
@@ -12,7 +12,6 @@
 from torchbenchmark.util.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,
-    dump_autotuner_best_config,
     register_benchmark,
     register_metric,
 )
@@ -150,7 +149,7 @@ def execute_kernel_2D_result(x):
 
 class Operator(BenchmarkOperator):
 
-    DEFAULT_METRICS = ["latency", "accuracy"]
+    DEFAULT_METRICS = ["latency", "accuracy", "best_config"]
 
     def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
         super().__init__(mode=mode, device=device, extra_args=extra_args)
@@ -302,25 +301,6 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
             * GIGABYTES_PER_BYTE
         )
 
-    @register_metric(skip_baseline=True)
-    def best_config(
-        self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics
-    ) -> str:
-        if self.input_dim == 2:
-            if self.sum_then_buffer:
-                return dump_autotuner_best_config(
-                    triton_sum_kernel_1D_result_sum_then_buffer
-                )
-            return dump_autotuner_best_config(
-                triton_sum_kernel_1D_result_buffer_then_sum
-            )
-        elif self.input_dim == 3:
-            return dump_autotuner_best_config(
-                triton_sum_kernel_2D_result_dim_1_sum_then_buffer
-            )
-        else:
-            return ""
-
     @register_metric(x_only=True)
     def input_shape(
         self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics
35 changes: 24 additions & 11 deletions torchbenchmark/util/triton_op.py
@@ -54,6 +54,7 @@
     "cpu_peak_mem",
     "gpu_peak_mem",
     "hw_roofline",
+    "best_config",
 ]
 BASELINE_SKIP_METRICS = set(["speedup", "accuracy"])
 X_ONLY_METRICS = set(["hw_roofline"])
@@ -169,17 +170,6 @@ def _find_op_name_from_module_path(module_path: str) -> str:
         return suffix.split(".")[1]
     return suffix.split(".")[0]
 
-def dump_autotuner_best_config(kernel: triton.runtime.Autotuner) -> str:
-    if not hasattr(kernel, "best_config"):
-        return ""
-    # pyre-ignore: Undefined attribute [16]
-    bconfig = kernel.best_config
-    kwargs = copy.deepcopy(bconfig.kwargs)
-    kwargs["num_stages"] = bconfig.num_stages
-    kwargs["num_warps"] = bconfig.num_warps
-    dumped_str = json.dumps(kwargs)
-    return dumped_str
-
 
 @dataclass
 class BenchmarkOperatorMetrics:
@@ -209,6 +199,8 @@ class BenchmarkOperatorMetrics:
     error_msg: Optional[str] = None
     # hw roofline
    hw_roofline: Optional[float] = None
+    # best config
+    best_config: Optional[Dict[str, Any]] = None
     # extra metrics
     extra_metrics: Optional[Dict[str, float]] = None
 
@@ -673,6 +665,25 @@ def plot(self):
             "Each operator must implement its own plotting logic."
         )
 
+    def best_config(self, fn):
+        from unittest import mock
+
+        from triton.runtime import Autotuner
+
+        original_run = Autotuner.run
+        autotuner = None
+
+        def run_and_capture(self, *args, **kwargs):
+            nonlocal autotuner
+            autotuner = self
+            original_run(self, *args, **kwargs)
+
+        with mock.patch.object(Autotuner, "run", run_and_capture):
+            fn()
+
+        if autotuner is not None:
+            return autotuner.best_config.all_kwargs()
+        return None
+
     def enable_bf16(self):
         tensor_cond = lambda x: x.dtype == torch.float32
         tensor_action = lambda x: x.to(torch.bfloat16)
@@ -859,6 +870,8 @@ def _init_extra_metrics() -> Dict[str, Any]:
             metrics.ncu_rep = self.ncu_trace(input_id, fn_name, replay=True)
         if "kineto_trace" in self.required_metrics:
             metrics.kineto_trace = self.kineto_trace(input_id, fn)
+        if "best_config" in self.required_metrics:
+            metrics.best_config = self.best_config(fn)
         # run the hidden metric "_compile_time_in_task"
         # to get the compile time in parent process
         if "_compile_time_in_task" in self.required_metrics:
