Skip to content

Commit 21e1580

Browse files
committed
Merge branch 'Snopt push origin mainoff-line'
2 parents d8ef506 + 2e82372 commit 21e1580

File tree

6 files changed

+548
-8
lines changed

6 files changed

+548
-8
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
transform_type: "lifting"
2+
transform_name: "SimplicialLineLifting"
3+
feature_lifting: ProjectionSum

modules/data/utils/utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,16 @@ def get_complex_connectivity(complex, max_rank, signed=False):
5050
)
5151
except ValueError: # noqa: PERF203
5252
if connectivity_info == "incidence":
53-
connectivity[f"{connectivity_info}_{rank_idx}"] = (
54-
generate_zero_sparse_connectivity(
55-
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
56-
)
53+
connectivity[
54+
f"{connectivity_info}_{rank_idx}"
55+
] = generate_zero_sparse_connectivity(
56+
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
5757
)
5858
else:
59-
connectivity[f"{connectivity_info}_{rank_idx}"] = (
60-
generate_zero_sparse_connectivity(
61-
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
62-
)
59+
connectivity[
60+
f"{connectivity_info}_{rank_idx}"
61+
] = generate_zero_sparse_connectivity(
62+
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
6363
)
6464
connectivity["shape"] = practical_shape
6565
return connectivity

modules/transforms/data_transform.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
from modules.transforms.liftings.graph2simplicial.graph_induced_lifting import (
2121
SimplicialGraphInducedLifting,
2222
)
23+
from modules.transforms.liftings.graph2simplicial.line_lifting import (
24+
SimplicialLineLifting,
25+
)
2326
from modules.transforms.liftings.graph2simplicial.vietoris_rips_lifting import (
2427
SimplicialVietorisRipsLifting,
2528
)
@@ -29,6 +32,7 @@
2932
"HypergraphKNNLifting": HypergraphKNNLifting,
3033
# Graph -> Simplicial Complex
3134
"SimplicialCliqueLifting": SimplicialCliqueLifting,
35+
"SimplicialLineLifting": SimplicialLineLifting,
3236
"SimplicialVietorisRipsLifting": SimplicialVietorisRipsLifting,
3337
"SimplicialGraphInducedLifting": SimplicialGraphInducedLifting,
3438
# Graph -> Cell Complex
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import networkx as nx
2+
import torch_geometric
3+
from toponetx.classes import SimplicialComplex
4+
5+
from modules.transforms.liftings.graph2simplicial.base import (
6+
Graph2SimplicialLifting,
7+
)
8+
9+
10+
class SimplicialLineLifting(Graph2SimplicialLifting):
11+
r"""Lifts graphs to a simplicial complex domain by considering line simplicial complex.
12+
13+
Line simplicial complex is a clique complex of the line graph. Line graph is a graph, in which
14+
the vertices are the edges in the initial graph, and two vertices are adjacent if the corresponding
15+
edges are adjacent in the initial graph.
16+
17+
Parameters
18+
----------
19+
**kwargs : optional
20+
Additional arguments for the class.
21+
"""
22+
23+
def __init__(self, **kwargs):
24+
super().__init__(**kwargs)
25+
26+
def lift_topology(self, data: torch_geometric.data.Data) -> dict:
27+
r"""Lifts the topology of a graph to simplicial domain via line simplicial complex construction.
28+
29+
Parameters
30+
----------
31+
data : torch_geometric.data.Data
32+
The input data to be lifted.
33+
34+
Returns
35+
----------
36+
dict
37+
The lifted topology.
38+
"""
39+
40+
graph = self._generate_graph_from_data(data)
41+
line_graph = nx.line_graph(graph)
42+
43+
node_features = {
44+
node: ((data.x[node[0], :] + data.x[node[1], :]) / 2)
45+
for node in list(line_graph.nodes)
46+
}
47+
48+
cliques = nx.find_cliques(line_graph)
49+
simplices = list(cliques) # list(map(lambda x: set(x), cliques))
50+
51+
# we need to rename simplices here since now vertices are named as pairs
52+
self.rename_vertices_dict = {
53+
node: i for i, node in enumerate(line_graph.nodes)
54+
}
55+
self.rename_vertices_dict_inverse = {
56+
i: node for node, i in self.rename_vertices_dict.items()
57+
}
58+
renamed_line_graph = nx.relabel_nodes(
59+
line_graph, self.rename_vertices_dict
60+
)
61+
62+
renamed_simplices = [
63+
{self.rename_vertices_dict[vertex] for vertex in simplex}
64+
for simplex in simplices
65+
]
66+
67+
renamed_node_features = {
68+
self.rename_vertices_dict[node]: value
69+
for node, value in node_features.items()
70+
}
71+
72+
simplicial_complex = SimplicialComplex(simplices=renamed_simplices)
73+
self.complex_dim = simplicial_complex.dim
74+
75+
simplicial_complex.set_simplex_attributes(
76+
renamed_node_features, name="features"
77+
)
78+
79+
return self._get_lifted_topology(
80+
simplicial_complex, renamed_line_graph
81+
)
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""Test the message passing module."""
2+
3+
import torch
4+
import torch_geometric
5+
6+
from modules.transforms.liftings.graph2simplicial.line_lifting import (
7+
SimplicialLineLifting,
8+
)
9+
10+
11+
def create_test_graph():
12+
num_nodes = 5
13+
x = [1] * num_nodes
14+
edge_index = torch.tensor(
15+
[[0, 1, 2, 3, 4, 1, 2], [1, 2, 3, 4, 0, 4, 3]], dtype=torch.long
16+
) # [[0, 0, 1, 1, 2, 2, 3], [1, 4, 2, 3, 3, 4, 4]]
17+
y = [0, 0, 1, 1, 0]
18+
19+
return torch_geometric.data.Data(
20+
x=torch.tensor(x).float().reshape(-1, 1),
21+
edge_index=edge_index, # torch.Tensor(edge_index, dtype=torch.long),
22+
num_nodes=num_nodes,
23+
y=torch.tensor(y),
24+
)
25+
26+
27+
class TestSimplicialLineLifting:
28+
"""Test the SimplicialLineLifting class."""
29+
30+
def setup_method(self):
31+
# Load the graph
32+
self.data = create_test_graph() # load_manual_graph()
33+
34+
# Initialise the SimplicialCliqueLifting class
35+
self.lifting_signed = SimplicialLineLifting(signed=True)
36+
self.lifting_unsigned = SimplicialLineLifting(signed=False)
37+
38+
def test_lift_topology(self):
39+
"""Test the lift_topology method."""
40+
41+
# Test the lift_topology method
42+
lifted_data_signed = self.lifting_signed.forward(self.data.clone())
43+
lifted_data_unsigned = self.lifting_unsigned.forward(self.data.clone())
44+
45+
expected_incidence_1 = torch.tensor(
46+
[
47+
[-1.0, -1.0, -1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
48+
[1.0, 0.0, 0.0, 0.0, -1.0, -1.0, 0.0, 0.0, 0.0],
49+
[0.0, 1.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0],
50+
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, -1.0, -1.0],
51+
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0],
52+
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0],
53+
]
54+
)
55+
56+
print(lifted_data_signed.incidence_1.to_dense())
57+
58+
assert (
59+
abs(expected_incidence_1) == lifted_data_unsigned.incidence_1.to_dense()
60+
).all(), "Something is wrong with unsigned incidence_1 (nodes to edges)."
61+
assert (
62+
expected_incidence_1 == lifted_data_signed.incidence_1.to_dense()
63+
).all(), "Something is wrong with signed incidence_1 (nodes to edges)."
64+
65+
expected_incidence_2 = torch.tensor(
66+
[
67+
[1.0, 0.0, 0.0],
68+
[-1.0, 1.0, 0.0],
69+
[0.0, 0.0, 1.0],
70+
[0.0, -1.0, -1.0],
71+
[1.0, 0.0, 0.0],
72+
[0.0, 0.0, 0.0],
73+
[0.0, 1.0, 0.0],
74+
[0.0, 0.0, 0.0],
75+
[0.0, 0.0, 1.0],
76+
]
77+
)
78+
79+
assert (
80+
abs(expected_incidence_2) == lifted_data_unsigned.incidence_2.to_dense()
81+
).all(), "Something is wrong with unsigned incidence_2 (edges to triangles)."
82+
assert (
83+
expected_incidence_2 == lifted_data_signed.incidence_2.to_dense()
84+
).all(), "Something is wrong with signed incidence_2 (edges to triangles)."

tutorials/graph2simplicial/line_lifting.ipynb

Lines changed: 368 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)