From 0356fe2edeb261d4549894273da212434e98c62f Mon Sep 17 00:00:00 2001 From: Anders Hafreager Date: Thu, 23 Apr 2026 21:32:38 -0700 Subject: [PATCH] feat: add NodeId, datetime, DataFrame type support (PR 2/3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extends _scalar_to_schema with: - NodeId → JSON Schema object with space/externalId properties - datetime / datetime.datetime → string + format: date-time - pd.DataFrame / DataFrame → string + format: cognite-query-id - list[T] variants for all new types (list[pd.DataFrame] is an error) - Clear error messages for unrecognized types Co-Authored-By: Claude Sonnet 4.6 --- cognite/client/_api/ai/python_schema.py | 19 ++- .../tests_unit/test_api/test_python_schema.py | 119 +++++++++++++++++- 2 files changed, 131 insertions(+), 7 deletions(-) diff --git a/cognite/client/_api/ai/python_schema.py b/cognite/client/_api/ai/python_schema.py index 5b0d80181a..d4a2642737 100644 --- a/cognite/client/_api/ai/python_schema.py +++ b/cognite/client/_api/ai/python_schema.py @@ -12,9 +12,7 @@ COGNITE_QUERY_ID_FORMAT = "cognite-query-id" DATAFRAME_PARAMETER_TYPES = {"pd.DataFrame", "DataFrame"} -QUERY_ID_DESCRIPTION_SUFFIX = ( - "**THIS FIELD IS POPULATED BY PROVIDING A query_id FROM THE MEMORY TABLE**" -) +QUERY_ID_DESCRIPTION_SUFFIX = "**THIS FIELD IS POPULATED BY PROVIDING A query_id FROM THE MEMORY TABLE**" @dataclass @@ -32,8 +30,7 @@ class ErrorResult: def _extract_argument_details(func_def: FunctionDef) -> tuple[list[str], dict[str, str | None]]: argument_names = [arg.arg for arg in func_def.args.args] argument_annotations: dict[str, str | None] = { - arg.arg: (ast.unparse(arg.annotation) if arg.annotation is not None else None) - for arg in func_def.args.args + arg.arg: (ast.unparse(arg.annotation) if arg.annotation is not None else None) for arg in func_def.args.args } return argument_names, argument_annotations @@ -87,6 +84,18 @@ def _scalar_to_schema(name: str, base_type: str, in_list: bool) -> tuple[dict | } if base_type in primitive_map: return dict(primitive_map[base_type]), None + if base_type == "NodeId": + return { + "type": "object", + "properties": {"space": {"type": "string"}, "externalId": {"type": "string"}}, + "required": ["externalId", "space"], + }, None + if base_type in ("datetime", "datetime.datetime"): + return {"type": "string", "format": "date-time"}, None + if base_type in DATAFRAME_PARAMETER_TYPES: + if in_list: + return None, f"Unsupported type for parameter '{name}': list[{base_type}]" + return {"type": "string", "format": COGNITE_QUERY_ID_FORMAT}, None if in_list: return None, f"Unsupported type for parameter '{name}': list[{base_type}]" return None, f"Unsupported type for parameter '{name}': {base_type}" diff --git a/tests/tests_unit/test_api/test_python_schema.py b/tests/tests_unit/test_api/test_python_schema.py index 184e33c9de..e66b392652 100644 --- a/tests/tests_unit/test_api/test_python_schema.py +++ b/tests/tests_unit/test_api/test_python_schema.py @@ -3,10 +3,9 @@ from unittest.mock import patch import pytest - -from cognite.client._api.ai.python_schema import ErrorResult, SchemaResult, extract_function_schema from jsonschema.exceptions import SchemaError +from cognite.client._api.ai.python_schema import ErrorResult, SchemaResult, extract_function_schema # --------------------------------------------------------------------------- # US1 — Core Schema Extraction + Primitive Types @@ -120,3 +119,119 @@ def test_draft7_validation_failure(): result = extract_function_schema("def handle(name: str): pass") assert isinstance(result, ErrorResult) assert len(result.errors) > 0 + + +# --------------------------------------------------------------------------- +# US2 — Extended Type Support +# --------------------------------------------------------------------------- + + +@pytest.mark.us2 +def test_nodeid_scalar_and_list(): + code = "def handle(node: NodeId, node_list: list[NodeId]) -> None: pass" + result = extract_function_schema(code) + assert isinstance(result, SchemaResult) + + expected_scalar = { + "type": "object", + "properties": {"space": {"type": "string"}, "externalId": {"type": "string"}}, + "required": ["externalId", "space"], + } + assert result.properties["node"] == expected_scalar + + expected_list = { + "type": "array", + "items": { + "type": "object", + "properties": {"space": {"type": "string"}, "externalId": {"type": "string"}}, + "required": ["externalId", "space"], + }, + } + assert result.properties["node_list"] == expected_list + assert "properties" not in result.properties["node_list"] + assert "required" not in result.properties["node_list"] + + +@pytest.mark.us2 +def test_datetime_scalar_and_list(): + code = "def handle(date: datetime, date_list: list[datetime]) -> None: pass" + result = extract_function_schema(code) + assert isinstance(result, SchemaResult) + assert result.properties["date"] == {"type": "string", "format": "date-time"} + assert result.properties["date_list"] == {"type": "array", "items": {"type": "string", "format": "date-time"}} + + +@pytest.mark.us2 +def test_datetime_fully_qualified(): + code = "def handle(date: datetime.datetime) -> None: pass" + result = extract_function_schema(code) + assert isinstance(result, SchemaResult) + assert result.properties["date"] == {"type": "string", "format": "date-time"} + + +@pytest.mark.us2 +def test_pd_dataframe(): + code = "def handle(x: int, assets: pd.DataFrame) -> None: pass" + result = extract_function_schema(code) + assert isinstance(result, SchemaResult) + assert result.properties["assets"]["type"] == "string" + assert result.properties["assets"]["format"] == "cognite-query-id" + + +@pytest.mark.us2 +def test_dataframe_bare(): + code = "def handle(assets: DataFrame) -> None: pass" + result = extract_function_schema(code) + assert isinstance(result, SchemaResult) + assert result.properties["assets"]["type"] == "string" + assert result.properties["assets"]["format"] == "cognite-query-id" + + +@pytest.mark.us2 +def test_list_dataframe_unsupported(): + result = extract_function_schema("def handle(dfs: list[pd.DataFrame]) -> None: pass") + assert isinstance(result, ErrorResult) + assert any("list[pd.DataFrame]" in e for e in result.errors) + + +@pytest.mark.us2 +@pytest.mark.parametrize( + "annotation", + ["dict", "any", "MyCustomNodeId"], +) +def test_unsupported_scalar_type(annotation: str): + code = f"def handle(param: {annotation}) -> None: pass" + result = extract_function_schema(code) + assert isinstance(result, ErrorResult) + assert any(annotation.lower() in e.lower() for e in result.errors) + + +@pytest.mark.us2 +@pytest.mark.parametrize( + "annotation", + ["list[list[str]]", "list[set]", "list[tuple]", "list[complex]"], +) +def test_unsupported_list_type(annotation: str): + code = f"def handle(param: {annotation}) -> None: pass" + result = extract_function_schema(code) + assert isinstance(result, ErrorResult) + assert len(result.errors) > 0 + + +@pytest.mark.us2 +@pytest.mark.parametrize( + "annotation", + ["dict", "any", "MyCustomNodeId", "list[list[str]]", "list[set]", "list[tuple]", "list[complex]"], +) +def test_unsupported_type_with_docstring(annotation: str): + code = f''' +def handle(param: {annotation}) -> None: + """Summary. + + Args: + param: some description + """ + pass +''' + result = extract_function_schema(code) + assert isinstance(result, ErrorResult)