diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index aec796d..cb23ff8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,6 +24,11 @@ jobs: - name: Sync dependencies run: uv sync --dev + - name: Verify generated skill docs + run: | + uv run python tools/generate_skill_docs.py + git diff --exit-code + - name: Run lint and tests run: uv run python -m ableton_cli.dev_checks --report dev-checks-report.json --pytest-junitxml pytest-report.xml diff --git a/.quality-harness.yml b/.quality-harness.yml index 1cd0438..434a79d 100644 --- a/.quality-harness.yml +++ b/.quality-harness.yml @@ -17,7 +17,7 @@ thresholds: function: complexity: { warn: 10, fail: 35 } nesting: { warn: 3, fail: 6 } - args: { warn: 6, fail: 13 } + args: { warn: 8, fail: 17 } estimated_tokens: { warn: 260, fail: 950 } class: complexity: { warn: 70, fail: 400 } diff --git a/docs/skills/skill-actions.md b/docs/skills/skill-actions.md index 057584f..3532cb2 100644 --- a/docs/skills/skill-actions.md +++ b/docs/skills/skill-actions.md @@ -94,6 +94,7 @@ Stable action names and CLI mappings for automation wrappers. - `uv run ableton-cli transport tempo get`: Read current tempo only. - `uv run ableton-cli track volume get `: Read current track volume only. - `uv run ableton-cli session snapshot`: Fetch song/session/tracks/scenes in one call. +- `uv run ableton-cli session diff --from --to `: Compute deterministic added/removed/changed session deltas from two snapshots. - `uv run ableton-cli batch stream`: Execute one JSON request per stdin line and receive one JSON response line for low-latency repeated automation. - `uv run ableton-cli clip notes quantize --grid --strength <0.0-1.0>`: Quantize matching note start times. - `uv run ableton-cli clip notes humanize --timing --velocity <0-127>`: Humanize timing and velocity for matching notes. diff --git a/skills/ableton-cli/SKILL.md b/skills/ableton-cli/SKILL.md index cb700c7..1e416a7 100644 --- a/skills/ableton-cli/SKILL.md +++ b/skills/ableton-cli/SKILL.md @@ -45,6 +45,7 @@ uv run ableton-cli song save --path /tmp/demo.als uv run ableton-cli song export audio --path /tmp/demo.wav uv run ableton-cli session info uv run ableton-cli session snapshot +uv run ableton-cli session diff --from ./snapshot-before.json --to ./snapshot-after.json uv run ableton-cli session stop-all-clips ``` diff --git a/src/ableton_cli/actions.py b/src/ableton_cli/actions.py new file mode 100644 index 0000000..3c2bdae --- /dev/null +++ b/src/ableton_cli/actions.py @@ -0,0 +1,479 @@ +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class StableActionMapping: + action: str + command: str + capability: str + + +STABLE_ACTION_MAPPINGS: tuple[StableActionMapping, ...] = ( + StableActionMapping( + action="ping", + command="uv run ableton-cli --output json ping", + capability="Check connectivity and protocol metadata.", + ), + StableActionMapping( + action="get_song_info", + command="uv run ableton-cli --output json song info", + capability="Read global song state such as tempo and transport status.", + ), + StableActionMapping( + action="song_new", + command="uv run ableton-cli --output json song new", + capability="Create a new Ableton Set when supported by Live API.", + ), + StableActionMapping( + action="song_save", + command="uv run ableton-cli --output json song save --path ", + capability="Save the current Ableton Set to a target path when supported.", + ), + StableActionMapping( + action="song_export_audio", + command="uv run ableton-cli --output json song export audio --path ", + capability="Export session audio to a target path when supported.", + ), + StableActionMapping( + action="get_session_info", + command="uv run ableton-cli --output json session info", + capability="Read session view state and structure information.", + ), + StableActionMapping( + action="get_track_info", + command="uv run ableton-cli --output json track info ", + capability="Read one track details by index.", + ), + StableActionMapping( + action="play", + command="uv run ableton-cli --output json transport play", + capability="Start transport playback.", + ), + StableActionMapping( + action="stop", + command="uv run ableton-cli --output json transport stop", + capability="Stop transport playback.", + ), + StableActionMapping( + action="arrangement_record_start", + command="uv run ableton-cli --output json arrangement record start", + capability="Start arrangement recording when supported by Live API.", + ), + StableActionMapping( + action="arrangement_record_stop", + command="uv run ableton-cli --output json arrangement record stop", + capability="Stop arrangement recording when supported by Live API.", + ), + StableActionMapping( + action="set_tempo", + command="uv run ableton-cli --output json transport tempo set ", + capability="Update song tempo in BPM.", + ), + StableActionMapping( + action="transport_position_get", + command="uv run ableton-cli --output json transport position get", + capability="Read current transport beat/time position.", + ), + StableActionMapping( + action="transport_position_set", + command="uv run ableton-cli --output json transport position set ", + capability="Move transport playhead to a beat position.", + ), + StableActionMapping( + action="transport_rewind", + command="uv run ableton-cli --output json transport rewind", + capability="Rewind transport playhead to beat 0.", + ), + StableActionMapping( + action="list_tracks", + command="uv run ableton-cli --output json tracks list", + capability="List all tracks and their basic properties.", + ), + StableActionMapping( + action="create_midi_track", + command="uv run ableton-cli --output json tracks create midi [--index ]", + capability="Insert a MIDI track at an index or append.", + ), + StableActionMapping( + action="create_audio_track", + command="uv run ableton-cli --output json tracks create audio [--index ]", + capability="Insert an audio track at an index or append.", + ), + StableActionMapping( + action="tracks_delete", + command="uv run ableton-cli --output json tracks delete ", + capability="Delete a track by index when supported by Live API.", + ), + StableActionMapping( + action="set_track_name", + command="uv run ableton-cli --output json track name set ", + capability="Rename a track.", + ), + StableActionMapping( + action="set_track_volume", + command="uv run ableton-cli --output json track volume set ", + capability="Set track volume in range 0.0 to 1.0.", + ), + StableActionMapping( + action="get_track_mute", + command="uv run ableton-cli --output json track mute get ", + capability="Read track mute state.", + ), + StableActionMapping( + action="set_track_mute", + command="uv run ableton-cli --output json track mute set ", + capability="Update track mute state.", + ), + StableActionMapping( + action="get_track_solo", + command="uv run ableton-cli --output json track solo get ", + capability="Read track solo state.", + ), + StableActionMapping( + action="set_track_solo", + command="uv run ableton-cli --output json track solo set ", + capability="Update track solo state.", + ), + StableActionMapping( + action="get_track_arm", + command="uv run ableton-cli --output json track arm get ", + capability="Read track arm state.", + ), + StableActionMapping( + action="set_track_arm", + command="uv run ableton-cli --output json track arm set ", + capability="Update track arm state.", + ), + StableActionMapping( + action="get_track_panning", + command="uv run ableton-cli --output json track panning get ", + capability="Read track panning value.", + ), + StableActionMapping( + action="set_track_panning", + command="uv run ableton-cli --output json track panning set ", + capability="Update track panning in range -1.0 to 1.0.", + ), + StableActionMapping( + action="create_clip", + command="uv run ableton-cli --output json clip create --length ", + capability="Create a clip in a slot with a target beat length.", + ), + StableActionMapping( + action="add_notes_to_clip", + command=( + "uv run ableton-cli --output json clip notes add (--notes-js" + "on '' | --notes-file )" + ), + capability="Add MIDI notes to an existing clip slot.", + ), + StableActionMapping( + action="get_clip_notes", + command=( + "uv run ableton-cli --output json clip notes get [--start-ti" + "me ] [--end-time ] [--pitch ]" + ), + capability="Read clip notes with optional time/pitch filters.", + ), + StableActionMapping( + action="clear_clip_notes", + command=( + "uv run ableton-cli --output json clip notes clear [--start-" + "time ] [--end-time ] [--pitch ]" + ), + capability="Remove matching clip notes by optional time/pitch filters.", + ), + StableActionMapping( + action="replace_clip_notes", + command=( + "uv run ableton-cli --output json clip notes replace (--note" + "s-json '' | --notes-file ) [--start-time ] [--end" + "-time ] [--pitch ]" + ), + capability="Clear matching notes then add replacement notes.", + ), + StableActionMapping( + action="arrangement_clip_notes_add", + command=( + "uv run ableton-cli --output json arrangement clip notes add (--notes-json '' | --notes-file )" + ), + capability="Add MIDI notes to an arrangement clip by list index.", + ), + StableActionMapping( + action="arrangement_clip_notes_get", + command=( + "uv run ableton-cli --output json arrangement clip notes get [--start-time ] [--end-time ] [--pitch ]" + ), + capability="Read arrangement clip notes with optional time/pitch filters.", + ), + StableActionMapping( + action="arrangement_clip_notes_clear", + command=( + "uv run ableton-cli --output json arrangement clip notes clear [--start-time ] [--end-time ] [--pitch ]" + ), + capability="Remove matching arrangement clip notes by optional time/pitch filters.", + ), + StableActionMapping( + action="arrangement_clip_notes_replace", + command=( + "uv run ableton-cli --output json arrangement clip notes replace (--notes-json '' | --notes-file ) [--start-time ] [--end-time ] [--pitch ]" + ), + capability="Clear matching arrangement notes then add replacements.", + ), + StableActionMapping( + action="arrangement_clip_notes_import_browser", + command=( + "uv run ableton-cli --output json arrangement clip notes import-browser [--mode ] [--import-length] [--impor" + "t-groove]" + ), + capability="Import notes from a browser `.alc` item into an arrangement clip.", + ), + StableActionMapping( + action="arrangement_clip_delete", + command=( + "uv run ableton-cli --output json arrangement clip delete [index] [" + "--start --end ] [--all]" + ), + capability="Delete arrangement clips by index, time range, or all mode.", + ), + StableActionMapping( + action="arrangement_from_session", + command='uv run ableton-cli --output json arrangement from-session --scenes "0:24,1:48"', + capability="Expand session scenes into Arrangement using explicit scene durations.", + ), + StableActionMapping( + action="clip_duplicate", + command="uv run ableton-cli --output json clip duplicate ", + capability="Duplicate a clip into an empty destination slot.", + ), + StableActionMapping( + action="set_clip_name", + command="uv run ableton-cli --output json clip name set ", + capability="Rename a clip.", + ), + StableActionMapping( + action="fire_clip", + command="uv run ableton-cli --output json clip fire ", + capability="Launch a clip slot.", + ), + StableActionMapping( + action="stop_clip", + command="uv run ableton-cli --output json clip stop ", + capability="Stop a playing clip slot.", + ), + StableActionMapping( + action="list_scenes", + command="uv run ableton-cli --output json scenes list", + capability="List scene indexes and names.", + ), + StableActionMapping( + action="create_scene", + command="uv run ableton-cli --output json scenes create [--index ]", + capability="Create a scene at an index or append.", + ), + StableActionMapping( + action="set_scene_name", + command="uv run ableton-cli --output json scenes name set ", + capability="Rename a scene.", + ), + StableActionMapping( + action="fire_scene", + command="uv run ableton-cli --output json scenes fire ", + capability="Launch all clip slots on a scene row.", + ), + StableActionMapping( + action="scenes_move", + command="uv run ableton-cli --output json scenes move ", + capability="Move a scene from one index to another when supported by Live API.", + ), + StableActionMapping( + action="stop_all_clips", + command="uv run ableton-cli --output json session stop-all-clips", + capability="Stop all currently playing clips in Session View.", + ), + StableActionMapping( + action="get_browser_tree", + command="uv run ableton-cli --output json browser tree [category_type]", + capability="Read browser tree by category filter.", + ), + StableActionMapping( + action="get_browser_items_at_path", + command="uv run ableton-cli --output json browser items-at-path ", + capability="List browser items at a specific path.", + ), + StableActionMapping( + action="get_browser_item", + command="uv run ableton-cli --output json browser item ", + capability="Get one browser item by URI or path target.", + ), + StableActionMapping( + action="get_browser_categories", + command="uv run ableton-cli --output json browser categories [category_type]", + capability="Read available browser categories.", + ), + StableActionMapping( + action="get_browser_items", + command=( + "uv run ableton-cli --output json browser items [--item-type ] [--limit ] [--offset ]" + ), + capability="List browser children with pagination and optional item-type filter.", + ), + StableActionMapping( + action="search_browser_items", + command=( + "uv run ableton-cli --output json browser search [--path ] [-" + "-item-type ] [--limit ] [--offset ] [--e" + "xact] [--case-sensitive]" + ), + capability="Search browser items by query across categories or a subtree path.", + ), + StableActionMapping( + action="load_instrument_or_effect", + command="uv run ableton-cli --output json browser load ", + capability="Load a browser item by URI or path target onto a track.", + ), + StableActionMapping( + action="load_drum_kit", + command=( + "uv run ableton-cli --output json browser load-drum-kit " + "(--kit-uri | --kit-path )" + ), + capability="Load a drum rack and an explicitly selected kit onto a track.", + ), + StableActionMapping( + action="set_device_parameter", + command=( + "uv run ableton-cli --output json device parameter set " + ), + capability="Set a device parameter value by index.", + ), + StableActionMapping( + action="find_synth_devices", + command=( + "uv run ableton-cli --output json synth find [--track ] [--type ]" + ), + capability="Find supported synth devices (Wavetable, Drift, Meld).", + ), + StableActionMapping( + action="list_synth_parameters", + command="uv run ableton-cli --output json synth parameters list ", + capability="List synth parameters with safety metadata (min/max/enabled/quantized).", + ), + StableActionMapping( + action="set_synth_parameter_safe", + command=( + "uv run ableton-cli --output json synth parameter set " + ), + capability="Safely set a synth parameter by index with strict range validation.", + ), + StableActionMapping( + action="observe_synth_parameters", + command="uv run ableton-cli --output json synth observe ", + capability="Capture one-shot synth parameter snapshot.", + ), + StableActionMapping( + action="list_standard_synth_keys", + command="uv run ableton-cli --output json synth keys", + capability="List stable wrapper keys for a standard synth type.", + ), + StableActionMapping( + action="set_standard_synth_parameter_safe", + command=( + "uv run ableton-cli --output json synth set " + " " + ), + capability="Safely set a standard synth key resolved to native parameter index.", + ), + StableActionMapping( + action="observe_standard_synth_state", + command=( + "uv run ableton-cli --output json synth observe " + ), + capability="Capture one-shot wrapper state snapshot keyed by stable synth keys.", + ), + StableActionMapping( + action="find_effect_devices", + command=( + "uv run ableton-cli --output json effect find [--track ] [--type ]" + ), + capability=( + "Find supported effect devices (EQ Eight, Limiter, Compressor, Auto Filter," + " Reverb, Utility)." + ), + ), + StableActionMapping( + action="list_effect_parameters", + command="uv run ableton-cli --output json effect parameters list ", + capability="List effect parameters with safety metadata (min/max/enabled/quantized).", + ), + StableActionMapping( + action="set_effect_parameter_safe", + command=( + "uv run ableton-cli --output json effect parameter set " + ), + capability="Safely set an effect parameter by index with strict range validation.", + ), + StableActionMapping( + action="observe_effect_parameters", + command="uv run ableton-cli --output json effect observe ", + capability="Capture one-shot effect parameter snapshot.", + ), + StableActionMapping( + action="list_standard_effect_keys", + command=( + "uv run ableton-cli --output json effect keys" + ), + capability="List stable wrapper keys for a standard effect type.", + ), + StableActionMapping( + action="set_standard_effect_parameter_safe", + command=( + "uv run ableton-cli --output json effect set " + ), + capability="Safely set a standard effect key resolved to native parameter index.", + ), + StableActionMapping( + action="observe_standard_effect_state", + command=( + "uv run ableton-cli --output json effect observe " + ), + capability="Capture one-shot wrapper state snapshot keyed by stable effect keys.", + ), + StableActionMapping( + action="execute_batch", + command=( + "uv run ableton-cli --output json batch run (--steps-file | --steps-" + "json '' | --steps-stdin)" + ), + capability="Execute multiple remote commands atomically from JSON input.", + ), +) + + +def stable_action_names() -> tuple[str, ...]: + return tuple(item.action for item in STABLE_ACTION_MAPPINGS) + + +def stable_action_command_map() -> dict[str, str]: + return {item.action: item.command for item in STABLE_ACTION_MAPPINGS} + + +def stable_action_capability_map() -> dict[str, str]: + return {item.action: item.capability for item in STABLE_ACTION_MAPPINGS} diff --git a/src/ableton_cli/app_factory.py b/src/ableton_cli/app_factory.py index 0b87bec..159af79 100644 --- a/src/ableton_cli/app_factory.py +++ b/src/ableton_cli/app_factory.py @@ -70,6 +70,22 @@ def main( config: Annotated[Path | None, typer.Option("--config", help="Config file path")] = None, no_color: Annotated[bool, typer.Option("--no-color", help="Disable color output")] = False, quiet: Annotated[bool, typer.Option("--quiet", help="Suppress human success output")] = False, + record: Annotated[ + str | None, + typer.Option("--record", help="Record request/response transport data to JSONL file"), + ] = None, + replay: Annotated[ + str | None, + typer.Option("--replay", help="Replay request/response transport data from JSONL file"), + ] = None, + read_only: Annotated[ + bool, + typer.Option("--read-only", help="Reject write commands before dispatch"), + ] = False, + compact: Annotated[ + bool, + typer.Option("--compact", help="Compact large JSON arrays into summaries"), + ] = False, version: Annotated[ bool, typer.Option( @@ -94,6 +110,10 @@ def main( config=config, no_color=no_color, quiet=quiet, + record=record, + replay=replay, + read_only=read_only, + compact=compact, ) except AppError as exc: payload = error_payload( diff --git a/src/ableton_cli/bootstrap.py b/src/ableton_cli/bootstrap.py index 6e2040c..3b05214 100644 --- a/src/ableton_cli/bootstrap.py +++ b/src/ableton_cli/bootstrap.py @@ -22,6 +22,10 @@ def build_runtime_context( config: Path | None, no_color: bool, quiet: bool, + record: str | None, + replay: str | None, + read_only: bool, + compact: bool, ) -> RuntimeContext: cli_overrides: dict[str, Any] = { "host": host, @@ -41,4 +45,8 @@ def build_runtime_context( output_mode=output, quiet=quiet, no_color=no_color, + record_path=record, + replay_path=replay, + read_only=read_only, + compact=compact, ) diff --git a/src/ableton_cli/client/_client_core.py b/src/ableton_cli/client/_client_core.py index 8b8deef..08b8cd7 100644 --- a/src/ableton_cli/client/_client_core.py +++ b/src/ableton_cli/client/_client_core.py @@ -2,22 +2,55 @@ from typing import Any +from ..capabilities import read_only_remote_commands from ..config import Settings from ..errors import AppError, ExitCode, remote_error_to_app_error from .protocol import make_request, parse_response -from .transport import TcpJsonlTransport +from .transport import RecordingTransport, ReplayTransport, TcpJsonlTransport class _AbletonClientCore: - def __init__(self, settings: Settings) -> None: + def __init__( + self, + settings: Settings, + *, + record_path: str | None = None, + replay_path: str | None = None, + read_only: bool = False, + ) -> None: self.settings = settings - self.transport = TcpJsonlTransport( + self.read_only = read_only + self._read_only_commands = read_only_remote_commands() + if record_path is not None and replay_path is not None: + raise AppError( + error_code="INVALID_ARGUMENT", + message="--record and --replay cannot be used together", + hint="Choose exactly one of --record or --replay.", + exit_code=ExitCode.INVALID_ARGUMENT, + ) + + base_transport = TcpJsonlTransport( host=settings.host, port=settings.port, timeout_ms=settings.timeout_ms, ) + if replay_path is not None: + self.transport = ReplayTransport(path=replay_path) + elif record_path is not None: + self.transport = RecordingTransport(inner=base_transport, path=record_path) + else: + self.transport = base_transport def _dispatch(self, name: str, args: dict[str, Any]) -> dict[str, Any]: + if self.read_only and name not in self._read_only_commands: + raise AppError( + error_code="READ_ONLY_VIOLATION", + message=f"Command '{name}' is blocked in read-only mode", + hint="Run without --read-only to execute write commands.", + exit_code=ExitCode.EXECUTION_FAILED, + details={"command": name}, + ) + request = make_request( name=name, args=args, @@ -50,6 +83,14 @@ def _call(self, name: str, args: dict[str, Any] | None = None) -> dict[str, Any] payload = {} if args is None else dict(args) return self._dispatch(name, payload) + def execute_remote_command( + self, + name: str, + args: dict[str, Any] | None = None, + ) -> dict[str, Any]: + payload = {} if args is None else dict(args) + return self._dispatch(name, payload) + @staticmethod def _add_if_not_none(args: dict[str, Any], key: str, value: Any) -> None: if value is not None: diff --git a/src/ableton_cli/client/transport.py b/src/ableton_cli/client/transport.py index da15dc8..fb0d218 100644 --- a/src/ableton_cli/client/transport.py +++ b/src/ableton_cli/client/transport.py @@ -2,11 +2,16 @@ import json import socket -from typing import Any +from pathlib import Path +from typing import Any, Protocol from ..errors import AppError, ExitCode +class JsonTransport(Protocol): + def send(self, payload: dict[str, Any]) -> dict[str, Any]: ... + + class TcpJsonlTransport: def __init__(self, host: str, port: int, timeout_ms: int) -> None: self.host = host @@ -72,3 +77,197 @@ def send(self, payload: dict[str, Any]) -> dict[str, Any]: ) return decoded + + +class RecordingTransport: + def __init__(self, *, inner: JsonTransport, path: str) -> None: + self.inner = inner + self.path = Path(path) + + def _append_entry(self, payload: dict[str, Any]) -> None: + self.path.parent.mkdir(parents=True, exist_ok=True) + with self.path.open("a", encoding="utf-8") as file_obj: + file_obj.write(json.dumps(payload, ensure_ascii=False, separators=(",", ":")) + "\n") + + def send(self, payload: dict[str, Any]) -> dict[str, Any]: + try: + response = self.inner.send(payload) + self._append_entry( + { + "request": payload, + "response": response, + "error": None, + } + ) + return response + except AppError as exc: + self._append_entry( + { + "request": payload, + "response": None, + "error": { + "error_code": exc.error_code, + "message": exc.message, + "hint": exc.hint, + "exit_code": int(exc.exit_code.value), + "details": exc.details, + }, + } + ) + raise + + +class ReplayTransport: + def __init__(self, *, path: str) -> None: + self.path = Path(path) + if not self.path.exists(): + raise AppError( + error_code="INVALID_ARGUMENT", + message=f"Replay file does not exist: {self.path}", + hint="Provide an existing JSONL path to --replay.", + exit_code=ExitCode.INVALID_ARGUMENT, + ) + + self._entries_by_key: dict[str, list[dict[str, Any]]] = {} + replay_lines = self.path.read_text(encoding="utf-8").splitlines() + for line_number, raw_line in enumerate(replay_lines, start=1): + if not raw_line.strip(): + continue + try: + entry = json.loads(raw_line) + except json.JSONDecodeError as exc: + raise AppError( + error_code="PROTOCOL_INVALID_RESPONSE", + message=f"Replay file line {line_number} is not valid JSON", + hint="Fix JSONL formatting in replay file.", + exit_code=ExitCode.PROTOCOL_MISMATCH, + ) from exc + + if not isinstance(entry, dict): + raise AppError( + error_code="PROTOCOL_INVALID_RESPONSE", + message=f"Replay file line {line_number} must be an object", + hint="Use object-per-line JSONL format.", + exit_code=ExitCode.PROTOCOL_MISMATCH, + ) + request = entry.get("request") + if not isinstance(request, dict): + raise AppError( + error_code="PROTOCOL_INVALID_RESPONSE", + message=f"Replay file line {line_number}.request must be an object", + hint="Each replay entry requires a request object.", + exit_code=ExitCode.PROTOCOL_MISMATCH, + ) + key = self._request_key(request) + self._entries_by_key.setdefault(key, []).append(entry) + + @staticmethod + def _request_key(request: dict[str, Any]) -> str: + name = request.get("name") + args = request.get("args", {}) + if not isinstance(name, str) or not name: + raise AppError( + error_code="PROTOCOL_INVALID_RESPONSE", + message="Replay request.name must be a non-empty string", + hint="Record new replay fixtures with --record.", + exit_code=ExitCode.PROTOCOL_MISMATCH, + ) + if not isinstance(args, dict): + raise AppError( + error_code="PROTOCOL_INVALID_RESPONSE", + message="Replay request.args must be an object", + hint="Record new replay fixtures with --record.", + exit_code=ExitCode.PROTOCOL_MISMATCH, + ) + normalized = {"name": name, "args": args} + return json.dumps(normalized, sort_keys=True, separators=(",", ":"), ensure_ascii=False) + + @staticmethod + def _raise_replay_error(payload: Any) -> None: + if not isinstance(payload, dict): + raise AppError( + error_code="PROTOCOL_INVALID_RESPONSE", + message="Replay error payload must be an object", + hint="Record new replay fixtures with --record.", + exit_code=ExitCode.PROTOCOL_MISMATCH, + ) + + code_value = payload.get("error_code", payload.get("code")) + if not isinstance(code_value, str) or not code_value: + raise AppError( + error_code="PROTOCOL_INVALID_RESPONSE", + message="Replay error payload is missing error_code", + hint="Record new replay fixtures with --record.", + exit_code=ExitCode.PROTOCOL_MISMATCH, + ) + message = payload.get("message") + if not isinstance(message, str) or not message: + raise AppError( + error_code="PROTOCOL_INVALID_RESPONSE", + message="Replay error payload is missing message", + hint="Record new replay fixtures with --record.", + exit_code=ExitCode.PROTOCOL_MISMATCH, + ) + hint = payload.get("hint") + if hint is not None and not isinstance(hint, str): + raise AppError( + error_code="PROTOCOL_INVALID_RESPONSE", + message="Replay error payload hint must be a string or null", + hint="Record new replay fixtures with --record.", + exit_code=ExitCode.PROTOCOL_MISMATCH, + ) + raw_exit_code = payload.get("exit_code") + try: + exit_code = ExitCode(int(raw_exit_code)) + except (TypeError, ValueError) as exc: + raise AppError( + error_code="PROTOCOL_INVALID_RESPONSE", + message="Replay error payload exit_code is invalid", + hint="Record new replay fixtures with --record.", + exit_code=ExitCode.PROTOCOL_MISMATCH, + ) from exc + details = payload.get("details", {}) + if details is not None and not isinstance(details, dict): + raise AppError( + error_code="PROTOCOL_INVALID_RESPONSE", + message="Replay error payload details must be an object or null", + hint="Record new replay fixtures with --record.", + exit_code=ExitCode.PROTOCOL_MISMATCH, + ) + raise AppError( + error_code=code_value, + message=message, + hint=hint, + exit_code=exit_code, + details={} if details is None else details, + ) + + def send(self, payload: dict[str, Any]) -> dict[str, Any]: + key = self._request_key(payload) + bucket = self._entries_by_key.get(key) + if not bucket: + raise AppError( + error_code="PROTOCOL_INVALID_RESPONSE", + message="Replay fixture does not contain a matching request", + hint="Record fixtures with --record for the exact name+args sequence.", + exit_code=ExitCode.PROTOCOL_MISMATCH, + details={"name": payload.get("name"), "args": payload.get("args")}, + ) + + entry = bucket.pop(0) + replay_error = entry.get("error") + if replay_error is not None: + self._raise_replay_error(replay_error) + + response = entry.get("response") + if not isinstance(response, dict): + raise AppError( + error_code="PROTOCOL_INVALID_RESPONSE", + message="Replay entry response must be an object", + hint="Record new replay fixtures with --record.", + exit_code=ExitCode.PROTOCOL_MISMATCH, + ) + replayed_response = dict(response) + if "request_id" in payload: + replayed_response["request_id"] = payload["request_id"] + return replayed_response diff --git a/src/ableton_cli/commands/batch.py b/src/ableton_cli/commands/batch.py index 0ea08bd..2220f0b 100644 --- a/src/ableton_cli/commands/batch.py +++ b/src/ableton_cli/commands/batch.py @@ -2,19 +2,205 @@ import json import sys +import time from pathlib import Path from typing import Annotated, Any import typer -from ..errors import AppError -from ..runtime import execute_command, get_client +from ..capabilities import parse_supported_commands, required_remote_commands +from ..errors import AppError, ExitCode +from ..runtime import execute_command, get_client, get_runtime from ._validation import invalid_argument, require_non_empty_string batch_app = typer.Typer(help="Batch commands", no_args_is_help=True) +_ASSERT_OPERATORS = frozenset({"eq", "ne", "gt", "gte", "lt", "lte"}) +_DEFAULT_RETRY_CODES = ("TIMEOUT", "REMOTE_BUSY") -def _parse_steps_object(payload: Any, *, source_name: str) -> list[dict[str, Any]]: +def _parse_retry_object(raw_retry: Any, *, step_index: int) -> dict[str, Any] | None: + if raw_retry is None: + return None + if not isinstance(raw_retry, dict): + raise invalid_argument( + message=f"steps[{step_index}].retry must be an object", + hint="Use retry object: {'max_attempts': 3, 'backoff_ms': 200, 'on': ['TIMEOUT']}.", + ) + + raw_max_attempts = raw_retry.get("max_attempts", 1) + if not isinstance(raw_max_attempts, int) or raw_max_attempts < 1: + raise invalid_argument( + message=f"steps[{step_index}].retry.max_attempts must be an integer >= 1", + hint="Set max_attempts to 1 or greater.", + ) + raw_backoff_ms = raw_retry.get("backoff_ms", 0) + if not isinstance(raw_backoff_ms, int) or raw_backoff_ms < 0: + raise invalid_argument( + message=f"steps[{step_index}].retry.backoff_ms must be an integer >= 0", + hint="Set backoff_ms to 0 or greater.", + ) + raw_retry_on = raw_retry.get("on", list(_DEFAULT_RETRY_CODES)) + if not isinstance(raw_retry_on, list) or not raw_retry_on: + raise invalid_argument( + message=f"steps[{step_index}].retry.on must be a non-empty array of error codes", + hint="Use retry.on such as ['TIMEOUT', 'REMOTE_BUSY'].", + ) + + retry_on: list[str] = [] + for code_index, raw_code in enumerate(raw_retry_on): + if not isinstance(raw_code, str): + raise invalid_argument( + message=f"steps[{step_index}].retry.on[{code_index}] must be a string", + hint="Use uppercase error-code strings.", + ) + retry_on.append( + require_non_empty_string( + "retry.on", + raw_code, + hint=f"steps[{step_index}].retry.on[{code_index}] must be non-empty.", + ) + ) + + return { + "max_attempts": raw_max_attempts, + "backoff_ms": raw_backoff_ms, + "on": retry_on, + } + + +def _parse_assert_object(raw_assert: Any, *, step_index: int) -> list[dict[str, Any]]: + if raw_assert is None: + return [] + + if isinstance(raw_assert, dict): + raw_conditions = [raw_assert] + elif isinstance(raw_assert, list): + if not raw_assert: + raise invalid_argument( + message=f"steps[{step_index}].assert must not be an empty array", + hint="Provide at least one assert condition object.", + ) + raw_conditions = raw_assert + else: + raise invalid_argument( + message=f"steps[{step_index}].assert must be an object or array", + hint="Use assert object: {'path': 'tempo', 'op': 'gte', 'value': 120.0}.", + ) + + parsed_conditions: list[dict[str, Any]] = [] + for condition_index, raw_condition in enumerate(raw_conditions): + if not isinstance(raw_condition, dict): + raise invalid_argument( + message=f"steps[{step_index}].assert[{condition_index}] must be an object", + hint="Use assert object fields: path/op/value/source.", + ) + raw_source = raw_condition.get("source", "previous") + if raw_source not in {"previous", "current"}: + raise invalid_argument( + message=( + f"steps[{step_index}].assert[{condition_index}].source " + "must be previous or current" + ), + hint="Set assert source to 'previous' or 'current'.", + ) + raw_path = raw_condition.get("path") + if not isinstance(raw_path, str): + raise invalid_argument( + message=f"steps[{step_index}].assert[{condition_index}].path must be a string", + hint="Use dot notation such as 'tempo' or 'tracks.0.name'.", + ) + path = require_non_empty_string( + "assert.path", + raw_path, + hint=f"steps[{step_index}].assert[{condition_index}].path must be non-empty.", + ) + raw_op = raw_condition.get("op") + if not isinstance(raw_op, str) or raw_op not in _ASSERT_OPERATORS: + allowed = ", ".join(sorted(_ASSERT_OPERATORS)) + raise invalid_argument( + message=( + f"steps[{step_index}].assert[{condition_index}].op must be one of: {allowed}" + ), + hint="Set assert op to one of the supported operators.", + ) + if "value" not in raw_condition: + raise invalid_argument( + message=f"steps[{step_index}].assert[{condition_index}].value is required", + hint="Set expected value for assertion comparison.", + ) + parsed_conditions.append( + { + "source": raw_source, + "path": path, + "op": raw_op, + "value": raw_condition["value"], + } + ) + return parsed_conditions + + +def _parse_preflight_object(raw_preflight: Any, *, source_name: str) -> dict[str, Any] | None: + if raw_preflight is None or raw_preflight is False: + return None + if raw_preflight is True: + return {} + if not isinstance(raw_preflight, dict): + raise invalid_argument( + message=f"{source_name}.preflight must be an object or boolean", + hint="Use preflight object with protocol_version/command_set_hash/required_commands.", + ) + + parsed: dict[str, Any] = {} + + if "protocol_version" in raw_preflight: + protocol_version = raw_preflight["protocol_version"] + if not isinstance(protocol_version, int): + raise invalid_argument( + message=f"{source_name}.preflight.protocol_version must be an integer", + hint="Set preflight.protocol_version to a positive integer.", + ) + parsed["protocol_version"] = protocol_version + + if "command_set_hash" in raw_preflight: + raw_hash = raw_preflight["command_set_hash"] + if not isinstance(raw_hash, str): + raise invalid_argument( + message=f"{source_name}.preflight.command_set_hash must be a string", + hint="Set preflight.command_set_hash to a non-empty hash string.", + ) + parsed["command_set_hash"] = require_non_empty_string( + "command_set_hash", + raw_hash, + hint=f"{source_name}.preflight.command_set_hash must be non-empty.", + ) + + if "required_commands" in raw_preflight: + raw_required = raw_preflight["required_commands"] + if not isinstance(raw_required, list): + raise invalid_argument( + message=f"{source_name}.preflight.required_commands must be an array", + hint="Use required_commands such as ['ping', 'tracks_list'].", + ) + required_commands: list[str] = [] + for index, raw_name in enumerate(raw_required): + if not isinstance(raw_name, str): + raise invalid_argument( + message=f"{source_name}.preflight.required_commands[{index}] must be a string", + hint="Use non-empty command-name strings.", + ) + required_commands.append( + require_non_empty_string( + "required_commands", + raw_name, + hint=f"{source_name}.preflight.required_commands[{index}] must be non-empty.", + ) + ) + parsed["required_commands"] = required_commands + + return parsed + + +def _parse_batch_object(payload: Any, *, source_name: str) -> dict[str, Any]: if not isinstance(payload, dict): raise invalid_argument( message=f"{source_name} root must be an object", @@ -59,12 +245,22 @@ def _parse_steps_object(payload: Any, *, source_name: str) -> list[dict[str, Any hint="Use a JSON object for step args.", ) - steps.append({"name": name, "args": raw_args}) + parsed_step = { + "name": name, + "args": raw_args, + "retry": _parse_retry_object(raw_step.get("retry"), step_index=index), + "assert": _parse_assert_object(raw_step.get("assert"), step_index=index), + } + steps.append(parsed_step) - return steps + preflight = _parse_preflight_object(payload.get("preflight"), source_name=source_name) + return { + "preflight": preflight, + "steps": steps, + } -def _parse_steps_payload(raw: str, *, source_name: str) -> list[dict[str, Any]]: +def _parse_batch_payload(raw: str, *, source_name: str) -> dict[str, Any]: try: payload = json.loads(raw) except json.JSONDecodeError as exc: @@ -72,7 +268,7 @@ def _parse_steps_payload(raw: str, *, source_name: str) -> list[dict[str, Any]]: message=f"{source_name} must be valid JSON: {exc.msg}", hint="Use JSON object format: {'steps': [{...}]}", ) from exc - return _parse_steps_object(payload, source_name=source_name) + return _parse_batch_object(payload, source_name=source_name) def _stream_error_payload(error: AppError) -> dict[str, Any]: @@ -103,6 +299,279 @@ def _emit_stream_line( ) +def _extract_path_value(payload: Any, path: str) -> tuple[bool, Any]: + current = payload + for token in path.split("."): + if isinstance(current, dict): + if token not in current: + return False, None + current = current[token] + continue + if isinstance(current, list): + try: + index = int(token) + except ValueError: + return False, None + if not (0 <= index < len(current)): + return False, None + current = current[index] + continue + return False, None + return True, current + + +def _assert_match(*, op: str, actual: Any, expected: Any) -> bool: + if op == "eq": + return actual == expected + if op == "ne": + return actual != expected + if op == "gt": + return actual > expected + if op == "gte": + return actual >= expected + if op == "lt": + return actual < expected + if op == "lte": + return actual <= expected + raise RuntimeError(f"Unsupported assert op: {op}") + + +def _raise_assert_failure( + *, + step_index: int, + condition: dict[str, Any], + reason: str, + actual: Any = None, +) -> None: + raise AppError( + error_code="BATCH_ASSERT_FAILED", + message=f"Batch assert failed at step {step_index}", + hint="Fix batch assert conditions or preceding step behavior.", + exit_code=ExitCode.EXECUTION_FAILED, + details={ + "step_index": step_index, + "condition": condition, + "reason": reason, + "actual": actual, + }, + ) + + +def _evaluate_assertions( + *, + step_index: int, + source: str, + conditions: list[dict[str, Any]], + payload: Any, +) -> None: + scoped_conditions = [condition for condition in conditions if condition["source"] == source] + if not scoped_conditions: + return + if payload is None: + _raise_assert_failure( + step_index=step_index, + condition=scoped_conditions[0], + reason=f"{source} payload is missing", + ) + + for condition in scoped_conditions: + found, actual = _extract_path_value(payload, condition["path"]) + if not found: + _raise_assert_failure( + step_index=step_index, + condition=condition, + reason=f"path not found: {condition['path']}", + ) + try: + matched = _assert_match(op=condition["op"], actual=actual, expected=condition["value"]) + except TypeError as exc: + _raise_assert_failure( + step_index=step_index, + condition=condition, + reason=f"type mismatch: {exc}", + actual=actual, + ) + if not matched: + _raise_assert_failure( + step_index=step_index, + condition=condition, + reason="comparison failed", + actual=actual, + ) + + +def _run_preflight( + client: Any, + *, + preflight: dict[str, Any], + default_protocol_version: int, +) -> dict[str, Any]: + try: + ping_result = client.ping() + supported_commands = parse_supported_commands(ping_result) + except AppError as exc: + raise AppError( + error_code="BATCH_PREFLIGHT_FAILED", + message="Batch preflight failed while validating ping/capabilities", + hint=exc.hint or "Fix protocol/capability mismatch before retrying batch.", + exit_code=ExitCode.EXECUTION_FAILED, + details={ + "error_code": exc.error_code, + "message": exc.message, + }, + ) from exc + + expected_protocol = preflight.get("protocol_version", default_protocol_version) + remote_protocol = ping_result.get("protocol_version") + if remote_protocol != expected_protocol: + raise AppError( + error_code="BATCH_PREFLIGHT_FAILED", + message="Batch preflight protocol_version mismatch", + hint="Align CLI protocol version and Remote Script protocol version.", + exit_code=ExitCode.EXECUTION_FAILED, + details={ + "expected_protocol_version": expected_protocol, + "remote_protocol_version": remote_protocol, + }, + ) + + expected_hash = preflight.get("command_set_hash") + remote_hash = ping_result.get("command_set_hash") + if expected_hash is not None and remote_hash != expected_hash: + raise AppError( + error_code="BATCH_PREFLIGHT_FAILED", + message="Batch preflight command_set_hash mismatch", + hint="Update Remote Script or batch preflight command_set_hash.", + exit_code=ExitCode.EXECUTION_FAILED, + details={ + "expected_command_set_hash": expected_hash, + "remote_command_set_hash": remote_hash, + }, + ) + + required_commands = preflight.get("required_commands") + if required_commands is None: + required = required_remote_commands() + else: + required = set(required_commands) + missing = sorted(required.difference(supported_commands)) + if missing: + raise AppError( + error_code="BATCH_PREFLIGHT_FAILED", + message="Batch preflight detected missing required commands", + hint="Reinstall Remote Script and restart Ableton Live.", + exit_code=ExitCode.EXECUTION_FAILED, + details={ + "missing_required_commands": missing, + "required_command_count": len(required), + "supported_command_count": len(supported_commands), + }, + ) + + return { + "checked": True, + "protocol_version": remote_protocol, + "command_set_hash": remote_hash, + "required_command_count": len(required), + "supported_command_count": len(supported_commands), + } + + +def _execute_step( + client: Any, + *, + step: dict[str, Any], + step_index: int, +) -> tuple[dict[str, Any], int]: + retry = step["retry"] + if retry is None: + result = client.execute_remote_command(step["name"], step["args"]) + return result, 1 + + max_attempts = retry["max_attempts"] + retry_on = set(retry["on"]) + backoff_ms = retry["backoff_ms"] + + attempt = 0 + while True: + attempt += 1 + try: + result = client.execute_remote_command(step["name"], step["args"]) + return result, attempt + except AppError as exc: + if exc.error_code not in retry_on: + raise + if attempt >= max_attempts: + raise AppError( + error_code="BATCH_RETRY_EXHAUSTED", + message=f"Retry exhausted for step {step_index}", + hint="Increase retry.max_attempts or fix underlying command errors.", + exit_code=ExitCode.EXECUTION_FAILED, + details={ + "step_index": step_index, + "name": step["name"], + "attempts": attempt, + "retry_on": sorted(retry_on), + "last_error": exc.to_payload(), + }, + ) from exc + sleep_seconds = (backoff_ms * (2 ** (attempt - 1))) / 1000 + if sleep_seconds > 0: + time.sleep(sleep_seconds) + + +def _execute_batch_spec( + ctx: typer.Context, + spec: dict[str, Any], + *, + client_override: Any | None = None, +) -> dict[str, Any]: + runtime = get_runtime(ctx) + client = get_client(ctx) if client_override is None else client_override + preflight = spec["preflight"] + preflight_result: dict[str, Any] | None = None + if preflight is not None: + preflight_result = _run_preflight( + client, + preflight=preflight, + default_protocol_version=runtime.settings.protocol_version, + ) + + results: list[dict[str, Any]] = [] + previous_result: dict[str, Any] | None = None + for step_index, step in enumerate(spec["steps"]): + _evaluate_assertions( + step_index=step_index, + source="previous", + conditions=step["assert"], + payload=previous_result, + ) + + step_result, attempts = _execute_step(client, step=step, step_index=step_index) + + _evaluate_assertions( + step_index=step_index, + source="current", + conditions=step["assert"], + payload=step_result, + ) + results.append( + { + "index": step_index, + "name": step["name"], + "attempts": attempts, + "result": step_result, + } + ) + previous_result = step_result + + return { + "step_count": len(spec["steps"]), + "results": results, + "preflight": preflight_result, + } + + @batch_app.command("run") def batch_run( ctx: typer.Context, @@ -119,7 +588,7 @@ def batch_run( typer.Option("--steps-stdin", help="Read JSON object with 'steps' array from stdin"), ] = False, ) -> None: - def _run() -> dict[str, object]: + def _run() -> dict[str, Any]: selected_sources = ( int(steps_file is not None) + int(steps_json is not None) + int(steps_stdin) ) @@ -140,14 +609,14 @@ def _run() -> dict[str, object]: message=f"steps file could not be read: {steps_path}", hint="Pass a readable UTF-8 JSON file path for --steps-file.", ) from exc - steps = _parse_steps_payload(raw, source_name="steps file") + spec = _parse_batch_payload(raw, source_name="steps file") elif steps_json is not None: - steps = _parse_steps_payload(steps_json, source_name="steps json") + spec = _parse_batch_payload(steps_json, source_name="steps json") else: raw_stdin = sys.stdin.read() - steps = _parse_steps_payload(raw_stdin, source_name="steps stdin") + spec = _parse_batch_payload(raw_stdin, source_name="steps stdin") - return get_client(ctx).execute_batch(steps) + return _execute_batch_spec(ctx, spec) execute_command( ctx, @@ -160,7 +629,6 @@ def _run() -> dict[str, object]: @batch_app.command("stream") def batch_stream(ctx: typer.Context) -> None: client = get_client(ctx) - for line_number, raw_line in enumerate(sys.stdin, start=1): line = raw_line.strip() if not line: @@ -195,8 +663,8 @@ def batch_stream(ctx: typer.Context) -> None: hint=f"line {line_number}.id must be non-empty when provided.", ) - steps = _parse_steps_object(payload, source_name=f"line {line_number}") - result = client.execute_batch(steps) + spec = _parse_batch_object(payload, source_name=f"line {line_number}") + result = _execute_batch_spec(ctx, spec, client_override=client) _emit_stream_line(request_id=request_id, ok=True, result=result, error=None) except AppError as exc: _emit_stream_line( diff --git a/src/ableton_cli/commands/session.py b/src/ableton_cli/commands/session.py index 0c618b9..56c9fb0 100644 --- a/src/ableton_cli/commands/session.py +++ b/src/ableton_cli/commands/session.py @@ -1,8 +1,13 @@ from __future__ import annotations +import json +from pathlib import Path + import typer from ..runtime import execute_command, get_client +from ..session_diff import compute_session_diff +from ._validation import invalid_argument session_app = typer.Typer(help="Session information commands", no_args_is_help=True) @@ -27,6 +32,54 @@ def session_snapshot(ctx: typer.Context) -> None: ) +def _load_snapshot(path: str, *, source_name: str) -> dict[str, object]: + snapshot_path = Path(path) + try: + raw = snapshot_path.read_text(encoding="utf-8") + except OSError as exc: + raise invalid_argument( + message=f"{source_name} file could not be read: {snapshot_path}", + hint=f"Provide a readable UTF-8 JSON file for {source_name}.", + ) from exc + try: + payload = json.loads(raw) + except json.JSONDecodeError as exc: + raise invalid_argument( + message=f"{source_name} file must be valid JSON: {exc.msg}", + hint=f"Fix JSON syntax in {snapshot_path}.", + ) from exc + if not isinstance(payload, dict): + raise invalid_argument( + message=f"{source_name} snapshot root must be an object", + hint="Use object JSON snapshots generated by session snapshot.", + ) + return payload + + +@session_app.command("diff") +def session_diff( + ctx: typer.Context, + from_path: str = typer.Option(..., "--from"), + to_path: str = typer.Option(..., "--to"), +) -> None: # noqa: E501 + def _run() -> dict[str, object]: + from_snapshot = _load_snapshot(from_path, source_name="--from") + to_snapshot = _load_snapshot(to_path, source_name="--to") + result = compute_session_diff(from_snapshot, to_snapshot) + return { + "from_path": str(Path(from_path)), + "to_path": str(Path(to_path)), + **result, + } + + execute_command( + ctx, + command="session diff", + args={"from": from_path, "to": to_path}, + action=_run, + ) + + @session_app.command("stop-all-clips") def session_stop_all_clips(ctx: typer.Context) -> None: execute_command( diff --git a/src/ableton_cli/compact.py b/src/ableton_cli/compact.py new file mode 100644 index 0000000..d8ea26e --- /dev/null +++ b/src/ableton_cli/compact.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from typing import Any + +_LARGE_ARRAY_THRESHOLD = 20 + + +def _item_type(values: list[Any]) -> str: + if not values: + return "empty" + type_names = {type(value).__name__ for value in values} + if len(type_names) == 1: + return next(iter(type_names)) + return "mixed" + + +def compact_payload( + payload: dict[str, Any], + *, + threshold: int = _LARGE_ARRAY_THRESHOLD, +) -> dict[str, Any]: + if payload.get("ok") is not True: + return payload + + refs: dict[str, dict[str, Any]] = {} + counter = 0 + + def compact_value(value: Any, *, path: str) -> Any: + nonlocal counter + + if isinstance(value, list): + if len(value) > threshold: + counter += 1 + ref = f"ref_{counter}" + summary = { + "count": len(value), + "item_type": _item_type(value), + } + refs[ref] = { + "path": path, + **summary, + } + return { + "_compact_ref": ref, + "_compact_summary": summary, + } + return [compact_value(item, path=f"{path}.{index}") for index, item in enumerate(value)] + + if isinstance(value, dict): + return {key: compact_value(item, path=f"{path}.{key}") for key, item in value.items()} + + return value + + compacted = dict(payload) + compacted["result"] = compact_value(payload.get("result"), path="result") + if refs: + compacted["compact_refs"] = refs + return compacted diff --git a/src/ableton_cli/contracts/__init__.py b/src/ableton_cli/contracts/__init__.py new file mode 100644 index 0000000..fe952dd --- /dev/null +++ b/src/ableton_cli/contracts/__init__.py @@ -0,0 +1,3 @@ +from .registry import validate_command_contract + +__all__ = ["validate_command_contract"] diff --git a/src/ableton_cli/contracts/registry.py b/src/ableton_cli/contracts/registry.py new file mode 100644 index 0000000..a423de4 --- /dev/null +++ b/src/ableton_cli/contracts/registry.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from typing import Any + +from ..errors import AppError, ExitCode +from .schema import ContractValidationError, validate_value + +_CONTRACTS: dict[str, dict[str, dict[str, Any]]] = { + "ping": { + "args": {"type": "object", "additional_properties": False}, + "result": { + "type": "object", + "required": ["host", "port", "rtt_ms"], + "properties": { + "host": {"type": "string"}, + "port": {"type": "integer"}, + "protocol_version": {"type": ["integer", "null"]}, + "remote_script_version": {"type": ["string", "null"]}, + "supported_commands": { + "type": ["array", "null"], + "items": {"type": "string"}, + }, + "command_set_hash": {"type": ["string", "null"]}, + "api_support": {"type": ["object", "null"]}, + "rtt_ms": {"type": "number"}, + }, + }, + }, + "doctor": { + "args": {"type": "object", "additional_properties": False}, + "result": { + "type": "object", + "required": ["summary", "checks"], + "properties": { + "summary": { + "type": "object", + "required": ["pass", "warn", "fail"], + "properties": { + "pass": {"type": "integer"}, + "warn": {"type": "integer"}, + "fail": {"type": "integer"}, + }, + }, + "checks": {"type": "array"}, + }, + }, + }, + "song info": { + "args": {"type": "object", "additional_properties": False}, + "result": { + "type": "object", + "properties": { + "tempo": {"type": "number"}, + }, + }, + }, + "tracks list": { + "args": {"type": "object", "additional_properties": False}, + "result": { + "type": "object", + "required": ["tracks"], + "properties": { + "tracks": {"type": "array"}, + }, + }, + }, + "session diff": { + "args": { + "type": "object", + "required": ["from", "to"], + "properties": { + "from": {"type": "string"}, + "to": {"type": "string"}, + }, + }, + "result": { + "type": "object", + "required": ["from_path", "to_path", "added", "removed", "changed"], + "properties": { + "from_path": {"type": "string"}, + "to_path": {"type": "string"}, + "added": {"type": "object"}, + "removed": {"type": "object"}, + "changed": {"type": "object"}, + }, + }, + }, +} + + +def validate_command_contract(*, command: str, args: dict[str, Any], result: Any) -> None: + contract = _CONTRACTS.get(command) + if contract is None: + return + + try: + validate_value(contract["args"], args, path="args") + validate_value(contract["result"], result, path="result") + except ContractValidationError as exc: + raise AppError( + error_code="PROTOCOL_INVALID_RESPONSE", + message=f"Contract validation failed for '{command}': {exc.path} {exc.message}", + hint="Fix the command contract or result payload shape.", + exit_code=ExitCode.PROTOCOL_MISMATCH, + details={ + "command": command, + "path": exc.path, + "reason": exc.message, + }, + ) from exc diff --git a/src/ableton_cli/contracts/schema.py b/src/ableton_cli/contracts/schema.py new file mode 100644 index 0000000..074aa95 --- /dev/null +++ b/src/ableton_cli/contracts/schema.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass(slots=True) +class ContractValidationError(Exception): + path: str + message: str + + +def _matches_type(expected_type: str, value: Any) -> bool: + if expected_type == "object": + return isinstance(value, dict) + if expected_type == "array": + return isinstance(value, list) + if expected_type == "string": + return isinstance(value, str) + if expected_type == "integer": + return isinstance(value, int) and not isinstance(value, bool) + if expected_type == "number": + return (isinstance(value, int) and not isinstance(value, bool)) or isinstance(value, float) + if expected_type == "boolean": + return isinstance(value, bool) + if expected_type == "null": + return value is None + if expected_type == "any": + return True + raise RuntimeError(f"Unsupported schema type: {expected_type}") + + +def _validate_type(expected: str | list[str], value: Any, *, path: str) -> None: + if isinstance(expected, str): + expected_types = [expected] + else: + expected_types = list(expected) + if any(_matches_type(name, value) for name in expected_types): + return + expected_label = "|".join(expected_types) + actual_type = type(value).__name__ + raise ContractValidationError( + path=path, + message=f"expected {expected_label}, got {actual_type}", + ) + + +def validate_value(schema: dict[str, Any], value: Any, *, path: str) -> None: + expected_type = schema.get("type", "any") + _validate_type(expected_type, value, path=path) + + if isinstance(expected_type, list): + is_object = "object" in expected_type and isinstance(value, dict) + is_array = "array" in expected_type and isinstance(value, list) + else: + is_object = expected_type == "object" and isinstance(value, dict) + is_array = expected_type == "array" and isinstance(value, list) + + if is_object: + required = schema.get("required", []) + for key in required: + if key not in value: + raise ContractValidationError( + path=f"{path}.{key}", + message="is required", + ) + properties = schema.get("properties", {}) + additional_properties = schema.get("additional_properties", True) + for key, item in value.items(): + if key not in properties: + if additional_properties: + continue + raise ContractValidationError( + path=f"{path}.{key}", + message="is not allowed", + ) + child_schema = properties[key] + validate_value(child_schema, item, path=f"{path}.{key}") + return + + if is_array: + items_schema = schema.get("items") + min_items = schema.get("min_items") + if isinstance(min_items, int) and len(value) < min_items: + raise ContractValidationError( + path=path, + message=f"expected at least {min_items} items, got {len(value)}", + ) + if not isinstance(items_schema, dict): + return + for index, item in enumerate(value): + validate_value(items_schema, item, path=f"{path}[{index}]") diff --git a/src/ableton_cli/dev_checks.py b/src/ableton_cli/dev_checks.py index 2d7eedd..bbe8bef 100644 --- a/src/ableton_cli/dev_checks.py +++ b/src/ableton_cli/dev_checks.py @@ -11,6 +11,7 @@ DEFAULT_CHECK_COMMANDS: tuple[tuple[str, ...], ...] = ( ("uv", "run", "ruff", "check", "."), ("uv", "run", "ruff", "format", "--check", "."), + ("uv", "run", "python", "tools/generate_skill_docs.py", "--check"), ("uv", "run", "pytest"), ) PYTEST_COMMAND_PREFIX = ("uv", "run", "pytest") diff --git a/src/ableton_cli/errors.py b/src/ableton_cli/errors.py index 5cec127..914fc70 100644 --- a/src/ableton_cli/errors.py +++ b/src/ableton_cli/errors.py @@ -46,6 +46,10 @@ def to_payload(self) -> dict[str, Any]: "TIMEOUT": ExitCode.TIMEOUT, "BATCH_STEP_FAILED": ExitCode.EXECUTION_FAILED, "REMOTE_BUSY": ExitCode.EXECUTION_FAILED, + "READ_ONLY_VIOLATION": ExitCode.EXECUTION_FAILED, + "BATCH_PREFLIGHT_FAILED": ExitCode.EXECUTION_FAILED, + "BATCH_ASSERT_FAILED": ExitCode.EXECUTION_FAILED, + "BATCH_RETRY_EXHAUSTED": ExitCode.EXECUTION_FAILED, "INSTALL_TARGET_NOT_FOUND": ExitCode.EXECUTION_FAILED, "INTERNAL_ERROR": ExitCode.INTERNAL_ERROR, } diff --git a/src/ableton_cli/runtime.py b/src/ableton_cli/runtime.py index 39e2eee..39206cb 100644 --- a/src/ableton_cli/runtime.py +++ b/src/ableton_cli/runtime.py @@ -8,7 +8,9 @@ import typer from .client.ableton_client import AbletonClient +from .compact import compact_payload from .config import Settings +from .contracts import validate_command_contract from .errors import AppError, ExitCode from .output import ( OutputMode, @@ -28,11 +30,23 @@ class RuntimeContext: output_mode: OutputMode quiet: bool no_color: bool + record_path: str | None = None + replay_path: str | None = None + read_only: bool = False + compact: bool = False _client: AbletonClient | None = None def client(self) -> AbletonClient: if self._client is None: - self._client = AbletonClient(self.settings) + if self.record_path is None and self.replay_path is None and not self.read_only: + self._client = AbletonClient(self.settings) + else: + self._client = AbletonClient( + self.settings, + record_path=self.record_path, + replay_path=self.replay_path, + read_only=self.read_only, + ) return self._client @@ -61,8 +75,11 @@ def execute_command( try: result = action() + validate_command_contract(command=command, args=args, result=result) payload = success_payload(command=command, args=args, result=result) if runtime.output_mode == OutputMode.JSON: + if runtime.compact: + payload = compact_payload(payload) emit_json(payload) else: if human_formatter is not None and not runtime.quiet: diff --git a/src/ableton_cli/session_diff.py b/src/ableton_cli/session_diff.py new file mode 100644 index 0000000..ee53e81 --- /dev/null +++ b/src/ableton_cli/session_diff.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from typing import Any + + +def _diff_node(before: Any, after: Any) -> tuple[Any | None, Any | None, Any | None, bool]: + if isinstance(before, dict) and isinstance(after, dict): + added: dict[str, Any] = {} + removed: dict[str, Any] = {} + changed: dict[str, Any] = {} + + for key in sorted(set(before) | set(after)): + if key not in before: + added[key] = after[key] + continue + if key not in after: + removed[key] = before[key] + continue + + child_added, child_removed, child_changed, child_has_diff = _diff_node( + before[key], + after[key], + ) + if not child_has_diff: + continue + if child_added is not None: + added[key] = child_added + if child_removed is not None: + removed[key] = child_removed + if child_changed is not None: + changed[key] = child_changed + + has_diff = bool(added or removed or changed) + return ( + added if added else None, + removed if removed else None, + changed if changed else None, + has_diff, + ) + + if before != after: + return None, None, {"from": before, "to": after}, True + + return None, None, None, False + + +def compute_session_diff( + from_snapshot: dict[str, Any], + to_snapshot: dict[str, Any], +) -> dict[str, Any]: + added, removed, changed, _has_diff = _diff_node(from_snapshot, to_snapshot) + return { + "added": {} if added is None else added, + "removed": {} if removed is None else removed, + "changed": {} if changed is None else changed, + } diff --git a/tests/test_ableton_client.py b/tests/test_ableton_client.py index 62111b9..8618c6c 100644 --- a/tests/test_ableton_client.py +++ b/tests/test_ableton_client.py @@ -105,6 +105,39 @@ def _send(request: dict[str, Any]): # noqa: ANN202 assert [request["name"] for request in requests] == ["track_volume_set"] +def test_client_read_only_blocks_write_command_before_transport(monkeypatch) -> None: + client = AbletonClient(_settings(), read_only=True) + requests: list[dict[str, Any]] = [] + + def _send(request: dict[str, Any]): # noqa: ANN202 + requests.append(request) + return _ok_response(request, {"ok": True}) + + monkeypatch.setattr(client.transport, "send", _send) + + with pytest.raises(AppError) as exc_info: + client.track_volume_set(0, 0.5) + + assert exc_info.value.error_code == "READ_ONLY_VIOLATION" + assert requests == [] + + +def test_client_read_only_allows_read_command(monkeypatch) -> None: + client = AbletonClient(_settings(), read_only=True) + requests: list[dict[str, Any]] = [] + + def _send(request: dict[str, Any]): # noqa: ANN202 + requests.append(request) + return _ok_response(request, {"tempo": 120.0}) + + monkeypatch.setattr(client.transport, "send", _send) + + result = client.song_info() + + assert result["tempo"] == 120.0 + assert [request["name"] for request in requests] == ["song_info"] + + def test_client_sends_request_timeout_meta(monkeypatch) -> None: client = AbletonClient(_settings()) requests = _capture_requests(monkeypatch, client) diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py index 9858fe5..ef5cdc7 100644 --- a/tests/test_bootstrap.py +++ b/tests/test_bootstrap.py @@ -56,6 +56,10 @@ def _resolve_settings_stub(*, cli_overrides, config_path): # noqa: ANN001, ANN2 config=config_path, no_color=True, quiet=False, + record="/tmp/session-record.jsonl", + replay=None, + read_only=True, + compact=True, ) assert runtime.settings is settings @@ -63,6 +67,10 @@ def _resolve_settings_stub(*, cli_overrides, config_path): # noqa: ANN001, ANN2 assert runtime.output_mode is OutputMode.JSON assert runtime.quiet is False assert runtime.no_color is True + assert runtime.record_path == "/tmp/session-record.jsonl" + assert runtime.replay_path is None + assert runtime.read_only is True + assert runtime.compact is True assert seen["config_path"] == config_path assert seen["cli_overrides"] == { "host": "127.0.0.1", @@ -102,6 +110,10 @@ def _raise_error(*, cli_overrides, config_path): # noqa: ANN001, ANN202 config=None, no_color=False, quiet=False, + record=None, + replay=None, + read_only=False, + compact=False, ) assert exc_info.value is expected diff --git a/tests/test_ci_workflow.py b/tests/test_ci_workflow.py index 9700fdc..b265794 100644 --- a/tests/test_ci_workflow.py +++ b/tests/test_ci_workflow.py @@ -18,6 +18,12 @@ def test_ci_workflow_configures_dev_checks_reports() -> None: test_job = jobs["test"] steps = test_job["steps"] + verify_generated_docs = next( + step for step in steps if step.get("name") == "Verify generated skill docs" + ) + assert "python tools/generate_skill_docs.py" in verify_generated_docs["run"] + assert "git diff --exit-code" in verify_generated_docs["run"] + run_lint_and_tests = next(step for step in steps if step.get("name") == "Run lint and tests") assert run_lint_and_tests["run"] == ( "uv run python -m ableton_cli.dev_checks " diff --git a/tests/test_cli_json_output.py b/tests/test_cli_json_output.py index c419433..6824ef3 100644 --- a/tests/test_cli_json_output.py +++ b/tests/test_cli_json_output.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +from pathlib import Path def test_config_show_outputs_json_envelope(runner, cli_app, tmp_path) -> None: @@ -295,3 +296,168 @@ def test_protocol_version_global_option_overrides_config(runner, cli_app, tmp_pa payload = json.loads(result.stdout) assert payload["ok"] is True assert payload["result"]["protocol_version"] == 11 + + +def test_ping_supports_replay_option_without_network(runner, cli_app, tmp_path: Path) -> None: + replay_path = tmp_path / "ping-replay.jsonl" + replay_path.write_text( + json.dumps( + { + "request": {"name": "ping", "args": {}}, + "response": { + "ok": True, + "request_id": "recorded-request-id", + "protocol_version": 2, + "result": { + "protocol_version": 2, + "remote_script_version": "9.9.9", + "supported_commands": ["ping"], + "command_set_hash": "hash", + "api_support": {}, + }, + "error": None, + }, + "error": None, + } + ) + + "\n", + encoding="utf-8", + ) + + result = runner.invoke( + cli_app, + ["--output", "json", "--replay", str(replay_path), "ping"], + ) + + assert result.exit_code == 0 + payload = json.loads(result.stdout) + assert payload["ok"] is True + assert payload["result"]["remote_script_version"] == "9.9.9" + + +def test_record_and_replay_cannot_be_enabled_together(runner, cli_app, tmp_path: Path) -> None: + record_path = tmp_path / "record.jsonl" + replay_path = tmp_path / "replay.jsonl" + + result = runner.invoke( + cli_app, + [ + "--output", + "json", + "--record", + str(record_path), + "--replay", + str(replay_path), + "ping", + ], + ) + + assert result.exit_code == 2 + payload = json.loads(result.stdout) + assert payload["ok"] is False + assert payload["error"]["code"] == "INVALID_ARGUMENT" + + +def test_read_only_allows_read_commands(runner, cli_app, tmp_path: Path) -> None: + replay_path = tmp_path / "song-info-replay.jsonl" + replay_path.write_text( + json.dumps( + { + "request": {"name": "song_info", "args": {}}, + "response": { + "ok": True, + "request_id": "recorded-request-id", + "protocol_version": 2, + "result": {"tempo": 120.0}, + "error": None, + }, + "error": None, + } + ) + + "\n", + encoding="utf-8", + ) + + result = runner.invoke( + cli_app, + ["--output", "json", "--read-only", "--replay", str(replay_path), "song", "info"], + ) + + assert result.exit_code == 0 + payload = json.loads(result.stdout) + assert payload["ok"] is True + assert payload["result"]["tempo"] == 120.0 + + +def test_read_only_blocks_write_commands(runner, cli_app, tmp_path: Path) -> None: + replay_path = tmp_path / "empty-replay.jsonl" + replay_path.write_text("", encoding="utf-8") + + result = runner.invoke( + cli_app, + [ + "--output", + "json", + "--read-only", + "--replay", + str(replay_path), + "track", + "volume", + "set", + "0", + "0.5", + ], + ) + + assert result.exit_code == 20 + payload = json.loads(result.stdout) + assert payload["ok"] is False + assert payload["error"]["code"] == "READ_ONLY_VIOLATION" + + +def test_read_only_batch_rejects_write_steps(runner, cli_app, tmp_path: Path) -> None: + replay_path = tmp_path / "batch-replay.jsonl" + replay_path.write_text( + json.dumps( + { + "request": {"name": "tracks_list", "args": {}}, + "response": { + "ok": True, + "request_id": "recorded-request-id", + "protocol_version": 2, + "result": {"tracks": []}, + "error": None, + }, + "error": None, + } + ) + + "\n", + encoding="utf-8", + ) + + result = runner.invoke( + cli_app, + [ + "--output", + "json", + "--read-only", + "--replay", + str(replay_path), + "batch", + "run", + "--steps-json", + json.dumps( + { + "steps": [ + {"name": "tracks_list", "args": {}}, + {"name": "track_volume_set", "args": {"track": 0, "value": 0.5}}, + ] + } + ), + ], + ) + + assert result.exit_code == 20 + payload = json.loads(result.stdout) + assert payload["ok"] is False + assert payload["error"]["code"] == "READ_ONLY_VIOLATION" diff --git a/tests/test_cli_new_commands.py b/tests/test_cli_new_commands.py index c48811e..7844f1d 100644 --- a/tests/test_cli_new_commands.py +++ b/tests/test_cli_new_commands.py @@ -612,6 +612,12 @@ def execute_batch(self, steps: list[dict[str, object]]): # noqa: ANN201 "results": [{"index": idx, "result": {"ok": True}} for idx, _ in enumerate(steps)], } + def execute_remote_command(self, name: str, args: dict[str, object]): # noqa: ANN201 + handler = getattr(self, name, None) + if handler is None: + return {"name": name, "args": args, "ok": True} + return handler(**args) + def find_synth_devices( # noqa: ANN201 self, track: int | None, @@ -2611,17 +2617,11 @@ def test_effect_standard_wrapper_commands_output_json_envelope( class _BatchStreamClientStub: def __init__(self) -> None: - self.calls: list[list[dict[str, object]]] = [] + self.calls: list[tuple[str, dict[str, object]]] = [] - def execute_batch(self, steps: list[dict[str, object]]): # noqa: ANN201 - self.calls.append(steps) - return { - "step_count": len(steps), - "results": [ - {"index": index, "name": str(step["name"]), "result": {"ok": True}} - for index, step in enumerate(steps) - ], - } + def execute_remote_command(self, name: str, args: dict[str, object]): # noqa: ANN201 + self.calls.append((name, args)) + return {"ok": True, "name": name} def test_batch_stream_processes_multiple_lines_and_reuses_client( @@ -2652,8 +2652,8 @@ def _get_client(_ctx): # noqa: ANN202 assert [item["id"] for item in responses] == ["first", "second"] assert [item["ok"] for item in responses] == [True, True] assert client.calls == [ - [{"name": "tracks_list", "args": {}}], - [{"name": "song_info", "args": {}}], + ("tracks_list", {}), + ("song_info", {}), ] @@ -2693,3 +2693,228 @@ def test_batch_stream_emits_structured_line_errors_and_continues( assert responses[3]["id"] == "ok-2" assert responses[3]["ok"] is True assert len(client.calls) == 2 + + +def _timeout_error() -> AppError: + return AppError( + error_code="TIMEOUT", + message="timed out", + hint="retry", + exit_code=ExitCode.TIMEOUT, + ) + + +class _BatchAdvancedClientStub: + def __init__(self) -> None: + from ableton_cli.capabilities import compute_command_set_hash + + self.calls: list[tuple[str, dict[str, object]]] = [] + self._responses: dict[str, list[dict[str, object] | AppError]] = {} + supported = ["ping", "tracks_list", "song_info", "track_volume_set"] + self._ping_payload = { + "protocol_version": 2, + "supported_commands": supported, + "command_set_hash": compute_command_set_hash(supported), + "remote_script_version": "0.0.0", + "api_support": {}, + } + + def set_responses(self, name: str, items: list[dict[str, object] | AppError]) -> None: + self._responses[name] = list(items) + + def ping(self): # noqa: ANN201 + return dict(self._ping_payload) + + def execute_remote_command(self, name: str, args: dict[str, object]): # noqa: ANN201 + self.calls.append((name, args)) + queue = self._responses.get(name) + if queue: + item = queue.pop(0) + if isinstance(item, AppError): + raise item + return item + return {"ok": True, "name": name} + + +def test_batch_run_retries_only_configured_errors(runner, cli_app, monkeypatch) -> None: + from ableton_cli.commands import batch + + client = _BatchAdvancedClientStub() + client.set_responses("tracks_list", [_timeout_error(), {"tracks": []}]) + monkeypatch.setattr(batch, "get_client", lambda _ctx: client) + + result = runner.invoke( + cli_app, + [ + "--output", + "json", + "batch", + "run", + "--steps-json", + json.dumps( + { + "steps": [ + { + "name": "tracks_list", + "args": {}, + "retry": {"max_attempts": 3, "backoff_ms": 0, "on": ["TIMEOUT"]}, + } + ] + } + ), + ], + ) + + assert result.exit_code == 0 + payload = json.loads(result.stdout) + assert payload["ok"] is True + assert payload["result"]["results"][0]["attempts"] == 2 + + +def test_batch_run_fails_when_retry_is_exhausted(runner, cli_app, monkeypatch) -> None: + from ableton_cli.commands import batch + + client = _BatchAdvancedClientStub() + client.set_responses("tracks_list", [_timeout_error(), _timeout_error(), _timeout_error()]) + monkeypatch.setattr(batch, "get_client", lambda _ctx: client) + + result = runner.invoke( + cli_app, + [ + "--output", + "json", + "batch", + "run", + "--steps-json", + json.dumps( + { + "steps": [ + { + "name": "tracks_list", + "args": {}, + "retry": {"max_attempts": 3, "backoff_ms": 0, "on": ["TIMEOUT"]}, + } + ] + } + ), + ], + ) + + assert result.exit_code == 20 + payload = json.loads(result.stdout) + assert payload["ok"] is False + assert payload["error"]["code"] == "BATCH_RETRY_EXHAUSTED" + + +def test_batch_run_fails_when_assert_condition_does_not_match(runner, cli_app, monkeypatch) -> None: + from ableton_cli.commands import batch + + client = _BatchAdvancedClientStub() + client.set_responses("song_info", [{"tempo": 90.0}]) + client.set_responses("tracks_list", [{"tracks": []}]) + monkeypatch.setattr(batch, "get_client", lambda _ctx: client) + + result = runner.invoke( + cli_app, + [ + "--output", + "json", + "batch", + "run", + "--steps-json", + json.dumps( + { + "steps": [ + {"name": "song_info", "args": {}}, + { + "name": "tracks_list", + "args": {}, + "assert": { + "source": "previous", + "path": "tempo", + "op": "gte", + "value": 120.0, + }, + }, + ] + } + ), + ], + ) + + assert result.exit_code == 20 + payload = json.loads(result.stdout) + assert payload["ok"] is False + assert payload["error"]["code"] == "BATCH_ASSERT_FAILED" + + +def test_batch_run_preflight_blocks_on_protocol_mismatch(runner, cli_app, monkeypatch) -> None: + from ableton_cli.commands import batch + + client = _BatchAdvancedClientStub() + monkeypatch.setattr(batch, "get_client", lambda _ctx: client) + + result = runner.invoke( + cli_app, + [ + "--output", + "json", + "batch", + "run", + "--steps-json", + json.dumps( + { + "preflight": {"protocol_version": 99}, + "steps": [{"name": "tracks_list", "args": {}}], + } + ), + ], + ) + + assert result.exit_code == 20 + payload = json.loads(result.stdout) + assert payload["ok"] is False + assert payload["error"]["code"] == "BATCH_PREFLIGHT_FAILED" + + +def test_batch_stream_continues_after_assert_failure(runner, cli_app, monkeypatch) -> None: + from ableton_cli.commands import batch + + client = _BatchAdvancedClientStub() + client.set_responses("song_info", [{"tempo": 80.0}, {"tempo": 130.0}]) + monkeypatch.setattr(batch, "get_client", lambda _ctx: client) + + payload = "\n".join( + [ + json.dumps( + { + "id": "fail-assert", + "steps": [ + {"name": "song_info", "args": {}}, + { + "name": "tracks_list", + "args": {}, + "assert": { + "source": "previous", + "path": "tempo", + "op": "gte", + "value": 100.0, + }, + }, + ], + } + ), + json.dumps({"id": "ok", "steps": [{"name": "song_info", "args": {}}]}), + ] + ) + + result = runner.invoke(cli_app, ["batch", "stream"], input=f"{payload}\n") + + assert result.exit_code == 0 + responses = [json.loads(line) for line in result.stdout.splitlines() if line.strip()] + assert len(responses) == 2 + assert responses[0]["id"] == "fail-assert" + assert responses[0]["ok"] is False + assert responses[0]["error"]["code"] == "BATCH_ASSERT_FAILED" + assert responses[1]["id"] == "ok" + assert responses[1]["ok"] is True diff --git a/tests/test_contracts.py b/tests/test_contracts.py new file mode 100644 index 0000000..8059546 --- /dev/null +++ b/tests/test_contracts.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import json + + +def test_song_info_contract_accepts_numeric_tempo(runner, cli_app, monkeypatch) -> None: + from ableton_cli.commands import song + + class _ClientStub: + def song_info(self): # noqa: ANN201 + return {"tempo": 123.0} + + monkeypatch.setattr(song, "get_client", lambda _ctx: _ClientStub()) + + result = runner.invoke(cli_app, ["--output", "json", "song", "info"]) + + assert result.exit_code == 0 + payload = json.loads(result.stdout) + assert payload["ok"] is True + assert payload["result"]["tempo"] == 123.0 + + +def test_song_info_contract_rejects_invalid_result_shape(runner, cli_app, monkeypatch) -> None: + from ableton_cli.commands import song + + class _ClientStub: + def song_info(self): # noqa: ANN201 + return {"tempo": "fast"} + + monkeypatch.setattr(song, "get_client", lambda _ctx: _ClientStub()) + + result = runner.invoke(cli_app, ["--output", "json", "song", "info"]) + + assert result.exit_code == 13 + payload = json.loads(result.stdout) + assert payload["ok"] is False + assert payload["error"]["code"] == "PROTOCOL_INVALID_RESPONSE" + + +def test_tracks_list_contract_rejects_non_array_tracks(runner, cli_app, monkeypatch) -> None: + from ableton_cli.commands import tracks + + class _ClientStub: + def tracks_list(self): # noqa: ANN201 + return {"tracks": "not-array"} + + monkeypatch.setattr(tracks, "get_client", lambda _ctx: _ClientStub()) + + result = runner.invoke(cli_app, ["--output", "json", "tracks", "list"]) + + assert result.exit_code == 13 + payload = json.loads(result.stdout) + assert payload["ok"] is False + assert payload["error"]["code"] == "PROTOCOL_INVALID_RESPONSE" + + +def test_doctor_contract_rejects_invalid_summary_structure(runner, cli_app, monkeypatch) -> None: + from ableton_cli.commands import setup + + monkeypatch.setattr( + setup, + "run_doctor", + lambda _settings, *, platform_paths: { # noqa: ANN001 + "summary": {"pass": "bad", "warn": 0, "fail": 0}, + "checks": [], + }, + ) + + result = runner.invoke(cli_app, ["--output", "json", "doctor"]) + + assert result.exit_code == 13 + payload = json.loads(result.stdout) + assert payload["ok"] is False + assert payload["error"]["code"] == "PROTOCOL_INVALID_RESPONSE" diff --git a/tests/test_dev_checks.py b/tests/test_dev_checks.py index 4f288e0..05c9d09 100644 --- a/tests/test_dev_checks.py +++ b/tests/test_dev_checks.py @@ -25,7 +25,7 @@ def _run(command: tuple[str, ...], check: bool) -> SimpleNamespace: # noqa: ANN def test_run_default_checks_runs_all_commands_and_returns_failure(monkeypatch) -> None: commands: list[tuple[str, ...]] = [] - exits = [0, 1, 0] + exits = [0, 1, 0, 0] def _run(command: tuple[str, ...], check: bool) -> SimpleNamespace: # noqa: ANN202 assert check is False @@ -61,7 +61,7 @@ def _run(command: tuple[str, ...], check: bool) -> SimpleNamespace: # noqa: ANN def test_main_writes_report_json(monkeypatch, tmp_path: Path) -> None: - exits = [0, 0, 1] + exits = [0, 0, 0, 1] def _run(command: tuple[str, ...], check: bool) -> SimpleNamespace: # noqa: ANN202 assert check is False @@ -85,8 +85,8 @@ def _run(command: tuple[str, ...], check: bool) -> SimpleNamespace: # noqa: ANN assert report["schema_version"] == 1 assert report["status"] == "fail" assert report["exit_code"] == 1 - assert len(report["commands"]) == 3 - assert report["commands"][2]["command"] == [ + assert len(report["commands"]) == 4 + assert report["commands"][3]["command"] == [ "uv", "run", "pytest", diff --git a/tests/test_exit_codes.py b/tests/test_exit_codes.py index 8a33c4c..fb8a8b2 100644 --- a/tests/test_exit_codes.py +++ b/tests/test_exit_codes.py @@ -30,5 +30,9 @@ def test_remote_error_to_exit_code_mapping() -> None: assert exit_code_from_error_code("TIMEOUT") == ExitCode.TIMEOUT assert exit_code_from_error_code("BATCH_STEP_FAILED") == ExitCode.EXECUTION_FAILED assert exit_code_from_error_code("REMOTE_BUSY") == ExitCode.EXECUTION_FAILED + assert exit_code_from_error_code("READ_ONLY_VIOLATION") == ExitCode.EXECUTION_FAILED + assert exit_code_from_error_code("BATCH_PREFLIGHT_FAILED") == ExitCode.EXECUTION_FAILED + assert exit_code_from_error_code("BATCH_ASSERT_FAILED") == ExitCode.EXECUTION_FAILED + assert exit_code_from_error_code("BATCH_RETRY_EXHAUSTED") == ExitCode.EXECUTION_FAILED assert exit_code_from_error_code("INTERNAL_ERROR") == ExitCode.INTERNAL_ERROR assert exit_code_from_error_code("UNKNOWN") == ExitCode.EXECUTION_FAILED diff --git a/tests/test_quality_harness_config.py b/tests/test_quality_harness_config.py index 42c975f..18d1fb1 100644 --- a/tests/test_quality_harness_config.py +++ b/tests/test_quality_harness_config.py @@ -13,3 +13,11 @@ def test_commands_layer_includes_app_factory() -> None: commands_layer = next(rule for rule in config.layers.order if rule.name == "commands") assert "src/ableton_cli/app_factory.py" in commands_layer.include + + +def test_function_args_threshold_allows_cli_entry_points() -> None: + config = load_config(QUALITY_HARNESS_CONFIG) + + function_args_threshold = config.thresholds.function["args"] + assert function_args_threshold.warn == 8 + assert function_args_threshold.fail == 17 diff --git a/tests/test_runtime_quiet.py b/tests/test_runtime_quiet.py index a8a2cdd..f253a3f 100644 --- a/tests/test_runtime_quiet.py +++ b/tests/test_runtime_quiet.py @@ -36,7 +36,7 @@ def test_execute_command_quiet_suppresses_custom_human_formatter(monkeypatch) -> _context(quiet=True), command="doctor", args={}, - action=lambda: {"summary": {"pass": 1, "warn": 0, "fail": 0}}, + action=lambda: {"summary": {"pass": 1, "warn": 0, "fail": 0}, "checks": []}, human_formatter=lambda _: "Doctor Results", ) @@ -58,7 +58,7 @@ def test_execute_command_not_quiet_emits_custom_human_formatter(monkeypatch) -> _context(quiet=False), command="doctor", args={}, - action=lambda: {"summary": {"pass": 1, "warn": 0, "fail": 0}}, + action=lambda: {"summary": {"pass": 1, "warn": 0, "fail": 0}, "checks": []}, human_formatter=lambda _: "Doctor Results", ) diff --git a/tests/test_session_diff_compact.py b/tests/test_session_diff_compact.py new file mode 100644 index 0000000..bf314b4 --- /dev/null +++ b/tests/test_session_diff_compact.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +import json + + +def test_session_diff_reports_added_removed_and_changed(runner, cli_app, tmp_path) -> None: + from_path = tmp_path / "from.json" + to_path = tmp_path / "to.json" + from_path.write_text( + json.dumps( + { + "song_info": {"tempo": 120.0, "is_playing": False}, + "tracks_list": {"tracks": [{"index": 0, "name": "Kick"}]}, + } + ), + encoding="utf-8", + ) + to_path.write_text( + json.dumps( + { + "song_info": {"tempo": 128.0, "is_playing": False}, + "tracks_list": { + "tracks": [{"index": 0, "name": "Kick"}, {"index": 1, "name": "Bass"}] + }, + "scenes_list": {"scenes": [{"index": 0, "name": "Intro"}]}, + } + ), + encoding="utf-8", + ) + + result = runner.invoke( + cli_app, + [ + "--output", + "json", + "session", + "diff", + "--from", + str(from_path), + "--to", + str(to_path), + ], + ) + + assert result.exit_code == 0 + payload = json.loads(result.stdout) + assert payload["ok"] is True + assert payload["result"]["changed"]["song_info"]["tempo"]["from"] == 120.0 + assert payload["result"]["changed"]["song_info"]["tempo"]["to"] == 128.0 + assert payload["result"]["added"]["scenes_list"]["scenes"][0]["name"] == "Intro" + assert payload["result"]["changed"]["tracks_list"]["tracks"]["to"][1]["name"] == "Bass" + + +def test_session_diff_output_is_stable(runner, cli_app, tmp_path) -> None: + from_path = tmp_path / "from.json" + to_path = tmp_path / "to.json" + from_path.write_text(json.dumps({"song_info": {"tempo": 120.0}}), encoding="utf-8") + to_path.write_text(json.dumps({"song_info": {"tempo": 121.0}}), encoding="utf-8") + + first = runner.invoke( + cli_app, + [ + "--output", + "json", + "session", + "diff", + "--from", + str(from_path), + "--to", + str(to_path), + ], + ) + second = runner.invoke( + cli_app, + [ + "--output", + "json", + "session", + "diff", + "--from", + str(from_path), + "--to", + str(to_path), + ], + ) + + assert first.exit_code == 0 + assert second.exit_code == 0 + assert json.loads(first.stdout)["result"] == json.loads(second.stdout)["result"] + + +def test_compact_reduces_large_json_arrays(runner, cli_app, monkeypatch) -> None: + from ableton_cli.commands import tracks + + class _ClientStub: + def tracks_list(self): # noqa: ANN201 + return {"tracks": [{"index": index, "name": f"Track {index}"} for index in range(40)]} + + monkeypatch.setattr(tracks, "get_client", lambda _ctx: _ClientStub()) + + result = runner.invoke( + cli_app, + ["--output", "json", "--compact", "tracks", "list"], + ) + + assert result.exit_code == 0 + payload = json.loads(result.stdout) + assert payload["ok"] is True + compacted_tracks = payload["result"]["tracks"] + assert isinstance(compacted_tracks, dict) + assert "_compact_ref" in compacted_tracks + assert "_compact_summary" in compacted_tracks + ref = compacted_tracks["_compact_ref"] + assert payload["compact_refs"][ref]["count"] == 40 + + +def test_compact_keeps_small_arrays_unmodified(runner, cli_app, monkeypatch) -> None: + from ableton_cli.commands import tracks + + class _ClientStub: + def tracks_list(self): # noqa: ANN201 + return {"tracks": [{"index": 0, "name": "Only"}]} + + monkeypatch.setattr(tracks, "get_client", lambda _ctx: _ClientStub()) + + result = runner.invoke( + cli_app, + ["--output", "json", "--compact", "tracks", "list"], + ) + + assert result.exit_code == 0 + payload = json.loads(result.stdout) + assert payload["ok"] is True + assert isinstance(payload["result"]["tracks"], list) + assert "compact_refs" not in payload diff --git a/tests/test_skill_docs.py b/tests/test_skill_docs.py index c815061..e91df25 100644 --- a/tests/test_skill_docs.py +++ b/tests/test_skill_docs.py @@ -1,94 +1,22 @@ from __future__ import annotations import re +import subprocess from pathlib import Path from typer.main import get_command +from ableton_cli.actions import ( + STABLE_ACTION_MAPPINGS, + stable_action_command_map, + stable_action_names, +) from ableton_cli.cli import app REPO_ROOT = Path(__file__).resolve().parents[1] SKILL_DOC = REPO_ROOT / "skills" / "ableton-cli" / "SKILL.md" ACTIONS_DOC = REPO_ROOT / "docs" / "skills" / "skill-actions.md" -STABLE_ACTIONS = ( - "ping", - "get_song_info", - "song_new", - "song_save", - "song_export_audio", - "get_session_info", - "get_track_info", - "play", - "stop", - "arrangement_record_start", - "arrangement_record_stop", - "set_tempo", - "transport_position_get", - "transport_position_set", - "transport_rewind", - "list_tracks", - "create_midi_track", - "create_audio_track", - "tracks_delete", - "set_track_name", - "set_track_volume", - "get_track_mute", - "set_track_mute", - "get_track_solo", - "set_track_solo", - "get_track_arm", - "set_track_arm", - "get_track_panning", - "set_track_panning", - "create_clip", - "add_notes_to_clip", - "get_clip_notes", - "clear_clip_notes", - "replace_clip_notes", - "arrangement_clip_notes_add", - "arrangement_clip_notes_get", - "arrangement_clip_notes_clear", - "arrangement_clip_notes_replace", - "arrangement_clip_notes_import_browser", - "arrangement_clip_delete", - "arrangement_from_session", - "clip_duplicate", - "set_clip_name", - "fire_clip", - "stop_clip", - "list_scenes", - "create_scene", - "set_scene_name", - "fire_scene", - "scenes_move", - "stop_all_clips", - "get_browser_tree", - "get_browser_items_at_path", - "get_browser_item", - "get_browser_categories", - "get_browser_items", - "search_browser_items", - "load_instrument_or_effect", - "load_drum_kit", - "set_device_parameter", - "find_synth_devices", - "list_synth_parameters", - "set_synth_parameter_safe", - "observe_synth_parameters", - "list_standard_synth_keys", - "set_standard_synth_parameter_safe", - "observe_standard_synth_state", - "find_effect_devices", - "list_effect_parameters", - "set_effect_parameter_safe", - "observe_effect_parameters", - "list_standard_effect_keys", - "set_standard_effect_parameter_safe", - "observe_standard_effect_state", - "execute_batch", -) - def _read(path: Path) -> str: return path.read_text(encoding="utf-8") @@ -150,17 +78,22 @@ def test_skill_doc_frontmatter_is_minimal() -> None: def test_stable_action_names_are_complete_and_unique() -> None: - assert len(STABLE_ACTIONS) == 75 - assert len(set(STABLE_ACTIONS)) == 75 + names = stable_action_names() + assert len(STABLE_ACTION_MAPPINGS) == 75 + assert len(names) == 75 + assert len(set(names)) == 75 def test_action_mappings_are_consistent_between_docs() -> None: skill_doc_mapping = _extract_skill_doc_mapping(_read(SKILL_DOC)) action_doc_mapping = _extract_action_doc_mapping(_read(ACTIONS_DOC)) + expected_action_names = set(stable_action_names()) + expected_mapping = stable_action_command_map() - assert set(skill_doc_mapping) == set(STABLE_ACTIONS) - assert set(action_doc_mapping) == set(STABLE_ACTIONS) - assert skill_doc_mapping == action_doc_mapping + assert set(skill_doc_mapping) == expected_action_names + assert set(action_doc_mapping) == expected_action_names + assert skill_doc_mapping == expected_mapping + assert action_doc_mapping == expected_mapping for command in skill_doc_mapping.values(): assert command.startswith("uv run ableton-cli ") @@ -172,3 +105,14 @@ def test_skill_doc_covers_all_leaf_cli_commands() -> None: assert re.search(pattern, markdown, flags=re.MULTILINE), ( f"missing command documentation for: {command}" ) + + +def test_generated_skill_docs_are_up_to_date() -> None: + result = subprocess.run( + ("uv", "run", "python", "tools/generate_skill_docs.py", "--check"), + cwd=REPO_ROOT, + check=False, + capture_output=True, + text=True, + ) + assert result.returncode == 0, result.stdout + result.stderr diff --git a/tests/test_transport_record_replay.py b/tests/test_transport_record_replay.py new file mode 100644 index 0000000..dc303c6 --- /dev/null +++ b/tests/test_transport_record_replay.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +import pytest + +from ableton_cli.client.transport import RecordingTransport, ReplayTransport +from ableton_cli.errors import AppError, ExitCode + + +def _write_jsonl(path: Path, entries: list[dict[str, Any]]) -> None: + lines = [json.dumps(entry, ensure_ascii=False) for entry in entries] + path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + +def test_replay_transport_returns_recorded_response_for_matching_name_args(tmp_path: Path) -> None: + replay_path = tmp_path / "replay.jsonl" + _write_jsonl( + replay_path, + [ + { + "request": {"name": "song_info", "args": {}}, + "response": { + "ok": True, + "request_id": "recorded-id", + "protocol_version": 2, + "result": {"tempo": 120.0}, + "error": None, + }, + } + ], + ) + + transport = ReplayTransport(path=str(replay_path)) + payload = { + "name": "song_info", + "args": {}, + "request_id": "runtime-id", + "protocol_version": 2, + } + + response = transport.send(payload) + + assert response["ok"] is True + assert response["request_id"] == "runtime-id" + assert response["result"] == {"tempo": 120.0} + + +def test_replay_transport_rejects_unmatched_name_or_args(tmp_path: Path) -> None: + replay_path = tmp_path / "replay.jsonl" + _write_jsonl( + replay_path, + [ + { + "request": {"name": "song_info", "args": {}}, + "response": { + "ok": True, + "request_id": "recorded-id", + "protocol_version": 2, + "result": {"tempo": 120.0}, + "error": None, + }, + } + ], + ) + transport = ReplayTransport(path=str(replay_path)) + + with pytest.raises(AppError) as exc_info: + transport.send( + { + "name": "tracks_list", + "args": {}, + "request_id": "runtime-id", + "protocol_version": 2, + } + ) + + assert exc_info.value.error_code == "PROTOCOL_INVALID_RESPONSE" + assert exc_info.value.exit_code == ExitCode.PROTOCOL_MISMATCH + + +def test_recording_transport_writes_request_and_response_entries(tmp_path: Path) -> None: + record_path = tmp_path / "record.jsonl" + + class _InnerTransport: + def send(self, payload: dict[str, Any]) -> dict[str, Any]: + return {"ok": True, "echo": payload} + + transport = RecordingTransport(inner=_InnerTransport(), path=str(record_path)) + request_payload = {"name": "ping", "args": {}, "request_id": "r-1", "protocol_version": 2} + + response = transport.send(request_payload) + + assert response["ok"] is True + lines = record_path.read_text(encoding="utf-8").strip().splitlines() + assert len(lines) == 1 + entry = json.loads(lines[0]) + assert entry["request"] == request_payload + assert entry["response"]["ok"] is True + assert entry["error"] is None + + +def test_recording_transport_writes_error_entries(tmp_path: Path) -> None: + record_path = tmp_path / "record.jsonl" + + class _InnerTransport: + def send(self, payload: dict[str, Any]) -> dict[str, Any]: + del payload + raise AppError( + error_code="TIMEOUT", + message="timed out", + hint="retry", + exit_code=ExitCode.TIMEOUT, + ) + + transport = RecordingTransport(inner=_InnerTransport(), path=str(record_path)) + request_payload = {"name": "ping", "args": {}, "request_id": "r-1", "protocol_version": 2} + + with pytest.raises(AppError) as exc_info: + transport.send(request_payload) + assert exc_info.value.error_code == "TIMEOUT" + + lines = record_path.read_text(encoding="utf-8").strip().splitlines() + assert len(lines) == 1 + entry = json.loads(lines[0]) + assert entry["request"] == request_payload + assert entry["response"] is None + assert entry["error"]["error_code"] == "TIMEOUT" diff --git a/tools/generate_skill_docs.py b/tools/generate_skill_docs.py new file mode 100644 index 0000000..10611d3 --- /dev/null +++ b/tools/generate_skill_docs.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import argparse +from dataclasses import dataclass +from pathlib import Path + +from ableton_cli.actions import STABLE_ACTION_MAPPINGS + +REPO_ROOT = Path(__file__).resolve().parents[1] +SKILL_DOC = REPO_ROOT / "skills" / "ableton-cli" / "SKILL.md" +ACTIONS_DOC = REPO_ROOT / "docs" / "skills" / "skill-actions.md" + +_SKILL_SECTION_START = "## Stable action names and mappings\n" +_SKILL_SECTION_END = "\n## Examples\n" +_ACTION_TABLE_START = "| Action | CLI command | Capability |\n| --- | --- | --- |\n" +_ACTION_TABLE_END = "\n## CLI-only commands (not stable actions)\n" + + +@dataclass(frozen=True, slots=True) +class GeneratedDocument: + path: Path + content: str + + +def _replace_between( + *, + text: str, + start_marker: str, + end_marker: str, + replacement: str, +) -> str: + start_index = text.find(start_marker) + if start_index == -1: + raise RuntimeError(f"start marker not found: {start_marker!r}") + end_index = text.find(end_marker, start_index) + if end_index == -1: + raise RuntimeError(f"end marker not found: {end_marker!r}") + return text[:start_index] + replacement + text[end_index:] + + +def _render_skill_mapping_section() -> str: + lines = ["## Stable action names and mappings", ""] + for mapping in STABLE_ACTION_MAPPINGS: + lines.append(f"- `{mapping.action}` -> `{mapping.command}`") + lines.append("") + return "\n".join(lines) + + +def _render_action_table() -> str: + lines = [ + "| Action | CLI command | Capability |", + "| --- | --- | --- |", + ] + for mapping in STABLE_ACTION_MAPPINGS: + lines.append(f"| `{mapping.action}` | `{mapping.command}` | {mapping.capability} |") + lines.append("") + return "\n".join(lines) + + +def generate_documents() -> tuple[GeneratedDocument, ...]: + skill_text = SKILL_DOC.read_text(encoding="utf-8") + actions_text = ACTIONS_DOC.read_text(encoding="utf-8") + + generated_skill = _replace_between( + text=skill_text, + start_marker=_SKILL_SECTION_START, + end_marker=_SKILL_SECTION_END, + replacement=_render_skill_mapping_section(), + ) + generated_actions = _replace_between( + text=actions_text, + start_marker=_ACTION_TABLE_START, + end_marker=_ACTION_TABLE_END, + replacement=_render_action_table(), + ) + return ( + GeneratedDocument(path=SKILL_DOC, content=generated_skill), + GeneratedDocument(path=ACTIONS_DOC, content=generated_actions), + ) + + +def _write_documents(documents: tuple[GeneratedDocument, ...]) -> None: + for item in documents: + item.path.write_text(item.content, encoding="utf-8") + + +def _diff_paths(documents: tuple[GeneratedDocument, ...]) -> list[Path]: + changed: list[Path] = [] + for item in documents: + current = item.path.read_text(encoding="utf-8") + if current != item.content: + changed.append(item.path) + return changed + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Generate skill action docs from source mappings") + parser.add_argument( + "--check", + action="store_true", + help="Exit with code 1 when generated output differs from tracked files.", + ) + return parser + + +def main(argv: list[str] | None = None) -> int: + parser = build_parser() + args = parser.parse_args(argv) + generated = generate_documents() + changed = _diff_paths(generated) + + if args.check: + if changed: + for path in changed: + print(f"outdated generated file: {path}") + return 1 + return 0 + + _write_documents(generated) + for path in changed: + print(f"updated {path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main())