Skip to content

Commit 1fd3447

Browse files
committed
t push origin mainMerge branch 'Jonas-Verhellen-graph-induced-lifting'
2 parents 2206228 + cd860bc commit 1fd3447

File tree

5 files changed

+628
-5
lines changed

5 files changed

+628
-5
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
transform_type: 'lifting'
2-
transform_name: "SimplicialVietorisRipsLifting"
2+
transform_name: "SimplicialGraphInducedLifting"
33
complex_dim: 3
44
preserve_edge_attr: False
55
signed: True
6-
distance_threshold: 2.0
76
feature_lifting: ProjectionSum

modules/transforms/data_transform.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@
1717
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
1818
SimplicialCliqueLifting,
1919
)
20-
from modules.transforms.liftings.graph2simplicial.vietoris_rips_lifting import (
21-
SimplicialVietorisRipsLifting,
20+
from modules.transforms.liftings.graph2simplicial.graph_induced_lifting import (
21+
SimplicialGraphInducedLifting,
2222
)
2323

2424
TRANSFORMS = {
2525
# Graph -> Hypergraph
2626
"HypergraphKNNLifting": HypergraphKNNLifting,
2727
# Graph -> Simplicial Complex
2828
"SimplicialCliqueLifting": SimplicialCliqueLifting,
29-
"SimplicialVietorisRipsLifting": SimplicialVietorisRipsLifting,
29+
"SimplicialGraphInducedLifting": SimplicialGraphInducedLifting,
3030
# Graph -> Cell Complex
3131
"CellCycleLifting": CellCycleLifting,
3232
# Feature Liftings
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from itertools import combinations
2+
3+
import networkx as nx
4+
import torch_geometric
5+
from toponetx.classes import SimplicialComplex
6+
7+
from modules.transforms.liftings.graph2simplicial.base import Graph2SimplicialLifting
8+
9+
10+
class SimplicialGraphInducedLifting(Graph2SimplicialLifting):
11+
r"""Lifts graphs to simplicial complex domain by identifying connected subgraphs as simplices.
12+
13+
Parameters
14+
----------
15+
**kwargs : optional
16+
Additional arguments for the class.
17+
"""
18+
19+
def __init__(self, **kwargs):
20+
super().__init__(**kwargs)
21+
22+
def lift_topology(self, data: torch_geometric.data.Data) -> dict:
23+
r"""Lifts the topology of a graph to a simplicial complex by identifying connected subgraphs as simplices.
24+
25+
Parameters
26+
----------
27+
data : torch_geometric.data.Data
28+
The input data to be lifted.
29+
30+
Returns
31+
-------
32+
dict
33+
The lifted topology.
34+
"""
35+
graph = self._generate_graph_from_data(data)
36+
simplicial_complex = SimplicialComplex(graph)
37+
all_nodes = list(graph.nodes)
38+
simplices = [set() for _ in range(2, self.complex_dim + 1)]
39+
40+
for k in range(2, self.complex_dim + 1):
41+
for combination in combinations(all_nodes, k + 1):
42+
subgraph = graph.subgraph(combination)
43+
if nx.is_connected(subgraph):
44+
simplices[k - 2].add(tuple(sorted(combination)))
45+
46+
for set_k_simplices in simplices:
47+
simplicial_complex.add_simplices_from(list(set_k_simplices))
48+
49+
return self._get_lifted_topology(simplicial_complex, graph)
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
"""Test the message passing module."""
2+
3+
import torch
4+
5+
from modules.data.utils.utils import load_manual_graph
6+
from modules.transforms.liftings.graph2simplicial.graph_induced_lifting import (
7+
SimplicialGraphInducedLifting,
8+
)
9+
10+
11+
class TestSimplicialCliqueLifting:
12+
"""Test the SimplicialCliqueLifting class."""
13+
14+
def setup_method(self):
15+
# Load the graph
16+
self.data = load_manual_graph()
17+
18+
# Initialise the SimplicialCliqueLifting class
19+
self.lifting_signed = SimplicialGraphInducedLifting(complex_dim=3, signed=True)
20+
self.lifting_unsigned = SimplicialGraphInducedLifting(
21+
complex_dim=3, signed=False
22+
)
23+
24+
def test_lift_topology(self):
25+
"""Test the lift_topology method."""
26+
27+
# Test the lift_topology method
28+
lifted_data_signed = self.lifting_signed.forward(self.data.clone())
29+
lifted_data_unsigned = self.lifting_unsigned.forward(self.data.clone())
30+
31+
expected_incidence_1_singular_values_unsigned = torch.tensor(
32+
[3.7417, 2.4495, 2.4495, 2.4495, 2.4495, 2.4495, 2.4495, 2.4495]
33+
)
34+
35+
expected_incidence_1_singular_values_signed = torch.tensor(
36+
[
37+
2.8284e00,
38+
2.8284e00,
39+
2.8284e00,
40+
2.8284e00,
41+
2.8284e00,
42+
2.8284e00,
43+
2.8284e00,
44+
6.8993e-08,
45+
]
46+
)
47+
48+
U, S_unsigned, V = torch.svd(lifted_data_unsigned.incidence_1.to_dense())
49+
U, S_signed, V = torch.svd(lifted_data_signed.incidence_1.to_dense())
50+
51+
assert torch.allclose(
52+
expected_incidence_1_singular_values_unsigned, S_unsigned, atol=1.0e-04
53+
), "Something is wrong with unsigned incidence_1 (nodes to edges)."
54+
assert torch.allclose(
55+
expected_incidence_1_singular_values_signed, S_signed, atol=1.0e-04
56+
), "Something is wrong with signed incidence_1 (nodes to edges)."
57+
58+
expected_incidence_2_singular_values_unsigned = torch.tensor(
59+
[
60+
4.1190,
61+
3.1623,
62+
3.1623,
63+
3.1623,
64+
3.0961,
65+
3.0000,
66+
3.0000,
67+
2.7564,
68+
2.0000,
69+
2.0000,
70+
2.0000,
71+
2.0000,
72+
2.0000,
73+
2.0000,
74+
2.0000,
75+
2.0000,
76+
2.0000,
77+
2.0000,
78+
2.0000,
79+
2.0000,
80+
2.0000,
81+
2.0000,
82+
2.0000,
83+
1.7321,
84+
1.6350,
85+
1.4142,
86+
1.4142,
87+
1.0849,
88+
]
89+
)
90+
91+
expected_incidence_2_singular_values_signed = torch.tensor(
92+
[
93+
2.8284e00,
94+
2.8284e00,
95+
2.8284e00,
96+
2.8284e00,
97+
2.8284e00,
98+
2.8284e00,
99+
2.8284e00,
100+
2.8284e00,
101+
2.8284e00,
102+
2.8284e00,
103+
2.8284e00,
104+
2.8284e00,
105+
2.8284e00,
106+
2.8284e00,
107+
2.8284e00,
108+
2.8284e00,
109+
2.6458e00,
110+
2.6458e00,
111+
2.2361e00,
112+
1.7321e00,
113+
1.7321e00,
114+
9.3758e-07,
115+
4.7145e-07,
116+
4.3417e-07,
117+
4.0241e-07,
118+
3.1333e-07,
119+
2.2512e-07,
120+
1.9160e-07,
121+
]
122+
)
123+
124+
U, S_unsigned, V = torch.svd(lifted_data_unsigned.incidence_2.to_dense())
125+
U, S_signed, V = torch.svd(lifted_data_signed.incidence_2.to_dense())
126+
assert torch.allclose(
127+
expected_incidence_2_singular_values_unsigned, S_unsigned, atol=1.0e-04
128+
), "Something is wrong with unsigned incidence_2 (edges to triangles)."
129+
assert torch.allclose(
130+
expected_incidence_2_singular_values_signed, S_signed, atol=1.0e-04
131+
), "Something is wrong with signed incidence_2 (edges to triangles)."
132+
133+
expected_incidence_3_singular_values_unsigned = torch.tensor(
134+
[
135+
3.8466,
136+
3.1379,
137+
3.0614,
138+
2.8749,
139+
2.8392,
140+
2.8125,
141+
2.5726,
142+
2.3709,
143+
2.2858,
144+
2.2369,
145+
2.1823,
146+
2.0724,
147+
2.0000,
148+
2.0000,
149+
2.0000,
150+
1.8937,
151+
1.7814,
152+
1.7321,
153+
1.7256,
154+
1.5469,
155+
1.5340,
156+
1.4834,
157+
1.4519,
158+
1.4359,
159+
1.4142,
160+
1.0525,
161+
1.0000,
162+
1.0000,
163+
1.0000,
164+
1.0000,
165+
0.9837,
166+
0.9462,
167+
0.8853,
168+
0.7850,
169+
]
170+
)
171+
172+
expected_incidence_3_singular_values_signed = torch.tensor(
173+
[
174+
2.8284e00,
175+
2.8284e00,
176+
2.8284e00,
177+
2.8284e00,
178+
2.8284e00,
179+
2.8284e00,
180+
2.8284e00,
181+
2.8284e00,
182+
2.8284e00,
183+
2.6933e00,
184+
2.6458e00,
185+
2.6458e00,
186+
2.6280e00,
187+
2.4495e00,
188+
2.3040e00,
189+
1.9475e00,
190+
1.7321e00,
191+
1.7321e00,
192+
1.7321e00,
193+
1.4823e00,
194+
1.0000e00,
195+
1.0000e00,
196+
1.0000e00,
197+
1.0000e00,
198+
1.0000e00,
199+
1.0000e00,
200+
1.0000e00,
201+
1.0000e00,
202+
1.0000e00,
203+
7.3584e-01,
204+
2.7959e-07,
205+
2.1776e-07,
206+
1.4498e-07,
207+
5.5373e-08,
208+
]
209+
)
210+
211+
U, S_unsigned, V = torch.svd(lifted_data_unsigned.incidence_3.to_dense())
212+
U, S_signed, V = torch.svd(lifted_data_signed.incidence_3.to_dense())
213+
214+
assert torch.allclose(
215+
expected_incidence_3_singular_values_unsigned, S_unsigned, atol=1.0e-04
216+
), "Something is wrong with unsigned incidence_3 (triangles to tetrahedrons)."
217+
assert torch.allclose(
218+
expected_incidence_3_singular_values_signed, S_signed, atol=1.0e-04
219+
), "Something is wrong with signed incidence_3 (triangles to tetrahedrons)."

0 commit comments

Comments
 (0)