|
1 | 1 | """Run smoke tests"""
|
2 | 2 |
|
3 | 3 | import os
|
| 4 | +from pathlib import Path |
4 | 5 |
|
| 6 | +import torch |
5 | 7 | import torchvision
|
6 | 8 | from torchvision.io import read_image
|
| 9 | +from torchvision.models import resnet50, ResNet50_Weights |
7 | 10 |
|
8 |
| -image_path = os.path.join( |
9 |
| - os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg" |
10 |
| -) |
11 |
| -print("torchvision version is ", torchvision.__version__) |
12 |
| -img = read_image(image_path) |
| 11 | +SCRIPT_DIR = Path(__file__).parent |
| 12 | + |
| 13 | + |
| 14 | +def smoke_test_torchvision() -> None: |
| 15 | + print( |
| 16 | + "Is torchvision useable?", |
| 17 | + all(x is not None for x in [torch.ops.image.decode_png, torch.ops.torchvision.roi_align]), |
| 18 | + ) |
| 19 | + |
| 20 | + |
| 21 | +def smoke_test_torchvision_read_decode() -> None: |
| 22 | + img_jpg = read_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg")) |
| 23 | + if img_jpg.ndim != 3 or img_jpg.numel() < 100: |
| 24 | + raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}") |
| 25 | + img_png = read_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png")) |
| 26 | + if img_png.ndim != 3 or img_png.numel() < 100: |
| 27 | + raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}") |
| 28 | + |
| 29 | + |
| 30 | +def smoke_test_torchvision_resnet50_classify() -> None: |
| 31 | + img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")) |
| 32 | + |
| 33 | + # Step 1: Initialize model with the best available weights |
| 34 | + weights = ResNet50_Weights.DEFAULT |
| 35 | + model = resnet50(weights=weights) |
| 36 | + model.eval() |
| 37 | + |
| 38 | + # Step 2: Initialize the inference transforms |
| 39 | + preprocess = weights.transforms() |
| 40 | + |
| 41 | + # Step 3: Apply inference preprocessing transforms |
| 42 | + batch = preprocess(img).unsqueeze(0) |
| 43 | + |
| 44 | + # Step 4: Use the model and print the predicted category |
| 45 | + prediction = model(batch).squeeze(0).softmax(0) |
| 46 | + class_id = prediction.argmax().item() |
| 47 | + score = prediction[class_id].item() |
| 48 | + category_name = weights.meta["categories"][class_id] |
| 49 | + expected_category = "German shepherd" |
| 50 | + print(f"{category_name}: {100 * score:.1f}%") |
| 51 | + if category_name != expected_category: |
| 52 | + raise RuntimeError(f"Failed ResNet50 classify {category_name} Expected: {expected_category}") |
| 53 | + |
| 54 | + |
| 55 | +def main() -> None: |
| 56 | + print(f"torchvision: {torchvision.__version__}") |
| 57 | + smoke_test_torchvision() |
| 58 | + smoke_test_torchvision_read_decode() |
| 59 | + smoke_test_torchvision_resnet50_classify() |
| 60 | + |
| 61 | + |
| 62 | +if __name__ == "__main__": |
| 63 | + main() |
0 commit comments