Skip to content

Commit

Permalink
Refactor smoke tests to configure module included in the release
Browse files Browse the repository at this point in the history
  • Loading branch information
atalman committed Dec 13, 2022
1 parent 9b31a47 commit 96a2712
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 107 deletions.
1 change: 0 additions & 1 deletion .github/workflows/validate-nightly-binaries.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ on:
- .github/workflows/validate-macos-binaries.yml
- .github/workflows/validate-macos-arm64-binaries.yml
- test/smoke_test/*

jobs:
nightly:
uses: ./.github/workflows/validate-binaries.yml
Expand Down
159 changes: 53 additions & 106 deletions test/smoke_test/smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import argparse
import torch
import platform
import importlib
import subprocess

gpu_arch_ver = os.getenv("GPU_ARCH_VER")
gpu_arch_type = os.getenv("GPU_ARCH_TYPE")
Expand All @@ -14,6 +16,21 @@
SCRIPT_DIR = Path(__file__).parent
NIGHTLY_ALLOWED_DELTA = 3

MODULES = [
{
"name": "torchvision",
"repo": "https://github.com/pytorch/vision.git",
"smoke_test": "python ./vision/test/smoke_test.py",
"extension": "extension",
},
{
"name": "torchaudio",
"repo": "https://github.com/pytorch/audio.git",
"smoke_test": "python ./audio/test/smoke_test/smoke_test.py --no-ffmpeg",
"extension": "_extension",
},
]

def check_nightly_binaries_date(package: str) -> None:
from datetime import datetime, timedelta
format_dt = '%Y%m%d'
Expand All @@ -27,33 +44,16 @@ def check_nightly_binaries_date(package: str) -> None:
)

if(package == "all"):
import torchaudio
import torchvision
ta_str = torchaudio.__version__
tv_str = torchvision.__version__
date_ta_str = re.findall("dev\d+", torchaudio.__version__)
date_tv_str = re.findall("dev\d+", torchvision.__version__)
date_ta_delta = datetime.now() - datetime.strptime(date_ta_str[0][3:], format_dt)
date_tv_delta = datetime.now() - datetime.strptime(date_tv_str[0][3:], format_dt)

# check that the above three lists are equal and none of them is empty
if date_ta_delta.days > NIGHTLY_ALLOWED_DELTA or date_tv_delta.days > NIGHTLY_ALLOWED_DELTA:
raise RuntimeError(
f"Expected torchaudio, torchvision to be less then {NIGHTLY_ALLOWED_DELTA} days. But they are from {date_ta_str}, {date_tv_str} respectively"
)

def check_cuda_version(version: str, dlibary: str):
version = torch.ops.torchaudio.cuda_version()
if version is not None and torch.version.cuda is not None:
version_str = str(version)
ta_version = f"{version_str[:-3]}.{version_str[-2]}"
t_version = torch.version.cuda.split(".")
t_version = f"{t_version[0]}.{t_version[1]}"
if ta_version != t_version:
raise RuntimeError(
"Detected that PyTorch and {dlibary} were compiled with different CUDA versions. "
f"PyTorch has CUDA version {t_version} whereas {dlibary} has CUDA version {ta_version}. "
)
for module in MODULES:
imported_module = importlib.import_module(module["name"])
module_version = imported_module.__version__
date_m_str = re.findall("dev\d+", module_version)
date_m_delta = datetime.now() - datetime.strptime(date_m_str[0][3:], format_dt)
print(f"Nightly date check for {module['name']} version {module_version}")
if date_m_delta.days > NIGHTLY_ALLOWED_DELTA:
raise RuntimeError(
f"Expected {module['name']} to be less then {NIGHTLY_ALLOWED_DELTA} days. But its {date_m_delta}"
)

def smoke_test_cuda(package: str) -> None:
if not torch.cuda.is_available() and is_cuda_system:
Expand All @@ -69,12 +69,15 @@ def smoke_test_cuda(package: str) -> None:
print(f"cuDNN enabled? {torch.backends.cudnn.enabled}")

if(package == 'all' and is_cuda_system):
import torchaudio
import torchvision
print(f"torchvision cuda: {torch.ops.torchvision._cuda_version()}")
print(f"torchaudio cuda: {torch.ops.torchaudio.cuda_version()}")
check_cuda_version(torch.ops.torchvision._cuda_version(), "TorchVision")
check_cuda_version(torch.ops.torchaudio.cuda_version(), "TorchAudio")
for module in MODULES:
imported_module = importlib.import_module(module["name"])
# TBD for vision move extension module to private so it will
# be _extention. For audio add version return from the check
if module["extension"] == "extension":
version = imported_module.extension._check_cuda_version()
print(f"{module['name']} CUDA: {version}")
else:
imported_module._extension._check_cuda_version()


def smoke_test_conv2d() -> None:
Expand All @@ -97,67 +100,20 @@ def smoke_test_conv2d() -> None:
out = conv(x)


def smoke_test_torchvision() -> None:
print(
"Is torchvision useable?",
all(
x is not None
for x in [torch.ops.image.decode_png, torch.ops.torchvision.roi_align]
),
)


def smoke_test_torchvision_read_decode() -> None:
from torchvision.io import read_image

img_jpg = read_image(str(SCRIPT_DIR / "assets" / "rgb_pytorch.jpg"))
if img_jpg.ndim != 3 or img_jpg.numel() < 100:
raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
img_png = read_image(str(SCRIPT_DIR / "assets" / "rgb_pytorch.png"))
if img_png.ndim != 3 or img_png.numel() < 100:
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")


def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights

img = read_image(str(SCRIPT_DIR / "assets" / "dog2.jpg")).to(device)

# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights).to(device)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
expected_category = "German shepherd"
print(f"{category_name}: {100 * score:.1f}%")
if category_name != expected_category:
raise RuntimeError(
f"Failed ResNet50 classify {category_name} Expected: {expected_category}"
)


def smoke_test_torchaudio() -> None:
import torchaudio
import torchaudio.compliance.kaldi # noqa: F401
import torchaudio.datasets # noqa: F401
import torchaudio.functional # noqa: F401
import torchaudio.models # noqa: F401
import torchaudio.pipelines # noqa: F401
import torchaudio.sox_effects # noqa: F401
import torchaudio.transforms # noqa: F401
import torchaudio.utils # noqa: F401
def smoke_test_modules():
for module in MODULES:
if module["repo"]:
subprocess.check_output(f"git clone --depth 1 {module['repo']}", stderr=subprocess.STDOUT, shell=True)
try:
output = subprocess.check_output(
module["smoke_test"], stderr=subprocess.STDOUT, shell=True,
universal_newlines=True)
except subprocess.CalledProcessError as exc:
raise RuntimeError(
f"Module {module['name']} FAIL: {exc.returncode} Output: {exc.output}"
)
else:
print("Output: \n{}\n".format(output))


def main() -> None:
Expand All @@ -171,25 +127,16 @@ def main() -> None:
)
options = parser.parse_args()
print(f"torch: {torch.__version__}")

smoke_test_cuda(options.package)
smoke_test_conv2d()

if options.package == "all":
smoke_test_modules()

# only makes sense to check nightly package where dates are known
if installation_str.find("nightly") != -1:
check_nightly_binaries_date(options.package)

if options.package == "all":
import torchaudio
import torchvision
print(f"torchvision: {torchvision.__version__}")
print(f"torchaudio: {torchaudio.__version__}")
smoke_test_torchaudio()
smoke_test_torchvision()
smoke_test_torchvision_read_decode()
smoke_test_torchvision_resnet50_classify()
if torch.cuda.is_available():
smoke_test_torchvision_resnet50_classify("cuda")

if __name__ == "__main__":
main()

0 comments on commit 96a2712

Please sign in to comment.