Skip to content

Commit abba7da

Browse files
Bihan Rana
authored and committed
Resolve pyright type check
1 parent 5abbcad commit abba7da

6 files changed

Lines changed: 57 additions & 76 deletions

File tree

src/dstack/_internal/core/models/configurations.py

Lines changed: 22 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -838,25 +838,19 @@ class ServiceConfigurationParams(CoreModel):
838838
SERVICE_HTTPS_DEFAULT
839839
)
840840
auth: Annotated[bool, Field(description="Enable the authorization")] = True
841-
# replicas: Annotated[
842-
# Range[int],
843-
# Field(
844-
# description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). "
845-
# "If it's a range, the `scaling` property is required"
846-
# ),
847-
# ] = Range[int](min=1, max=1)
848-
# scaling: Annotated[
849-
# Optional[ScalingSpec],
850-
# Field(description="The auto-scaling rules. Required if `replicas` is set to a range"),
851-
# ] = None
841+
842+
scaling: Annotated[
843+
Optional[ScalingSpec],
844+
Field(description="The auto-scaling rules. Required if `replicas` is set to a range"),
845+
] = None
852846
rate_limits: Annotated[list[RateLimit], Field(description="Rate limiting rules")] = []
853847
probes: Annotated[
854848
list[ProbeConfig],
855849
Field(description="List of probes used to determine job health"),
856850
] = []
857851

858852
replicas: Annotated[
859-
Optional[Union[Range[int], List[ReplicaGroup], int, str]],
853+
Optional[Union[Range[int], List[ReplicaGroup]]],
860854
Field(
861855
description=(
862856
"List of replica groups. Each group defines replicas with shared configuration "
@@ -882,16 +876,6 @@ def convert_model(cls, v: Optional[Union[AnyModel, str]]) -> Optional[AnyModel]:
882876
return OpenAIChatModel(type="chat", name=v, format="openai")
883877
return v
884878

885-
# @validator("replicas")
886-
# def convert_replicas(cls, v: Range[int]) -> Range[int]:
887-
# if v.max is None:
888-
# raise ValueError("The maximum number of replicas is required")
889-
# if v.min is None:
890-
# v.min = 0
891-
# if v.min < 0:
892-
# raise ValueError("The minimum number of replicas must be greater than or equal to 0")
893-
# return v
894-
895879
@validator("gateway")
896880
def validate_gateway(
897881
cls, v: Optional[Union[bool, str]]
@@ -902,22 +886,6 @@ def validate_gateway(
902886
)
903887
return v
904888

905-
# @root_validator()
906-
# def validate_scaling(cls, values):
907-
# replica_groups = values.get("replica_groups")
908-
# # If replica_groups are set, we don't need to validate scaling.
909-
# # Each replica group has its own scaling.
910-
# if replica_groups:
911-
# return values
912-
913-
# scaling = values.get("scaling")
914-
# replicas = values.get("replicas")
915-
# if replicas and replicas.min != replicas.max and not scaling:
916-
# raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.")
917-
# if replicas and replicas.min == replicas.max and scaling:
918-
# raise ValueError("To use `scaling`, `replicas` must be set to a range.")
919-
# return values
920-
921889
@root_validator()
922890
def normalize_replicas(cls, values):
923891
replicas = values.get("replicas")
@@ -966,10 +934,12 @@ def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]:
966934
return v
967935

968936
@validator("replicas")
969-
def validate_replicas(cls, v: Optional[List[ReplicaGroup]]) -> Optional[List[ReplicaGroup]]:
937+
def validate_replicas(
938+
cls, v: Optional[Union[Range[int], List[ReplicaGroup]]]
939+
) -> Optional[Union[Range[int], List[ReplicaGroup]]]:
970940
if v is None:
971941
return v
972-
if isinstance(v, (Range, int, str)):
942+
if isinstance(v, Range):
973943
return v
974944

975945
if isinstance(v, list):
@@ -1007,6 +977,18 @@ class ServiceConfiguration(
1007977
):
1008978
type: Literal["service"] = "service"
1009979

980+
@property
981+
def replica_groups(self) -> Optional[List[ReplicaGroup]]:
982+
"""
983+
Get normalized replica groups. After validation, replicas is always List[ReplicaGroup] or None.
984+
Use this property for type-safe access in code.
985+
"""
986+
if self.replicas is None:
987+
return None
988+
if isinstance(self.replicas, list):
989+
return self.replicas
990+
return None
991+
1010992

1011993
AnyRunConfiguration = Union[DevEnvironmentConfiguration, TaskConfiguration, ServiceConfiguration]
1012994

src/dstack/_internal/server/background/tasks/process_runs.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,9 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel):
196196
logger.debug("%s: retrying run is not yet ready for resubmission", fmt(run_model))
197197
return
198198

199-
# run_model.desired_replica_count = 1
200199
if run.run_spec.configuration.type == "service":
201200
run_model.desired_replica_count = sum(
202-
group.replicas.min or 0 for group in run.run_spec.configuration.replicas
201+
group.replicas.min or 0 for group in (run.run_spec.configuration.replica_groups or [])
203202
)
204203
await update_service_desired_replica_count(
205204
session,
@@ -214,7 +213,7 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel):
214213
return
215214

216215
# Per group scaling because single replica is also normalized to replica groups.
217-
replicas = run.run_spec.configuration.replicas or []
216+
replicas: List[ReplicaGroup] = run.run_spec.configuration.replica_groups or []
218217
counts = (
219218
json.loads(run_model.desired_replica_counts)
220219
if run_model.desired_replica_counts
@@ -461,7 +460,7 @@ async def _handle_run_replicas(
461460
# FIXME: should only include scaling events, not retries and deployments
462461
last_scaled_at=max((r.timestamp for r in replicas_info), default=None),
463462
)
464-
replicas = run_spec.configuration.replicas or []
463+
replicas: List[ReplicaGroup] = run_spec.configuration.replica_groups or []
465464
if replicas:
466465
counts = (
467466
json.loads(run_model.desired_replica_counts)

src/dstack/_internal/server/services/runs/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ async def submit_run(
520520

521521
global_replica_num = 0 # Global counter across all groups for unique replica_num
522522

523-
for replica_group in service_config.replicas:
523+
for replica_group in service_config.replica_groups or []:
524524
if run_spec.merged_profile.schedule is not None:
525525
group_initial_replicas = 0
526526
else:

src/dstack/_internal/server/services/runs/spec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def validate_run_spec_and_set_defaults(
9090
if isinstance(run_spec.configuration, ServiceConfiguration):
9191
# Check if any group has min=0
9292
if run_spec.merged_profile.schedule and any(
93-
group.replicas.min == 0 for group in run_spec.configuration.replicas
93+
group.replicas.min == 0 for group in (run_spec.configuration.replica_groups or [])
9494
):
9595
raise ServerClientError(
9696
"Scheduled services with autoscaling to zero are not supported"
@@ -154,7 +154,7 @@ def get_nodes_required_num(run_spec: RunSpec) -> int:
154154
nodes_required_num = run_spec.configuration.nodes
155155
elif run_spec.configuration.type == "service":
156156
nodes_required_num = sum(
157-
group.replicas.min or 0 for group in run_spec.configuration.replicas
157+
group.replicas.min or 0 for group in (run_spec.configuration.replica_groups or [])
158158
)
159159
return nodes_required_num
160160

src/dstack/_internal/server/services/services/__init__.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ def _register_service_in_server(run_model: RunModel, run_spec: RunSpec) -> Servi
144144
)
145145
# Check if any group has autoscaling (min != max)
146146
has_autoscaling = any(
147-
group.replicas.min != group.replicas.max for group in run_spec.configuration.replicas
147+
group.replicas.min != group.replicas.max
148+
for group in (run_spec.configuration.replica_groups or [])
148149
)
149150
if has_autoscaling:
150151
raise ServerClientError(
@@ -308,21 +309,17 @@ async def update_service_desired_replica_count(
308309
if run_model.gateway_id is not None:
309310
conn = await get_or_add_gateway_connection(session, run_model.gateway_id)
310311
stats = await conn.get_stats(run_model.project.name, run_model.run_name)
311-
if configuration.replicas:
312+
replica_groups = configuration.replica_groups or []
313+
if replica_groups:
312314
desired_replica_counts = {}
313315
total = 0
314316
prev_counts = (
315317
json.loads(run_model.desired_replica_counts)
316318
if run_model.desired_replica_counts
317319
else {}
318320
)
319-
for group in configuration.replicas:
320-
# temp group_wise config to get the group_wise desired replica count.
321-
group_config = configuration.copy(
322-
exclude={"replicas"},
323-
update={"replicas": group.replicas, "scaling": group.scaling},
324-
)
325-
scaler = get_service_scaler(group_config)
321+
for group in replica_groups:
322+
scaler = get_service_scaler(group.replicas, group.scaling)
326323
group_desired = scaler.get_desired_count(
327324
current_desired_count=prev_counts.get(group.name, group.replicas.min or 0),
328325
stats=stats,
@@ -334,9 +331,11 @@ async def update_service_desired_replica_count(
334331
run_model.desired_replica_count = total
335332
else:
336333
# Todo Not required as single replica is normalized to replicas.
337-
scaler = get_service_scaler(configuration)
338-
run_model.desired_replica_count = scaler.get_desired_count(
339-
current_desired_count=run_model.desired_replica_count,
340-
stats=stats,
341-
last_scaled_at=last_scaled_at,
342-
)
334+
if configuration.replica_groups:
335+
first_group = configuration.replica_groups[0]
336+
scaler = get_service_scaler(count=first_group.replicas, scaling=first_group.scaling)
337+
run_model.desired_replica_count = scaler.get_desired_count(
338+
current_desired_count=run_model.desired_replica_count,
339+
stats=stats,
340+
last_scaled_at=last_scaled_at,
341+
)

src/dstack/_internal/server/services/services/autoscalers.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from pydantic import BaseModel
77

88
import dstack._internal.utils.common as common_utils
9-
from dstack._internal.core.models.configurations import ServiceConfiguration
9+
from dstack._internal.core.models.configurations import ScalingSpec
10+
from dstack._internal.core.models.resources import Range
1011
from dstack._internal.proxy.gateway.schemas.stats import PerWindowStats
1112

1213

@@ -119,21 +120,21 @@ def get_desired_count(
119120
return new_desired_count
120121

121122

122-
def get_service_scaler(conf: ServiceConfiguration) -> BaseServiceScaler:
123-
assert conf.replicas.min is not None
124-
assert conf.replicas.max is not None
125-
if conf.scaling is None:
123+
def get_service_scaler(count: Range[int], scaling: Optional[ScalingSpec]) -> BaseServiceScaler:
124+
assert count.min is not None
125+
assert count.max is not None
126+
if scaling is None:
126127
return ManualScaler(
127-
min_replicas=conf.replicas.min,
128-
max_replicas=conf.replicas.max,
128+
min_replicas=count.min,
129+
max_replicas=count.max,
129130
)
130-
if conf.scaling.metric == "rps":
131+
if scaling.metric == "rps":
131132
return RPSAutoscaler(
132133
# replicas count validated by configuration model
133-
min_replicas=conf.replicas.min,
134-
max_replicas=conf.replicas.max,
135-
target=conf.scaling.target,
136-
scale_up_delay=conf.scaling.scale_up_delay,
137-
scale_down_delay=conf.scaling.scale_down_delay,
134+
min_replicas=count.min,
135+
max_replicas=count.max,
136+
target=scaling.target,
137+
scale_up_delay=scaling.scale_up_delay,
138+
scale_down_delay=scaling.scale_down_delay,
138139
)
139-
raise ValueError(f"No scaler found for scaling parameters {conf.scaling}")
140+
raise ValueError(f"No scaler found for scaling parameters {scaling}")

0 commit comments

Comments
 (0)