diff --git a/rust/src/swing.rs b/rust/src/swing.rs index 7db144c4..c8a112b2 100644 --- a/rust/src/swing.rs +++ b/rust/src/swing.rs @@ -11,13 +11,11 @@ use crate::sparse::{get_row, CsrMatrix}; #[pyclass(module = "recfarm", name = "Swing")] #[derive(Serialize, Deserialize)] pub struct PySwing { - task: String, top_k: usize, alpha: f32, max_cache_num: usize, n_users: usize, n_items: usize, - cum_swings: FxHashMap, swing_score_mapping: FxHashMap>, user_interactions: CsrMatrix, item_interactions: CsrMatrix, @@ -45,7 +43,6 @@ impl PySwing { #[new] fn new( - task: &str, top_k: usize, alpha: f32, max_cache_num: usize, @@ -60,13 +57,11 @@ impl PySwing { let user_interactions: CsrMatrix = user_interactions.extract()?; let item_interactions: CsrMatrix = item_interactions.extract()?; Ok(Self { - task: task.to_owned(), top_k, alpha, max_cache_num, n_users, n_items, - cum_swings: FxHashMap::default(), swing_score_mapping: FxHashMap::default(), user_interactions, item_interactions, @@ -130,7 +125,7 @@ impl PySwing { if k_nb_swings.is_empty() { self.default_pred } else { - compute_pred(&self.task, &k_nb_swings, &k_nb_labels)? + compute_pred("ranking", &k_nb_swings, &k_nb_labels)? } } _ => self.default_pred, @@ -223,12 +218,11 @@ mod tests { } fn get_swing_model() -> Result> { - let task = "ranking"; let top_k = 10; let alpha = 1.0; let cache_common_num = 100; let n_users = 3; - let n_items = 4; + let n_items = 5; let default_pred = 0.0; let swing = Python::with_gil(|py| -> PyResult { // item_interactions: @@ -263,7 +257,6 @@ mod tests { .into_py_dict(py)?; let mut swing = PySwing::new( - task, top_k, alpha, cache_common_num,