From 553e074ba67423949a9f5bb0af66db4a3f0c1e6c Mon Sep 17 00:00:00 2001 From: Haris Mahmood Date: Thu, 7 May 2026 16:43:31 +0500 Subject: [PATCH 01/17] fix(prompts): cache assistant timezone lookup Centralize assistant timezone resolution behind the shared prompt helper so repeated now() calls and renderer timezone blocks reuse one TTL-backed lookup while still computing the current timestamp fresh each time. --- tests/common/test_prompt_helpers.py | 101 +++++++++- .../core/test_renderer.py | 44 ++++ unity/common/prompt_helpers.py | 190 ++++++++++++++---- .../conversation_manager/domains/renderer.py | 97 +-------- 4 files changed, 299 insertions(+), 133 deletions(-) diff --git a/tests/common/test_prompt_helpers.py b/tests/common/test_prompt_helpers.py index 2ec791230..1566e1624 100644 --- a/tests/common/test_prompt_helpers.py +++ b/tests/common/test_prompt_helpers.py @@ -1,9 +1,22 @@ import functools -from datetime import datetime +import importlib +from datetime import datetime, timedelta, timezone +from types import SimpleNamespace + +import pytest import unity.common.prompt_helpers as prompt_helpers from unity.common.tool_spec import ToolSpec +pytestmark = pytest.mark.no_unify_context + + +def _real_prompt_helpers(monkeypatch): + module = importlib.reload(prompt_helpers) + monkeypatch.setattr(module, "log_startup_timing", lambda *args, **kwargs: None) + module._assistant_timezone_cache = None + return module + def test_now_full_format(): # Human-readable format with day, month, date, time, and timezone @@ -23,6 +36,92 @@ def test_now_as_datetime(): assert result.day == 13 +def test_assistant_timezone_lookup_caches_within_ttl(monkeypatch): + module = _real_prompt_helpers(monkeypatch) + calls = [] + + def fake_get_logs(**kwargs): + calls.append(kwargs) + return [SimpleNamespace(entries={"timezone": "Asia/Karachi"})] + + monkeypatch.setattr(module, "_contacts_context", lambda: "User/Assistant/Contacts") + monkeypatch.setattr(module.time, "monotonic", lambda: 1000.0) + monkeypatch.setattr("unify.get_logs", fake_get_logs) + + assert module.get_assistant_timezone() == "Asia/Karachi" + assert module.get_assistant_timezone() == "Asia/Karachi" + assert len(calls) == 1 + assert calls[0]["filter"] == "contact_id == 0" + assert calls[0]["from_fields"] == ["timezone"] + + +def test_assistant_timezone_lookup_refreshes_after_ttl(monkeypatch): + module = _real_prompt_helpers(monkeypatch) + calls = [] + monotonic_now = {"value": 1000.0} + + def fake_get_logs(**kwargs): + calls.append(kwargs) + return [SimpleNamespace(entries={"timezone": "Asia/Karachi"})] + + monkeypatch.setattr(module, "_contacts_context", lambda: "User/Assistant/Contacts") + monkeypatch.setattr(module.time, "monotonic", lambda: monotonic_now["value"]) + monkeypatch.setattr("unify.get_logs", fake_get_logs) + + assert module.get_assistant_timezone() == "Asia/Karachi" + monotonic_now["value"] += 299 + assert module.get_assistant_timezone() == "Asia/Karachi" + monotonic_now["value"] += 2 + assert module.get_assistant_timezone() == "Asia/Karachi" + assert len(calls) == 2 + + +def test_now_recomputes_current_time_while_reusing_cached_timezone(monkeypatch): + module = _real_prompt_helpers(monkeypatch) + calls = [] + current_times = iter( + [ + datetime(2026, 5, 7, 8, 0, 0, tzinfo=timezone.utc), + datetime(2026, 5, 7, 8, 0, 1, tzinfo=timezone.utc), + ], + ) + + def fake_get_logs(**kwargs): + calls.append(kwargs) + return [SimpleNamespace(entries={"timezone": "UTC"})] + + monkeypatch.setattr(module, "_contacts_context", lambda: 
"User/Assistant/Contacts") + monkeypatch.setattr(module.time, "monotonic", lambda: 1000.0) + monkeypatch.setattr(module, "_utc_now", lambda: next(current_times)) + monkeypatch.setattr("unify.get_logs", fake_get_logs) + + first = module.now(as_string=False) + second = module.now(as_string=False) + + assert first == datetime(2026, 5, 7, 8, 0, 0, tzinfo=timezone.utc) + assert second == first + timedelta(seconds=1) + assert len(calls) == 1 + + +def test_now_falls_back_to_utc_when_timezone_lookup_fails(monkeypatch): + module = _real_prompt_helpers(monkeypatch) + + def fake_get_logs(**kwargs): + raise RuntimeError("backend unavailable") + + monkeypatch.setattr(module, "_contacts_context", lambda: "User/Assistant/Contacts") + monkeypatch.setattr(module.time, "monotonic", lambda: 1000.0) + monkeypatch.setattr( + module, + "_utc_now", + lambda: datetime(2026, 5, 7, 8, 0, 0, tzinfo=timezone.utc), + ) + monkeypatch.setattr("unify.get_logs", fake_get_logs) + + assert module.get_assistant_timezone() is None + assert module.now() == "Thursday, May 07, 2026 at 08:00 AM UTC" + + async def _sample_execute_code( thought: str, code: str | None = None, diff --git a/tests/conversation_manager/core/test_renderer.py b/tests/conversation_manager/core/test_renderer.py index 5e681480c..6bdb654ba 100644 --- a/tests/conversation_manager/core/test_renderer.py +++ b/tests/conversation_manager/core/test_renderer.py @@ -32,6 +32,8 @@ _get_assistant_email_role, ) +pytestmark = pytest.mark.no_unify_context + # ============================================================================= # Test Fixtures # ============================================================================= @@ -788,6 +790,48 @@ def test_tracks_messages_in_conversation( assert msg.contact_id == 1 assert msg.timestamp == ts1 + def test_render_state_uses_shared_assistant_timezone_helper( + self, + renderer, + contact_index, + notification_bar, + monkeypatch, + ): + """Active conversation rendering gets assistant timezone from common helper.""" + from unity.conversation_manager.cm_types import Medium + + calls = [] + + def fake_get_assistant_timezone(): + calls.append(True) + return "America/New_York" + + monkeypatch.setattr( + "unity.conversation_manager.domains.renderer.get_assistant_timezone", + fake_get_assistant_timezone, + ) + contact_index._fallback_contacts[1]["timezone"] = "America/Los_Angeles" + ts1 = datetime(2025, 6, 13, 12, 0, 0, tzinfo=timezone.utc) + contact_index.push_message( + contact_id=1, + sender_name="Alice", + thread_name=Medium.SMS_MESSAGE, + message_content="Hello there!", + timestamp=ts1, + role="user", + ) + + result = renderer.render_state( + contact_index, + notification_bar, + in_flight_actions={}, + last_snapshot=datetime(2025, 6, 13, 11, 0, 0, tzinfo=timezone.utc), + ) + + assert calls == [True] + assert "America/New_York" in result.full_render + assert "America/Los_Angeles" in result.full_render + def test_tracks_notifications(self, renderer, contact_index, notification_bar): """Notifications are tracked with identity.""" ts1 = datetime(2025, 6, 13, 12, 0, 0, tzinfo=timezone.utc) diff --git a/unity/common/prompt_helpers.py b/unity/common/prompt_helpers.py index 000822cb1..380ccdddf 100644 --- a/unity/common/prompt_helpers.py +++ b/unity/common/prompt_helpers.py @@ -2,6 +2,7 @@ from dataclasses import dataclass, field from datetime import datetime import json +import time from time import perf_counter from pydantic import BaseModel @@ -14,6 +15,7 @@ "sig_dict", "unwrap_tool_callable", "now", + "get_assistant_timezone", 
"tool_name", "require_tools", "parallelism_guidance", @@ -99,6 +101,144 @@ def _stable(sig_str: str) -> str: } +@dataclass(frozen=True) +class _AssistantTimezoneLookup: + timezone: str | None + cache_hit: bool + cache_age_seconds: float + context_ms: float + get_logs_ms: float + extract_ms: float + cache_store_ms: float + rows_count: int + error_type: str + + +_ASSISTANT_TIMEZONE_TTL = 300 # 5 minutes — timezone changes are very rare +_assistant_timezone_cache: tuple[float, str, str | None] | None = None + + +def _contacts_context() -> str: + from unity.session_details import SESSION_DETAILS + + return ( + f"{SESSION_DETAILS.user_context}/{SESSION_DETAILS.assistant_context}/Contacts" + ) + + +def _lookup_assistant_timezone() -> _AssistantTimezoneLookup: + """Return assistant timezone lookup details with a process-local TTL cache.""" + global _assistant_timezone_cache + + _timing_t0 = perf_counter() + _monotonic_t0 = _timing_t0 + monotonic_now = time.monotonic() + _monotonic_ms = (perf_counter() - _monotonic_t0) * 1000 + + _context_t0 = perf_counter() + contacts_ctx = _contacts_context() + _context_ms = (perf_counter() - _context_t0) * 1000 + + if _assistant_timezone_cache is not None: + cached_at, cached_context, cached_val = _assistant_timezone_cache + cache_age = monotonic_now - cached_at + if cached_context == contacts_ctx and cache_age < _ASSISTANT_TIMEZONE_TTL: + log_startup_timing( + LOGGER, + ( + "⏱️ [StartupTiming] timezone.assistant_lookup.detail " + "total=%.0fms monotonic=%.0fms cache_hit=True cache_age=%.0fs " + "context=%.0fms get_logs=0ms extract=0ms rows=0 tz=%s error=" + ), + (perf_counter() - _timing_t0) * 1000, + _monotonic_ms, + cache_age, + _context_ms, + cached_val, + ) + return _AssistantTimezoneLookup( + timezone=cached_val, + cache_hit=True, + cache_age_seconds=cache_age, + context_ms=_context_ms, + get_logs_ms=0.0, + extract_ms=0.0, + cache_store_ms=0.0, + rows_count=0, + error_type="", + ) + + import unify as _unify + + result: str | None = None + rows_count = 0 + error_type = "" + _get_logs_t0 = perf_counter() + try: + rows = _unify.get_logs( + context=contacts_ctx, + filter="contact_id == 0", + limit=1, + from_fields=["timezone"], + ) + _get_logs_ms = (perf_counter() - _get_logs_t0) * 1000 + rows_count = len(rows or []) + _extract_t0 = perf_counter() + if rows: + val = rows[0].entries.get("timezone") + if isinstance(val, str) and val.strip(): + result = val.strip() + except Exception: + _get_logs_ms = (perf_counter() - _get_logs_t0) * 1000 + _extract_t0 = perf_counter() + error_type = "get_logs" + _extract_ms = (perf_counter() - _extract_t0) * 1000 + + _cache_store_t0 = perf_counter() + _assistant_timezone_cache = (monotonic_now, contacts_ctx, result) + _cache_store_ms = (perf_counter() - _cache_store_t0) * 1000 + log_startup_timing( + LOGGER, + ( + "⏱️ [StartupTiming] timezone.assistant_lookup.detail " + "total=%.0fms monotonic=%.0fms cache_hit=False cache_age=0s " + "context=%.0fms get_logs=%.0fms extract=%.0fms cache_store=%.0fms " + "rows=%d tz=%s error=%s" + ), + (perf_counter() - _timing_t0) * 1000, + _monotonic_ms, + _context_ms, + _get_logs_ms, + _extract_ms, + _cache_store_ms, + rows_count, + result, + error_type, + ) + return _AssistantTimezoneLookup( + timezone=result, + cache_hit=False, + cache_age_seconds=0.0, + context_ms=_context_ms, + get_logs_ms=_get_logs_ms, + extract_ms=_extract_ms, + cache_store_ms=_cache_store_ms, + rows_count=rows_count, + error_type=error_type, + ) + + +def get_assistant_timezone() -> str | None: + """Return the 
assistant's configured IANA timezone, cached by assistant context.""" + return _lookup_assistant_timezone().timezone + + +def _utc_now() -> datetime: + from datetime import timezone as dt_timezone + + return datetime.now(dt_timezone.utc) + + def now(time_only: bool = False, as_string: bool = True) -> "str | datetime": """Return the current timestamp in the assistant's timezone. @@ -117,10 +257,7 @@ def now(time_only: bool = False, as_string: bool = True) -> "str | datetime": In tests, this function is monkeypatched by tests/conftest.py to return fixed or incrementing datetimes for cache consistency. """ - from datetime import datetime, timezone as dt_timezone from zoneinfo import ZoneInfo - import unify as _unify - from unity.session_details import SESSION_DETAILS _timing_t0 = perf_counter() _step_t0 = _timing_t0 @@ -132,37 +269,13 @@ def _mark_step() -> float: _step_t0 = step_now return elapsed_ms - _contacts_ctx = ( - f"{SESSION_DETAILS.user_context}/{SESSION_DETAILS.assistant_context}/Contacts" - ) - _context_ms = _mark_step() - # Default to UTC if assistant row/field is unavailable - tz_name: str = "UTC" - rows_count = 0 - error_type = "" - try: - rows = _unify.get_logs( - context=_contacts_ctx, - filter="contact_id == 0", - limit=1, - from_fields=["timezone"], - ) - _get_logs_ms = _mark_step() - rows_count = len(rows or []) - if rows: - val = rows[0].entries.get("timezone") - if isinstance(val, str) and val.strip(): - tz_name = val.strip() - except Exception: - _get_logs_ms = _mark_step() - error_type = "get_logs" - # Best-effort only; fall back to UTC - tz_name = "UTC" - _extract_ms = _mark_step() + lookup = _lookup_assistant_timezone() + _mark_step() + tz_name = lookup.timezone or "UTC" # Convert UTC now to the target timezone - utc_now = datetime.now(dt_timezone.utc) + utc_now = _utc_now() _utc_now_ms = _mark_step() zone_error = "" try: @@ -190,21 +303,24 @@ def _mark_step() -> float: "⏱️ [StartupTiming] timezone.prompt_now.detail " "total=%.0fms context=%.0fms get_logs=%.0fms extract=%.0fms " "utc_now=%.0fms zone_convert=%.0fms format=%.0fms " - "rows=%d tz=%s label=%s as_string=%s time_only=%s error=%s" + "rows=%d tz=%s label=%s as_string=%s time_only=%s " + "cache_hit=%s cache_age=%.0fs error=%s" ), (perf_counter() - _timing_t0) * 1000, - _context_ms, - _get_logs_ms, - _extract_ms, + lookup.context_ms, + lookup.get_logs_ms, + lookup.extract_ms, _utc_now_ms, _zone_convert_ms, _format_ms, - rows_count, + lookup.rows_count, tz_name, label, as_string, time_only, - error_type or zone_error or "", + lookup.cache_hit, + lookup.cache_age_seconds, + lookup.error_type or zone_error or "", ) return result diff --git a/unity/conversation_manager/domains/renderer.py b/unity/conversation_manager/domains/renderer.py index e9346ce63..5608294be 100644 --- a/unity/conversation_manager/domains/renderer.py +++ b/unity/conversation_manager/domains/renderer.py @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING from unity.common._async_tool.utils import get_handle_paused_state +from unity.common.prompt_helpers import get_assistant_timezone from unity.common.startup_timing import log_startup_timing from unity.conversation_manager.domains.contact_index import ( ApiMessage, @@ -100,100 +101,6 @@ def _get_current_time_in_timezone(tz_name: str) -> str: return result -_assistant_tz_cache: tuple[float, str | None] | None = None -_ASSISTANT_TZ_TTL = 300 # 5 minutes — timezone changes are very rare - - -def _get_assistant_timezone() -> str | None: - """Get the assistant's timezone from contact_id=0. 
- - Uses a module-level TTL cache to avoid synchronous HTTP round-trips to - Orchestra on every render_state() call (which runs in the hot path of the - event loop). - - Returns: - IANA timezone identifier or None if not available. - """ - global _assistant_tz_cache - import time - - _timing_t0 = perf_counter() - now = time.monotonic() - _monotonic_ms = (perf_counter() - _timing_t0) * 1000 - if _assistant_tz_cache is not None: - cached_at, cached_val = _assistant_tz_cache - if now - cached_at < _ASSISTANT_TZ_TTL: - log_startup_timing( - LOGGER, - ( - "⏱️ [StartupTiming] timezone.assistant_lookup.detail " - "total=%.0fms monotonic=%.0fms cache_hit=True cache_age=%.0fs " - "context=0ms get_logs=0ms extract=0ms rows=0 tz=%s error=" - ), - (perf_counter() - _timing_t0) * 1000, - _monotonic_ms, - now - cached_at, - cached_val, - ) - return cached_val - - import unify as _unify - from unity.session_details import SESSION_DETAILS - - result: str | None = None - _context_t0 = perf_counter() - _contacts_ctx = ( - f"{SESSION_DETAILS.user_context}/{SESSION_DETAILS.assistant_context}/Contacts" - ) - _context_ms = (perf_counter() - _context_t0) * 1000 - - rows_count = 0 - error_type = "" - try: - _get_logs_t0 = perf_counter() - rows = _unify.get_logs( - context=_contacts_ctx, - filter="contact_id == 0", - limit=1, - from_fields=["timezone"], - ) - _get_logs_ms = (perf_counter() - _get_logs_t0) * 1000 - rows_count = len(rows or []) - _extract_t0 = perf_counter() - if rows: - val = rows[0].entries.get("timezone") - if isinstance(val, str) and val.strip(): - result = val.strip() - except Exception: - _get_logs_ms = (perf_counter() - _get_logs_t0) * 1000 - _extract_t0 = perf_counter() - error_type = "get_logs" - _extract_ms = (perf_counter() - _extract_t0) * 1000 - - _cache_store_t0 = perf_counter() - _assistant_tz_cache = (now, result) - _cache_store_ms = (perf_counter() - _cache_store_t0) * 1000 - log_startup_timing( - LOGGER, - ( - "⏱️ [StartupTiming] timezone.assistant_lookup.detail " - "total=%.0fms monotonic=%.0fms cache_hit=False cache_age=0s " - "context=%.0fms get_logs=%.0fms extract=%.0fms cache_store=%.0fms " - "rows=%d tz=%s error=%s" - ), - (perf_counter() - _timing_t0) * 1000, - _monotonic_ms, - _context_ms, - _get_logs_ms, - _extract_ms, - _cache_store_ms, - rows_count, - result, - error_type, - ) - return result - - def _format_timezone_block( assistant_tz: str | None, participants: list[tuple[str, str | None]], @@ -1173,7 +1080,7 @@ def render_active_conversations( """ _render_t0 = perf_counter() # Fetch assistant's timezone once for all contacts - assistant_timezone = _get_assistant_timezone() + assistant_timezone = get_assistant_timezone() _timezone_ms = (perf_counter() - _render_t0) * 1000 # Group global thread entries by contact_id From 510d6f400e9fda7ffcc0269102d6f9d14610b19c Mon Sep 17 00:00:00 2001 From: Haris Mahmood Date: Thu, 7 May 2026 17:17:41 +0500 Subject: [PATCH 02/17] fix(prompt): avoid caching failed timezone lookups Only cache successful assistant timezone values so transient get_logs errors or missing Contacts rows do not pin assistants to UTC for the full TTL. Add regression coverage for failed and empty timezone lookups. 
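
A condensed sketch of the guarded cache write (names mirror the helper in
unity/common/prompt_helpers.py; the ``unify`` import and the timing
instrumentation around it are omitted here):

    result: str | None = None
    try:
        rows = unify.get_logs(context=contacts_ctx, filter="contact_id == 0",
                              limit=1, from_fields=["timezone"])
        if rows:
            val = rows[0].entries.get("timezone")
            if isinstance(val, str) and val.strip():
                result = val.strip()
    except Exception:
        pass  # lookup failure leaves result as None
    if result is not None:
        # Only successful lookups are cached; a miss is retried on the next
        # call instead of pinning the assistant to UTC for the whole TTL.
        _assistant_timezone_cache = (time.monotonic(), contacts_ctx, result)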
--- tests/common/test_prompt_helpers.py | 41 +++++++++++++++++++++++++++++ unity/common/prompt_helpers.py | 5 +++- 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/tests/common/test_prompt_helpers.py b/tests/common/test_prompt_helpers.py index 1566e1624..a289b3077 100644 --- a/tests/common/test_prompt_helpers.py +++ b/tests/common/test_prompt_helpers.py @@ -105,8 +105,10 @@ def fake_get_logs(**kwargs): def test_now_falls_back_to_utc_when_timezone_lookup_fails(monkeypatch): module = _real_prompt_helpers(monkeypatch) + calls = [] def fake_get_logs(**kwargs): + calls.append(kwargs) raise RuntimeError("backend unavailable") monkeypatch.setattr(module, "_contacts_context", lambda: "User/Assistant/Contacts") @@ -120,6 +122,45 @@ def fake_get_logs(**kwargs): assert module.get_assistant_timezone() is None assert module.now() == "Thursday, May 07, 2026 at 08:00 AM UTC" + assert len(calls) == 2 + + +def test_failed_assistant_timezone_lookup_does_not_poison_cache(monkeypatch): + module = _real_prompt_helpers(monkeypatch) + calls = [] + + def fake_get_logs(**kwargs): + calls.append(kwargs) + if len(calls) == 1: + raise RuntimeError("backend unavailable") + return [SimpleNamespace(entries={"timezone": "Asia/Karachi"})] + + monkeypatch.setattr(module, "_contacts_context", lambda: "User/Assistant/Contacts") + monkeypatch.setattr(module.time, "monotonic", lambda: 1000.0) + monkeypatch.setattr("unify.get_logs", fake_get_logs) + + assert module.get_assistant_timezone() is None + assert module.get_assistant_timezone() == "Asia/Karachi" + assert len(calls) == 2 + + +def test_missing_assistant_timezone_row_does_not_poison_cache(monkeypatch): + module = _real_prompt_helpers(monkeypatch) + calls = [] + + def fake_get_logs(**kwargs): + calls.append(kwargs) + if len(calls) == 1: + return [] + return [SimpleNamespace(entries={"timezone": "Asia/Karachi"})] + + monkeypatch.setattr(module, "_contacts_context", lambda: "User/Assistant/Contacts") + monkeypatch.setattr(module.time, "monotonic", lambda: 1000.0) + monkeypatch.setattr("unify.get_logs", fake_get_logs) + + assert module.get_assistant_timezone() is None + assert module.get_assistant_timezone() == "Asia/Karachi" + assert len(calls) == 2 async def _sample_execute_code( diff --git a/unity/common/prompt_helpers.py b/unity/common/prompt_helpers.py index 380ccdddf..67c824b80 100644 --- a/unity/common/prompt_helpers.py +++ b/unity/common/prompt_helpers.py @@ -195,7 +195,10 @@ def _lookup_assistant_timezone() -> _AssistantTimezoneLookup: _extract_ms = (perf_counter() - _extract_t0) * 1000 _cache_store_t0 = perf_counter() - _assistant_timezone_cache = (monotonic_now, contacts_ctx, result) + # Only cache a real timezone. Startup can briefly query before Contacts are + # readable; caching that miss would pin the assistant to UTC for the full TTL. 
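+    # Failed get_logs calls and empty Contacts rows both surface here as
+    # result=None, so neither outcome is written back to the cache.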
+ if result is not None: + _assistant_timezone_cache = (monotonic_now, contacts_ctx, result) _cache_store_ms = (perf_counter() - _cache_store_t0) * 1000 log_startup_timing( LOGGER, From c08be252dd368e12ecba4fca0a2da3d472f43186 Mon Sep 17 00:00:00 2001 From: juliagsy <67888047+juliagsy@users.noreply.github.com> Date: Thu, 7 May 2026 21:06:44 +0800 Subject: [PATCH 03/17] updated integration detection to also pick up explicitly registered assistants during config --- unity/integration_status/__init__.py | 51 +++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/unity/integration_status/__init__.py b/unity/integration_status/__init__.py index ed9371019..7f995c60a 100644 --- a/unity/integration_status/__init__.py +++ b/unity/integration_status/__init__.py @@ -114,22 +114,55 @@ def reset_session_cache() -> None: def _load_registry() -> list[dict[str, Any]]: """Pull rows from ``Integrations/Manifests`` once per session, cache. - The persisted registry only contains rows for integrations that were - declared in the deployment spec (``integrations=[...]``). An assistant - may have credentials for *available* packages that weren't declared — - the typical case is a manual token paste in Console. To bridge that - gap we synthesize in-memory rows from disk discovery when the persisted - registry is empty. See :mod:`unity.integration_status.discovery`. + Always returns the **union** of two sources, deduped by slug: + + * ``_read_persisted_registry()`` — rows the deployment seed wrote into + ``Integrations/Manifests``. Authoritative for slugs they cover + because their ``function_names_json`` was the basis of the + FunctionManager rows the deploy pass created; ``enabled_function_ids`` + relies on that mapping holding. + * ``_synthesize_rows_from_discovery()`` — rows projected from every + manifest currently on disk under unity-deploy's package roots. + Covers packages installed but not declared in any deployment, so + auto-detection works for both registered and non-registered + assistants — matching the design intent that the integrations sync + exists only to auto-detect unity-deploy packages. + + Persisted rows win on slug conflict; disk-only slugs are added on top. + A package added to disk after the last deploy seed is detected via the + discovery half of the union; once a token paste enables it, hot-load + writes the real persisted row so subsequent reads come from the + persisted source. 
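+
+    Example: with a persisted registry of ``{hubspot}`` and disk discovery
+    finding ``{hubspot, webex}``, the merge keeps the persisted ``hubspot``
+    row and adds a synthesized ``webex`` row on top.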
""" cache = _session_cache() if cache["registry_loaded"]: return cache["registry"] - rows = _read_persisted_registry() - if not rows: - rows = _synthesize_rows_from_discovery() + persisted = _read_persisted_registry() + discovered = _synthesize_rows_from_discovery() + + by_slug: dict[str, dict[str, Any]] = {} + for row in persisted: + slug = row.get("slug") + if slug: + by_slug[slug] = row + for row in discovered: + slug = row.get("slug") + if slug and slug not in by_slug: + by_slug[slug] = row + + rows = list(by_slug.values()) cache["registry_loaded"] = True cache["registry"] = rows + + logger.info( + "[integrations] _load_registry: persisted=%d discovered=%d " + "merged=%d slugs=%s", + len(persisted), + len(discovered), + len(rows), + sorted(by_slug.keys()), + ) return rows From c601be739ad81ffbced1e72a523b2eeada1dabce Mon Sep 17 00:00:00 2001 From: juliagsy <67888047+juliagsy@users.noreply.github.com> Date: Fri, 8 May 2026 11:09:26 +0800 Subject: [PATCH 04/17] fixed company-registered assistant integrations discovery bug --- .../test_enablement.py | 581 +++++------ .../test_integration_status/test_hot_load.py | 427 -------- unity/__init__.py | 24 + unity/integration_status/__init__.py | 970 +++++++----------- unity/integration_status/discovery.py | 32 +- unity/secret_manager/secret_manager.py | 109 +- 6 files changed, 707 insertions(+), 1436 deletions(-) delete mode 100644 tests/test_integration_status/test_hot_load.py diff --git a/tests/test_integration_status/test_enablement.py b/tests/test_integration_status/test_enablement.py index ef656550b..4b14301eb 100644 --- a/tests/test_integration_status/test_enablement.py +++ b/tests/test_integration_status/test_enablement.py @@ -1,21 +1,39 @@ -"""Unit tests for ``unity.integration_status``. - -These tests exercise the enablement-detection logic with synthetic registry -rows so we don't need a real DataManager context provisioned. Live-context -tests live separately under ``tests/secret_manager/`` once we wire the -hook end-to-end against Orchestra. +"""Unit tests for ``unity.integration_status`` enablement read. + +Exercises the pure read API (``get_enabled_integrations``, +``get_setup_completeness``, ``enabled_summary_for_prompt``, +``build_guidance_filter_scope``) against synthetic disk-package metadata +and synthetic secret keysets. + +Synthetic state is injected by stubbing two helpers: + +* ``discover_available_packages`` from + ``unity.integration_status.discovery`` — returns the list of + package-metadata dicts the assistant has on disk. +* ``_read_local_secret_keyset`` on the ``unity.integration_status`` + module — returns the set of currently-present secret names in the + assistant's local Secrets context. + +Manager-coupled helpers (``enabled_function_ids``, +``enabled_guidance_ids``) are tested separately under +``tests/test_integration_status/test_manager_resolution.py`` (TODO) once +we have lightweight FunctionManager / GuidanceManager fakes. 
""" from __future__ import annotations -import json +from pathlib import Path import pytest from unity import integration_status as IS +# --------------------------------------------------------------------------- +# Test fixtures +# --------------------------------------------------------------------------- + -def _row( +def _pkg( *, slug: str, label: str, @@ -24,403 +42,314 @@ def _row( function_names: list[str] | None = None, guidance_titles: list[str] | None = None, ) -> dict: + """Build a synthetic package-metadata dict matching the shape returned + by ``discover_available_packages``.""" return { "slug": slug, "label": label, "category": "test", "version": "0.1.0", "tier": "api", - "quality": "bronze", - "required_secrets_json": json.dumps(required), - "optional_secrets_json": json.dumps(optional or []), - "function_names_json": json.dumps(function_names or []), - "guidance_titles_json": json.dumps(guidance_titles or []), - "capability_ids_json": "[]", - "tags_json": "[]", - "homepage": "", - "description": f"{label} description", + "root_dir": Path("/nonexistent"), + "required_secrets": required, + "optional_secrets": optional or [], + "function_names": function_names or [], + "guidance_titles": guidance_titles or [], + "function_dir": None, + "guidance_dir": None, } -@pytest.fixture(autouse=True) -def _reset_cache(monkeypatch): - """Reset the per-session cache and reload the registry from the supplied - rows on each test. Bypasses DataManager entirely. +def _stub_packages_and_keyset( + monkeypatch: pytest.MonkeyPatch, + *, + packages: list[dict], + keyset: set[str] | None = None, +) -> None: + """Stub disk discovery + local keyset reads for an enablement test.""" + from unity.integration_status import discovery as D - Stubs ``schedule_hot_load`` so ``recompute_enablement`` doesn't spawn - daemon threads that would race with the test's cache-reset teardown. + monkeypatch.setattr(D, "discover_available_packages", lambda: packages) + monkeypatch.setattr(IS, "_read_local_secret_keyset", lambda: set(keyset or set())) - Stubs ``_read_local_secret_keyset`` to return an empty set by default - so tests that pass ``secrets={...}`` to ``recompute_enablement`` see - that argument as the sole keyset source (production reads from the - local Secrets context too, but in unit tests that context isn't - provisioned). Individual tests can override via monkeypatch. 
- Tests that specifically care about hot-load scheduling or local-context - keys can re-stub via monkeypatch + assert on the calls.""" +@pytest.fixture(autouse=True) +def _reset_cache(): IS.reset_session_cache() - monkeypatch.setattr(IS, "schedule_hot_load", lambda slug: None) - monkeypatch.setattr(IS, "_read_local_secret_keyset", lambda: set()) yield IS.reset_session_cache() -def _seed_registry(monkeypatch, rows: list[dict]) -> None: - """Patch ``_load_registry`` to return *rows* and prime the cache.""" - - def _fake_load() -> list[dict]: - cache = IS._session_cache() - cache["registry_loaded"] = True - cache["registry"] = rows - return rows - - monkeypatch.setattr(IS, "_load_registry", _fake_load) - - -# --------------------------------------------------------------------------- -# Empty registry → no detection -# --------------------------------------------------------------------------- - - -def test_recompute_with_no_registry_returns_empty(monkeypatch): - _seed_registry(monkeypatch, []) - - IS.recompute_enablement( - assistant_id=1, - secrets={"HUBSPOT_PRIVATE_APP_TOKEN": "tok"}, - ) - - assert IS.get_enabled_integrations() == [] - assert IS.get_setup_completeness() == {} - - -def test_summary_empty_when_no_registry(monkeypatch): - _seed_registry(monkeypatch, []) - IS.recompute_enablement(assistant_id=1, secrets={}) - assert IS.enabled_summary_for_prompt() == "" - - -def test_build_guidance_filter_scope_returns_none_when_registry_empty(monkeypatch): - _seed_registry(monkeypatch, []) - IS.recompute_enablement(assistant_id=1, secrets={}) - assert IS.build_guidance_filter_scope() is None - - # --------------------------------------------------------------------------- -# Single-integration enablement: HubSpot +# get_enabled_integrations — basic enablement logic # --------------------------------------------------------------------------- -def _hubspot_row() -> dict: - return _row( - slug="hubspot", - label="HubSpot", - required=["HUBSPOT_PRIVATE_APP_TOKEN"], - optional=["HUBSPOT_PORTAL_ID"], - function_names=["get_hubspot_contact", "search_hubspot_contacts"], - guidance_titles=["Hubspot Overview", "Hubspot Crm Contacts"], +def test_no_packages_returns_empty(monkeypatch): + _stub_packages_and_keyset( + monkeypatch, + packages=[], + keyset={"HUBSPOT_PRIVATE_APP_TOKEN"}, ) - - -def test_hubspot_enabled_when_required_token_present(monkeypatch): - _seed_registry(monkeypatch, [_hubspot_row()]) - - IS.recompute_enablement( - assistant_id=1, - secrets={"HUBSPOT_PRIVATE_APP_TOKEN": "tok"}, + assert IS.get_enabled_integrations() == {} + + +def test_required_secret_present_enables_package(monkeypatch): + _stub_packages_and_keyset( + monkeypatch, + packages=[ + _pkg( + slug="hubspot", + label="HubSpot", + required=["HUBSPOT_PRIVATE_APP_TOKEN"], + ), + ], + keyset={"HUBSPOT_PRIVATE_APP_TOKEN"}, ) - - assert IS.get_enabled_integrations() == ["hubspot"] - completeness = IS.get_setup_completeness() - assert completeness["hubspot"]["status"] == "configured" - assert completeness["hubspot"]["missing_optional_secrets"] == ["HUBSPOT_PORTAL_ID"] - - -def test_hubspot_disabled_without_required_token(monkeypatch): - _seed_registry(monkeypatch, [_hubspot_row()]) - - IS.recompute_enablement(assistant_id=1, secrets={"HUBSPOT_PORTAL_ID": "12345"}) - - assert IS.get_enabled_integrations() == [] - - -def test_hubspot_fully_connected_when_optional_present(monkeypatch): - _seed_registry(monkeypatch, [_hubspot_row()]) - - IS.recompute_enablement( - assistant_id=1, - secrets={ - "HUBSPOT_PRIVATE_APP_TOKEN": "tok", - 
"HUBSPOT_PORTAL_ID": "12345", - }, + enabled = IS.get_enabled_integrations() + assert set(enabled.keys()) == {"hubspot"} + assert enabled["hubspot"]["label"] == "HubSpot" + + +def test_required_secret_missing_disables_package(monkeypatch): + _stub_packages_and_keyset( + monkeypatch, + packages=[ + _pkg( + slug="hubspot", + label="HubSpot", + required=["HUBSPOT_PRIVATE_APP_TOKEN"], + ), + ], + keyset=set(), ) + assert IS.get_enabled_integrations() == {} - completeness = IS.get_setup_completeness() - assert completeness["hubspot"]["status"] == "fully_connected" - assert completeness["hubspot"]["missing_optional_secrets"] == [] - - -def test_empty_string_secret_treated_as_missing(monkeypatch): - """Orchestra returns ``""`` for unset secrets in some paths. Treat - empty/whitespace as not-present so we don't false-positive enablement.""" - _seed_registry(monkeypatch, [_hubspot_row()]) - IS.recompute_enablement( - assistant_id=1, - secrets={"HUBSPOT_PRIVATE_APP_TOKEN": ""}, +def test_no_required_secrets_means_always_enabled(monkeypatch): + """Packages with no declared required secrets are read-only / always-on + and should be enabled unconditionally.""" + _stub_packages_and_keyset( + monkeypatch, + packages=[_pkg(slug="public_data", label="Public Data", required=[])], + keyset=set(), ) - - assert IS.get_enabled_integrations() == [] - - -# --------------------------------------------------------------------------- -# Multi-secret AND: Employment Hero (CLIENT_ID + CLIENT_SECRET both required) -# --------------------------------------------------------------------------- + assert "public_data" in IS.get_enabled_integrations() -def _eh_row() -> dict: - return _row( +def test_multi_secret_AND_requires_all_present(monkeypatch): + eh = _pkg( slug="employment_hero", label="Employment Hero", - required=[ - "EMPLOYMENTHERO_OAUTH_CLIENT_ID", - "EMPLOYMENTHERO_OAUTH_CLIENT_SECRET", - ], - optional=[ - "EMPLOYMENTHERO_REFRESH_TOKEN", - "EMPLOYMENTHERO_ORGANISATION_ID", - ], - function_names=["get_employmenthero_employee"], - guidance_titles=["Employmenthero Overview"], + required=["EH_CLIENT_ID", "EH_CLIENT_SECRET"], ) + # Only one of two required secrets present. + _stub_packages_and_keyset(monkeypatch, packages=[eh], keyset={"EH_CLIENT_ID"}) + assert IS.get_enabled_integrations() == {} + + # Both present → enabled. 
+ _stub_packages_and_keyset( + monkeypatch, + packages=[eh], + keyset={"EH_CLIENT_ID", "EH_CLIENT_SECRET"}, + ) + assert "employment_hero" in IS.get_enabled_integrations() -def test_eh_disabled_when_only_client_id_present(monkeypatch): - _seed_registry(monkeypatch, [_eh_row()]) - - IS.recompute_enablement( - assistant_id=1, - secrets={"EMPLOYMENTHERO_OAUTH_CLIENT_ID": "id"}, +def test_multiple_packages_independently_enabled(monkeypatch): + _stub_packages_and_keyset( + monkeypatch, + packages=[ + _pkg(slug="hubspot", label="HubSpot", required=["HUBSPOT_TOKEN"]), + _pkg(slug="webex", label="Webex", required=["WEBEX_TOKEN"]), + _pkg(slug="salesforce", label="Salesforce", required=["SF_TOKEN"]), + ], + keyset={"HUBSPOT_TOKEN", "WEBEX_TOKEN"}, # SF missing ) + enabled = IS.get_enabled_integrations() + assert set(enabled.keys()) == {"hubspot", "webex"} - assert IS.get_enabled_integrations() == [] +# --------------------------------------------------------------------------- +# get_setup_completeness — fully_connected vs configured +# --------------------------------------------------------------------------- -def test_eh_configured_with_both_required_but_oauth_not_complete(monkeypatch): - """Once CLIENT_ID + CLIENT_SECRET are pasted, EH is *configured* but the - OAuth Connect flow that populates REFRESH_TOKEN hasn't run. We surface - that gap via ``missing_optional_secrets`` so the prompt can guide the - user to complete setup.""" - _seed_registry(monkeypatch, [_eh_row()]) - IS.recompute_enablement( - assistant_id=1, - secrets={ - "EMPLOYMENTHERO_OAUTH_CLIENT_ID": "id", - "EMPLOYMENTHERO_OAUTH_CLIENT_SECRET": "secret", - }, +def test_completeness_configured_when_optional_missing(monkeypatch): + _stub_packages_and_keyset( + monkeypatch, + packages=[ + _pkg( + slug="hubspot", + label="HubSpot", + required=["HUBSPOT_TOKEN"], + optional=["HUBSPOT_PORTAL_ID"], + ), + ], + keyset={"HUBSPOT_TOKEN"}, ) - - enabled = IS.get_enabled_integrations() - assert enabled == ["employment_hero"] - completeness = IS.get_setup_completeness() - assert completeness["employment_hero"]["status"] == "configured" - assert ( - "EMPLOYMENTHERO_REFRESH_TOKEN" - in completeness["employment_hero"]["missing_optional_secrets"] + comp = IS.get_setup_completeness() + assert comp["hubspot"]["status"] == "configured" + assert comp["hubspot"]["missing_optional_secrets"] == ["HUBSPOT_PORTAL_ID"] + assert comp["hubspot"]["missing_required_secrets"] == [] + + +def test_completeness_fully_connected_when_optional_present(monkeypatch): + _stub_packages_and_keyset( + monkeypatch, + packages=[ + _pkg( + slug="hubspot", + label="HubSpot", + required=["HUBSPOT_TOKEN"], + optional=["HUBSPOT_PORTAL_ID"], + ), + ], + keyset={"HUBSPOT_TOKEN", "HUBSPOT_PORTAL_ID"}, ) + comp = IS.get_setup_completeness() + assert comp["hubspot"]["status"] == "fully_connected" + assert comp["hubspot"]["missing_optional_secrets"] == [] -def test_eh_fully_connected_with_full_oauth_set(monkeypatch): - _seed_registry(monkeypatch, [_eh_row()]) - - IS.recompute_enablement( - assistant_id=1, - secrets={ - "EMPLOYMENTHERO_OAUTH_CLIENT_ID": "id", - "EMPLOYMENTHERO_OAUTH_CLIENT_SECRET": "secret", - "EMPLOYMENTHERO_REFRESH_TOKEN": "rt", - "EMPLOYMENTHERO_ORGANISATION_ID": "org", - }, +def test_completeness_only_includes_enabled(monkeypatch): + _stub_packages_and_keyset( + monkeypatch, + packages=[ + _pkg(slug="hubspot", label="HubSpot", required=["HUBSPOT_TOKEN"]), + _pkg(slug="webex", label="Webex", required=["WEBEX_TOKEN"]), + ], + keyset={"HUBSPOT_TOKEN"}, ) - - completeness 
= IS.get_setup_completeness() - assert completeness["employment_hero"]["status"] == "fully_connected" + comp = IS.get_setup_completeness() + assert set(comp.keys()) == {"hubspot"} # --------------------------------------------------------------------------- -# Mid-session recompute (the core promise of the _sync_assistant_secrets hook) +# enabled_summary_for_prompt — prompt block rendering # --------------------------------------------------------------------------- -def test_mid_session_recompute_picks_up_newly_added_secret(monkeypatch): - """Simulates the user pasting a HubSpot token mid-session. The first - sync call runs with no token; the second runs after the user adds it. - The enabled set must update without any session restart.""" - _seed_registry(monkeypatch, [_hubspot_row()]) +def test_summary_empty_when_no_packages(monkeypatch): + _stub_packages_and_keyset(monkeypatch, packages=[], keyset=set()) + assert IS.enabled_summary_for_prompt() == "" - IS.recompute_enablement(assistant_id=1, secrets={}) - assert IS.get_enabled_integrations() == [] - # User pastes the token. Next _sync_assistant_secrets fires. - IS.recompute_enablement( - assistant_id=1, - secrets={"HUBSPOT_PRIVATE_APP_TOKEN": "tok"}, +def test_summary_renders_active_and_inactive(monkeypatch): + _stub_packages_and_keyset( + monkeypatch, + packages=[ + _pkg(slug="hubspot", label="HubSpot", required=["HUBSPOT_TOKEN"]), + _pkg(slug="webex", label="Webex", required=["WEBEX_TOKEN"]), + ], + keyset={"HUBSPOT_TOKEN"}, ) - assert IS.get_enabled_integrations() == ["hubspot"] - - -def test_mid_session_recompute_drops_disabled_integration(monkeypatch): - """Symmetric: user revokes a token mid-session → integration falls out - of the enabled set on next sync.""" - _seed_registry(monkeypatch, [_hubspot_row()]) - - IS.recompute_enablement( - assistant_id=1, - secrets={"HUBSPOT_PRIVATE_APP_TOKEN": "tok"}, + summary = IS.enabled_summary_for_prompt() + assert "### Integrations" in summary + assert "HubSpot" in summary + assert "Webex" in summary + assert "WEBEX_TOKEN" in summary + # HubSpot should be in the "Active" section, Webex in "Inactive". 
+ active_section, inactive_section = summary.split( + "Inactive (credentials not configured):", ) - assert IS.get_enabled_integrations() == ["hubspot"] - - IS.recompute_enablement(assistant_id=1, secrets={}) - assert IS.get_enabled_integrations() == [] - - -# --------------------------------------------------------------------------- -# Multi-integration: HubSpot + Employment Hero -# --------------------------------------------------------------------------- + assert "HubSpot" in active_section + assert "Webex" in inactive_section -def test_multi_integration_independent_enablement(monkeypatch): - """HubSpot and EH should enable independently of each other.""" - _seed_registry(monkeypatch, [_hubspot_row(), _eh_row()]) - - IS.recompute_enablement( - assistant_id=1, - secrets={"HUBSPOT_PRIVATE_APP_TOKEN": "tok"}, +def test_summary_no_inactive_section_when_all_active(monkeypatch): + _stub_packages_and_keyset( + monkeypatch, + packages=[_pkg(slug="hubspot", label="HubSpot", required=["HUBSPOT_TOKEN"])], + keyset={"HUBSPOT_TOKEN"}, ) - - enabled = sorted(IS.get_enabled_integrations()) - assert enabled == ["hubspot"] + summary = IS.enabled_summary_for_prompt() + assert "Active integrations:" in summary + assert "Inactive" not in summary # --------------------------------------------------------------------------- -# Allowlist union (drives SecretManager._resolve_secret_allowlist) +# build_guidance_filter_scope — guidance gating # --------------------------------------------------------------------------- -def test_all_known_secret_names_unions_all_required_and_optional(monkeypatch): - _seed_registry(monkeypatch, [_hubspot_row(), _eh_row()]) - - names = IS.all_known_secret_names() - - # Required secrets from both packages. - assert "HUBSPOT_PRIVATE_APP_TOKEN" in names - assert "EMPLOYMENTHERO_OAUTH_CLIENT_ID" in names - assert "EMPLOYMENTHERO_OAUTH_CLIENT_SECRET" in names - # Optional secrets from both packages. - assert "HUBSPOT_PORTAL_ID" in names - assert "EMPLOYMENTHERO_REFRESH_TOKEN" in names - - -def test_all_known_secret_names_empty_when_registry_empty(monkeypatch): - """``all_known_secret_names`` unions the persisted registry with disk - discovery (added in the hot-load PR). 
To assert "empty result", both - sources must be patched to empty — otherwise the function legitimately - returns disk-discovered packages installed alongside this venv.""" - _seed_registry(monkeypatch, []) +def test_filter_scope_none_when_no_packages(monkeypatch): + """No integration packages on disk → don't filter (preserve existing + behaviour for non-integration callers).""" + _stub_packages_and_keyset(monkeypatch, packages=[], keyset={"FOO"}) + assert IS.build_guidance_filter_scope() is None - from unity.integration_status import discovery as IS_DISCOVERY - monkeypatch.setattr( - IS_DISCOVERY, - "discover_available_packages", - lambda *, force_reload=False: [], +def test_filter_scope_never_match_when_packages_exist_but_none_enabled(monkeypatch): + """Packages on disk but none enabled → hide all integration guidance.""" + _stub_packages_and_keyset( + monkeypatch, + packages=[_pkg(slug="hubspot", label="HubSpot", required=["HUBSPOT_TOKEN"])], + keyset=set(), ) - - assert IS.all_known_secret_names() == set() + assert IS.build_guidance_filter_scope() == "guidance_id in ()" # --------------------------------------------------------------------------- -# Prompt-block rendering +# register_available_integrations — startup pass # --------------------------------------------------------------------------- -def test_summary_lists_active_and_inactive_with_setup_hint(monkeypatch): - _seed_registry(monkeypatch, [_hubspot_row(), _eh_row()]) +def test_register_is_idempotent(monkeypatch): + """Re-running register_available_integrations doesn't double-process + packages already in registered_slugs.""" + calls = {"functions": 0, "guidance": 0} - IS.recompute_enablement( - assistant_id=1, - secrets={ - "HUBSPOT_PRIVATE_APP_TOKEN": "tok", - "HUBSPOT_PORTAL_ID": "12345", - }, - ) + def fake_register_functions(pkg): + calls["functions"] += 1 + return 0 - summary = IS.enabled_summary_for_prompt() - # HubSpot ends up in Active. - assert "Active integrations:" in summary - assert "HubSpot (fully_connected)" in summary - # Employment Hero ends up in Inactive with the missing required keys - # spelled out, so the LLM can guide the user. 
- assert "Inactive (credentials not configured):" in summary - assert "Employment Hero" in summary - assert "EMPLOYMENTHERO_OAUTH_CLIENT_ID" in summary - assert "EMPLOYMENTHERO_OAUTH_CLIENT_SECRET" in summary - - -def test_summary_active_with_configured_lists_missing_optional(monkeypatch): - """When an integration is configured (required met) but missing optional - secrets like REFRESH_TOKEN, the prompt should call this out so the LLM - can guide the user through the OAuth Connect step.""" - _seed_registry(monkeypatch, [_eh_row()]) - - IS.recompute_enablement( - assistant_id=1, - secrets={ - "EMPLOYMENTHERO_OAUTH_CLIENT_ID": "id", - "EMPLOYMENTHERO_OAUTH_CLIENT_SECRET": "secret", - }, - ) + def fake_register_guidance(pkg): + calls["guidance"] += 1 + return 0 - summary = IS.enabled_summary_for_prompt() - assert "Active integrations:" in summary - assert "configured" in summary - assert "EMPLOYMENTHERO_REFRESH_TOKEN" in summary - - -# --------------------------------------------------------------------------- -# Filter scope construction (used to set GuidanceManager.filter_scope) -# --------------------------------------------------------------------------- - - -def test_build_guidance_filter_scope_returns_empty_set_when_registry_seeded_but_no_enablement( - monkeypatch, -): - """When integrations exist on the deployment but none have credentials, - we want guidance retrieval to actively *exclude* integration guidance, - not fall through silently. ``guidance_id in ()`` matches nothing — - this is intentional.""" - _seed_registry(monkeypatch, [_hubspot_row()]) - IS.recompute_enablement(assistant_id=1, secrets={}) + _stub_packages_and_keyset( + monkeypatch, + packages=[_pkg(slug="hubspot", label="HubSpot", required=["X"])], + keyset=set(), + ) + monkeypatch.setattr(IS, "_register_functions", fake_register_functions) + monkeypatch.setattr(IS, "_register_guidance", fake_register_guidance) - scope = IS.build_guidance_filter_scope() - assert scope == "guidance_id in ()" + IS.register_available_integrations() + assert calls == {"functions": 1, "guidance": 1} + # Second call should be a no-op for the already-registered slug. + IS.register_available_integrations() + assert calls == {"functions": 1, "guidance": 1} -# --------------------------------------------------------------------------- -# Cache reset -# --------------------------------------------------------------------------- +def test_register_per_package_failure_does_not_halt_others(monkeypatch): + """If one package's functions/guidance step raises, the remaining + packages still get processed.""" -def test_reset_session_cache_clears_state(monkeypatch): - _seed_registry(monkeypatch, [_hubspot_row()]) + def failing_functions(pkg): + if pkg["slug"] == "broken": + raise RuntimeError("simulated failure") + return 0 - IS.recompute_enablement( - assistant_id=1, - secrets={"HUBSPOT_PRIVATE_APP_TOKEN": "tok"}, + _stub_packages_and_keyset( + monkeypatch, + packages=[ + _pkg(slug="broken", label="Broken", required=[]), + _pkg(slug="hubspot", label="HubSpot", required=[]), + ], + keyset=set(), ) - assert IS.get_enabled_integrations() == ["hubspot"] - - IS.reset_session_cache() - assert IS.get_enabled_integrations() == [] - assert IS.get_setup_completeness() == {} + monkeypatch.setattr(IS, "_register_functions", failing_functions) + monkeypatch.setattr(IS, "_register_guidance", lambda pkg: 0) + + # Should not raise. 
+ IS.register_available_integrations() + + # Both slugs should be marked registered (per-step try/except in the + # implementation; the broken slug's failure is logged but not + # propagated, and registration of guidance still ran). + cache = IS._session_cache() + assert "broken" in cache["registered_slugs"] + assert "hubspot" in cache["registered_slugs"] diff --git a/tests/test_integration_status/test_hot_load.py b/tests/test_integration_status/test_hot_load.py deleted file mode 100644 index 5bb7e7f64..000000000 --- a/tests/test_integration_status/test_hot_load.py +++ /dev/null @@ -1,427 +0,0 @@ -"""Unit tests for the hot-load path in ``unity.integration_status``. - -Hot-load registers an available package's functions, guidance, and registry -row into the running managers in a daemon thread, so a token paste in -Console takes effect on the next message without a deployment cycle. - -These tests stub out the heavyweight pieces (FunctionManager, -GuidanceManager, unify backend) so we can verify the orchestration logic in -isolation. Live-context tests live separately. -""" - -from __future__ import annotations - -import logging -import time -from pathlib import Path - -import pytest - -from unity import integration_status as IS -from unity.integration_status import discovery as IS_DISCOVERY - -# --------------------------------------------------------------------------- -# Fixtures: synthetic discovery rows -# --------------------------------------------------------------------------- - - -def _hubspot_discovery_row(*, root: Path | None = None) -> dict: - return { - "slug": "hubspot", - "label": "HubSpot", - "category": "crm", - "version": "0.1.0", - "tier": "api", - "root_dir": root or Path("/tmp/fake/hubspot"), - "required_secrets": ["HUBSPOT_PRIVATE_APP_TOKEN"], - "optional_secrets": ["HUBSPOT_PORTAL_ID"], - "function_names": ["get_hubspot_contact", "search_hubspot_contacts"], - "guidance_titles": ["Hubspot Overview", "Hubspot Crm Contacts"], - "function_dir": ( - (root or Path("/tmp/fake/hubspot")) / "functions" if root else None - ), - "guidance_dir": ( - (root or Path("/tmp/fake/hubspot")) / "guidance" if root else None - ), - } - - -def _eh_discovery_row() -> dict: - return { - "slug": "employment_hero", - "label": "Employment Hero", - "category": "human-resources", - "version": "0.1.0", - "tier": "api", - "root_dir": Path("/tmp/fake/eh"), - "required_secrets": [ - "EMPLOYMENTHERO_OAUTH_CLIENT_ID", - "EMPLOYMENTHERO_OAUTH_CLIENT_SECRET", - ], - "optional_secrets": [ - "EMPLOYMENTHERO_REFRESH_TOKEN", - "EMPLOYMENTHERO_ORGANISATION_ID", - ], - "function_names": ["get_employmenthero_employee"], - "guidance_titles": ["Employmenthero Overview"], - "function_dir": None, - "guidance_dir": None, - } - - -@pytest.fixture(autouse=True) -def _isolate_state(monkeypatch): - """Clear cache + discovery between tests so they don't leak state. - - Also stubs ``_read_local_secret_keyset`` to return an empty set by - default so tests that drive ``recompute_enablement`` via ``secrets={}`` - don't accidentally pull from a real SecretManager (which doesn't exist - in unit-test scope). 
Tests that care about the local-context path - re-stub explicitly.""" - IS.reset_session_cache() - IS_DISCOVERY.reset_discovery_cache() - monkeypatch.setattr(IS, "_read_local_secret_keyset", lambda: set()) - yield - IS.reset_session_cache() - IS_DISCOVERY.reset_discovery_cache() - - -def _patch_discovery(monkeypatch, rows: list[dict]) -> None: - monkeypatch.setattr( - IS_DISCOVERY, - "discover_available_packages", - lambda *, force_reload=False: rows, - ) - - -# --------------------------------------------------------------------------- -# all_known_secret_names: now unions registry + discovery -# --------------------------------------------------------------------------- - - -def test_all_known_secret_names_includes_discovered_when_registry_empty(monkeypatch): - """When the persisted registry is empty (deployment didn't declare any - integrations) but disk has packages, allowlist still includes their - secrets — that's what bridges the token-paste-in-Console flow.""" - _patch_discovery(monkeypatch, [_hubspot_discovery_row(), _eh_discovery_row()]) - - # Force _load_registry to return empty (mimics no persisted registry). - monkeypatch.setattr(IS, "_read_persisted_registry", lambda: []) - - names = IS.all_known_secret_names() - assert "HUBSPOT_PRIVATE_APP_TOKEN" in names - assert "HUBSPOT_PORTAL_ID" in names - assert "EMPLOYMENTHERO_OAUTH_CLIENT_ID" in names - assert "EMPLOYMENTHERO_OAUTH_CLIENT_SECRET" in names - assert "EMPLOYMENTHERO_REFRESH_TOKEN" in names - - -def test_all_known_secret_names_empty_when_no_registry_no_discovery(monkeypatch): - monkeypatch.setattr(IS, "_read_persisted_registry", lambda: []) - _patch_discovery(monkeypatch, []) - assert IS.all_known_secret_names() == set() - - -# --------------------------------------------------------------------------- -# Synthesized registry rows: recompute_enablement works against discovery -# even when the persisted registry is empty -# --------------------------------------------------------------------------- - - -def test_recompute_works_against_discovery_when_persisted_registry_empty(monkeypatch): - """The whole point: an assistant whose deployment didn't declare HubSpot - can still detect HubSpot as enabled once a token is added to Secrets.""" - _patch_discovery(monkeypatch, [_hubspot_discovery_row()]) - monkeypatch.setattr(IS, "_read_persisted_registry", lambda: []) - # Stub out schedule_hot_load so the test doesn't spawn a thread. - spawned: list[str] = [] - monkeypatch.setattr(IS, "schedule_hot_load", lambda slug: spawned.append(slug)) - - IS.recompute_enablement( - assistant_id=42, - secrets={"HUBSPOT_PRIVATE_APP_TOKEN": "tok"}, - ) - - assert IS.get_enabled_integrations() == ["hubspot"] - assert spawned == ["hubspot"] - - -def test_recompute_does_not_schedule_when_no_required_present(monkeypatch): - _patch_discovery(monkeypatch, [_hubspot_discovery_row()]) - monkeypatch.setattr(IS, "_read_persisted_registry", lambda: []) - spawned: list[str] = [] - monkeypatch.setattr(IS, "schedule_hot_load", lambda slug: spawned.append(slug)) - - IS.recompute_enablement(assistant_id=42, secrets={}) - - assert IS.get_enabled_integrations() == [] - assert spawned == [] - - -def test_recompute_enables_integration_from_local_secrets_with_empty_orchestra( - monkeypatch, -): - """Regression test for the staging bug diagnosed 2026-05-06. - - Integration tokens (HubSpot etc.) are written by Console directly to - the assistant's local Secrets context — they never appear in the - Orchestra ``assistant.secrets`` payload. 
Before the fix, - ``recompute_enablement`` only saw the Orchestra payload, so HubSpot - could never be enabled even when the token was actually present in - the local context. This test pins the corrected behaviour: the local - keyset alone is sufficient to drive enablement + hot-load scheduling, - even when the Orchestra supplement is empty. - """ - _patch_discovery(monkeypatch, [_hubspot_discovery_row()]) - monkeypatch.setattr(IS, "_read_persisted_registry", lambda: []) - - # Mimic Console having written HUBSPOT_PRIVATE_APP_TOKEN to the local - # Secrets context. Orchestra returns nothing for this assistant. - monkeypatch.setattr( - IS, - "_read_local_secret_keyset", - lambda: {"HUBSPOT_PRIVATE_APP_TOKEN"}, - ) - spawned: list[str] = [] - monkeypatch.setattr(IS, "schedule_hot_load", lambda slug: spawned.append(slug)) - - IS.recompute_enablement(assistant_id=629, secrets={}) - - assert IS.get_enabled_integrations() == ["hubspot"] - assert spawned == ["hubspot"] - - -def test_recompute_unions_local_keyset_with_orchestra_supplement(monkeypatch): - """Both sources contribute. The local context holds Console-pasted - tokens (HubSpot here); the Orchestra supplement carries a key not yet - mirrored — both should drive enablement of their respective packages - in a single recompute call.""" - _patch_discovery(monkeypatch, [_hubspot_discovery_row(), _eh_discovery_row()]) - monkeypatch.setattr(IS, "_read_persisted_registry", lambda: []) - - # Local context: only HubSpot's token. - monkeypatch.setattr( - IS, - "_read_local_secret_keyset", - lambda: {"HUBSPOT_PRIVATE_APP_TOKEN"}, - ) - spawned: list[str] = [] - monkeypatch.setattr(IS, "schedule_hot_load", lambda slug: spawned.append(slug)) - - # Orchestra supplement: EH's two required keys (race window before mirror). - IS.recompute_enablement( - assistant_id=629, - secrets={ - "EMPLOYMENTHERO_OAUTH_CLIENT_ID": "id", - "EMPLOYMENTHERO_OAUTH_CLIENT_SECRET": "secret", - }, - ) - - assert sorted(IS.get_enabled_integrations()) == ["employment_hero", "hubspot"] - assert sorted(spawned) == ["employment_hero", "hubspot"] - - -def test_recompute_does_not_re_schedule_when_already_loaded(monkeypatch): - """Once a slug is in ``loaded_slugs`` (a previous hot-load completed), - recompute_enablement must not queue another load for the same slug.""" - _patch_discovery(monkeypatch, [_hubspot_discovery_row()]) - monkeypatch.setattr(IS, "_read_persisted_registry", lambda: []) - spawned: list[str] = [] - monkeypatch.setattr(IS, "schedule_hot_load", lambda slug: spawned.append(slug)) - - cache = IS._session_cache() - cache["loaded_slugs"] = {"hubspot"} - - IS.recompute_enablement( - assistant_id=42, - secrets={"HUBSPOT_PRIVATE_APP_TOKEN": "tok"}, - ) - - assert spawned == [] - - -# --------------------------------------------------------------------------- -# schedule_hot_load: non-blocking + idempotent -# --------------------------------------------------------------------------- - - -def test_schedule_hot_load_returns_immediately(monkeypatch): - """schedule_hot_load spawns a daemon thread; the calling thread must not - wait for the load to finish. 
We use a deliberately slow worker to assert - schedule_hot_load returns well before the load does.""" - barrier_seconds = 5.0 - - def _slow_hot_load(slug: str): - time.sleep(barrier_seconds) - - async def _slow_async(slug: str): - time.sleep(barrier_seconds) - - monkeypatch.setattr(IS, "hot_load_integration", _slow_async) - - started = time.perf_counter() - IS.schedule_hot_load("hubspot") - elapsed = time.perf_counter() - started - - # Schedule should return in well under 1s; the worker is sleeping for 5s. - assert elapsed < 1.0, f"schedule_hot_load took {elapsed:.2f}s — should be ~0s" - - -def test_schedule_hot_load_is_idempotent_for_loaded_slug(monkeypatch): - """If the slug is already in ``loaded_slugs``, schedule must skip the - thread spawn entirely.""" - threads_spawned: list[str] = [] - - real_thread_cls = IS.__dict__.get("threading", None) - - import threading as _threading - - monkeypatch.setattr( - _threading, - "Thread", - lambda *a, **k: pytest.fail("Should not spawn a thread"), - ) - - cache = IS._session_cache() - cache["loaded_slugs"] = {"hubspot"} - - # Should be a no-op. - IS.schedule_hot_load("hubspot") - - -def test_schedule_hot_load_skips_when_already_loading(monkeypatch): - """Same as above but for the in-flight guard.""" - cache = IS._session_cache() - cache["loading_slugs"] = {"hubspot"} - - import threading as _threading - - monkeypatch.setattr( - _threading, - "Thread", - lambda *a, **k: pytest.fail("Should not spawn a thread"), - ) - - IS.schedule_hot_load("hubspot") - - -def test_schedule_hot_load_failure_does_not_propagate(monkeypatch): - """Any error inside the worker (asyncio.run, etc.) must be logged and - swallowed; the calling thread must not see it.""" - - async def _raises(slug): - raise RuntimeError("synthetic failure") - - monkeypatch.setattr(IS, "hot_load_integration", _raises) - - # Should not raise. Wait a short time for the worker to fail. - IS.schedule_hot_load("hubspot") - time.sleep(0.2) - - # ``loading_slugs`` cleanup runs in the worker's finally block. - cache = IS._session_cache() - assert "hubspot" not in cache.get("loading_slugs", set()) - # ``loaded_slugs`` is NOT updated on failure → next recompute will retry. 
- assert "hubspot" not in cache.get("loaded_slugs", set()) - - -# --------------------------------------------------------------------------- -# hot_load_integration: orchestration with stubbed sub-steps -# --------------------------------------------------------------------------- - - -def test_hot_load_invokes_all_three_substeps_for_known_slug(monkeypatch): - pkg = _hubspot_discovery_row() - monkeypatch.setattr(IS_DISCOVERY, "get_package_for_slug", lambda slug: pkg) - - calls: list[str] = [] - monkeypatch.setattr(IS, "_hot_load_guidance", lambda p: calls.append("guidance")) - monkeypatch.setattr( - IS, - "_hot_load_functions", - lambda p: calls.append("functions") or 2, - ) - monkeypatch.setattr( - IS, - "_hot_load_registry_row", - lambda p: calls.append("registry"), - ) - - import asyncio - - asyncio.run(IS.hot_load_integration("hubspot")) - - assert calls == ["guidance", "functions", "registry"] - cache = IS._session_cache() - assert "hubspot" in cache.get("loaded_slugs", set()) - - -def test_hot_load_skips_when_already_loaded(monkeypatch): - monkeypatch.setattr( - IS_DISCOVERY, - "get_package_for_slug", - lambda slug: pytest.fail("Should not consult discovery"), - ) - - cache = IS._session_cache() - cache["loaded_slugs"] = {"hubspot"} - - import asyncio - - asyncio.run(IS.hot_load_integration("hubspot")) - - -def test_hot_load_no_op_for_unknown_slug(monkeypatch, caplog): - monkeypatch.setattr(IS_DISCOVERY, "get_package_for_slug", lambda slug: None) - - import asyncio - - with caplog.at_level(logging.WARNING): - asyncio.run(IS.hot_load_integration("nonexistent")) - - cache = IS._session_cache() - assert "nonexistent" not in cache.get("loaded_slugs", set()) - - -def test_hot_load_partial_failure_still_marks_loaded(monkeypatch): - """If one sub-step fails (e.g. guidance) but others succeed, the slug - is still marked loaded — partial success is better than re-attempting - the same failure on every recompute.""" - pkg = _hubspot_discovery_row() - monkeypatch.setattr(IS_DISCOVERY, "get_package_for_slug", lambda slug: pkg) - - def _explode(p): - raise RuntimeError("guidance step blew up") - - monkeypatch.setattr(IS, "_hot_load_guidance", _explode) - monkeypatch.setattr(IS, "_hot_load_functions", lambda p: 0) - monkeypatch.setattr(IS, "_hot_load_registry_row", lambda p: None) - - import asyncio - - asyncio.run(IS.hot_load_integration("hubspot")) - - cache = IS._session_cache() - assert "hubspot" in cache.get("loaded_slugs", set()) - - -def test_hot_load_invalidates_registry_cache(monkeypatch): - """After a successful hot-load, the persisted registry has a fresh row - that subsequent allowlist computations should see. 
Verify the cache - flag is reset so ``_load_registry`` re-reads.""" - pkg = _hubspot_discovery_row() - monkeypatch.setattr(IS_DISCOVERY, "get_package_for_slug", lambda slug: pkg) - monkeypatch.setattr(IS, "_hot_load_guidance", lambda p: 0) - monkeypatch.setattr(IS, "_hot_load_functions", lambda p: 0) - monkeypatch.setattr(IS, "_hot_load_registry_row", lambda p: None) - - cache = IS._session_cache() - cache["registry_loaded"] = True - cache["registry"] = [{"slug": "stale"}] - - import asyncio - - asyncio.run(IS.hot_load_integration("hubspot")) - - assert cache["registry_loaded"] is False - assert cache["registry"] == [] diff --git a/unity/__init__.py b/unity/__init__.py index 182d34ade..b7ad3f465 100644 --- a/unity/__init__.py +++ b/unity/__init__.py @@ -111,6 +111,30 @@ def init( with startup_timing(LOGGER, "unity.init.context_registry_setup"): ContextRegistry.setup() + # Schedule disk-package registration as a background task so the + # fast-brain conversation/communication loop can come online without + # waiting for function + guidance inserts to finish. Single + # daemon thread, runs once, captures + re-applies the Unify active + # context. Must be scheduled AFTER ContextRegistry.setup so the + # worker can resolve manager contexts. Integration functions become + # callable on the next conversation turn after the worker completes + # (low hundreds of ms typically). See + # :func:`unity.integration_status.schedule_register_available_integrations` + # for the contract — this replaces the May-2026 per-slug daemon-thread + # hot-load that ran from inside ``SecretManager.__init__``. + with startup_timing(LOGGER, "unity.init.schedule_register_integrations"): + try: + from .integration_status import ( + schedule_register_available_integrations, + ) + + schedule_register_available_integrations() + except Exception: + LOGGER.exception( + "[integrations] failed to schedule registration at startup; " + "integrations may not be available this session", + ) + from .events import event_bus as _event_bus_mod with startup_timing(LOGGER, "unity.init.event_bus_init"): diff --git a/unity/integration_status/__init__.py b/unity/integration_status/__init__.py index 7f995c60a..c6f9f3125 100644 --- a/unity/integration_status/__init__.py +++ b/unity/integration_status/__init__.py @@ -1,71 +1,57 @@ -"""Per-session detection of which integrations are enabled for an assistant. - -The deploy side (``unity_deploy.assistant_deployments.integrations``) seeds a -flat row per integration into the ``Integrations/Manifests`` DataManager -context. Each row carries the integration's ``required_secrets``, -``optional_secrets``, ``function_names``, and ``guidance_titles`` (all -JSON-stringified). - -At runtime the assistant has its own secret keyset (synced from Orchestra by -``SecretManager._sync_assistant_secrets``). An integration is *enabled* iff -every secret listed in ``required_secrets_json`` is present in that keyset. -Once we know which integrations are enabled we can: - -* Inject an integration-status block into the system prompt so the LLM knows - what's available + what setup steps remain. -* Set ``GuidanceManager.filter_scope`` to a guidance-id predicate that hides - guidance for integrations whose credentials aren't configured. -* Drive the per-session cache that ``recompute_enablement`` updates from - every ``_sync_assistant_secrets`` call. - -This module is purely a reader of the registry + secret keyset; it never -writes. All work is best-effort: if the registry hasn't been seeded (e.g. 
-first deploy after PR A1 lands but before A2 ships), every helper here -returns an empty/no-op result and the caller falls through to existing -behaviour. +"""Integration package registration and enablement detection. + +Two orthogonal mechanisms, deliberately decoupled from secret transport +(:mod:`unity.secret_manager`). + +1. **Registration (startup, mechanism 1).** + :func:`register_available_integrations` walks every package manifest under + unity-deploy's package roots and registers each one's functions + guidance + with the runtime managers. Synchronous, main-thread, idempotent. Called + once from :mod:`unity.__init__` after :meth:`ContextRegistry.setup`. + +2. **Enablement (read-only, mechanism 2).** + :func:`get_enabled_integrations` is a pure query that returns which + packages have their required secrets satisfied right now. No caching + that can go stale, no thread spawning, no manager re-construction. + Prompt builders / the actor invoke on demand. + +**Secret transport is a third, separate concern** owned by +:meth:`SecretManager._sync_assistant_secrets` (Google / Microsoft OAuth +tokens) and :meth:`SecretManager._sync_dotenv` (everything else, including +Console-pasted integration credentials). This module never touches secret +values; it only reads the local Secrets keyset to decide which integrations +are enabled. + +History: the early-May design entangled all three concerns into one sync +flow with daemon-thread hot-load, registry-derived secret allowlists, and +recompute_enablement called from inside SecretManager.__init__. That +produced cascading bugs (silent integration-secret wipes, recursive +manager construction, daemon-thread context-resolution failures). This +module is the post-cleanup shape — see ``unity/integration_status`` +commit history for the trail. """ from __future__ import annotations -import json import logging from typing import Any logger = logging.getLogger(__name__) -_REGISTRY_CONTEXT_LEAF = "Integrations/Manifests" - # --------------------------------------------------------------------------- -# Per-body session cache +# Per-body session cache (tiny — only registration idempotency) # --------------------------------------------------------------------------- -# -# The cache lives on ``SESSION_DETAILS`` so that mid-session changes (a user -# pasting a new token via Settings → Secrets, then ``_sync_assistant_secrets`` -# firing again) propagate to subsequent prompt builds + guidance retrievals -# without a session restart. -# -# We attach a single dict via ``setattr`` rather than a typed field so we can -# evolve the shape without coordinating a unity-side schema change. _SESSION_CACHE_ATTR = "_integration_status_cache" def _session_cache() -> dict[str, Any]: - """Lazily attach + return the per-body integration-status cache. - - The cache shape: - - { - "registry": [, ...] # full Integrations/Manifests rows - "registry_loaded": bool # have we tried at least once? - "enabled": {: } # only enabled integrations - "completeness": {: {...}} # configured vs fully_connected per slug - "secret_names": set[str] # last-observed assistant keyset - } + """Return the per-body cache dict, attaching to ``SESSION_DETAILS``. - Falls back to a process-local module dict when ``SESSION_DETAILS`` isn't - available (test paths that exercise this helper in isolation). + Falls back to a process-local module dict when ``SESSION_DETAILS`` + isn't constructed yet (test paths that exercise this helper in + isolation). 
""" try: from unity.session_details import SESSION_DETAILS @@ -76,170 +62,43 @@ def _session_cache() -> dict[str, Any]: setattr(SESSION_DETAILS, _SESSION_CACHE_ATTR, cache) return cache except Exception: - # Fallback for environments where SESSION_DETAILS isn't constructed yet. if not hasattr(_session_cache, "_fallback"): - _session_cache._fallback = _empty_cache() + _session_cache._fallback = _empty_cache() # type: ignore[attr-defined] return _session_cache._fallback # type: ignore[attr-defined] def _empty_cache() -> dict[str, Any]: return { - "registry": [], - "registry_loaded": False, - "enabled": {}, - "completeness": {}, - "secret_names": set(), - # Hot-load tracking. ``loaded_slugs`` is set after a successful - # hot-load completes for this body; ``loading_slugs`` is the set of - # slugs currently being loaded by a daemon thread. ``schedule_hot_load`` - # is a no-op when the slug is in either set. - "loaded_slugs": set(), - "loading_slugs": set(), + # Slugs whose functions + guidance have already been registered + # with the runtime managers this session. Idempotency guard for + # ``register_available_integrations`` re-runs. + "registered_slugs": set(), } def reset_session_cache() -> None: - """Drop the cached enablement state for the current session. - - Used by tests to isolate scenarios; production code shouldn't need this.""" + """Drop the cache. Used by tests for isolation; production code + shouldn't need this.""" cache = _session_cache() cache.update(_empty_cache()) # --------------------------------------------------------------------------- -# Registry lookup +# Local secret keyset (read once per enablement query) # --------------------------------------------------------------------------- -def _load_registry() -> list[dict[str, Any]]: - """Pull rows from ``Integrations/Manifests`` once per session, cache. - - Always returns the **union** of two sources, deduped by slug: - - * ``_read_persisted_registry()`` — rows the deployment seed wrote into - ``Integrations/Manifests``. Authoritative for slugs they cover - because their ``function_names_json`` was the basis of the - FunctionManager rows the deploy pass created; ``enabled_function_ids`` - relies on that mapping holding. - * ``_synthesize_rows_from_discovery()`` — rows projected from every - manifest currently on disk under unity-deploy's package roots. - Covers packages installed but not declared in any deployment, so - auto-detection works for both registered and non-registered - assistants — matching the design intent that the integrations sync - exists only to auto-detect unity-deploy packages. - - Persisted rows win on slug conflict; disk-only slugs are added on top. - A package added to disk after the last deploy seed is detected via the - discovery half of the union; once a token paste enables it, hot-load - writes the real persisted row so subsequent reads come from the - persisted source. 
- """ - cache = _session_cache() - if cache["registry_loaded"]: - return cache["registry"] - - persisted = _read_persisted_registry() - discovered = _synthesize_rows_from_discovery() - - by_slug: dict[str, dict[str, Any]] = {} - for row in persisted: - slug = row.get("slug") - if slug: - by_slug[slug] = row - for row in discovered: - slug = row.get("slug") - if slug and slug not in by_slug: - by_slug[slug] = row - - rows = list(by_slug.values()) - cache["registry_loaded"] = True - cache["registry"] = rows - - logger.info( - "[integrations] _load_registry: persisted=%d discovered=%d " - "merged=%d slugs=%s", - len(persisted), - len(discovered), - len(rows), - sorted(by_slug.keys()), - ) - return rows - - -def _read_persisted_registry() -> list[dict[str, Any]]: - try: - import unify - - active = unify.get_active_context()["read"] - ctx = f"{active}/{_REGISTRY_CONTEXT_LEAF}" - logs = unify.get_logs(context=ctx, limit=1000) - except Exception: - return [] - - rows: list[dict[str, Any]] = [] - for log in logs or []: - entries = dict(log.entries or {}) - if entries.get("slug"): - rows.append(entries) - return rows - - -def _synthesize_rows_from_discovery() -> list[dict[str, Any]]: - """Project disk-discovered packages into registry-row shape. - - These rows live only in the per-body cache; they're never persisted by - this module. After a successful hot-load, the persisted registry gets - its real row written by ``hot_load_integration``.""" - try: - from unity.integration_status.discovery import discover_available_packages - except Exception: - return [] - - rows: list[dict[str, Any]] = [] - for pkg in discover_available_packages(): - rows.append( - { - "slug": pkg["slug"], - "label": pkg["label"], - "category": pkg.get("category", ""), - "version": pkg.get("version", ""), - "tier": pkg.get("tier", ""), - "required_secrets_json": json.dumps(pkg.get("required_secrets", [])), - "optional_secrets_json": json.dumps(pkg.get("optional_secrets", [])), - "function_names_json": json.dumps(pkg.get("function_names", [])), - "guidance_titles_json": json.dumps(pkg.get("guidance_titles", [])), - "_synthesized": True, - }, - ) - return rows - - -def _row_required(row: dict) -> set[str]: - return set(json.loads(row.get("required_secrets_json") or "[]")) - - -def _row_optional(row: dict) -> set[str]: - return set(json.loads(row.get("optional_secrets_json") or "[]")) - - -def _row_function_names(row: dict) -> list[str]: - return list(json.loads(row.get("function_names_json") or "[]")) - - -def _row_guidance_titles(row: dict) -> list[str]: - return list(json.loads(row.get("guidance_titles_json") or "[]")) - - def _read_local_secret_keyset() -> set[str]: - """Names of non-empty secrets in the assistant's local Secrets context. + """Names of non-empty secrets in the assistant's local ``Secrets`` context. The local Secrets context is the source of truth for both - Orchestra-mirrored OAuth tokens (Google / MS365, written by - ``SecretManager._sync_assistant_secrets``) AND directly-pasted - integration tokens (HubSpot, EmploymentHero, etc., written by Console). + Orchestra-mirrored OAuth tokens (Google / MS, written by + :meth:`SecretManager._sync_assistant_secrets`) AND directly-pasted + integration tokens (HubSpot / EmploymentHero / Matterport / Webex / + Salesforce, written by Console). - Best-effort: returns an empty set if the SecretManager isn't available - or the context can't be read. Never raises. + Best-effort: returns an empty set if the SecretManager isn't + available or the context can't be read. 
Never raises. """ try: import unify @@ -264,126 +123,312 @@ def _read_local_secret_keyset() -> set[str]: # --------------------------------------------------------------------------- -# Public API +# Mechanism 1 — Startup registration # --------------------------------------------------------------------------- -def recompute_enablement( - *, - assistant_id: int, - secrets: dict[str, Any] | None = None, -) -> None: - """Recompute the enabled set + setup-completeness for this assistant. - - Reads the keyset from the assistant's local Secrets context (the - SecretManager-managed unify context), which is the single source of - truth for both Orchestra-mirrored OAuth tokens AND Console-pasted - integration tokens. - - The optional ``secrets`` argument is unioned in for the edge case - where a freshly-fetched Orchestra payload hasn't yet been mirrored to - the local context — used when called inline from - ``_sync_assistant_secrets`` between the fetch and the mirror. It is - purely additive: integration keys that live only in the local context - are never excluded just because they're absent from ``secrets``. - - Treats values as "present" iff ``isinstance(v, str) and v != ""``. - Idempotent. Cheap (~10ms on first call when registry needs fetching, - ~1ms thereafter — pure dict ops over the cached registry). +def register_available_integrations() -> None: + """Walk disk packages and register each one's functions + guidance with + the runtime managers. + + **Synchronous and idempotent.** In production, callers schedule this + as a background task via :func:`schedule_register_available_integrations` + so the fast-brain conversation loop can come online without waiting + for the (potentially many-hundreds-of-ms) function/guidance inserts + to finish. Direct synchronous use is fine for tests and CLI tools. + + Replaces the May-2026 per-slug daemon-thread hot-load mechanism. + Adding a new package to disk now requires a session restart — which + matches how every other deployment artifact (manifests, scenarios, + guidance) behaves. No cross-thread context-state races (single + thread when backgrounded), no recursive manager construction (no + re-entry into ``SecretManager``), no token-paste-triggered side + effects. """ cache = _session_cache() - local_keys = _read_local_secret_keyset() - orch_supplement = { - k for k, v in (secrets or {}).items() if isinstance(v, str) and v - } - keyset = local_keys | orch_supplement - cache["secret_names"] = keyset + try: + from unity.integration_status.discovery import discover_available_packages + except Exception: + logger.warning( + "[integrations] register: discovery module unimportable; skipping", + exc_info=True, + ) + return - registry = _load_registry() - if not registry: - cache["enabled"] = {} - cache["completeness"] = {} + packages = discover_available_packages() + if not packages: + logger.info("[integrations] register: no packages discovered on disk") return - enabled: dict[str, dict] = {} - completeness: dict[str, dict] = {} - for row in registry: - slug = row.get("slug") + total_funcs = 0 + total_guidance = 0 + registered_now: list[str] = [] + + for pkg in packages: + slug = pkg.get("slug") or "" if not slug: continue - required = _row_required(row) - optional = _row_optional(row) - - if not required.issubset(keyset): - # Required secrets missing → not enabled. We still surface this - # row in the prompt's "inactive" section via _render_status_block, - # which reads the registry directly so it can describe what's - # missing in plain English. 
+ already_registered = cache.setdefault("registered_slugs", set()) + if slug in already_registered: continue - enabled[slug] = row - missing_optional = sorted(optional - keyset) - completeness[slug] = { - "status": "fully_connected" if not missing_optional else "configured", - "missing_optional_secrets": missing_optional, - "missing_required_secrets": [], # by definition empty when enabled - } + try: + total_funcs += _register_functions(pkg) + except Exception: + logger.exception( + "[integrations] register: functions step failed for %s", + slug, + ) - # Detect newly-enabled integrations (in current set but not yet hot-loaded) - # and queue a non-blocking background load for each. schedule_hot_load - # spawns a daemon thread so the calling sync path returns immediately. - just_enabled = set(enabled.keys()) - cache.get("loaded_slugs", set()) + try: + total_guidance += _register_guidance(pkg) + except Exception: + logger.exception( + "[integrations] register: guidance step failed for %s", + slug, + ) - cache["enabled"] = enabled - cache["completeness"] = completeness + already_registered.add(slug) + registered_now.append(slug) logger.info( - "[integrations] recompute: assistant_id=%s keyset=%d enabled=%s just_enabled=%s", - assistant_id, - len(keyset), - sorted(enabled.keys()), - sorted(just_enabled), + "[integrations] register: packages=%d new=%s functions=%d guidance=%d", + len(packages), + sorted(registered_now), + total_funcs, + total_guidance, ) - for slug in just_enabled: + +def schedule_register_available_integrations() -> None: + """Spawn a single daemon thread that runs + :func:`register_available_integrations` in the background. + + Returns immediately so the calling startup path (typically + :mod:`unity.__init__`) stays non-blocking — the assistant's + conversation / communication "fast brain" comes online without + waiting for function and guidance inserts to finish. Integration + functions become callable as soon as the worker thread completes + (low hundreds of ms typically); turns that arrive earlier will see + them missing from the FunctionManager and fall through to existing + no-tool behaviour. + + Safety vs the May-2026 daemon-thread hot-load we removed: + + * Single thread, single registration pass — not per-slug spawned + from inside ``recompute_enablement`` running inside + ``SecretManager.__init__``. + * Caller is :mod:`unity.__init__` after :meth:`ContextRegistry.setup` + has populated ``_base_context``, so manager constructions inside + the worker can resolve their contexts. + * Captures the calling thread's Unify active context and re-applies + it inside the worker so ``unify.get_logs(...)`` / context-resolving + reads in the registration path see the same project + context the + main thread does. + + Best-effort: spawning failures are logged and never propagate. + """ + import threading + + try: + import unify + + captured_ctx = unify.get_active_context() + except Exception: + captured_ctx = None + + def _worker() -> None: + if captured_ctx is not None: + try: + import unify + + unify.set_context( + captured_ctx["read"], + skip_create=True, + ) + except Exception: + # Worst case the worker reads from a different context + # than the main thread; log and continue so a partial + # context-resolution failure doesn't silently block + # registration entirely. 
+ logger.warning( + "[integrations] register worker: failed to inherit " + "Unify active context; proceeding with worker default", + exc_info=True, + ) + try: + register_available_integrations() + except Exception: + logger.exception( + "[integrations] register worker: registration failed; " + "integrations may not be available this session", + ) + + thread = threading.Thread( + target=_worker, + daemon=True, + name="integration-register-startup", + ) + thread.start() + + +def _register_functions(pkg: dict) -> int: + """Add a package's ``@custom_function`` callables to FunctionManager. + + Per-name insert/update (no orphan-delete pass) so we don't disturb + functions registered by the deployment's own ``function_dirs``. + Returns the number of functions actually added or updated. + """ + function_dir = pkg.get("function_dir") + if function_dir is None: + return 0 + + from unity.function_manager.custom_functions import collect_custom_functions + from unity.manager_registry import ManagerRegistry + + source_fns = collect_custom_functions(directory=function_dir) + if not source_fns: + return 0 + + fm = ManagerRegistry.get_function_manager() + db_fns = fm._get_custom_functions_from_db() + + changed = 0 + for name, source_data in source_fns.items(): + try: + # ``custom_functions`` retains ``venv_name`` for the + # deploy-time sync path that resolves it via the venv + # catalog. At register-time we don't manage venvs (none of + # the in-tree integration packages declare one); strip the + # key so the FM insert API doesn't choke on it. + source_data = {k: v for k, v in source_data.items() if k != "venv_name"} + if name in db_fns: + if db_fns[name].get("custom_hash") != source_data.get("custom_hash"): + fm._update_custom_function( + function_id=db_fns[name]["function_id"], + data=source_data, + ) + changed += 1 + else: + fm._insert_custom_function(source_data) + changed += 1 + except Exception: + logger.exception("Failed to register function %s", name) + return changed + + +def _register_guidance(pkg: dict) -> int: + """Add a package's guidance markdown entries to GuidanceManager. + + Per-title check + insert. Doesn't update an existing entry's content + even if the markdown on disk has changed — that's a deploy-time + concern (``_sync_guidance`` handles drift via SeedMetaStore). + Returns the number of new entries added. + """ + guidance_dir = pkg.get("guidance_dir") + if guidance_dir is None: + return 0 + + try: + from unity_deploy.assistant_deployments.integrations.loader import ( + _load_guidance, + ) + except Exception: + return 0 + + from unity.manager_registry import ManagerRegistry + + entries = _load_guidance(guidance_dir) + if not entries: + return 0 + + gm = ManagerRegistry.get_guidance_manager() + added = 0 + for entry in entries: try: - schedule_hot_load(slug) + existing = gm.filter(filter=f"title == {entry.title!r}", limit=1) + if existing: + continue + gm.add_guidance(title=entry.title, content=entry.content) + added += 1 except Exception: - logger.exception("Failed to schedule hot-load for %s", slug) + logger.exception("Failed to register guidance %r", entry.title) + return added -def get_enabled_integrations(_assistant_id: int | None = None) -> list[str]: - """Return slugs of integrations whose required secrets are all configured. 
+# --------------------------------------------------------------------------- +# Mechanism 2 — Enablement read (pure, on-demand) +# --------------------------------------------------------------------------- + + +def get_enabled_integrations() -> dict[str, dict]: + """Return ``{slug: package_metadata}`` for every disk-discovered package + whose required secrets are currently in the assistant's local + ``/Secrets`` context. - Reads the session cache populated by ``recompute_enablement``. Returns - ``[]`` when no recompute has happened yet OR when the registry hasn't - been seeded. + Pure function — reads disk discovery (process-cached in + :mod:`unity.integration_status.discovery`) plus the local Secrets + keyset. Costs ~1ms. No caching that can go stale; callers get fresh + state every call. + + A package with no required secrets is considered enabled + unconditionally (rare; useful for read-only or always-on packages). """ - return list(_session_cache().get("enabled", {}).keys()) + try: + from unity.integration_status.discovery import discover_available_packages + except Exception: + return {} + keyset = _read_local_secret_keyset() + enabled: dict[str, dict] = {} + for pkg in discover_available_packages(): + slug = pkg.get("slug") or "" + if not slug: + continue + required = set(pkg.get("required_secrets", [])) + if not required or required.issubset(keyset): + enabled[slug] = pkg + return enabled -def get_setup_completeness(_assistant_id: int | None = None) -> dict[str, dict]: - """Per enabled integration, return ``{status, missing_optional_secrets, - missing_required_secrets}``. Only enabled integrations appear here; - inactive ones are described separately in the prompt block.""" - return dict(_session_cache().get("completeness", {})) + +def get_setup_completeness() -> dict[str, dict]: + """Per enabled integration, return setup-completeness metadata. + + Returns ``{slug: {status, missing_optional_secrets, + missing_required_secrets}}``. ``status`` is ``"fully_connected"`` if + every required + optional secret is present, ``"configured"`` if some + optional ones are still missing. ``missing_required_secrets`` is + always empty for enabled integrations (by definition). + """ + keyset = _read_local_secret_keyset() + out: dict[str, dict] = {} + for slug, pkg in get_enabled_integrations().items(): + optional = set(pkg.get("optional_secrets", [])) + missing_opt = sorted(optional - keyset) + out[slug] = { + "status": "fully_connected" if not missing_opt else "configured", + "missing_optional_secrets": missing_opt, + "missing_required_secrets": [], + } + return out -def enabled_function_ids(_assistant_id: int | None = None) -> set[int]: - """Resolve the enabled integrations' function names → FunctionManager ids. +def enabled_function_ids() -> set[int]: + """Resolve enabled integrations' function names → FunctionManager ids. - Returns an empty set when nothing is enabled OR when the FunctionManager - isn't available; callers should treat empty as "don't filter" so the - no-registry fallback preserves existing behaviour. + Returns an empty set when nothing is enabled OR when the + FunctionManager isn't available; callers should treat empty as + "don't filter" so the no-integrations fallback preserves existing + behaviour. 
""" - enabled = _session_cache().get("enabled", {}) + enabled = get_enabled_integrations() if not enabled: return set() function_names: set[str] = set() - for row in enabled.values(): - function_names.update(_row_function_names(row)) + for pkg in enabled.values(): + function_names.update(pkg.get("function_names", [])) if not function_names: return set() @@ -413,19 +458,21 @@ def enabled_function_ids(_assistant_id: int | None = None) -> set[int]: return ids -def enabled_guidance_ids(_assistant_id: int | None = None) -> set[int]: +def enabled_guidance_ids() -> set[int]: """Resolve enabled integrations' guidance titles → GuidanceManager ids. - Used to set ``GuidanceManager.filter_scope`` to a guidance-id predicate - that hides entries belonging to disabled integrations. Returns an empty - set when the registry is empty or the GuidanceManager isn't available.""" - enabled = _session_cache().get("enabled", {}) + Used to set ``GuidanceManager.filter_scope`` to a guidance-id + predicate that hides entries belonging to disabled integrations. + Returns an empty set when nothing is enabled or the GuidanceManager + isn't available. + """ + enabled = get_enabled_integrations() if not enabled: return set() titles: set[str] = set() - for row in enabled.values(): - titles.update(_row_guidance_titles(row)) + for pkg in enabled.values(): + titles.update(pkg.get("guidance_titles", [])) if not titles: return set() @@ -455,60 +502,73 @@ def enabled_guidance_ids(_assistant_id: int | None = None) -> set[int]: return ids -def build_guidance_filter_scope(_assistant_id: int | None = None) -> str | None: +def build_guidance_filter_scope() -> str | None: """Return a ``guidance_id in (...)`` filter scope, or ``None`` to disable. - Convention: an empty registry (no detection possible) or zero enabled - integrations means we don't gate retrieval — return ``None`` so existing - behaviour is preserved. Actively-enabled integrations with resolved - guidance ids produce a positive filter.""" - cache = _session_cache() - if not cache.get("registry"): + Convention: when there are no integration packages on disk at all, + return ``None`` so non-integration callers see existing behaviour. + When packages exist but none are enabled, return a never-matching + filter so disabled-integration guidance is hidden. When some are + enabled, return a positive filter naming their guidance ids. + """ + try: + from unity.integration_status.discovery import discover_available_packages + + packages = discover_available_packages() + except Exception: return None - if not cache.get("enabled"): - # Registry seeded but no integration enabled → hide all integration - # guidance. Use a filter that matches nothing. + + if not packages: + return None + + enabled = get_enabled_integrations() + if not enabled: return "guidance_id in ()" + ids = enabled_guidance_ids() if not ids: return "guidance_id in ()" return "guidance_id in (" + ", ".join(str(i) for i in sorted(ids)) + ")" -def enabled_summary_for_prompt(_assistant_id: int | None = None) -> str: +def enabled_summary_for_prompt() -> str: """Render the system-prompt status block. - Re-rendered per turn (cheap; reads only the session cache). Format:: + Re-rendered per turn (cheap; reads disk discovery + local keyset). + Returns an empty string when there are no integration packages on + disk. Format:: ### Integrations Active integrations: - HubSpot (fully_connected) - - Employment Hero (configured — OAuth Connect not complete; missing - EMPLOYMENTHERO_REFRESH_TOKEN, EMPLOYMENTHERO_ORGANISATION_ID. 
- Tell the user to click Connect in Settings → Integrations.) + - Employment Hero (configured — setup incomplete; missing optional + secrets: EMPLOYMENTHERO_REFRESH_TOKEN. Suggest the user complete + the Connect step in Settings → Integrations if applicable.) - Inactive (token not configured): + Inactive (credentials not configured): - Salesforce — needs SALESFORCE_CLIENT_ID, SALESFORCE_CLIENT_SECRET. - - Returns an empty string when there's nothing to say (no registry, or no - integrations defined for this deployment). """ - cache = _session_cache() - registry = cache.get("registry") or [] - if not registry: + try: + from unity.integration_status.discovery import discover_available_packages + + packages = discover_available_packages() + except Exception: return "" - enabled = cache.get("enabled", {}) - completeness = cache.get("completeness", {}) - keyset = cache.get("secret_names", set()) + if not packages: + return "" + + enabled = get_enabled_integrations() + completeness = get_setup_completeness() + keyset = _read_local_secret_keyset() active_lines: list[str] = [] inactive_lines: list[str] = [] - for row in sorted(registry, key=lambda r: r.get("slug", "")): - slug = row.get("slug") - label = row.get("label") or slug + for pkg in sorted(packages, key=lambda p: p.get("slug", "")): + slug = pkg.get("slug") + label = pkg.get("label") or slug if slug in enabled: comp = completeness.get(slug, {}) status = comp.get("status", "configured") @@ -525,7 +585,7 @@ def enabled_summary_for_prompt(_assistant_id: int | None = None) -> str: f"Connect step in Settings → Integrations if applicable.)", ) else: - required = _row_required(row) + required = set(pkg.get("required_secrets", [])) missing_required = sorted(required - keyset) missing_str = ( ", ".join(missing_required) if missing_required else "(see manifest)" @@ -554,294 +614,14 @@ def enabled_summary_for_prompt(_assistant_id: int | None = None) -> str: return "\n\n".join(parts) -def all_known_secret_names() -> set[str]: - """Return the union of every secret name (required + optional) declared - by any integration in the registry OR available on disk. - - Used by ``SecretManager._sync_assistant_secrets`` as a registry-derived - replacement for the hardcoded ``OAUTH_SECRET_ALLOWLIST``. Includes - discovery so that secrets for *available* (not-yet-loaded) packages are - allowlisted too — this is what makes a token paste in Console actually - sync from Orchestra without requiring a deployment-time declaration.""" - out: set[str] = set() - - # Persisted-registry rows (covers integrations the deployment declared). - registry = _load_registry() - for row in registry or []: - out |= _row_required(row) - out |= _row_optional(row) - - # Disk-discovered packages (covers available-but-not-declared ones). - try: - from unity.integration_status.discovery import discover_available_packages - - for pkg in discover_available_packages(): - out |= set(pkg.get("required_secrets", [])) - out |= set(pkg.get("optional_secrets", [])) - except Exception: - # Best-effort; unity_deploy not importable in the running env. - pass - - return out - - -# --------------------------------------------------------------------------- -# Hot-load: register an available package's functions + guidance + registry -# row into the running managers, in the background. -# --------------------------------------------------------------------------- - - -def schedule_hot_load(slug: str) -> None: - """Sync-safe entry point: register the package for ``slug`` in the - background. 
Returns immediately. - - Spawns a daemon thread that runs ``asyncio.run(hot_load_integration(slug))`` - so it works from any caller — sync or async, with or without a running - event loop (constructor paths in ``SecretManager`` run before the - assistant's main loop exists). - - Idempotent: a no-op when the slug is already loaded for this session - or already being loaded by an in-flight thread. Errors during the - background load are logged but never propagate to the caller.""" - import threading - - cache = _session_cache() - if slug in cache.setdefault("loaded_slugs", set()): - return - if slug in cache.setdefault("loading_slugs", set()): - return - - cache["loading_slugs"].add(slug) - - def _worker() -> None: - try: - import asyncio - - asyncio.run(hot_load_integration(slug)) - except Exception: - logger.exception("Hot-load worker failed for slug=%s", slug) - finally: - try: - cache["loading_slugs"].discard(slug) - except Exception: - pass - - thread = threading.Thread( - target=_worker, - daemon=True, - name=f"integration-hot-load-{slug}", - ) - thread.start() - - -async def hot_load_integration(slug: str) -> None: - """Load a package's functions + guidance + registry row. - - Per-step idempotent: re-running for the same slug is a no-op for any - sub-step that's already complete. Each step runs in its own try/except - so a partial failure (e.g. one bad function file) doesn't abort the rest. - """ - cache = _session_cache() - - if slug in cache.get("loaded_slugs", set()): - return - - try: - from unity.integration_status.discovery import get_package_for_slug - except Exception: - return - - pkg = get_package_for_slug(slug) - if pkg is None: - logger.warning("hot_load_integration: package %r not on disk", slug) - return - - try: - _hot_load_guidance(pkg) - except Exception: - logger.exception("hot_load_integration: guidance step failed for %s", slug) - - try: - functions_added = _hot_load_functions(pkg) - except Exception: - logger.exception("hot_load_integration: functions step failed for %s", slug) - functions_added = 0 - - try: - _hot_load_registry_row(pkg) - except Exception: - logger.exception("hot_load_integration: registry-row step failed for %s", slug) - - cache.setdefault("loaded_slugs", set()).add(slug) - - # Invalidate the registry cache so subsequent ``_load_registry`` calls - # see the persisted row we just wrote (and stop returning the synthesized - # variant). - cache["registry_loaded"] = False - cache["registry"] = [] - - logger.info( - "Hot-loaded integration slug=%s functions_added=%d guidance=%d", - slug, - functions_added, - len(pkg.get("guidance_titles", [])), - ) - - -def _hot_load_functions(pkg: dict) -> int: - """Add a package's @custom_function callables to FunctionManager. - - Uses per-name insert/update (no orphan-delete pass) so we don't disturb - functions registered by the deployment's own ``function_dirs``. Returns - the number of functions actually added or updated. - """ - function_dir = pkg.get("function_dir") - if function_dir is None: - return 0 - - from unity.function_manager.custom_functions import collect_custom_functions - from unity.manager_registry import ManagerRegistry - - source_fns = collect_custom_functions(directory=function_dir) - if not source_fns: - return 0 - - fm = ManagerRegistry.get_function_manager() - db_fns = fm._get_custom_functions_from_db() - - changed = 0 - for name, source_data in source_fns.items(): - try: - # ``custom_functions`` retains ``venv_name`` for the deploy-time - # sync path that resolves it via the venv catalog. 
At hot-load - # time we don't manage venvs (none of the in-tree integration - # packages declare one); strip the key so the FM insert API - # doesn't choke on it. - source_data = {k: v for k, v in source_data.items() if k != "venv_name"} - if name in db_fns: - if db_fns[name].get("custom_hash") != source_data.get("custom_hash"): - fm._update_custom_function( - function_id=db_fns[name]["function_id"], - data=source_data, - ) - changed += 1 - else: - fm._insert_custom_function(source_data) - changed += 1 - except Exception: - logger.exception("Failed to register function %s", name) - return changed - - -def _hot_load_guidance(pkg: dict) -> int: - """Add a package's guidance markdown entries to GuidanceManager. - - Per-title check + insert. Doesn't update an existing entry's content - even if the markdown on disk has changed — that's a deploy-time concern - (``_sync_guidance`` handles drift via SeedMetaStore). Returns the - number of new entries added. - """ - guidance_dir = pkg.get("guidance_dir") - if guidance_dir is None: - return 0 - - try: - from unity_deploy.assistant_deployments.integrations.loader import ( - _load_guidance, - ) - except Exception: - return 0 - - from unity.manager_registry import ManagerRegistry - - entries = _load_guidance(guidance_dir) - if not entries: - return 0 - - gm = ManagerRegistry.get_guidance_manager() - added = 0 - for entry in entries: - try: - existing = gm.filter(filter=f"title == {entry.title!r}", limit=1) - if existing: - continue - gm.add_guidance(title=entry.title, content=entry.content) - added += 1 - except Exception: - logger.exception("Failed to register guidance %r", entry.title) - return added - - -def _hot_load_registry_row(pkg: dict) -> None: - """Persist the integration's registry row into ``Integrations/Manifests``. - - Idempotent on slug: if a row with this slug already exists, update its - fields; otherwise insert. Mirrors what unity-deploy's - ``_sync_integration_registry`` would write at deploy time. 
- """ - try: - import unify - except Exception: - return - - try: - active = unify.get_active_context()["read"] - ctx = f"{active}/{_REGISTRY_CONTEXT_LEAF}" - try: - unify.create_context(ctx) - except Exception: - pass - except Exception: - return - - row = { - "slug": pkg["slug"], - "label": pkg["label"], - "category": pkg.get("category", ""), - "version": pkg.get("version", ""), - "tier": pkg.get("tier", ""), - "required_secrets_json": json.dumps(pkg.get("required_secrets", [])), - "optional_secrets_json": json.dumps(pkg.get("optional_secrets", [])), - "function_names_json": json.dumps(pkg.get("function_names", [])), - "guidance_titles_json": json.dumps(pkg.get("guidance_titles", [])), - } - - try: - existing = unify.get_logs( - context=ctx, - filter=f"slug == {pkg['slug']!r}", - limit=1, - ) - except Exception: - existing = [] - - if existing: - try: - unify.update_logs( - logs=[existing[0].id], - context=ctx, - entries=[row], - overwrite=True, - ) - except Exception: - logger.exception("Failed to update registry row for %s", pkg["slug"]) - else: - try: - unify.log(context=ctx, **row) - except Exception: - logger.exception("Failed to insert registry row for %s", pkg["slug"]) - - __all__ = [ - "all_known_secret_names", "build_guidance_filter_scope", "enabled_function_ids", "enabled_guidance_ids", "enabled_summary_for_prompt", "get_enabled_integrations", "get_setup_completeness", - "hot_load_integration", - "recompute_enablement", + "register_available_integrations", "reset_session_cache", - "schedule_hot_load", + "schedule_register_available_integrations", ] diff --git a/unity/integration_status/discovery.py b/unity/integration_status/discovery.py index 0dd757753..fc77565b2 100644 --- a/unity/integration_status/discovery.py +++ b/unity/integration_status/discovery.py @@ -1,24 +1,22 @@ -"""Runtime-side discovery of integration packages installed on disk. +"""Runtime discovery of integration packages installed on disk. -The deploy side (``unity_deploy.assistant_deployments.integrations``) seeds -*registered* integrations (those declared in a deployment's -``integrations=[...]`` list) into the ``Integrations/Manifests`` DataManager -context. But assistants can have credentials for *available* packages that -weren't declared on their deployment — that's the gap this module bridges. +This module enumerates every manifest under unity_deploy's package roots +and projects each one into a lightweight metadata record. It is the +**single source of truth** for "what integration packages exist" at +runtime — :mod:`unity.integration_status` reads from here, never from the +persisted ``Integrations/Manifests`` DataManager context. -At runtime we lazily enumerate every manifest under unity_deploy's package -roots so we can: - -* Allowlist secrets for *any* installed package (not just declared ones), so - ``SecretManager._sync_assistant_secrets`` pulls them from Orchestra. -* Detect when a token for an available-but-not-declared package gets pasted - into the assistant's secrets, and drive the hot-load that registers the - package's functions + guidance into the running managers. +(Historical note: the deploy side seeds rows into ``Integrations/Manifests`` +for telemetry, and the early-May runtime read from there too. That dual +source produced cascading bugs — registered-vs-non-registered asymmetry, +stale registry rows masking new disk packages — so the read side was +collapsed to disk-only. Deploy still writes the rows; runtime ignores +them.) 
Everything is best-effort: if unity_deploy isn't importable on the runtime -container (the only environment where this matters in practice is a unity-only -test env), every helper here returns an empty result and callers fall back to -pre-C behaviour. +container (the only environment where this matters in practice is a +unity-only test env), every helper here returns an empty result and +callers fall through to existing behaviour. """ from __future__ import annotations diff --git a/unity/secret_manager/secret_manager.py b/unity/secret_manager/secret_manager.py index 3702f86db..3269fd875 100644 --- a/unity/secret_manager/secret_manager.py +++ b/unity/secret_manager/secret_manager.py @@ -184,12 +184,16 @@ def _default_update_tool_policy( # --------------------- Internal helpers (assistant secret sync) ---------- # - # Built-in OAuth allowlist for Google + Microsoft. These are always synced - # because Communication writes them back to Orchestra independently of the - # integration registry (the OAuth callback runs before any deploy-time - # registry seeding has had a chance to land). At runtime, we extend this - # set with everything declared by the integration registry — see - # ``_resolve_secret_allowlist``. + # Allowlist for ``_sync_assistant_secrets``. Limited to OAuth tokens that + # Communication writes to Orchestra's ``AssistantSecret`` table after each + # Google / Microsoft OAuth callback — those are the only secrets THIS sync + # is responsible for transporting. Console-pasted integration secrets + # (HubSpot, Employment Hero, Matterport, Webex, Salesforce …) live in the + # ``/Secrets`` context directly and reach ``os.environ`` via + # ``_sync_dotenv``. They neither need nor go through this sync; mixing + # them in here causes the cleanup loop to wipe them, which is the bug + # ``61141bba2`` patched. Concern separation enforced explicitly: keep + # this allowlist OAuth-only so the bug class can't reappear. _BUILTIN_OAUTH_SECRET_ALLOWLIST = frozenset( { "GOOGLE_ACCESS_TOKEN", @@ -203,44 +207,25 @@ def _default_update_tool_policy( }, ) - # Backwards-compatible alias. Existing call sites + tests that read - # ``OAUTH_SECRET_ALLOWLIST`` keep working; the value now reflects the - # built-in set only and is augmented dynamically via - # ``_resolve_secret_allowlist`` at sync time. + # Backwards-compatible alias for the (small number of) call sites and + # tests that read ``OAUTH_SECRET_ALLOWLIST`` directly. Identical to + # the built-in set above. OAUTH_SECRET_ALLOWLIST = _BUILTIN_OAUTH_SECRET_ALLOWLIST - @classmethod - def _resolve_secret_allowlist(cls) -> frozenset[str]: - """Union the built-in OAuth allowlist with secrets declared by every - installed integration — both the seeded registry and packages - discovered on disk. - - Including disk discovery is what lets a token paste in Console for - an integration the deployment didn't declare actually sync from - Orchestra. Without it the token would be filtered out of - ``_sync_assistant_secrets`` and never reach the assistant's Secrets + def _sync_assistant_secrets(self) -> None: + """Pull Google / Microsoft OAuth tokens from Orchestra's + ``AssistantSecret`` table into the assistant's local ``Secrets`` context. - Falls back to the built-in set on any failure (unity_deploy not - importable, registry context unreadable, etc.). Adding a new - integration package therefore requires no edit to this file. 
- """ - try: - from unity.integration_status import all_known_secret_names - - return frozenset( - cls._BUILTIN_OAUTH_SECRET_ALLOWLIST | all_known_secret_names(), - ) - except Exception: - return cls._BUILTIN_OAUTH_SECRET_ALLOWLIST + Communication writes those tokens via REST (``/assistant/{id}/secret``) + from the OAuth callback. This sync mirrors them locally so the + Actor can use them in code-first plans, and writes them to + ``os.environ`` via ``_env_set`` so subprocesses see them too. - def _sync_assistant_secrets(self) -> None: - """Pull OAuth tokens from Orchestra's assistant secrets into the Secrets context. - - Communication stores OAuth tokens (Google/Microsoft) as assistant-level - secrets in Orchestra after each OAuth callback. This method reads those - secrets via the admin API and upserts them into the ``Secrets`` context - so the Actor can discover and use them in code-first plans. + **Scope is intentionally narrow.** Console-pasted integration + secrets live in the ``/Secrets`` context directly; they reach env + via :meth:`_sync_dotenv`. This method does not know or care about + them — see ``_BUILTIN_OAUTH_SECRET_ALLOWLIST``. Best-effort: failures are logged and silently swallowed. """ @@ -274,11 +259,10 @@ def _sync_assistant_secrets(self) -> None: except Exception: return - # Resolve the active allowlist by unioning the built-in OAuth keys - # (Google + Microsoft) with everything declared by the integration - # registry. Adding a new integration package adds its secret names - # here automatically, no edit to this module. - active_allowlist = self._resolve_secret_allowlist() + # Allowlist is intentionally OAuth-only — see the + # ``_BUILTIN_OAUTH_SECRET_ALLOWLIST`` docstring above for why + # integration secrets do NOT flow through this sync. + active_allowlist = self._BUILTIN_OAUTH_SECRET_ALLOWLIST written = 0 for name, value in secrets_dict.items(): @@ -323,15 +307,16 @@ def _sync_assistant_secrets(self) -> None: written, ) - # Stale-cleanup: only the built-in Google/Microsoft OAuth keys are - # owned by THIS sync, so only those may be deleted when missing - # from Orchestra's response. Integration-managed keys (e.g. - # EMPLOYMENTHERO_REFRESH_TOKEN) and customer-pasted keys (CLIENT_ID, - # CLIENT_SECRET, API tokens) are written into the same Secrets - # context by Console's OAuth callback / paste flow and must NOT - # be cleaned up here just because Orchestra's secrets payload - # omits them — that would silently wipe valid user state every - # time the admin endpoint returned a partial or stripped response. + # Stale-cleanup: only the built-in Google / Microsoft OAuth keys + # are owned by THIS sync, so only those may be deleted when + # missing from Orchestra's response. Console-pasted secrets + # (HubSpot, Matterport, etc.) and OAuth-managed integration + # tokens (EMPLOYMENTHERO_REFRESH_TOKEN, etc.) live in the same + # local Secrets context but are NOT this sync's responsibility, + # so they must not be cleaned up here just because Orchestra's + # secrets payload omits them — that would silently wipe valid + # user state every time the admin endpoint returned a partial + # or stripped response. for stale_name in self._BUILTIN_OAUTH_SECRET_ALLOWLIST - secrets_dict.keys(): try: ids = unify.get_logs( @@ -346,24 +331,6 @@ def _sync_assistant_secrets(self) -> None: except Exception: continue - # Recompute integration enablement. 
``recompute_enablement`` reads - # the keyset from the local Secrets context (the source of truth for - # both Orchestra-mirrored OAuth tokens — just written above — and - # Console-pasted integration tokens that never touch Orchestra). - # We still pass ``secrets_dict`` as a supplemental union so that any - # Google/MS keys present in the Orchestra payload but not yet - # mirrored (race window between fetch and write) are not missed. - # Best-effort; never blocks the secret-sync return. - try: - from unity.integration_status import recompute_enablement - - recompute_enablement( - assistant_id=int(agent_id), - secrets=secrets_dict, - ) - except Exception: - pass - # --------------------- Internal helpers (.env sync) --------------------- # def _dotenv_path(self) -> str: """Return the path to the .env file used for local sync. From 243b136d654fa41dd9824d772220051ffc6289e0 Mon Sep 17 00:00:00 2001 From: juliagsy <67888047+juliagsy@users.noreply.github.com> Date: Fri, 8 May 2026 11:49:11 +0800 Subject: [PATCH 05/17] added a fix to avoid non-integrated and non-registered assistants to pick up package's custom guidance/etc --- unity/integration_status/__init__.py | 65 ++++++++++++++++++++-------- 1 file changed, 48 insertions(+), 17 deletions(-) diff --git a/unity/integration_status/__init__.py b/unity/integration_status/__init__.py index c6f9f3125..3ca3f70f1 100644 --- a/unity/integration_status/__init__.py +++ b/unity/integration_status/__init__.py @@ -128,22 +128,36 @@ def _read_local_secret_keyset() -> set[str]: def register_available_integrations() -> None: - """Walk disk packages and register each one's functions + guidance with - the runtime managers. - - **Synchronous and idempotent.** In production, callers schedule this - as a background task via :func:`schedule_register_available_integrations` - so the fast-brain conversation loop can come online without waiting - for the (potentially many-hundreds-of-ms) function/guidance inserts - to finish. Direct synchronous use is fine for tests and CLI tools. - - Replaces the May-2026 per-slug daemon-thread hot-load mechanism. - Adding a new package to disk now requires a session restart — which - matches how every other deployment artifact (manifests, scenarios, - guidance) behaves. No cross-thread context-state races (single - thread when backgrounded), no recursive manager construction (no - re-entry into ``SecretManager``), no token-paste-triggered side - effects. + """Walk disk packages and register each ENABLED one's functions + + guidance with the runtime managers. + + **Gated by enablement.** Only packages whose required secrets are + present in the local ``/Secrets`` keyset are registered — this + prevents every package on disk from polluting FunctionManager / + GuidanceManager for assistants that never opted into them (e.g. an + assistant whose deployment declares only HubSpot shouldn't pick up + Matterport, Webex, etc. tools just because their packages happen to + be on disk). Packages with no declared required secrets are + always-on and are registered unconditionally. + + Deployment-declared packages whose secrets aren't pasted yet are + still loaded by the deploy seed via ``_sync_functions`` / + ``_sync_guidance``; this register pass is idempotent over those. 
+ + **Synchronous and idempotent.** In production, callers schedule + this as a background task via + :func:`schedule_register_available_integrations` so the fast-brain + conversation loop can come online without waiting for the + (potentially many-hundreds-of-ms) inserts to finish. Direct + synchronous use is fine for tests and CLI tools. + + Mid-session token paste does **not** auto-register the integration + today — adding a secret after startup means the package's functions + won't appear until the next session. The pre-May ``schedule_hot_load`` + mechanism handled this lazily but at the cost of a daemon-thread + bug class we removed; lazy mid-session registration can come back + as a separate, single-thread, debounced helper if the UX gap shows + up in practice. """ cache = _session_cache() @@ -161,9 +175,15 @@ def register_available_integrations() -> None: logger.info("[integrations] register: no packages discovered on disk") return + # Gate registration on the local secret keyset. Computed once per + # call rather than per-package so we don't re-hit the SecretManager + # context for each disk package. + keyset = _read_local_secret_keyset() + total_funcs = 0 total_guidance = 0 registered_now: list[str] = [] + skipped_no_secrets: list[str] = [] for pkg in packages: slug = pkg.get("slug") or "" @@ -173,6 +193,15 @@ def register_available_integrations() -> None: if slug in already_registered: continue + required = set(pkg.get("required_secrets", [])) + # A package with required_secrets is registered only when the + # user has configured every one of them. Packages with NO + # required secrets (always-on / read-only) are registered + # unconditionally. + if required and not required.issubset(keyset): + skipped_no_secrets.append(slug) + continue + try: total_funcs += _register_functions(pkg) except Exception: @@ -193,9 +222,11 @@ def register_available_integrations() -> None: registered_now.append(slug) logger.info( - "[integrations] register: packages=%d new=%s functions=%d guidance=%d", + "[integrations] register: discovered=%d registered=%s " + "skipped_no_secrets=%s functions=%d guidance=%d", len(packages), sorted(registered_now), + sorted(skipped_no_secrets), total_funcs, total_guidance, ) From 0f87d9dd03967c48ea15aeb2ac7b5cd899bcc8e9 Mon Sep 17 00:00:00 2001 From: nassimberrada <112006029+nassimberrada@users.noreply.github.com> Date: Fri, 8 May 2026 16:31:19 +0000 Subject: [PATCH 06/17] feat(billing): skip credit-balance gate for METERED accounts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The spending-limit guard's credits ≤ 0 → block rule is correct only for CREDITS-mode accounts. METERED accounts pay by monthly invoice (orchestra's monthly_metered_invoicer) and intentionally hold a zero wallet — records ledger-only on METERED — so the legacy gate would block every call once the first usage event lands. now parses the field from the orchestra spend endpoint, propagates it through , and skips the credit-balance check when . CREDITS accounts and the 'field not present' legacy path keep the existing behaviour, so the guard never loosens during a partial orchestra rollout. Test coverage in — asserts allow on METERED with zero balance, block on CREDITS with zero balance, and the legacy fallback when orchestra hasn't surfaced yet. 
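
For reviewers, the decision condenses to roughly the sketch below. It is
illustrative only: the real check lives in check_spending_limits_callback and
returns a LimitCheckResponse rather than a bool, but the branch order is the
same.

    def credit_gate_blocks(
        billing_mode: str | None,
        credit_balance: float | None,
    ) -> bool:
        # METERED accounts are invoiced monthly and legitimately sit at $0,
        # so only CREDITS-mode (or legacy, mode-less) responses are gated
        # on wallet balance.
        if billing_mode == "METERED":
            return False
        return credit_balance is not None and credit_balance <= 0

    assert credit_gate_blocks("METERED", 0.0) is False
    assert credit_gate_blocks("CREDITS", 0.0) is True
    assert credit_gate_blocks(None, 0.0) is True  # pre-rollout payloads keep the old rule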
--- tests/event_bus/test_credit_balance_guard.py | 160 +++++++++++++++++++ unity/spending_limits.py | 29 +++- 2 files changed, 187 insertions(+), 2 deletions(-) diff --git a/tests/event_bus/test_credit_balance_guard.py b/tests/event_bus/test_credit_balance_guard.py index 0ba878396..4f0c4a049 100644 --- a/tests/event_bus/test_credit_balance_guard.py +++ b/tests/event_bus/test_credit_balance_guard.py @@ -878,3 +878,163 @@ async def test_negative_from_overdraft_blocks(self): assert result.allowed is False assert "insufficient credits" in result.reason.lower() + + +# --------------------------------------------------------------------------- +# METERED-mode bypass (managed-billing) +# +# METERED accounts pay by monthly invoice via Orchestra's +# ``monthly_metered_invoicer`` rather than via a pre-paid wallet. +# Their wallet balance intentionally stays at $0 (``deduct_credits`` +# does not mutate it on METERED), so the legacy credit-balance gate +# would block every +# LLM call. The callback must skip the gate when the spend response +# carries ``billing_mode == "METERED"``. +# --------------------------------------------------------------------------- + + +def _metered_response( + *, + cumulative_spend: float = 0.0, + limit: float | None = None, + credit_balance: float | None = 0.0, +) -> dict: + """Build a spend response that mirrors a METERED account.""" + data: dict = { + "cumulative_spend": cumulative_spend, + "limit": limit, + "billing_mode": "METERED", + } + if credit_balance is not None: + data["credit_balance"] = credit_balance + return data + + +class TestMeteredBypassesCreditGate: + """METERED accounts have credit_balance=0 by design and must not be gated.""" + + @pytest.mark.asyncio + async def test_metered_zero_balance_allows(self): + """METERED + balance=$0 must allow (the canonical case).""" + from unity.spending_limits import check_spending_limits_callback + + resp = _metered_response(credit_balance=0.0) + with _patch_context(): + with _patch_spend_client(resp): + result = await check_spending_limits_callback( + LimitCheckRequest(model="gpt-4", endpoint="test"), + ) + + assert result.allowed is True + + @pytest.mark.asyncio + async def test_metered_negative_balance_allows(self): + """METERED + small negative balance (rounding/in-flight) must allow.""" + from unity.spending_limits import check_spending_limits_callback + + resp = _metered_response(credit_balance=-0.42) + with _patch_context(): + with _patch_spend_client(resp): + result = await check_spending_limits_callback( + LimitCheckRequest(model="gpt-4", endpoint="test"), + ) + + assert result.allowed is True + + @pytest.mark.asyncio + async def test_metered_org_zero_balance_allows(self): + """Same in org context: METERED org account isn't gated on balance.""" + from unity.spending_limits import check_spending_limits_callback + + resp = _metered_response(credit_balance=0.0) + with _patch_context(org_id=789): + with _patch_spend_client(resp): + result = await check_spending_limits_callback( + LimitCheckRequest(model="gpt-4", endpoint="test"), + ) + + assert result.allowed is True + + @pytest.mark.asyncio + async def test_metered_still_enforces_spending_cap(self): + """METERED accounts are still subject to spending caps.""" + from unity.spending_limits import check_spending_limits_callback + + resp = _metered_response( + cumulative_spend=200.0, + limit=100.0, + credit_balance=0.0, + ) + with _patch_context(): + with _patch_spend_client(resp): + result = await check_spending_limits_callback( + LimitCheckRequest(model="gpt-4", 
endpoint="test"), + ) + + assert result.allowed is False + assert "spending limit exceeded" in result.reason.lower() + + @pytest.mark.asyncio + async def test_credits_mode_explicit_still_gates(self): + """billing_mode=CREDITS keeps the legacy gate active.""" + from unity.spending_limits import check_spending_limits_callback + + data = { + "cumulative_spend": 0.0, + "limit": None, + "credit_balance": 0.0, + "billing_mode": "CREDITS", + } + with _patch_context(): + with _patch_spend_client(data): + result = await check_spending_limits_callback( + LimitCheckRequest(model="gpt-4", endpoint="test"), + ) + + assert result.allowed is False + assert "insufficient credits" in result.reason.lower() + + @pytest.mark.asyncio + async def test_missing_billing_mode_defaults_to_legacy_gate(self): + """Old Orchestra builds without billing_mode keep CREDITS-mode behaviour. + + Backward compat: a partial Orchestra rollout shouldn't loosen + the gate. ``billing_mode`` only bypasses the check when the + endpoint explicitly says ``"METERED"``. + """ + from unity.spending_limits import check_spending_limits_callback + + # No "billing_mode" key at all → should still gate at $0 + resp = _make_spend_response(credit_balance=0.0) + assert "billing_mode" not in resp + with _patch_context(): + with _patch_spend_client(resp): + result = await check_spending_limits_callback( + LimitCheckRequest(model="gpt-4", endpoint="test"), + ) + + assert result.allowed is False + + @pytest.mark.asyncio + async def test_metered_resolution_when_only_one_endpoint_returns_it(self): + """billing_mode discovered on any endpoint suffices to bypass. + + Mixed responses (assistant endpoint pre-rollout, user endpoint + post-rollout) shouldn't gate the call: the first non-None + billing_mode wins, matching how credit_balance is resolved. + """ + from unity.spending_limits import check_spending_limits_callback + + legacy = {"cumulative_spend": 0.0, "limit": None, "credit_balance": 0.0} + modern = _metered_response(credit_balance=0.0) + methods = { + "get_assistant_spend": legacy, + "get_user_spend": modern, + } + with _patch_context(): + with _patch_spend_client(methods): + result = await check_spending_limits_callback( + LimitCheckRequest(model="gpt-4", endpoint="test"), + ) + + assert result.allowed is True diff --git a/unity/spending_limits.py b/unity/spending_limits.py index 72c8ed3e1..2edbc35bf 100644 --- a/unity/spending_limits.py +++ b/unity/spending_limits.py @@ -71,6 +71,12 @@ class _LimitCheckResult: limit_set_at: Optional[str] = None # ISO format timestamp organization_id: Optional[int] = None # For member limits credit_balance: Optional[float] = None # Billing account credit balance + # Billing mode of the underlying account: "CREDITS" (pre-paid wallet, + # subject to the credit_balance gate) or "METERED" (invoiced + # monthly, gate must be skipped). Defaults to None when the spend + # endpoint didn't surface the field — older Orchestra builds — in + # which case we fall back to the legacy CREDITS-mode behaviour. 
+ billing_mode: Optional[str] = None def _get_current_month(timezone: str = "UTC") -> str: @@ -95,9 +101,14 @@ def _parse_spend_result( spend = data.get("cumulative_spend", 0) limit_set_at = data.get("limit_set_at") credit_balance = data.get("credit_balance") + billing_mode = data.get("billing_mode") if limit is None: - return _LimitCheckResult(exceeded=False, credit_balance=credit_balance) + return _LimitCheckResult( + exceeded=False, + credit_balance=credit_balance, + billing_mode=billing_mode, + ) return _LimitCheckResult( exceeded=spend >= limit, @@ -109,6 +120,7 @@ def _parse_spend_result( limit_set_at=limit_set_at, organization_id=organization_id, credit_balance=credit_balance, + billing_mode=billing_mode, ) @@ -325,6 +337,7 @@ def _to_limit_type(type_str: Optional[str]) -> Optional[LimitType]: return None credit_balance: Optional[float] = None + billing_mode: Optional[str] = None for result in results: if isinstance(result, Exception): @@ -333,6 +346,8 @@ def _to_limit_type(type_str: Optional[str]) -> Optional[LimitType]: if credit_balance is None and result.credit_balance is not None: credit_balance = result.credit_balance + if billing_mode is None and result.billing_mode is not None: + billing_mode = result.billing_mode if result.exceeded: current = ( @@ -355,7 +370,17 @@ def _to_limit_type(type_str: Optional[str]) -> Optional[LimitType]: entity_name=result.entity_name, ) - if credit_balance is not None and credit_balance <= 0: + # Credit-balance gate. METERED accounts pay by monthly invoice via + # ``monthly_metered_invoicer`` and intentionally have a zero wallet + # balance (``deduct_credits`` doesn't mutate it on METERED), so the + # legacy gate would block every call. Skip it for METERED, keep it + # for CREDITS (and for the no-billing-mode-yet legacy case so we + # don't loosen the gate during a partial Orchestra rollout). + if ( + billing_mode != "METERED" + and credit_balance is not None + and credit_balance <= 0 + ): return LimitCheckResponse( allowed=False, reason=( From f61a466cd123f3f7afa5c0a565a6a1c0b3a2369d Mon Sep 17 00:00:00 2001 From: Haris Mahmood Date: Fri, 8 May 2026 16:00:48 +0500 Subject: [PATCH 07/17] feat(oauth): add runtime access token helper Introduce a runtime-owned OAuth helper for refresh-token backed providers instead of putting provider-specific token semantics on SecretManager. The helper owns provider metadata, aliases, expiry checks, env overlay construction, and the actor-facing get_oauth_access_token(...) documentation surface. This gives generated Python a clear way to request an explicit provider-scoped access token when an SDK or HTTP client requires one, while preserving the normal environment-based credential path for SDKs that can read credentials directly. 
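
Illustrative actor-side usage (the HTTP client and the Graph endpoint are
examples for this message only, not part of the change; any SDK or client
that needs an explicit bearer token follows the same shape):

    import requests

    from unity.common.runtime_oauth import get_oauth_access_token

    token = get_oauth_access_token("microsoft")  # aliases such as "graph" also resolve
    resp = requests.get(
        "https://graph.microsoft.com/v1.0/me",
        headers={"Authorization": f"Bearer {token}"},
        timeout=30,
    )
    resp.raise_for_status()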
--- tests/common/test_runtime_oauth.py | 138 ++++++++++++++++ unity/common/runtime_oauth.py | 250 +++++++++++++++++++++++++++++ 2 files changed, 388 insertions(+) create mode 100644 tests/common/test_runtime_oauth.py create mode 100644 unity/common/runtime_oauth.py diff --git a/tests/common/test_runtime_oauth.py b/tests/common/test_runtime_oauth.py new file mode 100644 index 000000000..998452275 --- /dev/null +++ b/tests/common/test_runtime_oauth.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from typing import Any + +import pytest + +from unity.common import runtime_oauth + + +def _future_expiry() -> str: + return (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat() + + +def _past_expiry() -> str: + return (datetime.now(timezone.utc) - timedelta(minutes=5)).isoformat() + + +class _FakeSecretManager: + def __init__(self, secrets: dict[str, str]) -> None: + self.secrets = secrets + self.sync_calls: list[dict[str, Any]] = [] + self.on_sync = None + + def _get_secret_value(self, name: str) -> str | None: + return self.secrets.get(name) + + def sync_assistant_secrets_if_stale(self, **kwargs: Any) -> bool: + self.sync_calls.append(kwargs) + if self.on_sync is not None: + self.on_sync() + return True + + +def _install_secret_manager(monkeypatch, sm: _FakeSecretManager) -> None: + monkeypatch.setattr(runtime_oauth, "_get_secret_manager", lambda: sm) + + +def test_get_oauth_access_token_supports_provider_alias(monkeypatch): + sm = _FakeSecretManager( + { + "MICROSOFT_ACCESS_TOKEN": "fresh-ms-token", + "MICROSOFT_TOKEN_EXPIRES_AT": _future_expiry(), + }, + ) + _install_secret_manager(monkeypatch, sm) + + assert runtime_oauth.get_oauth_access_token("graph") == "fresh-ms-token" + assert sm.sync_calls[-1]["force"] is False + + +def test_get_oauth_access_token_unknown_provider_is_non_secret_error(monkeypatch): + sm = _FakeSecretManager({}) + _install_secret_manager(monkeypatch, sm) + + with pytest.raises(ValueError, match="Unknown refresh-token OAuth provider"): + runtime_oauth.get_oauth_access_token("not-a-real-provider") + + +def test_get_oauth_access_token_missing_token_raises_after_sync(monkeypatch): + sm = _FakeSecretManager({"MICROSOFT_TOKEN_EXPIRES_AT": _future_expiry()}) + _install_secret_manager(monkeypatch, sm) + + with pytest.raises(ValueError, match="No access token is available"): + runtime_oauth.get_oauth_access_token("microsoft") + + assert sm.sync_calls[-1]["force"] is True + + +def test_get_oauth_access_token_forces_sync_when_token_is_expired(monkeypatch): + sm = _FakeSecretManager( + { + "MICROSOFT_ACCESS_TOKEN": "old-ms-token", + "MICROSOFT_TOKEN_EXPIRES_AT": _past_expiry(), + }, + ) + _install_secret_manager(monkeypatch, sm) + + def refresh_token() -> None: + sm.secrets["MICROSOFT_ACCESS_TOKEN"] = "fresh-ms-token" + sm.secrets["MICROSOFT_TOKEN_EXPIRES_AT"] = _future_expiry() + + sm.on_sync = refresh_token + + assert runtime_oauth.get_oauth_access_token("microsoft") == "fresh-ms-token" + assert sm.sync_calls[-1]["force"] is True + + +def test_get_oauth_access_token_forces_sync_when_expiry_is_invalid(monkeypatch): + sm = _FakeSecretManager( + { + "GOOGLE_ACCESS_TOKEN": "old-google-token", + "GOOGLE_TOKEN_EXPIRES_AT": "not-a-date", + }, + ) + _install_secret_manager(monkeypatch, sm) + + def refresh_token() -> None: + sm.secrets["GOOGLE_ACCESS_TOKEN"] = "fresh-google-token" + sm.secrets["GOOGLE_TOKEN_EXPIRES_AT"] = _future_expiry() + + sm.on_sync = refresh_token + + assert 
runtime_oauth.get_oauth_access_token("google") == "fresh-google-token" + assert sm.sync_calls[-1]["force"] is True + + +def test_get_oauth_access_token_supports_multiple_providers(monkeypatch): + sm = _FakeSecretManager( + { + "MICROSOFT_ACCESS_TOKEN": "fresh-ms-token", + "MICROSOFT_TOKEN_EXPIRES_AT": _future_expiry(), + "GOOGLE_ACCESS_TOKEN": "fresh-google-token", + "GOOGLE_TOKEN_EXPIRES_AT": _future_expiry(), + }, + ) + _install_secret_manager(monkeypatch, sm) + + assert runtime_oauth.get_oauth_access_token("microsoft") == "fresh-ms-token" + assert runtime_oauth.get_oauth_access_token("google") == "fresh-google-token" + + +def test_get_refresh_token_oauth_env_overlay_returns_all_current_values(monkeypatch): + sm = _FakeSecretManager( + { + "MICROSOFT_ACCESS_TOKEN": "fresh-ms-token", + "MICROSOFT_TOKEN_EXPIRES_AT": _future_expiry(), + "GOOGLE_ACCESS_TOKEN": "fresh-google-token", + "GOOGLE_TOKEN_EXPIRES_AT": _future_expiry(), + }, + ) + _install_secret_manager(monkeypatch, sm) + + overlay = runtime_oauth.get_refresh_token_oauth_env_overlay() + + assert overlay["MICROSOFT_ACCESS_TOKEN"] == "fresh-ms-token" + assert overlay["GOOGLE_ACCESS_TOKEN"] == "fresh-google-token" + assert sm.sync_calls[-1]["reason"] == "oauth_env_overlay" diff --git a/unity/common/runtime_oauth.py b/unity/common/runtime_oauth.py new file mode 100644 index 000000000..24b70ae57 --- /dev/null +++ b/unity/common/runtime_oauth.py @@ -0,0 +1,250 @@ +"""Runtime helpers for refresh-token backed OAuth credentials. + +SecretManager owns storage and synchronization: it mirrors allowlisted assistant +secrets from Orchestra into the local ``Secrets`` context, ``.env``, and +``os.environ``. This module owns the runtime interpretation of those mirrored +values: provider aliases, access-token/expiry secret names, freshness checks, +and the sandbox helper exposed to actor-written Python. + +The split is deliberate. ``get_oauth_access_token(...)`` is not a +``primitives.secrets`` tool and does not expose arbitrary secrets; it behaves +like ``reason(...)``/``notify(...)`` as a Python runtime helper for code paths +that must pass an explicit OAuth access token to an SDK/client/request. Code +that can rely on provider SDK/default environment credential behavior should +continue to do so; env overlays below keep rotating OAuth env vars fresh for +venv and shell backends. +""" + +import inspect +import os +from dataclasses import dataclass +from datetime import datetime, timezone + + +@dataclass(frozen=True) +class OAuthProviderMetadata: + """Runtime metadata for a refresh-token backed OAuth provider.""" + + canonical_name: str + aliases: tuple[str, ...] 
+ access_token_secret: str + refresh_token_secret: str | None = None + expiry_secret: str | None = None + granted_scopes_secret: str | None = None + docs_label: str = "" + + @property + def secret_names(self) -> frozenset[str]: + return frozenset( + name + for name in ( + self.access_token_secret, + self.refresh_token_secret, + self.expiry_secret, + self.granted_scopes_secret, + ) + if name + ) + + +_OAUTH_PROVIDER_METADATA: dict[str, OAuthProviderMetadata] = { + "google": OAuthProviderMetadata( + canonical_name="google", + aliases=("google", "gmail", "google_workspace", "drive"), + access_token_secret="GOOGLE_ACCESS_TOKEN", + refresh_token_secret="GOOGLE_REFRESH_TOKEN", + expiry_secret="GOOGLE_TOKEN_EXPIRES_AT", + granted_scopes_secret="GOOGLE_GRANTED_SCOPES", + docs_label="Google APIs", + ), + "microsoft": OAuthProviderMetadata( + canonical_name="microsoft", + aliases=("microsoft", "msft", "ms365", "microsoft_365", "graph"), + access_token_secret="MICROSOFT_ACCESS_TOKEN", + refresh_token_secret="MICROSOFT_REFRESH_TOKEN", + expiry_secret="MICROSOFT_TOKEN_EXPIRES_AT", + granted_scopes_secret="MICROSOFT_GRANTED_SCOPES", + docs_label="Microsoft Graph", + ), +} +_OAUTH_PROVIDER_ALIASES: dict[str, str] = { + alias.strip().lower().replace("-", "_"): metadata.canonical_name + for metadata in _OAUTH_PROVIDER_METADATA.values() + for alias in metadata.aliases +} + + +def _resolve_oauth_provider(provider: str) -> OAuthProviderMetadata: + if not isinstance(provider, str) or not provider.strip(): + supported = ", ".join(sorted(_OAUTH_PROVIDER_METADATA)) + raise ValueError( + "A refresh-token OAuth provider name is required. " + f"Supported providers: {supported}", + ) + normalized = provider.strip().lower().replace("-", "_") + canonical = _OAUTH_PROVIDER_ALIASES.get(normalized, normalized) + metadata = _OAUTH_PROVIDER_METADATA.get(canonical) + if metadata is None: + supported = ", ".join(sorted(_OAUTH_PROVIDER_METADATA)) + raise ValueError( + f"Unknown refresh-token OAuth provider {provider!r}. 
" + f"Supported providers: {supported}", + ) + return metadata + + +def refresh_token_oauth_secret_names() -> frozenset[str]: + names: set[str] = set() + for metadata in _OAUTH_PROVIDER_METADATA.values(): + names.update(metadata.secret_names) + return frozenset(names) + + +def _get_secret_manager(): + from unity.manager_registry import ManagerRegistry + + return ManagerRegistry.get_secret_manager() + + +def _get_secret_value(secret_manager, name: str) -> str | None: + getter = getattr(secret_manager, "_get_secret_value", None) + if callable(getter): + value = getter(name) + if isinstance(value, str) and value: + return value + value = os.environ.get(name) + return value if value else None + + +def _parse_expiry(value: str) -> datetime: + if value.isdigit(): + return datetime.fromtimestamp(int(value), tz=timezone.utc) + normalized = value.replace("Z", "+00:00") + parsed = datetime.fromisoformat(normalized) + if parsed.tzinfo is None: + return parsed.replace(tzinfo=timezone.utc) + return parsed.astimezone(timezone.utc) + + +def _token_expires_within( + secret_manager, + metadata: OAuthProviderMetadata, + min_ttl_seconds: int, +) -> bool: + if metadata.expiry_secret is None: + return False + expiry_value = _get_secret_value(secret_manager, metadata.expiry_secret) + if not expiry_value: + return True + try: + expiry = _parse_expiry(expiry_value) + except ValueError: + return True + remaining = (expiry - datetime.now(timezone.utc)).total_seconds() + return remaining <= min_ttl_seconds + + +def get_oauth_access_token(provider: str, *, min_ttl_seconds: int = 300) -> str: + """ + Return a current OAuth access token for a refresh-token backed provider. + + Use this runtime helper inside generated Python code when a provider SDK, + client, or direct HTTP request requires an explicit access token. Prefer + provider SDK/default credential behavior when it can read credentials from + the runtime environment directly; Unity keeps rotating OAuth env vars + synced separately for that path. + + Parameters + ---------- + provider: + Provider name or alias. Built-in aliases include ``"microsoft"``, + ``"graph"``, ``"google"``, ``"gmail"``, and ``"drive"``. + min_ttl_seconds: + Minimum acceptable token lifetime. If the current token is missing or + expires within this many seconds, the parent runtime forces an + assistant-secret sync from Orchestra before returning a token. + + Examples + -------- + Multiple providers can be used in one sandbox; request each explicitly:: + + microsoft_token = get_oauth_access_token("microsoft") + google_token = get_oauth_access_token("google") + + For direct OAuth2 HTTP APIs such as Microsoft Graph, provider docs commonly + show the access token in an ``Authorization: Bearer ...`` header. Other SDKs + may require a credential object or may read environment variables directly, + so follow the provider's SDK/API docs for how to apply the token. + + Anti-patterns + ------------- + - Do not print, log, return, or store the token value. + - Do not save concrete token values in FunctionManager functions or + GuidanceManager guidance. + - Do not read rotating OAuth access-token env vars directly when this + helper is available and an explicit access token is required. 
+ """ + metadata = _resolve_oauth_provider(provider) + secret_manager = _get_secret_manager() + token = _get_secret_value(secret_manager, metadata.access_token_secret) + needs_force_sync = token is None or _token_expires_within( + secret_manager, + metadata, + min_ttl_seconds, + ) + secret_manager.sync_assistant_secrets_if_stale( + ttl_seconds=60.0, + force=needs_force_sync, + reason=f"oauth_access_token:{metadata.canonical_name}", + ) + token = _get_secret_value(secret_manager, metadata.access_token_secret) + if not token: + raise ValueError( + f"No access token is available for refresh-token OAuth provider " + f"{metadata.canonical_name!r}.", + ) + if _token_expires_within(secret_manager, metadata, min_ttl_seconds): + raise ValueError( + f"The access token for refresh-token OAuth provider " + f"{metadata.canonical_name!r} is expired or near expiry after sync.", + ) + return token + + +def get_refresh_token_oauth_env_overlay() -> dict[str, str]: + """Return fresh rotating OAuth env vars for subprocess execution backends. + + Venv and persistent shell sessions can outlive the parent process's last + environment update, so they cannot rely solely on the environment copied at + process start. This helper performs the debounced assistant-secret sync, + then returns only the built-in refresh-token OAuth variables that should be + overlaid into those subprocesses before execution. + """ + secret_manager = _get_secret_manager() + secret_manager.sync_assistant_secrets_if_stale( + ttl_seconds=60.0, + reason="oauth_env_overlay", + ) + overlay: dict[str, str] = {} + for name in refresh_token_oauth_secret_names(): + value = _get_secret_value(secret_manager, name) + if value: + overlay[name] = value + return overlay + + +def get_oauth_prompt_context() -> str: + """Return actor-facing documentation for OAuth runtime helpers.""" + doc = inspect.getdoc(get_oauth_access_token) or "" + signature = ( + f"def {get_oauth_access_token.__name__}" + f"{inspect.signature(get_oauth_access_token)}" + ) + return ( + "### OAuth Access Token Helper: `get_oauth_access_token(...)`\n\n" + "`get_oauth_access_token(...)` is available inside `execute_code` " + "Python sessions and stored Python functions. It is a normal sandbox " + "helper, not a JSON tool call.\n\n" + f"```python\n{signature}\n```\n\n" + f"{doc}" + ) From 78f9337db4777af18ddb1ba2cb3edbb20fb24976 Mon Sep 17 00:00:00 2001 From: Haris Mahmood Date: Fri, 8 May 2026 16:01:02 +0500 Subject: [PATCH 08/17] refactor(secrets): debounce runtime OAuth secret sync Keep SecretManager focused on mirroring allowlisted runtime OAuth secrets from Orchestra into local Secrets, .env, and os.environ, while keeping OAuth provider semantics in the runtime helper. The sync path now has a single debounced gate so frequent runtime callers can ask for freshness without forcing a network round trip on every operation. Assistant update events and secret inspection still force sync because those paths represent explicit freshness boundaries. Normal runtime execution can use the same gate with a TTL, which keeps credentials reasonably current without making every actor step pay the full Orchestra sync cost. 
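
A minimal sketch of the intended caller pattern (reason strings taken from
this change set; "runtime" is the parameter default):

    from unity.manager_registry import ManagerRegistry

    sm = ManagerRegistry.get_secret_manager()

    # Hot path: a cheap timestamp check inside the TTL window, so at most
    # one Orchestra pull per window no matter how often callers ask.
    sm.sync_assistant_secrets_if_stale(ttl_seconds=60.0, reason="runtime")

    # Explicit freshness boundary: bypass the debounce entirely.
    sm.sync_assistant_secrets_if_stale(force=True, reason="assistant_update")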
--- tests/secret_manager/test_oauth_tokens.py | 92 ++++++++++ .../domains/managers_utils.py | 2 +- unity/secret_manager/secret_manager.py | 165 +++++++++++++----- 3 files changed, 217 insertions(+), 42 deletions(-) create mode 100644 tests/secret_manager/test_oauth_tokens.py diff --git a/tests/secret_manager/test_oauth_tokens.py b/tests/secret_manager/test_oauth_tokens.py new file mode 100644 index 000000000..0370f9cf6 --- /dev/null +++ b/tests/secret_manager/test_oauth_tokens.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import os +from threading import Lock + +from unity.secret_manager.secret_manager import SecretManager + + +def _unit_secret_manager() -> SecretManager: + sm = object.__new__(SecretManager) + sm._assistant_secret_sync_lock = Lock() + sm._last_assistant_secret_sync_success_at = None + sm._last_assistant_secret_sync_failure_at = None + return sm + + +def test_sync_assistant_secrets_if_stale_debounces(monkeypatch): + sm = _unit_secret_manager() + calls: list[str] = [] + + monkeypatch.setattr( + sm, + "_sync_assistant_secrets", + lambda: calls.append("assistant"), + ) + monkeypatch.setattr(sm, "_sync_dotenv", lambda: calls.append("dotenv")) + + assert sm.sync_assistant_secrets_if_stale(reason="test") is True + assert sm.sync_assistant_secrets_if_stale(reason="test") is False + + assert calls == ["assistant", "dotenv"] + + +def test_sync_assistant_secrets_if_stale_force_bypasses_debounce(monkeypatch): + sm = _unit_secret_manager() + calls: list[str] = [] + + monkeypatch.setattr( + sm, + "_sync_assistant_secrets", + lambda: calls.append("assistant"), + ) + monkeypatch.setattr(sm, "_sync_dotenv", lambda: calls.append("dotenv")) + + assert sm.sync_assistant_secrets_if_stale(reason="test") is True + assert sm.sync_assistant_secrets_if_stale(reason="test", force=True) is True + + assert calls == ["assistant", "dotenv", "assistant", "dotenv"] + + +def test_sync_assistant_secrets_if_stale_observes_failure_cooldown(monkeypatch): + sm = _unit_secret_manager() + calls = {"assistant": 0, "dotenv": 0} + + def fail_sync(): + calls["assistant"] += 1 + raise RuntimeError("sync failed") + + def sync_dotenv(): + calls["dotenv"] += 1 + + monkeypatch.setattr(sm, "_sync_assistant_secrets", fail_sync) + monkeypatch.setattr(sm, "_sync_dotenv", sync_dotenv) + + assert sm.sync_assistant_secrets_if_stale(reason="test") is False + assert sm.sync_assistant_secrets_if_stale(reason="test") is False + assert calls == {"assistant": 1, "dotenv": 0} + + +def test_resolve_secret_allowlist_includes_runtime_oauth_secret_names(): + allowlist = SecretManager._resolve_secret_allowlist() + + assert "MICROSOFT_ACCESS_TOKEN" in allowlist + assert "MICROSOFT_TOKEN_EXPIRES_AT" in allowlist + assert "GOOGLE_ACCESS_TOKEN" in allowlist + assert "GOOGLE_TOKEN_EXPIRES_AT" in allowlist + + +def test_env_merge_and_write_updates_dotenv_and_process_env(monkeypatch, tmp_path): + sm = _unit_secret_manager() + dotenv_path = tmp_path / ".env" + + monkeypatch.setattr(sm, "_dotenv_path", lambda: str(dotenv_path)) + monkeypatch.delenv("MICROSOFT_ACCESS_TOKEN", raising=False) + + sm._env_merge_and_write( + add_or_update={"MICROSOFT_ACCESS_TOKEN": "fresh-token"}, + remove_keys=None, + ) + + assert dotenv_path.read_text() == "MICROSOFT_ACCESS_TOKEN=fresh-token\n" + assert os.environ["MICROSOFT_ACCESS_TOKEN"] == "fresh-token" diff --git a/unity/conversation_manager/domains/managers_utils.py b/unity/conversation_manager/domains/managers_utils.py index 944f27161..3d6e7b971 100644 --- 
a/unity/conversation_manager/domains/managers_utils.py +++ b/unity/conversation_manager/domains/managers_utils.py @@ -1211,7 +1211,7 @@ async def sync_assistant_secrets() -> None: from unity.manager_registry import ManagerRegistry sm = ManagerRegistry.get_secret_manager() - sm._sync_assistant_secrets() + sm.sync_assistant_secrets_if_stale(force=True, reason="assistant_update") # Contact updates diff --git a/unity/secret_manager/secret_manager.py b/unity/secret_manager/secret_manager.py index 3269fd875..771df41a1 100644 --- a/unity/secret_manager/secret_manager.py +++ b/unity/secret_manager/secret_manager.py @@ -4,6 +4,8 @@ import functools import logging import os +from threading import Lock +from time import monotonic from typing import Any, Callable, Dict, List, Optional, Type from pydantic import BaseModel @@ -57,6 +59,9 @@ def __init__(self) -> None: super().__init__() self.include_in_multi_assistant_table = True self._ctx = ContextRegistry.get_context(self, "Secrets") + self._assistant_secret_sync_lock = Lock() + self._last_assistant_secret_sync_success_at: float | None = None + self._last_assistant_secret_sync_failure_at: float | None = None # Ensure storage/schema exists deterministically (idempotent) self._provision_storage() @@ -212,22 +217,35 @@ def _default_update_tool_policy( # the built-in set above. OAUTH_SECRET_ALLOWLIST = _BUILTIN_OAUTH_SECRET_ALLOWLIST - def _sync_assistant_secrets(self) -> None: - """Pull Google / Microsoft OAuth tokens from Orchestra's - ``AssistantSecret`` table into the assistant's local ``Secrets`` - context. + @classmethod + def _resolve_secret_allowlist(cls) -> frozenset[str]: + """Return assistant-secret names owned by runtime OAuth sync. - Communication writes those tokens via REST (``/assistant/{id}/secret``) - from the OAuth callback. This sync mirrors them locally so the - Actor can use them in code-first plans, and writes them to - ``os.environ`` via ``_env_set`` so subprocesses see them too. + The set is intentionally limited to refresh-token OAuth metadata. + Console-pasted integration credentials already live in the local + ``Secrets`` context and reach ``os.environ`` through ``_sync_dotenv``. + """ + try: + from unity.common.runtime_oauth import refresh_token_oauth_secret_names - **Scope is intentionally narrow.** Console-pasted integration - secrets live in the ``/Secrets`` context directly; they reach env - via :meth:`_sync_dotenv`. This method does not know or care about - them — see ``_BUILTIN_OAUTH_SECRET_ALLOWLIST``. + return cls._BUILTIN_OAUTH_SECRET_ALLOWLIST | refresh_token_oauth_secret_names() + except Exception: + return cls._BUILTIN_OAUTH_SECRET_ALLOWLIST - Best-effort: failures are logged and silently swallowed. + def _sync_assistant_secrets(self) -> None: + """Mirror runtime OAuth assistant secrets from Orchestra into local state. + + Orchestra is the platform source of truth for assistant-level OAuth + secrets written outside this Unity process. Communication refresh jobs + persist updated access tokens there; this method pulls those values into + Unity's local ``Secrets`` context, then updates ``.env``/``os.environ`` + so generated code and provider SDKs can use normal environment-based + credential discovery. + + The sync is intentionally allowlisted. We mirror refresh-token OAuth + keys, but we do not copy arbitrary assistant secrets into the runtime. + Failures are best-effort: callers use ``sync_assistant_secrets_if_stale`` + as the observable gate. 
""" from ..session_details import SESSION_DETAILS @@ -259,10 +277,10 @@ def _sync_assistant_secrets(self) -> None: except Exception: return - # Allowlist is intentionally OAuth-only — see the - # ``_BUILTIN_OAUTH_SECRET_ALLOWLIST`` docstring above for why - # integration secrets do NOT flow through this sync. - active_allowlist = self._BUILTIN_OAUTH_SECRET_ALLOWLIST + # Allowlist is intentionally OAuth-only; integration secrets do not flow + # through this sync because they already live in the local Secrets + # context and are exported by _sync_dotenv. + active_allowlist = self._resolve_secret_allowlist() written = 0 for name, value in secrets_dict.items(): @@ -307,17 +325,10 @@ def _sync_assistant_secrets(self) -> None: written, ) - # Stale-cleanup: only the built-in Google / Microsoft OAuth keys - # are owned by THIS sync, so only those may be deleted when - # missing from Orchestra's response. Console-pasted secrets - # (HubSpot, Matterport, etc.) and OAuth-managed integration - # tokens (EMPLOYMENTHERO_REFRESH_TOKEN, etc.) live in the same - # local Secrets context but are NOT this sync's responsibility, - # so they must not be cleaned up here just because Orchestra's - # secrets payload omits them — that would silently wipe valid - # user state every time the admin endpoint returned a partial - # or stripped response. - for stale_name in self._BUILTIN_OAUTH_SECRET_ALLOWLIST - secrets_dict.keys(): + # Stale-cleanup is limited to the OAuth secrets owned by this sync. + # Console-pasted integration credentials live in the same local Secrets + # context but are not removed based on the admin assistant payload. + for stale_name in active_allowlist - secrets_dict.keys(): try: ids = unify.get_logs( context=self._ctx, @@ -331,6 +342,89 @@ def _sync_assistant_secrets(self) -> None: except Exception: continue + def sync_assistant_secrets_if_stale( + self, + ttl_seconds: float = 60.0, + *, + force: bool = False, + reason: str = "runtime", + failure_cooldown_seconds: float = 10.0, + ) -> bool: + """Pull assistant secrets through one debounced runtime sync gate. + + This is the single runtime entry point for keeping Unity's local secret + state close to Orchestra without adding a network round trip to every + actor operation. Normal callers, including ``execute_code``, call with + ``force=False`` and therefore only perform the expensive Orchestra pull + once per ``ttl_seconds`` window. Forced callers use this when freshness + matters more than debounce, such as SecretManager construction, + ``primitives.secrets.ask(...)``, assistant-update events, or an OAuth + helper detecting a missing/near-expiry access token. + + Returns ``True`` only when this invocation actually ran the sync work. + Returns ``False`` when the success debounce or failure cooldown skipped + work, or when the wrapped sync raised an exception. 
+ """ + now = monotonic() + if not force: + last_success = self._last_assistant_secret_sync_success_at + if last_success is not None and now - last_success < ttl_seconds: + return False + last_failure = self._last_assistant_secret_sync_failure_at + if ( + last_failure is not None + and now - last_failure < failure_cooldown_seconds + ): + return False + + with self._assistant_secret_sync_lock: + now = monotonic() + if not force: + last_success = self._last_assistant_secret_sync_success_at + if last_success is not None and now - last_success < ttl_seconds: + return False + last_failure = self._last_assistant_secret_sync_failure_at + if ( + last_failure is not None + and now - last_failure < failure_cooldown_seconds + ): + return False + try: + self._sync_assistant_secrets() + self._sync_dotenv() + except Exception: + self._last_assistant_secret_sync_failure_at = monotonic() + logger.warning( + "[integrations] assistant secret sync failed reason=%s", + reason, + exc_info=True, + ) + return False + self._last_assistant_secret_sync_success_at = monotonic() + self._last_assistant_secret_sync_failure_at = None + logger.info( + "[integrations] assistant secret sync complete reason=%s", + reason, + ) + return True + + def _get_secret_value(self, name: str) -> str | None: + try: + rows = unify.get_logs( + context=self._ctx, + filter=f"name == {name!r}", + limit=1, + from_fields=["name", "value"], + ) + if rows: + value = (rows[0].entries or {}).get("value") + if isinstance(value, str) and value: + return value + except Exception: + pass + value = os.environ.get(name) + return value if value else None + # --------------------- Internal helpers (.env sync) --------------------- # def _dotenv_path(self) -> str: """Return the path to the .env file used for local sync. @@ -382,8 +476,7 @@ def _ensure_dotenv_synced_on_init(self) -> None: with open(path, "w", encoding="utf-8") as fh: fh.write("") - self._sync_assistant_secrets() - self._sync_dotenv() + self.sync_assistant_secrets_if_stale(force=True, reason="secret_manager_init") @staticmethod def _parse_env_lines(lines: List[str]) -> Dict[str, int]: @@ -536,17 +629,7 @@ async def ask( _clarification_down_q: Optional[asyncio.Queue[str]] = None, _call_id: Optional[str] = None, ) -> SteerableToolHandle: - # Pull OAuth tokens from Orchestra → Secrets context, then sync - # all secrets (including freshly-synced OAuth tokens) into .env - # so they're available through os.environ before the Actor reads them. - try: - self._sync_assistant_secrets() - except Exception: - pass - try: - self._sync_dotenv() - except Exception: - pass + self.sync_assistant_secrets_if_stale(force=True, reason="secret_ask") # First, replace any known raw secret values with placeholders try: From 16b5d778e41cf8c60846537d5ff508b45bcfacb5 Mon Sep 17 00:00:00 2001 From: Haris Mahmood Date: Fri, 8 May 2026 16:01:14 +0500 Subject: [PATCH 09/17] feat(oauth): wire fresh tokens into actor runtimes Route in-process Python, venv-backed Python, persistent shell sessions, and runtime RPC through the OAuth runtime helper. The execute_code boundary now asks the debounced secret sync gate for freshness, and long-lived subprocesses receive OAuth env overlays so SDK/default-env credential paths do not keep stale inherited values. Explicit get_oauth_access_token(...) calls in venv and shell route back to the parent runtime, which keeps token freshness checks centralized instead of trusting child process environment snapshots. 
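
In sketch form, the child-side shim reduces to the call below. Both names are
only meaningful inside the venv runner's child process (rpc_call_sync is the
existing JSON-RPC bridge there); generated code simply calls
get_oauth_access_token("google") and never sees this plumbing.

    # Inside the venv child process: ask the parent runtime for the token
    # instead of trusting whatever environment the child inherited at spawn.
    token = rpc_call_sync(
        "runtime.get_oauth_access_token",
        {"provider": "google", "min_ttl_seconds": 300},
    )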
The actor integration test covers Microsoft and Google in the same sandbox to prevent accidental global-token behavior. --- .../code_act/test_execute_code_output.py | 57 ++++++++++++ .../test_runtime_oauth_bridge.py | 89 +++++++++++++++++++ unity/actor/code_act_actor.py | 26 ++++++ unity/actor/execution/session.py | 58 +++++++++++- unity/function_manager/execution_env.py | 2 + unity/function_manager/function_manager.py | 44 +++++++++ unity/function_manager/venv_runner.py | 39 ++++++++ 7 files changed, 314 insertions(+), 1 deletion(-) create mode 100644 tests/function_manager/test_runtime_oauth_bridge.py diff --git a/tests/actor/code_act/test_execute_code_output.py b/tests/actor/code_act/test_execute_code_output.py index 5090555a0..81397655b 100644 --- a/tests/actor/code_act/test_execute_code_output.py +++ b/tests/actor/code_act/test_execute_code_output.py @@ -134,6 +134,63 @@ def get_result(out: Any) -> Any: return get_output_field(out, "result", None) +# --------------------------------------------------------------------------- +# Test: Runtime OAuth token helper exposed to real actor execute_code +# --------------------------------------------------------------------------- + + +class _FakeOAuthSecretManager: + def __init__(self) -> None: + self.calls: list[Any] = [] + self.secrets = { + "MICROSOFT_ACCESS_TOKEN": "microsoft:fresh-token", + "MICROSOFT_TOKEN_EXPIRES_AT": "2999-01-01T00:00:00+00:00", + "GOOGLE_ACCESS_TOKEN": "google:fresh-token", + "GOOGLE_TOKEN_EXPIRES_AT": "2999-01-01T00:00:00+00:00", + } + + def sync_assistant_secrets_if_stale(self, **kwargs: Any) -> bool: + self.calls.append(("sync", kwargs)) + return True + + def _get_secret_value(self, name: str) -> str | None: + self.calls.append(("secret", name)) + return self.secrets.get(name) + + +@pytest.mark.asyncio +async def test_execute_code_oauth_helper_uses_parent_secret_manager( + execute_code_tool: tuple[Any, Primitives], + monkeypatch: pytest.MonkeyPatch, +) -> None: + execute_code, _ = execute_code_tool + fake_secret_manager = _FakeOAuthSecretManager() + monkeypatch.setattr( + ManagerRegistry, + "get_secret_manager", + lambda: fake_secret_manager, + ) + + out = await execute_code( + "mock scenario: call rotating OAuth token helper for multiple providers", + """ +microsoft_token = get_oauth_access_token("microsoft", min_ttl_seconds=123) +google_token = get_oauth_access_token("google", min_ttl_seconds=456) +assert microsoft_token == "microsoft:fresh-token" +assert google_token == "google:fresh-token" +print("TOKEN_OK") +""", + language="python", + state_mode="stateless", + ) + + assert get_error(out) is None + assert "TOKEN_OK" in get_stdout_text(out) + assert ("secret", "MICROSOFT_ACCESS_TOKEN") in fake_secret_manager.calls + assert ("secret", "GOOGLE_ACCESS_TOKEN") in fake_secret_manager.calls + assert any(call[0] == "sync" for call in fake_secret_manager.calls) + + # --------------------------------------------------------------------------- # Test: Basic stdout capture from primitives.*.ask().result() # --------------------------------------------------------------------------- diff --git a/tests/function_manager/test_runtime_oauth_bridge.py b/tests/function_manager/test_runtime_oauth_bridge.py new file mode 100644 index 000000000..7570f31cd --- /dev/null +++ b/tests/function_manager/test_runtime_oauth_bridge.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + +import pytest + +from unity.function_manager.function_manager import FunctionManager +from 
unity.function_manager import function_manager as function_manager_module + + +@dataclass +class _FakeSecretManager: + sync_reasons: list[str] = field(default_factory=list) + + def _get_secret_value(self, name: str) -> str | None: + values = { + "GOOGLE_ACCESS_TOKEN": "google-token", + "GOOGLE_TOKEN_EXPIRES_AT": "2999-01-01T00:00:00+00:00", + } + return values.get(name) + + def sync_assistant_secrets_if_stale(self, **kwargs) -> bool: + self.sync_reasons.append(kwargs["reason"]) + return True + + +@pytest.mark.asyncio +async def test_shell_runtime_oauth_token_helper_uses_parent_rpc(monkeypatch): + fake_secret_manager = _FakeSecretManager() + monkeypatch.setattr( + function_manager_module.ManagerRegistry, + "get_secret_manager", + lambda: fake_secret_manager, + ) + + fm = object.__new__(FunctionManager) + result = await fm.execute_shell_script( + implementation=( + "#!/bin/sh\n" + "token=$(unity-primitive runtime get_oauth_access_token " + "--provider google --min_ttl_seconds 42)\n" + 'if [ "$token" = \'"google-token"\' ]; then echo "TOKEN_OK"; fi\n' + ), + language="sh", + ) + + assert result["error"] is None + assert result["result"] == 0 + assert "TOKEN_OK" in result["stdout"] + assert fake_secret_manager.sync_reasons == ["oauth_access_token:google"] + + +def test_runtime_oauth_env_overlay_routes_through_runtime_helper(monkeypatch): + from unity.common import runtime_oauth + + monkeypatch.setattr( + runtime_oauth, + "get_refresh_token_oauth_env_overlay", + lambda: {"GOOGLE_ACCESS_TOKEN": "fresh-google-token"}, + ) + + fm = object.__new__(FunctionManager) + + assert fm._get_runtime_oauth_env_overlay() == { + "GOOGLE_ACCESS_TOKEN": "fresh-google-token", + } + + +def test_venv_runtime_oauth_helper_uses_parent_rpc(monkeypatch): + from unity.function_manager import venv_runner + + calls = [] + + def fake_rpc_call_sync(path, kwargs): + calls.append((path, kwargs)) + return "fresh-ms-token" + + monkeypatch.setattr(venv_runner, "rpc_call_sync", fake_rpc_call_sync) + + assert ( + venv_runner.get_oauth_access_token("microsoft", min_ttl_seconds=12) + == "fresh-ms-token" + ) + assert calls == [ + ( + "runtime.get_oauth_access_token", + {"provider": "microsoft", "min_ttl_seconds": 12}, + ), + ] diff --git a/unity/actor/code_act_actor.py b/unity/actor/code_act_actor.py index 5938dbf79..b4a32b778 100644 --- a/unity/actor/code_act_actor.py +++ b/unity/actor/code_act_actor.py @@ -2025,6 +2025,15 @@ async def execute_code( - **session_created**: True if a new session was created by this call. - **duration_ms**: Execution duration in milliseconds. + Runtime credential helpers + -------------------------- + Python execution globals include + ``get_oauth_access_token(provider)`` for refresh-token backed OAuth + providers when a provider SDK, client, or direct HTTP request needs + an explicit access token. Static API keys and provider SDKs that + read credentials from the environment may still use ``os.environ`` + after checking available secret names. + For in-process Python execution with rich output, the result is wrapped in an ExecutionResult object (a Pydantic model implementing FormattedToolResult). """ @@ -2088,6 +2097,23 @@ async def _pub_safe(**payload: Any) -> None: notification_q = _notification_up_q sandbox_id = None try: + try: + from unity.manager_registry import ManagerRegistry + + # Keep generated code's normal environment-based credential + # path fresh at the execution boundary. 
The SecretManager + # gate is debounced, so repeated execute_code calls only pay + # a cheap timestamp check within the TTL window. + ManagerRegistry.get_secret_manager().sync_assistant_secrets_if_stale( + ttl_seconds=60.0, + reason="execute_code", + ) + except Exception: + logger.warning( + "execute_code assistant secret sync failed", + exc_info=True, + ) + _rs = self._resolve_session( state_mode=state_mode, language=str(language), diff --git a/unity/actor/execution/session.py b/unity/actor/execution/session.py index 1210ea137..ca9cecb19 100644 --- a/unity/actor/execution/session.py +++ b/unity/actor/execution/session.py @@ -11,6 +11,8 @@ import ast import contextvars import logging +import json +import shlex import sys import traceback import types @@ -40,6 +42,25 @@ logger = logging.getLogger(__name__) +def _with_shell_env_overlay( + command: str, + env_overlay: dict[str, str], + *, + language: str, +) -> str: + if not env_overlay: + return command + if language == "powershell": + assignments = "\n".join( + f"$env:{key} = {json.dumps(value)}" for key, value in env_overlay.items() + ) + else: + assignments = "\n".join( + f"export {key}={shlex.quote(value)}" for key, value in env_overlay.items() + ) + return f"{assignments}\n{command}" + + # --------------------------------------------------------------------------- # Type aliases # --------------------------------------------------------------------------- @@ -780,6 +801,22 @@ def _se_ms(): if computer_primitives is None: computer_primitives = self._computer_primitives + def _runtime_oauth_env_overlay() -> dict[str, str]: + # The parent execute_code boundary already performs the generic + # debounced secret sync. This overlay is the subprocess-specific + # bridge: venv and shell sessions may be long-lived, so they need + # current rotating OAuth env vars injected for each execution. + if self._function_manager is None: + return {} + getter = getattr( + self._function_manager, + "_get_runtime_oauth_env_overlay", + None, + ) + if getter is None: + return {} + return getter() + async def _execute_in_python_session( sb: PythonExecutionSession, ) -> Dict[str, Any]: @@ -881,6 +918,10 @@ async def _execute_in_python_session( # Wrap arbitrary code in a function definition so venv_runner can execute it. implementation = _wrap_code_as_async_function(code) if state_mode == "stateful": + # Persistent venv workers keep their process environment + # across calls. Pass the OAuth overlay so SDK/default-env + # credential paths see fresh access tokens without the actor + # manually exporting anything. out = await self._venv_pool.execute_in_venv( venv_id=int(venv_id), implementation=implementation, @@ -891,6 +932,7 @@ async def _execute_in_python_session( computer_primitives=computer_primitives, function_manager=self._function_manager, timeout=self._timeout, + env_overlay=_runtime_oauth_env_overlay(), ) return { **out, @@ -916,6 +958,9 @@ async def _execute_in_python_session( session_id=int(session_id), timeout=10.0, ) + # Read-only venv execution runs in a one-shot subprocess + # seeded from persistent state, but still receives the same + # runtime OAuth overlay before code executes. 
out = await self._function_manager.execute_in_venv( venv_id=int(venv_id), implementation=implementation, @@ -924,6 +969,7 @@ async def _execute_in_python_session( initial_state=initial_state, primitives=primitives, computer_primitives=computer_primitives, + env_overlay=_runtime_oauth_env_overlay(), ) return { **out, @@ -1047,11 +1093,21 @@ async def _execute_in_python_session( language=language, # type: ignore[arg-type] session_id=int(session_id), ) + # Shells are especially prone to stale env because exports persist + # inside the session. We both pass an env overlay to the pool and + # prepend explicit assignments to the command so the current command + # and future commands in the same shell agree on the refreshed token + # values. res = await self._shell_pool.execute( language=language, # type: ignore[arg-type] - command=code, + command=_with_shell_env_overlay( + code, + _runtime_oauth_env_overlay(), + language=str(language), + ), session_id=int(session_id), timeout=self._timeout, + env=_runtime_oauth_env_overlay(), ) return { "stdout": res.stdout, diff --git a/unity/function_manager/execution_env.py b/unity/function_manager/execution_env.py index 799639282..adde67f17 100644 --- a/unity/function_manager/execution_env.py +++ b/unity/function_manager/execution_env.py @@ -259,11 +259,13 @@ async def my_workflow(goal: str) -> SteerableToolHandle: ) from unity.common.llm_client import new_llm_client from unity.common.reasoning import reason + from unity.common.runtime_oauth import get_oauth_access_token globals_dict["SteerableToolHandle"] = SteerableToolHandle globals_dict["start_async_tool_loop"] = start_async_tool_loop globals_dict["new_llm_client"] = new_llm_client globals_dict["reason"] = reason + globals_dict["get_oauth_access_token"] = get_oauth_access_token globals_dict["unillm"] = unillm return globals_dict diff --git a/unity/function_manager/function_manager.py b/unity/function_manager/function_manager.py index 64868adbf..9c25ce08d 100644 --- a/unity/function_manager/function_manager.py +++ b/unity/function_manager/function_manager.py @@ -414,6 +414,7 @@ async def execute( primitives: Optional[Any] = None, computer_primitives: Optional[Any] = None, timeout: Optional[float] = None, + env_overlay: Optional[Dict[str, str]] = None, ) -> dict: """ Execute a function in the persistent venv subprocess. @@ -446,6 +447,7 @@ async def execute( "implementation": implementation, "call_kwargs": call_kwargs, "is_async": is_async, + "env_overlay": env_overlay or {}, }, ) @@ -507,6 +509,16 @@ async def _handle_rpc_call( result, ), } + if namespace == "runtime" and method == "get_oauth_access_token": + from unity.common.runtime_oauth import get_oauth_access_token + + provider = kwargs.get("provider") + min_ttl_seconds = int(kwargs.get("min_ttl_seconds", 300)) + result = get_oauth_access_token( + provider, + min_ttl_seconds=min_ttl_seconds, + ) + return {"type": "rpc_result", "id": request_id, "result": result} if namespace == "computer" and computer_primitives is not None: fn = getattr(computer_primitives, method, None) elif primitives is not None: @@ -709,6 +721,7 @@ async def execute_in_venv( computer_primitives: Optional[Any] = None, function_manager: "FunctionManager", timeout: Optional[float] = None, + env_overlay: Optional[Dict[str, str]] = None, ) -> dict: """ Execute a function in a persistent venv subprocess. 
@@ -751,6 +764,7 @@ async def execute_in_venv( primitives=primitives, computer_primitives=computer_primitives, timeout=timeout, + env_overlay=env_overlay, ) # Update last_used best-effort md = self._metadata.get(key) @@ -779,6 +793,7 @@ async def execute_in_venv( primitives=primitives, computer_primitives=computer_primitives, timeout=timeout, + env_overlay=env_overlay, ) raise @@ -1646,6 +1661,23 @@ def __init__( # Dict[session_id, Dict[str, Any]] - persistent globals per session self._in_process_sessions: Dict[int, Dict[str, Any]] = {} + def _get_runtime_oauth_env_overlay(self) -> Dict[str, str]: + """Build the rotating OAuth env overlay for venv/shell execution. + + This is intentionally routed through ``unity.common.runtime_oauth`` + rather than SecretManager so provider metadata, expiry semantics, and + runtime helper behavior stay in one place. Failures should not block + unrelated function execution; explicit token calls can still surface a + provider-specific error when the actor really needs a token. + """ + try: + from unity.common.runtime_oauth import get_refresh_token_oauth_env_overlay + + return get_refresh_token_oauth_env_overlay() + except Exception: + logger.warning("Failed to build OAuth env overlay", exc_info=True) + return {} + @property def primitive_scope(self) -> PrimitiveScope: """The scope controlling which managers' primitives are accessible.""" @@ -4607,6 +4639,16 @@ async def _handle_rpc_call( return self._make_json_serializable(await reason(**kwargs)) + if manager_name == "runtime" and method_name == "get_oauth_access_token": + from unity.common.runtime_oauth import get_oauth_access_token + + provider = kwargs.get("provider") + min_ttl_seconds = int(kwargs.get("min_ttl_seconds", 300)) + return get_oauth_access_token( + provider, + min_ttl_seconds=min_ttl_seconds, + ) + # Handle computer primitives if manager_name == "computer": if computer_primitives is None: @@ -4649,6 +4691,7 @@ async def execute_in_venv( initial_state: Optional[Dict[str, Any]] = None, primitives: Optional[Any] = None, computer_primitives: Optional[Any] = None, + env_overlay: Optional[Dict[str, str]] = None, ) -> Dict[str, Any]: """ Execute a function implementation in a custom virtual environment. @@ -4688,6 +4731,7 @@ async def execute_in_venv( "implementation": implementation, "call_kwargs": call_kwargs, "is_async": is_async, + "env_overlay": env_overlay or self._get_runtime_oauth_env_overlay(), } if initial_state is not None: execute_payload["initial_state"] = initial_state diff --git a/unity/function_manager/venv_runner.py b/unity/function_manager/venv_runner.py index 4c4b62481..160e2d5b0 100644 --- a/unity/function_manager/venv_runner.py +++ b/unity/function_manager/venv_runner.py @@ -30,6 +30,7 @@ import asyncio import io import json +import os import signal import sys import threading @@ -295,6 +296,26 @@ async def reason( return result +def get_oauth_access_token(provider: str, *, min_ttl_seconds: int = 300) -> str: + """ + Return a current OAuth access token for a refresh-token backed provider. + + Custom virtual environments run in a child process whose environment can + be older than the parent Unity worker. This helper calls the parent process + over JSON-RPC so rotating OAuth access tokens are read from the current + assistant secret state instead of the child process's inherited env. 
+ + Examples + -------- + ``token = get_oauth_access_token("microsoft")`` + ``token = get_oauth_access_token("google")`` + """ + return rpc_call_sync( + "runtime.get_oauth_access_token", + {"provider": provider, "min_ttl_seconds": min_ttl_seconds}, + ) + + # ──────────────────────────────────────────────────────────────────────────── # Execution Environment # ──────────────────────────────────────────────────────────────────────────── @@ -413,6 +434,7 @@ def create_safe_globals(is_async: bool = True): # Primitives proxy (computer and actor accessible via primitives.computer.* etc.) "primitives": PrimitivesProxy(is_async=is_async), "reason": reason, + "get_oauth_access_token": get_oauth_access_token, } # Try to add pydantic if available in this venv @@ -736,6 +758,21 @@ def inject_state_into_globals(state: dict, globals_dict: dict) -> None: pass +def apply_env_overlay(env_overlay: dict | None) -> None: + """Apply parent-supplied runtime env updates inside the child process. + + The venv runner can be a long-lived subprocess, so inherited environment + variables may be older than Unity's parent runtime. The parent sends only + the runtime overlay needed for execution, currently rotating OAuth token + variables, before each function call. + """ + if not env_overlay: + return + for key, value in env_overlay.items(): + if isinstance(key, str) and isinstance(value, str): + os.environ[key] = value + + # ──────────────────────────────────────────────────────────────────────────── # Main Entry Point # ──────────────────────────────────────────────────────────────────────────── @@ -864,6 +901,7 @@ def main(): call_kwargs = input_data.get("call_kwargs", {}) is_async = input_data.get("is_async", False) initial_state = input_data.get("initial_state") + apply_env_overlay(input_data.get("env_overlay")) # Execute with RPC support, optionally with initial state result = run_with_rpc_loop( @@ -1030,6 +1068,7 @@ def main_server(): implementation = input_data.get("implementation", "") call_kwargs = input_data.get("call_kwargs", {}) is_async = input_data.get("is_async", True) + apply_env_overlay(input_data.get("env_overlay")) # Execute with RPC support using persistent globals result = run_server_with_rpc_loop( From f3ed8e3d1c1571ac1ab7a51bb5f30170e1f27993 Mon Sep 17 00:00:00 2001 From: Haris Mahmood Date: Fri, 8 May 2026 16:01:21 +0500 Subject: [PATCH 10/17] docs(actor): teach runtime OAuth token usage Expose the runtime OAuth helper in the CodeAct prompt using the same signature-and-docstring pattern as reason(...). The guidance distinguishes SDK/default environment behavior from cases that require an explicit access token, and warns against printing, logging, storing, or baking concrete token values into reusable functions or guidance. The prompt test locks in the exact helper signature, multi-provider examples, and anti-pattern guidance so future prompt edits do not accidentally regress the actor's understanding of refreshed OAuth credentials. 
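As an illustration of the pattern the prompt now teaches (this snippet is not part
of the prompt text itself; the `requests` dependency and the Graph endpoint are only
assumed examples), actor-written sandbox code should request a token only when an
explicit bearer value is unavoidable, and keep the value out of logs, results, and
stored functions:

    import requests  # assumed to be installed in the sandbox venv

    def fetch_profile_display_name() -> str:
        # A raw HTTP call needs an explicit token, so use the sandbox helper
        # instead of reading access-token environment variables directly.
        token = get_oauth_access_token("microsoft")  # injected sandbox global
        resp = requests.get(
            "https://graph.microsoft.com/v1.0/me",
            headers={"Authorization": f"Bearer {token}"},
            timeout=30,
        )
        resp.raise_for_status()
        # Return derived data only; never print, log, or store the token value.
        return resp.json().get("displayName", "")

When a provider SDK can obtain credentials through its own default mechanisms, the
guidance steers the actor toward that path and the helper is not called at all.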
--- tests/actor/code_act/test_prompt_builders.py | 20 +++++++++++ unity/actor/prompt_builders.py | 36 ++++++++++++++------ 2 files changed, 46 insertions(+), 10 deletions(-) diff --git a/tests/actor/code_act/test_prompt_builders.py b/tests/actor/code_act/test_prompt_builders.py index da67a54a5..13401f9bb 100644 --- a/tests/actor/code_act/test_prompt_builders.py +++ b/tests/actor/code_act/test_prompt_builders.py @@ -111,6 +111,26 @@ def test_code_act_prompt_includes_diverse_examples_sessions_computer_primitives_ assert "execute_function vs execute_code decision" in prompt +@pytest.mark.timeout(30) +def test_code_act_prompt_teaches_refresh_token_oauth_helper(): + actor = CodeActActor() + prompt = build_code_act_prompt( + environments={}, + tools=dict(actor.get_tools("act")), + ) + + assert "def reason(" in prompt + assert ( + "def get_oauth_access_token(provider: str, *, " + "min_ttl_seconds: int = 300) -> str" + ) in prompt + assert 'get_oauth_access_token("microsoft")' in prompt + assert 'get_oauth_access_token("google")' in prompt + assert "refresh-token backed OAuth" in prompt + assert "Do not print, log, return, or store the token value." in prompt + assert "provider sdk/default credential behavior" in prompt.lower() + + @pytest.mark.timeout(30) def test_code_act_prompt_includes_comms_namespace_and_docstrings(): from unity.actor.environments.state_managers import StateManagerEnvironment diff --git a/unity/actor/prompt_builders.py b/unity/actor/prompt_builders.py index 19e5ab627..a6d59f121 100644 --- a/unity/actor/prompt_builders.py +++ b/unity/actor/prompt_builders.py @@ -482,15 +482,29 @@ Cloud, `slack-sdk` for Slack, `boto3` for AWS, `stripe` for Stripe). 3. **Integrate**: Write Python code that uses the SDK with the stored - credentials to interact with the service. Credentials are synced to - environment variables via the `.env` file managed by SecretManager — - use `os.environ` to access them after confirming their names via - `primitives.secrets.ask(...)`. + credentials to interact with the service. Static credentials and + non-rotating API keys are synced to environment variables via the `.env` + file managed by SecretManager; use `os.environ` for those after + confirming their names via `primitives.secrets.ask(...)`. For provider + SDKs that can read OAuth credentials from environment variables, prefer + the SDK's normal/default credential behavior. When a provider SDK, + client, or direct HTTP request requires an explicit refresh-token backed + OAuth access token, use the sandbox helper + `get_oauth_access_token(provider)` instead of reading access-token env + vars directly. + + ```python + microsoft_token = get_oauth_access_token("microsoft") + google_token = get_oauth_access_token("google") + ``` 4. **Store for reuse**: After a successful integration, store reusable functions via `store_skills` and document the setup via `GuidanceManager_add_guidance` so future interactions can reuse the - integration without rediscovery. + integration without rediscovery. Reusable OAuth integrations should + call `get_oauth_access_token(provider)` at runtime only when an explicit + token is required; never store or capture a concrete access-token value + inside a function implementation. 
**Prefer Python SDKs over CLI tools.** Python packages benefit from full environment management (isolated venvs, dependency resolution via @@ -500,11 +514,11 @@ #### Checking OAuth Scope Before API Calls - Before making Google or Microsoft API calls that rely on - platform-managed OAuth tokens, check whether the scope you need - has been granted. `GOOGLE_GRANTED_SCOPES` and - `MICROSOFT_GRANTED_SCOPES` hold space-separated raw OAuth scope - strings — not feature names. Examples of what you will see: + Before making API calls that rely on platform-managed OAuth tokens, + check whether the scope you need has been granted when the provider has + a granted-scopes secret. For the built-in providers, `GOOGLE_GRANTED_SCOPES` + and `MICROSOFT_GRANTED_SCOPES` hold space-separated raw OAuth scope + strings — not feature names. Examples of what you will see: - Google: full URLs such as `https://www.googleapis.com/auth/drive` and @@ -833,8 +847,10 @@ def build_code_act_prompt( parts.append(_EXECUTION_RULES) parts.append(_SEMANTIC_REASONING_SELECTION) from unity.common.reasoning import get_reasoning_prompt_context + from unity.common.runtime_oauth import get_oauth_prompt_context parts.append(get_reasoning_prompt_context()) + parts.append(get_oauth_prompt_context()) parts.append(_INCREMENTAL_EXECUTION) parts.append(_EXTERNAL_APP_INTEGRATION) From b42b9a2e2997997ec204589ac9c70817d328e732 Mon Sep 17 00:00:00 2001 From: Haris Mahmood Date: Mon, 11 May 2026 14:00:00 +0500 Subject: [PATCH 11/17] chore: Run black formatting --- unity/secret_manager/secret_manager.py | 4 +++- unity/spending_limits.py | 6 +----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/unity/secret_manager/secret_manager.py b/unity/secret_manager/secret_manager.py index 771df41a1..487c4f037 100644 --- a/unity/secret_manager/secret_manager.py +++ b/unity/secret_manager/secret_manager.py @@ -228,7 +228,9 @@ def _resolve_secret_allowlist(cls) -> frozenset[str]: try: from unity.common.runtime_oauth import refresh_token_oauth_secret_names - return cls._BUILTIN_OAUTH_SECRET_ALLOWLIST | refresh_token_oauth_secret_names() + return ( + cls._BUILTIN_OAUTH_SECRET_ALLOWLIST | refresh_token_oauth_secret_names() + ) except Exception: return cls._BUILTIN_OAUTH_SECRET_ALLOWLIST diff --git a/unity/spending_limits.py b/unity/spending_limits.py index 2edbc35bf..277d59b8a 100644 --- a/unity/spending_limits.py +++ b/unity/spending_limits.py @@ -376,11 +376,7 @@ def _to_limit_type(type_str: Optional[str]) -> Optional[LimitType]: # legacy gate would block every call. Skip it for METERED, keep it # for CREDITS (and for the no-billing-mode-yet legacy case so we # don't loosen the gate during a partial Orchestra rollout). - if ( - billing_mode != "METERED" - and credit_balance is not None - and credit_balance <= 0 - ): + if billing_mode != "METERED" and credit_balance is not None and credit_balance <= 0: return LimitCheckResponse( allowed=False, reason=( From ce2ad82f1ef4652b8aae16b4b8927d26f55f8371 Mon Sep 17 00:00:00 2001 From: Haris Mahmood Date: Tue, 12 May 2026 15:34:35 +0500 Subject: [PATCH 12/17] feat(tasks): run description tasks through contained actors Route task execution through the active actor context instead of silently falling back to a simulated actor, and add workflow-specific post-run review plumbing for recurring and triggerable description-driven tasks. 
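The delegation itself is a run-scoped ContextVar that is set at the top of an
execution context and reset in a finally block. A minimal sketch of that wiring
(the wrapper function and its arguments are hypothetical; in production the
actor's own execution path installs the delegate before
primitives.tasks.execute(...) reaches the scheduler):

    from unity.actor.code_act_actor import _CodeActTaskExecutionDelegate
    from unity.common.task_execution_context import current_task_execution_delegate

    async def run_task_in_actor_context(actor, scheduler, task_id: int):
        # Install the run-scoped delegate; TaskScheduler picks it up via
        # current_task_execution_delegate.get() instead of a fallback actor.
        token = current_task_execution_delegate.set(
            _CodeActTaskExecutionDelegate(actor),
        )
        try:
            handle = await scheduler.execute(task_id=task_id)
            return await handle.result()
        finally:
            # Always reset so the delegate cannot leak into unrelated runs.
            current_task_execution_delegate.reset(token)

Without an installed delegate (and without an explicitly configured actor),
direct execution now fails loudly instead of silently falling back to a
simulated actor.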
--- unity/actor/code_act_actor.py | 120 ++++++++++++++++++++-- unity/common/task_execution_context.py | 23 ++++- unity/task_scheduler/active_task.py | 99 +++++++++++------- unity/task_scheduler/prompt_builders.py | 85 +++++++++++++++- unity/task_scheduler/task_scheduler.py | 129 +++++++++++++++++++++--- 5 files changed, 389 insertions(+), 67 deletions(-) diff --git a/unity/actor/code_act_actor.py b/unity/actor/code_act_actor.py index b4a32b778..f1d18de25 100644 --- a/unity/actor/code_act_actor.py +++ b/unity/actor/code_act_actor.py @@ -41,6 +41,8 @@ start_async_tool_loop, ) from unity.common.task_execution_context import ( + PostRunReviewContext, + current_post_run_review_context, TaskExecutionDelegate, current_task_execution_delegate, ) @@ -365,14 +367,16 @@ async def start_task_run( """Start one task run using this actor's CodeAct execution machinery.""" _ = images + task_guidelines = kwargs.pop("guidelines", None) return await self._actor.act( task_description, + guidelines=task_guidelines, _parent_chat_context=parent_chat_context, _clarification_up_q=clarification_up_q, _clarification_down_q=clarification_down_q, entrypoint=entrypoint, persist=False, - _reuse_actor_slot=True, + _reuse_actor_slot=entrypoint is not None, **kwargs, ) @@ -381,6 +385,11 @@ async def start_task_run( # Shared storage-review prompt sections # --------------------------------------------------------------------------- +_DEFAULT_STORAGE_REVIEW_LABEL = "Storing reusable skills" +_DEFAULT_STORAGE_REVIEW_INSTRUCTIONS = ( + "Review the trajectory and store any reusable functions and compositional guidance." +) + _STORAGE_WHAT_CAN_BE_STORED = ( "## What Can Be Stored\n\n" "Any code that executed successfully in `execute_code` during " @@ -623,6 +632,7 @@ def _build_storage_tools( actor: "CodeActActor", ask_tools: dict, completed_tool_metadata: dict | None = None, + task_entrypoint_review: dict[str, Any] | None = None, ) -> tuple[Dict[str, Callable], list[str]]: """Build the tool dict shared by both post-processing and proactive storage loops. @@ -801,6 +811,45 @@ async def resume_inner_storage(tool_name: str) -> str: tools["pause_inner_storage"] = pause_inner_storage tools["resume_inner_storage"] = resume_inner_storage + if task_entrypoint_review: + attach_entrypoint = task_entrypoint_review.get("attach_entrypoint") + metadata = dict(task_entrypoint_review.get("metadata") or {}) + task_id = metadata.get("task_id") + instance_id = metadata.get("instance_id") + task_name = metadata.get("task_name") or metadata.get("name") or "the task" + + async def attach_entrypoint_to_recurring_task( + function_id: int, + rationale: str, + ) -> str: + """Attach a stored FunctionManager entrypoint to future runs of this task. + + Use this only after you have reviewed the completed trajectory and + decided that the stored function captures a stable reusable workflow + for future scheduled or triggered instances. Leaving the task + description-driven is valid when future runs still need broad + planning or tool discovery. + """ + + if not callable(attach_entrypoint): + return "No task entrypoint attachment hook is available." + return str( + attach_entrypoint( + function_id=int(function_id), + rationale=str(rationale), + ), + ) + + attach_entrypoint_to_recurring_task.__doc__ += ( + f"\n\nCurrent task: {task_name} " + f"(task_id={task_id}, completed instance_id={instance_id}). " + "The tool only patches future non-terminal instances; it never " + "rewrites the completed run." 
+ ) + tools["attach_entrypoint_to_recurring_task"] = ( + attach_entrypoint_to_recurring_task + ) + return tools, storage_active_lines @@ -819,6 +868,7 @@ def _start_storage_check_loop( parent_lineage: list[str] | None = None, stop_reason: str | None = None, proactive_summaries: list[str] | None = None, + post_run_review_context: PostRunReviewContext | None = None, ) -> "AsyncToolLoopHandle | None": """Start a loop that reviews a completed trajectory for reusable knowledge. @@ -837,11 +887,17 @@ def _start_storage_check_loop( gm = actor.guidance_manager if fm is None or gm is None: return None + task_entrypoint_review = ( + post_run_review_context.extensions.get("task_entrypoint_review") + if post_run_review_context is not None + else None + ) tools, storage_active_lines = _build_storage_tools( actor=actor, ask_tools=ask_tools, completed_tool_metadata=completed_tool_metadata, + task_entrypoint_review=task_entrypoint_review, ) # ── Build prompt ────────────────────────────────────────────────── @@ -926,6 +982,31 @@ def _start_storage_check_loop( "be worth storing.\n\n" ) + task_entrypoint_section = "" + if task_entrypoint_review: + metadata = dict(task_entrypoint_review.get("metadata") or {}) + metadata_json = json.dumps(metadata, indent=2, default=str) + task_entrypoint_section = ( + "## Recurring Task Entrypoint Review\n\n" + "This trajectory completed a scheduled or triggered task that had " + "no stored entrypoint when it ran. You must explicitly consider " + "whether the successful run revealed a stable reusable workflow " + "worth attaching to future task instances.\n\n" + "No-op is valid: keep the task description-driven if future runs " + "need broad planning, changing tool discovery, or open-ended " + "judgment. If the workflow can be stabilized as code, it may still " + "use focused `reason(...)` calls for bounded semantic substeps " + "such as summarization, classification, ranking, drafting, or " + "source selection.\n\n" + "If you store a FunctionManager function and decide it is stable " + "enough for future runs, call " + "`attach_entrypoint_to_recurring_task(function_id=..., " + "rationale=...)`. Do not call that tool unless the function has " + "already been persisted and you have the numeric function_id.\n\n" + "Task metadata:\n" + f"```json\n{metadata_json}\n```\n\n" + ) + system_prompt = ( "You are a skill librarian. A CodeActActor has just completed a task. 
" "Your job is to review the execution trajectory and decide whether " @@ -936,6 +1017,7 @@ def _start_storage_check_loop( "## Final Result\n\n" f"{original_result}\n\n" f"{stop_context_section}" + f"{task_entrypoint_section}" f"{inner_storage_section}" f"{proactive_storage_section}" f"{_STORAGE_WHAT_CAN_BE_STORED}" @@ -1095,9 +1177,11 @@ def __init__( *, inner: "AsyncToolLoopHandle", actor: "CodeActActor", + post_run_review_context: PostRunReviewContext | None = None, ) -> None: self._inner = inner self._actor = actor + self._post_run_review_context = post_run_review_context self._notification_q: asyncio.Queue[dict] = asyncio.Queue() self._task_done_event = asyncio.Event() self._completion_event = asyncio.Event() @@ -1167,9 +1251,11 @@ async def _run_lifecycle(self) -> None: try: self._original_result = await self._inner.result() + task_succeeded = not self._stopped except asyncio.CancelledError: raise except Exception as exc: + task_succeeded = False self._original_result = ( f"Error: inner task failed: {type(exc).__name__}: {exc}" ) @@ -1230,14 +1316,27 @@ async def _run_lifecycle(self) -> None: _sc_suffix_token = _PENDING_LOOP_SUFFIX.set(_sc_suffix) try: + active_review_context = ( + self._post_run_review_context if task_succeeded else None + ) + review_display_label = ( + active_review_context.display_label + if active_review_context is not None + else _DEFAULT_STORAGE_REVIEW_LABEL + ) + review_instructions = ( + active_review_context.instructions + if active_review_context is not None + else _DEFAULT_STORAGE_REVIEW_INSTRUCTIONS + ) await publish_manager_method_event( _sc_call_id, "CodeActActor", "StorageCheck", phase="incoming", - display_label="Storing reusable skills", + display_label=review_display_label, hierarchy=_sc_hierarchy, - instructions="Review the trajectory and store any reusable functions and compositional guidance.", + instructions=review_instructions, ) proactive_summaries: list[str] = [] @@ -1259,6 +1358,7 @@ async def _run_lifecycle(self) -> None: parent_lineage=_sc_parent_lineage, stop_reason=self._stop_reason, proactive_summaries=proactive_summaries or None, + post_run_review_context=active_review_context, ) if storage_handle is None: @@ -1267,7 +1367,7 @@ async def _run_lifecycle(self) -> None: "CodeActActor", "StorageCheck", phase="outgoing", - display_label="Storing reusable skills", + display_label=review_display_label, hierarchy=_sc_hierarchy, ) else: @@ -1284,7 +1384,7 @@ async def _run_lifecycle(self) -> None: "CodeActActor", "StorageCheck", phase="outgoing", - display_label="Storing reusable skills", + display_label=review_display_label, hierarchy=_sc_hierarchy, ) finally: @@ -4063,9 +4163,15 @@ async def _resume_with_propagation(**kwargs: Any) -> None: # Update agent context with handle reference new_ctx.handle = handle + post_run_review_context = current_post_run_review_context.get() + # Wrap in StorageCheckHandle for post-completion function review. 
- if effective_can_store: - handle = _StorageCheckHandle(inner=handle, actor=self) + if effective_can_store or post_run_review_context is not None: + handle = _StorageCheckHandle( + inner=handle, + actor=self, + post_run_review_context=post_run_review_context, + ) return handle diff --git a/unity/common/task_execution_context.py b/unity/common/task_execution_context.py index 78ae1bfd9..4c75f51ab 100644 --- a/unity/common/task_execution_context.py +++ b/unity/common/task_execution_context.py @@ -23,7 +23,8 @@ Here, it enables run-scoped delegation: - **Run-scoped**: a delegate is set at the top of an async execution context and - reset in a `finally` block. + reset in a `finally` block. The delegate owns how the task run is contained + inside that environment, such as starting a child actor run for one task. - **Async-safe**: `ContextVar` propagation ensures each async task tree sees the correct delegate under concurrency. - **No leakage**: callers must reset to prevent delegates persisting across runs. @@ -50,6 +51,7 @@ from __future__ import annotations import asyncio +from dataclasses import dataclass, field from contextvars import ContextVar from typing import Any, Optional, Protocol, TYPE_CHECKING, runtime_checkable @@ -74,8 +76,8 @@ class TaskExecutionDelegate(Protocol): Usage ----- This protocol is used by task execution routing to run tasks through the - same execution environment that initiated the task, rather than spawning a - new one. + execution environment that initiated the task while preserving one task run + per returned handle. """ async def start_task_run( @@ -125,3 +127,18 @@ async def start_task_run( "current_task_execution_delegate", default=None, ) + + +@dataclass(frozen=True) +class PostRunReviewContext: + """Run-scoped metadata for an optional post-completion storage review.""" + + display_label: str + instructions: str + extensions: dict[str, Any] = field(default_factory=dict) + + +current_post_run_review_context: ContextVar[PostRunReviewContext | None] = ContextVar( + "current_post_run_review_context", + default=None, +) diff --git a/unity/task_scheduler/active_task.py b/unity/task_scheduler/active_task.py index 1349b673b..16e4276c4 100644 --- a/unity/task_scheduler/active_task.py +++ b/unity/task_scheduler/active_task.py @@ -18,7 +18,11 @@ from .base import BaseActiveTask from ..actor.base import BaseActor from unity.common.async_tool_loop import SteerableToolHandle -from unity.common.task_execution_context import current_task_execution_delegate +from unity.common.task_execution_context import ( + PostRunReviewContext, + current_post_run_review_context, + current_task_execution_delegate, +) from unity.common._async_tool.messages import forward_handle_call from .machine_state import ( TaskRunProvenance, @@ -167,6 +171,8 @@ async def create( entrypoint: Optional[int] = None, task_run_reference: Optional[TaskRunReference] = None, task_run_provenance: Optional[TaskRunProvenance] = None, + task_entrypoint_review: Optional[dict[str, Any]] = None, + task_guidelines: Optional[str] = None, ) -> "ActiveTask": """ Create an ActiveTask by starting work through a delegate or fallback actor. @@ -177,44 +183,63 @@ async def create( because execution is routed through the delegate instead. 
""" delegate = current_task_execution_delegate.get() + review_token = None + if task_entrypoint_review is not None: + review_token = current_post_run_review_context.set( + PostRunReviewContext( + display_label="Storing reusable workflow", + instructions=( + "Review the successful task trajectory and decide whether " + "a stable reusable workflow should be stored and attached " + "to future scheduled or triggered task instances." + ), + extensions={"task_entrypoint_review": task_entrypoint_review}, + ), + ) try: - if delegate is not None: - actor_steerable_handle = await delegate.start_task_run( - task_description=task_description, - entrypoint=entrypoint, - parent_chat_context=_parent_chat_context, - clarification_up_q=_clarification_up_q, - clarification_down_q=_clarification_down_q, - ) - else: - if fallback_actor is None: - raise RuntimeError( - "Task execution requires an actor when no run-scoped delegate is active.", + try: + if delegate is not None: + actor_steerable_handle = await delegate.start_task_run( + task_description=task_description, + entrypoint=entrypoint, + parent_chat_context=_parent_chat_context, + clarification_up_q=_clarification_up_q, + clarification_down_q=_clarification_down_q, + guidelines=task_guidelines, ) - actor_steerable_handle = await fallback_actor.act( - task_description, - _parent_chat_context=_parent_chat_context, - _clarification_up_q=_clarification_up_q, - _clarification_down_q=_clarification_down_q, - # Always pass entrypoint to the actor so it can immediately run the function - entrypoint=entrypoint, - persist=False, # Scheduler-run plans should complete instead of pausing for interjection - ) - except Exception as exc: - if task_run_reference is not None: - await asyncio.to_thread( - update_task_run_record, - task_run_reference, - { - "state": "failed", - "completed_at": _now_iso(), - "error": _truncate_task_run_text(str(exc)), - "result_summary": _truncate_task_run_text( - f"Task failed before execution fully started: {type(exc).__name__}({exc})", - ), - }, - ) - raise + else: + if fallback_actor is None: + raise RuntimeError( + "Task execution requires an actor when no run-scoped delegate is active.", + ) + actor_steerable_handle = await fallback_actor.act( + task_description, + guidelines=task_guidelines, + _parent_chat_context=_parent_chat_context, + _clarification_up_q=_clarification_up_q, + _clarification_down_q=_clarification_down_q, + # Always pass entrypoint to the actor so it can immediately run the function + entrypoint=entrypoint, + persist=False, # Scheduler-run plans should complete instead of pausing for interjection + ) + except Exception as exc: + if task_run_reference is not None: + await asyncio.to_thread( + update_task_run_record, + task_run_reference, + { + "state": "failed", + "completed_at": _now_iso(), + "error": _truncate_task_run_text(str(exc)), + "result_summary": _truncate_task_run_text( + f"Task failed before execution fully started: {type(exc).__name__}({exc})", + ), + }, + ) + raise + finally: + if review_token is not None: + current_post_run_review_context.reset(review_token) materialized_task_run_reference = task_run_reference if materialized_task_run_reference is None and task_run_provenance is not None: try: diff --git a/unity/task_scheduler/prompt_builders.py b/unity/task_scheduler/prompt_builders.py index 780d70341..6c8fd6430 100644 --- a/unity/task_scheduler/prompt_builders.py +++ b/unity/task_scheduler/prompt_builders.py @@ -8,17 +8,15 @@ from __future__ import annotations +import json from typing import Dict, 
Callable, Union, List from .types.task import Task +from .types.activated_by import ActivatedBy from ..common.prompt_helpers import ( - clarification_guidance, - sig_dict, - now, tool_name, require_tools, get_custom_columns, - # New standardized composer utilities PromptSpec, PromptParts, compose_system_prompt, @@ -33,6 +31,69 @@ # ───────────────────────────────────────────────────────────────────────────── +def build_task_execution_request(task: Task) -> str: + """Build the actor-facing request for one task instance.""" + + lines = [ + "Execute this TaskScheduler task as a contained task run.", + "", + f"Task id: {task.task_id}", + f"Instance id: {task.instance_id}", + f"Task name: {task.name}", + "", + "Task description:", + task.description or task.name, + ] + if task.response_policy: + lines.extend(["", "Task response policy:", task.response_policy]) + if task.schedule is not None: + lines.extend( + [ + "", + "Schedule metadata:", + json.dumps(task.schedule.model_dump(mode="json"), default=str), + ], + ) + if task.trigger is not None: + lines.extend( + [ + "", + "Trigger metadata:", + json.dumps(task.trigger.model_dump(mode="json"), default=str), + ], + ) + if task.repeat is not None: + lines.extend( + [ + "", + "Repeat metadata:", + json.dumps( + [r.model_dump(mode="json") for r in task.repeat], + default=str, + ), + ], + ) + return "\n".join(lines) + + +def build_task_run_guidelines(task: Task, reason: ActivatedBy) -> str: + """Build execution guidelines for a contained actor task run.""" + + return ( + "You are executing exactly one TaskScheduler task. Treat the task " + "name, description, schedule, trigger, repeat, and response policy " + "as the authoritative instruction for this run. Complete the task " + "itself; do not create another task unless the task description " + "explicitly asks you to create or modify tasks. If this task has no " + "stored entrypoint, interpret the natural-language description " + "directly using the available primitives and functions. Keep any " + "progress notifications focused on this task run.\n\n" + f"Activation reason: {reason.value}\n" + f"Task id: {task.task_id}\n" + f"Instance id: {task.instance_id}" + ) + + def build_ask_prompt( tools: Dict[str, Callable], num_tasks: int, @@ -360,6 +421,21 @@ def build_update_prompt( usage_examples_lines.extend( [ + "", + "Recurring and triggered workflows", + "---------------------------------", + '• For requests like "do this every Monday" or "send this report daily", create a live scheduled task with `schedule.start_at` for the first run and `repeat` for the cadence.', + "• For requests like \"whenever Alice emails about invoices\", create a live triggerable task with `trigger` and status 'triggerable'. Use contact lookup first when the trigger references a person.", + "• A scheduled/triggered live task may have `entrypoint=None`. This is the normal default for newly described natural-language workflows: execution will wake a contained actor run that interprets the description.", + "• Do not create an entrypoint function merely because a recurring task is being created. 
Entrypoint creation should follow an explicit user request or a successful run that has been reviewed as stable enough to store.", + "• If the user asks to repeat a workflow that just succeeded interactively and also wants hidden/offline execution, the workflow must first be stored as a function-backed skill; offline tasks require a numeric `entrypoint`.", + "• A stored entrypoint can still call `reason(...)` for bounded semantic judgment such as summarization, classification, ranking, drafting, or source selection. Keep broad planning and changing tool discovery actor-driven.", + "", + "Repeat field examples", + "---------------------", + "• Daily at a fixed time: set `schedule.start_at` to the first due datetime and `repeat=[{'frequency':'daily','interval':1}]`.", + "• Weekly on Monday at 12:00 UTC: set first `schedule.start_at` to the next Monday 12:00 UTC and `repeat=[{'frequency':'weekly','interval':1,'weekdays':['MO'],'time_of_day':'12:00'}]`.", + "• End after N runs: include `count`. End after a date: include `until`.", "", "Schedule/Queue invariants (must-follow)", "---------------------------------------", @@ -385,6 +461,7 @@ def build_update_prompt( "Triggers vs Schedules", "----------------------", f"• A task with a `trigger` must be in state 'triggerable'. Use `{update_task_fname}(task_id=, trigger=...)` to add/remove triggers. Do not set `start_at` on trigger‑based tasks.", + "• `schedule` and `trigger` are mutually exclusive. Use `repeat` with `schedule` for cadence-based tasks; use `trigger` for inbound-event tasks.", ], ) diff --git a/unity/task_scheduler/task_scheduler.py b/unity/task_scheduler/task_scheduler.py index caf816d03..0c9c7b9ec 100644 --- a/unity/task_scheduler/task_scheduler.py +++ b/unity/task_scheduler/task_scheduler.py @@ -19,7 +19,6 @@ from typing import Literal, overload from pydantic import BaseModel from dataclasses import dataclass -from functools import cached_property from ..settings import SETTINGS from ..common.embed_utils import ensure_vector_column @@ -64,11 +63,12 @@ from ..common.model_to_fields import model_to_fields from .prompt_builders import ( build_ask_prompt, + build_task_execution_request, + build_task_run_guidelines, build_update_prompt, ) from .base import BaseTaskScheduler from ..actor.base import BaseActor -from ..actor.simulated import SimulatedActor from .active_task import ActiveTask from .active_queue import ActiveQueue @@ -155,8 +155,9 @@ def __init__( Parameters ---------- actor : BaseActor | None, default ``None`` - Actor used to execute the steps of an active task. When ``None``, a - ``SimulatedActor(duration=20)`` is used. + Explicit fallback actor used for direct scheduler execution when no + run-scoped execution delegate is active. When ``None``, direct + execution fails loudly instead of creating an implicit actor. rolling_summary_in_prompts : bool, default ``True`` Whether to inject the rolling activity summary into system prompts sent to the LLM. @@ -296,18 +297,105 @@ def __init__( # this cache remains coherent without extra backend reads between tool calls. 
self._num_tasks_cached: Optional[int] = None - @cached_property - def _actor(self) -> BaseActor: - if self.__actor is None: - self.__actor = SimulatedActor(duration=SETTINGS.task.SIM_ACTOR_DURATION) - return self.__actor - def _actor_for_task_run(self) -> BaseActor | None: """Return the fallback actor only when task execution is not delegated.""" if current_task_execution_delegate.get() is not None: return None - return self._actor + return self.__actor + + def _build_task_entrypoint_review( + self, + *, + task: Task, + reason: ActivatedBy, + ) -> dict[str, Any] | None: + """Return post-run entrypoint review context for description-driven tasks.""" + + if task.entrypoint is not None: + return None + if task.repeat is None and task.trigger is None: + return None + + metadata: dict[str, Any] = { + "task_id": task.task_id, + "instance_id": task.instance_id, + "task_name": task.name, + "task_description": task.description, + "activation_reason": reason.value, + "response_policy": task.response_policy, + "schedule": ( + task.schedule.model_dump(mode="json") + if task.schedule is not None + else None + ), + "trigger": ( + task.trigger.model_dump(mode="json") + if task.trigger is not None + else None + ), + "repeat": ( + [pattern.model_dump(mode="json") for pattern in task.repeat] + if task.repeat is not None + else None + ), + } + + def _attach_entrypoint(*, function_id: int, rationale: str) -> dict[str, Any]: + return self._attach_entrypoint_to_future_instances( + task_id=task.task_id, + completed_instance_id=task.instance_id, + function_id=function_id, + rationale=rationale, + ) + + return { + "metadata": metadata, + "attach_entrypoint": _attach_entrypoint, + } + + def _attach_entrypoint_to_future_instances( + self, + *, + task_id: int, + completed_instance_id: int, + function_id: int, + rationale: str, + ) -> dict[str, Any]: + """Attach an entrypoint to future non-terminal instances of a logical task.""" + + if function_id < 0: + raise ValueError("function_id must be a non-negative integer.") + future_logs = self._view.get_rows( + filter=( + f"task_id == {task_id} and instance_id > {completed_instance_id} " + "and entrypoint is None and status not in ('completed','cancelled','failed','active')" + ), + return_ids_only=False, + ) + if not future_logs: + return { + "outcome": "no_future_instances", + "task_id": task_id, + "completed_instance_id": completed_instance_id, + "function_id": function_id, + "rationale": rationale, + } + + log_ids = [log.id for log in future_logs] + self._write_log_entries( + logs=log_ids, + entries={"entrypoint": int(function_id)}, + ) + return { + "outcome": "attached", + "task_id": task_id, + "patched_instance_ids": [ + log.entries.get("instance_id") for log in future_logs + ], + "function_id": int(function_id), + "rationale": rationale, + } # ------------------------------ Provisioning ----------------------------- # def warm_embeddings(self) -> None: @@ -746,6 +834,13 @@ async def _execute_internal( else: reason = ActivatedBy.explicit + fallback_actor = self._actor_for_task_run() + if fallback_actor is None and current_task_execution_delegate.get() is None: + raise RuntimeError( + "TaskScheduler.execute requires a run-scoped actor delegate or an explicit actor. 
" + "Description-driven tasks should be executed from Actor.act via primitives.tasks.execute(...).", + ) + task_run_source_type = ( "triggered" if trigger_attempt_token @@ -770,12 +865,9 @@ async def _execute_internal( unlink_from_prev=unlink_from_prev, ) - # Start task execution (delegated to the current execution environment when available) - # and wrap the resulting handle for Tasks-table synchronization. - handle = await ActiveTask.create( - self._actor_for_task_run(), - task_description=task.description or task.name, + fallback_actor, + task_description=build_task_execution_request(task), _parent_chat_context=parent_chat_context, _clarification_up_q=clarification_up_q, _clarification_down_q=clarification_down_q, @@ -784,6 +876,11 @@ async def _execute_internal( scheduler=self, entrypoint=task.entrypoint, task_run_provenance=task_run_provenance, + task_entrypoint_review=self._build_task_entrypoint_review( + task=task, + reason=reason, + ), + task_guidelines=build_task_run_guidelines(task, reason), ) self._active_task = TaskScheduler.ActivePointer( From ee5fbb65bbbf6b9eba3d71d216109e979a8d0670 Mon Sep 17 00:00:00 2001 From: Haris Mahmood Date: Tue, 12 May 2026 15:34:39 +0500 Subject: [PATCH 13/17] fix(tasks): preserve schedule as generic dict storage Mark schedule payloads with explicit dict typing so queue linkage and datetime schedules can coexist without backend type inference conflicts. --- unity/task_scheduler/storage.py | 27 ++++++++++++++++++++++++--- unity/task_scheduler/types/task.py | 12 ++++++++++-- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/unity/task_scheduler/storage.py b/unity/task_scheduler/storage.py index c708b5ce1..06a7b8b5b 100644 --- a/unity/task_scheduler/storage.py +++ b/unity/task_scheduler/storage.py @@ -357,6 +357,22 @@ def _norm(v: Any) -> Any: return [TasksStore._norm(x) for x in v] return v + @staticmethod + def _with_explicit_task_types(entries: Any) -> Any: + if isinstance(entries, list): + return [TasksStore._with_explicit_task_types(entry) for entry in entries] + if not isinstance(entries, dict): + return entries + if entries.get("schedule") is None: + return entries + out = dict(entries) + explicit_types = dict(out.get("explicit_types") or {}) + schedule_types = dict(explicit_types.get("schedule") or {}) + schedule_types["type"] = "dict" + explicit_types["schedule"] = schedule_types + out["explicit_types"] = explicit_types + return out + # ------------------------------- Writes -------------------------------- def update( self, @@ -392,7 +408,9 @@ def _strip_nones(value: Any, *, top_level: bool) -> Any: ] return value - norm_entries = _strip_nones(TasksStore._norm(entries), top_level=True) + norm_entries = TasksStore._with_explicit_task_types( + _strip_nones(TasksStore._norm(entries), top_level=True), + ) return unify.update_logs( logs=logs, context=self._ctx, @@ -401,7 +419,7 @@ def _strip_nones(value: Any, *, top_level: bool) -> Any: ) def log(self, *, entries: Dict[str, Any], new: bool = True) -> unify.Log: - norm_entries = TasksStore._norm(entries) + norm_entries = TasksStore._with_explicit_task_types(TasksStore._norm(entries)) # Create with expanded fields so auto-counting applies when ids are omitted return unity_log( project=self._project, @@ -420,7 +438,10 @@ def create_many(self, *, entries_list: List[Dict[str, Any]]) -> Dict[str, Any]: with auto-incremented row identifiers. 
""" - normalised = [{**TasksStore._norm(e)} for e in entries_list] + normalised = [ + TasksStore._with_explicit_task_types({**TasksStore._norm(e)}) + for e in entries_list + ] try: return unity_create_logs( project=self._project, diff --git a/unity/task_scheduler/types/task.py b/unity/task_scheduler/types/task.py index 2f3ae42ba..a8708d633 100644 --- a/unity/task_scheduler/types/task.py +++ b/unity/task_scheduler/types/task.py @@ -30,6 +30,7 @@ class TaskBase(BaseModel): schedule: Optional[Schedule] = Field( default=None, description="Information about task scheduling, including adjacent tasks in the queue and ideal start time", + json_schema_extra={"unify_type": "dict"}, ) trigger: Optional[Trigger] = Field( default=None, @@ -41,7 +42,11 @@ class TaskBase(BaseModel): ) repeat: Optional[List[RepeatPattern]] = Field( default=None, - description="Pattern defining how the task recurs over time", + description=( + "Pattern defining how the task recurs over time. Recurring live tasks " + "may begin with entrypoint=null and execute from the natural-language " + "description until a post-run review stores a stable function." + ), ) priority: Priority = Field( description="Importance level of the task (low, normal, high, urgent)", @@ -58,7 +63,10 @@ class TaskBase(BaseModel): default=None, description=( "Optional function_id from the Functions table that should be invoked to perform this task. " - "When null, the task is executed by an Actor interpreting the free-form description on the fly." + "When null, a live task is executed by a contained Actor run interpreting the free-form " + "description on the fly. Do not set this for a newly described workflow unless the user " + "explicitly asks for a stored function-backed workflow or a successful execution has been " + "reviewed and distilled into a stable function." ), ) offline: bool = Field( From 7ce27ad91e7bd1eb5136ec2f8a9e39d6aeb87f74 Mon Sep 17 00:00:00 2001 From: Haris Mahmood Date: Tue, 12 May 2026 15:34:44 +0500 Subject: [PATCH 14/17] docs(tasks): clarify live recurring workflow semantics Teach actor and scheduler prompts that new scheduled or triggered workflows should usually remain live and description-driven unless a stored function is explicitly requested or later distilled. --- unity/actor/prompt_builders.py | 31 +++++++++++++++++++++++++++++++ unity/actor/prompt_examples.py | 33 +++++++++++++++++++++++++++++++++ unity/task_scheduler/README.md | 24 ++++++++++++++++++++++++ unity/task_scheduler/base.py | 29 ++++++++++++++++++++++++----- 4 files changed, 112 insertions(+), 5 deletions(-) diff --git a/unity/actor/prompt_builders.py b/unity/actor/prompt_builders.py index a6d59f121..226ac6b3c 100644 --- a/unity/actor/prompt_builders.py +++ b/unity/actor/prompt_builders.py @@ -464,6 +464,34 @@ straight to `compress_context`. """).strip() +_TASK_SCHEDULING_WORKFLOWS = textwrap.dedent(""" + ### Durable Scheduled And Triggered Workflows + + When the user asks for work to happen later, repeatedly, or in response to + future inbound events, represent that durable intent with the task + primitives rather than only doing the work once. + + Use `primitives.tasks.update(...)` for requests like: + - "Repeat this every Monday at 12:00 UTC" + - "Send me this report every day" + - "Whenever Alice emails about invoices, summarize it and draft a reply" + - "Turn what we just did into a recurring workflow" + + Natural-language recurring tasks should normally start as live + description-driven tasks with `entrypoint=None`. 
The future due wake will + call `primitives.tasks.execute(task_id=...)`; execution then runs a + contained child actor dedicated to that task. Do not write and attach an + untested entrypoint function at task creation unless the user explicitly + requested a stored function-backed workflow. + + If a workflow has just been completed interactively and the user wants it + repeated, include the relevant context in the task description. Use + `store_skills` or direct FunctionManager writes only when the user asks to + store the workflow, or when the completed trajectory clearly reveals a + reusable function worth saving. Offline tasks require a stored entrypoint; + description-only recurring work should remain live. +""").strip() + _EXTERNAL_APP_INTEGRATION = textwrap.dedent(""" ### External App Integration @@ -745,12 +773,15 @@ def _build_code_act_rules_and_examples( _has_computer = any( k.startswith("primitives.computer.") for k in env.get_tools() ) + _has_tasks = any(k.startswith("primitives.tasks.") for k in env.get_tools()) _has_state = any( k.startswith("primitives.") and not k.startswith("primitives.computer.") and not k.startswith("primitives.actor.") for k in env.get_tools() ) + if _has_tasks: + parts.append(_TASK_SCHEDULING_WORKFLOWS) if _has_computer and _has_state: from unity.actor.prompt_examples import get_mixed_examples diff --git a/unity/actor/prompt_examples.py b/unity/actor/prompt_examples.py index b8782531f..c8475baf6 100644 --- a/unity/actor/prompt_examples.py +++ b/unity/actor/prompt_examples.py @@ -901,6 +901,38 @@ class TaskIdResult(BaseModel): ''' +def get_primitives_task_recurring_creation_example() -> str: + """Example: creating durable scheduled and triggered tasks.""" + + return """ +# Example: durable recurring and triggered workflow creation +async def create_description_driven_recurring_tasks() -> str: + # User: "Every Monday at 12:00 UTC, research AI/agentic AI work from + # the last week and email me a summary document." + scheduled = await primitives.tasks.update( + "Create a live scheduled recurring task. Name: Weekly AI research report. " + "Description: Every Monday at 12:00 UTC, research important AI and agentic AI " + "work from the previous week, summarize the most important developments, " + "create a concise document, and email it to me. Set the first start_at to " + "the next Monday 12:00 UTC and repeat weekly on Monday at 12:00 UTC. " + "Leave entrypoint as null unless there is already a proven stored function. " + "Do not mark it offline." + ) + scheduled_result = await scheduled.result() + + # User: "Whenever Alice emails about invoices, summarize it and draft a reply." + triggered = await primitives.tasks.update( + "Create a live triggerable task. Name: Alice invoice email follow-up. " + "Description: When Alice emails about invoices, summarize the inbound email, " + "identify what action is needed, and draft a reply for review. Resolve Alice " + "to the right contact id before setting trigger filters. Leave entrypoint as " + "null; this should wake a live actor to interpret the description." 
+ ) + triggered_result = await triggered.result() + return f"{scheduled_result}\\n{triggered_result}" +""" + + def get_primitives_dynamic_methods_example() -> str: """Example: using dynamic handle methods.""" @@ -1929,6 +1961,7 @@ def get_example_function_map() -> dict[str, callable]: "get_primitives_contact_update_example": get_primitives_contact_update_example, # Tasks "get_primitives_task_execute_example": get_primitives_task_execute_example, + "get_primitives_task_recurring_creation_example": get_primitives_task_recurring_creation_example, "get_primitives_dynamic_methods_example": get_primitives_dynamic_methods_example, # Knowledge "get_primitives_knowledge_ask_example": get_primitives_cross_manager_example, diff --git a/unity/task_scheduler/README.md b/unity/task_scheduler/README.md index 4a93d3a24..5cb72489e 100644 --- a/unity/task_scheduler/README.md +++ b/unity/task_scheduler/README.md @@ -72,6 +72,23 @@ This package manages the creation, scheduling, execution, and re‑ordering of t 3) Execute (run now) - Guards single‑active. If given a numeric id, can run in isolation (detach, followers keep schedule) or as a chain (preserve links). This path does not use an async LLM tool loop or an execute system prompt; it returns an `ActiveQueue` handle (direct delegation for isolated/single‑task). +4) Scheduled activation + - User-authored scheduled task rows are projected by Orchestra into machine-facing activation rows. + - Communication materializes scheduled live activations as Cloud Tasks targeting the adapters `/scheduled/tasks/due` endpoint. + - The live wake reason is delivered to ConversationManager, which asks the slow brain to start with `primitives.tasks.execute(task_id=...)`. + - Cloud Scheduler is used for platform maintenance jobs; per-task cadence is delivered by dynamic Cloud Tasks. + +5) Trigger activation + - Trigger definitions are projected into activation rows and mechanically matched by medium/contact filters when inbound communication events arrive. + - Live trigger candidates are surfaced to the slow brain, which performs semantic acceptance and calls `primitives.tasks.execute(task_id=..., trigger_attempt_token=...)` so the run adopts the exact inbound provenance. + - Recurring triggerable tasks clone a future triggerable instance before the current instance is marked active. + +6) Offline activation + - Offline means the hidden headless lane: the live ConversationManager and main actor are not woken. + - Offline scheduled activations use Cloud Tasks targeting Communication's offline-dispatch endpoint, which creates a short-lived Unity Kubernetes job. + - The job runs `offline_runner.py`, which executes exactly one stored FunctionManager entrypoint through `SingleFunctionActor(headless=True)`. + - Offline tasks require an entrypoint. Description-only tasks should remain live unless a later successful run is distilled into a stored function. + ### Queue/schedule invariants (enforced centrally) @@ -100,6 +117,13 @@ This package manages the creation, scheduling, execution, and re‑ordering of t - `ActiveTask`: internal steerable handle for a single running task; mirrors status and clears the scheduler’s active pointer when done. - `ActiveQueue`: public execution handle that sequences tasks using persisted `next_task` links, supports interjection routing across the queue, and provides a completion summary. Uses direct delegation when the queue is a singleton/isolated. +### Entrypoints and description-driven execution + +- `entrypoint` is optional for live tasks. 
When it is null, execution is actor-driven: a contained child actor run interprets the task name, description, schedule/trigger metadata, repeat pattern, and response policy. +- `entrypoint` is required for offline tasks because the headless lane executes one stored function without booting the live assistant runtime. +- Direct `TaskScheduler.execute(...)` needs either a run-scoped actor delegate or an explicitly configured actor. A production live wake normally reaches execution through `Actor.act` and `primitives.tasks.execute(...)`; tests can still inject a simulated actor explicitly. +- After a successful recurring or triggerable description-driven run, the actor always runs a storage review that considers whether the observed trajectory is stable enough to store as a function. The write is conditional: if future runs still need broad planning or tool discovery, the task remains description-driven. Stored functions may still use focused `reason(...)` calls for bounded judgment. + ### Clarification and contacts diff --git a/unity/task_scheduler/base.py b/unity/task_scheduler/base.py index 9ef4a76ba..cbb6de648 100644 --- a/unity/task_scheduler/base.py +++ b/unity/task_scheduler/base.py @@ -246,6 +246,24 @@ async def update( If the task is to be started *immediately*, then just put the current datetime as the `start_at`, and omit the deadline if one is not specified. + Entrypoints, live tasks, and offline tasks + ----------------------------------------- + A live scheduled or triggered task may start with ``entrypoint=None``. + In that case, execution wakes a contained actor run that interprets the + task's natural-language name/description and metadata. This is the + normal default for newly described recurring workflows. + + Offline tasks run in the hidden headless lane and must have a numeric + entrypoint before ``offline=True`` is set. Do not create a description-only + offline task. + + Do not create an entrypoint function merely because a new recurring task + was described. Entrypoint persistence should follow an explicit user + request or a successful execution reviewed as stable enough to store. + Stored functions may still use focused ``reason(...)`` calls for bounded + semantic judgment, but open-ended planning/tool discovery should remain + actor-driven. + All parameters mirror :pymeth:`ask`; refer there for detailed semantics. """ @@ -290,11 +308,12 @@ async def execute( Execution delegation -------------------- - When a run-scoped execution environment is available, task execution may be - delegated to that environment to maintain context continuity. Otherwise, - execution proceeds through the scheduler's configured execution strategy. - In both cases, a live steerable handle is returned that supports the full - steering interface. + When a run-scoped execution environment is available, task execution is + delegated through that environment while keeping one returned handle per + task run. A task without an entrypoint is executed by a contained actor + run dedicated to that task. Without a run-scoped delegate, direct + execution requires an explicitly configured actor; otherwise execution + fails loudly instead of silently using a simulated fallback. 
Returns ------- From 407a05f7edd6ac88ba0279e99c2ff02700953d4b Mon Sep 17 00:00:00 2001 From: Haris Mahmood Date: Tue, 12 May 2026 15:34:48 +0500 Subject: [PATCH 15/17] test(actor): cover task workflow execution guidance Verify child actor slot selection, reusable workflow review labeling, and real actor creation of live recurring and triggerable tasks with null entrypoints. --- tests/actor/code_act/test_prompt_builders.py | 38 ++++++++++ tests/actor/code_act/test_storage_on_stop.py | 71 +++++++++++++++++ .../code_act/test_task_execution_delegate.py | 57 +++++++++++++- .../tasks/test_recurring_creation_code_act.py | 76 +++++++++++++++++++ 4 files changed, 241 insertions(+), 1 deletion(-) create mode 100644 tests/actor/state_managers/real/tasks/test_recurring_creation_code_act.py diff --git a/tests/actor/code_act/test_prompt_builders.py b/tests/actor/code_act/test_prompt_builders.py index 13401f9bb..1552b785f 100644 --- a/tests/actor/code_act/test_prompt_builders.py +++ b/tests/actor/code_act/test_prompt_builders.py @@ -32,6 +32,17 @@ def get_tools(self) -> dict: return {} +class _DummyToolEnv(_DummyEnv): + """Minimal environment stub with configurable tool names.""" + + def __init__(self, prompt_context: str, tools: dict[str, Any]): + super().__init__(prompt_context) + self._tools = tools + + def get_tools(self) -> dict: + return self._tools + + def _real_envs_mixed() -> Mapping[str, Any]: """Real environments that produce self-contained prompt context.""" from unity.function_manager.primitives import ComputerPrimitives @@ -111,6 +122,33 @@ def test_code_act_prompt_includes_diverse_examples_sessions_computer_primitives_ assert "execute_function vs execute_code decision" in prompt +@pytest.mark.timeout(30) +def test_code_act_prompt_includes_task_workflow_guidance_only_with_task_primitives(): + prompt_with_tasks = build_code_act_prompt( + environments={ + "primitives": _DummyToolEnv( + "Task primitives are available.", + {"primitives.tasks.update": object()}, + ), + }, + tools={}, + ) + prompt_without_tasks = build_code_act_prompt( + environments={ + "primitives": _DummyToolEnv( + "Only contact primitives are available.", + {"primitives.contacts.ask": object()}, + ), + }, + tools={}, + ) + + assert "Durable Scheduled And Triggered Workflows" in prompt_with_tasks + assert "`entrypoint=None`" in prompt_with_tasks + assert "primitives.tasks.execute(task_id=...)" in prompt_with_tasks + assert "Durable Scheduled And Triggered Workflows" not in prompt_without_tasks + + @pytest.mark.timeout(30) def test_code_act_prompt_teaches_refresh_token_oauth_helper(): actor = CodeActActor() diff --git a/tests/actor/code_act/test_storage_on_stop.py b/tests/actor/code_act/test_storage_on_stop.py index 725bc06fd..7121d05d5 100644 --- a/tests/actor/code_act/test_storage_on_stop.py +++ b/tests/actor/code_act/test_storage_on_stop.py @@ -19,6 +19,7 @@ import pytest from unity.actor.code_act_actor import CodeActActor, _StorageCheckHandle +from unity.common.task_execution_context import PostRunReviewContext # --------------------------------------------------------------------------- # Symbolic: _StorageCheckHandle runs Phase 2 after stop @@ -230,6 +231,76 @@ async def _stop(**kwargs): ), f"StorageCheck incoming event must have a non-null instructions kwarg, got {instructions!r}" +@pytest.mark.asyncio +@pytest.mark.timeout(30) +async def test_task_entrypoint_review_uses_reusable_workflow_event_label(): + result_future: asyncio.Future[str] = asyncio.get_event_loop().create_future() + + inner = MagicMock() + + async def 
_await_result(): + return await result_future + + inner.result = _await_result + inner.next_notification = AsyncMock(side_effect=lambda: asyncio.Event().wait()) + + async def _stop(**kwargs): + if not result_future.done(): + result_future.set_result("done") + + inner.stop = AsyncMock(side_effect=_stop) + inner._client = MagicMock(messages=[{"role": "user", "content": "do task"}]) + + mock_task = MagicMock() + mock_task.get_ask_tools = MagicMock(return_value={}) + mock_task.get_completed_tool_metadata = MagicMock(return_value={}) + inner._task = mock_task + + actor = MagicMock() + actor.function_manager = None + actor.guidance_manager = None + + review_context = PostRunReviewContext( + display_label="Storing reusable workflow", + instructions="Review the completed recurring workflow.", + extensions={"task_entrypoint_review": {"metadata": {"task_id": 1}}}, + ) + + with ( + patch("unity.actor.code_act_actor._start_storage_check_loop") as mock_loop, + patch( + "unity.actor.code_act_actor.publish_manager_method_event", + new_callable=AsyncMock, + ) as mock_publish, + ): + mock_loop.return_value = None + handle = _StorageCheckHandle( + inner=inner, + actor=actor, + post_run_review_context=review_context, + ) + + result_future.set_result("done") + + deadline = asyncio.get_event_loop().time() + 10 + while not handle.done(): + if asyncio.get_event_loop().time() > deadline: + raise TimeoutError("Handle did not complete") + await asyncio.sleep(0.1) + + incoming_calls = [ + call + for call in mock_publish.call_args_list + if call.kwargs.get("phase") == "incoming" and call.args[2] == "StorageCheck" + ] + assert incoming_calls[0].kwargs["display_label"] == "Storing reusable workflow" + assert ( + incoming_calls[0].kwargs["instructions"] + == "Review the completed recurring workflow." 
+ ) + assert mock_loop.call_args.kwargs["post_run_review_context"] is review_context + + # --------------------------------------------------------------------------- # Eval: persist=True + stop with memoize intent stores a function # --------------------------------------------------------------------------- diff --git a/tests/actor/code_act/test_task_execution_delegate.py b/tests/actor/code_act/test_task_execution_delegate.py index 6fc925a00..17b205904 100644 --- a/tests/actor/code_act/test_task_execution_delegate.py +++ b/tests/actor/code_act/test_task_execution_delegate.py @@ -5,8 +5,9 @@ import pytest from tests.helpers import _handle_project -from unity.actor.code_act_actor import CodeActActor +from unity.actor.code_act_actor import CodeActActor, _CodeActTaskExecutionDelegate from unity.actor.environments.state_managers import StateManagerEnvironment +from unity.actor.simulated import SimulatedActor from unity.common.task_execution_context import current_task_execution_delegate from unity.function_manager.function_manager import FunctionManager from unity.function_manager.primitives import PrimitiveScope, Primitives @@ -15,6 +16,60 @@ from unity.task_scheduler.types.status import Status +@pytest.mark.asyncio +async def test_codeact_task_delegate_runs_description_tasks_in_child_actor_slot(): + calls = [] + actor = SimulatedActor(steps=0) + + original_act = actor.act + + async def _spy_act(*args, **kwargs): + calls.append(kwargs) + return await original_act(*args, **kwargs) + + actor.act = _spy_act # type: ignore[method-assign] + delegate = _CodeActTaskExecutionDelegate(actor) # type: ignore[arg-type] + + handle = await delegate.start_task_run( + task_description="Run the description-driven task.", + entrypoint=None, + parent_chat_context=None, + clarification_up_q=None, + clarification_down_q=None, + ) + await handle.result() + + assert calls[0]["_reuse_actor_slot"] is False + assert calls[0]["persist"] is False + + +@pytest.mark.asyncio +async def test_codeact_task_delegate_reuses_actor_slot_for_entrypoint_tasks(): + calls = [] + actor = SimulatedActor(steps=0) + + original_act = actor.act + + async def _spy_act(*args, **kwargs): + calls.append(kwargs) + return await original_act(*args, **kwargs) + + actor.act = _spy_act # type: ignore[method-assign] + delegate = _CodeActTaskExecutionDelegate(actor) # type: ignore[arg-type] + + handle = await delegate.start_task_run( + task_description="Run the function-backed task.", + entrypoint=123, + parent_chat_context=None, + clarification_up_q=None, + clarification_down_q=None, + ) + await handle.result() + + assert calls[0]["_reuse_actor_slot"] is True + assert calls[0]["entrypoint"] == 123 + + @pytest.mark.asyncio @pytest.mark.llm_call @pytest.mark.timeout(120) diff --git a/tests/actor/state_managers/real/tasks/test_recurring_creation_code_act.py b/tests/actor/state_managers/real/tasks/test_recurring_creation_code_act.py new file mode 100644 index 000000000..943a4408b --- /dev/null +++ b/tests/actor/state_managers/real/tasks/test_recurring_creation_code_act.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import pytest + +from tests.actor.state_managers.utils import make_code_act_actor +from unity.task_scheduler.types.status import Status + +pytestmark = [pytest.mark.eval, pytest.mark.llm_call] + + +@pytest.mark.asyncio +@pytest.mark.timeout(300) +async def test_code_act_creates_live_recurring_task_with_null_entrypoint(): + async with make_code_act_actor(impl="real", exposed_managers={"tasks"}) as ( + actor, + primitives, + calls, 
+ ): + handle = await actor.act( + ( + "Create exactly one live scheduled recurring task using " + "primitives.tasks.update. Name it exactly 'Controlled weekly AI report'. " + "Description: Every Monday at 12:00 UTC, research important AI and " + "agentic AI work from the previous week, summarize the key developments, " + "and email me a concise report. Set the first run for the next Monday " + "at 12:00 UTC and repeat weekly. Do not create or attach any entrypoint " + "function, do not mark it offline, and do not execute it now." + ), + clarification_enabled=False, + ) + result = await handle.result() + + assert result is not None + assert "primitives.tasks.update" in set(calls) + + rows = primitives.tasks._filter_tasks(filter="task_id >= 0") + task = [row for row in rows if row.name == "Controlled weekly AI report"][0] + assert task.offline is False + assert task.entrypoint is None + assert task.schedule is not None + assert task.repeat is not None + assert task.status == Status.scheduled + + +@pytest.mark.asyncio +@pytest.mark.timeout(300) +async def test_code_act_creates_live_triggerable_task_with_null_entrypoint(): + async with make_code_act_actor(impl="real", exposed_managers={"tasks"}) as ( + actor, + primitives, + calls, + ): + handle = await actor.act( + ( + "Create exactly one live triggerable task using primitives.tasks.update. " + "Name it exactly 'Controlled invoice email follow-up'. Description: " + "Whenever an inbound email about invoices arrives, summarize the email, " + "identify the needed action, and draft a reply for review. Use an email " + "trigger, leave entrypoint null, do not mark it offline, and do not " + "execute it now." + ), + clarification_enabled=False, + ) + result = await handle.result() + + assert result is not None + assert "primitives.tasks.update" in set(calls) + + rows = primitives.tasks._filter_tasks(filter="task_id >= 0") + task = [ + row for row in rows if row.name == "Controlled invoice email follow-up" + ][0] + assert task.offline is False + assert task.entrypoint is None + assert task.trigger is not None + assert task.status == Status.triggerable From 91a8e550c3623969de10dc92a869a62e8b2228d5 Mon Sep 17 00:00:00 2001 From: Haris Mahmood Date: Tue, 12 May 2026 15:35:05 +0500 Subject: [PATCH 16/17] test(tasks): cover description-driven recurring execution Add coverage for explicit actor requirements, entrypoint review context propagation, recurring clone timing, future instance patching, and task execution prompt builders. 
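
For reviewers, a minimal sketch of the flow these tests exercise, built from
the same calls the tests use (_create_task, execute,
_attach_entrypoint_to_future_instances); the function_id is a hypothetical
stored-function id and the snippet is illustrative only, not part of the
patch:

    import asyncio
    from datetime import datetime, timezone

    from unity.actor.simulated import SimulatedActor
    from unity.task_scheduler.task_scheduler import TaskScheduler
    from unity.task_scheduler.types.repetition import Frequency, RepeatPattern
    from unity.task_scheduler.types.schedule import Schedule
    from unity.task_scheduler.types.status import Status

    async def run_description_driven_task() -> None:
        # Direct execution now requires an explicit actor or run-scoped delegate.
        ts = TaskScheduler(actor=SimulatedActor(steps=0))
        task_id = ts._create_task(
            name="Daily summary",
            description="Summarize updates every day.",
            status=Status.scheduled,
            schedule=Schedule(start_at=datetime.now(timezone.utc)),
            repeat=[RepeatPattern(frequency=Frequency.DAILY)],
        )["details"]["task_id"]

        # No entrypoint is attached, so the run is driven by the description;
        # the recurring row is cloned before any entrypoint review patch lands.
        handle = await ts.execute(task_id=task_id)
        await handle.result()

        # After a stable run, the review hook may attach a stored function to
        # future instances only; the completed instance keeps entrypoint=None.
        ts._attach_entrypoint_to_future_instances(
            task_id=task_id,
            completed_instance_id=0,
            function_id=321,  # hypothetical stored-function id
            rationale="The completed run was stable enough to reuse.",
        )

    asyncio.run(run_description_driven_task())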
--- tests/task_scheduler/test_execute.py | 62 +++++++++++++- tests/task_scheduler/test_prompt_builders.py | 57 +++++++++++++ tests/task_scheduler/test_repetition.py | 86 ++++++++++++++++++++ tests/task_scheduler/test_trigger.py | 17 +++- 4 files changed, 220 insertions(+), 2 deletions(-) create mode 100644 tests/task_scheduler/test_prompt_builders.py diff --git a/tests/task_scheduler/test_execute.py b/tests/task_scheduler/test_execute.py index 26cf5232d..49482c15d 100644 --- a/tests/task_scheduler/test_execute.py +++ b/tests/task_scheduler/test_execute.py @@ -23,7 +23,9 @@ from unity.actor.simulated import SimulatedActorHandle from unity.task_scheduler.types.schedule import Schedule from unity.task_scheduler.types.activated_by import ActivatedBy +from unity.task_scheduler.types.repetition import Frequency, RepeatPattern from unity.task_scheduler.types.status import Status +from unity.common.task_execution_context import current_post_run_review_context # The helper used in the existing test‑suite – applies project‑level monkey‐ # patches (e.g. env vars, tracers) so we keep behaviour consistent. @@ -231,7 +233,7 @@ async def test_execute_interject(monkeypatch): @functools.wraps(original_interject) async def spy_interject(self, instruction: str, *, images=None) -> None: # type: ignore[override] calls["interject"] += 1 - await original_interject(self, instruction, images=images) + await original_interject(self, instruction) monkeypatch.setattr(SimulatedActorHandle, "interject", spy_interject, raising=True) @@ -344,6 +346,7 @@ async def test_execute_result_and_done(): # Perform an interjection for activity, then stop explicitly await task.interject("Provide initial outline first.") await task.stop(cancel=False) + await task.result() assert task.done(), "`done()` must return True after explicit stop" @@ -546,6 +549,63 @@ async def test_execute_sets_activated_by_explicit(): assert any(r.activated_by == ActivatedBy.explicit for r in rows) +@pytest.mark.asyncio +@_handle_project +async def test_execute_without_delegate_or_actor_fails_before_mutation(): + ts = TaskScheduler() + task_id = ts._create_task(name="Needs actor", description="Needs actor")["details"][ + "task_id" + ] + initial_status = ts._get_task_or_raise(task_id).status + + with pytest.raises(RuntimeError, match="run-scoped actor delegate"): + await ts.execute(task_id=task_id) + + row = ts._get_task_or_raise(task_id) + assert row.status == initial_status + assert ts._active_task is None + + +@pytest.mark.asyncio +@_handle_project +async def test_direct_description_driven_recurring_execution_passes_entrypoint_review(): + calls = [] + actor = SimulatedActor(steps=0) + original_act = actor.act + + async def _spy_act(*args, **kwargs): + calls.append( + { + "kwargs": kwargs, + "post_run_review_context": current_post_run_review_context.get(), + }, + ) + return await original_act(*args, **kwargs) + + actor.act = _spy_act # type: ignore[method-assign] + ts = TaskScheduler(actor=actor) + task_id = ts._create_task( + name="Recurring no-entrypoint task", + description="Run from the natural-language description every day.", + status=Status.scheduled, + schedule=Schedule(start_at=datetime.now(timezone.utc)), + repeat=[RepeatPattern(frequency=Frequency.DAILY)], + )["details"]["task_id"] + + handle = await ts.execute(task_id=task_id) + await handle.result() + + assert "task_entrypoint_review" not in calls[0]["kwargs"] + post_run_review_context = calls[0]["post_run_review_context"] + assert post_run_review_context is not None + assert 
post_run_review_context.display_label == "Storing reusable workflow" + review = post_run_review_context.extensions.get("task_entrypoint_review") + assert review is not None + assert review["metadata"]["task_id"] == task_id + assert review["metadata"]["task_name"] == "Recurring no-entrypoint task" + assert callable(review["attach_entrypoint"]) + + @pytest.mark.asyncio @_handle_project async def test_update_status_cannot_force_active(): diff --git a/tests/task_scheduler/test_prompt_builders.py b/tests/task_scheduler/test_prompt_builders.py new file mode 100644 index 000000000..27669f1e3 --- /dev/null +++ b/tests/task_scheduler/test_prompt_builders.py @@ -0,0 +1,57 @@ +from datetime import datetime, timezone + +from unity.task_scheduler.prompt_builders import ( + build_task_execution_request, + build_task_run_guidelines, +) +from unity.task_scheduler.types.activated_by import ActivatedBy +from unity.task_scheduler.types.priority import Priority +from unity.task_scheduler.types.repetition import Frequency, RepeatPattern +from unity.task_scheduler.types.schedule import Schedule +from unity.task_scheduler.types.status import Status +from unity.task_scheduler.types.task import Task + + +def test_build_task_execution_request_includes_run_metadata(): + task = Task( + task_id=7, + instance_id=2, + name="Weekly AI report", + description="Summarize the previous week's AI research.", + status=Status.scheduled, + priority=Priority.normal, + response_policy="Email the user a concise document.", + schedule=Schedule(start_at=datetime(2026, 5, 18, 12, 0, tzinfo=timezone.utc)), + repeat=[RepeatPattern(frequency=Frequency.WEEKLY)], + ) + + request = build_task_execution_request(task) + + assert "Execute this TaskScheduler task as a contained task run." in request + assert "Task id: 7" in request + assert "Instance id: 2" in request + assert "Weekly AI report" in request + assert "Summarize the previous week's AI research." 
in request + assert "Task response policy:" in request + assert "Schedule metadata:" in request + assert "Repeat metadata:" in request + + +def test_build_task_run_guidelines_keep_child_actor_focused_on_one_task(): + task = Task( + task_id=3, + instance_id=1, + name="Invoice follow-up", + description="Draft an invoice reply.", + status=Status.triggerable, + priority=Priority.normal, + ) + + guidelines = build_task_run_guidelines(task, ActivatedBy.trigger) + + assert "executing exactly one TaskScheduler task" in guidelines + assert "do not create another task" in guidelines + assert "interpret the natural-language description" in guidelines + assert "Activation reason: trigger" in guidelines + assert "Task id: 3" in guidelines + assert "Instance id: 1" in guidelines diff --git a/tests/task_scheduler/test_repetition.py b/tests/task_scheduler/test_repetition.py index ac289978e..233258187 100644 --- a/tests/task_scheduler/test_repetition.py +++ b/tests/task_scheduler/test_repetition.py @@ -1,6 +1,9 @@ from datetime import datetime, timedelta, timezone +import pytest + from tests.helpers import _handle_project +from unity.actor.simulated import SimulatedActor from unity.task_scheduler.task_scheduler import TaskScheduler from unity.task_scheduler.types.repetition import ( Frequency, @@ -76,6 +79,89 @@ def test_clone_task_instance_rearms_recurring_scheduled_task(): assert latest.instance_id == 1 assert latest.status == Status.scheduled assert latest.schedule_start_at == initial_start + timedelta(days=1) + assert latest.entrypoint is None + + +@_handle_project +def test_entrypoint_review_patches_future_description_driven_instances(): + scheduler = TaskScheduler() + initial_start = datetime.now(timezone.utc).replace(microsecond=0) - timedelta( + hours=1, + ) + scheduler._create_task( + name="Daily description-driven summary", + description="Summarize updates every day.", + status=Status.scheduled, + schedule=Schedule(start_at=initial_start.isoformat()), + repeat=[RepeatPattern(frequency=Frequency.DAILY)], + ) + + current = scheduler._get_task_or_raise(0) + scheduler._clone_task_instance(current) + result = scheduler._attach_entrypoint_to_future_instances( + task_id=0, + completed_instance_id=0, + function_id=321, + rationale="The successful run revealed a stable workflow.", + ) + + rows = scheduler._filter_tasks(filter="task_id == 0") + current_row = min(rows, key=lambda task: task.instance_id) + future_row = max(rows, key=lambda task: task.instance_id) + + assert result["outcome"] == "attached" + assert current_row.entrypoint is None + assert future_row.entrypoint == 321 + + +@pytest.mark.asyncio +@_handle_project +async def test_recurring_execution_clones_before_entrypoint_review_patch(): + scheduler = TaskScheduler(actor=SimulatedActor(steps=0)) + initial_start = datetime.now(timezone.utc).replace(microsecond=0) - timedelta( + hours=1, + ) + scheduler._create_task( + name="Daily report", + description="Run the daily report from the task description.", + status=Status.scheduled, + schedule=Schedule(start_at=initial_start.isoformat()), + repeat=[RepeatPattern(frequency=Frequency.DAILY)], + ) + + handle = await scheduler.execute(task_id=0) + await handle.result() + + rows_after_run = sorted( + scheduler._filter_tasks(filter="task_id == 0"), + key=lambda task: task.instance_id, + ) + assert [row.instance_id for row in rows_after_run] == [0, 1] + assert rows_after_run[0].entrypoint is None + assert rows_after_run[1].entrypoint is None + + result = scheduler._attach_entrypoint_to_future_instances( + 
task_id=0, + completed_instance_id=0, + function_id=321, + rationale="The completed run was stable enough to reuse.", + ) + assert result["outcome"] == "attached" + + patched_next = [ + row + for row in scheduler._filter_tasks(filter="task_id == 0") + if row.instance_id == 1 + ][0] + assert patched_next.entrypoint == 321 + + scheduler._clone_task_instance(patched_next) + cloned_from_patched = [ + row + for row in scheduler._filter_tasks(filter="task_id == 0") + if row.instance_id == 2 + ][0] + assert cloned_from_patched.entrypoint == 321 @_handle_project diff --git a/tests/task_scheduler/test_trigger.py b/tests/task_scheduler/test_trigger.py index 8cb910c96..0126f9d99 100644 --- a/tests/task_scheduler/test_trigger.py +++ b/tests/task_scheduler/test_trigger.py @@ -8,6 +8,7 @@ import pytest from tests.helpers import _handle_project +from unity.actor.simulated import SimulatedActor from unity.task_scheduler.task_scheduler import TaskScheduler from unity.task_scheduler.types.status import Status from unity.task_scheduler.types.schedule import Schedule @@ -178,7 +179,7 @@ async def test_triggerable_start_clones_instance(): • create a **new** row with the same `task_id` but `instance_id` 1 that remains in the *triggerable* state """ - ts = TaskScheduler() + ts = TaskScheduler(actor=SimulatedActor(steps=None, duration=None)) trig = Trigger(medium=Medium.EMAIL, recurring=False) tid = ts._create_task( @@ -202,6 +203,20 @@ async def test_triggerable_start_clones_instance(): assert status_by_inst[0] == Status.active assert status_by_inst[1] == Status.triggerable + result = ts._attach_entrypoint_to_future_instances( + task_id=tid, + completed_instance_id=0, + function_id=654, + rationale="The triggered run revealed a stable reusable workflow.", + ) + assert result["outcome"] == "attached" + future_row = [ + row + for row in ts._filter_tasks(filter=f"task_id == {tid}") + if row.instance_id == 1 + ][0] + assert future_row.entrypoint == 654 + # Clean-up (avoid background thread leaks) await handle.stop(cancel=True) await handle.result() From 6fede364f542ea0653e3eed5f0a52f13d8b0f369 Mon Sep 17 00:00:00 2001 From: Haris Mahmood Date: Tue, 12 May 2026 15:35:11 +0500 Subject: [PATCH 17/17] test(tasks): inject simulated actors in execution tests Update scheduler tests to provide explicit simulated actors now that direct execution no longer creates an implicit fallback actor. 
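
For context, the pattern repeated in the diffs below, shown here as a sketch
rather than anything this patch applies on its own: a test module shadows the
class with a thin factory so every scheduler it builds receives a simulated
actor unless a test passes one explicitly.

    from unity.actor.simulated import SimulatedActor
    from unity.task_scheduler.task_scheduler import TaskScheduler as _TaskScheduler

    def TaskScheduler(*args, **kwargs):
        # Inject a simulated actor when a test constructs a scheduler without
        # supplying one; direct execution no longer creates an implicit
        # fallback actor.
        if "actor" not in kwargs:
            kwargs["actor"] = SimulatedActor()
        return _TaskScheduler(*args, **kwargs)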
--- tests/task_scheduler/test_active_queue.py | 23 +++++++++++++++++++--- tests/task_scheduler/test_active_task.py | 2 +- tests/task_scheduler/test_event_logging.py | 5 +++-- tests/task_scheduler/test_reintegration.py | 19 ++++++++++++++++-- 4 files changed, 41 insertions(+), 8 deletions(-) diff --git a/tests/task_scheduler/test_active_queue.py b/tests/task_scheduler/test_active_queue.py index c2d878795..1652627b2 100644 --- a/tests/task_scheduler/test_active_queue.py +++ b/tests/task_scheduler/test_active_queue.py @@ -8,12 +8,26 @@ import pytest from tests.helpers import _handle_project -from unity.task_scheduler.task_scheduler import TaskScheduler +from unity.task_scheduler import task_scheduler as task_scheduler_module +from unity.task_scheduler.task_scheduler import TaskScheduler as _TaskScheduler from unity.task_scheduler.types.task import Task -from unity.actor.simulated import SimulatedActor, SimulatedActorHandle +from unity.actor.simulated import ( + SimulatedActor, + SimulatedActorHandle, + _StaticAnswerHandle, +) from unity.task_scheduler.types.status import Status import inspect +task_scheduler_module.SimulatedActor = SimulatedActor + + +def TaskScheduler(*args, **kwargs): + if "actor" not in kwargs: + actor_cls = getattr(task_scheduler_module, "SimulatedActor", SimulatedActor) + kwargs["actor"] = actor_cls() + return _TaskScheduler(*args, **kwargs) + async def _make_ordered_queue( ts: TaskScheduler, @@ -625,6 +639,9 @@ class _FakeQueueClient: def __init__(self, *a, **kw): pass + def set_on_log_file_pending(self, callback): + return None + def set_system_message(self, sys_msg): try: prompt_capture["system"] = str(sys_msg) @@ -1203,7 +1220,7 @@ async def spy_actor_ask(self, question: str): # type: ignore[override] self.simulate_step() except Exception: pass - return "OK" + return _StaticAnswerHandle("OK") async def spy_actor_interject(self, instruction: str, *, images=None): # type: ignore[override] interject_calls["count"] += 1 diff --git a/tests/task_scheduler/test_active_task.py b/tests/task_scheduler/test_active_task.py index 4b6e03f40..d7643f7b6 100644 --- a/tests/task_scheduler/test_active_task.py +++ b/tests/task_scheduler/test_active_task.py @@ -118,7 +118,7 @@ async def test_interject(monkeypatch): @functools.wraps(original_interject) async def spy_interject(self, instruction: str, *, images=None) -> None: # type: ignore[override] calls["interject"] += 1 - return await original_interject(self, instruction, images=images) + return await original_interject(self, instruction) monkeypatch.setattr(SimulatedActorHandle, "interject", spy_interject, raising=True) diff --git a/tests/task_scheduler/test_event_logging.py b/tests/task_scheduler/test_event_logging.py index facbd7bbb..63d3150d4 100644 --- a/tests/task_scheduler/test_event_logging.py +++ b/tests/task_scheduler/test_event_logging.py @@ -2,6 +2,7 @@ import pytest +from unity.actor.simulated import SimulatedActor from unity.task_scheduler.task_scheduler import TaskScheduler from tests.helpers import _handle_project, capture_events @@ -66,7 +67,7 @@ async def test_managermethod_events_for_update(): @pytest.mark.asyncio @_handle_project async def test_managermethod_events_for_execute(): - ts = TaskScheduler() + ts = TaskScheduler(actor=SimulatedActor(steps=0)) # create a simple task first outcome = ts._create_task(name="Demo", description="Run a demo task") @@ -81,7 +82,7 @@ async def test_managermethod_events_for_execute(): for e in events if e.payload.get("manager") == "TaskScheduler" and e.payload.get("method") == 
"execute" - and e.payload.get("request") == task_id + and e.payload.get("phase") == "incoming" ] assert incoming call_id = incoming[0].calling_id diff --git a/tests/task_scheduler/test_reintegration.py b/tests/task_scheduler/test_reintegration.py index 9be0b3b30..bc748fa1c 100644 --- a/tests/task_scheduler/test_reintegration.py +++ b/tests/task_scheduler/test_reintegration.py @@ -5,13 +5,25 @@ import pytest from tests.helpers import _handle_project -from unity.task_scheduler.task_scheduler import TaskScheduler +from unity.actor.simulated import SimulatedActor +from unity.task_scheduler import task_scheduler as task_scheduler_module +from unity.task_scheduler.task_scheduler import TaskScheduler as _TaskScheduler from unity.task_scheduler.types.schedule import Schedule from unity.task_scheduler.types.status import Status from unity.task_scheduler.types.trigger import Trigger, Medium pytestmark = pytest.mark.llm_call +task_scheduler_module.SimulatedActor = SimulatedActor + + +def TaskScheduler(*args, **kwargs): + if "actor" not in kwargs: + actor_cls = getattr(task_scheduler_module, "SimulatedActor", SimulatedActor) + kwargs["actor"] = actor_cls() + return _TaskScheduler(*args, **kwargs) + + # Speed up only this module's SimulatedActor by monkeypatching the class symbols # used by TaskScheduler to a shorter-duration variant. This does not affect # other test modules. @@ -424,7 +436,9 @@ async def test_reintegration_plan_clears_on_completion(): async def test_chain_then_defer_restores_next_head_start_at(monkeypatch): from datetime import datetime, timezone, timedelta - ts = TaskScheduler() + ts = _TaskScheduler( + actor=SimulatedActor(steps=None, duration=None, hold_completion=True), + ) # Chain execution is the default; no environment variable required. @@ -455,6 +469,7 @@ async def test_chain_then_defer_restores_next_head_start_at(monkeypatch): # Start the head in chain mode but only allow the head to complete handle = await ts.execute(task_id=head_id) + handle._current_handle._actor_handle.trigger_completion() # Wait for just the head to finish await handle._active_task_done()