diff --git a/src/algorithm/neighbour/fastpair.rs b/src/algorithm/neighbour/fastpair.rs index 9f663f67..4e99261b 100644 --- a/src/algorithm/neighbour/fastpair.rs +++ b/src/algorithm/neighbour/fastpair.rs @@ -212,7 +212,9 @@ mod tests_fastpair { use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix}; /// Brute force algorithm, used only for comparison and testing - pub fn closest_pair_brute(fastpair: &FastPair>) -> PairwiseDistance { + pub fn closest_pair_brute( + fastpair: &FastPair<'_, f64, DenseMatrix>, + ) -> PairwiseDistance { use itertools::Itertools; let m = fastpair.samples.shape().0; diff --git a/src/linalg/basic/matrix.rs b/src/linalg/basic/matrix.rs index 47c5e9d2..88a0849c 100644 --- a/src/linalg/basic/matrix.rs +++ b/src/linalg/basic/matrix.rs @@ -91,7 +91,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixView<'a, T> { } } -impl<'a, T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixView<'a, T> { +impl fmt::Display for DenseMatrixView<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!( f, @@ -142,7 +142,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> { } } - fn iter_mut<'b>(&'b mut self, axis: u8) -> Box + 'b> { + fn iter_mut<'b>(&'b mut self, axis: u8) -> Box + 'b> { let column_major = self.column_major; let stride = self.stride; let ptr = self.values.as_mut_ptr(); @@ -169,7 +169,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> { } } -impl<'a, T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixMutView<'a, T> { +impl fmt::Display for DenseMatrixMutView<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!( f, @@ -493,7 +493,7 @@ impl EVDDecomposable for DenseMatrix {} impl LUDecomposable for DenseMatrix {} impl SVDDecomposable for DenseMatrix {} -impl<'a, T: Debug + Display + Copy + Sized> Array for DenseMatrixView<'a, T> { +impl Array for DenseMatrixView<'_, T> { fn get(&self, pos: (usize, usize)) -> &T { if self.column_major { &self.values[pos.0 + pos.1 * self.stride] @@ -515,7 +515,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array for DenseMa } } -impl<'a, T: Debug + Display + Copy + Sized> Array for DenseMatrixView<'a, T> { +impl Array for DenseMatrixView<'_, T> { fn get(&self, i: usize) -> &T { if self.nrows == 1 { if self.column_major { @@ -553,11 +553,11 @@ impl<'a, T: Debug + Display + Copy + Sized> Array for DenseMatrixView< } } -impl<'a, T: Debug + Display + Copy + Sized> ArrayView2 for DenseMatrixView<'a, T> {} +impl ArrayView2 for DenseMatrixView<'_, T> {} -impl<'a, T: Debug + Display + Copy + Sized> ArrayView1 for DenseMatrixView<'a, T> {} +impl ArrayView1 for DenseMatrixView<'_, T> {} -impl<'a, T: Debug + Display + Copy + Sized> Array for DenseMatrixMutView<'a, T> { +impl Array for DenseMatrixMutView<'_, T> { fn get(&self, pos: (usize, usize)) -> &T { if self.column_major { &self.values[pos.0 + pos.1 * self.stride] @@ -579,9 +579,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array for DenseMa } } -impl<'a, T: Debug + Display + Copy + Sized> MutArray - for DenseMatrixMutView<'a, T> -{ +impl MutArray for DenseMatrixMutView<'_, T> { fn set(&mut self, pos: (usize, usize), x: T) { if self.column_major { self.values[pos.0 + pos.1 * self.stride] = x; @@ -595,15 +593,16 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray } } -impl<'a, T: Debug + Display + Copy + Sized> MutArrayView2 for DenseMatrixMutView<'a, T> {} +impl MutArrayView2 for DenseMatrixMutView<'_, T> {} -impl<'a, T: Debug + Display + Copy + Sized> ArrayView2 for DenseMatrixMutView<'a, T> {} +impl ArrayView2 for DenseMatrixMutView<'_, T> {} impl MatrixStats for DenseMatrix {} impl MatrixPreprocessing for DenseMatrix {} #[cfg(test)] +#[warn(clippy::reversed_empty_ranges)] mod tests { use super::*; use approx::relative_eq; diff --git a/src/linalg/basic/vector.rs b/src/linalg/basic/vector.rs index 05c03756..d2e0bae6 100644 --- a/src/linalg/basic/vector.rs +++ b/src/linalg/basic/vector.rs @@ -119,7 +119,7 @@ impl Array1 for Vec { } } -impl<'a, T: Debug + Display + Copy + Sized> Array for VecMutView<'a, T> { +impl Array for VecMutView<'_, T> { fn get(&self, i: usize) -> &T { &self.ptr[i] } @@ -138,7 +138,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array for VecMutView<'a, T } } -impl<'a, T: Debug + Display + Copy + Sized> MutArray for VecMutView<'a, T> { +impl MutArray for VecMutView<'_, T> { fn set(&mut self, i: usize, x: T) { self.ptr[i] = x; } @@ -149,10 +149,10 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray for VecMutView<'a } } -impl<'a, T: Debug + Display + Copy + Sized> ArrayView1 for VecMutView<'a, T> {} -impl<'a, T: Debug + Display + Copy + Sized> MutArrayView1 for VecMutView<'a, T> {} +impl ArrayView1 for VecMutView<'_, T> {} +impl MutArrayView1 for VecMutView<'_, T> {} -impl<'a, T: Debug + Display + Copy + Sized> Array for VecView<'a, T> { +impl Array for VecView<'_, T> { fn get(&self, i: usize) -> &T { &self.ptr[i] } @@ -171,7 +171,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array for VecView<'a, T> { } } -impl<'a, T: Debug + Display + Copy + Sized> ArrayView1 for VecView<'a, T> {} +impl ArrayView1 for VecView<'_, T> {} #[cfg(test)] mod tests { diff --git a/src/linalg/ndarray/matrix.rs b/src/linalg/ndarray/matrix.rs index adc8d7e8..5040497a 100644 --- a/src/linalg/ndarray/matrix.rs +++ b/src/linalg/ndarray/matrix.rs @@ -68,7 +68,7 @@ impl ArrayView2 for ArrayBase impl MutArrayView2 for ArrayBase, Ix2> {} -impl<'a, T: Debug + Display + Copy + Sized> BaseArray for ArrayView<'a, T, Ix2> { +impl BaseArray for ArrayView<'_, T, Ix2> { fn get(&self, pos: (usize, usize)) -> &T { &self[[pos.0, pos.1]] } @@ -144,11 +144,9 @@ impl EVDDecomposable for ArrayBase, Ix2> impl LUDecomposable for ArrayBase, Ix2> {} impl SVDDecomposable for ArrayBase, Ix2> {} -impl<'a, T: Debug + Display + Copy + Sized> ArrayView2 for ArrayView<'a, T, Ix2> {} +impl ArrayView2 for ArrayView<'_, T, Ix2> {} -impl<'a, T: Debug + Display + Copy + Sized> BaseArray - for ArrayViewMut<'a, T, Ix2> -{ +impl BaseArray for ArrayViewMut<'_, T, Ix2> { fn get(&self, pos: (usize, usize)) -> &T { &self[[pos.0, pos.1]] } @@ -175,9 +173,7 @@ impl<'a, T: Debug + Display + Copy + Sized> BaseArray } } -impl<'a, T: Debug + Display + Copy + Sized> MutArray - for ArrayViewMut<'a, T, Ix2> -{ +impl MutArray for ArrayViewMut<'_, T, Ix2> { fn set(&mut self, pos: (usize, usize), x: T) { self[[pos.0, pos.1]] = x } @@ -195,9 +191,9 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray } } -impl<'a, T: Debug + Display + Copy + Sized> MutArrayView2 for ArrayViewMut<'a, T, Ix2> {} +impl MutArrayView2 for ArrayViewMut<'_, T, Ix2> {} -impl<'a, T: Debug + Display + Copy + Sized> ArrayView2 for ArrayViewMut<'a, T, Ix2> {} +impl ArrayView2 for ArrayViewMut<'_, T, Ix2> {} #[cfg(test)] mod tests { diff --git a/src/linalg/ndarray/vector.rs b/src/linalg/ndarray/vector.rs index 7105da89..de3f7d93 100644 --- a/src/linalg/ndarray/vector.rs +++ b/src/linalg/ndarray/vector.rs @@ -41,7 +41,7 @@ impl ArrayView1 for ArrayBase impl MutArrayView1 for ArrayBase, Ix1> {} -impl<'a, T: Debug + Display + Copy + Sized> BaseArray for ArrayView<'a, T, Ix1> { +impl BaseArray for ArrayView<'_, T, Ix1> { fn get(&self, i: usize) -> &T { &self[i] } @@ -60,9 +60,9 @@ impl<'a, T: Debug + Display + Copy + Sized> BaseArray for ArrayView<'a } } -impl<'a, T: Debug + Display + Copy + Sized> ArrayView1 for ArrayView<'a, T, Ix1> {} +impl ArrayView1 for ArrayView<'_, T, Ix1> {} -impl<'a, T: Debug + Display + Copy + Sized> BaseArray for ArrayViewMut<'a, T, Ix1> { +impl BaseArray for ArrayViewMut<'_, T, Ix1> { fn get(&self, i: usize) -> &T { &self[i] } @@ -81,7 +81,7 @@ impl<'a, T: Debug + Display + Copy + Sized> BaseArray for ArrayViewMut } } -impl<'a, T: Debug + Display + Copy + Sized> MutArray for ArrayViewMut<'a, T, Ix1> { +impl MutArray for ArrayViewMut<'_, T, Ix1> { fn set(&mut self, i: usize, x: T) { self[i] = x; } @@ -92,8 +92,8 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray for ArrayViewMut< } } -impl<'a, T: Debug + Display + Copy + Sized> ArrayView1 for ArrayViewMut<'a, T, Ix1> {} -impl<'a, T: Debug + Display + Copy + Sized> MutArrayView1 for ArrayViewMut<'a, T, Ix1> {} +impl ArrayView1 for ArrayViewMut<'_, T, Ix1> {} +impl MutArrayView1 for ArrayViewMut<'_, T, Ix1> {} impl Array1 for ArrayBase, Ix1> { fn slice<'a>(&'a self, range: Range) -> Box + 'a> { diff --git a/src/linalg/traits/stats.rs b/src/linalg/traits/stats.rs index 8702a81a..6c3db820 100644 --- a/src/linalg/traits/stats.rs +++ b/src/linalg/traits/stats.rs @@ -142,7 +142,6 @@ pub trait MatrixPreprocessing: MutArrayView2 + Clone { /// /// assert_eq!(a, expected); /// ``` - fn binarize_mut(&mut self, threshold: T) { let (nrows, ncols) = self.shape(); for row in 0..nrows { diff --git a/src/linear/logistic_regression.rs b/src/linear/logistic_regression.rs index 7e934288..c28dc347 100644 --- a/src/linear/logistic_regression.rs +++ b/src/linear/logistic_regression.rs @@ -258,8 +258,8 @@ impl, Y: } } -impl<'a, T: Number + FloatNumber, X: Array2> ObjectiveFunction - for BinaryObjectiveFunction<'a, T, X> +impl> ObjectiveFunction + for BinaryObjectiveFunction<'_, T, X> { fn f(&self, w_bias: &[T]) -> T { let mut f = T::zero(); @@ -313,8 +313,8 @@ struct MultiClassObjectiveFunction<'a, T: Number + FloatNumber, X: Array2> { _phantom_t: PhantomData, } -impl<'a, T: Number + FloatNumber + RealNumber, X: Array2> ObjectiveFunction - for MultiClassObjectiveFunction<'a, T, X> +impl> ObjectiveFunction + for MultiClassObjectiveFunction<'_, T, X> { fn f(&self, w_bias: &[T]) -> T { let mut f = T::zero(); diff --git a/src/naive_bayes/mod.rs b/src/naive_bayes/mod.rs index 31cdd46d..26d91545 100644 --- a/src/naive_bayes/mod.rs +++ b/src/naive_bayes/mod.rs @@ -147,7 +147,7 @@ mod tests { #[derive(Debug, PartialEq, Clone)] struct TestDistribution<'d>(&'d Vec); - impl<'d> NBDistribution for TestDistribution<'d> { + impl NBDistribution for TestDistribution<'_> { fn prior(&self, _class_index: usize) -> f64 { 1. } diff --git a/src/preprocessing/numerical.rs b/src/preprocessing/numerical.rs index ddb74a45..674f6814 100644 --- a/src/preprocessing/numerical.rs +++ b/src/preprocessing/numerical.rs @@ -172,18 +172,14 @@ where T: Number + RealNumber, M: Array2, { - if let Some(output_matrix) = columns.first().cloned() { - return Some( - columns - .iter() - .skip(1) - .fold(output_matrix, |current_matrix, new_colum| { - current_matrix.h_stack(new_colum) - }), - ); - } else { - None - } + columns.first().cloned().map(|output_matrix| { + columns + .iter() + .skip(1) + .fold(output_matrix, |current_matrix, new_colum| { + current_matrix.h_stack(new_colum) + }) + }) } #[cfg(test)] diff --git a/src/readers/csv.rs b/src/readers/csv.rs index f8a03ebd..e9a88436 100644 --- a/src/readers/csv.rs +++ b/src/readers/csv.rs @@ -30,7 +30,7 @@ pub struct CSVDefinition<'a> { /// What seperates the fields in your csv-file? field_seperator: &'a str, } -impl<'a> Default for CSVDefinition<'a> { +impl Default for CSVDefinition<'_> { fn default() -> Self { Self { n_rows_header: 1, diff --git a/src/svm/svc.rs b/src/svm/svc.rs index 6477778b..cc5a0beb 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -360,8 +360,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2 + 'a, Y: Array } } -impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1> PartialEq - for SVC<'a, TX, TY, X, Y> +impl, Y: Array1> PartialEq + for SVC<'_, TX, TY, X, Y> { fn eq(&self, other: &Self) -> bool { if (self.b.unwrap().sub(other.b.unwrap())).abs() > TX::epsilon() * TX::two() @@ -1110,7 +1110,7 @@ mod tests { let svc = SVC::fit(&x, &y, ¶ms).unwrap(); // serialization - let deserialized_svc: SVC = + let deserialized_svc: SVC<'_, f64, i32, _, _> = serde_json::from_str(&serde_json::to_string(&svc).unwrap()).unwrap(); assert_eq!(svc, deserialized_svc); diff --git a/src/svm/svr.rs b/src/svm/svr.rs index e68ebf85..4ce0aa28 100644 --- a/src/svm/svr.rs +++ b/src/svm/svr.rs @@ -281,8 +281,8 @@ impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2, Y: Array1> SVR<' } } -impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2, Y: Array1> PartialEq - for SVR<'a, T, X, Y> +impl, Y: Array1> PartialEq + for SVR<'_, T, X, Y> { fn eq(&self, other: &Self) -> bool { if (self.b - other.b).abs() > T::epsilon() * T::two() @@ -702,7 +702,7 @@ mod tests { let svr = SVR::fit(&x, &y, ¶ms).unwrap(); - let deserialized_svr: SVR, _> = + let deserialized_svr: SVR<'_, f64, DenseMatrix, _> = serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap(); assert_eq!(svr, deserialized_svr); diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index c6596517..5679516a 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -77,7 +77,9 @@ use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; +use crate::linalg::basic::arrays::MutArray; use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1}; +use crate::linalg::basic::matrix::DenseMatrix; use crate::numbers::basenum::Number; use crate::rand_custom::get_rng_impl; @@ -887,11 +889,77 @@ impl, Y: Array1> } importances } + + /// Predict class probabilities for the input samples. + /// + /// # Arguments + /// + /// * `x` - The input samples as a matrix where each row is a sample and each column is a feature. + /// + /// # Returns + /// + /// A `Result` containing a `DenseMatrix` where each row corresponds to a sample and each column + /// corresponds to a class. The values represent the probability of the sample belonging to each class. + /// + /// # Errors + /// + /// Returns an error if at least one row prediction process fails. + pub fn predict_proba(&self, x: &X) -> Result, Failed> { + let (n_samples, _) = x.shape(); + let n_classes = self.classes().len(); + let mut result = DenseMatrix::::zeros(n_samples, n_classes); + + for i in 0..n_samples { + let probs = self.predict_proba_for_row(x, i)?; + for (j, &prob) in probs.iter().enumerate() { + result.set((i, j), prob); + } + } + + Ok(result) + } + + /// Predict class probabilities for a single input sample. + /// + /// # Arguments + /// + /// * `x` - The input matrix containing all samples. + /// * `row` - The index of the row in `x` for which to predict probabilities. + /// + /// # Returns + /// + /// A vector of probabilities, one for each class, representing the probability + /// of the input sample belonging to each class. + fn predict_proba_for_row(&self, x: &X, row: usize) -> Result, Failed> { + let mut node = 0; + + while let Some(current_node) = self.nodes().get(node) { + if current_node.true_child.is_none() && current_node.false_child.is_none() { + // Leaf node reached + let mut probs = vec![0.0; self.classes().len()]; + probs[current_node.output] = 1.0; + return Ok(probs); + } + + let split_feature = current_node.split_feature; + let split_value = current_node.split_value.unwrap_or(f64::NAN); + + if x.get((row, split_feature)).to_f64().unwrap() <= split_value { + node = current_node.true_child.unwrap(); + } else { + node = current_node.false_child.unwrap(); + } + } + + // This should never happen if the tree is properly constructed + Err(Failed::predict("Nodes iteration did not reach leaf")) + } } #[cfg(test)] mod tests { use super::*; + use crate::linalg::basic::arrays::Array; use crate::linalg::basic::matrix::DenseMatrix; #[test] @@ -934,6 +1002,51 @@ mod tests { ); } + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn test_predict_proba() { + let x: DenseMatrix = DenseMatrix::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + ]) + .unwrap(); + let y: Vec = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1]; + + let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap(); + let probabilities = tree.predict_proba(&x).unwrap(); + + assert_eq!(probabilities.shape(), (10, 2)); + + for row in 0..10 { + let row_sum: f64 = probabilities.get_row(row).sum(); + assert!( + (row_sum - 1.0).abs() < 1e-6, + "Row probabilities should sum to 1" + ); + } + + // Check if the first 5 samples have higher probability for class 0 + for i in 0..5 { + assert!(probabilities.get((i, 0)) > probabilities.get((i, 1))); + } + + // Check if the last 5 samples have higher probability for class 1 + for i in 5..10 { + assert!(probabilities.get((i, 1)) > probabilities.get((i, 0))); + } + } + #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test