Skip to content

Commit 5212c8e

Browse files
committed
Store I as a single number
1 parent ada1192 commit 5212c8e

File tree

1 file changed

+12
-14
lines changed

1 file changed

+12
-14
lines changed

torchmdnet/models/tensornet.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,10 @@ def vector_to_symtensor(vector):
5757
@nvtx_annotate("decompose_tensor")
5858
def decompose_tensor(tensor):
5959
"""Full tensor decomposition into irreducible components."""
60-
I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)[
61-
..., None, None
62-
] * torch.eye(3, 3, device=tensor.device, dtype=tensor.dtype)
60+
I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)
6361
A = 0.5 * (tensor - tensor.transpose(-2, -1))
64-
S = 0.5 * (tensor + tensor.transpose(-2, -1)) - I
62+
S = tensor - A
63+
S.diagonal(offset=0, dim1=-1, dim2=-2).sub_(I.unsqueeze(-1))
6564
return I, A, S
6665

6766

@@ -260,7 +259,7 @@ def _compute_neighbors(
260259
def output(self, X: Tensor) -> Tensor:
261260
I, A, S = decompose_tensor(X) # shape: (n_atoms, hidden_channels, 3, 3)
262261
x = torch.cat(
263-
(tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1
262+
(3 * I**2, tensor_norm(A), tensor_norm(S)), dim=-1
264263
) # shape: (n_atoms, 3*hidden_channels)
265264
x = self.out_norm(x) # shape: (n_atoms, 3*hidden_channels)
266265
x = self.act(self.linear((x))) # shape: (n_atoms, hidden_channels)
@@ -322,10 +321,7 @@ def forward(self, X: Tensor, factor: Optional[Tensor] = None) -> Tensor:
322321
.unsqueeze(-1)
323322
).expand(-1, -1, 3)
324323
I, A, S = decompose_tensor(X)
325-
I = (
326-
self.linearI(I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
327-
* factor[..., 0, None, None]
328-
)
324+
I = self.linearI(I) * factor[..., 0]
329325
A = (
330326
self.linearA(A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
331327
* factor[..., 1, None, None]
@@ -334,7 +330,8 @@ def forward(self, X: Tensor, factor: Optional[Tensor] = None) -> Tensor:
334330
self.linearS(S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
335331
* factor[..., 2, None, None]
336332
)
337-
dX = I + A + S
333+
dX = A + S
334+
dX.diagonal(dim1=-2, dim2=-1).add_(I.unsqueeze(-1))
338335
return dX
339336

340337

@@ -490,10 +487,11 @@ def forward(
490487
@nvtx_annotate("compute_tensor_edge_features")
491488
def compute_tensor_edge_features(X, edge_index, factor):
492489
I, A, S = decompose_tensor(X)
493-
msg = (
494-
factor[..., 0, None, None] * I.index_select(0, edge_index[1])
495-
+ factor[..., 1, None, None] * A.index_select(0, edge_index[1])
496-
+ factor[..., 2, None, None] * S.index_select(0, edge_index[1])
490+
msg = factor[..., 1, None, None] * A.index_select(0, edge_index[1]) + factor[
491+
..., 2, None, None
492+
] * S.index_select(0, edge_index[1])
493+
msg.diagonal(dim1=-2, dim2=-1).add_(
494+
factor[..., 0, None] * I.index_select(0, edge_index[1]).unsqueeze(-1)
497495
)
498496
return msg
499497

0 commit comments

Comments
 (0)