Skip to content

Commit 8e7f21f

Browse files
authored
Merge branch 'main' into malfet/be-delete-newlines
2 parents cca7a31 + d2db23a commit 8e7f21f

File tree

1 file changed

+7
-61
lines changed

1 file changed

+7
-61
lines changed

test/test_ops.py

Lines changed: 7 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,67 +1201,13 @@ def test_forward_scriptability(self):
12011201
torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))
12021202

12031203

1204-
@pytest.mark.parametrize("dtype", (torch.float16, torch.float32, torch.float64))
1205-
@pytest.mark.parametrize("device", cpu_and_cuda())
1206-
@pytest.mark.parametrize("requires_grad", (True, False))
1207-
def test_deform_conv2d_opcheck(dtype, device, requires_grad):
1208-
batch_size, channels_in, height, width = 1, 6, 10, 10
1209-
kernel_size = (3, 3)
1210-
stride = (1, 1)
1211-
padding = (1, 1)
1212-
dilation = (1, 1)
1213-
groups = 2
1214-
out_channels = 4
1215-
out_h = (height + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1
1216-
out_w = (width + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1
1217-
x = torch.randn(batch_size, channels_in, height, width, dtype=dtype, device=device, requires_grad=requires_grad)
1218-
offset = torch.randn(
1219-
batch_size,
1220-
2 * kernel_size[0] * kernel_size[1],
1221-
out_h,
1222-
out_w,
1223-
dtype=dtype,
1224-
device=device,
1225-
requires_grad=requires_grad,
1226-
)
1227-
weight = torch.randn(
1228-
out_channels,
1229-
channels_in // groups,
1230-
kernel_size[0],
1231-
kernel_size[1],
1232-
dtype=dtype,
1233-
device=device,
1234-
requires_grad=requires_grad,
1235-
)
1236-
bias = torch.randn(out_channels, dtype=dtype, device=device, requires_grad=requires_grad)
1237-
use_mask = True
1238-
mask = torch.sigmoid(
1239-
torch.randn(
1240-
batch_size,
1241-
kernel_size[0] * kernel_size[1],
1242-
out_h,
1243-
out_w,
1244-
dtype=dtype,
1245-
device=device,
1246-
requires_grad=requires_grad,
1247-
)
1248-
)
1249-
kwargs = {
1250-
"offset": offset,
1251-
"weight": weight,
1252-
"bias": bias,
1253-
"stride_h": stride[0],
1254-
"stride_w": stride[1],
1255-
"pad_h": padding[0],
1256-
"pad_w": padding[1],
1257-
"dilation_h": dilation[0],
1258-
"dilation_w": dilation[1],
1259-
"groups": groups,
1260-
"offset_groups": 1,
1261-
"use_mask": use_mask,
1262-
"mask": mask, # no modulation in this test
1263-
}
1264-
optests.opcheck(torch.ops.torchvision.deform_conv2d, args=(x,), kwargs=kwargs)
1204+
optests.generate_opcheck_tests(
1205+
testcase=TestDeformConv,
1206+
namespaces=["torchvision"],
1207+
failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
1208+
additional_decorators=[],
1209+
test_utils=OPTESTS,
1210+
)
12651211

12661212

12671213
class TestFrozenBNT:

0 commit comments

Comments
 (0)