11
11
rbf_class_mapping ,
12
12
act_class_mapping ,
13
13
MLP ,
14
+ nvtx_annotate ,
15
+ nvtx_range ,
14
16
)
15
17
16
18
__all__ = ["TensorNet" ]
17
- torch .set_float32_matmul_precision ("high " )
19
+ torch .set_float32_matmul_precision ("medium " )
18
20
torch .backends .cuda .matmul .allow_tf32 = True
19
21
20
22
23
+ @nvtx_annotate ("vector_to_skewtensor" )
21
24
def vector_to_skewtensor (vector ):
22
25
"""Creates a skew-symmetric tensor from a vector."""
23
- batch_size = vector .size ( 0 )
26
+ batch_size = vector .shape [: - 1 ]
24
27
zero = torch .zeros (batch_size , device = vector .device , dtype = vector .dtype )
25
28
tensor = torch .stack (
26
29
(
27
30
zero ,
28
- - vector [: , 2 ],
29
- vector [: , 1 ],
30
- vector [: , 2 ],
31
+ - vector [... , 2 ],
32
+ vector [... , 1 ],
33
+ vector [... , 2 ],
31
34
zero ,
32
- - vector [: , 0 ],
33
- - vector [: , 1 ],
34
- vector [: , 0 ],
35
+ - vector [... , 0 ],
36
+ - vector [... , 1 ],
37
+ vector [... , 0 ],
35
38
zero ,
36
39
),
37
- dim = 1 ,
40
+ dim = - 1 ,
38
41
)
39
- tensor = tensor .view (- 1 , 3 , 3 )
42
+ tensor = tensor .view (* batch_size , 3 , 3 )
40
43
return tensor .squeeze (0 )
41
44
42
45
46
+ @nvtx_annotate ("vector_to_symtensor" )
43
47
def vector_to_symtensor (vector ):
44
48
"""Creates a symmetric traceless tensor from the outer product of a vector with itself."""
45
49
tensor = torch .matmul (vector .unsqueeze (- 1 ), vector .unsqueeze (- 2 ))
@@ -50,6 +54,7 @@ def vector_to_symtensor(vector):
50
54
return S
51
55
52
56
57
+ @nvtx_annotate ("decompose_tensor" )
53
58
def decompose_tensor (tensor ):
54
59
"""Full tensor decomposition into irreducible components."""
55
60
I = (tensor .diagonal (offset = 0 , dim1 = - 1 , dim2 = - 2 )).mean (- 1 )[
@@ -60,6 +65,7 @@ def decompose_tensor(tensor):
60
65
return I , A , S
61
66
62
67
68
+ @nvtx_annotate ("tensor_norm" )
63
69
def tensor_norm (tensor ):
64
70
"""Computes Frobenius norm."""
65
71
return (tensor ** 2 ).sum ((- 2 , - 1 ))
@@ -220,6 +226,7 @@ def reset_parameters(self):
220
226
self .linear .reset_parameters ()
221
227
self .out_norm .reset_parameters ()
222
228
229
+ @nvtx_annotate ("make_static" )
223
230
def _make_static (
224
231
self , num_nodes : int , edge_index : Tensor , edge_weight : Tensor , edge_vec : Tensor
225
232
) -> Tuple [Tensor , Tensor , Tensor ]:
@@ -235,6 +242,7 @@ def _make_static(
235
242
)
236
243
return edge_index , edge_weight , edge_vec
237
244
245
+ @nvtx_annotate ("compute_neighbors" )
238
246
def _compute_neighbors (
239
247
self , pos : Tensor , batch : Tensor , box : Optional [Tensor ]
240
248
) -> Tuple [Tensor , Tensor , Tensor ]:
@@ -248,6 +256,7 @@ def _compute_neighbors(
248
256
)
249
257
return edge_index , edge_weight , edge_vec
250
258
259
+ @nvtx_annotate ("output" )
251
260
def output (self , X : Tensor ) -> Tensor :
252
261
I , A , S = decompose_tensor (X ) # shape: (n_atoms, hidden_channels, 3, 3)
253
262
x = torch .cat (
@@ -257,6 +266,7 @@ def output(self, X: Tensor) -> Tensor:
257
266
x = self .act (self .linear ((x ))) # shape: (n_atoms, hidden_channels)
258
267
return x
259
268
269
+ @nvtx_annotate ("TensorNet" )
260
270
def forward (
261
271
self ,
262
272
z : Tensor ,
@@ -303,6 +313,7 @@ def reset_parameters(self):
303
313
self .linearA .reset_parameters ()
304
314
self .linearS .reset_parameters ()
305
315
316
+ @nvtx_annotate ("TensorLinear" )
306
317
def forward (self , X : Tensor , factor : Optional [Tensor ] = None ) -> Tensor :
307
318
if factor is None :
308
319
factor = (
@@ -363,6 +374,8 @@ def __init__(
363
374
nn .Linear (2 * hidden_channels , 3 * hidden_channels , bias = True , dtype = dtype )
364
375
)
365
376
self .init_norm = nn .LayerNorm (hidden_channels , dtype = dtype )
377
+ self .num_rbf = num_rbf
378
+ self .hidden_channels = hidden_channels
366
379
self .reset_parameters ()
367
380
368
381
def reset_parameters (self ):
@@ -376,6 +389,7 @@ def reset_parameters(self):
376
389
linear .reset_parameters ()
377
390
self .init_norm .reset_parameters ()
378
391
392
+ @nvtx_annotate ("normalize_edges" )
379
393
def _normalize_edges (
380
394
self , edge_index : Tensor , edge_weight : Tensor , edge_vec : Tensor
381
395
) -> Tensor :
@@ -385,16 +399,18 @@ def _normalize_edges(
385
399
edge_vec = edge_vec / edge_weight .masked_fill (mask , 1 ).unsqueeze (1 )
386
400
return edge_vec
387
401
402
+ @nvtx_annotate ("compute_edge_atomic_features" )
388
403
def _compute_edge_atomic_features (self , z : Tensor , edge_index : Tensor ) -> Tensor :
389
404
Z = self .emb (z )
390
405
Zij = self .emb2 (
391
406
Z .index_select (0 , edge_index .t ().reshape (- 1 )).view (
392
407
- 1 , self .hidden_channels * 2
393
408
)
394
- )[..., None , None ]
409
+ )
395
410
return Zij
396
411
397
- def _compute_edge_tensor_features (
412
+ @nvtx_annotate ("compute_edge_tensor_features" )
413
+ def _compute_node_tensor_features (
398
414
self ,
399
415
z : Tensor ,
400
416
edge_index ,
@@ -405,44 +421,53 @@ def _compute_edge_tensor_features(
405
421
edge_vec_norm = self ._normalize_edges (
406
422
edge_index , edge_weight , edge_vec
407
423
) # shape: (n_edges, 3)
408
- Zij = self ._compute_edge_atomic_features (
424
+ Zij = self .cutoff ( edge_weight )[:, None ] * self . _compute_edge_atomic_features (
409
425
z , edge_index
410
426
) # shape: (n_edges, hidden_channels)
411
- C = self .cutoff (edge_weight ).reshape (- 1 , 1 , 1 , 1 ) * Zij
412
- Iij = self .distance_proj1 (edge_attr )
413
- Aij = (
414
- self .distance_proj2 (edge_attr )[..., None , None ]
415
- * vector_to_skewtensor (edge_vec_norm )[..., None , :, :]
416
- )
417
- Sij = (
427
+
428
+ A = (
429
+ self .distance_proj2 (edge_attr )[
430
+ ..., None
431
+ ] # shape: (n_edges, hidden_channels, 1)
432
+ * Zij [..., None ] # shape: (n_edges, hidden_channels, 1)
433
+ * edge_vec_norm [:, None , :] # shape: (n_edges, 1, 3)
434
+ ) # shape: (n_edges, hidden_channels, 3)
435
+ A = self ._aggregate_edge_features (
436
+ z .shape [0 ], A , edge_index [0 ]
437
+ ) # shape: (n_atoms, hidden_channels, 3)
438
+ A = vector_to_skewtensor (A ) # shape: (n_atoms, hidden_channels, 3, 3)
439
+
440
+ S = (
418
441
self .distance_proj3 (edge_attr )[..., None , None ]
442
+ * Zij [..., None , None ]
419
443
* vector_to_symtensor (edge_vec_norm )[..., None , :, :]
420
- )
421
- features = Aij + Sij
422
- features .diagonal (dim1 = - 2 , dim2 = - 1 ).add_ (Iij .unsqueeze (- 1 ))
423
- return features * C
444
+ ) # shape: (n_edges, hidden_channels, 3, 3)
445
+ S = self ._aggregate_edge_features (
446
+ z .shape [0 ], S , edge_index [0 ]
447
+ ) # shape: (n_atoms, hidden_channels, 3, 3)
448
+ I = self .distance_proj1 (edge_attr ) * Zij
449
+ I = self ._aggregate_edge_features (z .shape [0 ], I , edge_index [0 ])
450
+ features = A + S
451
+ features .diagonal (dim1 = - 2 , dim2 = - 1 ).add_ (I .unsqueeze (- 1 ))
452
+ return features
424
453
454
+ @nvtx_annotate ("aggregate_edge_features" )
425
455
def _aggregate_edge_features (
426
- self , num_atoms : int , X : Tensor , edge_index : Tensor
456
+ self , num_atoms : int , T : Tensor , source_indices : Tensor
427
457
) -> Tensor :
428
- Xij = torch .zeros (
429
- num_atoms ,
430
- self .hidden_channels ,
431
- 3 ,
432
- 3 ,
433
- device = X .device ,
434
- dtype = X .dtype ,
435
- )
436
- Xij = Xij .index_add (0 , edge_index [0 ], source = X )
437
- return Xij
458
+ targetI = torch .zeros (num_atoms , * T .shape [1 :], device = T .device , dtype = T .dtype )
459
+ I = targetI .index_add (dim = 0 , index = source_indices , source = T )
460
+ return I
438
461
462
+ @nvtx_annotate ("norm_mlp" )
439
463
def _norm_mlp (self , norm ):
440
464
norm = self .init_norm (norm )
441
465
for linear_scalar in self .linears_scalar :
442
466
norm = self .act (linear_scalar (norm ))
443
467
norm = norm .reshape (- 1 , self .hidden_channels , 3 )
444
468
return norm
445
469
470
+ @nvtx_annotate ("TensorEmbedding" )
446
471
def forward (
447
472
self ,
448
473
z : Tensor ,
@@ -451,17 +476,18 @@ def forward(
451
476
edge_vec : Tensor ,
452
477
edge_attr : Tensor ,
453
478
) -> Tensor :
454
- Xij = self ._compute_edge_tensor_features (
479
+ X = self ._compute_node_tensor_features (
455
480
z , edge_index , edge_weight , edge_vec , edge_attr
456
- ) # shape: (n_edges, hidden_channels, 3, 3)
457
- X = self ._aggregate_edge_features (
458
- z .shape [0 ], Xij , edge_index
459
481
) # shape: (n_atoms, hidden_channels, 3, 3)
482
+ # X = self._aggregate_edge_features(
483
+ # z.shape[0], Xij, edge_index
484
+ # ) # shape: (n_atoms, hidden_channels, 3, 3)
460
485
norm = self ._norm_mlp (tensor_norm (X )) # shape: (n_atoms, hidden_channels)
461
486
X = self .linear_tensor (X , norm ) # shape: (n_atoms, hidden_channels, 3, 3)
462
487
return X
463
488
464
489
490
+ @nvtx_annotate ("compute_tensor_edge_features" )
465
491
def compute_tensor_edge_features (X , edge_index , factor ):
466
492
I , A , S = decompose_tensor (X )
467
493
msg = (
@@ -472,6 +498,7 @@ def compute_tensor_edge_features(X, edge_index, factor):
472
498
return msg
473
499
474
500
501
+ @nvtx_annotate ("tensor_message_passing" )
475
502
def tensor_message_passing (n_atoms : int , edge_index : Tensor , tensor : Tensor ) -> Tensor :
476
503
msg = tensor .index_select (
477
504
0 , edge_index [1 ]
@@ -528,6 +555,7 @@ def reset_parameters(self):
528
555
self .tensor_linear_in .reset_parameters ()
529
556
self .tensor_linear_out .reset_parameters ()
530
557
558
+ @nvtx_annotate ("update_tensor_node_features" )
531
559
def _update_tensor_node_features (self , X , X_aggregated ):
532
560
X = self .tensor_linear_in (X )
533
561
B = torch .matmul (X , X_aggregated )
@@ -540,6 +568,7 @@ def _update_tensor_node_features(self, X, X_aggregated):
540
568
Xnew = A + B
541
569
return Xnew
542
570
571
+ @nvtx_annotate ("compute_vector_node_features" )
543
572
def _compute_vector_node_features (self , edge_attr , edge_weight ):
544
573
C = self .cutoff (edge_weight )
545
574
for linear_scalar in self .linears_scalar :
@@ -549,6 +578,7 @@ def _compute_vector_node_features(self, edge_attr, edge_weight):
549
578
)
550
579
return edge_attr
551
580
581
+ @nvtx_annotate ("Interaction" )
552
582
def forward (
553
583
self ,
554
584
X : Tensor ,
@@ -562,7 +592,7 @@ def forward(
562
592
) # shape (n_atoms, hidden_channels, 3, 3)
563
593
node_features = self ._compute_vector_node_features (
564
594
edge_attr , edge_weight
565
- ) # shape (n_atoms , hidden_channels, 3)
595
+ ) # shape (n_edges , hidden_channels, 3)
566
596
Y_edges = compute_tensor_edge_features (
567
597
X , edge_index , node_features
568
598
) # shape (n_edges, hidden_channels, 3, 3)
0 commit comments