From b37e9d735a1341c683b2a95ed870974584e61af7 Mon Sep 17 00:00:00 2001
From: Dopamine
Date: Tue, 13 Jan 2026 21:47:37 +0800
Subject: [PATCH 1/3] Add KNN classification algorithm

---
 src/machine_learning/k_nearest_neighbors.rs | 114 ++++++++++++++++++++
 src/machine_learning/mod.rs                 |   2 +
 2 files changed, 116 insertions(+)
 create mode 100644 src/machine_learning/k_nearest_neighbors.rs

diff --git a/src/machine_learning/k_nearest_neighbors.rs b/src/machine_learning/k_nearest_neighbors.rs
new file mode 100644
index 00000000000..b7e4b452020
--- /dev/null
+++ b/src/machine_learning/k_nearest_neighbors.rs
@@ -0,0 +1,114 @@
+/// K-Nearest Neighbors (KNN) algorithm for classification.
+/// KNN is a simple, instance-based learning algorithm that classifies
+/// a data point based on the majority class of its k nearest neighbors.
+
+fn euclidean_distance(p1: &[f64], p2: &[f64]) -> f64 {
+    if p1.len() != p2.len() {
+        return f64::INFINITY;
+    }
+
+    p1.iter()
+        .zip(p2.iter())
+        .map(|(a, b)| (a - b).powi(2))
+        .sum::<f64>()
+        .sqrt()
+}
+
+pub fn k_nearest_neighbors(
+    training_data: Vec<(Vec<f64>, f64)>,
+    test_point: Vec<f64>,
+    k: usize,
+) -> Option<f64> {
+    if training_data.is_empty() || k == 0 || k > training_data.len() {
+        return None;
+    }
+
+    let mut distances: Vec<(f64, f64)> = training_data
+        .iter()
+        .map(|(features, label)| (euclidean_distance(&test_point, features), *label))
+        .collect();
+
+    distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
+
+    let k_nearest = &distances[..k];
+
+    let mut label_counts: Vec<(f64, usize)> = Vec::new();
+    for (_, label) in k_nearest {
+        let found = label_counts
+            .iter_mut()
+            .find(|(l, _)| (l - label).abs() < 1e-10);
+        if let Some((_, count)) = found {
+            *count += 1;
+        } else {
+            label_counts.push((*label, 1));
+        }
+    }
+
+    label_counts
+        .iter()
+        .max_by_key(|(_, count)| *count)
+        .map(|(label, _)| *label)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_standard_knn() {
+        let training_data = vec![
+            (vec![0.0, 0.0], 0.0),
+            (vec![1.0, 0.0], 0.0),
+            (vec![0.0, 1.0], 0.0),
+            (vec![5.0, 5.0], 1.0),
+            (vec![6.0, 5.0], 1.0),
+            (vec![5.0, 6.0], 1.0),
+        ];
+
+        let test_point = vec![0.5, 0.5];
+        let result = k_nearest_neighbors(training_data.clone(), test_point, 3);
+        assert_eq!(result, Some(0.0));
+
+        let test_point = vec![5.5, 5.5];
+        let result = k_nearest_neighbors(training_data, test_point, 3);
+        assert_eq!(result, Some(1.0));
+    }
+
+    #[test]
+    fn test_one_dimensional_knn() {
+        let training_data = vec![
+            (vec![1.0], 0.0),
+            (vec![2.0], 0.0),
+            (vec![3.0], 0.0),
+            (vec![8.0], 1.0),
+            (vec![9.0], 1.0),
+            (vec![10.0], 1.0),
+        ];
+
+        let test_point = vec![2.5];
+        let result = k_nearest_neighbors(training_data, test_point, 3);
+        assert_eq!(result, Some(0.0));
+    }
+
+    #[test]
+    fn test_knn_empty_data() {
+        let training_data = vec![];
+        let test_point = vec![1.0, 2.0];
+        let result = k_nearest_neighbors(training_data, test_point, 3);
+        assert_eq!(result, None);
+    }
+
+    #[test]
+    fn test_knn_invalid_k() {
+        let training_data = vec![(vec![1.0], 0.0), (vec![2.0], 1.0)];
+        let test_point = vec![1.5];
+
+        // k = 0 should return None
+        let result = k_nearest_neighbors(training_data.clone(), test_point.clone(), 0);
+        assert_eq!(result, None);
+
+        // k > training_data.len() should return None
+        let result = k_nearest_neighbors(training_data, test_point, 10);
+        assert_eq!(result, None);
+    }
+}
diff --git a/src/machine_learning/mod.rs b/src/machine_learning/mod.rs
index 534326d2121..b4baa8025cf 100644
--- a/src/machine_learning/mod.rs
+++ b/src/machine_learning/mod.rs
@@ -1,5 +1,6 @@
 mod cholesky;
 mod k_means;
+mod k_nearest_neighbors;
 mod linear_regression;
 mod logistic_regression;
 mod loss_function;
@@ -7,6 +8,7 @@ mod optimization;
 
 pub use self::cholesky::cholesky;
 pub use self::k_means::k_means;
+pub use self::k_nearest_neighbors::k_nearest_neighbors;
 pub use self::linear_regression::linear_regression;
 pub use self::logistic_regression::logistic_regression;
 pub use self::loss_function::average_margin_ranking_loss;

From 678a5e12f6370fc4859a8f849997ef9c37c05d8d Mon Sep 17 00:00:00 2001
From: Dopamine
Date: Tue, 13 Jan 2026 21:53:32 +0800
Subject: [PATCH 2/3] update directory

---
 DIRECTORY.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/DIRECTORY.md b/DIRECTORY.md
index 735726149e4..a1b962ae803 100644
--- a/DIRECTORY.md
+++ b/DIRECTORY.md
@@ -201,6 +201,7 @@
   * Machine Learning
     * [Cholesky](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/cholesky.rs)
     * [K-Means](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/k_means.rs)
+    * [K-Nearest Neighbors](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/k_nearest_neighbors.rs)
     * [Linear Regression](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/linear_regression.rs)
     * [Logistic Regression](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/logistic_regression.rs)
     * Loss Function

From d33dadb5a036cbfe5f390a24362bec8dbbbf36cb Mon Sep 17 00:00:00 2001
From: Dopamine
Date: Tue, 13 Jan 2026 22:41:52 +0800
Subject: [PATCH 3/3] add coverage test on euclidean distance

---
 src/machine_learning/k_nearest_neighbors.rs | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/src/machine_learning/k_nearest_neighbors.rs b/src/machine_learning/k_nearest_neighbors.rs
index b7e4b452020..38c9fe1f99b 100644
--- a/src/machine_learning/k_nearest_neighbors.rs
+++ b/src/machine_learning/k_nearest_neighbors.rs
@@ -111,4 +111,16 @@ mod tests {
         let result = k_nearest_neighbors(training_data, test_point, 10);
         assert_eq!(result, None);
     }
+
+    #[test]
+    fn test_euclidean_distance_different_dimensions() {
+        let training_data = vec![
+            (vec![1.0, 2.0], 0.0),
+            (vec![2.0, 3.0], 0.0),
+            (vec![5.0], 1.0),
+        ];
+        let test_point = vec![1.5, 2.5];
+        let result = k_nearest_neighbors(training_data, test_point, 2);
+        assert_eq!(result, Some(0.0));
+    }
 }