Skip to content

Commit 48a7de5

Browse files
perf: use contiguous memory stride for edge/angle indices (#4804)
This brings ~10% speedup for the training of DPA3 model with 24 thin layers and dynamic sel (average training time: 0.6891 s/batch vs 0.7635 s/batch, batch size = auto:128). <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Refactor** - Updated the shape and indexing conventions for edge and angle index arrays and tensors across multiple components, standardizing them to column-major format for improved consistency. - Adjusted initialization and handling of index arrays/tensors to match the new conventions in both NumPy and PyTorch implementations. - Updated relevant method calls and internal logic to align with the revised index shapes and access patterns. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1dc5b04 commit 48a7de5

File tree

5 files changed

+40
-38
lines changed

5 files changed

+40
-38
lines changed

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,8 @@ def call(
578578
# n_angle x 1
579579
a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask]
580580
else:
581-
edge_index = angle_index = xp.zeros([1, 3], dtype=nlist.dtype)
581+
edge_index = xp.zeros([2, 1], dtype=nlist.dtype)
582+
angle_index = xp.zeros([3, 1], dtype=nlist.dtype)
582583

583584
# get edge and angle embedding
584585
# nb x nloc x nnei x e_dim [OR] n_edge x e_dim
@@ -622,7 +623,7 @@ def call(
622623
edge_ebd,
623624
h2,
624625
sw,
625-
owner=edge_index[:, 0],
626+
owner=edge_index[0, :],
626627
num_owner=nframes * nloc,
627628
nb=nframes,
628629
nloc=nloc,
@@ -1286,8 +1287,8 @@ def call(
12861287
a_nlist: np.ndarray, # nf x nloc x a_nnei
12871288
a_nlist_mask: np.ndarray, # nf x nloc x a_nnei
12881289
a_sw: np.ndarray, # switch func, nf x nloc x a_nnei
1289-
edge_index: np.ndarray, # n_edge x 2
1290-
angle_index: np.ndarray, # n_angle x 3
1290+
edge_index: np.ndarray, # 2 x n_edge
1291+
angle_index: np.ndarray, # 3 x n_angle
12911292
):
12921293
"""
12931294
Parameters
@@ -1312,12 +1313,12 @@ def call(
13121313
Masks of the neighbor list for angle. real nei 1 otherwise 0
13131314
a_sw : nf x nloc x a_nnei
13141315
Switch function for angle.
1315-
edge_index : Optional for dynamic sel, n_edge x 2
1316+
edge_index : Optional for dynamic sel, 2 x n_edge
13161317
n2e_index : n_edge
13171318
Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
13181319
n_ext2e_index : n_edge
13191320
Broadcast indices from extended node(j) to edge(ij).
1320-
angle_index : Optional for dynamic sel, n_angle x 3
1321+
angle_index : Optional for dynamic sel, 3 x n_angle
13211322
n2a_index : n_angle
13221323
Broadcast indices from extended node(j) to angle(ijk).
13231324
eij2a_index : n_angle
@@ -1362,11 +1363,11 @@ def call(
13621363
assert (n_edge, 3) == h2.shape
13631364
del a_nlist # may be used in the future
13641365

1365-
n2e_index, n_ext2e_index = edge_index[:, 0], edge_index[:, 1]
1366+
n2e_index, n_ext2e_index = edge_index[0, :], edge_index[1, :]
13661367
n2a_index, eij2a_index, eik2a_index = (
1367-
angle_index[:, 0],
1368-
angle_index[:, 1],
1369-
angle_index[:, 2],
1368+
angle_index[0, :],
1369+
angle_index[1, :],
1370+
angle_index[2, :],
13701371
)
13711372

13721373
# nb x nloc x nnei x n_dim [OR] n_edge x n_dim

deepmd/dpmodel/utils/network.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,12 +1036,12 @@ def get_graph_index(
10361036
10371037
Returns
10381038
-------
1039-
edge_index : n_edge x 2
1039+
edge_index : 2 x n_edge
10401040
n2e_index : n_edge
10411041
Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
10421042
n_ext2e_index : n_edge
10431043
Broadcast indices from extended node(j) to edge(ij).
1044-
angle_index : n_angle x 3
1044+
angle_index : 3 x n_angle
10451045
n2a_index : n_angle
10461046
Broadcast indices from extended node(j) to angle(ijk).
10471047
eij2a_index : n_angle
@@ -1111,7 +1111,7 @@ def get_graph_index(
11111111
# n_angle
11121112
eik2a_index = edge_index_ik[a_nlist_mask_3d]
11131113

1114-
edge_index_result = xp.stack([n2e_index, n_ext2e_index], axis=-1)
1115-
angle_index_result = xp.stack([n2a_index, eij2a_index, eik2a_index], axis=-1)
1114+
edge_index_result = xp.stack([n2e_index, n_ext2e_index], axis=0)
1115+
angle_index_result = xp.stack([n2a_index, eij2a_index, eik2a_index], axis=0)
11161116

11171117
return edge_index_result, angle_index_result

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def _cal_hg_dynamic(
370370
# n_edge x e_dim
371371
flat_edge_ebd = flat_edge_ebd * flat_sw.unsqueeze(-1)
372372
# n_edge x 3 x e_dim
373-
flat_h2g2 = (flat_h2[..., None] * flat_edge_ebd[:, None, :]).reshape(
373+
flat_h2g2 = (flat_h2.unsqueeze(-1) * flat_edge_ebd.unsqueeze(-2)).reshape(
374374
-1, 3 * e_dim
375375
)
376376
# nf x nloc x 3 x e_dim
@@ -694,8 +694,8 @@ def forward(
694694
a_nlist: torch.Tensor, # nf x nloc x a_nnei
695695
a_nlist_mask: torch.Tensor, # nf x nloc x a_nnei
696696
a_sw: torch.Tensor, # switch func, nf x nloc x a_nnei
697-
edge_index: torch.Tensor, # n_edge x 2
698-
angle_index: torch.Tensor, # n_angle x 3
697+
edge_index: torch.Tensor, # 2 x n_edge
698+
angle_index: torch.Tensor, # 3 x n_angle
699699
):
700700
"""
701701
Parameters
@@ -720,12 +720,12 @@ def forward(
720720
Masks of the neighbor list for angle. real nei 1 otherwise 0
721721
a_sw : nf x nloc x a_nnei
722722
Switch function for angle.
723-
edge_index : Optional for dynamic sel, n_edge x 2
723+
edge_index : Optional for dynamic sel, 2 x n_edge
724724
n2e_index : n_edge
725725
Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
726726
n_ext2e_index : n_edge
727727
Broadcast indices from extended node(j) to edge(ij).
728-
angle_index : Optional for dynamic sel, n_angle x 3
728+
angle_index : Optional for dynamic sel, 3 x n_angle
729729
n2a_index : n_angle
730730
Broadcast indices from extended node(j) to angle(ijk).
731731
eij2a_index : n_angle
@@ -745,19 +745,21 @@ def forward(
745745
nb, nloc, nnei = nlist.shape
746746
nall = node_ebd_ext.shape[1]
747747
node_ebd = node_ebd_ext[:, :nloc, :]
748-
n_edge = int(nlist_mask.sum().item())
749748
assert (nb, nloc) == node_ebd.shape[:2]
750749
if not self.use_dynamic_sel:
751750
assert (nb, nloc, nnei, 3) == h2.shape
751+
n_edge = None
752752
else:
753-
assert (n_edge, 3) == h2.shape
753+
# n_edge = int(nlist_mask.sum().item())
754+
# assert (n_edge, 3) == h2.shape
755+
n_edge = h2.shape[0]
754756
del a_nlist # may be used in the future
755757

756-
n2e_index, n_ext2e_index = edge_index[:, 0], edge_index[:, 1]
758+
n2e_index, n_ext2e_index = edge_index[0], edge_index[1]
757759
n2a_index, eij2a_index, eik2a_index = (
758-
angle_index[:, 0],
759-
angle_index[:, 1],
760-
angle_index[:, 2],
760+
angle_index[0],
761+
angle_index[1],
762+
angle_index[2],
761763
)
762764

763765
# nb x nloc x nnei x n_dim [OR] n_edge x n_dim
@@ -1026,7 +1028,9 @@ def forward(
10261028
if not self.use_dynamic_sel:
10271029
# nb x nloc x a_nnei x a_nnei x e_dim
10281030
weighted_edge_angle_update = (
1029-
a_sw[..., None, None] * a_sw[..., None, :, None] * edge_angle_update
1031+
a_sw.unsqueeze(-1).unsqueeze(-1)
1032+
* a_sw.unsqueeze(-2).unsqueeze(-1)
1033+
* edge_angle_update
10301034
)
10311035
# nb x nloc x a_nnei x e_dim
10321036
reduced_edge_angle_update = torch.sum(

deepmd/pt/model/descriptor/repflows.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -537,9 +537,8 @@ def forward(
537537
a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask]
538538
else:
539539
# avoid jit assertion
540-
edge_index = angle_index = torch.zeros(
541-
[1, 3], device=nlist.device, dtype=nlist.dtype
542-
)
540+
edge_index = torch.zeros([2, 1], device=nlist.device, dtype=nlist.dtype)
541+
angle_index = torch.zeros([3, 1], device=nlist.device, dtype=nlist.dtype)
543542
# get edge and angle embedding
544543
# nb x nloc x nnei x e_dim [OR] n_edge x e_dim
545544
if not self.edge_init_use_dist:
@@ -646,7 +645,7 @@ def forward(
646645
edge_ebd,
647646
h2,
648647
sw,
649-
owner=edge_index[:, 0],
648+
owner=edge_index[0],
650649
num_owner=nframes * nloc,
651650
nb=nframes,
652651
nloc=nloc,

deepmd/pt/model/network/utils.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,12 @@ def get_graph_index(
7474
7575
Returns
7676
-------
77-
edge_index : n_edge x 2
77+
edge_index : 2 x n_edge
7878
n2e_index : n_edge
7979
Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
8080
n_ext2e_index : n_edge
8181
Broadcast indices from extended node(j) to edge(ij).
82-
angle_index : n_angle x 3
82+
angle_index : 3 x n_angle
8383
n2a_index : n_angle
8484
Broadcast indices from extended node(j) to angle(ijk).
8585
eij2a_index : n_angle
@@ -135,9 +135,7 @@ def get_graph_index(
135135
# n_angle
136136
eik2a_index = edge_index_ik[a_nlist_mask_3d]
137137

138-
return torch.cat(
139-
[n2e_index.unsqueeze(-1), n_ext2e_index.unsqueeze(-1)], dim=-1
140-
), torch.cat(
141-
[n2a_index.unsqueeze(-1), eij2a_index.unsqueeze(-1), eik2a_index.unsqueeze(-1)],
142-
dim=-1,
143-
)
138+
edge_index_result = torch.stack([n2e_index, n_ext2e_index], dim=0)
139+
angle_index_result = torch.stack([n2a_index, eij2a_index, eik2a_index], dim=0)
140+
141+
return edge_index_result, angle_index_result

0 commit comments

Comments
 (0)