@@ -47,19 +47,18 @@ def vector_to_skewtensor(vector):
47
47
def vector_to_symtensor (vector ):
48
48
"""Creates a symmetric traceless tensor from the outer product of a vector with itself."""
49
49
tensor = torch .matmul (vector .unsqueeze (- 1 ), vector .unsqueeze (- 2 ))
50
- I = (tensor .diagonal (offset = 0 , dim1 = - 1 , dim2 = - 2 )).mean (- 1 )[
51
- ..., None , None
52
- ] * torch .eye (3 , 3 , device = tensor .device , dtype = tensor .dtype )
53
- S = 0.5 * (tensor + tensor .transpose (- 2 , - 1 )) - I
50
+ S = 0.5 * (tensor + tensor .transpose (- 2 , - 1 ))
51
+ I = (tensor .diagonal (offset = 0 , dim1 = - 1 , dim2 = - 2 )).mean (- 1 )
52
+ S .diagonal (offset = 0 , dim1 = - 1 , dim2 = - 2 ).sub_ (I .unsqueeze (- 1 ))
54
53
return S
55
54
56
55
57
56
@nvtx_annotate ("decompose_tensor" )
58
57
def decompose_tensor (tensor ):
59
58
"""Full tensor decomposition into irreducible components."""
60
- I = (tensor .diagonal (offset = 0 , dim1 = - 1 , dim2 = - 2 )).mean (- 1 )
61
59
A = 0.5 * (tensor - tensor .transpose (- 2 , - 1 ))
62
60
S = tensor - A
61
+ I = (tensor .diagonal (offset = 0 , dim1 = - 1 , dim2 = - 2 )).mean (- 1 )
63
62
S .diagonal (offset = 0 , dim1 = - 1 , dim2 = - 2 ).sub_ (I .unsqueeze (- 1 ))
64
63
return I , A , S
65
64
0 commit comments