@@ -295,8 +295,8 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
295
295
Args:
296
296
x (Tensor): The input tensor could be N-D tensor, and the input data
297
297
type could be float32 or float64.
298
- p (float|string, optional): Order of the norm. Supported values are `fro`, `0`, `1`, `2`,
299
- `inf`, `-inf` and any positive real number yielding the corresponding p-norm. Not supported: ord < 0 and nuclear norm .
298
+ p (float|string, optional): Order of the norm. Supported values are `fro`, `nuc`, ` 0`, `1`, `2`,
299
+ `inf`, `-inf` and any positive real number yielding the corresponding p-norm. Not supported: ord < 0.
300
300
Default value is `fro`.
301
301
axis (int|list|tuple, optional): The axis on which to apply norm operation. If axis is int
302
302
or list(int)/tuple(int) with only one element, the vector norm is computed over the axis.
@@ -374,6 +374,21 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
374
374
[4., 3., 2., 1.]])
375
375
"""
376
376
377
+ def _backshift_permutation (dim0 , dim1 , dimn ):
378
+ """
379
+ Auxiliary function for matrix_norm
380
+ Computes the permutation that moves the two given dimensions to the back
381
+ """
382
+ ret = [i for i in range (dimn ) if i != dim0 and i != dim1 ]
383
+ ret .extend ((dim0 , dim1 ))
384
+ return ret
385
+
386
+ def _inverse_permutation (perm ):
387
+ """
388
+ Given a permutation, returns its inverse. It's equivalent to argsort on an array
389
+ """
390
+ return [i for i , j in sorted (enumerate (perm ), key = lambda ij : ij [1 ])]
391
+
377
392
def frobenius_norm (input , dim = None , keepdim = False , name = None ):
378
393
"""
379
394
The frobenius norm OP is to calculate the frobenius norm of certain two dimensions of Tensor `input`.
@@ -414,6 +429,98 @@ def frobenius_norm(input, dim=None, keepdim=False, name=None):
414
429
)
415
430
return out
416
431
432
+ def nuclear_norm (input , axis = axis , keepdim = False , name = None ):
433
+ """
434
+ The nuclear norm OP is to calculate the nuclear norm of certain two dimensions of Tensor `input`.
435
+ Args:
436
+ input (Variable): Tensor, data type float32, float64.
437
+ dim (list): Two dimensions.
438
+ keepdim (bool, optional): Whether keep the dimensions as the `input`, Default False.
439
+ name (str, optional): The default value is None. Normally there is no need for
440
+ user to set this property. For more information, please refer to :ref:`api_guide_Name`.
441
+ """
442
+
443
+ perm = _backshift_permutation (axis [0 ], axis [1 ], len (input .shape ))
444
+ inv_perm = _inverse_permutation (perm )
445
+
446
+ if in_dynamic_mode ():
447
+ transposed = _C_ops .transpose (input , perm )
448
+ u , s , vh = _C_ops .svd (transposed , False )
449
+ result = _C_ops .sum (s , - 1 , None , keepdim )
450
+ if keepdim :
451
+ result = _C_ops .transpose (
452
+ _C_ops .unsqueeze (result , - 1 ), inv_perm
453
+ )
454
+ return result
455
+
456
+ attrs = {'axis' : axis , 'keepdim' : keepdim }
457
+
458
+ check_variable_and_dtype (
459
+ input , 'input' , ['float32' , 'float64' ], 'nuclear_norm'
460
+ )
461
+
462
+ block = LayerHelper ('nuclear_nrom' , ** locals ())
463
+ out = block .create_variable_for_type_inference (
464
+ dtype = block .input_dtype ()
465
+ )
466
+
467
+ transpose_out = block .create_variable_for_type_inference (
468
+ dtype = block .input_dtype ()
469
+ )
470
+ input_shape = block .create_variable_for_type_inference (
471
+ dtype = block .input_dtype ()
472
+ )
473
+
474
+ block .append_op (
475
+ type = 'transpose2' ,
476
+ inputs = {'X' : [input ]},
477
+ outputs = {'Out' : [transpose_out ], 'XShape' : [input_shape ]},
478
+ attrs = {'axis' : perm },
479
+ )
480
+
481
+ u = block .create_variable_for_type_inference (dtype = block .input_dtype ())
482
+ s = block .create_variable_for_type_inference (dtype = block .input_dtype ())
483
+ vt = block .create_variable_for_type_inference (dtype = block .input_dtype ())
484
+ block .append_op (
485
+ type = 'svd' ,
486
+ inputs = {'X' : [transpose_out ]},
487
+ outputs = {'U' : u , 'VH' : vt , 'S' : s },
488
+ attrs = {'full_matrices' : False },
489
+ )
490
+
491
+ reduce_all , sum_axis = _get_reduce_axis (- 1 , s )
492
+ block .append_op (
493
+ type = 'reduce_sum' ,
494
+ inputs = {'X' : s },
495
+ outputs = {'Out' : out },
496
+ attrs = {
497
+ 'dim' : sum_axis ,
498
+ 'keep_dim' : keepdim ,
499
+ 'reduce_all' : reduce_all ,
500
+ },
501
+ )
502
+
503
+ if keepdim :
504
+ unsqueeze_out = block .create_variable_for_type_inference (
505
+ dtype = block .input_dtype ()
506
+ )
507
+
508
+ block .append_op (
509
+ type = 'unsqueeze2' ,
510
+ inputs = {'X' : [out ]},
511
+ outputs = {'Out' : [unsqueeze_out ], 'XShape' : [input_shape ]},
512
+ attrs = {'axes' : [- 1 ]},
513
+ )
514
+
515
+ block .append_op (
516
+ type = 'transpose2' ,
517
+ inputs = {'X' : [unsqueeze_out ]},
518
+ outputs = {'Out' : [out ], 'XShape' : [input_shape ]},
519
+ attrs = {'axis' : inv_perm },
520
+ )
521
+
522
+ return out
523
+
417
524
def vector_norm (
418
525
input , porder = None , axis = None , keepdim = False , asvector = False , name = None
419
526
):
@@ -616,6 +723,8 @@ def p_matrix_norm(input, porder=1.0, axis=axis, keepdim=False, name=None):
616
723
elif isinstance (axis , list ) and len (axis ) == 2 :
617
724
if p == "fro" :
618
725
return frobenius_norm (x , dim = axis , keepdim = keepdim , name = name )
726
+ elif p == "nuc" :
727
+ return nuclear_norm (x , axis = axis , keepdim = keepdim , name = name )
619
728
elif p == np .inf or p == - np .inf :
620
729
return inf_norm (x , porder = p , axis = axis , keepdim = keepdim , name = name )
621
730
elif p == 0 :
0 commit comments