
Commit 9d3936a

kshitij12345 authored and larroy committed

[MXNET-978] n-th order gradient test support. (apache#15611)

* n-th order grad test support
* use check_nth_order_unary for second order check
* add docstring to check_nth_order_unary
* retrigger CI
* add assertions for requirements of orders
* fix assert condition
* retrigger CI
1 parent eace8c7 commit 9d3936a

File tree

1 file changed: +84 -15 lines changed


tests/python/unittest/test_higher_order_grad.py

Lines changed: 84 additions & 15 deletions
@@ -31,10 +31,16 @@ def sin(x):
     def grad_grad_op(x):
         return -nd.sin(x)
 
+    def grad_grad_grad_op(x):
+        return -nd.cos(x)
+
     for dim in range(1, 5):
         shape = rand_shape_nd(dim)
         array = random_arrays(shape)
         check_second_order_unary(array, sin, grad_grad_op)
+        # TODO(kshitij12345): Remove
+        check_nth_order_unary(array, sin,
+                              [grad_grad_op, grad_grad_grad_op], [2, 3])
 
 
 @with_seed()
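For context on the technique under test: n-th order gradients come from chaining autograd.grad calls, each with create_graph=True so the result stays differentiable. A minimal sketch (not part of the test file; assumes only a working mxnet install) computing the third derivative of sin, which is the -cos(x) encoded in grad_grad_grad_op above:

import mxnet.ndarray as nd
from mxnet import autograd

x = nd.array([0.5, 1.0])
x.attach_grad()
with autograd.record():
    y = nd.sin(x)
    # Each call differentiates the previous result; create_graph=True
    # keeps the new graph differentiable for the next call.
    first = autograd.grad(heads=y, variables=x,
                          create_graph=True, retain_graph=True)[0]   # cos(x)
    second = autograd.grad(heads=first, variables=x,
                           create_graph=True, retain_graph=True)[0]  # -sin(x)
    third = autograd.grad(heads=second, variables=x,
                          create_graph=False, retain_graph=False)[0]  # -cos(x)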
@@ -45,10 +51,16 @@ def cos(x):
     def grad_grad_op(x):
         return -nd.cos(x)
 
+    def grad_grad_grad_op(x):
+        return nd.sin(x)
+
     for dim in range(1, 5):
         shape = rand_shape_nd(dim)
         array = random_arrays(shape)
         check_second_order_unary(array, cos, grad_grad_op)
+        # TODO(kshitij12345): Remove
+        check_nth_order_unary(array, cos,
+                              [grad_grad_op, grad_grad_grad_op], [2, 3])
 
 
 @with_seed()
@@ -178,13 +190,18 @@ def test_log():
     def log(x):
         return nd.log(x)
 
+    def grad_op(x):
+        return 1/x
+
     def grad_grad_op(x):
         return -1/(x**2)
 
     for dim in range(1, 5):
         shape = rand_shape_nd(dim)
         array = random_arrays(shape)
         check_second_order_unary(array, log, grad_grad_op)
+        # TODO(kshitij12345): Remove
+        check_nth_order_unary(array, log, [grad_op, grad_grad_op], [1, 2])
 
 
 @with_seed()
@@ -288,6 +305,9 @@ def grad_grad_op(x):
         shape = rand_shape_nd(dim)
         array = random_arrays(shape)
         check_second_order_unary(array, sigmoid, grad_grad_op)
+        # TODO(kshitij12345): Remove
+        check_nth_order_unary(array, sigmoid, [grad_op, grad_grad_op], [1, 2])
+        check_nth_order_unary(array, sigmoid, grad_grad_op, 2)
 
 
 @with_seed()
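One step of reasoning behind the rewritten checker in the next hunk: for an elementwise op, backpropagating y = f(x) with head gradient h1 yields f'(x) * h1, and differentiating that result with head gradient h2 yields f''(x) * h1 * h2. That is why the validation multiplies the analytic derivative by every head gradient up to the checked order. A small standalone sketch of the identity (hypothetical snippet, not part of the test file; assumes mxnet and numpy):

import numpy as np
import mxnet.ndarray as nd
from mxnet import autograd

x = nd.array([0.3, 0.7])
x.attach_grad()
h1 = nd.random.normal(shape=x.shape)  # head grad for the first grad call
h2 = nd.random.normal(shape=x.shape)  # head grad for the second grad call

with autograd.record():
    y = nd.sin(x)
    g1 = autograd.grad(heads=y, variables=x, head_grads=h1,
                       create_graph=True, retain_graph=True)[0]   # cos(x) * h1
    g2 = autograd.grad(heads=g1, variables=x, head_grads=h2,
                       create_graph=False, retain_graph=True)[0]  # -sin(x) * h1 * h2

expected = (-nd.sin(x) * h1 * h2).asnumpy()
assert np.allclose(g2.asnumpy(), expected, atol=1e-6)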
@@ -331,28 +351,77 @@ def grad_grad_op(x):
 
 
 def check_second_order_unary(x, op, grad_grad_op, rtol=None, atol=None):
+    check_nth_order_unary(x, op, grad_grad_op, 2, rtol, atol)
+
+
+def check_nth_order_unary(x, op, grad_ops, orders, rtol=None, atol=None):
+    """Assert n-th order autograd gradient against expected gradient.
+
+    Multiple orders of gradients can be checked by passing a list of
+    functions that compute the gradients of those orders, together with
+    the corresponding list of orders.
+
+    Note
+    ----
+    1. Orders should always be monotonically increasing.
+    2. Elements of grad_ops should correspond to elements of orders,
+       i.e. grad_ops = [grad_op, grad_grad_grad_op] should be passed with
+       orders = [1, 3]
+
+    Parameters
+    ----------
+    x : mxnet.NDArray
+        Input Array.
+    op : Callable
+        Operation to perform on Input Array.
+    grad_ops : Callable or List of Callable
+        Function(s) computing the expected gradient(s) of the given order(s).
+    orders : int or List of int
+        Order(s) at which to compare expected and computed gradients.
+
+    Returns
+    -------
+    None
+
+    """
+    if isinstance(orders, int):
+        orders = [orders]
+        grad_ops = [grad_ops]
+
+    assert all(i < j for i, j in zip(orders[0:-1], orders[1:])), \
+        "orders should be monotonically increasing"
+    assert len(set(orders)) == len(orders), \
+        "orders should have unique elements"
+    highest_order = max(orders)
+
     x = nd.array(x)
-    grad_grad_x = grad_grad_op(x)
     x.attach_grad()
 
-    # Manual head_grads.
-    y_grad = nd.random.normal(shape=x.shape)
-    head_grad_grads = nd.random.normal(shape=x.shape)
+    expected_grads = [grad_op(x) for grad_op in grad_ops]
+    computed_grads = []
+    head_grads = []
 
     # Perform compute.
     with autograd.record():
         y = op(x)
-        x_grad = autograd.grad(heads=y, variables=x, head_grads=y_grad,
-                               create_graph=True, retain_graph=True)[0]
-        x_grad.backward(head_grad_grads)
-
-    # Compute expected values.
-    expected_grad_grad = grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
-        y_grad.asnumpy()
-
-    # Validate the gradients.
-    assert_almost_equal(expected_grad_grad,
-                        x.grad.asnumpy(), rtol=rtol, atol=atol)
+        for current_order in range(1, highest_order+1):
+            head_grad = nd.random.normal(shape=x.shape)
+            y = autograd.grad(heads=y, variables=x, head_grads=head_grad,
+                              create_graph=True, retain_graph=True)[0]
+            if current_order in orders:
+                computed_grads.append(y)
+            head_grads.append(head_grad)
+
+    # Validate all the gradients.
+    for order, grad, computed_grad in \
+            zip(orders, expected_grads, computed_grads):
+        # Compute expected values.
+        expected_grad = grad.asnumpy()
+        for head_grad in head_grads[:order]:
+            expected_grad *= head_grad.asnumpy()
+
+        assert_almost_equal(
+            expected_grad, computed_grad.asnumpy(), rtol=rtol, atol=atol)
 
 
 if __name__ == '__main__':
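A hypothetical standalone usage of the new helper, mirroring the example in its docstring (assumes check_nth_order_unary is importable from this test file; the derivative helpers are redefined here for illustration):

import mxnet.ndarray as nd

def sin(x):
    return nd.sin(x)

def grad_op(x):            # 1st derivative of sin
    return nd.cos(x)

def grad_grad_grad_op(x):  # 3rd derivative of sin
    return -nd.cos(x)

# orders must be unique and monotonically increasing; each grad_ops entry
# pairs with the order at the same position.
check_nth_order_unary(nd.array([0.1, 0.2, 0.3]), sin,
                      [grad_op, grad_grad_grad_op], [1, 3])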
