Skip to content

Path based lifting (Graph to Combinatorial) #43

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
5 changes: 5 additions & 0 deletions configs/models/combinatorial/hmc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
in_channels: null # This will be set by the dataset
hidden_channels: 32
out_channels: null # This will be set by the dataset
n_layers: 2
negative_slope: 0.2
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
transform_type: 'lifting'
transform_name: "CombinatorialPathLifting"
preserve_edge_attr: False
feature_lifting: ProjectionSum
24 changes: 16 additions & 8 deletions modules/data/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,23 +43,31 @@ def get_complex_connectivity(complex, max_rank, signed=False):
"hodge_laplacian",
]:
try:
# if isinstance(complex, CombinatorialComplex):
# matrix_method = f"get_{connectivity_info}_matrix"
# else:
# matrix_method = f"{connectivity_info}_matrix"
# connectivity[f"{connectivity_info}_{rank_idx}"] = from_sparse(
# getattr(complex, matrix_method)(rank=rank_idx, signed=signed)
# )

connectivity[f"{connectivity_info}_{rank_idx}"] = from_sparse(
getattr(complex, f"{connectivity_info}_matrix")(
rank=rank_idx, signed=signed
)
)
except ValueError: # noqa: PERF203
if connectivity_info == "incidence":
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
else:
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity["shape"] = practical_shape
return connectivity
Expand Down
79 changes: 79 additions & 0 deletions modules/models/combinatorial/hmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import torch
from topomodelx.nn.combinatorial.hmc import HMC


class HMCModel(torch.nn.Module):
r"""A simple CWN model that runs over combinatorial complex data.
Note that some parameters are defined by the considered dataset.

Parameters
----------
model_config : Dict | DictConfig
Model configuration.
dataset_config : Dict | DictConfig
Dataset configuration.
"""

def __init__(self, model_config, dataset_config):
in_channels_0 = (
dataset_config["num_features"]
if isinstance(dataset_config["num_features"], int)
else dataset_config["num_features"][0]
)
n_layers = model_config["n_layers"]
hidden_channels = model_config["hidden_channels"]
out_channels = dataset_config["num_classes"]

channels_per_layer = [
[
[in_channels_0, in_channels_0, in_channels_0],
[hidden_channels, hidden_channels, hidden_channels],
[hidden_channels, hidden_channels, hidden_channels],
]
]
rest = [
[
[hidden_channels for _ in range(3)],
[hidden_channels for _ in range(3)],
[hidden_channels for _ in range(3)],
]
for __ in range(1, n_layers)
]
channels_per_layer.extend(rest)

negative_slope = model_config["negative_slope"]
super().__init__()
self.base_model = HMC(
channels_per_layer=channels_per_layer, negative_slope=negative_slope
)
self.linear_0 = torch.nn.Linear(hidden_channels, out_channels)
self.linear_1 = torch.nn.Linear(hidden_channels, out_channels)
self.linear_2 = torch.nn.Linear(hidden_channels, out_channels)

def forward(self, data):
r"""Forward pass of the model.

Parameters
----------
data : torch_geometric.data.Data
Input data.

Returns
-------
tuple of torch.Tensor
Output tensor.
"""
x_0, x_1, x_2 = self.base_model(
data.x_0,
data.x_1,
data.x_2,
data.adjacency_0,
data.adjacency_1,
data.adjacency_2,
data.incidence_1,
data.incidence_2,
)
x_0 = self.linear_0(x_0)
x_1 = self.linear_1(x_1)
x_2 = self.linear_2(x_2)
return x_0, x_1, x_2
7 changes: 6 additions & 1 deletion modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
)
from modules.transforms.feature_liftings.feature_liftings import ProjectionSum
from modules.transforms.liftings.graph2cell.cycle_lifting import CellCycleLifting
from modules.transforms.liftings.graph2combinatorial.path_lifting import (
CombinatorialPathLifting,
)
from modules.transforms.liftings.graph2hypergraph.knn_lifting import (
HypergraphKNNLifting,
)
Expand All @@ -23,7 +26,9 @@
"SimplicialCliqueLifting": SimplicialCliqueLifting,
# Graph -> Cell Complex
"CellCycleLifting": CellCycleLifting,
# Feature Liftings
# Graph -> Combinatorial Complex
"CombinatorialPathLifting": CombinatorialPathLifting,
# Feature Lifting
"ProjectionSum": ProjectionSum,
# Data Manipulations
"Identity": IdentityTransform,
Expand Down
2 changes: 1 addition & 1 deletion modules/transforms/feature_liftings/feature_liftings.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def lift_features(
-------
torch_geometric.data.Data | dict
The lifted data."""
keys = sorted([key.split("_")[1] for key in data.keys() if "incidence" in key]) # noqa : SIM118
keys = sorted([key.split("_")[1] for key in data if "incidence" in key])
for elem in keys:
if f"x_{elem}" not in data:
idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1
Expand Down
3 changes: 2 additions & 1 deletion modules/transforms/liftings/graph2cell/cycle_lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict:

# Eliminate self-loop cycles
cycles = [cycle for cycle in cycles if len(cycle) != 1]
# Eliminate cycles that are greater than the max_cell_lenght
# Eliminate cycles that are greater than the max_cell_length
if self.max_cell_length is not None:
cycles = [cycle for cycle in cycles if len(cycle) <= self.max_cell_length]
if len(cycles) != 0:
cell_complex.add_cells_from(cycles, rank=self.complex_dim)

return self._get_lifted_topology(cell_complex, G)
Loading
Loading