diff --git a/training_testing.py b/training_testing.py index 69dcbac..0173096 100644 --- a/training_testing.py +++ b/training_testing.py @@ -899,8 +899,15 @@ def test_simple_model(self, embedding_name, emb_model, tokenizer, simple_model, destination_folder3, "model_result.tsv" ) # Model saved path + self.dish_dicts = dict() + self.gold_alignments = dict() + for dish in dish_list: + dish_dict, dish_group_alignments = fetch_dish(dish, folder, alignment_file, recipe_folder_name, emb_model, tokenizer, device, embedding_name) + self.dish_dicts[dish] = dish_dict + self.gold_alignments[dish] = dish_group_alignments + correct_predictions, num_actions, results_df = self.run_model( self.dish_dicts[dish], self.gold_alignments[dish],