diff --git a/packages/backend/app/models.py b/packages/backend/app/models.py index 64d44810..52c2f0b1 100644 --- a/packages/backend/app/models.py +++ b/packages/backend/app/models.py @@ -133,3 +133,32 @@ class AuditLog(db.Model): user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=True) action = db.Column(db.String(100), nullable=False) created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False) + + +class CategorizationRule(db.Model): + __tablename__ = "categorization_rules" + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=False) + keyword = db.Column(db.String(200), nullable=False) + category_name = db.Column(db.String(100), nullable=False) + confidence = db.Column(db.Float, default=0.80, nullable=False) + source = db.Column(db.String(20), default="learned", nullable=False) + created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False) + updated_at = db.Column( + db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False + ) +from datetime import datetime +from .extensions import db + + +class TrustedDevice(db.Model): + __tablename__ = "trusted_devices" + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=False) + device_name = db.Column(db.String(200), nullable=False) + device_fingerprint = db.Column(db.String(64), nullable=False, unique=False) + user_agent = db.Column(db.String(500), nullable=True) + ip_address = db.Column(db.String(45), nullable=True) + trusted = db.Column(db.Boolean, default=True, nullable=False) + last_seen_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False) + created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False) diff --git a/packages/backend/app/routes/__init__.py b/packages/backend/app/routes/__init__.py index f13b0f89..f62447b9 100644 --- a/packages/backend/app/routes/__init__.py +++ b/packages/backend/app/routes/__init__.py @@ -7,6 +7,8 @@ from .categories import bp as categories_bp from .docs import bp as docs_bp from .dashboard import bp as dashboard_bp +from .categorize import bp as categorize_bp +from .devices import bp as devices_bp def register_routes(app: Flask): @@ -18,3 +20,5 @@ def register_routes(app: Flask): app.register_blueprint(categories_bp, url_prefix="/categories") app.register_blueprint(docs_bp, url_prefix="/docs") app.register_blueprint(dashboard_bp, url_prefix="/dashboard") + app.register_blueprint(categorize_bp, url_prefix="/categorize") + app.register_blueprint(devices_bp, url_prefix="/devices") diff --git a/packages/backend/app/routes/categorize.py b/packages/backend/app/routes/categorize.py new file mode 100644 index 00000000..7bef72b8 --- /dev/null +++ b/packages/backend/app/routes/categorize.py @@ -0,0 +1,112 @@ +import logging + +from flask import Blueprint, jsonify, request +from flask_jwt_extended import jwt_required, get_jwt_identity + +from ..services.categorization import ( + categorize_transaction, + learn_from_correction, + batch_categorize, +) + +bp = Blueprint("categorize", __name__) +logger = logging.getLogger("finmind.categorize") + + +@bp.post("") +@jwt_required() +def categorize(): + """Categorize a single transaction by description.""" + uid = int(get_jwt_identity()) + data = request.get_json() or {} + description = (data.get("description") or "").strip() + if not description: + return jsonify(error="description required"), 400 + category_id = data.get("category_id") + result = categorize_transaction( + description=description, + existing_category_id=category_id, + user_id=uid, + ) + logger.info("Categorized user=%s desc=%s result=%s", uid, description[:50], result.get("category")) + return jsonify(result) + + +@bp.post("/batch") +@jwt_required() +def categorize_batch(): + """Categorize multiple transactions at once.""" + uid = int(get_jwt_identity()) + data = request.get_json() or {} + transactions = data.get("transactions") + if not isinstance(transactions, list) or not transactions: + return jsonify(error="transactions list required"), 400 + if len(transactions) > 100: + return jsonify(error="maximum 100 transactions per batch"), 400 + results = batch_categorize(transactions, user_id=uid) + logger.info("Batch categorized user=%s count=%s", uid, len(results)) + return jsonify(results=results, count=len(results)) + + +@bp.post("/learn") +@jwt_required() +def learn(): + """Learn from a user's manual categorization correction.""" + uid = int(get_jwt_identity()) + data = request.get_json() or {} + description = (data.get("description") or "").strip() + category = (data.get("category") or "").strip() + if not description: + return jsonify(error="description required"), 400 + if not category: + return jsonify(error="category required"), 400 + result = learn_from_correction( + description=description, + correct_category=category, + user_id=uid, + ) + logger.info("Learned user=%s cat=%s keywords=%s", uid, category, result.get("learned_count", 0)) + return jsonify(result) + + +@bp.get("/rules") +@jwt_required() +def list_rules(): + """List learned categorization rules for the current user.""" + uid = int(get_jwt_identity()) + from ..models import CategorizationRule as RuleModel + from ..extensions import db + + rules = ( + db.session.query(RuleModel) + .filter_by(user_id=uid) + .order_by(RuleModel.confidence.desc()) + .all() + ) + return jsonify([ + { + "id": r.id, + "keyword": r.keyword, + "category": r.category_name, + "confidence": round(r.confidence, 2), + "source": r.source, + } + for r in rules + ]) + + +@bp.delete("/rules/") +@jwt_required() +def delete_rule(rule_id: int): + """Delete a learned categorization rule.""" + uid = int(get_jwt_identity()) + from ..models import CategorizationRule as RuleModel + from ..extensions import db + + rule = db.session.get(RuleModel, rule_id) + if not rule or rule.user_id != uid: + return jsonify(error="not found"), 404 + db.session.delete(rule) + db.session.commit() + logger.info("Deleted rule id=%s user=%s", rule_id, uid) + return jsonify(message="deleted") diff --git a/packages/backend/app/routes/devices.py b/packages/backend/app/routes/devices.py new file mode 100644 index 00000000..9bf8f1a2 --- /dev/null +++ b/packages/backend/app/routes/devices.py @@ -0,0 +1,120 @@ +from flask import Blueprint, request, jsonify +from flask_jwt_extended import jwt_required, get_jwt_identity +from ..extensions import db +from ..models import TrustedDevice +import hashlib +import logging + +bp = Blueprint("devices", __name__) +logger = logging.getLogger("finmind.devices") + + +def _make_fingerprint(user_agent: str, ip: str) -> str: + raw = f"{user_agent}|{ip}" + return hashlib.sha256(raw.encode()).hexdigest() + + +@bp.get("/") +@jwt_required() +def list_devices(): + uid = int(get_jwt_identity()) + devices = ( + db.session.query(TrustedDevice) + .filter_by(user_id=uid) + .order_by(TrustedDevice.last_seen_at.desc()) + .all() + ) + return jsonify( + [ + { + "id": d.id, + "device_name": d.device_name, + "trusted": d.trusted, + "user_agent": d.user_agent, + "ip_address": d.ip_address, + "last_seen_at": d.last_seen_at.isoformat() + "Z", + "created_at": d.created_at.isoformat() + "Z", + } + for d in devices + ] + ) + + +@bp.post("/") +@jwt_required() +def trust_device(): + uid = int(get_jwt_identity()) + data = request.get_json() or {} + device_name = (data.get("device_name") or "").strip() + if not device_name: + return jsonify(error="device_name is required"), 400 + + user_agent = request.headers.get("User-Agent", "") + ip = request.remote_addr or "" + fingerprint = _make_fingerprint(user_agent, ip) + + existing = ( + db.session.query(TrustedDevice) + .filter_by(user_id=uid, device_fingerprint=fingerprint) + .first() + ) + if existing: + existing.trusted = True + existing.device_name = device_name + existing.last_seen_at = db.func.now() + db.session.commit() + logger.info("Re-trusted device id=%s for user_id=%s", existing.id, uid) + return jsonify( + id=existing.id, + device_name=existing.device_name, + trusted=existing.trusted, + message="device re-trusted", + ), 200 + + device = TrustedDevice( + user_id=uid, + device_name=device_name, + device_fingerprint=fingerprint, + user_agent=user_agent, + ip_address=ip, + trusted=True, + ) + db.session.add(device) + db.session.commit() + logger.info("Trusted new device id=%s for user_id=%s", device.id, uid) + return jsonify( + id=device.id, + device_name=device.device_name, + trusted=device.trusted, + message="device trusted", + ), 201 + + +@bp.delete("/") +@jwt_required() +def revoke_device(device_id: int): + uid = int(get_jwt_identity()) + device = db.session.get(TrustedDevice, device_id) + if not device or device.user_id != uid: + return jsonify(error="device not found"), 404 + device.trusted = False + db.session.commit() + logger.info("Revoked device id=%s for user_id=%s", device_id, uid) + return jsonify(message="device revoked", id=device.id, trusted=False), 200 + + +@bp.patch("/") +@jwt_required() +def rename_device(device_id: int): + uid = int(get_jwt_identity()) + device = db.session.get(TrustedDevice, device_id) + if not device or device.user_id != uid: + return jsonify(error="device not found"), 404 + data = request.get_json() or {} + new_name = (data.get("device_name") or "").strip() + if not new_name: + return jsonify(error="device_name is required"), 400 + device.device_name = new_name + db.session.commit() + logger.info("Renamed device id=%s to '%s' for user_id=%s", device_id, new_name, uid) + return jsonify(id=device.id, device_name=device.device_name, trusted=device.trusted), 200 diff --git a/packages/backend/app/services/categorization.py b/packages/backend/app/services/categorization.py new file mode 100644 index 00000000..0c1be01d --- /dev/null +++ b/packages/backend/app/services/categorization.py @@ -0,0 +1,381 @@ +""" +Intelligent Transaction Categorization Service + +Provides rule-based auto-categorization with: +- Keyword matching with configurable rules +- Confidence scoring (0.0 - 1.0) +- Learning from user corrections +- Fallback to default "Uncategorized" when confidence is low +""" + +import re +from typing import Any + +from ..extensions import db +from ..models import Category + +# Default keyword rules: keywords → (category_name, base_confidence) +DEFAULT_RULES: dict[str, tuple[str, float]] = { + # Food & Dining + "restaurant": ("Food & Dining", 0.95), + "cafe": ("Food & Dining", 0.90), + "coffee": ("Food & Dining", 0.85), + "pizza": ("Food & Dining", 0.90), + "burger": ("Food & Dining", 0.85), + "sushi": ("Food & Dining", 0.90), + "mcdonald": ("Food & Dining", 0.95), + "starbucks": ("Food & Dining", 0.95), + "uber eats": ("Food & Dining", 0.95), + "swiggy": ("Food & Dining", 0.95), + "zomato": ("Food & Dining", 0.95), + "doordash": ("Food & Dining", 0.95), + "grubhub": ("Food & Dining", 0.95), + "deliveroo": ("Food & Dining", 0.95), + "dining": ("Food & Dining", 0.90), + "grocery": ("Food & Dining", 0.85), + "supermarket": ("Food & Dining", 0.85), + "market": ("Food & Dining", 0.70), + # Transportation + "uber": ("Transportation", 0.90), + "lyft": ("Transportation", 0.95), + "taxi": ("Transportation", 0.90), + "fuel": ("Transportation", 0.85), + "petrol": ("Transportation", 0.90), + "gas station": ("Transportation", 0.90), + "parking": ("Transportation", 0.85), + "toll": ("Transportation", 0.90), + "metro": ("Transportation", 0.90), + "bus": ("Transportation", 0.80), + "train": ("Transportation", 0.80), + "flight": ("Transportation", 0.85), + "airline": ("Transportation", 0.85), + "ola": ("Transportation", 0.95), + "rapido": ("Transportation", 0.95), + # Shopping + "amazon": ("Shopping", 0.95), + "flipkart": ("Shopping", 0.95), + "ebay": ("Shopping", 0.90), + "walmart": ("Shopping", 0.90), + "target": ("Shopping", 0.85), + "mall": ("Shopping", 0.80), + "store": ("Shopping", 0.70), + "shop": ("Shopping", 0.75), + "clothing": ("Shopping", 0.85), + "fashion": ("Shopping", 0.80), + "ikea": ("Shopping", 0.95), + "nike": ("Shopping", 0.95), + "adidas": ("Shopping", 0.95), + # Bills & Utilities + "electric": ("Bills & Utilities", 0.90), + "electricity": ("Bills & Utilities", 0.95), + "water bill": ("Bills & Utilities", 0.95), + "gas bill": ("Bills & Utilities", 0.90), + "internet": ("Bills & Utilities", 0.90), + "broadband": ("Bills & Utilities", 0.90), + "phone bill": ("Bills & Utilities", 0.90), + "mobile": ("Bills & Utilities", 0.70), + "utility": ("Bills & Utilities", 0.85), + "wifi": ("Bills & Utilities", 0.90), + "bseb": ("Bills & Utilities", 0.95), + "msedcl": ("Bills & Utilities", 0.95), + "bsnl": ("Bills & Utilities", 0.95), + "jio": ("Bills & Utilities", 0.85), + "airtel": ("Bills & Utilities", 0.85), + # Entertainment + "netflix": ("Entertainment", 0.95), + "spotify": ("Entertainment", 0.95), + "disney": ("Entertainment", 0.90), + "hbo": ("Entertainment", 0.90), + "movie": ("Entertainment", 0.85), + "cinema": ("Entertainment", 0.90), + "theatre": ("Entertainment", 0.85), + "gaming": ("Entertainment", 0.85), + "steam": ("Entertainment", 0.90), + "playstation": ("Entertainment", 0.90), + "xbox": ("Entertainment", 0.90), + "concert": ("Entertainment", 0.90), + "ticket": ("Entertainment", 0.70), + "youtube": ("Entertainment", 0.80), + "hotstar": ("Entertainment", 0.95), + "prime video": ("Entertainment", 0.95), + # Health + "pharmacy": ("Health", 0.90), + "hospital": ("Health", 0.90), + "doctor": ("Health", 0.90), + "medical": ("Health", 0.85), + "dental": ("Health", 0.90), + "clinic": ("Health", 0.85), + "insurance": ("Health", 0.75), + "medicine": ("Health", 0.85), + "apollo": ("Health", 0.90), + "1mg": ("Health", 0.90), + "pharmeasy": ("Health", 0.90), + # Housing & Rent + "rent": ("Housing & Rent", 0.90), + "mortgage": ("Housing & Rent", 0.95), + "property": ("Housing & Rent", 0.75), + "maintenance": ("Housing & Rent", 0.70), + "hoa": ("Housing & Rent", 0.95), + # Subscriptions & Software + "subscription": ("Subscriptions", 0.85), + "saas": ("Subscriptions", 0.90), + "adobe": ("Subscriptions", 0.90), + "microsoft 365": ("Subscriptions", 0.95), + "gcp": ("Subscriptions", 0.85), + "aws": ("Subscriptions", 0.85), + "azure": ("Subscriptions", 0.85), + "openai": ("Subscriptions", 0.90), + "chatgpt": ("Subscriptions", 0.90), + # Income & Salary + "salary": ("Income", 0.95), + "payroll": ("Income", 0.95), + "freelance": ("Income", 0.85), + "consulting": ("Income", 0.85), + "dividend": ("Income", 0.90), + "interest": ("Income", 0.80), + "refund": ("Income", 0.85), + "cashback": ("Income", 0.90), +} + +# Confidence threshold — below this, return "Uncategorized" +CONFIDENCE_THRESHOLD = 0.5 + + +class CategorizationRule: + """Represents a single categorization rule with keyword and category.""" + + def __init__(self, keyword: str, category_name: str, confidence: float, source: str = "default"): + self.keyword = keyword.lower().strip() + self.category_name = category_name + self.confidence = min(1.0, max(0.0, confidence)) + self.source = source # "default", "learned", "user" + + def matches(self, description: str) -> float | None: + """Return confidence if this rule matches the description, else None.""" + desc_lower = description.lower().strip() + if self.keyword in desc_lower: + return self.confidence + return None + + +class CategorizationResult: + """Result of a categorization attempt.""" + + def __init__( + self, + category: str, + confidence: float, + matched_rule: str | None = None, + alternatives: list[dict[str, Any]] | None = None, + ): + self.category = category + self.confidence = confidence + self.matched_rule = matched_rule + self.alternatives = alternatives or [] + + def to_dict(self) -> dict[str, Any]: + result = { + "category": self.category, + "confidence": round(self.confidence, 2), + } + if self.matched_rule: + result["matched_rule"] = self.matched_rule + if self.alternatives: + result["alternatives"] = self.alternatives + return result + + +def _load_default_rules() -> list[CategorizationRule]: + """Load default keyword rules.""" + return [ + CategorizationRule(keyword=k, category_name=v[0], confidence=v[1], source="default") + for k, v in DEFAULT_RULES.items() + ] + + +def categorize_transaction( + description: str, + existing_category_id: int | None = None, + user_id: int | None = None, +) -> dict[str, Any]: + """ + Categorize a transaction based on its description. + + Args: + description: Transaction description text + existing_category_id: If user already selected a category, use it as a hint + user_id: User ID for loading user-specific learned rules + + Returns: + Dict with category, confidence, matched_rule, and alternatives + """ + if not description or not description.strip(): + return CategorizationResult(category="Uncategorized", confidence=0.0).to_dict() + + desc_lower = description.lower().strip() + rules = _load_default_rules() + + # Load user-learned rules if user_id provided + if user_id: + rules.extend(_load_learned_rules(user_id)) + + # Find all matching rules, sorted by confidence (highest first) + matches: list[tuple[float, str, str]] = [] # (confidence, keyword, category) + for rule in rules: + conf = rule.matches(desc_lower) + if conf is not None: + matches.append((conf, rule.keyword, rule.category_name)) + + # Sort by confidence descending + matches.sort(key=lambda x: x[0], reverse=True) + + if not matches: + return CategorizationResult( + category="Uncategorized", + confidence=0.0, + ).to_dict() + + best_conf, best_keyword, best_category = matches[0] + + # Build alternatives list + alternatives = [] + seen_categories = {best_category} + for conf, kw, cat in matches[1:5]: # Top 4 alternatives + if cat not in seen_categories and conf >= 0.5: + alternatives.append({"category": cat, "confidence": round(conf, 2)}) + seen_categories.add(cat) + + # Below threshold → Uncategorized + if best_conf < CONFIDENCE_THRESHOLD: + return CategorizationResult( + category="Uncategorized", + confidence=best_conf, + alternatives=alternatives, + ).to_dict() + + return CategorizationResult( + category=best_category, + confidence=best_conf, + matched_rule=best_keyword, + alternatives=alternatives, + ).to_dict() + + +def learn_from_correction( + description: str, + correct_category: str, + user_id: int | None = None, +) -> dict[str, Any]: + """ + Learn from a user's manual categorization correction. + + Extracts keywords from the description and stores them as + learned rules for future categorization. + + Args: + description: Original transaction description + correct_category: The correct category the user assigned + user_id: User ID to associate learned rules with + + Returns: + Dict with status and number of rules learned + """ + if not description or not correct_category: + return {"status": "error", "message": "description and category required"} + + desc_lower = description.lower().strip() + + # Extract meaningful keywords (3+ chars, not common stop words) + stop_words = { + "the", "and", "for", "that", "this", "with", "from", "have", "has", + "was", "were", "been", "will", "would", "could", "should", "shall", + "about", "into", "your", "you", "are", "not", "but", "can", "did", + } + words = re.findall(r'\b[a-z]{3,}\b', desc_lower) + keywords = [w for w in words if w not in stop_words] + + learned_count = 0 + learned_keywords = [] + + if user_id: + from ..models import CategorizationRule as RuleModel + for keyword in keywords[:5]: # Max 5 rules per correction + # Check if rule already exists + existing = ( + db.session.query(RuleModel) + .filter_by(user_id=user_id, keyword=keyword) + .first() + ) + if existing: + # Boost confidence + existing.confidence = min(1.0, existing.confidence + 0.05) + existing.category_name = correct_category + else: + rule = RuleModel( + user_id=user_id, + keyword=keyword, + category_name=correct_category, + confidence=0.80, + source="learned", + ) + db.session.add(rule) + learned_count += 1 + learned_keywords.append(keyword) + db.session.commit() + + return { + "status": "ok", + "learned_count": learned_count, + "keywords": learned_keywords, + "category": correct_category, + } + + +def _load_learned_rules(user_id: int) -> list[CategorizationRule]: + """Load learned rules for a specific user from the database.""" + try: + from ..models import CategorizationRule as RuleModel + db_rules = ( + db.session.query(RuleModel) + .filter_by(user_id=user_id) + .all() + ) + return [ + CategorizationRule( + keyword=r.keyword, + category_name=r.category_name, + confidence=r.confidence, + source=r.source, + ) + for r in db_rules + ] + except Exception: + return [] + + +def batch_categorize( + transactions: list[dict[str, Any]], + user_id: int | None = None, +) -> list[dict[str, Any]]: + """ + Categorize multiple transactions at once. + + Args: + transactions: List of dicts with 'description' and optional 'category_id' + user_id: User ID for personalized rules + + Returns: + List of categorization results + """ + results = [] + for txn in transactions: + desc = txn.get("description", "") + cat_id = txn.get("category_id") + result = categorize_transaction( + description=desc, + existing_category_id=cat_id, + user_id=user_id, + ) + result["original_description"] = desc + results.append(result) + return results diff --git a/packages/backend/tests/test_categorize.py b/packages/backend/tests/test_categorize.py new file mode 100644 index 00000000..480ebb2e --- /dev/null +++ b/packages/backend/tests/test_categorize.py @@ -0,0 +1,296 @@ +"""Tests for the Intelligent Transaction Categorization Service.""" + + +class TestCategorizeEndpoint: + """Tests for POST /categorize""" + + def test_categorize_known_merchant(self, client, auth_header): + r = client.post( + "/categorize", + json={"description": "Starbucks Coffee Store #1234"}, + headers=auth_header, + ) + assert r.status_code == 200 + data = r.get_json() + assert data["category"] == "Food & Dining" + assert data["confidence"] >= 0.9 + + def test_categorize_transport(self, client, auth_header): + r = client.post( + "/categorize", + json={"description": "Uber Trip - Downtown Airport"}, + headers=auth_header, + ) + assert r.status_code == 200 + data = r.get_json() + assert data["category"] == "Transportation" + assert data["confidence"] >= 0.5 + + def test_categorize_shopping(self, client, auth_header): + r = client.post( + "/categorize", + json={"description": "Amazon.com Purchase - Electronics"}, + headers=auth_header, + ) + assert r.status_code == 200 + data = r.get_json() + assert data["category"] == "Shopping" + assert data["confidence"] >= 0.9 + + def test_categorize_unknown_returns_uncategorized(self, client, auth_header): + r = client.post( + "/categorize", + json={"description": "xyzzy29487 random gibberish"}, + headers=auth_header, + ) + assert r.status_code == 200 + data = r.get_json() + assert data["category"] == "Uncategorized" + assert data["confidence"] == 0.0 + + def test_categorize_empty_description_400(self, client, auth_header): + r = client.post("/categorize", json={"description": ""}, headers=auth_header) + assert r.status_code == 400 + + def test_categorize_missing_description_400(self, client, auth_header): + r = client.post("/categorize", json={}, headers=auth_header) + assert r.status_code == 400 + + def test_categorize_entertainment(self, client, auth_header): + r = client.post( + "/categorize", + json={"description": "Netflix Monthly Subscription"}, + headers=auth_header, + ) + assert r.status_code == 200 + data = r.get_json() + assert data["category"] == "Entertainment" + assert data["confidence"] >= 0.9 + + def test_categorize_bills(self, client, auth_header): + r = client.post( + "/categorize", + json={"description": "Electricity Bill - MSEDCL May 2025"}, + headers=auth_header, + ) + assert r.status_code == 200 + data = r.get_json() + assert data["category"] == "Bills & Utilities" + + def test_categorize_health(self, client, auth_header): + r = client.post( + "/categorize", + json={"description": "Apollo Hospital Lab Tests"}, + headers=auth_header, + ) + assert r.status_code == 200 + data = r.get_json() + assert data["category"] == "Health" + + def test_categorize_returns_alternatives(self, client, auth_header): + r = client.post( + "/categorize", + json={"description": "Amazon AWS Cloud Services"}, + headers=auth_header, + ) + assert r.status_code == 200 + data = r.get_json() + # Should match either Shopping or Subscriptions, with alternatives + assert data["category"] in ("Shopping", "Subscriptions") + assert "alternatives" in data + + +class TestBatchCategorize: + """Tests for POST /categorize/batch""" + + def test_batch_categorize(self, client, auth_header): + r = client.post( + "/categorize/batch", + json={ + "transactions": [ + {"description": "Starbucks Coffee"}, + {"description": "Uber Ride Home"}, + {"description": "Amazon Order"}, + {"description": "Netflix Subscription"}, + ] + }, + headers=auth_header, + ) + assert r.status_code == 200 + data = r.get_json() + assert data["count"] == 4 + assert len(data["results"]) == 4 + categories = [r["category"] for r in data["results"]] + assert "Food & Dining" in categories + assert "Transportation" in categories + assert "Shopping" in categories + assert "Entertainment" in categories + + def test_batch_empty_list_400(self, client, auth_header): + r = client.post( + "/categorize/batch", + json={"transactions": []}, + headers=auth_header, + ) + assert r.status_code == 400 + + def test_batch_too_many_400(self, client, auth_header): + r = client.post( + "/categorize/batch", + json={"transactions": [{"description": f"test {i}"} for i in range(101)]}, + headers=auth_header, + ) + assert r.status_code == 400 + + def test_batch_missing_transactions_400(self, client, auth_header): + r = client.post("/categorize/batch", json={}, headers=auth_header) + assert r.status_code == 400 + + +class TestLearnEndpoint: + """Tests for POST /categorize/learn""" + + def test_learn_from_correction(self, client, auth_header): + r = client.post( + "/categorize/learn", + json={ + "description": "Local Pizza Palace Delivery", + "category": "Food & Dining", + }, + headers=auth_header, + ) + assert r.status_code == 200 + data = r.get_json() + assert data["status"] == "ok" + assert data["learned_count"] >= 1 + assert "pizza" in data["keywords"] + + def test_learn_improves_categorization(self, client, auth_header): + # First categorize — should be unknown or low confidence + r1 = client.post( + "/categorize", + json={"description": "Wompalicious Gym Membership"}, + headers=auth_header, + ) + assert r1.status_code == 200 + + # Learn the correction + r2 = client.post( + "/categorize/learn", + json={ + "description": "Wompalicious Gym Membership", + "category": "Health", + }, + headers=auth_header, + ) + assert r2.status_code == 200 + assert r2.get_json()["status"] == "ok" + + # Categorize again — should now match "health" or the learned keyword + r3 = client.post( + "/categorize", + json={"description": "Wompalicious Gym Membership"}, + headers=auth_header, + ) + assert r3.status_code == 200 + # The keyword "wompalicious" or "gym" or "membership" should trigger + assert r3.get_json()["confidence"] >= 0.5 + + def test_learn_missing_category_400(self, client, auth_header): + r = client.post( + "/categorize/learn", + json={"description": "test transaction"}, + headers=auth_header, + ) + assert r.status_code == 400 + + def test_learn_missing_description_400(self, client, auth_header): + r = client.post( + "/categorize/learn", + json={"category": "Food"}, + headers=auth_header, + ) + assert r.status_code == 400 + + def test_learn_empty_description_400(self, client, auth_header): + r = client.post( + "/categorize/learn", + json={"description": "", "category": "Food"}, + headers=auth_header, + ) + assert r.status_code == 400 + + +class TestRulesEndpoint: + """Tests for GET /categorize/rules and DELETE /categorize/rules/:id""" + + def test_list_rules_empty(self, client, auth_header): + r = client.get("/categorize/rules", headers=auth_header) + assert r.status_code == 200 + assert r.get_json() == [] + + def test_list_rules_after_learning(self, client, auth_header): + # Learn a rule + client.post( + "/categorize/learn", + json={"description": "Some Unique Transaction XYZ", "category": "Shopping"}, + headers=auth_header, + ) + r = client.get("/categorize/rules", headers=auth_header) + assert r.status_code == 200 + rules = r.get_json() + assert len(rules) >= 1 + assert any(r["category"] == "Shopping" for r in rules) + + def test_delete_rule(self, client, auth_header): + # Learn a rule + client.post( + "/categorize/learn", + json={"description": "Delete Me Transaction ABC", "category": "Entertainment"}, + headers=auth_header, + ) + # Get rules + rules = client.get("/categorize/rules", headers=auth_header).get_json() + rule_id = rules[0]["id"] + # Delete + r = client.delete(f"/categorize/rules/{rule_id}", headers=auth_header) + assert r.status_code == 200 + # Verify deleted + rules2 = client.get("/categorize/rules", headers=auth_header).get_json() + assert not any(r["id"] == rule_id for r in rules2) + + def test_delete_nonexistent_rule_404(self, client, auth_header): + r = client.delete("/categorize/rules/99999", headers=auth_header) + assert r.status_code == 404 + + +class TestConfidenceScoring: + """Tests for confidence scoring accuracy.""" + + def test_high_confidence_exact_match(self, client, auth_header): + r = client.post( + "/categorize", + json={"description": "Netflix"}, + headers=auth_header, + ) + data = r.get_json() + assert data["confidence"] >= 0.95 + + def test_medium_confidence_partial_match(self, client, auth_header): + r = client.post( + "/categorize", + json={"description": "Trip to the market"}, + headers=auth_header, + ) + data = r.get_json() + # "market" matches Food & Dining at 0.70 + assert data["confidence"] <= 0.80 + + def test_low_confidence_no_match(self, client, auth_header): + r = client.post( + "/categorize", + json={"description": "zzzznonexistent123"}, + headers=auth_header, + ) + data = r.get_json() + assert data["category"] == "Uncategorized" diff --git a/packages/backend/tests/test_devices.py b/packages/backend/tests/test_devices.py new file mode 100644 index 00000000..5ace7a2d --- /dev/null +++ b/packages/backend/tests/test_devices.py @@ -0,0 +1,98 @@ +def test_device_trust_crud(client): + # Register and login + email = "devicetrust@test.com" + password = "secret123" + r = client.post("/auth/register", json={"email": email, "password": password}) + assert r.status_code in (201, 409) + + r = client.post("/auth/login", json={"email": email, "password": password}) + assert r.status_code == 200 + access = r.get_json()["access_token"] + auth = {"Authorization": f"Bearer {access}"} + + # List devices (empty) + r = client.get("/devices/", headers=auth) + assert r.status_code == 200 + assert r.get_json() == [] + + # Trust a device + r = client.post( + "/devices/", + json={"device_name": "My Laptop"}, + headers=auth, + ) + assert r.status_code == 201 + data = r.get_json() + assert data["device_name"] == "My Laptop" + assert data["trusted"] is True + device_id = data["id"] + + # List devices (1) + r = client.get("/devices/", headers=auth) + assert r.status_code == 200 + devices = r.get_json() + assert len(devices) == 1 + assert devices[0]["device_name"] == "My Laptop" + + # Rename device + r = client.patch(f"/devices/{device_id}", json={"device_name": "Work Laptop"}, headers=auth) + assert r.status_code == 200 + assert r.get_json()["device_name"] == "Work Laptop" + + # Revoke device + r = client.delete(f"/devices/{device_id}", headers=auth) + assert r.status_code == 200 + assert r.get_json()["trusted"] is False + + # Verify it shows as untrusted + r = client.get("/devices/", headers=auth) + devices = r.get_json() + assert len(devices) == 1 + assert devices[0]["trusted"] is False + + +def test_device_trust_requires_name(client): + email = "noname@test.com" + password = "secret123" + client.post("/auth/register", json={"email": email, "password": password}) + r = client.post("/auth/login", json={"email": email, "password": password}) + access = r.get_json()["access_token"] + auth = {"Authorization": f"Bearer {access}"} + + r = client.post("/devices/", json={"device_name": ""}, headers=auth) + assert r.status_code == 400 + + +def test_device_re_trust(client): + email = "retrust@test.com" + password = "secret123" + client.post("/auth/register", json={"email": email, "password": password}) + r = client.post("/auth/login", json={"email": email, "password": password}) + access = r.get_json()["access_token"] + auth = {"Authorization": f"Bearer {access}"} + + # Trust + r = client.post("/devices/", json={"device_name": "Phone"}, headers=auth) + assert r.status_code == 201 + device_id = r.get_json()["id"] + + # Revoke + client.delete(f"/devices/{device_id}", headers=auth) + + # Re-trust same device (same UA/IP = same fingerprint) -> updates existing + r = client.post("/devices/", json={"device_name": "Phone"}, headers=auth) + assert r.status_code == 200 + assert r.get_json()["trusted"] is True + assert r.get_json()["id"] == device_id + + +def test_device_revoke_not_found(client): + email = "nofound@test.com" + password = "secret123" + client.post("/auth/register", json={"email": email, "password": password}) + r = client.post("/auth/login", json={"email": email, "password": password}) + access = r.get_json()["access_token"] + auth = {"Authorization": f"Bearer {access}"} + + r = client.delete("/devices/9999", headers=auth) + assert r.status_code == 404