diff --git a/TODO.md b/TODO.md index adb0382..fc05dc2 100644 --- a/TODO.md +++ b/TODO.md @@ -189,6 +189,12 @@ After completing a milestone, create a pull request with your changes for review - [x] Surface error messages in UI pages using `st.error` - [x] Add tests for new error handling in data and transform modules +## PR18: Upload Helpers Refactor + +- [x] Implement helper functions in `utils/data.py` for uploading/validating files +- [x] Provide wrapper storing uploaded data in `st.session_state` +- [x] Replace repetitive upload code across pages + ## Notes for Development - Create comprehensive commit messages that clearly describe changes diff --git a/pages/data_explorer.py b/pages/data_explorer.py index da68be0..0e80380 100644 --- a/pages/data_explorer.py +++ b/pages/data_explorer.py @@ -26,10 +26,11 @@ def main() -> None: with st.sidebar: st.header("Data Options") - uploaded_file = st.file_uploader( + data_utils.upload_data_to_session( "Upload CSV or Excel file", - type=["csv", "xlsx", "xls"], - key="file_uploader", + session_key="data", + datetime_key="datetime_cols", + uploader_key="file_uploader", help="Supported formats: CSV, XLSX", ) @@ -51,15 +52,7 @@ def main() -> None: with st.expander("Help"): st.markdown(ui.help_markdown()) - if uploaded_file is not None: - try: - df = data_utils.load_data(uploaded_file) - df = data_utils.convert_dtypes(df) - st.session_state["data"] = df - st.session_state["datetime_cols"] = eda.detect_datetime_columns(df) - st.success("File loaded successfully!") - except (ValueError, TypeError) as exc: - st.error(f"Failed to load file: {exc}") + data = st.session_state.get("data") if data is not None: diff --git a/pages/prediction.py b/pages/prediction.py index 82e87b2..004452b 100644 --- a/pages/prediction.py +++ b/pages/prediction.py @@ -22,8 +22,14 @@ def main() -> None: with st.sidebar: mode = st.radio("Mode", ["Single", "Batch"], key="pred_mode") - model_file = st.file_uploader("Upload Model (.joblib)", type=["joblib"], key="model_file") - data_file = st.file_uploader("Upload Data", type=["csv", "xlsx", "xls"], key="pred_data") + model_file = st.file_uploader( + "Upload Model (.joblib)", type=["joblib"], key="model_file" + ) + data_utils.upload_data_to_session( + "Upload Data", + session_key="pred_data", + uploader_key="pred_data", + ) model_obj = None if model_file is not None: @@ -32,14 +38,6 @@ def main() -> None: tmp.flush() model_obj = predict.load_model(Path(tmp.name)) - if data_file is not None: - try: - df = data_utils.load_data(data_file) - df = data_utils.convert_dtypes(df) - st.session_state["pred_data"] = df - except (ValueError, TypeError) as exc: - st.error(f"Failed to load data: {exc}") - data = st.session_state.get("pred_data") if model_obj is not None and data is not None: diff --git a/pages/report.py b/pages/report.py index 61ea129..5e9d0f3 100644 --- a/pages/report.py +++ b/pages/report.py @@ -21,17 +21,13 @@ def main() -> None: st.title("Report Generator") with st.sidebar: - data_file = st.file_uploader( - "Upload Data", type=["csv", "xlsx", "xls"], key="report_data" + data_utils.upload_data_to_session( + "Upload Data", + session_key="report_data", + uploader_key="report_data", ) - if data_file is not None: - try: - df = data_utils.load_data(data_file) - df = data_utils.convert_dtypes(df) - st.session_state["report_data"] = df - except (ValueError, TypeError) as exc: - st.error(f"Failed to load data: {exc}") + df = st.session_state.get("report_data") if df is not None: diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index 04e0de4..d3e5687 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -61,3 +61,40 @@ def test_data_summary_empty(): with pytest.raises(ValueError): data.data_summary(pd.DataFrame()) + +def test_validate_file_type(tmp_path): + file = tmp_path / "a.csv" + file.write_text("a,b\n1,2") + ext = data.validate_file_type(file, ["csv", "xlsx"]) + assert ext == ".csv" + + +def test_validate_file_type_invalid(tmp_path): + file = tmp_path / "a.txt" + file.write_text("x") + with pytest.raises(ValueError): + data.validate_file_type(file, ["csv"]) + + +def test_process_uploaded_file(tmp_path): + import streamlit as st + + df = pd.DataFrame({"a": [1]}) + file = tmp_path / "test.csv" + df.to_csv(file, index=False) + st.session_state.clear() + data.process_uploaded_file(file, session_key="up") + pd.testing.assert_frame_equal(st.session_state["up"], df) + + +def test_upload_data_to_session(monkeypatch, tmp_path): + import streamlit as st + + df = pd.DataFrame({"a": [1]}) + file = tmp_path / "test.csv" + df.to_csv(file, index=False) + st.session_state.clear() + monkeypatch.setattr(st, "file_uploader", lambda *a, **k: file) + data.upload_data_to_session("Upload", session_key="foo") + pd.testing.assert_frame_equal(st.session_state["foo"], df) + diff --git a/utils/data.py b/utils/data.py index c627b5c..393a589 100644 --- a/utils/data.py +++ b/utils/data.py @@ -3,7 +3,9 @@ from __future__ import annotations from pathlib import Path -from typing import Any +from typing import Any, Iterable + +import streamlit as st import pandas as pd @@ -50,3 +52,70 @@ def data_summary(df: pd.DataFrame) -> pd.DataFrame: if df.empty: raise ValueError("DataFrame is empty") return df.describe(include="all") + + +def validate_file_type(file: Any, allowed_types: Iterable[str]) -> str: + """Return the lowercase file extension if allowed.""" + if file is None: + raise ValueError("No file provided") + + if hasattr(file, "name"): + ext = Path(file.name).suffix.lower() + else: + path = Path(str(file)) + ext = path.suffix.lower() + if not path.exists(): + raise ValueError("Invalid file path") + + if ext not in {f".{t.lstrip('.').lower()}" for t in allowed_types}: + raise ValueError(f"Unsupported file type: {ext}") + return ext + + +def process_uploaded_file( + uploaded_file: Any, + *, + session_key: str, + detect_datetime: bool = False, + datetime_key: str = "datetime_cols", +) -> pd.DataFrame | None: + """Load an uploaded file and store the DataFrame in session state.""" + if uploaded_file is None: + return None + try: + _ = validate_file_type(uploaded_file, ["csv", "xls", "xlsx"]) + df = load_data(uploaded_file) + df = convert_dtypes(df) + except (ValueError, TypeError) as exc: # pragma: no cover - tested via wrapper + st.error(f"Failed to load file: {exc}") + return None + st.session_state[session_key] = df + if detect_datetime: + from . import eda + + st.session_state[datetime_key] = eda.detect_datetime_columns(df) + return df + + +def upload_data_to_session( + label: str, + *, + session_key: str, + datetime_key: str | None = None, + uploader_key: str | None = None, + help: str | None = None, + types: Iterable[str] = ("csv", "xlsx", "xls"), +) -> pd.DataFrame | None: + """Upload a file and store the loaded DataFrame in session state.""" + file = st.file_uploader( + label, + type=list(types), + key=uploader_key or f"{session_key}_uploader", + help=help, + ) + return process_uploaded_file( + file, + session_key=session_key, + detect_datetime=datetime_key is not None, + datetime_key=datetime_key or "datetime_cols", + )