Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torch.compile and eager benchmarks for softmax #1670

Merged
merged 6 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions python_benchmarks/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@ def pytest_addoption(parser):
default=False,
help="Disable benchmarking.",
)
parser.addoption(
"--benchmark-eager",
action="store_true",
default=False,
Priya2698 marked this conversation as resolved.
Show resolved Hide resolved
help="Benchmarks torch eager mode.",
)

parser.addoption(
"--benchmark-torchcompile",
action="store_true",
default=False,
help="Benchmarks torch.compile mode.",
)


@pytest.fixture
Expand All @@ -33,3 +46,27 @@ def pytest_make_parametrize_id(val):

def pytest_benchmark_update_machine_info(config, machine_info):
machine_info.update(DEVICE_PROPERTIES)


def pytest_collection_modifyitems(session, config, items):
run_eager = config.getoption("--benchmark-eager")
run_torchcompile = config.getoption("--benchmark-torchcompile")

if not run_eager:
skip_eager = pytest.mark.skip(reason="need --benchmark-eager option to run")
for item in items:
# If the benchmark has compile=False parameter (eager mode), skip it.
if (
"compile" in item.callspec.params
and not item.callspec.params["compile"]
):
item.add_marker(skip_eager)

if not run_torchcompile:
skip_torchcompile = pytest.mark.skip(
reason="need --benchmark-torchcompile option to run"
)
for item in items:
# If the benchmark has compile=True parameter (torch.compile mode), skip it.
if "compile" in item.callspec.params and item.callspec.params["compile"]:
item.add_marker(skip_torchcompile)
Priya2698 marked this conversation as resolved.
Show resolved Hide resolved
29 changes: 28 additions & 1 deletion python_benchmarks/test_softmax_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,15 @@ def softmax_bwd_fusion(
fd.add_output(T19)


def unary_bwd_torch(inputs: list): # [in_tensor, output, grads]
inputs[1].backward(inputs[2], retain_graph=True)
return inputs[0].grad


@pytest.mark.parametrize("size", generate_input_sizes(dims=2))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("reduction_axis", [0, 1])
def test_softmax_bwd_benchmark(
def test_softmax_bwd_nvf_benchmark(
benchmark,
size: tuple,
dtype: torch.dtype,
Expand All @@ -82,3 +87,25 @@ def test_softmax_bwd_benchmark(

if not disable_benchmarking:
run_benchmark(benchmark, fd.execute, inputs)


@pytest.mark.parametrize("compile", [False, True], ids=["eager", "compile"])
@pytest.mark.parametrize("size", generate_input_sizes(dims=2))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("reduction_axis", [0, 1])
def test_softmax_bwd_baseline_benchmark(
benchmark,
size: tuple,
dtype: torch.dtype,
reduction_axis: int,
compile: bool,
):
clear_cuda_cache()
input = torch.randn(*size, device="cuda", dtype=dtype, requires_grad=True)
grads = torch.randn(*size, device="cuda", dtype=dtype)
Priya2698 marked this conversation as resolved.
Show resolved Hide resolved
output = torch.nn.functional.softmax(input, dim=reduction_axis)
run_benchmark(
benchmark,
torch.compile(unary_bwd_torch) if compile else unary_bwd_torch,
[input, output, grads],
)
28 changes: 26 additions & 2 deletions python_benchmarks/test_softmax_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,14 @@ def softmax_fwd_fusion(
fd.add_output(T27)


def softmax_fwd_fn(inputs: list): # [in_tensor, reduction_axis]
return torch.nn.functional.softmax(inputs[0], inputs[1])


@pytest.mark.parametrize("size", generate_input_sizes(dims=2))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("reduction_axis", [0, 1])
def test_softmax_fwd_benchmark(
def test_softmax_fwd_nvf_benchmark(
benchmark,
size: tuple,
dtype: torch.dtype,
Expand All @@ -63,8 +67,28 @@ def test_softmax_fwd_benchmark(
softmax_fwd_fusion(fd, torch_dtype_to_nvfuser_dtype(dtype), reduction_axis)

if not disable_validation:
eager_output = torch.nn.functional.softmax(inputs[0], dim=reduction_axis)
eager_output = softmax_fwd_fn([inputs[0], reduction_axis])
fd.validate(inputs, [eager_output])

if not disable_benchmarking:
run_benchmark(benchmark, fd.execute, inputs)


@pytest.mark.parametrize("compile", [False, True], ids=["eager", "compile"])
@pytest.mark.parametrize("size", generate_input_sizes(dims=2))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("reduction_axis", [0, 1])
def test_softmax_fwd_baseline_benchmark(
benchmark,
size: tuple,
dtype: torch.dtype,
reduction_axis: int,
compile: bool,
):
clear_cuda_cache()
input = torch.randn(*size, device="cuda", dtype=dtype)
Priya2698 marked this conversation as resolved.
Show resolved Hide resolved
run_benchmark(
benchmark,
torch.compile(softmax_fwd_fn) if compile else softmax_fwd_fn,
[input, reduction_axis],
)
Loading