Skip to content

Commit 05bb93e

Browse files
Yosua Michael Maranatha authored and facebook-github-bot committed
[fbsync] add usage logging to prototype dispatchers / kernels (#7012)
Reviewed By: datumbox
Differential Revision: D41836890
fbshipit-source-id: bc3cf1b53fba33beaac82fa5ba5f289b8839e84d
1 parent ce0ba2d commit 05bb93e

File tree

8 files changed

+148
-0
lines changed

8 files changed

+148
-0
lines changed

test/prototype_transforms_kernel_infos.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ def __init__(
5757
# structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input
5858
# dtype.
5959
float32_vs_uint8=False,
60+
# Some kernels don't have dispatchers that would handle logging the usage. Thus, the kernel has to do it
61+
# manually. If set, triggers a test that makes sure this happens.
62+
logs_usage=False,
6063
# See InfoBase
6164
test_marks=None,
6265
# See InfoBase
@@ -71,6 +74,7 @@ def __init__(
7174
if float32_vs_uint8 and not callable(float32_vs_uint8):
7275
float32_vs_uint8 = lambda other_args, kwargs: (other_args, kwargs) # noqa: E731
7376
self.float32_vs_uint8 = float32_vs_uint8
77+
self.logs_usage = logs_usage
7478

7579

7680
def _pixel_difference_closeness_kwargs(uint8_atol, *, dtype=torch.uint8, mae=False):
@@ -675,6 +679,7 @@ def reference_inputs_convert_format_bounding_box():
675679
sample_inputs_fn=sample_inputs_convert_format_bounding_box,
676680
reference_fn=reference_convert_format_bounding_box,
677681
reference_inputs_fn=reference_inputs_convert_format_bounding_box,
682+
logs_usage=True,
678683
),
679684
)
680685

@@ -2100,6 +2105,7 @@ def sample_inputs_clamp_bounding_box():
21002105
KernelInfo(
21012106
F.clamp_bounding_box,
21022107
sample_inputs_fn=sample_inputs_clamp_bounding_box,
2108+
logs_usage=True,
21032109
)
21042110
)
21052111

test/test_prototype_transforms_functional.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,19 @@ class TestKernels:
108108
args_kwargs_fn=lambda info: info.reference_inputs_fn(),
109109
)
110110

111+
@make_info_args_kwargs_parametrization(
112+
[info for info in KERNEL_INFOS if info.logs_usage],
113+
args_kwargs_fn=lambda info: info.sample_inputs_fn(),
114+
)
115+
@pytest.mark.parametrize("device", cpu_and_gpu())
116+
def test_logging(self, spy_on, info, args_kwargs, device):
117+
spy = spy_on(torch._C._log_api_usage_once)
118+
119+
args, kwargs = args_kwargs.load(device)
120+
info.kernel(*args, **kwargs)
121+
122+
spy.assert_any_call(f"{info.kernel.__module__}.{info.id}")
123+
111124
@ignore_jit_warning_no_profile
112125
@sample_inputs
113126
@pytest.mark.parametrize("device", cpu_and_gpu())
@@ -291,6 +304,19 @@ class TestDispatchers:
291304
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
292305
)
293306

307+
@make_info_args_kwargs_parametrization(
308+
DISPATCHER_INFOS,
309+
args_kwargs_fn=lambda info: info.sample_inputs(),
310+
)
311+
@pytest.mark.parametrize("device", cpu_and_gpu())
312+
def test_logging(self, spy_on, info, args_kwargs, device):
313+
spy = spy_on(torch._C._log_api_usage_once)
314+
315+
args, kwargs = args_kwargs.load(device)
316+
info.dispatcher(*args, **kwargs)
317+
318+
spy.assert_any_call(f"{info.dispatcher.__module__}.{info.id}")
319+
294320
@ignore_jit_warning_no_profile
295321
@image_sample_inputs
296322
@pytest.mark.parametrize("device", cpu_and_gpu())

torchvision/prototype/transforms/functional/_augment.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch
66
from torchvision.prototype import datapoints
77
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
8+
from torchvision.utils import _log_api_usage_once
89

910

1011
def erase_image_tensor(
@@ -41,6 +42,9 @@ def erase(
4142
v: torch.Tensor,
4243
inplace: bool = False,
4344
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
45+
if not torch.jit.is_scripting():
46+
_log_api_usage_once(erase)
47+
4448
if isinstance(inpt, torch.Tensor) and (
4549
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
4650
):

torchvision/prototype/transforms/functional/_color.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from torchvision.transforms import functional_pil as _FP
66
from torchvision.transforms.functional_tensor import _max_value
77

8+
from torchvision.utils import _log_api_usage_once
9+
810
from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
911

1012

@@ -38,6 +40,9 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
3840

3941

4042
def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -> datapoints.InputTypeJIT:
43+
if not torch.jit.is_scripting():
44+
_log_api_usage_once(adjust_brightness)
45+
4146
if isinstance(inpt, torch.Tensor) and (
4247
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
4348
):
@@ -79,6 +84,9 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
7984

8085

8186
def adjust_saturation(inpt: datapoints.InputTypeJIT, saturation_factor: float) -> datapoints.InputTypeJIT:
87+
if not torch.jit.is_scripting():
88+
_log_api_usage_once(adjust_saturation)
89+
8290
if isinstance(inpt, torch.Tensor) and (
8391
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
8492
):
@@ -120,6 +128,9 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
120128

121129

122130
def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> datapoints.InputTypeJIT:
131+
if not torch.jit.is_scripting():
132+
_log_api_usage_once(adjust_contrast)
133+
123134
if isinstance(inpt, torch.Tensor) and (
124135
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
125136
):
@@ -195,6 +206,9 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
195206

196207

197208
def adjust_sharpness(inpt: datapoints.InputTypeJIT, sharpness_factor: float) -> datapoints.InputTypeJIT:
209+
if not torch.jit.is_scripting():
210+
_log_api_usage_once(adjust_sharpness)
211+
198212
if isinstance(inpt, torch.Tensor) and (
199213
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
200214
):
@@ -309,6 +323,9 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
309323

310324

311325
def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.InputTypeJIT:
326+
if not torch.jit.is_scripting():
327+
_log_api_usage_once(adjust_hue)
328+
312329
if isinstance(inpt, torch.Tensor) and (
313330
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
314331
):
@@ -351,6 +368,9 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
351368

352369

353370
def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -> datapoints.InputTypeJIT:
371+
if not torch.jit.is_scripting():
372+
_log_api_usage_once(adjust_gamma)
373+
354374
if isinstance(inpt, torch.Tensor) and (
355375
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
356376
):
@@ -387,6 +407,9 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
387407

388408

389409
def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJIT:
410+
if not torch.jit.is_scripting():
411+
_log_api_usage_once(posterize)
412+
390413
if isinstance(inpt, torch.Tensor) and (
391414
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
392415
):
@@ -417,6 +440,9 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
417440

418441

419442
def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.InputTypeJIT:
443+
if not torch.jit.is_scripting():
444+
_log_api_usage_once(solarize)
445+
420446
if isinstance(inpt, torch.Tensor) and (
421447
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
422448
):
@@ -469,6 +495,9 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
469495

470496

471497
def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
498+
if not torch.jit.is_scripting():
499+
_log_api_usage_once(autocontrast)
500+
472501
if isinstance(inpt, torch.Tensor) and (
473502
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
474503
):
@@ -561,6 +590,9 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
561590

562591

563592
def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
593+
if not torch.jit.is_scripting():
594+
_log_api_usage_once(equalize)
595+
564596
if isinstance(inpt, torch.Tensor) and (
565597
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
566598
):
@@ -594,6 +626,9 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
594626

595627

596628
def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
629+
if not torch.jit.is_scripting():
630+
_log_api_usage_once(invert)
631+
597632
if isinstance(inpt, torch.Tensor) and (
598633
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
599634
):

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
)
2020
from torchvision.transforms.functional_tensor import _pad_symmetric
2121

22+
from torchvision.utils import _log_api_usage_once
23+
2224
from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
2325

2426

@@ -55,6 +57,9 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
5557

5658

5759
def horizontal_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
60+
if not torch.jit.is_scripting():
61+
_log_api_usage_once(horizontal_flip)
62+
5863
if isinstance(inpt, torch.Tensor) and (
5964
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
6065
):
@@ -103,6 +108,9 @@ def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
103108

104109

105110
def vertical_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
111+
if not torch.jit.is_scripting():
112+
_log_api_usage_once(vertical_flip)
113+
106114
if isinstance(inpt, torch.Tensor) and (
107115
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
108116
):
@@ -231,6 +239,8 @@ def resize(
231239
max_size: Optional[int] = None,
232240
antialias: Optional[bool] = None,
233241
) -> datapoints.InputTypeJIT:
242+
if not torch.jit.is_scripting():
243+
_log_api_usage_once(resize)
234244
if isinstance(inpt, torch.Tensor) and (
235245
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
236246
):
@@ -730,6 +740,9 @@ def affine(
730740
fill: datapoints.FillTypeJIT = None,
731741
center: Optional[List[float]] = None,
732742
) -> datapoints.InputTypeJIT:
743+
if not torch.jit.is_scripting():
744+
_log_api_usage_once(affine)
745+
733746
# TODO: consider deprecating integers from angle and shear on the future
734747
if isinstance(inpt, torch.Tensor) and (
735748
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
@@ -913,6 +926,9 @@ def rotate(
913926
center: Optional[List[float]] = None,
914927
fill: datapoints.FillTypeJIT = None,
915928
) -> datapoints.InputTypeJIT:
929+
if not torch.jit.is_scripting():
930+
_log_api_usage_once(rotate)
931+
916932
if isinstance(inpt, torch.Tensor) and (
917933
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
918934
):
@@ -1120,6 +1136,9 @@ def pad(
11201136
fill: datapoints.FillTypeJIT = None,
11211137
padding_mode: str = "constant",
11221138
) -> datapoints.InputTypeJIT:
1139+
if not torch.jit.is_scripting():
1140+
_log_api_usage_once(pad)
1141+
11231142
if isinstance(inpt, torch.Tensor) and (
11241143
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
11251144
):
@@ -1197,6 +1216,9 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int
11971216

11981217

11991218
def crop(inpt: datapoints.InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints.InputTypeJIT:
1219+
if not torch.jit.is_scripting():
1220+
_log_api_usage_once(crop)
1221+
12001222
if isinstance(inpt, torch.Tensor) and (
12011223
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
12021224
):
@@ -1452,6 +1474,8 @@ def perspective(
14521474
fill: datapoints.FillTypeJIT = None,
14531475
coefficients: Optional[List[float]] = None,
14541476
) -> datapoints.InputTypeJIT:
1477+
if not torch.jit.is_scripting():
1478+
_log_api_usage_once(perspective)
14551479
if isinstance(inpt, torch.Tensor) and (
14561480
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
14571481
):
@@ -1612,6 +1636,9 @@ def elastic(
16121636
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
16131637
fill: datapoints.FillTypeJIT = None,
16141638
) -> datapoints.InputTypeJIT:
1639+
if not torch.jit.is_scripting():
1640+
_log_api_usage_once(elastic)
1641+
16151642
if isinstance(inpt, torch.Tensor) and (
16161643
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
16171644
):
@@ -1724,6 +1751,9 @@ def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tens
17241751

17251752

17261753
def center_crop(inpt: datapoints.InputTypeJIT, output_size: List[int]) -> datapoints.InputTypeJIT:
1754+
if not torch.jit.is_scripting():
1755+
_log_api_usage_once(center_crop)
1756+
17271757
if isinstance(inpt, torch.Tensor) and (
17281758
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
17291759
):
@@ -1817,6 +1847,9 @@ def resized_crop(
18171847
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
18181848
antialias: Optional[bool] = None,
18191849
) -> datapoints.InputTypeJIT:
1850+
if not torch.jit.is_scripting():
1851+
_log_api_usage_once(resized_crop)
1852+
18201853
if isinstance(inpt, torch.Tensor) and (
18211854
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
18221855
):
@@ -1897,6 +1930,9 @@ def five_crop_video(
18971930
def five_crop(
18981931
inpt: ImageOrVideoTypeJIT, size: List[int]
18991932
) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
1933+
if not torch.jit.is_scripting():
1934+
_log_api_usage_once(five_crop)
1935+
19001936
# TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
19011937
# `ten_crop`
19021938
if isinstance(inpt, torch.Tensor) and (
@@ -1952,6 +1988,9 @@ def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = F
19521988
def ten_crop(
19531989
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], size: List[int], vertical_flip: bool = False
19541990
) -> Union[List[datapoints.ImageTypeJIT], List[datapoints.VideoTypeJIT]]:
1991+
if not torch.jit.is_scripting():
1992+
_log_api_usage_once(ten_crop)
1993+
19551994
if isinstance(inpt, torch.Tensor) and (
19561995
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
19571996
):

0 commit comments

Comments (0)