Skip to content

Implement predict_proba for DecisionTreeClassifier #287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/algorithm/neighbour/fastpair.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,9 @@ mod tests_fastpair {
use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};

/// Brute force algorithm, used only for comparison and testing
pub fn closest_pair_brute(fastpair: &FastPair<f64, DenseMatrix<f64>>) -> PairwiseDistance<f64> {
pub fn closest_pair_brute(
fastpair: &FastPair<'_, f64, DenseMatrix<f64>>,
) -> PairwiseDistance<f64> {
use itertools::Itertools;
let m = fastpair.samples.shape().0;

Expand Down
25 changes: 12 additions & 13 deletions src/linalg/basic/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixView<'a, T> {
}
}

impl<'a, T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixView<'a, T> {
impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixView<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
Expand Down Expand Up @@ -142,7 +142,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> {
}
}

fn iter_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &mut T> + 'b> {
fn iter_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
let column_major = self.column_major;
let stride = self.stride;
let ptr = self.values.as_mut_ptr();
Expand All @@ -169,7 +169,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> {
}
}

impl<'a, T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixMutView<'a, T> {
impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixMutView<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
Expand Down Expand Up @@ -493,7 +493,7 @@ impl<T: Number + RealNumber> EVDDecomposable<T> for DenseMatrix<T> {}
impl<T: Number + RealNumber> LUDecomposable<T> for DenseMatrix<T> {}
impl<T: Number + RealNumber> SVDDecomposable<T> for DenseMatrix<T> {}

impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixView<'a, T> {
impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixView<'_, T> {
fn get(&self, pos: (usize, usize)) -> &T {
if self.column_major {
&self.values[pos.0 + pos.1 * self.stride]
Expand All @@ -515,7 +515,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMa
}
}

impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<'a, T> {
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<'_, T> {
fn get(&self, i: usize) -> &T {
if self.nrows == 1 {
if self.column_major {
Expand Down Expand Up @@ -553,11 +553,11 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<
}
}

impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixView<'a, T> {}
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixView<'_, T> {}

impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for DenseMatrixView<'a, T> {}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for DenseMatrixView<'_, T> {}

impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixMutView<'a, T> {
impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
fn get(&self, pos: (usize, usize)) -> &T {
if self.column_major {
&self.values[pos.0 + pos.1 * self.stride]
Expand All @@ -579,9 +579,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMa
}
}

impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
for DenseMatrixMutView<'a, T>
{
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
fn set(&mut self, pos: (usize, usize), x: T) {
if self.column_major {
self.values[pos.0 + pos.1 * self.stride] = x;
Expand All @@ -595,15 +593,16 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
}
}

impl<'a, T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrixMutView<'a, T> {}
impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrixMutView<'_, T> {}

impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixMutView<'a, T> {}
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixMutView<'_, T> {}

impl<T: RealNumber> MatrixStats<T> for DenseMatrix<T> {}

impl<T: RealNumber> MatrixPreprocessing<T> for DenseMatrix<T> {}

#[cfg(test)]
#[warn(clippy::reversed_empty_ranges)]
mod tests {
use super::*;
use approx::relative_eq;
Expand Down
12 changes: 6 additions & 6 deletions src/linalg/basic/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ impl<T: Debug + Display + Copy + Sized> Array1<T> for Vec<T> {
}
}

impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'a, T> {
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'_, T> {
fn get(&self, i: usize) -> &T {
&self.ptr[i]
}
Expand All @@ -138,7 +138,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'a, T
}
}

impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'a, T> {
impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'_, T> {
fn set(&mut self, i: usize, x: T) {
self.ptr[i] = x;
}
Expand All @@ -149,10 +149,10 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'a
}
}

impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for VecMutView<'a, T> {}
impl<'a, T: Debug + Display + Copy + Sized> MutArrayView1<T> for VecMutView<'a, T> {}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for VecMutView<'_, T> {}
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for VecMutView<'_, T> {}

impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'a, T> {
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'_, T> {
fn get(&self, i: usize) -> &T {
&self.ptr[i]
}
Expand All @@ -171,7 +171,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'a, T> {
}
}

impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for VecView<'a, T> {}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for VecView<'_, T> {}

#[cfg(test)]
mod tests {
Expand Down
16 changes: 6 additions & 10 deletions src/linalg/ndarray/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayBase<OwnedRepr<T>

impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayBase<OwnedRepr<T>, Ix2> {}

impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayView<'a, T, Ix2> {
impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayView<'_, T, Ix2> {
fn get(&self, pos: (usize, usize)) -> &T {
&self[[pos.0, pos.1]]
}
Expand Down Expand Up @@ -144,11 +144,9 @@ impl<T: Number + RealNumber> EVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2>
impl<T: Number + RealNumber> LUDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: Number + RealNumber> SVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}

impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'a, T, Ix2> {}
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'_, T, Ix2> {}

impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)>
for ArrayViewMut<'a, T, Ix2>
{
impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
fn get(&self, pos: (usize, usize)) -> &T {
&self[[pos.0, pos.1]]
}
Expand All @@ -175,9 +173,7 @@ impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)>
}
}

impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
for ArrayViewMut<'a, T, Ix2>
{
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
fn set(&mut self, pos: (usize, usize), x: T) {
self[[pos.0, pos.1]] = x
}
Expand All @@ -195,9 +191,9 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
}
}

impl<'a, T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayViewMut<'a, T, Ix2> {}
impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayViewMut<'_, T, Ix2> {}

impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayViewMut<'a, T, Ix2> {}
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayViewMut<'_, T, Ix2> {}

#[cfg(test)]
mod tests {
Expand Down
12 changes: 6 additions & 6 deletions src/linalg/ndarray/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayBase<OwnedRepr<T>

impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayBase<OwnedRepr<T>, Ix1> {}

impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'a, T, Ix1> {
impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'_, T, Ix1> {
fn get(&self, i: usize) -> &T {
&self[i]
}
Expand All @@ -60,9 +60,9 @@ impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'a
}
}

impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayView<'a, T, Ix1> {}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayView<'_, T, Ix1> {}

impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut<'a, T, Ix1> {
impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut<'_, T, Ix1> {
fn get(&self, i: usize) -> &T {
&self[i]
}
Expand All @@ -81,7 +81,7 @@ impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut
}
}

impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut<'a, T, Ix1> {
impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut<'_, T, Ix1> {
fn set(&mut self, i: usize, x: T) {
self[i] = x;
}
Expand All @@ -92,8 +92,8 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut<
}
}

impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayViewMut<'a, T, Ix1> {}
impl<'a, T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayViewMut<'a, T, Ix1> {}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayViewMut<'_, T, Ix1> {}
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayViewMut<'_, T, Ix1> {}

impl<T: Debug + Display + Copy + Sized> Array1<T> for ArrayBase<OwnedRepr<T>, Ix1> {
fn slice<'a>(&'a self, range: Range<usize>) -> Box<dyn ArrayView1<T> + 'a> {
Expand Down
1 change: 0 additions & 1 deletion src/linalg/traits/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ pub trait MatrixPreprocessing<T: RealNumber>: MutArrayView2<T> + Clone {
///
/// assert_eq!(a, expected);
/// ```

fn binarize_mut(&mut self, threshold: T) {
let (nrows, ncols) = self.shape();
for row in 0..nrows {
Expand Down
8 changes: 4 additions & 4 deletions src/linear/logistic_regression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,8 @@ impl<TX: Number + FloatNumber + RealNumber, TY: Number + Ord, X: Array2<TX>, Y:
}
}

impl<'a, T: Number + FloatNumber, X: Array2<T>> ObjectiveFunction<T, X>
for BinaryObjectiveFunction<'a, T, X>
impl<T: Number + FloatNumber, X: Array2<T>> ObjectiveFunction<T, X>
for BinaryObjectiveFunction<'_, T, X>
{
fn f(&self, w_bias: &[T]) -> T {
let mut f = T::zero();
Expand Down Expand Up @@ -313,8 +313,8 @@ struct MultiClassObjectiveFunction<'a, T: Number + FloatNumber, X: Array2<T>> {
_phantom_t: PhantomData<T>,
}

impl<'a, T: Number + FloatNumber + RealNumber, X: Array2<T>> ObjectiveFunction<T, X>
for MultiClassObjectiveFunction<'a, T, X>
impl<T: Number + FloatNumber + RealNumber, X: Array2<T>> ObjectiveFunction<T, X>
for MultiClassObjectiveFunction<'_, T, X>
{
fn f(&self, w_bias: &[T]) -> T {
let mut f = T::zero();
Expand Down
2 changes: 1 addition & 1 deletion src/naive_bayes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ mod tests {
#[derive(Debug, PartialEq, Clone)]
struct TestDistribution<'d>(&'d Vec<i32>);

impl<'d> NBDistribution<i32, i32> for TestDistribution<'d> {
impl NBDistribution<i32, i32> for TestDistribution<'_> {
fn prior(&self, _class_index: usize) -> f64 {
1.
}
Expand Down
20 changes: 8 additions & 12 deletions src/preprocessing/numerical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,18 +172,14 @@ where
T: Number + RealNumber,
M: Array2<T>,
{
if let Some(output_matrix) = columns.first().cloned() {
return Some(
columns
.iter()
.skip(1)
.fold(output_matrix, |current_matrix, new_colum| {
current_matrix.h_stack(new_colum)
}),
);
} else {
None
}
columns.first().cloned().map(|output_matrix| {
columns
.iter()
.skip(1)
.fold(output_matrix, |current_matrix, new_colum| {
current_matrix.h_stack(new_colum)
})
})
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion src/readers/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub struct CSVDefinition<'a> {
/// What seperates the fields in your csv-file?
field_seperator: &'a str,
}
impl<'a> Default for CSVDefinition<'a> {
impl Default for CSVDefinition<'_> {
fn default() -> Self {
Self {
n_rows_header: 1,
Expand Down
6 changes: 3 additions & 3 deletions src/svm/svc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array
}
}

impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq
for SVC<'a, TX, TY, X, Y>
impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq
for SVC<'_, TX, TY, X, Y>
{
fn eq(&self, other: &Self) -> bool {
if (self.b.unwrap().sub(other.b.unwrap())).abs() > TX::epsilon() * TX::two()
Expand Down Expand Up @@ -1110,7 +1110,7 @@ mod tests {
let svc = SVC::fit(&x, &y, &params).unwrap();

// serialization
let deserialized_svc: SVC<f64, i32, _, _> =
let deserialized_svc: SVC<'_, f64, i32, _, _> =
serde_json::from_str(&serde_json::to_string(&svc).unwrap()).unwrap();

assert_eq!(svc, deserialized_svc);
Expand Down
6 changes: 3 additions & 3 deletions src/svm/svr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,8 @@ impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> SVR<'
}
}

impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> PartialEq
for SVR<'a, T, X, Y>
impl<T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> PartialEq
for SVR<'_, T, X, Y>
{
fn eq(&self, other: &Self) -> bool {
if (self.b - other.b).abs() > T::epsilon() * T::two()
Expand Down Expand Up @@ -702,7 +702,7 @@ mod tests {

let svr = SVR::fit(&x, &y, &params).unwrap();

let deserialized_svr: SVR<f64, DenseMatrix<f64>, _> =
let deserialized_svr: SVR<'_, f64, DenseMatrix<f64>, _> =
serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap();

assert_eq!(svr, deserialized_svr);
Expand Down
Loading
Loading