From 84b945c27b069419717c568a9a196417f606cedb Mon Sep 17 00:00:00 2001 From: Chad Newbry Date: Fri, 27 Mar 2026 17:15:00 -0400 Subject: [PATCH 1/8] Add mta-next-train community ability --- community/mta-next-train/README.md | 50 +++ community/mta-next-train/__init__.py | 1 + community/mta-next-train/main.py | 195 +++++++++++ .../mta-next-train/mta_next_train_core.py | 318 ++++++++++++++++++ community/mta-next-train/test_harness.py | 72 ++++ community/mta-next-train/tests/fixtures.py | 53 +++ .../tests/test_mta_next_train.py | 59 ++++ 7 files changed, 748 insertions(+) create mode 100644 community/mta-next-train/README.md create mode 100644 community/mta-next-train/__init__.py create mode 100644 community/mta-next-train/main.py create mode 100644 community/mta-next-train/mta_next_train_core.py create mode 100644 community/mta-next-train/test_harness.py create mode 100644 community/mta-next-train/tests/fixtures.py create mode 100644 community/mta-next-train/tests/test_mta_next_train.py diff --git a/community/mta-next-train/README.md b/community/mta-next-train/README.md new file mode 100644 index 00000000..a55ce0db --- /dev/null +++ b/community/mta-next-train/README.md @@ -0,0 +1,50 @@ +# MTA Next Train + +![Community](https://img.shields.io/badge/OpenHome-Community-orange?style=flat-square) +![Author](https://img.shields.io/badge/Author-@chadnewbry-lightgrey?style=flat-square) + +Live New York City subway arrivals for OpenHome. Ask for your next train using a saved default station, or ask for a specific line and station on demand. + +## What It Does + +- Reads live subway arrivals from `SubwayInfo.nyc` with no API key required +- Saves one default station so the shortest prompt can be "when's my next train?" +- Handles explicit station requests like "next Q train at Union Square" +- Supports simple direction filters like "northbound" or "downtown" +- Lets the user change their default station by voice + +## Trigger Words + +- `next train` +- `next subway` +- `when's my next train` +- `subway arrivals` +- `mta` +- `mta next train` + +## Setup + +This ability uses `SubwayInfo.nyc` and does not require an API key. + +Install the ability in OpenHome, then set a default station once: + +- "set my default station to Astor Place" +- "use 14 street union square as my default station" + +Then the main demo becomes: + +- "when's my next train?" + +## Example Voice Commands + +- "When's my next train?" +- "Next Q train at Union Square" +- "When is the next northbound 6 at Astor Place?" +- "Set my default station to Fulton Street" +- "Change my default station to Jay Street MetroTech" + +## Technical Notes + +- Station search and live arrivals come from `SubwayInfo.nyc` +- No API key or account setup is required +- The local test harness includes fixture-based end-to-end checks plus an optional live mode diff --git a/community/mta-next-train/__init__.py b/community/mta-next-train/__init__.py new file mode 100644 index 00000000..7197977e --- /dev/null +++ b/community/mta-next-train/__init__.py @@ -0,0 +1 @@ +# Community ability package marker. diff --git a/community/mta-next-train/main.py b/community/mta-next-train/main.py new file mode 100644 index 00000000..43d79759 --- /dev/null +++ b/community/mta-next-train/main.py @@ -0,0 +1,195 @@ +import os +import sys +from typing import Dict, Optional + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +if CURRENT_DIR not in sys.path: + sys.path.append(CURRENT_DIR) + +from src.agent.capability import MatchingCapability +from src.agent.capability_worker import CapabilityWorker +from src.main import AgentWorker + +from mta_next_train_core import ( + ACTION_HELP, + ACTION_SET_DEFAULT, + QueryIntent, + Station, + fetch_arrivals, + find_station_matches, + format_arrivals_for_voice, + parse_query_intent, + search_stations, + station_from_prefs, +) + + +PREFS_KEY = "mta_next_train_prefs" +EXIT_WORDS = { + "stop", "exit", "quit", "done", "cancel", "bye", "goodbye", "nothing else", +} + + +class MTANextTrainCapability(MatchingCapability): + worker: AgentWorker = None + capability_worker: CapabilityWorker = None + prefs: Dict = None + + #{{register_capability}} + + def call(self, worker: AgentWorker): + self.worker = worker + self.capability_worker = CapabilityWorker(self.worker) + self.worker.session_tasks.create(self.run()) + + async def run(self): + try: + self.prefs = self.load_prefs() + trigger_text = self.get_trigger_text() + if not trigger_text: + trigger_text = await self.capability_worker.run_io_loop( + "What would you like to check? You can say, when is my next train, or set my default station to Astor Place." + ) + if not trigger_text: + return + + while True: + if self._is_exit(trigger_text): + await self.capability_worker.speak("Okay. Closing MTA Next Train.") + return + + intent = parse_query_intent(trigger_text) + if intent.action == ACTION_HELP: + await self.capability_worker.speak( + "You can ask for your next train, ask for a specific line at a station, or set a default station. For example: when is my next train, next Q train at Union Square, or set my default station to Astor Place." + ) + elif intent.action == ACTION_SET_DEFAULT: + await self.handle_set_default(intent) + else: + await self.handle_arrivals(intent) + + trigger_text = await self.capability_worker.run_io_loop( + "Anything else for the subway? Say another station or line, or say done." + ) + if not trigger_text: + return + except Exception as exc: + self.worker.editor_logging_handler.error(f"[MTANextTrain] {exc}") + await self.capability_worker.speak( + "Something went wrong while checking live arrivals." + ) + finally: + self.capability_worker.resume_normal_flow() + + def load_prefs(self) -> Dict: + prefs = self.capability_worker.get_single_key(PREFS_KEY) + if isinstance(prefs, dict): + return prefs + return { + "default_station_id": "", + "default_station_name": "", + "default_station_borough": "", + "default_station_lines": [], + } + + def save_prefs(self): + existing = self.capability_worker.get_single_key(PREFS_KEY) + if existing: + self.capability_worker.update_key(PREFS_KEY, self.prefs) + else: + self.capability_worker.create_key(PREFS_KEY, self.prefs) + + def get_trigger_text(self) -> str: + history = self.capability_worker.get_full_message_history() or [] + for item in reversed(history): + if item.get("role") == "user": + content = (item.get("content") or "").strip() + if content: + return content + return "" + + def _is_exit(self, text: str) -> bool: + lowered = (text or "").strip().lower() + return any(phrase in lowered for phrase in EXIT_WORDS) + + async def handle_set_default(self, intent: QueryIntent): + station_text = intent.station_text + if not station_text: + station_text = await self.capability_worker.run_io_loop( + "Which station should I save as your default?" + ) + if not station_text: + return + + station = await self.resolve_station(station_text) + if not station: + return + + self.prefs["default_station_id"] = station.station_id + self.prefs["default_station_name"] = station.name + self.prefs["default_station_borough"] = station.borough + self.prefs["default_station_lines"] = station.lines + self.save_prefs() + await self.capability_worker.speak( + f"Saved {station.name} as your default station." + ) + + async def handle_arrivals(self, intent: QueryIntent): + station = None + if intent.station_text: + station = await self.resolve_station(intent.station_text) + elif self.prefs.get("default_station_id"): + station = station_from_prefs( + self.prefs["default_station_id"], + self.prefs.get("default_station_name", ""), + self.prefs.get("default_station_borough", ""), + self.prefs.get("default_station_lines", []), + ) + + if not station: + spoken_station = await self.capability_worker.run_io_loop( + "Which station do you want to check? You can also say set my default station to save one." + ) + if not spoken_station: + return + if self._is_exit(spoken_station): + await self.capability_worker.speak("Okay. Closing MTA Next Train.") + return + station = await self.resolve_station(spoken_station) + if not station: + return + + await self.capability_worker.speak(f"Checking live arrivals for {station.name}.") + arrivals = fetch_arrivals(station.station_id, intent.routes, intent.direction) + summary = format_arrivals_for_voice( + station, + arrivals, + intent.routes, + intent.direction, + ) + await self.capability_worker.speak(summary) + + async def resolve_station(self, station_text: str) -> Optional[Station]: + stations = search_stations(station_text, limit=5) + matches = find_station_matches(stations, station_text) + if not matches: + await self.capability_worker.speak( + f"I could not find a subway station matching {station_text}." + ) + return None + + top_match = matches[0] + if len(matches) > 1 and (top_match.score - matches[1].score) < 0.08: + options = ", ".join(match.station.name for match in matches[:3]) + response = await self.capability_worker.run_io_loop( + f"I found a few close matches: {options}. Which one did you mean?" + ) + if not response: + return None + narrowed = find_station_matches(stations, response, limit=1) + if not narrowed: + await self.capability_worker.speak("I still could not pin down the station.") + return None + return narrowed[0].station + + return top_match.station diff --git a/community/mta-next-train/mta_next_train_core.py b/community/mta-next-train/mta_next_train_core.py new file mode 100644 index 00000000..c0bc222d --- /dev/null +++ b/community/mta-next-train/mta_next_train_core.py @@ -0,0 +1,318 @@ +from __future__ import annotations + +import difflib +import json +import re +import urllib.parse +import urllib.request +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Sequence + + +API_BASE_URL = "https://subwayinfo.nyc/api" +ACTION_SET_DEFAULT = "set_default" +ACTION_ARRIVALS = "arrivals" +ACTION_HELP = "help" +ROUTE_ALIASES = { + "one": "1", + "two": "2", + "three": "3", + "four": "4", + "five": "5", + "six": "6", + "seven": "7", + "a": "A", + "b": "B", + "c": "C", + "d": "D", + "e": "E", + "f": "F", + "g": "G", + "j": "J", + "l": "L", + "m": "M", + "n": "N", + "q": "Q", + "r": "R", + "s": "S", + "w": "W", + "z": "Z", + "shuttle": "S", +} +NORTHBOUND_WORDS = {"northbound", "uptown"} +SOUTHBOUND_WORDS = {"southbound", "downtown"} +STREET_WORDS = { + "st": "street", + "street": "street", + "ave": "avenue", + "av": "avenue", + "avenue": "avenue", + "sq": "square", + "square": "square", + "plz": "plaza", + "plaza": "plaza", +} + + +@dataclass +class Station: + station_id: str + name: str + normalized_name: str + borough: str = "" + lines: List[str] = field(default_factory=list) + + +@dataclass +class StationMatch: + station: Station + score: float + + +@dataclass +class QueryIntent: + action: str + station_text: Optional[str] = None + routes: List[str] = field(default_factory=list) + direction: Optional[str] = None + + +@dataclass +class Arrival: + route_id: str + direction: str + direction_label: str + minutes_away: int + headsign: str + + +def normalize_text(text: str) -> str: + lowered = (text or "").strip().lower() + lowered = lowered.replace("&", " and ") + lowered = re.sub(r"[^a-z0-9\s]", " ", lowered) + lowered = re.sub(r"\s+", " ", lowered).strip() + words = [] + for part in lowered.split(): + words.append(STREET_WORDS.get(part, part)) + return " ".join(words) + + +def station_from_api_item(item: Dict) -> Station: + return Station( + station_id=str(item.get("id", "")), + name=str(item.get("name", "")).strip(), + normalized_name=normalize_text(str(item.get("name", "")).strip()), + borough=str(item.get("borough", "")).strip(), + lines=[str(line) for line in item.get("lines", []) if str(line).strip()], + ) + + +def station_from_prefs(station_id: str, name: str, borough: str = "", lines: Optional[List[str]] = None) -> Station: + return Station( + station_id=station_id, + name=name, + normalized_name=normalize_text(name), + borough=borough, + lines=lines or [], + ) + + +def find_station_matches(stations: Sequence[Station], query: str, limit: int = 3) -> List[StationMatch]: + normalized_query = normalize_text(query) + if not normalized_query: + return [] + + matches: List[StationMatch] = [] + query_tokens = set(normalized_query.split()) + for station in stations: + ratio = difflib.SequenceMatcher(None, normalized_query, station.normalized_name).ratio() + station_tokens = set(station.normalized_name.split()) + overlap = 0.0 + if station_tokens: + overlap = len(query_tokens & station_tokens) / len(query_tokens | station_tokens) + contains_bonus = 0.18 if normalized_query in station.normalized_name else 0.0 + prefix_bonus = 0.12 if station.normalized_name.startswith(normalized_query) else 0.0 + score = max(ratio, overlap) + contains_bonus + prefix_bonus + if score >= 0.48: + matches.append(StationMatch(station=station, score=score)) + + matches.sort(key=lambda item: (-item.score, item.station.name)) + deduped: List[StationMatch] = [] + seen = set() + for match in matches: + key = (match.station.station_id, match.station.name) + if key in seen: + continue + seen.add(key) + deduped.append(match) + if len(deduped) == limit: + break + return deduped + + +def parse_query_intent(text: str) -> QueryIntent: + normalized = normalize_text(text) + routes = extract_routes(normalized) + direction = extract_direction(normalized) + + for pattern in ( + r"(?:set|change|update|use)\s+(?:my\s+)?(?:default|home)\s+station(?:\s+to)?\s+(?P.+)", + r"my\s+(?:default\s+|home\s+)?station\s+is\s+(?P.+)", + r"use\s+(?P.+)\s+as\s+(?:my\s+)?(?:default|home)\s+station", + ): + match = re.search(pattern, normalized) + if match: + return QueryIntent( + action=ACTION_SET_DEFAULT, + station_text=clean_station_phrase(match.group("station")), + routes=routes, + direction=direction, + ) + + if any(phrase in normalized for phrase in ("help", "what can you do", "how does this work")): + return QueryIntent(action=ACTION_HELP, routes=routes, direction=direction) + + return QueryIntent( + action=ACTION_ARRIVALS, + station_text=extract_station_phrase(normalized), + routes=routes, + direction=direction, + ) + + +def clean_station_phrase(text: Optional[str]) -> Optional[str]: + if not text: + return None + cleaned = normalize_text(text) + cleaned = re.sub(r"\b(?:station|stop|train|subway)\b", " ", cleaned) + cleaned = re.sub(r"\s+", " ", cleaned).strip() + return cleaned or None + + +def extract_station_phrase(normalized_text: str) -> Optional[str]: + patterns = ( + r"(?:at|from)\s+(?P[a-z0-9\s]+?)(?:\s+for\s+[a-z0-9]+\s+train|\s+for\s+[a-z0-9]+|\s+train|$)", + r"next\s+(?:train|subway)\s+(?:at|from)\s+(?P[a-z0-9\s]+)$", + r"when(?:'s|\s+is)?\s+the\s+next\s+(?:train|subway)\s+(?:at|from)\s+(?P[a-z0-9\s]+)$", + ) + for pattern in patterns: + match = re.search(pattern, normalized_text) + if match: + return clean_station_phrase(match.group("station")) + return None + + +def extract_routes(normalized_text: str) -> List[str]: + routes: List[str] = [] + for alias, route in sorted(ROUTE_ALIASES.items(), key=lambda item: -len(item[0])): + if re.search(rf"\b{re.escape(alias)}\b", normalized_text): + routes.append(route) + + tokens = re.findall(r"\b[a-z0-9]{1,2}\b", normalized_text) + for token in tokens: + upper = token.upper() + if upper in {"1", "2", "3", "4", "5", "6", "7", "A", "B", "C", "D", "E", "F", "G", "J", "L", "M", "N", "Q", "R", "S", "W", "Z"}: + routes.append(upper) + + deduped: List[str] = [] + for route in routes: + if route not in deduped: + deduped.append(route) + return deduped + + +def extract_direction(normalized_text: str) -> Optional[str]: + for phrase in NORTHBOUND_WORDS: + if phrase in normalized_text: + return "N" + for phrase in SOUTHBOUND_WORDS: + if phrase in normalized_text: + return "S" + return None + + +def fetch_json(url: str, timeout: int = 12) -> object: + request = urllib.request.Request( + url, + headers={ + "User-Agent": "OpenHome-MTA-Next-Train/1.0", + "Accept": "application/json", + }, + method="GET", + ) + with urllib.request.urlopen(request, timeout=timeout) as response: + return json.loads(response.read().decode("utf-8")) + + +def search_stations(query: str, limit: int = 5) -> List[Station]: + params = urllib.parse.urlencode({"query": query, "limit": limit}) + payload = fetch_json(f"{API_BASE_URL}/stations?{params}") + if not isinstance(payload, list): + return [] + return [station_from_api_item(item) for item in payload if isinstance(item, dict)] + + +def fetch_arrivals(station_id: str, routes: Sequence[str], direction: Optional[str], limit: int = 8) -> List[Arrival]: + params = {"station_id": station_id, "limit": limit} + if routes: + params["line"] = routes[0] + if direction: + params["direction"] = direction + payload = fetch_json(f"{API_BASE_URL}/arrivals?{urllib.parse.urlencode(params)}") + if not isinstance(payload, dict): + return [] + raw_arrivals = payload.get("arrivals", []) + if not isinstance(raw_arrivals, list): + return [] + arrivals: List[Arrival] = [] + for item in raw_arrivals: + if not isinstance(item, dict): + continue + route_id = str(item.get("line", "")).strip() + item_direction = str(item.get("direction", "")).strip() + if routes and route_id not in routes: + continue + if direction and item_direction != direction: + continue + arrivals.append( + Arrival( + route_id=route_id, + direction=item_direction, + direction_label=str(item.get("directionLabel", "")).strip(), + minutes_away=int(item.get("minutesAway", 0)), + headsign=str(item.get("headsign", "")).strip(), + ) + ) + return arrivals + + +def format_arrivals_for_voice(station: Station, arrivals: Sequence[Arrival], routes: Sequence[str], direction: Optional[str]) -> str: + if not arrivals: + return f"I am not seeing any matching trains for {station.name} right now." + + grouped: Dict[tuple[str, str], List[Arrival]] = {} + for arrival in arrivals: + grouped.setdefault((arrival.route_id, arrival.direction), []).append(arrival) + + lead_bits: List[str] = [] + for (_, _), group in sorted(grouped.items(), key=lambda item: item[1][0].minutes_away): + first = group[0] + direction_label = "northbound" if first.direction == "N" else "southbound" + bit = f"{direction_label} {first.route_id} in {render_minutes(first.minutes_away)}" + if len(group) > 1: + bit += f", then {render_minutes(group[1].minutes_away)}" + lead_bits.append(bit) + if len(lead_bits) == 3: + break + + if len(lead_bits) == 1: + return f"At {station.name}, the next train is {lead_bits[0]}." + return f"At {station.name}, next trains: " + "; ".join(lead_bits) + "." + + +def render_minutes(minutes: int) -> str: + if minutes <= 0: + return "now" + if minutes == 1: + return "1 minute" + return f"{minutes} minutes" diff --git a/community/mta-next-train/test_harness.py b/community/mta-next-train/test_harness.py new file mode 100644 index 00000000..90daf2ca --- /dev/null +++ b/community/mta-next-train/test_harness.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import sys +from pathlib import Path + +CURRENT_DIR = Path(__file__).resolve().parent +if str(CURRENT_DIR) not in sys.path: + sys.path.append(str(CURRENT_DIR)) + +from mta_next_train_core import ( + fetch_arrivals, + find_station_matches, + format_arrivals_for_voice, + parse_query_intent, + search_stations, + station_from_api_item, +) +from tests.fixtures import SAMPLE_ARRIVALS_RESPONSE, SAMPLE_STATIONS + + +def run_fixture_mode(phrase: str) -> int: + stations = [station_from_api_item(item) for item in SAMPLE_STATIONS] + intent = parse_query_intent(phrase) + station_query = intent.station_text or "Astor Place" + matches = find_station_matches(stations, station_query, limit=1) + if not matches: + print("No station match") + return 1 + station = matches[0].station + import mta_next_train_core + original_fetch_json = mta_next_train_core.fetch_json + try: + mta_next_train_core.fetch_json = lambda url, timeout=12: SAMPLE_ARRIVALS_RESPONSE + arrivals = fetch_arrivals(station.station_id, intent.routes, intent.direction) + finally: + mta_next_train_core.fetch_json = original_fetch_json + print(format_arrivals_for_voice(station, arrivals, intent.routes, intent.direction)) + return 0 + + +def run_live_mode(phrase: str) -> int: + intent = parse_query_intent(phrase) + station_query = intent.station_text + if not station_query: + print("Live mode needs a station in the phrase, e.g. 'next Q train at Union Square'") + return 1 + stations = search_stations(station_query, limit=5) + matches = find_station_matches(stations, station_query, limit=1) + if not matches: + print("No station match") + return 1 + station = matches[0].station + arrivals = fetch_arrivals(station.station_id, intent.routes, intent.direction) + print(format_arrivals_for_voice(station, arrivals, intent.routes, intent.direction)) + return 0 + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--phrase", required=True) + parser.add_argument("--live", action="store_true") + args = parser.parse_args() + + if args.live: + return run_live_mode(args.phrase) + return run_fixture_mode(args.phrase) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/community/mta-next-train/tests/fixtures.py b/community/mta-next-train/tests/fixtures.py new file mode 100644 index 00000000..2753471e --- /dev/null +++ b/community/mta-next-train/tests/fixtures.py @@ -0,0 +1,53 @@ +SAMPLE_STATIONS = [ + { + "id": "636", + "name": "Astor Pl", + "lat": 40.730054, + "lon": -73.99107, + "borough": "Manhattan", + "lines": ["4", "6"], + }, + { + "id": "R20", + "name": "14 St-Union Sq", + "lat": 40.734673, + "lon": -73.989951, + "borough": "Manhattan", + "lines": ["4", "5", "6", "L", "N", "Q", "R", "W"], + }, +] + +SAMPLE_ARRIVALS_RESPONSE = { + "stationId": "636", + "stationName": "Astor Pl", + "arrivals": [ + { + "line": "6", + "direction": "N", + "directionLabel": "Bronx-bound to Pelham Bay Park", + "arrivalTime": "2026-03-27T21:10:30.000Z", + "minutesAway": 0, + "isAssigned": False, + "headsign": "Pelham Bay Park", + }, + { + "line": "6", + "direction": "S", + "directionLabel": "to Brooklyn Bridge", + "arrivalTime": "2026-03-27T21:12:42.000Z", + "minutesAway": 3, + "isAssigned": False, + "headsign": "Brooklyn Bridge-City Hall", + }, + { + "line": "6", + "direction": "N", + "directionLabel": "Bronx-bound to Pelham Bay Park", + "arrivalTime": "2026-03-27T21:13:19.000Z", + "minutesAway": 3, + "isAssigned": False, + "headsign": "Pelham Bay Park", + }, + ], + "lastUpdated": "2026-03-27T21:09:08.865Z", +} diff --git a/community/mta-next-train/tests/test_mta_next_train.py b/community/mta-next-train/tests/test_mta_next_train.py new file mode 100644 index 00000000..75504985 --- /dev/null +++ b/community/mta-next-train/tests/test_mta_next_train.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import unittest + +from mta_next_train_core import ( + ACTION_ARRIVALS, + ACTION_SET_DEFAULT, + fetch_arrivals, + find_station_matches, + format_arrivals_for_voice, + parse_query_intent, + station_from_api_item, +) +from tests.fixtures import SAMPLE_ARRIVALS_RESPONSE, SAMPLE_STATIONS + + +class MTANextTrainCoreTests(unittest.TestCase): + def setUp(self): + self.stations = [station_from_api_item(item) for item in SAMPLE_STATIONS] + + def test_parse_default_station_command(self): + intent = parse_query_intent("set my default station to astor place") + self.assertEqual(intent.action, ACTION_SET_DEFAULT) + self.assertEqual(intent.station_text, "astor place") + + def test_parse_arrivals_command(self): + intent = parse_query_intent("when is the next northbound 6 train at astor place") + self.assertEqual(intent.action, ACTION_ARRIVALS) + self.assertEqual(intent.direction, "N") + self.assertEqual(intent.routes, ["6"]) + self.assertEqual(intent.station_text, "astor place") + + def test_station_match(self): + matches = find_station_matches(self.stations, "14 street union square") + self.assertEqual(matches[0].station.name, "14 St-Union Sq") + + def test_fixture_end_to_end(self): + station = find_station_matches(self.stations, "astor place", limit=1)[0].station + import mta_next_train_core + original_fetch_json = mta_next_train_core.fetch_json + try: + mta_next_train_core.fetch_json = lambda url, timeout=12: SAMPLE_ARRIVALS_RESPONSE + arrivals = fetch_arrivals(station.station_id, routes=["6"], direction=None) + finally: + mta_next_train_core.fetch_json = original_fetch_json + self.assertEqual([arrival.route_id for arrival in arrivals[:3]], ["6", "6", "6"]) + spoken = format_arrivals_for_voice( + station, + arrivals, + routes=["6"], + direction=None, + ) + self.assertIn("Astor Pl", spoken) + self.assertIn("northbound 6 in now", spoken) + self.assertIn("southbound 6 in 3 minutes", spoken) + + +if __name__ == "__main__": + unittest.main() From 0eb8ee0e2e0121a7bd9baacee520a6f9742d896c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 27 Mar 2026 21:15:41 +0000 Subject: [PATCH 2/8] style: auto-format Python files with autoflake + autopep8 --- community/mta-next-train/main.py | 24 +++++++++++------------- community/mta-next-train/test_harness.py | 19 +++++++++---------- 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/community/mta-next-train/main.py b/community/mta-next-train/main.py index 43d79759..2d5ee0fc 100644 --- a/community/mta-next-train/main.py +++ b/community/mta-next-train/main.py @@ -1,15 +1,3 @@ -import os -import sys -from typing import Dict, Optional - -CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) -if CURRENT_DIR not in sys.path: - sys.path.append(CURRENT_DIR) - -from src.agent.capability import MatchingCapability -from src.agent.capability_worker import CapabilityWorker -from src.main import AgentWorker - from mta_next_train_core import ( ACTION_HELP, ACTION_SET_DEFAULT, @@ -22,6 +10,16 @@ search_stations, station_from_prefs, ) +from src.main import AgentWorker +from src.agent.capability_worker import CapabilityWorker +from src.agent.capability import MatchingCapability +import os +import sys +from typing import Dict, Optional + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +if CURRENT_DIR not in sys.path: + sys.path.append(CURRENT_DIR) PREFS_KEY = "mta_next_train_prefs" @@ -35,7 +33,7 @@ class MTANextTrainCapability(MatchingCapability): capability_worker: CapabilityWorker = None prefs: Dict = None - #{{register_capability}} + # {{register_capability}} def call(self, worker: AgentWorker): self.worker = worker diff --git a/community/mta-next-train/test_harness.py b/community/mta-next-train/test_harness.py index 90daf2ca..27eacad6 100644 --- a/community/mta-next-train/test_harness.py +++ b/community/mta-next-train/test_harness.py @@ -1,14 +1,6 @@ #!/usr/bin/env python3 from __future__ import annotations - -import argparse -import sys -from pathlib import Path - -CURRENT_DIR = Path(__file__).resolve().parent -if str(CURRENT_DIR) not in sys.path: - sys.path.append(str(CURRENT_DIR)) - +from tests.fixtures import SAMPLE_ARRIVALS_RESPONSE, SAMPLE_STATIONS from mta_next_train_core import ( fetch_arrivals, find_station_matches, @@ -17,7 +9,14 @@ search_stations, station_from_api_item, ) -from tests.fixtures import SAMPLE_ARRIVALS_RESPONSE, SAMPLE_STATIONS + +import argparse +import sys +from pathlib import Path + +CURRENT_DIR = Path(__file__).resolve().parent +if str(CURRENT_DIR) not in sys.path: + sys.path.append(str(CURRENT_DIR)) def run_fixture_mode(phrase: str) -> int: From 94ca79503cb1ccaa5ca45b6b053213e74925ca36 Mon Sep 17 00:00:00 2001 From: Chad Newbry Date: Fri, 27 Mar 2026 17:17:44 -0400 Subject: [PATCH 3/8] Fix empty __init__ for mta-next-train --- community/mta-next-train/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/community/mta-next-train/__init__.py b/community/mta-next-train/__init__.py index 7197977e..e69de29b 100644 --- a/community/mta-next-train/__init__.py +++ b/community/mta-next-train/__init__.py @@ -1 +0,0 @@ -# Community ability package marker. From a8ebdbf2b465b8b7607a285ab464c7f5fd678ee7 Mon Sep 17 00:00:00 2001 From: Chad Newbry Date: Fri, 27 Mar 2026 17:51:30 -0400 Subject: [PATCH 4/8] Update mta-next-train runtime compatibility --- community/mta-next-train/main.py | 6 ---- .../mta-next-train/mta_next_train_core.py | 36 ++++++++++++++----- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/community/mta-next-train/main.py b/community/mta-next-train/main.py index 2d5ee0fc..7025c810 100644 --- a/community/mta-next-train/main.py +++ b/community/mta-next-train/main.py @@ -13,14 +13,8 @@ from src.main import AgentWorker from src.agent.capability_worker import CapabilityWorker from src.agent.capability import MatchingCapability -import os -import sys from typing import Dict, Optional -CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) -if CURRENT_DIR not in sys.path: - sys.path.append(CURRENT_DIR) - PREFS_KEY = "mta_next_train_prefs" EXIT_WORDS = { diff --git a/community/mta-next-train/mta_next_train_core.py b/community/mta-next-train/mta_next_train_core.py index c0bc222d..1328d009 100644 --- a/community/mta-next-train/mta_next_train_core.py +++ b/community/mta-next-train/mta_next_train_core.py @@ -3,12 +3,10 @@ import difflib import json import re -import urllib.parse -import urllib.request from dataclasses import dataclass, field from typing import Dict, List, Optional, Sequence - +import requests API_BASE_URL = "https://subwayinfo.nyc/api" ACTION_SET_DEFAULT = "set_default" ACTION_ARRIVALS = "arrivals" @@ -232,20 +230,40 @@ def extract_direction(normalized_text: str) -> Optional[str]: def fetch_json(url: str, timeout: int = 12) -> object: - request = urllib.request.Request( + response = requests.get( url, headers={ "User-Agent": "OpenHome-MTA-Next-Train/1.0", "Accept": "application/json", }, - method="GET", + timeout=timeout, ) - with urllib.request.urlopen(request, timeout=timeout) as response: - return json.loads(response.read().decode("utf-8")) + response.raise_for_status() + return response.json() + + +def urlencode(params: Dict[str, object]) -> str: + parts = [] + for key, value in params.items(): + parts.append(f"{escape_query_value(str(key))}={escape_query_value(str(value))}") + return "&".join(parts) + + +def escape_query_value(value: str) -> str: + safe = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~" + out = [] + for ch in value: + if ch in safe: + out.append(ch) + elif ch == " ": + out.append("+") + else: + out.append(f"%{ord(ch):02X}") + return "".join(out) def search_stations(query: str, limit: int = 5) -> List[Station]: - params = urllib.parse.urlencode({"query": query, "limit": limit}) + params = urlencode({"query": query, "limit": limit}) payload = fetch_json(f"{API_BASE_URL}/stations?{params}") if not isinstance(payload, list): return [] @@ -258,7 +276,7 @@ def fetch_arrivals(station_id: str, routes: Sequence[str], direction: Optional[s params["line"] = routes[0] if direction: params["direction"] = direction - payload = fetch_json(f"{API_BASE_URL}/arrivals?{urllib.parse.urlencode(params)}") + payload = fetch_json(f"{API_BASE_URL}/arrivals?{urlencode(params)}") if not isinstance(payload, dict): return [] raw_arrivals = payload.get("arrivals", []) From 7d4954a9a8b3e44513473cbdc1301a56a1892e00 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 27 Mar 2026 21:51:43 +0000 Subject: [PATCH 5/8] style: auto-format Python files with autoflake + autopep8 --- community/mta-next-train/mta_next_train_core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/community/mta-next-train/mta_next_train_core.py b/community/mta-next-train/mta_next_train_core.py index 1328d009..bfc9933d 100644 --- a/community/mta-next-train/mta_next_train_core.py +++ b/community/mta-next-train/mta_next_train_core.py @@ -1,7 +1,6 @@ from __future__ import annotations import difflib -import json import re from dataclasses import dataclass, field from typing import Dict, List, Optional, Sequence From 5c36b1fcdb7cc23d760fa746ff28c28f780811b9 Mon Sep 17 00:00:00 2001 From: Chad Newbry Date: Fri, 27 Mar 2026 19:13:48 -0400 Subject: [PATCH 6/8] Flatten mta-next-train into single-file runtime --- community/mta-next-train/README.md | 29 +- community/mta-next-train/main.py | 446 +++++++++++++++++- .../mta-next-train/mta_next_train_core.py | 335 ------------- community/mta-next-train/test_harness.py | 71 --- community/mta-next-train/tests/fixtures.py | 53 --- .../tests/test_mta_next_train.py | 59 --- 6 files changed, 429 insertions(+), 564 deletions(-) delete mode 100644 community/mta-next-train/mta_next_train_core.py delete mode 100644 community/mta-next-train/test_harness.py delete mode 100644 community/mta-next-train/tests/fixtures.py delete mode 100644 community/mta-next-train/tests/test_mta_next_train.py diff --git a/community/mta-next-train/README.md b/community/mta-next-train/README.md index a55ce0db..83f35a5c 100644 --- a/community/mta-next-train/README.md +++ b/community/mta-next-train/README.md @@ -15,36 +15,11 @@ Live New York City subway arrivals for OpenHome. Ask for your next train using a ## Trigger Words -- `next train` +- `mta next train` - `next subway` -- `when's my next train` - `subway arrivals` -- `mta` -- `mta next train` +- `nyc mta` ## Setup This ability uses `SubwayInfo.nyc` and does not require an API key. - -Install the ability in OpenHome, then set a default station once: - -- "set my default station to Astor Place" -- "use 14 street union square as my default station" - -Then the main demo becomes: - -- "when's my next train?" - -## Example Voice Commands - -- "When's my next train?" -- "Next Q train at Union Square" -- "When is the next northbound 6 at Astor Place?" -- "Set my default station to Fulton Street" -- "Change my default station to Jay Street MetroTech" - -## Technical Notes - -- Station search and live arrivals come from `SubwayInfo.nyc` -- No API key or account setup is required -- The local test harness includes fixture-based end-to-end checks plus an optional live mode diff --git a/community/mta-next-train/main.py b/community/mta-next-train/main.py index 7025c810..e39dd83c 100644 --- a/community/mta-next-train/main.py +++ b/community/mta-next-train/main.py @@ -1,41 +1,414 @@ -from mta_next_train_core import ( - ACTION_HELP, - ACTION_SET_DEFAULT, - QueryIntent, - Station, - fetch_arrivals, - find_station_matches, - format_arrivals_for_voice, - parse_query_intent, - search_stations, - station_from_prefs, -) -from src.main import AgentWorker -from src.agent.capability_worker import CapabilityWorker from src.agent.capability import MatchingCapability -from typing import Dict, Optional +from src.agent.capability_worker import CapabilityWorker +from src.main import AgentWorker + +import difflib +import re +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Sequence + +import requests +API_BASE_URL = "https://subwayinfo.nyc/api" +ACTION_SET_DEFAULT = "set_default" +ACTION_ARRIVALS = "arrivals" +ACTION_HELP = "help" +ACTION_GET_DEFAULT = "get_default" +ROUTE_ALIASES = { + "one": "1", + "two": "2", + "three": "3", + "four": "4", + "five": "5", + "six": "6", + "seven": "7", + "a": "A", + "b": "B", + "c": "C", + "d": "D", + "e": "E", + "f": "F", + "g": "G", + "j": "J", + "l": "L", + "m": "M", + "n": "N", + "q": "Q", + "r": "R", + "s": "S", + "w": "W", + "z": "Z", + "shuttle": "S", +} +NORTHBOUND_WORDS = {"northbound", "uptown"} +SOUTHBOUND_WORDS = {"southbound", "downtown"} +STREET_WORDS = { + "st": "street", + "street": "street", + "ave": "avenue", + "av": "avenue", + "avenue": "avenue", + "sq": "square", + "square": "square", + "plz": "plaza", + "plaza": "plaza", +} PREFS_KEY = "mta_next_train_prefs" EXIT_WORDS = { "stop", "exit", "quit", "done", "cancel", "bye", "goodbye", "nothing else", } +@dataclass +class Station: + station_id: str + name: str + normalized_name: str + borough: str = "" + lines: List[str] = field(default_factory=list) + + +@dataclass +class StationMatch: + station: Station + score: float + + +@dataclass +class QueryIntent: + action: str + station_text: Optional[str] = None + routes: List[str] = field(default_factory=list) + direction: Optional[str] = None + + +@dataclass +class Arrival: + route_id: str + direction: str + direction_label: str + minutes_away: int + headsign: str + + +def normalize_text(text: str) -> str: + lowered = (text or "").strip().lower() + lowered = lowered.replace("&", " and ") + lowered = re.sub(r"[^a-z0-9\s]", " ", lowered) + lowered = re.sub(r"\s+", " ", lowered).strip() + words = [] + for part in lowered.split(): + words.append(STREET_WORDS.get(part, part)) + return " ".join(words) + + +def station_from_api_item(item: Dict) -> Station: + return Station( + station_id=str(item.get("id", "")), + name=str(item.get("name", "")).strip(), + normalized_name=normalize_text(str(item.get("name", "")).strip()), + borough=str(item.get("borough", "")).strip(), + lines=[str(line) for line in item.get("lines", []) if str(line).strip()], + ) + + +def station_from_prefs( + station_id: str, + name: str, + borough: str = "", + lines: Optional[List[str]] = None, +) -> Station: + return Station( + station_id=station_id, + name=name, + normalized_name=normalize_text(name), + borough=borough, + lines=lines or [], + ) + + +def find_station_matches( + stations: Sequence[Station], query: str, limit: int = 3 +) -> List[StationMatch]: + normalized_query = normalize_text(query) + if not normalized_query: + return [] + + matches: List[StationMatch] = [] + query_tokens = set(normalized_query.split()) + for station in stations: + ratio = difflib.SequenceMatcher( + None, normalized_query, station.normalized_name + ).ratio() + station_tokens = set(station.normalized_name.split()) + overlap = 0.0 + if station_tokens: + overlap = len(query_tokens & station_tokens) / len( + query_tokens | station_tokens + ) + contains_bonus = 0.18 if normalized_query in station.normalized_name else 0.0 + prefix_bonus = 0.12 if station.normalized_name.startswith(normalized_query) else 0.0 + score = max(ratio, overlap) + contains_bonus + prefix_bonus + if score >= 0.48: + matches.append(StationMatch(station=station, score=score)) + + matches.sort(key=lambda item: (-item.score, item.station.name)) + deduped: List[StationMatch] = [] + seen = set() + for match in matches: + key = (match.station.station_id, match.station.name) + if key in seen: + continue + seen.add(key) + deduped.append(match) + if len(deduped) == limit: + break + return deduped + + +def parse_query_intent(text: str) -> QueryIntent: + normalized = normalize_text(text) + routes = extract_routes(text, normalized) + direction = extract_direction(normalized) + + for pattern in ( + r"(?:set|change|update|use)\s+(?:my\s+)?(?:default|home)\s+station(?:\s+to)?\s+(?P.+)", + r"my\s+(?:default\s+|home\s+)?station\s+is\s+(?P.+)", + r"use\s+(?P.+)\s+as\s+(?:my\s+)?(?:default|home)\s+station", + ): + match = re.search(pattern, normalized) + if match: + return QueryIntent( + action=ACTION_SET_DEFAULT, + station_text=clean_station_phrase(match.group("station")), + routes=routes, + direction=direction, + ) + + if any( + phrase in normalized + for phrase in ( + "what is my default station", + "what s my default station", + "what is the default station", + "what s the default station", + "what is my home station", + "what s my home station", + ) + ): + return QueryIntent(action=ACTION_GET_DEFAULT, routes=routes, direction=direction) + + if any( + phrase in normalized for phrase in ("help", "what can you do", "how does this work") + ): + return QueryIntent(action=ACTION_HELP, routes=routes, direction=direction) + + return QueryIntent( + action=ACTION_ARRIVALS, + station_text=extract_station_phrase(normalized), + routes=routes, + direction=direction, + ) + + +def clean_station_phrase(text: Optional[str]) -> Optional[str]: + if not text: + return None + cleaned = normalize_text(text) + cleaned = re.sub(r"\b(?:station|stop|train|subway)\b", " ", cleaned) + cleaned = re.sub(r"\s+", " ", cleaned).strip() + return cleaned or None + + +def extract_station_phrase(normalized_text: str) -> Optional[str]: + patterns = ( + r"\b(?:at|from)\b\s+(?P[a-z0-9\s]+?)(?:\s+for\s+[a-z0-9]+\s+train|\s+for\s+[a-z0-9]+|\s+train|$)", + r"\bnext\s+(?:train|subway)\s+(?:at|from)\b\s+(?P[a-z0-9\s]+)$", + r"\bwhen(?:'s|\s+is)?\s+the\s+next\s+(?:train|subway)\s+(?:at|from)\b\s+(?P[a-z0-9\s]+)$", + ) + for pattern in patterns: + match = re.search(pattern, normalized_text) + if match: + return clean_station_phrase(match.group("station")) + return None + + +def extract_routes(raw_text: str, normalized_text: str) -> List[str]: + routes: List[str] = [] + + # Spoken-out number routes are safe to match directly. + spoken_number_aliases = { + alias: route + for alias, route in ROUTE_ALIASES.items() + if len(alias) > 1 and alias not in {"shuttle"} + } + for alias, route in sorted( + spoken_number_aliases.items(), key=lambda item: -len(item[0]) + ): + if re.search(rf"\b{re.escape(alias)}\b", normalized_text): + routes.append(route) + + route_context_patterns = ( + r"\b(?P[1234567abcdefgjlmnqrswz])\s+(?:train|line)\b", + r"\b(?:train|line)\s+(?P[1234567abcdefgjlmnqrswz])\b", + r"\bnext\s+(?P[1234567abcdefgjlmnqrswz])\b", + r"\b(?:uptown|downtown|northbound|southbound)\s+(?P[1234567abcdefgjlmnqrswz])\b", + ) + lowered_raw = (raw_text or "").lower() + for pattern in route_context_patterns: + for match in re.finditer(pattern, lowered_raw): + routes.append(match.group("route").upper()) + + if re.search(r"\bshuttle\b", normalized_text): + routes.append("S") + + deduped: List[str] = [] + for route in routes: + if route not in deduped: + deduped.append(route) + return deduped + + +def extract_direction(normalized_text: str) -> Optional[str]: + for phrase in NORTHBOUND_WORDS: + if phrase in normalized_text: + return "N" + for phrase in SOUTHBOUND_WORDS: + if phrase in normalized_text: + return "S" + return None + + +def fetch_json(url: str, timeout: int = 12) -> object: + response = requests.get( + url, + headers={ + "User-Agent": "OpenHome-MTA-Next-Train/1.0", + "Accept": "application/json", + }, + timeout=timeout, + ) + response.raise_for_status() + return response.json() + + +def urlencode(params: Dict[str, object]) -> str: + parts = [] + for key, value in params.items(): + parts.append(f"{escape_query_value(str(key))}={escape_query_value(str(value))}") + return "&".join(parts) + + +def escape_query_value(value: str) -> str: + safe = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~" + out = [] + for ch in value: + if ch in safe: + out.append(ch) + elif ch == " ": + out.append("+") + else: + out.append(f"%{ord(ch):02X}") + return "".join(out) + + +def search_stations(query: str, limit: int = 5) -> List[Station]: + params = urlencode({"query": query, "limit": limit}) + payload = fetch_json(f"{API_BASE_URL}/stations?{params}") + if not isinstance(payload, list): + return [] + return [station_from_api_item(item) for item in payload if isinstance(item, dict)] + + +def fetch_arrivals( + station_id: str, routes: Sequence[str], direction: Optional[str], limit: int = 8 +) -> List[Arrival]: + params = {"station_id": station_id, "limit": limit} + if routes: + params["line"] = routes[0] + if direction: + params["direction"] = direction + payload = fetch_json(f"{API_BASE_URL}/arrivals?{urlencode(params)}") + if not isinstance(payload, dict): + return [] + raw_arrivals = payload.get("arrivals", []) + if not isinstance(raw_arrivals, list): + return [] + arrivals: List[Arrival] = [] + for item in raw_arrivals: + if not isinstance(item, dict): + continue + route_id = str(item.get("line", "")).strip() + item_direction = str(item.get("direction", "")).strip() + if routes and route_id not in routes: + continue + if direction and item_direction != direction: + continue + arrivals.append( + Arrival( + route_id=route_id, + direction=item_direction, + direction_label=str(item.get("directionLabel", "")).strip(), + minutes_away=int(item.get("minutesAway", 0)), + headsign=str(item.get("headsign", "")).strip(), + ) + ) + return arrivals + + +def format_arrivals_for_voice( + station: Station, + arrivals: Sequence[Arrival], + routes: Sequence[str], + direction: Optional[str], +) -> str: + if not arrivals: + return f"I am not seeing any matching trains for {station.name} right now." + + grouped: Dict[tuple[str, str], List[Arrival]] = {} + for arrival in arrivals: + grouped.setdefault((arrival.route_id, arrival.direction), []).append(arrival) + + lead_bits: List[str] = [] + for (_, _), group in sorted(grouped.items(), key=lambda item: item[1][0].minutes_away): + first = group[0] + direction_label = "northbound" if first.direction == "N" else "southbound" + bit = f"{direction_label} {first.route_id} in {render_minutes(first.minutes_away)}" + if len(group) > 1: + bit += f", then {render_minutes(group[1].minutes_away)}" + lead_bits.append(bit) + if len(lead_bits) == 3: + break + + if len(lead_bits) == 1: + return f"At {station.name}, the next train is {lead_bits[0]}." + return f"At {station.name}, next trains: " + "; ".join(lead_bits) + "." + + +def render_minutes(minutes: int) -> str: + if minutes <= 0: + return "now" + if minutes == 1: + return "1 minute" + return f"{minutes} minutes" + + class MTANextTrainCapability(MatchingCapability): worker: AgentWorker = None capability_worker: CapabilityWorker = None prefs: Dict = None - # {{register_capability}} + #{{register capability}} def call(self, worker: AgentWorker): self.worker = worker - self.capability_worker = CapabilityWorker(self.worker) + self.capability_worker = CapabilityWorker(self) self.worker.session_tasks.create(self.run()) async def run(self): try: + self.worker.editor_logging_handler.info("MTA Next Train triggered") self.prefs = self.load_prefs() trigger_text = self.get_trigger_text() if not trigger_text: @@ -55,6 +428,8 @@ async def run(self): await self.capability_worker.speak( "You can ask for your next train, ask for a specific line at a station, or set a default station. For example: when is my next train, next Q train at Union Square, or set my default station to Astor Place." ) + elif intent.action == ACTION_GET_DEFAULT: + await self.handle_get_default() elif intent.action == ACTION_SET_DEFAULT: await self.handle_set_default(intent) else: @@ -112,6 +487,12 @@ async def handle_set_default(self, intent: QueryIntent): ) if not station_text: return + follow_up_intent = parse_query_intent(station_text) + if follow_up_intent.action == ACTION_SET_DEFAULT: + station_text = follow_up_intent.station_text + elif follow_up_intent.action == ACTION_GET_DEFAULT: + await self.handle_get_default() + return station = await self.resolve_station(station_text) if not station: @@ -126,6 +507,16 @@ async def handle_set_default(self, intent: QueryIntent): f"Saved {station.name} as your default station." ) + async def handle_get_default(self): + if not self.prefs.get("default_station_id"): + await self.capability_worker.speak( + "You do not have a default station saved yet." + ) + return + await self.capability_worker.speak( + f"Your default station is {self.prefs.get('default_station_name', 'unknown')}." + ) + async def handle_arrivals(self, intent: QueryIntent): station = None if intent.station_text: @@ -147,12 +538,27 @@ async def handle_arrivals(self, intent: QueryIntent): if self._is_exit(spoken_station): await self.capability_worker.speak("Okay. Closing MTA Next Train.") return + follow_up_intent = parse_query_intent(spoken_station) + if follow_up_intent.action == ACTION_SET_DEFAULT: + await self.handle_set_default(follow_up_intent) + return + if follow_up_intent.action == ACTION_GET_DEFAULT: + await self.handle_get_default() + return station = await self.resolve_station(spoken_station) if not station: return - await self.capability_worker.speak(f"Checking live arrivals for {station.name}.") + await self.capability_worker.speak( + f"Checking live arrivals for {station.name}." + ) + self.worker.editor_logging_handler.info( + f"MTA arrivals request station_id={station.station_id} station={station.name} routes={intent.routes} direction={intent.direction}" + ) arrivals = fetch_arrivals(station.station_id, intent.routes, intent.direction) + self.worker.editor_logging_handler.info( + f"MTA arrivals result count={len(arrivals)} station_id={station.station_id}" + ) summary = format_arrivals_for_voice( station, arrivals, @@ -180,7 +586,9 @@ async def resolve_station(self, station_text: str) -> Optional[Station]: return None narrowed = find_station_matches(stations, response, limit=1) if not narrowed: - await self.capability_worker.speak("I still could not pin down the station.") + await self.capability_worker.speak( + "I still could not pin down the station." + ) return None return narrowed[0].station diff --git a/community/mta-next-train/mta_next_train_core.py b/community/mta-next-train/mta_next_train_core.py deleted file mode 100644 index bfc9933d..00000000 --- a/community/mta-next-train/mta_next_train_core.py +++ /dev/null @@ -1,335 +0,0 @@ -from __future__ import annotations - -import difflib -import re -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Sequence - -import requests -API_BASE_URL = "https://subwayinfo.nyc/api" -ACTION_SET_DEFAULT = "set_default" -ACTION_ARRIVALS = "arrivals" -ACTION_HELP = "help" -ROUTE_ALIASES = { - "one": "1", - "two": "2", - "three": "3", - "four": "4", - "five": "5", - "six": "6", - "seven": "7", - "a": "A", - "b": "B", - "c": "C", - "d": "D", - "e": "E", - "f": "F", - "g": "G", - "j": "J", - "l": "L", - "m": "M", - "n": "N", - "q": "Q", - "r": "R", - "s": "S", - "w": "W", - "z": "Z", - "shuttle": "S", -} -NORTHBOUND_WORDS = {"northbound", "uptown"} -SOUTHBOUND_WORDS = {"southbound", "downtown"} -STREET_WORDS = { - "st": "street", - "street": "street", - "ave": "avenue", - "av": "avenue", - "avenue": "avenue", - "sq": "square", - "square": "square", - "plz": "plaza", - "plaza": "plaza", -} - - -@dataclass -class Station: - station_id: str - name: str - normalized_name: str - borough: str = "" - lines: List[str] = field(default_factory=list) - - -@dataclass -class StationMatch: - station: Station - score: float - - -@dataclass -class QueryIntent: - action: str - station_text: Optional[str] = None - routes: List[str] = field(default_factory=list) - direction: Optional[str] = None - - -@dataclass -class Arrival: - route_id: str - direction: str - direction_label: str - minutes_away: int - headsign: str - - -def normalize_text(text: str) -> str: - lowered = (text or "").strip().lower() - lowered = lowered.replace("&", " and ") - lowered = re.sub(r"[^a-z0-9\s]", " ", lowered) - lowered = re.sub(r"\s+", " ", lowered).strip() - words = [] - for part in lowered.split(): - words.append(STREET_WORDS.get(part, part)) - return " ".join(words) - - -def station_from_api_item(item: Dict) -> Station: - return Station( - station_id=str(item.get("id", "")), - name=str(item.get("name", "")).strip(), - normalized_name=normalize_text(str(item.get("name", "")).strip()), - borough=str(item.get("borough", "")).strip(), - lines=[str(line) for line in item.get("lines", []) if str(line).strip()], - ) - - -def station_from_prefs(station_id: str, name: str, borough: str = "", lines: Optional[List[str]] = None) -> Station: - return Station( - station_id=station_id, - name=name, - normalized_name=normalize_text(name), - borough=borough, - lines=lines or [], - ) - - -def find_station_matches(stations: Sequence[Station], query: str, limit: int = 3) -> List[StationMatch]: - normalized_query = normalize_text(query) - if not normalized_query: - return [] - - matches: List[StationMatch] = [] - query_tokens = set(normalized_query.split()) - for station in stations: - ratio = difflib.SequenceMatcher(None, normalized_query, station.normalized_name).ratio() - station_tokens = set(station.normalized_name.split()) - overlap = 0.0 - if station_tokens: - overlap = len(query_tokens & station_tokens) / len(query_tokens | station_tokens) - contains_bonus = 0.18 if normalized_query in station.normalized_name else 0.0 - prefix_bonus = 0.12 if station.normalized_name.startswith(normalized_query) else 0.0 - score = max(ratio, overlap) + contains_bonus + prefix_bonus - if score >= 0.48: - matches.append(StationMatch(station=station, score=score)) - - matches.sort(key=lambda item: (-item.score, item.station.name)) - deduped: List[StationMatch] = [] - seen = set() - for match in matches: - key = (match.station.station_id, match.station.name) - if key in seen: - continue - seen.add(key) - deduped.append(match) - if len(deduped) == limit: - break - return deduped - - -def parse_query_intent(text: str) -> QueryIntent: - normalized = normalize_text(text) - routes = extract_routes(normalized) - direction = extract_direction(normalized) - - for pattern in ( - r"(?:set|change|update|use)\s+(?:my\s+)?(?:default|home)\s+station(?:\s+to)?\s+(?P.+)", - r"my\s+(?:default\s+|home\s+)?station\s+is\s+(?P.+)", - r"use\s+(?P.+)\s+as\s+(?:my\s+)?(?:default|home)\s+station", - ): - match = re.search(pattern, normalized) - if match: - return QueryIntent( - action=ACTION_SET_DEFAULT, - station_text=clean_station_phrase(match.group("station")), - routes=routes, - direction=direction, - ) - - if any(phrase in normalized for phrase in ("help", "what can you do", "how does this work")): - return QueryIntent(action=ACTION_HELP, routes=routes, direction=direction) - - return QueryIntent( - action=ACTION_ARRIVALS, - station_text=extract_station_phrase(normalized), - routes=routes, - direction=direction, - ) - - -def clean_station_phrase(text: Optional[str]) -> Optional[str]: - if not text: - return None - cleaned = normalize_text(text) - cleaned = re.sub(r"\b(?:station|stop|train|subway)\b", " ", cleaned) - cleaned = re.sub(r"\s+", " ", cleaned).strip() - return cleaned or None - - -def extract_station_phrase(normalized_text: str) -> Optional[str]: - patterns = ( - r"(?:at|from)\s+(?P[a-z0-9\s]+?)(?:\s+for\s+[a-z0-9]+\s+train|\s+for\s+[a-z0-9]+|\s+train|$)", - r"next\s+(?:train|subway)\s+(?:at|from)\s+(?P[a-z0-9\s]+)$", - r"when(?:'s|\s+is)?\s+the\s+next\s+(?:train|subway)\s+(?:at|from)\s+(?P[a-z0-9\s]+)$", - ) - for pattern in patterns: - match = re.search(pattern, normalized_text) - if match: - return clean_station_phrase(match.group("station")) - return None - - -def extract_routes(normalized_text: str) -> List[str]: - routes: List[str] = [] - for alias, route in sorted(ROUTE_ALIASES.items(), key=lambda item: -len(item[0])): - if re.search(rf"\b{re.escape(alias)}\b", normalized_text): - routes.append(route) - - tokens = re.findall(r"\b[a-z0-9]{1,2}\b", normalized_text) - for token in tokens: - upper = token.upper() - if upper in {"1", "2", "3", "4", "5", "6", "7", "A", "B", "C", "D", "E", "F", "G", "J", "L", "M", "N", "Q", "R", "S", "W", "Z"}: - routes.append(upper) - - deduped: List[str] = [] - for route in routes: - if route not in deduped: - deduped.append(route) - return deduped - - -def extract_direction(normalized_text: str) -> Optional[str]: - for phrase in NORTHBOUND_WORDS: - if phrase in normalized_text: - return "N" - for phrase in SOUTHBOUND_WORDS: - if phrase in normalized_text: - return "S" - return None - - -def fetch_json(url: str, timeout: int = 12) -> object: - response = requests.get( - url, - headers={ - "User-Agent": "OpenHome-MTA-Next-Train/1.0", - "Accept": "application/json", - }, - timeout=timeout, - ) - response.raise_for_status() - return response.json() - - -def urlencode(params: Dict[str, object]) -> str: - parts = [] - for key, value in params.items(): - parts.append(f"{escape_query_value(str(key))}={escape_query_value(str(value))}") - return "&".join(parts) - - -def escape_query_value(value: str) -> str: - safe = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~" - out = [] - for ch in value: - if ch in safe: - out.append(ch) - elif ch == " ": - out.append("+") - else: - out.append(f"%{ord(ch):02X}") - return "".join(out) - - -def search_stations(query: str, limit: int = 5) -> List[Station]: - params = urlencode({"query": query, "limit": limit}) - payload = fetch_json(f"{API_BASE_URL}/stations?{params}") - if not isinstance(payload, list): - return [] - return [station_from_api_item(item) for item in payload if isinstance(item, dict)] - - -def fetch_arrivals(station_id: str, routes: Sequence[str], direction: Optional[str], limit: int = 8) -> List[Arrival]: - params = {"station_id": station_id, "limit": limit} - if routes: - params["line"] = routes[0] - if direction: - params["direction"] = direction - payload = fetch_json(f"{API_BASE_URL}/arrivals?{urlencode(params)}") - if not isinstance(payload, dict): - return [] - raw_arrivals = payload.get("arrivals", []) - if not isinstance(raw_arrivals, list): - return [] - arrivals: List[Arrival] = [] - for item in raw_arrivals: - if not isinstance(item, dict): - continue - route_id = str(item.get("line", "")).strip() - item_direction = str(item.get("direction", "")).strip() - if routes and route_id not in routes: - continue - if direction and item_direction != direction: - continue - arrivals.append( - Arrival( - route_id=route_id, - direction=item_direction, - direction_label=str(item.get("directionLabel", "")).strip(), - minutes_away=int(item.get("minutesAway", 0)), - headsign=str(item.get("headsign", "")).strip(), - ) - ) - return arrivals - - -def format_arrivals_for_voice(station: Station, arrivals: Sequence[Arrival], routes: Sequence[str], direction: Optional[str]) -> str: - if not arrivals: - return f"I am not seeing any matching trains for {station.name} right now." - - grouped: Dict[tuple[str, str], List[Arrival]] = {} - for arrival in arrivals: - grouped.setdefault((arrival.route_id, arrival.direction), []).append(arrival) - - lead_bits: List[str] = [] - for (_, _), group in sorted(grouped.items(), key=lambda item: item[1][0].minutes_away): - first = group[0] - direction_label = "northbound" if first.direction == "N" else "southbound" - bit = f"{direction_label} {first.route_id} in {render_minutes(first.minutes_away)}" - if len(group) > 1: - bit += f", then {render_minutes(group[1].minutes_away)}" - lead_bits.append(bit) - if len(lead_bits) == 3: - break - - if len(lead_bits) == 1: - return f"At {station.name}, the next train is {lead_bits[0]}." - return f"At {station.name}, next trains: " + "; ".join(lead_bits) + "." - - -def render_minutes(minutes: int) -> str: - if minutes <= 0: - return "now" - if minutes == 1: - return "1 minute" - return f"{minutes} minutes" diff --git a/community/mta-next-train/test_harness.py b/community/mta-next-train/test_harness.py deleted file mode 100644 index 27eacad6..00000000 --- a/community/mta-next-train/test_harness.py +++ /dev/null @@ -1,71 +0,0 @@ -#!/usr/bin/env python3 -from __future__ import annotations -from tests.fixtures import SAMPLE_ARRIVALS_RESPONSE, SAMPLE_STATIONS -from mta_next_train_core import ( - fetch_arrivals, - find_station_matches, - format_arrivals_for_voice, - parse_query_intent, - search_stations, - station_from_api_item, -) - -import argparse -import sys -from pathlib import Path - -CURRENT_DIR = Path(__file__).resolve().parent -if str(CURRENT_DIR) not in sys.path: - sys.path.append(str(CURRENT_DIR)) - - -def run_fixture_mode(phrase: str) -> int: - stations = [station_from_api_item(item) for item in SAMPLE_STATIONS] - intent = parse_query_intent(phrase) - station_query = intent.station_text or "Astor Place" - matches = find_station_matches(stations, station_query, limit=1) - if not matches: - print("No station match") - return 1 - station = matches[0].station - import mta_next_train_core - original_fetch_json = mta_next_train_core.fetch_json - try: - mta_next_train_core.fetch_json = lambda url, timeout=12: SAMPLE_ARRIVALS_RESPONSE - arrivals = fetch_arrivals(station.station_id, intent.routes, intent.direction) - finally: - mta_next_train_core.fetch_json = original_fetch_json - print(format_arrivals_for_voice(station, arrivals, intent.routes, intent.direction)) - return 0 - - -def run_live_mode(phrase: str) -> int: - intent = parse_query_intent(phrase) - station_query = intent.station_text - if not station_query: - print("Live mode needs a station in the phrase, e.g. 'next Q train at Union Square'") - return 1 - stations = search_stations(station_query, limit=5) - matches = find_station_matches(stations, station_query, limit=1) - if not matches: - print("No station match") - return 1 - station = matches[0].station - arrivals = fetch_arrivals(station.station_id, intent.routes, intent.direction) - print(format_arrivals_for_voice(station, arrivals, intent.routes, intent.direction)) - return 0 - - -def main() -> int: - parser = argparse.ArgumentParser() - parser.add_argument("--phrase", required=True) - parser.add_argument("--live", action="store_true") - args = parser.parse_args() - - if args.live: - return run_live_mode(args.phrase) - return run_fixture_mode(args.phrase) - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/community/mta-next-train/tests/fixtures.py b/community/mta-next-train/tests/fixtures.py deleted file mode 100644 index 2753471e..00000000 --- a/community/mta-next-train/tests/fixtures.py +++ /dev/null @@ -1,53 +0,0 @@ -SAMPLE_STATIONS = [ - { - "id": "636", - "name": "Astor Pl", - "lat": 40.730054, - "lon": -73.99107, - "borough": "Manhattan", - "lines": ["4", "6"], - }, - { - "id": "R20", - "name": "14 St-Union Sq", - "lat": 40.734673, - "lon": -73.989951, - "borough": "Manhattan", - "lines": ["4", "5", "6", "L", "N", "Q", "R", "W"], - }, -] - -SAMPLE_ARRIVALS_RESPONSE = { - "stationId": "636", - "stationName": "Astor Pl", - "arrivals": [ - { - "line": "6", - "direction": "N", - "directionLabel": "Bronx-bound to Pelham Bay Park", - "arrivalTime": "2026-03-27T21:10:30.000Z", - "minutesAway": 0, - "isAssigned": False, - "headsign": "Pelham Bay Park", - }, - { - "line": "6", - "direction": "S", - "directionLabel": "to Brooklyn Bridge", - "arrivalTime": "2026-03-27T21:12:42.000Z", - "minutesAway": 3, - "isAssigned": False, - "headsign": "Brooklyn Bridge-City Hall", - }, - { - "line": "6", - "direction": "N", - "directionLabel": "Bronx-bound to Pelham Bay Park", - "arrivalTime": "2026-03-27T21:13:19.000Z", - "minutesAway": 3, - "isAssigned": False, - "headsign": "Pelham Bay Park", - }, - ], - "lastUpdated": "2026-03-27T21:09:08.865Z", -} diff --git a/community/mta-next-train/tests/test_mta_next_train.py b/community/mta-next-train/tests/test_mta_next_train.py deleted file mode 100644 index 75504985..00000000 --- a/community/mta-next-train/tests/test_mta_next_train.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -import unittest - -from mta_next_train_core import ( - ACTION_ARRIVALS, - ACTION_SET_DEFAULT, - fetch_arrivals, - find_station_matches, - format_arrivals_for_voice, - parse_query_intent, - station_from_api_item, -) -from tests.fixtures import SAMPLE_ARRIVALS_RESPONSE, SAMPLE_STATIONS - - -class MTANextTrainCoreTests(unittest.TestCase): - def setUp(self): - self.stations = [station_from_api_item(item) for item in SAMPLE_STATIONS] - - def test_parse_default_station_command(self): - intent = parse_query_intent("set my default station to astor place") - self.assertEqual(intent.action, ACTION_SET_DEFAULT) - self.assertEqual(intent.station_text, "astor place") - - def test_parse_arrivals_command(self): - intent = parse_query_intent("when is the next northbound 6 train at astor place") - self.assertEqual(intent.action, ACTION_ARRIVALS) - self.assertEqual(intent.direction, "N") - self.assertEqual(intent.routes, ["6"]) - self.assertEqual(intent.station_text, "astor place") - - def test_station_match(self): - matches = find_station_matches(self.stations, "14 street union square") - self.assertEqual(matches[0].station.name, "14 St-Union Sq") - - def test_fixture_end_to_end(self): - station = find_station_matches(self.stations, "astor place", limit=1)[0].station - import mta_next_train_core - original_fetch_json = mta_next_train_core.fetch_json - try: - mta_next_train_core.fetch_json = lambda url, timeout=12: SAMPLE_ARRIVALS_RESPONSE - arrivals = fetch_arrivals(station.station_id, routes=["6"], direction=None) - finally: - mta_next_train_core.fetch_json = original_fetch_json - self.assertEqual([arrival.route_id for arrival in arrivals[:3]], ["6", "6", "6"]) - spoken = format_arrivals_for_voice( - station, - arrivals, - routes=["6"], - direction=None, - ) - self.assertIn("Astor Pl", spoken) - self.assertIn("northbound 6 in now", spoken) - self.assertIn("southbound 6 in 3 minutes", spoken) - - -if __name__ == "__main__": - unittest.main() From 4c931d23fd157db20d0446f68baf7ec028e23dbd Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 27 Mar 2026 23:14:14 +0000 Subject: [PATCH 7/8] style: auto-format Python files with autoflake + autopep8 --- community/mta-next-train/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/community/mta-next-train/main.py b/community/mta-next-train/main.py index e39dd83c..1c5961ef 100644 --- a/community/mta-next-train/main.py +++ b/community/mta-next-train/main.py @@ -399,7 +399,7 @@ class MTANextTrainCapability(MatchingCapability): capability_worker: CapabilityWorker = None prefs: Dict = None - #{{register capability}} + # {{register capability}} def call(self, worker: AgentWorker): self.worker = worker From 8fd34e608462b3892ba4ea7a680e8a84f8f7e489 Mon Sep 17 00:00:00 2001 From: Uzair Ullah Date: Mon, 30 Mar 2026 15:00:58 +0500 Subject: [PATCH 8/8] Enhance direction keywords and user prompts for better UX Added new keywords for northbound and southbound directions, enhanced exit phrases, and improved user prompts for station queries. Signed-off-by: Uzair Ullah --- community/mta-next-train/main.py | 154 +++++++++++++++++++++++-------- 1 file changed, 117 insertions(+), 37 deletions(-) diff --git a/community/mta-next-train/main.py b/community/mta-next-train/main.py index 1c5961ef..19b822a8 100644 --- a/community/mta-next-train/main.py +++ b/community/mta-next-train/main.py @@ -2,6 +2,7 @@ from src.agent.capability_worker import CapabilityWorker from src.main import AgentWorker +import asyncio import difflib import re from dataclasses import dataclass, field @@ -41,8 +42,8 @@ "z": "Z", "shuttle": "S", } -NORTHBOUND_WORDS = {"northbound", "uptown"} -SOUTHBOUND_WORDS = {"southbound", "downtown"} +NORTHBOUND_WORDS = {"northbound", "uptown", "north", "up"} +SOUTHBOUND_WORDS = {"southbound", "downtown", "south", "down"} STREET_WORDS = { "st": "street", "street": "street", @@ -56,9 +57,14 @@ } PREFS_KEY = "mta_next_train_prefs" EXIT_WORDS = { - "stop", "exit", "quit", "done", "cancel", "bye", "goodbye", "nothing else", + "stop", "exit", "quit", "done", "cancel", "bye", "goodbye", + "nothing else", "i'm done", "that's it", "i'm good", "all set", + "we're done", "no thanks", "never mind", "that's all", + "all done", "i'm finished", } +_ORDINAL_TO_IDX = {"first": 0, "second": 1, "third": 2, "fourth": 3, "fifth": 4} + @dataclass class Station: @@ -186,6 +192,15 @@ def parse_query_intent(text: str) -> QueryIntent: direction=direction, ) + # "change my default station" / "set my default" without a station name + if re.search(r"(?:set|change|update)\s+(?:my\s+)?(?:default|home)\s+station", normalized): + return QueryIntent( + action=ACTION_SET_DEFAULT, + station_text=None, + routes=routes, + direction=direction, + ) + if any( phrase in normalized for phrase in ( @@ -195,12 +210,18 @@ def parse_query_intent(text: str) -> QueryIntent: "what s the default station", "what is my home station", "what s my home station", + "what s my station", + "which station am i using", + "what station is saved", ) ): return QueryIntent(action=ACTION_GET_DEFAULT, routes=routes, direction=direction) if any( - phrase in normalized for phrase in ("help", "what can you do", "how does this work") + phrase in normalized for phrase in ( + "help", "what can you do", "how does this work", + "what can i say", "what are my options", + ) ): return QueryIntent(action=ACTION_HELP, routes=routes, direction=direction) @@ -223,9 +244,10 @@ def clean_station_phrase(text: Optional[str]) -> Optional[str]: def extract_station_phrase(normalized_text: str) -> Optional[str]: patterns = ( - r"\b(?:at|from)\b\s+(?P[a-z0-9\s]+?)(?:\s+for\s+[a-z0-9]+\s+train|\s+for\s+[a-z0-9]+|\s+train|$)", - r"\bnext\s+(?:train|subway)\s+(?:at|from)\b\s+(?P[a-z0-9\s]+)$", - r"\bwhen(?:'s|\s+is)?\s+the\s+next\s+(?:train|subway)\s+(?:at|from)\b\s+(?P[a-z0-9\s]+)$", + r"\b(?:at|from|in|for)\b\s+(?P[a-z0-9\s]+?)(?:\s+for\s+[a-z0-9]+\s+train|\s+for\s+[a-z0-9]+|\s+train|$)", + r"\bnext\s+(?:train|subway)\s+(?:at|from|in)\b\s+(?P[a-z0-9\s]+)$", + r"\bwhen(?:'s|\s+is)?\s+the\s+next\s+(?:train|subway)\s+(?:at|from|in|per)\b\s+(?P[a-z0-9\s]+)$", + r"\b(?:check|get)\s+.*?(?:at|from|in|for|per)\b\s+(?P[a-z0-9\s]+)$", ) for pattern in patterns: match = re.search(pattern, normalized_text) @@ -364,7 +386,7 @@ def format_arrivals_for_voice( direction: Optional[str], ) -> str: if not arrivals: - return f"I am not seeing any matching trains for {station.name} right now." + return f"I'm not seeing any matching trains at {station.name} right now." grouped: Dict[tuple[str, str], List[Arrival]] = {} for arrival in arrivals: @@ -383,7 +405,8 @@ def format_arrivals_for_voice( if len(lead_bits) == 1: return f"At {station.name}, the next train is {lead_bits[0]}." - return f"At {station.name}, next trains: " + "; ".join(lead_bits) + "." + joined = ", ".join(lead_bits[:-1]) + f", and {lead_bits[-1]}" + return f"At {station.name}, next trains: {joined}." def render_minutes(minutes: int) -> str: @@ -406,27 +429,55 @@ def call(self, worker: AgentWorker): self.capability_worker = CapabilityWorker(self) self.worker.session_tasks.create(self.run()) + def _get_trigger_text(self) -> str: + history = self.capability_worker.get_full_message_history() or [] + for item in reversed(history): + if item.get("role") == "user": + content = (item.get("content") or "").strip() + if content: + return content + return "" + async def run(self): try: self.worker.editor_logging_handler.info("MTA Next Train triggered") self.prefs = self.load_prefs() - trigger_text = self.get_trigger_text() - if not trigger_text: + trigger_text = self._get_trigger_text() + + # Parse trigger intent first — handle get_default/set_default/help + # directly without prompting for a station + if trigger_text: + intent = parse_query_intent(trigger_text) + if intent.action in (ACTION_GET_DEFAULT, ACTION_SET_DEFAULT, ACTION_HELP): + # Has a clear non-arrivals intent, use it directly + pass + elif intent.action == ACTION_ARRIVALS and not intent.station_text and not self.prefs.get("default_station_id"): + # Arrivals but no station and no default — need to ask + trigger_text = await self.capability_worker.run_io_loop( + "What station or line do you want to check?" + ) + # else: arrivals with station or default — proceed + else: trigger_text = await self.capability_worker.run_io_loop( - "What would you like to check? You can say, when is my next train, or set my default station to Astor Place." + "What station or line do you want to check?" ) if not trigger_text: return while True: if self._is_exit(trigger_text): - await self.capability_worker.speak("Okay. Closing MTA Next Train.") + await self.capability_worker.speak("All good. See you next ride.") return intent = parse_query_intent(trigger_text) if intent.action == ACTION_HELP: await self.capability_worker.speak( - "You can ask for your next train, ask for a specific line at a station, or set a default station. For example: when is my next train, next Q train at Union Square, or set my default station to Astor Place." + "You can ask for your next train at any station, " + "or a specific line like the Q at Union Square." + ) + await self.capability_worker.speak( + "You can also set a default station so you " + "just have to say, when's my next train." ) elif intent.action == ACTION_GET_DEFAULT: await self.handle_get_default() @@ -436,14 +487,14 @@ async def run(self): await self.handle_arrivals(intent) trigger_text = await self.capability_worker.run_io_loop( - "Anything else for the subway? Say another station or line, or say done." + "Anything else?" ) if not trigger_text: return except Exception as exc: self.worker.editor_logging_handler.error(f"[MTANextTrain] {exc}") await self.capability_worker.speak( - "Something went wrong while checking live arrivals." + "Something went wrong checking live arrivals." ) finally: self.capability_worker.resume_normal_flow() @@ -466,18 +517,12 @@ def save_prefs(self): else: self.capability_worker.create_key(PREFS_KEY, self.prefs) - def get_trigger_text(self) -> str: - history = self.capability_worker.get_full_message_history() or [] - for item in reversed(history): - if item.get("role") == "user": - content = (item.get("content") or "").strip() - if content: - return content - return "" - def _is_exit(self, text: str) -> bool: lowered = (text or "").strip().lower() - return any(phrase in lowered for phrase in EXIT_WORDS) + for phrase in EXIT_WORDS: + if lowered == phrase or lowered.startswith(phrase + " ") or lowered.endswith(" " + phrase): + return True + return False async def handle_set_default(self, intent: QueryIntent): station_text = intent.station_text @@ -498,19 +543,26 @@ async def handle_set_default(self, intent: QueryIntent): if not station: return + confirmed = await self.capability_worker.run_confirmation_loop( + f"Save {station.name} as your default station?" + ) + if not confirmed: + await self.capability_worker.speak("No problem.") + return + self.prefs["default_station_id"] = station.station_id self.prefs["default_station_name"] = station.name self.prefs["default_station_borough"] = station.borough self.prefs["default_station_lines"] = station.lines self.save_prefs() await self.capability_worker.speak( - f"Saved {station.name} as your default station." + f"Done. {station.name} is your default now." ) async def handle_get_default(self): if not self.prefs.get("default_station_id"): await self.capability_worker.speak( - "You do not have a default station saved yet." + "You don't have a default station saved yet." ) return await self.capability_worker.speak( @@ -531,12 +583,12 @@ async def handle_arrivals(self, intent: QueryIntent): if not station: spoken_station = await self.capability_worker.run_io_loop( - "Which station do you want to check? You can also say set my default station to save one." + "Which station do you want to check?" ) if not spoken_station: return if self._is_exit(spoken_station): - await self.capability_worker.speak("Okay. Closing MTA Next Train.") + await self.capability_worker.speak("All good. See you next ride.") return follow_up_intent = parse_query_intent(spoken_station) if follow_up_intent.action == ACTION_SET_DEFAULT: @@ -550,12 +602,14 @@ async def handle_arrivals(self, intent: QueryIntent): return await self.capability_worker.speak( - f"Checking live arrivals for {station.name}." + f"Checking {station.name}." ) self.worker.editor_logging_handler.info( f"MTA arrivals request station_id={station.station_id} station={station.name} routes={intent.routes} direction={intent.direction}" ) - arrivals = fetch_arrivals(station.station_id, intent.routes, intent.direction) + arrivals = await asyncio.to_thread( + fetch_arrivals, station.station_id, intent.routes, intent.direction + ) self.worker.editor_logging_handler.info( f"MTA arrivals result count={len(arrivals)} station_id={station.station_id}" ) @@ -567,27 +621,53 @@ async def handle_arrivals(self, intent: QueryIntent): ) await self.capability_worker.speak(summary) + # Offer to save as default if user doesn't have one yet + if not self.prefs.get("default_station_id"): + save = await self.capability_worker.run_confirmation_loop( + f"Want me to save {station.name} as your default?" + ) + if save: + self.prefs["default_station_id"] = station.station_id + self.prefs["default_station_name"] = station.name + self.prefs["default_station_borough"] = station.borough + self.prefs["default_station_lines"] = station.lines + self.save_prefs() + await self.capability_worker.speak("Saved.") + async def resolve_station(self, station_text: str) -> Optional[Station]: - stations = search_stations(station_text, limit=5) + await self.capability_worker.speak("One sec.") + stations = await asyncio.to_thread(search_stations, station_text, 5) matches = find_station_matches(stations, station_text) if not matches: await self.capability_worker.speak( - f"I could not find a subway station matching {station_text}." + f"Couldn't find a station matching {station_text}." ) return None top_match = matches[0] if len(matches) > 1 and (top_match.score - matches[1].score) < 0.08: - options = ", ".join(match.station.name for match in matches[:3]) + # Deduplicate by name — if all close matches have the same name, just pick the first + unique_names = list(dict.fromkeys(m.station.name for m in matches[:3])) + if len(unique_names) == 1: + return top_match.station + + options = ", ".join(unique_names) response = await self.capability_worker.run_io_loop( - f"I found a few close matches: {options}. Which one did you mean?" + f"I found a few close matches: {options}. Which one?" ) if not response: return None + + # Handle ordinal selection ("first one", "second one", "the first", etc.) + lowered_resp = response.lower().strip() + for word, idx in _ORDINAL_TO_IDX.items(): + if word in lowered_resp and idx < len(matches): + return matches[idx].station + narrowed = find_station_matches(stations, response, limit=1) if not narrowed: await self.capability_worker.speak( - "I still could not pin down the station." + "Still couldn't pin down the station." ) return None return narrowed[0].station