diff --git a/libreco/algorithms/swing.py b/libreco/algorithms/swing.py index 7c9430df..1beae679 100644 --- a/libreco/algorithms/swing.py +++ b/libreco/algorithms/swing.py @@ -79,19 +79,26 @@ def fit( self.show_start_time() user_interacts = build_sparse(train_data.sparse_interaction) item_interacts = build_sparse(train_data.sparse_interaction, transpose=True) - self.rs_model = recfarm.Swing( - self.top_k, - self.alpha, - self.max_cache_num, - self.n_users, - self.n_items, - user_interacts, - item_interacts, - self.user_consumed, - self.default_pred, - ) - with time_block("swing computing", verbose=1): - self.rs_model.compute_swing(self.num_threads, self.incremental) + if self.incremental: + assert self.rs_model is not None + with time_block("update swing", verbose=1): + self.rs_model.update_swing( + self.num_threads, user_interacts, item_interacts + ) + else: + self.rs_model = recfarm.Swing( + self.top_k, + self.alpha, + self.max_cache_num, + self.n_users, + self.n_items, + user_interacts, + item_interacts, + self.user_consumed, + self.default_pred, + ) + with time_block("swing computing", verbose=1): + self.rs_model.compute_swing(self.num_threads) num = self.rs_model.num_swing_elements() density_ratio = 100 * num / (self.n_items * self.n_items) @@ -137,17 +144,21 @@ def recommend_user( result_recs[u] = popular_recommendations( self.data_info, inner_id, n_rec ) + if user_ids: - computed_recs, no_rec_indices = self.rs_model.recommend( + computed_recs, additional_rec_counts = self.rs_model.recommend( user_ids, n_rec, filter_consumed, random_rec, ) - for i in no_rec_indices: - computed_recs[i] = popular_recommendations( - self.data_info, inner_id=True, n_rec=n_rec - ) + for rec, arc in zip(computed_recs, additional_rec_counts): + if arc > 0: + additional_recs = popular_recommendations( + self.data_info, inner_id=True, n_rec=arc + ) + rec.extend(additional_recs) + user_recs = construct_rec(self.data_info, user_ids, computed_recs, inner_id) result_recs.update(user_recs) return result_recs