Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions cognite/client/_api/ai/python_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@

COGNITE_QUERY_ID_FORMAT = "cognite-query-id"
DATAFRAME_PARAMETER_TYPES = {"pd.DataFrame", "DataFrame"}
QUERY_ID_DESCRIPTION_SUFFIX = (
"**THIS FIELD IS POPULATED BY PROVIDING A query_id FROM THE MEMORY TABLE**"
)
QUERY_ID_DESCRIPTION_SUFFIX = "**THIS FIELD IS POPULATED BY PROVIDING A query_id FROM THE MEMORY TABLE**"


@dataclass
Expand All @@ -32,8 +30,7 @@ class ErrorResult:
def _extract_argument_details(func_def: FunctionDef) -> tuple[list[str], dict[str, str | None]]:
argument_names = [arg.arg for arg in func_def.args.args]
argument_annotations: dict[str, str | None] = {
arg.arg: (ast.unparse(arg.annotation) if arg.annotation is not None else None)
for arg in func_def.args.args
arg.arg: (ast.unparse(arg.annotation) if arg.annotation is not None else None) for arg in func_def.args.args
}
return argument_names, argument_annotations

Expand Down Expand Up @@ -87,6 +84,18 @@ def _scalar_to_schema(name: str, base_type: str, in_list: bool) -> tuple[dict |
}
if base_type in primitive_map:
return dict(primitive_map[base_type]), None
if base_type == "NodeId":
return {
"type": "object",
"properties": {"space": {"type": "string"}, "externalId": {"type": "string"}},
"required": ["externalId", "space"],
}, None
if base_type in ("datetime", "datetime.datetime"):
return {"type": "string", "format": "date-time"}, None
if base_type in DATAFRAME_PARAMETER_TYPES:
if in_list:
return None, f"Unsupported type for parameter '{name}': list[{base_type}]"
return {"type": "string", "format": COGNITE_QUERY_ID_FORMAT}, None
if in_list:
return None, f"Unsupported type for parameter '{name}': list[{base_type}]"
return None, f"Unsupported type for parameter '{name}': {base_type}"
Expand Down
119 changes: 117 additions & 2 deletions tests/tests_unit/test_api/test_python_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from unittest.mock import patch

import pytest

from cognite.client._api.ai.python_schema import ErrorResult, SchemaResult, extract_function_schema
from jsonschema.exceptions import SchemaError

from cognite.client._api.ai.python_schema import ErrorResult, SchemaResult, extract_function_schema

# ---------------------------------------------------------------------------
# US1 — Core Schema Extraction + Primitive Types
Expand Down Expand Up @@ -120,3 +119,119 @@ def test_draft7_validation_failure():
result = extract_function_schema("def handle(name: str): pass")
assert isinstance(result, ErrorResult)
assert len(result.errors) > 0


# ---------------------------------------------------------------------------
# US2 — Extended Type Support
# ---------------------------------------------------------------------------


@pytest.mark.us2
def test_nodeid_scalar_and_list():
code = "def handle(node: NodeId, node_list: list[NodeId]) -> None: pass"
result = extract_function_schema(code)
assert isinstance(result, SchemaResult)

expected_scalar = {
"type": "object",
"properties": {"space": {"type": "string"}, "externalId": {"type": "string"}},
"required": ["externalId", "space"],
}
assert result.properties["node"] == expected_scalar

expected_list = {
"type": "array",
"items": {
"type": "object",
"properties": {"space": {"type": "string"}, "externalId": {"type": "string"}},
"required": ["externalId", "space"],
},
}
assert result.properties["node_list"] == expected_list
assert "properties" not in result.properties["node_list"]
assert "required" not in result.properties["node_list"]


@pytest.mark.us2
def test_datetime_scalar_and_list():
code = "def handle(date: datetime, date_list: list[datetime]) -> None: pass"
result = extract_function_schema(code)
assert isinstance(result, SchemaResult)
assert result.properties["date"] == {"type": "string", "format": "date-time"}
assert result.properties["date_list"] == {"type": "array", "items": {"type": "string", "format": "date-time"}}


@pytest.mark.us2
def test_datetime_fully_qualified():
code = "def handle(date: datetime.datetime) -> None: pass"
result = extract_function_schema(code)
assert isinstance(result, SchemaResult)
assert result.properties["date"] == {"type": "string", "format": "date-time"}


@pytest.mark.us2
def test_pd_dataframe():
code = "def handle(x: int, assets: pd.DataFrame) -> None: pass"
result = extract_function_schema(code)
assert isinstance(result, SchemaResult)
assert result.properties["assets"]["type"] == "string"
assert result.properties["assets"]["format"] == "cognite-query-id"


@pytest.mark.us2
def test_dataframe_bare():
code = "def handle(assets: DataFrame) -> None: pass"
result = extract_function_schema(code)
assert isinstance(result, SchemaResult)
assert result.properties["assets"]["type"] == "string"
assert result.properties["assets"]["format"] == "cognite-query-id"


@pytest.mark.us2
def test_list_dataframe_unsupported():
result = extract_function_schema("def handle(dfs: list[pd.DataFrame]) -> None: pass")
assert isinstance(result, ErrorResult)
assert any("list[pd.DataFrame]" in e for e in result.errors)


@pytest.mark.us2
@pytest.mark.parametrize(
"annotation",
["dict", "any", "MyCustomNodeId"],
)
def test_unsupported_scalar_type(annotation: str):
code = f"def handle(param: {annotation}) -> None: pass"
result = extract_function_schema(code)
assert isinstance(result, ErrorResult)
assert any(annotation.lower() in e.lower() for e in result.errors)


@pytest.mark.us2
@pytest.mark.parametrize(
"annotation",
["list[list[str]]", "list[set]", "list[tuple]", "list[complex]"],
)
def test_unsupported_list_type(annotation: str):
code = f"def handle(param: {annotation}) -> None: pass"
result = extract_function_schema(code)
assert isinstance(result, ErrorResult)
assert len(result.errors) > 0


@pytest.mark.us2
@pytest.mark.parametrize(
"annotation",
["dict", "any", "MyCustomNodeId", "list[list[str]]", "list[set]", "list[tuple]", "list[complex]"],
)
def test_unsupported_type_with_docstring(annotation: str):
code = f'''
def handle(param: {annotation}) -> None:
"""Summary.

Args:
param: some description
"""
pass
'''
result = extract_function_schema(code)
assert isinstance(result, ErrorResult)
Loading