diff --git a/src/graphon/dsl/entities.py b/src/graphon/dsl/entities.py index 3871b77..d9a17cc 100644 --- a/src/graphon/dsl/entities.py +++ b/src/graphon/dsl/entities.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc from collections.abc import Mapping from enum import StrEnum, auto from typing import Any, Protocol @@ -125,4 +126,5 @@ def loadable(self) -> bool: class TypedNodeFactory(Protocol): + @abc.abstractmethod def create_node(self, node_config: NodeConfigDict) -> Any: ... diff --git a/src/graphon/dsl/slim/llm.py b/src/graphon/dsl/slim/llm.py index 0bc8131..e31defb 100644 --- a/src/graphon/dsl/slim/llm.py +++ b/src/graphon/dsl/slim/llm.py @@ -3,7 +3,7 @@ import json from collections.abc import Generator, Iterable, Mapping, Sequence from dataclasses import dataclass, field -from typing import Any, Literal, overload, override +from typing import TYPE_CHECKING, Any, Literal, overload, override from pydantic import StrictStr, TypeAdapter, ValidationError @@ -682,3 +682,11 @@ def _parse_optional_llm_usage(payload: object) -> LLMUsage | None: raise TypeError(msg) normalized_payload[key] = value return LLMUsage.from_metadata(normalized_payload) + + +if TYPE_CHECKING: + # static assertion to ensure SlimLLM implements LLMProtocol. + def _assert_slim_llm_protocol( + runtime: SlimLLM, + ) -> LLMProtocol: # pyright: ignore[reportUnusedFunction] + return runtime diff --git a/src/graphon/dsl/tool_runtime.py b/src/graphon/dsl/tool_runtime.py index f389fbf..6623d5a 100644 --- a/src/graphon/dsl/tool_runtime.py +++ b/src/graphon/dsl/tool_runtime.py @@ -4,7 +4,7 @@ from collections.abc import Callable, Generator, Iterable, Mapping, Sequence from dataclasses import dataclass, field from enum import StrEnum -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from graphon.file.models import File from graphon.model_runtime.entities.llm_entities import LLMUsage @@ -896,3 +896,12 @@ def _decode_blob(value: object) -> bytes: raise ToolNodeError(msg) from error msg = "Slim blob payload must be bytes or base64 text." raise ToolNodeError(msg) + + +if TYPE_CHECKING: + # static assertion to ensure SlimToolNodeRuntime implements + # ToolNodeRuntimeProtocol. + def _assert_slim_tool_node_runtime_protocol( + runtime: SlimToolNodeRuntime, + ) -> ToolNodeRuntimeProtocol: # pyright: ignore[reportUnusedFunction] + return runtime diff --git a/src/graphon/file/protocols.py b/src/graphon/file/protocols.py index c376c02..1caaecd 100644 --- a/src/graphon/file/protocols.py +++ b/src/graphon/file/protocols.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc from collections.abc import Generator from typing import TYPE_CHECKING, Literal, Protocol @@ -18,8 +19,10 @@ class WorkflowFileRuntimeProtocol(Protocol): """ @property + @abc.abstractmethod def multimodal_send_format(self) -> str: ... + @abc.abstractmethod def http_get( self, url: str, @@ -27,10 +30,13 @@ def http_get( follow_redirects: bool = True, ) -> HttpResponseProtocol: ... + @abc.abstractmethod def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: ... + @abc.abstractmethod def load_file_bytes(self, *, file: File) -> bytes: ... + @abc.abstractmethod def resolve_file_url( self, *, @@ -38,6 +44,7 @@ def resolve_file_url( for_external: bool = True, ) -> str | None: ... + @abc.abstractmethod def resolve_upload_file_url( self, *, @@ -46,6 +53,7 @@ def resolve_upload_file_url( for_external: bool = True, ) -> str: ... + @abc.abstractmethod def resolve_tool_file_url( self, *, @@ -54,6 +62,7 @@ def resolve_tool_file_url( for_external: bool = True, ) -> str: ... + @abc.abstractmethod def verify_preview_signature( self, *, diff --git a/src/graphon/graph/graph.py b/src/graphon/graph/graph.py index ac79390..2eb9481 100644 --- a/src/graphon/graph/graph.py +++ b/src/graphon/graph/graph.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc import logging from collections import defaultdict from collections.abc import Mapping, Sequence @@ -27,6 +28,7 @@ class NodeFactory(Protocol): allowing for different node creation strategies while maintaining type safety. """ + @abc.abstractmethod def create_node(self, node_config: NodeConfigDict) -> Node: """Create a Node instance from node configuration data. diff --git a/src/graphon/graph/validation.py b/src/graphon/graph/validation.py index b66def8..af3ac0a 100644 --- a/src/graphon/graph/validation.py +++ b/src/graphon/graph/validation.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Protocol @@ -34,6 +35,7 @@ def __init__(self, issues: Sequence[GraphValidationIssue]) -> None: class GraphValidationRule(Protocol): """Protocol that individual validation rules must satisfy.""" + @abc.abstractmethod def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: """Validate the provided graph and return any discovered issues.""" ... diff --git a/src/graphon/graph_engine/command_channels/in_memory_channel.py b/src/graphon/graph_engine/command_channels/in_memory_channel.py index 557100e..ad91d00 100644 --- a/src/graphon/graph_engine/command_channels/in_memory_channel.py +++ b/src/graphon/graph_engine/command_channels/in_memory_channel.py @@ -5,7 +5,7 @@ """ from queue import Empty, Queue -from typing import final +from typing import TYPE_CHECKING, final from ..entities.commands import GraphEngineCommand @@ -49,3 +49,13 @@ def send_command(self, command: GraphEngineCommand) -> None: """ self._queue.put(command) + + +if TYPE_CHECKING: + from .protocol import CommandChannel + + # static assertion to ensure InMemoryChannel implements CommandChannel. + def _assert_command_channel( + channel: InMemoryChannel, + ) -> CommandChannel: # pyright: ignore[reportUnusedFunction] + return channel diff --git a/src/graphon/graph_engine/command_channels/protocol.py b/src/graphon/graph_engine/command_channels/protocol.py index 10e3e59..956fd30 100644 --- a/src/graphon/graph_engine/command_channels/protocol.py +++ b/src/graphon/graph_engine/command_channels/protocol.py @@ -4,6 +4,7 @@ to/from a GraphEngine instance, supporting both local and distributed scenarios. """ +import abc from typing import Protocol from ..entities.commands import GraphEngineCommand @@ -16,6 +17,7 @@ class CommandChannel(Protocol): this channel is dedicated to that single execution. """ + @abc.abstractmethod def fetch_commands(self) -> list[GraphEngineCommand]: """Fetch pending commands for this GraphEngine instance. @@ -27,6 +29,7 @@ def fetch_commands(self) -> list[GraphEngineCommand]: """ ... + @abc.abstractmethod def send_command(self, command: GraphEngineCommand) -> None: """Send a command to be processed by this GraphEngine instance. diff --git a/src/graphon/graph_engine/command_channels/redis_channel.py b/src/graphon/graph_engine/command_channels/redis_channel.py index 79312dc..8f45d27 100644 --- a/src/graphon/graph_engine/command_channels/redis_channel.py +++ b/src/graphon/graph_engine/command_channels/redis_channel.py @@ -5,9 +5,10 @@ Each instance uses a unique key for its command queue. """ +import abc import json from contextlib import AbstractContextManager -from typing import Any, Protocol, final +from typing import TYPE_CHECKING, Any, Protocol, final from ..entities.commands import ( AbortCommand, @@ -27,18 +28,26 @@ class RedisPipelineProtocol(Protocol): """Minimal Redis pipeline contract used by the command channel.""" + @abc.abstractmethod def lrange(self, name: str, start: int, end: int) -> Any: ... + @abc.abstractmethod def delete(self, *names: str) -> Any: ... + @abc.abstractmethod def execute(self) -> list[Any]: ... + @abc.abstractmethod def rpush(self, name: str, *values: str) -> Any: ... + @abc.abstractmethod def expire(self, name: str, time: int) -> Any: ... + @abc.abstractmethod def set(self, name: str, value: str, ex: int | None = None) -> Any: ... + @abc.abstractmethod def get(self, name: str) -> Any: ... class RedisClientProtocol(Protocol): """Redis client contract required by the command channel.""" + @abc.abstractmethod def pipeline(self) -> AbstractContextManager[RedisPipelineProtocol]: ... @@ -160,3 +169,13 @@ def _has_pending_commands(self) -> bool: pending_value, _ = pipe.execute() return pending_value is not None + + +if TYPE_CHECKING: + from .protocol import CommandChannel + + # static assertion to ensure RedisChannel implements CommandChannel. + def _assert_command_channel( + channel: RedisChannel, + ) -> CommandChannel: # pyright: ignore[reportUnusedFunction] + return channel diff --git a/src/graphon/graph_engine/command_processing/command_handlers.py b/src/graphon/graph_engine/command_processing/command_handlers.py index 745f8e4..978e25b 100644 --- a/src/graphon/graph_engine/command_processing/command_handlers.py +++ b/src/graphon/graph_engine/command_processing/command_handlers.py @@ -1,5 +1,5 @@ import logging -from typing import final, override +from typing import TYPE_CHECKING, final, override from graphon.entities.pause_reason import SchedulingPause from graphon.runtime.graph_runtime_state import GraphExecutionProtocol @@ -69,3 +69,21 @@ def handle( execution.workflow_id, exc, ) + + +if TYPE_CHECKING: + # static assertions to ensure command handlers implement CommandHandler. + def _assert_abort_handler_protocol( + handler: AbortCommandHandler, + ) -> CommandHandler[AbortCommand]: # pyright: ignore[reportUnusedFunction] + return handler + + def _assert_pause_handler_protocol( + handler: PauseCommandHandler, + ) -> CommandHandler[PauseCommand]: # pyright: ignore[reportUnusedFunction] + return handler + + def _assert_update_variables_handler_protocol( + handler: UpdateVariablesCommandHandler, + ) -> CommandHandler[UpdateVariablesCommand]: # pyright: ignore[reportUnusedFunction] + return handler diff --git a/src/graphon/graph_engine/command_processing/command_processor.py b/src/graphon/graph_engine/command_processing/command_processor.py index 99903d9..3be4317 100644 --- a/src/graphon/graph_engine/command_processing/command_processor.py +++ b/src/graphon/graph_engine/command_processing/command_processor.py @@ -1,5 +1,6 @@ """Main command processor for handling external commands.""" +import abc import logging from collections.abc import Callable from typing import Protocol, final @@ -15,6 +16,7 @@ class CommandHandler[CommandT: GraphEngineCommand](Protocol): """Protocol for command handlers.""" + @abc.abstractmethod def handle( self, command: CommandT, diff --git a/src/graphon/graph_engine/domain/node_execution.py b/src/graphon/graph_engine/domain/node_execution.py index 17270ee..aeb25c4 100644 --- a/src/graphon/graph_engine/domain/node_execution.py +++ b/src/graphon/graph_engine/domain/node_execution.py @@ -1,6 +1,7 @@ """NodeExecution entity representing a node's execution state.""" from dataclasses import dataclass +from typing import TYPE_CHECKING from graphon.enums import NodeState @@ -40,3 +41,13 @@ def mark_skipped(self) -> None: def increment_retry(self) -> None: """Increment the retry count for this node.""" self.retry_count += 1 + + +if TYPE_CHECKING: + from graphon.runtime.graph_runtime_state import NodeExecutionProtocol + + # static assertion to ensure NodeExecution implements NodeExecutionProtocol. + def _assert_node_execution_protocol( + execution: NodeExecution, + ) -> NodeExecutionProtocol: # pyright: ignore[reportUnusedFunction] + return execution diff --git a/src/graphon/graph_engine/ready_queue/in_memory.py b/src/graphon/graph_engine/ready_queue/in_memory.py index 19dee5a..5677880 100644 --- a/src/graphon/graph_engine/ready_queue/in_memory.py +++ b/src/graphon/graph_engine/ready_queue/in_memory.py @@ -5,7 +5,7 @@ """ import queue -from typing import final +from typing import TYPE_CHECKING, final from .protocol import ReadyQueue, ReadyQueueState @@ -137,3 +137,9 @@ def loads(self, data: str) -> None: # Restore items for item in state.items: self._queue.put(item) + + +if TYPE_CHECKING: + # static assertion to ensure InMemoryReadyQueue implements ReadyQueue. + def _assert_ready_queue(queue_impl: InMemoryReadyQueue) -> ReadyQueue: # pyright: ignore[reportUnusedFunction] + return queue_impl diff --git a/src/graphon/graph_engine/ready_queue/protocol.py b/src/graphon/graph_engine/ready_queue/protocol.py index 6c53677..cd1cb55 100644 --- a/src/graphon/graph_engine/ready_queue/protocol.py +++ b/src/graphon/graph_engine/ready_queue/protocol.py @@ -4,6 +4,7 @@ for execution, supporting both in-memory and persistent storage scenarios. """ +import abc from collections.abc import Sequence from typing import Protocol @@ -35,6 +36,7 @@ class ReadyQueue(Protocol): that can be serialized for state storage. """ + @abc.abstractmethod def put(self, item: str) -> None: """Add a node ID to the ready queue. @@ -44,6 +46,7 @@ def put(self, item: str) -> None: """ ... + @abc.abstractmethod def get(self, timeout: float | None = None) -> str: """Retrieve and remove a node ID from the queue. @@ -56,6 +59,7 @@ def get(self, timeout: float | None = None) -> str: """ ... + @abc.abstractmethod def task_done(self) -> None: """Indicate that a previously retrieved task is complete. @@ -64,6 +68,7 @@ def task_done(self) -> None: """ ... + @abc.abstractmethod def empty(self) -> bool: """Check if the queue is empty. @@ -73,6 +78,7 @@ def empty(self) -> bool: """ ... + @abc.abstractmethod def qsize(self) -> int: """Get the approximate size of the queue. @@ -82,6 +88,7 @@ def qsize(self) -> int: """ ... + @abc.abstractmethod def dumps(self) -> str: """Serialize the queue state to a JSON string for storage. @@ -92,6 +99,7 @@ def dumps(self) -> str: """ ... + @abc.abstractmethod def loads(self, data: str) -> None: """Restore the queue state from a JSON string. diff --git a/src/graphon/graph_engine/response_coordinator/coordinator.py b/src/graphon/graph_engine/response_coordinator/coordinator.py index 0216b40..89a4e78 100644 --- a/src/graphon/graph_engine/response_coordinator/coordinator.py +++ b/src/graphon/graph_engine/response_coordinator/coordinator.py @@ -791,3 +791,14 @@ def loads(self, data: str) -> None: if state.active_session else None ) + + +if TYPE_CHECKING: + from graphon.runtime.graph_runtime_state import ResponseStreamCoordinatorProtocol + + # static assertion to ensure ResponseStreamCoordinator implements + # ResponseStreamCoordinatorProtocol. + def _assert_response_stream_coordinator_protocol( + coordinator: ResponseStreamCoordinator, + ) -> ResponseStreamCoordinatorProtocol: # pyright: ignore[reportUnusedFunction] + return coordinator diff --git a/src/graphon/http/client.py b/src/graphon/http/client.py index a258deb..514a397 100644 --- a/src/graphon/http/client.py +++ b/src/graphon/http/client.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Any, override +from typing import TYPE_CHECKING, Any, override import httpx @@ -102,3 +102,11 @@ def _normalize_request_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]: ) return request_kwargs + + +if TYPE_CHECKING: + # static assertion to ensure HttpxHttpClient implements HttpClientProtocol. + def _assert_http_client_protocol( + client: HttpxHttpClient, + ) -> HttpClientProtocol: # pyright: ignore[reportUnusedFunction] + return client diff --git a/src/graphon/http/protocols.py b/src/graphon/http/protocols.py index 505dfa0..ce31c93 100644 --- a/src/graphon/http/protocols.py +++ b/src/graphon/http/protocols.py @@ -1,3 +1,4 @@ +import abc from collections.abc import Mapping from typing import Any, Protocol @@ -6,38 +7,51 @@ class HttpResponseProtocol(Protocol): @property + @abc.abstractmethod def headers(self) -> Mapping[str, str]: ... @property + @abc.abstractmethod def content(self) -> bytes: ... @property + @abc.abstractmethod def status_code(self) -> int: ... @property + @abc.abstractmethod def text(self) -> str: ... @property + @abc.abstractmethod def is_success(self) -> bool: ... + @abc.abstractmethod def raise_for_status(self) -> None: ... class HttpClientProtocol(Protocol): @property + @abc.abstractmethod def max_retries_exceeded_error(self) -> type[Exception]: ... @property + @abc.abstractmethod def request_error(self) -> type[Exception]: ... + @abc.abstractmethod def get(self, url: str, max_retries: int = ..., **kwargs: Any) -> HttpResponse: ... + @abc.abstractmethod def head(self, url: str, max_retries: int = ..., **kwargs: Any) -> HttpResponse: ... + @abc.abstractmethod def post(self, url: str, max_retries: int = ..., **kwargs: Any) -> HttpResponse: ... + @abc.abstractmethod def put(self, url: str, max_retries: int = ..., **kwargs: Any) -> HttpResponse: ... + @abc.abstractmethod def delete( self, url: str, @@ -45,6 +59,7 @@ def delete( **kwargs: Any, ) -> HttpResponse: ... + @abc.abstractmethod def patch( self, url: str, diff --git a/src/graphon/http/response.py b/src/graphon/http/response.py index e033fd8..781ed5f 100644 --- a/src/graphon/http/response.py +++ b/src/graphon/http/response.py @@ -122,3 +122,13 @@ def _extract_charset_from_content_type(self) -> str | None: def raise_for_status(self) -> None: if not self.is_success: raise HttpStatusError(self) + + +if TYPE_CHECKING: + from .protocols import HttpResponseProtocol + + # static assertion to ensure HttpResponse implements HttpResponseProtocol. + def _assert_http_response_protocol( + response: HttpResponse, + ) -> HttpResponseProtocol: # pyright: ignore[reportUnusedFunction] + return response diff --git a/src/graphon/model_runtime/memory/prompt_message_memory.py b/src/graphon/model_runtime/memory/prompt_message_memory.py index ea2c245..ccb6b36 100644 --- a/src/graphon/model_runtime/memory/prompt_message_memory.py +++ b/src/graphon/model_runtime/memory/prompt_message_memory.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc from collections.abc import Sequence from typing import Protocol @@ -11,6 +12,7 @@ class PromptMessageMemory(Protocol): """Port for loading memory as prompt messages.""" + @abc.abstractmethod def get_history_prompt_messages( self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, diff --git a/src/graphon/model_runtime/model_providers/base/tokenizers/gpt2_tokenizer.py b/src/graphon/model_runtime/model_providers/base/tokenizers/gpt2_tokenizer.py index 2e33253..7bb3493 100644 --- a/src/graphon/model_runtime/model_providers/base/tokenizers/gpt2_tokenizer.py +++ b/src/graphon/model_runtime/model_providers/base/tokenizers/gpt2_tokenizer.py @@ -1,3 +1,4 @@ +import abc import logging from collections.abc import Sequence from pathlib import Path @@ -8,6 +9,7 @@ class _TokenizerProtocol(Protocol): + @abc.abstractmethod def encode(self, text: str) -> Sequence[int]: ... diff --git a/src/graphon/model_runtime/protocols/llm_runtime.py b/src/graphon/model_runtime/protocols/llm_runtime.py index 08b52d1..0e4c9ff 100644 --- a/src/graphon/model_runtime/protocols/llm_runtime.py +++ b/src/graphon/model_runtime/protocols/llm_runtime.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc from collections.abc import Generator, Sequence from typing import Any, Literal, Protocol, overload, runtime_checkable @@ -22,6 +23,7 @@ class LLMModelRuntime(ModelProviderRuntime, Protocol): """Runtime surface required by LLM-backed model wrappers.""" @overload + @abc.abstractmethod def invoke_llm( self, *, @@ -36,6 +38,7 @@ def invoke_llm( ) -> LLMResult: ... @overload + @abc.abstractmethod def invoke_llm( self, *, @@ -49,6 +52,7 @@ def invoke_llm( stream: Literal[True], ) -> Generator[LLMResultChunk, None, None]: ... + @abc.abstractmethod def invoke_llm( self, *, @@ -63,6 +67,7 @@ def invoke_llm( ) -> LLMResult | Generator[LLMResultChunk, None, None]: ... @overload + @abc.abstractmethod def invoke_llm_with_structured_output( self, *, @@ -77,6 +82,7 @@ def invoke_llm_with_structured_output( ) -> LLMResultWithStructuredOutput: ... @overload + @abc.abstractmethod def invoke_llm_with_structured_output( self, *, @@ -90,6 +96,7 @@ def invoke_llm_with_structured_output( stream: Literal[True], ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + @abc.abstractmethod def invoke_llm_with_structured_output( self, *, @@ -106,6 +113,7 @@ def invoke_llm_with_structured_output( | Generator[LLMResultChunkWithStructuredOutput, None, None] ): ... + @abc.abstractmethod def get_llm_num_tokens( self, *, diff --git a/src/graphon/model_runtime/protocols/moderation_runtime.py b/src/graphon/model_runtime/protocols/moderation_runtime.py index 8cf6ae9..52c4087 100644 --- a/src/graphon/model_runtime/protocols/moderation_runtime.py +++ b/src/graphon/model_runtime/protocols/moderation_runtime.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc from typing import Any, Protocol, runtime_checkable from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime @@ -9,6 +10,7 @@ class ModerationModelRuntime(ModelProviderRuntime, Protocol): """Runtime surface required by moderation model wrappers.""" + @abc.abstractmethod def invoke_moderation( self, *, diff --git a/src/graphon/model_runtime/protocols/provider_runtime.py b/src/graphon/model_runtime/protocols/provider_runtime.py index 83c35fe..5cc98d7 100644 --- a/src/graphon/model_runtime/protocols/provider_runtime.py +++ b/src/graphon/model_runtime/protocols/provider_runtime.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc from collections.abc import Sequence from typing import Any, Protocol, runtime_checkable @@ -11,8 +12,10 @@ class ModelProviderRuntime(Protocol): """Shared provider discovery, credential validation, and schema lookup.""" + @abc.abstractmethod def fetch_model_providers(self) -> Sequence[ProviderEntity]: ... + @abc.abstractmethod def get_provider_icon( self, *, @@ -21,6 +24,7 @@ def get_provider_icon( lang: str, ) -> tuple[bytes, str]: ... + @abc.abstractmethod def validate_provider_credentials( self, *, @@ -28,6 +32,7 @@ def validate_provider_credentials( credentials: dict[str, Any], ) -> None: ... + @abc.abstractmethod def validate_model_credentials( self, *, @@ -37,6 +42,7 @@ def validate_model_credentials( credentials: dict[str, Any], ) -> None: ... + @abc.abstractmethod def get_model_schema( self, *, diff --git a/src/graphon/model_runtime/protocols/rerank_runtime.py b/src/graphon/model_runtime/protocols/rerank_runtime.py index aa3814a..2f3a953 100644 --- a/src/graphon/model_runtime/protocols/rerank_runtime.py +++ b/src/graphon/model_runtime/protocols/rerank_runtime.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc from typing import Any, Protocol, runtime_checkable from graphon.model_runtime.entities.rerank_entities import ( @@ -13,6 +14,7 @@ class RerankModelRuntime(ModelProviderRuntime, Protocol): """Runtime surface required by rerank model wrappers.""" + @abc.abstractmethod def invoke_rerank( self, *, @@ -25,6 +27,7 @@ def invoke_rerank( top_n: int | None, ) -> RerankResult: ... + @abc.abstractmethod def invoke_multimodal_rerank( self, *, diff --git a/src/graphon/model_runtime/protocols/speech_to_text_runtime.py b/src/graphon/model_runtime/protocols/speech_to_text_runtime.py index 8f59a62..07096a6 100644 --- a/src/graphon/model_runtime/protocols/speech_to_text_runtime.py +++ b/src/graphon/model_runtime/protocols/speech_to_text_runtime.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc from typing import IO, Any, Protocol, runtime_checkable from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime @@ -9,6 +10,7 @@ class SpeechToTextModelRuntime(ModelProviderRuntime, Protocol): """Runtime surface required by speech-to-text model wrappers.""" + @abc.abstractmethod def invoke_speech_to_text( self, *, diff --git a/src/graphon/model_runtime/protocols/text_embedding_runtime.py b/src/graphon/model_runtime/protocols/text_embedding_runtime.py index 4938ccd..b7808b8 100644 --- a/src/graphon/model_runtime/protocols/text_embedding_runtime.py +++ b/src/graphon/model_runtime/protocols/text_embedding_runtime.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc from typing import Any, Protocol, runtime_checkable from graphon.model_runtime.entities.text_embedding_entities import ( @@ -13,6 +14,7 @@ class TextEmbeddingModelRuntime(ModelProviderRuntime, Protocol): """Runtime surface required by text and multimodal embedding wrappers.""" + @abc.abstractmethod def invoke_text_embedding( self, *, @@ -23,6 +25,7 @@ def invoke_text_embedding( input_type: EmbeddingInputType, ) -> EmbeddingResult: ... + @abc.abstractmethod def invoke_multimodal_embedding( self, *, @@ -33,6 +36,7 @@ def invoke_multimodal_embedding( input_type: EmbeddingInputType, ) -> EmbeddingResult: ... + @abc.abstractmethod def get_text_embedding_num_tokens( self, *, diff --git a/src/graphon/model_runtime/protocols/tts_runtime.py b/src/graphon/model_runtime/protocols/tts_runtime.py index 2b9129a..37a924a 100644 --- a/src/graphon/model_runtime/protocols/tts_runtime.py +++ b/src/graphon/model_runtime/protocols/tts_runtime.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc from collections.abc import Iterable from typing import Any, Protocol, runtime_checkable @@ -10,6 +11,7 @@ class TTSModelRuntime(ModelProviderRuntime, Protocol): """Runtime surface required by text-to-speech model wrappers.""" + @abc.abstractmethod def invoke_tts( self, *, @@ -20,6 +22,7 @@ def invoke_tts( voice: str, ) -> Iterable[bytes]: ... + @abc.abstractmethod def get_tts_model_voices( self, *, diff --git a/src/graphon/nodes/code/code_node.py b/src/graphon/nodes/code/code_node.py index a37039d..37c4c20 100644 --- a/src/graphon/nodes/code/code_node.py +++ b/src/graphon/nodes/code/code_node.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc from collections.abc import Mapping, Sequence from decimal import Decimal from textwrap import dedent @@ -23,6 +24,7 @@ class CodeExecutorProtocol(Protocol): + @abc.abstractmethod def execute( self, *, @@ -31,6 +33,7 @@ def execute( inputs: Mapping[str, Any], ) -> Mapping[str, Any]: ... + @abc.abstractmethod def is_execution_error(self, error: Exception) -> bool: ... diff --git a/src/graphon/nodes/llm/file_saver.py b/src/graphon/nodes/llm/file_saver.py index c6708ae..9372b1d 100644 --- a/src/graphon/nodes/llm/file_saver.py +++ b/src/graphon/nodes/llm/file_saver.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc import mimetypes import typing as tp @@ -21,6 +22,7 @@ class LLMFileSaver(tp.Protocol): LLM. """ + @abc.abstractmethod def save_binary_string( self, data: bytes, @@ -58,6 +60,7 @@ def save_binary_string( """ raise NotImplementedError + @abc.abstractmethod def save_remote_url(self, url: str, file_type: FileType) -> File: """save_remote_url saves the file from a remote url returned by LLM. @@ -198,3 +201,9 @@ def _validate_extension_override(extension_override: str | None) -> str | None: extension_override, ) return extension_override + + +if tp.TYPE_CHECKING: + # static assertion to ensure FileSaverImpl implements LLMFileSaver. + def _assert_llm_file_saver(saver: FileSaverImpl) -> LLMFileSaver: # pyright: ignore[reportUnusedFunction] + return saver diff --git a/src/graphon/nodes/llm/protocols.py b/src/graphon/nodes/llm/protocols.py index 740d039..1c307f2 100644 --- a/src/graphon/nodes/llm/protocols.py +++ b/src/graphon/nodes/llm/protocols.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc from typing import Any, Protocol from graphon.nodes.llm.runtime_protocols import LLMProtocol @@ -8,6 +9,7 @@ class CredentialsProvider(Protocol): """Port for loading runtime credentials for a provider/model pair.""" + @abc.abstractmethod def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: """Return credentials for the target provider/model or raise a domain error.""" ... @@ -16,6 +18,7 @@ def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: class ModelFactory(Protocol): """Port for creating prepared graph-facing LLM runtimes for execution.""" + @abc.abstractmethod def init_model_instance( self, provider_name: str, diff --git a/src/graphon/nodes/llm/runtime_protocols.py b/src/graphon/nodes/llm/runtime_protocols.py index efafc4c..3b31cb3 100644 --- a/src/graphon/nodes/llm/runtime_protocols.py +++ b/src/graphon/nodes/llm/runtime_protocols.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc from collections.abc import Generator, Mapping, Sequence from typing import Any, Literal, Protocol, overload @@ -22,25 +23,33 @@ class LLMProtocol(Protocol): """A graph-facing LLM runtime adapter for node execution.""" @property + @abc.abstractmethod def provider(self) -> str: ... @property + @abc.abstractmethod def model_name(self) -> str: ... @property + @abc.abstractmethod def parameters(self) -> Mapping[str, Any]: ... @parameters.setter + @abc.abstractmethod def parameters(self, value: Mapping[str, Any]) -> None: ... @property + @abc.abstractmethod def stop(self) -> Sequence[str] | None: ... + @abc.abstractmethod def get_model_schema(self) -> AIModelEntity: ... + @abc.abstractmethod def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: ... @overload + @abc.abstractmethod def invoke_llm( self, *, @@ -52,6 +61,7 @@ def invoke_llm( ) -> LLMResult: ... @overload + @abc.abstractmethod def invoke_llm( self, *, @@ -62,6 +72,7 @@ def invoke_llm( stream: Literal[True], ) -> Generator[LLMResultChunk, None, None]: ... + @abc.abstractmethod def invoke_llm( self, *, @@ -73,6 +84,7 @@ def invoke_llm( ) -> LLMResult | Generator[LLMResultChunk, None, None]: ... @overload + @abc.abstractmethod def invoke_llm_with_structured_output( self, *, @@ -84,6 +96,7 @@ def invoke_llm_with_structured_output( ) -> LLMResultWithStructuredOutput: ... @overload + @abc.abstractmethod def invoke_llm_with_structured_output( self, *, @@ -94,6 +107,7 @@ def invoke_llm_with_structured_output( stream: Literal[True], ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + @abc.abstractmethod def invoke_llm_with_structured_output( self, *, @@ -107,12 +121,14 @@ def invoke_llm_with_structured_output( | Generator[LLMResultChunkWithStructuredOutput, None, None] ): ... + @abc.abstractmethod def is_structured_output_parse_error(self, error: Exception) -> bool: ... class PromptMessageSerializerProtocol(Protocol): """Port for converting compiled prompt messages into persisted process data.""" + @abc.abstractmethod def serialize( self, *, @@ -124,4 +140,5 @@ def serialize( class RetrieverAttachmentLoaderProtocol(Protocol): """Port for resolving retriever segment attachments into graph file references.""" + @abc.abstractmethod def load(self, *, segment_id: str) -> Sequence[File]: ... diff --git a/src/graphon/nodes/protocols.py b/src/graphon/nodes/protocols.py index 3557792..df903d6 100644 --- a/src/graphon/nodes/protocols.py +++ b/src/graphon/nodes/protocols.py @@ -1,3 +1,4 @@ +import abc from collections.abc import Generator, Mapping from typing import Any, Protocol @@ -6,10 +7,12 @@ class FileManagerProtocol(Protocol): + @abc.abstractmethod def download(self, f: File, /) -> bytes: ... class ToolFileManagerProtocol(Protocol): + @abc.abstractmethod def create_file_by_raw( self, *, @@ -18,6 +21,7 @@ def create_file_by_raw( filename: str | None = None, ) -> Any: ... + @abc.abstractmethod def get_file_generator_by_tool_file_id( self, tool_file_id: str, @@ -29,6 +33,7 @@ class FileReferenceFactoryProtocol(Protocol): format. It enforces approriate permission filtering for the file. """ + @abc.abstractmethod def build_from_mapping(self, *, mapping: Mapping[str, Any]) -> File: ... diff --git a/src/graphon/nodes/runtime.py b/src/graphon/nodes/runtime.py index 2843e4a..fb35f66 100644 --- a/src/graphon/nodes/runtime.py +++ b/src/graphon/nodes/runtime.py @@ -26,6 +26,7 @@ class ToolNodeRuntimeProtocol(Protocol): translate between graph-owned abstractions and `core.tools` internals. """ + @abc.abstractmethod def get_runtime( self, *, @@ -35,12 +36,14 @@ def get_runtime( node_execution_id: str | None = None, ) -> ToolRuntimeHandle: ... + @abc.abstractmethod def get_runtime_parameters( self, *, tool_runtime: ToolRuntimeHandle, ) -> Sequence[ToolRuntimeParameter]: ... + @abc.abstractmethod def invoke( self, *, @@ -50,12 +53,14 @@ def invoke( provider_name: str, ) -> Generator[ToolRuntimeMessage, None, None]: ... + @abc.abstractmethod def get_usage( self, *, tool_runtime: ToolRuntimeHandle, ) -> LLMUsage: ... + @abc.abstractmethod def build_file_reference(self, *, mapping: Mapping[str, Any]) -> Any: ... @@ -85,6 +90,7 @@ def create_form( class HumanInputFormRepositoryBindableRuntimeProtocol(Protocol): """Optional capability for runtimes that require explicit repository binding.""" + @abc.abstractmethod def with_form_repository( self, form_repository: object, @@ -130,22 +136,29 @@ def _normalize_human_input_runtime( class HumanInputFormStateProtocol(Protocol): @property + @abc.abstractmethod def id(self) -> str: ... @property + @abc.abstractmethod def rendered_content(self) -> str: ... @property + @abc.abstractmethod def selected_action_id(self) -> str | None: ... @property + @abc.abstractmethod def submitted_data(self) -> Mapping[str, Any] | None: ... @property + @abc.abstractmethod def submitted(self) -> bool: ... @property + @abc.abstractmethod def status(self) -> HumanInputFormStatus: ... @property + @abc.abstractmethod def expiration_time(self) -> datetime: ... diff --git a/src/graphon/runtime/graph_runtime_state.py b/src/graphon/runtime/graph_runtime_state.py index c4a9221..b29f57b 100644 --- a/src/graphon/runtime/graph_runtime_state.py +++ b/src/graphon/runtime/graph_runtime_state.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc import importlib import json from collections.abc import Mapping, Sequence @@ -25,32 +26,39 @@ class ReadyQueueProtocol(Protocol): """Structural interface required from ready queue implementations.""" + @abc.abstractmethod def put(self, item: str) -> None: """Enqueue the identifier of a node that is ready to run.""" ... + @abc.abstractmethod def get(self, timeout: float | None = None) -> str: """Return the next node identifier, blocking until available or timeout expires. """ ... + @abc.abstractmethod def task_done(self) -> None: """Signal that the most recently dequeued node has completed processing.""" ... + @abc.abstractmethod def empty(self) -> bool: """Return True when the queue contains no pending nodes.""" ... + @abc.abstractmethod def qsize(self) -> int: """Approximate the number of pending nodes awaiting execution.""" ... + @abc.abstractmethod def dumps(self) -> str: """Serialize the queue contents for persistence.""" ... + @abc.abstractmethod def loads(self, data: str) -> None: """Restore the queue contents from a serialized payload.""" ... @@ -63,18 +71,22 @@ class NodeExecutionProtocol(Protocol): retry_count: int execution_id: str | None + @abc.abstractmethod def mark_started(self, execution_id: str) -> None: """Mark the node execution as started.""" ... + @abc.abstractmethod def mark_taken(self) -> None: """Mark the node execution as successfully completed.""" ... + @abc.abstractmethod def mark_failed(self, error: str) -> None: """Mark the node execution as failed with an error.""" ... + @abc.abstractmethod def increment_retry(self) -> None: """Increment the retry counter for the node execution.""" ... @@ -98,52 +110,64 @@ class GraphExecutionProtocol(Protocol): pause_reasons: list[PauseReason] @property + @abc.abstractmethod def node_executions(self) -> Mapping[str, NodeExecutionProtocol]: """Return the persisted node execution state keyed by node id.""" ... + @abc.abstractmethod def start(self) -> None: """Transition execution into the running state.""" ... + @abc.abstractmethod def complete(self) -> None: """Mark execution as successfully completed.""" ... + @abc.abstractmethod def abort(self, reason: str) -> None: """Abort execution in response to an external stop request.""" ... + @abc.abstractmethod def pause(self, reason: PauseReason) -> None: """Pause execution with a recorded reason.""" ... + @abc.abstractmethod def fail(self, error: Exception) -> None: """Record an unrecoverable error and end execution.""" ... + @abc.abstractmethod def record_node_failure(self) -> None: """Increment the count of node failures observed during execution.""" ... + @abc.abstractmethod def get_or_create_node_execution(self, node_id: str) -> NodeExecutionProtocol: """Return the execution entity for a node, creating it when needed.""" ... @property + @abc.abstractmethod def is_paused(self) -> bool: """Return whether the execution is currently paused.""" ... @property + @abc.abstractmethod def has_error(self) -> bool: """Return whether the execution has recorded an error.""" ... + @abc.abstractmethod def dumps(self) -> str: """Serialize execution state into a JSON payload.""" ... + @abc.abstractmethod def loads(self, data: str) -> None: """Restore execution state from a previously serialized payload.""" ... @@ -152,18 +176,22 @@ def loads(self, data: str) -> None: class ResponseStreamCoordinatorProtocol(Protocol): """Structural interface for response stream coordinator.""" + @abc.abstractmethod def register(self, response_node_id: str) -> None: """Register a response node so its outputs can be streamed.""" ... + @abc.abstractmethod def track_node_execution(self, node_id: str, execution_id: str) -> None: """Track the current execution id for a node.""" ... + @abc.abstractmethod def on_edge_taken(self, edge_id: str) -> Sequence[NodeRunStreamChunkEvent]: """Update pending response sessions after an edge is taken.""" ... + @abc.abstractmethod def intercept_event( self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent, @@ -171,10 +199,12 @@ def intercept_event( """Translate node events into streamed response events.""" ... + @abc.abstractmethod def loads(self, data: str) -> None: """Restore coordinator state from a serialized payload.""" ... + @abc.abstractmethod def dumps(self) -> str: """Serialize coordinator state for persistence.""" ... @@ -188,6 +218,7 @@ class NodeProtocol(Protocol): execution_type: NodeExecutionType node_type: ClassVar[NodeType] + @abc.abstractmethod def blocks_variable_output( self, variable_selectors: set[tuple[str, ...]], @@ -211,10 +242,12 @@ class GraphProtocol(Protocol): edges: Mapping[str, EdgeProtocol] root_node: NodeProtocol + @abc.abstractmethod def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ... class ChildGraphEngineBuilderProtocol(Protocol): + @abc.abstractmethod def build_child_engine( self, *, diff --git a/src/graphon/runtime/graph_runtime_state_protocol.py b/src/graphon/runtime/graph_runtime_state_protocol.py index 7b5c9ce..e732e78 100644 --- a/src/graphon/runtime/graph_runtime_state_protocol.py +++ b/src/graphon/runtime/graph_runtime_state_protocol.py @@ -1,3 +1,4 @@ +import abc from collections.abc import Mapping, Sequence from typing import Protocol @@ -8,10 +9,12 @@ class ReadOnlyVariablePool(Protocol): """Read-only interface for VariablePool.""" + @abc.abstractmethod def get(self, selector: Sequence[str], /) -> Segment | None: """Get a variable value (read-only).""" ... + @abc.abstractmethod def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]: """Get all variables stored under a given node prefix (read-only).""" ... @@ -26,49 +29,59 @@ class ReadOnlyGraphRuntimeState(Protocol): """ @property + @abc.abstractmethod def variable_pool(self) -> ReadOnlyVariablePool: """Get read-only access to the variable pool.""" ... @property + @abc.abstractmethod def start_at(self) -> float: """Get the start time (read-only).""" ... @property + @abc.abstractmethod def total_tokens(self) -> int: """Get the total tokens count (read-only).""" ... @property + @abc.abstractmethod def llm_usage(self) -> LLMUsage: """Get a copy of LLM usage info (read-only).""" ... @property + @abc.abstractmethod def outputs(self) -> dict[str, object]: """Get a defensive copy of outputs (read-only).""" ... @property + @abc.abstractmethod def node_run_steps(self) -> int: """Get the node run steps count (read-only).""" ... @property + @abc.abstractmethod def ready_queue_size(self) -> int: """Get the number of nodes currently in the ready queue.""" ... @property + @abc.abstractmethod def exceptions_count(self) -> int: """Get the number of node execution exceptions recorded.""" ... + @abc.abstractmethod def get_output(self, key: str, default: object = None) -> object: """Get a single output value (returns a copy).""" ... + @abc.abstractmethod def dumps(self) -> str: """Serialize the runtime state into a JSON snapshot (read-only).""" ... diff --git a/src/graphon/runtime/read_only_wrappers.py b/src/graphon/runtime/read_only_wrappers.py index 822c3ce..8c836fa 100644 --- a/src/graphon/runtime/read_only_wrappers.py +++ b/src/graphon/runtime/read_only_wrappers.py @@ -2,6 +2,7 @@ from collections.abc import Mapping, Sequence from copy import deepcopy +from typing import TYPE_CHECKING from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.variables.segments import Segment @@ -72,3 +73,18 @@ def get_output(self, key: str, default: object = None) -> object: def dumps(self) -> str: """Serialize the underlying runtime state for external persistence.""" return self._state.dumps() + + +if TYPE_CHECKING: + from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState + + # static assertions to ensure read-only wrappers implement their protocols. + def _assert_readonly_variable_pool_wrapper( + pool: ReadOnlyVariablePoolWrapper, + ) -> ReadOnlyVariablePool: # pyright: ignore[reportUnusedFunction] + return pool + + def _assert_readonly_graph_runtime_state_wrapper( + state: ReadOnlyGraphRuntimeStateWrapper, + ) -> ReadOnlyGraphRuntimeState: # pyright: ignore[reportUnusedFunction] + return state diff --git a/src/graphon/variable_loader.py b/src/graphon/variable_loader.py index 3cf32ca..3c97633 100644 --- a/src/graphon/variable_loader.py +++ b/src/graphon/variable_loader.py @@ -1,6 +1,6 @@ import abc from collections.abc import Mapping, Sequence -from typing import Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol from graphon.runtime.variable_pool import VariablePool from graphon.variables.consts import SELECTORS_LENGTH @@ -45,6 +45,14 @@ def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: return [] +if TYPE_CHECKING: + # static assertion to ensure _DummyVariableLoader implements VariableLoader. + def _assert_variable_loader( + loader: _DummyVariableLoader, + ) -> VariableLoader: # pyright: ignore[reportUnusedFunction] + return loader + + DUMMY_VARIABLE_LOADER = _DummyVariableLoader() diff --git a/tests/test_protocol_abstract_contracts.py b/tests/test_protocol_abstract_contracts.py new file mode 100644 index 0000000..0468b33 --- /dev/null +++ b/tests/test_protocol_abstract_contracts.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +import ast +import importlib +import inspect +from pathlib import Path +from typing import Any + +import pytest + + +def _is_direct_protocol_class( + class_def: ast.ClassDef, + *, + protocol_aliases: set[str], +) -> bool: + for base in class_def.bases: + if isinstance(base, ast.Name) and base.id in protocol_aliases: + return True + if ( + isinstance(base, ast.Attribute) + and isinstance(base.value, ast.Name) + and f"{base.value.id}.{base.attr}" in protocol_aliases + ): + return True + return False + + +def _discover_protocol_aliases(parsed: ast.Module) -> set[str]: + protocol_aliases = {"Protocol"} + typing_aliases: set[str] = set() + for node in parsed.body: + if isinstance(node, ast.ImportFrom) and node.module == "typing": + for alias in node.names: + if alias.name == "Protocol": + protocol_aliases.add(alias.asname or alias.name) + elif isinstance(node, ast.Import): + for alias in node.names: + if alias.name == "typing": + typing_aliases.add(alias.asname or "typing") + + protocol_aliases.update(f"{alias}.Protocol" for alias in typing_aliases) + return protocol_aliases + + +def _has_protocol_members(class_def: ast.ClassDef) -> bool: + return any( + isinstance(member, ast.FunctionDef | ast.AsyncFunctionDef) + for member in class_def.body + ) + + +def _discover_protocol_targets() -> list[type[object]]: + src_root = Path(__file__).resolve().parents[1] / "src" / "graphon" + protocol_classes: list[type[object]] = [] + + for file_path in sorted(src_root.rglob("*.py")): + parsed = ast.parse(file_path.read_text()) + protocol_aliases = _discover_protocol_aliases(parsed) + module_name = "graphon." + ".".join( + file_path.relative_to(src_root).with_suffix("").parts + ) + module = importlib.import_module(module_name) + + for class_def in [ + node for node in parsed.body if isinstance(node, ast.ClassDef) + ]: + if not _is_direct_protocol_class( + class_def, + protocol_aliases=protocol_aliases, + ): + continue + if not _has_protocol_members(class_def): + continue + protocol_classes.append(getattr(module, class_def.name)) + + protocol_classes.sort(key=lambda cls: (cls.__module__, cls.__name__)) + return protocol_classes + + +def _protocol_member_names(protocol_cls: type[object]) -> list[str]: + member_names: list[str] = [] + for name, value in protocol_cls.__dict__.items(): + if name.startswith("__") and name.endswith("__"): + continue + if isinstance(value, property | classmethod | staticmethod): + member_names.append(name) + continue + if inspect.isfunction(value): + member_names.append(name) + return member_names + + +def _build_member_override(protocol_cls: type[object], member_name: str) -> Any: + member = protocol_cls.__dict__[member_name] + if isinstance(member, property): + return property(lambda _: None) + if isinstance(member, classmethod): + + def _class_stub(_cls: type[object], *args: object, **kwargs: object) -> None: + _ = args, kwargs + + return classmethod(_class_stub) + if isinstance(member, staticmethod): + + def _static_stub(*args: object, **kwargs: object) -> None: + _ = args, kwargs + + return staticmethod(_static_stub) + + def _stub(self: object, *args: object, **kwargs: object) -> None: + _ = self, args, kwargs + + return _stub + + +PROTOCOL_TARGETS = _discover_protocol_targets() + + +@pytest.mark.parametrize( + "protocol_cls", + PROTOCOL_TARGETS, + ids=lambda cls: f"{cls.__module__}.{cls.__name__}", +) +def test_protocol_members_are_abstract(protocol_cls: type[object]) -> None: + member_names = _protocol_member_names(protocol_cls) + assert member_names, ( + f"{protocol_cls.__module__}.{protocol_cls.__name__} has no protocol members." + ) + + non_abstract_members = [ + name + for name in member_names + if not getattr(protocol_cls.__dict__[name], "__isabstractmethod__", False) + ] + assert not non_abstract_members, ( + f"{protocol_cls.__module__}.{protocol_cls.__name__} contains non-abstract " + f"members: {non_abstract_members!r}" + ) + + +@pytest.mark.parametrize( + "protocol_cls", + PROTOCOL_TARGETS, + ids=lambda cls: f"{cls.__module__}.{cls.__name__}", +) +def test_protocol_direct_subclass_requires_overrides( + protocol_cls: type[object], +) -> None: + direct_impl = type( + f"Direct{protocol_cls.__name__}", + (protocol_cls,), + {}, + ) + with pytest.raises(TypeError): + direct_impl() + + +@pytest.mark.parametrize( + "protocol_cls", + PROTOCOL_TARGETS, + ids=lambda cls: f"{cls.__module__}.{cls.__name__}", +) +def test_protocol_indirect_partial_override_stays_abstract( + protocol_cls: type[object], +) -> None: + member_names = _protocol_member_names(protocol_cls) + if len(member_names) < 2: + pytest.skip("Protocol has fewer than two members.") + + first_member = member_names[0] + intermediate = type( + f"Intermediate{protocol_cls.__name__}", + (protocol_cls,), + {}, + ) + partial_impl = type( + f"Partial{protocol_cls.__name__}", + (intermediate,), + {first_member: _build_member_override(protocol_cls, first_member)}, + ) + + with pytest.raises(TypeError): + partial_impl()