Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions strix/llm/dedupe.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,16 @@ def check_duplicate(
if api_base:
completion_kwargs["api_base"] = api_base

try:
from strix.telemetry.tracer import get_global_tracer

tracer = get_global_tracer()
if tracer:
run_id = tracer.run_id
completion_kwargs["metadata"] = {"$ai_trace_id": run_id}
except Exception as e:
logger.error(f"Could not set trace metadata: {e}")

response = litellm.completion(**completion_kwargs)

content = response.choices[0].message.content
Expand Down
47 changes: 35 additions & 12 deletions strix/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import logging
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any
Expand All @@ -17,10 +18,14 @@
parse_tool_invocations,
)
from strix.skills import load_skills
from strix.telemetry import posthog
from strix.tools import get_tools_prompt
from strix.utils.resource_paths import get_strix_resource_path


logger = logging.getLogger(__name__)


litellm.drop_params = True
litellm.modify_params = True

Expand Down Expand Up @@ -74,6 +79,11 @@ def __init__(self, config: LLMConfig, agent_name: str | None = None):
else:
self._reasoning_effort = "high"

try:
posthog.configure_litellm_posthog()
except Exception as e:
logger.error(f"Could not config posthog traces: {e}")

def _load_system_prompt(self, agent_name: str | None) -> str:
if not agent_name:
return ""
Expand Down Expand Up @@ -128,32 +138,31 @@ async def generate(
async def _stream(self, messages: list[dict[str, Any]]) -> AsyncIterator[LLMResponse]:
accumulated = ""
chunks: list[Any] = []
done_streaming = 0
found_function_end = False

self._total_stats.requests += 1
response = await acompletion(**self._build_completion_args(messages), stream=True)
completion_args = self._build_completion_args(messages)
response = await acompletion(**completion_args, stream=True)

async for chunk in response:
chunks.append(chunk)
if done_streaming:
done_streaming += 1
if getattr(chunk, "usage", None) or done_streaming > 5:
break
continue
chunks.append(chunk)
delta = self._get_chunk_content(chunk)
if delta:
accumulated += delta
if "</function>" in accumulated:
if not found_function_end and "</function>" in accumulated:
accumulated = accumulated[
: accumulated.find("</function>") + len("</function>")
]
yield LLMResponse(content=accumulated)
done_streaming = 1
found_function_end = True
continue
yield LLMResponse(content=accumulated)

if not found_function_end:
yield LLMResponse(content=accumulated)

if chunks:
self._update_usage_stats(stream_chunk_builder(chunks))
final_response = stream_chunk_builder(chunks)
self._update_usage_stats(final_response)

accumulated = fix_incomplete_tool_call(_truncate_to_first_function(accumulated))
yield LLMResponse(
Expand Down Expand Up @@ -200,6 +209,20 @@ def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, An
"stream_options": {"include_usage": True},
}

metadata: dict[str, Any] = {}

try:
from strix.telemetry.tracer import get_global_tracer

tracer = get_global_tracer()
if tracer:
run_id = tracer.run_id
metadata["$ai_trace_id"] = run_id
except Exception as e:
logger.error(f"Could not set trace metadata: {e}")
if metadata:
args["metadata"] = metadata

if api_key := Config.get("llm_api_key"):
args["api_key"] = api_key
if api_base := (
Expand Down
12 changes: 12 additions & 0 deletions strix/llm/memory_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,18 @@ def _summarize_messages(
"timeout": timeout,
}

try:
from strix.telemetry.tracer import get_global_tracer

tracer = get_global_tracer()
if tracer:
run_id = tracer.run_id
completion_args["metadata"] = {
"$ai_trace_id": run_id,
}
except Exception as e:
logger.error(f"Could not set trace metadata: {e}")

response = litellm.completion(**completion_args)
summary = response.choices[0].message.content or ""
if not summary.strip():
Expand Down
52 changes: 47 additions & 5 deletions strix/telemetry/posthog.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,65 @@
from litellm import CALLBACK_TYPES


import json
import logging
import os
import platform
import sys
import urllib.request
from pathlib import Path
from typing import TYPE_CHECKING, Any
from uuid import uuid4

import litellm

from strix.config import Config


logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from strix.telemetry.tracer import Tracer

_POSTHOG_PUBLIC_API_KEY = "phc_7rO3XRuNT5sgSKAl6HDIrWdSGh1COzxw0vxVIAR6vVZ"
_POSTHOG_HOST = "https://us.i.posthog.com"
_POSTHOG_PRIMARY_API_KEY = "phc_7rO3XRuNT5sgSKAl6HDIrWdSGh1COzxw0vxVIAR6vVZ"
_POSTHOG_PRIMARY_HOST = "https://us.i.posthog.com"

_POSTHOG_LLM_API_KEY = os.environ.get("POSTHOG_LLM_API_KEY")
_POSTHOG_LLM_HOST = os.environ.get("POSTHOG_LLM_HOST")

_SESSION_ID = uuid4().hex[:16]


def _is_enabled() -> bool:
    """Return True unless telemetry has been explicitly opted out.

    Telemetry is on by default ("1" when unset); it is disabled only when
    the ``strix_telemetry`` config value is one of "0", "false", "no",
    "off" (case-insensitive).
    """
    # NOTE: the pasted diff carried both the pre- and post-change bodies,
    # leaving an unreachable duplicate return; only the final version is kept.
    telemetry_value = Config.get("strix_telemetry") or "1"
    return telemetry_value.lower() not in ("0", "false", "no", "off")


def configure_litellm_posthog() -> None:
    """Configure LiteLLM to send LLM traces to the env-configured PostHog account.

    No-op when telemetry is opted out or when the ``POSTHOG_LLM_API_KEY`` /
    ``POSTHOG_LLM_HOST`` environment variables are not both set. Each no-op
    path logs a distinct message so the two cases are distinguishable.
    Safe to call more than once: the "posthog" callback is registered
    idempotently on both the success and failure callback lists.
    """
    # Cheapest check first: user opt-out short-circuits everything else.
    if not _is_enabled():
        logger.info("PostHog telemetry (traces) is disabled by user opt-out")
        return

    # Explicit `is None` checks (rather than a precomputed bool) so type
    # checkers narrow the `str | None` module globals before the env writes.
    if _POSTHOG_LLM_API_KEY is None or _POSTHOG_LLM_HOST is None:
        logger.info(
            "PostHog telemetry (traces) is not configured "
            "(POSTHOG_LLM_API_KEY / POSTHOG_LLM_HOST env vars missing)"
        )
        return

    # litellm's posthog callback reads its credentials from the environment.
    os.environ["POSTHOG_API_KEY"] = _POSTHOG_LLM_API_KEY
    os.environ["POSTHOG_API_URL"] = _POSTHOG_LLM_HOST

    if "posthog" not in (litellm.success_callback or []):
        callbacks = list[CALLBACK_TYPES](litellm.success_callback or [])
        callbacks.append("posthog")
        litellm.success_callback = callbacks

    if "posthog" not in (litellm.failure_callback or []):
        callbacks = list[CALLBACK_TYPES](litellm.failure_callback or [])
        callbacks.append("posthog")
        litellm.failure_callback = callbacks


def _is_first_run() -> bool:
Expand All @@ -44,22 +84,24 @@ def _get_version() -> str:


def _send(event: str, properties: dict[str, Any]) -> None:
    """Send a custom event to the primary (hardcoded) PostHog instance.

    Best-effort and fire-and-forget: any serialization or network failure
    is deliberately swallowed so telemetry can never break the caller.
    No-op when telemetry is disabled.
    """
    if not _is_enabled():
        return
    try:
        payload = {
            # The pasted diff left both the old (_POSTHOG_PUBLIC_*) and new
            # lines in place; only the renamed _POSTHOG_PRIMARY_* names are kept.
            "api_key": _POSTHOG_PRIMARY_API_KEY,
            "event": event,
            "distinct_id": _SESSION_ID,
            "properties": properties,
        }
        req = urllib.request.Request(  # noqa: S310
            f"{_POSTHOG_PRIMARY_HOST}/capture/",
            data=json.dumps(payload).encode(),
            headers={"Content-Type": "application/json"},
        )
        with urllib.request.urlopen(req, timeout=10):  # noqa: S310 # nosec B310
            pass
        # Success path: log at DEBUG (was erroneously logger.error), with
        # lazy %-formatting so the string is only built when emitted.
        logger.debug("Sent custom event '%s' to primary PostHog instance", event)
    except Exception:  # noqa: BLE001, S110
        pass  # nosec B110

Expand Down