diff --git a/natsapi/routing.py b/natsapi/routing.py index da24b68..395f91d 100644 --- a/natsapi/routing.py +++ b/natsapi/routing.py @@ -34,7 +34,7 @@ def __init__( reply_name = "Reply_" + self.operation_id request_name = "Request_" + self.operation_id self.reply_field = create_field(name=reply_name, type_=self.params) - self.request_field = create_field(name=request_name, type_=self.result) + self.request_field = create_field(name=request_name, type_=self.result, mode="serialization") self.tags = tags or [] self.description = description or "" @@ -66,7 +66,7 @@ def __init__( self.operation_id = generate_operation_id_for_subject(summary=self.summary, subject=self.subject) self.params = get_request_model(self.endpoint, subject, self.skip_validation) reply_name = "Reply_" + self.operation_id - self.reply_field = create_field(name=reply_name, type_=self.params) + self.reply_field = create_field(name=reply_name, type_=self.params, mode="serialization") self.tags = tags or [] self.description = description or "" @@ -114,7 +114,7 @@ def __init__( self.tags = tags or [] self.externalDocs = externalDocs self.params = params - self.params_field = create_field(name="Publish_" + subject, type_=self.params) + self.params_field = create_field(name="Publish_" + subject, type_=self.params, mode="serialization") class SubjectRouter: diff --git a/natsapi/utils.py b/natsapi/utils.py index b95ae0d..6790eca 100644 --- a/natsapi/utils.py +++ b/natsapi/utils.py @@ -3,7 +3,7 @@ import re from collections.abc import Callable from enum import Enum -from typing import Any, Union +from typing import Any, Literal, Union from pydantic import BaseConfig, BaseModel, create_model from pydantic.fields import FieldInfo @@ -33,10 +33,15 @@ def create_field( class_validators: dict[str, Any] | None = None, model_config: type[BaseConfig] = BaseConfig, field_info: FieldInfo | None = None, + mode: Literal["validation", "serialization"] = "validation", ) -> ModelField: """ Yanked from fastapi.utils Create a new reply field. Raises if type_ is invalid. + + `mode` controls which pydantic JSON schema view is generated for this field: + "validation" for incoming payloads, "serialization" for outgoing payloads + (required for `@computed_field` to appear - Pydantic V2 only). """ class_validators = class_validators or {} @@ -44,7 +49,9 @@ def create_field( kwargs = {"name": name, "field_info": field_info} - if not PYDANTIC_V2: + if PYDANTIC_V2: + kwargs["mode"] = mode + else: kwargs.update( { "type_": type_, diff --git a/tests/asyncapi/test_generation.py b/tests/asyncapi/test_generation.py index 4d9c4be..662b312 100644 --- a/tests/asyncapi/test_generation.py +++ b/tests/asyncapi/test_generation.py @@ -5,10 +5,14 @@ from pydantic import BaseModel, Field from natsapi import NatsAPI, Pub, SubjectRouter +from natsapi._compat import PYDANTIC_V2 from natsapi.asyncapi import Errors from natsapi.asyncapi.models import ExternalDocumentation, Server from natsapi.exceptions import JsonRPCException +if PYDANTIC_V2: + from pydantic import computed_field + pytestmark = pytest.mark.asyncio production_server = Server( @@ -411,3 +415,49 @@ def _(app): schema = (await app.nc.request("natsapi.development.schema.RETRIEVE", {})).result assert schema["channels"]["natsapi.development.req"]["request"]["summary"] == "Req" assert schema["channels"]["natsapi.development.pub"]["publish"]["summary"] == "Pub" + + +@pytest.mark.skipif(not PYDANTIC_V2, reason="@computed_field and serialization-mode schemas are Pydantic v2 only") +def test_computed_field_in_result_should_appear_in_response_schema(): + + class User(BaseModel): + first_name: str + last_name: str + + @computed_field + @property + def full_name(self) -> str: + return f"{self.first_name} {self.last_name}" + + class CreateUserParams(BaseModel): + first_name: str + last_name: str + + @computed_field + @property + def full_name(self) -> str: + return f"{self.first_name} {self.last_name}" + + app = NatsAPI("natsapi.development") + router = SubjectRouter(prefix="v1") + + @router.request("users.GET", result=User) + def get_user(app, params: CreateUserParams): + return {} + + @router.publish("users.CREATED") + def user_created(app, user: User): + return {} + + app.include_router(router) + app.generate_asyncapi() + schema = app.asyncapi_schema + + user_schema = schema["components"]["schemas"]["User"] + assert "full_name" in user_schema["properties"] + assert user_schema["properties"]["full_name"]["readOnly"] is True + assert "full_name" in user_schema["required"] + + # the request model does not include the calculated property + params_schema = schema["components"]["schemas"]["CreateUserParams"] + assert set(params_schema["properties"].keys()) == {"first_name", "last_name"}