@@ -57,11 +57,10 @@ def vector_to_symtensor(vector):
57
57
@nvtx_annotate ("decompose_tensor" )
58
58
def decompose_tensor (tensor ):
59
59
"""Full tensor decomposition into irreducible components."""
60
- I = (tensor .diagonal (offset = 0 , dim1 = - 1 , dim2 = - 2 )).mean (- 1 )[
61
- ..., None , None
62
- ] * torch .eye (3 , 3 , device = tensor .device , dtype = tensor .dtype )
60
+ I = (tensor .diagonal (offset = 0 , dim1 = - 1 , dim2 = - 2 )).mean (- 1 )
63
61
A = 0.5 * (tensor - tensor .transpose (- 2 , - 1 ))
64
- S = 0.5 * (tensor + tensor .transpose (- 2 , - 1 )) - I
62
+ S = tensor - A
63
+ S .diagonal (offset = 0 , dim1 = - 1 , dim2 = - 2 ).sub_ (I .unsqueeze (- 1 ))
65
64
return I , A , S
66
65
67
66
@@ -260,7 +259,7 @@ def _compute_neighbors(
260
259
def output (self , X : Tensor ) -> Tensor :
261
260
I , A , S = decompose_tensor (X ) # shape: (n_atoms, hidden_channels, 3, 3)
262
261
x = torch .cat (
263
- (tensor_norm ( I ) , tensor_norm (A ), tensor_norm (S )), dim = - 1
262
+ (3 * I ** 2 , tensor_norm (A ), tensor_norm (S )), dim = - 1
264
263
) # shape: (n_atoms, 3*hidden_channels)
265
264
x = self .out_norm (x ) # shape: (n_atoms, 3*hidden_channels)
266
265
x = self .act (self .linear ((x ))) # shape: (n_atoms, hidden_channels)
@@ -322,10 +321,7 @@ def forward(self, X: Tensor, factor: Optional[Tensor] = None) -> Tensor:
322
321
.unsqueeze (- 1 )
323
322
).expand (- 1 , - 1 , 3 )
324
323
I , A , S = decompose_tensor (X )
325
- I = (
326
- self .linearI (I .permute (0 , 2 , 3 , 1 )).permute (0 , 3 , 1 , 2 )
327
- * factor [..., 0 , None , None ]
328
- )
324
+ I = self .linearI (I ) * factor [..., 0 ]
329
325
A = (
330
326
self .linearA (A .permute (0 , 2 , 3 , 1 )).permute (0 , 3 , 1 , 2 )
331
327
* factor [..., 1 , None , None ]
@@ -334,7 +330,8 @@ def forward(self, X: Tensor, factor: Optional[Tensor] = None) -> Tensor:
334
330
self .linearS (S .permute (0 , 2 , 3 , 1 )).permute (0 , 3 , 1 , 2 )
335
331
* factor [..., 2 , None , None ]
336
332
)
337
- dX = I + A + S
333
+ dX = A + S
334
+ dX .diagonal (dim1 = - 2 , dim2 = - 1 ).add_ (I .unsqueeze (- 1 ))
338
335
return dX
339
336
340
337
@@ -490,10 +487,11 @@ def forward(
490
487
@nvtx_annotate ("compute_tensor_edge_features" )
491
488
def compute_tensor_edge_features (X , edge_index , factor ):
492
489
I , A , S = decompose_tensor (X )
493
- msg = (
494
- factor [..., 0 , None , None ] * I .index_select (0 , edge_index [1 ])
495
- + factor [..., 1 , None , None ] * A .index_select (0 , edge_index [1 ])
496
- + factor [..., 2 , None , None ] * S .index_select (0 , edge_index [1 ])
490
+ msg = factor [..., 1 , None , None ] * A .index_select (0 , edge_index [1 ]) + factor [
491
+ ..., 2 , None , None
492
+ ] * S .index_select (0 , edge_index [1 ])
493
+ msg .diagonal (dim1 = - 2 , dim2 = - 1 ).add_ (
494
+ factor [..., 0 , None ] * I .index_select (0 , edge_index [1 ]).unsqueeze (- 1 )
497
495
)
498
496
return msg
499
497
0 commit comments