Skip to content

Commit

Permalink
#13373: PyTorch Tracing Sweeps for hardswish
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Oct 3, 2024
1 parent a4dda5e commit 58a26e2
Show file tree
Hide file tree
Showing 8 changed files with 367 additions and 160 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ jobs:
cmd: tests/scripts/single_card/nightly/run_ttnn.sh,
timeout: 70
},
{
name: "WH N300 pgm dispatch nightly",
arch: wormhole_b0,
runs-on: ["cloud-virtual-machine", "N300", "in-service"],
cmd: ./tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/compare_pgm_dispatch_perf_ci.sh,
timeout: 10
},
{
name: "GS-only models",
arch: grayskull,
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/ttnn-run-sweeps.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ on:
- eltwise.unary.lgamma.lgamma
- eltwise.unary.sigmoid.sigmoid
- eltwise.unary.sigmoid_accurate.sigmoid_accurate
- eltwise.unary.hardswish.hardswish_pytorch2
- eltwise.binary.subtract.subtract
- eltwise.binary.multiply.multiply
- eltwise.binary.div.div
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple
from functools import partial

import torch
import random
import ttnn
from tests.sweep_framework.utils import gen_shapes, gen_low_high_scalars
from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt

from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from models.utility_functions import torch_random

# Override the default timeout in seconds for hang detection.
TIMEOUT = 30

random.seed(0)

parameters = {
"hardswish_1": {
"input_shape": [
[1, 1024],
[1, 120, 14, 14],
[1, 1280],
[1, 144, 14, 14],
[1, 16, 112, 112],
[1, 16, 160, 160],
[1, 184, 14, 14],
[1, 184, 20, 20],
[1, 200, 14, 14],
[1, 200, 20, 20],
[1, 240, 14, 14],
[1, 240, 20, 20],
[1, 240, 28, 28],
[1, 240, 40, 40],
[1, 288, 14, 14],
[1, 288, 7, 7],
[1, 480, 10, 10],
[1, 480, 14, 14],
[1, 480, 20, 20],
[1, 576, 7, 7],
[1, 672, 10, 10],
[1, 672, 14, 14],
[1, 672, 20, 20],
[1, 672, 7, 7],
[1, 96, 14, 14],
[1, 96, 28, 28],
[1, 960, 7, 7],
],
"input_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
"input_layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT],
"input_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG],
"output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG],
},
}


def run(
input_shape,
input_dtype,
input_layout,
input_memory_config,
output_memory_config,
*,
device,
) -> list:
data_seed = random.randint(0, 20000000)
torch.manual_seed(data_seed)

torch_input_tensor_a = gen_func_with_cast_tt(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_dtype
)(input_shape)

torch_output_tensor = torch.nn.functional.hardswish(torch_input_tensor_a)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
dtype=input_dtype,
layout=input_layout,
device=device,
memory_config=input_memory_config,
)

start_time = start_measuring_time()
result = ttnn.hardswish(input_tensor_a, memory_config=output_memory_config)
output_tensor = ttnn.to_torch(result)
e2e_perf = stop_measuring_time(start_time)

return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf]
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/bin/bash

LOG_FILE1="$TT_METAL_HOME/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/pgm_dispatch_golden.log"
LOG_FILE2="results.log"

# Run the pgm dispatch sweep with trace mode
cd $TT_METAL_HOME
./tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/sweep_pgm_dispatch.sh --trace | tee log
./tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/filt_pgm_dispatch.pl log > $LOG_FILE2

THRESHOLD=4 # Percentage difference threshold

# Check if log files exist
if [[ ! -f "$LOG_FILE1" || ! -f "$LOG_FILE2" ]]; then
echo "Error: One or both log files do not exist."
exit 1
fi

# Read and compare values from the log files
line_count=0
exit_code=0
while IFS= read -r line1 && IFS= read -r line2 <&3; do
# Convert commas to newlines to handle both formats
values1=($(echo "$line1" | tr ',' '\n'))
values2=($(echo "$line2" | tr ',' '\n'))

# Iterate through values
for i in "${!values1[@]}"; do
value1="${values1[$i]}"
value2="${values2[$i]}"

# Check if both values are numeric
if [[ -z "$value1" || -z "$value2" || ! "$value1" =~ ^[0-9]+(\.[0-9]+)?$ || ! "$value2" =~ ^[0-9]+(\.[0-9]+)?$ ]]; then
echo "Got invalid numeric value in output, check if all pgm dispatch tests ran properly."
cat $LOG_FILE2
exit 1
fi
if (( $(echo "$value2 < $value1" | bc -l) )); then
echo "Line $line_count test $i got $value2 which is lower than expected $value1, consider updating $LOG_FILE1"
fi
# Calculate the percentage difference
if (( $(echo "$value1 != 0" | bc -l) )); then
percentage_diff=$(echo "scale=2; 100 * (($value2 - $value1) / $value1)" | bc)
else
continue
fi

# Check if the percentage difference exceeds the threshold
if (( $(echo "$percentage_diff > $THRESHOLD" | bc -l) )); then
echo "Error: Line $line_count test $i expected value $value1 but got $value2"
exit_code=1
fi
done
line_count=$((line_count+1))
done < "$LOG_FILE1" 3< "$LOG_FILE2"

exit $exit_code
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
my @parts = split(' ', $line);
my $us = $parts[8];
my $index = index($parts[8], ".");
$us = substr($us, 0, $index + 3);
my $digits = index($parts[8], "us") - $index;
$digits = $digits >= 3 ? 3 : $digits;
$us = substr($us, 0, $index + $digits);
$data->[$j][$i] = $us;
$j++;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
3.03, 3.03, 3.26, 3.37, 3.52, 3.95, 3.99, 4.85, 3.87, 6.14, 3.71, 4.66, 6.07, 12.29, 4.42, 4.41, 14.04, 10.73,
3.04, 3.04, 3.27, 3.38, 3.58, 3.99, 4.03, 4.88, 3.87, 6.15, 3.72, 4.69, 6.10, 12.33, 4.42, 4.42, 14.07, 10.94,
3.06, 3.06, 3.34, 3.54, 3.75, 4.10, 4.15, 4.97, 4.04, 6.19, 3.74, 4.75, 5.98, 12.51, 4.45, 4.44, 14.21, 11.37,
3.10, 3.10, 3.55, 3.86, 3.98, 4.16, 4.40, 5.15, 4.26, 6.22, 3.77, 4.86, 7.00, 12.76, 4.51, 4.51, 14.44, 12.17,
3.31, 3.31, 3.99, 4.31, 4.65, 4.80, 5.12, 5.64, 5.01, 6.46, 3.88, 5.24, 7.17, 13.43, 4.71, 4.70, 15.17, 15.28,
6.49,
Loading

0 comments on commit 58a26e2

Please sign in to comment.