diff --git a/docker-compose.yml b/docker-compose.yml index 2f580fa..9b829e7 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -185,13 +185,16 @@ services: - database - sftp_receiver environment: - - DATABASE_URL=postgresql://myuser:mypassword@database:5432/mydatabase + - DATABASE_URL=postgresql://webserver_role:webserverpassword@database:5432/mydatabase + - SECRET_KEY=${SECRET_KEY:-changeme-set-a-real-secret-in-production} + - USERS_CONFIG_PATH=/app/configs/users.json - STORAGE_BACKEND=local # Options: local, s3, minio - STORAGE_BASE_PATH=/app/data volumes: - sftp_openmrg_uploads:/app/data/incoming:ro # Read-only access to OpenMRG SFTP uploads - webserver_data_staged:/app/data/staged - webserver_data_archived:/app/data/archived + - ./webserver/configs:/app/configs:ro sftp_receiver: image: atmoz/sftp:latest diff --git a/docs/multi-user-architecture.md b/docs/multi-user-architecture.md index 8d1af72..851219d 100644 --- a/docs/multi-user-architecture.md +++ b/docs/multi-user-architecture.md @@ -2,18 +2,19 @@ ## Status -PRs 1–2 are merged. The database schema and isolation model are in place. -PRs 3–7 remain and are described below. +PRs 1–5 are merged. The database schema, isolation model, and webserver login are in place. +PRs 6–8 remain and are described below. 
| PR | Branch | Status | Scope | |----|--------|--------|-------| | 1 | `feat/db-add-user-id` | merged | `user_id` columns, updated aggregate + compression | | 2 | `feat/db-roles-rls` | merged | Roles, RLS, security-barrier views | -| 3 | `feat/parser-user-id` | **next** | Parser injects `user_id`; removes compat defaults | -| 4 | `feat/sftp-multi-user` | not started | Per-user SFTP dirs, volumes, parser instances | -| 5 | `feat/webserver-auth` | not started | Login, session, DB role switching — go-live milestone | +| 3 | `feat/parser-user-id` | merged | Parser injects `user_id`; removes compat defaults | +| 4 | `feat/sftp-multi-user` | merged | Per-user SFTP dirs, volumes, parser instances | +| 5 | `feat/webserver-auth` | merged | Login, session, DB role switching — go-live milestone | | 6 | `feat/web-api-upload` | not started | HTTP API upload + drag-and-drop | | 7 | `feat/user-onboarding` | not started | `add_user.sh`, docs | +| 8 | `feat/grafana-auth-proxy` | not started | Per-user Grafana datasources + auth proxy header | --- @@ -557,13 +558,115 @@ No second user should be onboarded until this is resolved. --- +## PR8 — `feat/grafana-auth-proxy` + +**Goal:** Grafana dashboards are scoped to the logged-in user without a separate Grafana login. +Users retain full interactive dashboard access and can build their own panels. + +### Known gap (until this PR) + +Grafana currently connects as `myuser` (superuser) and sees all tenants' data regardless of +which user is logged in to the webserver. The `/grafana/` proxy is gated by `@login_required` +(PR5), so unauthenticated access is blocked, but data isolation within Grafana is not enforced. + +### Approach — Grafana auth proxy + per-user datasources + +Grafana's [Auth Proxy](https://grafana.com/docs/grafana/latest/setup-grafana/configure-security/configure-authentication/auth-proxy/) +mode trusts an upstream header (`X-WEBAUTH-USER`) set by a reverse proxy or, in our case, the +Flask `/grafana/` proxy. 
Grafana auto-provisions a Grafana user on first login and maps them to
+an org/team. Combined with per-user PostgreSQL datasources (each connecting as the matching PG
+role), queries are automatically scoped to that user's data via RLS and security-barrier views.
+
+**Data isolation chain:**
+```
+Flask session → X-WEBAUTH-USER header → Grafana user → per-user datasource → PG role → RLS
+```
+
+### Changes
+
+**`grafana/provisioning/datasources/postgres.yml`** — replace single `myuser` datasource with
+one datasource per user:
+
+```yaml
+apiVersion: 1
+datasources:
+  - name: demo_openmrg
+    uid: ds_demo_openmrg
+    type: grafana-postgresql-datasource
+    access: proxy
+    url: database:5432
+    database: mydatabase
+    user: demo_openmrg
+    secureJsonData:
+      password: <demo_openmrg DB password>
+    jsonData:
+      sslmode: disable
+
+  - name: demo_orange_cameroun
+    uid: ds_demo_orange_cameroun
+    type: grafana-postgresql-datasource
+    access: proxy
+    url: database:5432
+    database: mydatabase
+    user: demo_orange_cameroun
+    secureJsonData:
+      password: <demo_orange_cameroun DB password>
+    jsonData:
+      sslmode: disable
+```
+
+**`grafana/provisioning/datasources/postgres.yml`** — also keep an admin datasource connecting
+as `webserver_role` for cross-tenant dashboards used by operators.
+
+**`docker-compose.yml` — Grafana environment:**
+
+```yaml
+  grafana:
+    environment:
+      - GF_AUTH_PROXY_ENABLED=true
+      - GF_AUTH_PROXY_HEADER_NAME=X-WEBAUTH-USER
+      - GF_AUTH_PROXY_HEADER_PROPERTY=username
+      - GF_AUTH_PROXY_AUTO_SIGN_UP=true
+      - GF_AUTH_PROXY_WHITELIST=webserver # only accept header from the webserver container
+      - GF_AUTH_DISABLE_LOGIN_FORM=true
+      - GF_AUTH_ANONYMOUS_ENABLED=false
+```
+
+**`webserver/main.py` — inject header in Grafana proxy:**
+
+```python
+@app.route("/grafana/", defaults={"path": ""}, methods=[...])
+@app.route("/grafana/<path:path>", methods=[...])
+@login_required
+def grafana_proxy(path):
+    headers = {k: v for k, v in request.headers if k.lower() != "host"}
+    headers["X-WEBAUTH-USER"] = current_user.id # inject identity
+    ...
+``` + +### Onboarding impact + +`scripts/add_user.sh` (PR7) must also provision the Grafana datasource and add a `GF_` +environment variable or Grafana API call to create the user's org/team mapping. + +### Security notes + +- `GF_AUTH_PROXY_WHITELIST` must restrict the trusted header to the webserver container IP/name + so external clients cannot forge `X-WEBAUTH-USER`. +- Grafana datasource passwords are dev defaults; rotate before production. +- Per-user datasources connecting as PG login roles provide the same DB-level isolation as the + webserver (RLS on `cml_metadata`/`cml_stats`, security-barrier views for `cml_data`). + +--- + ## Success criteria - Each user's `cml_metadata` and `cml_stats` rows are invisible to other user roles (RLS). - Each user role cannot read `cml_data` directly; only `cml_data_secure` is accessible. - `webserver_role` without `SET ROLE` can read all tenants' metadata and stats (admin path). - After `SET ROLE user1`, all queries on `cml_data_secure` and `cml_data_1h_secure` return only `user_id = 'user1'` rows. -- The webserver requires login on all routes (PR5). +- The webserver requires login on all routes (PR5). ✓ - A second user can be fully onboarded without touching the running DB schema (PR7). +- Grafana dashboards are scoped to the logged-in user; no cross-tenant data visible (PR8). - Database RAM stays ≤ 3 GB for 10 users (compression + aggregate already in place). diff --git a/webserver/README.md b/webserver/README.md new file mode 100644 index 0000000..1d155b8 --- /dev/null +++ b/webserver/README.md @@ -0,0 +1,44 @@ +# Webserver + +Flask application serving the GMDI data portal. + +## User Management + +Users are stored in `configs/users.json`. Each entry maps a **user ID** (which must match the corresponding PostgreSQL role name) to a display name and a hashed password. 
+
+```json
+{
+    "alice": {
+        "display_name": "Alice",
+        "password_hash": "<hash generated as shown below>"
+    }
+}
+```
+
+### Generating a password hash
+
+Use werkzeug (already installed in the webserver image) to produce a hash:
+
+```bash
+python -c "from werkzeug.security import generate_password_hash; print(generate_password_hash('yourpassword'))"
+```
+
+Copy the output into `password_hash`. The hash format is `scrypt:32768:8:1$<salt>$<hash>` — werkzeug selects the algorithm and parameters automatically.
+
+### Adding a user
+
+1. Create the PostgreSQL role in the database (see `database/migrations/` for examples).
+2. Add an entry to `configs/users.json` with the generated hash.
+3. Restart the webserver container (it reads the file at startup).
+
+> **Important:** the user ID in `users.json` must exactly match the PostgreSQL role name, because the webserver issues `SET LOCAL ROLE <user_id>` to scope every DB query to that tenant.
+
+## Running Tests
+
+```bash
+docker compose run --rm --no-deps \
+  -e DATABASE_URL=postgresql://x:x@localhost/x \
+  -e USERS_CONFIG_PATH=/app/configs/users.json \
+  -v "$(pwd)/configs:/app/configs:ro" \
+  webserver sh -c "pip install pytest pytest-cov && python -m pytest tests/ -v"
+```
diff --git a/webserver/configs/users.json b/webserver/configs/users.json
new file mode 100644
index 0000000..65de9c3
--- /dev/null
+++ b/webserver/configs/users.json
@@ -0,0 +1,10 @@
+{
+    "demo_openmrg": {
+        "password_hash": "scrypt:32768:8:1$HLOwGuhFRtd4Dah3$f7dca30ff20c0da01f53569bf7396bdaec4bbd428ef1875f07791a857d57d8434c8fbfa67269453a0980769e1db6787ab46c01e8e33b57aa160d615db385a944",
+        "display_name": "OpenMRG Demo"
+    },
+    "demo_orange_cameroun": {
+        "password_hash": "scrypt:32768:8:1$EpwEQPpmJkYCDx4I$662c778d419645ac0f8be645b5af56543e1f47cfc5a4d33c261935db3750941d0ce9e107a5586415f112c5b31e57d85ada449d8c678a5de7169fce7854647a54",
+        "display_name": "Orange Cameroun Demo"
+    }
+}
\ No newline at end of file
diff --git a/webserver/main.py b/webserver/main.py
index 92d2707..64e3a48 100644
---
a/webserver/main.py +++ b/webserver/main.py @@ -3,20 +3,66 @@ import time import math import psycopg2 +from psycopg2 import sql as pgsql import folium import requests from markupsafe import escape -from flask import Flask, render_template, request, jsonify, Response, redirect +from flask import ( + Flask, + render_template, + request, + jsonify, + Response, + redirect, + url_for, + flash, +) +from flask_login import ( + LoginManager, + UserMixin, + login_user, + logout_user, + login_required, + current_user, +) +from werkzeug.security import check_password_hash from werkzeug.utils import secure_filename from datetime import datetime, timedelta from pathlib import Path +from contextlib import contextmanager import uuid app = Flask(__name__) +app.secret_key = os.getenv("SECRET_KEY", os.urandom(32)) +app.config["MAX_CONTENT_LENGTH"] = 500 * 1024 * 1024 # WSGI-level enforcement + +# ── User store (loaded from file at startup) ────────────────────────────────── +_users_config_path = os.getenv("USERS_CONFIG_PATH", "/app/configs/users.json") +try: + with open(_users_config_path) as _f: + USERS = json.load(_f) +except FileNotFoundError: + USERS = {} + +# ── Flask-Login setup ───────────────────────────────────────────────────────── +login_manager = LoginManager(app) +login_manager.login_view = "login" +login_manager.login_message = "Please log in to access this page." 
+ + +class User(UserMixin): + def __init__(self, user_id: str): + self.id = user_id + self.display_name = USERS[user_id].get("display_name", user_id) + + +@login_manager.user_loader +def load_user(user_id: str): + return User(user_id) if user_id in USERS else None + ALLOWED_EXTENSIONS = {"nc", "csv", "h5", "hdf5"} MAX_FILE_SIZE = 500 * 1024 * 1024 # 500 MB -app.config["MAX_CONTENT_LENGTH"] = MAX_FILE_SIZE # WSGI-level enforcement # Data directories DATA_INCOMING_DIR = "/app/data_incoming" @@ -47,9 +93,11 @@ def safe_float(value): return parsed -# Database connection helper +# ── Database helpers ───────────────────────────────────────────────────────── + + def get_db_connection(): - """Create and return a database connection""" + """Admin connection as webserver_role (cross-tenant queries).""" try: conn = psycopg2.connect(os.getenv("DATABASE_URL")) return conn @@ -58,10 +106,71 @@ def get_db_connection(): return None +@contextmanager +def user_db_scope(user_id: str): + """Context manager: connection scoped to user_id for one request. + + Connects as webserver_role then issues SET LOCAL ROLE . + SET LOCAL is automatically reverted at transaction end, so role + bleed is impossible even on connection reuse. + + The role name is composed with pgsql.Identifier (never %s) so it + cannot be used as a SQL injection vector. user_id is also + allowlisted against USERS before reaching SQL composition. 
+ """ + if user_id not in USERS: + raise ValueError(f"Unknown user_id: {user_id!r}") + + conn = psycopg2.connect(os.getenv("DATABASE_URL")) + try: + with conn.cursor() as cur: + cur.execute( + pgsql.SQL("SET LOCAL ROLE {}").format(pgsql.Identifier(user_id)) + ) + yield conn + conn.commit() + except Exception: + conn.rollback() + raise + finally: + conn.close() + + +# ==================== AUTH ROUTES ==================== + + +@app.route("/login", methods=["GET", "POST"]) +def login(): + if current_user.is_authenticated: + return redirect(url_for("overview")) + if request.method == "POST": + username = request.form.get("username", "") + password = request.form.get("password", "") + if username in USERS and check_password_hash( + USERS[username]["password_hash"], password + ): + login_user(User(username)) + next_page = request.args.get("next") + # Guard against open-redirect: only allow relative paths. + if next_page and not next_page.startswith("/"): + next_page = None + return redirect(next_page or url_for("overview")) + flash("Invalid username or password.") + return render_template("login.html") + + +@app.route("/logout") +@login_required +def logout(): + logout_user() + return redirect(url_for("login")) + + # ==================== LANDING PAGE ROUTES ==================== @app.route("/") +@login_required def overview(): """Landing page with overview and processing status""" stats = { @@ -73,27 +182,25 @@ def overview(): } try: - conn = get_db_connection() - if conn: + with user_db_scope(current_user.id) as conn: cur = conn.cursor() - # Get count of CMLs + # Get count of CMLs visible to this user (RLS enforced) cur.execute("SELECT COUNT(DISTINCT cml_id) FROM cml_metadata") stats["total_cmls"] = cur.fetchone()[0] - # Get approximate count of data records (fast on large tables) - cur.execute("SELECT approximate_row_count('cml_data')") + # Approximate count via secure view + cur.execute("SELECT COUNT(*) FROM cml_data_secure") stats["total_records"] = cur.fetchone()[0] 
- # Get data date range (from 1h aggregate — fast, indexed) - cur.execute("SELECT MIN(bucket), MAX(bucket) FROM cml_data_1h") + # Get data date range (from 1h secure view) + cur.execute("SELECT MIN(bucket), MAX(bucket) FROM cml_data_1h_secure") result = cur.fetchone() if result: stats["data_start_date"] = result[0] stats["data_end_date"] = result[1] cur.close() - conn.close() except Exception as e: print(f"Error fetching landing stats: {e}") @@ -103,20 +210,16 @@ def overview(): # ==================== REAL-TIME DATA ROUTES ==================== -def generate_cml_map(): +def generate_cml_map(user_id: str): """Generate a Leaflet map showing all CMLs with clickable lines""" try: - conn = get_db_connection() - if not conn: - return None - - cur = conn.cursor() - cur.execute( - "SELECT DISTINCT ON (cml_id) cml_id, site_0_lon, site_0_lat, site_1_lon, site_1_lat FROM cml_metadata ORDER BY cml_id" - ) - data = cur.fetchall() - cur.close() - conn.close() + with user_db_scope(user_id) as conn: + cur = conn.cursor() + cur.execute( + "SELECT DISTINCT ON (cml_id) cml_id, site_0_lon, site_0_lat, site_1_lon, site_1_lat FROM cml_metadata ORDER BY cml_id" + ) + data = cur.fetchall() + cur.close() if not data: return None @@ -269,18 +372,14 @@ def generate_cml_map(): return None -def get_available_cmls(): - """Get list of available CMLs""" +def get_available_cmls(user_id: str): + """Get list of CMLs visible to the given user.""" try: - conn = get_db_connection() - if not conn: - return [] - - cur = conn.cursor() - cur.execute("SELECT DISTINCT cml_id FROM cml_metadata ORDER BY cml_id") - cmls = [row[0] for row in cur.fetchall()] - cur.close() - conn.close() + with user_db_scope(user_id) as conn: + cur = conn.cursor() + cur.execute("SELECT DISTINCT cml_id FROM cml_metadata ORDER BY cml_id") + cmls = [row[0] for row in cur.fetchall()] + cur.close() return cmls except Exception as e: print(f"Error fetching CMLs: {e}") @@ -288,10 +387,11 @@ def get_available_cmls(): 
@app.route("/realtime") +@login_required def realtime(): """Real-time data page""" - map_html = generate_cml_map() - cmls = get_available_cmls() + map_html = generate_cml_map(current_user.id) + cmls = get_available_cmls(current_user.id) default_cml = cmls[0] if cmls else None return render_template( @@ -303,6 +403,7 @@ def realtime(): @app.route("/grafana") +@login_required def grafana_root_redirect(): """Redirect /grafana to /grafana/ for proper subpath routing.""" return redirect("/grafana/", code=302) @@ -316,6 +417,7 @@ def grafana_root_redirect(): @app.route( "/grafana/", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"] ) +@login_required def grafana_proxy(path): """Proxy all requests to Grafana container.""" grafana_url = f"http://grafana:3000/grafana/{path}" @@ -349,20 +451,17 @@ def grafana_proxy(path): @app.route("/api/cml-metadata") +@login_required def api_cml_metadata(): """API endpoint for fetching CML metadata""" try: - conn = get_db_connection() - if not conn: - return jsonify({"cmls": []}) - - cur = conn.cursor() - cur.execute( - "SELECT DISTINCT ON (cml_id) cml_id, site_0_lon, site_0_lat, site_1_lon, site_1_lat FROM cml_metadata ORDER BY cml_id" - ) - data = cur.fetchall() - cur.close() - conn.close() + with user_db_scope(current_user.id) as conn: + cur = conn.cursor() + cur.execute( + "SELECT DISTINCT ON (cml_id) cml_id, site_0_lon, site_0_lat, site_1_lon, site_1_lat FROM cml_metadata ORDER BY cml_id" + ) + data = cur.fetchall() + cur.close() cmls = [ { @@ -381,20 +480,17 @@ def api_cml_metadata(): @app.route("/api/cml-map") +@login_required def api_cml_map(): """API endpoint for fetching CML data optimized for map rendering""" try: - conn = get_db_connection() - if not conn: - return jsonify([]) - - cur = conn.cursor() - cur.execute( - "SELECT DISTINCT ON (cml_id) cml_id::text, site_0_lon, site_0_lat, site_1_lon, site_1_lat FROM cml_metadata ORDER BY cml_id" - ) - data = cur.fetchall() - cur.close() - conn.close() + with 
user_db_scope(current_user.id) as conn: + cur = conn.cursor() + cur.execute( + "SELECT DISTINCT ON (cml_id) cml_id::text, site_0_lon, site_0_lat, site_1_lon, site_1_lat FROM cml_metadata ORDER BY cml_id" + ) + data = cur.fetchall() + cur.close() cmls = [ { @@ -411,43 +507,40 @@ def api_cml_map(): @app.route("/api/cml-stats") +@login_required def api_cml_stats(): """API endpoint for fetching per-CML statistics for data quality visualization""" try: - conn = get_db_connection() - if not conn: - return jsonify([]) - - cur = conn.cursor() - cur.execute( + with user_db_scope(current_user.id) as conn: + cur = conn.cursor() + cur.execute( + """ + SELECT + cs.cml_id::text, + cs.total_records, + cs.valid_records, + cs.null_records, + cs.completeness_percent, + cs.min_rsl, + cs.max_rsl, + cs.mean_rsl, + cs.stddev_rsl, + cs.last_rsl, + ROUND(STDDEV(cd.rsl)::numeric, 2) as stddev_last_60min + FROM cml_stats cs + LEFT JOIN ( + SELECT cml_id, rsl + FROM cml_data_secure + WHERE time >= (SELECT MAX(bucket) FROM cml_data_1h_secure) - INTERVAL '60 minutes' + ) cd ON cs.cml_id = cd.cml_id + GROUP BY cs.cml_id, cs.total_records, cs.valid_records, cs.null_records, + cs.completeness_percent, cs.min_rsl, cs.max_rsl, cs.mean_rsl, + cs.stddev_rsl, cs.last_rsl + ORDER BY cs.cml_id """ - SELECT - cs.cml_id::text, - cs.total_records, - cs.valid_records, - cs.null_records, - cs.completeness_percent, - cs.min_rsl, - cs.max_rsl, - cs.mean_rsl, - cs.stddev_rsl, - cs.last_rsl, - ROUND(STDDEV(cd.rsl)::numeric, 2) as stddev_last_60min - FROM cml_stats cs - LEFT JOIN ( - SELECT cml_id, rsl - FROM cml_data - WHERE time >= (SELECT MAX(bucket) FROM cml_data_1h) - INTERVAL '60 minutes' - ) cd ON cs.cml_id = cd.cml_id - GROUP BY cs.cml_id, cs.total_records, cs.valid_records, cs.null_records, - cs.completeness_percent, cs.min_rsl, cs.max_rsl, cs.mean_rsl, - cs.stddev_rsl, cs.last_rsl - ORDER BY cs.cml_id - """ - ) - data = cur.fetchall() - cur.close() - conn.close() + ) + data = cur.fetchall() + 
cur.close() stats = [ { @@ -472,18 +565,15 @@ def api_cml_stats(): @app.route("/api/data-time-range") +@login_required def api_data_time_range(): """API endpoint for fetching the actual time range of available data""" try: - conn = get_db_connection() - if not conn: - return jsonify({"earliest": None, "latest": None}) - - cur = conn.cursor() - cur.execute("SELECT MIN(bucket), MAX(bucket) FROM cml_data_1h") - result = cur.fetchone() - cur.close() - conn.close() + with user_db_scope(current_user.id) as conn: + cur = conn.cursor() + cur.execute("SELECT MIN(bucket), MAX(bucket) FROM cml_data_1h_secure") + result = cur.fetchone() + cur.close() if result and result[0] and result[1]: # Format as ISO 8601 strings @@ -499,8 +589,8 @@ def api_data_time_range(): # ==================== ARCHIVE STATISTICS ROUTES ==================== -def get_archive_statistics(): - """Fetch aggregated statistics from the long-term archive""" +def get_archive_statistics(user_id: str): + """Fetch aggregated statistics from the long-term archive for the given user.""" stats = { "total_records": 0, "cml_count": 0, @@ -508,29 +598,25 @@ def get_archive_statistics(): } try: - conn = get_db_connection() - if not conn: - return stats - - cur = conn.cursor() + with user_db_scope(user_id) as conn: + cur = conn.cursor() - # Total records (approximate, fast on large tables) - cur.execute("SELECT approximate_row_count('cml_data')") - stats["total_records"] = cur.fetchone()[0] + # Row count via secure view + cur.execute("SELECT COUNT(*) FROM cml_data_secure") + stats["total_records"] = cur.fetchone()[0] - # CML count - cur.execute("SELECT COUNT(DISTINCT cml_id) FROM cml_metadata") - stats["cml_count"] = cur.fetchone()[0] + # CML count (RLS enforced) + cur.execute("SELECT COUNT(DISTINCT cml_id) FROM cml_metadata") + stats["cml_count"] = cur.fetchone()[0] - # Date range (from 1h aggregate — fast, indexed) - cur.execute("SELECT MIN(bucket), MAX(bucket) FROM cml_data_1h") - result = cur.fetchone() - if result: - 
stats["date_range"]["start"] = result[0] - stats["date_range"]["end"] = result[1] + # Date range (from 1h secure view) + cur.execute("SELECT MIN(bucket), MAX(bucket) FROM cml_data_1h_secure") + result = cur.fetchone() + if result: + stats["date_range"]["start"] = result[0] + stats["date_range"]["end"] = result[1] - cur.close() - conn.close() + cur.close() except Exception as e: print(f"Error fetching archive statistics: {e}") @@ -538,9 +624,10 @@ def get_archive_statistics(): @app.route("/archive") +@login_required def archive(): """Archive statistics page""" - stats = get_archive_statistics() + stats = get_archive_statistics(current_user.id) return render_template("archive.html", stats=stats) @@ -548,6 +635,7 @@ def archive(): @app.route("/data-uploads") +@login_required def data_uploads(): """Data uploads page""" return render_template("data_uploads.html") @@ -567,6 +655,7 @@ def get_file_size_mb(filepath): @app.route("/api/upload", methods=["POST"]) +@login_required def upload_file(): """Handle file upload via drag and drop""" try: @@ -581,7 +670,10 @@ def upload_file(): safe_name = secure_filename(file.filename) if not safe_name or not allowed_file(safe_name): - return jsonify({"error": "File type not allowed. Allowed: nc, csv, h5, hdf5"}), 400 + return ( + jsonify({"error": "File type not allowed. 
Allowed: nc, csv, h5, hdf5"}), + 400, + ) # Generate unique filename to avoid collisions timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") @@ -627,6 +719,7 @@ def upload_file(): @app.route("/api/files", methods=["GET"]) +@login_required def get_files(): """Get list of files in data_incoming and data_staged_for_parsing directories""" try: diff --git a/webserver/requirements.txt b/webserver/requirements.txt index 0111892..e9a5a70 100644 --- a/webserver/requirements.txt +++ b/webserver/requirements.txt @@ -1,4 +1,5 @@ Flask==2.3.3 +Flask-Login==0.6.3 psycopg2-binary==2.9.7 folium==0.14.0 gunicorn==22.0.0 diff --git a/webserver/templates/base.html b/webserver/templates/base.html index 38a16d4..93d21bf 100644 --- a/webserver/templates/base.html +++ b/webserver/templates/base.html @@ -23,6 +23,7 @@ diff --git a/webserver/templates/login.html b/webserver/templates/login.html new file mode 100644 index 0000000..7202192 --- /dev/null +++ b/webserver/templates/login.html @@ -0,0 +1,49 @@ + + + + + + + Login — GMDI Platform + + + + + +
+
+
+
+ GMDI Logo +

GMDI Platform

+

Sign in to continue

+
+ + {% with messages = get_flashed_messages() %} + {% if messages %} + + {% endif %} + {% endwith %} + +
+
+ + +
+
+ + +
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/webserver/tests/test_api_cml_stats.py b/webserver/tests/test_api_cml_stats.py index 36b85e7..f90834e 100644 --- a/webserver/tests/test_api_cml_stats.py +++ b/webserver/tests/test_api_cml_stats.py @@ -1,4 +1,5 @@ import sys +from contextlib import contextmanager from unittest.mock import Mock import pytest @@ -45,7 +46,18 @@ def test_api_cml_stats_returns_cached_stats(monkeypatch): mock_cursor.close = Mock() mock_conn.close = Mock() - monkeypatch.setattr(wm, "get_db_connection", lambda: mock_conn) + # The route now uses user_db_scope(current_user.id) instead of get_db_connection(). + # Mock user_db_scope to yield the mock connection, and disable login enforcement. + @contextmanager + def mock_user_db_scope(user_id): + yield mock_conn + + mock_user = Mock() + mock_user.id = "demo_openmrg" + + monkeypatch.setattr(wm, "user_db_scope", mock_user_db_scope) + monkeypatch.setattr(wm, "current_user", mock_user) + monkeypatch.setitem(wm.app.config, "LOGIN_DISABLED", True) client = wm.app.test_client() resp = client.get("/api/cml-stats") diff --git a/webserver/tests/test_api_routes.py b/webserver/tests/test_api_routes.py new file mode 100644 index 0000000..ffc4783 --- /dev/null +++ b/webserver/tests/test_api_routes.py @@ -0,0 +1,87 @@ +import os +import sys +from contextlib import contextmanager +from datetime import datetime +from unittest.mock import Mock + +import pytest + +sys.modules.setdefault("folium", Mock()) +sys.modules.setdefault("requests", Mock()) + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) +import main as wm # noqa: E402 + + +@pytest.fixture +def auth_client(monkeypatch): + """Test client with login bypassed and user_db_scope mocked.""" + mock_conn = Mock() + mock_cursor = Mock() + mock_conn.cursor.return_value = mock_cursor + + mock_user = Mock() + mock_user.id = "demo_openmrg" + + @contextmanager + def mock_scope(user_id): + yield mock_conn + + monkeypatch.setattr(wm, "user_db_scope", 
mock_scope) + monkeypatch.setattr(wm, "current_user", mock_user) + monkeypatch.setitem(wm.app.config, "LOGIN_DISABLED", True) + + wm.app.config["TESTING"] = True + return wm.app.test_client(), mock_cursor + + +def test_api_cml_metadata_returns_cmls(auth_client): + client, cursor = auth_client + cursor.fetchall.return_value = [("CML01", 10.0, 5.0, 10.5, 5.5)] + resp = client.get("/api/cml-metadata") + assert resp.status_code == 200 + data = resp.get_json() + assert len(data["cmls"]) == 1 + assert data["cmls"][0]["id"] == "CML01" + assert data["cmls"][0]["site_0_lon"] == 10.0 + + +def test_api_cml_metadata_empty(auth_client): + client, cursor = auth_client + cursor.fetchall.return_value = [] + resp = client.get("/api/cml-metadata") + assert resp.status_code == 200 + assert resp.get_json() == {"cmls": []} + + +def test_api_cml_map_returns_list(auth_client): + client, cursor = auth_client + cursor.fetchall.return_value = [("CML01", 10.0, 5.0, 10.5, 5.5)] + resp = client.get("/api/cml-map") + assert resp.status_code == 200 + data = resp.get_json() + assert data[0]["cml_id"] == "CML01" + assert data[0]["site_0"] == {"lon": 10.0, "lat": 5.0} + assert data[0]["site_1"] == {"lon": 10.5, "lat": 5.5} + + +def test_api_data_time_range_with_data(auth_client): + client, cursor = auth_client + dt1 = datetime(2025, 1, 1, 0, 0) + dt2 = datetime(2025, 12, 31, 0, 0) + cursor.fetchone.return_value = (dt1, dt2) + resp = client.get("/api/data-time-range") + assert resp.status_code == 200 + data = resp.get_json() + assert data["earliest"] == dt1.isoformat() + assert data["latest"] == dt2.isoformat() + + +def test_api_data_time_range_no_data(auth_client): + client, cursor = auth_client + cursor.fetchone.return_value = (None, None) + resp = client.get("/api/data-time-range") + assert resp.status_code == 200 + data = resp.get_json() + assert data["earliest"] is None + assert data["latest"] is None diff --git a/webserver/tests/test_auth.py b/webserver/tests/test_auth.py new file mode 100644 
index 0000000..756dcf4 --- /dev/null +++ b/webserver/tests/test_auth.py @@ -0,0 +1,86 @@ +import os +import sys +from unittest.mock import Mock + +import pytest +from werkzeug.security import generate_password_hash + +# Stub optional heavy imports before main.py is loaded +sys.modules.setdefault("folium", Mock()) +sys.modules.setdefault("requests", Mock()) + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) +import main as wm # noqa: E402 + +_PROTECTED_ROUTES = ["/", "/realtime", "/api/cml-stats", "/api/cml-metadata"] + + +@pytest.fixture +def client(): + wm.app.config["TESTING"] = True + return wm.app.test_client() + + +@pytest.fixture +def test_user(monkeypatch): + monkeypatch.setitem( + wm.USERS, + "testuser", + {"password_hash": generate_password_hash("testpass"), "display_name": "Test"}, + ) + return "testuser", "testpass" + + +def test_login_page_accessible(client): + assert client.get("/login").status_code == 200 + + +@pytest.mark.parametrize("path", _PROTECTED_ROUTES) +def test_protected_routes_redirect_unauthenticated(client, path): + resp = client.get(path) + assert resp.status_code == 302 + assert "login" in resp.headers["Location"] + + +def test_login_valid_credentials_redirects(client, test_user): + username, password = test_user + resp = client.post("/login", data={"username": username, "password": password}) + assert resp.status_code == 302 + assert resp.headers["Location"].endswith("/") + + +def test_login_wrong_password_stays_on_login(client, test_user): + username, _ = test_user + resp = client.post("/login", data={"username": username, "password": "wrong"}) + assert resp.status_code == 200 + + +def test_login_unknown_user_stays_on_login(client): + resp = client.post("/login", data={"username": "nobody", "password": "x"}) + assert resp.status_code == 200 + + +def test_login_open_redirect_blocked(client, test_user): + username, password = test_user + resp = client.post( + "/login?next=https://evil.com", + data={"username": username, 
"password": password}, + ) + assert resp.status_code == 302 + assert "evil.com" not in resp.headers["Location"] + + +def test_logout_redirects_to_login(client): + resp = client.get("/logout") + assert resp.status_code == 302 + assert "login" in resp.headers["Location"] + + +def test_logout_when_authenticated_clears_session(client, test_user): + username, password = test_user + client.post("/login", data={"username": username, "password": password}) + resp = client.get("/logout") + assert resp.status_code == 302 + assert "login" in resp.headers["Location"] + # Session is cleared: next protected request redirects again + assert client.get("/").status_code == 302 diff --git a/webserver/tests/test_helpers.py b/webserver/tests/test_helpers.py new file mode 100644 index 0000000..f633aa5 --- /dev/null +++ b/webserver/tests/test_helpers.py @@ -0,0 +1,49 @@ +import os +import sys +from unittest.mock import Mock + +sys.modules.setdefault("folium", Mock()) +sys.modules.setdefault("requests", Mock()) + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) +import main as wm # noqa: E402 + + +# ── safe_float ──────────────────────────────────────────────────────────────── + + +def test_safe_float_none_returns_none(): + assert wm.safe_float(None) is None + + +def test_safe_float_valid_number(): + assert wm.safe_float(3.14) == 3.14 + assert wm.safe_float("2.5") == 2.5 + + +def test_safe_float_non_numeric_returns_none(): + assert wm.safe_float("abc") is None + + +def test_safe_float_nan_returns_none(): + assert wm.safe_float(float("nan")) is None + + +def test_safe_float_inf_returns_none(): + assert wm.safe_float(float("inf")) is None + + +# ── load_user ───────────────────────────────────────────────────────────────── + + +def test_load_user_known(monkeypatch): + monkeypatch.setitem(wm.USERS, "alice", {"display_name": "Alice"}) + user = wm.load_user("alice") + assert user is not None + assert user.id == "alice" + assert user.display_name == "Alice" + + +def 
test_load_user_unknown(monkeypatch): + monkeypatch.setattr(wm, "USERS", {"alice": {}}) + assert wm.load_user("nobody") is None diff --git a/webserver/tests/test_user_db_scope.py b/webserver/tests/test_user_db_scope.py new file mode 100644 index 0000000..3a22806 --- /dev/null +++ b/webserver/tests/test_user_db_scope.py @@ -0,0 +1,52 @@ +import os +import sys +from unittest.mock import Mock, MagicMock + +import pytest +from psycopg2 import sql as pgsql + +# Stub optional heavy imports before main.py is loaded +sys.modules.setdefault("folium", Mock()) +sys.modules.setdefault("requests", Mock()) + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) +import main as wm # noqa: E402 + + +def test_unknown_user_raises_before_connecting(monkeypatch): + mock_connect = Mock() + monkeypatch.setattr(wm.psycopg2, "connect", mock_connect) + monkeypatch.setattr(wm, "USERS", {"known": {}}) + + with pytest.raises(ValueError, match="Unknown user_id"): + with wm.user_db_scope("injected_role"): + pass # pragma: no cover + + mock_connect.assert_not_called() + + +def test_exception_inside_scope_triggers_rollback(monkeypatch): + mock_conn = MagicMock() + monkeypatch.setattr(wm.psycopg2, "connect", Mock(return_value=mock_conn)) + monkeypatch.setattr(wm, "USERS", {"myuser": {}}) + + with pytest.raises(RuntimeError): + with wm.user_db_scope("myuser"): + raise RuntimeError("boom") + + mock_conn.rollback.assert_called_once() + mock_conn.commit.assert_not_called() + + +def test_set_local_role_uses_sql_identifier(monkeypatch): + """SET LOCAL ROLE must use pgsql.Identifier, not string interpolation.""" + mock_conn = MagicMock() + monkeypatch.setattr(wm.psycopg2, "connect", Mock(return_value=mock_conn)) + monkeypatch.setattr(wm, "USERS", {"myuser": {}}) + + with wm.user_db_scope("myuser"): + pass + + cur = mock_conn.cursor.return_value.__enter__.return_value + call_arg = cur.execute.call_args[0][0] + assert isinstance(call_arg, pgsql.Composable)