Skip to content

Commit 978ea84

Browse files
committed
Fix bug with toy point cloud
1 parent 6dec5a7 commit 978ea84

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

configs/datasets/toy_point_cloud.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ data_name: toy_point_cloud
44
data_dir: datasets/${data_domain}/${data_type}
55

66
# Dataset parameters
7-
num_points: 8
7+
num_samples: 8
88
num_classes: 2
99

1010
num_features: 1

modules/data/utils/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -298,17 +298,17 @@ def load_hypergraph_pickle_dataset(cfg):
298298

299299

300300
def load_point_cloud(
301-
num_classes: int = 2, num_points: int = 18, seed: int = 42
301+
num_classes: int = 2, num_samples: int = 18, seed: int = 42
302302
):
303303
"""Create a toy point cloud dataset"""
304304
rng = np.random.default_rng(seed)
305305

306-
points = torch.tensor(rng.random((num_points, 2)), dtype=torch.float)
306+
points = torch.tensor(rng.random((num_samples, 2)), dtype=torch.float)
307307
classes = torch.tensor(
308-
rng.integers(num_classes, size=num_points), dtype=torch.long
308+
rng.integers(num_classes, size=num_samples), dtype=torch.long
309309
)
310310
features = torch.tensor(
311-
rng.integers(3, size=(num_points, 1)), dtype=torch.float
311+
rng.integers(3, size=(num_samples, 1)), dtype=torch.float
312312
)
313313

314314
return torch_geometric.data.Data(x=features, y=classes, pos=points)

0 commit comments

Comments
 (0)