Skip to content

Commit 47ea8bd

Browse files
feat: add K-Nearest Neighbors classification algorithm in machine learning (#993)
1 parent b705fd7 commit 47ea8bd

3 files changed

Lines changed: 129 additions & 0 deletions

File tree

DIRECTORY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@
201201
* Machine Learning
202202
* [Cholesky](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/cholesky.rs)
203203
* [K-Means](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/k_means.rs)
204+
* [K-Nearest Neighbors](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/k_nearest_neighbors.rs)
204205
* [Linear Regression](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/linear_regression.rs)
205206
* [Logistic Regression](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/logistic_regression.rs)
206207
* Loss Function
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
//! K-Nearest Neighbors (KNN) algorithm for classification.
//!
//! KNN is a simple, instance-based learning algorithm that classifies
//! a data point based on the majority class of its k nearest neighbors.
4+
5+
/// Computes the Euclidean (L2) distance between two points.
///
/// Returns `None` when the two points do not have the same number of
/// dimensions, since no distance is defined between them.
fn euclidean_distance(p1: &[f64], p2: &[f64]) -> Option<f64> {
    if p1.len() != p2.len() {
        return None;
    }

    let squared_sum: f64 = p1.iter().zip(p2).map(|(a, b)| (a - b).powi(2)).sum();
    Some(squared_sum.sqrt())
}

/// Classifies `test_point` by a majority vote among its `k` nearest
/// training points, using Euclidean distance.
///
/// # Arguments
///
/// * `training_data` - `(features, label)` pairs. Points whose feature count
///   differs from `test_point`, or whose distance to it is NaN, cannot be
///   ranked and are ignored.
/// * `test_point` - Feature vector of the point to classify.
/// * `k` - Number of neighbors that take part in the vote.
///
/// # Returns
///
/// The winning label, or `None` when `k` is zero or when fewer than `k`
/// usable training points are available (this includes empty training data).
/// When several labels receive the same number of votes, the label of the
/// nearest neighbor among the tied labels wins.
pub fn k_nearest_neighbors(
    training_data: Vec<(Vec<f64>, f64)>,
    test_point: Vec<f64>,
    k: usize,
) -> Option<f64> {
    // Two labels closer than this are treated as the same class.
    const LABEL_TOLERANCE: f64 = 1e-10;

    if k == 0 {
        return None;
    }

    // Keep only the points that can actually be ranked against `test_point`.
    let mut distances: Vec<(f64, f64)> = training_data
        .iter()
        .filter_map(|(features, label)| {
            euclidean_distance(&test_point, features)
                .filter(|distance| !distance.is_nan())
                .map(|distance| (distance, *label))
        })
        .collect();

    if k > distances.len() {
        return None;
    }

    // NaN distances were removed above, so this is a well-defined total order.
    // The sort is stable: equally distant points keep their training order.
    distances.sort_by(|a, b| a.0.total_cmp(&b.0));

    let k_nearest = &distances[..k];

    // Labels are recorded in order of first appearance, i.e. nearest first.
    let mut label_counts: Vec<(f64, usize)> = Vec::new();
    for (_, label) in k_nearest {
        let found = label_counts
            .iter_mut()
            .find(|(l, _)| (l - label).abs() < LABEL_TOLERANCE);
        if let Some((_, count)) = found {
            *count += 1;
        } else {
            label_counts.push((*label, 1));
        }
    }

    // `max_by_key` returns the last maximum it sees; iterating in reverse
    // therefore resolves ties in favor of the label seen nearest to the point.
    label_counts
        .iter()
        .rev()
        .max_by_key(|(_, count)| *count)
        .map(|(label, _)| *label)
}
52+
53+
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the classifier on borrowed inputs so one data set can be reused
    /// across several assertions without cloning at every call site.
    fn classify(samples: &[(Vec<f64>, f64)], point: &[f64], k: usize) -> Option<f64> {
        k_nearest_neighbors(samples.to_vec(), point.to_vec(), k)
    }

    /// Two well-separated 2-D clusters: a query near either cluster takes
    /// that cluster's label.
    #[test]
    fn test_standard_knn() {
        let samples = vec![
            (vec![0.0, 0.0], 0.0),
            (vec![1.0, 0.0], 0.0),
            (vec![0.0, 1.0], 0.0),
            (vec![5.0, 5.0], 1.0),
            (vec![6.0, 5.0], 1.0),
            (vec![5.0, 6.0], 1.0),
        ];

        assert_eq!(classify(&samples, &[0.5, 0.5], 3), Some(0.0));
        assert_eq!(classify(&samples, &[5.5, 5.5], 3), Some(1.0));
    }

    /// The same idea with a single feature per point.
    #[test]
    fn test_one_dimensional_knn() {
        let samples = vec![
            (vec![1.0], 0.0),
            (vec![2.0], 0.0),
            (vec![3.0], 0.0),
            (vec![8.0], 1.0),
            (vec![9.0], 1.0),
            (vec![10.0], 1.0),
        ];

        assert_eq!(classify(&samples, &[2.5], 3), Some(0.0));
    }

    /// Without any training data there is nothing to vote with.
    #[test]
    fn test_knn_empty_data() {
        let samples: Vec<(Vec<f64>, f64)> = Vec::new();

        assert_eq!(classify(&samples, &[1.0, 2.0], 3), None);
    }

    /// `k` must be at least one and no larger than the training set.
    #[test]
    fn test_knn_invalid_k() {
        let samples = vec![(vec![1.0], 0.0), (vec![2.0], 1.0)];
        let point = [1.5];

        // k = 0 should return None
        assert_eq!(classify(&samples, &point, 0), None);

        // k > training_data.len() should return None
        assert_eq!(classify(&samples, &point, 10), None);
    }

    /// A training point with a different number of features must not
    /// influence the vote when enough compatible points exist.
    #[test]
    fn test_euclidean_distance_different_dimensions() {
        let samples = vec![
            (vec![1.0, 2.0], 0.0),
            (vec![2.0, 3.0], 0.0),
            (vec![5.0], 1.0),
        ];

        assert_eq!(classify(&samples, &[1.5, 2.5], 2), Some(0.0));
    }
}

src/machine_learning/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
mod cholesky;
22
mod k_means;
3+
mod k_nearest_neighbors;
34
mod linear_regression;
45
mod logistic_regression;
56
mod loss_function;
67
mod optimization;
78

89
pub use self::cholesky::cholesky;
910
pub use self::k_means::k_means;
11+
pub use self::k_nearest_neighbors::k_nearest_neighbors;
1012
pub use self::linear_regression::linear_regression;
1113
pub use self::logistic_regression::logistic_regression;
1214
pub use self::loss_function::average_margin_ranking_loss;

0 commit comments

Comments
 (0)