From 6a4cd14b2a26a6964db99a5dbdae5c5d3d9bdc77 Mon Sep 17 00:00:00 2001 From: massquantity Date: Sun, 6 Jul 2025 18:34:38 +0800 Subject: [PATCH 1/2] [Rust] Rename add matrix to merge --- rust/src/item_cf.rs | 4 ++-- rust/src/sparse.rs | 35 +++++++++++++++++++---------------- rust/src/user_cf.rs | 4 ++-- 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/rust/src/item_cf.rs b/rust/src/item_cf.rs index 7ce8bbaf..991111bf 100644 --- a/rust/src/item_cf.rs +++ b/rust/src/item_cf.rs @@ -228,12 +228,12 @@ impl PyItemCF { update_by_sims(self.n_items, &cosine_sims, &mut self.sim_mapping)?; // merge interactions for inference on new users/items - self.user_interactions = CsrMatrix::add( + self.user_interactions = CsrMatrix::merge( &self.user_interactions, &new_user_interactions, Some(self.n_users), ); - self.item_interactions = CsrMatrix::add( + self.item_interactions = CsrMatrix::merge( &self.item_interactions, &new_item_interactions, Some(self.n_items), diff --git a/rust/src/sparse.rs b/rust/src/sparse.rs index 6ab45f3c..52ede22f 100644 --- a/rust/src/sparse.rs +++ b/rust/src/sparse.rs @@ -23,7 +23,7 @@ impl CsrMatrix { #[inline] pub fn n_rows(&self) -> usize { - self.indptr.len() - 1 + self.indptr.len().saturating_sub(1) } fn to_dok(&self, n_rows: Option) -> DokMatrix { @@ -39,13 +39,13 @@ impl CsrMatrix { DokMatrix { data } } - pub fn add( + pub fn merge( this: &CsrMatrix, other: &CsrMatrix, n_rows: Option, ) -> CsrMatrix { let mut dok_matrix = this.to_dok(n_rows); - dok_matrix.add(other).to_csr() + dok_matrix.merge(other).to_csr() } fn iter(&self) -> CsrMatrixIterator { @@ -97,16 +97,19 @@ where None }; } - let mut index = start; - let index_iter = std::iter::from_fn(move || { - if index < end { - let item = (matrix.indices[index], matrix.data[index]); - index += 1; - Some(item) - } else { - None - } - }); + + // let mut index = start; + // let index_iter = std::iter::from_fn(move || { + // if index < end { + // let item = (matrix.indices[index], matrix.data[index]); + // index += 1; + // Some(item) + // } else { + // None + // } + // }); + + let index_iter = (start..end).map(|i| (matrix.indices[i], matrix.data[i])); Some(Box::new(index_iter)) } @@ -121,7 +124,7 @@ where T: Copy + Eq + Hash + Ord, U: Copy, { - fn add(&mut self, other: &CsrMatrix) -> &Self { + fn merge(&mut self, other: &CsrMatrix) -> &Self { for (i, row) in other.iter().enumerate() { if row.is_empty() { continue; @@ -183,13 +186,13 @@ mod tests { }; // [[1, 0, 0], [1, 0, 2], [3, 3, 0]] - matrix = CsrMatrix::add(&matrix, &matrix_large, Some(3)); + matrix = CsrMatrix::merge(&matrix, &matrix_large, Some(3)); assert_eq!(matrix.indices, vec![0, 0, 2, 0, 1]); assert_eq!(matrix.indptr, vec![0, 1, 3, 5]); assert_eq!(matrix.data, vec![1, 1, 2, 3, 3]); // [[2, 0, 4], [1, 0, 2], [3, 3, 0]] - matrix = CsrMatrix::add(&matrix, &matrix_small, Some(3)); + matrix = CsrMatrix::merge(&matrix, &matrix_small, Some(3)); assert_eq!(matrix.indices, vec![0, 2, 0, 2, 0, 1]); assert_eq!(matrix.indptr, vec![0, 2, 4, 6]); assert_eq!(matrix.data, vec![2, 4, 1, 2, 3, 3]); diff --git a/rust/src/user_cf.rs b/rust/src/user_cf.rs index 4191bd7e..946262b6 100644 --- a/rust/src/user_cf.rs +++ b/rust/src/user_cf.rs @@ -221,12 +221,12 @@ impl PyUserCF { update_by_sims(self.n_users, &cosine_sims, &mut self.sim_mapping)?; // merge interactions for inference on new users/items - self.user_interactions = CsrMatrix::add( + self.user_interactions = CsrMatrix::merge( &self.user_interactions, &new_user_interactions, Some(self.n_users), ); - self.item_interactions = CsrMatrix::add( + self.item_interactions = CsrMatrix::merge( &self.item_interactions, &new_item_interactions, Some(self.n_items), From b0db859bf7568910e75cbd843d27b5800fd81a63 Mon Sep 17 00:00:00 2001 From: massquantity Date: Sun, 6 Jul 2025 18:35:18 +0800 Subject: [PATCH 2/2] [Rust] Fix empty row in indptr --- rust/src/sparse.rs | 38 ++++++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/rust/src/sparse.rs b/rust/src/sparse.rs index 52ede22f..26e9c674 100644 --- a/rust/src/sparse.rs +++ b/rust/src/sparse.rs @@ -142,14 +142,14 @@ where let mut indptr: Vec = vec![0]; let mut data: Vec = Vec::new(); for d in self.data.iter() { - if d.is_empty() { - continue; + if !d.is_empty() { + let mut mapping: Vec<(&T, &U)> = d.iter().collect(); + mapping.sort_unstable_by_key(|(i, _)| *i); + let (idx, dat): (Vec, Vec) = mapping.into_iter().unzip(); + indices.extend(idx); + data.extend(dat); } - let mut mapping: Vec<(&T, &U)> = d.iter().collect(); - mapping.sort_unstable_by_key(|(i, _)| *i); - let (idx, dat): (Vec, Vec) = mapping.into_iter().unzip(); - indices.extend(idx); - data.extend(dat); + // ensure keeping empty oov row indptr.push(indices.len()); } CsrMatrix { @@ -214,6 +214,28 @@ mod tests { indptr: vec![0, 0, 2, 4], data: vec![1, 2, 3, 3], }; - CsrMatrix::add(&matrix, &matrix_large, Some(new_size)); + CsrMatrix::merge(&matrix, &matrix_large, Some(new_size)); + } + + #[test] + fn test_add_with_empty_rows() { + // [[0, 0, 0], [0, 0, 1]] + let mut matrix = CsrMatrix { + indices: vec![2], + indptr: vec![0, 0, 1], + data: vec![1], + }; + // [[0, 0, 0], [1, 0, 2], [3, 3, 0]] + let matrix_large = CsrMatrix { + indices: vec![0, 2, 0, 1], + indptr: vec![0, 0, 2, 4], + data: vec![1, 2, 3, 3], + }; + + // [[0, 0, 0], [1, 0, 2], [3, 3, 0]] + matrix = CsrMatrix::merge(&matrix, &matrix_large, Some(3)); + assert_eq!(matrix.indices, vec![0, 2, 0, 1]); + assert_eq!(matrix.indptr, vec![0, 0, 2, 4]); + assert_eq!(matrix.data, vec![1, 2, 3, 3]); } }