Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions rust/src/item_cf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,12 @@ impl PyItemCF {
update_by_sims(self.n_items, &cosine_sims, &mut self.sim_mapping)?;

// merge interactions for inference on new users/items
self.user_interactions = CsrMatrix::add(
self.user_interactions = CsrMatrix::merge(
&self.user_interactions,
&new_user_interactions,
Some(self.n_users),
);
self.item_interactions = CsrMatrix::add(
self.item_interactions = CsrMatrix::merge(
&self.item_interactions,
&new_item_interactions,
Some(self.n_items),
Expand Down
73 changes: 49 additions & 24 deletions rust/src/sparse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl<T: Copy + Eq + Hash + Ord, U: Copy> CsrMatrix<T, U> {

#[inline]
pub fn n_rows(&self) -> usize {
self.indptr.len() - 1
self.indptr.len().saturating_sub(1)
}

fn to_dok(&self, n_rows: Option<usize>) -> DokMatrix<T, U> {
Expand All @@ -39,13 +39,13 @@ impl<T: Copy + Eq + Hash + Ord, U: Copy> CsrMatrix<T, U> {
DokMatrix { data }
}

pub fn add(
pub fn merge(
this: &CsrMatrix<T, U>,
other: &CsrMatrix<T, U>,
n_rows: Option<usize>,
) -> CsrMatrix<T, U> {
let mut dok_matrix = this.to_dok(n_rows);
dok_matrix.add(other).to_csr()
dok_matrix.merge(other).to_csr()
}

fn iter(&self) -> CsrMatrixIterator<T, U> {
Expand Down Expand Up @@ -97,16 +97,19 @@ where
None
};
}
let mut index = start;
let index_iter = std::iter::from_fn(move || {
if index < end {
let item = (matrix.indices[index], matrix.data[index]);
index += 1;
Some(item)
} else {
None
}
});

// let mut index = start;
// let index_iter = std::iter::from_fn(move || {
// if index < end {
// let item = (matrix.indices[index], matrix.data[index]);
// index += 1;
// Some(item)
// } else {
// None
// }
// });

let index_iter = (start..end).map(|i| (matrix.indices[i], matrix.data[i]));
Some(Box::new(index_iter))
}

Expand All @@ -121,7 +124,7 @@ where
T: Copy + Eq + Hash + Ord,
U: Copy,
{
fn add(&mut self, other: &CsrMatrix<T, U>) -> &Self {
fn merge(&mut self, other: &CsrMatrix<T, U>) -> &Self {
for (i, row) in other.iter().enumerate() {
if row.is_empty() {
continue;
Expand All @@ -139,14 +142,14 @@ where
let mut indptr: Vec<usize> = vec![0];
let mut data: Vec<U> = Vec::new();
for d in self.data.iter() {
if d.is_empty() {
continue;
if !d.is_empty() {
let mut mapping: Vec<(&T, &U)> = d.iter().collect();
mapping.sort_unstable_by_key(|(i, _)| *i);
let (idx, dat): (Vec<T>, Vec<U>) = mapping.into_iter().unzip();
indices.extend(idx);
data.extend(dat);
}
let mut mapping: Vec<(&T, &U)> = d.iter().collect();
mapping.sort_unstable_by_key(|(i, _)| *i);
let (idx, dat): (Vec<T>, Vec<U>) = mapping.into_iter().unzip();
indices.extend(idx);
data.extend(dat);
// ensure keeping empty oov row
indptr.push(indices.len());
}
CsrMatrix {
Expand Down Expand Up @@ -183,13 +186,13 @@ mod tests {
};

// [[1, 0, 0], [1, 0, 2], [3, 3, 0]]
matrix = CsrMatrix::add(&matrix, &matrix_large, Some(3));
matrix = CsrMatrix::merge(&matrix, &matrix_large, Some(3));
assert_eq!(matrix.indices, vec![0, 0, 2, 0, 1]);
assert_eq!(matrix.indptr, vec![0, 1, 3, 5]);
assert_eq!(matrix.data, vec![1, 1, 2, 3, 3]);

// [[2, 0, 4], [1, 0, 2], [3, 3, 0]]
matrix = CsrMatrix::add(&matrix, &matrix_small, Some(3));
matrix = CsrMatrix::merge(&matrix, &matrix_small, Some(3));
assert_eq!(matrix.indices, vec![0, 2, 0, 2, 0, 1]);
assert_eq!(matrix.indptr, vec![0, 2, 4, 6]);
assert_eq!(matrix.data, vec![2, 4, 1, 2, 3, 3]);
Expand All @@ -211,6 +214,28 @@ mod tests {
indptr: vec![0, 0, 2, 4],
data: vec![1, 2, 3, 3],
};
CsrMatrix::add(&matrix, &matrix_large, Some(new_size));
CsrMatrix::merge(&matrix, &matrix_large, Some(new_size));
}

#[test]
fn test_add_with_empty_rows() {
// [[0, 0, 0], [0, 0, 1]]
let mut matrix = CsrMatrix {
indices: vec![2],
indptr: vec![0, 0, 1],
data: vec![1],
};
// [[0, 0, 0], [1, 0, 2], [3, 3, 0]]
let matrix_large = CsrMatrix {
indices: vec![0, 2, 0, 1],
indptr: vec![0, 0, 2, 4],
data: vec![1, 2, 3, 3],
};

// [[0, 0, 0], [1, 0, 2], [3, 3, 0]]
matrix = CsrMatrix::merge(&matrix, &matrix_large, Some(3));
assert_eq!(matrix.indices, vec![0, 2, 0, 1]);
assert_eq!(matrix.indptr, vec![0, 0, 2, 4]);
assert_eq!(matrix.data, vec![1, 2, 3, 3]);
}
}
4 changes: 2 additions & 2 deletions rust/src/user_cf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,12 @@ impl PyUserCF {
update_by_sims(self.n_users, &cosine_sims, &mut self.sim_mapping)?;

// merge interactions for inference on new users/items
self.user_interactions = CsrMatrix::add(
self.user_interactions = CsrMatrix::merge(
&self.user_interactions,
&new_user_interactions,
Some(self.n_users),
);
self.item_interactions = CsrMatrix::add(
self.item_interactions = CsrMatrix::merge(
&self.item_interactions,
&new_item_interactions,
Some(self.n_items),
Expand Down
Loading