Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 185 additions & 0 deletions cognite/client/_api/ai/python_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
from __future__ import annotations

import ast
from dataclasses import dataclass, field
from typing import TypeAlias

from jsonschema.exceptions import SchemaError
from jsonschema.validators import Draft7Validator

# Declared value type for parameter defaults handled by this module
# (used in the annotations of _extract_argument_defaults / _parse_single_annotation).
DefaultValue: TypeAlias = str | int | float | bool
# Either a synchronous or an asynchronous function-definition AST node.
FunctionDef: TypeAlias = ast.FunctionDef | ast.AsyncFunctionDef

# NOTE(review): the three constants below are not referenced by the code visible
# in this module -- presumably reserved for later schema enrichment; confirm.
COGNITE_QUERY_ID_FORMAT = "cognite-query-id"
DATAFRAME_PARAMETER_TYPES = {"pd.DataFrame", "DataFrame"}
QUERY_ID_DESCRIPTION_SUFFIX = (
"**THIS FIELD IS POPULATED BY PROVIDING A query_id FROM THE MEMORY TABLE**"
)


@dataclass
class SchemaResult:
    """Successful outcome of schema extraction, shaped like a JSON Schema object.

    Every instance gets its own fresh ``properties`` dict and ``required`` list.
    """

    # JSON Schema "type" keyword; defaults to "object".
    type: str = "object"
    # Parameter name -> schema fragment for that parameter.
    properties: dict[str, dict] = field(default_factory=dict)
    # Names of the parameters listed as required.
    required: list[str] = field(default_factory=list)


@dataclass
class ErrorResult:
    """Failed outcome of schema extraction, carrying human-readable messages."""

    # One message per problem found; has no default, so it must always be supplied.
    errors: list[str]


def _extract_argument_details(func_def: FunctionDef) -> tuple[list[str], dict[str, str | None]]:
argument_names = [arg.arg for arg in func_def.args.args]
argument_annotations: dict[str, str | None] = {
arg.arg: (ast.unparse(arg.annotation) if arg.annotation is not None else None)
for arg in func_def.args.args
}
return argument_names, argument_annotations


def _find_duplicate_argument_names(argument_names: list[str]) -> list[str]:
seen: set[str] = set()
duplicates: list[str] = []
for name in argument_names:
if name in seen and name not in duplicates:
duplicates.append(name)
seen.add(name)
return duplicates


def _find_missing_type_annotations(argument_annotations: dict[str, str | None]) -> list[str]:
return [name for name, annotation in argument_annotations.items() if annotation is None]


def _validate_function_arguments(func_def: FunctionDef) -> list[str]:
    """Check the arguments of *func_def* for repeated names and missing annotations.

    Args:
        func_def: Function definition node to validate.

    Returns:
        Error messages (duplicates first, then missing annotations); an empty
        list when no problem was found.
    """
    names, annotations = _extract_argument_details(func_def)
    repeated = _find_duplicate_argument_names(names)
    unannotated = _find_missing_type_annotations(annotations)

    problems: list[str] = []
    if repeated:
        problems.append(f"Duplicate argument name(s) found: {', '.join(repeated)}")
    if unannotated:
        problems.append(f"Missing type annotation for parameter(s): {', '.join(unannotated)}")
    return problems


def _extract_argument_defaults(func_def: FunctionDef, argument_names: list[str]) -> dict[str, DefaultValue]:
defaults = func_def.args.defaults
if not defaults:
return {}
offset = len(argument_names) - len(defaults)
return {argument_names[offset + i]: ast.unparse(default) for i, default in enumerate(defaults)}


def _parse_list_type(annotation: str) -> tuple[bool, str | None]:
if annotation.startswith("list["):
return True, annotation[5:-1]
return False, None


def _scalar_to_schema(name: str, base_type: str, in_list: bool) -> tuple[dict | None, str | None]:
primitive_map: dict[str, dict] = {
"str": {"type": "string"},
"int": {"type": "integer"},
"float": {"type": "number"},
"bool": {"type": "boolean"},
}
if base_type in primitive_map:
return dict(primitive_map[base_type]), None
if in_list:
return None, f"Unsupported type for parameter '{name}': list[{base_type}]"
return None, f"Unsupported type for parameter '{name}': {base_type}"


def _parse_single_annotation(
    name: str, annotation: str, default_value: DefaultValue | None
) -> tuple[dict | None, str | None]:
    """Build the schema fragment for one annotated parameter.

    Args:
        name: Parameter name, used in error messages.
        annotation: Annotation source text.
        default_value: Default for the parameter; accepted but not used here.

    Returns:
        ``(schema, None)`` on success, ``(None, message)`` when the annotation
        is unsupported (including lists of lists).
    """
    is_list, inner_type = _parse_list_type(annotation)

    if not is_list:
        schema, problem = _scalar_to_schema(name, annotation, in_list=False)
        if problem or schema is None:
            return None, problem
        return schema, None

    element_type = inner_type or ""
    if _parse_list_type(element_type)[0]:
        return None, f"Unsupported type for parameter '{name}': nested list types are not supported"

    element_schema, problem = _scalar_to_schema(name, element_type, in_list=True)
    if problem:
        return None, problem
    return {"type": "array", "items": element_schema}, None


def _parse_type_annotations(func_def: FunctionDef) -> SchemaResult | ErrorResult:
    """Turn the annotated arguments of *func_def* into a schema.

    Args:
        func_def: Function definition node to convert.

    Returns:
        An ``ErrorResult`` with every collected message if any argument is
        unannotated or has an unsupported type; otherwise a ``SchemaResult``
        whose ``required`` holds, sorted, the arguments without a default.
    """
    argument_names, argument_annotations = _extract_argument_details(func_def)
    argument_defaults = _extract_argument_defaults(func_def, argument_names)

    problems: list[str] = []
    schema_properties: dict[str, dict] = {}
    mandatory: list[str] = []

    for param in argument_names:
        annotation = argument_annotations.get(param)
        if annotation is None:
            problems.append(f"Missing type annotation for parameter '{param}'")
            continue

        default = argument_defaults.get(param)
        fragment, problem = _parse_single_annotation(param, annotation, default)
        if problem:
            problems.append(problem)
            continue

        schema_properties[param] = fragment  # type: ignore[assignment]
        # Defaults are stored as source text, so None here means "no default".
        if default is None:
            mandatory.append(param)

    if problems:
        return ErrorResult(problems)
    return SchemaResult(type="object", properties=schema_properties, required=sorted(mandatory))


def _parse_docstring(func_def: FunctionDef, schema: SchemaResult) -> SchemaResult:
return schema


def extract_function_schema(code: str) -> SchemaResult | ErrorResult:
    """Derive a parameter schema from the ``handle`` function found in *code*.

    The first function (sync or async) named ``handle`` encountered by
    ``ast.walk`` is validated, converted to a schema, passed through the
    docstring step and finally checked with ``Draft7Validator.check_schema``.

    Args:
        code: Python source text.

    Returns:
        A ``SchemaResult`` on success. Any failure (unparsable code, no
        ``handle`` function, argument problems, an invalid schema, or any
        unexpected exception) is reported as an ``ErrorResult``; nothing is
        raised to the caller.
    """
    try:
        try:
            module = ast.parse(code)
        except SyntaxError:
            return ErrorResult(["Failed to parse Python code. Check that the code is valid Python code."])

        handle_func: FunctionDef | None = next(
            (
                candidate
                for candidate in ast.walk(module)
                if isinstance(candidate, (ast.FunctionDef, ast.AsyncFunctionDef)) and candidate.name == "handle"
            ),
            None,
        )
        if handle_func is None:
            return ErrorResult(["Function 'handle' not found."])

        argument_problems = _validate_function_arguments(handle_func)
        if argument_problems:
            return ErrorResult(argument_problems)

        result = _parse_type_annotations(handle_func)
        if isinstance(result, ErrorResult):
            return result

        try:
            result = _parse_docstring(handle_func, result)
        except ValueError as exc:
            return ErrorResult([str(exc)])

        candidate_schema = {
            "type": result.type,
            "properties": result.properties,
            "required": result.required,
        }
        try:
            Draft7Validator.check_schema(candidate_schema)
        except SchemaError as exc:
            return ErrorResult([str(exc)])

        return result

    except Exception as exc:
        # Boundary handler: unexpected failures become an error result.
        return ErrorResult([f"Unexpected error during schema extraction: {exc}"])
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ pandas = { version = ">=2.1", optional = true }
geopandas = { version = ">=0.14", optional = true }
shapely = { version = ">=1.7.0", optional = true }
PyYAML = { version = "^6.0", optional = true }
griffe = { version = ">=1.0,<2", optional = true }
jsonschema = { version = ">=4.0,<5", optional = true }

[tool.poetry.extras]
pandas = ["pandas"]
Expand All @@ -62,6 +64,7 @@ geo = ["geopandas", "shapely"]
sympy = ["sympy"]
functions = ["pip"]
yaml = ["PyYAML"]
schema = ["griffe", "jsonschema"]
pyodide = ["tzdata", "anyio"] # keep pyodide related dependencies outside of 'all'
all = ["numpy", "pandas", "geopandas", "shapely", "sympy", "pip", "PyYAML"]

Expand Down
3 changes: 3 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
markers =
dsl
coredeps
us1: User Story 1 — Core Schema Extraction + Primitive Types (PR 1)
us2: User Story 2 — Extended Type Support (PR 2)
us3: User Story 3 — Docstring Enrichment + Default Values (PR 3)

anyio_mode = auto

Expand Down
122 changes: 122 additions & 0 deletions tests/tests_unit/test_api/test_python_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from __future__ import annotations

from unittest.mock import patch

import pytest

from cognite.client._api.ai.python_schema import ErrorResult, SchemaResult, extract_function_schema
from jsonschema.exceptions import SchemaError


# ---------------------------------------------------------------------------
# US1 — Core Schema Extraction + Primitive Types
# ---------------------------------------------------------------------------


@pytest.mark.us1
def test_no_handle_function():
    """Source without a ``handle`` function is reported as not found."""
    outcome = extract_function_schema("def process(name: str): pass")
    assert isinstance(outcome, ErrorResult)
    first_message = outcome.errors[0]
    assert first_message == "Function 'handle' not found."


@pytest.mark.us1
def test_missing_type_annotation():
    """An unannotated parameter is named in the missing-annotation error."""
    outcome = extract_function_schema("def handle(name: str, age): pass")
    assert isinstance(outcome, ErrorResult)
    expected_fragment = "Missing type annotation for parameter(s): age"
    matching = [message for message in outcome.errors if expected_fragment in message]
    assert matching


@pytest.mark.us1
def test_duplicate_argument_names():
    """A repeated parameter name is named in the duplicate-argument error."""
    outcome = extract_function_schema("def handle(name: str, name: int): pass")
    assert isinstance(outcome, ErrorResult)
    expected_fragment = "Duplicate argument name(s) found: name"
    matching = [message for message in outcome.errors if expected_fragment in message]
    assert matching


@pytest.mark.us1
def test_invalid_python_syntax():
    """Unparsable source produces an error result with at least one message."""
    broken_source = "def handle(name: str\n pass"
    outcome = extract_function_schema(broken_source)
    assert isinstance(outcome, ErrorResult)
    assert outcome.errors


@pytest.mark.us1
def test_primitive_types():
    """Each primitive annotation maps to its JSON Schema type; all are required."""
    outcome = extract_function_schema(
        "def handle(name: str, age: int, height: float, is_active: bool) -> None: pass"
    )
    assert isinstance(outcome, SchemaResult)
    expected_types = {
        "name": "string",
        "age": "integer",
        "height": "number",
        "is_active": "boolean",
    }
    for param, json_type in expected_types.items():
        assert outcome.properties[param] == {"type": json_type}
    assert outcome.required == ["age", "height", "is_active", "name"]


@pytest.mark.us1
def test_primitive_list_types():
    """Lists of primitives map to array schemas with the matching item type."""
    outcome = extract_function_schema(
        "def handle(names: list[str], ages: list[int], heights: list[float], flags: list[bool]) -> None: pass"
    )
    assert isinstance(outcome, SchemaResult)
    expected_item_types = {
        "names": "string",
        "ages": "integer",
        "heights": "number",
        "flags": "boolean",
    }
    for param, item_type in expected_item_types.items():
        assert outcome.properties[param] == {"type": "array", "items": {"type": item_type}}


@pytest.mark.us1
def test_nested_list_unsupported():
    """A list-of-lists annotation is rejected with a 'nested list' message."""
    outcome = extract_function_schema("def handle(nested: list[list[str]]) -> None: pass")
    assert isinstance(outcome, ErrorResult)
    mentions_nesting = ["nested list" in message for message in outcome.errors]
    assert any(mentions_nesting)


@pytest.mark.us1
def test_async_handle():
    """An ``async def handle`` is processed like a synchronous one."""
    outcome = extract_function_schema("async def handle(name: str) -> None: pass")
    assert isinstance(outcome, SchemaResult)
    assert outcome.properties["name"] == {"type": "string"}
    assert outcome.required == ["name"]


@pytest.mark.us1
def test_empty_handle():
    """A parameterless ``handle`` yields an empty schema."""
    outcome = extract_function_schema("def handle() -> None: pass")
    assert isinstance(outcome, SchemaResult)
    assert outcome.properties == {}
    assert outcome.required == []


@pytest.mark.us1
def test_no_docstring():
    """Without a docstring, no property carries a ``description`` key."""
    outcome = extract_function_schema("def handle(name: str, age: int) -> None: pass")
    assert isinstance(outcome, SchemaResult)
    for param in ("name", "age"):
        assert "description" not in outcome.properties[param]


@pytest.mark.us1
def test_multiple_functions_only_handle_extracted():
    """Only ``handle`` contributes properties when other functions are present."""
    # The function bodies inside the fixture must be indented; with unindented
    # ``pass`` lines the source is not valid Python and extraction would return
    # an ErrorResult instead of the schema this test expects.
    code = """
def helper(x: int):
    pass

def handle(name: str):
    pass
"""
    result = extract_function_schema(code)
    assert isinstance(result, SchemaResult)
    assert list(result.properties.keys()) == ["name"]


@pytest.mark.us1
def test_draft7_validation_failure():
    """A ``SchemaError`` raised by the Draft 7 check becomes an error result."""
    failing_check = patch(
        "cognite.client._api.ai.python_schema.Draft7Validator.check_schema",
        side_effect=SchemaError("bad schema"),
    )
    with failing_check:
        outcome = extract_function_schema("def handle(name: str): pass")
    assert isinstance(outcome, ErrorResult)
    assert outcome.errors
Loading