Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion shard_core/data_model/backend/api_token_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .permission_model import PermissionHolder, Permission


class ApiTokenResult(PermissionHolder, BaseModel):
class ApiTokenDb(BaseModel):
id: uuid.UUID
name: str
created: datetime
Expand All @@ -18,6 +18,10 @@ class ApiTokenResult(PermissionHolder, BaseModel):
owner_name: str | None


class ApiTokenResult(ApiTokenDb, PermissionHolder):
pass


class ApiTokenCreate(BaseModel):
name: str
permissions: Set[Permission]
4 changes: 4 additions & 0 deletions shard_core/data_model/backend/settings_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
class SettingKey(StrEnum):
MIN_NR_OF_STANDBY_SHARDS = auto()
NEW_INSTANCE_SIZE = auto()
AUTO_PROVISIONING_ENABLED = auto()
NEW_SHARD_CORE_VERSION = auto()
TRIAL_MAX_VM_SIZE = auto()
SUBSCRIBED_MAX_VM_SIZE = auto()


class Setting(BaseModel):
Expand Down
81 changes: 67 additions & 14 deletions shard_core/data_model/backend/shard_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from datetime import datetime
from enum import StrEnum, auto
from functools import total_ordering
from typing import List

from pydantic import BaseModel

from .permission_model import PermissionHolder
from .subscription_model import SubscriptionStatus
from .telemetry_model import Telemetry


Expand All @@ -32,18 +32,35 @@ class ShardStatus(StrEnum):
_VM_SIZE_ORDER = ["xs", "s", "m", "l", "xl"]


@total_ordering
class VmSize(StrEnum):
XS = auto()
S = auto()
M = auto()
L = auto()
XL = auto()

def _idx(self) -> int:
return _VM_SIZE_ORDER.index(self.value)

def __lt__(self, other: object) -> bool:
if not isinstance(other, VmSize):
return NotImplemented
return _VM_SIZE_ORDER.index(self.value) < _VM_SIZE_ORDER.index(other.value)
return self._idx() < other._idx()

def __le__(self, other: object) -> bool:
if not isinstance(other, VmSize):
return NotImplemented
return self._idx() <= other._idx()

def __gt__(self, other: object) -> bool:
if not isinstance(other, VmSize):
return NotImplemented
return self._idx() > other._idx()

def __ge__(self, other: object) -> bool:
if not isinstance(other, VmSize):
return NotImplemented
return self._idx() >= other._idx()

def __eq__(self, other: object) -> bool:
return str.__eq__(self, other)
Expand All @@ -58,7 +75,7 @@ class Cloud(StrEnum):
OVHCLOUD = auto()


class ShardBase(PermissionHolder, BaseModel):
class ShardBase(BaseModel):
machine_id: str | None
hash_id: str | None = None
domain: str | None = None
Expand All @@ -81,6 +98,10 @@ class ShardBase(PermissionHolder, BaseModel):
auto_managed: bool = True
core_version: str | None = None
last_seen_backup: datetime | None = None
subscription_id: int | None = None
price_cents: int | None = None
pending_vm_size: VmSize | None = None
pending_price_cents: int | None = None

@property
def short_id(self) -> str:
Expand All @@ -91,15 +112,36 @@ class ShardDb(ShardBase):
id: int


class ShardResponse(ShardDb):
class ShardWithPermissions(ShardDb, PermissionHolder):
"""ShardDb enriched with permissions loaded from the DB."""

pass


class ShardSubscriptionSummary(BaseModel):
status: SubscriptionStatus
price_cents: int
currency: str
next_billing_date: datetime | None = None
payer_email: str | None = None
pending_vm_size: VmSize | None = None
pending_price_cents: int | None = None
paypal_manage_url: str


class ShardResponse(ShardWithPermissions):
telemetry: List[Telemetry]
telemetry_start: datetime
telemetry_end: datetime
subscription: ShardSubscriptionSummary | None = None


class ShardUpdate(BaseModel):
owner_name: str | None = None
max_vm_size: VmSize | None = None
delete_after: datetime | None = None
status: ShardStatus | None = None
core_version: str | None = None


class ShardCreateDb(BaseModel):
Expand Down Expand Up @@ -131,6 +173,10 @@ class ShardUpdateDb(BaseModel):
time_assigned: datetime | None = None
core_version: str | None = None
last_seen_backup: datetime | None = None
subscription_id: int | None = None
price_cents: int | None = None
pending_vm_size: VmSize | None = None
pending_price_cents: int | None = None


class AppUsageReport(BaseModel):
Expand All @@ -157,6 +203,10 @@ class AssignShardRequest(BaseModel):
owner_email: str | None = None


class AssignTrialRequest(BaseModel):
owner_email: str | None = None


class AssignShardResponse(BaseModel):
hash_id: str
domain: str
Expand Down Expand Up @@ -191,6 +241,15 @@ class AddPubkeyRequest(BaseModel):
pubkey: str


class BulkUpgradeCoreRequest(BaseModel):
shard_ids: list[int]


class BulkUpgradeCoreResponse(BaseModel):
upgraded: int
skipped: int


class InvalidShardStatus(Exception):
pass

Expand All @@ -206,12 +265,6 @@ class ShardLifecycleEventDb(BaseModel):
error_traceback: str | None = None


class ShardLifecycleEventResponse(BaseModel):
id: int
shard_id: int
timestamp: datetime
status_to: ShardStatus
actor: str
details: str | None = None
error_message: str | None = None
error_traceback: str | None = None
class ShardLifecycleEventResponse(ShardLifecycleEventDb):
actor_owner_name: str | None = None
actor_db_id: int | None = None
65 changes: 65 additions & 0 deletions shard_core/data_model/backend/subscription_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# DO NOT MODIFY - copied from freeshard-controller

from datetime import datetime
from enum import StrEnum

from pydantic import BaseModel


class SubscriptionStatus(StrEnum):
ACTIVE = "active"
GRACE = "grace"
ENDED = "ended"
ERROR = "error"


class SubscriptionBase(BaseModel):
paypal_subscription_id: str
status: SubscriptionStatus
payer_email: str | None = None
payer_name: str | None = None
price_cents: int
currency: str = "EUR"
next_billing_date: datetime | None = None
last_payment_failed_at: datetime | None = None
created: datetime
activated: datetime | None = None
ended: datetime | None = None


class SubscriptionDb(SubscriptionBase):
id: int


class SubscriptionCreateDb(BaseModel):
paypal_subscription_id: str
status: SubscriptionStatus
price_cents: int
currency: str = "EUR"
payer_email: str | None = None
payer_name: str | None = None
next_billing_date: datetime | None = None
created: datetime
activated: datetime | None = None


class SubscriptionUpdateDb(BaseModel):
status: SubscriptionStatus | None = None
payer_email: str | None = None
payer_name: str | None = None
price_cents: int | None = None
next_billing_date: datetime | None = None
last_payment_failed_at: datetime | None = None
activated: datetime | None = None
ended: datetime | None = None


class SubscribeResponse(BaseModel):
approval_url: str
expected_price_cents: int


class ResizeResponse(BaseModel):
approval_url: str | None = None
expected_price_cents: int | None = None
current_price_cents: int | None = None
7 changes: 6 additions & 1 deletion shard_core/data_model/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@

from pydantic import BaseModel

from shard_core.data_model.backend.shard_model import ShardBase
from shard_core.data_model.backend.shard_model import (
ShardBase,
ShardSubscriptionSummary,
)
from shard_core.database.database import set_value, get_value
from shard_core.data_model.app_meta import VMSize

Expand All @@ -17,6 +20,7 @@ class Profile(BaseModel):
delete_after: Optional[datetime] = None
vm_size: VMSize
max_vm_size: Optional[VMSize] = None
subscription: Optional[ShardSubscriptionSummary] = None

@classmethod
def from_shard(cls, shard: ShardBase):
Expand All @@ -31,6 +35,7 @@ def from_shard(cls, shard: ShardBase):
max_vm_size=(
VMSize(shard.max_vm_size.value.lower()) if shard.max_vm_size else None
),
subscription=getattr(shard, "subscription", None),
)


Expand Down
4 changes: 2 additions & 2 deletions shard_core/service/portal_controller.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from requests.exceptions import HTTPError

from shard_core.data_model.backend.shard_model import ShardBase, SasUrlResponse
from shard_core.data_model.backend.shard_model import ShardResponse, SasUrlResponse
from shard_core.data_model import profile
from shard_core.settings import settings
from shard_core.service.signed_call import signed_request
Expand Down Expand Up @@ -29,7 +29,7 @@ async def refresh_profile() -> profile.Profile | None:
return None
else:
raise
shard = ShardBase.model_validate(response.json())
shard = ShardResponse.model_validate(response.json())
profile_ = profile.Profile.from_shard(shard)
await profile.set_profile(profile_)
log.debug("refreshed profile")
Expand Down
3 changes: 2 additions & 1 deletion shard_core/web/management/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from fastapi import APIRouter

from . import apps, pairing_code
from . import apps, notify, pairing_code

router = APIRouter(
prefix="/management",
tags=["/management"],
)

router.include_router(apps.router)
router.include_router(notify.router)
router.include_router(pairing_code.router)
24 changes: 24 additions & 0 deletions shard_core/web/management/notify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import logging

from fastapi import APIRouter, status
from pydantic import BaseModel

from shard_core.service.websocket import ws_worker

log = logging.getLogger(__name__)

router = APIRouter(
prefix="/notify",
)


class NotifyRequest(BaseModel):
type: str


@router.post("", status_code=status.HTTP_204_NO_CONTENT)
async def notify(body: NotifyRequest):
try:
ws_worker.broadcast_message(body.type)
except Exception:
log.exception("failed to broadcast notify message of type %s", body.type)
24 changes: 22 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
ShardStatus,
VmSize,
ShardDb,
ShardResponse,
ShardSubscriptionSummary,
Cloud,
)
from shard_core import app_factory
Expand Down Expand Up @@ -261,8 +263,26 @@ async def app_client(mocker) -> AsyncGenerator[AsyncClient]:
)


def _shard_self_response_body(
shard: ShardDb, subscription: ShardSubscriptionSummary | None = None
) -> str:
now = datetime.now()
return ShardResponse(
**shard.model_dump(),
telemetry=[],
telemetry_start=now,
telemetry_end=now,
subscription=subscription,
).model_dump_json()


@contextmanager
def requests_mock_context(*, shard: ShardDb = None, profile: Profile = None):
def requests_mock_context(
*,
shard: ShardDb = None,
profile: Profile = None,
subscription: ShardSubscriptionSummary | None = None,
):
from shard_core.settings import settings

management_api = "https://management-mock"
Expand Down Expand Up @@ -296,7 +316,7 @@ def requests_mock_context(*, shard: ShardDb = None, profile: Profile = None):
)
rsps.get(
f"{controller_base_url}/api/shards/self",
body=(shard or mock_shard).model_dump_json(),
body=_shard_self_response_body(shard or mock_shard, subscription),
)
rsps.post(f"{controller_base_url}/api/feedback")
rsps.get(f"{controller_base_url}/api/foo")
Expand Down
Loading
Loading