Skip to content

Commit 68c4f56

Browse files
Yosua Michael Maranathadatumbox
authored andcommitted
[fbsync] Add more advanced smoke test for project Nova and validation workflows (#7014)
Summary: * Add more advanced smoke test * add torch import * remove dependency on torch * Add missing vars * More code and ufmt Reviewed By: datumbox Differential Revision: D41836892 fbshipit-source-id: 643a85707be252c4a84c59578b8b1929d415193a Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 35f7afb commit 68c4f56

File tree

1 file changed

+56
-5
lines changed

1 file changed

+56
-5
lines changed

test/smoke_test.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,63 @@
11
"""Run smoke tests"""
22

33
import os
4+
from pathlib import Path
45

6+
import torch
57
import torchvision
68
from torchvision.io import read_image
9+
from torchvision.models import resnet50, ResNet50_Weights
710

8-
image_path = os.path.join(
9-
os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
10-
)
11-
print("torchvision version is ", torchvision.__version__)
12-
img = read_image(image_path)
11+
SCRIPT_DIR = Path(__file__).parent
12+
13+
14+
def smoke_test_torchvision() -> None:
15+
print(
16+
"Is torchvision useable?",
17+
all(x is not None for x in [torch.ops.image.decode_png, torch.ops.torchvision.roi_align]),
18+
)
19+
20+
21+
def smoke_test_torchvision_read_decode() -> None:
22+
img_jpg = read_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
23+
if img_jpg.ndim != 3 or img_jpg.numel() < 100:
24+
raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
25+
img_png = read_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
26+
if img_png.ndim != 3 or img_png.numel() < 100:
27+
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
28+
29+
30+
def smoke_test_torchvision_resnet50_classify() -> None:
31+
img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg"))
32+
33+
# Step 1: Initialize model with the best available weights
34+
weights = ResNet50_Weights.DEFAULT
35+
model = resnet50(weights=weights)
36+
model.eval()
37+
38+
# Step 2: Initialize the inference transforms
39+
preprocess = weights.transforms()
40+
41+
# Step 3: Apply inference preprocessing transforms
42+
batch = preprocess(img).unsqueeze(0)
43+
44+
# Step 4: Use the model and print the predicted category
45+
prediction = model(batch).squeeze(0).softmax(0)
46+
class_id = prediction.argmax().item()
47+
score = prediction[class_id].item()
48+
category_name = weights.meta["categories"][class_id]
49+
expected_category = "German shepherd"
50+
print(f"{category_name}: {100 * score:.1f}%")
51+
if category_name != expected_category:
52+
raise RuntimeError(f"Failed ResNet50 classify {category_name} Expected: {expected_category}")
53+
54+
55+
def main() -> None:
56+
print(f"torchvision: {torchvision.__version__}")
57+
smoke_test_torchvision()
58+
smoke_test_torchvision_read_decode()
59+
smoke_test_torchvision_resnet50_classify()
60+
61+
62+
if __name__ == "__main__":
63+
main()

0 commit comments

Comments
 (0)