diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 9a82f18..e6cbbf9 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -19,7 +19,7 @@ jobs: uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.9' @@ -36,15 +36,15 @@ jobs: - name: Setup Pages if: github.ref == 'refs/heads/main' - uses: actions/configure-pages@v3 + uses: actions/configure-pages@v5 - name: Upload artifact if: github.ref == 'refs/heads/main' - uses: actions/upload-pages-artifact@v2 + uses: actions/upload-pages-artifact@v3 with: path: 'docs/_build/html' - name: Deploy to GitHub Pages if: github.ref == 'refs/heads/main' id: deployment - uses: actions/deploy-pages@v2 + uses: actions/deploy-pages@v4 diff --git a/segmind/__init__.py b/segmind/__init__.py index 5382169..c6c9bd6 100644 --- a/segmind/__init__.py +++ b/segmind/__init__.py @@ -6,11 +6,19 @@ Usage: import segmind - # Run a model + # Run a model (sync v1) response = segmind.run("seedream-v3-text-to-image", prompt="A sunset") with open("image.jpg", "wb") as f: f.write(response.content) + # Run a model (v2 async — submit + poll until done) + result = segmind.run_async("seedance-1-pro", prompt="A sunset", timeout=300) + + # Or split the submit / wait for finer control + job = segmind.submit_async("seedance-1-pro", prompt="A sunset") + print(job.request_id) + result = job.wait(timeout=300) + # Upload files result = segmind.files.upload("image.png") print(result["file_urls"]) @@ -22,6 +30,13 @@ from typing import Optional from segmind.client import SegmindClient +from segmind.v2 import ( + DEFAULT_POLL_INTERVAL_S, + DEFAULT_POLL_TIMEOUT_S, + AsyncJob, + InferenceFailed, + InferenceTimeout, +) __version__ = "1.0.0" @@ -38,7 +53,7 @@ def _get_client() -> SegmindClient: def run(slug: str, **params): - """Run a model inference request. + """Run a sync (v1) model inference request. Args: slug: Model slug/identifier @@ -56,6 +71,55 @@ def run(slug: str, **params): return _get_client().run(slug, **params) +def submit_async(slug: str, **params) -> AsyncJob: + """Submit a v2 async inference request and return a job handle. + + The handle exposes `.wait()`, `.status()`, and `.result()`. Use `.wait()` + to block until COMPLETED or FAILED. Useful when you want to track the + request_id, run other work in parallel, or batch many submissions. + + Args: + slug: Model slug/identifier. + **params: Parameters to pass to the model. + + Example: + import segmind + job = segmind.submit_async("seedance-1-pro", prompt="A sunset") + print(job.request_id) + result = job.wait(timeout=300) + """ + return _get_client().submit_async(slug, **params) + + +def run_async( + slug: str, + *, + timeout: float = DEFAULT_POLL_TIMEOUT_S, + interval: float = DEFAULT_POLL_INTERVAL_S, + **params, +) -> dict: + """Run a v2 async inference request to completion (submit + poll). + + Args: + slug: Model slug/identifier. + timeout: Hard deadline in seconds (default 600s). + interval: Status-poll cadence (default 1.0s). + **params: Parameters to pass to the model. + + Returns: + The final response body once the task reaches COMPLETED. + + Raises: + segmind.InferenceFailed: server returned FAILED. + segmind.InferenceTimeout: timeout elapsed before terminal state. + + Example: + import segmind + result = segmind.run_async("seedance-1-pro", prompt="A sunset", timeout=300) + """ + return _get_client().run_async(slug, timeout=timeout, interval=interval, **params) + + # Namespace proxies class _Files: def upload(self, file_paths): @@ -123,11 +187,18 @@ def recent(self, model_name): generations = _Generations() __all__ = [ + "DEFAULT_POLL_INTERVAL_S", + "DEFAULT_POLL_TIMEOUT_S", + "AsyncJob", + "InferenceFailed", + "InferenceTimeout", "SegmindClient", "files", "generations", "models", "pixelflows", "run", + "run_async", + "submit_async", "webhooks", ] diff --git a/segmind/client.py b/segmind/client.py index 2e84499..77a4ced 100644 --- a/segmind/client.py +++ b/segmind/client.py @@ -3,6 +3,7 @@ import httpx +from segmind import v2 as _v2 from segmind.accounts import Accounts from segmind.exceptions import raise_for_status from segmind.files import Files @@ -72,6 +73,45 @@ def run(self, slug: str, **params) -> httpx.Response: raise_for_status(response) return response + def submit_async(self, slug: str, **params) -> "_v2.AsyncJob": + """Submit a v2 async inference request and return a job handle. + + The handle exposes `wait()`, `status()`, and `result()`. Use `wait()` + to block until COMPLETED or FAILED. Default poll interval 1.0s, + timeout 600s — override per-call for very slow models. + + Args: + slug: Model slug/identifier. + **params: Parameters to pass to the model. + + Returns: + AsyncJob handle wrapping the submit response. + """ + return _v2.submit(self, slug, **params) + + def run_async( + self, + slug: str, + *, + timeout: float = _v2.DEFAULT_POLL_TIMEOUT_S, + interval: float = _v2.DEFAULT_POLL_INTERVAL_S, + **params, + ) -> dict: + """One-shot v2 async inference: submit + wait. Returns the final + response body (the same dict you'd get from polling to COMPLETED). + + Args: + slug: Model slug/identifier. + timeout: Hard deadline in seconds (default 600s). + interval: Status-poll cadence (default 1.0s). + **params: Parameters to pass to the model. + + Raises: + v2.InferenceFailed: server returned FAILED. + v2.InferenceTimeout: timeout elapsed before terminal state. + """ + return _v2.run(self, slug, timeout=timeout, interval=interval, **params) + def stream(self, slug: str, **params) -> httpx.Response: """Stream a model inference request (not implemented). diff --git a/segmind/v2.py b/segmind/v2.py new file mode 100644 index 0000000..ea849e3 --- /dev/null +++ b/segmind/v2.py @@ -0,0 +1,238 @@ +# ruff: noqa: N818 exception names are deliberately non-Error-suffixed for natural reading in user code +"""v2 async inference for the Segmind Python SDK. + +The v2 path is a two-step submit-then-poll: `POST /v2/{slug}` returns a +`request_id` immediately; the actual result lands in Redis once a worker +finishes. Clients poll `/v2/requests/{id}/status` until the task hits +`COMPLETED` or `FAILED`, then GET `/v2/requests/{id}` for the body. + +This module provides: + + client.submit_async(slug, **params) -> AsyncJob + client.run_async(slug, **params) -> dict # submit + wait + AsyncJob.wait(timeout, interval) -> dict # block to completion + +Defaults are 1.0s poll interval, 600s overall timeout. For slugs known to +be slow (long video, long-running LLM), pass a larger `timeout` and +`interval` per call. For fire-and-forget patterns, use webhooks. +""" + +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from segmind.exceptions import SegmindError, raise_for_status + +if TYPE_CHECKING: + from segmind.client import SegmindClient + + +DEFAULT_POLL_INTERVAL_S = 1.0 +DEFAULT_POLL_TIMEOUT_S = 600.0 + +# heimdall's /v2/requests/{id}/status returns HTTP 422 on FAILED while still +# carrying a status=FAILED body. We treat a body that announces a known +# terminal state as a valid payload regardless of the HTTP code. +_TERMINAL_STATES = ("COMPLETED", "FAILED") + + +class InferenceFailed(SegmindError): + """Raised when a v2 async request reaches the FAILED state. + + Server-provided error string is in `detail`. The status-endpoint body + is on `.status_body` for callers that want the raw payload; if you + need server-side metrics or a fuller failure record, call + `AsyncJob.result()` separately after catching. + """ + + status_body: dict[str, Any] + + def __init__(self, detail: str | None, status_body: dict[str, Any]) -> None: + super().__init__(status=None, detail=detail) + self.status_body = status_body + + +class InferenceTimeout(SegmindError): + """Raised when an `AsyncJob.wait()` exceeds its `timeout` before the + task reaches a terminal state. The job may still be running on the + server — re-fetch the status URL to confirm and recover the result. + """ + + request_id: str + elapsed_s: float + + def __init__(self, request_id: str, elapsed_s: float) -> None: + super().__init__( + status=None, + detail=f"v2 request {request_id!r} did not complete within {elapsed_s:.1f}s", + ) + self.request_id = request_id + self.elapsed_s = elapsed_s + + +@dataclass +class AsyncJob: + """Handle for a v2 async request that has been submitted but not yet completed. + + Returned by `SegmindClient.submit_async()`. Use `wait()` to block until + a terminal state, or poll `status()` manually if you need finer control. + """ + + request_id: str + status_url: str + response_url: str + submit_response: dict[str, Any] + _client: SegmindClient = field(repr=False) + + def status(self) -> dict[str, Any]: + """Fetch the current status payload without blocking. + + Returns the server's body for `GET /v2/requests/{id}/status`. The + `status` field is one of `QUEUED`, `PROCESSING`, `COMPLETED`, + `FAILED`. On `FAILED`, the body also includes `error`. + """ + return self._fetch_terminal_tolerant(self.status_url) + + def result(self) -> dict[str, Any]: + """Fetch the final response body. Only meaningful once status is + COMPLETED — for a FAILED task the body is also returned (heimdall + serves it under HTTP 422).""" + return self._fetch_terminal_tolerant(self.response_url) + + def _fetch_terminal_tolerant(self, url: str) -> dict[str, Any]: + """GET a v2 status / response URL, tolerating heimdall's 4xx-on-FAILED. + + heimdall returns the FAILED body under HTTP 422 on both `/status` and + `/requests/{id}`. The body itself still carries the terminal state + (`status="FAILED"`, plus `error`). Treat any body that announces a + recognised terminal state as a valid payload, regardless of HTTP code; + otherwise fall through to the existing `raise_for_status` so genuine + transport errors (401/404/5xx, missing body) still surface as + `SegmindError`. + """ + # Use the underlying httpx client directly so we can inspect the body + # before deciding whether the non-2xx is a transport error or a FAILED + # task body served with a 4xx code. + resp = self._client._client.request("GET", url) + try: + body = resp.json() + except ValueError: + body = {} + if isinstance(body, dict) and body.get("status") in _TERMINAL_STATES: + return body + if resp.is_success: + return body if isinstance(body, dict) else {} + raise_for_status(resp) + return body # unreachable; raise_for_status raised + + def wait( + self, + timeout: float = DEFAULT_POLL_TIMEOUT_S, + interval: float = DEFAULT_POLL_INTERVAL_S, + ) -> dict[str, Any]: + """Block until the task reaches a terminal state and return the result. + + Args: + timeout: Hard deadline in seconds. Raises `InferenceTimeout` if + exceeded. Default 600s. + interval: Sleep between status polls. Default 1.0s. + + Returns: + The server body from `GET /v2/requests/{id}` on COMPLETED. Shape + varies by model — every model carries `status`, `metrics`, and an + `output` key, but the rest of the response is model-specific + (e.g. image models include the image URL; mock-inference includes + `partial` / `reasoning`). Don't assume a fixed key set. + + Raises: + InferenceFailed: status reached FAILED. The server error string + is in `e.detail`; the raw status body in `e.status_body`. + InferenceTimeout: `timeout` elapsed before a terminal state. + """ + start = time.monotonic() + deadline = start + timeout + while True: + status_body = self.status() + state = status_body.get("status") + + if state == "COMPLETED": + return self.result() + + if state == "FAILED": + # /status already carries the error string for FAILED (heimdall + # SEG-97). Build the exception from the status body directly so + # we don't pay a second HTTP round-trip on every failure path. + raise InferenceFailed( + detail=status_body.get("error"), + status_body=status_body, + ) + + if time.monotonic() >= deadline: + raise InferenceTimeout( + request_id=self.request_id, + elapsed_s=time.monotonic() - start, + ) + + time.sleep(interval) + + +def submit(client: SegmindClient, slug: str, **params) -> AsyncJob: + """`POST /v2/{slug}`; return an AsyncJob handle for polling.""" + url = _v2_base(client) + "/" + slug.lstrip("/") + resp = client._request("POST", url, json=params) + body = resp.json() + + request_id = body.get("request_id") + status_url = body.get("status_url") + response_url = body.get("response_url") + if not (request_id and status_url and response_url): + # Server's contract is to always return these three on a successful + # submit. If we got a 2xx without them, something is genuinely off + # — fail loudly rather than poll forever on an unknown URL. + raise SegmindError( + status=resp.status_code, + detail=( + "v2 submit returned 2xx but is missing request_id / status_url / " + f"response_url; got keys={sorted(body.keys())}" + ), + ) + + return AsyncJob( + request_id=request_id, + status_url=status_url, + response_url=response_url, + submit_response=body, + _client=client, + ) + + +def run( + client: SegmindClient, + slug: str, + *, + timeout: float = DEFAULT_POLL_TIMEOUT_S, + interval: float = DEFAULT_POLL_INTERVAL_S, + **params, +) -> dict[str, Any]: + """One-shot convenience: submit and wait. Equivalent to + `client.submit_async(slug, **params).wait(timeout, interval)`. + """ + job = submit(client, slug, **params) + return job.wait(timeout=timeout, interval=interval) + + +def _v2_base(client: SegmindClient) -> str: + """Derive the v2 prefix from the client's `base_url`. + + The default client base_url is `https://api.segmind.com/v1`; the v2 + endpoint sits at `https://api.segmind.com/v2`. We strip the trailing + `/vN` segment and append `/v2` so callers who override base_url for + staging (`api-latest.segmind.com/v1`) keep working without extra + config. + """ + base = client.base_url.rstrip("/") + if "/" in base and base.rsplit("/", 1)[1].startswith("v"): + base = base.rsplit("/", 1)[0] + return base + "/v2" diff --git a/tests/test_v2.py b/tests/test_v2.py new file mode 100644 index 0000000..727c140 --- /dev/null +++ b/tests/test_v2.py @@ -0,0 +1,285 @@ +"""Unit tests for the v2 async helpers. + +All HTTP traffic is mocked with respx; no network required. +""" + +from __future__ import annotations + +import os +from unittest.mock import patch + +import httpx +import pytest +import respx + +import segmind +from segmind import AsyncJob, InferenceFailed, InferenceTimeout, SegmindClient + +API_HOST = "https://api.segmind.com" +SLUG = "mock-inference" +REQ_ID = "11111111-2222-3333-4444-555555555555" + +SUBMIT_BODY = { + "request_id": REQ_ID, + "status": "QUEUED", + "poll_url": f"{API_HOST}/v1/requests/{REQ_ID}", + "response_url": f"{API_HOST}/v2/requests/{REQ_ID}", + "status_url": f"{API_HOST}/v2/requests/{REQ_ID}/status", +} +RESULT_BODY = { + "status": "COMPLETED", + "error": None, + "metrics": {"inference_time": 0.5}, + "output": "ok", +} + + +@pytest.fixture +def client(): + """A SegmindClient with a fake API key so the test doesn't depend on env.""" + return SegmindClient(api_key="sk-test", base_url=f"{API_HOST}/v1") + + +# ---- submit ---------------------------------------------------------------- + + +@respx.mock +def test_submit_async_returns_job_with_urls(client): + respx.post(f"{API_HOST}/v2/{SLUG}").mock(return_value=httpx.Response(200, json=SUBMIT_BODY)) + + job = client.submit_async(SLUG, sleep=1, credits=1e-6) + + assert isinstance(job, AsyncJob) + assert job.request_id == REQ_ID + assert job.status_url.endswith(f"/v2/requests/{REQ_ID}/status") + assert job.response_url.endswith(f"/v2/requests/{REQ_ID}") + assert job.submit_response["status"] == "QUEUED" + + +@respx.mock +def test_submit_async_propagates_4xx_as_segmind_error(client): + respx.post(f"{API_HOST}/v2/{SLUG}").mock( + return_value=httpx.Response(401, json={"error": "Invalid API key"}) + ) + + with pytest.raises(segmind.SegmindError if hasattr(segmind, "SegmindError") else Exception): + client.submit_async(SLUG) + + +@respx.mock +def test_submit_raises_when_response_is_missing_request_id(client): + """If the server's 2xx body lacks request_id, we must fail loudly rather + than swallow it and poll forever on a missing URL.""" + respx.post(f"{API_HOST}/v2/{SLUG}").mock( + return_value=httpx.Response(200, json={"status": "QUEUED"}), + ) + + from segmind.exceptions import SegmindError + + with pytest.raises(SegmindError) as exc: + client.submit_async(SLUG) + + assert "missing request_id" in str(exc.value).lower() + + +# ---- wait ------------------------------------------------------------------ + + +@respx.mock +def test_wait_returns_result_on_completed(client): + respx.post(f"{API_HOST}/v2/{SLUG}").mock(return_value=httpx.Response(200, json=SUBMIT_BODY)) + respx.get(f"{API_HOST}/v2/requests/{REQ_ID}/status").mock( + return_value=httpx.Response(200, json={"status": "COMPLETED"}) + ) + respx.get(f"{API_HOST}/v2/requests/{REQ_ID}").mock( + return_value=httpx.Response(200, json=RESULT_BODY) + ) + + job = client.submit_async(SLUG) + out = job.wait(timeout=5, interval=0.01) + + assert out == RESULT_BODY + + +@respx.mock +def test_wait_polls_until_completed(client): + respx.post(f"{API_HOST}/v2/{SLUG}").mock(return_value=httpx.Response(200, json=SUBMIT_BODY)) + # First two polls report QUEUED then PROCESSING, third reports COMPLETED. + respx.get(f"{API_HOST}/v2/requests/{REQ_ID}/status").mock( + side_effect=[ + httpx.Response(200, json={"status": "QUEUED"}), + httpx.Response(200, json={"status": "PROCESSING"}), + httpx.Response(200, json={"status": "COMPLETED"}), + ] + ) + respx.get(f"{API_HOST}/v2/requests/{REQ_ID}").mock( + return_value=httpx.Response(200, json=RESULT_BODY) + ) + + out = client.submit_async(SLUG).wait(timeout=5, interval=0.01) + + assert out == RESULT_BODY + + +@respx.mock +def test_wait_raises_inference_failed_on_failed(client): + respx.post(f"{API_HOST}/v2/{SLUG}").mock(return_value=httpx.Response(200, json=SUBMIT_BODY)) + respx.get(f"{API_HOST}/v2/requests/{REQ_ID}/status").mock( + return_value=httpx.Response(200, json={"status": "FAILED", "error": "boom"}) + ) + # The response-URL mock must NOT be hit on FAILED — the exception is + # built from the status body alone. If wait() ever regresses to fetching + # the result on failure, this mock will fire and `result_route.called` + # below will flip True. + result_route = respx.get(f"{API_HOST}/v2/requests/{REQ_ID}").mock( + return_value=httpx.Response(200, json={"_unused": True}) + ) + + job = client.submit_async(SLUG) + with pytest.raises(InferenceFailed) as exc: + job.wait(timeout=5, interval=0.01) + + assert exc.value.detail == "boom" + assert exc.value.status_body["status"] == "FAILED" + assert result_route.called is False, "FAILED path should not GET the response URL" + + +@respx.mock +def test_wait_raises_inference_failed_on_4xx_terminal_body(client): + """Regression test for the heimdall behaviour reported on SEG-52: + /v2/requests/{id}/status returns HTTP 422 on FAILED while still carrying + a `status=FAILED` body. We must surface this as `InferenceFailed`, not + as a `SegmindError(422)` swallowing the failure detail. + """ + respx.post(f"{API_HOST}/v2/{SLUG}").mock(return_value=httpx.Response(200, json=SUBMIT_BODY)) + respx.get(f"{API_HOST}/v2/requests/{REQ_ID}/status").mock( + return_value=httpx.Response( + 422, + json={ + "status": "FAILED", + "error": "Validation error in MockInferenceProcessor: sleep must be …", + "metrics": {"inference_time": 0.008}, + "request_id": REQ_ID, + }, + ), + ) + + job = client.submit_async(SLUG) + with pytest.raises(InferenceFailed) as exc: + job.wait(timeout=5, interval=0.01) + + assert "Validation error" in (exc.value.detail or "") + assert exc.value.status_body["status"] == "FAILED" + + +@respx.mock +def test_wait_propagates_genuine_transport_error(client): + """A non-terminal 4xx (auth failure, missing resource, …) without a + `status=FAILED` body must still raise `SegmindError` — we don't want the + 422-tolerance above to swallow real transport failures.""" + respx.post(f"{API_HOST}/v2/{SLUG}").mock(return_value=httpx.Response(200, json=SUBMIT_BODY)) + respx.get(f"{API_HOST}/v2/requests/{REQ_ID}/status").mock( + return_value=httpx.Response(401, json={"error": "Invalid API key"}) + ) + + from segmind.exceptions import SegmindError + + job = client.submit_async(SLUG) + with pytest.raises(SegmindError) as exc: + job.wait(timeout=5, interval=0.01) + + assert exc.value.status == 401 + # And specifically NOT InferenceFailed — that's reserved for terminal-FAILED + # bodies. + assert not isinstance(exc.value, InferenceFailed) + + +@respx.mock +def test_wait_raises_inference_timeout(client): + respx.post(f"{API_HOST}/v2/{SLUG}").mock(return_value=httpx.Response(200, json=SUBMIT_BODY)) + respx.get(f"{API_HOST}/v2/requests/{REQ_ID}/status").mock( + return_value=httpx.Response(200, json={"status": "PROCESSING"}), + ) + + job = client.submit_async(SLUG) + with pytest.raises(InferenceTimeout) as exc: + job.wait(timeout=0.05, interval=0.01) + + assert exc.value.request_id == REQ_ID + assert exc.value.elapsed_s == pytest.approx(0.05, rel=0.5) + + +# ---- run_async one-shot ---------------------------------------------------- + + +@respx.mock +def test_run_async_one_shot(client): + respx.post(f"{API_HOST}/v2/{SLUG}").mock(return_value=httpx.Response(200, json=SUBMIT_BODY)) + respx.get(f"{API_HOST}/v2/requests/{REQ_ID}/status").mock( + return_value=httpx.Response(200, json={"status": "COMPLETED"}) + ) + respx.get(f"{API_HOST}/v2/requests/{REQ_ID}").mock( + return_value=httpx.Response(200, json=RESULT_BODY) + ) + + out = client.run_async(SLUG, sleep=1, credits=1e-6, timeout=5, interval=0.01) + + assert out == RESULT_BODY + + +# ---- staging base_url derivation ------------------------------------------- + + +@respx.mock +def test_v2_url_derives_from_staging_base(): + """If the caller overrides base_url for staging, v2 derives correctly.""" + staging_host = "https://api-latest.segmind.com" + client = SegmindClient(api_key="sk-test", base_url=f"{staging_host}/v1") + + respx.post(f"{staging_host}/v2/{SLUG}").mock( + return_value=httpx.Response( + 200, + json={ + **SUBMIT_BODY, + "status_url": f"{staging_host}/v2/requests/{REQ_ID}/status", + "response_url": f"{staging_host}/v2/requests/{REQ_ID}", + }, + ), + ) + + job = client.submit_async(SLUG) + assert staging_host in job.status_url + + +# ---- module-level helpers -------------------------------------------------- + + +@respx.mock +def test_module_level_run_async_uses_default_client(): + respx.post(f"{API_HOST}/v2/{SLUG}").mock(return_value=httpx.Response(200, json=SUBMIT_BODY)) + respx.get(f"{API_HOST}/v2/requests/{REQ_ID}/status").mock( + return_value=httpx.Response(200, json={"status": "COMPLETED"}) + ) + respx.get(f"{API_HOST}/v2/requests/{REQ_ID}").mock( + return_value=httpx.Response(200, json=RESULT_BODY) + ) + + with patch.dict(os.environ, {"SEGMIND_API_KEY": "sk-test"}): + # Reset the cached default client so it picks up the env var. + segmind._default_client = None + out = segmind.run_async(SLUG, sleep=1, timeout=5, interval=0.01) + + assert out == RESULT_BODY + + +# ---- module exports -------------------------------------------------------- + + +def test_module_exports_v2_symbols(): + assert hasattr(segmind, "submit_async") + assert hasattr(segmind, "run_async") + assert hasattr(segmind, "AsyncJob") + assert hasattr(segmind, "InferenceFailed") + assert hasattr(segmind, "InferenceTimeout") + assert "submit_async" in segmind.__all__ + assert "run_async" in segmind.__all__