@@ -319,6 +319,39 @@ def test_flip():
319
319
assert np .sum (t [- 1 , :].asnumpy () == 0 ) == b .shape [1 ]
320
320
321
321
322
+ def check_pad_with_shape (shape , pad_width , mode , dtype = "float64" ):
323
+ # bind with label
324
+ X = mx .symbol .Variable ('X' , dtype = dtype )
325
+ Y = mx .symbol .Pad (data = X , mode = mode , pad_width = pad_width )
326
+ x = mx .random .uniform (- 1 , 1 , shape , dtype = dtype )
327
+ # numpy result
328
+ pad_grouped = list (zip (* [iter (list (pad_width ))] * 2 ))
329
+ np_out = np .pad (x .asnumpy (), pad_grouped , mode )
330
+ # mxnet result
331
+ grad = mx .nd .empty (shape , dtype = dtype )
332
+ exec1 = Y .bind (args = [x ], ctx = mx .cpu (), args_grad = {'X' : grad })
333
+ exec1 .forward (is_train = True )
334
+ out = exec1 .outputs [0 ].asnumpy ()
335
+ # compare numpy + mxnet
336
+ assert_almost_equal (out , np_out )
337
+
338
+ def test_pad ():
339
+ shape1 = (LARGE_X , 2 , 2 , 2 )
340
+ pad1 = (0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 )
341
+ shape2 = (LARGE_X , 2 , 2 , 1 , 1 )
342
+ pad2 = (0 , 0 , 0 , 0 , 1 , 1 , 0 , 0 , 0 , 0 )
343
+ # note: this op doesn't support ints yet. Add tests when supported
344
+ dtypes = ["float16" , "float32" , "float64" ]
345
+ dtypes = ["float16" ]
346
+ for dtype in dtypes :
347
+ check_pad_with_shape (shape1 , pad1 , 'constant' , dtype )
348
+ check_pad_with_shape (shape1 , pad1 , 'edge' , dtype )
349
+ check_pad_with_shape (shape2 , pad2 , 'constant' , dtype )
350
+ check_pad_with_shape (shape2 , pad2 , 'edge' , dtype )
351
+ check_pad_with_shape (shape1 , pad1 , 'reflect' , dtype )
352
+ check_pad_with_shape (shape2 , pad2 , 'reflect' , dtype )
353
+
354
+
322
355
if __name__ == '__main__' :
323
356
import nose
324
357
nose .runmodule ()
0 commit comments