diff --git a/dreadnode_cli/agent/cli.py b/dreadnode_cli/agent/cli.py index b21bdeb..910162a 100644 --- a/dreadnode_cli/agent/cli.py +++ b/dreadnode_cli/agent/cli.py @@ -328,11 +328,18 @@ def deploy( print(formatted) return - with Live(formatted, refresh_per_second=2) as live: - while run.is_running(): - time.sleep(1) - run = client.get_strike_run(run.id) - live.update(format_run(run)) + try: + with Live(formatted, refresh_per_second=2) as live: + while run.is_running(): + time.sleep(1) + run = client.get_strike_run(run.id) + live.update(format_run(run)) + except KeyboardInterrupt: + print("\n:warning: Terminating run...") + client.terminate_strike_run(run.id) + run = client.get_strike_run(run.id) + print(format_run(run)) + return @cli.command(help="List available models for the current (or specified) strike") diff --git a/dreadnode_cli/api.py b/dreadnode_cli/api.py index 2fcefd5..fa8e8a9 100644 --- a/dreadnode_cli/api.py +++ b/dreadnode_cli/api.py @@ -448,6 +448,11 @@ def get_strike_run(self, run: UUID | str) -> StrikeRunResponse: response = self.request("GET", f"/api/strikes/runs/{run}") return self.StrikeRunResponse(**response.json()) + def terminate_strike_run(self, run: UUID | str) -> StrikeRunResponse: + """Terminate a running strike.""" + response = self.request("POST", f"/api/strikes/runs/{run}/terminate") + return self.StrikeRunResponse(**response.json()) + def list_strike_runs(self, *, strike_id: UUID | str | None = None) -> list[StrikeRunSummaryResponse]: response = self.request( "GET", "/api/strikes/runs", query_params={"strike_id": str(strike_id)} if strike_id else None