Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions tests/actor/code_act/test_execute_code_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,63 @@ def get_result(out: Any) -> Any:
return get_output_field(out, "result", None)


# ---------------------------------------------------------------------------
# Test: Runtime OAuth token helper exposed to real actor execute_code
# ---------------------------------------------------------------------------


class _FakeOAuthSecretManager:
def __init__(self) -> None:
self.calls: list[Any] = []
self.secrets = {
"MICROSOFT_ACCESS_TOKEN": "microsoft:fresh-token",
"MICROSOFT_TOKEN_EXPIRES_AT": "2999-01-01T00:00:00+00:00",
"GOOGLE_ACCESS_TOKEN": "google:fresh-token",
"GOOGLE_TOKEN_EXPIRES_AT": "2999-01-01T00:00:00+00:00",
}

def sync_assistant_secrets_if_stale(self, **kwargs: Any) -> bool:
self.calls.append(("sync", kwargs))
return True

def _get_secret_value(self, name: str) -> str | None:
self.calls.append(("secret", name))
return self.secrets.get(name)


@pytest.mark.asyncio
async def test_execute_code_oauth_helper_uses_parent_secret_manager(
execute_code_tool: tuple[Any, Primitives],
monkeypatch: pytest.MonkeyPatch,
) -> None:
execute_code, _ = execute_code_tool
fake_secret_manager = _FakeOAuthSecretManager()
monkeypatch.setattr(
ManagerRegistry,
"get_secret_manager",
lambda: fake_secret_manager,
)

out = await execute_code(
"mock scenario: call rotating OAuth token helper for multiple providers",
"""
microsoft_token = get_oauth_access_token("microsoft", min_ttl_seconds=123)
google_token = get_oauth_access_token("google", min_ttl_seconds=456)
assert microsoft_token == "microsoft:fresh-token"
assert google_token == "google:fresh-token"
print("TOKEN_OK")
""",
language="python",
state_mode="stateless",
)

assert get_error(out) is None
assert "TOKEN_OK" in get_stdout_text(out)
assert ("secret", "MICROSOFT_ACCESS_TOKEN") in fake_secret_manager.calls
assert ("secret", "GOOGLE_ACCESS_TOKEN") in fake_secret_manager.calls
assert any(call[0] == "sync" for call in fake_secret_manager.calls)


# ---------------------------------------------------------------------------
# Test: Basic stdout capture from primitives.*.ask().result()
# ---------------------------------------------------------------------------
Expand Down
20 changes: 20 additions & 0 deletions tests/actor/code_act/test_prompt_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,26 @@ def test_code_act_prompt_includes_diverse_examples_sessions_computer_primitives_
assert "execute_function vs execute_code decision" in prompt


@pytest.mark.timeout(30)
def test_code_act_prompt_teaches_refresh_token_oauth_helper():
actor = CodeActActor()
prompt = build_code_act_prompt(
environments={},
tools=dict(actor.get_tools("act")),
)

assert "def reason(" in prompt
assert (
"def get_oauth_access_token(provider: str, *, "
"min_ttl_seconds: int = 300) -> str"
) in prompt
assert 'get_oauth_access_token("microsoft")' in prompt
assert 'get_oauth_access_token("google")' in prompt
assert "refresh-token backed OAuth" in prompt
assert "Do not print, log, return, or store the token value." in prompt
assert "provider sdk/default credential behavior" in prompt.lower()


@pytest.mark.timeout(30)
def test_code_act_prompt_includes_comms_namespace_and_docstrings():
from unity.actor.environments.state_managers import StateManagerEnvironment
Expand Down
138 changes: 138 additions & 0 deletions tests/common/test_runtime_oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from __future__ import annotations

from datetime import datetime, timedelta, timezone
from typing import Any

import pytest

from unity.common import runtime_oauth


def _future_expiry() -> str:
return (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat()


def _past_expiry() -> str:
return (datetime.now(timezone.utc) - timedelta(minutes=5)).isoformat()


class _FakeSecretManager:
def __init__(self, secrets: dict[str, str]) -> None:
self.secrets = secrets
self.sync_calls: list[dict[str, Any]] = []
self.on_sync = None

def _get_secret_value(self, name: str) -> str | None:
return self.secrets.get(name)

def sync_assistant_secrets_if_stale(self, **kwargs: Any) -> bool:
self.sync_calls.append(kwargs)
if self.on_sync is not None:
self.on_sync()
return True


def _install_secret_manager(monkeypatch, sm: _FakeSecretManager) -> None:
monkeypatch.setattr(runtime_oauth, "_get_secret_manager", lambda: sm)


def test_get_oauth_access_token_supports_provider_alias(monkeypatch):
sm = _FakeSecretManager(
{
"MICROSOFT_ACCESS_TOKEN": "fresh-ms-token",
"MICROSOFT_TOKEN_EXPIRES_AT": _future_expiry(),
},
)
_install_secret_manager(monkeypatch, sm)

assert runtime_oauth.get_oauth_access_token("graph") == "fresh-ms-token"
assert sm.sync_calls[-1]["force"] is False


def test_get_oauth_access_token_unknown_provider_is_non_secret_error(monkeypatch):
sm = _FakeSecretManager({})
_install_secret_manager(monkeypatch, sm)

with pytest.raises(ValueError, match="Unknown refresh-token OAuth provider"):
runtime_oauth.get_oauth_access_token("not-a-real-provider")


def test_get_oauth_access_token_missing_token_raises_after_sync(monkeypatch):
sm = _FakeSecretManager({"MICROSOFT_TOKEN_EXPIRES_AT": _future_expiry()})
_install_secret_manager(monkeypatch, sm)

with pytest.raises(ValueError, match="No access token is available"):
runtime_oauth.get_oauth_access_token("microsoft")

assert sm.sync_calls[-1]["force"] is True


def test_get_oauth_access_token_forces_sync_when_token_is_expired(monkeypatch):
sm = _FakeSecretManager(
{
"MICROSOFT_ACCESS_TOKEN": "old-ms-token",
"MICROSOFT_TOKEN_EXPIRES_AT": _past_expiry(),
},
)
_install_secret_manager(monkeypatch, sm)

def refresh_token() -> None:
sm.secrets["MICROSOFT_ACCESS_TOKEN"] = "fresh-ms-token"
sm.secrets["MICROSOFT_TOKEN_EXPIRES_AT"] = _future_expiry()

sm.on_sync = refresh_token

assert runtime_oauth.get_oauth_access_token("microsoft") == "fresh-ms-token"
assert sm.sync_calls[-1]["force"] is True


def test_get_oauth_access_token_forces_sync_when_expiry_is_invalid(monkeypatch):
sm = _FakeSecretManager(
{
"GOOGLE_ACCESS_TOKEN": "old-google-token",
"GOOGLE_TOKEN_EXPIRES_AT": "not-a-date",
},
)
_install_secret_manager(monkeypatch, sm)

def refresh_token() -> None:
sm.secrets["GOOGLE_ACCESS_TOKEN"] = "fresh-google-token"
sm.secrets["GOOGLE_TOKEN_EXPIRES_AT"] = _future_expiry()

sm.on_sync = refresh_token

assert runtime_oauth.get_oauth_access_token("google") == "fresh-google-token"
assert sm.sync_calls[-1]["force"] is True


def test_get_oauth_access_token_supports_multiple_providers(monkeypatch):
sm = _FakeSecretManager(
{
"MICROSOFT_ACCESS_TOKEN": "fresh-ms-token",
"MICROSOFT_TOKEN_EXPIRES_AT": _future_expiry(),
"GOOGLE_ACCESS_TOKEN": "fresh-google-token",
"GOOGLE_TOKEN_EXPIRES_AT": _future_expiry(),
},
)
_install_secret_manager(monkeypatch, sm)

assert runtime_oauth.get_oauth_access_token("microsoft") == "fresh-ms-token"
assert runtime_oauth.get_oauth_access_token("google") == "fresh-google-token"


def test_get_refresh_token_oauth_env_overlay_returns_all_current_values(monkeypatch):
sm = _FakeSecretManager(
{
"MICROSOFT_ACCESS_TOKEN": "fresh-ms-token",
"MICROSOFT_TOKEN_EXPIRES_AT": _future_expiry(),
"GOOGLE_ACCESS_TOKEN": "fresh-google-token",
"GOOGLE_TOKEN_EXPIRES_AT": _future_expiry(),
},
)
_install_secret_manager(monkeypatch, sm)

overlay = runtime_oauth.get_refresh_token_oauth_env_overlay()

assert overlay["MICROSOFT_ACCESS_TOKEN"] == "fresh-ms-token"
assert overlay["GOOGLE_ACCESS_TOKEN"] == "fresh-google-token"
assert sm.sync_calls[-1]["reason"] == "oauth_env_overlay"
89 changes: 89 additions & 0 deletions tests/function_manager/test_runtime_oauth_bridge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from __future__ import annotations

from dataclasses import dataclass, field

import pytest

from unity.function_manager.function_manager import FunctionManager
from unity.function_manager import function_manager as function_manager_module


@dataclass
class _FakeSecretManager:
sync_reasons: list[str] = field(default_factory=list)

def _get_secret_value(self, name: str) -> str | None:
values = {
"GOOGLE_ACCESS_TOKEN": "google-token",
"GOOGLE_TOKEN_EXPIRES_AT": "2999-01-01T00:00:00+00:00",
}
return values.get(name)

def sync_assistant_secrets_if_stale(self, **kwargs) -> bool:
self.sync_reasons.append(kwargs["reason"])
return True


@pytest.mark.asyncio
async def test_shell_runtime_oauth_token_helper_uses_parent_rpc(monkeypatch):
fake_secret_manager = _FakeSecretManager()
monkeypatch.setattr(
function_manager_module.ManagerRegistry,
"get_secret_manager",
lambda: fake_secret_manager,
)

fm = object.__new__(FunctionManager)
result = await fm.execute_shell_script(
implementation=(
"#!/bin/sh\n"
"token=$(unity-primitive runtime get_oauth_access_token "
"--provider google --min_ttl_seconds 42)\n"
'if [ "$token" = \'"google-token"\' ]; then echo "TOKEN_OK"; fi\n'
),
language="sh",
)

assert result["error"] is None
assert result["result"] == 0
assert "TOKEN_OK" in result["stdout"]
assert fake_secret_manager.sync_reasons == ["oauth_access_token:google"]


def test_runtime_oauth_env_overlay_routes_through_runtime_helper(monkeypatch):
from unity.common import runtime_oauth

monkeypatch.setattr(
runtime_oauth,
"get_refresh_token_oauth_env_overlay",
lambda: {"GOOGLE_ACCESS_TOKEN": "fresh-google-token"},
)

fm = object.__new__(FunctionManager)

assert fm._get_runtime_oauth_env_overlay() == {
"GOOGLE_ACCESS_TOKEN": "fresh-google-token",
}


def test_venv_runtime_oauth_helper_uses_parent_rpc(monkeypatch):
from unity.function_manager import venv_runner

calls = []

def fake_rpc_call_sync(path, kwargs):
calls.append((path, kwargs))
return "fresh-ms-token"

monkeypatch.setattr(venv_runner, "rpc_call_sync", fake_rpc_call_sync)

assert (
venv_runner.get_oauth_access_token("microsoft", min_ttl_seconds=12)
== "fresh-ms-token"
)
assert calls == [
(
"runtime.get_oauth_access_token",
{"provider": "microsoft", "min_ttl_seconds": 12},
),
]
Loading