@@ -1201,67 +1201,13 @@ def test_forward_scriptability(self):
1201
1201
torch .jit .script (ops .DeformConv2d (in_channels = 8 , out_channels = 8 , kernel_size = 3 ))
1202
1202
1203
1203
1204
- @pytest .mark .parametrize ("dtype" , (torch .float16 , torch .float32 , torch .float64 ))
1205
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
1206
- @pytest .mark .parametrize ("requires_grad" , (True , False ))
1207
- def test_deform_conv2d_opcheck (dtype , device , requires_grad ):
1208
- batch_size , channels_in , height , width = 1 , 6 , 10 , 10
1209
- kernel_size = (3 , 3 )
1210
- stride = (1 , 1 )
1211
- padding = (1 , 1 )
1212
- dilation = (1 , 1 )
1213
- groups = 2
1214
- out_channels = 4
1215
- out_h = (height + 2 * padding [0 ] - dilation [0 ] * (kernel_size [0 ] - 1 ) - 1 ) // stride [0 ] + 1
1216
- out_w = (width + 2 * padding [1 ] - dilation [1 ] * (kernel_size [1 ] - 1 ) - 1 ) // stride [1 ] + 1
1217
- x = torch .randn (batch_size , channels_in , height , width , dtype = dtype , device = device , requires_grad = requires_grad )
1218
- offset = torch .randn (
1219
- batch_size ,
1220
- 2 * kernel_size [0 ] * kernel_size [1 ],
1221
- out_h ,
1222
- out_w ,
1223
- dtype = dtype ,
1224
- device = device ,
1225
- requires_grad = requires_grad ,
1226
- )
1227
- weight = torch .randn (
1228
- out_channels ,
1229
- channels_in // groups ,
1230
- kernel_size [0 ],
1231
- kernel_size [1 ],
1232
- dtype = dtype ,
1233
- device = device ,
1234
- requires_grad = requires_grad ,
1235
- )
1236
- bias = torch .randn (out_channels , dtype = dtype , device = device , requires_grad = requires_grad )
1237
- use_mask = True
1238
- mask = torch .sigmoid (
1239
- torch .randn (
1240
- batch_size ,
1241
- kernel_size [0 ] * kernel_size [1 ],
1242
- out_h ,
1243
- out_w ,
1244
- dtype = dtype ,
1245
- device = device ,
1246
- requires_grad = requires_grad ,
1247
- )
1248
- )
1249
- kwargs = {
1250
- "offset" : offset ,
1251
- "weight" : weight ,
1252
- "bias" : bias ,
1253
- "stride_h" : stride [0 ],
1254
- "stride_w" : stride [1 ],
1255
- "pad_h" : padding [0 ],
1256
- "pad_w" : padding [1 ],
1257
- "dilation_h" : dilation [0 ],
1258
- "dilation_w" : dilation [1 ],
1259
- "groups" : groups ,
1260
- "offset_groups" : 1 ,
1261
- "use_mask" : use_mask ,
1262
- "mask" : mask , # no modulation in this test
1263
- }
1264
- optests .opcheck (torch .ops .torchvision .deform_conv2d , args = (x ,), kwargs = kwargs )
1204
+ optests .generate_opcheck_tests (
1205
+ testcase = TestDeformConv ,
1206
+ namespaces = ["torchvision" ],
1207
+ failures_dict_path = os .path .join (os .path .dirname (__file__ ), "optests_failures_dict.json" ),
1208
+ additional_decorators = [],
1209
+ test_utils = OPTESTS ,
1210
+ )
1265
1211
1266
1212
1267
1213
class TestFrozenBNT :
0 commit comments