
Commit ada1192

Updates to TensorEmbedding MP
1 parent 6ea18d1 commit ada1192

File tree

1 file changed: +70 -40 lines changed

torchmdnet/models/tensornet.py

Lines changed: 70 additions & 40 deletions
@@ -11,35 +11,39 @@
     rbf_class_mapping,
     act_class_mapping,
     MLP,
+    nvtx_annotate,
+    nvtx_range,
 )

 __all__ = ["TensorNet"]
-torch.set_float32_matmul_precision("high")
+torch.set_float32_matmul_precision("medium")
 torch.backends.cuda.matmul.allow_tf32 = True


+@nvtx_annotate("vector_to_skewtensor")
 def vector_to_skewtensor(vector):
     """Creates a skew-symmetric tensor from a vector."""
-    batch_size = vector.size(0)
+    batch_size = vector.shape[:-1]
     zero = torch.zeros(batch_size, device=vector.device, dtype=vector.dtype)
     tensor = torch.stack(
         (
             zero,
-            -vector[:, 2],
-            vector[:, 1],
-            vector[:, 2],
+            -vector[..., 2],
+            vector[..., 1],
+            vector[..., 2],
             zero,
-            -vector[:, 0],
-            -vector[:, 1],
-            vector[:, 0],
+            -vector[..., 0],
+            -vector[..., 1],
+            vector[..., 0],
             zero,
         ),
-        dim=1,
+        dim=-1,
     )
-    tensor = tensor.view(-1, 3, 3)
+    tensor = tensor.view(*batch_size, 3, 3)
     return tensor.squeeze(0)


+@nvtx_annotate("vector_to_symtensor")
 def vector_to_symtensor(vector):
     """Creates a symmetric traceless tensor from the outer product of a vector with itself."""
     tensor = torch.matmul(vector.unsqueeze(-1), vector.unsqueeze(-2))
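Two performance-oriented changes land in this hunk. First, the float32 matmul precision drops from "high" to "medium", which lets PyTorch use bfloat16 internal math for fp32 matrix multiplications. Second, hot functions gain @nvtx_annotate decorators so they show up as named ranges in Nsight Systems timelines; the hunk also generalizes vector_to_skewtensor to arbitrary leading batch dimensions (vector[:, i] becomes vector[..., i]), whose payoff appears in the TensorEmbedding hunks below. The real decorator (and nvtx_range, presumably its context-manager counterpart) is imported from torchmdnet's utilities and not shown in this diff; a minimal sketch of an equivalent could look like:

import functools
import torch

def nvtx_annotate(name):
    # Hypothetical stand-in for torchmdnet's nvtx_annotate: bracket the call
    # with a named NVTX range so it appears as a labeled span in profiles.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            torch.cuda.nvtx.range_push(name)
            try:
                return fn(*args, **kwargs)
            finally:
                torch.cuda.nvtx.range_pop()
        return wrapper
    return decorator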
@@ -50,6 +54,7 @@ def vector_to_symtensor(vector):
     return S


+@nvtx_annotate("decompose_tensor")
 def decompose_tensor(tensor):
     """Full tensor decomposition into irreducible components."""
     I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)[
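The hunk truncates decompose_tensor, but the decomposition it names is standard: a 3x3 tensor X splits into an isotropic part I, an antisymmetric part A, and a symmetric traceless part S with X = I + A + S. A self-contained sketch of that identity (not the file's exact code):

import torch

X = torch.randn(4, 16, 3, 3)  # (n_atoms, hidden_channels, 3, 3)
I = X.diagonal(dim1=-2, dim2=-1).mean(-1)[..., None, None] * torch.eye(
    3, device=X.device, dtype=X.dtype
)
A = 0.5 * (X - X.transpose(-2, -1))      # antisymmetric part
S = 0.5 * (X + X.transpose(-2, -1)) - I  # symmetric traceless part
assert torch.allclose(I + A + S, X, atol=1e-6)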
@@ -60,6 +65,7 @@ def decompose_tensor(tensor):
     return I, A, S


+@nvtx_annotate("tensor_norm")
 def tensor_norm(tensor):
     """Computes Frobenius norm."""
     return (tensor**2).sum((-2, -1))
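One reading note for the shape comments elsewhere in the file: despite its docstring, tensor_norm returns the squared Frobenius norm (no square root is taken), reduced over the trailing 3x3 dimensions. A quick check against torch.linalg:

import torch

X = torch.randn(4, 16, 3, 3)
sq = (X**2).sum((-2, -1))  # shape: (4, 16)
assert torch.allclose(sq, torch.linalg.matrix_norm(X, ord="fro") ** 2, atol=1e-5)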
@@ -220,6 +226,7 @@ def reset_parameters(self):
         self.linear.reset_parameters()
         self.out_norm.reset_parameters()

+    @nvtx_annotate("make_static")
     def _make_static(
         self, num_nodes: int, edge_index: Tensor, edge_weight: Tensor, edge_vec: Tensor
     ) -> Tuple[Tensor, Tensor, Tensor]:
@@ -235,6 +242,7 @@ def _make_static(
         )
         return edge_index, edge_weight, edge_vec

+    @nvtx_annotate("compute_neighbors")
     def _compute_neighbors(
         self, pos: Tensor, batch: Tensor, box: Optional[Tensor]
     ) -> Tuple[Tensor, Tensor, Tensor]:
@@ -248,6 +256,7 @@ def _compute_neighbors(
         )
         return edge_index, edge_weight, edge_vec

+    @nvtx_annotate("output")
     def output(self, X: Tensor) -> Tensor:
         I, A, S = decompose_tensor(X)  # shape: (n_atoms, hidden_channels, 3, 3)
         x = torch.cat(
@@ -257,6 +266,7 @@ def output(self, X: Tensor) -> Tensor:
         x = self.act(self.linear((x)))  # shape: (n_atoms, hidden_channels)
         return x

+    @nvtx_annotate("TensorNet")
     def forward(
         self,
         z: Tensor,
@@ -303,6 +313,7 @@ def reset_parameters(self):
         self.linearA.reset_parameters()
         self.linearS.reset_parameters()

+    @nvtx_annotate("TensorLinear")
     def forward(self, X: Tensor, factor: Optional[Tensor] = None) -> Tensor:
         if factor is None:
             factor = (
@@ -363,6 +374,8 @@ def __init__(
             nn.Linear(2 * hidden_channels, 3 * hidden_channels, bias=True, dtype=dtype)
         )
         self.init_norm = nn.LayerNorm(hidden_channels, dtype=dtype)
+        self.num_rbf = num_rbf
+        self.hidden_channels = hidden_channels
         self.reset_parameters()

     def reset_parameters(self):
@@ -376,6 +389,7 @@ def reset_parameters(self):
             linear.reset_parameters()
         self.init_norm.reset_parameters()

+    @nvtx_annotate("normalize_edges")
     def _normalize_edges(
         self, edge_index: Tensor, edge_weight: Tensor, edge_vec: Tensor
     ) -> Tensor:
@@ -385,16 +399,18 @@ def _normalize_edges(
         edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1)
         return edge_vec

+    @nvtx_annotate("compute_edge_atomic_features")
     def _compute_edge_atomic_features(self, z: Tensor, edge_index: Tensor) -> Tensor:
         Z = self.emb(z)
         Zij = self.emb2(
             Z.index_select(0, edge_index.t().reshape(-1)).view(
                 -1, self.hidden_channels * 2
             )
-        )[..., None, None]
+        )
         return Zij

-    def _compute_edge_tensor_features(
+    @nvtx_annotate("compute_edge_tensor_features")
+    def _compute_node_tensor_features(
         self,
         z: Tensor,
         edge_index,
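With the trailing [..., None, None] dropped, _compute_edge_atomic_features now returns a flat (n_edges, hidden_channels) Zij; any broadcasting happens at each use site in the next hunk, where Zij is also pre-scaled by the cutoff envelope rather than multiplying the full (n_edges, hidden_channels, 3, 3) features by a C factor at the end. self.cutoff here is TensorNet's CosineCutoff; assuming the usual cosine form with a zero lower cutoff, it behaves roughly like:

import math
import torch

def cosine_cutoff(edge_weight, cutoff_upper=5.0):
    # Decays smoothly from 1 at r=0 to 0 at r=cutoff_upper and is exactly
    # zero beyond it, so edge contributions vanish continuously at the cutoff.
    c = 0.5 * (torch.cos(edge_weight * math.pi / cutoff_upper) + 1.0)
    return c * (edge_weight < cutoff_upper).to(edge_weight.dtype)

(cutoff_upper = 5.0 is only a placeholder value for this sketch.)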
@@ -405,44 +421,53 @@ def _compute_edge_tensor_features(
         edge_vec_norm = self._normalize_edges(
             edge_index, edge_weight, edge_vec
         )  # shape: (n_edges, 3)
-        Zij = self._compute_edge_atomic_features(
+        Zij = self.cutoff(edge_weight)[:, None] * self._compute_edge_atomic_features(
             z, edge_index
         )  # shape: (n_edges, hidden_channels)
-        C = self.cutoff(edge_weight).reshape(-1, 1, 1, 1) * Zij
-        Iij = self.distance_proj1(edge_attr)
-        Aij = (
-            self.distance_proj2(edge_attr)[..., None, None]
-            * vector_to_skewtensor(edge_vec_norm)[..., None, :, :]
-        )
-        Sij = (
+
+        A = (
+            self.distance_proj2(edge_attr)[
+                ..., None
+            ]  # shape: (n_edges, hidden_channels, 1)
+            * Zij[..., None]  # shape: (n_edges, hidden_channels, 1)
+            * edge_vec_norm[:, None, :]  # shape: (n_edges, 1, 3)
+        )  # shape: (n_edges, hidden_channels, 3)
+        A = self._aggregate_edge_features(
+            z.shape[0], A, edge_index[0]
+        )  # shape: (n_atoms, hidden_channels, 3)
+        A = vector_to_skewtensor(A)  # shape: (n_atoms, hidden_channels, 3, 3)
+
+        S = (
             self.distance_proj3(edge_attr)[..., None, None]
+            * Zij[..., None, None]
             * vector_to_symtensor(edge_vec_norm)[..., None, :, :]
-        )
-        features = Aij + Sij
-        features.diagonal(dim1=-2, dim2=-1).add_(Iij.unsqueeze(-1))
-        return features * C
+        )  # shape: (n_edges, hidden_channels, 3, 3)
+        S = self._aggregate_edge_features(
+            z.shape[0], S, edge_index[0]
+        )  # shape: (n_atoms, hidden_channels, 3, 3)
+        I = self.distance_proj1(edge_attr) * Zij
+        I = self._aggregate_edge_features(z.shape[0], I, edge_index[0])
+        features = A + S
+        features.diagonal(dim1=-2, dim2=-1).add_(I.unsqueeze(-1))
+        return features

+    @nvtx_annotate("aggregate_edge_features")
     def _aggregate_edge_features(
-        self, num_atoms: int, X: Tensor, edge_index: Tensor
+        self, num_atoms: int, T: Tensor, source_indices: Tensor
     ) -> Tensor:
-        Xij = torch.zeros(
-            num_atoms,
-            self.hidden_channels,
-            3,
-            3,
-            device=X.device,
-            dtype=X.dtype,
-        )
-        Xij = Xij.index_add(0, edge_index[0], source=X)
-        return Xij
+        targetI = torch.zeros(num_atoms, *T.shape[1:], device=T.device, dtype=T.dtype)
+        I = targetI.index_add(dim=0, index=source_indices, source=T)
+        return I

+    @nvtx_annotate("norm_mlp")
     def _norm_mlp(self, norm):
         norm = self.init_norm(norm)
         for linear_scalar in self.linears_scalar:
             norm = self.act(linear_scalar(norm))
         norm = norm.reshape(-1, self.hidden_channels, 3)
         return norm

+    @nvtx_annotate("TensorEmbedding")
     def forward(
         self,
         z: Tensor,
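The core of this refactor: vector_to_skewtensor is linear in its argument, so the per-edge vectors A can be scatter-added into atoms first and a single skew tensor built per atom, rather than building one (3, 3) skew tensor per edge and aggregating those, as the old _compute_edge_tensor_features did. The symmetric part is quadratic in the edge vector, so S is still built per edge and aggregated afterwards. A quick equivalence check for the antisymmetric path, reusing the batched vector_to_skewtensor from this file with made-up sizes:

import torch

n_edges, n_atoms, hidden = 20, 6, 4
vec = torch.randn(n_edges, hidden, 3)        # per-edge weighted vectors
idx = torch.randint(0, n_atoms, (n_edges,))  # edge_index[0]: receiving atom

# Old path: one skew tensor per edge, then scatter-add into atoms.
old = torch.zeros(n_atoms, hidden, 3, 3).index_add(
    0, idx, vector_to_skewtensor(vec)
)
# New path: scatter-add the vectors, then one skew tensor per atom.
new = vector_to_skewtensor(torch.zeros(n_atoms, hidden, 3).index_add(0, idx, vec))
assert torch.allclose(old, new, atol=1e-5)

This also explains why _aggregate_edge_features was generalized from a hard-coded (num_atoms, hidden_channels, 3, 3) buffer to torch.zeros(num_atoms, *T.shape[1:]): it now aggregates rank-1 vectors (A), rank-2 tensors (S), and scalars (I) alike.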
@@ -451,17 +476,18 @@ def forward(
         edge_vec: Tensor,
         edge_attr: Tensor,
     ) -> Tensor:
-        Xij = self._compute_edge_tensor_features(
+        X = self._compute_node_tensor_features(
             z, edge_index, edge_weight, edge_vec, edge_attr
-        )  # shape: (n_edges, hidden_channels, 3, 3)
-        X = self._aggregate_edge_features(
-            z.shape[0], Xij, edge_index
         )  # shape: (n_atoms, hidden_channels, 3, 3)
+        # X = self._aggregate_edge_features(
+        #     z.shape[0], Xij, edge_index
+        # )  # shape: (n_atoms, hidden_channels, 3, 3)
         norm = self._norm_mlp(tensor_norm(X))  # shape: (n_atoms, hidden_channels)
         X = self.linear_tensor(X, norm)  # shape: (n_atoms, hidden_channels, 3, 3)
         return X


+@nvtx_annotate("compute_tensor_edge_features")
 def compute_tensor_edge_features(X, edge_index, factor):
     I, A, S = decompose_tensor(X)
     msg = (
@@ -472,6 +498,7 @@ def compute_tensor_edge_features(X, edge_index, factor):
     return msg


+@nvtx_annotate("tensor_message_passing")
 def tensor_message_passing(n_atoms: int, edge_index: Tensor, tensor: Tensor) -> Tensor:
     msg = tensor.index_select(
         0, edge_index[1]
@@ -528,6 +555,7 @@ def reset_parameters(self):
         self.tensor_linear_in.reset_parameters()
         self.tensor_linear_out.reset_parameters()

+    @nvtx_annotate("update_tensor_node_features")
     def _update_tensor_node_features(self, X, X_aggregated):
         X = self.tensor_linear_in(X)
         B = torch.matmul(X, X_aggregated)
@@ -540,6 +568,7 @@ def _update_tensor_node_features(self, X, X_aggregated):
         Xnew = A + B
         return Xnew

+    @nvtx_annotate("compute_vector_node_features")
     def _compute_vector_node_features(self, edge_attr, edge_weight):
         C = self.cutoff(edge_weight)
         for linear_scalar in self.linears_scalar:
@@ -549,6 +578,7 @@ def _compute_vector_node_features(self, edge_attr, edge_weight):
         )
         return edge_attr

+    @nvtx_annotate("Interaction")
     def forward(
         self,
         X: Tensor,
@@ -562,7 +592,7 @@ def forward(
         )  # shape (n_atoms, hidden_channels, 3, 3)
         node_features = self._compute_vector_node_features(
             edge_attr, edge_weight
-        )  # shape (n_atoms, hidden_channels, 3)
+        )  # shape (n_edges, hidden_channels, 3)
         Y_edges = compute_tensor_edge_features(
             X, edge_index, node_features
         )  # shape (n_edges, hidden_channels, 3, 3)
