diff --git a/mcp_server/data_manipulator_agent.py b/mcp_server/data_manipulator_agent.py index 57ef98b..9afa774 100644 --- a/mcp_server/data_manipulator_agent.py +++ b/mcp_server/data_manipulator_agent.py @@ -30,8 +30,8 @@ def plan_manipulation(self, user_query: str, available_files: list[str], raw_dat "You are a Senior Data Engineering Agent.\n" "Your Job: Listen to the user's request and decide how to manipulate their datasets.\n" "Currently, you have access to the following tools:\n" - "1. 'merge_datasets' (Parameters: file_paths (list of strings), output_filename (string), join_column (string))\n" - " - Use 'join_column' if the user specifies a column to merge on (e.g., 'ID', 'Subject'). Provide an empty string if unknown.\n" + "1. 'merge_datasets' (Parameters: file_paths (list of strings), output_filename (string), join_columns (list of strings))\n" + " - Use 'join_columns' if the user specifies one or more columns to merge on (e.g., ['Case', 'age']). Provide an empty list if unknown.\n" "You must return ONLY a JSON object." ) @@ -48,7 +48,7 @@ def plan_manipulation(self, user_query: str, available_files: list[str], raw_dat "parameters": {{ "file_paths": ["file_A.csv", "file_B.csv"], "output_filename": "merged_output.csv", - "join_column": "ID" + "join_columns": ["ID"] }}, "explanation": "Merging file A and B based on the user request." }} @@ -62,7 +62,10 @@ def plan_manipulation(self, user_query: str, available_files: list[str], raw_dat # Execute the merge immediately on the backend if merge_datasets was selected if parsed_result.tool_to_call == "merge_datasets": file_paths = parsed_result.parameters.get("file_paths", []) - join_column = parsed_result.parameters.get("join_column", "") + join_columns = parsed_result.parameters.get("join_columns", []) + + if isinstance(join_columns, str): + join_columns = [join_columns] if join_columns else [] # Robust matching: If the LLM returned nothing, but we only have 2 files in the context, just use them. if not file_paths and len(raw_datasets) >= 2: @@ -72,7 +75,15 @@ def plan_manipulation(self, user_query: str, available_files: list[str], raw_dat # Fetch data from the provided raw_datasets mapped from the frontend dfs = [] + seen_keys = set() + + # Pre-filter file_paths to prevent LLM hallucinating the same file twice + unique_fps = [] for fp in file_paths: + if fp not in unique_fps: + unique_fps.append(fp) + + for fp in unique_fps: matched_key = None # Exact match if fp in raw_datasets: @@ -84,49 +95,88 @@ def plan_manipulation(self, user_query: str, available_files: list[str], raw_dat matched_key = raw_k break - if matched_key: + if matched_key and matched_key not in seen_keys: # Convert frontend dict records back to pandas dataframe df = pd.DataFrame(raw_datasets[matched_key]) dfs.append(df) - else: + seen_keys.add(matched_key) + elif not matched_key: parsed_result.explanation += f" (Warning: File '{fp}' not found in uploaded dataset context)" - # Fallback: if we still didn't find at least 2 datasets, but the context has exactly 2, just use them. - if len(dfs) < 2 and len(raw_datasets) == 2: - dfs = [pd.DataFrame(data) for data in raw_datasets.values()] - parsed_result.explanation += " (Fallback: Merged all available files because specific matches failed.)" + # Fallback: if we still didn't find at least 2 datasets, but the context has 2+ files, just grab the first two. + if len(dfs) < 2 and len(raw_datasets) >= 2: + dfs = [] + for k, data in list(raw_datasets.items())[:2]: + dfs.append(pd.DataFrame(data)) + parsed_result.explanation += " (Fallback: Merged available files because specific matches failed.)" if len(dfs) >= 2: try: - # Try to infer a join column - potential_ids = ['ID', 'id', 'Subject', 'subject', 'RID', 'rid', 'Participant_ID', 'participant_id', 'Case', 'case'] + # Helper to find a fuzzy column match in a dataframe + def find_col(df, query_col): + if query_col in df.columns: + return query_col + # Fuzzy check: case insensitive or substring + for c in df.columns: + if query_col.lower() == c.lower() or query_col.lower() in c.lower() or c.lower() in query_col.lower(): + return c + return None + + # Match join_columns for each dataframe independently + df_join_keys = [[] for _ in dfs] + valid_merge = False + + if join_columns: + valid_merge = True + for req_col in join_columns: + for i, df in enumerate(dfs): + matched_col = find_col(df, req_col) + if matched_col: + df_join_keys[i].append(matched_col) + else: + valid_merge = False # Missing column in at least one DF - valid_join_col = None - if join_column and all(join_column in df.columns for df in dfs): - valid_join_col = join_column - else: - for cand in potential_ids: - if all(cand in df.columns for df in dfs): - valid_join_col = cand - break + # The Universal Fix: Dynamic Intersection + # If no explicit join_columns were provided or they failed, find the exact intersection of all columns. + if not valid_merge: + # Find common columns across all dataframes (case-insensitive intersection) + common_cols_lower = set(col.lower() for col in dfs[0].columns) + for df in dfs[1:]: + common_cols_lower = common_cols_lower.intersection(set(col.lower() for col in df.columns)) + + if common_cols_lower: + valid_merge = True + df_join_keys = [[] for _ in dfs] + # Map the lowercased common columns back to their original case for each dataframe + for c_lower in common_cols_lower: + for i, df in enumerate(dfs): + for orig_col in df.columns: + if orig_col.lower() == c_lower: + df_join_keys[i].append(orig_col) + break - if valid_join_col: - # Merge using inner/outer join based on the column + if valid_merge: merged_df = dfs[0] - for df in dfs[1:]: - merged_df = pd.merge(merged_df, df, on=valid_join_col, how='outer', suffixes=('', '_dup')) - # remove duplicate columns - cols_to_drop = [c for c in merged_df.columns if c.endswith('_dup')] - merged_df.drop(columns=cols_to_drop, inplace=True) + left_keys = df_join_keys[0] + + for i, df in enumerate(dfs[1:], start=1): + right_keys = df_join_keys[i] + + # Perform merge mapping left keys to right keys + merged_df = pd.merge(merged_df, df, left_on=left_keys, right_on=right_keys, how='outer', suffixes=('', f'_file{i+1}')) + + # If the right keys had a different name than the left keys, drop the redundant right key column + for l_key, r_key in zip(left_keys, right_keys): + if l_key != r_key and r_key in merged_df.columns: + merged_df.drop(columns=[r_key], inplace=True) - # Move ID column to front - cols = [valid_join_col] + [c for c in merged_df.columns if c != valid_join_col] + # Move join columns to front + cols = left_keys + [c for c in merged_df.columns if c not in left_keys] merged_df = merged_df[cols] else: - # Fallback: concatenate - merged_df = pd.concat(dfs, axis=1) - # Remove duplicate columns if they arose from concat - merged_df = merged_df.loc[:, ~merged_df.columns.duplicated()] + # If absolutely no columns match, attempting to concat horizontally is dangerous for medical data + # because it assumes perfect row alignment. We will reject the merge to prevent silent data corruption. + raise ValueError("Cannot merge datasets: No common columns found to join on, and horizontal concatenation is unsafe.") # Convert back to CSV string to send to frontend (na_rep outputs empty string for NaNs) parsed_result.merged_csv_data = merged_df.to_csv(index=False, na_rep="") @@ -203,4 +253,4 @@ async def manipulate_endpoint(request: ManipulationRequest): # The team's ports: Planner=8011, Executor=8012, Researcher=9013, Validator=8014. # Use 8015 for the Manipulator print("Starting Data Manipulator A2A Server on Port 8015") - uvicorn.run(app, host="0.0.0.0", port=8015) \ No newline at end of file + uvicorn.run(app, host="0.0.0.0", port=8015)