@@ -444,6 +444,72 @@ def test_topk():
444
444
assert l .sum () == np .sum (np .arange (0 , SMALL_Y ))
445
445
446
446
447
+ def test_exponent_logarithm_operators ():
448
+ a = 2 * nd .ones (shape = (LARGE_X , SMALL_Y ))
449
+ # exponent
450
+ result = nd .exp (a )
451
+ assert result [0 ][- 1 ] == 7.389056
452
+ assert result .shape == a .shape
453
+
454
+ # exponent minus 1
455
+ result = nd .expm1 (a )
456
+ assert result [0 ][- 1 ] == 6.389056
457
+ assert result .shape == a .shape
458
+
459
+ # log2
460
+ result = nd .log2 (a )
461
+ assert result [0 ][- 1 ] == 1
462
+ assert result .shape == a .shape
463
+
464
+ # log10
465
+ result = nd .log10 (a )
466
+ assert result [0 ][- 1 ] == 0.30103
467
+ assert result .shape == a .shape
468
+
469
+ # log1p
470
+ result = nd .log1p (a )
471
+ assert result [0 ][- 1 ] == 1.0986123
472
+ assert result .shape == a .shape
473
+
474
+ # log
475
+ result = nd .log (a )
476
+ assert result [0 ][- 1 ] == 0.6931472
477
+ assert result .shape == a .shape
478
+
479
+
480
+ def test_power_operators ():
481
+ a = 2 * nd .ones (shape = (LARGE_X , SMALL_Y ))
482
+ # sqrt
483
+ result = nd .sqrt (a )
484
+ assert result [0 ][- 1 ] == 1.4142135
485
+ assert result .shape == a .shape
486
+
487
+ # rsqrt
488
+ result = nd .rsqrt (a )
489
+ assert result [0 ][- 1 ] == 0.70710677
490
+ assert result .shape == a .shape
491
+
492
+ # cbrt
493
+ result = nd .cbrt (a )
494
+ assert result [0 ][- 1 ] == 1.2599211
495
+ assert result .shape == a .shape
496
+
497
+ # rcbrt
498
+ result = nd .rcbrt (a )
499
+ assert result [0 ][- 1 ] == 0.7937005
500
+ assert result .shape == a .shape
501
+
502
+ # square
503
+ result = nd .square (a )
504
+ assert result [0 ][- 1 ] == 4
505
+ assert result .shape == a .shape
506
+
507
+ # reciprocal
508
+ result = nd .reciprocal (a )
509
+ assert result [0 ][- 1 ] == 0.5
510
+ assert result .shape == a .shape
511
+
512
+
447
513
def test_sequence_mask ():
448
514
# Sequence Mask input [max_sequence_length, batch_size, other_feature_dims]
449
515
# test with input batch_size = 2
0 commit comments