Skip to content

Commit

Permalink
Option to run resnet classifier on specific device
Browse files Browse the repository at this point in the history
  • Loading branch information
malfet committed Oct 13, 2022
1 parent 6bc0bc2 commit 3b93537
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions test/smoke_test/smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,15 @@ def smoke_test_torchvision_read_decode() -> None:
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")


def smoke_test_torchvision_resnet50_classify() -> None:
def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights

img = read_image(str(SCRIPT_DIR / "assets" / "dog2.jpg"))
img = read_image(str(SCRIPT_DIR / "assets" / "dog2.jpg")).to(device)

# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model = resnet50(weights=weights).to(device)
model.eval()

# Step 2: Initialize the inference transforms
Expand Down

0 comments on commit 3b93537

Please sign in to comment.