Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,12 @@ After completing a milestone, create a pull request with your changes for review
- [x] Surface error messages in UI pages using `st.error`
- [x] Add tests for new error handling in data and transform modules

## PR18: Upload Helpers Refactor

- [x] Implement helper functions in `utils/data.py` for uploading/validating files
- [x] Provide wrapper storing uploaded data in `st.session_state`
- [x] Replace repetitive upload code across pages

## Notes for Development

- Create comprehensive commit messages that clearly describe changes
Expand Down
17 changes: 5 additions & 12 deletions pages/data_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@ def main() -> None:

with st.sidebar:
st.header("Data Options")
uploaded_file = st.file_uploader(
data_utils.upload_data_to_session(
"Upload CSV or Excel file",
type=["csv", "xlsx", "xls"],
key="file_uploader",
session_key="data",
datetime_key="datetime_cols",
uploader_key="file_uploader",
help="Supported formats: CSV, XLSX",
)

Expand All @@ -51,15 +52,7 @@ def main() -> None:
with st.expander("Help"):
st.markdown(ui.help_markdown())

if uploaded_file is not None:
try:
df = data_utils.load_data(uploaded_file)
df = data_utils.convert_dtypes(df)
st.session_state["data"] = df
st.session_state["datetime_cols"] = eda.detect_datetime_columns(df)
st.success("File loaded successfully!")
except (ValueError, TypeError) as exc:
st.error(f"Failed to load file: {exc}")


data = st.session_state.get("data")
if data is not None:
Expand Down
18 changes: 8 additions & 10 deletions pages/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,14 @@ def main() -> None:

with st.sidebar:
mode = st.radio("Mode", ["Single", "Batch"], key="pred_mode")
model_file = st.file_uploader("Upload Model (.joblib)", type=["joblib"], key="model_file")
data_file = st.file_uploader("Upload Data", type=["csv", "xlsx", "xls"], key="pred_data")
model_file = st.file_uploader(
"Upload Model (.joblib)", type=["joblib"], key="model_file"
)
data_utils.upload_data_to_session(
"Upload Data",
session_key="pred_data",
uploader_key="pred_data",
)

model_obj = None
if model_file is not None:
Expand All @@ -32,14 +38,6 @@ def main() -> None:
tmp.flush()
model_obj = predict.load_model(Path(tmp.name))

if data_file is not None:
try:
df = data_utils.load_data(data_file)
df = data_utils.convert_dtypes(df)
st.session_state["pred_data"] = df
except (ValueError, TypeError) as exc:
st.error(f"Failed to load data: {exc}")

data = st.session_state.get("pred_data")

if model_obj is not None and data is not None:
Expand Down
14 changes: 5 additions & 9 deletions pages/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,13 @@ def main() -> None:
st.title("Report Generator")

with st.sidebar:
data_file = st.file_uploader(
"Upload Data", type=["csv", "xlsx", "xls"], key="report_data"
data_utils.upload_data_to_session(
"Upload Data",
session_key="report_data",
uploader_key="report_data",
)

if data_file is not None:
try:
df = data_utils.load_data(data_file)
df = data_utils.convert_dtypes(df)
st.session_state["report_data"] = df
except (ValueError, TypeError) as exc:
st.error(f"Failed to load data: {exc}")


df = st.session_state.get("report_data")
if df is not None:
Expand Down
37 changes: 37 additions & 0 deletions tests/test_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,40 @@ def test_data_summary_empty():
with pytest.raises(ValueError):
data.data_summary(pd.DataFrame())


def test_validate_file_type(tmp_path):
    """A whitelisted extension is accepted and returned in lowercase."""
    csv_path = tmp_path / "a.csv"
    csv_path.write_text("a,b\n1,2")
    assert data.validate_file_type(csv_path, ["csv", "xlsx"]) == ".csv"


def test_validate_file_type_invalid(tmp_path):
    """An extension outside the whitelist raises ValueError."""
    txt_path = tmp_path / "a.txt"
    txt_path.write_text("x")
    with pytest.raises(ValueError):
        data.validate_file_type(txt_path, ["csv"])


def test_process_uploaded_file(tmp_path):
    """process_uploaded_file loads a CSV and stores it under session_key."""
    import streamlit as st

    expected = pd.DataFrame({"a": [1]})
    csv_path = tmp_path / "test.csv"
    expected.to_csv(csv_path, index=False)
    st.session_state.clear()
    data.process_uploaded_file(csv_path, session_key="up")
    pd.testing.assert_frame_equal(st.session_state["up"], expected)


def test_upload_data_to_session(monkeypatch, tmp_path):
    """upload_data_to_session wires the uploader result into session state."""
    import streamlit as st

    expected = pd.DataFrame({"a": [1]})
    csv_path = tmp_path / "test.csv"
    expected.to_csv(csv_path, index=False)
    st.session_state.clear()
    # Stub the Streamlit widget so the helper "receives" our temp file.
    monkeypatch.setattr(st, "file_uploader", lambda *args, **kwargs: csv_path)
    data.upload_data_to_session("Upload", session_key="foo")
    pd.testing.assert_frame_equal(st.session_state["foo"], expected)

71 changes: 70 additions & 1 deletion utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from __future__ import annotations

from pathlib import Path
from typing import Any
from typing import Any, Iterable

import streamlit as st

import pandas as pd

Expand Down Expand Up @@ -50,3 +52,70 @@ def data_summary(df: pd.DataFrame) -> pd.DataFrame:
if df.empty:
raise ValueError("DataFrame is empty")
return df.describe(include="all")


def validate_file_type(file: Any, allowed_types: Iterable[str]) -> str:
    """Validate a file's extension against a whitelist.

    Parameters
    ----------
    file:
        Either an uploaded-file-like object exposing a ``.name`` attribute
        (e.g. a Streamlit ``UploadedFile``) or a filesystem path given as
        ``str`` or ``Path``.
    allowed_types:
        Accepted extensions, case-insensitive, with or without a leading
        dot (e.g. ``["csv", ".XLSX"]``).

    Returns
    -------
    str
        The lowercase extension including the leading dot, e.g. ``".csv"``.

    Raises
    ------
    ValueError
        If ``file`` is ``None``, a filesystem path does not exist, or the
        extension is not in ``allowed_types``.
    """
    if file is None:
        raise ValueError("No file provided")

    # Bug fix: Path objects also have a ``.name`` attribute, so they used
    # to skip the existence check that str paths received.  Route all
    # filesystem paths (str and Path alike) through the same branch.
    if hasattr(file, "name") and not isinstance(file, Path):
        # Uploaded-file-like object: only a name is available, nothing on disk.
        ext = Path(file.name).suffix.lower()
    else:
        path = Path(str(file))
        if not path.exists():
            raise ValueError("Invalid file path")
        ext = path.suffix.lower()

    # Normalize the whitelist to ".ext" lowercase before comparing.
    allowed = {f".{t.lstrip('.').lower()}" for t in allowed_types}
    if ext not in allowed:
        raise ValueError(f"Unsupported file type: {ext}")
    return ext


def process_uploaded_file(
    uploaded_file: Any,
    *,
    session_key: str,
    detect_datetime: bool = False,
    datetime_key: str = "datetime_cols",
) -> pd.DataFrame | None:
    """Load an uploaded file and store the DataFrame in ``st.session_state``.

    Parameters
    ----------
    uploaded_file:
        The object returned by ``st.file_uploader`` (or a path accepted by
        ``load_data``); ``None`` means nothing was uploaded.
    session_key:
        Key under which the loaded DataFrame is stored in session state.
    detect_datetime:
        When True, also detect datetime columns and store the result.
    datetime_key:
        Session-state key for the detected datetime columns.

    Returns
    -------
    pd.DataFrame | None
        The loaded DataFrame, or ``None`` when nothing was uploaded or
        loading failed (the error is shown in the UI via ``st.error``).
    """
    if uploaded_file is None:
        return None
    try:
        validate_file_type(uploaded_file, ["csv", "xls", "xlsx"])
        df = load_data(uploaded_file)
        df = convert_dtypes(df)
    except (ValueError, TypeError) as exc:
        # Surface the failure in the UI instead of raising; callers treat a
        # None return as "nothing loaded".
        st.error(f"Failed to load file: {exc}")
        return None
    st.session_state[session_key] = df
    if detect_datetime:
        # NOTE(review): local import — presumably deferred to avoid a
        # circular import between utils.data and utils.eda; confirm.
        from . import eda

        st.session_state[datetime_key] = eda.detect_datetime_columns(df)
    return df


def upload_data_to_session(
    label: str,
    *,
    session_key: str,
    datetime_key: str | None = None,
    uploader_key: str | None = None,
    help: str | None = None,
    types: Iterable[str] = ("csv", "xlsx", "xls"),
) -> pd.DataFrame | None:
    """Render a file uploader widget and load its result into session state.

    Thin wrapper around ``st.file_uploader`` + ``process_uploaded_file``:
    the loaded DataFrame lands in ``st.session_state[session_key]``, and
    datetime-column detection is enabled only when ``datetime_key`` is given.
    Returns the DataFrame, or ``None`` when nothing was uploaded or loading
    failed.
    """
    # Fall back to a key derived from session_key so each uploader widget
    # stays unique within the app.
    widget_key = uploader_key or f"{session_key}_uploader"
    uploaded = st.file_uploader(
        label,
        type=list(types),
        key=widget_key,
        help=help,
    )
    return process_uploaded_file(
        uploaded,
        session_key=session_key,
        detect_datetime=datetime_key is not None,
        datetime_key=datetime_key or "datetime_cols",
    )