Skip to content

Commit 6dec5a7

Browse files
committed
Fix bug in loaders
1 parent 5f650c1 commit 6dec5a7

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

modules/data/load/loaders.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
load_manual_graph,
2323
load_manual_mol,
2424
load_manual_points,
25+
load_point_cloud,
2526
load_random_points,
2627
load_simplicial_dataset,
2728
)
@@ -291,15 +292,17 @@ def load(self) -> torch_geometric.data.Dataset:
291292
feature_generator=self.feature_generator,
292293
target_generator=self.target_generator,
293294
)
294-
elif (
295-
self.parameters.data_name == "random_points"
296-
or self.parameters.data_name == "toy_point_cloud"
297-
):
295+
elif self.parameters.data_name == "random_points":
298296
data = load_random_points(
299297
dim=self.parameters["dim"],
300298
num_classes=self.parameters["num_classes"],
301299
num_samples=self.parameters["num_samples"],
302300
)
301+
elif self.parameters.data_name == "toy_point_cloud":
302+
data = load_point_cloud(
303+
num_classes=self.parameters["num_classes"],
304+
num_points=self.parameters["num_samples"],
305+
)
303306
elif self.parameters.data_name == "manual_points":
304307
data = load_manual_points()
305308
else:

0 commit comments

Comments
 (0)