@@ -31,10 +31,16 @@ def sin(x):
     def grad_grad_op(x):
         return -nd.sin(x)
 
+    def grad_grad_grad_op(x):
+        return -nd.cos(x)
+
     for dim in range(1, 5):
         shape = rand_shape_nd(dim)
         array = random_arrays(shape)
         check_second_order_unary(array, sin, grad_grad_op)
+        # TODO(kshitij12345): Remove
+        check_nth_order_unary(array, sin,
+                              [grad_grad_op, grad_grad_grad_op], [2, 3])
 
 
 @with_seed()
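(Reviewer note, not part of the patch.) The derivative chain exercised here is sin -> cos -> -sin -> -cos, so `grad_grad_op` and the new `grad_grad_grad_op` return the second and third derivatives of sin. A throwaway NumPy sketch, independent of the test harness and purely illustrative, confirms the closed forms by finite differences:

    import numpy as np

    x = np.linspace(-3, 3, 601)
    d1 = np.gradient(np.sin(x), x, edge_order=2)  # ~ cos(x)
    d2 = np.gradient(d1, x, edge_order=2)         # ~ -sin(x), as in grad_grad_op
    d3 = np.gradient(d2, x, edge_order=2)         # ~ -cos(x), as in grad_grad_grad_op
    inner = slice(5, -5)  # stacked finite differences lose accuracy at the edges
    assert np.allclose(d2[inner], -np.sin(x[inner]), atol=1e-3)
    assert np.allclose(d3[inner], -np.cos(x[inner]), atol=1e-2)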
@@ -45,10 +51,16 @@ def cos(x):
     def grad_grad_op(x):
         return -nd.cos(x)
 
+    def grad_grad_grad_op(x):
+        return nd.sin(x)
+
     for dim in range(1, 5):
         shape = rand_shape_nd(dim)
         array = random_arrays(shape)
         check_second_order_unary(array, cos, grad_grad_op)
+        # TODO(kshitij12345): Remove
+        check_nth_order_unary(array, cos,
+                              [grad_grad_op, grad_grad_grad_op], [2, 3])
 
 
 @with_seed()
@@ -178,13 +190,18 @@ def test_log():
     def log(x):
         return nd.log(x)
 
+    def grad_op(x):
+        return 1/x
+
     def grad_grad_op(x):
         return -1/(x**2)
 
     for dim in range(1, 5):
         shape = rand_shape_nd(dim)
         array = random_arrays(shape)
         check_second_order_unary(array, log, grad_grad_op)
+        # TODO(kshitij12345): Remove
+        check_nth_order_unary(array, log, [grad_op, grad_grad_op], [1, 2])
 
 
 @with_seed()
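(Reviewer note, not part of the patch.) The new `grad_op` supplies the first derivative of log so that orders 1 and 2 can be checked in one call. The same kind of standalone NumPy sketch, illustrative only, validates d/dx log(x) = 1/x and d2/dx2 log(x) = -1/x**2 on the positive axis where log is defined:

    import numpy as np

    x = np.linspace(0.5, 3.0, 501)                # log requires x > 0
    d1 = np.gradient(np.log(x), x, edge_order=2)  # ~ 1/x
    d2 = np.gradient(d1, x, edge_order=2)         # ~ -1/x**2
    inner = slice(5, -5)  # skip edge points, where stacked differences degrade
    assert np.allclose(d1, 1 / x, atol=1e-3)                   # matches grad_op
    assert np.allclose(d2[inner], -1 / x[inner] ** 2, atol=1e-3)  # matches grad_grad_op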
@@ -288,6 +305,9 @@ def grad_grad_op(x):
         shape = rand_shape_nd(dim)
         array = random_arrays(shape)
         check_second_order_unary(array, sigmoid, grad_grad_op)
+        # TODO(kshitij12345): Remove
+        check_nth_order_unary(array, sigmoid, [grad_op, grad_grad_op], [1, 2])
+        check_nth_order_unary(array, sigmoid, grad_grad_op, 2)
 
 
 @with_seed()
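(Reviewer note, not part of the patch.) `grad_op` and `grad_grad_op` for sigmoid are defined earlier in the test and are not shown in this hunk; presumably they implement the standard closed forms s' = s(1 - s) and s'' = s'(1 - 2s). Note that the second added call passes a bare callable and a plain int, exercising the scalar path of `check_nth_order_unary` where both arguments are wrapped into single-element lists. A minimal NumPy check of the first-derivative identity, illustrative only:

    import numpy as np

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    x = np.linspace(-3, 3, 601)
    s = sigmoid(x)
    assert np.allclose(np.gradient(s, x, edge_order=2), s * (1 - s), atol=1e-3)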
@@ -331,28 +351,77 @@ def grad_grad_op(x):
 
 
 def check_second_order_unary(x, op, grad_grad_op, rtol=None, atol=None):
+    check_nth_order_unary(x, op, grad_grad_op, 2, rtol, atol)
+
+
+def check_nth_order_unary(x, op, grad_ops, orders, rtol=None, atol=None):
+    """Assert n-th order autograd gradients against expected gradients.
+
+    Multiple orders of gradients can be checked by passing a list of
+    functions that compute the gradients of particular orders, together
+    with the corresponding list of orders.
+
+    Note
+    ----
+    1. Orders should always be monotonically increasing.
+    2. Elements of grad_ops should correspond to elements of orders,
+       i.e. grad_ops = [grad_op, grad_grad_grad_op] should be passed with
+       orders = [1, 3].
+
+    Parameters
+    ----------
+    x : mxnet.NDArray
+        Input array.
+    op : Callable
+        Operation to perform on the input array.
+    grad_ops : Callable or List of Callable
+        Function(s) to compute and assert the gradient of the given order.
+    orders : int or List of int
+        Order(s) at which to assert expected and computed gradients.
+
+    Returns
+    -------
+    None
+
+    """
+    if isinstance(orders, int):
+        orders = [orders]
+        grad_ops = [grad_ops]
+
+    assert all(i < j for i, j in zip(orders[0:-1], orders[1:])), \
+        "orders should be monotonically increasing"
+    assert len(set(orders)) == len(orders), \
+        "orders should have unique elements"
+    highest_order = max(orders)
+
     x = nd.array(x)
-    grad_grad_x = grad_grad_op(x)
     x.attach_grad()
 
-    # Manual head_grads.
-    y_grad = nd.random.normal(shape=x.shape)
-    head_grad_grads = nd.random.normal(shape=x.shape)
+    expected_grads = [grad_op(x) for grad_op in grad_ops]
+    computed_grads = []
+    head_grads = []
 
     # Perform compute.
     with autograd.record():
         y = op(x)
-        x_grad = autograd.grad(heads=y, variables=x, head_grads=y_grad,
-                               create_graph=True, retain_graph=True)[0]
-        x_grad.backward(head_grad_grads)
-
-    # Compute expected values.
-    expected_grad_grad = grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
-        y_grad.asnumpy()
-
-    # Validate the gradients.
-    assert_almost_equal(expected_grad_grad,
-                        x.grad.asnumpy(), rtol=rtol, atol=atol)
+        for current_order in range(1, highest_order+1):
+            head_grad = nd.random.normal(shape=x.shape)
+            y = autograd.grad(heads=y, variables=x, head_grads=head_grad,
+                              create_graph=True, retain_graph=True)[0]
+            if current_order in orders:
+                computed_grads.append(y)
+            head_grads.append(head_grad)
+
+    # Validate all the gradients.
+    for order, grad, computed_grad in \
+            zip(orders, expected_grads, computed_grads):
+        # Compute expected values.
+        expected_grad = grad.asnumpy()
+        for head_grad in head_grads[:order]:
+            expected_grad *= head_grad.asnumpy()
+
+        assert_almost_equal(
+            expected_grad, computed_grad.asnumpy(), rtol=rtol, atol=atol)
 
 
 if __name__ == '__main__':
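(Reviewer note, not part of the patch.) Each `autograd.grad` call in the loop contracts the current head with a fresh random head gradient, so for an elementwise unary op the array captured at order n equals f^(n)(x) times the elementwise product of the first n head gradients; the validation loop reconstructs exactly that product via `head_grads[:order]`. This is also why `head_grads` must record the head gradient of every order, including orders that are not checked. A hypothetical call mirroring `test_sin` above, using only names introduced in this patch:

    # Check the 2nd and 3rd derivatives of sin in a single pass.
    array = random_arrays(rand_shape_nd(2))
    check_nth_order_unary(array, lambda x: nd.sin(x),
                          [lambda x: -nd.sin(x), lambda x: -nd.cos(x)],
                          [2, 3])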