From 566eff567b582277a7370ed3d8463dc249aca87f Mon Sep 17 00:00:00 2001 From: Anders Hafreager Date: Thu, 23 Apr 2026 21:33:52 -0700 Subject: [PATCH] feat: add docstring enrichment, default values, re-exports (PR 3/3) Completes the extract_function_schema feature: - _parse_docstring: parses Google-style docstrings via griffe to set description fields; appends cognite-query-id guidance suffix to DataFrame params; raises ValueError for unknown docstring params - _apply_default_value: coerces ast.unparse'd defaults to the correct Python type (str via literal_eval, int/float cast, bool via == "True") - _parse_single_annotation: wired to apply defaults for all param types - __init__.py: re-exports extract_function_schema, SchemaResult, ErrorResult Co-Authored-By: Claude Sonnet 4.6 --- cognite/client/_api/ai/__init__.py | 1 + cognite/client/_api/ai/python_schema.py | 60 ++++++- .../tests_unit/test_api/test_python_schema.py | 157 ++++++++++++++++++ 3 files changed, 217 insertions(+), 1 deletion(-) diff --git a/cognite/client/_api/ai/__init__.py b/cognite/client/_api/ai/__init__.py index 662093f4eb..a663818322 100644 --- a/cognite/client/_api/ai/__init__.py +++ b/cognite/client/_api/ai/__init__.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING +from cognite.client._api.ai.python_schema import ErrorResult, SchemaResult, extract_function_schema from cognite.client._api.ai.tools import AIToolsAPI from cognite.client._api_client import APIClient diff --git a/cognite/client/_api/ai/python_schema.py b/cognite/client/_api/ai/python_schema.py index d4a2642737..b13c7abfcc 100644 --- a/cognite/client/_api/ai/python_schema.py +++ b/cognite/client/_api/ai/python_schema.py @@ -101,6 +101,27 @@ def _scalar_to_schema(name: str, base_type: str, in_list: bool) -> tuple[dict | return None, f"Unsupported type for parameter '{name}': {base_type}" +def _apply_default_value(prop: dict, name: str, annotation: str, default_value: DefaultValue) -> None: + if "list" in annotation: + raise ValueError( + f"Default value for parameter '{name}' was provided, but default values are not supported for lists." + ) + str_default = str(default_value) + if annotation == "float": + prop["default"] = float(str_default) + elif annotation == "int": + prop["default"] = int(str_default) + elif annotation == "bool": + prop["default"] = str_default == "True" + elif annotation == "str": + prop["default"] = ast.literal_eval(str_default) + else: + raise ValueError( + f"Default value for parameter '{name}' was provided, but only primitive types" + " (str, int, float, bool) support default values." + ) + + def _parse_single_annotation( name: str, annotation: str, default_value: DefaultValue | None ) -> tuple[dict | None, str | None]: @@ -113,12 +134,24 @@ def _parse_single_annotation( items_schema, err = _scalar_to_schema(name, inner_type or "", in_list=True) if err: return None, err - return {"type": "array", "items": items_schema}, None + prop: dict = {"type": "array", "items": items_schema} + if default_value is not None: + try: + _apply_default_value(prop, name, annotation, default_value) + except ValueError as e: + return None, str(e) + return prop, None scalar_prop, err = _scalar_to_schema(name, annotation, in_list=False) if err or scalar_prop is None: return None, err + if default_value is not None: + try: + _apply_default_value(scalar_prop, name, annotation, default_value) + except ValueError as e: + return None, str(e) + return scalar_prop, None @@ -150,6 +183,31 @@ def _parse_type_annotations(func_def: FunctionDef) -> SchemaResult | ErrorResult def _parse_docstring(func_def: FunctionDef, schema: SchemaResult) -> SchemaResult: + from griffe import Docstring + from griffe import parse as griffe_parse + + docstring_text = ast.get_docstring(func_def) + if docstring_text is not None: + doc_obj = Docstring(docstring_text, lineno=1) + sections = griffe_parse(doc_obj, parser="google") + for section in sections: + if section.kind.value == "parameters": + for param in section.value: + if param.name in schema.properties: + schema.properties[param.name]["description"] = param.description + else: + raise ValueError( + f"Docstring parameter '{param.name}' does not match any function arguments." + ) + + for _, prop in schema.properties.items(): + if prop.get("format") == COGNITE_QUERY_ID_FORMAT: + existing_desc = prop.get("description", "").rstrip(".") + if existing_desc: + prop["description"] = f"{existing_desc}. {QUERY_ID_DESCRIPTION_SUFFIX}" + else: + prop["description"] = QUERY_ID_DESCRIPTION_SUFFIX + return schema diff --git a/tests/tests_unit/test_api/test_python_schema.py b/tests/tests_unit/test_api/test_python_schema.py index e66b392652..2118021934 100644 --- a/tests/tests_unit/test_api/test_python_schema.py +++ b/tests/tests_unit/test_api/test_python_schema.py @@ -235,3 +235,160 @@ def handle(param: {annotation}) -> None: ''' result = extract_function_schema(code) assert isinstance(result, ErrorResult) + + +# --------------------------------------------------------------------------- +# US3 — Docstring Enrichment + Default Values +# --------------------------------------------------------------------------- + + +@pytest.mark.us3 +def test_full_docstring(): + code = ''' +def handle(name: str, age: int, height: float) -> None: + """Summary. + + Args: + name: The person's name. + age: The person's age. + height: The person's height. + """ + pass +''' + result = extract_function_schema(code) + assert isinstance(result, SchemaResult) + assert result.properties["name"]["description"] == "The person's name." + assert result.properties["age"]["description"] == "The person's age." + assert result.properties["height"]["description"] == "The person's height." + + +@pytest.mark.us3 +def test_partial_docstring(): + code = ''' +def handle(name: str, age: int, height: float) -> None: + """Summary. + + Args: + name: The person's name. + age: The person's age. + """ + pass +''' + result = extract_function_schema(code) + assert isinstance(result, SchemaResult) + assert result.properties["name"]["description"] == "The person's name." + assert result.properties["age"]["description"] == "The person's age." + assert "description" not in result.properties["height"] + + +@pytest.mark.us3 +def test_no_docstring_no_description(): + code = "def handle(name: str, age: int) -> None: pass" + result = extract_function_schema(code) + assert isinstance(result, SchemaResult) + assert "description" not in result.properties["name"] + assert "description" not in result.properties["age"] + + +@pytest.mark.us3 +def test_docstring_unknown_param(): + code = ''' +def handle(name: str) -> None: + """Summary. + + Args: + name: The person's name. + extra_param: This param doesn't exist. + """ + pass +''' + result = extract_function_schema(code) + assert isinstance(result, ErrorResult) + + +@pytest.mark.us3 +def test_dataframe_with_docstring_gets_suffix(): + code = ''' +def handle(x: int, assets: pd.DataFrame) -> None: + """Summary. + + Args: + assets: test assets + """ + pass +''' + result = extract_function_schema(code) + assert isinstance(result, SchemaResult) + expected = "test assets. **THIS FIELD IS POPULATED BY PROVIDING A query_id FROM THE MEMORY TABLE**" + assert result.properties["assets"]["description"] == expected + + +@pytest.mark.us3 +def test_dataframe_without_docstring_gets_suffix(): + code = "def handle(assets: pd.DataFrame) -> None: pass" + result = extract_function_schema(code) + assert isinstance(result, SchemaResult) + expected = "**THIS FIELD IS POPULATED BY PROVIDING A query_id FROM THE MEMORY TABLE**" + assert result.properties["assets"]["description"] == expected + + +@pytest.mark.us3 +@pytest.mark.parametrize( + ("py_type", "default_val", "json_type", "expected_default"), + [ + ("str", '"Joe"', "string", "Joe"), + ("int", "1337", "integer", 1337), + ("float", "4.2", "number", 4.2), + ("bool", "True", "boolean", True), + ], +) +def test_primitive_defaults(py_type: str, default_val: str, json_type: str, expected_default: object): + code = f"def handle(param: {py_type} = {default_val}) -> None: pass" + result = extract_function_schema(code) + assert isinstance(result, SchemaResult) + assert result.properties["param"]["type"] == json_type + assert result.properties["param"]["default"] == expected_default + assert "param" not in result.required + + +@pytest.mark.us3 +def test_defaults_with_docstring(): + code = ''' +def handle(name: str, age: int = 30, is_active: bool = True) -> None: + """Summary. + + Args: + name: The person's name. + age: The person's age. + is_active: Whether the person is active. + """ + pass +''' + result = extract_function_schema(code) + assert isinstance(result, SchemaResult) + assert result.properties["age"]["description"] == "The person's age." + assert result.properties["age"]["default"] == 30 + assert "name" in result.required + assert "age" not in result.required + assert "is_active" not in result.required + + +@pytest.mark.us3 +def test_list_default_unsupported(): + result = extract_function_schema('def handle(names: list[str] = ["a", "b"]) -> None: pass') + assert isinstance(result, ErrorResult) + assert any("default values are not supported for lists" in e for e in result.errors) + + +@pytest.mark.us3 +def test_nodeid_default_unsupported(): + result = extract_function_schema('def handle(node: NodeId = NodeId("space", "id")) -> None: pass') + assert isinstance(result, ErrorResult) + assert any("only primitive types (str, int, float, bool) support default values" in e for e in result.errors) + + +@pytest.mark.us3 +def test_datetime_default_unsupported(): + result = extract_function_schema("def handle(date: datetime = datetime.now()) -> None: pass") + assert isinstance(result, ErrorResult) + assert any("only primitive types (str, int, float, bool) support default values" in e for e in result.errors)