Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions natsapi/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
reply_name = "Reply_" + self.operation_id
request_name = "Request_" + self.operation_id
self.reply_field = create_field(name=reply_name, type_=self.params)
self.request_field = create_field(name=request_name, type_=self.result)
self.request_field = create_field(name=request_name, type_=self.result, mode="serialization")

self.tags = tags or []
self.description = description or ""
Expand Down Expand Up @@ -66,7 +66,7 @@ def __init__(
self.operation_id = generate_operation_id_for_subject(summary=self.summary, subject=self.subject)
self.params = get_request_model(self.endpoint, subject, self.skip_validation)
reply_name = "Reply_" + self.operation_id
self.reply_field = create_field(name=reply_name, type_=self.params)
self.reply_field = create_field(name=reply_name, type_=self.params, mode="serialization")

self.tags = tags or []
self.description = description or ""
Expand Down Expand Up @@ -114,7 +114,7 @@ def __init__(
self.tags = tags or []
self.externalDocs = externalDocs
self.params = params
self.params_field = create_field(name="Publish_" + subject, type_=self.params)
self.params_field = create_field(name="Publish_" + subject, type_=self.params, mode="serialization")


class SubjectRouter:
Expand Down
11 changes: 9 additions & 2 deletions natsapi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
from collections.abc import Callable
from enum import Enum
from typing import Any, Union
from typing import Any, Literal, Union

from pydantic import BaseConfig, BaseModel, create_model
from pydantic.fields import FieldInfo
Expand Down Expand Up @@ -33,18 +33,25 @@ def create_field(
class_validators: dict[str, Any] | None = None,
model_config: type[BaseConfig] = BaseConfig,
field_info: FieldInfo | None = None,
mode: Literal["validation", "serialization"] = "validation",
) -> ModelField:
"""
Yanked from fastapi.utils
Create a new reply field. Raises if type_ is invalid.

`mode` controls which pydantic JSON schema view is generated for this field:
"validation" for incoming payloads, "serialization" for outgoing payloads
(required for `@computed_field` to appear - Pydantic V2 only).
"""
class_validators = class_validators or {}

field_info = (field_info or FieldInfo(annotation=type_)) if PYDANTIC_V2 else (field_info or FieldInfo())

kwargs = {"name": name, "field_info": field_info}

if not PYDANTIC_V2:
if PYDANTIC_V2:
kwargs["mode"] = mode
else:
kwargs.update(
{
"type_": type_,
Expand Down
50 changes: 50 additions & 0 deletions tests/asyncapi/test_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
from pydantic import BaseModel, Field

from natsapi import NatsAPI, Pub, SubjectRouter
from natsapi._compat import PYDANTIC_V2
from natsapi.asyncapi import Errors
from natsapi.asyncapi.models import ExternalDocumentation, Server
from natsapi.exceptions import JsonRPCException

if PYDANTIC_V2:
from pydantic import computed_field

pytestmark = pytest.mark.asyncio

production_server = Server(
Expand Down Expand Up @@ -411,3 +415,49 @@ def _(app):
schema = (await app.nc.request("natsapi.development.schema.RETRIEVE", {})).result
assert schema["channels"]["natsapi.development.req"]["request"]["summary"] == "Req"
assert schema["channels"]["natsapi.development.pub"]["publish"]["summary"] == "Pub"


@pytest.mark.skipif(not PYDANTIC_V2, reason="@computed_field and serialization-mode schemas are Pydantic v2 only")
def test_computed_field_in_result_should_appear_in_response_schema():
Comment thread
LanderMoerkerke marked this conversation as resolved.

class User(BaseModel):
first_name: str
last_name: str

@computed_field
@property
def full_name(self) -> str:
return f"{self.first_name} {self.last_name}"

class CreateUserParams(BaseModel):
first_name: str
last_name: str

@computed_field
@property
def full_name(self) -> str:
return f"{self.first_name} {self.last_name}"

app = NatsAPI("natsapi.development")
router = SubjectRouter(prefix="v1")

@router.request("users.GET", result=User)
def get_user(app, params: CreateUserParams):
return {}

@router.publish("users.CREATED")
Comment thread
LanderMoerkerke marked this conversation as resolved.
def user_created(app, user: User):
return {}

app.include_router(router)
app.generate_asyncapi()
schema = app.asyncapi_schema

user_schema = schema["components"]["schemas"]["User"]
assert "full_name" in user_schema["properties"]
assert user_schema["properties"]["full_name"]["readOnly"] is True
assert "full_name" in user_schema["required"]

# the request model does not include the calculated property
params_schema = schema["components"]["schemas"]["CreateUserParams"]
assert set(params_schema["properties"].keys()) == {"first_name", "last_name"}
Loading