Skip to content

Commit f20dbc6

Browse files
authored
v.cluster: Add tests (#5538)
1 parent 1ff6bf1 commit f20dbc6

File tree

1 file changed

+242
-0
lines changed

1 file changed

+242
-0
lines changed
+242
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
import os
2+
import tempfile
3+
from grass.gunittest.case import TestCase
4+
from grass.gunittest.main import test
5+
from grass.script import core as grass
6+
7+
8+
class TestVCluster(TestCase):
9+
@classmethod
10+
def setUpClass(cls):
11+
cls.runModule(
12+
"v.random",
13+
output="test_points",
14+
npoints=100,
15+
seed=42,
16+
overwrite=True,
17+
)
18+
19+
@classmethod
20+
def tearDownClass(cls):
21+
"""Clean up"""
22+
cls.runModule(
23+
"g.remove",
24+
type="vector",
25+
name="test_points,clustered",
26+
flags="f",
27+
)
28+
29+
def setUp(self):
30+
self.temp_files = []
31+
self.temp_points = self.create_temp_file(
32+
"1 1\n1 2\n2 1\n2 2\n3 3\n11 11\n11 12\n12 11\n12 12\n50 50\n50 51\n51 50\n51 51\n100 100"
33+
)
34+
35+
self.temp_3d = self.create_temp_file(
36+
"1 1 5\n1 2 7\n2 1 8\n2 2 0\n11 11 10\n11 12 10\n12 11 10\n12 12 10\n50 50 20\n50 51 20\n51 50 20\n51 51 20\n100 100 30"
37+
)
38+
39+
self.runModule(
40+
"v.in.ascii",
41+
input=self.temp_points,
42+
format="point",
43+
separator="space",
44+
output="test_points",
45+
overwrite=True,
46+
)
47+
48+
self.runModule(
49+
"v.in.ascii",
50+
input=self.temp_3d,
51+
format="point",
52+
z=3,
53+
flags="z",
54+
separator="space",
55+
output="test_points_3d",
56+
overwrite=True,
57+
)
58+
59+
def tearDown(self):
60+
"""Removes all temporary files created during the tests."""
61+
for temp_file in self.temp_files:
62+
os.remove(temp_file)
63+
64+
def create_temp_file(self, content):
65+
"""Creates a temporary file with the given content and returns its path."""
66+
with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
67+
temp_file.write(content)
68+
temp_file_name = temp_file.name
69+
self.temp_files.append(temp_file_name)
70+
return temp_file_name
71+
72+
def get_cluster_info(self, map_name):
73+
# Export the clustered points to ASCII format
74+
ascii_output = grass.read_command(
75+
"v.out.ascii", layer=2, input=map_name, format="point", separator="comma"
76+
)
77+
78+
# print(ascii_output)
79+
# Parse the ASCII output to extract cluster IDs
80+
clusters = {}
81+
82+
for line in ascii_output.splitlines():
83+
if line.strip(): # Skip empty lines
84+
parts = line.split(",")
85+
if len(parts) >= 3:
86+
x, y, cluster_id = float(parts[0]), float(parts[1]), int(parts[2])
87+
if cluster_id != 0: # Skip noise points
88+
if cluster_id not in clusters:
89+
clusters[cluster_id] = []
90+
clusters[cluster_id].append((x, y))
91+
return clusters
92+
93+
def get_noise_points(self, map_name):
94+
ascii_output = grass.read_command(
95+
"v.out.ascii", layer=2, input=map_name, format="point", separator="comma"
96+
)
97+
98+
noise_points = []
99+
for line in ascii_output.splitlines():
100+
if line.strip():
101+
parts = line.split(",")
102+
if len(parts) >= 3:
103+
cluster_id = int(parts[2])
104+
if cluster_id == 0:
105+
noise_points.append(cluster_id)
106+
107+
return noise_points
108+
109+
def test_cluster_formation(self):
110+
"""Test DBSCAN clustering with proper attribute handling"""
111+
# Run clustering with clean table creation
112+
self.assertModule(
113+
"v.cluster",
114+
input="test_points",
115+
output="clustered",
116+
method="dbscan",
117+
distance=1.5,
118+
min=4,
119+
flags="b",
120+
overwrite=True,
121+
)
122+
123+
clusters = self.get_cluster_info("clustered")
124+
# print(clusters)
125+
self.assertGreater(len(clusters), 1)
126+
cluster_sizes = sorted([len(points) for _, points in clusters.items()])
127+
self.assertEqual(cluster_sizes, [4, 4, 5])
128+
129+
noise_points = self.get_noise_points("clustered")
130+
self.assertEqual(len(noise_points), 1)
131+
132+
def test_min_points(self):
133+
"""Testing the effect of the min points parameter on clustering"""
134+
self.assertModule(
135+
"v.cluster",
136+
input="test_points",
137+
output="clustered",
138+
method="dbscan",
139+
distance=1.5,
140+
min=5,
141+
flags="b",
142+
overwrite=True,
143+
)
144+
145+
clusters = self.get_cluster_info("clustered")
146+
self.assertEqual(len(clusters), 1)
147+
148+
def test_distance_threshold_effect(self):
149+
"""Test that distance threshold correctly affects cluster formation"""
150+
151+
self.assertModule(
152+
"v.cluster",
153+
input="test_points",
154+
output="clustered",
155+
method="dbscan",
156+
distance=1.5,
157+
min=4,
158+
flags="b",
159+
overwrite=True,
160+
)
161+
162+
clusters = self.get_cluster_info("clustered")
163+
nodes = len(clusters[1])
164+
# print(nodes)
165+
166+
self.assertModule(
167+
"v.cluster",
168+
input="test_points",
169+
output="clustered_20",
170+
method="dbscan",
171+
distance=20,
172+
min=4,
173+
flags="b",
174+
overwrite=True,
175+
)
176+
177+
clusters_20 = self.get_cluster_info("clustered_20")
178+
nodes_20 = len(clusters_20[1])
179+
180+
self.assertGreaterEqual(nodes_20, nodes)
181+
182+
def test_2d_flag(self):
183+
"""Test the effect of 2d flag on clustering for 3D points"""
184+
self.assertModule(
185+
"v.cluster",
186+
input="test_points_3d",
187+
output="clustered_3d",
188+
method="dbscan",
189+
distance=1.5,
190+
overwrite=True,
191+
)
192+
193+
self.assertVectorExists("clustered_3d")
194+
ascii_output = grass.read_command(
195+
"v.out.ascii",
196+
input="clustered_3d",
197+
format="point",
198+
layer=2,
199+
separator="comma",
200+
)
201+
202+
clusterIds_3d = set()
203+
for line in ascii_output.splitlines():
204+
if line.strip(): # Skip empty lines
205+
parts = line.split(",")
206+
if len(parts) >= 4:
207+
clusterIds_3d.add(parts[3])
208+
209+
# print(ascii_output)
210+
211+
self.assertModule(
212+
"v.cluster",
213+
input="test_points_3d",
214+
output="clustered_2d",
215+
method="dbscan",
216+
distance=1.5,
217+
min=4,
218+
flags="2b",
219+
overwrite=True,
220+
)
221+
222+
self.assertVectorExists("clustered_2d")
223+
ascii_2d = grass.read_command(
224+
"v.out.ascii",
225+
input="clustered_2d",
226+
format="point",
227+
layer=2,
228+
separator="comma",
229+
)
230+
231+
clusterIds_2d = set()
232+
for line in ascii_2d.splitlines():
233+
if line.strip(): # Skip empty lines
234+
parts = line.split(",")
235+
if len(parts) >= 4:
236+
clusterIds_2d.add(parts[3])
237+
238+
self.assertNotEqual(clusterIds_2d, clusterIds_3d)
239+
240+
241+
if __name__ == "__main__":
242+
test()

0 commit comments

Comments
 (0)