Skip to content

Commit 89c7a18

Browse files
[Services] Add default probes if model is set #3522
1 parent 9c81898 commit 89c7a18

2 files changed

Lines changed: 81 additions & 3 deletions

File tree

src/dstack/_internal/core/models/configurations.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@
5656
DEFAULT_PROBE_METHOD = "get"
5757
MAX_PROBE_URL_LEN = 2048
5858
DEFAULT_REPLICA_GROUP_NAME = "0"
59+
# Timeout given to the probe that is auto-generated when a service sets `model`
# but omits `probes`. NOTE(review): presumably seconds -- confirm against ProbeConfig.
DEFAULT_MODEL_PROBE_TIMEOUT = 30
60+
# URL path targeted by the probe that is auto-generated when a service sets `model`
# but omits `probes`.
DEFAULT_MODEL_PROBE_URL = "/v1/chat/completions"
5961

6062

6163
class RunConfigurationType(str, Enum):
@@ -851,9 +853,9 @@ class ServiceConfigurationParams(CoreModel):
851853
] = None
852854
rate_limits: Annotated[list[RateLimit], Field(description="Rate limiting rules")] = []
853855
probes: Annotated[
854-
list[ProbeConfig],
856+
Optional[list[ProbeConfig]],
855857
Field(description="List of probes used to determine job health"),
856-
] = []
858+
] = None # None = omitted (may get default when model is set); [] = explicit empty
857859

858860
replicas: Annotated[
859861
Optional[Union[List[ReplicaGroup], Range[int]]],
@@ -895,7 +897,9 @@ def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]:
895897
return v
896898

897899
@validator("probes")
898-
def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]:
900+
def validate_probes(cls, v: Optional[list[ProbeConfig]]) -> Optional[list[ProbeConfig]]:
901+
if v is None:
902+
return v
899903
if has_duplicates(v):
900904
# Using a custom validator instead of Field(unique_items=True) to avoid Pydantic bug:
901905
# https://github.com/pydantic/pydantic/issues/3765
@@ -932,6 +936,35 @@ def validate_replicas(
932936
)
933937
return v
934938

939+
@root_validator()
def set_default_probes_for_model(cls, values):
    """Fill in `probes` when the configuration omitted it.

    `probes` being `None` means the field was omitted; an explicit list
    (including an empty one) is left untouched.

    * `model` set and `probes` omitted: install a single HTTP POST probe
      against `DEFAULT_MODEL_PROBE_URL` with a minimal JSON chat request.
    * `model` not set and `probes` omitted: normalize to an empty list so
      downstream code always sees a list.

    Returns the (possibly updated) `values` dict.
    """
    # Anything the user spelled out explicitly wins, even `[]`.
    if values.get("probes") is not None:
        return values

    model = values.get("model")
    if model is None:
        # Probes omitted and model not set: normalize to empty list for downstream.
        values["probes"] = []
        return values

    # Smallest request that still exercises the model: one short message, one token.
    payload = {
        "model": model.name,
        "messages": [{"role": "user", "content": "hi"}],
        "max_tokens": 1,
    }
    default_probe = ProbeConfig(
        type="http",
        method="post",
        url=DEFAULT_MODEL_PROBE_URL,
        headers=[
            HTTPHeaderSpec(name="Content-Type", value="application/json"),
        ],
        body=orjson.dumps(payload).decode("utf-8"),
        timeout=DEFAULT_MODEL_PROBE_TIMEOUT,
    )
    values["probes"] = [default_probe]
    return values
967+
935968
@root_validator()
936969
def validate_scaling(cls, values):
937970
scaling = values.get("scaling")

src/tests/_internal/core/models/test_configurations.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from dstack._internal.core.errors import ConfigurationError
66
from dstack._internal.core.models.common import RegistryAuth
77
from dstack._internal.core.models.configurations import (
8+
DEFAULT_MODEL_PROBE_TIMEOUT,
9+
DEFAULT_MODEL_PROBE_URL,
810
DevEnvironmentConfigurationParams,
911
RepoSpec,
1012
parse_run_configuration,
@@ -13,6 +15,49 @@
1315

1416

1517
class TestParseConfiguration:
18+
def test_service_model_sets_default_probes_when_probes_omitted(self):
    """A service with `model` and no `probes` key gets exactly one default probe."""
    model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    parsed = parse_run_configuration(
        {
            "type": "service",
            "commands": ["python3 -m http.server"],
            "port": 8000,
            "model": model_name,
        }
    )

    assert len(parsed.probes) == 1
    probe = parsed.probes[0]

    # Request line: HTTP POST to the default model endpoint with the default timeout.
    assert (probe.type, probe.method) == ("http", "post")
    assert probe.url == DEFAULT_MODEL_PROBE_URL
    assert probe.timeout == DEFAULT_MODEL_PROBE_TIMEOUT

    # Exactly one header, declaring a JSON body.
    header_pairs = [(header.name, header.value) for header in probe.headers]
    assert header_pairs == [("Content-Type", "application/json")]

    # The body mentions the configured model and caps the generated tokens.
    body_text = probe.body or ""
    assert model_name in body_text
    assert "max_tokens" in body_text
37+
38+
def test_service_model_does_not_override_explicit_probes(self):
    """User-supplied probes are kept as-is even when `model` is set."""
    user_probes = [{"type": "http", "url": "/health"}]
    parsed = parse_run_configuration(
        {
            "type": "service",
            "commands": ["python3 -m http.server"],
            "port": 8000,
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "probes": user_probes,
        }
    )

    # Only the explicit probe is present; no default was added or substituted.
    assert [probe.url for probe in parsed.probes] == ["/health"]
49+
50+
def test_service_model_explicit_empty_probes_no_default(self):
    """An explicit `probes: []` opts out of the default probe even when `model` is set."""
    parsed = parse_run_configuration(
        {
            "type": "service",
            "commands": ["python3 -m http.server"],
            "port": 8000,
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "probes": [],
        }
    )

    # The explicit empty list survives parsing untouched.
    assert list(parsed.probes) == []
60+
1661
def test_services_replicas_and_scaling(self):
1762
def test_conf(replicas: Any, scaling: Optional[Any] = None):
1863
conf = {

0 commit comments

Comments
 (0)