diff --git a/strix/llm/dedupe.py b/strix/llm/dedupe.py
index 9edd6b7b..b3eb4676 100644
--- a/strix/llm/dedupe.py
+++ b/strix/llm/dedupe.py
@@ -186,6 +186,15 @@ def check_duplicate(
     if api_base:
         completion_kwargs["api_base"] = api_base
 
+    # Best-effort: tag this LLM call with the current run's trace id so the
+    # PostHog LiteLLM callback can group it with the rest of the run.
+    try:
+        from strix.telemetry.tracer import get_global_tracer
+
+        tracer = get_global_tracer()
+        if tracer:
+            completion_kwargs["metadata"] = {"$ai_trace_id": tracer.run_id}
+    except Exception as e:
+        logger.warning(f"Could not set trace metadata: {e}")
+
     response = litellm.completion(**completion_kwargs)
 
     content = response.choices[0].message.content
diff --git a/strix/llm/llm.py b/strix/llm/llm.py
index 311de35e..1b2430d1 100644
--- a/strix/llm/llm.py
+++ b/strix/llm/llm.py
@@ -1,4 +1,5 @@
 import asyncio
+import logging
 from collections.abc import AsyncIterator
 from dataclasses import dataclass
 from typing import Any
@@ -17,9 +18,13 @@
     parse_tool_invocations,
 )
 from strix.skills import load_skills
+from strix.telemetry import posthog
 from strix.tools import get_tools_prompt
 from strix.utils.resource_paths import get_strix_resource_path
 
 
+logger = logging.getLogger(__name__)
+
+
 litellm.drop_params = True
 litellm.modify_params = True
@@ -74,6 +79,11 @@ def __init__(self, config: LLMConfig, agent_name: str | None = None):
         else:
             self._reasoning_effort = "high"
 
+        try:
+            posthog.configure_litellm_posthog()
+        except Exception as e:
+            logger.warning(f"Could not configure PostHog traces: {e}")
+
     def _load_system_prompt(self, agent_name: str | None) -> str:
         if not agent_name:
             return ""
@@ -128,34 +138,36 @@ async def generate(
 
     async def _stream(self, messages: list[dict[str, Any]]) -> AsyncIterator[LLMResponse]:
         accumulated = ""
         chunks: list[Any] = []
-        done_streaming = 0
+        found_function_end = False
 
         self._total_stats.requests += 1
 
-        response = await acompletion(**self._build_completion_args(messages), stream=True)
+        completion_args = self._build_completion_args(messages)
+        response = await acompletion(**completion_args, stream=True)
 
         async for chunk in response:
+            # Keep consuming the stream after the tool call ends so the final
+            # chunk's usage data reaches stream_chunk_builder below.
             chunks.append(chunk)
-            if done_streaming:
-                done_streaming += 1
-                if getattr(chunk, "usage", None) or done_streaming > 5:
-                    break
-                continue
             delta = self._get_chunk_content(chunk)
             if delta:
                 accumulated += delta
+                # NOTE(review): sentinel tag restored from a mangled paste — confirm it matches the parser.
-                if "</function_calls>" in accumulated:
+                if not found_function_end and "</function_calls>" in accumulated:
                     accumulated = accumulated[
                         : accumulated.find("</function_calls>") + len("</function_calls>")
                     ]
                     yield LLMResponse(content=accumulated)
-                    done_streaming = 1
+                    found_function_end = True
                     continue
-            yield LLMResponse(content=accumulated)
+
+            if not found_function_end:
+                yield LLMResponse(content=accumulated)
 
         if chunks:
-            self._update_usage_stats(stream_chunk_builder(chunks))
+            final_response = stream_chunk_builder(chunks)
+            self._update_usage_stats(final_response)
 
         accumulated = fix_incomplete_tool_call(_truncate_to_first_function(accumulated))
         yield LLMResponse(
@@ -200,6 +212,20 @@ def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, An
             "stream_options": {"include_usage": True},
         }
 
+        metadata: dict[str, Any] = {}
+
+        # Best-effort: attach the run's trace id for the PostHog LiteLLM callback.
+        try:
+            from strix.telemetry.tracer import get_global_tracer
+
+            tracer = get_global_tracer()
+            if tracer:
+                metadata["$ai_trace_id"] = tracer.run_id
+        except Exception as e:
+            logger.warning(f"Could not set trace metadata: {e}")
+
+        if metadata:
+            args["metadata"] = metadata
+
         if api_key := Config.get("llm_api_key"):
             args["api_key"] = api_key
         if api_base := (
diff --git a/strix/llm/memory_compressor.py b/strix/llm/memory_compressor.py
index a9532f8f..661929af 100644
--- a/strix/llm/memory_compressor.py
+++ b/strix/llm/memory_compressor.py
@@ -111,6 +111,15 @@ def _summarize_messages(
         "timeout": timeout,
     }
 
+    # Best-effort: tag the summarization call with the run's trace id.
+    try:
+        from strix.telemetry.tracer import get_global_tracer
+
+        tracer = get_global_tracer()
+        if tracer:
+            completion_args["metadata"] = {"$ai_trace_id": tracer.run_id}
+    except Exception as e:
+        logger.warning(f"Could not set trace metadata: {e}")
+
     response = litellm.completion(**completion_args)
     summary = response.choices[0].message.content or ""
     if not summary.strip():
diff --git a/strix/telemetry/posthog.py b/strix/telemetry/posthog.py
index fd66bcc0..00369ee8 100644
--- a/strix/telemetry/posthog.py
+++ b/strix/telemetry/posthog.py
@@ -1,4 +1,6 @@
 import json
+import logging
+import os
 import platform
 import sys
 import urllib.request
@@ -6,20 +8,50 @@
 from typing import TYPE_CHECKING, Any
 from uuid import uuid4
 
+import litellm
+
 from strix.config import Config
 
 
+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from strix.telemetry.tracer import Tracer
 
 
-_POSTHOG_PUBLIC_API_KEY = "phc_7rO3XRuNT5sgSKAl6HDIrWdSGh1COzxw0vxVIAR6vVZ"
-_POSTHOG_HOST = "https://us.i.posthog.com"
+_POSTHOG_PRIMARY_API_KEY = "phc_7rO3XRuNT5sgSKAl6HDIrWdSGh1COzxw0vxVIAR6vVZ"
+_POSTHOG_PRIMARY_HOST = "https://us.i.posthog.com"
+
+# Optional second PostHog project for LLM traces, configured via environment.
+_POSTHOG_LLM_API_KEY = os.environ.get("POSTHOG_LLM_API_KEY")
+_POSTHOG_LLM_HOST = os.environ.get("POSTHOG_LLM_HOST")
 
 _SESSION_ID = uuid4().hex[:16]
 
 
 def _is_enabled() -> bool:
-    return (Config.get("strix_telemetry") or "1").lower() not in ("0", "false", "no", "off")
+    telemetry_value = Config.get("strix_telemetry") or "1"
+    return telemetry_value.lower() not in ("0", "false", "no", "off")
+
+
+def configure_litellm_posthog() -> None:
+    """Register LiteLLM's PostHog callback so LLM traces go to the env-configured project."""
+    if not _is_enabled():
+        logger.info("PostHog telemetry (traces) is disabled")
+        return
+
+    if _POSTHOG_LLM_API_KEY is None or _POSTHOG_LLM_HOST is None:
+        logger.info("POSTHOG_LLM_API_KEY/POSTHOG_LLM_HOST not set; skipping LLM traces")
+        return
+
+    # LiteLLM's posthog callback reads these environment variables.
+    os.environ["POSTHOG_API_KEY"] = _POSTHOG_LLM_API_KEY
+    os.environ["POSTHOG_API_URL"] = _POSTHOG_LLM_HOST
+
+    if "posthog" not in (litellm.success_callback or []):
+        litellm.success_callback = [*(litellm.success_callback or []), "posthog"]
+
+    if "posthog" not in (litellm.failure_callback or []):
+        litellm.failure_callback = [*(litellm.failure_callback or []), "posthog"]
 
 
 def _is_first_run() -> bool:
@@ -44,21 +76,23 @@ def _get_version() -> str:
 
 
 def _send(event: str, properties: dict[str, Any]) -> None:
+    """Send custom product events to the primary (hard-coded) PostHog project."""
     if not _is_enabled():
         return
     try:
         payload = {
-            "api_key": _POSTHOG_PUBLIC_API_KEY,
+            "api_key": _POSTHOG_PRIMARY_API_KEY,
             "event": event,
             "distinct_id": _SESSION_ID,
             "properties": properties,
         }
         req = urllib.request.Request(  # noqa: S310
-            f"{_POSTHOG_HOST}/capture/",
+            f"{_POSTHOG_PRIMARY_HOST}/capture/",
             data=json.dumps(payload).encode(),
             headers={"Content-Type": "application/json"},
         )
         with urllib.request.urlopen(req, timeout=10):  # noqa: S310  # nosec B310
             pass
+        logger.debug("Sent telemetry event '%s'", event)
     except Exception:  # noqa: BLE001, S110
         pass  # nosec B110