diff --git a/shard_core/data_model/backend/api_token_model.py b/shard_core/data_model/backend/api_token_model.py index 981b745..86bbd7c 100644 --- a/shard_core/data_model/backend/api_token_model.py +++ b/shard_core/data_model/backend/api_token_model.py @@ -9,7 +9,7 @@ from .permission_model import PermissionHolder, Permission -class ApiTokenResult(PermissionHolder, BaseModel): +class ApiTokenDb(BaseModel): id: uuid.UUID name: str created: datetime @@ -18,6 +18,10 @@ class ApiTokenResult(PermissionHolder, BaseModel): owner_name: str | None +class ApiTokenResult(ApiTokenDb, PermissionHolder): + pass + + class ApiTokenCreate(BaseModel): name: str permissions: Set[Permission] diff --git a/shard_core/data_model/backend/settings_model.py b/shard_core/data_model/backend/settings_model.py index fa72d54..ed04032 100644 --- a/shard_core/data_model/backend/settings_model.py +++ b/shard_core/data_model/backend/settings_model.py @@ -12,6 +12,10 @@ class SettingKey(StrEnum): MIN_NR_OF_STANDBY_SHARDS = auto() NEW_INSTANCE_SIZE = auto() + AUTO_PROVISIONING_ENABLED = auto() + NEW_SHARD_CORE_VERSION = auto() + TRIAL_MAX_VM_SIZE = auto() + SUBSCRIBED_MAX_VM_SIZE = auto() class Setting(BaseModel): diff --git a/shard_core/data_model/backend/shard_model.py b/shard_core/data_model/backend/shard_model.py index 54bfca3..e01ac15 100644 --- a/shard_core/data_model/backend/shard_model.py +++ b/shard_core/data_model/backend/shard_model.py @@ -2,12 +2,12 @@ from datetime import datetime from enum import StrEnum, auto -from functools import total_ordering from typing import List from pydantic import BaseModel from .permission_model import PermissionHolder +from .subscription_model import SubscriptionStatus from .telemetry_model import Telemetry @@ -32,7 +32,6 @@ class ShardStatus(StrEnum): _VM_SIZE_ORDER = ["xs", "s", "m", "l", "xl"] -@total_ordering class VmSize(StrEnum): XS = auto() S = auto() @@ -40,10 +39,28 @@ class VmSize(StrEnum): L = auto() XL = auto() + def _idx(self) -> int: + return _VM_SIZE_ORDER.index(self.value) + def __lt__(self, other: object) -> bool: if not isinstance(other, VmSize): return NotImplemented - return _VM_SIZE_ORDER.index(self.value) < _VM_SIZE_ORDER.index(other.value) + return self._idx() < other._idx() + + def __le__(self, other: object) -> bool: + if not isinstance(other, VmSize): + return NotImplemented + return self._idx() <= other._idx() + + def __gt__(self, other: object) -> bool: + if not isinstance(other, VmSize): + return NotImplemented + return self._idx() > other._idx() + + def __ge__(self, other: object) -> bool: + if not isinstance(other, VmSize): + return NotImplemented + return self._idx() >= other._idx() def __eq__(self, other: object) -> bool: return str.__eq__(self, other) @@ -58,7 +75,7 @@ class Cloud(StrEnum): OVHCLOUD = auto() -class ShardBase(PermissionHolder, BaseModel): +class ShardBase(BaseModel): machine_id: str | None hash_id: str | None = None domain: str | None = None @@ -81,6 +98,10 @@ class ShardBase(PermissionHolder, BaseModel): auto_managed: bool = True core_version: str | None = None last_seen_backup: datetime | None = None + subscription_id: int | None = None + price_cents: int | None = None + pending_vm_size: VmSize | None = None + pending_price_cents: int | None = None @property def short_id(self) -> str: @@ -91,8 +112,28 @@ class ShardDb(ShardBase): id: int -class ShardResponse(ShardDb): +class ShardWithPermissions(ShardDb, PermissionHolder): + """ShardDb enriched with permissions loaded from the DB.""" + + pass + + +class ShardSubscriptionSummary(BaseModel): + status: SubscriptionStatus + price_cents: int + currency: str + next_billing_date: datetime | None = None + payer_email: str | None = None + pending_vm_size: VmSize | None = None + pending_price_cents: int | None = None + paypal_manage_url: str + + +class ShardResponse(ShardWithPermissions): telemetry: List[Telemetry] + telemetry_start: datetime + telemetry_end: datetime + subscription: ShardSubscriptionSummary | None = None class ShardUpdate(BaseModel): @@ -100,6 +141,7 @@ class ShardUpdate(BaseModel): max_vm_size: VmSize | None = None delete_after: datetime | None = None status: ShardStatus | None = None + core_version: str | None = None class ShardCreateDb(BaseModel): @@ -131,6 +173,10 @@ class ShardUpdateDb(BaseModel): time_assigned: datetime | None = None core_version: str | None = None last_seen_backup: datetime | None = None + subscription_id: int | None = None + price_cents: int | None = None + pending_vm_size: VmSize | None = None + pending_price_cents: int | None = None class AppUsageReport(BaseModel): @@ -157,6 +203,10 @@ class AssignShardRequest(BaseModel): owner_email: str | None = None +class AssignTrialRequest(BaseModel): + owner_email: str | None = None + + class AssignShardResponse(BaseModel): hash_id: str domain: str @@ -191,6 +241,15 @@ class AddPubkeyRequest(BaseModel): pubkey: str +class BulkUpgradeCoreRequest(BaseModel): + shard_ids: list[int] + + +class BulkUpgradeCoreResponse(BaseModel): + upgraded: int + skipped: int + + class InvalidShardStatus(Exception): pass @@ -206,12 +265,6 @@ class ShardLifecycleEventDb(BaseModel): error_traceback: str | None = None -class ShardLifecycleEventResponse(BaseModel): - id: int - shard_id: int - timestamp: datetime - status_to: ShardStatus - actor: str - details: str | None = None - error_message: str | None = None - error_traceback: str | None = None +class ShardLifecycleEventResponse(ShardLifecycleEventDb): + actor_owner_name: str | None = None + actor_db_id: int | None = None diff --git a/shard_core/data_model/backend/subscription_model.py b/shard_core/data_model/backend/subscription_model.py new file mode 100644 index 0000000..bcdf10e --- /dev/null +++ b/shard_core/data_model/backend/subscription_model.py @@ -0,0 +1,65 @@ +# DO NOT MODIFY - copied from freeshard-controller + +from datetime import datetime +from enum import StrEnum + +from pydantic import BaseModel + + +class SubscriptionStatus(StrEnum): + ACTIVE = "active" + GRACE = "grace" + ENDED = "ended" + ERROR = "error" + + +class SubscriptionBase(BaseModel): + paypal_subscription_id: str + status: SubscriptionStatus + payer_email: str | None = None + payer_name: str | None = None + price_cents: int + currency: str = "EUR" + next_billing_date: datetime | None = None + last_payment_failed_at: datetime | None = None + created: datetime + activated: datetime | None = None + ended: datetime | None = None + + +class SubscriptionDb(SubscriptionBase): + id: int + + +class SubscriptionCreateDb(BaseModel): + paypal_subscription_id: str + status: SubscriptionStatus + price_cents: int + currency: str = "EUR" + payer_email: str | None = None + payer_name: str | None = None + next_billing_date: datetime | None = None + created: datetime + activated: datetime | None = None + + +class SubscriptionUpdateDb(BaseModel): + status: SubscriptionStatus | None = None + payer_email: str | None = None + payer_name: str | None = None + price_cents: int | None = None + next_billing_date: datetime | None = None + last_payment_failed_at: datetime | None = None + activated: datetime | None = None + ended: datetime | None = None + + +class SubscribeResponse(BaseModel): + approval_url: str + expected_price_cents: int + + +class ResizeResponse(BaseModel): + approval_url: str | None = None + expected_price_cents: int | None = None + current_price_cents: int | None = None diff --git a/shard_core/data_model/profile.py b/shard_core/data_model/profile.py index a6dce1d..c7064e1 100644 --- a/shard_core/data_model/profile.py +++ b/shard_core/data_model/profile.py @@ -3,7 +3,10 @@ from pydantic import BaseModel -from shard_core.data_model.backend.shard_model import ShardBase +from shard_core.data_model.backend.shard_model import ( + ShardBase, + ShardSubscriptionSummary, +) from shard_core.database.database import set_value, get_value from shard_core.data_model.app_meta import VMSize @@ -17,6 +20,7 @@ class Profile(BaseModel): delete_after: Optional[datetime] = None vm_size: VMSize max_vm_size: Optional[VMSize] = None + subscription: Optional[ShardSubscriptionSummary] = None @classmethod def from_shard(cls, shard: ShardBase): @@ -31,6 +35,7 @@ def from_shard(cls, shard: ShardBase): max_vm_size=( VMSize(shard.max_vm_size.value.lower()) if shard.max_vm_size else None ), + subscription=getattr(shard, "subscription", None), ) diff --git a/shard_core/service/portal_controller.py b/shard_core/service/portal_controller.py index 52768d3..26ab5f2 100644 --- a/shard_core/service/portal_controller.py +++ b/shard_core/service/portal_controller.py @@ -1,7 +1,7 @@ import logging from requests.exceptions import HTTPError -from shard_core.data_model.backend.shard_model import ShardBase, SasUrlResponse +from shard_core.data_model.backend.shard_model import ShardResponse, SasUrlResponse from shard_core.data_model import profile from shard_core.settings import settings from shard_core.service.signed_call import signed_request @@ -29,7 +29,7 @@ async def refresh_profile() -> profile.Profile | None: return None else: raise - shard = ShardBase.model_validate(response.json()) + shard = ShardResponse.model_validate(response.json()) profile_ = profile.Profile.from_shard(shard) await profile.set_profile(profile_) log.debug("refreshed profile") diff --git a/shard_core/web/management/__init__.py b/shard_core/web/management/__init__.py index 7055c80..3d76a35 100644 --- a/shard_core/web/management/__init__.py +++ b/shard_core/web/management/__init__.py @@ -1,6 +1,6 @@ from fastapi import APIRouter -from . import apps, pairing_code +from . import apps, notify, pairing_code router = APIRouter( prefix="/management", @@ -8,4 +8,5 @@ ) router.include_router(apps.router) +router.include_router(notify.router) router.include_router(pairing_code.router) diff --git a/shard_core/web/management/notify.py b/shard_core/web/management/notify.py new file mode 100644 index 0000000..68b2ad0 --- /dev/null +++ b/shard_core/web/management/notify.py @@ -0,0 +1,24 @@ +import logging + +from fastapi import APIRouter, status +from pydantic import BaseModel + +from shard_core.service.websocket import ws_worker + +log = logging.getLogger(__name__) + +router = APIRouter( + prefix="/notify", +) + + +class NotifyRequest(BaseModel): + type: str + + +@router.post("", status_code=status.HTTP_204_NO_CONTENT) +async def notify(body: NotifyRequest): + try: + ws_worker.broadcast_message(body.type) + except Exception: + log.exception("failed to broadcast notify message of type %s", body.type) diff --git a/tests/conftest.py b/tests/conftest.py index e6e6707..66ed62d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,6 +30,8 @@ ShardStatus, VmSize, ShardDb, + ShardResponse, + ShardSubscriptionSummary, Cloud, ) from shard_core import app_factory @@ -261,8 +263,26 @@ async def app_client(mocker) -> AsyncGenerator[AsyncClient]: ) +def _shard_self_response_body( + shard: ShardDb, subscription: ShardSubscriptionSummary | None = None +) -> str: + now = datetime.now() + return ShardResponse( + **shard.model_dump(), + telemetry=[], + telemetry_start=now, + telemetry_end=now, + subscription=subscription, + ).model_dump_json() + + @contextmanager -def requests_mock_context(*, shard: ShardDb = None, profile: Profile = None): +def requests_mock_context( + *, + shard: ShardDb = None, + profile: Profile = None, + subscription: ShardSubscriptionSummary | None = None, +): from shard_core.settings import settings management_api = "https://management-mock" @@ -296,7 +316,7 @@ def requests_mock_context(*, shard: ShardDb = None, profile: Profile = None): ) rsps.get( f"{controller_base_url}/api/shards/self", - body=(shard or mock_shard).model_dump_json(), + body=_shard_self_response_body(shard or mock_shard, subscription), ) rsps.post(f"{controller_base_url}/api/feedback") rsps.get(f"{controller_base_url}/api/foo") diff --git a/tests/test_notify.py b/tests/test_notify.py new file mode 100644 index 0000000..545a700 --- /dev/null +++ b/tests/test_notify.py @@ -0,0 +1,36 @@ +from httpx import AsyncClient + + +async def test_notify_broadcasts_message_type(app_client: AsyncClient, mocker): + mock_broadcast = mocker.patch( + "shard_core.web.management.notify.ws_worker.broadcast_message" + ) + response = await app_client.post( + "management/notify", + json={"type": "subscription_updated"}, + headers={"authorization": "constantSharedSecret"}, + ) + assert response.status_code == 204 + mock_broadcast.assert_called_once_with("subscription_updated") + + +async def test_notify_swallows_broadcast_errors(app_client: AsyncClient, mocker): + mocker.patch( + "shard_core.web.management.notify.ws_worker.broadcast_message", + side_effect=RuntimeError("queue exploded"), + ) + response = await app_client.post( + "management/notify", + json={"type": "subscription_updated"}, + headers={"authorization": "constantSharedSecret"}, + ) + assert response.status_code == 204 + + +async def test_notify_rejects_missing_type(app_client: AsyncClient): + response = await app_client.post( + "management/notify", + json={}, + headers={"authorization": "constantSharedSecret"}, + ) + assert response.status_code == 422 diff --git a/tests/test_profile.py b/tests/test_profile.py index d01b076..3034baa 100644 --- a/tests/test_profile.py +++ b/tests/test_profile.py @@ -1,5 +1,9 @@ +from datetime import datetime, timezone + from httpx import AsyncClient +from shard_core.data_model.backend.shard_model import ShardSubscriptionSummary +from shard_core.data_model.backend.subscription_model import SubscriptionStatus from shard_core.data_model.profile import Profile from tests import conftest @@ -10,3 +14,28 @@ async def test_profile(requests_mock, app_client: AsyncClient): assert Profile.model_validate(response.json()) == Profile.from_shard( conftest.mock_shard ) + + +async def test_profile_includes_subscription(app_client: AsyncClient): + subscription = ShardSubscriptionSummary( + status=SubscriptionStatus.ACTIVE, + price_cents=499, + currency="EUR", + next_billing_date=datetime(2026, 7, 1, tzinfo=timezone.utc), + payer_email="payer@example.com", + paypal_manage_url="https://paypal.example/manage/abc", + ) + with conftest.requests_mock_context(subscription=subscription): + response = await app_client.get("protected/management/profile") + response.raise_for_status() + profile = Profile.model_validate(response.json()) + assert profile.subscription == subscription + + +async def test_profile_without_subscription_is_none( + requests_mock, app_client: AsyncClient +): + response = await app_client.get("protected/management/profile") + response.raise_for_status() + profile = Profile.model_validate(response.json()) + assert profile.subscription is None