Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/graphon/dsl/entities.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import abc
from collections.abc import Mapping
from enum import StrEnum, auto
from typing import Any, Protocol
Expand Down Expand Up @@ -125,4 +126,5 @@ def loadable(self) -> bool:


class TypedNodeFactory(Protocol):
@abc.abstractmethod
def create_node(self, node_config: NodeConfigDict) -> Any: ...
10 changes: 9 additions & 1 deletion src/graphon/dsl/slim/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
from collections.abc import Generator, Iterable, Mapping, Sequence
from dataclasses import dataclass, field
from typing import Any, Literal, overload, override
from typing import TYPE_CHECKING, Any, Literal, overload, override

from pydantic import StrictStr, TypeAdapter, ValidationError

Expand Down Expand Up @@ -682,3 +682,11 @@ def _parse_optional_llm_usage(payload: object) -> LLMUsage | None:
raise TypeError(msg)
normalized_payload[key] = value
return LLMUsage.from_metadata(normalized_payload)


if TYPE_CHECKING:
# static assertion to ensure SlimLLM implements LLMProtocol.
def _assert_slim_llm_protocol(
runtime: SlimLLM,
) -> LLMProtocol: # pyright: ignore[reportUnusedFunction]
return runtime
11 changes: 10 additions & 1 deletion src/graphon/dsl/tool_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast

from graphon.file.models import File
from graphon.model_runtime.entities.llm_entities import LLMUsage
Expand Down Expand Up @@ -896,3 +896,12 @@ def _decode_blob(value: object) -> bytes:
raise ToolNodeError(msg) from error
msg = "Slim blob payload must be bytes or base64 text."
raise ToolNodeError(msg)


if TYPE_CHECKING:
# static assertion to ensure SlimToolNodeRuntime implements
# ToolNodeRuntimeProtocol.
def _assert_slim_tool_node_runtime_protocol(
runtime: SlimToolNodeRuntime,
) -> ToolNodeRuntimeProtocol: # pyright: ignore[reportUnusedFunction]
return runtime
9 changes: 9 additions & 0 deletions src/graphon/file/protocols.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import abc
from collections.abc import Generator
from typing import TYPE_CHECKING, Literal, Protocol

Expand All @@ -18,26 +19,32 @@ class WorkflowFileRuntimeProtocol(Protocol):
"""

@property
@abc.abstractmethod
def multimodal_send_format(self) -> str: ...

@abc.abstractmethod
def http_get(
self,
url: str,
*,
follow_redirects: bool = True,
) -> HttpResponseProtocol: ...

@abc.abstractmethod
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: ...

@abc.abstractmethod
def load_file_bytes(self, *, file: File) -> bytes: ...

@abc.abstractmethod
def resolve_file_url(
self,
*,
file: File,
for_external: bool = True,
) -> str | None: ...

@abc.abstractmethod
def resolve_upload_file_url(
self,
*,
Expand All @@ -46,6 +53,7 @@ def resolve_upload_file_url(
for_external: bool = True,
) -> str: ...

@abc.abstractmethod
def resolve_tool_file_url(
self,
*,
Expand All @@ -54,6 +62,7 @@ def resolve_tool_file_url(
for_external: bool = True,
) -> str: ...

@abc.abstractmethod
def verify_preview_signature(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions src/graphon/graph/graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import abc
import logging
from collections import defaultdict
from collections.abc import Mapping, Sequence
Expand Down Expand Up @@ -27,6 +28,7 @@ class NodeFactory(Protocol):
allowing for different node creation strategies while maintaining type safety.
"""

@abc.abstractmethod
def create_node(self, node_config: NodeConfigDict) -> Node:
"""Create a Node instance from node configuration data.

Expand Down
2 changes: 2 additions & 0 deletions src/graphon/graph/validation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import abc
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol
Expand Down Expand Up @@ -34,6 +35,7 @@ def __init__(self, issues: Sequence[GraphValidationIssue]) -> None:
class GraphValidationRule(Protocol):
"""Protocol that individual validation rules must satisfy."""

@abc.abstractmethod
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
"""Validate the provided graph and return any discovered issues."""
...
Expand Down
12 changes: 11 additions & 1 deletion src/graphon/graph_engine/command_channels/in_memory_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

from queue import Empty, Queue
from typing import final
from typing import TYPE_CHECKING, final

from ..entities.commands import GraphEngineCommand

Expand Down Expand Up @@ -49,3 +49,13 @@ def send_command(self, command: GraphEngineCommand) -> None:

"""
self._queue.put(command)


if TYPE_CHECKING:
from .protocol import CommandChannel

# static assertion to ensure InMemoryChannel implements CommandChannel.
def _assert_command_channel(
channel: InMemoryChannel,
) -> CommandChannel: # pyright: ignore[reportUnusedFunction]
return channel
3 changes: 3 additions & 0 deletions src/graphon/graph_engine/command_channels/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
to/from a GraphEngine instance, supporting both local and distributed scenarios.
"""

import abc
from typing import Protocol

from ..entities.commands import GraphEngineCommand
Expand All @@ -16,6 +17,7 @@ class CommandChannel(Protocol):
this channel is dedicated to that single execution.
"""

@abc.abstractmethod
def fetch_commands(self) -> list[GraphEngineCommand]:
"""Fetch pending commands for this GraphEngine instance.

Expand All @@ -27,6 +29,7 @@ def fetch_commands(self) -> list[GraphEngineCommand]:
"""
...

@abc.abstractmethod
def send_command(self, command: GraphEngineCommand) -> None:
"""Send a command to be processed by this GraphEngine instance.

Expand Down
21 changes: 20 additions & 1 deletion src/graphon/graph_engine/command_channels/redis_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
Each instance uses a unique key for its command queue.
"""

import abc
import json
from contextlib import AbstractContextManager
from typing import Any, Protocol, final
from typing import TYPE_CHECKING, Any, Protocol, final

from ..entities.commands import (
AbortCommand,
Expand All @@ -27,18 +28,26 @@
class RedisPipelineProtocol(Protocol):
"""Minimal Redis pipeline contract used by the command channel."""

@abc.abstractmethod
def lrange(self, name: str, start: int, end: int) -> Any: ...
@abc.abstractmethod
def delete(self, *names: str) -> Any: ...
@abc.abstractmethod
def execute(self) -> list[Any]: ...
@abc.abstractmethod
def rpush(self, name: str, *values: str) -> Any: ...
@abc.abstractmethod
def expire(self, name: str, time: int) -> Any: ...
@abc.abstractmethod
def set(self, name: str, value: str, ex: int | None = None) -> Any: ...
@abc.abstractmethod
def get(self, name: str) -> Any: ...


class RedisClientProtocol(Protocol):
"""Redis client contract required by the command channel."""

@abc.abstractmethod
def pipeline(self) -> AbstractContextManager[RedisPipelineProtocol]: ...


Expand Down Expand Up @@ -160,3 +169,13 @@ def _has_pending_commands(self) -> bool:
pending_value, _ = pipe.execute()

return pending_value is not None


if TYPE_CHECKING:
from .protocol import CommandChannel

# static assertion to ensure RedisChannel implements CommandChannel.
def _assert_command_channel(
channel: RedisChannel,
) -> CommandChannel: # pyright: ignore[reportUnusedFunction]
return channel
20 changes: 19 additions & 1 deletion src/graphon/graph_engine/command_processing/command_handlers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import final, override
from typing import TYPE_CHECKING, final, override

from graphon.entities.pause_reason import SchedulingPause
from graphon.runtime.graph_runtime_state import GraphExecutionProtocol
Expand Down Expand Up @@ -69,3 +69,21 @@ def handle(
execution.workflow_id,
exc,
)


if TYPE_CHECKING:
# static assertions to ensure command handlers implement CommandHandler.
def _assert_abort_handler_protocol(
handler: AbortCommandHandler,
) -> CommandHandler[AbortCommand]: # pyright: ignore[reportUnusedFunction]
return handler

def _assert_pause_handler_protocol(
handler: PauseCommandHandler,
) -> CommandHandler[PauseCommand]: # pyright: ignore[reportUnusedFunction]
return handler

def _assert_update_variables_handler_protocol(
handler: UpdateVariablesCommandHandler,
) -> CommandHandler[UpdateVariablesCommand]: # pyright: ignore[reportUnusedFunction]
return handler
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Main command processor for handling external commands."""

import abc
import logging
from collections.abc import Callable
from typing import Protocol, final
Expand All @@ -15,6 +16,7 @@
class CommandHandler[CommandT: GraphEngineCommand](Protocol):
"""Protocol for command handlers."""

@abc.abstractmethod
def handle(
self,
command: CommandT,
Expand Down
11 changes: 11 additions & 0 deletions src/graphon/graph_engine/domain/node_execution.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""NodeExecution entity representing a node's execution state."""

from dataclasses import dataclass
from typing import TYPE_CHECKING

from graphon.enums import NodeState

Expand Down Expand Up @@ -40,3 +41,13 @@ def mark_skipped(self) -> None:
def increment_retry(self) -> None:
"""Increment the retry count for this node."""
self.retry_count += 1


if TYPE_CHECKING:
from graphon.runtime.graph_runtime_state import NodeExecutionProtocol

# static assertion to ensure NodeExecution implements NodeExecutionProtocol.
def _assert_node_execution_protocol(
execution: NodeExecution,
) -> NodeExecutionProtocol: # pyright: ignore[reportUnusedFunction]
return execution
8 changes: 7 additions & 1 deletion src/graphon/graph_engine/ready_queue/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

import queue
from typing import final
from typing import TYPE_CHECKING, final

from .protocol import ReadyQueue, ReadyQueueState

Expand Down Expand Up @@ -137,3 +137,9 @@ def loads(self, data: str) -> None:
# Restore items
for item in state.items:
self._queue.put(item)


if TYPE_CHECKING:
# static assertion to ensure InMemoryReadyQueue implements ReadyQueue.
def _assert_ready_queue(queue_impl: InMemoryReadyQueue) -> ReadyQueue: # pyright: ignore[reportUnusedFunction]
return queue_impl
8 changes: 8 additions & 0 deletions src/graphon/graph_engine/ready_queue/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
for execution, supporting both in-memory and persistent storage scenarios.
"""

import abc
from collections.abc import Sequence
from typing import Protocol

Expand Down Expand Up @@ -35,6 +36,7 @@ class ReadyQueue(Protocol):
that can be serialized for state storage.
"""

@abc.abstractmethod
def put(self, item: str) -> None:
"""Add a node ID to the ready queue.

Expand All @@ -44,6 +46,7 @@ def put(self, item: str) -> None:
"""
...

@abc.abstractmethod
def get(self, timeout: float | None = None) -> str:
"""Retrieve and remove a node ID from the queue.

Expand All @@ -56,6 +59,7 @@ def get(self, timeout: float | None = None) -> str:
"""
...

@abc.abstractmethod
def task_done(self) -> None:
"""Indicate that a previously retrieved task is complete.

Expand All @@ -64,6 +68,7 @@ def task_done(self) -> None:
"""
...

@abc.abstractmethod
def empty(self) -> bool:
"""Check if the queue is empty.

Expand All @@ -73,6 +78,7 @@ def empty(self) -> bool:
"""
...

@abc.abstractmethod
def qsize(self) -> int:
"""Get the approximate size of the queue.

Expand All @@ -82,6 +88,7 @@ def qsize(self) -> int:
"""
...

@abc.abstractmethod
def dumps(self) -> str:
"""Serialize the queue state to a JSON string for storage.

Expand All @@ -92,6 +99,7 @@ def dumps(self) -> str:
"""
...

@abc.abstractmethod
def loads(self, data: str) -> None:
"""Restore the queue state from a JSON string.

Expand Down
Loading
Loading