Skip to content

Independent sets lifting (graph to simplicial) #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
transform_type: 'lifting'
transform_name: "SimplicialIndependentSetLifting"
complex_dim: 3
preserve_edge_attr: False
signed: False
feature_lifting: ProjectionSum
4 changes: 4 additions & 0 deletions modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
SimplicialCliqueLifting,
)
from modules.transforms.liftings.graph2simplicial.independent_set_lifting import (
SimplicialIndependentSetsLifting,
)

TRANSFORMS = {
# Graph -> Hypergraph
"HypergraphKNNLifting": HypergraphKNNLifting,
# Graph -> Simplicial Complex
"SimplicialCliqueLifting": SimplicialCliqueLifting,
"SimplicialIndependentSetLifting": SimplicialIndependentSetsLifting,
# Graph -> Cell Complex
"CellCycleLifting": CellCycleLifting,
# Feature Liftings
Expand Down
31 changes: 25 additions & 6 deletions modules/transforms/liftings/graph2simplicial/clique_lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,30 @@ class SimplicialCliqueLifting(Graph2SimplicialLifting):
def __init__(self, **kwargs):
super().__init__(**kwargs)

@staticmethod
def generate_simplices(dim: int, graph: nx.Graph) -> list:
r"""Generates list of simplices from cliques of a given graph

Parameters
----------
dim: int
Maximum dimension of the complex
graph: nx.Graph
Input graph

Returns
-------
list[tuple]
List of simplices
"""
cliques = nx.find_cliques(graph)
simplices = [set() for _ in range(2, dim + 1)]
for clique in cliques:
for i in range(2, dim + 1):
for c in combinations(clique, i + 1):
simplices[i - 2].add(tuple(c))
return simplices

def lift_topology(self, data: torch_geometric.data.Data) -> dict:
r"""Lifts the topology of a graph to a simplicial complex by identifying the cliques as k-simplices.

Expand All @@ -34,13 +58,8 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict:
"""
graph = self._generate_graph_from_data(data)
simplicial_complex = SimplicialComplex(graph)
cliques = nx.find_cliques(graph)
simplices = [set() for _ in range(2, self.complex_dim + 1)]
for clique in cliques:
for i in range(2, self.complex_dim + 1):
for c in combinations(clique, i + 1):
simplices[i - 2].add(tuple(c))

simplices = self.generate_simplices(self.complex_dim, graph)
for set_k_simplices in simplices:
simplicial_complex.add_simplices_from(list(set_k_simplices))

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import networkx as nx
import torch_geometric
from toponetx.classes import SimplicialComplex

from modules.transforms.liftings.graph2simplicial.base import Graph2SimplicialLifting
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
SimplicialCliqueLifting,
)


class SimplicialIndependentSetsLifting(Graph2SimplicialLifting):
r"""Lifts graphs to simplicial complex domain by identifying the independent sets as k-simplices

Parameters
----------
**kwargs : optional
Additional arguments for the class.
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)

def lift_topology(self, data: torch_geometric.data.Data) -> dict:
r"""Lifts the topology of a graph to a simplicial complex by identifying the independent sets as k-simplices

We use the fact that the independent sets of a graph G are the cliques of its complement graph Gc. The nodes and the edges
of the complement graph represent the 0-simplices and 1-simplices respectively.

Parameters
----------
data : torch_geometric.data.Data
The input data to be lifted.

Returns
-------
dict
The lifted topology.
"""
graph = self._generate_graph_from_data(data)
complement_graph = nx.complement(graph)

# Since we lose the original edges, not sure we can keep the edge features ? Should we warn ?
self.contains_edge_attr = False

# Propagate node features to complement
nodes_attributes = {
n: dict(features=data.x[n], dim=0) for n in range(data.x.shape[0])
}
nx.set_node_attributes(complement_graph, nodes_attributes)

simplicial_complex = SimplicialComplex(complement_graph)
simplices = SimplicialCliqueLifting.generate_simplices(
self.complex_dim, complement_graph
)

for set_k_simplices in simplices:
simplicial_complex.add_simplices_from(list(set_k_simplices))

return self._get_lifted_topology(simplicial_complex, graph)
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
"""Test the message passing module."""
import torch

from modules.data.utils.utils import load_manual_graph
from modules.transforms.liftings.graph2simplicial.independent_set_lifting import (
SimplicialIndependentSetsLifting,
)


class TestSimplicialIndependentSetsLiftingClass:
"""Test the SimplicialIndependentSetsLifting class."""

def setup_method(self):
# Load the graph
self.data = load_manual_graph()

# Initialise the class
self.lifting_unsigned = SimplicialIndependentSetsLifting(
complex_dim=3, signed=False
)

def test_lift_topology(self):
"""Test the lift_topology method."""

# Test the lift_topology method
lifted_data_unsigned = self.lifting_unsigned.forward(self.data.clone())

expected_incidence_1 = torch.tensor(
[
[
1.0,
1.0,
1.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
],
[
0.0,
0.0,
0.0,
1.0,
1.0,
1.0,
1.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
],
[
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
1.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
],
[
1.0,
0.0,
0.0,
1.0,
0.0,
0.0,
0.0,
0.0,
1.0,
1.0,
1.0,
0.0,
0.0,
0.0,
0.0,
],
[
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
1.0,
0.0,
0.0,
1.0,
1.0,
1.0,
0.0,
],
[
0.0,
1.0,
0.0,
0.0,
1.0,
0.0,
0.0,
0.0,
0.0,
1.0,
0.0,
1.0,
0.0,
0.0,
0.0,
],
[
0.0,
0.0,
1.0,
0.0,
0.0,
1.0,
0.0,
1.0,
0.0,
0.0,
0.0,
0.0,
1.0,
0.0,
1.0,
],
[
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
1.0,
0.0,
0.0,
0.0,
1.0,
0.0,
0.0,
1.0,
1.0,
],
]
)

assert (
abs(expected_incidence_1) == lifted_data_unsigned.incidence_1.to_dense()
).all(), "Something is wrong with unsigned incidence_1 (nodes to edges)."

expected_incidence_2 = torch.tensor(
[
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0],
[1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0],
]
)

assert (
abs(expected_incidence_2) == lifted_data_unsigned.incidence_2.to_dense()
).all(), "Something is wrong with unsigned incidence_2 (edges to triangles)."

expected_incidence_3 = torch.empty(7, 0)
assert (
abs(expected_incidence_3) == lifted_data_unsigned.incidence_3.to_dense()
).all(), (
"Something is wrong with unsigned incidence_3 (triangles to tetrahedrons)."
)
Loading
Loading