From 8201881de44aa9c72556619f7c3eaf8bb95818c1 Mon Sep 17 00:00:00 2001 From: Anders Hafreager Date: Thu, 23 Apr 2026 21:31:36 -0700 Subject: [PATCH] feat: add extract_function_schema with primitive type support (PR 1/3) Ports the Python script schema generator into the SDK as a pure-Python, no-network utility. This PR establishes the module skeleton, dataclasses, AST helpers, validation, and primitive type support (str, int, float, bool). - New module: cognite/client/_api/ai/python_schema.py - New tests: tests/tests_unit/test_api/test_python_schema.py (us1 group) - pyproject.toml: adds griffe>=1.0,<2 and jsonschema>=4.0,<5 as optional deps under a new [schema] extra - pytest.ini: registers us1/us2/us3 markers Co-Authored-By: Claude Sonnet 4.6 --- cognite/client/_api/ai/python_schema.py | 185 ++++++++++++++++++ pyproject.toml | 3 + pytest.ini | 3 + .../tests_unit/test_api/test_python_schema.py | 122 ++++++++++++ 4 files changed, 313 insertions(+) create mode 100644 cognite/client/_api/ai/python_schema.py create mode 100644 tests/tests_unit/test_api/test_python_schema.py diff --git a/cognite/client/_api/ai/python_schema.py b/cognite/client/_api/ai/python_schema.py new file mode 100644 index 0000000000..5b0d80181a --- /dev/null +++ b/cognite/client/_api/ai/python_schema.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +import ast +from dataclasses import dataclass, field +from typing import TypeAlias + +from jsonschema.exceptions import SchemaError +from jsonschema.validators import Draft7Validator + +DefaultValue: TypeAlias = str | int | float | bool +FunctionDef: TypeAlias = ast.FunctionDef | ast.AsyncFunctionDef + +COGNITE_QUERY_ID_FORMAT = "cognite-query-id" +DATAFRAME_PARAMETER_TYPES = {"pd.DataFrame", "DataFrame"} +QUERY_ID_DESCRIPTION_SUFFIX = ( + "**THIS FIELD IS POPULATED BY PROVIDING A query_id FROM THE MEMORY TABLE**" +) + + +@dataclass +class SchemaResult: + type: str = "object" + properties: dict[str, dict] = field(default_factory=dict) + required: list[str] = field(default_factory=list) + + +@dataclass +class ErrorResult: + errors: list[str] + + +def _extract_argument_details(func_def: FunctionDef) -> tuple[list[str], dict[str, str | None]]: + argument_names = [arg.arg for arg in func_def.args.args] + argument_annotations: dict[str, str | None] = { + arg.arg: (ast.unparse(arg.annotation) if arg.annotation is not None else None) + for arg in func_def.args.args + } + return argument_names, argument_annotations + + +def _find_duplicate_argument_names(argument_names: list[str]) -> list[str]: + seen: set[str] = set() + duplicates: list[str] = [] + for name in argument_names: + if name in seen and name not in duplicates: + duplicates.append(name) + seen.add(name) + return duplicates + + +def _find_missing_type_annotations(argument_annotations: dict[str, str | None]) -> list[str]: + return [name for name, annotation in argument_annotations.items() if annotation is None] + + +def _validate_function_arguments(func_def: FunctionDef) -> list[str]: + argument_names, argument_annotations = _extract_argument_details(func_def) + errors: list[str] = [] + duplicates = _find_duplicate_argument_names(argument_names) + if duplicates: + errors.append(f"Duplicate argument name(s) found: {', '.join(duplicates)}") + missing = _find_missing_type_annotations(argument_annotations) + if missing: + errors.append(f"Missing type annotation for parameter(s): {', '.join(missing)}") + return errors + + +def _extract_argument_defaults(func_def: FunctionDef, argument_names: list[str]) -> dict[str, DefaultValue]: + defaults = func_def.args.defaults + if not defaults: + return {} + offset = len(argument_names) - len(defaults) + return {argument_names[offset + i]: ast.unparse(default) for i, default in enumerate(defaults)} + + +def _parse_list_type(annotation: str) -> tuple[bool, str | None]: + if annotation.startswith("list["): + return True, annotation[5:-1] + return False, None + + +def _scalar_to_schema(name: str, base_type: str, in_list: bool) -> tuple[dict | None, str | None]: + primitive_map: dict[str, dict] = { + "str": {"type": "string"}, + "int": {"type": "integer"}, + "float": {"type": "number"}, + "bool": {"type": "boolean"}, + } + if base_type in primitive_map: + return dict(primitive_map[base_type]), None + if in_list: + return None, f"Unsupported type for parameter '{name}': list[{base_type}]" + return None, f"Unsupported type for parameter '{name}': {base_type}" + + +def _parse_single_annotation( + name: str, annotation: str, default_value: DefaultValue | None +) -> tuple[dict | None, str | None]: + is_list, inner_type = _parse_list_type(annotation) + + if is_list: + inner_is_list, _ = _parse_list_type(inner_type or "") + if inner_is_list: + return None, f"Unsupported type for parameter '{name}': nested list types are not supported" + items_schema, err = _scalar_to_schema(name, inner_type or "", in_list=True) + if err: + return None, err + return {"type": "array", "items": items_schema}, None + + scalar_prop, err = _scalar_to_schema(name, annotation, in_list=False) + if err or scalar_prop is None: + return None, err + + return scalar_prop, None + + +def _parse_type_annotations(func_def: FunctionDef) -> SchemaResult | ErrorResult: + argument_names, argument_annotations = _extract_argument_details(func_def) + argument_defaults = _extract_argument_defaults(func_def, argument_names) + + errors: list[str] = [] + properties: dict[str, dict] = {} + required_params: list[str] = [] + + for name in argument_names: + annotation = argument_annotations.get(name) + if annotation is None: + errors.append(f"Missing type annotation for parameter '{name}'") + continue + default_value = argument_defaults.get(name) + prop, err = _parse_single_annotation(name, annotation, default_value) + if err: + errors.append(err) + else: + properties[name] = prop # type: ignore[assignment] + if default_value is None: + required_params.append(name) + + if errors: + return ErrorResult(errors) + return SchemaResult(type="object", properties=properties, required=sorted(required_params)) + + +def _parse_docstring(func_def: FunctionDef, schema: SchemaResult) -> SchemaResult: + return schema + + +def extract_function_schema(code: str) -> SchemaResult | ErrorResult: + try: + try: + tree = ast.parse(code) + except SyntaxError: + return ErrorResult(["Failed to parse Python code. Check that the code is valid Python code."]) + + handle_func: FunctionDef | None = None + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == "handle": + handle_func = node + break + + if handle_func is None: + return ErrorResult(["Function 'handle' not found."]) + + errors = _validate_function_arguments(handle_func) + if errors: + return ErrorResult(errors) + + result = _parse_type_annotations(handle_func) + if isinstance(result, ErrorResult): + return result + + try: + result = _parse_docstring(handle_func, result) + except ValueError as e: + return ErrorResult([str(e)]) + + schema_dict = {"type": result.type, "properties": result.properties, "required": result.required} + try: + Draft7Validator.check_schema(schema_dict) + except SchemaError as e: + return ErrorResult([str(e)]) + + return result + + except Exception as e: + return ErrorResult([f"Unexpected error during schema extraction: {e}"]) diff --git a/pyproject.toml b/pyproject.toml index 05d4d18255..4efe59ee77 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,8 @@ pandas = { version = ">=2.1", optional = true } geopandas = { version = ">=0.14", optional = true } shapely = { version = ">=1.7.0", optional = true } PyYAML = { version = "^6.0", optional = true } +griffe = { version = ">=1.0,<2", optional = true } +jsonschema = { version = ">=4.0,<5", optional = true } [tool.poetry.extras] pandas = ["pandas"] @@ -62,6 +64,7 @@ geo = ["geopandas", "shapely"] sympy = ["sympy"] functions = ["pip"] yaml = ["PyYAML"] +schema = ["griffe", "jsonschema"] pyodide = ["tzdata", "anyio"] # keep pyodide related dependencies outside of 'all' all = ["numpy", "pandas", "geopandas", "shapely", "sympy", "pip", "PyYAML"] diff --git a/pytest.ini b/pytest.ini index d03db47b74..46f5e33a98 100644 --- a/pytest.ini +++ b/pytest.ini @@ -2,6 +2,9 @@ markers = dsl coredeps + us1: User Story 1 — Core Schema Extraction + Primitive Types (PR 1) + us2: User Story 2 — Extended Type Support (PR 2) + us3: User Story 3 — Docstring Enrichment + Default Values (PR 3) anyio_mode = auto diff --git a/tests/tests_unit/test_api/test_python_schema.py b/tests/tests_unit/test_api/test_python_schema.py new file mode 100644 index 0000000000..184e33c9de --- /dev/null +++ b/tests/tests_unit/test_api/test_python_schema.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from unittest.mock import patch + +import pytest + +from cognite.client._api.ai.python_schema import ErrorResult, SchemaResult, extract_function_schema +from jsonschema.exceptions import SchemaError + + +# --------------------------------------------------------------------------- +# US1 — Core Schema Extraction + Primitive Types +# --------------------------------------------------------------------------- + + +@pytest.mark.us1 +def test_no_handle_function(): + result = extract_function_schema("def process(name: str): pass") + assert isinstance(result, ErrorResult) + assert result.errors[0] == "Function 'handle' not found." + + +@pytest.mark.us1 +def test_missing_type_annotation(): + result = extract_function_schema("def handle(name: str, age): pass") + assert isinstance(result, ErrorResult) + assert any("Missing type annotation for parameter(s): age" in e for e in result.errors) + + +@pytest.mark.us1 +def test_duplicate_argument_names(): + result = extract_function_schema("def handle(name: str, name: int): pass") + assert isinstance(result, ErrorResult) + assert any("Duplicate argument name(s) found: name" in e for e in result.errors) + + +@pytest.mark.us1 +def test_invalid_python_syntax(): + result = extract_function_schema("def handle(name: str\n pass") + assert isinstance(result, ErrorResult) + assert len(result.errors) > 0 + + +@pytest.mark.us1 +def test_primitive_types(): + code = "def handle(name: str, age: int, height: float, is_active: bool) -> None: pass" + result = extract_function_schema(code) + assert isinstance(result, SchemaResult) + assert result.properties["name"] == {"type": "string"} + assert result.properties["age"] == {"type": "integer"} + assert result.properties["height"] == {"type": "number"} + assert result.properties["is_active"] == {"type": "boolean"} + assert result.required == ["age", "height", "is_active", "name"] + + +@pytest.mark.us1 +def test_primitive_list_types(): + code = "def handle(names: list[str], ages: list[int], heights: list[float], flags: list[bool]) -> None: pass" + result = extract_function_schema(code) + assert isinstance(result, SchemaResult) + assert result.properties["names"] == {"type": "array", "items": {"type": "string"}} + assert result.properties["ages"] == {"type": "array", "items": {"type": "integer"}} + assert result.properties["heights"] == {"type": "array", "items": {"type": "number"}} + assert result.properties["flags"] == {"type": "array", "items": {"type": "boolean"}} + + +@pytest.mark.us1 +def test_nested_list_unsupported(): + result = extract_function_schema("def handle(nested: list[list[str]]) -> None: pass") + assert isinstance(result, ErrorResult) + assert any("nested list" in e for e in result.errors) + + +@pytest.mark.us1 +def test_async_handle(): + code = "async def handle(name: str) -> None: pass" + result = extract_function_schema(code) + assert isinstance(result, SchemaResult) + assert result.properties["name"] == {"type": "string"} + assert result.required == ["name"] + + +@pytest.mark.us1 +def test_empty_handle(): + result = extract_function_schema("def handle() -> None: pass") + assert isinstance(result, SchemaResult) + assert result.properties == {} + assert result.required == [] + + +@pytest.mark.us1 +def test_no_docstring(): + code = "def handle(name: str, age: int) -> None: pass" + result = extract_function_schema(code) + assert isinstance(result, SchemaResult) + assert "description" not in result.properties["name"] + assert "description" not in result.properties["age"] + + +@pytest.mark.us1 +def test_multiple_functions_only_handle_extracted(): + code = """ +def helper(x: int): + pass + +def handle(name: str): + pass +""" + result = extract_function_schema(code) + assert isinstance(result, SchemaResult) + assert list(result.properties.keys()) == ["name"] + + +@pytest.mark.us1 +def test_draft7_validation_failure(): + with patch( + "cognite.client._api.ai.python_schema.Draft7Validator.check_schema", + side_effect=SchemaError("bad schema"), + ): + result = extract_function_schema("def handle(name: str): pass") + assert isinstance(result, ErrorResult) + assert len(result.errors) > 0