diff --git a/src/labthings_fastapi/actions.py b/src/labthings_fastapi/actions.py index 18f6acf..95b9910 100644 --- a/src/labthings_fastapi/actions.py +++ b/src/labthings_fastapi/actions.py @@ -41,7 +41,7 @@ from .base_descriptor import BaseDescriptor from .logs import add_thing_log_destination -from .utilities import model_to_dict +from .utilities import model_to_dict, wrap_plain_types_in_rootmodel from .invocations import InvocationModel, InvocationStatus, LogRecordModel from .dependencies.invocation import NonWarningInvocationID from .exceptions import ( @@ -477,7 +477,6 @@ def list_all_invocations( @app.get( ACTION_INVOCATIONS_PATH + "/{id}", - response_model=InvocationModel, responses={404: {"description": "Invocation ID not found"}}, ) def action_invocation( @@ -683,7 +682,7 @@ def __init__( remove_first_positional_arg=True, ignore=[p.name for p in self.dependency_params], ) - self.output_model = return_type(func) + self.output_model = wrap_plain_types_in_rootmodel(return_type(func)) self.invocation_model = create_model( f"{name}_invocation", __base__=InvocationModel, diff --git a/tests/test_actions.py b/tests/test_actions.py index 6994701..d08c6e6 100644 --- a/tests/test_actions.py +++ b/tests/test_actions.py @@ -161,7 +161,7 @@ def action_wrapper(*args, **kwargs): return action_wrapper -def assert_input_models_equivalent(model_a, model_b): +def assert_models_equivalent(model_a, model_b): """Check two basemodels are equivalent.""" keys = list(model_a.model_fields.keys()) assert list(model_b.model_fields.keys()) == keys @@ -198,11 +198,10 @@ def decorated( """An example decorated action with type annotations.""" return 0.5 - assert_input_models_equivalent( - Example.action.input_model, Example.decorated.input_model + assert_models_equivalent(Example.action.input_model, Example.decorated.input_model) + assert_models_equivalent( + Example.action.output_model, Example.decorated.output_model ) - assert Example.action.output_model == Example.decorated.output_model - # Check we can make the thing and it has a valid TD example = create_thing_without_server(Example) example.validate_thing_description() diff --git a/tests/test_numpy_type.py b/tests/test_numpy_type.py index 1f53f6c..cf1d5c9 100644 --- a/tests/test_numpy_type.py +++ b/tests/test_numpy_type.py @@ -2,6 +2,7 @@ from pydantic import BaseModel, RootModel import numpy as np +from fastapi.testclient import TestClient from labthings_fastapi.testing import create_thing_without_server from labthings_fastapi.types.numpy import NDArray, DenumpifyingDict @@ -70,6 +71,14 @@ class MyNumpyThing(lt.Thing): def action_with_arrays(self, a: NDArray) -> NDArray: return a * 2 + @lt.action + def read_array(self) -> NDArray: + return np.array([1, 2]) + + @lt.property + def array_property(self) -> NDArray: + return np.array([3, 4, 5]) + def test_thing_description(): """Make sure the TD validates when numpy types are used.""" @@ -102,3 +111,16 @@ def test_rootmodel(): m = ArrayModel(root=input) assert isinstance(m.root, np.ndarray) assert (m.model_dump() == [0, 1, 2]).all() + + +def test_numpy_over_http(): + """Read numpy array over http.""" + server = lt.ThingServer({"np_thing": MyNumpyThing}) + with TestClient(server.app) as client: + np_thing_client = lt.ThingClient.from_url("/np_thing/", client=client) + + arrayprop = np_thing_client.array_property + assert np.array_equal(np.asarray(arrayprop), np.array([3, 4, 5])) + + array = np_thing_client.read_array() + assert np.array_equal(np.asarray(array), np.array([1, 2]))