Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit df9bcc6

Browse files
committed
update test
1 parent a632ac6 commit df9bcc6

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

tests/python/unittest/test_operator.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6520,6 +6520,7 @@ def test_laop_6():
65206520
v = np.random.random(4)
65216521
a = np.eye(4) + np.outer(v, v)
65226522
a = np.tile(a, (3, 1, 1))
6523+
permute_mat = np.eye(4)[[1, 0, 2, 3]]
65236524

65246525
# test matrix inverse
65256526
r = np.eye(4)
@@ -6541,10 +6542,13 @@ def test_laop_6():
65416542
check_fw(test_logdet, [a], [r])
65426543
check_grad(test_logdet, [a])
65436544
# test slogdet
6544-
r = np.log(np.abs(np.linalg.det(a)))
6545-
_, test_slogdet = mx.sym.linalg.slogdet(data)
6546-
check_fw(test_slogdet, [a], [r])
6547-
check_grad(test_slogdet, [a])
6545+
r1 = np.array([1., 1., 1.])
6546+
r2 = np.log(np.abs(np.linalg.det(a)))
6547+
test_sign, test_logabsdet = mx.sym.linalg.slogdet(data)
6548+
check_fw(test_sign, [a], [r1])
6549+
check_fw(test_sign, [np.dot(a, permute_mat)], [-r1])
6550+
check_fw(test_logabsdet, [a], [r2])
6551+
check_grad(test_logabsdet, [a])
65486552

65496553
@with_seed()
65506554
def test_stack():

0 commit comments

Comments
 (0)