Skip to content

Add KNN graph lifting (point cloud to graph) #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions modules/data/load/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
load_cell_complex_dataset,
load_hypergraph_pickle_dataset,
load_manual_graph,
load_random_points,
load_simplicial_dataset,
)

Expand Down Expand Up @@ -204,3 +205,32 @@ def load(
torch_geometric.data.Dataset object containing the loaded data.
"""
return load_hypergraph_pickle_dataset(self.parameters)


class PointCloudLoader(AbstractLoader):
r"""Loader for point-cloud dataset.

Parameters
----------
parameters: DictConfig
Configuration parameters
"""

def __init__(self, parameters: DictConfig):
super().__init__(parameters)
self.parameters = parameters

def load(self) -> torch_geometric.data.Dataset:
r"""Load point-cloud dataset.

Parameters
----------
None

Returns
-------
torch_geometric.data.Dataset
torch_geometric.data.Dataset object containing the loaded data.
"""
data = load_random_points(num_classes=self.cfg["num_classes"])
return CustomDataset([data], self.cfg["data_dir"])
25 changes: 17 additions & 8 deletions modules/data/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ def get_complex_connectivity(complex, max_rank, signed=False):
)
except ValueError: # noqa: PERF203
if connectivity_info == "incidence":
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
else:
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity["shape"] = practical_shape
return connectivity
Expand Down Expand Up @@ -283,6 +283,15 @@ def load_hypergraph_pickle_dataset(cfg):
return data


def load_random_points(num_classes: int = 2, num_points: int = 8, seed: int = 2024):
"""Create a toy point cloud dataset"""
rng = np.random.default_rng(seed)
points = torch.tensor(rng.rand(num_points, 2))
classes = torch.tensor(rng.randint(num_classes, size=num_points))
features = torch.tensor(rng.randint(3, size=(num_points, 1)) * 1.0).float()
return torch_geometric.data.Data(x=features, y=classes, pos=points)


def load_manual_graph():
"""Create a manual graph for testing purposes."""
# Define the vertices (just 8 vertices)
Expand Down
3 changes: 3 additions & 0 deletions modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
SimplicialCliqueLifting,
)
from modules.transforms.liftings.pointcloud2graph.knn_lifting import GraphKNNLifting

TRANSFORMS = {
# Graph -> Hypergraph
Expand All @@ -23,6 +24,8 @@
"SimplicialCliqueLifting": SimplicialCliqueLifting,
# Graph -> Cell Complex
"CellCycleLifting": CellCycleLifting,
# Point-cloud -> Graph
"GraphKNNLifting": GraphKNNLifting,
# Feature Liftings
"ProjectionSum": ProjectionSum,
# Data Manipulations
Expand Down
38 changes: 38 additions & 0 deletions modules/transforms/liftings/pointcloud2graph/knn_lifting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import torch_geometric

from modules.transforms.liftings.pointcloud2graph.base import PointCloud2GraphLifting


class GraphKNNLifting(PointCloud2GraphLifting):
r"""Lifts point cloud data to graph by creating its k-NN graph

Parameters
----------
**kwargs : optional
Additional arguments for the class
"""

def __init__(self, k: int = 1, **kwargs):
super().__init__(**kwargs)
self.k = k
self.transform = torch_geometric.transforms.KNNGraph(k=k)

def lift_topology(self, data: torch_geometric.data.Data) -> dict:
r"""Lifts a point cloud dataset to a graph by constructing its k-NN graph.

Parameters
----------
data : torch_geometric.data.Data
The input data to be lifted

Returns
-------
dict
The lifted topology
"""
graph_data = self.transform(data)
return {
"shape": [graph_data.x.shape[0], graph_data.edge_index.shape[1]],
"edge_index": graph_data.edge_index,
"num_nodes": graph_data.x.shape[0],
}
4 changes: 2 additions & 2 deletions modules/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def describe_data(dataset: torch_geometric.data.Dataset, idx_sample: int = 0):
if hasattr(data, "num_edges"):
complex_dim.append(data.num_edges)
features_dim.append(data.num_edge_features)
elif hasattr(data, "edge_index"):
elif hasattr(data, "edge_index") and data.edge_index is not None:
complex_dim.append(data.edge_index.shape[1])
features_dim.append(data.edge_attr.shape[1])
# Check if the data object contains hyperedges
Expand Down Expand Up @@ -249,7 +249,7 @@ def sort_vertices_ccw(vertices):
edges.append(torch.where(edge != 0)[0].numpy())
edge_mapper[edge_idx] = sorted(node_idxs)
edges = np.array(edges)
elif hasattr(data, "edge_index"):
elif hasattr(data, "edge_index") and data.edge_index is not None:
edges = data.edge_index.T.tolist()
edge_mapper = {}
for e, edge in enumerate(edges):
Expand Down
Loading