Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,23 @@ jobs:
- name: Run ruff format check
run: ruff format --check app/

test:
name: Unit tests (pytest)
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"

- name: Install dev deps
run: pip install -r requirements-dev.txt

- name: Run pytest
run: pytest -q

docker-build:
name: Docker build
runs-on: ubuntu-latest
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@ __pycache__/
*.py[cod]
*.pyo
.venv/
.venv-introspect/
venv/
.eggs/
*.egg-info/
dist/
build/
.pytest_cache/
.ruff_cache/

# Docker
*.tar
Expand Down
151 changes: 151 additions & 0 deletions app/schwab_data_proxy/enum_mapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""
Enum mapping helpers — translate REST query-string params into the typed enums
that schwab-py (pinned 1.5.1) requires.

schwab-py enforces enum types on its client methods (``enforce_enums=True`` by
default). Passing raw strings/ints raises ``ValueError``/``AttributeError`` at
call time. These helpers centralize the param -> enum translation so the REST
handlers stay thin and the mapping is unit-testable against a mocked client.

Design notes (validated against schwab-py 1.5.1):
- ``client.Quote.Fields`` replaced the old ``client.Quote.FIELD_*`` members.
- ``PriceHistory.Period`` / ``PriceHistory.Frequency`` are int-valued enums; the
proxy receives raw ints and must construct the enum *by value*.
- ``PriceHistory.PeriodType`` / ``FrequencyType`` are string-valued enums; the
proxy receives the lowercase token (e.g. ``"daily"``) and constructs *by value*.
- ``get_option_chain`` takes ``strike_range`` (NOT ``range``) and the Options
enums are string-valued (``ITM``, ``CALL`` ...). Callers may send either the
member name (``IN_THE_MONEY``) or the wire value (``ITM``); both are accepted.
- ``MarketHours.Market`` is a string-valued enum (``equity`` ...).

Every mapper accepts the live (or mocked) ``client`` so it can read the enum
classes off it, and raises ``UnknownEnumValue`` on an unmappable input rather
than silently passing a raw string through (which would surface as an opaque
schwab-py error downstream).
"""

from __future__ import annotations

from enum import Enum
from typing import Any


class UnknownEnumValue(ValueError):
"""Raised when an inbound param cannot be mapped to a schwab-py enum member."""


def _coerce(enum_cls: type[Enum], raw: Any, *, label: str) -> Enum:
"""Map ``raw`` onto ``enum_cls`` by member name or by wire value.

Accepts (in order):
1. an already-constructed member of ``enum_cls`` (pass-through)
2. the member NAME, case-insensitive (e.g. ``"in_the_money"``)
3. the member VALUE (e.g. ``"ITM"``, ``1``, ``"daily"``)
"""
if isinstance(raw, enum_cls):
return raw

# by name (case-insensitive) — only meaningful for string-ish inputs
if isinstance(raw, str):
try:
return enum_cls[raw.strip().upper()]
except KeyError:
pass

# by value (handles int-valued Period/Frequency and string-valued enums)
try:
return enum_cls(raw)
except ValueError:
pass

# last attempt: string value match case-insensitively
if isinstance(raw, str):
target = raw.strip().upper()
for member in enum_cls:
if str(member.value).upper() == target:
return member

valid = [m.name for m in enum_cls]
raise UnknownEnumValue(
f"{label}: cannot map {raw!r} to {enum_cls.__name__}; valid: {valid}"
)


# ---------------------------------------------------------------------------
# Quotes
# ---------------------------------------------------------------------------


def map_quote_fields(client: Any, fields: str) -> list[Enum]:
"""``"quote,reference"`` -> ``[Quote.Fields.QUOTE, Quote.Fields.REFERENCE]``."""
enum_cls = client.Quote.Fields
out: list[Enum] = []
for f in fields.split(","):
f = f.strip()
if not f:
continue
out.append(_coerce(enum_cls, f, label="quotes.fields"))
return out


# ---------------------------------------------------------------------------
# Price history
# ---------------------------------------------------------------------------


def map_period_type(client: Any, value: str) -> Enum:
return _coerce(
client.PriceHistory.PeriodType, value, label="pricehistory.period_type"
)


def map_period(client: Any, value: int) -> Enum:
return _coerce(client.PriceHistory.Period, value, label="pricehistory.period")


def map_frequency_type(client: Any, value: str) -> Enum:
return _coerce(
client.PriceHistory.FrequencyType, value, label="pricehistory.frequency_type"
)


def map_frequency(client: Any, value: int) -> Enum:
return _coerce(client.PriceHistory.Frequency, value, label="pricehistory.frequency")


# ---------------------------------------------------------------------------
# Option chains
# ---------------------------------------------------------------------------


def map_contract_type(client: Any, value: str) -> Enum:
return _coerce(client.Options.ContractType, value, label="chains.contract_type")


def map_strategy(client: Any, value: str) -> Enum:
return _coerce(client.Options.Strategy, value, label="chains.strategy")


def map_strike_range(client: Any, value: str) -> Enum:
return _coerce(client.Options.StrikeRange, value, label="chains.range")


def map_expiration_month(client: Any, value: str) -> Enum:
return _coerce(client.Options.ExpirationMonth, value, label="chains.exp_month")


def map_option_type(client: Any, value: str) -> Enum:
return _coerce(client.Options.Type, value, label="chains.option_type")


def map_entitlement(client: Any, value: str) -> Enum:
return _coerce(client.Options.Entitlement, value, label="chains.entitlement")


# ---------------------------------------------------------------------------
# Market hours
# ---------------------------------------------------------------------------


def map_market(client: Any, value: str) -> Enum:
return _coerce(client.MarketHours.Market, value, label="markets.markets")
81 changes: 26 additions & 55 deletions app/schwab_data_proxy/rest_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from fastapi import APIRouter, Query, Request
from fastapi.responses import JSONResponse

from . import enum_mapping
from .enum_mapping import UnknownEnumValue
from .schwab_session import session
from .settings import settings

Expand Down Expand Up @@ -57,6 +59,11 @@ async def _cached_call(endpoint: str, params: dict, coroutine_factory) -> JSONRe

try:
resp = await coroutine_factory()
except UnknownEnumValue as exc:
# An inbound param could not be mapped to a schwab-py enum — this is a
# client error (bad query string), not an upstream failure.
logger.warning("Invalid enum param for %s: %s", endpoint, exc)
return _error_response("BAD_REQUEST", str(exc), http_status=400)
except Exception as exc: # noqa: BLE001
logger.error("Upstream call failed for %s: %s", endpoint, exc)
return _error_response("UPSTREAM_ERROR", str(exc))
Expand Down Expand Up @@ -126,18 +133,7 @@ async def get_quotes(
async def call():
kwargs: dict[str, Any] = {}
if fields:
field_values = []
field_map = {
"quote": client.Quote.FIELD_QUOTE,
"reference": client.Quote.FIELD_REFERENCE,
"extended": client.Quote.FIELD_EXTENDED,
"fundamental": client.Quote.FIELD_FUNDAMENTAL,
"regular": client.Quote.FIELD_REGULAR,
}
for f in fields.split(","):
f = f.strip()
if f in field_map:
field_values.append(field_map[f])
field_values = enum_mapping.map_quote_fields(client, fields)
if field_values:
kwargs["fields"] = field_values
return await client.get_quotes(symbol_list, **kwargs)
Expand Down Expand Up @@ -199,30 +195,22 @@ async def get_chains(
async def call():
kwargs: dict[str, Any] = {"symbol": symbol}
if contract_type is not None:
try:
kwargs["contract_type"] = client.Options.ContractType[
contract_type.upper()
]
except (KeyError, AttributeError):
kwargs["contract_type"] = contract_type
kwargs["contract_type"] = enum_mapping.map_contract_type(
client, contract_type
)
if strike_count is not None:
kwargs["strike_count"] = strike_count
if include_underlying_quote is not None:
kwargs["include_underlying_quote"] = include_underlying_quote
if strategy is not None:
try:
kwargs["strategy"] = client.Options.Strategy[strategy.upper()]
except (KeyError, AttributeError):
kwargs["strategy"] = strategy
kwargs["strategy"] = enum_mapping.map_strategy(client, strategy)
if interval is not None:
kwargs["interval"] = interval
if strike is not None:
kwargs["strike"] = strike
if range is not None:
try:
kwargs["range"] = client.Options.StrikeRange[range.upper()]
except (KeyError, AttributeError):
kwargs["range"] = range
# schwab-py 1.5.1 names this kwarg ``strike_range`` (not ``range``).
kwargs["strike_range"] = enum_mapping.map_strike_range(client, range)
if from_date is not None:
kwargs["from_date"] = from_date
if to_date is not None:
Expand All @@ -236,17 +224,11 @@ async def call():
if days_to_expiration is not None:
kwargs["days_to_expiration"] = days_to_expiration
if exp_month is not None:
try:
kwargs["exp_month"] = client.Options.ExpirationMonth[exp_month.upper()]
except (KeyError, AttributeError):
kwargs["exp_month"] = exp_month
kwargs["exp_month"] = enum_mapping.map_expiration_month(client, exp_month)
if option_type is not None:
try:
kwargs["option_type"] = client.Options.Type[option_type.upper()]
except (KeyError, AttributeError):
kwargs["option_type"] = option_type
kwargs["option_type"] = enum_mapping.map_option_type(client, option_type)
if entitlement is not None:
kwargs["entitlement"] = entitlement
kwargs["entitlement"] = enum_mapping.map_entitlement(client, entitlement)
return await client.get_option_chain(**kwargs)

return await _cached_call("chains", params, call)
Expand Down Expand Up @@ -290,23 +272,17 @@ async def get_pricehistory(
async def call():
kwargs: dict[str, Any] = {"symbol": symbol}
if period_type is not None:
try:
kwargs["period_type"] = client.PriceHistory.PeriodType[
period_type.upper()
]
except (KeyError, AttributeError):
kwargs["period_type"] = period_type
kwargs["period_type"] = enum_mapping.map_period_type(client, period_type)
if period is not None:
kwargs["period"] = period
# schwab-py 1.5.1 enforces PriceHistory.Period (int-valued enum).
kwargs["period"] = enum_mapping.map_period(client, period)
if frequency_type is not None:
try:
kwargs["frequency_type"] = client.PriceHistory.FrequencyType[
frequency_type.upper()
]
except (KeyError, AttributeError):
kwargs["frequency_type"] = frequency_type
kwargs["frequency_type"] = enum_mapping.map_frequency_type(
client, frequency_type
)
if frequency is not None:
kwargs["frequency"] = frequency
# schwab-py 1.5.1 enforces PriceHistory.Frequency (int-valued enum).
kwargs["frequency"] = enum_mapping.map_frequency(client, frequency)
if start_datetime is not None:
kwargs["start_datetime"] = start_datetime
if end_datetime is not None:
Expand Down Expand Up @@ -344,12 +320,7 @@ async def get_markets(
)

async def call():
market_enums = []
for m in market_list:
try:
market_enums.append(client.MarketHours.Market[m.upper()])
except (KeyError, AttributeError):
market_enums.append(m)
market_enums = [enum_mapping.map_market(client, m) for m in market_list]
kwargs: dict[str, Any] = {"markets": market_enums}
if date:
kwargs["date"] = date
Expand Down
12 changes: 12 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[tool.pytest.ini_options]
# Run from repo root: `app` resolves as a namespace package and `schwab_data_proxy`
# matches the in-container import path. asyncio_mode=auto lets async test
# functions run without a per-test decorator.
pythonpath = ["."]
testpaths = ["tests"]
asyncio_mode = "auto"

[tool.ruff]
# rest_proxy intentionally shadows the builtin `range` to match the Schwab REST
# query-param name; that file is reviewed and the shadow is deliberate.
target-version = "py312"
3 changes: 3 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-r requirements.txt
pytest>=8
pytest-asyncio>=0.24
6 changes: 5 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
schwab-py
# Pinned: the rest_proxy enum mapping (Quote.Fields, PriceHistory.Period/Frequency,
# Options.*, MarketHours.Market) is validated against this exact version. Do not
# unpin — schwab-py renames/retypes enum members across releases and silently
# broke the /v1 endpoints when previously unpinned.
schwab-py==1.5.1
fastapi
uvicorn[standard]
pydantic-settings
Expand Down
Loading
Loading