Skip to content

Commit b8fe5be

Browse files
committed
Use InstanceTerminationReason instead of strings
Use the `InstanceTerminationReason` enum as the type of `InstanceModel.termination_reason`. To handle old instances, automatically convert their legacy termination reason strings to relevant `InstanceTerminationReason` enum members on reads.
1 parent 84ca903 commit b8fe5be

7 files changed

Lines changed: 150 additions & 42 deletions

File tree

src/dstack/_internal/core/models/instances.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from dstack._internal.core.models.health import HealthStatus
1616
from dstack._internal.core.models.volumes import Volume
1717
from dstack._internal.utils.common import pretty_resources
18+
from dstack._internal.utils.logging import get_logger
19+
20+
logger = get_logger(__name__)
1821

1922

2023
class Gpu(CoreModel):
@@ -256,14 +259,68 @@ def finished_statuses(cls) -> List["InstanceStatus"]:
256259

257260
class InstanceTerminationReason(str, Enum):
    """
    Machine-readable reason why an instance was (or is being) terminated.

    Human-readable details, if any, are stored separately from this enum.
    """

    IDLE_TIMEOUT = "idle_timeout"
    PROVISIONING_TIMEOUT = "provisioning_timeout"
    ERROR = "error"
    JOB_FINISHED = "job_finished"
    TERMINATION_TIMEOUT = "termination_timeout"
    STARTING_TIMEOUT = "starting_timeout"
    NO_OFFERS = "no_offers"
    MASTER_FAILED = "master_failed"
    MAX_INSTANCES_LIMIT = "max_instances_limit"
    NO_BALANCE = "no_balance"  # used in dstack Sky

    @classmethod
    def from_legacy_str(cls, v: str) -> "InstanceTerminationReason":
        """
        Convert legacy termination reason string to relevant termination reason enum.

        dstack versions prior to 0.20.1 represented instance termination reasons as raw
        strings. Such strings may still be stored in the database.

        Args:
            v: The raw string read from storage. Either a legacy human-readable
                sentence (e.g. "Idle timeout") or a canonical enum value
                (e.g. "idle_timeout").

        Returns:
            The matching enum member. Unrecognized strings are logged with a
            warning and mapped to `ERROR`, so this method does not raise on
            unknown input.
        """

        # Some rows may hold the enum *value* (e.g. "idle_timeout") rather than a
        # legacy sentence: earlier code assigned `InstanceTerminationReason.<X>.value`
        # to the string field. Those are already canonical, so resolve them directly
        # instead of degrading them to ERROR with a warning.
        try:
            return cls(v)
        except ValueError:
            pass

        if v == "Idle timeout":
            return cls.IDLE_TIMEOUT
        if v in (
            "Provisioning timeout expired",
            "Proivisioning timeout expired",  # typo is intentional
            "The proivisioning timeout expired",  # typo is intentional
        ):
            return cls.PROVISIONING_TIMEOUT
        if v in (
            "Unsupported private SSH key type",
            "Failed to locate internal IP address on the given network",
            "Specified internal IP not found among instance interfaces",
            "Cannot split into blocks",
            "Backend not available",
            "Error while waiting for instance to become running",
            "Empty profile, requirements or instance_configuration",
            "Unable to locate the internal ip-address for the given network",
            "Private SSH key is encrypted, password required",
            "Cannot parse private key, key type is not supported",
        ) or v.startswith("Error to parse profile, requirements or instance_configuration:"):
            # The parse-error message carries a variable suffix, hence the prefix match.
            return cls.ERROR
        if v in (
            "All offers failed",
            "No offers found",
            "There were no offers found",
            "Retry duration expired",
            "The retry's duration expired",
        ):
            return cls.NO_OFFERS
        if v == "Master instance failed to start":
            return cls.MASTER_FAILED
        if v == "Instance job finished":
            return cls.JOB_FINISHED
        if v == "Termination deadline":
            return cls.TERMINATION_TIMEOUT
        if v == "Instance has not become running in time":
            return cls.STARTING_TIMEOUT
        if v == "Fleet has too many instances":
            return cls.MAX_INSTANCES_LIMIT
        if v == "Low account balance":
            return cls.NO_BALANCE
        # Unknown string: keep the read path working, but leave a trace for operators.
        logger.warning("Unexpected instance termination reason string: %r", v)
        return cls.ERROR
267324

268325

269326
class Instance(CoreModel):

src/dstack/_internal/server/background/tasks/process_fleets.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sqlalchemy.orm import joinedload, load_only, selectinload
99

1010
from dstack._internal.core.models.fleets import FleetSpec, FleetStatus
11-
from dstack._internal.core.models.instances import InstanceStatus
11+
from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason
1212
from dstack._internal.server.db import get_db, get_session_ctx
1313
from dstack._internal.server.models import (
1414
FleetModel,
@@ -213,7 +213,8 @@ def _maintain_fleet_nodes_in_min_max_range(
213213
break
214214
if instance.status in [InstanceStatus.IDLE]:
215215
instance.status = InstanceStatus.TERMINATING
216-
instance.termination_reason = "Fleet has too many instances"
216+
instance.termination_reason = InstanceTerminationReason.MAX_INSTANCES_LIMIT
217+
instance.termination_reason_message = "Fleet has too many instances"
217218
nodes_redundant -= 1
218219
logger.info(
219220
"Terminating instance %s: %s",

src/dstack/_internal/server/background/tasks/process_instances.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def _check_and_mark_terminating_if_idle_duration_expired(instance: InstanceModel
275275
delta = datetime.timedelta(seconds=idle_seconds)
276276
if idle_duration > delta:
277277
instance.status = InstanceStatus.TERMINATING
278-
instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT.value
278+
instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT
279279
logger.info(
280280
"Instance %s idle duration expired: idle time %ss. Terminating",
281281
instance.name,
@@ -311,7 +311,7 @@ async def _add_remote(instance: InstanceModel) -> None:
311311
retry_duration_deadline = instance.created_at + timedelta(seconds=PROVISIONING_TIMEOUT_SECONDS)
312312
if retry_duration_deadline < get_current_datetime():
313313
instance.status = InstanceStatus.TERMINATED
314-
instance.termination_reason = InstanceTerminationReason.PROOVISIONING_TIMEOUT.value
314+
instance.termination_reason = InstanceTerminationReason.PROVISIONING_TIMEOUT
315315
logger.warning(
316316
"Failed to start instance %s in %d seconds. Terminating...",
317317
instance.name,
@@ -334,7 +334,7 @@ async def _add_remote(instance: InstanceModel) -> None:
334334
ssh_proxy_pkeys = None
335335
except (ValueError, PasswordRequiredException):
336336
instance.status = InstanceStatus.TERMINATED
337-
instance.termination_reason = InstanceTerminationReason.ERROR.value
337+
instance.termination_reason = InstanceTerminationReason.ERROR
338338
instance.termination_reason_message = "Unsupported private SSH key type"
339339
logger.warning(
340340
"Failed to add instance %s: unsupported private SSH key type",
@@ -393,7 +393,7 @@ async def _add_remote(instance: InstanceModel) -> None:
393393
)
394394
if instance_network is not None and internal_ip is None:
395395
instance.status = InstanceStatus.TERMINATED
396-
instance.termination_reason = InstanceTerminationReason.ERROR.value
396+
instance.termination_reason = InstanceTerminationReason.ERROR
397397
instance.termination_reason_message = (
398398
"Failed to locate internal IP address on the given network"
399399
)
@@ -409,7 +409,7 @@ async def _add_remote(instance: InstanceModel) -> None:
409409
if internal_ip is not None:
410410
if not is_ip_among_addresses(ip_address=internal_ip, addresses=host_network_addresses):
411411
instance.status = InstanceStatus.TERMINATED
412-
instance.termination_reason = InstanceTerminationReason.ERROR.value
412+
instance.termination_reason = InstanceTerminationReason.ERROR
413413
instance.termination_reason_message = (
414414
"Specified internal IP not found among instance interfaces"
415415
)
@@ -432,7 +432,7 @@ async def _add_remote(instance: InstanceModel) -> None:
432432
instance.total_blocks = blocks
433433
else:
434434
instance.status = InstanceStatus.TERMINATED
435-
instance.termination_reason = InstanceTerminationReason.ERROR.value
435+
instance.termination_reason = InstanceTerminationReason.ERROR
436436
instance.termination_reason_message = "Cannot split into blocks"
437437
logger.warning(
438438
"Failed to add instance %s: cannot split into blocks",
@@ -552,7 +552,7 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
552552
requirements = get_instance_requirements(instance)
553553
except ValidationError as e:
554554
instance.status = InstanceStatus.TERMINATED
555-
instance.termination_reason = InstanceTerminationReason.ERROR.value
555+
instance.termination_reason = InstanceTerminationReason.ERROR
556556
instance.termination_reason_message = (
557557
f"Error to parse profile, requirements or instance_configuration: {e}"
558558
)
@@ -679,17 +679,19 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
679679
)
680680
return
681681

682-
_mark_terminated(instance, InstanceTerminationReason.NO_OFFERS.value)
682+
_mark_terminated(instance, InstanceTerminationReason.NO_OFFERS)
683683
if instance.fleet and is_fleet_master_instance(instance) and is_cloud_cluster(instance.fleet):
684684
# Do not attempt to deploy other instances, as they won't determine the correct cluster
685685
# backend, region, and placement group without a successfully deployed master instance
686686
for sibling_instance in instance.fleet.instances:
687687
if sibling_instance.id == instance.id:
688688
continue
689-
_mark_terminated(sibling_instance, InstanceTerminationReason.MASTER_FAILED.value)
689+
_mark_terminated(sibling_instance, InstanceTerminationReason.MASTER_FAILED)
690690

691691

692-
def _mark_terminated(instance: InstanceModel, termination_reason: str) -> None:
692+
def _mark_terminated(
693+
instance: InstanceModel, termination_reason: InstanceTerminationReason
694+
) -> None:
693695
instance.status = InstanceStatus.TERMINATED
694696
instance.termination_reason = termination_reason
695697
logger.info(
@@ -711,7 +713,7 @@ async def _check_instance(session: AsyncSession, instance: InstanceModel) -> Non
711713
):
712714
# A busy instance could have no active jobs due to this bug: https://github.com/dstackai/dstack/issues/2068
713715
instance.status = InstanceStatus.TERMINATING
714-
instance.termination_reason = InstanceTerminationReason.JOB_FINISHED.value
716+
instance.termination_reason = InstanceTerminationReason.JOB_FINISHED
715717
logger.info(
716718
"Detected busy instance %s with finished job. Marked as TERMINATING",
717719
instance.name,
@@ -840,7 +842,7 @@ async def _check_instance(session: AsyncSession, instance: InstanceModel) -> Non
840842
deadline = instance.termination_deadline
841843
if get_current_datetime() > deadline:
842844
instance.status = InstanceStatus.TERMINATING
843-
instance.termination_reason = InstanceTerminationReason.TERMINATION_TIMEOUT.value
845+
instance.termination_reason = InstanceTerminationReason.TERMINATION_TIMEOUT
844846
logger.warning(
845847
"Instance %s shim waiting timeout. Marked as TERMINATING",
846848
instance.name,
@@ -869,7 +871,7 @@ async def _wait_for_instance_provisioning_data(
869871
"Instance %s failed because instance has not become running in time", instance.name
870872
)
871873
instance.status = InstanceStatus.TERMINATING
872-
instance.termination_reason = InstanceTerminationReason.STARTING_TIMEOUT.value
874+
instance.termination_reason = InstanceTerminationReason.STARTING_TIMEOUT
873875
return
874876

875877
backend = await backends_services.get_project_backend_by_type(
@@ -882,7 +884,7 @@ async def _wait_for_instance_provisioning_data(
882884
instance.name,
883885
)
884886
instance.status = InstanceStatus.TERMINATING
885-
instance.termination_reason = InstanceTerminationReason.ERROR.value
887+
instance.termination_reason = InstanceTerminationReason.ERROR
886888
instance.termination_reason_message = "Backend not available"
887889
return
888890
try:
@@ -900,7 +902,7 @@ async def _wait_for_instance_provisioning_data(
900902
repr(e),
901903
)
902904
instance.status = InstanceStatus.TERMINATING
903-
instance.termination_reason = InstanceTerminationReason.ERROR.value
905+
instance.termination_reason = InstanceTerminationReason.ERROR
904906
instance.termination_reason_message = "Error while waiting for instance to become running"
905907
except Exception:
906908
logger.exception(

src/dstack/_internal/server/background/tasks/process_running_jobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
393393
if (
394394
job_model.instance is not None
395395
and job_model.instance.termination_reason
396-
== InstanceTerminationReason.NO_BALANCE.value
396+
== InstanceTerminationReason.NO_BALANCE
397397
):
398398
# if instance was terminated due to no balance, set job termination reason accordingly
399399
job_model.termination_reason = JobTerminationReason.NO_BALANCE

src/dstack/_internal/server/models.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import enum
22
import uuid
33
from datetime import datetime, timezone
4-
from typing import Callable, List, Optional, Union
4+
from typing import Callable, Generic, List, Optional, TypeVar, Union
55

66
from sqlalchemy import (
77
BigInteger,
@@ -30,7 +30,7 @@
3030
from dstack._internal.core.models.fleets import FleetStatus
3131
from dstack._internal.core.models.gateways import GatewayStatus
3232
from dstack._internal.core.models.health import HealthStatus
33-
from dstack._internal.core.models.instances import InstanceStatus
33+
from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason
3434
from dstack._internal.core.models.profiles import (
3535
DEFAULT_FLEET_TERMINATION_IDLE_TIME,
3636
TerminationPolicy,
@@ -141,26 +141,45 @@ def process_result_value(self, value: Optional[str], dialect) -> Optional[Decryp
141141
return DecryptedString(plaintext=None, decrypted=False, exc=e)
142142

143143

144-
class EnumAsString(TypeDecorator):
144+
E = TypeVar("E", bound=enum.Enum)


class EnumAsString(TypeDecorator, Generic[E]):
    """
    A custom type decorator that stores enums as strings in the DB.

    The member *name* (not its value) is what gets written; reads look the
    stored string up by member name.
    """

    impl = String
    cache_ok = True

    def __init__(
        self,
        enum_class: type[E],
        *args,
        fallback_deserializer: Optional[Callable[[str], E]] = None,
        **kwargs,
    ):
        """
        Args:
            enum_class: The enum class to be stored.
            fallback_deserializer: An optional function used when the string
                from the DB does not match any enum member name. If not
                provided, an exception will be raised in such cases.
        """
        self.enum_class = enum_class
        self.fallback_deserializer = fallback_deserializer
        # Remaining positional/keyword arguments are forwarded unchanged to the base class.
        super().__init__(*args, **kwargs)

    def process_bind_param(self, value: Optional[E], dialect) -> Optional[str]:
        """Serialize an enum member to its name; `None` passes through."""
        return None if value is None else value.name

    def process_result_value(self, value: Optional[str], dialect) -> Optional[E]:
        """
        Deserialize a stored string into an enum member; `None` passes through.

        Strings that are not member names go to `fallback_deserializer` when one
        was supplied; otherwise the name lookup raises `KeyError`.
        """
        if value is None:
            return None
        known_members = self.enum_class.__members__
        if value in known_members:
            return self.enum_class[value]
        fallback = self.fallback_deserializer
        if fallback is None:
            # No fallback configured: let the plain name lookup raise.
            return self.enum_class[value]
        return fallback(value)
165184

166185

@@ -641,8 +660,16 @@ class InstanceModel(BaseModel):
641660

642661
# instance termination handling
643662
termination_deadline: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
644-
# TODO: Migrate to EnumAsString(InstanceTerminationReason, 100) after enough releases to ensure backward compatibility
645-
termination_reason: Mapped[Optional[str]] = mapped_column(String(4000))
663+
# dstack versions prior to 0.20.1 represented instance termination reasons as raw strings.
664+
# Such strings may still be stored in the database, so we are using a wide column (4000 chars)
665+
# and a fallback deserializer to convert them to relevant enum members.
666+
termination_reason: Mapped[Optional[InstanceTerminationReason]] = mapped_column(
667+
EnumAsString(
668+
InstanceTerminationReason,
669+
4000,
670+
fallback_deserializer=InstanceTerminationReason.from_legacy_str,
671+
)
672+
)
646673
termination_reason_message: Mapped[Optional[str]] = mapped_column(String(4000))
647674
# Deprecated since 0.19.22, not used
648675
health_status: Mapped[Optional[str]] = mapped_column(String(4000), deferred=True)

0 commit comments

Comments
 (0)