-
-
Notifications
You must be signed in to change notification settings - Fork 2.4k
/
Copy pathk_means.rs
90 lines (66 loc) · 2.35 KB
/
k_means.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
use rand::random;
/// Euclidean distance between two points in the plane, each given as `(x, y)`.
fn get_distance(p1: &(f64, f64), p2: &(f64, f64)) -> f64 {
    let (x1, y1) = *p1;
    let (x2, y2) = *p2;
    let (dx, dy) = (x1 - x2, y1 - y2);
    (dx * dx + dy * dy).sqrt()
}
/// Returns the index of the centroid closest (by Euclidean distance) to `data_point`.
///
/// Ties are resolved in favour of the lowest index because only a strictly smaller
/// distance replaces the current best. Returns `0` when `centroids` is empty, so a
/// caller must not index into `centroids` with the result in that case.
fn find_nearest(data_point: &(f64, f64), centroids: &[(f64, f64)]) -> u32 {
    // Seed the running best from centroid 0 rather than +infinity, so a NaN distance
    // to centroid 0 leaves index 0 selected, as in the previous implementation.
    let mut best_index: usize = 0;
    let mut best_distance = match centroids.first() {
        Some(first) => get_distance(data_point, first),
        None => return 0,
    };
    for (i, c) in centroids.iter().enumerate().skip(1) {
        // Cache the best distance instead of recomputing it for every candidate.
        let d = get_distance(data_point, c);
        if d < best_distance {
            best_index = i;
            best_distance = d;
        }
    }
    best_index as u32
}
/// Clusters 2-D points with Lloyd's k-means algorithm.
///
/// Initial centroids are `n_clusters` distinct data points chosen at random, so they
/// always lie among the data. Each round assigns every point to its nearest centroid,
/// then moves each non-empty cluster's centroid to the mean of its members. A cluster
/// that receives no points keeps its previous centroid. Iteration stops after
/// `max_iter` rounds, or earlier once a round changes no label (further rounds could
/// not change anything).
///
/// # Returns
/// `Some(labels)`, where `labels[i]` is the cluster index (in `0..n_clusters`) of
/// `data_points[i]` from the final assignment round.
///
/// Returns `None` in two cases:
/// - there are fewer data points than clusters;
/// - `n_clusters` is zero and there is at least one point to label.
///
/// If `max_iter <= 0`, no assignment round runs and every label is `0`.
pub fn k_means(
    data_points: Vec<(f64, f64)>,
    n_clusters: usize,
    max_iter: i32,
) -> Option<Vec<u32>> {
    if n_clusters == 0 {
        // No cluster can hold a point; only an empty data set has a valid labelling.
        // (Previously this case indexed an empty vector and panicked.)
        return if data_points.is_empty() {
            Some(Vec::new())
        } else {
            None
        };
    }
    if data_points.len() < n_clusters {
        return None;
    }

    // Forgy initialisation. A partial Fisher-Yates shuffle of the point indices selects
    // `n_clusters` distinct points. Drawing centroids uniformly from [0, 1)^2 instead
    // would leave most clusters empty whenever the data lies outside the unit square.
    let mut indices: Vec<usize> = (0..data_points.len()).collect();
    for k in 0..n_clusters {
        // `indices.len() - k >= 1` because `k < n_clusters <= data_points.len()`.
        let j = k + random::<usize>() % (indices.len() - k);
        indices.swap(k, j);
    }
    let mut centroids: Vec<(f64, f64)> = indices[..n_clusters]
        .iter()
        .map(|&i| data_points[i])
        .collect();

    let mut labels: Vec<u32> = vec![0; data_points.len()];
    for iteration in 0..max_iter {
        let mut sums = vec![(0.0_f64, 0.0_f64); n_clusters];
        let mut counts = vec![0_u32; n_clusters];
        let mut changed = false;

        // Assignment step: label each point with its nearest centroid and accumulate
        // per-cluster coordinate sums for the update step.
        for (label, point) in labels.iter_mut().zip(&data_points) {
            let nearest = find_nearest(point, &centroids);
            changed |= *label != nearest;
            *label = nearest;
            let c = nearest as usize;
            sums[c].0 += point.0;
            sums[c].1 += point.1;
            counts[c] += 1;
        }

        // Identical assignments produce identical centroids, so every later round would
        // repeat this one. Round 0 is compared against placeholder zeros, so it never
        // triggers the stop.
        if iteration > 0 && !changed {
            break;
        }

        // Update step: move non-empty clusters to their mean. Empty clusters keep their
        // centroid, which also avoids dividing by zero.
        for ((centroid, sum), &count) in centroids.iter_mut().zip(&sums).zip(&counts) {
            if count > 0 {
                let n = f64::from(count);
                *centroid = (sum.0 / n, sum.1 / n);
            }
        }
    }
    Some(labels)
}
#[cfg(test)]
mod test {
    use super::*;

    /// Random points in [0, 100)^2: every point receives a label, and every label is a
    /// valid cluster index.
    #[test]
    fn test_k_means() {
        let n_points: usize = 1000;
        let n_clusters: usize = 10;
        let data_points: Vec<(f64, f64)> = (0..n_points)
            .map(|_| (random::<f64>() * 100.0, random::<f64>() * 100.0))
            .collect();
        let labels = k_means(data_points, n_clusters, 100)
            .expect("1000 points are enough for 10 clusters");
        assert_eq!(labels.len(), n_points);
        assert!(labels.iter().all(|&l| (l as usize) < n_clusters));
    }

    /// Two tight, far-apart groups must end up in two different clusters.
    #[test]
    fn test_k_means_separates_distant_groups() {
        let data_points = vec![
            (0.0, 0.0),
            (0.0, 1.0),
            (1.0, 0.0),
            (1.0, 1.0),
            (100.0, 100.0),
            (100.0, 101.0),
            (101.0, 100.0),
            (101.0, 101.0),
        ];
        let labels = k_means(data_points, 2, 100).expect("8 points are enough for 2 clusters");
        assert!(labels[..4].iter().all(|&l| l == labels[0]));
        assert!(labels[4..].iter().all(|&l| l == labels[4]));
        assert_ne!(labels[0], labels[4]);
    }

    /// Asking for more clusters than there are points is rejected.
    #[test]
    fn test_k_means_too_few_points() {
        assert!(k_means(vec![(0.0, 0.0)], 2, 10).is_none());
    }
}