Skip to content

Commit 63ba850

Browse files
committed
Merge branch 'PatRyg99-pr/voronoi-lifting'
2 parents 346bb93 + 0df2c38 commit 63ba850

File tree

10 files changed

+462
-9
lines changed

10 files changed

+462
-9
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
data_domain: pointcloud
2+
data_type: toy_dataset
3+
data_name: random_pointcloud
4+
data_dir: datasets/${data_domain}/${data_type}/${data_name}
5+
6+
# Dataset parameters
7+
pos_to_x: True
8+
num_features: 3
9+
num_points: 20
10+
dim: 3
11+
num_classes: 2
12+
task: classification
13+
loss_type: cross_entropy
14+
monitor_metric: accuracy
15+
task_level: node

configs/datasets/stanford_bunny.yaml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
data_domain: pointcloud
2+
data_type: toy_dataset
3+
data_name: stanford_bunny
4+
data_dir: datasets/${data_domain}/${data_type}/${data_name}
5+
6+
# Dataset parameters
7+
pos_to_x: True
8+
num_features: 3
9+
num_classes: 1
10+
task: regression
11+
loss_type: mse
12+
monitor_metric: mae
13+
task_level: graph
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
transform_type: 'lifting'
2+
transform_name: "VoronoiLifting"
3+
support_ratio: 0.005
4+
feature_lifting: ProjectionSum

modules/data/load/loaders.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
load_manual_mol,
2525
load_manual_points,
2626
load_point_cloud,
27+
load_pointcloud_dataset,
2728
load_random_points,
2829
load_simplicial_dataset,
2930
)
@@ -278,7 +279,6 @@ def load(self) -> torch_geometric.data.Dataset:
278279
----------
279280
None
280281
281-
282282
Returns
283283
-------
284284
torch_geometric.data.Dataset
@@ -316,3 +316,8 @@ def load(self) -> torch_geometric.data.Dataset:
316316
)
317317

318318
return CustomDataset([data], self.data_dir)
319+
self.data_dir = os.path.join(root_folder, self.parameters["data_dir"])
320+
321+
return CustomDataset(
322+
[load_pointcloud_dataset(self.parameters)], self.data_dir
323+
)

modules/data/utils/utils.py

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,6 @@ def load_manual_graph():
405405
)
406406

407407

408-
409408
def load_k4_graph() -> torch_geometric.data.Data:
410409
"""K_4 is a complete graph with 4 vertices."""
411410
vertices = [i for i in range(4)]
@@ -475,9 +474,13 @@ def load_8_vertex_cubic_graphs() -> list[torch_geometric.data.Data]:
475474
n = 8 if i < 5 else 10
476475
vertices = [i for i in range(n)]
477476
x = (
478-
torch.tensor([1, 5, 10, 50, 100, 500, 1000, 5000]).unsqueeze(1).float()
477+
torch.tensor([1, 5, 10, 50, 100, 500, 1000, 5000])
478+
.unsqueeze(1)
479+
.float()
479480
if i < 5
480-
else torch.tensor([1, 5, 10, 50, 100, 500, 1000, 5000, 10000, 50000])
481+
else torch.tensor(
482+
[1, 5, 10, 50, 100, 500, 1000, 5000, 10000, 50000]
483+
)
481484
.unsqueeze(1)
482485
.float()
483486
)
@@ -494,11 +497,14 @@ def load_8_vertex_cubic_graphs() -> list[torch_geometric.data.Data]:
494497
G.to_undirected()
495498
edge_list = torch.Tensor(list(G.edges())).T.long()
496499

497-
data = torch_geometric.data.Data(x=x, edge_index=edge_list, num_nodes=n, y=y)
500+
data = torch_geometric.data.Data(
501+
x=x, edge_index=edge_list, num_nodes=n, y=y
502+
)
498503

499504
list_data.append(data)
500505
return list_data
501506

507+
502508
def load_manual_mol():
503509
"""Create a manual graph for testing the ring implementation.
504510
Actually is the 471 molecule of QM9 dataset."""
@@ -620,7 +626,6 @@ def load_manual_mol():
620626
)
621627

622628

623-
624629
def get_Planetoid_pyg(cfg):
625630
r"""Loads Planetoid graph datasets from torch_geometric.
626631
@@ -822,3 +827,71 @@ def load_manual_points():
822827
x = torch.ones_like(pos, dtype=torch.float)
823828
y = torch.randint(0, 2, (pos.shape[0],), dtype=torch.float)
824829
return torch_geometric.data.Data(x=x, y=y, pos=pos, complex_dim=0)
830+
831+
832+
def load_pointcloud_dataset(cfg):
833+
r"""Loads point cloud datasets.
834+
835+
Parameters
836+
----------
837+
cfg : DictConfig
838+
Configuration parameters.
839+
840+
Returns
841+
-------
842+
torch_geometric.data.Data
843+
Point cloud dataset.
844+
"""
845+
# Define the path to the data directory
846+
root_folder = rootutils.find_root()
847+
data_dir = osp.join(root_folder, cfg["data_dir"])
848+
849+
if cfg["data_name"] == "random_pointcloud":
850+
num_points, dim = cfg["num_points"], cfg["dim"]
851+
pos = torch.rand((num_points, dim))
852+
elif cfg["data_name"] == "stanford_bunny":
853+
pos = fetch_bunny(
854+
file_path=osp.join(data_dir, "stanford_bunny.npy"),
855+
accept_license=False,
856+
)
857+
num_points = len(pos)
858+
pos = torch.tensor(pos)
859+
860+
if cfg.pos_to_x:
861+
return torch_geometric.data.Data(
862+
x=pos, pos=pos, num_nodes=num_points, num_features=pos.size(1)
863+
)
864+
865+
return torch_geometric.data.Data(
866+
pos=pos, num_nodes=num_points, num_features=0
867+
)
868+
869+
870+
def load_manual_pointcloud(pos_to_x: bool = False):
871+
"""Create a manual pointcloud for testing purposes."""
872+
# Define the positions
873+
pos = torch.tensor(
874+
[
875+
[0, 0, 0],
876+
[0, 0, 1],
877+
[0, 1, 0],
878+
[10, 0, 0],
879+
[10, 0, 1],
880+
[10, 1, 0],
881+
[10, 1, 1],
882+
[20, 0, 0],
883+
[20, 0, 1],
884+
[20, 1, 0],
885+
[20, 1, 1],
886+
[30, 0, 0],
887+
]
888+
).float()
889+
890+
if pos_to_x:
891+
return torch_geometric.data.Data(
892+
x=pos, pos=pos, num_nodes=pos.size(0), num_features=pos.size(1)
893+
)
894+
895+
return torch_geometric.data.Data(
896+
pos=pos, num_nodes=pos.size(0), num_features=0
897+
)

modules/transforms/data_transform.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@
4444
from modules.transforms.liftings.graph2simplicial.vietoris_rips_lifting import (
4545
SimplicialVietorisRipsLifting,
4646
)
47+
from modules.transforms.liftings.pointcloud2hypergraph.voronoi_lifting import (
48+
VoronoiLifting,
49+
)
4750
from modules.transforms.liftings.pointcloud2simplicial.alpha_complex_lifting import (
4851
AlphaComplexLifting,
4952
)
@@ -72,6 +75,8 @@
7275
"AlphaComplexLifting": AlphaComplexLifting,
7376
# Point-cloud -> Simplicial Complex
7477
"DelaunayLifting": DelaunayLifting,
78+
# Pointcloud -> Hypergraph
79+
"VoronoiLifting": VoronoiLifting,
7580
# Feature Liftings
7681
"ProjectionSum": ProjectionSum,
7782
# Data Manipulations
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import torch
2+
import torch_geometric
3+
from torch_cluster import fps, knn
4+
5+
from modules.transforms.liftings.pointcloud2hypergraph.base import (
6+
PointCloud2HypergraphLifting,
7+
)
8+
9+
10+
class VoronoiLifting(PointCloud2HypergraphLifting):
11+
r"""Lifts pointcloud to Farthest-point Voronoi graph.
12+
13+
Parameters
14+
----------
15+
support_ratio : float
16+
Ratio of points to sample with FPS to form voronoi support set.
17+
**kwargs : optional
18+
Additional arguments for the class.
19+
"""
20+
21+
def __init__(self, support_ratio: float, **kwargs):
22+
super().__init__(**kwargs)
23+
self.support_ratio = support_ratio
24+
25+
def lift_topology(self, data: torch_geometric.data.Data) -> dict:
26+
r"""Lifts pointcloud to voronoi graph induced by Farthest Point Sampling (FPS) support set.
27+
28+
Parameters
29+
----------
30+
data : torch_geometric.data.Data
31+
The input data to be lifted.
32+
33+
Returns
34+
-------
35+
dict
36+
The lifted topology.
37+
"""
38+
39+
# Sample FPS induced Voronoi graph
40+
support_idcs = fps(data.pos, ratio=self.support_ratio)
41+
target_idcs, source_idcs = knn(data.pos[support_idcs], data.pos, k=1)
42+
43+
# Construct incidence matrix
44+
incidence_matrix = torch.sparse_coo_tensor(
45+
torch.stack((target_idcs, source_idcs)),
46+
torch.ones(source_idcs.numel()),
47+
size=(data.num_nodes, support_idcs.numel()),
48+
)
49+
50+
return {
51+
"incidence_hyperedges": incidence_matrix,
52+
"num_hyperedges": incidence_matrix.size(1),
53+
"x_0": data.x,
54+
}

modules/utils/utils.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import rootutils
1111
import torch
1212
import torch_geometric
13+
from matplotlib import cm
1314
from matplotlib.patches import Polygon
1415

1516
plt.rcParams["text.usetex"] = bool(shutil.which("latex"))
@@ -148,19 +149,22 @@ def describe_data(dataset: torch_geometric.data.Dataset, idx_sample: int = 0):
148149
if hasattr(data, "num_nodes"):
149150
complex_dim.append(data.num_nodes)
150151
features_dim.append(data.num_node_features)
151-
elif hasattr(data, "x"):
152+
elif hasattr(data, "x") and data.x is not None:
152153
complex_dim.append(data.x.shape[0])
153154
features_dim.append(data.x.shape[1])
154155
else:
155156
raise ValueError(
156157
"Data object does not contain any vertices/points."
157158
)
158-
if hasattr(data, "num_edges") and hasattr(data, "num_edge_features"):
159+
160+
if hasattr(data, "num_edges") and data.num_edges > 0:
159161
complex_dim.append(data.num_edges)
160162
features_dim.append(data.num_edge_features)
161-
elif hasattr(data, "edge_index") and (data.edge_index is not None):
163+
164+
elif hasattr(data, "edge_index") and data.edge_index is not None:
162165
complex_dim.append(data.edge_index.shape[1])
163166
features_dim.append(data.edge_attr.shape[1])
167+
164168
# Check if the data object contains hyperedges
165169
hyperedges = False
166170
if hasattr(data, "x_hyperedges"):
@@ -201,6 +205,7 @@ def describe_data(dataset: torch_geometric.data.Dataset, idx_sample: int = 0):
201205
hasattr(data, "edge_index")
202206
and hasattr(data, "x")
203207
and (data.edge_index is not None)
208+
and data.x is not None
204209
):
205210
connected_nodes = torch.unique(data.edge_index)
206211
isolated_nodes = []
@@ -587,3 +592,22 @@ def describe_hypergraph(data: torch_geometric.data.Data):
587592
if he_idx >= 10:
588593
print("...")
589594
break
595+
596+
597+
def plot_pointcloud_voronoi(
598+
dataset, idx_sample: int = 0, azim: float = 180, roll: float = -90
599+
):
600+
fig = plt.figure(figsize=(8, 8))
601+
ax = fig.add_subplot(111, projection="3d")
602+
603+
data = dataset.get(idx_sample % len(dataset))
604+
points = data.pos
605+
606+
if hasattr(data, "incidence_hyperedges"):
607+
color = np.array(data.incidence_hyperedges.coalesce().indices()[1])
608+
else:
609+
color = np.ones(len(points))
610+
611+
ax.scatter(*points.T, s=1.0, c=color, cmap=cm.flag)
612+
ax.view_init(elev=10.0, azim=azim, roll=roll)
613+
plt.show()
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Test the message passing module."""
2+
3+
import torch
4+
5+
from modules.data.utils.utils import load_manual_pointcloud
6+
from modules.transforms.liftings.pointcloud2hypergraph.voronoi_lifting import (
7+
VoronoiLifting,
8+
)
9+
10+
11+
class TestVoronoiLifting:
12+
"""Test the SimplicialCliqueLifting class."""
13+
14+
def setup_method(self):
15+
# Load the graph
16+
self.data = load_manual_pointcloud(pos_to_x=True)
17+
18+
# Initialise the VoronoiLifting class
19+
self.lifting = VoronoiLifting(support_ratio=0.26)
20+
21+
def test_lift_topology(self):
22+
"""Test the lift_topology method."""
23+
24+
# Test the lift_topology method
25+
lifted_data = self.lifting.forward(self.data.clone())
26+
27+
expected_cluster_sizes = torch.tensor([1, 3, 4, 4])
28+
cluster_sizes = torch.sort(
29+
torch.unique(
30+
lifted_data.incidence_hyperedges.coalesce().indices()[1],
31+
return_counts=True,
32+
)[1]
33+
)[0]
34+
35+
assert (
36+
expected_cluster_sizes == cluster_sizes
37+
).all(), "Something is wrong with edge_index (mismatched cluster sizes)."

0 commit comments

Comments
 (0)