@@ -1,5 +1,6 @@
import json
import logging

# import sys
# import os

216 changes: 207 additions & 9 deletions altk/core/llm/output_parser.py
@@ -26,8 +26,25 @@


def json_schema_to_pydantic_model(
schema: Dict[str, Any], model_name: str = "AutoModel"
schema: Dict[str, Any],
model_name: str = "AutoModel",
free_form_object_as_str: bool = False,
) -> Type[BaseModel]:
"""Build a Pydantic model from a JSON Schema dict.

Args:
schema: JSON Schema dict.
model_name: name of the generated Pydantic model.
free_form_object_as_str: when ``True``, any free-form ``type: object``
property (one without its own ``properties`` sub-schema) is
modeled as a JSON-formatted ``str`` instead of a ``dict``. This
is the workaround for OpenAI's structured-output API, which
requires ``additionalProperties: false`` on every object schema —
a constraint that free-form dicts cannot meet. The caller is
expected to use :func:`relax_freeform_object_schema` when
validating the raw output so the JSON-string form is accepted.
Default ``False`` preserves backward-compatible behavior.
"""
fields = {}
required_fields = set(schema.get("required", []))

@@ -41,9 +58,29 @@ def json_schema_to_pydantic_model(
"null": type(None),
}

def parse_type(type_def: Union[str, List[str]]) -> Type[T]:
def _map_object_for_prop(prop_schema: Dict[str, Any]) -> Type:
"""Return dict/str for a property whose declared type is ``object``.

A property is "free-form" if it has no ``properties`` sub-schema; the
OpenAI workaround only applies to those.
"""
if free_form_object_as_str and "properties" not in prop_schema:
return str
return dict

def parse_type(
type_def: Union[str, List[str], None],
prop_schema: Dict[str, Any],
) -> Type[T]:
def _lookup(t: str) -> Type:
return (
_map_object_for_prop(prop_schema)
if t == "object"
else type_mapping.get(t, Any)
)

if isinstance(type_def, list):
python_types = [type_mapping.get(t, Any) for t in type_def]
python_types = [_lookup(t) for t in type_def]
if type(None) in python_types:
python_types.remove(type(None))
if len(python_types) == 1:
@@ -52,11 +89,12 @@ def parse_type(type_def: Union[str, List[str]]) -> Type[T]:
return Optional[Union[tuple(python_types)]] # type: ignore
else:
return Union[tuple(python_types)] # type: ignore
else:
return type_mapping.get(type_def, Any)
if isinstance(type_def, str):
return _lookup(type_def)
return Any # type: ignore[return-value]

for prop_name, prop_schema in schema.get("properties", {}).items():
field_type: Any = parse_type(prop_schema.get("type"))
field_type: Any = parse_type(prop_schema.get("type"), prop_schema)
default = ... if prop_name in required_fields else None
description = prop_schema.get("description", None)
field_args = {"description": description} if description else {}
@@ -65,6 +103,28 @@ def parse_type(type_def: Union[str, List[str]]) -> Type[T]:
return create_model(model_name, **fields) # type: ignore


def relax_freeform_object_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
"""Return a deep copy of *schema* with free-form ``"type": "object"``
properties widened to accept ``"string"`` as well.

This is the validation-time counterpart to
``json_schema_to_pydantic_model(..., free_form_object_as_str=True)``: when
the Pydantic model emits a JSON string for a free-form object field,
``jsonschema.validate`` against the original schema would reject it. This
helper widens those fields so the same schema accepts both object-literal
and stringified forms. Schemas where the object has sub-``properties`` are
left alone.
"""
import copy

relaxed = copy.deepcopy(schema)
for _prop, prop_schema in relaxed.get("properties", {}).items():
t = prop_schema.get("type")
if t == "object" and "properties" not in prop_schema:
prop_schema["type"] = ["object", "string"]
return relaxed
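
Taken together, the two helpers are meant to round-trip as sketched below; a minimal sketch with an illustrative schema (the schema literal is not part of this diff):

from altk.core.llm.output_parser import (
    json_schema_to_pydantic_model,
    relax_freeform_object_schema,
)

schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "metadata": {"type": "object"},  # free-form: no "properties" sub-schema
    },
    "required": ["name"],
}

# The free-form "metadata" property is modeled as str, so OpenAI's
# additionalProperties: false requirement never applies to it.
Model = json_schema_to_pydantic_model(schema, free_form_object_as_str=True)

# The widened copy accepts both the object literal and the JSON-string form.
relaxed = relax_freeform_object_schema(schema)
assert relaxed["properties"]["metadata"]["type"] == ["object", "string"]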


class OutputValidationError(Exception):
"""Raised when LLM output cannot be validated against the provided schema."""

@@ -82,8 +142,110 @@ class ValidatingLLMClient(BaseLLMClient, ABC):
- Validates and parses the response.
- Retries only invalid items (single or batch) up to `retries` times.
- Falls back to single-item loops if no batch method is configured.

Production knobs (instance-level, with class-level defaults):
- ``free_form_object_as_str``: when ``True``, free-form ``type: object``
schema fields are modeled in Pydantic as ``str`` (and the validation
schema is widened at runtime to accept both object and string). Use
this for providers that require ``additionalProperties: false`` on
every object schema (notably OpenAI's structured-output API).
- ``prompt_based_validation``: when ``True``, the schema is always
injected into the system prompt and no native ``response_format``
kwarg is forwarded. Use for providers that don't support OpenAI-style
structured output (e.g. watsonx).
- ``default_generation_kwargs``: dict of kwargs merged into every
``generate``/``generate_async`` call (e.g. ``{"max_tokens": 8096,
"temperature": 0}``). Caller-provided kwargs override the defaults.
"""

# Class-level defaults — override on subclasses or per instance in
# ``configure_validation`` / constructor kwargs.
free_form_object_as_str: bool = False
prompt_based_validation: bool = False

def __init__(
self,
*,
free_form_object_as_str: Optional[bool] = None,
prompt_based_validation: Optional[bool] = None,
default_generation_kwargs: Optional[Dict[str, Any]] = None,
**base_kwargs: Any,
) -> None:
if free_form_object_as_str is not None:
self.free_form_object_as_str = free_form_object_as_str
if prompt_based_validation is not None:
self.prompt_based_validation = prompt_based_validation
self.default_generation_kwargs: Dict[str, Any] = dict(
default_generation_kwargs or {}
)
super().__init__(**base_kwargs)
# Wrap the subclass's _parse_llm_response so empty / malformed LLM
# outputs retry gracefully (the retry loop treats "" as invalid)
# rather than raising an unrecoverable ValueError.
# This particularly covers reasoning models that exhaust max_tokens
# on "thinking" tokens and return finish_reason="length" with no
# content but non-empty reasoning_content.
orig_parse = self._parse_llm_response
self._parse_llm_response = self._build_safe_parse(orig_parse) # type: ignore[assignment]

def configure_validation(
self,
*,
free_form_object_as_str: Optional[bool] = None,
prompt_based_validation: Optional[bool] = None,
default_generation_kwargs: Optional[Dict[str, Any]] = None,
) -> "ValidatingLLMClient":
"""Update the validation knobs after construction (chainable)."""
if free_form_object_as_str is not None:
self.free_form_object_as_str = free_form_object_as_str
if prompt_based_validation is not None:
self.prompt_based_validation = prompt_based_validation
if default_generation_kwargs is not None:
self.default_generation_kwargs = dict(default_generation_kwargs)
return self
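
For illustration, a typical setup might look like the sketch below; MyOpenAIClient is a hypothetical concrete subclass (ValidatingLLMClient itself is abstract):

client = MyOpenAIClient(  # hypothetical subclass, for illustration only
    free_form_object_as_str=True,  # OpenAI structured-output workaround
    default_generation_kwargs={"max_tokens": 8096, "temperature": 0},
)
# The knobs can also be flipped after construction; the call chains.
client = client.configure_validation(prompt_based_validation=True)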

@staticmethod
def _build_safe_parse(orig): # noqa: ANN001, ANN205
"""Wrap ``_parse_llm_response`` so parse failures become retry-worthy
empty strings instead of raising. Also surfaces a targeted warning
when a reasoning-only response exhausted the token budget."""
import logging as _logging

_logger = _logging.getLogger("altk.core.llm.output_parser")

def _safe_parse(raw): # noqa: ANN001, ANN202
try:
return orig(raw)
except (ValueError, KeyError):
# Detect: choice with reasoning_content but finish_reason='length'
_choices = getattr(raw, "choices", None) or (
raw.get("choices", []) if isinstance(raw, dict) else []
)
if _choices:
c0 = _choices[0]
_msg = getattr(c0, "message", None) or (
c0.get("message", {}) if isinstance(c0, dict) else {}
)
_reasoning = getattr(_msg, "reasoning_content", None) or (
_msg.get("reasoning_content")
if isinstance(_msg, dict)
else None
)
_finish = getattr(c0, "finish_reason", None) or (
c0.get("finish_reason") if isinstance(c0, dict) else None
)
if _reasoning and _finish == "length":
_logger.warning(
"LLM reasoning consumed the entire token budget "
"(finish_reason='length'). Consider increasing "
"max_tokens. Will retry."
)
return ""
_logger.debug("LLM returned empty/unparseable response; will retry.")
return ""

return _safe_parse
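
The guarantee _build_safe_parse provides can be exercised directly; a small sketch assuming the import path from this file (the raw response dict is illustrative of an OpenAI-style chat completion):

from altk.core.llm.output_parser import ValidatingLLMClient

def failing_parse(raw):
    raise ValueError("no content in response")

raw = {
    "choices": [
        {
            "finish_reason": "length",
            "message": {"content": None, "reasoning_content": "step 1 ..."},
        }
    ]
}

safe = ValidatingLLMClient._build_safe_parse(failing_parse)
assert safe(raw) == ""  # warns about the exhausted token budget; "" is retried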

@classmethod
@abstractmethod
def provider_class(cls) -> Type[Any]:
@@ -168,8 +330,16 @@ def _validate(
raise ImportError(
"jsonschema is required for JSON Schema validation. Install with: pip install jsonschema"
)
# Widen free-form object props to also accept strings when we're
# configured to round-trip them as JSON strings (see
# ``free_form_object_as_str`` in the class docstring).
effective_schema = (
relax_freeform_object_schema(schema)
if self.free_form_object_as_str
else schema
)
try:
jsonschema.validate(instance=data, schema=schema)
jsonschema.validate(instance=data, schema=effective_schema)
except jsonschema.ValidationError as e:
raise OutputValidationError(
f"JSON Schema validation error: {e.message}"
@@ -225,6 +395,17 @@ def generate(
"""
Synchronous single-item generation with validation + retries.
"""
# Instance defaults — caller kwargs win.
if self.default_generation_kwargs:
merged = {**self.default_generation_kwargs}
merged.update(kwargs)
kwargs = merged
# Providers that don't support native structured output switch to
# prompt-based schema injection and drop any OpenAI-style
# ``response_format`` field.
if self.prompt_based_validation:
include_schema_in_system_prompt = True
schema_field = None
current = prompt
instr = None
if include_schema_in_system_prompt:
@@ -233,7 +414,10 @@ def generate(
if schema_field:
kwargs[schema_field] = schema
if isinstance(schema, dict):
new_schema = json_schema_to_pydantic_model(schema)
new_schema = json_schema_to_pydantic_model(
schema,
free_form_object_as_str=self.free_form_object_as_str,
)
kwargs[schema_field] = new_schema

last_error: Optional[str] = None
@@ -289,6 +473,17 @@ async def generate_async(
"""
Asynchronous single-item generation with validation + retries.
"""
# Instance defaults — caller kwargs win.
if self.default_generation_kwargs:
merged = {**self.default_generation_kwargs}
merged.update(kwargs)
kwargs = merged
# Providers that don't support native structured output switch to
# prompt-based schema injection and drop any OpenAI-style
# ``response_format`` field.
if self.prompt_based_validation:
include_schema_in_system_prompt = True
schema_field = None
current = prompt
instr = None
if include_schema_in_system_prompt:
@@ -297,7 +492,10 @@ async def generate_async(
if schema_field:
kwargs[schema_field] = schema
if isinstance(schema, dict):
new_schema = json_schema_to_pydantic_model(schema)
new_schema = json_schema_to_pydantic_model(
schema,
free_form_object_as_str=self.free_form_object_as_str,
)
kwargs[schema_field] = new_schema

last_error: Optional[str] = None
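
End to end, the new validation path is used roughly as below; MyOpenAIClient again stands in for any concrete subclass, and the prompt, schema, and keyword names are illustrative:

schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}

client = MyOpenAIClient(default_generation_kwargs={"temperature": 0})
result = client.generate(
    "Reply as JSON with a single 'answer' field.",
    schema=schema,
    retries=3,  # only invalid outputs are retried, up to this many times
)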
1 change: 0 additions & 1 deletion altk/core/llm_examples/azure_openai_example.py
@@ -6,7 +6,6 @@
from altk.core.llm import get_llm, GenerationMode
from altk.core.llm.types import GenerationArgs


# ──────────────────────────────────────────────────────────────────────────────
# 1. Define schemas for structured output
# ──────────────────────────────────────────────────────────────────────────────
1 change: 0 additions & 1 deletion altk/core/llm_examples/ibm_watsonx_ai_example.py
@@ -5,7 +5,6 @@
from altk.core.llm import get_llm, GenerationMode
from altk.core.llm.types import GenerationArgs


# ──────────────────────────────────────────────────────────────────────────────
# 1. Define schemas for structured output
# ──────────────────────────────────────────────────────────────────────────────
1 change: 0 additions & 1 deletion altk/core/llm_examples/litellm_ollama_example.py
@@ -5,7 +5,6 @@
from altk.core.llm import get_llm
from altk.core.llm.types import GenerationArgs


# ──────────────────────────────────────────────────────────────────────────────
# 1. Define schemas for structured output
# ──────────────────────────────────────────────────────────────────────────────
1 change: 0 additions & 1 deletion altk/core/llm_examples/litellm_watsonx_example.py
@@ -5,7 +5,6 @@
from altk.core.llm import get_llm
from altk.core.llm.types import GenerationArgs


# ──────────────────────────────────────────────────────────────────────────────
# 1. Define schemas for structured output
# ──────────────────────────────────────────────────────────────────────────────
1 change: 0 additions & 1 deletion altk/pre_llm/routing/follow_up_detection/follow_up.py
@@ -18,7 +18,6 @@
)
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage


logger = logging.getLogger(__name__)

FOLLOW_UP_PROMPT = [
@@ -1,6 +1,5 @@
import logging


logger = logging.getLogger(__name__)
# import built-in TopicRetriever implementations to force their registering
try:
@@ -13,7 +13,6 @@
run_topic_extractions,
)


logger = logging.getLogger(__name__)


1 change: 0 additions & 1 deletion altk/pre_response/policy_guard/core/toolkit.py
@@ -4,7 +4,6 @@

from altk.core.toolkit import ComponentInput, ComponentOutput


######### Policy Guard Middleware Interfaces ##############


6 changes: 4 additions & 2 deletions altk/pre_response/policy_guard/detect/task_judge.py
@@ -28,7 +28,8 @@ def create_adherence_check_report(results: list[dict]) -> dict:
class TaskJudge:
def __init__(self, config: ComponentConfig):
self.config = config
self.task_completion_prompt = Template("""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
self.task_completion_prompt = Template(
"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a grader whose job is to determine if a response is a valid response to a query. Your score should be "Yes" or "No". If a response is ambiguous your score should be "Yes". The score should only be "No" if the response is definitely not a valid response.<|eot_id|>
<|start_header_id|>user<|end_header_id|>

@@ -47,7 +48,8 @@ def __init__(self, config: ComponentConfig):
}

Return ONLY one JSON object, nothing else. Do not include any additional text or explanations outside the JSON object.<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>""")
<|start_header_id|>assistant<|end_header_id|>"""
)

def check_task_completion(self, task: str, response: str):
prompt = self.task_completion_prompt.safe_substitute(