diff --git a/rust/src/graph.rs b/rust/src/graph.rs index 6eae7e31..a9853f35 100644 --- a/rust/src/graph.rs +++ b/rust/src/graph.rs @@ -159,7 +159,11 @@ fn compute_single_swing( let target_i32 = target_item as i32; let users = get_row_vec(item_interactions, target_item); if users.is_empty() { - return (target_i32, Vec::new()); + let scores = prev_scores + .get(&target_i32) + .cloned() + .unwrap_or_default(); + return (target_i32, scores); } let mut item_scores = init_item_scores(target_i32, n_items, prev_scores); diff --git a/rust/src/sparse.rs b/rust/src/sparse.rs index 26e9c674..ad0bd1b8 100644 --- a/rust/src/sparse.rs +++ b/rust/src/sparse.rs @@ -48,7 +48,7 @@ impl CsrMatrix { dok_matrix.merge(other).to_csr() } - fn iter(&self) -> CsrMatrixIterator { + fn iter(&self) -> CsrMatrixIterator<'_, T, U> { CsrMatrixIterator { matrix: self, row_idx: 0, diff --git a/rust/src/swing.rs b/rust/src/swing.rs index 7db144c4..766e64fd 100644 --- a/rust/src/swing.rs +++ b/rust/src/swing.rs @@ -11,13 +11,11 @@ use crate::sparse::{get_row, CsrMatrix}; #[pyclass(module = "recfarm", name = "Swing")] #[derive(Serialize, Deserialize)] pub struct PySwing { - task: String, top_k: usize, alpha: f32, max_cache_num: usize, n_users: usize, n_items: usize, - cum_swings: FxHashMap, swing_score_mapping: FxHashMap>, user_interactions: CsrMatrix, item_interactions: CsrMatrix, @@ -45,7 +43,6 @@ impl PySwing { #[new] fn new( - task: &str, top_k: usize, alpha: f32, max_cache_num: usize, @@ -60,13 +57,11 @@ impl PySwing { let user_interactions: CsrMatrix = user_interactions.extract()?; let item_interactions: CsrMatrix = item_interactions.extract()?; Ok(Self { - task: task.to_owned(), top_k, alpha, max_cache_num, n_users, n_items, - cum_swings: FxHashMap::default(), swing_score_mapping: FxHashMap::default(), user_interactions, item_interactions, @@ -75,11 +70,9 @@ impl PySwing { }) } - fn compute_swing(&mut self, num_threads: usize, update_scores: bool) -> PyResult<()> { + fn compute_swing(&mut self, num_threads: usize) -> PyResult<()> { std::env::set_var("RAYON_NUM_THREADS", format!("{num_threads}")); - if !update_scores { - self.swing_score_mapping.clear(); - } + self.swing_score_mapping.clear(); self.swing_score_mapping = compute_swing_scores( &self.user_interactions, &self.item_interactions, @@ -92,6 +85,57 @@ impl PySwing { Ok(()) } + /// update on new sparse interactions + fn update_swing( + &mut self, + num_threads: usize, + user_interactions: &Bound<'_, PyAny>, + item_interactions: &Bound<'_, PyAny>, + ) -> PyResult<()> { + std::env::set_var("RAYON_NUM_THREADS", format!("{num_threads}")); + let new_user_interactions: CsrMatrix = user_interactions.extract()?; + let new_item_interactions: CsrMatrix = item_interactions.extract()?; + self.swing_score_mapping = compute_swing_scores( + &new_user_interactions, + &new_item_interactions, + &self.swing_score_mapping, + self.n_users, + self.n_items, + self.alpha, + self.max_cache_num, + )?; + + // merge interactions for inference on new users/items + self.user_interactions = CsrMatrix::merge( + &self.user_interactions, + &new_user_interactions, + Some(self.n_users), + ); + self.item_interactions = CsrMatrix::merge( + &self.item_interactions, + &new_item_interactions, + Some(self.n_items), + ); + Ok(()) + } + + // fn get_item_interactions(&self, user: usize) -> PyResult> { + // let start = self.user_interactions.indptr[user]; + // let end = self.user_interactions.indptr[user + 1]; + // let item_interactions = (start..end) + // .map(|i| self.user_interactions.indices[i]) + // .collect(); + // Ok(item_interactions) + // } + // + // fn get_swing_scores(&self, item: i32) -> PyResult> { + // let scores = match self.swing_score_mapping.get(&item).cloned() { + // Some(ss) => ss, + // None => Vec::new(), + // }; + // Ok(scores) + // } + fn num_swing_elements(&self) -> PyResult { if self.swing_score_mapping.is_empty() { return Err(pyo3::exceptions::PyRuntimeError::new_err( @@ -130,7 +174,7 @@ impl PySwing { if k_nb_swings.is_empty() { self.default_pred } else { - compute_pred(&self.task, &k_nb_swings, &k_nb_labels)? + compute_pred("ranking", &k_nb_swings, &k_nb_labels)? } } _ => self.default_pred, @@ -149,8 +193,8 @@ impl PySwing { random_rec: bool, ) -> PyResult<(Vec>, Bound<'py, PyList>)> { let mut recs = Vec::new(); - let mut no_rec_indices = Vec::new(); - for (k, u) in users.iter().enumerate() { + let mut additional_rec_counts = Vec::new(); + for u in users { let u: i32 = u.extract()?; let consumed = self .user_consumed @@ -174,23 +218,25 @@ impl PySwing { } } } + if item_scores.is_empty() { + additional_rec_counts.push(n_rec); recs.push(PyList::empty(py)); - no_rec_indices.push(k); } else { let items = get_rec_items(item_scores, n_rec, random_rec); + additional_rec_counts.push(n_rec - items.len()); recs.push(PyList::new(py, items)?); } } None => { + additional_rec_counts.push(n_rec); recs.push(PyList::empty(py)); - no_rec_indices.push(k); } } } - let no_rec_indices = PyList::new(py, no_rec_indices)?; - Ok((recs, no_rec_indices)) + let additional_rec_counts = PyList::new(py, additional_rec_counts)?; + Ok((recs, additional_rec_counts)) } } @@ -223,12 +269,11 @@ mod tests { } fn get_swing_model() -> Result> { - let task = "ranking"; let top_k = 10; let alpha = 1.0; let cache_common_num = 100; let n_users = 3; - let n_items = 4; + let n_items = 5; let default_pred = 0.0; let swing = Python::with_gil(|py| -> PyResult { // item_interactions: @@ -263,7 +308,6 @@ mod tests { .into_py_dict(py)?; let mut swing = PySwing::new( - task, top_k, alpha, cache_common_num, @@ -274,7 +318,7 @@ mod tests { &user_consumed, default_pred, )?; - swing.compute_swing(2, false)?; + swing.compute_swing(2)?; Ok(swing) })?; Ok(swing) @@ -311,7 +355,9 @@ mod tests { fn test_save_model() -> Result<(), Box> { pyo3::prepare_freethreaded_python(); let model = get_swing_model()?; - let cur_dir = std::env::current_dir()?.to_string_lossy().to_string(); + let cur_dir = std::env::current_dir()? + .to_string_lossy() + .to_string(); let model_name = "swing_model"; save(&model, &cur_dir, model_name)?;