2
2
import torch_geometric
3
3
from toponetx .classes import SimplicialComplex
4
4
5
- from modules .transforms .liftings .graph2simplicial .base import Graph2SimplicialLifting
5
+ from modules .transforms .liftings .graph2simplicial .base import (
6
+ Graph2SimplicialLifting ,
7
+ )
6
8
7
9
8
10
class SimplicialLineLifting (Graph2SimplicialLifting ):
@@ -47,11 +49,15 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict:
47
49
simplices = list (cliques ) # list(map(lambda x: set(x), cliques))
48
50
49
51
# we need to rename simplices here since now vertices are named as pairs
50
- self .rename_vertices_dict = {node : i for i , node in enumerate (line_graph .nodes )}
52
+ self .rename_vertices_dict = {
53
+ node : i for i , node in enumerate (line_graph .nodes )
54
+ }
51
55
self .rename_vertices_dict_inverse = {
52
56
i : node for node , i in self .rename_vertices_dict .items ()
53
57
}
54
- renamed_line_graph = nx .relabel_nodes (line_graph , self .rename_vertices_dict )
58
+ renamed_line_graph = nx .relabel_nodes (
59
+ line_graph , self .rename_vertices_dict
60
+ )
55
61
56
62
renamed_simplices = [
57
63
{self .rename_vertices_dict [vertex ] for vertex in simplex }
@@ -70,4 +76,6 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict:
70
76
renamed_node_features , name = "features"
71
77
)
72
78
73
- return self ._get_lifted_topology (simplicial_complex , renamed_line_graph )
79
+ return self ._get_lifted_topology (
80
+ simplicial_complex , renamed_line_graph
81
+ )
0 commit comments