Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cognite/client/_api/ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import TYPE_CHECKING

from cognite.client._api.ai.python_schema import ErrorResult, SchemaResult, extract_function_schema
from cognite.client._api.ai.tools import AIToolsAPI
from cognite.client._api_client import APIClient

Expand Down
60 changes: 59 additions & 1 deletion cognite/client/_api/ai/python_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,27 @@ def _scalar_to_schema(name: str, base_type: str, in_list: bool) -> tuple[dict |
return None, f"Unsupported type for parameter '{name}': {base_type}"


def _apply_default_value(prop: dict, name: str, annotation: str, default_value: DefaultValue) -> None:
if "list" in annotation:
raise ValueError(
f"Default value for parameter '{name}' was provided, but default values are not supported for lists."
)
str_default = str(default_value)
if annotation == "float":
prop["default"] = float(str_default)
elif annotation == "int":
prop["default"] = int(str_default)
elif annotation == "bool":
prop["default"] = str_default == "True"
elif annotation == "str":
prop["default"] = ast.literal_eval(str_default)
else:
raise ValueError(
f"Default value for parameter '{name}' was provided, but only primitive types"
" (str, int, float, bool) support default values."
)


def _parse_single_annotation(
name: str, annotation: str, default_value: DefaultValue | None
) -> tuple[dict | None, str | None]:
Expand All @@ -113,12 +134,24 @@ def _parse_single_annotation(
items_schema, err = _scalar_to_schema(name, inner_type or "", in_list=True)
if err:
return None, err
return {"type": "array", "items": items_schema}, None
prop: dict = {"type": "array", "items": items_schema}
if default_value is not None:
try:
_apply_default_value(prop, name, annotation, default_value)
except ValueError as e:
return None, str(e)
return prop, None

scalar_prop, err = _scalar_to_schema(name, annotation, in_list=False)
if err or scalar_prop is None:
return None, err

if default_value is not None:
try:
_apply_default_value(scalar_prop, name, annotation, default_value)
except ValueError as e:
return None, str(e)

return scalar_prop, None


Expand Down Expand Up @@ -150,6 +183,31 @@ def _parse_type_annotations(func_def: FunctionDef) -> SchemaResult | ErrorResult


def _parse_docstring(func_def: FunctionDef, schema: SchemaResult) -> SchemaResult:
from griffe import Docstring
from griffe import parse as griffe_parse

docstring_text = ast.get_docstring(func_def)
if docstring_text is not None:
doc_obj = Docstring(docstring_text, lineno=1)
sections = griffe_parse(doc_obj, parser="google")
for section in sections:
if section.kind.value == "parameters":
for param in section.value:
if param.name in schema.properties:
schema.properties[param.name]["description"] = param.description
else:
raise ValueError(
f"Docstring parameter '{param.name}' does not match any function arguments."
)

for _, prop in schema.properties.items():
if prop.get("format") == COGNITE_QUERY_ID_FORMAT:
existing_desc = prop.get("description", "").rstrip(".")
if existing_desc:
prop["description"] = f"{existing_desc}. {QUERY_ID_DESCRIPTION_SUFFIX}"
else:
prop["description"] = QUERY_ID_DESCRIPTION_SUFFIX

return schema


Expand Down
157 changes: 157 additions & 0 deletions tests/tests_unit/test_api/test_python_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,160 @@ def handle(param: {annotation}) -> None:
'''
result = extract_function_schema(code)
assert isinstance(result, ErrorResult)


# ---------------------------------------------------------------------------
# US3 — Docstring Enrichment + Default Values
# ---------------------------------------------------------------------------


@pytest.mark.us3
def test_full_docstring():
code = '''
def handle(name: str, age: int, height: float) -> None:
"""Summary.

Args:
name: The person's name.
age: The person's age.
height: The person's height.
"""
pass
'''
result = extract_function_schema(code)
assert isinstance(result, SchemaResult)
assert result.properties["name"]["description"] == "The person's name."
assert result.properties["age"]["description"] == "The person's age."
assert result.properties["height"]["description"] == "The person's height."


@pytest.mark.us3
def test_partial_docstring():
code = '''
def handle(name: str, age: int, height: float) -> None:
"""Summary.

Args:
name: The person's name.
age: The person's age.
"""
pass
'''
result = extract_function_schema(code)
assert isinstance(result, SchemaResult)
assert result.properties["name"]["description"] == "The person's name."
assert result.properties["age"]["description"] == "The person's age."
assert "description" not in result.properties["height"]


@pytest.mark.us3
def test_no_docstring_no_description():
code = "def handle(name: str, age: int) -> None: pass"
result = extract_function_schema(code)
assert isinstance(result, SchemaResult)
assert "description" not in result.properties["name"]
assert "description" not in result.properties["age"]


@pytest.mark.us3
def test_docstring_unknown_param():
code = '''
def handle(name: str) -> None:
"""Summary.

Args:
name: The person's name.
extra_param: This param doesn't exist.
"""
pass
'''
result = extract_function_schema(code)
assert isinstance(result, ErrorResult)


@pytest.mark.us3
def test_dataframe_with_docstring_gets_suffix():
code = '''
def handle(x: int, assets: pd.DataFrame) -> None:
"""Summary.

Args:
assets: test assets
"""
pass
'''
result = extract_function_schema(code)
assert isinstance(result, SchemaResult)
expected = "test assets. **THIS FIELD IS POPULATED BY PROVIDING A query_id FROM THE MEMORY TABLE**"
assert result.properties["assets"]["description"] == expected


@pytest.mark.us3
def test_dataframe_without_docstring_gets_suffix():
code = "def handle(assets: pd.DataFrame) -> None: pass"
result = extract_function_schema(code)
assert isinstance(result, SchemaResult)
expected = "**THIS FIELD IS POPULATED BY PROVIDING A query_id FROM THE MEMORY TABLE**"
assert result.properties["assets"]["description"] == expected


@pytest.mark.us3
@pytest.mark.parametrize(
("py_type", "default_val", "json_type", "expected_default"),
[
("str", '"Joe"', "string", "Joe"),
("int", "1337", "integer", 1337),
("float", "4.2", "number", 4.2),
("bool", "True", "boolean", True),
],
)
def test_primitive_defaults(py_type: str, default_val: str, json_type: str, expected_default: object):
code = f"def handle(param: {py_type} = {default_val}) -> None: pass"
result = extract_function_schema(code)
assert isinstance(result, SchemaResult)
assert result.properties["param"]["type"] == json_type
assert result.properties["param"]["default"] == expected_default
assert "param" not in result.required


@pytest.mark.us3
def test_defaults_with_docstring():
code = '''
def handle(name: str, age: int = 30, is_active: bool = True) -> None:
"""Summary.

Args:
name: The person's name.
age: The person's age.
is_active: Whether the person is active.
"""
pass
'''
result = extract_function_schema(code)
assert isinstance(result, SchemaResult)
assert result.properties["age"]["description"] == "The person's age."
assert result.properties["age"]["default"] == 30
assert "name" in result.required
assert "age" not in result.required
assert "is_active" not in result.required


@pytest.mark.us3
def test_list_default_unsupported():
result = extract_function_schema('def handle(names: list[str] = ["a", "b"]) -> None: pass')
assert isinstance(result, ErrorResult)
assert any("default values are not supported for lists" in e for e in result.errors)


@pytest.mark.us3
def test_nodeid_default_unsupported():
result = extract_function_schema('def handle(node: NodeId = NodeId("space", "id")) -> None: pass')
assert isinstance(result, ErrorResult)
assert any("only primitive types (str, int, float, bool) support default values" in e for e in result.errors)


@pytest.mark.us3
def test_datetime_default_unsupported():
result = extract_function_schema("def handle(date: datetime = datetime.now()) -> None: pass")
assert isinstance(result, ErrorResult)
assert any("only primitive types (str, int, float, bool) support default values" in e for e in result.errors)
Loading