Skip to content

Commit c1bd2e9

Browse files
committed
Small changes to I
1 parent 5212c8e commit c1bd2e9

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

torchmdnet/models/tensornet.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,19 +47,18 @@ def vector_to_skewtensor(vector):
4747
def vector_to_symtensor(vector):
4848
"""Creates a symmetric traceless tensor from the outer product of a vector with itself."""
4949
tensor = torch.matmul(vector.unsqueeze(-1), vector.unsqueeze(-2))
50-
I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)[
51-
..., None, None
52-
] * torch.eye(3, 3, device=tensor.device, dtype=tensor.dtype)
53-
S = 0.5 * (tensor + tensor.transpose(-2, -1)) - I
50+
S = 0.5 * (tensor + tensor.transpose(-2, -1))
51+
I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)
52+
S.diagonal(offset=0, dim1=-1, dim2=-2).sub_(I.unsqueeze(-1))
5453
return S
5554

5655

5756
@nvtx_annotate("decompose_tensor")
5857
def decompose_tensor(tensor):
5958
"""Full tensor decomposition into irreducible components."""
60-
I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)
6159
A = 0.5 * (tensor - tensor.transpose(-2, -1))
6260
S = tensor - A
61+
I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)
6362
S.diagonal(offset=0, dim1=-1, dim2=-2).sub_(I.unsqueeze(-1))
6463
return I, A, S
6564

0 commit comments

Comments
 (0)