Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 189 additions & 0 deletions checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""SQL database schema definition for the Checkpoint Tiering Service (CTS).

Provides SQLAlchemy models for tracking assets, tier paths, and job queues.
"""

import enum
import json
from typing import Any, Dict, List, Optional
import uuid

import sqlalchemy
import sqlalchemy.orm

# Shared declarative base class; every ORM model in this module subclasses it.
Base = sqlalchemy.orm.declarative_base()


@enum.unique
class AssetState(enum.IntEnum):
  """Lifecycle state of an asset tracked by CTS.

  The partial unique index on the `assets` table filters on these member
  names as string literals, so renaming a member requires updating that index
  predicate as well.

  `enum.unique` makes an accidental duplicate value fail at import time
  instead of silently turning the second member into an alias of the first.
  """

  # Column default for `Asset.state`.
  ASSET_STATE_UNSPECIFIED = 0
  # Live state: `unique_path` must be unique among ACTIVE_WRITE/STORED assets.
  ASSET_STATE_ACTIVE_WRITE = 1
  # Live state: `unique_path` must be unique among ACTIVE_WRITE/STORED assets.
  ASSET_STATE_STORED = 2
  # Not live: duplicate `unique_path` values are allowed.
  ASSET_STATE_DELETED = 3
  # Not live: duplicate `unique_path` values are allowed.
  ASSET_STATE_INCOMPLETE = 4


@enum.unique
class BackendType(enum.IntEnum):
  """Storage backend type of a tier path.

  `enum.unique` makes an accidental duplicate value fail at import time
  instead of silently turning the second member into an alias of the first.
  """

  # Column default for `TierPath.backend_type`.
  BACKEND_TYPE_UNSPECIFIED = 0
  # Presumably a Lustre file system location; confirm with the service owner.
  BACKEND_TYPE_LUSTRE = 1
  # Presumably a Google Cloud Storage location; confirm with the service owner.
  BACKEND_TYPE_GCS = 2


@enum.unique
class JobStatus(enum.Enum):
  """Execution status of an asset job.

  Members are plain `enum.Enum` values (not `str` subclasses), so a member
  does not compare equal to its string value.

  `enum.unique` makes an accidental duplicate value fail at import time
  instead of silently turning the second member into an alias of the first.
  """

  # Column default for `AssetJob.status`.
  QUEUED = "QUEUED"
  PROCESSING = "PROCESSING"
  COMPLETED = "COMPLETED"
  FAILED = "FAILED"


class Asset(Base):
  """Database model representing a distinct CTS asset.

  One row per asset. Holds the asset's identity and lifecycle metadata, and
  owns its `TierPath` and `AssetJob` rows through the `tier_paths` and `jobs`
  relationships. Both relationships cascade "all, delete-orphan", so deleting
  an asset through the ORM session also deletes its children.

  `unique_path` is required to be unique only among live assets
  (ACTIVE_WRITE or STORED); see `__table_args__`.
  """

  __tablename__ = "assets"

  # String primary key, filled in on the Python side with a random UUID4 when
  # the caller does not supply one.
  # The lambda resolves `uuid` in module scope when it is called, so it reaches
  # the stdlib module even though this class attribute has the same name.
  # NOTE(review): a primary key is already indexed; `index=True` requests an
  # additional index on the same column. Confirm that is wanted.
  uuid = sqlalchemy.Column(
      sqlalchemy.String,
      primary_key=True,
      default=lambda: str(uuid.uuid4()),
      index=True,
  )
  # Non-unique lookup index here; conditional uniqueness is enforced by the
  # partial index in `__table_args__`.
  unique_path = sqlalchemy.Column(sqlalchemy.String, index=True, nullable=False)
  user = sqlalchemy.Column(sqlalchemy.String, nullable=False)
  tags = sqlalchemy.Column(sqlalchemy.JSON, nullable=True)
  state = sqlalchemy.Column(
      sqlalchemy.Enum(AssetState), default=AssetState.ASSET_STATE_UNSPECIFIED
  )
  # Lifecycle timestamps. None of them has a default, so callers set them.
  created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
  finalized_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
  deleted_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
  updated_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)

  tier_paths = sqlalchemy.orm.relationship(
      "TierPath", back_populates="asset", cascade="all, delete-orphan"
  )
  jobs = sqlalchemy.orm.relationship(
      "AssetJob", back_populates="asset", cascade="all, delete-orphan"
  )

  __table_args__ = (
      # Enforce unique_path only for live assets (ACTIVE_WRITE, STORED).
      # Duplicates are allowed for DELETED or INCOMPLETE states.
      #
      # The predicate compares against AssetState member names, which is what
      # sqlalchemy.Enum persists by default.
      #
      # The WHERE clause is a dialect-specific option, so it is given for both
      # SQLite and PostgreSQL. With only `sqlite_where`, PostgreSQL would build
      # an unconditional UNIQUE index and reject the duplicates this comment
      # promises to allow. Dialects without partial-index support still get an
      # unconditional unique index.
      sqlalchemy.Index(
          "idx_assets_unique_path_active_stored",
          "unique_path",
          unique=True,
          sqlite_where=sqlalchemy.text(
              "state IN ('ASSET_STATE_ACTIVE_WRITE', 'ASSET_STATE_STORED')"
          ),
          postgresql_where=sqlalchemy.text(
              "state IN ('ASSET_STATE_ACTIVE_WRITE', 'ASSET_STATE_STORED')"
          ),
      ),
  )


class TierPath(Base):
  """ORM row describing one storage location of an asset.

  Each row records a tier `level`, a backend type and a `path`, and points
  back to its owning asset through `asset_uuid`.
  """

  __tablename__ = "tier_paths"

  id = sqlalchemy.Column(
      sqlalchemy.Integer,
      primary_key=True,
      autoincrement=True,
  )
  # Owning asset; the foreign key is declared with ON DELETE CASCADE.
  asset_uuid = sqlalchemy.Column(
      sqlalchemy.String,
      sqlalchemy.ForeignKey("assets.uuid", ondelete="CASCADE"),
      nullable=False,
  )
  level = sqlalchemy.Column(sqlalchemy.Integer, nullable=False)
  zone = sqlalchemy.Column(sqlalchemy.String, nullable=True)
  region = sqlalchemy.Column(sqlalchemy.String, nullable=True)
  # JSON text backing the `multi_region` property below.
  multi_region_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
  backend_type = sqlalchemy.Column(
      sqlalchemy.Enum(BackendType),
      default=BackendType.BACKEND_TYPE_UNSPECIFIED,
  )
  path = sqlalchemy.Column(sqlalchemy.String, nullable=False)
  ready_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
  expires_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)

  asset = sqlalchemy.orm.relationship("Asset", back_populates="tier_paths")

  @property
  def multi_region(self) -> List[str]:
    """Returns the multi-region locations decoded from `multi_region_json`.

    An unset or empty column decodes to an empty list.
    """
    serialized = self.multi_region_json
    return json.loads(serialized) if serialized else []

  @multi_region.setter
  def multi_region(self, value: Optional[List[str]]):
    """Stores `value` as JSON text; `None` clears the column."""
    self.multi_region_json = None if value is None else json.dumps(value)


class AssetJob(Base):
  """ORM row describing one queued job that targets a single asset.

  Rows carry a request type, a status and an optional JSON payload.
  NOTE(review): the queue is presumably meant to serialize work per asset;
  that logic is not part of this model. Confirm where it is enforced.
  """

  __tablename__ = "asset_jobs"

  id = sqlalchemy.Column(
      sqlalchemy.Integer,
      primary_key=True,
      autoincrement=True,
  )
  # Target asset; the foreign key is declared with ON DELETE CASCADE.
  asset_uuid = sqlalchemy.Column(
      sqlalchemy.String,
      sqlalchemy.ForeignKey("assets.uuid", ondelete="CASCADE"),
      nullable=False,
  )
  request_type = sqlalchemy.Column(sqlalchemy.String, nullable=False)
  # Indexed, and defaults to QUEUED for newly inserted rows.
  status = sqlalchemy.Column(
      sqlalchemy.Enum(JobStatus),
      default=JobStatus.QUEUED,
      index=True,
  )
  # JSON text backing the `payload` property below.
  payload_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
  created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
  completed_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)

  asset = sqlalchemy.orm.relationship("Asset", back_populates="jobs")

  @property
  def payload(self) -> Dict[str, Any]:
    """Returns the job payload decoded from `payload_json`.

    An unset or empty column decodes to an empty dict.
    """
    serialized = self.payload_json
    return json.loads(serialized) if serialized else {}

  @payload.setter
  def payload(self, value: Optional[Dict[str, Any]]):
    """Stores `value` as JSON text; `None` clears the column."""
    self.payload_json = None if value is None else json.dumps(value)
Loading
Loading