Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 84 additions & 34 deletions mcp_server/data_manipulator_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def plan_manipulation(self, user_query: str, available_files: list[str], raw_dat
"You are a Senior Data Engineering Agent.\n"
"Your Job: Listen to the user's request and decide how to manipulate their datasets.\n"
"Currently, you have access to the following tools:\n"
"1. 'merge_datasets' (Parameters: file_paths (list of strings), output_filename (string), join_column (string))\n"
" - Use 'join_column' if the user specifies a column to merge on (e.g., 'ID', 'Subject'). Provide an empty string if unknown.\n"
"1. 'merge_datasets' (Parameters: file_paths (list of strings), output_filename (string), join_columns (list of strings))\n"
" - Use 'join_columns' if the user specifies one or more columns to merge on (e.g., ['Case', 'age']). Provide an empty list if unknown.\n"
"You must return ONLY a JSON object."
)

Expand All @@ -48,7 +48,7 @@ def plan_manipulation(self, user_query: str, available_files: list[str], raw_dat
"parameters": {{
"file_paths": ["file_A.csv", "file_B.csv"],
"output_filename": "merged_output.csv",
"join_column": "ID"
"join_columns": ["ID"]
}},
"explanation": "Merging file A and B based on the user request."
}}
Expand All @@ -62,7 +62,10 @@ def plan_manipulation(self, user_query: str, available_files: list[str], raw_dat
# Execute the merge immediately on the backend if merge_datasets was selected
if parsed_result.tool_to_call == "merge_datasets":
file_paths = parsed_result.parameters.get("file_paths", [])
join_column = parsed_result.parameters.get("join_column", "")
join_columns = parsed_result.parameters.get("join_columns", [])

if isinstance(join_columns, str):
join_columns = [join_columns] if join_columns else []

# Robust matching: If the LLM returned nothing, but we only have 2 files in the context, just use them.
if not file_paths and len(raw_datasets) >= 2:
Expand All @@ -72,7 +75,15 @@ def plan_manipulation(self, user_query: str, available_files: list[str], raw_dat

# Fetch data from the provided raw_datasets mapped from the frontend
dfs = []
seen_keys = set()

# Pre-filter file_paths to prevent LLM hallucinating the same file twice
unique_fps = []
for fp in file_paths:
if fp not in unique_fps:
unique_fps.append(fp)

for fp in unique_fps:
matched_key = None
# Exact match
if fp in raw_datasets:
Expand All @@ -84,49 +95,88 @@ def plan_manipulation(self, user_query: str, available_files: list[str], raw_dat
matched_key = raw_k
break

if matched_key:
if matched_key and matched_key not in seen_keys:
# Convert frontend dict records back to pandas dataframe
df = pd.DataFrame(raw_datasets[matched_key])
dfs.append(df)
else:
seen_keys.add(matched_key)
elif not matched_key:
parsed_result.explanation += f" (Warning: File '{fp}' not found in uploaded dataset context)"

# Fallback: if we still didn't find at least 2 datasets, but the context has exactly 2, just use them.
if len(dfs) < 2 and len(raw_datasets) == 2:
dfs = [pd.DataFrame(data) for data in raw_datasets.values()]
parsed_result.explanation += " (Fallback: Merged all available files because specific matches failed.)"
# Fallback: if we still didn't find at least 2 datasets, but the context has 2+ files, just grab the first two.
if len(dfs) < 2 and len(raw_datasets) >= 2:
dfs = []
for k, data in list(raw_datasets.items())[:2]:
dfs.append(pd.DataFrame(data))
parsed_result.explanation += " (Fallback: Merged available files because specific matches failed.)"

if len(dfs) >= 2:
try:
# Try to infer a join column
potential_ids = ['ID', 'id', 'Subject', 'subject', 'RID', 'rid', 'Participant_ID', 'participant_id', 'Case', 'case']
# Helper to find a fuzzy column match in a dataframe
def find_col(df, query_col):
    """Return the column label in *df* that best matches *query_col*.

    Matching is tiered so the most specific hit always wins:

    1. exact label match,
    2. case-insensitive (and surrounding-whitespace-insensitive) match,
    3. substring containment in either direction.

    Args:
        df: Object exposing a ``columns`` iterable of labels (a DataFrame).
        query_col: Requested column name, typically supplied by the LLM.

    Returns:
        The matching label exactly as it appears in ``df.columns``, or
        ``None`` when nothing matches or the query is empty/``None``.
    """
    if query_col is None:
        return None
    if query_col in df.columns:
        return query_col

    # Labels and queries are not guaranteed to be strings (e.g. integer
    # column names), so normalise through str() before comparing.
    needle = str(query_col).strip().casefold()
    if not needle:
        # An empty string is a substring of everything; without this guard
        # a blank query would silently select the first column.
        return None

    folded_labels = [(col, str(col).strip().casefold()) for col in df.columns]

    # Tier 2: a full case-insensitive match beats any substring match,
    # regardless of column order.
    for col, folded in folded_labels:
        if folded == needle:
            return col

    # Tier 3: loose containment. Blank labels are skipped because "" is
    # contained in every query and would match unconditionally.
    for col, folded in folded_labels:
        if folded and (needle in folded or folded in needle):
            return col
    return None

# Match join_columns for each dataframe independently
df_join_keys = [[] for _ in dfs]
valid_merge = False

if join_columns:
valid_merge = True
for req_col in join_columns:
for i, df in enumerate(dfs):
matched_col = find_col(df, req_col)
if matched_col:
df_join_keys[i].append(matched_col)
else:
valid_merge = False # Missing column in at least one DF

valid_join_col = None
if join_column and all(join_column in df.columns for df in dfs):
valid_join_col = join_column
else:
for cand in potential_ids:
if all(cand in df.columns for df in dfs):
valid_join_col = cand
break
# The Universal Fix: Dynamic Intersection
# If no explicit join_columns were provided or they failed, find the exact intersection of all columns.
if not valid_merge:
# Find common columns across all dataframes (case-insensitive intersection)
common_cols_lower = set(col.lower() for col in dfs[0].columns)
for df in dfs[1:]:
common_cols_lower = common_cols_lower.intersection(set(col.lower() for col in df.columns))

if common_cols_lower:
valid_merge = True
df_join_keys = [[] for _ in dfs]
# Map the lowercased common columns back to their original case for each dataframe
for c_lower in common_cols_lower:
for i, df in enumerate(dfs):
for orig_col in df.columns:
if orig_col.lower() == c_lower:
df_join_keys[i].append(orig_col)
break

if valid_join_col:
# Merge using inner/outer join based on the column
if valid_merge:
merged_df = dfs[0]
for df in dfs[1:]:
merged_df = pd.merge(merged_df, df, on=valid_join_col, how='outer', suffixes=('', '_dup'))
# remove duplicate columns
cols_to_drop = [c for c in merged_df.columns if c.endswith('_dup')]
merged_df.drop(columns=cols_to_drop, inplace=True)
left_keys = df_join_keys[0]

for i, df in enumerate(dfs[1:], start=1):
right_keys = df_join_keys[i]

# Perform merge mapping left keys to right keys
merged_df = pd.merge(merged_df, df, left_on=left_keys, right_on=right_keys, how='outer', suffixes=('', f'_file{i+1}'))

# If the right keys had a different name than the left keys, drop the redundant right key column
for l_key, r_key in zip(left_keys, right_keys):
if l_key != r_key and r_key in merged_df.columns:
merged_df.drop(columns=[r_key], inplace=True)

# Move ID column to front
cols = [valid_join_col] + [c for c in merged_df.columns if c != valid_join_col]
# Move join columns to front
cols = left_keys + [c for c in merged_df.columns if c not in left_keys]
merged_df = merged_df[cols]
else:
# Fallback: concatenate
merged_df = pd.concat(dfs, axis=1)
# Remove duplicate columns if they arose from concat
merged_df = merged_df.loc[:, ~merged_df.columns.duplicated()]
# If absolutely no columns match, attempting to concat horizontally is dangerous for medical data
# because it assumes perfect row alignment. We will reject the merge to prevent silent data corruption.
raise ValueError("Cannot merge datasets: No common columns found to join on, and horizontal concatenation is unsafe.")

# Convert back to CSV string to send to frontend (na_rep outputs empty string for NaNs)
parsed_result.merged_csv_data = merged_df.to_csv(index=False, na_rep="")
Expand Down Expand Up @@ -203,4 +253,4 @@ async def manipulate_endpoint(request: ManipulationRequest):
# The team's ports: Planner=8011, Executor=8012, Researcher=9013, Validator=8014.
# Use 8015 for the Manipulator
print("Starting Data Manipulator A2A Server on Port 8015")
uvicorn.run(app, host="0.0.0.0", port=8015)
uvicorn.run(app, host="0.0.0.0", port=8015)