Skip to content

Commit dc0cbc5

Browse files
authored
Merge pull request #1 from WaveSpeedAI/feat/waveless_env
fix: separate URL resolution for runpod and waverless environments
2 parents 83121bf + e6ff395 commit dc0cbc5

File tree

3 files changed

+131
-47
lines changed

3 files changed

+131
-47
lines changed

src/wavespeed/config.py

Lines changed: 86 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import os
44
import sys
5+
import uuid
56
from typing import Optional
67

78
from ._config_module import install_config_module
@@ -83,22 +84,41 @@ def _detect_serverless_env() -> Optional[str]:
8384
The serverless environment type ("runpod", "waverless") or None
8485
if not running in a known serverless environment.
8586
"""
86-
# Check for RunPod environment
87-
if os.environ.get("RUNPOD_POD_ID"):
88-
return "runpod"
8987

9088
# Check for native Waverless environment
91-
if os.environ.get("WAVERLESS_POD_ID"):
89+
if os.environ.get("WAVERLESS_ENDPOINT_ID"):
9290
return "waverless"
9391

92+
# Check for RunPod environment
93+
if os.environ.get("RUNPOD_ENDPOINT_ID"):
94+
return "runpod"
95+
9496
return None
9597

9698

97-
def _resolve_url(url_template: Optional[str], pod_id: str) -> Optional[str]:
98-
"""Replace pod ID placeholder in URL template.
99+
def _generate_pod_id(endpoint_id: Optional[str], raw_pod_id: Optional[str]) -> str:
100+
"""Generate or resolve pod_id.
99101
100-
Note: Only $RUNPOD_POD_ID is replaced here. The $ID placeholder is
101-
replaced later at runtime with the actual job ID in http._handle_result.
102+
Priority: raw_pod_id > DEVICE_ID > auto-generate
103+
104+
Args:
105+
endpoint_id: The endpoint identifier.
106+
raw_pod_id: The raw pod_id from environment variable.
107+
108+
Returns:
109+
The resolved pod_id.
110+
"""
111+
if raw_pod_id:
112+
return raw_pod_id
113+
device_id = os.environ.get("DEVICE_ID")
114+
if device_id:
115+
return device_id
116+
prefix = endpoint_id or "worker"
117+
return f"{prefix}-{uuid.uuid4().hex}"
118+
119+
120+
def _resolve_runpod_url(url_template: Optional[str], pod_id: str) -> Optional[str]:
121+
"""Replace pod ID placeholder in RunPod URL template.
102122
103123
Args:
104124
url_template: URL template with $RUNPOD_POD_ID placeholder.
@@ -112,26 +132,55 @@ def _resolve_url(url_template: Optional[str], pod_id: str) -> Optional[str]:
112132
return url_template.replace("$RUNPOD_POD_ID", pod_id)
113133

114134

135+
def _resolve_waverless_url(url_template: Optional[str], pod_id: str) -> Optional[str]:
136+
"""Replace pod ID placeholder in Waverless URL template.
137+
138+
Args:
139+
url_template: URL template with $WAVERLESS_POD_ID placeholder.
140+
pod_id: The worker/pod ID to substitute.
141+
142+
Returns:
143+
URL with $WAVERLESS_POD_ID placeholder replaced, or None if template is None.
144+
"""
145+
if not url_template:
146+
return None
147+
return url_template.replace("$WAVERLESS_POD_ID", pod_id)
148+
149+
115150
def _load_runpod_serverless_config() -> None:
116151
"""Load RunPod environment variables into serverless config."""
152+
# Endpoint identification (load first for pod_id generation)
153+
serverless.endpoint_id = os.environ.get("RUNPOD_ENDPOINT_ID")
154+
serverless.project_id = os.environ.get("RUNPOD_PROJECT_ID")
155+
117156
# Worker identification
118-
serverless.pod_id = os.environ.get("RUNPOD_POD_ID") or ""
157+
raw_pod_id = os.environ.get("RUNPOD_POD_ID")
158+
serverless.pod_id = _generate_pod_id(serverless.endpoint_id, raw_pod_id)
159+
serverless.pod_hostname = os.environ.get("RUNPOD_POD_HOSTNAME", serverless.pod_id)
119160

120161
# API endpoint templates
121162
serverless.webhook_get_job = os.environ.get("RUNPOD_WEBHOOK_GET_JOB")
122163
serverless.webhook_post_output = os.environ.get("RUNPOD_WEBHOOK_POST_OUTPUT")
123164
serverless.webhook_post_stream = os.environ.get("RUNPOD_WEBHOOK_POST_STREAM")
124165
serverless.webhook_ping = os.environ.get("RUNPOD_WEBHOOK_PING")
125166

126-
# Resolved API endpoints (with pod_id substituted)
127-
serverless.job_get_url = _resolve_url(serverless.webhook_get_job, serverless.pod_id)
128-
serverless.job_done_url = _resolve_url(
167+
# Resolved API endpoints (with $RUNPOD_POD_ID substituted)
168+
job_get_url = _resolve_runpod_url(serverless.webhook_get_job, serverless.pod_id)
169+
# job_get_url also needs $ID replaced with worker ID (like runpod-python)
170+
if job_get_url:
171+
job_get_url = job_get_url.replace("$ID", serverless.pod_id)
172+
serverless.job_get_url = job_get_url
173+
174+
# job_done_url keeps $ID for runtime replacement with job_id
175+
serverless.job_done_url = _resolve_runpod_url(
129176
serverless.webhook_post_output, serverless.pod_id
130177
)
131-
serverless.job_stream_url = _resolve_url(
178+
serverless.job_stream_url = _resolve_runpod_url(
132179
serverless.webhook_post_stream, serverless.pod_id
133180
)
134-
serverless.ping_url = _resolve_url(serverless.webhook_ping, serverless.pod_id)
181+
serverless.ping_url = _resolve_runpod_url(
182+
serverless.webhook_ping, serverless.pod_id
183+
)
135184

136185
# Authentication
137186
serverless.api_key = os.environ.get("RUNPOD_AI_API_KEY")
@@ -142,11 +191,6 @@ def _load_runpod_serverless_config() -> None:
142191
log_level = os.environ.get("RUNPOD_DEBUG_LEVEL")
143192
serverless.log_level = log_level or "INFO"
144193

145-
# Endpoint identification
146-
serverless.endpoint_id = os.environ.get("RUNPOD_ENDPOINT_ID")
147-
serverless.project_id = os.environ.get("RUNPOD_PROJECT_ID")
148-
serverless.pod_hostname = os.environ.get("RUNPOD_POD_HOSTNAME")
149-
150194
# Timing and concurrency
151195
ping_interval = os.environ.get("RUNPOD_PING_INTERVAL")
152196
if ping_interval:
@@ -163,36 +207,48 @@ def _load_runpod_serverless_config() -> None:
163207

164208
def _load_waverless_serverless_config() -> None:
165209
"""Load Waverless environment variables into serverless config."""
210+
# Endpoint identification (load first for pod_id generation)
211+
serverless.endpoint_id = os.environ.get("WAVERLESS_ENDPOINT_ID")
212+
# Endpoint identification (endpoint_id already set above)
213+
serverless.project_id = os.environ.get("WAVERLESS_PROJECT_ID")
214+
166215
# Worker identification
167-
serverless.pod_id = os.environ.get("WAVERLESS_POD_ID") or ""
216+
raw_pod_id = os.environ.get("WAVERLESS_POD_ID")
217+
serverless.pod_id = _generate_pod_id(serverless.endpoint_id, raw_pod_id)
218+
serverless.pod_hostname = os.environ.get(
219+
"WAVERLESS_POD_HOSTNAME", serverless.pod_id
220+
)
168221

169222
# API endpoint templates
170223
serverless.webhook_get_job = os.environ.get("WAVERLESS_WEBHOOK_GET_JOB")
171224
serverless.webhook_post_output = os.environ.get("WAVERLESS_WEBHOOK_POST_OUTPUT")
172225
serverless.webhook_post_stream = os.environ.get("WAVERLESS_WEBHOOK_POST_STREAM")
173226
serverless.webhook_ping = os.environ.get("WAVERLESS_WEBHOOK_PING")
174227

175-
# Resolved API endpoints (with pod_id substituted)
176-
serverless.job_get_url = _resolve_url(serverless.webhook_get_job, serverless.pod_id)
177-
serverless.job_done_url = _resolve_url(
228+
# Resolved API endpoints (with $WAVERLESS_POD_ID substituted)
229+
job_get_url = _resolve_waverless_url(serverless.webhook_get_job, serverless.pod_id)
230+
# job_get_url also needs $ID replaced with worker ID (like runpod)
231+
if job_get_url:
232+
job_get_url = job_get_url.replace("$ID", serverless.pod_id)
233+
serverless.job_get_url = job_get_url
234+
235+
# job_done_url keeps $ID for runtime replacement with job_id
236+
serverless.job_done_url = _resolve_waverless_url(
178237
serverless.webhook_post_output, serverless.pod_id
179238
)
180-
serverless.job_stream_url = _resolve_url(
239+
serverless.job_stream_url = _resolve_waverless_url(
181240
serverless.webhook_post_stream, serverless.pod_id
182241
)
183-
serverless.ping_url = _resolve_url(serverless.webhook_ping, serverless.pod_id)
242+
serverless.ping_url = _resolve_waverless_url(
243+
serverless.webhook_ping, serverless.pod_id
244+
)
184245

185246
# Authentication
186247
serverless.api_key = os.environ.get("WAVERLESS_API_KEY")
187248

188249
# Logging
189250
serverless.log_level = os.environ.get("WAVERLESS_LOG_LEVEL", "INFO")
190251

191-
# Endpoint identification
192-
serverless.endpoint_id = os.environ.get("WAVERLESS_ENDPOINT_ID")
193-
serverless.project_id = os.environ.get("WAVERLESS_PROJECT_ID")
194-
serverless.pod_hostname = os.environ.get("WAVERLESS_POD_HOSTNAME")
195-
196252
# Timing and concurrency
197253
ping_interval = os.environ.get("WAVERLESS_PING_INTERVAL")
198254
if ping_interval:

src/wavespeed/serverless/modules/fastapi.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,15 @@ def __init__(self, config: Dict[str, Any]):
220220
tags=["Status"],
221221
)
222222

223+
# Health check endpoint
224+
router.add_api_route(
225+
"/health",
226+
lambda: {"status": "ok"},
227+
methods=["GET"],
228+
summary="Health check",
229+
tags=["Status"],
230+
)
231+
223232
self.app.include_router(router)
224233

225234
def start(

tests/test_config.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,51 +2,70 @@
22

33
import unittest
44

5-
from wavespeed.config import _resolve_url, serverless
5+
from wavespeed.config import _resolve_runpod_url, _resolve_waverless_url, serverless
66

77

8-
class TestResolveUrl(unittest.TestCase):
9-
"""Tests for the _resolve_url function."""
8+
class TestResolveRunpodUrl(unittest.TestCase):
9+
"""Tests for the _resolve_runpod_url function."""
1010

1111
def test_replaces_runpod_pod_id(self):
1212
"""Test that $RUNPOD_POD_ID is replaced with pod_id."""
1313
template = "https://api.runpod.ai/v2/endpoint/job-done/$RUNPOD_POD_ID"
14-
result = _resolve_url(template, "my-pod-123")
14+
result = _resolve_runpod_url(template, "my-pod-123")
1515
self.assertEqual(
1616
result, "https://api.runpod.ai/v2/endpoint/job-done/my-pod-123"
1717
)
1818

1919
def test_preserves_id_placeholder(self):
2020
"""Test that $ID is NOT replaced - it's for job ID at runtime."""
2121
template = "https://api.runpod.ai/v2/endpoint/job-done/$RUNPOD_POD_ID/$ID"
22-
result = _resolve_url(template, "my-pod-123")
23-
# $ID should remain as placeholder for job ID replacement later
22+
result = _resolve_runpod_url(template, "my-pod-123")
2423
self.assertEqual(
2524
result, "https://api.runpod.ai/v2/endpoint/job-done/my-pod-123/$ID"
2625
)
2726

2827
def test_handles_none_template(self):
2928
"""Test that None template returns None."""
30-
result = _resolve_url(None, "my-pod-123")
29+
result = _resolve_runpod_url(None, "my-pod-123")
3130
self.assertIsNone(result)
3231

33-
def test_handles_empty_template(self):
34-
"""Test that empty template returns None (falsy check)."""
35-
result = _resolve_url("", "my-pod-123")
32+
def test_no_placeholders(self):
33+
"""Test URL without any placeholders."""
34+
template = "https://api.example.com/endpoint"
35+
result = _resolve_runpod_url(template, "my-pod-123")
36+
self.assertEqual(result, "https://api.example.com/endpoint")
37+
38+
39+
class TestResolveWaverlessUrl(unittest.TestCase):
40+
"""Tests for the _resolve_waverless_url function."""
41+
42+
def test_replaces_waverless_pod_id_placeholder(self):
43+
"""Test that $WAVERLESS_POD_ID is replaced with pod_id."""
44+
template = "https://api.wavespeed.ai/v2/test/job-take/$WAVERLESS_POD_ID"
45+
result = _resolve_waverless_url(template, "my-pod-123")
46+
self.assertEqual(
47+
result, "https://api.wavespeed.ai/v2/test/job-take/my-pod-123"
48+
)
49+
50+
def test_preserves_id_placeholder(self):
51+
"""Test that $ID is NOT replaced - it's for job/worker ID at runtime."""
52+
template = "https://api.wavespeed.ai/v2/test/job-done/$WAVERLESS_POD_ID/$ID"
53+
result = _resolve_waverless_url(template, "my-pod-123")
54+
self.assertEqual(
55+
result, "https://api.wavespeed.ai/v2/test/job-done/my-pod-123/$ID"
56+
)
57+
58+
def test_handles_none_template(self):
59+
"""Test that None template returns None."""
60+
result = _resolve_waverless_url(None, "my-pod-123")
3661
self.assertIsNone(result)
3762

3863
def test_no_placeholders(self):
3964
"""Test URL without any placeholders."""
4065
template = "https://api.example.com/endpoint"
41-
result = _resolve_url(template, "my-pod-123")
66+
result = _resolve_waverless_url(template, "my-pod-123")
4267
self.assertEqual(result, "https://api.example.com/endpoint")
4368

44-
def test_multiple_pod_id_placeholders(self):
45-
"""Test multiple $RUNPOD_POD_ID placeholders are all replaced."""
46-
template = "https://api.runpod.ai/$RUNPOD_POD_ID/test/$RUNPOD_POD_ID"
47-
result = _resolve_url(template, "pod-456")
48-
self.assertEqual(result, "https://api.runpod.ai/pod-456/test/pod-456")
49-
5069

5170
class TestServerlessConfig(unittest.TestCase):
5271
"""Tests for serverless config loading."""

0 commit comments

Comments
 (0)