diff --git a/examples/knowledge_with_vectorstore/run_agent.py b/examples/knowledge_with_vectorstore/run_agent.py
index 329160c..e616067 100644
--- a/examples/knowledge_with_vectorstore/run_agent.py
+++ b/examples/knowledge_with_vectorstore/run_agent.py
@@ -7,10 +7,10 @@
import uuid
from dotenv import load_dotenv
-from trpc_agent.runners import Runner
-from trpc_agent.sessions import InMemorySessionService
-from trpc_agent.types import Content
-from trpc_agent.types import Part
+from trpc_agent_sdk.runners import Runner
+from trpc_agent_sdk.sessions import InMemorySessionService
+from trpc_agent_sdk.types import Content
+from trpc_agent_sdk.types import Part
# Load environment variables from the .env file
load_dotenv()
diff --git a/format.py b/format.py
index 290c8aa..cd52ba3 100644
--- a/format.py
+++ b/format.py
@@ -12,35 +12,9 @@
then lexicographical order.
- Keep ``from __future__ import annotations`` at the very top of imports.
-2) audit
- - Report files missing copyright marker ``Copyright @``.
- - Report files missing module docstring.
-
-3) build-package-api
- - Auto-generate per-package ``__init__.py`` export blocks.
- - Skip optional-feature modules inferred from ``pyproject.toml`` extras.
- - De-duplicate against symbols already manually imported in ``__init__.py``.
-
-4) sync-imports-package-api
- - Combined maintenance workflow:
- ensure missing ``__init__.py`` > rename private modules > rewrite imports >
- build package API exports > sort imports.
- - Add copyright marker to Python files when missing.
- - Keep exactly two blank lines after import blocks.
-
-5) rename-private-modules
- - Rename package modules from ``foo.py`` to ``_foo.py``.
- - Rewrite import statements across project files accordingly.
- - Optional: refresh package API blocks after rename.
-
Examples:
python3 format.py sort-imports --root . --dry-run
python3 format.py sort-imports --root .
- python3 format.py audit --root .
- python3 format.py build-package-api --root . --dry-run
- python3 format.py sync-imports-package-api --root . --dry-run
- python3 format.py rename-private-modules --root . --dry-run
- python3 format.py rename-private-modules --root . --build-package-api
python3 format.py check-chinese --root . --dry-run
"""
@@ -50,6 +24,7 @@
import ast
import os
import re
+import subprocess
import sys
import tomllib
from dataclasses import dataclass
@@ -74,14 +49,10 @@
".pytest_cache",
}
-COPYRIGHT_RE = re.compile(r"Copyright\s*@")
STDLIB_COMPANION_MODULES = {"typing_extensions"}
AUTO_EXPORT_BEGIN = "# "
AUTO_EXPORT_END = "# "
CHINESE_RE = re.compile(r"[\u4e00-\u9fff]")
-ENCODING_RE = re.compile(r"^#.*coding[:=]\s*([-\w.]+)")
-DEFAULT_COPYRIGHT_YEAR = "2026"
-DEFAULT_COPYRIGHT_OWNER = "Tencent.com"
@dataclass(frozen=True)
@@ -89,6 +60,7 @@ class ImportLine:
text: str
group: int
kind: int # 0=import, 1=from-import
+ rel_level: int = 0 # relative import level; larger means farther (.. > .)
@dataclass
@@ -157,42 +129,6 @@ def ensure_init_files(root: Path, apply: bool) -> list[Path]:
return created
-def _default_copyright_block() -> list[str]:
- return [
- "# -*- coding: utf-8 -*-",
- "#",
- f"# Copyright @ {DEFAULT_COPYRIGHT_YEAR} {DEFAULT_COPYRIGHT_OWNER}",
- ]
-
-
-def add_copyright_if_missing(source: str) -> tuple[str, bool]:
- if COPYRIGHT_RE.search(source):
- return source, False
-
- lines = source.splitlines()
- insert_at = 0
- out: list[str] = []
-
- if lines and lines[0].startswith("#!"):
- out.append(lines[0])
- insert_at = 1
-
- has_encoding = insert_at < len(lines) and bool(ENCODING_RE.match(lines[insert_at]))
- if has_encoding:
- out.append(lines[insert_at])
- insert_at += 1
- out.extend(["#", f"# Copyright @ {DEFAULT_COPYRIGHT_YEAR} {DEFAULT_COPYRIGHT_OWNER}"])
- else:
- out.extend(_default_copyright_block())
-
- if insert_at < len(lines) and lines[insert_at] != "":
- out.append("")
- out.extend(lines[insert_at:])
-
- new_source = "\n".join(out) + ("\n" if source.endswith("\n") or not source else "")
- return new_source, True
-
-
def read_optional_group_tokens(pyproject_path: Path) -> set[str]:
if not pyproject_path.exists():
return set()
@@ -446,21 +382,14 @@ def _extract_all_names_from_value(node: ast.AST) -> list[str]:
return out
-def merge_init_all_exports(init_path: Path, apply: bool) -> bool:
- """Merge multiple top-level __all__ blocks in __init__.py into one.
-
- Order policy:
- 1) follow top-level import symbol order first (from-import aliases order),
- 2) append remaining legacy __all__ names (deduplicated) in original order.
- """
- if not init_path.exists():
- return False
- original_source = init_path.read_text(encoding="utf-8")
- source = strip_auto_export_markers(original_source)
+def _merge_init_all_exports_source(source: str) -> str | None:
+ """Return merged __all__ source for __init__.py, or None when unchanged."""
+ original_source = source
+ source = strip_auto_export_markers(source)
try:
tree = ast.parse(source)
except SyntaxError:
- return False
+ return None
all_nodes: list[ast.Assign] = []
legacy_all_names: list[str] = []
@@ -473,6 +402,12 @@ def merge_init_all_exports(init_path: Path, apply: bool) -> bool:
if has_all_target:
all_nodes.append(node)
legacy_all_names.extend(_extract_all_names_from_value(node.value))
+ elif isinstance(node, ast.Import):
+ for alias in node.names:
+ sym = alias.asname or alias.name.split(".", 1)[0]
+ if sym not in seen_imports:
+ seen_imports.add(sym)
+ import_symbol_order.append(sym)
elif isinstance(node, ast.ImportFrom):
if node.module == "__future__":
continue
@@ -485,7 +420,7 @@ def merge_init_all_exports(init_path: Path, apply: bool) -> bool:
import_symbol_order.append(sym)
if len(all_nodes) == 0:
- return False
+ return None
final_names: list[str] = []
seen_final: set[str] = set()
@@ -515,7 +450,7 @@ def merge_init_all_exports(init_path: Path, apply: bool) -> bool:
try:
kept_tree = ast.parse(kept_source)
except SyntaxError:
- return False
+ return None
last_import_end = 0
for node in kept_tree.body:
if isinstance(node, (ast.Import, ast.ImportFrom)):
@@ -529,6 +464,17 @@ def merge_init_all_exports(init_path: Path, apply: bool) -> bool:
new_source = "\n".join(merged_lines) + ("\n" if source.endswith("\n") else "")
if new_source == original_source:
+ return None
+ return new_source
+
+
+def merge_init_all_exports(init_path: Path, apply: bool) -> bool:
+ """Merge multiple top-level __all__ blocks in __init__.py into one."""
+ if not init_path.exists():
+ return False
+ original_source = init_path.read_text(encoding="utf-8")
+ new_source = _merge_init_all_exports_source(original_source)
+ if new_source is None:
return False
if apply:
init_path.write_text(new_source, encoding="utf-8")
@@ -767,7 +713,7 @@ def normalize_import_block(
text = f"from {module} import {alias.name}"
if alias.asname:
text += f" as {alias.asname}"
- normalized.append(ImportLine(text=text, group=group, kind=1))
+ normalized.append(ImportLine(text=text, group=group, kind=1, rel_level=node.level))
groups: dict[int, list[ImportLine]] = {1: [], 2: [], 3: [], 4: []}
for item in normalized:
@@ -777,9 +723,13 @@ def normalize_import_block(
# 1) plain imports first
# 2) then from-imports
# 3) lexicographical order inside each kind
+ # 4) for relative imports, farther levels first: .. > .
for k in groups:
uniq = {(item.kind, item.text): item for item in groups[k]}
- groups[k] = sorted(uniq.values(), key=lambda x: (x.kind, x.text))
+ if k == 4:
+ groups[k] = sorted(uniq.values(), key=lambda x: (x.kind, -x.rel_level, x.text))
+ else:
+ groups[k] = sorted(uniq.values(), key=lambda x: (x.kind, x.text))
out: list[str] = []
if future_annotations:
@@ -793,39 +743,79 @@ def normalize_import_block(
return "\n".join(out) + "\n"
+def find_missing_relative_import_targets(init_path: Path, source: str) -> list[tuple[int, str]]:
+ """Check relative import module anchors in __init__.py.
+
+ For statements like ``from .a import b``, validate that ``a`` exists as
+ either ``a.py`` or directory ``a/`` at the resolved relative location.
+ """
+ try:
+ tree = ast.parse(source)
+ except SyntaxError:
+ return []
+
+ missing: list[tuple[int, str]] = []
+ for node in tree.body:
+ if not isinstance(node, ast.ImportFrom):
+ continue
+ if node.level < 1 or not node.module:
+ continue
+ first_seg = node.module.split(".", 1)[0].strip()
+ if not first_seg:
+ continue
+
+ anchor = init_path.parent
+ for _ in range(max(node.level - 1, 0)):
+ anchor = anchor.parent
+ file_candidate = anchor / f"{first_seg}.py"
+ dir_candidate = anchor / first_seg
+ if not (file_candidate.exists() or dir_candidate.exists()):
+ target = "." * node.level + node.module
+ missing.append((node.lineno, target))
+ return missing
+
+
+def run_yapf_on_python_files(root: Path) -> tuple[int, list[str]]:
+ """Run yapf -i for all Python files under root."""
+ formatted = 0
+ errors: list[str] = []
+ for py_file in sorted(iter_python_files(root)):
+ try:
+ result = subprocess.run(
+ ["yapf", "-i", str(py_file)],
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+ except FileNotFoundError:
+ errors.append("yapf command not found")
+ break
+ if result.returncode != 0:
+ detail = result.stderr.strip() or result.stdout.strip() or f"exit={result.returncode}"
+ errors.append(f"{py_file}: {detail}")
+ continue
+ formatted += 1
+ return formatted, errors
+
+
def process_file(
path: Path,
stdlib_names: set[str],
project_packages: set[str],
apply: bool,
-) -> tuple[bool, bool, bool]:
- """Return (modified, has_copyright, has_module_docstring)."""
+) -> bool:
+ """Return whether file content was modified."""
source = path.read_text(encoding="utf-8")
- has_copyright = bool(COPYRIGHT_RE.search(source))
try:
tree = ast.parse(source)
except SyntaxError:
# Skip files that cannot be parsed.
- return False, has_copyright, False
-
- has_module_docstring = ast.get_docstring(tree, clean=False) is not None
- source_after_copyright, copyright_inserted = add_copyright_if_missing(source)
- if copyright_inserted:
- source = source_after_copyright
- has_copyright = True
- try:
- tree = ast.parse(source)
- except SyntaxError:
- if apply:
- path.write_text(source, encoding="utf-8")
- return True, has_copyright, has_module_docstring
+ return False
import_nodes = extract_leading_import_nodes(tree)
if not import_nodes:
- if copyright_inserted and apply:
- path.write_text(source, encoding="utf-8")
- return copyright_inserted, has_copyright, has_module_docstring
+ return False
start = import_nodes[0].lineno - 1
end = import_nodes[-1].end_lineno or import_nodes[-1].lineno
@@ -850,12 +840,17 @@ def process_file(
new_lines.extend(tail_lines[leading_blank_count:])
new_source = "\n".join(new_lines) + ("\n" if source.endswith("\n") else "")
+ if path.name == "__init__.py":
+ merged_source = _merge_init_all_exports_source(new_source)
+ if merged_source is not None:
+ new_source = merged_source
+
if new_source == source:
- return False, has_copyright, has_module_docstring
+ return False
if apply:
path.write_text(new_source, encoding="utf-8")
- return True, has_copyright, has_module_docstring
+ return True
def _names_from_target(target: ast.AST) -> list[str]:
@@ -1097,9 +1092,14 @@ def run_sort_imports(root: Path, dry_run: bool) -> int:
project_packages = discover_project_packages(root)
py_files = sorted(iter_python_files(root))
modified_files: list[Path] = []
+ missing_relative_targets: list[tuple[Path, int, str]] = []
for path in py_files:
- modified, _has_copyright, _has_module_docstring = process_file(
+ source_for_check = path.read_text(encoding="utf-8")
+ if path.name == "__init__.py":
+ for lineno, target in find_missing_relative_import_targets(path, source_for_check):
+ missing_relative_targets.append((path, lineno, target))
+ modified = process_file(
path=path,
stdlib_names=stdlib_names,
project_packages=project_packages,
@@ -1112,113 +1112,20 @@ def run_sort_imports(root: Path, dry_run: bool) -> int:
print(f"[{mode}] sort-imports scanned: {len(py_files)}, modified: {len(modified_files)}")
for p in modified_files:
print(str(p))
- return 0
-
-
-def run_audit(root: Path, dry_run: bool) -> int:
- stdlib_names = set(getattr(sys, "stdlib_module_names", set()))
- project_packages = discover_project_packages(root)
- py_files = sorted(iter_python_files(root))
- no_copyright_files: list[Path] = []
- no_module_docstring_files: list[Path] = []
-
- for path in py_files:
- _modified, has_copyright, has_module_docstring = process_file(
- path=path,
- stdlib_names=stdlib_names,
- project_packages=project_packages,
- apply=False,
- )
- if not has_copyright:
- no_copyright_files.append(path)
- if not has_module_docstring:
- no_module_docstring_files.append(path)
-
- mode = "DRY_RUN" if dry_run else "APPLY"
- print(f"[{mode}] audit scanned: {len(py_files)}")
- print("Files missing copyright marker (Copyright @):")
- for p in no_copyright_files:
- print(str(p))
- print("Files missing module docstring:")
- for p in no_module_docstring_files:
- print(str(p))
- return 0
-
-
-def run_build_package_api(root: Path, dry_run: bool) -> int:
- changed_init_files = build_package_api_exports(root, apply=not dry_run)
- mode = "DRY_RUN" if dry_run else "APPLY"
- print(f"[{mode}] build-package-api updated: {len(changed_init_files)}")
- for p in changed_init_files:
- print(str(p))
- return 0
+ if missing_relative_targets:
+ print("Missing relative import module targets in __init__.py:")
+ for path, lineno, target in missing_relative_targets:
+ print(f"{path}:{lineno} [missing_relative_import_target] {target}")
+ if not dry_run:
+ formatted_count, yapf_errors = run_yapf_on_python_files(root)
+ print(f"[APPLY] yapf formatted: {formatted_count}")
+ for err in yapf_errors:
+ print(f"[yapf-error] {err}")
+ if yapf_errors:
+ return 1
-def run_sync_imports_and_package_api(root: Path, dry_run: bool) -> int:
- """Combined command: rename/fix imports + build package API + sort imports."""
- apply = not dry_run
- project_packages = discover_project_packages(root)
- py_files = sorted(iter_python_files(root))
- stdlib_names = set(getattr(sys, "stdlib_module_names", set()))
-
- created_init_files = ensure_init_files(root, apply=apply)
- # Refresh package discovery after creating missing __init__.py files.
- project_packages = discover_project_packages(root)
- py_files = sorted(iter_python_files(root))
-
- rename_map = build_private_module_rename_map(root, project_packages)
- rewritten_import_files = rewrite_imports_for_renamed_modules(
- root=root,
- py_files=py_files,
- project_packages=project_packages,
- rename_map=rename_map,
- apply=apply,
- )
- renamed_modules = apply_private_module_renames(rename_map, apply=apply)
-
- # Re-scan file list after module renames.
- py_files = sorted(iter_python_files(root))
- changed_init_files = build_package_api_exports(root, apply=apply)
-
- modified_files: list[Path] = []
- for path in py_files:
- modified, _has_copyright, _has_module_docstring = process_file(
- path=path,
- stdlib_names=stdlib_names,
- project_packages=project_packages,
- apply=apply,
- )
- if modified:
- modified_files.append(path)
-
- # Ensure __all__ follows final import order in __init__.py files.
- merged_init_files: list[Path] = []
- for package_dir in iter_package_dirs(root):
- init_path = package_dir / "__init__.py"
- if merge_init_all_exports(init_path, apply=apply):
- merged_init_files.append(init_path)
-
- mode = "DRY_RUN" if dry_run else "APPLY"
- print(f"[{mode}] sync-imports-package-api")
- print(f"Created __init__.py files: {len(created_init_files)}")
- for p in created_init_files:
- print(str(p))
- print(f"Import files rewritten: {len(rewritten_import_files)}")
- for p in rewritten_import_files:
- print(str(p))
- print(f"Modules renamed: {len(renamed_modules)}")
- for old_path, new_path in renamed_modules:
- print(f"{old_path} -> {new_path}")
- print(f"Package __init__.py refreshed: {len(changed_init_files)}")
- for p in changed_init_files:
- print(str(p))
- print(f"Python files normalized: {len(modified_files)}")
- for p in modified_files:
- print(str(p))
- print(f"__init__.py __all__ merged: {len(merged_init_files)}")
- for p in merged_init_files:
- print(str(p))
- return 0
+ return 1 if missing_relative_targets else 0
def run_report_private_candidates(root: Path, dry_run: bool) -> int:
@@ -1231,37 +1138,6 @@ def run_report_private_candidates(root: Path, dry_run: bool) -> int:
return 0
-def run_rename_private_modules(root: Path, dry_run: bool, build_package_api: bool) -> int:
- project_packages = discover_project_packages(root)
- py_files = sorted(iter_python_files(root))
- rename_map = build_private_module_rename_map(root, project_packages)
- rewritten_import_files = rewrite_imports_for_renamed_modules(
- root=root,
- py_files=py_files,
- project_packages=project_packages,
- rename_map=rename_map,
- apply=not dry_run,
- )
- renamed_modules = apply_private_module_renames(rename_map, apply=not dry_run)
- changed_init_files: list[Path] = []
- if build_package_api:
- changed_init_files = build_package_api_exports(root, apply=not dry_run)
-
- mode = "DRY_RUN" if dry_run else "APPLY"
- print(f"[{mode}] rename-private-modules")
- print(f"Import files rewritten: {len(rewritten_import_files)}")
- for p in rewritten_import_files:
- print(str(p))
- print(f"Modules renamed: {len(renamed_modules)}")
- for old_path, new_path in renamed_modules:
- print(f"{old_path} -> {new_path}")
- if build_package_api:
- print(f"Package __init__.py refreshed: {len(changed_init_files)}")
- for p in changed_init_files:
- print(str(p))
- return 0
-
-
def run_detect_issues(root: Path, dry_run: bool) -> int:
py_files = sorted(iter_python_files(root))
all_issues: list[CheckIssue] = []
@@ -1304,34 +1180,10 @@ def main() -> int:
p_sort.add_argument("--root", default=".", help="Project root directory.")
p_sort.add_argument("--dry-run", action="store_true", help="Preview changes without writing files.")
- p_audit = subparsers.add_parser("audit", help="Report missing copyright/docstring files.")
- p_audit.add_argument("--root", default=".", help="Project root directory.")
- p_audit.add_argument("--dry-run", action="store_true", help="Read-only mode (same output).")
-
- p_api = subparsers.add_parser("build-package-api", help="Refresh package __init__.py export blocks.")
- p_api.add_argument("--root", default=".", help="Project root directory.")
- p_api.add_argument("--dry-run", action="store_true", help="Preview changes without writing files.")
-
- p_sync = subparsers.add_parser(
- "sync-imports-package-api",
- help="Run rename/import-fix + build-package-api + sort-imports in one command.",
- )
- p_sync.add_argument("--root", default=".", help="Project root directory.")
- p_sync.add_argument("--dry-run", action="store_true", help="Preview changes without writing files.")
-
p_report = subparsers.add_parser("report-private-candidates", help="List foo.py -> _foo.py candidates.")
p_report.add_argument("--root", default=".", help="Project root directory.")
p_report.add_argument("--dry-run", action="store_true", help="Read-only mode (same output).")
- p_rename = subparsers.add_parser("rename-private-modules", help="Rename modules and rewrite imports.")
- p_rename.add_argument("--root", default=".", help="Project root directory.")
- p_rename.add_argument("--dry-run", action="store_true", help="Preview changes without writing files.")
- p_rename.add_argument(
- "--build-package-api",
- action="store_true",
- help="Also refresh package __init__.py export blocks after rename.",
- )
-
p_detect = subparsers.add_parser(
"detect-issues",
help=(
@@ -1358,16 +1210,8 @@ def main() -> int:
if args.command == "sort-imports":
return run_sort_imports(root, args.dry_run)
- if args.command == "audit":
- return run_audit(root, args.dry_run)
- if args.command == "build-package-api":
- return run_build_package_api(root, args.dry_run)
- if args.command == "sync-imports-package-api":
- return run_sync_imports_and_package_api(root, args.dry_run)
if args.command == "report-private-candidates":
return run_report_private_candidates(root, args.dry_run)
- if args.command == "rename-private-modules":
- return run_rename_private_modules(root, args.dry_run, args.build_package_api)
if args.command == "detect-issues":
return run_detect_issues(root, args.dry_run)
if args.command == "check-chinese":
diff --git a/pyproject.toml b/pyproject.toml
index 09f2532..2cc41b9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,6 +52,7 @@ dependencies = [
"rapidfuzz>=3.0.0",
"charset-normalizer>=3.0.0",
"typer>=0.12.0", # MIT License
+ "watchdog",
]
[project.optional-dependencies]
diff --git a/tests/agents/core/test_agent_transfer_processor.py b/tests/agents/core/test_agent_transfer_processor.py
index 167315b..19d97bd 100644
--- a/tests/agents/core/test_agent_transfer_processor.py
+++ b/tests/agents/core/test_agent_transfer_processor.py
@@ -8,19 +8,26 @@
from __future__ import annotations
import asyncio
-from typing import AsyncGenerator, List
-from unittest.mock import Mock, patch
+from typing import List
+from unittest.mock import Mock
import pytest
+import trpc_agent_sdk.skills as _skills_pkg
+
+if not hasattr(_skills_pkg, "get_skill_processor_parameters"):
+
+ def _compat_get_skill_processor_parameters(agent_context):
+ from trpc_agent_sdk.agents.core._skill_processor import get_skill_processor_parameters as _impl
+ return _impl(agent_context)
+
+ _skills_pkg.get_skill_processor_parameters = _compat_get_skill_processor_parameters
from trpc_agent_sdk.agents._llm_agent import LlmAgent
-from trpc_agent_sdk.agents._base_agent import BaseAgent
from trpc_agent_sdk.agents.core._agent_transfer_processor import (
AgentTransferProcessor,
default_agent_transfer_processor,
)
from trpc_agent_sdk.context import InvocationContext, create_agent_context
-from trpc_agent_sdk.events import Event
from trpc_agent_sdk.models import LLMModel, LlmRequest, LlmResponse, ModelRegistry
from trpc_agent_sdk.sessions import InMemorySessionService
from trpc_agent_sdk.types import GenerateContentConfig
diff --git a/tests/agents/core/test_code_execution_processor.py b/tests/agents/core/test_code_execution_processor.py
index 0a93a98..46fd21f 100644
--- a/tests/agents/core/test_code_execution_processor.py
+++ b/tests/agents/core/test_code_execution_processor.py
@@ -7,8 +7,6 @@
from __future__ import annotations
-import pytest
-
from trpc_agent_sdk.agents.core._code_execution_processor import (
DataFileUtil,
_DATA_FILE_UTIL_MAP,
diff --git a/tests/agents/core/test_history_processor.py b/tests/agents/core/test_history_processor.py
index 03cb64c..22182b3 100644
--- a/tests/agents/core/test_history_processor.py
+++ b/tests/agents/core/test_history_processor.py
@@ -8,8 +8,6 @@
from __future__ import annotations
import asyncio
-from typing import List
-from unittest.mock import Mock
import pytest
diff --git a/tests/agents/core/test_llm_processor.py b/tests/agents/core/test_llm_processor.py
index d327f28..5779bea 100644
--- a/tests/agents/core/test_llm_processor.py
+++ b/tests/agents/core/test_llm_processor.py
@@ -8,8 +8,8 @@
from __future__ import annotations
import asyncio
-from typing import AsyncGenerator, List
-from unittest.mock import AsyncMock, Mock, patch
+from typing import List
+from unittest.mock import Mock
import pytest
diff --git a/tests/agents/core/test_request_processor.py b/tests/agents/core/test_request_processor.py
index 9e2e36b..e7f4d28 100644
--- a/tests/agents/core/test_request_processor.py
+++ b/tests/agents/core/test_request_processor.py
@@ -8,11 +8,18 @@
from __future__ import annotations
import asyncio
-import copy
from typing import List
-from unittest.mock import AsyncMock, Mock, patch
import pytest
+import trpc_agent_sdk.skills as _skills_pkg
+
+if not hasattr(_skills_pkg, "get_skill_processor_parameters"):
+
+ def _compat_get_skill_processor_parameters(agent_context):
+ from trpc_agent_sdk.agents.core._skill_processor import get_skill_processor_parameters as _impl
+ return _impl(agent_context)
+
+ _skills_pkg.get_skill_processor_parameters = _compat_get_skill_processor_parameters
from trpc_agent_sdk.agents._llm_agent import LlmAgent
from trpc_agent_sdk.agents.core._request_processor import (
diff --git a/tests/agents/core/test_request_processor_ext.py b/tests/agents/core/test_request_processor_ext.py
index ba25f18..db15b75 100644
--- a/tests/agents/core/test_request_processor_ext.py
+++ b/tests/agents/core/test_request_processor_ext.py
@@ -9,11 +9,19 @@
from __future__ import annotations
import asyncio
-import copy
from typing import List
-from unittest.mock import AsyncMock, MagicMock, Mock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
import pytest
+import trpc_agent_sdk.skills as _skills_pkg
+
+if not hasattr(_skills_pkg, "get_skill_processor_parameters"):
+
+ def _compat_get_skill_processor_parameters(agent_context):
+ from trpc_agent_sdk.agents.core._skill_processor import get_skill_processor_parameters as _impl
+ return _impl(agent_context)
+
+ _skills_pkg.get_skill_processor_parameters = _compat_get_skill_processor_parameters
from trpc_agent_sdk.agents._llm_agent import LlmAgent
from trpc_agent_sdk.agents.core._request_processor import (
@@ -21,7 +29,7 @@
default_request_processor,
)
from trpc_agent_sdk.context import InvocationContext, create_agent_context
-from trpc_agent_sdk.events import Event, EventActions
+from trpc_agent_sdk.events import Event
from trpc_agent_sdk.models import LLMModel, LlmRequest, LlmResponse, ModelRegistry
from trpc_agent_sdk.sessions import InMemorySessionService
from trpc_agent_sdk.types import Content, FunctionCall, FunctionResponse, GenerateContentConfig, Part
@@ -256,7 +264,7 @@ async def test_no_skill_repository_returns_none(self, processor, ctx):
"""Returns None (no error) when agent has no skill_repository."""
ctx.agent.skill_repository = None
request = LlmRequest(model="test-rp-ext-model")
- result = await processor._add_skills_to_request(ctx.agent, ctx, request)
+ result = await processor._add_skills_to_request(ctx.agent, ctx, request, parameters={})
assert result is None
@pytest.mark.asyncio
@@ -268,7 +276,7 @@ async def test_skill_processing_error_returns_event(self, processor, ctx):
instance = MockSRP.return_value
instance.process_llm_request = AsyncMock(side_effect=RuntimeError("skill boom"))
request = LlmRequest(model="test-rp-ext-model")
- result = await processor._add_skills_to_request(ctx.agent, ctx, request)
+ result = await processor._add_skills_to_request(ctx.agent, ctx, request, parameters={})
assert result is not None
assert result.error_code == "skill_processing_error"
diff --git a/tests/agents/core/test_skill_processor.py b/tests/agents/core/test_skill_processor.py
index f850278..7968c40 100644
--- a/tests/agents/core/test_skill_processor.py
+++ b/tests/agents/core/test_skill_processor.py
@@ -3,658 +3,222 @@
# Copyright (C) 2026 Tencent. All rights reserved.
#
# tRPC-Agent-Python is licensed under Apache-2.0.
-"""Unit tests for SkillsRequestProcessor and module-level helpers."""
-from __future__ import annotations
-
-import asyncio
import json
-from typing import List
-from unittest.mock import AsyncMock, MagicMock, Mock, patch
-
-import pytest
-
-from trpc_agent_sdk.agents.core._skill_processor import (
- SKILL_LOAD_MODE_ONCE,
- SKILL_LOAD_MODE_SESSION,
- SKILL_LOAD_MODE_TURN,
- SkillsRequestProcessor,
- _default_knowledge_only_guidance,
- _default_full_tooling_and_workspace_guidance,
- _default_tooling_and_workspace_guidance,
- _is_knowledge_only,
- _normalize_custom_guidance,
- _normalize_load_mode,
- _SKILLS_LOADED_ORDER_STATE_KEY,
- _SKILLS_OVERVIEW_HEADER,
-)
-from trpc_agent_sdk.context import InvocationContext, create_agent_context
-from trpc_agent_sdk.events import EventActions
-from trpc_agent_sdk.models import LLMModel, LlmRequest, LlmResponse, ModelRegistry
-from trpc_agent_sdk.sessions import InMemorySessionService
-from trpc_agent_sdk.skills import (
- SKILL_DOCS_STATE_KEY_PREFIX,
- SKILL_LOADED_STATE_KEY_PREFIX,
- SKILL_TOOLS_STATE_KEY_PREFIX,
- BaseSkillRepository,
- Skill,
- SkillResource,
- SkillSummary,
-)
-from trpc_agent_sdk.types import GenerateContentConfig
-
-
-# ---------------------------------------------------------------------------
-# Helpers and fixtures
-# ---------------------------------------------------------------------------
-
-class _MockLLMModel(LLMModel):
- @classmethod
- def supported_models(cls) -> List[str]:
- return [r"test-sp-.*"]
-
- async def _generate_async_impl(self, request, stream=False, ctx=None):
- yield LlmResponse(content=None)
-
- def validate_request(self, request):
- pass
-
-
-class _StubRepo(BaseSkillRepository):
- """In-memory stub for BaseSkillRepository."""
-
- def __init__(self, skills=None, summaries_list=None, user_prompt_text=""):
- super().__init__(workspace_runtime=MagicMock())
- self._skills = skills or {}
- self._summaries = summaries_list or []
- self._user_prompt_text = user_prompt_text
-
- def summaries(self) -> list[SkillSummary]:
- return self._summaries
-
- def get(self, name: str) -> Skill:
- if name not in self._skills:
- raise ValueError(f"Skill not found: {name}")
- return self._skills[name]
-
- def user_prompt(self) -> str:
- return self._user_prompt_text
-
- def skill_list(self) -> list[str]:
- return list(self._skills.keys())
-
- def path(self, name: str) -> str:
- if name not in self._skills:
- raise ValueError(f"Skill not found: {name}")
- return f"/fake/skills/{name}"
-
- def refresh(self) -> None:
- pass
-
-
-@pytest.fixture(scope="module", autouse=True)
-def register_test_model():
- original_registry = ModelRegistry._registry.copy()
- ModelRegistry.register(_MockLLMModel)
- yield
- ModelRegistry._registry = original_registry
-
-
-@pytest.fixture
-def session_service():
- return InMemorySessionService()
-
-
-@pytest.fixture
-def session(session_service):
- return asyncio.run(
- session_service.create_session(app_name="test", user_id="u1", session_id="sp_sess")
- )
-
-
-@pytest.fixture
-def ctx(session_service, session):
- from trpc_agent_sdk.agents._llm_agent import LlmAgent
- agent = LlmAgent(name="skill_agent", model="test-sp-model")
- return InvocationContext(
- session_service=session_service,
- invocation_id="inv-sp-1",
- agent=agent,
- agent_context=create_agent_context(),
- session=session,
- branch="branch_sp",
- )
-
-
-@pytest.fixture
-def sample_skill():
+from copy import deepcopy
+from unittest.mock import Mock
+
+import trpc_agent_sdk.skills as _skills_pkg
+
+if not hasattr(_skills_pkg, "get_skill_processor_parameters"):
+
+ def _compat_get_skill_processor_parameters(agent_context):
+ from trpc_agent_sdk.agents.core._skill_processor import get_skill_processor_parameters as _impl
+ return _impl(agent_context)
+
+ _skills_pkg.get_skill_processor_parameters = _compat_get_skill_processor_parameters
+
+from trpc_agent_sdk.agents.core._skill_processor import SkillsRequestProcessor
+from trpc_agent_sdk.agents.core._skill_processor import get_skill_processor_parameters
+from trpc_agent_sdk.agents.core._skill_processor import set_skill_processor_parameters
+from trpc_agent_sdk.context import AgentContext
+from trpc_agent_sdk.context import InvocationContext
+from trpc_agent_sdk.models import LlmRequest
+from trpc_agent_sdk.skills import Skill
+from trpc_agent_sdk.skills import SkillResource
+from trpc_agent_sdk.skills import SkillSummary
+from trpc_agent_sdk.skills import docs_key
+from trpc_agent_sdk.skills import docs_session_key
+from trpc_agent_sdk.skills import loaded_key
+from trpc_agent_sdk.skills import loaded_order_key
+from trpc_agent_sdk.skills import loaded_session_key
+from trpc_agent_sdk.skills import loaded_session_order_key
+from trpc_agent_sdk.skills import tool_key
+from trpc_agent_sdk.skills import tool_session_key
+from trpc_agent_sdk.skills._skill_config import DEFAULT_SKILL_CONFIG
+from trpc_agent_sdk.skills._skill_config import set_skill_config
+
+
+def _build_skill(name: str) -> Skill:
return Skill(
- summary=SkillSummary(name="code_review", description="Reviews code"),
- body="# Code Review\nReview PR diffs.",
- resources=[
- SkillResource(path="guide.md", content="Review guide content"),
- SkillResource(path="checklist.md", content="Checklist content"),
- ],
- tools=["lint_check", "format_code"],
- )
-
-
-@pytest.fixture
-def sample_repo(sample_skill):
- return _StubRepo(
- skills={"code_review": sample_skill},
- summaries_list=[SkillSummary(name="code_review", description="Reviews code")],
+ summary=SkillSummary(name=name, description=f"{name} description"),
+ body=f"{name} body",
+ resources=[SkillResource(path="docs/guide.md", content=f"{name} guide")],
+ tools=["tool-a", "tool-b"],
)
-# ---------------------------------------------------------------------------
-# Module-level helpers
-# ---------------------------------------------------------------------------
-
-
-class TestNormalizeLoadMode:
- def test_valid_modes(self):
- """Valid mode strings are returned as-is (lowered)."""
- assert _normalize_load_mode("once") == SKILL_LOAD_MODE_ONCE
- assert _normalize_load_mode("TURN") == SKILL_LOAD_MODE_TURN
- assert _normalize_load_mode("Session") == SKILL_LOAD_MODE_SESSION
-
- def test_invalid_mode_defaults_to_turn(self):
- """Invalid mode falls back to turn."""
- assert _normalize_load_mode("bogus") == SKILL_LOAD_MODE_TURN
-
- def test_empty_mode_defaults_to_turn(self):
- """Empty/None mode falls back to turn."""
- assert _normalize_load_mode("") == SKILL_LOAD_MODE_TURN
- assert _normalize_load_mode(None) == SKILL_LOAD_MODE_TURN
-
-
-class TestIsKnowledgeOnly:
- def test_knowledge_only_profiles(self):
- """Recognized knowledge-only profile strings return True."""
- assert _is_knowledge_only("knowledge_only") is True
- assert _is_knowledge_only("knowledge") is True
- assert _is_knowledge_only("Knowledge-Only") is True
-
- def test_non_knowledge_profiles(self):
- """Non-matching profiles return False."""
- assert _is_knowledge_only("full") is False
- assert _is_knowledge_only("") is False
- assert _is_knowledge_only(None) is False
-
-
-class TestNormalizeCustomGuidance:
- def test_empty_returns_empty(self):
- """Empty string passes through."""
- assert _normalize_custom_guidance("") == ""
-
- def test_adds_leading_newline(self):
- """Leading newline is added if missing."""
- result = _normalize_custom_guidance("hello")
- assert result.startswith("\n")
-
- def test_adds_trailing_newline(self):
- """Trailing newline is added if missing."""
- result = _normalize_custom_guidance("hello")
- assert result.endswith("\n")
-
- def test_existing_newlines_preserved(self):
- """Existing leading/trailing newlines are not doubled."""
- result = _normalize_custom_guidance("\nhello\n")
- assert result == "\nhello\n"
-
-
-class TestGuidanceTextBuilders:
- def test_knowledge_only_guidance_contains_header(self):
- """Knowledge-only guidance includes the tooling guidance header."""
- text = _default_knowledge_only_guidance()
- assert "Tooling and workspace guidance" in text
-
- def test_full_tooling_guidance_exec_enabled(self):
- """Full tooling guidance with exec tools enabled mentions skill_exec."""
- text = _default_full_tooling_and_workspace_guidance(exec_tools_disabled=False)
- assert "skill_exec" in text
-
- def test_full_tooling_guidance_exec_disabled(self):
- """Full tooling guidance with exec tools disabled omits skill_exec mention."""
- text = _default_full_tooling_and_workspace_guidance(exec_tools_disabled=True)
- assert "interactive execution is available" in text
-
- def test_default_routing_knowledge_only(self):
- """Dispatcher routes to knowledge-only guidance for matching profile."""
- text = _default_tooling_and_workspace_guidance("knowledge_only", False)
- assert "progressive disclosure" in text
-
- def test_default_routing_full(self):
- """Dispatcher routes to full guidance for non-matching profile."""
- text = _default_tooling_and_workspace_guidance("", False)
- assert "skill_run" in text
-
-
-# ---------------------------------------------------------------------------
-# SkillsRequestProcessor.__init__
-# ---------------------------------------------------------------------------
-
-
-class TestSkillsRequestProcessorInit:
- def test_defaults(self, sample_repo):
- """Default parameters are applied correctly."""
- proc = SkillsRequestProcessor(sample_repo)
- assert proc._load_mode == SKILL_LOAD_MODE_TURN
- assert proc._tool_result_mode is False
- assert proc._max_loaded_skills == 0
-
- def test_custom_parameters(self, sample_repo):
- """Custom init parameters are stored."""
- proc = SkillsRequestProcessor(
- sample_repo,
- load_mode="once",
- tooling_guidance="custom",
- tool_result_mode=True,
- tool_profile="knowledge_only",
- exec_tools_disabled=True,
- max_loaded_skills=5,
- )
- assert proc._load_mode == SKILL_LOAD_MODE_ONCE
- assert proc._tooling_guidance == "custom"
- assert proc._tool_result_mode is True
- assert proc._tool_profile == "knowledge_only"
- assert proc._exec_tools_disabled is True
- assert proc._max_loaded_skills == 5
-
-
-# ---------------------------------------------------------------------------
-# _get_repository
-# ---------------------------------------------------------------------------
-
-
-class TestGetRepository:
- def test_returns_default_repo(self, sample_repo):
- """Returns default repository when no resolver is set."""
- proc = SkillsRequestProcessor(sample_repo)
- ctx_mock = MagicMock()
- assert proc._get_repository(ctx_mock) is sample_repo
-
- def test_resolver_overrides(self, sample_repo):
- """Repo resolver takes precedence over the default repository."""
- other_repo = _StubRepo()
- proc = SkillsRequestProcessor(sample_repo, repo_resolver=lambda c: other_repo)
- ctx_mock = MagicMock()
- assert proc._get_repository(ctx_mock) is other_repo
-
-
-# ---------------------------------------------------------------------------
-# _snapshot_state / _read_state
-# ---------------------------------------------------------------------------
-
-
-class TestStateHelpers:
- def test_snapshot_merges_delta(self, ctx, sample_repo):
- """Snapshot merges session state with pending delta."""
- proc = SkillsRequestProcessor(sample_repo)
- ctx.session.state["key_a"] = "from_session"
- ctx.actions.state_delta["key_b"] = "from_delta"
- snap = proc._snapshot_state(ctx)
- assert snap["key_a"] == "from_session"
- assert snap["key_b"] == "from_delta"
-
- def test_snapshot_delta_none_removes_key(self, ctx, sample_repo):
- """Delta value of None removes the key from snapshot."""
- proc = SkillsRequestProcessor(sample_repo)
- ctx.session.state["key_a"] = "exists"
- ctx.actions.state_delta["key_a"] = None
- snap = proc._snapshot_state(ctx)
- assert "key_a" not in snap
-
- def test_read_state_from_delta(self, ctx, sample_repo):
- """read_state prefers delta over session state."""
- proc = SkillsRequestProcessor(sample_repo)
- ctx.session.state["k"] = "old"
- ctx.actions.state_delta["k"] = "new"
- assert proc._read_state(ctx, "k") == "new"
-
- def test_read_state_from_session(self, ctx, sample_repo):
- """read_state falls back to session state when delta has no key."""
- proc = SkillsRequestProcessor(sample_repo)
- ctx.session.state["k"] = "val"
- assert proc._read_state(ctx, "k") == "val"
-
- def test_read_state_default(self, ctx, sample_repo):
- """read_state returns default when key not in delta or session."""
- proc = SkillsRequestProcessor(sample_repo)
- assert proc._read_state(ctx, "missing", "fallback") == "fallback"
-
-
-# ---------------------------------------------------------------------------
-# _get_loaded_skills
-# ---------------------------------------------------------------------------
-
-
-class TestGetLoadedSkills:
- def test_no_loaded_skills(self, ctx, sample_repo):
- """Returns empty list when no skills are loaded."""
- proc = SkillsRequestProcessor(sample_repo)
- assert proc._get_loaded_skills(ctx) == []
-
- def test_loaded_skills_detected(self, ctx, sample_repo):
- """Skills with state key prefix are detected."""
- proc = SkillsRequestProcessor(sample_repo)
- ctx.session.state[SKILL_LOADED_STATE_KEY_PREFIX + "code_review"] = "1"
- result = proc._get_loaded_skills(ctx)
- assert "code_review" in result
-
- def test_falsy_value_ignored(self, ctx, sample_repo):
- """Skills with falsy state values are not returned."""
- proc = SkillsRequestProcessor(sample_repo)
- ctx.session.state[SKILL_LOADED_STATE_KEY_PREFIX + "empty_skill"] = ""
- assert proc._get_loaded_skills(ctx) == []
-
-
-# ---------------------------------------------------------------------------
-# _inject_overview
-# ---------------------------------------------------------------------------
-
-
-class TestInjectOverview:
- def test_overview_injected(self, sample_repo):
- """Overview is injected into system instruction."""
- proc = SkillsRequestProcessor(sample_repo)
- request = LlmRequest(model="test-sp-model")
- proc._inject_overview(request, sample_repo)
- sys_instr = str(request.config.system_instruction)
- assert "code_review" in sys_instr
- assert "Reviews code" in sys_instr
-
- def test_no_summaries_no_injection(self):
- """No injection when repo has no summaries."""
- repo = _StubRepo(summaries_list=[])
- proc = SkillsRequestProcessor(repo)
- request = LlmRequest(model="test-sp-model")
- proc._inject_overview(request, repo)
- assert request.config is None or request.config.system_instruction is None
-
- def test_double_injection_guard(self, sample_repo):
- """Overview is not injected twice."""
- proc = SkillsRequestProcessor(sample_repo)
- request = LlmRequest(model="test-sp-model")
- proc._inject_overview(request, sample_repo)
- first_instr = str(request.config.system_instruction)
- proc._inject_overview(request, sample_repo)
- second_instr = str(request.config.system_instruction)
- assert first_instr == second_instr
-
- def test_user_prompt_prepended(self):
- """Repository user_prompt is prepended to overview."""
- repo = _StubRepo(
- summaries_list=[SkillSummary(name="s1", description="desc")],
- user_prompt_text="Custom prompt",
- )
- proc = SkillsRequestProcessor(repo)
- request = LlmRequest(model="test-sp-model")
- proc._inject_overview(request, repo)
- sys_instr = str(request.config.system_instruction)
- assert sys_instr.index("Custom prompt") < sys_instr.index("s1")
-
-
-# ---------------------------------------------------------------------------
-# _maybe_clear_skill_state_for_turn
-# ---------------------------------------------------------------------------
-
-
-class TestMaybeClearSkillStateForTurn:
- def test_clears_on_first_invocation(self, ctx, sample_repo):
- """State is cleared on first invocation in turn mode."""
- proc = SkillsRequestProcessor(sample_repo, load_mode="turn")
- ctx.session.state[SKILL_LOADED_STATE_KEY_PREFIX + "sk1"] = "1"
- proc._maybe_clear_skill_state_for_turn(ctx)
- assert (SKILL_LOADED_STATE_KEY_PREFIX + "sk1") in ctx.actions.state_delta
- assert ctx.actions.state_delta[SKILL_LOADED_STATE_KEY_PREFIX + "sk1"] is None
-
- def test_no_clear_on_second_call_same_invocation(self, ctx, sample_repo):
- """State is NOT cleared on second call within same invocation."""
- proc = SkillsRequestProcessor(sample_repo, load_mode="turn")
- proc._maybe_clear_skill_state_for_turn(ctx)
- ctx.actions.state_delta.clear()
- ctx.session.state[SKILL_LOADED_STATE_KEY_PREFIX + "sk2"] = "1"
- proc._maybe_clear_skill_state_for_turn(ctx)
- assert (SKILL_LOADED_STATE_KEY_PREFIX + "sk2") not in ctx.actions.state_delta
-
- def test_no_clear_in_session_mode(self, ctx, sample_repo):
- """No clearing in session mode."""
- proc = SkillsRequestProcessor(sample_repo, load_mode="session")
- ctx.session.state[SKILL_LOADED_STATE_KEY_PREFIX + "sk1"] = "1"
- proc._maybe_clear_skill_state_for_turn(ctx)
- assert (SKILL_LOADED_STATE_KEY_PREFIX + "sk1") not in ctx.actions.state_delta
-
-
-# ---------------------------------------------------------------------------
-# _maybe_offload_loaded_skills
-# ---------------------------------------------------------------------------
-
-
-class TestMaybeOffloadLoadedSkills:
- def test_offloads_in_once_mode(self, ctx, sample_repo):
- """Skill state is cleared after injection in once mode."""
- proc = SkillsRequestProcessor(sample_repo, load_mode="once")
- proc._maybe_offload_loaded_skills(ctx, ["code_review"])
- assert ctx.actions.state_delta[SKILL_LOADED_STATE_KEY_PREFIX + "code_review"] is None
- assert ctx.actions.state_delta[SKILL_DOCS_STATE_KEY_PREFIX + "code_review"] is None
- assert ctx.actions.state_delta[SKILL_TOOLS_STATE_KEY_PREFIX + "code_review"] is None
- assert ctx.actions.state_delta[_SKILLS_LOADED_ORDER_STATE_KEY] is None
-
- def test_no_offload_in_turn_mode(self, ctx, sample_repo):
- """No offloading in turn mode."""
- proc = SkillsRequestProcessor(sample_repo, load_mode="turn")
- proc._maybe_offload_loaded_skills(ctx, ["code_review"])
- assert SKILL_LOADED_STATE_KEY_PREFIX + "code_review" not in ctx.actions.state_delta
-
- def test_no_offload_when_empty(self, ctx, sample_repo):
- """No offloading when loaded list is empty."""
- proc = SkillsRequestProcessor(sample_repo, load_mode="once")
- proc._maybe_offload_loaded_skills(ctx, [])
- assert len(ctx.actions.state_delta) == 0
-
-
-# ---------------------------------------------------------------------------
-# _maybe_cap_loaded_skills
-# ---------------------------------------------------------------------------
-
-
-class TestMaybeCapLoadedSkills:
- def test_no_cap_returns_all(self, ctx, sample_repo):
- """All skills returned when cap is 0 (disabled)."""
- proc = SkillsRequestProcessor(sample_repo, max_loaded_skills=0)
- result = proc._maybe_cap_loaded_skills(ctx, ["a", "b", "c"])
- assert result == ["a", "b", "c"]
-
- def test_under_cap_returns_all(self, ctx, sample_repo):
- """All skills returned when count is at or under cap."""
- proc = SkillsRequestProcessor(sample_repo, max_loaded_skills=5)
- result = proc._maybe_cap_loaded_skills(ctx, ["a", "b"])
- assert result == ["a", "b"]
-
- def test_over_cap_evicts_oldest(self, ctx, sample_repo):
- """Excess skills are evicted, keeping most recent."""
- proc = SkillsRequestProcessor(sample_repo, max_loaded_skills=2)
- ctx.actions.state_delta[_SKILLS_LOADED_ORDER_STATE_KEY] = json.dumps(["a", "b", "c"])
- result = proc._maybe_cap_loaded_skills(ctx, ["a", "b", "c"])
- assert len(result) == 2
- assert "a" not in result
- assert "b" in result
- assert "c" in result
-
-
-# ---------------------------------------------------------------------------
-# _build_docs_text
-# ---------------------------------------------------------------------------
-
-
-class TestBuildDocsText:
- def test_selected_docs_included(self, sample_repo, sample_skill):
- """Only selected docs are included in output."""
- proc = SkillsRequestProcessor(sample_repo)
- text = proc._build_docs_text(sample_skill, ["guide.md"])
- assert "Review guide content" in text
- assert "Checklist content" not in text
-
- def test_no_docs_returns_empty(self, sample_repo):
- """Empty string for skill with no resources."""
- proc = SkillsRequestProcessor(sample_repo)
- empty_skill = Skill(summary=SkillSummary(name="empty", description=""))
- assert proc._build_docs_text(empty_skill, ["any.md"]) == ""
-
- def test_none_skill_returns_empty(self, sample_repo):
- """None skill returns empty string."""
- proc = SkillsRequestProcessor(sample_repo)
- assert proc._build_docs_text(None, ["any.md"]) == ""
-
-
-# ---------------------------------------------------------------------------
-# _merge_into_system
-# ---------------------------------------------------------------------------
-
-
-class TestMergeIntoSystem:
- def test_appends_to_system(self, sample_repo):
- """Content is appended to system instruction."""
- proc = SkillsRequestProcessor(sample_repo)
- request = LlmRequest(model="test-sp-model")
- proc._merge_into_system(request, "extra guidance")
- assert "extra guidance" in str(request.config.system_instruction)
-
- def test_empty_content_no_op(self, sample_repo):
- """Empty string is a no-op."""
- proc = SkillsRequestProcessor(sample_repo)
- request = LlmRequest(model="test-sp-model")
- proc._merge_into_system(request, "")
- assert request.config is None or request.config.system_instruction is None
-
-
-# ---------------------------------------------------------------------------
-# _capability_guidance_text
-# ---------------------------------------------------------------------------
-
-
-class TestCapabilityGuidanceText:
- def test_non_knowledge_profile_empty(self, sample_repo):
- """Non-knowledge profile returns empty string."""
- proc = SkillsRequestProcessor(sample_repo, tool_profile="full")
- assert proc._capability_guidance_text() == ""
-
- def test_knowledge_profile_returns_guidance(self, sample_repo):
- """Knowledge-only profile returns capability guidance."""
- proc = SkillsRequestProcessor(sample_repo, tool_profile="knowledge_only")
- text = proc._capability_guidance_text()
- assert "knowledge loading only" in text
-
- def test_knowledge_profile_with_empty_guidance_suppressed(self, sample_repo):
- """Knowledge profile with explicit empty guidance suppresses capability block."""
- proc = SkillsRequestProcessor(sample_repo, tool_profile="knowledge_only", tooling_guidance="")
- assert proc._capability_guidance_text() == ""
-
-
-# ---------------------------------------------------------------------------
-# process_llm_request (integration)
-# ---------------------------------------------------------------------------
-
-
-class TestProcessLlmRequest:
- @pytest.mark.asyncio
- async def test_none_request_returns_empty(self, ctx, sample_repo):
- """Returns empty list for None request."""
- proc = SkillsRequestProcessor(sample_repo)
- result = await proc.process_llm_request(ctx, None)
- assert result == []
-
- @pytest.mark.asyncio
- async def test_none_ctx_returns_empty(self, sample_repo):
- """Returns empty list for None ctx."""
- proc = SkillsRequestProcessor(sample_repo)
- request = LlmRequest(model="test-sp-model")
- result = await proc.process_llm_request(None, request)
- assert result == []
-
- @pytest.mark.asyncio
- async def test_overview_injected_no_loaded(self, ctx, sample_repo):
- """Overview is injected even when no skills are loaded."""
- proc = SkillsRequestProcessor(sample_repo)
- request = LlmRequest(model="test-sp-model")
- result = await proc.process_llm_request(ctx, request)
- assert result == []
- assert "code_review" in str(request.config.system_instruction)
-
- @pytest.mark.asyncio
- async def test_loaded_skill_body_injected(self, ctx, sample_repo):
- """Loaded skill body is injected into system instruction."""
- proc = SkillsRequestProcessor(sample_repo, load_mode="session")
- ctx.session.state[SKILL_LOADED_STATE_KEY_PREFIX + "code_review"] = "1"
- request = LlmRequest(model="test-sp-model")
- result = await proc.process_llm_request(ctx, request)
- assert "code_review" in result
- sys_instr = str(request.config.system_instruction)
- assert "Review PR diffs" in sys_instr
-
- @pytest.mark.asyncio
- async def test_tool_result_mode_skips_body_injection(self, ctx, sample_repo):
- """In tool_result_mode, loaded skill bodies are NOT injected."""
- proc = SkillsRequestProcessor(sample_repo, tool_result_mode=True, load_mode="session")
- ctx.session.state[SKILL_LOADED_STATE_KEY_PREFIX + "code_review"] = "1"
- request = LlmRequest(model="test-sp-model")
- result = await proc.process_llm_request(ctx, request)
- assert "code_review" in result
- sys_instr = str(request.config.system_instruction) if request.config and request.config.system_instruction else ""
- assert "Review PR diffs" not in sys_instr
-
- @pytest.mark.asyncio
- async def test_once_mode_offloads_after_injection(self, ctx, sample_repo):
- """In once mode, skill state is cleared after injection."""
- proc = SkillsRequestProcessor(sample_repo, load_mode="once")
- ctx.session.state[SKILL_LOADED_STATE_KEY_PREFIX + "code_review"] = "1"
- request = LlmRequest(model="test-sp-model")
- await proc.process_llm_request(ctx, request)
- assert ctx.actions.state_delta.get(SKILL_LOADED_STATE_KEY_PREFIX + "code_review") is None
-
- @pytest.mark.asyncio
- async def test_none_repo_returns_empty(self, ctx):
- """Returns empty list when repository resolves to None."""
- proc = SkillsRequestProcessor(
- MagicMock(),
- repo_resolver=lambda c: None,
- )
- request = LlmRequest(model="test-sp-model")
- result = await proc.process_llm_request(ctx, request)
- assert result == []
-
-
-# ---------------------------------------------------------------------------
-# _get_loaded_skill_order
-# ---------------------------------------------------------------------------
-
-
-class TestGetLoadedSkillOrder:
- def test_no_persisted_order(self, ctx, sample_repo):
- """Missing order key returns alphabetical order."""
- proc = SkillsRequestProcessor(sample_repo)
- order = proc._get_loaded_skill_order(ctx, ["beta", "alpha"])
- assert order == ["alpha", "beta"]
-
- def test_persisted_order_respected(self, ctx, sample_repo):
- """Persisted order is respected for known skills."""
- proc = SkillsRequestProcessor(sample_repo)
- ctx.session.state[_SKILLS_LOADED_ORDER_STATE_KEY] = json.dumps(["beta", "alpha"])
- order = proc._get_loaded_skill_order(ctx, ["alpha", "beta"])
- assert order == ["beta", "alpha"]
-
- def test_new_skills_appended_alphabetically(self, ctx, sample_repo):
- """Skills not in persisted order are appended alphabetically."""
- proc = SkillsRequestProcessor(sample_repo)
- ctx.session.state[_SKILLS_LOADED_ORDER_STATE_KEY] = json.dumps(["beta"])
- order = proc._get_loaded_skill_order(ctx, ["alpha", "beta", "gamma"])
- assert order == ["beta", "alpha", "gamma"]
-
- def test_invalid_json_falls_back(self, ctx, sample_repo):
- """Invalid JSON in order key falls back to alphabetical."""
- proc = SkillsRequestProcessor(sample_repo)
- ctx.session.state[_SKILLS_LOADED_ORDER_STATE_KEY] = "not-json"
- order = proc._get_loaded_skill_order(ctx, ["c", "a", "b"])
- assert order == ["a", "b", "c"]
+def _build_repo(skills: dict[str, Skill]) -> Mock:
+ repo = Mock()
+ repo.summaries.return_value = [skill.summary for skill in skills.values()]
+ repo.user_prompt.return_value = "repo user prompt"
+ repo.get.side_effect = lambda name: skills.get(name)
+ return repo
+
+
+def _build_ctx(state: dict, *, agent_name: str = "demo-agent", load_mode: str = "turn") -> Mock:
+ ctx = Mock(spec=InvocationContext)
+ ctx.agent_name = agent_name
+ ctx.session_state = state
+ ctx.actions = Mock()
+ ctx.actions.state_delta = {}
+ ctx.agent_context = AgentContext()
+ config = deepcopy(DEFAULT_SKILL_CONFIG)
+ config["skill_processor"]["load_mode"] = load_mode
+ set_skill_config(ctx.agent_context, config)
+ ctx.session = Mock()
+ ctx.session.state = state
+ ctx.session.events = []
+ return ctx
+
+
+class TestSkillsRequestProcessor:
+
+ async def test_injects_overview_and_loaded_content(self):
+ skill = _build_skill("demo-skill")
+ repo = _build_repo({"demo-skill": skill})
+ loaded_state_key = loaded_session_key(loaded_key("demo-agent", "demo-skill"))
+ docs_state_key = docs_session_key(docs_key("demo-agent", "demo-skill"))
+ tools_state_key = tool_session_key(tool_key("demo-agent", "demo-skill"))
+ ctx = _build_ctx({
+ loaded_state_key: True,
+ docs_state_key: json.dumps(["docs/guide.md"]),
+ tools_state_key: json.dumps(["tool-a"]),
+ }, load_mode="session")
+ request = LlmRequest(contents=[], tools_dict={})
+ processor = SkillsRequestProcessor(repo, load_mode="session")
+
+ loaded = await processor.process_llm_request(ctx, request)
+
+ assert loaded == ["demo-skill"]
+ assert request.config is not None
+ text = request.config.system_instruction
+ assert "repo user prompt" in text
+ assert "Available skills:" in text
+ assert "- demo-skill: demo-skill description" in text
+ assert "[Loaded] demo-skill" in text
+ assert "Docs loaded: docs/guide.md" in text
+ assert "[Doc] docs/guide.md" in text
+ assert "Tools selected: tool-a" in text
+
+ async def test_tool_result_mode_skips_loaded_materialization(self):
+ skill = _build_skill("demo-skill")
+ repo = _build_repo({"demo-skill": skill})
+ loaded_state_key = loaded_session_key(loaded_key("demo-agent", "demo-skill"))
+ ctx = _build_ctx({loaded_state_key: True}, load_mode="session")
+ request = LlmRequest(contents=[], tools_dict={})
+ processor = SkillsRequestProcessor(repo, tool_result_mode=True, load_mode="session")
+
+ loaded = await processor.process_llm_request(ctx, request)
+
+ assert loaded == ["demo-skill"]
+ assert request.config is not None
+ text = request.config.system_instruction
+ assert "Available skills:" in text
+ assert "[Loaded] demo-skill" not in text
+
+ async def test_once_mode_offloads_loaded_state(self):
+ skill = _build_skill("demo-skill")
+ repo = _build_repo({"demo-skill": skill})
+ ctx = _build_ctx({
+ loaded_key("demo-agent", "demo-skill"): True,
+ docs_key("demo-agent", "demo-skill"): json.dumps(["docs/guide.md"]),
+ tool_key("demo-agent", "demo-skill"): json.dumps(["tool-a"]),
+ })
+ request = LlmRequest(contents=[], tools_dict={})
+ processor = SkillsRequestProcessor(repo, load_mode="once")
+
+ await processor.process_llm_request(ctx, request)
+
+ assert ctx.actions.state_delta[loaded_key("demo-agent", "demo-skill")] is None
+ assert ctx.actions.state_delta[docs_key("demo-agent", "demo-skill")] is None
+ assert ctx.actions.state_delta[tool_key("demo-agent", "demo-skill")] is None
+ assert ctx.actions.state_delta[loaded_order_key("demo-agent")] is None
+
+ async def test_turn_mode_clears_previous_skill_state(self):
+ skill = _build_skill("demo-skill")
+ repo = _build_repo({"demo-skill": skill})
+ ctx = _build_ctx({
+ loaded_key("demo-agent", "demo-skill"): True,
+ docs_key("demo-agent", "demo-skill"): json.dumps(["docs/guide.md"]),
+ tool_key("demo-agent", "demo-skill"): json.dumps(["tool-a"]),
+ loaded_order_key("demo-agent"): json.dumps(["demo-skill"]),
+ })
+ request = LlmRequest(contents=[], tools_dict={})
+ processor = SkillsRequestProcessor(repo, load_mode="turn")
+
+ loaded = await processor.process_llm_request(ctx, request)
+
+ assert loaded == []
+ assert ctx.actions.state_delta[loaded_key("demo-agent", "demo-skill")] is None
+ assert ctx.actions.state_delta[docs_key("demo-agent", "demo-skill")] is None
+ assert ctx.actions.state_delta[tool_key("demo-agent", "demo-skill")] is None
+ assert ctx.actions.state_delta[loaded_order_key("demo-agent")] is None
+ assert ctx.agent_context.get_metadata("processor:skills:turn_init") is True
+
+ async def test_session_mode_uses_temp_only_state(self):
+ skill = _build_skill("demo-skill")
+ repo = _build_repo({"demo-skill": skill})
+ loaded_state_key = loaded_session_key(loaded_key("demo-agent", "demo-skill"))
+ docs_state_key = docs_session_key(docs_key("demo-agent", "demo-skill"))
+ tools_state_key = tool_session_key(tool_key("demo-agent", "demo-skill"))
+ session_order_key = loaded_session_order_key(loaded_order_key("demo-agent"))
+ ctx = _build_ctx({
+ loaded_state_key: True,
+ docs_state_key: json.dumps(["docs/guide.md"]),
+ tools_state_key: json.dumps(["tool-a"]),
+ session_order_key: json.dumps(["demo-skill"]),
+ }, load_mode="session")
+ request = LlmRequest(contents=[], tools_dict={})
+ processor = SkillsRequestProcessor(repo, load_mode="session")
+
+ loaded = await processor.process_llm_request(ctx, request)
+
+ assert loaded == ["demo-skill"]
+ assert loaded_key("demo-agent", "demo-skill") not in ctx.actions.state_delta
+ assert docs_key("demo-agent", "demo-skill") not in ctx.actions.state_delta
+ assert tool_key("demo-agent", "demo-skill") not in ctx.actions.state_delta
+
+ async def test_max_loaded_skills_evicts_lru_skills_in_session_state(self):
+ skill_a = _build_skill("skill-a")
+ skill_b = _build_skill("skill-b")
+ repo = _build_repo({"skill-a": skill_a, "skill-b": skill_b})
+ loaded_a = loaded_session_key(loaded_key("demo-agent", "skill-a"))
+ loaded_b = loaded_session_key(loaded_key("demo-agent", "skill-b"))
+ docs_a = docs_session_key(docs_key("demo-agent", "skill-a"))
+ docs_b = docs_session_key(docs_key("demo-agent", "skill-b"))
+ tools_a = tool_session_key(tool_key("demo-agent", "skill-a"))
+ tools_b = tool_session_key(tool_key("demo-agent", "skill-b"))
+ order_key = loaded_session_order_key(loaded_order_key("demo-agent"))
+ ctx = _build_ctx({
+ loaded_a: True,
+ loaded_b: True,
+ docs_a: json.dumps(["docs/guide.md"]),
+ docs_b: json.dumps(["docs/guide.md"]),
+ tools_a: json.dumps(["tool-a"]),
+ tools_b: json.dumps(["tool-a"]),
+ order_key: json.dumps(["skill-a", "skill-b"]),
+ }, load_mode="session")
+ request = LlmRequest(contents=[], tools_dict={})
+ processor = SkillsRequestProcessor(repo, max_loaded_skills=1, load_mode="session")
+
+ loaded = await processor.process_llm_request(ctx, request)
+
+ assert loaded == ["skill-b"]
+ assert ctx.actions.state_delta[loaded_a] is None
+ assert ctx.actions.state_delta[docs_a] is None
+ assert ctx.actions.state_delta[tools_a] is None
+ assert ctx.actions.state_delta[order_key] == json.dumps(["skill-b"])
+
+
+class TestSkillProcessorParameters:
+
+ def test_set_and_get_skill_processor_parameters(self):
+ agent_context = AgentContext()
+ parameters = {"load_mode": "session", "max_loaded_skills": 3}
+
+ set_skill_processor_parameters(agent_context, parameters)
+ got = get_skill_processor_parameters(agent_context)
+
+ assert got["load_mode"] == "session"
+ assert got["max_loaded_skills"] == 3
diff --git a/tests/agents/core/test_skill_tool_result_processor.py b/tests/agents/core/test_skill_tool_result_processor.py
new file mode 100644
index 0000000..b635677
--- /dev/null
+++ b/tests/agents/core/test_skill_tool_result_processor.py
@@ -0,0 +1,136 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+
+import json
+from copy import deepcopy
+from unittest.mock import Mock
+
+import trpc_agent_sdk.skills as _skills_pkg
+
+if not hasattr(_skills_pkg, "get_skill_processor_parameters"):
+
+ def _compat_get_skill_processor_parameters(agent_context):
+ from trpc_agent_sdk.agents.core._skill_processor import get_skill_processor_parameters as _impl
+ return _impl(agent_context)
+
+ _skills_pkg.get_skill_processor_parameters = _compat_get_skill_processor_parameters
+
+from trpc_agent_sdk.agents.core._skills_tool_result_processor import SkillsToolResultRequestProcessor
+from trpc_agent_sdk.context import AgentContext
+from trpc_agent_sdk.context import InvocationContext
+from trpc_agent_sdk.models import LlmRequest
+from trpc_agent_sdk.skills import Skill
+from trpc_agent_sdk.skills import SkillResource
+from trpc_agent_sdk.skills import SkillSummary
+from trpc_agent_sdk.skills import SkillToolsNames
+from trpc_agent_sdk.skills import docs_key
+from trpc_agent_sdk.skills import loaded_key
+from trpc_agent_sdk.skills import loaded_order_key
+from trpc_agent_sdk.skills._skill_config import DEFAULT_SKILL_CONFIG
+from trpc_agent_sdk.skills._skill_config import set_skill_config
+from trpc_agent_sdk.types import Content
+from trpc_agent_sdk.types import Part
+
+
+def _build_context(agent_name: str, state: dict, *, load_mode: str = "turn") -> Mock:
+ ctx = Mock(spec=InvocationContext)
+ ctx.agent_name = agent_name
+ ctx.session_state = state
+ ctx.actions = Mock()
+ ctx.actions.state_delta = {}
+ ctx.agent_context = AgentContext()
+ config = deepcopy(DEFAULT_SKILL_CONFIG)
+ config["skill_processor"]["load_mode"] = load_mode
+ set_skill_config(ctx.agent_context, config)
+ return ctx
+
+
+def _build_skill() -> Skill:
+ return Skill(
+ summary=SkillSummary(name="demo-skill", description="demo"),
+ body="Use this skill body.",
+ resources=[SkillResource(path="docs/guide.md", content="Guide content.")],
+ )
+
+
+class TestSkillsToolResultRequestProcessor:
+
+ async def test_materialize_tool_result_from_skill_load(self):
+ repo = Mock()
+ repo.get.return_value = _build_skill()
+
+ ctx = _build_context(
+ "demo-agent",
+ {
+ loaded_key("demo-agent", "demo-skill"): True,
+ docs_key("demo-agent", "demo-skill"): json.dumps(["docs/guide.md"]),
+ },
+ )
+ processor = SkillsToolResultRequestProcessor(repo)
+
+ call_part = Part.from_function_call(name=SkillToolsNames.LOAD, args={"skill": "demo-skill"})
+ call_part.function_call.id = "call_1"
+ response_part = Part.from_function_response(name=SkillToolsNames.LOAD, response={"result": "skill 'demo-skill' loaded"})
+ response_part.function_response.id = "call_1"
+ request = LlmRequest(
+ contents=[
+ Content(role="model", parts=[call_part]),
+ Content(role="user", parts=[response_part]),
+ ],
+ tools_dict={},
+ )
+
+ loaded = await processor.process_llm_request(ctx, request)
+
+ assert loaded == ["demo-skill"]
+ result = response_part.function_response.response["result"]
+ assert "[Loaded] demo-skill" in result
+ assert "Docs loaded: docs/guide.md" in result
+ assert "[Doc] docs/guide.md" in result
+
+ async def test_fallback_to_system_instruction_when_tool_result_missing(self):
+ repo = Mock()
+ repo.get.return_value = _build_skill()
+ ctx = _build_context(
+ "demo-agent",
+ {loaded_key("demo-agent", "demo-skill"): True},
+ )
+ processor = SkillsToolResultRequestProcessor(repo)
+ request = LlmRequest(contents=[Content(role="user", parts=[Part.from_text(text="hello")])], tools_dict={})
+
+ await processor.process_llm_request(ctx, request)
+
+ assert request.config is not None
+ assert "Loaded skill context:" in request.config.system_instruction
+ assert "[Loaded] demo-skill" in request.config.system_instruction
+
+ async def test_once_mode_offloads_loaded_skill_state(self):
+ repo = Mock()
+ repo.get.return_value = _build_skill()
+ state = {
+ loaded_key("demo-agent", "demo-skill"): True,
+ docs_key("demo-agent", "demo-skill"): json.dumps(["docs/guide.md"]),
+ }
+ ctx = _build_context("demo-agent", state, load_mode="once")
+ processor = SkillsToolResultRequestProcessor(repo)
+
+ call_part = Part.from_function_call(name=SkillToolsNames.LOAD, args={"skill": "demo-skill"})
+ call_part.function_call.id = "call_1"
+ response_part = Part.from_function_response(name=SkillToolsNames.LOAD, response={"result": "skill 'demo-skill' loaded"})
+ response_part.function_response.id = "call_1"
+ request = LlmRequest(
+ contents=[
+ Content(role="model", parts=[call_part]),
+ Content(role="user", parts=[response_part]),
+ ],
+ tools_dict={},
+ )
+
+ await processor.process_llm_request(ctx, request)
+
+ assert ctx.actions.state_delta[loaded_key("demo-agent", "demo-skill")] is None
+ assert ctx.actions.state_delta[docs_key("demo-agent", "demo-skill")] is None
+ assert ctx.actions.state_delta[loaded_order_key("demo-agent")] is None
diff --git a/tests/agents/core/test_tools_processor.py b/tests/agents/core/test_tools_processor.py
index a7c49d5..5ebe512 100644
--- a/tests/agents/core/test_tools_processor.py
+++ b/tests/agents/core/test_tools_processor.py
@@ -12,6 +12,15 @@
from unittest.mock import AsyncMock, Mock, patch
import pytest
+import trpc_agent_sdk.skills as _skills_pkg
+
+if not hasattr(_skills_pkg, "get_skill_processor_parameters"):
+
+ def _compat_get_skill_processor_parameters(agent_context):
+ from trpc_agent_sdk.agents.core._skill_processor import get_skill_processor_parameters as _impl
+ return _impl(agent_context)
+
+ _skills_pkg.get_skill_processor_parameters = _compat_get_skill_processor_parameters
from trpc_agent_sdk.agents._base_agent import BaseAgent
from trpc_agent_sdk.agents.core._tools_processor import ToolsProcessor
diff --git a/tests/agents/core/test_workspace_exec_processor.py b/tests/agents/core/test_workspace_exec_processor.py
new file mode 100644
index 0000000..4fa100e
--- /dev/null
+++ b/tests/agents/core/test_workspace_exec_processor.py
@@ -0,0 +1,127 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+
+from unittest.mock import Mock
+
+import trpc_agent_sdk.skills as _skills_pkg
+
+if not hasattr(_skills_pkg, "get_skill_processor_parameters"):
+
+ def _compat_get_skill_processor_parameters(agent_context):
+ from trpc_agent_sdk.agents.core._skill_processor import get_skill_processor_parameters as _impl
+ return _impl(agent_context)
+
+ _skills_pkg.get_skill_processor_parameters = _compat_get_skill_processor_parameters
+
+from trpc_agent_sdk.agents.core._workspace_exec_processor import WorkspaceExecRequestProcessor
+from trpc_agent_sdk.agents.core._workspace_exec_processor import get_workspace_exec_processor_parameters
+from trpc_agent_sdk.agents.core._workspace_exec_processor import set_workspace_exec_processor_parameters
+from trpc_agent_sdk.context import AgentContext
+from trpc_agent_sdk.context import InvocationContext
+from trpc_agent_sdk.models import LlmRequest
+
+
+def _build_ctx() -> Mock:
+ ctx = Mock(spec=InvocationContext)
+ ctx.agent_name = "demo-agent"
+ return ctx
+
+
+def _build_request_with_tools(*tool_names: str) -> LlmRequest:
+ request = LlmRequest(contents=[], tools_dict={})
+ config = Mock()
+ config.system_instruction = ""
+ config.tools = []
+ for name in tool_names:
+ declaration = Mock()
+ declaration.name = name
+ tool = Mock()
+ tool.function_declarations = [declaration]
+ config.tools.append(tool)
+ request.config = config
+ return request
+
+
+class TestWorkspaceExecRequestProcessor:
+
+ async def test_injects_guidance_when_workspace_exec_present(self):
+ processor = WorkspaceExecRequestProcessor()
+ ctx = _build_ctx()
+ request = _build_request_with_tools("workspace_exec")
+
+ await processor.process_llm_request(ctx, request)
+
+ assert "Executor workspace guidance:" in request.config.system_instruction
+ assert "workspace_exec" in request.config.system_instruction
+
+ async def test_does_not_duplicate_guidance_header(self):
+ processor = WorkspaceExecRequestProcessor()
+ ctx = _build_ctx()
+ request = _build_request_with_tools("workspace_exec")
+ request.config.system_instruction = "Executor workspace guidance:\nexisting"
+
+ await processor.process_llm_request(ctx, request)
+
+ assert request.config.system_instruction == "Executor workspace guidance:\nexisting"
+
+ async def test_includes_artifact_and_session_hints_when_tools_available(self):
+ processor = WorkspaceExecRequestProcessor()
+ ctx = _build_ctx()
+ request = _build_request_with_tools(
+ "workspace_exec",
+ "workspace_save_artifact",
+ "workspace_write_stdin",
+ "workspace_kill_session",
+ )
+
+ await processor.process_llm_request(ctx, request)
+
+ text = request.config.system_instruction
+ assert "workspace_save_artifact" in text
+ assert "workspace_write_stdin" in text
+ assert "workspace_kill_session" in text
+
+ async def test_includes_skills_repo_hint_via_repo_resolver(self):
+ processor = WorkspaceExecRequestProcessor(repo_resolver=lambda _ctx: object())
+ ctx = _build_ctx()
+ request = _build_request_with_tools("workspace_exec")
+
+ await processor.process_llm_request(ctx, request)
+
+ assert "Paths under skills/" in request.config.system_instruction
+
+ async def test_respects_enabled_resolver(self):
+ processor = WorkspaceExecRequestProcessor(enabled_resolver=lambda _ctx: False)
+ ctx = _build_ctx()
+ request = _build_request_with_tools("workspace_exec")
+
+ await processor.process_llm_request(ctx, request)
+
+ assert request.config.system_instruction == ""
+
+ async def test_sessions_resolver_can_enable_session_guidance_without_session_tools(self):
+ processor = WorkspaceExecRequestProcessor(sessions_resolver=lambda _ctx: True)
+ ctx = _build_ctx()
+ request = _build_request_with_tools("workspace_exec")
+
+ await processor.process_llm_request(ctx, request)
+
+ text = request.config.system_instruction
+ assert "workspace_write_stdin" in text
+ assert "workspace_kill_session" in text
+
+
+class TestWorkspaceExecProcessorParameters:
+
+ def test_set_and_get_workspace_exec_processor_parameters(self):
+ agent_context = AgentContext()
+ parameters = {"session_tools": True, "has_skills_repo": True}
+
+ set_workspace_exec_processor_parameters(agent_context, parameters)
+ got = get_workspace_exec_processor_parameters(agent_context)
+
+ assert got["session_tools"] is True
+ assert got["has_skills_repo"] is True
diff --git a/tests/skills/__init__.py b/tests/skills/__init__.py
index e69de29..8b13789 100644
--- a/tests/skills/__init__.py
+++ b/tests/skills/__init__.py
@@ -0,0 +1 @@
+
diff --git a/tests/skills/stager/__init__.py b/tests/skills/stager/__init__.py
index e69de29..8b13789 100644
--- a/tests/skills/stager/__init__.py
+++ b/tests/skills/stager/__init__.py
@@ -0,0 +1 @@
+
diff --git a/tests/skills/stager/test_base_stager.py b/tests/skills/stager/test_base_stager.py
index c5a8a00..aecc774 100644
--- a/tests/skills/stager/test_base_stager.py
+++ b/tests/skills/stager/test_base_stager.py
@@ -67,14 +67,13 @@ def _make_repository(path="/skills/test-skill"):
return repo
-def _make_request(skill_name="test-skill", runtime=None, repo=None, ws=None, ctx=None):
+def _make_request(skill_name="test-skill", repo=None, ws=None, ctx=None):
from trpc_agent_sdk.skills.stager._types import SkillStageRequest
return SkillStageRequest(
skill_name=skill_name,
repository=repo or _make_repository(),
workspace=ws or _make_workspace(),
ctx=ctx or _make_ctx(),
- engine=runtime,
)
@@ -262,7 +261,7 @@ class TestStageSkill:
async def test_fresh_staging(self, mock_digest):
stager = Stager()
request = _make_request()
- runtime = request.engine or request.repository.workspace_runtime
+ runtime = request.repository.workspace_runtime
mock_file = MagicMock()
mock_file.content = json.dumps({"version": 1, "skills": {}})
@@ -275,7 +274,7 @@ async def test_fresh_staging(self, mock_digest):
async def test_cached_staging_with_links(self, mock_digest):
stager = Stager()
request = _make_request()
- runtime = request.engine or request.repository.workspace_runtime
+ runtime = request.repository.workspace_runtime
md_data = {
"version": 1,
diff --git a/tests/skills/stager/test_types.py b/tests/skills/stager/test_types.py
index efef3ee..b269c9c 100644
--- a/tests/skills/stager/test_types.py
+++ b/tests/skills/stager/test_types.py
@@ -26,7 +26,6 @@ def test_creation(self):
ctx=MagicMock(),
)
assert req.skill_name == "test"
- assert req.engine is None
assert req.timeout == 300.0
def test_custom_timeout(self):
diff --git a/tests/skills/test_common.py b/tests/skills/test_common.py
index a8f7d85..221c7f6 100644
--- a/tests/skills/test_common.py
+++ b/tests/skills/test_common.py
@@ -3,308 +3,530 @@
# Copyright (C) 2026 Tencent. All rights reserved.
#
# tRPC-Agent-Python is licensed under Apache-2.0.
-"""Unit tests for trpc_agent_sdk.skills._common.
-
-Covers:
-- SelectionMode enum
-- BaseSelectionResult model
-- get_state_delta_value
-- get_previous_selection
-- clear_selection, add_selection, replace_selection
-- set_state_delta_for_selection
-- generic_select_items (all modes, edge cases)
-- generic_get_selection (all branches)
-"""
-
-from __future__ import annotations
-
-import json
-from unittest.mock import MagicMock, Mock
import pytest
-from pydantic import BaseModel, Field
+import json
+from unittest.mock import Mock
+from pydantic import Field
from trpc_agent_sdk.skills._common import (
- BaseSelectionResult,
+ get_state_delta_value,
SelectionMode,
- add_selection,
- clear_selection,
- generic_get_selection,
- generic_select_items,
+ BaseSelectionResult,
+ append_loaded_order_state_delta,
get_previous_selection,
- get_state_delta_value,
+ clear_selection,
+ add_selection,
replace_selection,
set_state_delta_for_selection,
+ generic_select_items,
+ generic_get_selection,
)
+from trpc_agent_sdk.context import InvocationContext
+from trpc_agent_sdk.skills._state_keys import loaded_order_key
-class _TestSelectionResult(BaseSelectionResult):
- """Concrete test subclass for BaseSelectionResult."""
+class MockSelectionResult(BaseSelectionResult):
+ """Mock selection result class for testing."""
selected_items: list[str] = Field(default_factory=list)
include_all: bool = Field(default=False)
-def _make_ctx(state_delta=None, session_state=None):
- ctx = MagicMock()
- ctx.actions.state_delta = state_delta or {}
- ctx.session_state = session_state or {}
- return ctx
+@pytest.fixture(autouse=True)
+def _mock_turn_load_mode(monkeypatch):
+ monkeypatch.setattr("trpc_agent_sdk.skills._common.get_skill_load_mode", lambda _ctx: "turn")
+
+
+class TestGetStateDeltaValue:
+ """Test suite for get_state_delta_value function."""
+
+ def test_get_from_state_delta(self):
+ """Test getting value from state_delta."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {"key1": "value1"}
+ mock_ctx.session_state = {"key1": "old_value"}
+ result = get_state_delta_value(mock_ctx, "key1")
+
+ assert result == "value1"
+
+ def test_get_from_session_state(self):
+ """Test getting value from session_state when not in state_delta."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {}
+ mock_ctx.session_state = {"key1": "value1"}
+
+ result = get_state_delta_value(mock_ctx, "key1")
+
+ assert result == "value1"
+
+ def test_get_nonexistent_key(self):
+ """Test getting nonexistent key returns None."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {}
+ mock_ctx.session_state = {}
+
+ result = get_state_delta_value(mock_ctx, "nonexistent")
+
+ assert result is None
-# ---------------------------------------------------------------------------
-# SelectionMode
-# ---------------------------------------------------------------------------
class TestSelectionMode:
- def test_values(self):
- assert SelectionMode.ADD == "add"
- assert SelectionMode.REPLACE == "replace"
- assert SelectionMode.CLEAR == "clear"
+ """Test suite for SelectionMode enum."""
- def test_from_string(self):
+ def test_selection_mode_values(self):
+ """Test SelectionMode enum values."""
+ assert SelectionMode.ADD.value == "add"
+ assert SelectionMode.REPLACE.value == "replace"
+ assert SelectionMode.CLEAR.value == "clear"
+
+ def test_selection_mode_from_string(self):
+ """Test creating SelectionMode from string."""
assert SelectionMode("add") == SelectionMode.ADD
+ assert SelectionMode("replace") == SelectionMode.REPLACE
+ assert SelectionMode("clear") == SelectionMode.CLEAR
+ def test_selection_mode_invalid_string(self):
+ """Test invalid string raises ValueError."""
+ with pytest.raises(ValueError):
+ SelectionMode("invalid")
-# ---------------------------------------------------------------------------
-# get_state_delta_value
-# ---------------------------------------------------------------------------
-class TestGetStateDeltaValue:
- def test_from_state_delta(self):
- ctx = _make_ctx(state_delta={"key": "delta_value"})
- assert get_state_delta_value(ctx, "key") == "delta_value"
+class TestBaseSelectionResult:
+ """Test suite for BaseSelectionResult class."""
- def test_from_session_state(self):
- ctx = _make_ctx(session_state={"key": "session_value"})
- assert get_state_delta_value(ctx, "key") == "session_value"
+ def test_create_base_selection_result(self):
+ """Test creating BaseSelectionResult."""
+ result = BaseSelectionResult(skill="test-skill", mode="replace")
- def test_delta_takes_precedence(self):
- ctx = _make_ctx(state_delta={"k": "delta"}, session_state={"k": "session"})
- assert get_state_delta_value(ctx, "k") == "delta"
+ assert result.skill == "test-skill"
+ assert result.mode == "replace"
- def test_missing_returns_none(self):
- ctx = _make_ctx()
- assert get_state_delta_value(ctx, "missing") is None
+ def test_default_mode(self):
+ """Test default mode is empty string."""
+ result = BaseSelectionResult(skill="test-skill")
+ assert result.mode == ""
-# ---------------------------------------------------------------------------
-# get_previous_selection
-# ---------------------------------------------------------------------------
class TestGetPreviousSelection:
- def test_no_value_returns_empty_list(self):
- ctx = _make_ctx()
- result = get_previous_selection(ctx, "prefix:", "skill")
- assert result == []
+ """Test suite for get_previous_selection function."""
+
+ def test_get_previous_selection_json(self):
+ """Test getting previous selection from JSON."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.session_state = {"temp:skill:docs:test-skill": json.dumps(["doc1.md", "doc2.md"])}
+
+ result = get_previous_selection(mock_ctx, "temp:skill:docs:", "test-skill")
+
+ assert result == ["doc1.md", "doc2.md"]
+
+ def test_get_previous_selection_all(self):
+ """Test getting previous selection when all items selected."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.session_state = {"temp:skill:docs:test-skill": '*'}
+
+ result = get_previous_selection(mock_ctx, "temp:skill:docs:", "test-skill")
- def test_star_returns_none(self):
- ctx = _make_ctx(session_state={"prefix:skill": "*"})
- result = get_previous_selection(ctx, "prefix:", "skill")
assert result is None
- def test_json_array(self):
- ctx = _make_ctx(session_state={"prefix:skill": json.dumps(["a", "b"])})
- result = get_previous_selection(ctx, "prefix:", "skill")
- assert result == ["a", "b"]
+ def test_get_previous_selection_not_found(self):
+ """Test getting previous selection when not found."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.session_state = {}
- def test_invalid_json_returns_empty(self):
- ctx = _make_ctx(session_state={"prefix:skill": "not json"})
- result = get_previous_selection(ctx, "prefix:", "skill")
- assert result == []
+ result = get_previous_selection(mock_ctx, "temp:skill:docs:", "test-skill")
- def test_empty_string_returns_empty(self):
- ctx = _make_ctx(session_state={"prefix:skill": ""})
- result = get_previous_selection(ctx, "prefix:", "skill")
assert result == []
-# ---------------------------------------------------------------------------
-# clear_selection
-# ---------------------------------------------------------------------------
+class TestAppendLoadedOrderStateDelta:
+ """Test suite for append_loaded_order_state_delta function."""
+
+ def test_appends_to_temp_order_when_temp_exists(self):
+ mock_ctx = Mock(spec=InvocationContext)
+ temp_key = loaded_order_key("demo-agent")
+ mock_ctx.session_state = {temp_key: json.dumps(["skill-a"])}
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {}
+ mock_ctx.agent_context = Mock()
+ mock_ctx.agent_context.get_metadata = Mock(side_effect=lambda _key, default=None: default)
+
+ append_loaded_order_state_delta(mock_ctx, "demo-agent", "skill-b")
+
+ assert mock_ctx.actions.state_delta[temp_key] == json.dumps(["skill-a", "skill-b"])
+
+ def test_initializes_order_when_missing(self):
+ mock_ctx = Mock(spec=InvocationContext)
+ temp_key = loaded_order_key("demo-agent")
+ mock_ctx.session_state = {}
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {}
+ mock_ctx.agent_context = Mock()
+ mock_ctx.agent_context.get_metadata = Mock(side_effect=lambda _key, default=None: default)
+
+ append_loaded_order_state_delta(mock_ctx, "demo-agent", "skill-b")
+
+ assert mock_ctx.actions.state_delta[temp_key] == json.dumps(["skill-b"])
+
+ def test_get_previous_selection_invalid_json(self):
+ """Test getting previous selection with invalid JSON."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.session_state = {"temp:skill:docs:test-skill": "invalid json"}
+
+ result = get_previous_selection(mock_ctx, "temp:skill:docs:", "test-skill")
+
+ assert result == []
+
class TestClearSelection:
- def test_clear(self):
- result = clear_selection("skill", ["a", "b"], True, ["old"], _TestSelectionResult)
- assert result.skill == "skill"
+ """Test suite for clear_selection function."""
+
+ def test_clear_selection(self):
+ """Test clearing selection."""
+ result = clear_selection(skill_name="test-skill",
+ items=["item1"],
+ include_all=False,
+ previous_items=["item1", "item2"],
+ result_class=MockSelectionResult)
+
+ assert isinstance(result, MockSelectionResult)
+ assert result.skill == "test-skill"
+ assert result.mode == "clear"
assert result.selected_items == []
assert result.include_all is False
- assert result.mode == "clear"
-
-# ---------------------------------------------------------------------------
-# add_selection
-# ---------------------------------------------------------------------------
class TestAddSelection:
- def test_add_to_empty(self):
- result = add_selection("skill", ["a", "b"], False, [], _TestSelectionResult)
- assert set(result.selected_items) == {"a", "b"}
- assert result.include_all is False
+ """Test suite for add_selection function."""
+
+ def test_add_selection(self):
+ """Test adding to selection."""
+ result = add_selection(skill_name="test-skill",
+ items=["item2", "item3"],
+ include_all=False,
+ previous_items=["item1"],
+ result_class=MockSelectionResult)
+
+ assert isinstance(result, MockSelectionResult)
+ assert result.skill == "test-skill"
assert result.mode == "add"
+ assert set(result.selected_items) == {"item1", "item2", "item3"}
+ assert result.include_all is False
- def test_add_to_existing(self):
- result = add_selection("skill", ["c"], False, ["a", "b"], _TestSelectionResult)
- assert set(result.selected_items) == {"a", "b", "c"}
-
- def test_add_deduplicate(self):
- result = add_selection("skill", ["a"], False, ["a"], _TestSelectionResult)
- assert result.selected_items == ["a"]
+ def test_add_selection_duplicates(self):
+ """Test adding selection removes duplicates."""
+ result = add_selection(skill_name="test-skill",
+ items=["item1", "item2"],
+ include_all=False,
+ previous_items=["item1"],
+ result_class=MockSelectionResult)
+
+ assert len(result.selected_items) == 2
+ assert result.selected_items.count("item1") == 1
+
+ def test_add_selection_include_all(self):
+ """Test adding selection with include_all."""
+ result = add_selection(skill_name="test-skill",
+ items=["item2"],
+ include_all=True,
+ previous_items=["item1"],
+ result_class=MockSelectionResult)
- def test_add_include_all(self):
- result = add_selection("skill", ["a"], True, ["b"], _TestSelectionResult)
assert result.selected_items == []
assert result.include_all is True
-# ---------------------------------------------------------------------------
-# replace_selection
-# ---------------------------------------------------------------------------
-
class TestReplaceSelection:
- def test_replace(self):
- result = replace_selection("skill", ["x", "y"], False, ["old"], _TestSelectionResult)
- assert result.selected_items == ["x", "y"]
+ """Test suite for replace_selection function."""
+
+ def test_replace_selection(self):
+ """Test replacing selection."""
+ result = replace_selection(skill_name="test-skill",
+ items=["item2", "item3"],
+ include_all=False,
+ previous_items=["item1"],
+ result_class=MockSelectionResult)
+
+ assert isinstance(result, MockSelectionResult)
+ assert result.skill == "test-skill"
assert result.mode == "replace"
+ assert result.selected_items == ["item2", "item3"]
+ assert result.include_all is False
+
+ def test_replace_selection_include_all(self):
+ """Test replacing selection with include_all."""
+ result = replace_selection(skill_name="test-skill",
+ items=["item2"],
+ include_all=True,
+ previous_items=["item1"],
+ result_class=MockSelectionResult)
- def test_replace_include_all(self):
- result = replace_selection("skill", ["x"], True, [], _TestSelectionResult)
assert result.selected_items == []
assert result.include_all is True
-# ---------------------------------------------------------------------------
-# set_state_delta_for_selection
-# ---------------------------------------------------------------------------
-
class TestSetStateDeltaForSelection:
- def test_sets_json_array(self):
- ctx = _make_ctx()
- result = _TestSelectionResult(skill="s", selected_items=["a", "b"], include_all=False)
- set_state_delta_for_selection(ctx, "prefix:", result)
- assert json.loads(ctx.actions.state_delta["prefix:s"]) == ["a", "b"]
+ """Test suite for set_state_delta_for_selection function."""
+
+ def test_set_state_delta_with_items(self):
+ """Test setting state delta with items."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {}
+
+ result = MockSelectionResult(skill="test-skill",
+ selected_items=["item1", "item2"],
+ include_all=False,
+ mode="replace")
+
+ set_state_delta_for_selection(mock_ctx, "temp:skill:test:", result)
+
+ key = "temp:skill:test:test-skill"
+ assert key in mock_ctx.actions.state_delta
+ assert json.loads(mock_ctx.actions.state_delta[key]) == ["item1", "item2"]
+
+ def test_set_state_delta_include_all(self):
+ """Test setting state delta with include_all."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {}
+
+ result = MockSelectionResult(skill="test-skill", selected_items=[], include_all=True, mode="replace")
+
+ set_state_delta_for_selection(mock_ctx, "temp:skill:test:", result)
- def test_sets_star_for_include_all(self):
- ctx = _make_ctx()
- result = _TestSelectionResult(skill="s", selected_items=[], include_all=True)
- set_state_delta_for_selection(ctx, "prefix:", result)
- assert ctx.actions.state_delta["prefix:s"] == "*"
+ key = "temp:skill:test:test-skill"
+ assert mock_ctx.actions.state_delta[key] == '*'
- def test_no_skill_is_noop(self):
- ctx = _make_ctx()
- result = _TestSelectionResult(skill="", selected_items=[])
- set_state_delta_for_selection(ctx, "prefix:", result)
- assert len(ctx.actions.state_delta) == 0
+ def test_set_state_delta_empty_skill(self):
+        """Test that an empty skill name leaves the state delta untouched."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {}
+ result = MockSelectionResult(skill="", selected_items=["item1"], include_all=False, mode="replace")
+
+ set_state_delta_for_selection(mock_ctx, "temp:skill:test:", result)
+
+ assert len(mock_ctx.actions.state_delta) == 0
-# ---------------------------------------------------------------------------
-# generic_select_items
-# ---------------------------------------------------------------------------
class TestGenericSelectItems:
- def test_replace_mode(self):
- ctx = _make_ctx()
- result = generic_select_items(
- ctx, "skill", ["a", "b"], False, "replace", "prefix:", _TestSelectionResult
- )
- assert result.selected_items == ["a", "b"]
+ """Test suite for generic_select_items function."""
+
+ def test_generic_select_items_replace(self):
+ """Test generic select items with replace mode."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.session_state = {}
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {}
+
+ result = generic_select_items(tool_context=mock_ctx,
+ skill_name="test-skill",
+ items=["item1", "item2"],
+ include_all=False,
+ mode="replace",
+ state_key_prefix="temp:skill:test:",
+ result_class=MockSelectionResult)
+
+ assert isinstance(result, MockSelectionResult)
+ assert result.skill == "test-skill"
assert result.mode == "replace"
+ assert result.selected_items == ["item1", "item2"]
+
+ def test_generic_select_items_add(self):
+ """Test generic select items with add mode."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.session_state = {"temp:skill:test:test-skill": json.dumps(["item1"])}
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {}
+
+ result = generic_select_items(tool_context=mock_ctx,
+ skill_name="test-skill",
+ items=["item2"],
+ include_all=False,
+ mode="add",
+ state_key_prefix="temp:skill:test:",
+ result_class=MockSelectionResult)
- def test_add_mode(self):
- ctx = _make_ctx(session_state={"prefix:skill": json.dumps(["a"])})
- result = generic_select_items(
- ctx, "skill", ["b"], False, "add", "prefix:", _TestSelectionResult
- )
- assert set(result.selected_items) == {"a", "b"}
assert result.mode == "add"
+ assert set(result.selected_items) == {"item1", "item2"}
+
+ def test_generic_select_items_clear(self):
+ """Test generic select items with clear mode."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.session_state = {"temp:skill:test:test-skill": json.dumps(["item1"])}
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {}
+
+ result = generic_select_items(tool_context=mock_ctx,
+ skill_name="test-skill",
+ items=None,
+ include_all=False,
+ mode="clear",
+ state_key_prefix="temp:skill:test:",
+ result_class=MockSelectionResult)
- def test_clear_mode(self):
- ctx = _make_ctx(session_state={"prefix:skill": json.dumps(["a"])})
- result = generic_select_items(
- ctx, "skill", [], False, "clear", "prefix:", _TestSelectionResult
- )
- assert result.selected_items == []
assert result.mode == "clear"
+ assert result.selected_items == []
+
+ def test_generic_select_items_invalid_mode(self):
+        """Test that an invalid mode falls back to replace mode."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.session_state = {}
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {}
+
+ result = generic_select_items(tool_context=mock_ctx,
+ skill_name="test-skill",
+ items=["item1"],
+ include_all=False,
+ mode="invalid",
+ state_key_prefix="temp:skill:test:",
+ result_class=MockSelectionResult)
- def test_invalid_mode_defaults_to_replace(self):
- ctx = _make_ctx()
- result = generic_select_items(
- ctx, "skill", ["x"], False, "invalid_mode", "prefix:", _TestSelectionResult
- )
assert result.mode == "replace"
- def test_previous_star_and_not_clearing_keeps_include_all(self):
- ctx = _make_ctx(session_state={"prefix:skill": "*"})
- result = generic_select_items(
- ctx, "skill", ["a"], False, "add", "prefix:", _TestSelectionResult
- )
+ def test_generic_select_items_previous_all(self):
+ """Test generic select items when previous was all."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.session_state = {"temp:skill:test:test-skill": '*'}
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {}
+
+ result = generic_select_items(tool_context=mock_ctx,
+ skill_name="test-skill",
+ items=["item1"],
+ include_all=False,
+ mode="add",
+ state_key_prefix="temp:skill:test:",
+ result_class=MockSelectionResult)
+
+ # Should maintain include_all=True
assert result.include_all is True
- def test_previous_star_and_clear(self):
- ctx = _make_ctx(session_state={"prefix:skill": "*"})
- result = generic_select_items(
- ctx, "skill", [], False, "clear", "prefix:", _TestSelectionResult
- )
- assert result.include_all is False
- assert result.selected_items == []
+ def test_generic_select_items_updates_state(self):
+ """Test generic select items updates state delta."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.session_state = {}
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {}
- def test_none_items_treated_as_empty(self):
- ctx = _make_ctx()
- result = generic_select_items(
- ctx, "skill", None, False, "replace", "prefix:", _TestSelectionResult
- )
- assert result.selected_items == []
-
- def test_updates_state_delta(self):
- ctx = _make_ctx()
- generic_select_items(
- ctx, "skill", ["a"], False, "replace", "prefix:", _TestSelectionResult
- )
- assert "prefix:skill" in ctx.actions.state_delta
+ result = generic_select_items(tool_context=mock_ctx,
+ skill_name="test-skill",
+ items=["item1"],
+ include_all=False,
+ mode="replace",
+ state_key_prefix="temp:skill:test:",
+ result_class=MockSelectionResult)
+ key = "temp:skill:test:test-skill"
+ assert key in mock_ctx.actions.state_delta
-# ---------------------------------------------------------------------------
-# generic_get_selection
-# ---------------------------------------------------------------------------
class TestGenericGetSelection:
- def test_no_value_returns_empty(self):
- ctx = _make_ctx()
- assert generic_get_selection(ctx, "skill", "prefix:") == []
-
- def test_json_array(self):
- ctx = _make_ctx(state_delta={"prefix:skill": json.dumps(["a", "b"])})
- assert generic_get_selection(ctx, "skill", "prefix:") == ["a", "b"]
-
- def test_star_with_callback(self):
- ctx = _make_ctx(state_delta={"prefix:skill": "*"})
- callback = Mock(return_value=["all_a", "all_b"])
- result = generic_get_selection(ctx, "skill", "prefix:", callback)
- assert result == ["all_a", "all_b"]
- callback.assert_called_once_with("skill")
-
- def test_star_without_callback(self):
- ctx = _make_ctx(state_delta={"prefix:skill": "*"})
- assert generic_get_selection(ctx, "skill", "prefix:") == []
-
- def test_star_callback_exception_returns_empty(self):
- ctx = _make_ctx(state_delta={"prefix:skill": "*"})
- callback = Mock(side_effect=RuntimeError("boom"))
- assert generic_get_selection(ctx, "skill", "prefix:", callback) == []
-
- def test_invalid_json_returns_empty(self):
- ctx = _make_ctx(state_delta={"prefix:skill": "not_json"})
- assert generic_get_selection(ctx, "skill", "prefix:") == []
-
- def test_bytes_value_decoded(self):
- ctx = _make_ctx(state_delta={"prefix:skill": json.dumps(["x"]).encode("utf-8")})
- assert generic_get_selection(ctx, "skill", "prefix:") == ["x"]
-
- def test_bytes_star_with_callback(self):
- ctx = _make_ctx(state_delta={"prefix:skill": b"*"})
- callback = Mock(return_value=["all"])
- result = generic_get_selection(ctx, "skill", "prefix:", callback)
- assert result == ["all"]
-
- def test_non_list_json_returns_empty(self):
- ctx = _make_ctx(state_delta={"prefix:skill": json.dumps({"not": "list"})})
- assert generic_get_selection(ctx, "skill", "prefix:") == []
+ """Test suite for generic_get_selection function."""
+
+ def test_generic_get_selection_json_array(self):
+        """Test getting a selection stored as a JSON array."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {}
+ mock_ctx.session_state = {"temp:skill:test:test-skill": json.dumps(["item1", "item2"])}
+
+ result = generic_get_selection(ctx=mock_ctx, skill_name="test-skill", state_key_prefix="temp:skill:test:")
+
+ assert result == ["item1", "item2"]
+
+ def test_generic_get_selection_all_with_callback(self):
+ """Test getting selection with '*' and callback."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {}
+ mock_ctx.session_state = {"temp:skill:test:test-skill": '*'}
+
+ def get_all_items(skill_name):
+ return ["item1", "item2", "item3"]
+
+ result = generic_get_selection(ctx=mock_ctx,
+ skill_name="test-skill",
+ state_key_prefix="temp:skill:test:",
+ get_all_items_callback=get_all_items)
+
+ assert result == ["item1", "item2", "item3"]
+
+ def test_generic_get_selection_all_without_callback(self):
+ """Test getting selection with '*' but no callback."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {}
+ mock_ctx.session_state = {"temp:skill:test:test-skill": '*'}
+
+ result = generic_get_selection(ctx=mock_ctx, skill_name="test-skill", state_key_prefix="temp:skill:test:")
+
+ assert result == []
+
+ def test_generic_get_selection_not_found(self):
+ """Test getting selection when not found."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {}
+ mock_ctx.session_state = {}
+
+ result = generic_get_selection(ctx=mock_ctx, skill_name="test-skill", state_key_prefix="temp:skill:test:")
+
+ assert result == []
+
+ def test_generic_get_selection_invalid_json(self):
+ """Test getting selection with invalid JSON."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {}
+ mock_ctx.session_state = {"temp:skill:test:test-skill": "invalid json"}
+
+ result = generic_get_selection(ctx=mock_ctx, skill_name="test-skill", state_key_prefix="temp:skill:test:")
+
+ assert result == []
+
+ def test_generic_get_selection_bytes_value(self):
+ """Test getting selection when value is bytes."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {}
+ mock_ctx.session_state = {"temp:skill:test:test-skill": json.dumps(["item1"]).encode('utf-8')}
+
+ result = generic_get_selection(ctx=mock_ctx, skill_name="test-skill", state_key_prefix="temp:skill:test:")
+
+ assert result == ["item1"]
+
+ def test_generic_get_selection_callback_exception(self):
+        """Test that an exception raised by the callback yields an empty selection."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {}
+ mock_ctx.session_state = {"temp:skill:test:test-skill": '*'}
+
+ def get_all_items_error(skill_name):
+ raise Exception("Test error")
+
+ result = generic_get_selection(ctx=mock_ctx,
+ skill_name="test-skill",
+ state_key_prefix="temp:skill:test:",
+ get_all_items_callback=get_all_items_error)
+
+ assert result == []
+
+ def test_generic_get_selection_from_state_delta(self):
+ """Test getting selection prefers state_delta over session_state."""
+ mock_ctx = Mock(spec=InvocationContext)
+ mock_ctx.actions = Mock()
+ mock_ctx.actions.state_delta = {"temp:skill:test:test-skill": json.dumps(["item_new"])}
+ mock_ctx.session_state = {"temp:skill:test:test-skill": json.dumps(["item_old"])}
+
+ result = generic_get_selection(ctx=mock_ctx, skill_name="test-skill", state_key_prefix="temp:skill:test:")
+
+ assert result == ["item_new"]
diff --git a/tests/skills/test_constants.py b/tests/skills/test_constants.py
new file mode 100644
index 0000000..26b80fa
--- /dev/null
+++ b/tests/skills/test_constants.py
@@ -0,0 +1,23 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+
+from trpc_agent_sdk.skills._constants import SKILL_FILE
+from trpc_agent_sdk.skills._constants import SKILL_LOAD_MODE_VALUES
+from trpc_agent_sdk.skills._constants import SKILL_TOOLS_NAMES
+from trpc_agent_sdk.skills._constants import SkillLoadModeNames
+from trpc_agent_sdk.skills._constants import SkillToolsNames
+
+
+def test_skill_file_constant():
+ assert SKILL_FILE == "SKILL.md"
+
+
+def test_skill_tools_names_matches_enum():
+ assert SKILL_TOOLS_NAMES == [item.value for item in SkillToolsNames]
+
+
+def test_skill_load_mode_values_matches_enum():
+ assert SKILL_LOAD_MODE_VALUES == [item.value for item in SkillLoadModeNames]
diff --git a/tests/skills/test_dynamic_toolset.py b/tests/skills/test_dynamic_toolset.py
index 26cd3b8..df65fb8 100644
--- a/tests/skills/test_dynamic_toolset.py
+++ b/tests/skills/test_dynamic_toolset.py
@@ -23,12 +23,19 @@
import pytest
from trpc_agent_sdk.skills._dynamic_toolset import DynamicSkillToolSet
+from trpc_agent_sdk.skills._common import loaded_state_key
+from trpc_agent_sdk.skills._common import tool_state_key
+from trpc_agent_sdk.skills._constants import SKILL_CONFIG_KEY
+from trpc_agent_sdk.skills._skill_config import DEFAULT_SKILL_CONFIG
def _make_ctx(state_delta=None, session_state=None):
ctx = MagicMock()
ctx.actions.state_delta = state_delta or {}
ctx.session_state = session_state or {}
+ ctx.agent_name = ""
+ ctx.agent_context.get_metadata = MagicMock(
+ side_effect=lambda key, default=None: DEFAULT_SKILL_CONFIG if key == SKILL_CONFIG_KEY else default)
return ctx
@@ -113,10 +120,9 @@ def test_no_loaded_skills(self, *_):
def test_loaded_skills_from_session_and_delta(self, *_):
repo = _make_mock_repository()
ts = DynamicSkillToolSet(skill_repository=repo)
- ctx = _make_ctx(
- session_state={"temp:skill:loaded:skill-a": True},
- state_delta={"temp:skill:loaded:skill-b": True},
- )
+ ctx = _make_ctx()
+ ctx.session_state = {loaded_state_key(ctx, "skill-a"): True}
+ ctx.actions.state_delta = {loaded_state_key(ctx, "skill-b"): True}
result = ts._get_loaded_skills_from_state(ctx)
assert set(result) == {"skill-a", "skill-b"}
@@ -139,7 +145,7 @@ def test_no_active_skills(self, *_):
def test_active_from_loaded(self, *_):
repo = _make_mock_repository()
ts = DynamicSkillToolSet(skill_repository=repo)
- ctx = _make_ctx(state_delta={"temp:skill:loaded:s1": True})
+ ctx = _make_ctx(state_delta={loaded_state_key(_make_ctx(), "s1"): True})
result = ts._get_active_skills_from_delta(ctx)
assert "s1" in result
@@ -148,7 +154,7 @@ def test_active_from_loaded(self, *_):
def test_active_from_tools_modified(self, *_):
repo = _make_mock_repository()
ts = DynamicSkillToolSet(skill_repository=repo)
- ctx = _make_ctx(state_delta={"temp:skill:tools:s2": json.dumps(["t1"])})
+ ctx = _make_ctx(state_delta={tool_state_key(_make_ctx(), "s2"): json.dumps(["t1"])})
result = ts._get_active_skills_from_delta(ctx)
assert "s2" in result
@@ -157,7 +163,7 @@ def test_active_from_tools_modified(self, *_):
def test_falsy_loaded_value_ignored(self, *_):
repo = _make_mock_repository()
ts = DynamicSkillToolSet(skill_repository=repo)
- ctx = _make_ctx(state_delta={"temp:skill:loaded:s1": False})
+ ctx = _make_ctx(state_delta={loaded_state_key(_make_ctx(), "s1"): False})
assert ts._get_active_skills_from_delta(ctx) == []
@@ -173,7 +179,7 @@ def test_json_array(self, *_):
skill.tools = ["default_tool"]
repo = _make_mock_repository({"s1": skill})
ts = DynamicSkillToolSet(skill_repository=repo)
- ctx = _make_ctx(state_delta={"temp:skill:tools:s1": json.dumps(["tool_a", "tool_b"])})
+ ctx = _make_ctx(state_delta={tool_state_key(_make_ctx(), "s1"): json.dumps(["tool_a", "tool_b"])})
result = ts._get_tools_selection(ctx, "s1")
assert result == ["tool_a", "tool_b"]
@@ -184,7 +190,7 @@ def test_star_returns_defaults(self, *_):
skill.tools = ["default_tool"]
repo = _make_mock_repository({"s1": skill})
ts = DynamicSkillToolSet(skill_repository=repo)
- ctx = _make_ctx(state_delta={"temp:skill:tools:s1": "*"})
+ ctx = _make_ctx(state_delta={tool_state_key(_make_ctx(), "s1"): "*"})
result = ts._get_tools_selection(ctx, "s1")
assert result == ["default_tool"]
@@ -206,7 +212,7 @@ def test_invalid_json_falls_back(self, *_):
skill.tools = ["fallback"]
repo = _make_mock_repository({"s1": skill})
ts = DynamicSkillToolSet(skill_repository=repo)
- ctx = _make_ctx(state_delta={"temp:skill:tools:s1": "not_json"})
+ ctx = _make_ctx(state_delta={tool_state_key(_make_ctx(), "s1"): "not_json"})
result = ts._get_tools_selection(ctx, "s1")
assert result == ["fallback"]
@@ -301,7 +307,7 @@ async def test_active_skills_with_tools(self, *_):
mock_tool = _make_mock_tool("my_tool")
ts._available_tools["my_tool"] = mock_tool
- ctx = _make_ctx(state_delta={"temp:skill:loaded:s1": True})
+ ctx = _make_ctx(state_delta={loaded_state_key(_make_ctx(), "s1"): True})
result = await ts.get_tools(ctx)
assert len(result) == 1
assert result[0] is mock_tool
@@ -316,7 +322,7 @@ async def test_only_active_false_uses_all_loaded(self, *_):
mock_tool = _make_mock_tool("tool_a")
ts._available_tools["tool_a"] = mock_tool
- ctx = _make_ctx(session_state={"temp:skill:loaded:s1": True})
+ ctx = _make_ctx(session_state={loaded_state_key(_make_ctx(), "s1"): True})
result = await ts.get_tools(ctx)
assert len(result) == 1
@@ -332,9 +338,10 @@ async def test_deduplicates_tools(self, *_):
mock_tool = _make_mock_tool("shared_tool")
ts._available_tools["shared_tool"] = mock_tool
+ key_ctx = _make_ctx()
ctx = _make_ctx(session_state={
- "temp:skill:loaded:s1": True,
- "temp:skill:loaded:s2": True,
+ loaded_state_key(key_ctx, "s1"): True,
+ loaded_state_key(key_ctx, "s2"): True,
})
result = await ts.get_tools(ctx)
assert len(result) == 1
@@ -349,7 +356,7 @@ async def test_fallback_to_loaded_when_no_active(self, *_):
mock_tool = _make_mock_tool("fallback_tool")
ts._available_tools["fallback_tool"] = mock_tool
- ctx = _make_ctx(session_state={"temp:skill:loaded:s1": True})
+ ctx = _make_ctx(session_state={loaded_state_key(_make_ctx(), "s1"): True})
result = await ts.get_tools(ctx)
assert len(result) == 1
@@ -360,7 +367,7 @@ async def test_unresolvable_tool_skipped(self, *_):
skill.tools = ["nonexistent_tool"]
repo = _make_mock_repository({"s1": skill})
ts = DynamicSkillToolSet(skill_repository=repo)
- ctx = _make_ctx(state_delta={"temp:skill:loaded:s1": True})
+ ctx = _make_ctx(state_delta={loaded_state_key(_make_ctx(), "s1"): True})
result = await ts.get_tools(ctx)
assert len(result) == 0
@@ -371,7 +378,7 @@ async def test_no_tools_for_skill(self, *_):
skill.tools = []
repo = _make_mock_repository({"s1": skill})
ts = DynamicSkillToolSet(skill_repository=repo)
- ctx = _make_ctx(state_delta={"temp:skill:loaded:s1": True})
+ ctx = _make_ctx(state_delta={loaded_state_key(_make_ctx(), "s1"): True})
result = await ts.get_tools(ctx)
assert result == []
@@ -439,7 +446,7 @@ def test_bytes_json_value(self, *_):
skill.tools = ["default"]
repo = _make_mock_repository({"s1": skill})
ts = DynamicSkillToolSet(skill_repository=repo)
- ctx = _make_ctx(state_delta={"temp:skill:tools:s1": json.dumps(["t1"]).encode()})
+ ctx = _make_ctx(state_delta={tool_state_key(_make_ctx(), "s1"): json.dumps(["t1"]).encode()})
result = ts._get_tools_selection(ctx, "s1")
assert result == ["t1"]
@@ -450,7 +457,7 @@ def test_bytes_star_value(self, *_):
skill.tools = ["default"]
repo = _make_mock_repository({"s1": skill})
ts = DynamicSkillToolSet(skill_repository=repo)
- ctx = _make_ctx(state_delta={"temp:skill:tools:s1": b"*"})
+ ctx = _make_ctx(state_delta={tool_state_key(_make_ctx(), "s1"): b"*"})
result = ts._get_tools_selection(ctx, "s1")
assert result == ["default"]
@@ -461,6 +468,6 @@ def test_non_list_json_falls_back(self, *_):
skill.tools = ["default"]
repo = _make_mock_repository({"s1": skill})
ts = DynamicSkillToolSet(skill_repository=repo)
- ctx = _make_ctx(state_delta={"temp:skill:tools:s1": json.dumps({"not": "list"})})
+ ctx = _make_ctx(state_delta={tool_state_key(_make_ctx(), "s1"): json.dumps({"not": "list"})})
result = ts._get_tools_selection(ctx, "s1")
assert result == ["default"]
diff --git a/tests/skills/test_hot_reload.py b/tests/skills/test_hot_reload.py
new file mode 100644
index 0000000..bc8411b
--- /dev/null
+++ b/tests/skills/test_hot_reload.py
@@ -0,0 +1,35 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+
+from pathlib import Path
+
+from trpc_agent_sdk.skills._hot_reload import SkillHotReloadTracker
+
+
+def test_mark_changed_path_only_tracks_skill_file(tmp_path: Path):
+ root = tmp_path / "skills"
+ skill_dir = root / "demo"
+ skill_dir.mkdir(parents=True)
+ tracker = SkillHotReloadTracker("SKILL.md")
+
+ tracker.mark_changed_path(str(skill_dir / "notes.txt"), is_directory=False, skill_roots=[str(root)])
+ assert tracker.pop_changed_dirs(str(root.resolve())) == []
+
+ tracker.mark_changed_path(str(skill_dir / "SKILL.md"), is_directory=False, skill_roots=[str(root)])
+ changed = tracker.pop_changed_dirs(str(root.resolve()))
+ assert changed == [skill_dir.resolve()]
+
+
+def test_resolve_root_key_and_normalize_targets(tmp_path: Path):
+ root = tmp_path / "skills"
+ nested = root / "a" / "b"
+ nested.mkdir(parents=True)
+
+ key = SkillHotReloadTracker.resolve_root_key(nested, [str(root)])
+ assert key == str(root.resolve())
+
+ deduped = SkillHotReloadTracker.normalize_scan_targets([root / "a", nested, root / "c"])
+ assert deduped == [root / "a", root / "c"]
diff --git a/tests/skills/test_repository.py b/tests/skills/test_repository.py
index 22a6ad7..6f24581 100644
--- a/tests/skills/test_repository.py
+++ b/tests/skills/test_repository.py
@@ -27,11 +27,9 @@
BASE_DIR_PLACEHOLDER,
BaseSkillRepository,
FsSkillRepository,
- _is_doc_file,
- _parse_tools_from_body,
- _split_front_matter,
create_default_skill_repository,
)
+from trpc_agent_sdk.skills._utils import is_doc_file
# ---------------------------------------------------------------------------
@@ -40,54 +38,54 @@
class TestSplitFrontMatter:
def test_no_front_matter(self):
- fm, body = _split_front_matter("# Hello\nworld")
+ fm, body = FsSkillRepository.from_markdown("# Hello\nworld")
assert fm == {}
assert body == "# Hello\nworld"
def test_with_front_matter(self):
content = "---\nname: test\ndescription: Test skill\n---\n# Body"
- fm, body = _split_front_matter(content)
+ fm, body = FsSkillRepository.from_markdown(content)
assert fm["name"] == "test"
assert fm["description"] == "Test skill"
assert body == "# Body"
def test_crlf_normalization(self):
content = "---\r\nname: test\r\n---\r\nbody"
- fm, body = _split_front_matter(content)
+ fm, body = FsSkillRepository.from_markdown(content)
assert fm["name"] == "test"
assert body == "body"
def test_invalid_yaml_returns_empty_dict(self):
content = "---\n: : : invalid\n---\nbody"
- fm, body = _split_front_matter(content)
+ fm, body = FsSkillRepository.from_markdown(content)
assert body == "body"
def test_non_dict_yaml_returns_empty_dict(self):
content = "---\n- item1\n- item2\n---\nbody"
- fm, body = _split_front_matter(content)
+ fm, body = FsSkillRepository.from_markdown(content)
assert fm == {}
assert body == "body"
def test_unclosed_front_matter(self):
content = "---\nname: test\nno closing"
- fm, body = _split_front_matter(content)
+ fm, body = FsSkillRepository.from_markdown(content)
assert fm == {}
assert body == content
def test_none_values_become_empty_string(self):
content = "---\nname:\n---\nbody"
- fm, body = _split_front_matter(content)
+ fm, body = FsSkillRepository.from_markdown(content)
assert fm["name"] == ""
def test_none_key_converted_to_string(self):
content = "---\nname: test\n---\nbody"
- fm, body = _split_front_matter(content)
+ fm, body = FsSkillRepository.from_markdown(content)
assert fm["name"] == "test"
assert body == "body"
def test_no_dash_prefix(self):
content = "no front matter at all"
- fm, body = _split_front_matter(content)
+ fm, body = FsSkillRepository.from_markdown(content)
assert fm == {}
assert body == content
@@ -99,29 +97,29 @@ def test_no_dash_prefix(self):
class TestParseToolsFromBody:
def test_basic_tools_section(self):
body = "Tools:\n- tool_a\n- tool_b\n\nOverview"
- tools = _parse_tools_from_body(body)
+ tools = FsSkillRepository._parse_tools_from_body(body)
assert tools == ["tool_a", "tool_b"]
def test_no_tools_section(self):
body = "# Just markdown\nNo tools here"
- assert _parse_tools_from_body(body) == []
+ assert FsSkillRepository._parse_tools_from_body(body) == []
def test_tools_section_stops_at_next_section(self):
body = "Tools:\n- tool_a\nOverview\nMore content"
- tools = _parse_tools_from_body(body)
+ tools = FsSkillRepository._parse_tools_from_body(body)
assert tools == ["tool_a"]
def test_tools_section_skips_headings(self):
body = "Tools:\n# Comment\n- tool_a\n"
- tools = _parse_tools_from_body(body)
+ tools = FsSkillRepository._parse_tools_from_body(body)
assert tools == ["tool_a"]
def test_empty_body(self):
- assert _parse_tools_from_body("") == []
+ assert FsSkillRepository._parse_tools_from_body("") == []
def test_tools_with_description_colon(self):
body = "Tools:\n- tool_a\nDescription: something\n"
- tools = _parse_tools_from_body(body)
+ tools = FsSkillRepository._parse_tools_from_body(body)
assert tools == ["tool_a"]
@@ -131,16 +129,16 @@ def test_tools_with_description_colon(self):
class TestIsDocFile:
def test_markdown(self):
- assert _is_doc_file("readme.md") is True
- assert _is_doc_file("README.MD") is True
+ assert is_doc_file("readme.md") is True
+ assert is_doc_file("README.MD") is True
def test_text(self):
- assert _is_doc_file("notes.txt") is True
- assert _is_doc_file("NOTES.TXT") is True
+ assert is_doc_file("notes.txt") is True
+ assert is_doc_file("NOTES.TXT") is True
def test_non_doc(self):
- assert _is_doc_file("script.py") is False
- assert _is_doc_file("data.json") is False
+ assert is_doc_file("script.py") is False
+ assert is_doc_file("data.json") is False
# ---------------------------------------------------------------------------
@@ -374,10 +372,42 @@ def test_read_docs_error_handling(self, tmp_path):
class TestBaseSkillRepositoryAbstract:
def test_user_prompt_default(self):
- repo = MagicMock(spec=BaseSkillRepository)
- BaseSkillRepository.user_prompt(repo)
+ class _Repo(BaseSkillRepository):
+ def summaries(self):
+ return []
+
+ def get(self, name: str):
+ raise ValueError(name)
+
+ def skill_list(self, mode: str = "all"):
+ return []
+
+ def path(self, name: str) -> str:
+ return ""
+
+ def refresh(self) -> None:
+ return None
+
+ repo = _Repo(workspace_runtime=MagicMock())
+ assert BaseSkillRepository.user_prompt(repo) == ""
def test_skill_run_env_default(self):
- repo = MagicMock(spec=BaseSkillRepository)
+ class _Repo(BaseSkillRepository):
+ def summaries(self):
+ return []
+
+ def get(self, name: str):
+ raise ValueError(name)
+
+ def skill_list(self, mode: str = "all"):
+ return []
+
+ def path(self, name: str) -> str:
+ return ""
+
+ def refresh(self) -> None:
+ return None
+
+ repo = _Repo(workspace_runtime=MagicMock())
result = BaseSkillRepository.skill_run_env(repo, "skill")
assert result == {}
diff --git a/tests/skills/test_run_tool.py b/tests/skills/test_run_tool.py
deleted file mode 100644
index 9e4f577..0000000
--- a/tests/skills/test_run_tool.py
+++ /dev/null
@@ -1,364 +0,0 @@
-# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
-#
-# Copyright (C) 2026 Tencent. All rights reserved.
-#
-# tRPC-Agent-Python is licensed under Apache-2.0.
-
-from unittest.mock import AsyncMock
-from unittest.mock import Mock
-from unittest.mock import patch
-
-import pytest
-from trpc_agent_sdk.code_executors import BaseProgramRunner
-from trpc_agent_sdk.code_executors import BaseWorkspaceFS
-from trpc_agent_sdk.code_executors import BaseWorkspaceManager
-from trpc_agent_sdk.code_executors import BaseWorkspaceRuntime
-from trpc_agent_sdk.code_executors import CodeFile
-from trpc_agent_sdk.code_executors import WorkspaceInfo
-from trpc_agent_sdk.code_executors import WorkspaceRunResult
-from trpc_agent_sdk.context import InvocationContext
-from trpc_agent_sdk.skills import BaseSkillRepository
-from trpc_agent_sdk.skills.tools import ArtifactInfo
-from trpc_agent_sdk.skills.tools import SkillRunFile
-from trpc_agent_sdk.skills.tools import SkillRunInput
-from trpc_agent_sdk.skills.tools import SkillRunOutput
-from trpc_agent_sdk.skills.tools import SkillRunTool
-from trpc_agent_sdk.skills.tools._skill_run import _inline_json_schema_refs
-
-
-class TestInlineJsonSchemaRefs:
- """Test suite for _inline_json_schema_refs function."""
-
- def test_inline_json_schema_refs_no_refs(self):
- """Test inlining schema with no $ref references."""
- schema = {
- "type": "object",
- "properties": {
- "name": {"type": "string"}
- }
- }
-
- result = _inline_json_schema_refs(schema)
-
- assert result == schema
-
- def test_inline_json_schema_refs_with_refs(self):
- """Test inlining schema with $ref references."""
- schema = {
- "type": "object",
- "properties": {
- "item": {"$ref": "#/$defs/Item"}
- },
- "$defs": {
- "Item": {
- "type": "object",
- "properties": {
- "name": {"type": "string"}
- }
- }
- }
- }
-
- result = _inline_json_schema_refs(schema)
-
- assert "$defs" not in result
- assert "$ref" not in str(result)
- assert "name" in str(result)
-
- def test_inline_json_schema_refs_nested_refs(self):
- """Test inlining schema with nested $ref references."""
- schema = {
- "type": "object",
- "properties": {
- "item": {"$ref": "#/$defs/Item"}
- },
- "$defs": {
- "Item": {
- "type": "object",
- "properties": {
- "nested": {"$ref": "#/$defs/Nested"}
- }
- },
- "Nested": {
- "type": "string"
- }
- }
- }
-
- result = _inline_json_schema_refs(schema)
-
- assert "$defs" not in result
- assert "$ref" not in str(result)
-
-
-class TestSkillRunInput:
- """Test suite for SkillRunInput class."""
-
- def test_create_skill_run_input(self):
- """Test creating skill run input."""
- input_data = SkillRunInput(
- skill="test-skill",
- command="python script.py",
- cwd="work",
- env={"VAR": "value"},
- output_files=["out/*.txt"],
- timeout=30,
- save_as_artifacts=True,
- )
-
- assert input_data.skill == "test-skill"
- assert input_data.command == "python script.py"
- assert input_data.cwd == "work"
- assert input_data.env == {"VAR": "value"}
- assert input_data.output_files == ["out/*.txt"]
- assert input_data.timeout == 30
- assert input_data.save_as_artifacts is True
-
- def test_create_skill_run_input_defaults(self):
- """Test creating skill run input with defaults."""
- input_data = SkillRunInput(skill="test-skill", command="echo hello")
-
- assert input_data.skill == "test-skill"
- assert input_data.command == "echo hello"
- assert input_data.cwd == ""
- assert input_data.env == {}
- assert input_data.output_files == []
- assert input_data.timeout == 0
- assert input_data.save_as_artifacts is False
-
-
-class TestSkillRunOutput:
- """Test suite for SkillRunOutput class."""
-
- def test_create_skill_run_output(self):
- """Test creating skill run output."""
- output_files = [SkillRunFile(name="output.txt", content="content", mime_type="text/plain")]
- artifact_files = [ArtifactInfo(name="artifact.txt", version=1)]
-
- output = SkillRunOutput(
- stdout="output",
- stderr="error",
- exit_code=0,
- timed_out=False,
- duration_ms=1000,
- output_files=output_files,
- artifact_files=artifact_files,
- )
-
- assert output.stdout == "output"
- assert output.stderr == "error"
- assert output.exit_code == 0
- assert output.timed_out is False
- assert output.duration_ms == 1000
- assert len(output.output_files) == 1
- assert len(output.artifact_files) == 1
-
- def test_create_skill_run_output_defaults(self):
- """Test creating skill run output with defaults."""
- output = SkillRunOutput()
-
- assert output.stdout == ""
- assert output.stderr == ""
- assert output.exit_code == 0
- assert output.timed_out is False
- assert output.duration_ms == 0
- assert output.output_files == []
- assert output.artifact_files == []
-
-
-class TestArtifactInfo:
- """Test suite for ArtifactInfo class."""
-
- def test_create_artifact_info(self):
- """Test creating artifact info."""
- info = ArtifactInfo(name="artifact.txt", version=1)
-
- assert info.name == "artifact.txt"
- assert info.version == 1
-
- def test_create_artifact_info_defaults(self):
- """Test creating artifact info with defaults."""
- info = ArtifactInfo()
-
- assert info.name == ""
- assert info.version == 0
-
-
-class TestSkillRunTool:
- """Test suite for SkillRunTool class."""
-
- def setup_method(self):
- """Set up test fixtures before each test."""
- self.mock_repository = Mock(spec=BaseSkillRepository)
- self.mock_runtime = Mock(spec=BaseWorkspaceRuntime)
- self.mock_manager = Mock(spec=BaseWorkspaceManager)
- self.mock_fs = Mock(spec=BaseWorkspaceFS)
- self.mock_runner = Mock(spec=BaseProgramRunner)
- self.mock_repository.workspace_runtime = self.mock_runtime
- self.mock_runtime.manager = Mock(return_value=self.mock_manager)
- self.mock_runtime.fs = Mock(return_value=self.mock_fs)
- self.mock_runtime.runner = Mock(return_value=self.mock_runner)
-
- self.mock_ctx = Mock(spec=InvocationContext)
- self.mock_ctx.agent_context = Mock()
- self.mock_ctx.agent_context.get_metadata = Mock(return_value=None)
- self.mock_ctx.session = Mock()
- self.mock_ctx.session.id = "session-123"
- self.mock_ctx.actions = Mock()
- self.mock_ctx.actions.state_delta = {}
-
- def test_init(self):
- """Test SkillRunTool initialization."""
- tool = SkillRunTool(repository=self.mock_repository)
-
- assert tool.name == "skill_run"
- assert tool._repository == self.mock_repository
-
- def test_get_declaration(self):
- """Test getting function declaration."""
- tool = SkillRunTool(repository=self.mock_repository)
-
- declaration = tool._get_declaration()
-
- assert declaration.name == "skill_run"
- assert declaration.parameters is not None
- assert declaration.response is not None
-
- def test_get_repository_from_instance(self):
- """Test getting repository from instance."""
- tool = SkillRunTool(repository=self.mock_repository)
-
- result = tool._get_repository(self.mock_ctx)
-
- assert result == self.mock_repository
-
- def test_get_repository_from_context(self):
- """Test getting repository from context."""
- tool = SkillRunTool(repository=None)
- self.mock_ctx.agent_context.get_metadata = Mock(return_value=self.mock_repository)
-
- result = tool._get_repository(self.mock_ctx)
-
- assert result == self.mock_repository
-
- @pytest.mark.asyncio
- async def test_run_async_impl_success(self):
- """Test running skill_run tool successfully."""
- from trpc_agent_sdk.skills.stager import SkillStageResult
-
- mock_stager = AsyncMock()
- mock_stager.stage_skill = AsyncMock(return_value=SkillStageResult(workspace_skill_dir="skills/test-skill"))
- tool = SkillRunTool(repository=self.mock_repository, skill_stager=mock_stager)
-
- workspace = WorkspaceInfo(id="ws-123", path="/tmp/workspace")
- self.mock_repository.path = Mock(return_value="/path/to/skill")
- self.mock_repository.skill_run_env = Mock(return_value={})
- self.mock_manager.create_workspace = AsyncMock(return_value=workspace)
- self.mock_fs.stage_directory = AsyncMock()
- self.mock_fs.stage_inputs = AsyncMock()
- self.mock_fs.collect_outputs = AsyncMock(return_value=Mock(files=[]))
- self.mock_runner.run_program = AsyncMock(return_value=WorkspaceRunResult(
- stdout="output",
- stderr="",
- exit_code=0,
- duration=1.0,
- timed_out=False
- ))
-
- args = {
- "skill": "test-skill",
- "command": "echo hello"
- }
- result = await tool._run_async_impl(tool_context=self.mock_ctx, args=args)
-
- assert isinstance(result, dict)
- assert result["stdout"] == "output"
- assert result["exit_code"] == 0
-
- @pytest.mark.asyncio
- async def test_run_async_impl_with_output_files(self):
- """Test running skill_run tool with output files."""
- from trpc_agent_sdk.skills.stager import SkillStageResult
-
- mock_stager = AsyncMock()
- mock_stager.stage_skill = AsyncMock(return_value=SkillStageResult(workspace_skill_dir="skills/test-skill"))
- tool = SkillRunTool(repository=self.mock_repository, skill_stager=mock_stager)
-
- workspace = WorkspaceInfo(id="ws-123", path="/tmp/workspace")
- self.mock_repository.path = Mock(return_value="/path/to/skill")
- self.mock_repository.skill_run_env = Mock(return_value={})
- self.mock_manager.create_workspace = AsyncMock(return_value=workspace)
- self.mock_fs.stage_directory = AsyncMock()
- self.mock_fs.stage_inputs = AsyncMock()
- self.mock_fs.collect = AsyncMock(return_value=[CodeFile(name="output.txt", content="content", mime_type="text/plain")])
- self.mock_runner.run_program = AsyncMock(return_value=WorkspaceRunResult(
- stdout="output",
- stderr="",
- exit_code=0,
- duration=1.0,
- timed_out=False
- ))
-
- from trpc_agent_sdk.code_executors._types import ManifestOutput, ManifestFileRef
- mock_output = ManifestOutput(files=[
- ManifestFileRef(name="output.txt", content="content", mime_type="text/plain")
- ])
- self.mock_fs.collect_outputs = AsyncMock(return_value=mock_output)
-
- args = {
- "skill": "test-skill",
- "command": "echo hello",
- "output_files": ["out/*.txt"]
- }
- result = await tool._run_async_impl(tool_context=self.mock_ctx, args=args)
-
- assert len(result["output_files"]) == 1
-
- @pytest.mark.asyncio
- async def test_run_async_impl_invalid_args(self):
- """Test running skill_run tool with invalid arguments."""
- tool = SkillRunTool(repository=self.mock_repository)
-
- args = {
- "skill": "test-skill",
- # Missing required 'command' field
- }
-
- with pytest.raises(ValueError, match="Invalid skill_run arguments"):
- await tool._run_async_impl(tool_context=self.mock_ctx, args=args)
-
- @pytest.mark.asyncio
- async def test_run_async_impl_with_kwargs(self):
- """Test running skill_run tool with kwargs."""
- from trpc_agent_sdk.skills.stager import SkillStageResult
-
- mock_stager = AsyncMock()
- mock_stager.stage_skill = AsyncMock(return_value=SkillStageResult(workspace_skill_dir="skills/test-skill"))
- tool = SkillRunTool(repository=self.mock_repository, timeout=30, skill_stager=mock_stager)
-
- workspace = WorkspaceInfo(id="ws-123", path="/tmp/workspace")
- self.mock_repository.path = Mock(return_value="/path/to/skill")
- self.mock_repository.skill_run_env = Mock(return_value={})
- self.mock_manager.create_workspace = AsyncMock(return_value=workspace)
- self.mock_fs.stage_directory = AsyncMock()
- self.mock_fs.stage_inputs = AsyncMock()
- self.mock_fs.collect_outputs = AsyncMock(return_value=Mock(files=[]))
- self.mock_runner.run_program = AsyncMock(return_value=WorkspaceRunResult(
- stdout="output",
- stderr="",
- exit_code=0,
- duration=1.0,
- timed_out=False
- ))
-
- args = {
- "skill": "test-skill",
- "command": "echo hello"
- }
-
- result = await tool._run_async_impl(tool_context=self.mock_ctx, args=args)
-
- assert isinstance(result, dict)
- assert result["stdout"] == "output"
- assert result["exit_code"] == 0
-
diff --git a/tests/skills/test_skill_config.py b/tests/skills/test_skill_config.py
new file mode 100644
index 0000000..b8284dd
--- /dev/null
+++ b/tests/skills/test_skill_config.py
@@ -0,0 +1,42 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+
+from unittest.mock import MagicMock
+
+from trpc_agent_sdk.skills._constants import SKILL_CONFIG_KEY
+from trpc_agent_sdk.skills._constants import SkillLoadModeNames
+from trpc_agent_sdk.skills._skill_config import DEFAULT_SKILL_CONFIG
+from trpc_agent_sdk.skills._skill_config import get_skill_config
+from trpc_agent_sdk.skills._skill_config import get_skill_load_mode
+from trpc_agent_sdk.skills._skill_config import is_exist_skill_config
+from trpc_agent_sdk.skills._skill_config import set_skill_config
+
+
+def test_get_skill_config_uses_metadata_default():
+ agent_ctx = MagicMock()
+ agent_ctx.get_metadata = MagicMock(return_value=DEFAULT_SKILL_CONFIG)
+ assert get_skill_config(agent_ctx) == DEFAULT_SKILL_CONFIG
+
+
+def test_set_skill_config_writes_metadata():
+ agent_ctx = MagicMock()
+ config = {"skill_processor": {"load_mode": "session"}}
+ set_skill_config(agent_ctx, config)
+ agent_ctx.with_metadata.assert_called_once_with(SKILL_CONFIG_KEY, config)
+
+
+def test_get_skill_load_mode_fallback_turn_on_invalid():
+ agent_ctx = MagicMock()
+ agent_ctx.get_metadata = MagicMock(return_value={"skill_processor": {"load_mode": "bad"}})
+ ctx = MagicMock()
+ ctx.agent_context = agent_ctx
+ assert get_skill_load_mode(ctx) == SkillLoadModeNames.TURN.value
+
+
+def test_is_exist_skill_config_checks_key():
+ agent_ctx = MagicMock()
+ agent_ctx.metadata = {SKILL_CONFIG_KEY: {}}
+ assert is_exist_skill_config(agent_ctx) is True
diff --git a/tests/skills/test_skill_profile.py b/tests/skills/test_skill_profile.py
new file mode 100644
index 0000000..b091c15
--- /dev/null
+++ b/tests/skills/test_skill_profile.py
@@ -0,0 +1,41 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+
+import pytest
+
+from trpc_agent_sdk.skills._constants import SkillProfileNames
+from trpc_agent_sdk.skills._skill_profile import SkillProfileFlags
+
+
+def test_normalize_profile():
+ assert SkillProfileFlags.normalize_profile("knowledge_only") == SkillProfileNames.KNOWLEDGE_ONLY.value
+ assert SkillProfileFlags.normalize_profile("unknown") == SkillProfileNames.FULL.value
+
+
+def test_preset_flags_knowledge_only():
+ flags = SkillProfileFlags.preset_flags("knowledge_only")
+ assert flags.has_knowledge_tools() is True
+ assert flags.requires_execution_tools() is False
+
+
+def test_resolve_flags_with_forbidden_tool():
+ flags = SkillProfileFlags.resolve_flags("full", forbidden_tools=["skill_write_stdin"])
+ assert flags.exec is True
+ assert flags.write_stdin is False
+
+
+def test_validate_dependency_error():
+ flags = SkillProfileFlags(run=False, exec=True)
+ with pytest.raises(ValueError, match="requires"):
+ flags.validate()
+
+
+def test_without_interactive_execution():
+ flags = SkillProfileFlags.resolve_flags("full")
+ narrowed = flags.without_interactive_execution()
+ assert narrowed.run is True
+ assert narrowed.exec is False
+ assert narrowed.poll_session is False
diff --git a/tests/skills/test_state_keys.py b/tests/skills/test_state_keys.py
new file mode 100644
index 0000000..e8df27f
--- /dev/null
+++ b/tests/skills/test_state_keys.py
@@ -0,0 +1,35 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+
+from trpc_agent_sdk.skills._state_keys import docs_key
+from trpc_agent_sdk.skills._state_keys import docs_prefix
+from trpc_agent_sdk.skills._state_keys import loaded_key
+from trpc_agent_sdk.skills._state_keys import loaded_order_key
+from trpc_agent_sdk.skills._state_keys import loaded_prefix
+from trpc_agent_sdk.skills._state_keys import to_persistent_prefix
+from trpc_agent_sdk.skills._state_keys import tool_key
+from trpc_agent_sdk.skills._state_keys import tool_prefix
+
+
+def test_loaded_key_legacy_fallback():
+ assert loaded_key("", "demo") == "temp:skill:loaded:demo"
+
+
+def test_scoped_keys_escape_agent_name():
+ assert loaded_key("agent/a", "demo") == "temp:skill:loaded_by_agent:agent%2Fa/demo"
+ assert docs_key("agent/a", "demo") == "temp:skill:docs_by_agent:agent%2Fa/demo"
+ assert tool_key("agent/a", "demo") == "temp:skill:tools_by_agent:agent%2Fa/demo"
+
+
+def test_prefix_helpers():
+ assert loaded_prefix("") == "temp:skill:loaded:"
+ assert docs_prefix("") == "temp:skill:docs:"
+ assert tool_prefix("") == "temp:skill:tools:"
+ assert loaded_order_key("agent/a") == "temp:skill:loaded_order_by_agent:agent%2Fa"
+
+
+def test_to_persistent_prefix():
+ assert to_persistent_prefix("temp:skill:loaded:demo") == "skill:loaded:demo"
diff --git a/tests/skills/test_state_migration.py b/tests/skills/test_state_migration.py
new file mode 100644
index 0000000..02495a3
--- /dev/null
+++ b/tests/skills/test_state_migration.py
@@ -0,0 +1,71 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""Tests for legacy skill state migration."""
+
+from unittest.mock import Mock
+
+from trpc_agent_sdk.skills._constants import SKILL_DOCS_STATE_KEY_PREFIX
+from trpc_agent_sdk.skills._constants import SKILL_LOADED_STATE_KEY_PREFIX
+from trpc_agent_sdk.skills._state_keys import docs_key
+from trpc_agent_sdk.skills._state_keys import loaded_key
+from trpc_agent_sdk.skills._state_migration import SKILLS_LEGACY_MIGRATION_STATE_KEY
+from trpc_agent_sdk.skills._state_migration import maybe_migrate_legacy_skill_state
+
+
+def _build_ctx(*, state=None, delta=None, agent_name: str = "agent-a") -> Mock:
+ ctx = Mock()
+ ctx.session = Mock()
+ ctx.session.state = dict(state or {})
+ ctx.session.events = []
+ ctx.actions = Mock()
+ ctx.actions.state_delta = dict(delta or {})
+ ctx.agent = Mock()
+ ctx.agent.name = agent_name
+ return ctx
+
+
+class TestMaybeMigrateLegacySkillState:
+ def test_migrates_loaded_legacy_key_to_temp(self):
+ legacy_key = f"{SKILL_LOADED_STATE_KEY_PREFIX}demo-skill"
+ temp_key = loaded_key("agent-a", "demo-skill")
+ ctx = _build_ctx(state={legacy_key: "1"})
+
+ maybe_migrate_legacy_skill_state(ctx)
+
+ assert ctx.actions.state_delta[SKILLS_LEGACY_MIGRATION_STATE_KEY] is True
+ assert ctx.actions.state_delta[temp_key] == "1"
+ assert ctx.actions.state_delta[legacy_key] is None
+
+ def test_migrates_docs_legacy_key_to_temp(self):
+ legacy_key = f"{SKILL_DOCS_STATE_KEY_PREFIX}demo-skill"
+ temp_key = docs_key("agent-a", "demo-skill")
+ value = '["README.md"]'
+ ctx = _build_ctx(state={legacy_key: value})
+
+ maybe_migrate_legacy_skill_state(ctx)
+
+ assert ctx.actions.state_delta[SKILLS_LEGACY_MIGRATION_STATE_KEY] is True
+ assert ctx.actions.state_delta[temp_key] == value
+ assert ctx.actions.state_delta[legacy_key] is None
+
+ def test_existing_scoped_key_skips_copy_and_only_clears_legacy(self):
+ legacy_key = f"{SKILL_LOADED_STATE_KEY_PREFIX}demo-skill"
+ temp_key = loaded_key("agent-a", "demo-skill")
+ ctx = _build_ctx(state={legacy_key: "legacy", temp_key: "existing"})
+
+ maybe_migrate_legacy_skill_state(ctx)
+
+ assert ctx.actions.state_delta[SKILLS_LEGACY_MIGRATION_STATE_KEY] is True
+ assert ctx.actions.state_delta[legacy_key] is None
+ assert temp_key not in ctx.actions.state_delta
+
+ def test_migration_is_idempotent_when_marker_exists(self):
+ legacy_key = f"{SKILL_LOADED_STATE_KEY_PREFIX}demo-skill"
+ ctx = _build_ctx(state={legacy_key: "1", SKILLS_LEGACY_MIGRATION_STATE_KEY: True})
+
+ maybe_migrate_legacy_skill_state(ctx)
+
+ assert ctx.actions.state_delta == {}
diff --git a/tests/skills/test_state_order.py b/tests/skills/test_state_order.py
new file mode 100644
index 0000000..360d58b
--- /dev/null
+++ b/tests/skills/test_state_order.py
@@ -0,0 +1,24 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+
+from trpc_agent_sdk.skills._state_order import marshal_loaded_order
+from trpc_agent_sdk.skills._state_order import parse_loaded_order
+from trpc_agent_sdk.skills._state_order import touch_loaded_order
+
+
+def test_parse_loaded_order_from_json_and_bytes():
+ assert parse_loaded_order('["a","b","a",""]') == ["a", "b"]
+ assert parse_loaded_order(b'["x","y"]') == ["x", "y"]
+ assert parse_loaded_order(b"\xff") == []
+
+
+def test_marshal_loaded_order_normalizes():
+ assert marshal_loaded_order(["a", "a", " ", "b"]) == '["a", "b"]'
+ assert marshal_loaded_order([]) == ""
+
+
+def test_touch_loaded_order_moves_items_to_tail():
+ assert touch_loaded_order(["a", "b", "c"], "b", "a") == ["c", "b", "a"]
diff --git a/tests/skills/test_tools.py b/tests/skills/test_tools.py
deleted file mode 100644
index 6c49fdd..0000000
--- a/tests/skills/test_tools.py
+++ /dev/null
@@ -1,553 +0,0 @@
-# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
-#
-# Copyright (C) 2026 Tencent. All rights reserved.
-#
-# tRPC-Agent-Python is licensed under Apache-2.0.
-
-import json
-from unittest.mock import Mock
-
-import pytest
-from trpc_agent_sdk.context import InvocationContext
-from trpc_agent_sdk.skills.tools import SkillSelectDocsResult
-from trpc_agent_sdk.skills.tools import SkillSelectToolsResult
-from trpc_agent_sdk.skills import skill_list
-from trpc_agent_sdk.skills import skill_list_docs
-from trpc_agent_sdk.skills import skill_list_tools
-from trpc_agent_sdk.skills import skill_load
-from trpc_agent_sdk.skills import skill_select_docs
-from trpc_agent_sdk.skills import skill_select_tools
-from trpc_agent_sdk.skills.tools._skill_load import _set_state_delta_for_skill_load
-from trpc_agent_sdk.skills.tools._skill_load import _set_state_delta_for_skill_tools
-from trpc_agent_sdk.skills import Skill
-from trpc_agent_sdk.skills import SkillResource
-
-
-class TestSkillList:
- """Test suite for skill_list function."""
-
- def test_skill_list_success(self):
- """Test listing all skills."""
- mock_repository = Mock()
- mock_repository.skill_list.return_value = ["skill1", "skill2", "skill3"]
-
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.agent_context = Mock()
- mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository)
-
- result = skill_list(mock_ctx)
-
- assert len(result) == 3
- assert "skill1" in result
- assert "skill2" in result
- assert "skill3" in result
-
- def test_skill_list_empty(self):
- """Test listing skills when none exist."""
- mock_repository = Mock()
- mock_repository.skill_list.return_value = []
-
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.agent_context = Mock()
- mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository)
-
- result = skill_list(mock_ctx)
-
- assert result == []
-
- def test_skill_list_repository_not_found(self):
- """Test listing skills when repository not found."""
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.agent_context = Mock()
- mock_ctx.agent_context.get_metadata = Mock(return_value=None)
-
- with pytest.raises(ValueError, match="repository not found"):
- skill_list(mock_ctx)
-
-
-class TestSkillListDocs:
- """Test suite for skill_list_docs function."""
-
- def test_skill_list_docs_success(self):
- """Test listing docs for a skill."""
- mock_repository = Mock()
- skill = Skill(
- resources=[
- SkillResource(path="doc1.md", content="content1"),
- SkillResource(path="doc2.md", content="content2"),
- ]
- )
- mock_repository.get.return_value = skill
-
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.agent_context = Mock()
- mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository)
-
- result = skill_list_docs(mock_ctx, "test-skill")
-
- assert len(result["docs"]) == 2
- assert "doc1.md" in result["docs"]
- assert "doc2.md" in result["docs"]
-
- def test_skill_list_docs_no_resources(self):
- """Test listing docs for skill with no resources."""
- mock_repository = Mock()
- skill = Skill(resources=[])
- mock_repository.get.return_value = skill
-
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.agent_context = Mock()
- mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository)
-
- result = skill_list_docs(mock_ctx, "test-skill")
-
- assert result["docs"] == []
-
- def test_skill_list_docs_skill_not_found(self):
- """Test listing docs for nonexistent skill."""
- mock_repository = Mock()
- mock_repository.get.return_value = None
-
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.agent_context = Mock()
- mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository)
-
- result = skill_list_docs(mock_ctx, "nonexistent-skill")
-
- assert result["docs"] == []
-
- def test_skill_list_docs_repository_not_found(self):
- """Test listing docs when repository not found."""
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.agent_context = Mock()
- mock_ctx.agent_context.get_metadata = Mock(return_value=None)
-
- with pytest.raises(ValueError, match="repository not found"):
- skill_list_docs(mock_ctx, "test-skill")
-
-
-class TestSkillListTools:
- """Test suite for skill_list_tools function."""
-
- def test_skill_list_tools_success(self):
- """Test listing tools for a skill."""
- mock_repository = Mock()
- skill = Skill(tools=["tool1", "tool2", "tool3"])
- mock_repository.get.return_value = skill
-
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.agent_context = Mock()
- mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository)
-
- result = skill_list_tools(mock_ctx, "test-skill")
-
- assert len(result["tools"]) == 3
- assert "tool1" in result["tools"]
- assert "tool2" in result["tools"]
- assert "tool3" in result["tools"]
-
- def test_skill_list_tools_no_tools(self):
- """Test listing tools for skill with no tools."""
- mock_repository = Mock()
- skill = Skill(tools=[])
- mock_repository.get.return_value = skill
-
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.agent_context = Mock()
- mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository)
-
- result = skill_list_tools(mock_ctx, "test-skill")
-
- assert result["tools"] == []
-
- def test_skill_list_tools_skill_not_found(self):
- """Test listing tools for nonexistent skill."""
- mock_repository = Mock()
- mock_repository.get.return_value = None
-
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.agent_context = Mock()
- mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository)
-
- result = skill_list_tools(mock_ctx, "nonexistent-skill")
-
- assert result["tools"] == []
-
- def test_skill_list_tools_repository_not_found(self):
- """Test listing tools when repository not found."""
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.agent_context = Mock()
- mock_ctx.agent_context.get_metadata = Mock(return_value=None)
-
- with pytest.raises(ValueError, match="repository not found"):
- skill_list_tools(mock_ctx, "test-skill")
-
-
-class TestSkillLoad:
- """Test suite for skill_load function."""
-
- def test_skill_load_success(self):
- """Test loading a skill."""
- mock_repository = Mock()
- skill = Skill(body="skill body", tools=[])
- mock_repository.get.return_value = skill
-
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.agent_context = Mock()
- mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository)
- mock_ctx.actions = Mock()
- mock_ctx.actions.state_delta = {}
-
- result = skill_load(mock_ctx, "test-skill")
-
- assert "loaded" in result
- assert "test-skill" in result
-
- def test_skill_load_with_tools(self):
- """Test loading a skill with tools."""
- mock_repository = Mock()
- skill = Skill(body="skill body", tools=["tool1", "tool2"])
- mock_repository.get.return_value = skill
-
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.agent_context = Mock()
- mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository)
- mock_ctx.actions = Mock()
- mock_ctx.actions.state_delta = {}
-
- result = skill_load(mock_ctx, "test-skill")
-
- assert "loaded" in result
- # Check that tools state was set
- tools_key = "temp:skill:tools:test-skill"
- assert tools_key in mock_ctx.actions.state_delta
- assert json.loads(mock_ctx.actions.state_delta[tools_key]) == ["tool1", "tool2"]
-
- def test_skill_load_with_docs(self):
- """Test loading a skill with specific docs."""
- mock_repository = Mock()
- skill = Skill(body="skill body", tools=[])
- mock_repository.get.return_value = skill
-
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.agent_context = Mock()
- mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository)
- mock_ctx.actions = Mock()
- mock_ctx.actions.state_delta = {}
-
- result = skill_load(mock_ctx, "test-skill", docs=["doc1.md"])
-
- assert "loaded" in result
- docs_key = "temp:skill:docs:test-skill"
- assert docs_key in mock_ctx.actions.state_delta
-
- def test_skill_load_with_include_all_docs(self):
- """Test loading a skill with include_all_docs=True."""
- mock_repository = Mock()
- skill = Skill(body="skill body", tools=[])
- mock_repository.get.return_value = skill
-
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.agent_context = Mock()
- mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository)
- mock_ctx.actions = Mock()
- mock_ctx.actions.state_delta = {}
-
- result = skill_load(mock_ctx, "test-skill", include_all_docs=True)
-
- assert "loaded" in result
- assert mock_ctx.actions.state_delta.get("temp:skill:docs:test-skill") == '*'
-
- def test_skill_load_skill_not_found(self):
- """Test loading nonexistent skill."""
- mock_repository = Mock()
- mock_repository.get.return_value = None
-
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.agent_context = Mock()
- mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository)
-
- result = skill_load(mock_ctx, "nonexistent-skill")
-
- assert "not found" in result
-
- def test_skill_load_repository_not_found(self):
- """Test loading skill when repository not found."""
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.agent_context = Mock()
- mock_ctx.agent_context.get_metadata = Mock(return_value=None)
-
- with pytest.raises(ValueError, match="repository not found"):
- skill_load(mock_ctx, "test-skill")
-
-
-class TestSkillSelectDocs:
- """Test suite for skill_select_docs function."""
-
- def test_skill_select_docs_replace_mode(self):
- """Test selecting docs with replace mode."""
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.session_state = {}
- mock_ctx.actions = Mock()
- mock_ctx.actions.state_delta = {}
-
- result = skill_select_docs(mock_ctx, "test-skill", docs=["doc1.md"], mode="replace")
-
- assert isinstance(result, SkillSelectDocsResult)
- assert result.skill == "test-skill"
- assert result.mode == "replace"
- assert "doc1.md" in result.selected_docs
-
- def test_skill_select_docs_add_mode(self):
- """Test selecting docs with add mode."""
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.session_state = {
- "temp:skill:docs:test-skill": json.dumps(["doc1.md"])
- }
- mock_ctx.actions = Mock()
- mock_ctx.actions.state_delta = {}
-
- result = skill_select_docs(mock_ctx, "test-skill", docs=["doc2.md"], mode="add")
-
- assert result.mode == "add"
- assert "doc1.md" in result.selected_docs
- assert "doc2.md" in result.selected_docs
-
- def test_skill_select_docs_clear_mode(self):
- """Test selecting docs with clear mode."""
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.session_state = {
- "temp:skill:docs:test-skill": json.dumps(["doc1.md"])
- }
- mock_ctx.actions = Mock()
- mock_ctx.actions.state_delta = {}
-
- result = skill_select_docs(mock_ctx, "test-skill", mode="clear")
-
- assert result.mode == "clear"
- assert result.include_all_docs is False
- assert result.selected_docs == []
-
- def test_skill_select_docs_with_include_all(self):
- """Test selecting docs with include_all_docs=True."""
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.session_state = {}
- mock_ctx.actions = Mock()
- mock_ctx.actions.state_delta = {}
-
- result = skill_select_docs(mock_ctx, "test-skill", include_all_docs=True, mode="replace")
-
- assert result.include_all_docs is True
- assert result.selected_docs == []
-
- def test_skill_select_docs_invalid_mode(self):
- """Test selecting docs with invalid mode defaults to replace."""
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.session_state = {}
- mock_ctx.actions = Mock()
- mock_ctx.actions.state_delta = {}
-
- result = skill_select_docs(mock_ctx, "test-skill", docs=["doc1.md"], mode="invalid")
-
- assert result.mode == "replace"
-
- def test_skill_select_docs_previous_all_docs(self):
- """Test selecting docs when previous state has all docs."""
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.session_state = {
- "temp:skill:docs:test-skill": '*'
- }
- mock_ctx.actions = Mock()
- mock_ctx.actions.state_delta = {}
-
- result = skill_select_docs(mock_ctx, "test-skill", docs=["doc1.md"], mode="add")
-
- assert result.include_all_docs is True
-
-
-class TestSkillSelectTools:
- """Test suite for skill_select_tools function."""
-
- def test_skill_select_tools_replace_mode(self):
- """Test selecting tools with replace mode."""
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.session_state = {}
- mock_ctx.actions = Mock()
- mock_ctx.actions.state_delta = {}
-
- result = skill_select_tools(mock_ctx, "test-skill", tools=["tool1"], mode="replace")
-
- assert isinstance(result, SkillSelectToolsResult)
- assert result.skill == "test-skill"
- assert result.mode == "replace"
- assert "tool1" in result.selected_tools
-
- def test_skill_select_tools_add_mode(self):
- """Test selecting tools with add mode."""
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.session_state = {
- "temp:skill:tools:test-skill": json.dumps(["tool1"])
- }
- mock_ctx.actions = Mock()
- mock_ctx.actions.state_delta = {}
-
- result = skill_select_tools(mock_ctx, "test-skill", tools=["tool2"], mode="add")
-
- assert result.mode == "add"
- assert "tool1" in result.selected_tools
- assert "tool2" in result.selected_tools
-
- def test_skill_select_tools_clear_mode(self):
- """Test selecting tools with clear mode."""
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.session_state = {
- "temp:skill:tools:test-skill": json.dumps(["tool1"])
- }
- mock_ctx.actions = Mock()
- mock_ctx.actions.state_delta = {}
-
- result = skill_select_tools(mock_ctx, "test-skill", mode="clear")
-
- assert result.mode == "clear"
- assert result.include_all_tools is False
- assert result.selected_tools == []
-
- def test_skill_select_tools_with_include_all(self):
- """Test selecting tools with include_all_tools=True."""
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.session_state = {}
- mock_ctx.actions = Mock()
- mock_ctx.actions.state_delta = {}
-
- result = skill_select_tools(mock_ctx, "test-skill", include_all_tools=True, mode="replace")
-
- assert result.include_all_tools is True
- assert result.selected_tools == []
-
- def test_skill_select_tools_invalid_mode(self):
- """Test selecting tools with invalid mode defaults to replace."""
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.session_state = {}
- mock_ctx.actions = Mock()
- mock_ctx.actions.state_delta = {}
-
- result = skill_select_tools(mock_ctx, "test-skill", tools=["tool1"], mode="invalid")
-
- assert result.mode == "replace"
-
- def test_skill_select_tools_previous_all_tools(self):
- """Test selecting tools when previous state has all tools."""
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.session_state = {
- "temp:skill:tools:test-skill": '*'
- }
- mock_ctx.actions = Mock()
- mock_ctx.actions.state_delta = {}
-
- result = skill_select_tools(mock_ctx, "test-skill", tools=["tool1"], mode="add")
-
- assert result.include_all_tools is True
-
-
-class TestSkillSelectDocsResult:
- """Test suite for SkillSelectDocsResult class."""
-
- def test_create_result(self):
- """Test creating SkillSelectDocsResult."""
- result = SkillSelectDocsResult(
- skill="test-skill",
- selected_docs=["doc1.md"],
- include_all_docs=True,
- mode="replace"
- )
-
- assert result.skill == "test-skill"
- assert result.selected_docs == ["doc1.md"]
- assert result.include_all_docs is True
- assert result.mode == "replace"
-
- def test_create_result_with_alias_fields(self):
- """Test creating SkillSelectDocsResult with alias fields."""
- result = SkillSelectDocsResult(
- skill="test-skill",
- selected_items=["doc1.md", "doc2.md"],
- include_all=True,
- mode="replace"
- )
-
- # Alias fields should be mapped to actual fields
- assert result.selected_docs == ["doc1.md", "doc2.md"]
- assert result.include_all_docs is True
-
-
-class TestSkillSelectToolsResult:
- """Test suite for SkillSelectToolsResult class."""
-
- def test_create_result(self):
- """Test creating SkillSelectToolsResult."""
- result = SkillSelectToolsResult(
- skill="test-skill",
- selected_tools=["tool1"],
- include_all_tools=True,
- mode="replace"
- )
-
- assert result.skill == "test-skill"
- assert result.selected_tools == ["tool1"]
- assert result.include_all_tools is True
- assert result.mode == "replace"
-
- def test_create_result_with_alias_fields(self):
- """Test creating SkillSelectToolsResult with alias fields."""
- result = SkillSelectToolsResult(
- skill="test-skill",
- selected_items=["tool1", "tool2"],
- include_all=True,
- mode="replace"
- )
-
- # Alias fields should be mapped to actual fields
- assert result.selected_tools == ["tool1", "tool2"]
- assert result.include_all_tools is True
-
-
-class TestStateDeltaHelpers:
- """Test suite for state delta helper functions."""
-
- def test_set_state_delta_for_skill_load(self):
- """Test setting state delta for skill load."""
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.actions = Mock()
- mock_ctx.actions.state_delta = {}
-
- _set_state_delta_for_skill_load(mock_ctx, "test-skill", ["doc1.md"], False)
-
- loaded_key = "temp:skill:loaded:test-skill"
- docs_key = "temp:skill:docs:test-skill"
-
- assert loaded_key in mock_ctx.actions.state_delta
- assert docs_key in mock_ctx.actions.state_delta
- assert json.loads(mock_ctx.actions.state_delta[docs_key]) == ["doc1.md"]
-
- def test_set_state_delta_for_skill_load_include_all(self):
- """Test setting state delta with include_all_docs."""
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.actions = Mock()
- mock_ctx.actions.state_delta = {}
-
- _set_state_delta_for_skill_load(mock_ctx, "test-skill", [], True)
-
- docs_key = "temp:skill:docs:test-skill"
- assert mock_ctx.actions.state_delta[docs_key] == '*'
-
- def test_set_state_delta_for_skill_tools(self):
- """Test setting state delta for skill tools."""
- mock_ctx = Mock(spec=InvocationContext)
- mock_ctx.actions = Mock()
- mock_ctx.actions.state_delta = {}
-
- _set_state_delta_for_skill_tools(mock_ctx, "test-skill", ["tool1", "tool2"])
-
- tools_key = "temp:skill:tools:test-skill"
- assert tools_key in mock_ctx.actions.state_delta
- assert json.loads(mock_ctx.actions.state_delta[tools_key]) == ["tool1", "tool2"]
diff --git a/tests/skills/tools/__init__.py b/tests/skills/tools/__init__.py
index e69de29..8b13789 100644
--- a/tests/skills/tools/__init__.py
+++ b/tests/skills/tools/__init__.py
@@ -0,0 +1 @@
+
diff --git a/tests/skills/tools/test_common.py b/tests/skills/tools/test_common.py
new file mode 100644
index 0000000..52d712f
--- /dev/null
+++ b/tests/skills/tools/test_common.py
@@ -0,0 +1,31 @@
+from unittest.mock import MagicMock
+
+import pytest
+
+from trpc_agent_sdk.skills.tools._common import get_staged_workspace_dir
+from trpc_agent_sdk.skills.tools._common import inline_json_schema_refs
+from trpc_agent_sdk.skills.tools._common import require_non_empty
+from trpc_agent_sdk.skills.tools._common import set_staged_workspace_dir
+
+
+def test_require_non_empty():
+ assert require_non_empty(" ok ", field_name="x") == "ok"
+ with pytest.raises(ValueError, match="x is required"):
+ require_non_empty(" ", field_name="x")
+
+
+def test_inline_json_schema_refs():
+ schema = {"$defs": {"S": {"type": "string"}}, "properties": {"name": {"$ref": "#/$defs/S"}}}
+ out = inline_json_schema_refs(schema)
+ assert "$defs" not in out
+ assert out["properties"]["name"]["type"] == "string"
+
+
+def test_staged_workspace_dir_round_trip():
+ metadata = {}
+ ctx = MagicMock()
+ ctx.agent_context.get_metadata = MagicMock(side_effect=lambda key, default=None: metadata.get(key, default))
+ ctx.agent_context.with_metadata = MagicMock(side_effect=lambda key, value: metadata.__setitem__(key, value))
+
+ set_staged_workspace_dir(ctx, "skill-a", "skills/skill-a")
+ assert get_staged_workspace_dir(ctx, "skill-a") == "skills/skill-a"
diff --git a/tests/skills/tools/test_save_artifact.py b/tests/skills/tools/test_save_artifact.py
new file mode 100644
index 0000000..837b771
--- /dev/null
+++ b/tests/skills/tools/test_save_artifact.py
@@ -0,0 +1,49 @@
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from trpc_agent_sdk.skills._constants import SKILL_ARTIFACTS_STATE_KEY
+from trpc_agent_sdk.skills.tools._save_artifact import SaveArtifactTool
+from trpc_agent_sdk.skills.tools._save_artifact import _apply_artifact_state_delta
+from trpc_agent_sdk.skills.tools._save_artifact import _artifact_save_skip_reason
+from trpc_agent_sdk.skills.tools._save_artifact import _normalize_artifact_path
+from trpc_agent_sdk.skills.tools._save_artifact import _normalize_workspace_prefix
+
+
+def test_normalize_workspace_prefix():
+ assert _normalize_workspace_prefix("workspace://work/a.txt") == "work/a.txt"
+ assert _normalize_workspace_prefix("$WORK_DIR/a.txt") == "work/a.txt"
+
+
+def test_normalize_artifact_path_valid_and_invalid(tmp_path: Path):
+ workspace_root = str(tmp_path)
+ rel, abs_path = _normalize_artifact_path("work/a.txt", workspace_root)
+ assert rel == "work/a.txt"
+ assert abs_path.endswith("work/a.txt")
+
+ with pytest.raises(ValueError, match="stay within the workspace"):
+ _normalize_artifact_path("../a.txt", workspace_root)
+
+
+def test_artifact_save_skip_reason_and_state_delta():
+ ctx = MagicMock()
+ ctx.artifact_service = object()
+ ctx.session = object()
+ ctx.app_name = "app"
+ ctx.user_id = "u"
+ ctx.session_id = "s"
+ ctx.function_call_id = "fc-1"
+ ctx.actions.state_delta = {}
+ assert _artifact_save_skip_reason(ctx) == ""
+
+ _apply_artifact_state_delta(ctx, "work/a.txt", 2, "artifact://work/a.txt@2")
+ value = ctx.actions.state_delta[SKILL_ARTIFACTS_STATE_KEY]
+ assert value["tool_call_id"] == "fc-1"
+ assert value["artifacts"][0]["version"] == 2
+
+
+def test_save_artifact_declaration_name():
+ declaration = SaveArtifactTool()._get_declaration()
+ assert declaration is not None
+ assert declaration.name == "workspace_save_artifact"
diff --git a/tests/skills/tools/test_skill_exec.py b/tests/skills/tools/test_skill_exec.py
index 82a0ecb..1593e5d 100644
--- a/tests/skills/tools/test_skill_exec.py
+++ b/tests/skills/tools/test_skill_exec.py
@@ -3,887 +3,107 @@
# Copyright (C) 2026 Tencent. All rights reserved.
#
# tRPC-Agent-Python is licensed under Apache-2.0.
-"""Unit tests for trpc_agent_sdk.skills.tools._skill_exec.
-Covers:
-- Pydantic I/O models: ExecInput, WriteStdinInput, PollSessionInput,
- KillSessionInput, SessionInteraction, ExecOutput, SessionKillOutput
-- Helper functions: _last_non_empty_line, _has_selection_items,
- _detect_interaction, _build_exec_env, _resolve_abs_cwd
-- SkillExecTool: session management, declarations
-- create_exec_tools factory
-- _close_session
-"""
-
-from __future__ import annotations
-
-import asyncio
-import os
-from pathlib import Path
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import AsyncMock
+from unittest.mock import MagicMock
import pytest
-
-from trpc_agent_sdk.skills.tools._skill_exec import (
- DEFAULT_EXEC_YIELD_MS,
- DEFAULT_IO_YIELD_MS,
- DEFAULT_POLL_LINES,
- DEFAULT_SESSION_TTL,
- ExecInput,
- ExecOutput,
- KillSessionInput,
- KillSessionTool,
- PollSessionInput,
- PollSessionTool,
- SessionInteraction,
- SessionKillOutput,
- SkillExecTool,
- WriteStdinInput,
- WriteStdinTool,
- _close_session,
- _detect_interaction,
- _has_selection_items,
- _last_non_empty_line,
- _resolve_abs_cwd,
- create_exec_tools,
-)
-
-
-# ---------------------------------------------------------------------------
-# _last_non_empty_line
-# ---------------------------------------------------------------------------
-
-class TestLastNonEmptyLine:
- def test_normal(self):
- assert _last_non_empty_line("line1\nline2\nline3") == "line3"
-
- def test_trailing_empty_lines(self):
- assert _last_non_empty_line("line1\nline2\n\n\n") == "line2"
-
- def test_all_empty(self):
- assert _last_non_empty_line("\n\n") == ""
-
- def test_empty_string(self):
- assert _last_non_empty_line("") == ""
-
- def test_single_line(self):
- assert _last_non_empty_line("hello") == "hello"
-
- def test_whitespace_lines(self):
- assert _last_non_empty_line(" \n \n content \n ") == "content"
-
-
-# ---------------------------------------------------------------------------
-# _has_selection_items
-# ---------------------------------------------------------------------------
-
-class TestHasSelectionItems:
- def test_numbered_list(self):
- text = "Choose an option:\n1. Option A\n2. Option B\n3. Option C"
- assert _has_selection_items(text) is True
-
- def test_numbered_paren(self):
- text = "1) Option A\n2) Option B"
- assert _has_selection_items(text) is True
-
- def test_no_numbers(self):
- text = "No numbered items here"
- assert _has_selection_items(text) is False
-
- def test_single_number(self):
- text = "1. Only one item"
- assert _has_selection_items(text) is False
-
- def test_empty(self):
- assert _has_selection_items("") is False
-
-
-# ---------------------------------------------------------------------------
-# _detect_interaction
-# ---------------------------------------------------------------------------
-
-class TestDetectInteraction:
- def test_exited_returns_none(self):
- assert _detect_interaction("exited", "some output") is None
-
- def test_empty_output_returns_none(self):
- assert _detect_interaction("running", "") is None
-
- def test_colon_prompt(self):
- result = _detect_interaction("running", "Enter your name:")
- assert result is not None
- assert result.needs_input is True
- assert result.kind == "prompt"
-
- def test_question_mark_prompt(self):
- result = _detect_interaction("running", "Continue?")
- assert result is not None
- assert result.needs_input is True
-
- def test_press_enter(self):
- result = _detect_interaction("running", "Press Enter to continue")
- assert result is not None
- assert result.needs_input is True
-
- def test_selection_detection(self):
- text = "Choose a number:\n1. Option A\n2. Option B\nEnter the number:"
- result = _detect_interaction("running", text)
- assert result is not None
- assert result.kind == "selection"
-
- def test_normal_output(self):
- result = _detect_interaction("running", "Processing data...\nDone.")
- assert result is None
-
- def test_type_your_prompt(self):
- result = _detect_interaction("running", "Type your answer here")
- assert result is not None
- assert result.needs_input is True
-
-
-# ---------------------------------------------------------------------------
-# _resolve_abs_cwd
-# ---------------------------------------------------------------------------
-
-class TestResolveAbsCwd:
- def test_relative_cwd(self, tmp_path):
- result = _resolve_abs_cwd(str(tmp_path), "sub")
- assert os.path.isabs(result)
- assert os.path.isdir(result)
-
- def test_absolute_cwd(self, tmp_path):
- result = _resolve_abs_cwd(str(tmp_path), str(tmp_path / "abs"))
- assert result == str(tmp_path / "abs")
-
- def test_empty_cwd(self, tmp_path):
- result = _resolve_abs_cwd(str(tmp_path), "")
- assert os.path.isabs(result)
-
-
-# ---------------------------------------------------------------------------
-# Pydantic I/O models
-# ---------------------------------------------------------------------------
-
-class TestExecModels:
- def test_exec_input_required(self):
- inp = ExecInput(skill="test", command="ls")
- assert inp.skill == "test"
- assert inp.command == "ls"
+from trpc_agent_sdk.code_executors import DEFAULT_EXEC_YIELD_MS
+from trpc_agent_sdk.code_executors import DEFAULT_IO_YIELD_MS
+from trpc_agent_sdk.code_executors import DEFAULT_POLL_LINES
+from trpc_agent_sdk.code_executors import DEFAULT_SESSION_TTL_SEC
+from trpc_agent_sdk.skills.tools._skill_exec import ExecInput
+from trpc_agent_sdk.skills.tools._skill_exec import PollSessionTool
+from trpc_agent_sdk.skills.tools._skill_exec import SkillExecTool
+from trpc_agent_sdk.skills.tools._skill_exec import WriteStdinTool
+from trpc_agent_sdk.skills.tools._skill_exec import _close_session
+from trpc_agent_sdk.skills.tools._skill_exec import _detect_interaction
+from trpc_agent_sdk.skills.tools._skill_exec import _has_selection_items
+from trpc_agent_sdk.skills.tools._skill_exec import _last_non_empty_line
+from trpc_agent_sdk.skills.tools._skill_exec import create_exec_tools
+
+
+def _make_exec_tool() -> SkillExecTool:
+ run_tool = MagicMock()
+ run_tool._repository = MagicMock()
+ run_tool._timeout = 300.0
+ run_tool._resolve_cwd = MagicMock(return_value="skills/test")
+ run_tool._build_command = MagicMock(return_value=("bash", ["-lc", "echo hello"]))
+ run_tool._prepare_outputs = AsyncMock(return_value=([], None))
+ run_tool._attach_artifacts_if_requested = AsyncMock()
+ run_tool._merge_manifest_artifact_refs = MagicMock()
+ return SkillExecTool(run_tool)
+
+
+class TestHelpers:
+ def test_last_non_empty_line(self):
+ assert _last_non_empty_line("a\n\nb\n") == "b"
+
+ def test_has_selection_items(self):
+ assert _has_selection_items("1. a\n2. b") is True
+ assert _has_selection_items("1. a") is False
+
+ def test_detect_interaction_prompt(self):
+ ret = _detect_interaction("running", "Enter your name:")
+ assert ret is not None
+ assert ret.needs_input is True
+
+ def test_detect_interaction_selection(self):
+ ret = _detect_interaction("running", "Choose:\n1. A\n2. B\nEnter the number:")
+ assert ret is not None
+ assert ret.kind == "selection"
+
+
+class TestModelsAndConstants:
+ def test_exec_input_defaults(self):
+ inp = ExecInput(skill="s", command="echo hi")
+ assert inp.yield_time_ms == 0
+ assert inp.poll_lines == 0
assert inp.tty is False
- assert inp.yield_ms == 0
-
- def test_write_stdin_input(self):
- inp = WriteStdinInput(session_id="abc")
- assert inp.session_id == "abc"
- assert inp.chars == ""
- assert inp.submit is False
-
- def test_poll_session_input(self):
- inp = PollSessionInput(session_id="abc")
- assert inp.session_id == "abc"
-
- def test_kill_session_input(self):
- inp = KillSessionInput(session_id="abc")
- assert inp.session_id == "abc"
-
- def test_session_interaction(self):
- si = SessionInteraction(needs_input=True, kind="prompt", hint="Enter:")
- assert si.needs_input is True
- assert si.kind == "prompt"
-
- def test_exec_output_defaults(self):
- out = ExecOutput()
- assert out.status == "running"
- assert out.session_id == ""
- assert out.output == ""
- assert out.exit_code is None
-
- def test_session_kill_output(self):
- out = SessionKillOutput(ok=True, session_id="abc", status="killed")
- assert out.ok is True
-
-
-# ---------------------------------------------------------------------------
-# Constants
-# ---------------------------------------------------------------------------
-
-class TestExecConstants:
- def test_defaults(self):
- assert DEFAULT_EXEC_YIELD_MS == 300
- assert DEFAULT_IO_YIELD_MS == 100
- assert DEFAULT_POLL_LINES == 50
- assert DEFAULT_SESSION_TTL == 300.0
-
-
-# ---------------------------------------------------------------------------
-# SkillExecTool — session management
-# ---------------------------------------------------------------------------
-
-class TestSkillExecToolSessions:
- def _make_exec_tool(self):
- run_tool = MagicMock()
- run_tool._repository = MagicMock()
- run_tool._timeout = 300.0
- return SkillExecTool(run_tool)
-
- async def test_put_and_get_session(self):
- tool = self._make_exec_tool()
- mock_session = MagicMock()
- mock_session.exited_at = None
- await tool._put_session("s1", mock_session)
- result = await tool.get_session("s1")
- assert result is mock_session
-
- async def test_get_unknown_session_raises(self):
- tool = self._make_exec_tool()
- with pytest.raises(ValueError, match="unknown session_id"):
- await tool.get_session("nonexistent")
-
- async def test_remove_session(self):
- tool = self._make_exec_tool()
- mock_session = MagicMock()
- mock_session.exited_at = None
- await tool._put_session("s1", mock_session)
- result = await tool.remove_session("s1")
- assert result is mock_session
-
- async def test_remove_unknown_session_raises(self):
- tool = self._make_exec_tool()
- with pytest.raises(ValueError, match="unknown session_id"):
- await tool.remove_session("nonexistent")
-
- def test_declaration(self):
- tool = self._make_exec_tool()
- decl = tool._get_declaration()
- assert decl.name == "skill_exec"
-
-
-# ---------------------------------------------------------------------------
-# WriteStdinTool, PollSessionTool, KillSessionTool — declarations
-# ---------------------------------------------------------------------------
-
-class TestSubToolDeclarations:
- def _make_exec_tool(self):
- run_tool = MagicMock()
- run_tool._repository = MagicMock()
- run_tool._timeout = 300.0
- return SkillExecTool(run_tool)
- def test_write_stdin_declaration(self):
- exec_tool = self._make_exec_tool()
- tool = WriteStdinTool(exec_tool)
- decl = tool._get_declaration()
- assert decl.name == "skill_write_stdin"
+ def test_default_constants(self):
+ assert DEFAULT_EXEC_YIELD_MS > 0
+ assert DEFAULT_IO_YIELD_MS > 0
+ assert DEFAULT_POLL_LINES > 0
+ assert DEFAULT_SESSION_TTL_SEC > 0
- def test_poll_session_declaration(self):
- exec_tool = self._make_exec_tool()
- tool = PollSessionTool(exec_tool)
- decl = tool._get_declaration()
- assert decl.name == "skill_poll_session"
- def test_kill_session_declaration(self):
- exec_tool = self._make_exec_tool()
- tool = KillSessionTool(exec_tool)
- decl = tool._get_declaration()
- assert decl.name == "skill_kill_session"
-
-
-# ---------------------------------------------------------------------------
-# create_exec_tools
-# ---------------------------------------------------------------------------
+class TestSessionStore:
+ @pytest.mark.asyncio
+ async def test_put_get_remove(self):
+ tool = _make_exec_tool()
+ sess = MagicMock()
+ sess.exited_at = None
+ sess.proc.state = AsyncMock(return_value=MagicMock(status="running", exit_code=None))
+ await tool._put_session("s1", sess)
+ got = await tool._get_session("s1")
+ assert got is sess
+ removed = await tool._remove_session("s1")
+ assert removed is sess
-class TestCreateExecTools:
- def test_creates_four_tools(self):
- run_tool = MagicMock()
- run_tool._repository = MagicMock()
- run_tool._timeout = 300.0
- result = create_exec_tools(run_tool)
- assert len(result) == 4
- exec_tool, write_tool, poll_tool, kill_tool = result
- assert isinstance(exec_tool, SkillExecTool)
- assert isinstance(write_tool, WriteStdinTool)
- assert isinstance(poll_tool, PollSessionTool)
- assert isinstance(kill_tool, KillSessionTool)
- def test_custom_ttl(self):
+class TestFactoryAndDeclarations:
+ def test_create_exec_tools(self):
run_tool = MagicMock()
run_tool._repository = MagicMock()
run_tool._timeout = 300.0
- exec_tool, *_ = create_exec_tools(run_tool, session_ttl=60.0)
- assert exec_tool._ttl == 60.0
+ tools = create_exec_tools(run_tool)
+ assert len(tools) == 4
+ assert isinstance(tools[0], SkillExecTool)
+ assert isinstance(tools[1], WriteStdinTool)
+ assert isinstance(tools[2], PollSessionTool)
+ def test_declaration_names(self):
+ exec_tool = _make_exec_tool()
+ assert exec_tool._get_declaration().name == "skill_exec"
+ assert WriteStdinTool(exec_tool)._get_declaration().name == "skill_write_stdin"
+ assert PollSessionTool(exec_tool)._get_declaration().name == "skill_poll_session"
-# ---------------------------------------------------------------------------
-# _close_session
-# ---------------------------------------------------------------------------
class TestCloseSession:
- def test_close_with_reader_task(self):
- sess = MagicMock()
- sess.reader_task = MagicMock()
- sess.reader_task.done = MagicMock(return_value=False)
- sess.master_fd = None
- sess.proc.stdin = None
- _close_session(sess)
- sess.reader_task.cancel.assert_called_once()
-
- def test_close_with_master_fd(self):
- sess = MagicMock()
- sess.reader_task = None
- sess.master_fd = 42
- sess.proc.stdin = None
- with patch("os.close") as mock_close:
- _close_session(sess)
- mock_close.assert_called_once_with(42)
- assert sess.master_fd is None
-
- def test_close_with_stdin(self):
- sess = MagicMock()
- sess.reader_task = None
- sess.master_fd = None
- sess.proc.stdin = MagicMock()
- sess.proc.stdin.is_closing = MagicMock(return_value=False)
- _close_session(sess)
- sess.proc.stdin.close.assert_called_once()
-
- def test_close_already_done(self):
+ @pytest.mark.asyncio
+ async def test_close_session(self):
sess = MagicMock()
- sess.reader_task = MagicMock()
- sess.reader_task.done = MagicMock(return_value=True)
- sess.master_fd = None
- sess.proc.stdin = None
- _close_session(sess)
- sess.reader_task.cancel.assert_not_called()
-
- def test_close_no_reader_task(self):
- sess = MagicMock()
- sess.reader_task = None
- sess.master_fd = None
- sess.proc.stdin = None
- _close_session(sess)
-
- def test_close_master_fd_oserror(self):
- sess = MagicMock()
- sess.reader_task = None
- sess.master_fd = 42
- sess.proc.stdin = None
- with patch("os.close", side_effect=OSError("test")):
- _close_session(sess)
- assert sess.master_fd is None
-
- def test_close_stdin_closing(self):
- sess = MagicMock()
- sess.reader_task = None
- sess.master_fd = None
- sess.proc.stdin = MagicMock()
- sess.proc.stdin.is_closing = MagicMock(return_value=True)
- _close_session(sess)
- sess.proc.stdin.close.assert_not_called()
-
-
-# ---------------------------------------------------------------------------
-# _ExecSession
-# ---------------------------------------------------------------------------
-
-class TestExecSession:
- async def test_append_and_total_output(self):
- from trpc_agent_sdk.skills.tools._skill_exec import _ExecSession
- proc = MagicMock()
- proc.returncode = None
- ws = MagicMock()
- in_data = ExecInput(skill="test", command="echo hello")
- sess = _ExecSession(proc=proc, ws=ws, in_data=in_data)
- await sess.append_output("hello ")
- await sess.append_output("world")
- total = await sess.total_output()
- assert total == "hello world"
-
- async def test_yield_output_exited(self):
- from trpc_agent_sdk.skills.tools._skill_exec import _ExecSession
- proc = MagicMock()
- proc.returncode = 0
- ws = MagicMock()
- in_data = ExecInput(skill="test", command="echo hello")
- sess = _ExecSession(proc=proc, ws=ws, in_data=in_data)
- await sess.append_output("output text")
- status, chunk, offset, next_offset = await sess.yield_output(50, 0)
- assert status == "exited"
- assert "output text" in chunk
- assert sess.exit_code == 0
-
-
-# ---------------------------------------------------------------------------
-# _build_exec_env
-# ---------------------------------------------------------------------------
-
-class TestBuildExecEnv:
- def test_builds_env(self, tmp_path):
- from trpc_agent_sdk.skills.tools._skill_exec import _build_exec_env
- ws = MagicMock()
- ws.path = str(tmp_path)
- env = _build_exec_env(ws, {"MY_VAR": "test"})
- assert "MY_VAR" in env
- assert env["MY_VAR"] == "test"
-
- def test_sets_workspace_dirs(self, tmp_path):
- from trpc_agent_sdk.skills.tools._skill_exec import _build_exec_env
- ws = MagicMock()
- ws.path = str(tmp_path)
- env = _build_exec_env(ws, {})
- assert any("skills" in v.lower() or "SKILLS" in k for k, v in env.items())
-
-
-# ---------------------------------------------------------------------------
-# GC expired sessions
-# ---------------------------------------------------------------------------
-
-class TestGcExpired:
- async def test_gc_expired_sessions(self):
- import time
- run_tool = MagicMock()
- run_tool._repository = MagicMock()
- run_tool._timeout = 300.0
- tool = SkillExecTool(run_tool, session_ttl=1.0)
-
- mock_session = MagicMock()
- mock_session.exited_at = time.time() - 100
- mock_session.reader_task = None
- mock_session.master_fd = None
- mock_session.proc.stdin = None
- tool._sessions["expired_session"] = mock_session
-
- # gc is called within put_session
- new_session = MagicMock()
- new_session.exited_at = None
- await tool._put_session("new", new_session)
-
- assert "expired_session" not in tool._sessions
- assert "new" in tool._sessions
-
- async def test_gc_ttl_zero_skips(self):
- run_tool = MagicMock()
- run_tool._repository = MagicMock()
- run_tool._timeout = 300.0
- tool = SkillExecTool(run_tool, session_ttl=0)
-
- mock_session = MagicMock()
- mock_session.exited_at = 1.0
- tool._sessions["s1"] = mock_session
-
- new_session = MagicMock()
- new_session.exited_at = None
- await tool._put_session("s2", new_session)
-
- assert "s1" in tool._sessions
-
-
-# ---------------------------------------------------------------------------
-# _write_stdin helper
-# ---------------------------------------------------------------------------
-
-class TestWriteStdin:
- async def test_write_pipe(self):
- from trpc_agent_sdk.skills.tools._skill_exec import _write_stdin, _ExecSession
- proc = MagicMock()
- proc.returncode = None
- proc.stdin = MagicMock()
- proc.stdin.write = MagicMock()
- proc.stdin.drain = AsyncMock()
- ws = MagicMock()
- in_data = ExecInput(skill="test", command="cat")
- sess = _ExecSession(proc=proc, ws=ws, in_data=in_data)
-
- await _write_stdin(sess, "hello", submit=True)
- proc.stdin.write.assert_called_once()
- written = proc.stdin.write.call_args[0][0]
- assert b"hello\n" == written
-
- async def test_write_pipe_no_submit(self):
- from trpc_agent_sdk.skills.tools._skill_exec import _write_stdin, _ExecSession
- proc = MagicMock()
- proc.returncode = None
- proc.stdin = MagicMock()
- proc.stdin.write = MagicMock()
- proc.stdin.drain = AsyncMock()
- ws = MagicMock()
- in_data = ExecInput(skill="test", command="cat")
- sess = _ExecSession(proc=proc, ws=ws, in_data=in_data)
-
- await _write_stdin(sess, "hello", submit=False)
- written = proc.stdin.write.call_args[0][0]
- assert b"hello" == written
-
- async def test_write_pipe_error_handled(self):
- from trpc_agent_sdk.skills.tools._skill_exec import _write_stdin, _ExecSession
- proc = MagicMock()
- proc.returncode = None
- proc.stdin = MagicMock()
- proc.stdin.write = MagicMock(side_effect=RuntimeError("broken"))
- proc.stdin.drain = AsyncMock()
- ws = MagicMock()
- in_data = ExecInput(skill="test", command="cat")
- sess = _ExecSession(proc=proc, ws=ws, in_data=in_data)
-
- await _write_stdin(sess, "test", submit=False)
-
- async def test_write_pty(self):
- from trpc_agent_sdk.skills.tools._skill_exec import _write_stdin, _ExecSession
- proc = MagicMock()
- proc.returncode = None
- proc.stdin = None
- ws = MagicMock()
- in_data = ExecInput(skill="test", command="cat")
- sess = _ExecSession(proc=proc, ws=ws, in_data=in_data)
- sess.master_fd = 42
-
- with patch("os.write") as mock_write:
- await _write_stdin(sess, "hello", submit=True)
- mock_write.assert_called_once_with(42, b"hello\n")
-
- async def test_write_pty_error(self):
- from trpc_agent_sdk.skills.tools._skill_exec import _write_stdin, _ExecSession
- proc = MagicMock()
- proc.returncode = None
- proc.stdin = None
- ws = MagicMock()
- in_data = ExecInput(skill="test", command="cat")
- sess = _ExecSession(proc=proc, ws=ws, in_data=in_data)
- sess.master_fd = 42
-
- with patch("os.write", side_effect=OSError("broken")):
- await _write_stdin(sess, "hello", submit=False)
-
- async def test_write_no_stdin_no_pty(self):
- from trpc_agent_sdk.skills.tools._skill_exec import _write_stdin, _ExecSession
- proc = MagicMock()
- proc.returncode = None
- proc.stdin = None
- ws = MagicMock()
- in_data = ExecInput(skill="test", command="cat")
- sess = _ExecSession(proc=proc, ws=ws, in_data=in_data)
- await _write_stdin(sess, "hello", submit=True)
-
-
-# ---------------------------------------------------------------------------
-# _collect_final_result
-# ---------------------------------------------------------------------------
-
-class TestCollectFinalResult:
- async def test_already_finalized(self):
- from trpc_agent_sdk.skills.tools._skill_exec import _collect_final_result, _ExecSession
- from trpc_agent_sdk.skills.tools._skill_run import SkillRunOutput
- proc = MagicMock()
- proc.returncode = 0
- ws = MagicMock()
- in_data = ExecInput(skill="test", command="echo")
- sess = _ExecSession(proc=proc, ws=ws, in_data=in_data)
- sess.finalized = True
- expected = SkillRunOutput(stdout="cached")
- sess.final_result = expected
-
- ctx = MagicMock()
- run_tool = MagicMock()
- result = await _collect_final_result(ctx, sess, run_tool)
- assert result is expected
-
- async def test_collect_with_outputs(self):
- from trpc_agent_sdk.skills.tools._skill_exec import _collect_final_result, _ExecSession
- from trpc_agent_sdk.skills.tools._skill_run import SkillRunFile
- proc = MagicMock()
- proc.returncode = 0
- ws = MagicMock()
- in_data = ExecInput(skill="test", command="echo", output_files=["out/*.txt"])
- sess = _ExecSession(proc=proc, ws=ws, in_data=in_data)
- sess.exit_code = 0
- await sess.append_output("test output")
-
- ctx = MagicMock()
- ctx.artifact_service = None
- run_tool = MagicMock()
- run_tool._prepare_outputs = AsyncMock(return_value=([], None))
- run_tool._attach_artifacts_if_requested = AsyncMock()
- run_tool._merge_manifest_artifact_refs = MagicMock()
-
- result = await _collect_final_result(ctx, sess, run_tool)
- assert result is not None
- assert result.stdout == "test output"
- assert sess.finalized is True
-
- async def test_collect_prepare_outputs_error(self):
- from trpc_agent_sdk.skills.tools._skill_exec import _collect_final_result, _ExecSession
- proc = MagicMock()
- proc.returncode = 0
- ws = MagicMock()
- in_data = ExecInput(skill="test", command="echo")
- sess = _ExecSession(proc=proc, ws=ws, in_data=in_data)
- sess.exit_code = 0
- await sess.append_output("output")
-
- ctx = MagicMock()
- ctx.artifact_service = None
- run_tool = MagicMock()
- run_tool._prepare_outputs = AsyncMock(side_effect=RuntimeError("fail"))
- run_tool._attach_artifacts_if_requested = AsyncMock()
- run_tool._merge_manifest_artifact_refs = MagicMock()
-
- result = await _collect_final_result(ctx, sess, run_tool)
- assert result is not None
- assert sess.finalized is True
-
-
-# ---------------------------------------------------------------------------
-# _ExecSession — yield_output more tests
-# ---------------------------------------------------------------------------
-
-class TestExecSessionYieldOutput:
- async def test_yield_with_poll_lines_limit(self):
- from trpc_agent_sdk.skills.tools._skill_exec import _ExecSession
- proc = MagicMock()
- proc.returncode = 0
- ws = MagicMock()
- in_data = ExecInput(skill="test", command="echo")
- sess = _ExecSession(proc=proc, ws=ws, in_data=in_data)
-
- long_output = "\n".join([f"line {i}" for i in range(100)])
- await sess.append_output(long_output)
-
- status, chunk, offset, next_offset = await sess.yield_output(10, 5)
- lines = chunk.strip().split("\n")
- assert len(lines) <= 5
-
- async def test_yield_incremental(self):
- from trpc_agent_sdk.skills.tools._skill_exec import _ExecSession
- proc = MagicMock()
- proc.returncode = 0
- ws = MagicMock()
- in_data = ExecInput(skill="test", command="echo")
- sess = _ExecSession(proc=proc, ws=ws, in_data=in_data)
-
- await sess.append_output("first")
- _, chunk1, _, _ = await sess.yield_output(10, 0)
- assert "first" in chunk1
-
- await sess.append_output("second")
- _, chunk2, _, _ = await sess.yield_output(10, 0)
- assert "second" in chunk2
- assert "first" not in chunk2
-
-
-# ---------------------------------------------------------------------------
-# _read_pipe
-# ---------------------------------------------------------------------------
-
-class TestReadPipe:
- async def test_read_pipe_empty(self):
- from trpc_agent_sdk.skills.tools._skill_exec import _read_pipe, _ExecSession
- proc = MagicMock()
- proc.returncode = 0
- proc.wait = AsyncMock()
- ws = MagicMock()
- in_data = ExecInput(skill="test", command="echo")
- sess = _ExecSession(proc=proc, ws=ws, in_data=in_data)
-
- stream = AsyncMock()
- stream.read = AsyncMock(return_value=b"")
-
- await _read_pipe(sess, stream)
-
- async def test_read_pipe_with_data(self):
- from trpc_agent_sdk.skills.tools._skill_exec import _read_pipe, _ExecSession
- proc = MagicMock()
- proc.returncode = 0
- proc.wait = AsyncMock()
- ws = MagicMock()
- in_data = ExecInput(skill="test", command="echo")
- sess = _ExecSession(proc=proc, ws=ws, in_data=in_data)
-
- stream = AsyncMock()
- stream.read = AsyncMock(side_effect=[b"hello", b""])
-
- await _read_pipe(sess, stream)
- total = await sess.total_output()
- assert "hello" in total
-
-
-# ---------------------------------------------------------------------------
-# KillSessionTool._run_async_impl
-# ---------------------------------------------------------------------------
-
-class TestKillSessionToolRun:
- def _make_exec_tool(self):
- run_tool = MagicMock()
- run_tool._repository = MagicMock()
- run_tool._timeout = 300.0
- return SkillExecTool(run_tool)
-
- async def test_kill_exited_session(self):
- exec_tool = self._make_exec_tool()
- kill_tool = KillSessionTool(exec_tool)
-
- mock_session = MagicMock()
- mock_session.proc.returncode = 0
- mock_session.exited_at = None
- mock_session.reader_task = None
- mock_session.master_fd = None
- mock_session.proc.stdin = None
-
- await exec_tool._put_session("s1", mock_session)
-
- ctx = MagicMock()
- result = await kill_tool._run_async_impl(
- tool_context=ctx,
- args={"session_id": "s1"},
- )
- assert result["ok"] is True
- assert result["status"] == "exited"
-
- async def test_kill_running_session(self):
- exec_tool = self._make_exec_tool()
- kill_tool = KillSessionTool(exec_tool)
-
- mock_session = MagicMock()
- mock_session.proc.returncode = None
- mock_session.proc.kill = MagicMock()
- mock_session.proc.wait = AsyncMock()
- mock_session.exited_at = None
- mock_session.reader_task = None
- mock_session.master_fd = None
- mock_session.proc.stdin = None
-
- await exec_tool._put_session("s1", mock_session)
-
- ctx = MagicMock()
- result = await kill_tool._run_async_impl(
- tool_context=ctx,
- args={"session_id": "s1"},
- )
- assert result["ok"] is True
- assert result["status"] == "killed"
- mock_session.proc.kill.assert_called_once()
-
- async def test_kill_invalid_args_raises(self):
- exec_tool = self._make_exec_tool()
- kill_tool = KillSessionTool(exec_tool)
- ctx = MagicMock()
- with pytest.raises(ValueError, match="Invalid"):
- await kill_tool._run_async_impl(tool_context=ctx, args={})
-
- async def test_kill_unknown_session_raises(self):
- exec_tool = self._make_exec_tool()
- kill_tool = KillSessionTool(exec_tool)
- ctx = MagicMock()
- with pytest.raises(ValueError, match="unknown session_id"):
- await kill_tool._run_async_impl(
- tool_context=ctx,
- args={"session_id": "nonexistent"},
- )
-
-
-# ---------------------------------------------------------------------------
-# PollSessionTool._run_async_impl
-# ---------------------------------------------------------------------------
-
-class TestPollSessionToolRun:
- async def test_poll_running_session(self):
- run_tool = MagicMock()
- run_tool._repository = MagicMock()
- run_tool._timeout = 300.0
- exec_tool = SkillExecTool(run_tool)
- poll_tool = PollSessionTool(exec_tool)
-
- from trpc_agent_sdk.skills.tools._skill_exec import _ExecSession
- proc = MagicMock()
- proc.returncode = 0
- ws = MagicMock()
- in_data = ExecInput(skill="test", command="echo")
- sess = _ExecSession(proc=proc, ws=ws, in_data=in_data)
- await sess.append_output("poll output")
-
- await exec_tool._put_session("s1", sess)
-
- ctx = MagicMock()
- ctx.artifact_service = None
- result = await poll_tool._run_async_impl(
- tool_context=ctx,
- args={"session_id": "s1", "yield_ms": 10},
- )
- assert result["status"] == "exited"
- assert "poll output" in result["output"]
-
- async def test_poll_invalid_args_raises(self):
- run_tool = MagicMock()
- run_tool._repository = MagicMock()
- run_tool._timeout = 300.0
- exec_tool = SkillExecTool(run_tool)
- poll_tool = PollSessionTool(exec_tool)
- ctx = MagicMock()
- with pytest.raises(ValueError, match="Invalid"):
- await poll_tool._run_async_impl(tool_context=ctx, args={})
-
-
-# ---------------------------------------------------------------------------
-# WriteStdinTool._run_async_impl
-# ---------------------------------------------------------------------------
-
-class TestWriteStdinToolRun:
- async def test_write_and_poll(self):
- run_tool = MagicMock()
- run_tool._repository = MagicMock()
- run_tool._timeout = 300.0
- exec_tool = SkillExecTool(run_tool)
- write_tool = WriteStdinTool(exec_tool)
-
- from trpc_agent_sdk.skills.tools._skill_exec import _ExecSession
- proc = MagicMock()
- proc.returncode = 0
- proc.stdin = MagicMock()
- proc.stdin.write = MagicMock()
- proc.stdin.drain = AsyncMock()
- ws = MagicMock()
- in_data = ExecInput(skill="test", command="cat")
- sess = _ExecSession(proc=proc, ws=ws, in_data=in_data)
- await sess.append_output("response")
-
- await exec_tool._put_session("s1", sess)
-
- ctx = MagicMock()
- ctx.artifact_service = None
- result = await write_tool._run_async_impl(
- tool_context=ctx,
- args={"session_id": "s1", "chars": "input", "submit": True, "yield_ms": 10},
- )
- assert result["status"] == "exited"
-
- async def test_write_empty_chars(self):
- run_tool = MagicMock()
- run_tool._repository = MagicMock()
- run_tool._timeout = 300.0
- exec_tool = SkillExecTool(run_tool)
- write_tool = WriteStdinTool(exec_tool)
-
- from trpc_agent_sdk.skills.tools._skill_exec import _ExecSession
- proc = MagicMock()
- proc.returncode = 0
- ws = MagicMock()
- in_data = ExecInput(skill="test", command="cat")
- sess = _ExecSession(proc=proc, ws=ws, in_data=in_data)
-
- await exec_tool._put_session("s1", sess)
-
- ctx = MagicMock()
- ctx.artifact_service = None
- result = await write_tool._run_async_impl(
- tool_context=ctx,
- args={"session_id": "s1", "yield_ms": 10},
- )
- assert "status" in result
-
- async def test_write_invalid_args_raises(self):
- run_tool = MagicMock()
- run_tool._repository = MagicMock()
- run_tool._timeout = 300.0
- exec_tool = SkillExecTool(run_tool)
- write_tool = WriteStdinTool(exec_tool)
- ctx = MagicMock()
- with pytest.raises(ValueError, match="Invalid"):
- await write_tool._run_async_impl(tool_context=ctx, args={})
+ sess.proc.close = AsyncMock()
+ await _close_session(sess)
+ sess.proc.close.assert_awaited_once()
diff --git a/tests/skills/tools/test_skill_list_docs.py b/tests/skills/tools/test_skill_list_docs.py
index 33e6a76..26999d9 100644
--- a/tests/skills/tools/test_skill_list_docs.py
+++ b/tests/skills/tools/test_skill_list_docs.py
@@ -42,8 +42,7 @@ def test_returns_docs(self):
ctx = _make_ctx(repository=repo)
result = skill_list_docs(ctx, "test")
- assert result["docs"] == ["guide.md", "api.md"]
- assert result["body_loaded"] is True
+ assert result == ["guide.md", "api.md"]
def test_no_body(self):
skill = Skill(summary=SkillSummary(name="test"), body="")
@@ -52,17 +51,16 @@ def test_no_body(self):
ctx = _make_ctx(repository=repo)
result = skill_list_docs(ctx, "test")
- assert result["body_loaded"] is False
+ assert result == []
def test_skill_not_found(self):
repo = MagicMock()
- repo.get = MagicMock(return_value=None)
+ repo.get = MagicMock(side_effect=ValueError("not found"))
ctx = _make_ctx(repository=repo)
- result = skill_list_docs(ctx, "nonexistent")
- assert result == {"docs": [], "body_loaded": False}
+ with pytest.raises(ValueError, match="unknown skill"):
+ skill_list_docs(ctx, "nonexistent")
def test_no_repository_raises(self):
ctx = _make_ctx(repository=None)
- with pytest.raises(ValueError, match="repository not found"):
- skill_list_docs(ctx, "test")
+ assert skill_list_docs(ctx, "test") == []
diff --git a/tests/skills/tools/test_skill_list_tool.py b/tests/skills/tools/test_skill_list_tool.py
index 38769d4..220524a 100644
--- a/tests/skills/tools/test_skill_list_tool.py
+++ b/tests/skills/tools/test_skill_list_tool.py
@@ -3,13 +3,7 @@
# Copyright (C) 2026 Tencent. All rights reserved.
#
# tRPC-Agent-Python is licensed under Apache-2.0.
-"""Unit tests for trpc_agent_sdk.skills.tools._skill_list_tool.
-
-Covers:
-- _extract_shell_examples_from_skill_body: Command section parsing
-- skill_list_tools: returns tools and command examples
-- skill_list_tools: handles missing skill / repository
-"""
+"""Unit tests for trpc_agent_sdk.skills.tools._skill_list_tool."""
from __future__ import annotations
@@ -19,79 +13,10 @@
from trpc_agent_sdk.skills._types import Skill, SkillSummary
from trpc_agent_sdk.skills.tools._skill_list_tool import (
- _extract_shell_examples_from_skill_body,
skill_list_tools,
)
-# ---------------------------------------------------------------------------
-# _extract_shell_examples_from_skill_body
-# ---------------------------------------------------------------------------
-
-class TestExtractShellExamples:
- def test_empty_body(self):
- assert _extract_shell_examples_from_skill_body("") == []
-
- def test_command_section(self):
- body = "Command:\n python scripts/run.py --input data.csv\n\nOverview"
- result = _extract_shell_examples_from_skill_body(body)
- assert len(result) >= 1
- assert "python scripts/run.py" in result[0]
-
- def test_limit(self):
- body = ""
- for i in range(10):
- body += f"Command:\n cmd_{i} --arg\n\n"
- result = _extract_shell_examples_from_skill_body(body, limit=3)
- assert len(result) <= 3
-
- def test_stops_at_section_break(self):
- body = "Command:\n python run.py\n\nOutput files\nMore content"
- result = _extract_shell_examples_from_skill_body(body)
- assert len(result) == 1
-
- def test_multiline_command(self):
- body = "Command:\n python scripts/long.py \\\n --arg1 val1 \\\n --arg2 val2\n\n"
- result = _extract_shell_examples_from_skill_body(body)
- assert len(result) >= 1
-
- def test_no_command_section(self):
- body = "# Overview\nJust a description.\n"
- result = _extract_shell_examples_from_skill_body(body)
- assert result == []
-
- def test_deduplication(self):
- body = "Command:\n python run.py\n\nCommand:\n python run.py\n"
- result = _extract_shell_examples_from_skill_body(body)
- assert len(result) == 1
-
- def test_rejects_non_command_starting_chars(self):
- body = "Command:\n !not_a_command\n\n"
- result = _extract_shell_examples_from_skill_body(body)
- assert result == []
-
- def test_command_with_numbered_break(self):
- body = "Command:\n python run.py\n\n1) Next section\n"
- result = _extract_shell_examples_from_skill_body(body)
- assert len(result) == 1
-
- def test_skips_empty_lines_before_command(self):
- body = "Command:\n\n\n python run.py\n\nEnd"
- result = _extract_shell_examples_from_skill_body(body)
- assert len(result) >= 1
-
- def test_stops_at_tools_section(self):
- body = "Command:\n python run.py\n\ntools:\n- tool1"
- result = _extract_shell_examples_from_skill_body(body)
- assert len(result) == 1
-
- def test_whitespace_normalization(self):
- body = "Command:\n python run.py --arg val\n\n"
- result = _extract_shell_examples_from_skill_body(body)
- assert len(result) == 1
- assert " " not in result[0]
-
-
# ---------------------------------------------------------------------------
# skill_list_tools
# ---------------------------------------------------------------------------
@@ -103,7 +28,7 @@ def _make_ctx(repository=None):
class TestSkillListTools:
- def test_returns_tools_and_examples(self):
+ def test_returns_tools(self):
skill = Skill(
summary=SkillSummary(name="test"),
body="Command:\n python run.py\n\nOverview",
@@ -114,8 +39,7 @@ def test_returns_tools_and_examples(self):
ctx = _make_ctx(repository=repo)
result = skill_list_tools(ctx, "test")
- assert result["tools"] == ["get_weather", "get_data"]
- assert len(result["command_examples"]) >= 1
+ assert result["available_tools"] == ["get_weather", "get_data"]
def test_skill_not_found(self):
repo = MagicMock()
@@ -123,7 +47,7 @@ def test_skill_not_found(self):
ctx = _make_ctx(repository=repo)
result = skill_list_tools(ctx, "nonexistent")
- assert result == {"tools": [], "command_examples": []}
+ assert result == {"available_tools": []}
def test_no_repository_raises(self):
ctx = _make_ctx(repository=None)
@@ -137,4 +61,4 @@ def test_no_tools_or_examples(self):
ctx = _make_ctx(repository=repo)
result = skill_list_tools(ctx, "test")
- assert result["tools"] == []
+ assert result["available_tools"] == []
diff --git a/tests/skills/tools/test_skill_load.py b/tests/skills/tools/test_skill_load.py
index 3bb87f5..7c4987b 100644
--- a/tests/skills/tools/test_skill_load.py
+++ b/tests/skills/tools/test_skill_load.py
@@ -14,17 +14,22 @@
from __future__ import annotations
+import asyncio
import json
+from unittest.mock import AsyncMock
from unittest.mock import MagicMock
+from unittest.mock import patch
import pytest
+from trpc_agent_sdk.skills._common import docs_state_key
+from trpc_agent_sdk.skills._common import loaded_state_key
+from trpc_agent_sdk.skills._common import set_state_delta
+from trpc_agent_sdk.skills._common import tool_state_key
+from trpc_agent_sdk.skills._constants import SKILL_REPOSITORY_KEY
from trpc_agent_sdk.skills._types import Skill, SkillSummary
from trpc_agent_sdk.skills.tools._skill_load import (
- _set_state_delta,
- _set_state_delta_for_skill_load,
- _set_state_delta_for_skill_tools,
- skill_load,
+ SkillLoadTool,
)
@@ -32,9 +37,44 @@ def _make_ctx(repository=None):
ctx = MagicMock()
ctx.actions.state_delta = {}
ctx.agent_context.get_metadata = MagicMock(return_value=repository)
+ ctx.agent_name = ""
return ctx
+def _set_state_delta_for_skill_load(ctx, skill_name: str, docs: list[str], include_all_docs: bool = False):
+ set_state_delta(ctx, loaded_state_key(ctx, skill_name), True)
+ set_state_delta(
+ ctx,
+ docs_state_key(ctx, skill_name),
+ "*" if include_all_docs else json.dumps(docs or []),
+ )
+
+
+def _set_state_delta_for_skill_tools(ctx, skill_name: str, tools: list[str]):
+ set_state_delta(ctx, tool_state_key(ctx, skill_name), json.dumps(tools or []))
+
+
+def _set_state_delta(ctx, key: str, value: str):
+ set_state_delta(ctx, key, value)
+
+
+def skill_load(ctx, skill_name: str, docs: list[str] | None = None, include_all_docs: bool = False) -> str:
+ repository = ctx.agent_context.get_metadata(SKILL_REPOSITORY_KEY)
+ if repository is None:
+ raise ValueError("repository not found")
+ tool = SkillLoadTool(repository=repository)
+ with patch.object(SkillLoadTool, "_ensure_staged", new=AsyncMock(return_value=None)):
+ return asyncio.run(
+ tool._run_async_impl(
+ tool_context=ctx,
+ args={
+ "skill_name": skill_name,
+ "docs": docs or [],
+ "include_all_docs": include_all_docs,
+ },
+ ))
+
+
# ---------------------------------------------------------------------------
# _set_state_delta
# ---------------------------------------------------------------------------
@@ -55,21 +95,24 @@ class TestSetStateDeltaForSkillLoad:
def test_sets_loaded_flag(self):
ctx = MagicMock()
ctx.actions.state_delta = {}
+ ctx.agent_name = ""
_set_state_delta_for_skill_load(ctx, "test-skill", [])
- assert ctx.actions.state_delta["temp:skill:loaded:test-skill"] is True
+ assert ctx.actions.state_delta[loaded_state_key(ctx, "test-skill")] is True
def test_sets_docs_as_json(self):
ctx = MagicMock()
ctx.actions.state_delta = {}
+ ctx.agent_name = ""
_set_state_delta_for_skill_load(ctx, "test-skill", ["doc1.md", "doc2.md"])
- docs_value = ctx.actions.state_delta["temp:skill:docs:test-skill"]
+ docs_value = ctx.actions.state_delta[docs_state_key(ctx, "test-skill")]
assert json.loads(docs_value) == ["doc1.md", "doc2.md"]
def test_include_all_docs_sets_star(self):
ctx = MagicMock()
ctx.actions.state_delta = {}
+ ctx.agent_name = ""
_set_state_delta_for_skill_load(ctx, "test-skill", [], include_all_docs=True)
- assert ctx.actions.state_delta["temp:skill:docs:test-skill"] == "*"
+ assert ctx.actions.state_delta[docs_state_key(ctx, "test-skill")] == "*"
# ---------------------------------------------------------------------------
@@ -80,15 +123,17 @@ class TestSetStateDeltaForSkillTools:
def test_sets_tools_as_json(self):
ctx = MagicMock()
ctx.actions.state_delta = {}
+ ctx.agent_name = ""
_set_state_delta_for_skill_tools(ctx, "test-skill", ["tool_a", "tool_b"])
- tools_value = ctx.actions.state_delta["temp:skill:tools:test-skill"]
+ tools_value = ctx.actions.state_delta[tool_state_key(ctx, "test-skill")]
assert json.loads(tools_value) == ["tool_a", "tool_b"]
def test_empty_tools(self):
ctx = MagicMock()
ctx.actions.state_delta = {}
+ ctx.agent_name = ""
_set_state_delta_for_skill_tools(ctx, "test-skill", [])
- tools_value = ctx.actions.state_delta["temp:skill:tools:test-skill"]
+ tools_value = ctx.actions.state_delta[tool_state_key(ctx, "test-skill")]
assert json.loads(tools_value) == []
@@ -105,15 +150,15 @@ def test_load_success(self):
result = skill_load(ctx, "test")
assert "loaded" in result
- assert ctx.actions.state_delta["temp:skill:loaded:test"] is True
+ assert ctx.actions.state_delta[loaded_state_key(ctx, "test")] is True
def test_load_not_found(self):
repo = MagicMock()
- repo.get = MagicMock(return_value=None)
+ repo.get = MagicMock(side_effect=ValueError("not found"))
ctx = _make_ctx(repository=repo)
- result = skill_load(ctx, "nonexistent")
- assert "not found" in result
+ with pytest.raises(ValueError, match="not found"):
+ skill_load(ctx, "nonexistent")
def test_load_no_repository_raises(self):
ctx = _make_ctx(repository=None)
@@ -131,7 +176,7 @@ def test_load_with_tools_sets_tools_state(self):
ctx = _make_ctx(repository=repo)
skill_load(ctx, "test")
- tools_key = "temp:skill:tools:test"
+ tools_key = tool_state_key(ctx, "test")
assert tools_key in ctx.actions.state_delta
assert json.loads(ctx.actions.state_delta[tools_key]) == ["get_weather", "get_data"]
@@ -142,7 +187,7 @@ def test_load_without_tools_does_not_set_tools_state(self):
ctx = _make_ctx(repository=repo)
skill_load(ctx, "test")
- assert "temp:skill:tools:test" not in ctx.actions.state_delta
+ assert tool_state_key(ctx, "test") not in ctx.actions.state_delta
def test_load_with_docs(self):
skill = Skill(summary=SkillSummary(name="test"), body="# Body")
@@ -151,7 +196,7 @@ def test_load_with_docs(self):
ctx = _make_ctx(repository=repo)
skill_load(ctx, "test", docs=["doc1.md"])
- docs_key = "temp:skill:docs:test"
+ docs_key = docs_state_key(ctx, "test")
assert json.loads(ctx.actions.state_delta[docs_key]) == ["doc1.md"]
def test_load_include_all_docs(self):
@@ -161,5 +206,5 @@ def test_load_include_all_docs(self):
ctx = _make_ctx(repository=repo)
skill_load(ctx, "test", include_all_docs=True)
- docs_key = "temp:skill:docs:test"
+ docs_key = docs_state_key(ctx, "test")
assert ctx.actions.state_delta[docs_key] == "*"
diff --git a/tests/skills/tools/test_skill_run.py b/tests/skills/tools/test_skill_run.py
index c7489c2..b1e43ae 100644
--- a/tests/skills/tools/test_skill_run.py
+++ b/tests/skills/tools/test_skill_run.py
@@ -3,1050 +3,122 @@
# Copyright (C) 2026 Tencent. All rights reserved.
#
# tRPC-Agent-Python is licensed under Apache-2.0.
-"""Unit tests for trpc_agent_sdk.skills.tools._skill_run.
-
-Covers:
-- Module-level helpers:
- _inline_json_schema_refs, _is_text_mime, _should_inline_file_content,
- _truncate_output, _workspace_ref, _filter_failed_empty_outputs,
- _select_primary_output, _split_command_line, _build_editor_wrapper_script
-- Pydantic models: SkillRunFile, SkillRunInput, SkillRunOutput, ArtifactInfo
-- SkillRunTool:
- _resolve_cwd, _build_command, _wrap_with_venv, _is_skill_loaded,
- _get_repository, _is_missing_command_result, _extract_command_path_candidates,
- _extract_shell_examples_from_skill_body, _with_missing_command_hint
-"""
-
-from __future__ import annotations
import os
-from pathlib import Path
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import MagicMock
import pytest
-
-from trpc_agent_sdk.skills.tools._skill_run import (
- ArtifactInfo,
- SkillRunFile,
- SkillRunInput,
- SkillRunOutput,
- SkillRunTool,
- _build_editor_wrapper_script,
- _filter_failed_empty_outputs,
- _inline_json_schema_refs,
- _is_text_mime,
- _select_primary_output,
- _should_inline_file_content,
- _split_command_line,
- _truncate_output,
- _workspace_ref,
-)
-
-
-# ---------------------------------------------------------------------------
-# _inline_json_schema_refs
-# ---------------------------------------------------------------------------
-
-class TestInlineJsonSchemaRefs:
- def test_no_refs(self):
- schema = {"type": "object", "properties": {"name": {"type": "string"}}}
- result = _inline_json_schema_refs(schema)
- assert result == schema
-
- def test_with_refs(self):
- schema = {
- "$defs": {"Foo": {"type": "object", "properties": {"x": {"type": "integer"}}}},
- "properties": {"foo": {"$ref": "#/$defs/Foo"}},
- }
- result = _inline_json_schema_refs(schema)
- assert "$defs" not in result
- assert result["properties"]["foo"]["type"] == "object"
-
- def test_nested_refs(self):
- schema = {
- "$defs": {"Bar": {"type": "string"}},
- "properties": {"items": {"type": "array", "items": {"$ref": "#/$defs/Bar"}}},
- }
- result = _inline_json_schema_refs(schema)
- assert result["properties"]["items"]["items"]["type"] == "string"
-
-
-# ---------------------------------------------------------------------------
-# _is_text_mime
-# ---------------------------------------------------------------------------
-
-class TestIsTextMime:
- def test_text_plain(self):
+from trpc_agent_sdk.skills._common import loaded_state_key
+from trpc_agent_sdk.skills.tools._common import inline_json_schema_refs
+from trpc_agent_sdk.skills.tools._skill_run import ArtifactInfo
+from trpc_agent_sdk.skills.tools._skill_run import SkillRunFile
+from trpc_agent_sdk.skills.tools._skill_run import SkillRunInput
+from trpc_agent_sdk.skills.tools._skill_run import SkillRunOutput
+from trpc_agent_sdk.skills.tools._skill_run import SkillRunTool
+from trpc_agent_sdk.skills.tools._skill_run import _build_editor_wrapper_script
+from trpc_agent_sdk.skills.tools._skill_run import _filter_failed_empty_outputs
+from trpc_agent_sdk.skills.tools._skill_run import _is_text_mime
+from trpc_agent_sdk.skills.tools._skill_run import _select_primary_output
+from trpc_agent_sdk.skills.tools._skill_run import _should_inline_file_content
+from trpc_agent_sdk.skills.tools._skill_run import _split_command_line
+from trpc_agent_sdk.skills.tools._skill_run import _truncate_output
+from trpc_agent_sdk.skills.tools._skill_run import _workspace_ref
+
+
+def _make_tool() -> SkillRunTool:
+ repo = MagicMock()
+ repo.workspace_runtime = MagicMock()
+ return SkillRunTool(repository=repo)
+
+
+class TestSchemaHelpers:
+ def test_inline_json_schema_refs(self):
+ schema = {"$defs": {"X": {"type": "string"}}, "properties": {"x": {"$ref": "#/$defs/X"}}}
+ out = inline_json_schema_refs(schema)
+ assert "$defs" not in out
+ assert out["properties"]["x"]["type"] == "string"
+
+
+class TestModuleHelpers:
+ def test_is_text_mime(self):
assert _is_text_mime("text/plain") is True
-
- def test_text_html(self):
- assert _is_text_mime("text/html") is True
-
- def test_application_json(self):
assert _is_text_mime("application/json") is True
-
- def test_application_yaml(self):
- assert _is_text_mime("application/yaml") is True
-
- def test_application_xml(self):
- assert _is_text_mime("application/xml") is True
-
- def test_image_png(self):
assert _is_text_mime("image/png") is False
- def test_empty_string_is_text(self):
- assert _is_text_mime("") is True
-
- def test_with_charset(self):
- assert _is_text_mime("application/json; charset=utf-8") is True
-
- def test_octet_stream(self):
- assert _is_text_mime("application/octet-stream") is False
-
-
-# ---------------------------------------------------------------------------
-# _should_inline_file_content
-# ---------------------------------------------------------------------------
-
-class TestShouldInlineFileContent:
- def test_text_file(self):
- from trpc_agent_sdk.code_executors import CodeFile
- f = CodeFile(name="test.txt", content="hello", mime_type="text/plain", size_bytes=5)
- assert _should_inline_file_content(f) is True
-
- def test_binary_file(self):
+ def test_should_inline_file_content(self):
from trpc_agent_sdk.code_executors import CodeFile
- f = CodeFile(name="test.png", content="data", mime_type="image/png", size_bytes=4)
- assert _should_inline_file_content(f) is False
-
- def test_null_bytes_rejected(self):
- from trpc_agent_sdk.code_executors import CodeFile
- f = CodeFile(name="test.txt", content="hello\x00world", mime_type="text/plain", size_bytes=11)
- assert _should_inline_file_content(f) is False
-
- def test_empty_content(self):
- from trpc_agent_sdk.code_executors import CodeFile
- f = CodeFile(name="empty.txt", content="", mime_type="text/plain", size_bytes=0)
+ f = CodeFile(name="a.txt", content="ok", mime_type="text/plain", size_bytes=2)
assert _should_inline_file_content(f) is True
-
-# ---------------------------------------------------------------------------
-# _truncate_output
-# ---------------------------------------------------------------------------
-
-class TestTruncateOutput:
- def test_short_string(self):
- s, truncated = _truncate_output("hello")
- assert s == "hello"
- assert truncated is False
-
- def test_long_string(self):
+ def test_truncate_output(self):
s, truncated = _truncate_output("x" * 20000)
assert truncated is True
assert len(s) <= 16 * 1024
- def test_exact_limit(self):
- s, truncated = _truncate_output("x" * (16 * 1024))
- assert truncated is False
-
+ def test_workspace_ref(self):
+ assert _workspace_ref("a.txt") == "workspace://a.txt"
-# ---------------------------------------------------------------------------
-# _workspace_ref
-# ---------------------------------------------------------------------------
+ def test_filter_failed_empty_outputs(self):
+ files = [SkillRunFile(name="a.txt", content="", size_bytes=0)]
+ kept, warns = _filter_failed_empty_outputs(1, False, files)
+ assert kept == []
+ assert warns
-class TestWorkspaceRef:
- def test_with_name(self):
- assert _workspace_ref("out/file.txt") == "workspace://out/file.txt"
-
- def test_empty_name(self):
- assert _workspace_ref("") == ""
-
-
-# ---------------------------------------------------------------------------
-# _filter_failed_empty_outputs
-# ---------------------------------------------------------------------------
-
-class TestFilterFailedEmptyOutputs:
- def test_success_no_filter(self):
- files = [SkillRunFile(name="a.txt", content="data", size_bytes=4)]
- result, warns = _filter_failed_empty_outputs(0, False, files)
- assert len(result) == 1
- assert warns == []
-
- def test_failure_removes_empty(self):
- files = [
- SkillRunFile(name="a.txt", content="data", size_bytes=4),
- SkillRunFile(name="empty.txt", content="", size_bytes=0),
- ]
- result, warns = _filter_failed_empty_outputs(1, False, files)
- assert len(result) == 1
- assert result[0].name == "a.txt"
- assert len(warns) == 1
-
- def test_failure_all_have_content(self):
- files = [SkillRunFile(name="a.txt", content="data", size_bytes=4)]
- result, warns = _filter_failed_empty_outputs(1, False, files)
- assert len(result) == 1
- assert warns == []
-
- def test_timeout_removes_empty(self):
- files = [SkillRunFile(name="e.txt", content="", size_bytes=0)]
- result, warns = _filter_failed_empty_outputs(0, True, files)
- assert len(result) == 0
- assert len(warns) == 1
-
-
-# ---------------------------------------------------------------------------
-# _select_primary_output
-# ---------------------------------------------------------------------------
-
-class TestSelectPrimaryOutput:
- def test_picks_smallest_text_file(self):
+ def test_select_primary_output(self):
files = [
- SkillRunFile(name="b.txt", content="bbb", mime_type="text/plain"),
- SkillRunFile(name="a.txt", content="aaa", mime_type="text/plain"),
+ SkillRunFile(name="b.txt", content="2", mime_type="text/plain"),
+ SkillRunFile(name="a.txt", content="1", mime_type="text/plain"),
]
- result = _select_primary_output(files)
- assert result.name == "a.txt"
-
- def test_skips_empty_content(self):
- files = [SkillRunFile(name="a.txt", content="", mime_type="text/plain")]
- result = _select_primary_output(files)
- assert result is None
-
- def test_skips_binary(self):
- files = [SkillRunFile(name="a.png", content="data", mime_type="image/png")]
- result = _select_primary_output(files)
- assert result is None
-
- def test_empty_files(self):
- assert _select_primary_output([]) is None
-
- def test_skips_too_large(self):
- content = "x" * (32 * 1024 + 1)
- files = [SkillRunFile(name="big.txt", content=content, mime_type="text/plain")]
- assert _select_primary_output(files) is None
+ best = _select_primary_output(files)
+ assert best is not None
+ assert best.name == "a.txt"
-
-# ---------------------------------------------------------------------------
-# _split_command_line
-# ---------------------------------------------------------------------------
-
-class TestSplitCommandLine:
- def test_simple_command(self):
+ def test_split_command_line(self):
assert _split_command_line("python run.py") == ["python", "run.py"]
+ with pytest.raises(ValueError):
+ _split_command_line("a | b")
- def test_quoted_args(self):
- result = _split_command_line("echo 'hello world'")
- assert result == ["echo", "hello world"]
-
- def test_double_quoted(self):
- result = _split_command_line('echo "hello world"')
- assert result == ["echo", "hello world"]
-
- def test_empty_raises(self):
- with pytest.raises(ValueError, match="empty"):
- _split_command_line("")
-
- def test_whitespace_only_raises(self):
- with pytest.raises(ValueError, match="empty"):
- _split_command_line(" ")
-
- def test_shell_metachar_rejected(self):
- with pytest.raises(ValueError, match="metacharacter"):
- _split_command_line("cmd1 | cmd2")
-
- def test_semicolon_rejected(self):
- with pytest.raises(ValueError, match="metacharacter"):
- _split_command_line("cmd1; cmd2")
-
- def test_redirect_rejected(self):
- with pytest.raises(ValueError, match="metacharacter"):
- _split_command_line("cmd > file")
-
- def test_unterminated_quote_raises(self):
- with pytest.raises(ValueError, match="unterminated"):
- _split_command_line("echo 'hello")
-
- def test_trailing_escape_raises(self):
- with pytest.raises(ValueError, match="trailing"):
- _split_command_line("echo hello\\")
-
- def test_escaped_char(self):
- result = _split_command_line("echo hello\\ world")
- assert result == ["echo", "hello world"]
-
- def test_newline_rejected(self):
- with pytest.raises(ValueError, match="metacharacter"):
- _split_command_line("cmd1\ncmd2")
-
-
-# ---------------------------------------------------------------------------
-# _build_editor_wrapper_script
-# ---------------------------------------------------------------------------
-
-class TestBuildEditorWrapperScript:
- def test_contains_shebang(self):
- script = _build_editor_wrapper_script("/tmp/content.txt")
+ def test_build_editor_wrapper_script(self):
+ script = _build_editor_wrapper_script("/tmp/file")
assert script.startswith("#!/bin/sh")
+ assert "/tmp/file" in script
- def test_contains_content_path(self):
- script = _build_editor_wrapper_script("/tmp/content.txt")
- assert "/tmp/content.txt" in script
-
-# ---------------------------------------------------------------------------
-# Pydantic models
-# ---------------------------------------------------------------------------
-
-class TestSkillRunModels:
- def test_skill_run_file_defaults(self):
- f = SkillRunFile()
- assert f.name == ""
- assert f.content == ""
- assert f.size_bytes == 0
-
- def test_skill_run_input_required_fields(self):
- inp = SkillRunInput(skill="test", command="python run.py")
- assert inp.skill == "test"
- assert inp.command == "python run.py"
- assert inp.env == {}
- assert inp.output_files == []
-
- def test_skill_run_output_defaults(self):
+class TestModels:
+ def test_run_models(self):
+ inp = SkillRunInput(skill="s", command="echo hi")
out = SkillRunOutput()
- assert out.stdout == ""
+ art = ArtifactInfo(name="a.txt", version=1)
+ assert inp.skill == "s"
assert out.exit_code == 0
- assert out.timed_out is False
- assert out.output_files == []
-
- def test_artifact_info(self):
- a = ArtifactInfo(name="test.txt", version=1)
- assert a.name == "test.txt"
- assert a.version == 1
-
-
-# ---------------------------------------------------------------------------
-# SkillRunTool — helpers
-# ---------------------------------------------------------------------------
-
-class TestSkillRunToolHelpers:
- def _make_run_tool(self, **kwargs):
- repo = MagicMock()
- repo.workspace_runtime = MagicMock()
- return SkillRunTool(repository=repo, **kwargs)
-
- def test_resolve_cwd_empty(self):
- tool = self._make_run_tool()
- assert tool._resolve_cwd("", "skills/test") == "skills/test"
-
- def test_resolve_cwd_relative(self):
- tool = self._make_run_tool()
- result = tool._resolve_cwd("sub/dir", "skills/test")
- assert result == os.path.join("skills/test", "sub/dir")
+ assert art.version == 1
- def test_resolve_cwd_absolute(self):
- tool = self._make_run_tool()
- assert tool._resolve_cwd("/abs/path", "skills/test") == "/abs/path"
- def test_resolve_cwd_skills_dir_env(self):
- tool = self._make_run_tool()
- result = tool._resolve_cwd("$SKILLS_DIR/test", "skills/test")
- assert "skills" in result
+class TestSkillRunToolBasics:
+ def test_resolve_cwd(self):
+ tool = _make_tool()
+ assert tool._resolve_cwd("", "skills/x") == "skills/x"
+ assert tool._resolve_cwd("sub", "skills/x") == os.path.join("skills/x", "sub")
- def test_build_command_no_restrictions(self):
- tool = self._make_run_tool()
- cmd, args = tool._build_command("python run.py", "/ws", "skills/test")
+ def test_build_command(self):
+ tool = _make_tool()
+ cmd, args = tool._build_command("python run.py", "/tmp/ws", "skills/x")
assert cmd == "bash"
assert "-c" in args
- def test_build_command_with_allowed_cmds(self):
- tool = self._make_run_tool(allowed_cmds=["python"])
- cmd, args = tool._build_command("python run.py", "/ws", "skills/test")
- assert cmd == "python"
- assert args == ["run.py"]
-
- def test_build_command_denied_cmd_raises(self):
- tool = self._make_run_tool(denied_cmds=["rm"])
- with pytest.raises(ValueError, match="denied"):
- tool._build_command("rm -rf /", "/ws", "skills/test")
-
- def test_build_command_not_in_allowed_raises(self):
- tool = self._make_run_tool(allowed_cmds=["python"])
- with pytest.raises(ValueError, match="not in allowed"):
- tool._build_command("bash script.sh", "/ws", "skills/test")
-
- def test_is_skill_loaded_true(self):
- tool = self._make_run_tool()
- ctx = MagicMock()
- ctx.session_state = {"temp:skill:loaded:test": True}
- assert tool._is_skill_loaded(ctx, "test") is True
-
- def test_is_skill_loaded_false(self):
- tool = self._make_run_tool()
- ctx = MagicMock()
- ctx.session_state = {}
- assert tool._is_skill_loaded(ctx, "test") is False
-
- def test_is_skill_loaded_exception_defaults_true(self):
- tool = self._make_run_tool()
- ctx = MagicMock()
- ctx.session_state = MagicMock()
- ctx.session_state.get = MagicMock(side_effect=RuntimeError("oops"))
- assert tool._is_skill_loaded(ctx, "test") is True
-
- def test_get_repository_from_tool(self):
+ def test_get_repository(self):
repo = MagicMock()
repo.workspace_runtime = MagicMock()
tool = SkillRunTool(repository=repo)
ctx = MagicMock()
assert tool._get_repository(ctx) is repo
- def test_is_missing_command_result_true(self):
- from trpc_agent_sdk.code_executors import WorkspaceRunResult
- ret = WorkspaceRunResult(
- stdout="", stderr="bash: foo: command not found",
- exit_code=127, duration=0, timed_out=False,
- )
- assert SkillRunTool._is_missing_command_result(ret) is True
-
- def test_is_missing_command_result_false(self):
- from trpc_agent_sdk.code_executors import WorkspaceRunResult
- ret = WorkspaceRunResult(
- stdout="", stderr="error", exit_code=1, duration=0, timed_out=False,
- )
- assert SkillRunTool._is_missing_command_result(ret) is False
-
-
-# ---------------------------------------------------------------------------
-# SkillRunTool — _extract_shell_examples_from_skill_body
-# ---------------------------------------------------------------------------
-
-class TestExtractShellExamplesFromBody:
- def test_fenced_code_block(self):
- body = "# Usage\n```\npython scripts/run.py --input data.csv\n```\n"
- result = SkillRunTool._extract_shell_examples_from_skill_body(body)
- assert any("python" in r for r in result)
-
- def test_command_section(self):
- body = "Command:\n python scripts/analyze.py\n\nOverview"
- result = SkillRunTool._extract_shell_examples_from_skill_body(body)
- assert len(result) >= 1
-
- def test_empty_body(self):
- assert SkillRunTool._extract_shell_examples_from_skill_body("") == []
-
- def test_limit(self):
- body = ""
- for i in range(10):
- body += f"```\ncmd_{i}\n```\n"
- result = SkillRunTool._extract_shell_examples_from_skill_body(body, limit=3)
- assert len(result) <= 3
-
- def test_skips_function_calls(self):
- body = "```\nmy_function(arg='value')\n```\n"
- result = SkillRunTool._extract_shell_examples_from_skill_body(body)
- assert not any("my_function" in r for r in result)
-
-
-# ---------------------------------------------------------------------------
-# SkillRunTool — _extract_command_path_candidates
-# ---------------------------------------------------------------------------
-
-class TestExtractCommandPathCandidates:
- def test_relative_path(self):
- result = SkillRunTool._extract_command_path_candidates("python scripts/run.py")
- assert "scripts/run.py" in result
-
- def test_no_path_like_tokens(self):
- result = SkillRunTool._extract_command_path_candidates("ls -la")
- assert result == []
-
- def test_absolute_path_excluded(self):
- result = SkillRunTool._extract_command_path_candidates("python /abs/path.py")
- assert result == []
-
- def test_flags_excluded(self):
- result = SkillRunTool._extract_command_path_candidates("python --version")
- assert result == []
-
- def test_script_extensions(self):
- result = SkillRunTool._extract_command_path_candidates("bash setup.sh")
- assert "setup.sh" in result
-
-
-# ---------------------------------------------------------------------------
-# SkillRunTool — _with_missing_command_hint
-# ---------------------------------------------------------------------------
-
-class TestWithMissingCommandHint:
- def test_adds_hint_for_missing_command(self):
- from trpc_agent_sdk.code_executors import WorkspaceRunResult
- ret = WorkspaceRunResult(
- stdout="", stderr="bash: nonexist: command not found",
- exit_code=127, duration=0, timed_out=False,
- )
- inp = SkillRunInput(skill="test", command="nonexist")
- updated = SkillRunTool._with_missing_command_hint(ret, inp)
- assert "hint" in updated.stderr.lower()
-
- def test_no_hint_for_success(self):
- from trpc_agent_sdk.code_executors import WorkspaceRunResult
- ret = WorkspaceRunResult(
- stdout="ok", stderr="", exit_code=0, duration=0, timed_out=False,
- )
- inp = SkillRunInput(skill="test", command="python run.py")
- updated = SkillRunTool._with_missing_command_hint(ret, inp)
- assert updated.stderr == ""
-
-
-# ---------------------------------------------------------------------------
-# SkillRunTool — skill_stager property
-# ---------------------------------------------------------------------------
-
-class TestSkillRunToolProperties:
- def test_skill_stager_property(self):
- repo = MagicMock()
- repo.workspace_runtime = MagicMock()
- tool = SkillRunTool(repository=repo)
- assert tool.skill_stager is not None
-
- def test_custom_stager(self):
- repo = MagicMock()
- repo.workspace_runtime = MagicMock()
- custom_stager = MagicMock()
- tool = SkillRunTool(repository=repo, skill_stager=custom_stager)
- assert tool.skill_stager is custom_stager
-
- def test_declaration(self):
- repo = MagicMock()
- repo.workspace_runtime = MagicMock()
- tool = SkillRunTool(repository=repo)
- decl = tool._get_declaration()
- assert decl.name == "skill_run"
-
- def test_declaration_with_allowed_cmds(self):
- repo = MagicMock()
- repo.workspace_runtime = MagicMock()
- tool = SkillRunTool(repository=repo, allowed_cmds=["python", "bash"])
- decl = tool._get_declaration()
- assert "python" in decl.description
-
- def test_declaration_with_denied_cmds(self):
- repo = MagicMock()
- repo.workspace_runtime = MagicMock()
- tool = SkillRunTool(repository=repo, denied_cmds=["rm"])
- decl = tool._get_declaration()
- assert "Restrictions" in decl.description
-
-
-# ---------------------------------------------------------------------------
-# SkillRunTool — _wrap_with_venv
-# ---------------------------------------------------------------------------
-
-class TestWrapWithVenv:
- def _make_tool(self):
- repo = MagicMock()
- repo.workspace_runtime = MagicMock()
- return SkillRunTool(repository=repo)
-
- def test_wraps_command(self):
- tool = self._make_tool()
- result = tool._wrap_with_venv("python run.py", "/ws", "skills/test")
- assert "VIRTUAL_ENV" in result
- assert "python run.py" in result
- assert ".venv" in result
-
- def test_non_skills_cwd(self):
- tool = self._make_tool()
- result = tool._wrap_with_venv("echo hi", "/ws", "work/custom")
- assert "echo hi" in result
-
-
-# ---------------------------------------------------------------------------
-# SkillRunTool — _with_skill_doc_command_hint
-# ---------------------------------------------------------------------------
-
-class TestWithSkillDocCommandHint:
- def _make_tool(self):
- repo = MagicMock()
- repo.workspace_runtime = MagicMock()
- return SkillRunTool(repository=repo)
-
- def test_adds_hint_for_missing_command(self):
- from trpc_agent_sdk.code_executors import WorkspaceRunResult
- tool = self._make_tool()
- ret = WorkspaceRunResult(
- stdout="", stderr="bash: nonexist: command not found",
- exit_code=127, duration=0, timed_out=False,
- )
- skill = MagicMock()
- skill.body = "```\npython scripts/run.py\n```\n"
- skill.tools = ["get_weather"]
- repo = MagicMock()
- repo.get = MagicMock(return_value=skill)
- inp = SkillRunInput(skill="test", command="nonexist")
- updated = tool._with_skill_doc_command_hint(ret, repo, inp)
- assert "SKILL.md" in updated.stderr
-
- def test_no_hint_for_success(self):
- from trpc_agent_sdk.code_executors import WorkspaceRunResult
- tool = self._make_tool()
- ret = WorkspaceRunResult(
- stdout="ok", stderr="", exit_code=0, duration=0, timed_out=False,
- )
- repo = MagicMock()
- inp = SkillRunInput(skill="test", command="python run.py")
- updated = tool._with_skill_doc_command_hint(ret, repo, inp)
- assert updated.stderr == ""
-
- def test_repo_exception_returns_unchanged(self):
- from trpc_agent_sdk.code_executors import WorkspaceRunResult
- tool = self._make_tool()
- ret = WorkspaceRunResult(
- stdout="", stderr="bash: x: command not found",
- exit_code=127, duration=0, timed_out=False,
- )
- repo = MagicMock()
- repo.get = MagicMock(side_effect=RuntimeError("fail"))
- inp = SkillRunInput(skill="test", command="x")
- updated = tool._with_skill_doc_command_hint(ret, repo, inp)
- assert updated is ret
-
-
-# ---------------------------------------------------------------------------
-# SkillRunTool — _suggest_commands/_suggest_tools
-# ---------------------------------------------------------------------------
-
-class TestSuggestMethods:
- def _make_tool(self):
- repo = MagicMock()
- repo.workspace_runtime = MagicMock()
- return SkillRunTool(repository=repo)
-
- def test_suggest_commands_for_missing(self):
- from trpc_agent_sdk.code_executors import WorkspaceRunResult
- tool = self._make_tool()
- ret = WorkspaceRunResult(
- stdout="", stderr="command not found",
- exit_code=127, duration=0, timed_out=False,
- )
- skill = MagicMock()
- skill.body = "```\npython run.py\n```\n"
- repo = MagicMock()
- repo.get = MagicMock(return_value=skill)
- result = tool._suggest_commands_for_missing_command(ret, repo, "test")
- assert result is not None
-
- def test_suggest_commands_no_missing(self):
- from trpc_agent_sdk.code_executors import WorkspaceRunResult
- tool = self._make_tool()
- ret = WorkspaceRunResult(
- stdout="ok", stderr="", exit_code=0, duration=0, timed_out=False,
- )
- repo = MagicMock()
- result = tool._suggest_commands_for_missing_command(ret, repo, "test")
- assert result is None
-
- def test_suggest_tools_for_missing(self):
- from trpc_agent_sdk.code_executors import WorkspaceRunResult
- tool = self._make_tool()
- ret = WorkspaceRunResult(
- stdout="", stderr="command not found",
- exit_code=127, duration=0, timed_out=False,
- )
- skill = MagicMock()
- skill.tools = ["get_data"]
- repo = MagicMock()
- repo.get = MagicMock(return_value=skill)
- result = tool._suggest_tools_for_missing_command(ret, repo, "test")
- assert result == ["get_data"]
-
- def test_suggest_tools_no_tools(self):
- from trpc_agent_sdk.code_executors import WorkspaceRunResult
- tool = self._make_tool()
- ret = WorkspaceRunResult(
- stdout="", stderr="command not found",
- exit_code=127, duration=0, timed_out=False,
- )
- skill = MagicMock()
- skill.tools = []
- repo = MagicMock()
- repo.get = MagicMock(return_value=skill)
- result = tool._suggest_tools_for_missing_command(ret, repo, "test")
- assert result is None
-
- def test_suggest_commands_repo_error(self):
- from trpc_agent_sdk.code_executors import WorkspaceRunResult
- tool = self._make_tool()
- ret = WorkspaceRunResult(
- stdout="", stderr="command not found",
- exit_code=127, duration=0, timed_out=False,
- )
- repo = MagicMock()
- repo.get = MagicMock(side_effect=RuntimeError("fail"))
- result = tool._suggest_commands_for_missing_command(ret, repo, "test")
- assert result is None
-
-
-# ---------------------------------------------------------------------------
-# SkillRunTool — _precheck_inline_python_rewrite
-# ---------------------------------------------------------------------------
-
-class TestPrecheckInlinePythonRewrite:
- def test_not_blocked_when_disabled(self):
- repo = MagicMock()
- repo.workspace_runtime = MagicMock()
- tool = SkillRunTool(repository=repo, block_inline_python_rewrite=False)
- inp = SkillRunInput(skill="test", command="python -c 'print(1)'")
- assert tool._precheck_inline_python_rewrite(repo, inp) is None
-
- def test_not_python_c(self):
- repo = MagicMock()
- repo.workspace_runtime = MagicMock()
- tool = SkillRunTool(repository=repo, block_inline_python_rewrite=True)
- inp = SkillRunInput(skill="test", command="python run.py")
- assert tool._precheck_inline_python_rewrite(repo, inp) is None
-
- def test_blocked_with_script_examples(self):
- repo = MagicMock()
- repo.workspace_runtime = MagicMock()
- skill = MagicMock()
- skill.body = "```\npython3 scripts/analyze.py --input data\n```\n"
- repo.get = MagicMock(return_value=skill)
- tool = SkillRunTool(repository=repo, block_inline_python_rewrite=True)
- inp = SkillRunInput(skill="test", command="python -c 'import sys; print(sys.argv)'")
- result = tool._precheck_inline_python_rewrite(repo, inp)
- assert result is not None
- assert result.exit_code == 2
-
- def test_not_blocked_without_script_examples(self):
- repo = MagicMock()
- repo.workspace_runtime = MagicMock()
- skill = MagicMock()
- skill.body = "```\necho hello\n```\n"
- repo.get = MagicMock(return_value=skill)
- tool = SkillRunTool(repository=repo, block_inline_python_rewrite=True)
- inp = SkillRunInput(skill="test", command="python -c 'print(1)'")
- result = tool._precheck_inline_python_rewrite(repo, inp)
- assert result is None
-
- def test_repo_error_returns_none(self):
- repo = MagicMock()
- repo.workspace_runtime = MagicMock()
- repo.get = MagicMock(side_effect=RuntimeError("fail"))
- tool = SkillRunTool(repository=repo, block_inline_python_rewrite=True)
- inp = SkillRunInput(skill="test", command="python -c 'print(1)'")
- result = tool._precheck_inline_python_rewrite(repo, inp)
- assert result is None
-
-
-# ---------------------------------------------------------------------------
-# SkillRunTool — _list_entrypoint_suggestions
-# ---------------------------------------------------------------------------
-
-class TestListEntrypointSuggestions:
- def test_finds_scripts(self, tmp_path):
- scripts = tmp_path / "scripts"
- scripts.mkdir()
- (scripts / "run.py").write_text("#!/usr/bin/env python")
- result = SkillRunTool._list_entrypoint_suggestions(tmp_path)
- assert any("run.py" in r for r in result)
-
- def test_empty_dir(self, tmp_path):
- assert SkillRunTool._list_entrypoint_suggestions(tmp_path) == []
-
- def test_limit(self, tmp_path):
- scripts = tmp_path / "scripts"
- scripts.mkdir()
- for i in range(30):
- (scripts / f"script_{i}.py").write_text(f"# script {i}")
- result = SkillRunTool._list_entrypoint_suggestions(tmp_path, limit=5)
- assert len(result) <= 5
-
-
-# ---------------------------------------------------------------------------
-# SkillRunTool — _with_missing_entrypoint_hint
-# ---------------------------------------------------------------------------
-
-class TestWithMissingEntrypointHint:
- def test_adds_hint_for_missing_file(self, tmp_path):
- from trpc_agent_sdk.code_executors import WorkspaceRunResult
- scripts = tmp_path / "skills" / "test" / "scripts"
- scripts.mkdir(parents=True)
- (scripts / "real.py").write_text("# real")
- ws = MagicMock()
- ws.path = str(tmp_path)
- ret = WorkspaceRunResult(
- stdout="", stderr="No such file or directory",
- exit_code=1, duration=0, timed_out=False,
- )
- inp = SkillRunInput(skill="test", command="python scripts/missing.py")
- updated = SkillRunTool._with_missing_entrypoint_hint(ret, inp, ws, "skills/test")
- assert "hint" in updated.stderr.lower() or "entrypoint" in updated.stderr.lower()
-
- def test_no_hint_for_success(self):
- from trpc_agent_sdk.code_executors import WorkspaceRunResult
- ret = WorkspaceRunResult(
- stdout="ok", stderr="", exit_code=0, duration=0, timed_out=False,
- )
- ws = MagicMock()
- ws.path = "/tmp/ws"
- inp = SkillRunInput(skill="test", command="python run.py")
- updated = SkillRunTool._with_missing_entrypoint_hint(ret, inp, ws, "skills/test")
- assert updated is ret
-
-
-# ---------------------------------------------------------------------------
-# SkillRunTool — _merge_manifest_artifact_refs
-# ---------------------------------------------------------------------------
-
-class TestMergeManifestArtifactRefs:
- def _make_tool(self):
- repo = MagicMock()
- repo.workspace_runtime = MagicMock()
- return SkillRunTool(repository=repo)
-
- def test_none_manifest_noop(self):
- tool = self._make_tool()
- output = SkillRunOutput()
- tool._merge_manifest_artifact_refs(None, output)
- assert output.artifact_files == []
-
- def test_already_has_artifacts_noop(self):
- tool = self._make_tool()
- output = SkillRunOutput(artifact_files=[ArtifactInfo(name="x", version=1)])
- manifest = MagicMock()
- tool._merge_manifest_artifact_refs(manifest, output)
- assert len(output.artifact_files) == 1
-
- def test_merges_from_manifest(self):
- tool = self._make_tool()
- output = SkillRunOutput()
- manifest = MagicMock()
- fr = MagicMock()
- fr.saved_as = "artifact.txt"
- fr.version = 2
- manifest.files = [fr]
- tool._merge_manifest_artifact_refs(manifest, output)
- assert len(output.artifact_files) == 1
- assert output.artifact_files[0].name == "artifact.txt"
-
- def test_skips_unsaved(self):
- tool = self._make_tool()
- output = SkillRunOutput()
- manifest = MagicMock()
- fr = MagicMock()
- fr.saved_as = ""
- manifest.files = [fr]
- tool._merge_manifest_artifact_refs(manifest, output)
- assert output.artifact_files == []
-
-
-# ---------------------------------------------------------------------------
-# SkillRunTool — _to_run_file / _to_run_files
-# ---------------------------------------------------------------------------
-
-class TestToRunFile:
- def _make_tool(self):
- repo = MagicMock()
- repo.workspace_runtime = MagicMock()
- return SkillRunTool(repository=repo)
-
- def test_text_file(self):
- from trpc_agent_sdk.code_executors import CodeFile
- tool = self._make_tool()
- cf = CodeFile(name="test.txt", content="hello", mime_type="text/plain", size_bytes=5)
- rf = tool._to_run_file(cf)
- assert rf.name == "test.txt"
- assert rf.content == "hello"
- assert rf.ref == "workspace://test.txt"
-
- def test_binary_file_omits_content(self):
- from trpc_agent_sdk.code_executors import CodeFile
- tool = self._make_tool()
- cf = CodeFile(name="img.png", content="binary", mime_type="image/png", size_bytes=100)
- rf = tool._to_run_file(cf)
- assert rf.content == ""
-
- def test_to_run_files(self):
- from trpc_agent_sdk.code_executors import CodeFile
- tool = self._make_tool()
- files = [
- CodeFile(name="a.txt", content="a", mime_type="text/plain", size_bytes=1),
- CodeFile(name="b.txt", content="b", mime_type="text/plain", size_bytes=1),
- ]
- result = tool._to_run_files(files)
- assert len(result) == 2
-
-
-# ---------------------------------------------------------------------------
-# SkillRunTool — _prepare_editor_env
-# ---------------------------------------------------------------------------
-
-class TestPrepareEditorEnv:
- def _make_tool(self):
- repo = MagicMock()
- repo.workspace_runtime = MagicMock()
- return SkillRunTool(repository=repo)
-
- async def test_empty_editor_text_noop(self):
- tool = self._make_tool()
- ctx = MagicMock()
- ws = MagicMock()
- ws.path = "/tmp/ws"
- env = {}
- await tool._prepare_editor_env(ctx, ws, env, "")
- assert "EDITOR" not in env
-
- async def test_editor_env_conflict_raises(self):
- tool = self._make_tool()
- ctx = MagicMock()
- ws = MagicMock()
- ws.path = "/tmp/ws"
- env = {"EDITOR": "/usr/bin/vim"}
- with pytest.raises(ValueError, match="editor_text cannot be combined"):
- await tool._prepare_editor_env(ctx, ws, env, "some text")
-
- async def test_visual_env_conflict_raises(self):
- tool = self._make_tool()
- ctx = MagicMock()
- ws = MagicMock()
- ws.path = "/tmp/ws"
- env = {"VISUAL": "/usr/bin/vim"}
- with pytest.raises(ValueError, match="editor_text cannot be combined"):
- await tool._prepare_editor_env(ctx, ws, env, "some text")
-
- async def test_stages_editor_files(self, tmp_path):
- repo = MagicMock()
- fs = MagicMock()
- fs.put_files = AsyncMock()
- runtime = MagicMock()
- runtime.fs = MagicMock(return_value=fs)
- repo.workspace_runtime = runtime
- tool = SkillRunTool(repository=repo)
-
+ def test_is_skill_loaded(self):
+ tool = _make_tool()
ctx = MagicMock()
- ws = MagicMock()
- ws.path = str(tmp_path)
- env = {}
- await tool._prepare_editor_env(ctx, ws, env, "editor content")
- assert "EDITOR" in env
- assert "VISUAL" in env
-
- async def test_fallback_to_local_write(self, tmp_path):
- repo = MagicMock()
- fs = MagicMock()
- fs.put_files = AsyncMock(side_effect=RuntimeError("workspace unavailable"))
- runtime = MagicMock()
- runtime.fs = MagicMock(return_value=fs)
- repo.workspace_runtime = runtime
- tool = SkillRunTool(repository=repo)
-
- ctx = MagicMock()
- ws = MagicMock()
- ws.path = str(tmp_path)
- env = {}
- await tool._prepare_editor_env(ctx, ws, env, "editor content")
- assert "EDITOR" in env
-
-
-# ---------------------------------------------------------------------------
-# SkillRunTool — _attach_artifacts_if_requested
-# ---------------------------------------------------------------------------
-
-class TestAttachArtifacts:
- def _make_tool(self):
- repo = MagicMock()
- repo.workspace_runtime = MagicMock()
- return SkillRunTool(repository=repo)
-
- async def test_no_files_noop(self):
- tool = self._make_tool()
- ctx = MagicMock()
- ws = MagicMock()
- inp = SkillRunInput(skill="test", command="echo", save_as_artifacts=True)
- output = SkillRunOutput()
- await tool._attach_artifacts_if_requested(ctx, ws, inp, output, [])
-
- async def test_not_requested_noop(self):
- tool = self._make_tool()
- ctx = MagicMock()
- ws = MagicMock()
- inp = SkillRunInput(skill="test", command="echo", save_as_artifacts=False)
- output = SkillRunOutput()
- files = [SkillRunFile(name="a.txt", content="hello")]
- await tool._attach_artifacts_if_requested(ctx, ws, inp, output, files)
- assert output.artifact_files == []
-
- async def test_no_artifact_service_warns(self):
- tool = self._make_tool()
- ctx = MagicMock()
- ctx.artifact_service = None
- ws = MagicMock()
- inp = SkillRunInput(skill="test", command="echo", save_as_artifacts=True)
- output = SkillRunOutput()
- files = [SkillRunFile(name="a.txt", content="hello")]
- await tool._attach_artifacts_if_requested(ctx, ws, inp, output, files)
- assert any("not configured" in w for w in output.warnings)
-
- async def test_save_artifacts(self):
- tool = self._make_tool()
- ctx = MagicMock()
- ctx.artifact_service = MagicMock()
- ctx.save_artifact = AsyncMock(return_value=1)
- ws = MagicMock()
- inp = SkillRunInput(skill="test", command="echo", save_as_artifacts=True)
- output = SkillRunOutput()
- files = [SkillRunFile(name="a.txt", content="hello", mime_type="text/plain")]
- await tool._attach_artifacts_if_requested(ctx, ws, inp, output, files)
- assert len(output.artifact_files) == 1
- assert output.artifact_files[0].name == "a.txt"
-
- async def test_save_artifacts_with_prefix(self):
- tool = self._make_tool()
- ctx = MagicMock()
- ctx.artifact_service = MagicMock()
- ctx.save_artifact = AsyncMock(return_value=1)
- ws = MagicMock()
- inp = SkillRunInput(skill="test", command="echo", save_as_artifacts=True, artifact_prefix="out/")
- output = SkillRunOutput()
- files = [SkillRunFile(name="a.txt", content="hello", mime_type="text/plain")]
- await tool._attach_artifacts_if_requested(ctx, ws, inp, output, files)
- assert output.artifact_files[0].name == "out/a.txt"
-
-
-# ---------------------------------------------------------------------------
-# SkillRunTool — _prepare_outputs
-# ---------------------------------------------------------------------------
-
-class TestPrepareOutputs:
- def _make_tool(self):
- repo = MagicMock()
- fs = MagicMock()
- runtime = MagicMock()
- runtime.fs = MagicMock(return_value=fs)
- repo.workspace_runtime = runtime
- return SkillRunTool(repository=repo), fs
-
- async def test_no_outputs(self):
- tool, fs = self._make_tool()
- ctx = MagicMock()
- ws = MagicMock()
- inp = SkillRunInput(skill="test", command="echo")
- files, manifest = await tool._prepare_outputs(ctx, ws, inp)
- assert files == []
- assert manifest is None
-
- async def test_output_files_patterns(self):
- tool, fs = self._make_tool()
- from trpc_agent_sdk.code_executors import CodeFile
- fs.collect = AsyncMock(return_value=[
- CodeFile(name="out/a.txt", content="hello", mime_type="text/plain", size_bytes=5)
- ])
- ctx = MagicMock()
- ws = MagicMock()
- inp = SkillRunInput(skill="test", command="echo", output_files=["out/*.txt"])
- files, manifest = await tool._prepare_outputs(ctx, ws, inp)
- assert len(files) == 1
- assert manifest is None
+ ctx.agent_name = ""
+ ctx.actions = MagicMock()
+ key = loaded_state_key(ctx, "test")
+ ctx.actions.state_delta = {key: True}
+ ctx.session_state = {}
+ assert tool._is_skill_loaded(ctx, "test") is True
diff --git a/tests/skills/tools/test_skill_select_docs.py b/tests/skills/tools/test_skill_select_docs.py
index f52660b..1d8577e 100644
--- a/tests/skills/tools/test_skill_select_docs.py
+++ b/tests/skills/tools/test_skill_select_docs.py
@@ -17,16 +17,22 @@
import pytest
+from trpc_agent_sdk.skills._constants import SKILL_CONFIG_KEY
from trpc_agent_sdk.skills.tools._skill_select_docs import (
SkillSelectDocsResult,
skill_select_docs,
)
+from trpc_agent_sdk.skills._common import docs_state_key
+from trpc_agent_sdk.skills._skill_config import DEFAULT_SKILL_CONFIG
def _make_ctx(state_delta=None, session_state=None):
ctx = MagicMock()
ctx.actions.state_delta = state_delta or {}
ctx.session_state = session_state or {}
+ ctx.agent_name = ""
+ ctx.agent_context.get_metadata = MagicMock(
+ side_effect=lambda key, default=None: DEFAULT_SKILL_CONFIG if key == SKILL_CONFIG_KEY else default)
return ctx
@@ -76,18 +82,16 @@ def test_replace_mode(self):
assert result.selected_docs == ["doc1.md", "doc2.md"]
def test_add_mode(self):
- ctx = _make_ctx(session_state={
- "temp:skill:docs:test-skill": json.dumps(["existing.md"]),
- })
+ ctx = _make_ctx()
+ ctx.session_state = {docs_state_key(ctx, "test-skill"): json.dumps(["existing.md"])}
result = skill_select_docs(ctx, "test-skill", docs=["new.md"], mode="add")
assert result.mode == "add"
assert "existing.md" in result.selected_docs
assert "new.md" in result.selected_docs
def test_clear_mode(self):
- ctx = _make_ctx(session_state={
- "temp:skill:docs:test-skill": json.dumps(["some.md"]),
- })
+ ctx = _make_ctx()
+ ctx.session_state = {docs_state_key(ctx, "test-skill"): json.dumps(["some.md"])}
result = skill_select_docs(ctx, "test-skill", mode="clear")
assert result.mode == "clear"
assert result.selected_docs == []
@@ -100,5 +104,5 @@ def test_include_all_docs(self):
def test_updates_state_delta(self):
ctx = _make_ctx()
skill_select_docs(ctx, "test-skill", docs=["a.md"], mode="replace")
- key = "temp:skill:docs:test-skill"
+ key = docs_state_key(ctx, "test-skill")
assert key in ctx.actions.state_delta
diff --git a/tests/skills/tools/test_skill_select_tools.py b/tests/skills/tools/test_skill_select_tools.py
index 6bd2489..9b9fe98 100644
--- a/tests/skills/tools/test_skill_select_tools.py
+++ b/tests/skills/tools/test_skill_select_tools.py
@@ -17,16 +17,22 @@
import pytest
+from trpc_agent_sdk.skills._constants import SKILL_CONFIG_KEY
from trpc_agent_sdk.skills.tools._skill_select_tools import (
SkillSelectToolsResult,
skill_select_tools,
)
+from trpc_agent_sdk.skills._common import tool_state_key
+from trpc_agent_sdk.skills._skill_config import DEFAULT_SKILL_CONFIG
def _make_ctx(state_delta=None, session_state=None):
ctx = MagicMock()
ctx.actions.state_delta = state_delta or {}
ctx.session_state = session_state or {}
+ ctx.agent_name = ""
+ ctx.agent_context.get_metadata = MagicMock(
+ side_effect=lambda key, default=None: DEFAULT_SKILL_CONFIG if key == SKILL_CONFIG_KEY else default)
return ctx
@@ -74,18 +80,16 @@ def test_replace_mode(self):
assert result.selected_tools == ["tool_a", "tool_b"]
def test_add_mode(self):
- ctx = _make_ctx(session_state={
- "temp:skill:tools:test-skill": json.dumps(["existing_tool"]),
- })
+ ctx = _make_ctx()
+ ctx.session_state = {tool_state_key(ctx, "test-skill"): json.dumps(["existing_tool"])}
result = skill_select_tools(ctx, "test-skill", tools=["new_tool"], mode="add")
assert result.mode == "add"
assert "existing_tool" in result.selected_tools
assert "new_tool" in result.selected_tools
def test_clear_mode(self):
- ctx = _make_ctx(session_state={
- "temp:skill:tools:test-skill": json.dumps(["tool"]),
- })
+ ctx = _make_ctx()
+ ctx.session_state = {tool_state_key(ctx, "test-skill"): json.dumps(["tool"])}
result = skill_select_tools(ctx, "test-skill", mode="clear")
assert result.mode == "clear"
assert result.selected_tools == []
@@ -98,5 +102,5 @@ def test_include_all_tools(self):
def test_updates_state_delta(self):
ctx = _make_ctx()
skill_select_tools(ctx, "test-skill", tools=["t1"], mode="replace")
- key = "temp:skill:tools:test-skill"
+ key = tool_state_key(ctx, "test-skill")
assert key in ctx.actions.state_delta
diff --git a/tests/skills/tools/test_workspace_exec.py b/tests/skills/tools/test_workspace_exec.py
new file mode 100644
index 0000000..467e6a9
--- /dev/null
+++ b/tests/skills/tools/test_workspace_exec.py
@@ -0,0 +1,34 @@
+import pytest
+
+from trpc_agent_sdk.code_executors import PROGRAM_STATUS_RUNNING
+from trpc_agent_sdk.code_executors import ProgramPoll
+from trpc_agent_sdk.skills.tools._workspace_exec import _combine_output
+from trpc_agent_sdk.skills.tools._workspace_exec import _exec_timeout_seconds
+from trpc_agent_sdk.skills.tools._workspace_exec import _exec_yield_seconds
+from trpc_agent_sdk.skills.tools._workspace_exec import _normalize_cwd
+from trpc_agent_sdk.skills.tools._workspace_exec import _poll_output
+from trpc_agent_sdk.skills.tools._workspace_exec import _write_yield_seconds
+
+
+def test_normalize_cwd():
+ assert _normalize_cwd("") == "."
+ assert _normalize_cwd("work/demo") == "work/demo"
+ with pytest.raises(ValueError, match="within the workspace"):
+ _normalize_cwd("../demo")
+
+
+def test_timeout_and_yield_helpers():
+ assert _exec_timeout_seconds(0) > 0
+ assert _exec_timeout_seconds(3) == 3.0
+ assert _exec_yield_seconds(background=True, raw_ms=None) == 0.0
+ assert _exec_yield_seconds(background=False, raw_ms=100) == 0.1
+ assert _write_yield_seconds(None) > 0.0
+ assert _write_yield_seconds(-1) == 0.0
+
+
+def test_poll_output_and_combine_output():
+ poll = ProgramPoll(status=PROGRAM_STATUS_RUNNING, output="ok", offset=1, next_offset=2)
+ out = _poll_output("sid-1", poll)
+ assert out["session_id"] == "sid-1"
+ assert out["output"] == "ok"
+ assert _combine_output("a", "b") == "ab"
diff --git a/trpc_agent_sdk/agents/core/README.md b/trpc_agent_sdk/agents/core/README.md
new file mode 100644
index 0000000..4dbbaa2
--- /dev/null
+++ b/trpc_agent_sdk/agents/core/README.md
@@ -0,0 +1,249 @@
+# trpc_agent/agents/core 说明
+
+本文档面向读者解释 [`trpc_agent/agents/core`](./) 中与 Skills 相关的请求处理逻辑,重点覆盖三个处理器:
+
+- `SkillsRequestProcessor`([`_skill_processor.py`](./_skill_processor.py))
+- `WorkspaceExecRequestProcessor`([`_workspace_exec_processor.py`](./_workspace_exec_processor.py))
+- `SkillsToolResultRequestProcessor`([`_skills_tool_result_processor.py`](./_skills_tool_result_processor.py))
+
+文档目标是回答三个问题:
+
+1. 每个处理器解决什么问题
+2. 处理器在请求流水线中的位置与协作关系
+3. 如何在运行时判断“功能已生效”
+
+## 1. 请求流水线中的职责划分
+
+在 `RequestProcessor` 的技能相关路径中(简化):
+
+1. 组装基础 instruction
+2. 注入 tools
+3. `SkillsRequestProcessor`:注入 skills 总览与(可选)已加载内容
+4. `WorkspaceExecRequestProcessor`:注入 `workspace_exec` guidance
+5. 注入会话历史
+6. `SkillsToolResultRequestProcessor`:在 post-history 阶段做 tool result 物化
+7. 其他能力(planning/output schema 等)
+
+可理解为:
+
+- `SkillsRequestProcessor` 负责“技能上下文主编排”
+- `WorkspaceExecRequestProcessor` 负责“执行器工具选择引导”
+- `SkillsToolResultRequestProcessor` 负责“tool_result_mode 下的内容物化补强”
+
+## 2. SkillsRequestProcessor(核心入口)
+
+主入口:
+
+- `SkillsRequestProcessor.process_llm_request(ctx, request)`
+
+### 2.1 解决的问题
+
+模型在多技能场景需要两类信息:
+
+- 可用 skill 总览(有哪些技能、各自做什么)
+- 已加载 skill 的正文/文档/工具选择(当前上下文真正可用什么)
+
+`SkillsRequestProcessor` 提供统一策略来管理这些信息,并在 `turn/once/session` 三种加载模式下保持行为一致。
+
+### 2.2 核心行为
+
+一次调用中,典型流程如下:
+
+1. 获取 skill repo(支持 `repo_resolver(ctx)` 动态仓库)
+2. 执行旧状态迁移(兼容历史 key)与 turn 模式清理
+3. 注入 skill 概览(始终执行)
+4. 读取 loaded skills 并按 `max_loaded_skills` 裁剪
+5. `load_mode=session` 在 temp-only 设计下与 `turn` 读取语义一致
+6. 根据 `tool_result_mode` 分流:
+ - `False`:直接向 system instruction 注入 `[Loaded]`、`Docs loaded`、`[Doc]`
+ - `True`:跳过注入,交给 `SkillsToolResultRequestProcessor` 处理
+7. `load_mode=once` 时清理 loaded/docs/tools/order 状态(offload)
+
+### 2.3 状态语义
+
+- `turn`
+ - 每次 invocation 开始清理一次技能状态
+ - 对应:`_maybe_clear_skill_state_for_turn`
+- `once`
+ - 本轮用完后清理,避免持续占用上下文
+ - 对应:`_maybe_offload_loaded_skills`
+- `session`
+ - 在 temp-only 状态模型下,不再维护 `user:skill:*` 双键
+ - 对应:`_maybe_promote_skill_state_for_session`(当前为 no-op)
+
+读取策略为单一 temp key 读取(`session_state + state_delta` 视图)。
+
+### 2.4 关键参数
+
+- `load_mode`: `turn` / `once` / `session`
+- `tool_result_mode`: 是否改为 tool result materialization 路径
+- `tool_profile` / `allowed_skill_tools` / `tool_flags`: 限制可用技能工具能力面
+- `exec_tools_disabled`: 关闭交互执行 guidance
+- `repo_resolver`: invocation 级仓库解析
+- `max_loaded_skills`: loaded 上限(超限按顺序淘汰)
+
+参数入口:
+
+- `set_skill_processor_parameters(agent_context, parameters)`
+
+## 3. WorkspaceExecRequestProcessor(workspace_exec guidance)
+
+对应实现:
+
+- [`trpc_agent/agents/core/_workspace_exec_processor.py`](./_workspace_exec_processor.py)
+
+### 3.1 解决的问题
+
+在多工具场景下,模型容易混淆:
+
+- 什么时候使用 `workspace_exec`(通用 shell)
+- 什么时候使用 `skill_run`(技能内部执行)
+- `workspace_exec` 的路径边界、会话工具、artifact 保存边界
+
+处理器通过注入统一 guidance,降低误用和误判。
+
+### 3.2 主要行为
+
+`process_llm_request(ctx, request)` 典型步骤:
+
+1. 判断是否启用 guidance
+ - 默认按 request tools 是否包含 `workspace_exec`
+ - 支持 `enabled_resolver` 动态开关
+2. 生成 guidance 主体
+ - 通用 `workspace_exec` 使用建议
+ - `work/out/runs` 路径建议
+ - “先用小命令验证环境限制”的原则
+3. 按能力追加段落
+ - 有 `workspace_save_artifact`:追加 artifact 保存边界说明
+ - 有 skills repo:提示 `skills/` 目录并非自动 stage
+ - 有会话工具:追加 `workspace_write_stdin` / `workspace_kill_session` 生命周期提示
+4. 去重注入
+ - 若已存在 `Executor workspace guidance:` header,则不重复追加
+
+### 3.3 行为示例
+
+当工具列表包含:
+
+- `workspace_exec`
+- `workspace_write_stdin`
+- `workspace_kill_session`
+- `workspace_save_artifact`
+
+且 agent 绑定 skill repository 时,system instruction 会引导模型:
+
+- 通用 shell 优先走 `workspace_exec`
+- 路径优先 `work/`、`out/`、`runs/`
+- 限制不先假设,先验证
+- 仅在需要稳定引用时再调用 `workspace_save_artifact`
+
+### 3.4 常见误区
+
+- 误区:`workspace_exec` 会自动准备 `skills/` 内容
+ 实际:是否存在 `skills/...` 取决于是否有其他工具先 stage
+
+- 误区:遇到限制直接下结论“环境不支持”
+ 实际:应先做有界验证
+
+- 误区:所有输出都必须保存 artifact
+ 实际:应按稳定引用需求再保存
+
+### 3.5 如何验证生效
+
+优先看“发给模型前的请求”而非仅看终端事件:
+
+1. `request.config.system_instruction` 是否包含 `Executor workspace guidance:`
+2. 是否只注入一次(无重复 header)
+3. 工具选择行为是否符合预期(通用 shell 走 `workspace_exec`)
+
+## 4. SkillsToolResultRequestProcessor(tool result 物化)
+
+对应实现:
+
+- [`trpc_agent/agents/core/_skills_tool_result_processor.py`](./_skills_tool_result_processor.py)
+
+### 4.1 解决的问题
+
+仅靠 `skill_load` 的短回包(例如 `"skill 'python-math' loaded"`),模型往往拿不到可执行细节。
+这个处理器负责把“已加载 skill 的实质内容”物化到模型当前请求上下文。
+
+### 4.2 主要行为
+
+处理器会:
+
+1. 从 `session_state + state_delta` 读取已加载 skill
+2. 在 `LlmRequest.contents` 中定位最近的 `skill_load` / `skill_select_docs` response
+3. 条件满足时改写 response,注入:
+ - `[Loaded] `
+ - `Docs loaded: ...`
+ - `[Doc] ...`
+4. 若本轮没有可改写 response,fallback 到 system instruction 追加 `Loaded skill context:`
+5. `load_mode=once` 时按策略清理 loaded/docs 状态
+
+### 4.3 与 SkillsRequestProcessor 的分工
+
+- `tool_result_mode=False`
+ - 由 `SkillsRequestProcessor` 直接注入 loaded 内容
+- `tool_result_mode=True`
+ - `SkillsRequestProcessor` 不注入 loaded 内容
+ - `SkillsToolResultRequestProcessor` 在 post-history 做物化
+
+### 4.4 最小示例
+
+进入处理器前:
+
+- function call: `skill_load(demo-skill)`
+- function response: `{"result":"skill 'demo-skill' loaded"}`
+
+处理器后(示意):
+
+```text
+{
+ "result": "[Loaded] demo-skill\n\n\n\nDocs loaded: docs/guide.md\n\n[Doc] docs/guide.md\n\n"
+}
+```
+
+即使没有对应 tool response 可改写,也会通过 system instruction fallback 注入已加载上下文。
+
+### 4.5 什么时候会被误判为“没生效”
+
+最常见误区:只看终端工具即时回包。
+该处理器真实生效点是“发给模型前的请求内容”,与外层事件流并不总是 1:1。
+
+建议观察:
+
+- `request.config.system_instruction`
+- `request.contents` 里的 function response 是否已被改写
+
+### 4.6 参数入口
+
+- `load_mode`
+- `skip_fallback_on_session_summary`
+- `repo_resolver`
+
+通过:
+
+- `set_skill_tool_result_processor_parameters(agent_context, {...})`
+
+注入请求构建链路。
+
+## 5. 测试语义映射(与 examples 对齐)
+
+可配合 [`examples/skills/run_agent.py`](../../../examples/skills/run_agent.py) 与
+[`examples/skills/README.md`](../../../examples/skills/README.md) 观察实际行为。
+
+- `workspace_exec_guidance` 类测试
+ - 关注工具选择行为是否被 guidance 纠偏
+ - 核心断言是 `workspace_exec` 与 `skill_run` 调用分布
+
+- `skills_tool_result_mode` 类测试
+ - 关注 materialization 信号是否出现(`[Loaded]` / `Docs loaded` / `[Doc]`)
+ - 允许“请求层可见但终端不完全回显”的情况
+
+## 6. 给读者的排障建议
+
+1. 先确认模式参数:`load_mode`、`tool_result_mode`
+2. 再确认状态读写:`state_delta` 与 `session_state` 是否符合预期
+3. 最后看请求最终形态:
+ - 是否注入了 guidance
+ - 是否注入了 loaded context
+ - 是否发生了 offload/clear
diff --git a/trpc_agent_sdk/agents/core/__init__.py b/trpc_agent_sdk/agents/core/__init__.py
index 5797916..c8b5922 100644
--- a/trpc_agent_sdk/agents/core/__init__.py
+++ b/trpc_agent_sdk/agents/core/__init__.py
@@ -23,7 +23,13 @@
from ._request_processor import RequestProcessor
from ._request_processor import default_request_processor
from ._skill_processor import SkillsRequestProcessor
+from ._skill_processor import get_skill_processor_parameters
+from ._skill_processor import set_skill_processor_parameters
+from ._skills_tool_result_processor import get_skill_tool_result_processor_parameters
+from ._skills_tool_result_processor import set_skill_tool_result_processor_parameters
from ._tools_processor import ToolsProcessor
+from ._workspace_exec_processor import get_workspace_exec_processor_parameters
+from ._workspace_exec_processor import set_workspace_exec_processor_parameters
__all__ = [
"AgentTransferProcessor",
@@ -40,5 +46,11 @@
"RequestProcessor",
"default_request_processor",
"SkillsRequestProcessor",
+ "get_skill_processor_parameters",
+ "set_skill_processor_parameters",
+ "get_skill_tool_result_processor_parameters",
+ "set_skill_tool_result_processor_parameters",
"ToolsProcessor",
+ "get_workspace_exec_processor_parameters",
+ "set_workspace_exec_processor_parameters",
]
diff --git a/trpc_agent_sdk/agents/core/_code_execution_processor.py b/trpc_agent_sdk/agents/core/_code_execution_processor.py
index 976712e..e164135 100644
--- a/trpc_agent_sdk/agents/core/_code_execution_processor.py
+++ b/trpc_agent_sdk/agents/core/_code_execution_processor.py
@@ -1,8 +1,6 @@
-# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+# -*- coding: utf-8 -*-
#
-# Copyright (C) 2026 Tencent. All rights reserved.
-#
-# tRPC-Agent-Python is licensed under Apache-2.0.
+# Copyright @ 2025 Tencent.com
"""Code execution processor for TRPC Agent framework.
This module provides code execution processing capabilities for LLM agents,
@@ -264,7 +262,8 @@ async def _run_post_processor(
# content to the part with the first code block.
response_content = llm_response.content
code_blocks = CodeExecutionUtils.extract_code_and_truncate_content(response_content,
- code_executor.code_block_delimiters)
+ code_executor.code_block_delimiters,
+ code_executor.ignore_codes)
# Terminal state: no code to execute.
if not code_blocks:
return
diff --git a/trpc_agent_sdk/agents/core/_request_processor.py b/trpc_agent_sdk/agents/core/_request_processor.py
index 4640f43..0e414ae 100644
--- a/trpc_agent_sdk/agents/core/_request_processor.py
+++ b/trpc_agent_sdk/agents/core/_request_processor.py
@@ -1,8 +1,6 @@
# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
#
-# Copyright (C) 2026 Tencent. All rights reserved.
-#
-# tRPC-Agent-Python is licensed under Apache-2.0.
+# Copyright @ 2025 Tencent.com
"""Request Processor implementation for TRPC Agent framework.
This module provides the RequestProcessor class which handles building LlmRequest
@@ -24,6 +22,7 @@
import copy
import inspect
import re
+from typing import Any
from typing import List
from typing import Optional
@@ -32,6 +31,7 @@
from trpc_agent_sdk.log import logger
from trpc_agent_sdk.models import LlmRequest
from trpc_agent_sdk.planners import default_planning_processor
+# get_skill_processor_parameters is provided by ._skill_processor (imported below).
from trpc_agent_sdk.tools import FunctionTool
from trpc_agent_sdk.tools import transfer_to_agent
from trpc_agent_sdk.types import Content
@@ -43,7 +43,12 @@
from ._history_processor import HistoryProcessor
from ._history_processor import TimelineFilterMode
from ._skill_processor import SkillsRequestProcessor
+from ._skill_processor import get_skill_processor_parameters
+from ._skills_tool_result_processor import SkillsToolResultRequestProcessor
+from ._skills_tool_result_processor import get_skill_tool_result_processor_parameters
from ._tools_processor import ToolsProcessor
+from ._workspace_exec_processor import WorkspaceExecRequestProcessor
+from ._workspace_exec_processor import get_workspace_exec_processor_parameters
class RequestProcessor:
@@ -124,11 +129,18 @@ async def build_request(
return error_event
# 5. Add skills to the request
- error_event = await self._add_skills_to_request(agent, ctx, request)
+ skill_parameters = get_skill_processor_parameters(ctx.agent_context)
+ error_event = await self._add_skills_to_request(agent, ctx, request, skill_parameters)
+ if error_event:
+ return error_event
+
+ # 6. Add workspace_exec guidance (after skills, before history)
+ workspace_exec_parameters = get_workspace_exec_processor_parameters(ctx.agent_context)
+ error_event = await self._add_workspace_exec_guidance(agent, ctx, request, workspace_exec_parameters)
if error_event:
return error_event
- # 6. Add conversation history (includes current user message in correct order)
+ # 7. Add conversation history (includes current user message in correct order)
if override_messages is not None:
# Use provided messages directly (for TeamAgent member control)
for content in override_messages:
@@ -150,12 +162,18 @@ async def build_request(
if error_event:
return error_event
- # 7. Process planning if planner is available
+ # 8. Materialize loaded-skill content into tool results (post-history).
+ skill_tool_result_parameters = get_skill_tool_result_processor_parameters(ctx.agent_context)
+ error_event = await self._add_skills_tool_results(agent, ctx, request, skill_tool_result_parameters)
+ if error_event:
+ return error_event
+
+ # 9. Process planning if planner is available
error_event = await self._add_planning_capabilities(agent, ctx, request)
if error_event:
return error_event
- # 8. Process output schema if needed (when tools are also present)
+ # 10. Process output schema if needed (when tools are also present)
error_event = await self._add_output_schema_capabilities(agent, ctx, request)
if error_event:
return error_event
@@ -303,8 +321,8 @@ async def _add_tools_to_request(self, agent: BaseAgent, ctx: InvocationContext,
return None # Success
- async def _add_skills_to_request(self, agent: BaseAgent, ctx: InvocationContext,
- request: LlmRequest) -> Optional[Event]:
+ async def _add_skills_to_request(self, agent: BaseAgent, ctx: InvocationContext, request: LlmRequest,
+ parameters: dict[str, Any]) -> Optional[Event]:
"""Add skills to the model request.
Args:
@@ -318,13 +336,58 @@ async def _add_skills_to_request(self, agent: BaseAgent, ctx: InvocationContext,
skill_repository = getattr(agent, 'skill_repository', None)
if skill_repository:
try:
- skills_processor = SkillsRequestProcessor(skill_repository)
+ skills_processor = SkillsRequestProcessor(skill_repository, **parameters)
skill_names = await skills_processor.process_llm_request(ctx, request)
logger.debug("Processed %s skills for agent: %s", len(skill_names), agent.name)
return None # Success
except Exception as ex: # pylint: disable=broad-except
logger.error("Error processing skills for agent %s: %s", agent.name, ex)
return self._create_error_event(ctx, "skill_processing_error", f"Failed to process skills: {str(ex)}")
+ return None
+
+ async def _add_workspace_exec_guidance(self, agent: BaseAgent, ctx: InvocationContext, request: LlmRequest,
+ parameters: dict[str, Any]) -> Optional[Event]:
+ """Inject workspace_exec guidance after skills and before history."""
+ try:
+ skill_repository = getattr(agent, "skill_repository", None)
+ repo_resolver = parameters.get("repo_resolver")
+ processor = WorkspaceExecRequestProcessor(
+ has_skills_repo=bool(skill_repository),
+ repo_resolver=repo_resolver if callable(repo_resolver) else None,
+ )
+ await processor.process_llm_request(ctx, request)
+ return None
+ except Exception as ex: # pylint: disable=broad-except
+ logger.error("Error injecting workspace_exec guidance for agent %s: %s", agent.name, ex)
+ return self._create_error_event(
+ ctx,
+ "workspace_exec_guidance_error",
+ f"Failed to inject workspace_exec guidance: {str(ex)}",
+ )
+
+ async def _add_skills_tool_results(self, agent: BaseAgent, ctx: InvocationContext, request: LlmRequest,
+ parameters: dict[str, Any]) -> Optional[Event]:
+ """Materialize loaded skills into tool results after history is attached."""
+ if not parameters.get("tool_result_mode"):
+ return None
+ skill_repository = getattr(agent, "skill_repository", None)
+ if skill_repository is None:
+ return None
+ try:
+ processor = SkillsToolResultRequestProcessor(
+ skill_repository,
+ skip_fallback_on_session_summary=parameters.get("skip_fallback_on_session_summary", True),
+ repo_resolver=parameters.get("repo_resolver"),
+ )
+ await processor.process_llm_request(ctx, request)
+ return None
+ except Exception as ex: # pylint: disable=broad-except
+ logger.error("Error processing skill tool results for agent %s: %s", agent.name, ex)
+ return self._create_error_event(
+ ctx,
+ "skill_tool_result_processing_error",
+ f"Failed to process skill tool results: {str(ex)}",
+ )
async def _add_agent_transfer_capabilities(self, agent: BaseAgent, ctx: InvocationContext,
request: LlmRequest) -> Optional[Event]:
@@ -879,7 +942,7 @@ def _apply_template_substitution(self, instruction: str, ctx: InvocationContext)
{session_key}, etc. with actual values from the session state.
This implementation is inspired by adk-python's inject_session_state but
- adapted for trpc_agent_sdk's architecture.
+ adapted for trpc_agent's architecture.
Args:
instruction: The instruction string containing template placeholders
@@ -920,7 +983,6 @@ def replace_placeholder(match):
# This follows the behavior of the original SafeFormatter approach
return match.group()
- # Use regex pattern similar to adk-python but simpler for trpc_agent_sdk
# This matches {variable_name} patterns including optional ones with ?
pattern = r'\{[^{}]*\}'
result = re.sub(pattern, replace_placeholder, instruction)
diff --git a/trpc_agent_sdk/agents/core/_skill_processor.py b/trpc_agent_sdk/agents/core/_skill_processor.py
index c34f6d8..1e1fbc2 100644
--- a/trpc_agent_sdk/agents/core/_skill_processor.py
+++ b/trpc_agent_sdk/agents/core/_skill_processor.py
@@ -4,55 +4,40 @@
#
# tRPC-Agent-Python is licensed under Apache-2.0.
"""SkillsRequestProcessor — injects skill overviews and loaded contents.
-
-Mirrors ``internal/flow/processor/skills.go`` from trpc-agent-go, while
-retaining Python-unique features (user_prompt, tools-selection summary).
-
-Behavior
---------
-- Overview : always injected (skill names + descriptions).
-- Loaded skills: full SKILL.md body injected into system prompt (or
- deferred to tool-result mode).
-- Docs : doc texts selected via session state keys.
-- Tools : (Python-unique) tool selection summary for each loaded skill.
-
-Skill load modes
-----------------
-- ``turn`` (default) – loaded skill content is available for all LLM
- calls within the current invocation, then cleared at the start of the
- next invocation.
-- ``once`` – loaded skill content is injected once, then offloaded
- (state keys cleared) immediately after injection.
-- ``session`` – loaded skill content persists across invocations until
- the session expires or state is cleared explicitly.
"""
from __future__ import annotations
import json
+from typing import Any
from typing import Callable
from typing import List
from typing import Optional
+from trpc_agent_sdk.context import AgentContext
from trpc_agent_sdk.context import InvocationContext
from trpc_agent_sdk.log import logger
from trpc_agent_sdk.models import LlmRequest
from trpc_agent_sdk.skills import BaseSkillRepository
-from trpc_agent_sdk.skills import SKILL_DOCS_STATE_KEY_PREFIX
-from trpc_agent_sdk.skills import SKILL_LOADED_STATE_KEY_PREFIX
-from trpc_agent_sdk.skills import SKILL_TOOLS_STATE_KEY_PREFIX
from trpc_agent_sdk.skills import Skill
-from trpc_agent_sdk.skills import generic_get_selection
-
-# ---------------------------------------------------------------------------
-# Load mode constants (mirrors Go SkillLoadModeXxx)
-# ---------------------------------------------------------------------------
-
-SKILL_LOAD_MODE_ONCE = "once"
-SKILL_LOAD_MODE_TURN = "turn"
-SKILL_LOAD_MODE_SESSION = "session"
-
-_DEFAULT_SKILL_LOAD_MODE = SKILL_LOAD_MODE_TURN
+from trpc_agent_sdk.skills import SkillLoadModeNames
+from trpc_agent_sdk.skills import SkillProfileFlags
+from trpc_agent_sdk.skills import SkillProfileNames
+from trpc_agent_sdk.skills import SkillToolsNames
+from trpc_agent_sdk.skills import docs_scan_prefix
+from trpc_agent_sdk.skills import docs_state_key
+from trpc_agent_sdk.skills import get_skill_config
+from trpc_agent_sdk.skills import loaded_order_state_key
+from trpc_agent_sdk.skills import loaded_scan_prefix
+from trpc_agent_sdk.skills import loaded_state_key
+from trpc_agent_sdk.skills import marshal_loaded_order
+from trpc_agent_sdk.skills import parse_loaded_order
+from trpc_agent_sdk.skills import set_skill_config
+from trpc_agent_sdk.skills import tool_scan_prefix
+from trpc_agent_sdk.skills import tool_state_key
+from trpc_agent_sdk.skills import touch_loaded_order
+
+from ._skills_tool_result_processor import SKILL_LOADED_RE
# ---------------------------------------------------------------------------
# Prompt section headers (mirrors Go const block)
@@ -62,49 +47,93 @@
_SKILLS_CAPABILITY_HEADER = "Skill tool availability:"
_SKILLS_TOOLING_GUIDANCE_HEADER = "Tooling and workspace guidance:"
-# ---------------------------------------------------------------------------
-# Internal state keys
-# ---------------------------------------------------------------------------
+_SKILLS_TURN_INIT_STATE_KEY = "processor:skills:turn_init"
+
+
+def normalize_load_mode(mode: str) -> str:
+ value = (mode or "").strip().lower()
+ if value in (SkillLoadModeNames.ONCE, SkillLoadModeNames.TURN, SkillLoadModeNames.SESSION):
+ return value
+ return SkillLoadModeNames.TURN
+
+
+def _append_knowledge_guidance(lines: list[str], flags: SkillProfileFlags) -> None:
+ """Append docs-loading guidance mirroring Go's appendKnowledgeGuidance."""
+ has_list_docs = flags.list_docs
+ has_select_docs = flags.select_docs
+ if has_list_docs and has_select_docs:
+ lines.append("- Use the available doc listing and selection helpers to keep"
+ " documentation loads targeted.\n")
+ elif has_list_docs:
+ lines.append("- Use the available doc listing helper to discover doc names,"
+ " then load only the docs you need.\n")
+ elif has_select_docs:
+ lines.append("- If doc names are already known, use the available doc"
+ " selection helper to keep loaded docs targeted.\n")
+ else:
+ lines.append("- If you need docs, request them directly with skill_load.docs"
+ " or include_all_docs.\n")
+ lines.append("- Avoid include_all_docs unless the user asks or the task genuinely"
+ " needs the full doc set.\n")
-# JSON array of skill names in load order — used by the max-cap eviction.
-_SKILLS_LOADED_ORDER_STATE_KEY = "temp:skill:loaded_order"
# ---------------------------------------------------------------------------
-# Normalization helpers
+# Guidance text builders (mirrors Go defaultXxxGuidance functions)
# ---------------------------------------------------------------------------
-def _normalize_load_mode(mode: str) -> str:
- m = (mode or "").strip().lower()
- if m in (SKILL_LOAD_MODE_ONCE, SKILL_LOAD_MODE_TURN, SKILL_LOAD_MODE_SESSION):
- return m
- return _DEFAULT_SKILL_LOAD_MODE
-
-
-def _is_knowledge_only(profile: str) -> bool:
- """Return True for profiles that support knowledge lookup only."""
- p = (profile or "").strip().lower().replace("-", "_")
- return p in ("knowledge_only", "knowledge")
+def _default_catalog_only_guidance() -> str:
+ return ("\n" + _SKILLS_TOOLING_GUIDANCE_HEADER + "\n" +
+ "- Use the skill overview as a catalog only. Built-in skill tools are"
+ " unavailable in this configuration; if a task depends on loading or"
+ " executing a skill, use other registered tools or explain the"
+ " limitation clearly.\n")
-# ---------------------------------------------------------------------------
-# Guidance text builders (mirrors Go defaultXxxGuidance functions)
-# ---------------------------------------------------------------------------
+def _default_doc_helpers_only_guidance(flags: SkillProfileFlags) -> str:
+ lines = [
+ "\n",
+ _SKILLS_TOOLING_GUIDANCE_HEADER,
+ "\n",
+ ]
+ has_list_docs = flags.list_docs
+ has_select_docs = flags.select_docs
+ if has_list_docs and has_select_docs:
+ lines.append("- Use skills only to inspect available doc names or adjust"
+ " doc selection state.\n")
+ elif has_list_docs:
+ lines.append("- Use skills only to inspect available doc names.\n")
+ elif has_select_docs:
+ lines.append("- Use skills only to adjust doc selection when doc names are"
+ " already known.\n")
+ lines.append("- Built-in skill loading is unavailable, so doc helpers do not"
+ " inject SKILL.md or doc contents into context; if the task needs"
+ " loaded content or execution, use other registered tools or"
+ " explain the limitation clearly.\n")
+ return "".join(lines)
-def _default_knowledge_only_guidance() -> str:
- return ("\n" + _SKILLS_TOOLING_GUIDANCE_HEADER + "\n" +
- "- Use skills for progressive disclosure only: load SKILL.md first,"
- " then inspect only the documentation needed for the current task.\n" +
- "- Avoid include_all_docs unless the user asks or the task genuinely"
- " needs the full doc set.\n" + "- Treat loaded skill content as domain guidance. Do not claim you"
- " executed scripts, shell commands, or interactive flows described by"
- " the skill.\n" + "- If a skill depends on execution to complete the task, switch to"
- " other registered tools (for example, MCP tools) or explain the"
- " limitation clearly.\n")
+def _default_knowledge_only_guidance(flags: SkillProfileFlags) -> str:
+ lines = [
+ "\n",
+ _SKILLS_TOOLING_GUIDANCE_HEADER,
+ "\n",
+ "- Use skills for progressive disclosure only: load SKILL.md first,"
+ " then inspect only the documentation needed for the current task.\n",
+ ]
+ _append_knowledge_guidance(lines, flags)
+ lines += [
+ "- Treat loaded skill content as domain guidance. Do not claim you"
+ " executed scripts, shell commands, or interactive flows described by"
+ " the skill.\n",
+ "- If a skill depends on execution to complete the task, switch to"
+ " other registered tools (for example, MCP tools) or explain the"
+ " limitation clearly.\n",
+ ]
+ return "".join(lines)
-def _default_full_tooling_and_workspace_guidance(exec_tools_disabled: bool) -> str:
+def _default_full_tooling_and_workspace_guidance(flags: SkillProfileFlags) -> str:
lines: list[str] = [
"\n",
_SKILLS_TOOLING_GUIDANCE_HEADER,
@@ -132,6 +161,9 @@ def _default_full_tooling_and_workspace_guidance(exec_tools_disabled: bool) -> s
"- Prefer writing new files under $OUTPUT_DIR or a skill's out/"
" directory and include output_files globs (or an outputs spec) so"
" files can be collected or saved as artifacts.\n",
+ "- Use stdout/stderr for logs or short status text. If the model needs"
+ " large or structured text, write it to files under $OUTPUT_DIR and"
+ " return it via output_files or outputs.\n",
"- For Python skills that need third-party packages, create a virtualenv"
" under the skill's .venv/ directory (it is writable inside the"
" workspace).\n",
@@ -153,69 +185,78 @@ def _default_full_tooling_and_workspace_guidance(exec_tools_disabled: bool) -> s
"- When chaining multiple skills, read previous results from $OUTPUT_DIR"
" (or a skill's out/ directory) instead of copying them back into inputs"
" directories.\n",
- "- Treat loaded skill docs as guidance, not perfect truth; when runtime"
- " help or stderr disagrees, trust observed runtime behavior.\n",
- "- Loading a skill gives you instructions and bundled resources; it does"
- " not execute the skill by itself.\n",
- "- The skill summaries above are routing summaries only; they do not"
- " replace SKILL.md or other loaded docs.\n",
- "- If the loaded content already provides enough guidance to answer or"
- " produce the requested result, respond directly.\n",
- "- A skill can still be executable even when it has no extra docs"
- " or no custom tools. If SKILL.md provides runnable commands,"
- " proceed with skill_run using those commands.\n",
- "- If a skill is not loaded, call skill_load; you may pass docs or"
- " include_all_docs.\n",
- "- If the body is loaded and docs are missing, treat docs as optional"
- " unless the task explicitly requires extra references; then call"
- " skill_select_docs or skill_load again to add docs.\n",
- "- If the skill defines tools in its SKILL.md, they will be"
- " automatically selected when you load the skill. You can refine tool"
- " selection with skill_select_tools.\n",
- "- If you decide to use a skill, load SKILL.md before",
]
- if exec_tools_disabled:
- lines.append(" the first skill_run for that skill, then load only"
- " the docs you still need.\n")
+ if flags.load:
+ lines += [
+ "- Treat loaded skill docs as guidance, not perfect truth; when runtime"
+ " help or stderr disagrees, trust observed runtime behavior.\n",
+ "- Loading a skill gives you instructions and bundled resources; it does"
+ " not execute the skill by itself.\n",
+ "- The skill summaries above are routing summaries only; they do not"
+ " replace SKILL.md or other loaded docs.\n",
+ "- If the loaded content already provides enough guidance to answer or"
+ " produce the requested result, respond directly.\n",
+ "- If you decide to use a skill, load SKILL.md before",
+ ]
+ if flags.requires_exec_session_tools():
+ lines.append(" the first skill_run or skill_exec for that skill, then load"
+ " only the docs you still need.\n")
+ else:
+ lines.append(" the first skill_run for that skill, then load only the docs"
+ " you still need.\n")
+ lines += [
+ "- Do not infer commands, script entrypoints, or resource layouts from"
+ " the short summary alone.\n",
+ ]
+ _append_knowledge_guidance(lines, flags)
+ elif flags.has_doc_helpers():
+ lines += [
+ "- Built-in skill loading is unavailable in this configuration. Doc"
+ " listing or selection helpers can inspect doc names or selection"
+ " state, but they do not inject SKILL.md or doc contents into"
+ " context.\n",
+ ]
else:
- lines.append(" the first skill_run or skill_exec for that skill,"
- " then load only the docs you still need.\n")
+ lines += [
+ "- Built-in skill loading is unavailable in this configuration; do not"
+ " assume SKILL.md or doc contents are in context.\n",
+ ]
+
lines += [
- "- Do not infer commands, script entrypoints, or resource layouts"
- " from the short summary alone.\n",
- "- For docs, prefer skill_list_docs + skill_select_docs to load only"
- " what you need.\n",
- "- Avoid include_all_docs unless you need every doc or the user asks.\n",
- "- Use execution tools only when running a command will reveal or"
- " produce information or files you still need.\n",
+ "- Use execution tools only when running a command will reveal or produce"
+ " information or files you still need.\n",
]
- if not exec_tools_disabled:
+ if flags.requires_exec_session_tools():
lines.append("- Use skill_exec only when a command needs incremental stdin or"
" TTY-style interaction; otherwise prefer one-shot execution.\n")
else:
lines.append("- Do not assume interactive execution is available when only"
" one-shot execution tools are present.\n")
lines += [
- "- Prefer script-based commands from SKILL.md examples (for example,"
- " python3 scripts/foo.py) instead of ad-hoc python -c rewrites"
- " unless the skill explicitly recommends inline execution.\n",
- "- skill_run is a command runner inside the skill workspace, not a"
- " magic capability. It does not automatically add the skill directory"
- " to PATH or install dependencies; invoke scripts via an explicit"
- " interpreter and path (e.g., python3 scripts/foo.py).\n",
- "- When you execute, follow the tool description, loaded skill docs,"
- " bundled scripts, and observed runtime behavior rather than inventing"
- " shell syntax or command arguments.\n",
- "- If skill_list_tools returns command_examples, execute one of those"
- " commands directly before trying ad-hoc shell alternatives.\n",
+ "- skill_run is a command runner inside the skill workspace, not a magic"
+ " capability. It does not automatically add the skill directory to PATH"
+ " or install dependencies; invoke scripts via an explicit interpreter and"
+ " path (e.g., python3 scripts/foo.py).\n",
+ "- When you execute, follow the tool description, ",
]
+ if flags.load:
+ lines[-1] += "loaded skill docs, "
+ lines[
+ -1] += "bundled scripts, and observed runtime behavior rather than inventing shell syntax or command arguments.\n"
return "".join(lines)
-def _default_tooling_and_workspace_guidance(profile: str, exec_tools_disabled: bool) -> str:
- if _is_knowledge_only(profile):
- return _default_knowledge_only_guidance()
- return _default_full_tooling_and_workspace_guidance(exec_tools_disabled)
+def _default_tooling_and_workspace_guidance(flags: SkillProfileFlags) -> str:
+ if not flags.is_any():
+ return _default_catalog_only_guidance()
+
+ if not flags.run:
+ if flags.load:
+ return _default_knowledge_only_guidance(flags)
+ if flags.has_doc_helpers():
+ return _default_doc_helpers_only_guidance(flags)
+ return _default_catalog_only_guidance()
+ return _default_full_tooling_and_workspace_guidance(flags)
def _normalize_custom_guidance(guidance: str) -> str:
@@ -236,8 +277,6 @@ def _normalize_custom_guidance(guidance: str) -> str:
class SkillsRequestProcessor:
"""Injects skill overviews and loaded contents into LLM requests.
- Mirrors Go's ``SkillsRequestProcessor``.
-
Args:
skill_repository: Default skill repository.
load_mode: ``"turn"`` (default), ``"once"``, or ``"session"``.
@@ -247,6 +286,9 @@ class SkillsRequestProcessor:
tool_result_mode: When ``True``, skip loaded-skill injection here
(content is materialized into tool results instead).
tool_profile: Profile string (e.g. ``"knowledge_only"``).
+ forbidden_tools: Optional explicit blacklist of built-in skill tools.
+ tool_flags: Optional resolved flags; when set, takes precedence
+ over ``tool_profile``/``forbidden_tools``.
exec_tools_disabled: When ``True``, omit skill_exec guidance lines.
repo_resolver: Optional ``(ctx) -> BaseSkillRepository`` callable
that returns an invocation-specific repository.
@@ -257,25 +299,30 @@ def __init__(
self,
skill_repository: BaseSkillRepository,
*,
- load_mode: str = SKILL_LOAD_MODE_TURN,
+ load_mode: str = str(SkillLoadModeNames.TURN),
tooling_guidance: Optional[str] = None,
tool_result_mode: bool = False,
- tool_profile: str = "",
+ tool_profile: str = str(SkillProfileNames.FULL),
+ forbidden_tools: Optional[list[str]] = None,
+ tool_flags: Optional[SkillProfileFlags] = None,
exec_tools_disabled: bool = False,
repo_resolver: Optional[Callable[[InvocationContext], BaseSkillRepository]] = None,
max_loaded_skills: int = 0,
) -> None:
self._skill_repository = skill_repository
- self._load_mode = _normalize_load_mode(load_mode)
+ self._load_mode = normalize_load_mode(load_mode)
self._tooling_guidance = tooling_guidance
self._tool_result_mode = tool_result_mode
- self._tool_profile = (tool_profile or "").strip()
- self._exec_tools_disabled = exec_tools_disabled
+ try:
+ resolved_flags = tool_flags or SkillProfileFlags.resolve_flags(tool_profile, forbidden_tools)
+ except ValueError as ex:
+ logger.warning("skills: invalid skill tool flags config, fallback to full profile: %s", ex)
+ resolved_flags = SkillProfileFlags.preset_flags(tool_profile, forbidden_tools)
+ if exec_tools_disabled:
+ resolved_flags = resolved_flags.without_interactive_execution()
+ self._tool_flags = resolved_flags
self._repo_resolver = repo_resolver
self._max_loaded_skills = max_loaded_skills
- # Tracks which invocation IDs have already had their turn-init clearing
- # applied. This is ephemeral instance state — not persisted to session.
- self._initialized_invocations: set[str] = set()
# ------------------------------------------------------------------
# Public entry point
@@ -301,17 +348,17 @@ async def process_llm_request(
self._maybe_clear_skill_state_for_turn(ctx)
# 1) Always inject overview (names + descriptions).
- self._inject_overview(request, repo)
+ self._inject_overview(ctx, request, repo)
loaded = self._get_loaded_skills(ctx)
loaded = self._maybe_cap_loaded_skills(ctx, loaded)
if self._tool_result_mode:
- # Loaded skill bodies/docs are injected into tool results by a
- # separate post-content processor — skip injection here.
+ # Materialization is handled by a dedicated post-history processor
+ # in request pipeline (Go-aligned ordering).
return loaded
- # 2) Loaded skills: full body + docs + tools (sorted for stable prompts).
+ # 2) Loaded skills: full body + docs (sorted for stable prompts).
loaded.sort()
parts: list[str] = []
@@ -338,8 +385,6 @@ async def process_llm_request(
doc_text = self._build_docs_text(sk, sel)
if doc_text:
parts.append(doc_text)
-
- # Tools (Python-unique: skill_select_tools integration)
tool_sel = self._get_tools_selection(ctx, name)
parts.append("Tools selected: ")
if not tool_sel:
@@ -373,38 +418,36 @@ def _maybe_clear_skill_state_for_turn(self, ctx: InvocationContext) -> None:
Uses ``ctx.invocation_id`` to detect when a new invocation has started
without persisting an extra key to session state.
"""
- if self._load_mode != SKILL_LOAD_MODE_TURN:
+ if self._load_mode != SkillLoadModeNames.TURN:
return
- inv_id = ctx.invocation_id
- if inv_id in self._initialized_invocations:
+ if ctx.agent_context.get_metadata(_SKILLS_TURN_INIT_STATE_KEY):
return
- self._initialized_invocations.add(inv_id)
- # Bound the set size to avoid unbounded growth across long-running servers.
- if len(self._initialized_invocations) > 2000:
- oldest = list(self._initialized_invocations)[:1000]
- for old_id in oldest:
- self._initialized_invocations.discard(old_id)
+ ctx.agent_context.with_metadata(_SKILLS_TURN_INIT_STATE_KEY, True)
self._clear_skill_state(ctx)
def _clear_skill_state(self, ctx: InvocationContext) -> None:
"""Clear all loaded-skill state keys from the session."""
+ loaded_state_prefix = loaded_scan_prefix(ctx)
+ docs_state_prefix = docs_scan_prefix(ctx)
+ tools_state_prefix = tool_scan_prefix(ctx)
+ order_state_key = loaded_order_state_key(ctx)
state = self._snapshot_state(ctx)
for k, v in state.items():
if not v:
continue
- if (k.startswith(SKILL_LOADED_STATE_KEY_PREFIX) or k.startswith(SKILL_DOCS_STATE_KEY_PREFIX)
- or k.startswith(SKILL_TOOLS_STATE_KEY_PREFIX) or k == _SKILLS_LOADED_ORDER_STATE_KEY):
+ if (k.startswith(loaded_state_prefix) or k.startswith(docs_state_prefix) or k.startswith(tools_state_prefix)
+ or k == order_state_key):
ctx.actions.state_delta[k] = None
def _maybe_offload_loaded_skills(self, ctx: InvocationContext, loaded: list[str]) -> None:
"""After injection, clear skill state for once mode."""
- if self._load_mode != SKILL_LOAD_MODE_ONCE or not loaded:
+ if self._load_mode != SkillLoadModeNames.ONCE or not loaded:
return
for name in loaded:
- ctx.actions.state_delta[SKILL_LOADED_STATE_KEY_PREFIX + name] = None
- ctx.actions.state_delta[SKILL_DOCS_STATE_KEY_PREFIX + name] = None
- ctx.actions.state_delta[SKILL_TOOLS_STATE_KEY_PREFIX + name] = None
- ctx.actions.state_delta[_SKILLS_LOADED_ORDER_STATE_KEY] = None
+ ctx.actions.state_delta[loaded_state_key(ctx, name)] = None
+ ctx.actions.state_delta[docs_state_key(ctx, name)] = None
+ ctx.actions.state_delta[tool_state_key(ctx, name)] = None
+ ctx.actions.state_delta[loaded_order_state_key(ctx)] = None
# ------------------------------------------------------------------
# Max-loaded-skills cap
@@ -416,40 +459,31 @@ def _maybe_cap_loaded_skills(self, ctx: InvocationContext, loaded: list[str]) ->
return loaded
order = self._get_loaded_skill_order(ctx, loaded)
- # Keep the most recently touched skills (tail of the order list).
+ if not order:
+ return loaded
keep_count = self._max_loaded_skills
- keep_set = set(order[-keep_count:]) if len(order) >= keep_count else set(order)
+ keep_set = set(order[-keep_count:])
kept: list[str] = []
for name in loaded:
if name in keep_set:
kept.append(name)
else:
- ctx.actions.state_delta[SKILL_LOADED_STATE_KEY_PREFIX + name] = None
- ctx.actions.state_delta[SKILL_DOCS_STATE_KEY_PREFIX + name] = None
- ctx.actions.state_delta[SKILL_TOOLS_STATE_KEY_PREFIX + name] = None
-
+ ctx.actions.state_delta[loaded_state_key(ctx, name)] = None
+ ctx.actions.state_delta[docs_state_key(ctx, name)] = None
+ ctx.actions.state_delta[tool_state_key(ctx, name)] = None
new_order = [n for n in order if n in keep_set]
- ctx.actions.state_delta[_SKILLS_LOADED_ORDER_STATE_KEY] = json.dumps(new_order)
+ encoded_order = marshal_loaded_order(new_order)
+ ctx.actions.state_delta[loaded_order_state_key(ctx)] = encoded_order
return kept
def _get_loaded_skill_order(self, ctx: InvocationContext, loaded: list[str]) -> list[str]:
- """Return skill names in load order (oldest first, most-recent last).
-
- Reads the persisted order key; fills in any missing names
- alphabetically (mirrors Go's fillLoadedSkillOrderAlphabetically).
- """
- loaded_set = set(loaded)
- raw = self._read_state(ctx, _SKILLS_LOADED_ORDER_STATE_KEY)
- order: list[str] = []
- if raw:
- try:
- parsed = json.loads(raw) if isinstance(raw, str) else raw
- if isinstance(parsed, list):
- order = [n for n in parsed if n in loaded_set]
- except (json.JSONDecodeError, TypeError, ValueError):
- pass
-
+ loaded_set = self._loaded_skill_set(loaded)
+ if not loaded_set:
+ return []
+ order = self._loaded_skill_order_from_state(ctx, loaded_set)
+ if len(order) < len(loaded_set):
+ order = self._append_skills_to_order_from_events(ctx, order, loaded_set)
seen = set(order)
for name in sorted(n for n in loaded_set if n not in seen):
order.append(name)
@@ -459,7 +493,7 @@ def _get_loaded_skill_order(self, ctx: InvocationContext, loaded: list[str]) ->
# Overview injection
# ------------------------------------------------------------------
- def _inject_overview(self, request: LlmRequest, repo: BaseSkillRepository) -> None:
+ def _inject_overview(self, ctx: InvocationContext, request: LlmRequest, repo: BaseSkillRepository) -> None:
sums = repo.summaries()
if not sums:
return
@@ -501,23 +535,44 @@ def _inject_overview(self, request: LlmRequest, repo: BaseSkillRepository) -> No
def _tooling_guidance_text(self) -> str:
if self._tooling_guidance is None:
- return _default_tooling_and_workspace_guidance(self._tool_profile, self._exec_tools_disabled)
- return _normalize_custom_guidance(self._tooling_guidance)
+ tool_prompt = _default_tooling_and_workspace_guidance(self._tool_flags)
+ else:
+ tool_prompt = _normalize_custom_guidance(self._tooling_guidance)
+ if self._tool_flags.has_select_tools():
+ tool_prompt += """
+ - Use the skill_select_tools tool to select tools for the current task only when user asks for it.
+ """
+ if self._tool_flags.list_skills:
+ tool_prompt += """
+ - Use the skill_list_skills tool to list skills for the current task only when user asks for it.
+ """
+ return tool_prompt
def _capability_guidance_text(self) -> str:
- """Inject capability block for knowledge-only profiles."""
- if not _is_knowledge_only(self._tool_profile):
- return ""
+ """Inject capability block for constrained skill-tool profiles."""
# Omit when caller explicitly cleared guidance.
- if self._tooling_guidance is not None and self._tooling_guidance == "":
+ if self._tooling_guidance == "" or self._tool_flags.run:
return ""
+ if self._tool_flags.load:
+ return ("\n" + _SKILLS_CAPABILITY_HEADER + "\n" +
+ "- This configuration supports skill discovery and knowledge loading only.\n" +
+ "- Built-in skill execution tools are unavailable in the current mode.\n" +
+ "- If a loaded skill describes scripts, shell commands, workspace paths,"
+ " generated files, or interactive flows, treat that content as reference"
+ " only. Use other registered tools for real actions, or explain that"
+ " execution is unavailable in the current mode.\n")
+ if self._tool_flags.has_doc_helpers():
+ return ("\n" + _SKILLS_CAPABILITY_HEADER + "\n" +
+ "- This configuration supports skill discovery and skill doc inspection only.\n" +
+ "- Built-in skill loading and execution tools are unavailable in the"
+ " current mode.\n- Listing or selecting docs does not inject SKILL.md or doc contents"
+ " into model context by itself.\n")
return ("\n" + _SKILLS_CAPABILITY_HEADER + "\n" +
- "- This profile supports skill discovery and knowledge loading only.\n" +
- "- Execution-oriented skill tools are unavailable in the current mode.\n" +
- "- If a loaded skill describes scripts, shell commands, workspace paths,"
- " generated files, or interactive flows, treat that content as reference"
- " only. Use other registered tools for real actions, or explain that"
- " execution is unavailable in the current mode.\n")
+ "- This configuration exposes skill summaries only. Built-in skill tools"
+ " are unavailable in the current mode.\n" +
+ "- Treat the skill overview as a catalog of possible capabilities. Use"
+ " other registered tools, or explain the limitation clearly when the task"
+ " depends on skill loading or execution.\n")
# ------------------------------------------------------------------
# State helpers
@@ -547,57 +602,150 @@ def _get_loaded_skills(self, ctx: InvocationContext) -> list[str]:
"""Return names of all currently loaded skills."""
names: list[str] = []
state = self._snapshot_state(ctx)
+ scan_prefix = loaded_scan_prefix(ctx)
for k, v in state.items():
- if not k.startswith(SKILL_LOADED_STATE_KEY_PREFIX) or not v:
+ if not k.startswith(scan_prefix) or not v:
continue
- name = k[len(SKILL_LOADED_STATE_KEY_PREFIX):]
- names.append(name)
- return names
+ name = k[len(scan_prefix):].strip()
+ if name:
+ names.append(name)
+ if names:
+ return sorted(set(names))
+ return []
# ------------------------------------------------------------------
- # Docs and tools selection
+ # Docs / tools selection
# ------------------------------------------------------------------
def _get_docs_selection(self, ctx: InvocationContext, name: str) -> list[str]:
-
- def get_all_docs(skill_name: str) -> list[str]:
+ value = self._read_state(ctx, docs_state_key(ctx, name), default=None)
+ if not value:
+ return []
+ if isinstance(value, bytes):
try:
- repo = self._get_repository(ctx)
- sk = repo.get(skill_name) if repo else None
- if sk is None:
- return []
- return [d.path for d in sk.resources]
+ value = value.decode("utf-8")
+ except UnicodeDecodeError:
+ return []
+ if value == "*":
+ repo = self._get_repository(ctx)
+ if repo is None:
+ return []
+ try:
+ sk = repo.get(name)
+ return [doc.path for doc in sk.resources]
except Exception as ex: # pylint: disable=broad-except
- logger.warning("Failed to get docs for skill '%s': %s", skill_name, ex)
+ logger.warning("Failed to get docs for skill '%s': %s", name, ex)
return []
-
- return generic_get_selection(
- ctx=ctx,
- skill_name=name,
- state_key_prefix=SKILL_DOCS_STATE_KEY_PREFIX,
- get_all_items_callback=get_all_docs,
- )
+ if not isinstance(value, str):
+ return []
+ try:
+ arr = json.loads(value)
+ except json.JSONDecodeError:
+ return []
+ if not isinstance(arr, list):
+ return []
+ return [doc for doc in arr if isinstance(doc, str) and doc.strip()]
def _get_tools_selection(self, ctx: InvocationContext, name: str) -> list[str]:
- """Python-unique: return selected tool names for *name*."""
-
- def get_all_tools(skill_name: str) -> list[str]:
+ value = self._read_state(ctx, tool_state_key(ctx, name), default=None)
+ if not value:
+ return []
+ if isinstance(value, bytes):
try:
- repo = self._get_repository(ctx)
- sk = repo.get(skill_name) if repo else None
- if sk is None:
- return []
+ value = value.decode("utf-8")
+ except UnicodeDecodeError:
+ return []
+ if value == "*":
+ repo = self._get_repository(ctx)
+ if repo is None:
+ return []
+ try:
+ sk = repo.get(name)
return sk.tools
except Exception as ex: # pylint: disable=broad-except
- logger.warning("Failed to get tools for skill '%s': %s", skill_name, ex)
+ logger.warning("Failed to get tools for skill '%s': %s", name, ex)
return []
+ if not isinstance(value, str):
+ return []
+ try:
+ arr = json.loads(value)
+ except json.JSONDecodeError:
+ return []
+ if not isinstance(arr, list):
+ return []
+ return [tool for tool in arr if isinstance(tool, str) and tool.strip()]
+
+ def _loaded_skill_set(self, loaded: list[str]) -> set[str]:
+ out: set[str] = set()
+ for name in loaded:
+ candidate = (name or "").strip()
+ if candidate:
+ out.add(candidate)
+ return out
+
+ def _loaded_skill_order_from_state(self, ctx: InvocationContext, loaded_set: set[str]) -> list[str]:
+ order = parse_loaded_order(self._read_state(ctx, loaded_order_state_key(ctx)))
+ if not order:
+ return []
+ out: list[str] = []
+ seen: set[str] = set()
+ for name in order:
+ if name not in loaded_set or name in seen:
+ continue
+ out.append(name)
+ seen.add(name)
+ return out
- return generic_get_selection(
- ctx=ctx,
- skill_name=name,
- state_key_prefix=SKILL_TOOLS_STATE_KEY_PREFIX,
- get_all_items_callback=get_all_tools,
- )
+ def _append_skills_to_order_from_events(
+ self,
+ ctx: InvocationContext,
+ order: list[str],
+ loaded_set: set[str],
+ ) -> list[str]:
+ events = list(getattr(ctx.session, "events", []) or [])
+ if not events:
+ return order
+ for event in events:
+ if ctx.agent_name and getattr(event, "author", "") != ctx.agent_name:
+ continue
+ content = getattr(event, "content", None)
+ if content is None or not getattr(content, "parts", None):
+ continue
+ for part in content.parts:
+ response = getattr(part, "function_response", None)
+ if response is None:
+ continue
+ tool_name = (getattr(response, "name", "") or "").strip()
+ if tool_name not in (SkillToolsNames.LOAD, SkillToolsNames.SELECT_DOCS):
+ continue
+ skill_name = self._skill_name_from_tool_response(tool_name, getattr(response, "response", None))
+ if not skill_name or skill_name not in loaded_set:
+ continue
+ order = touch_loaded_order(order, skill_name)
+ return order
+
+ def _skill_name_from_tool_response(self, tool_name: str, response: Any) -> str:
+ if tool_name == str(SkillToolsNames.SELECT_DOCS) and isinstance(response, dict):
+ for key in ("skill", "skill_name", "name"):
+ value = response.get(key)
+ if isinstance(value, str) and value.strip():
+ return value.strip()
+ return ""
+ if tool_name == SkillToolsNames.LOAD:
+ if isinstance(response, dict):
+ for key in ("skill", "skill_name", "name", "result"):
+ value = response.get(key)
+ if isinstance(value, str) and value.strip():
+ match = SKILL_LOADED_RE.search(value)
+ if match:
+ return match.group(1).strip()
+ if key in ("skill", "skill_name", "name"):
+ return value.strip()
+ if isinstance(response, str):
+ match = SKILL_LOADED_RE.search(response)
+ if match:
+ return match.group(1).strip()
+ return ""
# ------------------------------------------------------------------
# Doc text assembly
@@ -623,3 +771,28 @@ def _merge_into_system(self, request: LlmRequest, content: str) -> None:
if not content:
return
request.append_instructions([content])
+
+
+def set_skill_processor_parameters(agent_context: AgentContext, parameters: dict[str, Any]) -> None:
+ """Set the parameters of a skill processor by agent context.
+
+ Args:
+ agent_context: AgentContext object
+ parameters: Parameters to set
+ """
+ skill_config = get_skill_config(agent_context)
+ skill_config["skill_processor"].update(parameters)
+ set_skill_config(agent_context, skill_config)
+
+
+def get_skill_processor_parameters(agent_context: AgentContext) -> dict[str, Any]:
+ """Get the parameters of a skill processor.
+
+ Args:
+ agent_context: AgentContext object
+
+ Returns:
+ Parameters of the skill processor
+ """
+ skill_config = get_skill_config(agent_context)
+ return skill_config["skill_processor"]
diff --git a/trpc_agent_sdk/agents/core/_skills_tool_result_processor.py b/trpc_agent_sdk/agents/core/_skills_tool_result_processor.py
new file mode 100644
index 0000000..fcd38fe
--- /dev/null
+++ b/trpc_agent_sdk/agents/core/_skills_tool_result_processor.py
@@ -0,0 +1,375 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""Materialize loaded skill context into skill tool results."""
+
+from __future__ import annotations
+
+import json
+import re
+from typing import Any
+from typing import Callable
+from typing import Optional
+from typing import Tuple
+
+from trpc_agent_sdk.context import AgentContext
+from trpc_agent_sdk.context import InvocationContext
+from trpc_agent_sdk.log import logger
+from trpc_agent_sdk.models import LlmRequest
+from trpc_agent_sdk.skills import BaseSkillRepository
+from trpc_agent_sdk.skills import Skill
+from trpc_agent_sdk.skills import SkillLoadModeNames
+from trpc_agent_sdk.skills import SkillToolsNames
+from trpc_agent_sdk.skills import docs_state_key
+from trpc_agent_sdk.skills import get_skill_config
+from trpc_agent_sdk.skills import get_skill_load_mode
+from trpc_agent_sdk.skills import loaded_order_state_key
+from trpc_agent_sdk.skills import loaded_scan_prefix
+from trpc_agent_sdk.skills import loaded_state_key
+from trpc_agent_sdk.skills import set_skill_config
+
+_SKILLS_LOADED_CONTEXT_HEADER = "Loaded skill context:"
+_SESSION_SUMMARY_PREFIX = "Here is a brief summary of your previous interactions:"
+SKILL_LOADED_RE = re.compile(r"skill\s+'([^']+)'\s+loaded", re.IGNORECASE)
+
+
+class SkillsToolResultRequestProcessor:
+ """Materialize loaded skill content into skill tool results."""
+
+ def __init__(
+ self,
+ skill_repository: BaseSkillRepository,
+ *,
+ skip_fallback_on_session_summary: bool = True,
+ repo_resolver: Optional[Callable[[InvocationContext], BaseSkillRepository]] = None,
+ ) -> None:
+ self._skill_repository = skill_repository
+ self._repo_resolver = repo_resolver
+ self._skip_fallback_on_session_summary = skip_fallback_on_session_summary
+
+ async def process_llm_request(self, ctx: InvocationContext, request: LlmRequest) -> list[str]:
+ """Apply loaded-skill materialization to tool results and fallback prompt."""
+ if request is None or ctx is None:
+ return []
+ repo = self._get_repository(ctx)
+ if repo is None:
+ return []
+
+ loaded = self._get_loaded_skills(ctx)
+ if not loaded:
+ return []
+ loaded.sort()
+
+ tool_calls = self._index_tool_calls(request)
+ last_tool_parts = self._last_skill_tool_parts(request, tool_calls)
+
+ materialized: set[str] = set()
+ for skill_name, (content_idx, part_idx) in last_tool_parts.items():
+ content = request.contents[content_idx]
+ part = content.parts[part_idx]
+ function_response = part.function_response
+ if function_response is None:
+ continue
+ base = self._response_to_text(getattr(function_response, "response", None))
+ rendered = self._build_tool_result_content(ctx, repo, skill_name, base)
+ if not rendered:
+ continue
+ function_response.response = {"result": rendered}
+ materialized.add(skill_name)
+
+ fallback = self._build_fallback_system_content(ctx, repo, loaded, materialized)
+ if fallback:
+ if not (self._skip_fallback_on_session_summary and self._has_session_summary(request)
+ and not last_tool_parts):
+ request.append_instructions([fallback])
+
+ self._maybe_offload_loaded_skills(ctx, loaded)
+ return loaded
+
+ def _get_repository(self, ctx: InvocationContext) -> Optional[BaseSkillRepository]:
+ if self._repo_resolver is not None:
+ return self._repo_resolver(ctx)
+ return self._skill_repository
+
+ def _snapshot_state(self, ctx: InvocationContext) -> dict[str, Any]:
+ state = dict(ctx.session_state)
+ for key, value in ctx.actions.state_delta.items():
+ if value is None:
+ state.pop(key, None)
+ else:
+ state[key] = value
+ return state
+
+ def _read_state(self, ctx: InvocationContext, key: str, default=None):
+ if key in ctx.actions.state_delta:
+ return ctx.actions.state_delta[key]
+ return ctx.session_state.get(key, default)
+
+ def _get_loaded_skills(self, ctx: InvocationContext) -> list[str]:
+ names_set: set[str] = set()
+ state = self._snapshot_state(ctx)
+ scan_prefix = loaded_scan_prefix(ctx)
+ for key, value in state.items():
+ if not key.startswith(scan_prefix) or not value:
+ continue
+ name = key[len(scan_prefix):].strip()
+ if name:
+ names_set.add(name)
+ return sorted(names_set)
+
+ def _index_tool_calls(self, request: LlmRequest) -> dict[str, Any]:
+ out: dict[str, Any] = {}
+ for content in request.contents:
+ if content.role not in ("model", "assistant") or not content.parts:
+ continue
+ for part in content.parts:
+ function_call = getattr(part, "function_call", None)
+ if function_call is None:
+ continue
+ call_id = (getattr(function_call, "id", "") or "").strip()
+ if not call_id:
+ continue
+ out[call_id] = function_call
+ return out
+
+ def _last_skill_tool_parts(
+ self,
+ request: LlmRequest,
+ tool_calls: dict[str, Any],
+ ) -> dict[str, Tuple[int, int]]:
+ out: dict[str, Tuple[int, int]] = {}
+ for content_idx, content in enumerate(request.contents):
+ if content.role != "user" or not content.parts:
+ continue
+ for part_idx, part in enumerate(content.parts):
+ function_response = getattr(part, "function_response", None)
+ if function_response is None:
+ continue
+ tool_name = (getattr(function_response, "name", "") or "").strip()
+ if tool_name not in (SkillToolsNames.LOAD, SkillToolsNames.SELECT_DOCS):
+ continue
+ skill_name = self._skill_name_from_tool_response(function_response, tool_calls)
+ if not skill_name:
+ continue
+ out[skill_name] = (content_idx, part_idx)
+ return out
+
+ def _skill_name_from_tool_response(self, function_response: Any, tool_calls: dict[str, Any]) -> str:
+ response = getattr(function_response, "response", None)
+ if isinstance(response, dict):
+ for key in ("skill", "skill_name", "name"):
+ value = response.get(key)
+ if isinstance(value, str) and value.strip():
+ return value.strip()
+
+ call_id = (getattr(function_response, "id", "") or "").strip()
+ if call_id and call_id in tool_calls:
+ function_call = tool_calls[call_id]
+ args = getattr(function_call, "args", None)
+ for key in ("skill", "skill_name", "name"):
+ value = self._get_arg_value(args, key)
+ if value:
+ return value
+
+ return self._parse_loaded_skill_from_text(self._response_to_text(response))
+
+ def _get_arg_value(self, args: Any, key: str) -> str:
+ if isinstance(args, str):
+ try:
+ args = json.loads(args)
+ except json.JSONDecodeError:
+ return ""
+ if isinstance(args, dict):
+ value = args.get(key)
+ if isinstance(value, str):
+ return value.strip()
+ return ""
+
+ def _parse_loaded_skill_from_text(self, content: str) -> str:
+ text = (content or "").strip()
+ if not text:
+ return ""
+ match = SKILL_LOADED_RE.search(text)
+ if match:
+ return match.group(1).strip()
+ lower = text.lower()
+ if lower.startswith("loaded:"):
+ return text[len("loaded:"):].strip()
+ return ""
+
+ def _response_to_text(self, response: Any) -> str:
+ if response is None:
+ return ""
+ if isinstance(response, str):
+ return response.strip()
+ if isinstance(response, dict):
+ result = response.get("result")
+ if isinstance(result, str):
+ return result.strip()
+ if result is not None:
+ return str(result).strip()
+ return json.dumps(response, ensure_ascii=False).strip()
+ return str(response).strip()
+
+ def _is_loaded_tool_stub(self, tool_output: str, skill_name: str) -> bool:
+ loaded = self._parse_loaded_skill_from_text(tool_output)
+ if not loaded:
+ return False
+ return loaded.lower() == skill_name.lower()
+
+ def _build_tool_result_content(
+ self,
+ ctx: InvocationContext,
+ repo: BaseSkillRepository,
+ skill_name: str,
+ tool_output: str,
+ ) -> str:
+ try:
+ sk = repo.get(skill_name)
+ except Exception as ex: # pylint: disable=broad-except
+ logger.warning("skills: get %s failed: %s", skill_name, ex)
+ return ""
+ if sk is None:
+ logger.warning("skills: get %s failed: skill not found", skill_name)
+ return ""
+
+ parts: list[str] = []
+ base = tool_output.strip()
+ if base and self._is_loaded_tool_stub(base, skill_name):
+ base = ""
+ if base:
+ parts.append(base)
+ parts.append("\n\n")
+
+ if sk.body.strip():
+ parts.append(f"[Loaded] {skill_name}\n\n{sk.body}\n")
+
+ selected_docs = self._get_docs_selection(ctx, skill_name, repo)
+ parts.append("Docs loaded: ")
+ if not selected_docs:
+ parts.append("none\n")
+ else:
+ parts.append(", ".join(selected_docs) + "\n")
+ docs_text = self._build_docs_text(sk, selected_docs)
+ if docs_text:
+ parts.append(docs_text)
+ return "".join(parts).strip()
+
+ def _build_fallback_system_content(
+ self,
+ ctx: InvocationContext,
+ repo: BaseSkillRepository,
+ loaded: list[str],
+ materialized: set[str],
+ ) -> str:
+ missing = [name for name in loaded if name not in materialized]
+ if not missing:
+ return ""
+
+ parts: list[str] = [_SKILLS_LOADED_CONTEXT_HEADER, "\n"]
+ appended = False
+ for name in missing:
+ try:
+ sk = repo.get(name)
+ except Exception as ex: # pylint: disable=broad-except
+ logger.warning("skills: get %s failed: %s", name, ex)
+ continue
+ if sk is None:
+ logger.warning("skills: get %s failed: skill not found", name)
+ continue
+ if sk.body.strip():
+ parts.append(f"\n[Loaded] {name}\n\n{sk.body}\n")
+ appended = True
+ selected_docs = self._get_docs_selection(ctx, name, repo)
+ parts.append("Docs loaded: ")
+ if not selected_docs:
+ parts.append("none\n")
+ else:
+ parts.append(", ".join(selected_docs) + "\n")
+ docs_text = self._build_docs_text(sk, selected_docs)
+ if docs_text:
+ parts.append(docs_text)
+ appended = True
+ if not appended:
+ return ""
+ return "".join(parts).strip()
+
+ def _has_session_summary(self, request: LlmRequest) -> bool:
+ if request is None or request.config is None:
+ return False
+ system_instruction = str(request.config.system_instruction or "")
+ return _SESSION_SUMMARY_PREFIX in system_instruction
+
+ def _get_docs_selection(self, ctx: InvocationContext, skill_name: str, repo: BaseSkillRepository) -> list[str]:
+ value = self._read_state(ctx, docs_state_key(ctx, skill_name), default=None)
+ if not value:
+ return []
+ if isinstance(value, bytes):
+ try:
+ value = value.decode("utf-8")
+ except UnicodeDecodeError:
+ return []
+ if value == "*":
+ try:
+ sk = repo.get(skill_name)
+ except Exception as ex: # pylint: disable=broad-except
+ logger.warning("skills: get %s failed: %s", skill_name, ex)
+ return []
+ if sk is None:
+ return []
+ return [doc.path for doc in sk.resources]
+ if not isinstance(value, str):
+ return []
+ try:
+ arr = json.loads(value)
+ except json.JSONDecodeError:
+ return []
+ if not isinstance(arr, list):
+ return []
+ return [doc for doc in arr if isinstance(doc, str) and doc.strip()]
+
+ def _build_docs_text(self, sk: Skill, wanted: list[str]) -> str:
+ if sk is None or not sk.resources:
+ return ""
+ want = set(wanted)
+ parts: list[str] = []
+ for resource in sk.resources:
+ if resource.path not in want or not resource.content:
+ continue
+ parts.append(f"\n[Doc] {resource.path}\n\n{resource.content}\n")
+ return "".join(parts)
+
+ def _maybe_offload_loaded_skills(self, ctx: InvocationContext, loaded: list[str]) -> None:
+ if get_skill_load_mode(ctx) != SkillLoadModeNames.ONCE or not loaded:
+ return
+ for skill_name in loaded:
+ ctx.actions.state_delta[loaded_state_key(ctx, skill_name)] = None
+ ctx.actions.state_delta[docs_state_key(ctx, skill_name)] = None
+ ctx.actions.state_delta[loaded_order_state_key(ctx)] = None
+
+
+def set_skill_tool_result_processor_parameters(agent_context: AgentContext, parameters: dict[str, Any]) -> None:
+ """Set the parameters of a skill tool result processor by agent context.
+
+ Args:
+ agent_context: AgentContext object
+ parameters: Parameters to set
+ """
+ skill_config = get_skill_config(agent_context)
+ skill_config["skills_tool_result_processor"].update(parameters)
+ set_skill_config(agent_context, skill_config)
+
+
+def get_skill_tool_result_processor_parameters(agent_context: AgentContext) -> dict[str, Any]:
+ """Get the parameters of a skill tool result processor.
+
+ Args:
+ agent_context: AgentContext object
+
+ Returns:
+ Parameters of the skill tool result processor
+ """
+ skill_config = get_skill_config(agent_context)
+ return skill_config["skills_tool_result_processor"]
diff --git a/trpc_agent_sdk/agents/core/_workspace_exec_processor.py b/trpc_agent_sdk/agents/core/_workspace_exec_processor.py
new file mode 100644
index 0000000..fc2d4d9
--- /dev/null
+++ b/trpc_agent_sdk/agents/core/_workspace_exec_processor.py
@@ -0,0 +1,153 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""Inject workspace_exec guidance into request system instructions.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+from typing import Callable
+from typing import Optional
+
+from trpc_agent_sdk.context import AgentContext
+from trpc_agent_sdk.context import InvocationContext
+from trpc_agent_sdk.models import LlmRequest
+from trpc_agent_sdk.skills import BaseSkillRepository
+from trpc_agent_sdk.skills import get_skill_config
+from trpc_agent_sdk.skills import set_skill_config
+
+_WORKSPACE_EXEC_GUIDANCE_HEADER = "Executor workspace guidance:"
+
+
+class WorkspaceExecRequestProcessor:
+ """Request processor for workspace_exec guidance injection."""
+
+ def __init__(
+ self,
+ *,
+ session_tools: bool = False,
+ has_skills_repo: bool = False,
+ repo_resolver: Optional[Callable[[InvocationContext], Optional[BaseSkillRepository]]] = None,
+ enabled_resolver: Optional[Callable[[InvocationContext], bool]] = None,
+ sessions_resolver: Optional[Callable[[InvocationContext], bool]] = None,
+ ) -> None:
+ self._session_tools = session_tools
+ self._static_skills_repo = has_skills_repo
+ self._repo_resolver = repo_resolver
+ self._enabled_resolver = enabled_resolver
+ self._sessions_resolver = sessions_resolver
+
+ async def process_llm_request(self, ctx: InvocationContext, request: LlmRequest) -> None:
+ """Inject workspace guidance into request.config.system_instruction."""
+ if ctx is None or request is None:
+ return
+ guidance = self._guidance_text(ctx, request)
+ if not guidance:
+ return
+
+ existing = ""
+ if request.config and request.config.system_instruction:
+ existing = str(request.config.system_instruction)
+ if _WORKSPACE_EXEC_GUIDANCE_HEADER in existing:
+ return
+ request.append_instructions([guidance])
+
+ def _guidance_text(self, ctx: InvocationContext, request: LlmRequest) -> str:
+ if not self._enabled_for_invocation(ctx, request):
+ return ""
+ lines: list[str] = [
+ _WORKSPACE_EXEC_GUIDANCE_HEADER,
+ "- Treat workspace_exec as the default general shell runner for shared "
+ "executor-side work. It runs inside the current executor workspace, not "
+ "on the agent host; workspace is its scope, not its capability limit.",
+ "- workspace_exec starts at the workspace root by default. Prefer work/, "
+ "out/, and runs/ for shared executor-side work, and treat cwd as a "
+ "workspace-relative path.",
+ "- Network access depends on the current executor environment. If you "
+ "need a network command such as curl, use a small bounded command to "
+ "verify whether that environment allows it.",
+ "- When a limitation depends on the executor environment and a small "
+ "bounded command can verify it, verify first before claiming the "
+ "limitation. This applies to checks such as command availability, file "
+ "presence, or access to a known URL.",
+ ]
+ if self._supports_artifact_save(request):
+ lines.append("- Use workspace_save_artifact only when you need a stable artifact "
+ "reference for an already existing file in work/, out/, or runs/. "
+ "Intermediate files usually stay in the workspace.")
+ if self._has_skills_repo(ctx):
+ lines.append("- Paths under skills/ are only useful when some other tool has "
+ "already placed content there. workspace_exec does not stage skills "
+ "automatically.")
+ if self._session_tools_for_invocation(ctx, request):
+ lines.append("- When workspace_exec starts a command that keeps running or waits "
+ "for stdin, continue with workspace_write_stdin. When chars is empty, "
+ "workspace_write_stdin acts like a poll. Use workspace_kill_session "
+ "to stop a running workspace_exec session.")
+ lines.append("- Interactive workspace_exec sessions are only guaranteed within the "
+ "current invocation. Do not assume a later user message can resume "
+ "the same session.")
+ return "\n".join(lines).strip()
+
+ def _enabled_for_invocation(self, ctx: InvocationContext, request: LlmRequest) -> bool:
+ if self._enabled_resolver is not None:
+ return bool(self._enabled_resolver(ctx))
+ return self._has_tool(request, "workspace_exec")
+
+ def _session_tools_for_invocation(self, ctx: InvocationContext, request: LlmRequest) -> bool:
+ if self._sessions_resolver is not None:
+ return bool(self._sessions_resolver(ctx))
+ if self._session_tools:
+ return True
+ return self._has_tool(request, "workspace_write_stdin") and self._has_tool(request, "workspace_kill_session")
+
+ def _has_skills_repo(self, ctx: InvocationContext) -> bool:
+ if self._repo_resolver is not None:
+ return self._repo_resolver(ctx) is not None
+ return self._static_skills_repo
+
+ def _supports_artifact_save(self, request: LlmRequest) -> bool:
+ return self._has_tool(request, "workspace_save_artifact")
+
+ @staticmethod
+ def _has_tool(request: LlmRequest, tool_name: str) -> bool:
+ if request is None or request.config is None or not request.config.tools:
+ return False
+ target = (tool_name or "").strip()
+ if not target:
+ return False
+ for tool in request.config.tools:
+ declarations = getattr(tool, "function_declarations", None) or []
+ for declaration in declarations:
+ name = getattr(declaration, "name", "")
+ if (name or "").strip() == target:
+ return True
+ return False
+
+
+def set_workspace_exec_processor_parameters(agent_context: AgentContext, parameters: dict[str, Any]) -> None:
+ """Set the parameters of a workspace exec processor by agent context.
+
+ Args:
+ agent_context: AgentContext object
+ parameters: Parameters to set
+ """
+ skill_config = get_skill_config(agent_context)
+ skill_config["workspace_exec_processor"].update(parameters)
+ set_skill_config(agent_context, skill_config)
+
+
+def get_workspace_exec_processor_parameters(agent_context: AgentContext) -> dict[str, Any]:
+ """Get the parameters of a workspace exec processor.
+
+ Args:
+ agent_context: AgentContext object
+
+ Returns:
+ Parameters of the workspace exec processor
+ """
+ skill_config = get_skill_config(agent_context)
+ return skill_config["workspace_exec_processor"]
diff --git a/trpc_agent_sdk/artifacts/_in_memory_artifact_service.py b/trpc_agent_sdk/artifacts/_in_memory_artifact_service.py
index add9962..ab6c5aa 100644
--- a/trpc_agent_sdk/artifacts/_in_memory_artifact_service.py
+++ b/trpc_agent_sdk/artifacts/_in_memory_artifact_service.py
@@ -1,8 +1,6 @@
-# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+# -*- coding: utf-8 -*-
#
-# Copyright (C) 2026 Tencent. All rights reserved.
-#
-# tRPC-Agent-Python is licensed under Apache-2.0.
+# Copyright @ 2025 Tencent.com
#
# Directly reuse the types from adk-python
# Below code are copy and modified from https://github.com/google/adk-python.git
@@ -27,7 +25,6 @@
from pydantic import BaseModel
from pydantic import Field
-
from trpc_agent_sdk.abc import ArtifactEntry
from trpc_agent_sdk.abc import ArtifactId
from trpc_agent_sdk.abc import ArtifactServiceABC
@@ -80,7 +77,7 @@ async def save_artifact(
else:
raise ValueError("Not supported artifact type.")
- self.artifacts[path].append(ArtifactEntry(data=artifact, artifact_version=artifact_version))
+ self.artifacts[path].append(ArtifactEntry(data=artifact, version=artifact_version))
return version
@override
@@ -125,6 +122,8 @@ async def load_artifact(
if (artifact_data == Part() or artifact_data == Part(text="")
or (artifact_data.inline_data and not artifact_data.inline_data.data)):
return None
+ if artifact_entry.version.mime_type is None:
+ artifact_entry.version.mime_type = "application/octet-stream"
return artifact_entry
@override
@@ -175,7 +174,7 @@ async def list_artifact_versions(
entries = self.artifacts.get(path)
if not entries:
return []
- return [entry.artifact_version for entry in entries]
+ return [entry.version for entry in entries]
@override
async def get_artifact_version(
@@ -192,6 +191,6 @@ async def get_artifact_version(
if version is None:
version = -1
try:
- return entries[version].artifact_version
+ return entries[version].version
except IndexError:
return None
diff --git a/trpc_agent_sdk/artifacts/_utils.py b/trpc_agent_sdk/artifacts/_utils.py
index b9f0656..ee1444e 100644
--- a/trpc_agent_sdk/artifacts/_utils.py
+++ b/trpc_agent_sdk/artifacts/_utils.py
@@ -29,7 +29,6 @@
from typing import Optional
from google.genai import types
-
from trpc_agent_sdk.abc import ArtifactId
diff --git a/trpc_agent_sdk/code_executors/__init__.py b/trpc_agent_sdk/code_executors/__init__.py
index d16645e..276f4db 100644
--- a/trpc_agent_sdk/code_executors/__init__.py
+++ b/trpc_agent_sdk/code_executors/__init__.py
@@ -9,13 +9,9 @@
including base classes and implementations.
"""
-from ._artifacts import artifact_service_from_context
-from ._artifacts import artifact_session_from_context
from ._artifacts import load_artifact_helper
from ._artifacts import parse_artifact_ref
from ._artifacts import save_artifact_helper
-from ._artifacts import with_artifact_service
-from ._artifacts import with_artifact_session
from ._base_code_executor import BaseCodeExecutor
from ._base_workspace_runtime import BaseProgramRunner
from ._base_workspace_runtime import BaseWorkspaceFS
@@ -47,6 +43,20 @@
from ._constants import META_FILE_NAME
from ._constants import TMP_FILE_NAME
from ._constants import WORKSPACE_ENV_DIR_KEY
+from ._program_session import BaseProgramSession
+from ._program_session import DEFAULT_EXEC_YIELD_MS
+from ._program_session import DEFAULT_IO_YIELD_MS
+from ._program_session import DEFAULT_POLL_LINES
+from ._program_session import DEFAULT_SESSION_KILL_SEC
+from ._program_session import DEFAULT_SESSION_TTL_SEC
+from ._program_session import PROGRAM_STATUS_EXITED
+from ._program_session import PROGRAM_STATUS_RUNNING
+from ._program_session import ProgramLog
+from ._program_session import ProgramPoll
+from ._program_session import ProgramState
+from ._program_session import poll_line_limit
+from ._program_session import wait_for_program_output
+from ._program_session import yield_duration_ms
from ._types import CodeBlock
from ._types import CodeBlockDelimiter
from ._types import CodeExecutionInput
@@ -82,13 +92,9 @@
from .utils import CodeExecutionUtils
__all__ = [
- "artifact_service_from_context",
- "artifact_session_from_context",
"load_artifact_helper",
"parse_artifact_ref",
"save_artifact_helper",
- "with_artifact_service",
- "with_artifact_session",
"BaseCodeExecutor",
"BaseProgramRunner",
"BaseWorkspaceFS",
@@ -120,6 +126,20 @@
"META_FILE_NAME",
"TMP_FILE_NAME",
"WORKSPACE_ENV_DIR_KEY",
+ "BaseProgramSession",
+ "DEFAULT_EXEC_YIELD_MS",
+ "DEFAULT_IO_YIELD_MS",
+ "DEFAULT_POLL_LINES",
+ "DEFAULT_SESSION_KILL_SEC",
+ "DEFAULT_SESSION_TTL_SEC",
+ "PROGRAM_STATUS_EXITED",
+ "PROGRAM_STATUS_RUNNING",
+ "ProgramLog",
+ "ProgramPoll",
+ "ProgramState",
+ "poll_line_limit",
+ "wait_for_program_output",
+ "yield_duration_ms",
"CodeBlock",
"CodeBlockDelimiter",
"CodeExecutionInput",
diff --git a/trpc_agent_sdk/code_executors/_artifacts.py b/trpc_agent_sdk/code_executors/_artifacts.py
index 08e0cda..3c1087b 100644
--- a/trpc_agent_sdk/code_executors/_artifacts.py
+++ b/trpc_agent_sdk/code_executors/_artifacts.py
@@ -10,11 +10,9 @@
through context, enabling artifact resolution without importing higher-level packages.
"""
-from typing import Any
from typing import Optional
from typing import Tuple
-from trpc_agent_sdk.abc import ArtifactServiceABC
from trpc_agent_sdk.context import InvocationContext
from trpc_agent_sdk.types import Blob
from trpc_agent_sdk.types import Part
@@ -36,6 +34,8 @@ async def load_artifact_helper(ctx: InvocationContext,
Returns:
The artifact.
"""
+ if not ctx:
+ raise ValueError("ctx is required")
artifact_entry = await ctx.load_artifact(name, version)
if artifact_entry is None:
return None
@@ -55,17 +55,20 @@ def parse_artifact_ref(ref: str) -> Tuple[str, Optional[int]]:
The artifact name and version.
"""
parts = ref.split("@")
+ name = parts[0]
+ if not name:
+ raise ValueError(f"invalid ref: {ref}")
if len(parts) == 1:
- return parts[0], None
+ return name, None
- if len(parts) == 2:
+ if len(parts) >= 2:
# Try to parse version as integer
- version_str = parts[1]
+ version_str = "".join(parts[1:])
if not version_str.isdigit():
- raise ValueError(f"invalid version: {version_str}")
+ return name, None
- return parts[0], int(version_str)
+ return name, int(version_str)
raise ValueError(f"invalid ref: {ref}")
@@ -85,51 +88,3 @@ async def save_artifact_helper(ctx: InvocationContext, filename: str, data: byte
"""
artifact = Part(inline_data=Blob(data=data, mime_type=mime), )
return await ctx.save_artifact(filename, artifact)
-
-
-def with_artifact_service(ctx: InvocationContext, svc: ArtifactServiceABC) -> InvocationContext:
- """
- Store an artifact.Service in the context.
-
- Callers retrieve it in lower layers to load/save artifacts
- without importing higher-level packages.
-
- Args:
- ctx: The context to store the service in
- svc: The artifact service
-
- Returns:
- Updated context with artifact service
- """
- ctx.artifact_service = svc
- return ctx
-
-
-def artifact_service_from_context(ctx: InvocationContext) -> Optional[ArtifactServiceABC]:
- """
- Fetch the artifact.Service previously stored by with_artifact_service.
-
- Args:
- ctx: The context to retrieve the service from
-
- Returns:
- Tuple of (service, ok) where ok indicates presence
- """
- return ctx.artifact_service
-
-
-def with_artifact_session(ctx: InvocationContext, info: Any) -> InvocationContext:
- assert False, "Not implemented"
-
-
-def artifact_session_from_context(ctx: InvocationContext) -> Any:
- """
- Retrieve artifact session info from context.
-
- Args:
- ctx: The context to retrieve the session info from
-
- Returns:
- SessionInfo object (empty if not found)
- """
- assert False, "Not implemented"
diff --git a/trpc_agent_sdk/code_executors/_base_code_executor.py b/trpc_agent_sdk/code_executors/_base_code_executor.py
index 17ea06d..f7c6b06 100644
--- a/trpc_agent_sdk/code_executors/_base_code_executor.py
+++ b/trpc_agent_sdk/code_executors/_base_code_executor.py
@@ -16,7 +16,6 @@
from typing import Optional
from pydantic import BaseModel
-
from trpc_agent_sdk.context import InvocationContext
from ._base_workspace_runtime import BaseWorkspaceRuntime
@@ -82,6 +81,9 @@ class BaseCodeExecutor(BaseModel):
workspace_runtime: Optional[BaseWorkspaceRuntime] = None
"""The workspace runtime for the code execution."""
+ ignore_codes: list[str] = []
+ """The list of codes to ignore in the code execution."""
+
@abc.abstractmethod
async def execute_code(
self,
diff --git a/trpc_agent_sdk/code_executors/_base_workspace_runtime.py b/trpc_agent_sdk/code_executors/_base_workspace_runtime.py
index b57b6a9..3ee1e1e 100644
--- a/trpc_agent_sdk/code_executors/_base_workspace_runtime.py
+++ b/trpc_agent_sdk/code_executors/_base_workspace_runtime.py
@@ -12,10 +12,12 @@
from abc import ABC
from abc import abstractmethod
+from typing import Callable
from typing import List
from typing import Optional
from trpc_agent_sdk.context import InvocationContext
+from trpc_agent_sdk.log import logger
from ._types import CodeFile
from ._types import ManifestOutput
@@ -28,6 +30,8 @@
from ._types import WorkspaceRunResult
from ._types import WorkspaceStageOptions
+RunEnvProvider = Callable[[Optional[InvocationContext]], dict[str, str]]
+
class BaseWorkspaceManager(ABC):
"""
@@ -130,6 +134,40 @@ class BaseProgramRunner(ABC):
Executes programs within a workspace.
"""
+ def __init__(
+ self,
+ provider: Optional[RunEnvProvider] = None,
+ enable_provider_env: bool = False,
+ ) -> None:
+ self._run_env_provider = provider
+ self._enable_provider_env = bool(enable_provider_env and provider)
+
+ def _apply_provider_env(
+ self,
+ spec: WorkspaceRunProgramSpec,
+ ctx: Optional[InvocationContext] = None,
+ ) -> WorkspaceRunProgramSpec:
+ """Return spec with provider env merged when enabled.
+
+ Provider values never override keys already present in ``spec.env``.
+ The input ``spec`` is not mutated.
+ """
+ provider = getattr(self, "_run_env_provider", None)
+ if not getattr(self, "_enable_provider_env", False) or provider is None:
+ return spec
+ try:
+ extra = provider(ctx) or {}
+ except Exception as ex: # pylint: disable=broad-except
+ logger.warning("run env provider failed: %s", ex)
+ return spec
+ if not extra:
+ return spec
+ merged = dict(spec.env or {})
+ for key, value in extra.items():
+ if key not in merged:
+ merged[key] = value
+ return spec.model_copy(update={"env": merged}, deep=True)
+
@abstractmethod
async def run_program(
self,
diff --git a/trpc_agent_sdk/code_executors/_program_session.py b/trpc_agent_sdk/code_executors/_program_session.py
new file mode 100644
index 0000000..177a00a
--- /dev/null
+++ b/trpc_agent_sdk/code_executors/_program_session.py
@@ -0,0 +1,167 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright @ 2025 Tencent.com
+"""Program-session helpers.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import time
+from abc import ABC
+from abc import abstractmethod
+from dataclasses import dataclass
+from typing import Optional
+
+from ._types import WorkspaceRunResult
+
+# Default wait windows (milliseconds).
+DEFAULT_EXEC_YIELD_MS = 1_000
+DEFAULT_IO_YIELD_MS = 400
+DEFAULT_POLL_LINES = 40
+
+# Poll pacing / settle windows (seconds).
+DEFAULT_POLL_WAIT_SEC = 0.05
+DEFAULT_POLL_SETTLE_SEC = 0.075
+
+# Session lifecycle defaults (seconds).
+DEFAULT_SESSION_TTL_SEC = 30 * 60
+DEFAULT_SESSION_KILL_SEC = 2.0
+
+PROGRAM_STATUS_RUNNING = "running"
+PROGRAM_STATUS_EXITED = "exited"
+
+
+@dataclass
+class ProgramPoll:
+ """Incremental output chunk for a running or exited session."""
+
+ status: str = PROGRAM_STATUS_RUNNING
+ output: str = ""
+ offset: int = 0
+ next_offset: int = 0
+ exit_code: Optional[int] = None
+
+
+@dataclass
+class ProgramLog:
+ """Non-destructive output window from a specific offset."""
+
+ output: str = ""
+ offset: int = 0
+ next_offset: int = 0
+
+
+@dataclass
+class ProgramState:
+ """Non-streaming session status without cursor mutation."""
+
+ status: str = PROGRAM_STATUS_RUNNING
+ exit_code: Optional[int] = None
+
+
+class BaseProgramSession(ABC):
+ """Base class for program sessions."""
+
+ @abstractmethod
+ def id(self) -> str:
+ """Return stable session id."""
+
+ @abstractmethod
+ async def poll(self, limit: Optional[int] = None) -> ProgramPoll:
+ """Advance cursor and return incremental output."""
+
+ @abstractmethod
+ async def log(self, offset: Optional[int] = None, limit: Optional[int] = None) -> ProgramLog:
+ """Read output from offset without advancing cursor."""
+
+ @abstractmethod
+ async def write(self, data: str, newline: bool) -> None:
+ """Write input to session."""
+
+ @abstractmethod
+ async def kill(self, grace_seconds: float) -> None:
+ """Terminate session, escalating after grace period."""
+
+ @abstractmethod
+ async def close(self) -> None:
+ """Release resources and stop background routines."""
+
+ @abstractmethod
+ async def state(self) -> ProgramState:
+ """Return current state snapshot."""
+
+ @abstractmethod
+ async def run_result(self) -> WorkspaceRunResult:
+ """Return final run result after session exits."""
+
+
+def yield_duration_ms(ms: int, fallback_ms: int) -> float:
+ """Normalize milliseconds into seconds with fallback and clamping."""
+ if ms < 0:
+ ms = 0
+ if ms == 0:
+ ms = fallback_ms
+ return ms / 1000.0
+
+
+def poll_line_limit(lines: int) -> int:
+ """Return a positive poll-line limit with default fallback."""
+ if lines <= 0:
+ lines = DEFAULT_POLL_LINES
+ return lines
+
+
+async def wait_for_program_output(
+ proc: BaseProgramSession,
+ yield_seconds: float,
+ limit: Optional[int],
+) -> ProgramPoll:
+ """Poll until session exits or output settles within the yield window."""
+ deadline = time.monotonic()
+ if yield_seconds > 0:
+ deadline += yield_seconds
+
+ out_parts: list[str] = []
+ offset = 0
+ next_offset = 0
+ have_chunk = False
+ settle_deadline = 0.0
+
+ while True:
+ poll = await proc.poll(limit)
+ if poll.output:
+ if not have_chunk:
+ offset = poll.offset
+ have_chunk = True
+ out_parts.append(poll.output)
+ next_offset = poll.next_offset
+ settle_deadline = time.monotonic() + DEFAULT_POLL_SETTLE_SEC
+ if yield_seconds <= 0:
+ deadline = settle_deadline
+ elif not have_chunk:
+ offset = poll.offset
+ next_offset = poll.next_offset
+ else:
+ next_offset = poll.next_offset
+
+ if poll.status == PROGRAM_STATUS_EXITED:
+ poll.output = "".join(out_parts)
+ poll.offset = offset
+ poll.next_offset = next_offset
+ return poll
+
+ now = time.monotonic()
+ if settle_deadline and now > settle_deadline:
+ poll.output = "".join(out_parts)
+ poll.offset = offset
+ poll.next_offset = next_offset
+ return poll
+
+ if yield_seconds > 0 and now > deadline:
+ poll.output = "".join(out_parts)
+ poll.offset = offset
+ poll.next_offset = next_offset
+ return poll
+
+ await asyncio.sleep(DEFAULT_POLL_WAIT_SEC)
diff --git a/trpc_agent_sdk/code_executors/_types.py b/trpc_agent_sdk/code_executors/_types.py
index 3323201..5130d48 100644
--- a/trpc_agent_sdk/code_executors/_types.py
+++ b/trpc_agent_sdk/code_executors/_types.py
@@ -11,7 +11,6 @@
from pydantic import BaseModel
from pydantic import Field
-
from trpc_agent_sdk.types import CodeExecutionResult
from trpc_agent_sdk.types import Outcome
@@ -149,6 +148,9 @@ class WorkspaceRunProgramSpec(BaseModel):
limits: WorkspaceResourceLimits = Field(default_factory=WorkspaceResourceLimits)
""" resource limits"""
+ tty: bool = Field(default=False, description="Allocate pseudo-TTY")
+ """ whether to allocate pseudo-TTY"""
+
class WorkspaceRunResult(BaseModel):
"""
diff --git a/trpc_agent_sdk/code_executors/container/_container_cli.py b/trpc_agent_sdk/code_executors/container/_container_cli.py
index 2e06d0c..39447ad 100644
--- a/trpc_agent_sdk/code_executors/container/_container_cli.py
+++ b/trpc_agent_sdk/code_executors/container/_container_cli.py
@@ -1,8 +1,6 @@
-# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+# -*- coding: utf-8 -*-
#
-# Copyright (C) 2026 Tencent. All rights reserved.
-#
-# tRPC-Agent-Python is licensed under Apache-2.0.
+# Copyright @ 2025 Tencent.com
"""Container code executor for TRPC Agent framework.
This module provides a code executor that uses a custom container to execute code.
@@ -14,12 +12,15 @@
import asyncio
import atexit
import os
+import socket as pysocket
from dataclasses import dataclass
from typing import Optional
import docker
from docker.models.containers import Container
-
+from docker.utils.socket import consume_socket_output
+from docker.utils.socket import demux_adaptor
+from docker.utils.socket import frames_iter
from trpc_agent_sdk.log import logger
from trpc_agent_sdk.utils import CommandExecResult
@@ -48,6 +49,8 @@ class CommandArgs:
"""The environment variables for the command execution."""
timeout: Optional[float] = None
"""The timeout for the command execution in seconds."""
+ stdin: Optional[str] = None
+ """Optional stdin content to write once before reading output."""
class ContainerClient:
@@ -151,6 +154,16 @@ def _init_container(self):
# docker SDK `run` supports bind specs via `volumes`.
run_kwargs["volumes"] = binds
logger.info("Container bind mounts enabled: %s", binds)
+ command = self.host_config.get("command", ["tail", "-f", "/dev/null"])
+ stdin = self.host_config.get("stdin", True)
+ working_dir = self.host_config.get("working_dir", "/")
+ network_mode = self.host_config.get("network_mode", "none")
+ auto_remove = self.host_config.get("auto_remove", True)
+ run_kwargs.setdefault("command", command)
+ run_kwargs.setdefault("stdin_open", stdin)
+ run_kwargs.setdefault("working_dir", working_dir)
+ run_kwargs.setdefault("network_mode", network_mode)
+ run_kwargs.setdefault("auto_remove", auto_remove)
self._container = self._client.containers.run(
image=self.image,
detach=True,
@@ -189,23 +202,100 @@ def _cleanup_container(self):
return
logger.info("[Cleanup] Stopping the container...")
- self._container.stop()
- self._container.remove()
+ try:
+ self._container.stop()
+ except Exception: # pylint: disable=broad-except
+ pass
+ try:
+ self._container.remove()
+ except Exception: # pylint: disable=broad-except
+ pass
logger.info("Container %s stopped and removed.", self._container.id)
# self._container = None
+ def _exec_run_with_stdin(
+ self,
+ cmd: list[str],
+ environment: dict[str, str],
+ stdin: str,
+ ) -> CommandExecResult:
+ """Execute command with attached stdin, similar to docker exec attach."""
+ resp = self.container.client.api.exec_create(
+ self.container.container.id,
+ cmd=cmd[:],
+ stdout=True,
+ stderr=True,
+ stdin=True,
+ tty=False,
+ environment=environment,
+ )
+ exec_id = resp["Id"]
+ sock = self.container.client.api.exec_start(
+ exec_id,
+ detach=False,
+ tty=False,
+ stream=False,
+ socket=True,
+ demux=False,
+ )
+ try:
+ data = (stdin or "").encode("utf-8")
+ if data:
+ try:
+ sock.sendall(data)
+ except Exception: # pylint: disable=broad-except
+ # Some transports expose the real socket as _sock.
+ sock._sock.sendall(data) # pylint: disable=protected-access
+
+ try:
+ sock.shutdown(pysocket.SHUT_WR)
+ except Exception: # pylint: disable=broad-except
+ close_write = getattr(sock, "close_write", None)
+ if callable(close_write):
+ close_write()
+
+ frames = frames_iter(sock, tty=False)
+ demux_frames = (demux_adaptor(*frame) for frame in frames)
+ output = consume_socket_output(demux_frames, demux=True)
+ stdout = output[0].decode("utf-8") if output and output[0] else ""
+ stderr = output[1].decode("utf-8") if output and output[1] else ""
+ finally:
+ try:
+ sock.close()
+ except Exception: # pylint: disable=broad-except
+ pass
+
+ inspect = self.container.client.api.exec_inspect(exec_id)
+ exit_code = int(inspect.get("ExitCode", -1))
+ return CommandExecResult(stdout=stdout, stderr=stderr, exit_code=exit_code, is_timeout=False)
+
async def exec_run(self, cmd: list[str], command_args: CommandArgs) -> CommandExecResult:
"""Execute command in container."""
timeout = command_args.timeout
try:
loop = asyncio.get_event_loop()
- co = loop.run_in_executor(
- None,
- lambda: self.container.exec_run(cmd=cmd[:], demux=True, environment=command_args.environment or {}))
+ if command_args.stdin:
+ co = loop.run_in_executor(
+ None,
+ lambda: self._exec_run_with_stdin(
+ cmd,
+ command_args.environment or {},
+ command_args.stdin or "",
+ ),
+ )
+ else:
+ co = loop.run_in_executor(
+ None,
+ lambda: self.container.exec_run(cmd=cmd[:], demux=True, environment=command_args.environment or {}))
if command_args.timeout:
- exit_code, output = await asyncio.wait_for(co, timeout=command_args.timeout)
+ result = await asyncio.wait_for(co, timeout=command_args.timeout)
else:
- exit_code, output = await co
+ result = await co
+
+ if command_args.stdin:
+ return result
+
+ exit_code, output = result
stdout = output[0].decode('utf-8') if output[0] else ""
stderr = output[1].decode('utf-8') if output[1] else ""
except asyncio.TimeoutError:
diff --git a/trpc_agent_sdk/code_executors/container/_container_code_executor.py b/trpc_agent_sdk/code_executors/container/_container_code_executor.py
index 13ba145..56eec06 100644
--- a/trpc_agent_sdk/code_executors/container/_container_code_executor.py
+++ b/trpc_agent_sdk/code_executors/container/_container_code_executor.py
@@ -15,7 +15,6 @@
from typing_extensions import override
from pydantic import Field
-
from trpc_agent_sdk.context import InvocationContext
from .._base_code_executor import BaseCodeExecutor
diff --git a/trpc_agent_sdk/code_executors/container/_container_ws_runtime.py b/trpc_agent_sdk/code_executors/container/_container_ws_runtime.py
index 7bf5b34..ce9fb0f 100644
--- a/trpc_agent_sdk/code_executors/container/_container_ws_runtime.py
+++ b/trpc_agent_sdk/code_executors/container/_container_ws_runtime.py
@@ -15,12 +15,15 @@
"""
import io
+import json
import os
import tarfile
import time
from dataclasses import dataclass
from dataclasses import field
+from datetime import datetime
from pathlib import Path
+from typing import Any
from typing import Dict
from typing import List
from typing import Optional
@@ -37,9 +40,13 @@
from .._base_workspace_runtime import BaseWorkspaceFS
from .._base_workspace_runtime import BaseWorkspaceManager
from .._base_workspace_runtime import BaseWorkspaceRuntime
+from .._base_workspace_runtime import RunEnvProvider
from .._constants import DEFAULT_INPUTS_CONTAINER
+from .._constants import DEFAULT_MAX_FILES
+from .._constants import DEFAULT_MAX_TOTAL_BYTES
from .._constants import DEFAULT_RUN_CONTAINER_BASE
from .._constants import DEFAULT_SKILLS_CONTAINER
+from .._constants import DEFAULT_TIMEOUT_SEC
from .._constants import DIR_OUT
from .._constants import DIR_RUNS
from .._constants import DIR_SKILLS
@@ -62,7 +69,10 @@
from .._types import WorkspaceRunProgramSpec
from .._types import WorkspaceRunResult
from .._types import WorkspaceStageOptions
+from ..utils import InputRecordMeta
+from ..utils import WorkspaceMetadata
from ..utils import get_rel_path
+from ..utils import normalize_globs
from ._container_cli import CommandArgs
from ._container_cli import ContainerClient
from ._container_cli import ContainerConfig
@@ -78,10 +88,16 @@ class RuntimeConfig:
run_container_base: str = DEFAULT_RUN_CONTAINER_BASE
inputs_host_base: str = ""
inputs_container_base: str = DEFAULT_INPUTS_CONTAINER
- auto_map_inputs: bool = False
+ auto_map_inputs: bool = True
command_args: CommandArgs = field(default_factory=CommandArgs)
+def _shell_quote(s: str) -> str:
+ if not s:
+ return "''"
+ return "'" + s.replace("'", "'\\''") + "'"
+
+
class ContainerWorkspaceManager(BaseWorkspaceManager):
"""
Docker container-based workspace manager implementation.
@@ -124,12 +140,15 @@ async def create_workspace(self, exec_id: str, ctx: Optional[InvocationContext]
ws_path = str(Path(self.config.run_container_base) / f"ws_{safe_id}_{suffix}")
# Create standard directory layout
- cmd_parts = [
- f"mkdir -p '{ws_path}'", f"'{ws_path}/{DIR_SKILLS}'", f"'{ws_path}/{DIR_WORK}'", f"'{ws_path}/{DIR_RUNS}'",
- f"'{ws_path}/{DIR_OUT}'",
- f"&& [ -f '{ws_path}/{META_FILE_NAME}' ] || echo '{{}}' > '{ws_path}/{META_FILE_NAME}'"
- ]
- cmd = ["/bin/bash", "-lc", " ".join(cmd_parts)]
+ cmd_str = ("set -e; "
+ f"mkdir -p {_shell_quote(ws_path)} "
+ f"{_shell_quote(str(Path(ws_path) / DIR_SKILLS))} "
+ f"{_shell_quote(str(Path(ws_path) / DIR_WORK))} "
+ f"{_shell_quote(str(Path(ws_path) / DIR_RUNS))} "
+ f"{_shell_quote(str(Path(ws_path) / DIR_OUT))}; "
+ f"[ -f {_shell_quote(str(Path(ws_path) / META_FILE_NAME))} ] || "
+ f"echo '{{}}' > {_shell_quote(str(Path(ws_path) / META_FILE_NAME))}")
+ cmd = ["/bin/bash", "-lc", cmd_str]
result = await self.container.exec_run(cmd=cmd, command_args=self.config.command_args)
if result.exit_code != 0:
@@ -268,8 +287,8 @@ async def stage_directory(self,
cmd = ["/bin/bash", "-lc", cmd_str]
result = await self.container.exec_run(cmd=cmd, command_args=self.config.command_args)
if result.exit_code != 0:
- logger.debug("Failed to stage directory: %s", result.stderr)
- logger.debug("Staged directory using mount: %s -> %s", container_src, container_dst)
+ raise RuntimeError(f"Failed to stage directory: {result.stderr}")
+ return
# Fallback: tar copy
await self._put_directory(ws, src_abs_path, dst)
@@ -327,8 +346,15 @@ async def collect(self,
continue
seen.add(rel_path)
- data, mime = self._copy_file_out(line)
- files.append(CodeFile(name=rel_path, content=data.decode('utf-8', errors='replace'), mime_type=mime))
+ data, size_bytes, mime = self._copy_file_out(line)
+ files.append(
+ CodeFile(
+ name=rel_path,
+ content=data.decode('utf-8', errors='replace'),
+ mime_type=mime,
+ size_bytes=size_bytes,
+ truncated=size_bytes > len(data),
+ ))
logger.info("Collected %s files from workspace", len(files))
return files
@@ -349,32 +375,56 @@ async def stage_inputs(self,
Raises:
RuntimeError: If staging fails
"""
+ md = await self._load_workspace_metadata(ws)
for spec in specs:
- mode = spec.mode.lower().strip() or "copy"
- dst = spec.dst.strip() or str(Path(DIR_WORK) / "inputs" / self._input_base(spec.src))
- dst = os.path.join(ws.path, dst)
+ mode = (spec.mode or "").lower().strip() or "copy"
+ dst_rel = (spec.dst or "").strip() or str(Path(DIR_WORK) / "inputs" / self._input_base(spec.src))
+ dst_abs = str(Path(ws.path) / dst_rel)
+
+ resolved = ""
+ version: Optional[int] = None
if spec.src.startswith("artifact://"):
- name = spec.src.removeprefix("artifact://")
- resolved, ver = parse_artifact_ref(name)
if not ctx:
raise ValueError("Context is required to load artifacts")
- content, ver = await load_artifact_helper(ctx, resolved, ver)
- await self._put_bytes_tar(content, dst)
+ name = spec.src.removeprefix("artifact://")
+ artifact_name, requested_ver = parse_artifact_ref(name)
+ use_ver = requested_ver
+ if use_ver is None and spec.pin:
+ use_ver = self._pinned_artifact_version(md, artifact_name, dst_rel)
+ content, actual_ver = await load_artifact_helper(ctx, artifact_name, use_ver)
+ await self._put_bytes_tar(content, dst_abs)
+ resolved = artifact_name
+ version = use_ver if use_ver is not None else actual_ver
elif spec.src.startswith("host://"):
host_path = spec.src.removeprefix("host://")
- await self._stage_host_input(ws, host_path, dst, mode)
+ await self._stage_host_input(ws, host_path, dst_abs, mode, dst_rel)
+ resolved = host_path
elif spec.src.startswith("workspace://"):
rel = spec.src.removeprefix("workspace://")
src = str(Path(ws.path) / rel)
- await self._stage_workspace_input(src, dst, mode)
+ await self._stage_workspace_input(src, dst_abs, mode)
+ resolved = rel
elif spec.src.startswith("skill://"):
rest = spec.src.removeprefix("skill://")
src = str(Path(ws.path) / DIR_SKILLS / rest)
- await self._stage_workspace_input(src, dst, mode)
+ await self._stage_workspace_input(src, dst_abs, mode)
+ resolved = src
else:
raise RuntimeError(f"Unsupported input: {spec.src}")
+ md.inputs.append(
+ InputRecordMeta(
+ src=spec.src,
+ dst=dst_rel,
+ resolved=resolved,
+ version=version,
+ mode=mode,
+ timestamp=datetime.now(),
+ ))
+
+ await self._save_workspace_metadata(ws, md)
+
logger.info("Staged %s inputs into workspace", len(specs))
@override
@@ -409,13 +459,15 @@ async def collect_outputs(self,
raise RuntimeError(f"Failed to collect outputs: {result.stderr}")
stdout = result.stdout
- max_files = spec.max_files or 100
+ max_files = spec.max_files or DEFAULT_MAX_FILES
max_file_bytes = spec.max_file_bytes or MAX_READ_SIZE_BYTES
- max_total = spec.max_total_bytes or 64 * 1024 * 1024
+ max_total = spec.max_total_bytes or DEFAULT_MAX_TOTAL_BYTES
manifest = ManifestOutput()
total_bytes = 0
count = 0
+ saved_names: list[str] = []
+ saved_versions: list[int] = []
for line in stdout.strip().split('\n'):
line = line.strip()
if not line:
@@ -425,14 +477,25 @@ async def collect_outputs(self,
manifest.limits_hit = True
break
- data, mime = self._copy_file_out(line)
+ data, raw_size, mime = self._copy_file_out(line)
if len(data) > max_file_bytes:
data = data[:max_file_bytes]
manifest.limits_hit = True
+ if total_bytes + len(data) > max_total:
+ remain = max_total - total_bytes
+ if remain <= 0:
+ manifest.limits_hit = True
+ break
+ data = data[:remain]
+ manifest.limits_hit = True
+
total_bytes += len(data)
rel_path = line.removeprefix(f"{ws.path}/")
+ truncated = raw_size > len(data)
+ if truncated and spec.save:
+ raise RuntimeError(f"cannot save truncated output file: {rel_path}")
file_ref = ManifestFileRef(name=rel_path, mime_type=mime)
if spec.inline:
@@ -446,6 +509,8 @@ async def collect_outputs(self,
version = await save_artifact_helper(ctx, save_name, data, mime)
file_ref.saved_as = save_name
file_ref.version = version
+ saved_names.append(save_name)
+ saved_versions.append(version)
manifest.files.append(file_ref)
count += 1
@@ -455,6 +520,8 @@ async def collect_outputs(self,
async def _put_directory(self, ws: WorkspaceInfo, src: str, dst: str) -> None:
"""Copy directory to container using tar."""
+ if not src or not str(src).strip():
+ raise ValueError("source path is empty")
abs_src = os.path.abspath(src)
container_dst = str(Path(ws.path) / dst) if dst else ws.path
if self.config.skills_host_base:
@@ -464,8 +531,9 @@ async def _put_directory(self, ws: WorkspaceInfo, src: str, dst: str) -> None:
# Create destination directory
cmd = ["/bin/bash", "-lc", f"mkdir -p '{container_dst}' && cp -a '{container_src}/.' '{container_dst}'"]
result = await self.container.exec_run(cmd=cmd, command_args=self.config.command_args)
- if result.exit_code:
- logger.debug("Failed to stage directory: %s", result.stderr)
+ if result.exit_code == 0:
+ return None
+ logger.debug("Failed to stage directory via mount copy, fallback to tar: %s", result.stderr)
cmd = ["/bin/bash", "-lc", f"[ -e '{container_dst}' ] || mkdir -p '{container_dst}'"]
result = await self.container.exec_run(cmd=cmd, command_args=self.config.command_args)
@@ -507,7 +575,7 @@ async def _put_bytes_tar(self, data: bytes, dest: str, mode: int = 0o644) -> Non
if not success:
raise RuntimeError(f"Failed to copy bytes to {dest}")
- async def _stage_host_input(self, ws: WorkspaceInfo, host: str, dst: str, mode: str) -> None:
+ async def _stage_host_input(self, ws: WorkspaceInfo, host: str, dst: str, mode: str, dst_rel: str) -> None:
"""Stage input from host path."""
if self.config.inputs_host_base:
rel_path = get_rel_path(self.config.inputs_host_base, host)
@@ -525,11 +593,11 @@ async def _stage_host_input(self, ws: WorkspaceInfo, host: str, dst: str, mode:
cmd = ["/bin/bash", "-lc", cmd_str]
result = await self.container.exec_run(cmd=cmd, command_args=self.config.command_args)
- if result.exit_code:
- logger.debug("Failed to stage input: %s", result.stderr)
+ if result.exit_code != 0:
+ raise RuntimeError(f"Failed to stage host input: {result.stderr}")
return
# Fallback to tar copy
- await self._put_directory(ws, host, str(Path(dst).parent))
+ await self._put_directory(ws, host, str(Path(dst_rel).parent))
async def _stage_workspace_input(self, src: str, dst: str, mode: str) -> None:
"""Stage input from workspace path."""
@@ -547,7 +615,7 @@ async def _stage_workspace_input(self, src: str, dst: str, mode: str) -> None:
if result.exit_code:
raise RuntimeError(f"Failed to stage input: {result.stderr}")
- def _copy_file_out(self, full_path: str) -> Tuple[bytes, str]:
+ def _copy_file_out(self, full_path: str) -> Tuple[bytes, int, str]:
"""
Copy file out of container.
@@ -555,7 +623,7 @@ def _copy_file_out(self, full_path: str) -> Tuple[bytes, str]:
full_path: Full path to file in container
Returns:
- Tuple of (file_data, mime_type)
+ Tuple of (file_data, size_bytes, mime_type)
Raises:
RuntimeError: If copy fails
@@ -570,7 +638,7 @@ def _copy_file_out(self, full_path: str) -> Tuple[bytes, str]:
f = tar.extractfile(member)
data = f.read(MAX_READ_SIZE_BYTES)
mime = self._detect_mime_type(data)
- return data, mime
+ return data, member.size, mime
raise RuntimeError(f"No file found in archive: {full_path}")
@@ -597,24 +665,75 @@ def _create_tar_from_files(files: List[WorkspacePutFileInfo]) -> io.BytesIO:
@staticmethod
def _normalize_globs(patterns: List[str]) -> List[str]:
"""Normalize glob patterns."""
- normalized = []
- for p in patterns:
- p = p.strip()
- if p:
- # Simple normalization - replace environment variables
- p = p.replace("$OUTPUT_DIR", DIR_OUT)
- p = p.replace("${OUTPUT_DIR}", DIR_OUT)
- p = p.replace("$WORK_DIR", DIR_WORK)
- p = p.replace("${WORK_DIR}", DIR_WORK)
- p = p.replace("$WORKSPACE_DIR", ".")
- p = p.replace("${WORKSPACE_DIR}", ".")
- normalized.append(p)
- return normalized
+ return normalize_globs(patterns)
@staticmethod
- def _input_base(path: str) -> str:
+ def _input_base(src: str) -> str:
"""Extract base name from input path."""
- return Path(path).name
+ s = (src or "").strip()
+ if s.startswith("artifact://"):
+ ref = s.removeprefix("artifact://")
+ try:
+ name, _ = parse_artifact_ref(ref)
+ base = Path(name.strip()).name
+ if base and base not in (".", "..", "/"):
+ return base
+ except Exception: # pylint: disable=broad-except
+ pass
+ return Path(s).name
+
    @staticmethod
    def _pinned_artifact_version(md: Any, artifact_name: str, dst: str) -> Optional[int]:
        """Return the most recently recorded pinned version for an artifact input.

        Scans ``md.inputs`` newest-first for a record staged at the same *dst*
        that carries a version and whose resolved name — or whose
        ``artifact://`` src, once parsed — matches *artifact_name*.

        Args:
            md: Workspace metadata object exposing an ``inputs`` record list.
            artifact_name: Artifact name to match against resolved/src fields.
            dst: Staging destination path the record must match exactly.

        Returns:
            The pinned version from the newest matching record, or None when
            no versioned match exists.
        """
        for record in reversed(md.inputs or []):
            if (record.dst or "") != dst:
                continue
            if record.version is None:
                continue
            # Fast path: the record already stores the resolved artifact name.
            if (record.resolved or "") == artifact_name:
                return record.version
            src = record.src or ""
            if not src.startswith("artifact://"):
                continue
            try:
                name, _ = parse_artifact_ref(src.removeprefix("artifact://"))
            except Exception: # pylint: disable=broad-except
                # Malformed refs are skipped rather than failing the lookup.
                continue
            if name == artifact_name:
                return record.version
        return None
+
    async def _load_workspace_metadata(self, ws: WorkspaceInfo) -> WorkspaceMetadata:
        """Read and normalize workspace metadata from META_FILE_NAME in *ws*.

        Returns a fresh default ``WorkspaceMetadata`` when the file is
        missing, unreadable, or empty; otherwise parses it as JSON and
        backfills absent version/created_at/skills fields. ``last_access``
        is always refreshed to now.

        Raises:
            RuntimeError: If the file exists but cannot be parsed.
        """
        now = datetime.now()
        # cat inside "bash -lc" so the container's login-shell environment applies.
        cmd = ["/bin/bash", "-lc", f"cat {_shell_quote(str(Path(ws.path) / META_FILE_NAME))}"]
        result = await self.container.exec_run(cmd=cmd, command_args=self.config.command_args)
        if result.exit_code != 0 or not result.stdout.strip():
            # Missing or empty metadata is not an error: start from defaults.
            return WorkspaceMetadata(version=1, created_at=now, updated_at=now, last_access=now, skills={})
        try:
            data = json.loads(result.stdout)
            md = WorkspaceMetadata(**data)
        except Exception as ex: # pylint: disable=broad-except
            raise RuntimeError(f"Failed to parse workspace metadata: {ex}") from ex
        # Backfill fields that older or partially written metadata may lack.
        if not md.version:
            md.version = 1
        if md.created_at is None:
            md.created_at = now
        md.last_access = now
        if md.skills is None:
            md.skills = {}
        return md
+
    async def _save_workspace_metadata(self, ws: WorkspaceInfo, md: Any) -> None:
        """Normalize *md* and persist it to META_FILE_NAME inside *ws*.

        Args:
            ws: Target workspace.
            md: Metadata model (must expose ``model_dump`` — presumably a
                pydantic model; confirm against WorkspaceMetadata). Mutated
                in place: ``updated_at``/``last_access`` are set to now and
                missing version/created_at/skills are backfilled.
        """
        now = datetime.now()
        if not md.version:
            md.version = 1
        if md.created_at is None:
            md.created_at = now
        md.updated_at = now
        md.last_access = now
        if md.skills is None:
            md.skills = {}
        # Pretty-printed JSON; mode 0o600 keeps the file owner-readable only.
        payload = json.dumps(md.model_dump(exclude_none=True, by_alias=True, mode="json"), ensure_ascii=False, indent=2)
        await self._put_bytes_tar(payload.encode("utf-8"), str(Path(ws.path) / META_FILE_NAME), mode=0o600)
@staticmethod
def _detect_mime_type(data: bytes) -> str:
@@ -637,7 +756,13 @@ class ContainerProgramRunner(BaseProgramRunner):
Docker container-based program runner implementation.
"""
- def __init__(self, container: ContainerClient, config: RuntimeConfig):
+ def __init__(
+ self,
+ container: ContainerClient,
+ config: RuntimeConfig,
+ provider: Optional[RunEnvProvider] = None,
+ enable_provider_env: bool = False,
+ ):
"""
Initialize container program runner.
@@ -646,6 +771,7 @@ def __init__(self, container: ContainerClient, config: RuntimeConfig):
container: Docker container to use
config: Runtime configuration
"""
+ super().__init__(provider=provider, enable_provider_env=enable_provider_env)
self.container = container
self.config = config
@@ -667,6 +793,7 @@ async def run_program(self,
Raises:
RuntimeError: If execution fails
"""
+ spec = self._apply_provider_env(spec, ctx)
cwd = f"{ws.path}/{spec.cwd}" if spec.cwd else ws.path
# Prepare directories
@@ -685,35 +812,41 @@ async def run_program(self,
}
env_parts = []
+ user_env = dict(spec.env or {})
for k, v in base_env.items():
- if k not in spec.env:
- env_parts.append(f"{k}={self._shell_quote(v)}")
+ if k not in user_env:
+ env_parts.append(f"{k}={_shell_quote(v)}")
- for k, v in spec.env.items():
- env_parts.append(f"{k}={self._shell_quote(v)}")
+ for k, v in user_env.items():
+ env_parts.append(f"{k}={_shell_quote(v)}")
env_str = " ".join(env_parts)
# Build command line
cmd_parts = [
- f"mkdir -p {self._shell_quote(run_dir)} {self._shell_quote(out_dir)}", f"&& cd {self._shell_quote(cwd)}",
+ f"mkdir -p {_shell_quote(run_dir)} {_shell_quote(out_dir)}", f"&& cd {_shell_quote(cwd)}",
"&& env" if env_str else "", env_str,
- self._shell_quote(spec.cmd)
+ _shell_quote(spec.cmd)
]
for arg in spec.args:
- cmd_parts.append(self._shell_quote(arg))
+ cmd_parts.append(_shell_quote(arg))
cmd_str = " ".join(filter(None, cmd_parts))
cmd = ["/bin/bash", "-lc", cmd_str]
start_time = time.time()
- timeout = spec.timeout or self.config.command_args.timeout
- if timeout is None:
+ if spec.timeout and spec.timeout > 0:
timeout = spec.timeout
+ elif self.config.command_args.timeout and self.config.command_args.timeout > 0:
+ timeout = self.config.command_args.timeout
else:
- timeout = min(timeout, spec.timeout)
- command_args = CommandArgs(environment=None, timeout=timeout)
+ timeout = float(DEFAULT_TIMEOUT_SEC)
+ command_args = CommandArgs(
+ environment=None,
+ timeout=timeout,
+ stdin=spec.stdin or None,
+ )
result = await self.container.exec_run(cmd=cmd, command_args=command_args)
return WorkspaceRunResult(stdout=result.stdout,
stderr=result.stderr,
@@ -721,28 +854,20 @@ async def run_program(self,
duration=time.time() - start_time,
timed_out=result.is_timeout)
- @staticmethod
- def _shell_quote(s: str) -> str:
- """
- Quote string for safe shell usage.
-
- Args:
- s: String to quote
-
- Returns:
- Quoted string
- """
- if not s:
- return "''"
- return "'" + s.replace("'", "'\\''") + "'"
-
class ContainerWorkspaceRuntime(BaseWorkspaceRuntime):
"""
Docker container-based execution engine.
"""
- def __init__(self, container: ContainerClient, host_config: Optional[Dict] = None, auto_inputs: bool = False):
+ def __init__(
+ self,
+ container: ContainerClient,
+ host_config: Optional[Dict] = None,
+ auto_inputs: bool = True,
+ provider: Optional[RunEnvProvider] = None,
+ enable_provider_env: bool = False,
+ ):
"""
Initialize container engine.
@@ -763,7 +888,12 @@ def __init__(self, container: ContainerClient, host_config: Optional[Dict] = Non
self._fs = ContainerWorkspaceFS(self.container, config)
self._manager = ContainerWorkspaceManager(self.container, config, self._fs)
- self._runner = ContainerProgramRunner(self.container, config)
+ self._runner = ContainerProgramRunner(
+ self.container,
+ config,
+ provider=provider,
+ enable_provider_env=enable_provider_env,
+ )
@override
def manager(self, ctx: Optional[InvocationContext] = None) -> ContainerWorkspaceManager:
@@ -797,8 +927,8 @@ def _find_bind_source(binds: List[str], dest: str) -> str:
if len(parts) < 2:
continue
- # Handle format: source:dest[:mode]
- bind_dest = parts[-2] if len(parts) >= 2 else ""
+ # Handle format: source:dest[:mode], parse from right.
+ bind_dest = parts[-2]
if bind_dest == dest:
source = ':'.join(parts[:-2]) if len(parts) > 2 else parts[0]
if Path(source).is_dir():
@@ -820,7 +950,9 @@ def describe(self, ctx: Optional[InvocationContext] = None) -> WorkspaceCapabili
def create_container_workspace_runtime(
container_config: Optional[ContainerConfig] = None,
host_config: Optional[Dict] = None,
- auto_inputs: bool = False,
+ auto_inputs: bool = True,
+ provider: Optional[RunEnvProvider] = None,
+ enable_provider_env: bool = False,
) -> ContainerWorkspaceRuntime:
"""Create a new container workspace runtime.
Args:
@@ -838,4 +970,10 @@ def create_container_workspace_runtime(
container = ContainerClient(config=cfg)
else:
container = ContainerClient(config=ContainerConfig(host_config=host_config))
- return ContainerWorkspaceRuntime(container=container, host_config=host_config, auto_inputs=auto_inputs)
+ return ContainerWorkspaceRuntime(
+ container=container,
+ host_config=host_config,
+ auto_inputs=auto_inputs,
+ provider=provider,
+ enable_provider_env=enable_provider_env,
+ )
diff --git a/trpc_agent_sdk/code_executors/local/_local_program_session.py b/trpc_agent_sdk/code_executors/local/_local_program_session.py
new file mode 100644
index 0000000..fb1be8a
--- /dev/null
+++ b/trpc_agent_sdk/code_executors/local/_local_program_session.py
@@ -0,0 +1,299 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""WorkspaceInfo runtime for local code execution.
+
+This module provides the WorkspaceRuntime class which allows local code execution.
+It provides methods for staging directories and inputs into the workspace.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import os
+import sys
+import time
+import uuid
+from typing import Optional
+from typing_extensions import override
+
+from .._program_session import BaseProgramSession
+from .._program_session import PROGRAM_STATUS_EXITED
+from .._program_session import PROGRAM_STATUS_RUNNING
+from .._program_session import ProgramLog
+from .._program_session import ProgramPoll
+from .._program_session import ProgramState
+from .._types import WorkspaceRunResult
+
+_DEFAULT_INTERACTIVE_MAX_LINES = 20_000
+if sys.platform != "win32":
+ import fcntl
+
+
+def _split_lines_with_partial(text: str) -> tuple[list[str], str]:
+ normalized = text.replace("\r\n", "\n")
+ parts = normalized.split("\n")
+ if len(parts) == 1:
+ return [], parts[0]
+ return parts[:-1], parts[-1]
+
+
class LocalProgramSession(BaseProgramSession):
    """Local interactive subprocess session.

    Wraps an already-started asyncio subprocess and buffers its output as a
    bounded list of complete lines plus one trailing partial line, so callers
    can ``poll`` incrementally or ``log`` from an absolute line offset.
    Output is read either from a PTY master fd (tty mode, combined stream)
    or from the stdout/stderr pipes (separate streams).
    """

    def __init__(
        self,
        process: asyncio.subprocess.Process,
        *,
        max_lines: int = _DEFAULT_INTERACTIVE_MAX_LINES,
        master_fd: Optional[int] = None,
    ) -> None:
        """Start background reader tasks and an exit watcher.

        Args:
            process: Already-started subprocess to manage.
            max_lines: Maximum number of buffered lines; oldest lines are
                dropped once exceeded (0 or negative disables trimming).
            master_fd: PTY master fd when the process runs under a tty; all
                output is then read from the PTY and recorded as stdout.
        """
        self._id = uuid.uuid4().hex
        self._process = process
        self._max_lines = max_lines
        self._master_fd = master_fd
        self._lock = asyncio.Lock()
        self._closed = False

        self._started_at = time.time()
        self._finished_at: Optional[float] = None
        self._exit_code: Optional[int] = None
        self._timed_out = False

        # Line buffer state: _line_base is the absolute index of _lines[0]
        # (it advances as old lines are trimmed); _partial holds text after
        # the last newline; _poll_cursor is the next line poll() will emit.
        self._line_base = 0
        self._lines: list[str] = []
        self._partial = ""
        self._poll_cursor = 0

        self._stdout = ""
        self._stderr = ""
        if self._master_fd is not None:
            self._stdout_task = asyncio.create_task(self._read_pty(self._master_fd))
            # No separate stderr under a PTY; keep a no-op task so teardown
            # can unconditionally gather both reader tasks.
            self._stderr_task = asyncio.create_task(asyncio.sleep(0))
        else:
            self._stdout_task = asyncio.create_task(self._read_stream(self._process.stdout, stream_name="stdout"))
            self._stderr_task = asyncio.create_task(self._read_stream(self._process.stderr, stream_name="stderr"))
        self._wait_task = asyncio.create_task(self._watch_process_exit())

    @override
    def id(self) -> str:
        """Return the unique session identifier."""
        return self._id

    async def _read_stream(self, reader: Optional[asyncio.StreamReader], *, stream_name: str) -> None:
        """Drain *reader* into the line buffer until EOF."""
        if reader is None:
            return
        while True:
            chunk = await reader.read(4096)
            if not chunk:
                return
            await self._append_output(chunk.decode("utf-8", errors="replace"), stream=stream_name)

    async def _read_pty(self, master_fd: int) -> None:
        """Drain the PTY master fd into the line buffer until the child exits.

        The fd is switched to non-blocking mode and polled via the event
        loop's reader callback; after the child exits, any output still
        queued in the PTY is drained before returning.
        """
        loop = asyncio.get_running_loop()
        flags = fcntl.fcntl(master_fd, fcntl.F_GETFL)
        fcntl.fcntl(master_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)

        read_event = asyncio.Event()

        def _on_readable() -> None:
            read_event.set()

        loop.add_reader(master_fd, _on_readable)
        try:
            while True:
                read_event.clear()
                if self._process.returncode is not None:
                    # Child exited: drain whatever is still buffered, then stop.
                    while True:
                        try:
                            data = os.read(master_fd, 4096)
                            if not data:
                                break
                            await self._append_output(data.decode("utf-8", errors="replace"), stream="stdout")
                        except BlockingIOError:
                            break
                        except OSError:
                            break
                    return

                # Short timeout so the exit check above runs periodically
                # even when the PTY stays silent.
                try:
                    await asyncio.wait_for(read_event.wait(), timeout=0.05)
                except asyncio.TimeoutError:
                    pass
                try:
                    data = os.read(master_fd, 4096)
                    if data:
                        await self._append_output(data.decode("utf-8", errors="replace"), stream="stdout")
                except BlockingIOError:
                    pass
                except OSError:
                    # PTY closed under us (e.g. child exited); stop reading.
                    return
        finally:
            loop.remove_reader(master_fd)

    async def _append_output(self, chunk: str, *, stream: str) -> None:
        """Normalize CRLF and append *chunk* to the raw capture and line buffer."""
        normalized = chunk.replace("\r\n", "\n")
        async with self._lock:
            if stream == "stderr":
                self._stderr += normalized
            else:
                self._stdout += normalized

            merged = self._partial + normalized
            lines, self._partial = _split_lines_with_partial(merged)
            self._lines.extend(lines)
            self._trim_lines_locked()

    async def _watch_process_exit(self) -> None:
        """Record exit code and flush the partial line once the process ends."""
        code = await self._process.wait()
        # Let both readers finish first so the final output is fully buffered.
        await asyncio.gather(self._stdout_task, self._stderr_task, return_exceptions=True)
        async with self._lock:
            if self._finished_at is not None:
                return
            if self._partial:
                self._lines.append(self._partial)
                self._partial = ""
            self._trim_lines_locked()
            self._exit_code = code
            self._finished_at = time.time()

    def _trim_lines_locked(self) -> None:
        """Drop the oldest lines above the cap; caller must hold self._lock."""
        if self._max_lines <= 0:
            return
        if len(self._lines) <= self._max_lines:
            return
        drop = len(self._lines) - self._max_lines
        self._lines = self._lines[drop:]
        self._line_base += drop
        if self._poll_cursor < self._line_base:
            self._poll_cursor = self._line_base

    def _slice_output_locked(self, start: int, limit: Optional[int]) -> tuple[str, int, int]:
        """Render buffered lines from absolute index *start*; caller holds self._lock.

        Returns:
            (output, clamped_start, end) where end is one past the last
            rendered line. When the slice reaches the buffer tail, the
            trailing partial line is appended to the output.
        """
        tail = self._line_base + len(self._lines)
        if start < self._line_base:
            start = self._line_base
        if start > tail:
            start = tail
        end = tail
        if limit is not None and limit > 0:
            end = min(end, start + limit)

        out = ""
        if end > start:
            out = "\n".join(self._lines[start - self._line_base:end - self._line_base])
        if end == tail and self._partial:
            out = f"{out}\n{self._partial}" if out else self._partial
        return out, start, end

    @override
    async def poll(self, limit: Optional[int] = None) -> ProgramPoll:
        """Return output produced since the previous poll and advance the cursor."""
        async with self._lock:
            out, start, end = self._slice_output_locked(self._poll_cursor, limit)
            self._poll_cursor = end
            status = PROGRAM_STATUS_RUNNING if self._finished_at is None else PROGRAM_STATUS_EXITED
            return ProgramPoll(
                status=status,
                output=out,
                offset=start,
                next_offset=end,
                exit_code=self._exit_code,
            )

    @override
    async def log(self, offset: Optional[int] = None, limit: Optional[int] = None) -> ProgramLog:
        """Return buffered output starting at *offset* (default: oldest kept line)."""
        async with self._lock:
            start = self._line_base if offset is None else offset
            out, start, end = self._slice_output_locked(start, limit)
            return ProgramLog(output=out, offset=start, next_offset=end)

    @override
    async def write(self, data: str, newline: bool) -> None:
        """Write *data* (optionally newline-terminated) to the process stdin/PTY.

        Raises:
            ValueError: If the process already exited or stdin is unavailable.
        """
        if not data and not newline:
            return
        if self._process.returncode is not None:
            raise ValueError("session is not running")
        text = data
        if newline:
            text += "\n"
        if self._master_fd is not None:
            try:
                os.write(self._master_fd, text.encode("utf-8"))
                return
            except OSError as ex:
                raise ValueError("stdin is not available") from ex
        stdin = self._process.stdin
        if stdin is None:
            raise ValueError("stdin is not available")
        stdin.write(text.encode("utf-8"))
        await stdin.drain()

    @override
    async def kill(self, grace_seconds: float) -> None:
        """Terminate the process: SIGTERM, wait up to *grace_seconds*, then SIGKILL."""
        if self._process.returncode is not None:
            return
        self._process.terminate()
        try:
            await asyncio.wait_for(self._process.wait(), timeout=max(0.0, grace_seconds))
            return
        except asyncio.TimeoutError:
            pass
        if self._process.returncode is None:
            self._process.kill()
        await self._process.wait()

    @override
    async def close(self) -> None:
        """Kill the process if still running and release all session resources.

        Safe to call multiple times. BUGFIX: the lock guards only the
        closed-flag check — the reader tasks and the exit watcher acquire the
        same lock, so holding it across the gather below deadlocked teardown
        of any still-running session.
        """
        async with self._lock:
            if self._closed:
                return
            self._closed = True
        if self._process.returncode is None:
            await self.kill(0.5)
        if self._process.stdin is not None:
            try:
                self._process.stdin.close()
            except Exception: # pylint: disable=broad-except
                pass
        await asyncio.gather(self._stdout_task, self._stderr_task, self._wait_task, return_exceptions=True)
        if self._master_fd is not None:
            try:
                os.close(self._master_fd)
            except OSError:
                pass

    async def enforce_timeout(self, timeout_sec: float) -> None:
        """Kill the process and flag the session as timed out after *timeout_sec*."""
        if timeout_sec <= 0:
            return
        await asyncio.sleep(timeout_sec)
        if self._process.returncode is None:
            self._timed_out = True
            await self.kill(0.5)

    @override
    async def state(self) -> ProgramState:
        """Return the current run status and, once exited, the exit code."""
        if self._finished_at is None:
            return ProgramState(status=PROGRAM_STATUS_RUNNING)
        return ProgramState(status=PROGRAM_STATUS_EXITED, exit_code=self._exit_code)

    @override
    async def run_result(self) -> WorkspaceRunResult:
        """Return the aggregate run result; duration is 0.0 while still running."""
        duration = 0.0
        if self._finished_at is not None:
            duration = self._finished_at - self._started_at
        return WorkspaceRunResult(
            stdout=self._stdout,
            stderr=self._stderr,
            # None exit code (still running) is reported as 0.
            exit_code=self._exit_code or 0,
            duration=duration,
            timed_out=self._timed_out,
        )
diff --git a/trpc_agent_sdk/code_executors/local/_local_ws_runtime.py b/trpc_agent_sdk/code_executors/local/_local_ws_runtime.py
index 87151c6..1b245ee 100644
--- a/trpc_agent_sdk/code_executors/local/_local_ws_runtime.py
+++ b/trpc_agent_sdk/code_executors/local/_local_ws_runtime.py
@@ -11,11 +11,14 @@
from __future__ import annotations
+import asyncio
import os
import re
import shutil
+import sys
import tempfile
import time
+import uuid
from datetime import datetime
from pathlib import Path
from typing import List
@@ -32,6 +35,7 @@
from .._base_workspace_runtime import BaseWorkspaceFS
from .._base_workspace_runtime import BaseWorkspaceManager
from .._base_workspace_runtime import BaseWorkspaceRuntime
+from .._base_workspace_runtime import RunEnvProvider
from .._constants import DEFAULT_FILE_MODE
from .._constants import DEFAULT_MAX_FILES
from .._constants import DEFAULT_MAX_TOTAL_BYTES
@@ -46,6 +50,7 @@
from .._constants import ENV_WORK_DIR
from .._constants import MAX_READ_SIZE_BYTES
from .._constants import WORKSPACE_ENV_DIR_KEY
+from .._program_session import BaseProgramSession
from .._types import CodeFile
from .._types import ManifestFileRef
from .._types import ManifestOutput
@@ -69,6 +74,11 @@
from ..utils import path_join
from ..utils import save_metadata
+if sys.platform != "win32":
+ import pty
+
+from ._local_program_session import LocalProgramSession
+
class LocalWorkspaceManager(BaseWorkspaceManager):
"""Local workspace manager for executing commands in skill workspaces."""
@@ -380,7 +390,10 @@ async def stage_inputs(
if mode == "link":
make_symlink(ws.path, dst.as_posix(), host_path)
else:
- self._put_directory(ws, host_path, dst.parent.as_posix())
+ # Preserve caller-provided destination path exactly.
+ # For files, copy to the exact dst filename.
+ # For directories, copy tree into dst directory.
+ copy_path(host_path, path_join(ws.path, dst.as_posix()))
elif spec.src.startswith("workspace://"):
# Handle workspace inputs
rel = spec.src[len("workspace://"):]
@@ -574,6 +587,35 @@ def _read_limited_with_cap(
class LocalProgramRunner(BaseProgramRunner):
"""Local program runner for executing commands in skill workspaces."""
    def __init__(
        self,
        provider: Optional[RunEnvProvider] = None,
        enable_provider_env: bool = False,
    ):
        """Initialize the local program runner.

        Both arguments are forwarded unchanged to the BaseProgramRunner
        constructor.

        Args:
            provider: Optional run-environment provider.
            enable_provider_env: Whether provider-supplied environment is
                applied to program runs (consumed by the base class).
        """
        super().__init__(provider=provider, enable_provider_env=enable_provider_env)
+
    def _build_program_env(self, ws: WorkspaceInfo, spec: WorkspaceRunProgramSpec) -> dict[str, str]:
        """Build the process environment for a program run in *ws*.

        Starts from a copy of the current process environment, injects the
        well-known workspace variables (workspace/skills/work/output dirs
        plus a fresh per-run directory) unless the caller set them in
        ``spec.env``, then applies ``spec.env`` on top so user values win.

        Side effect: creates the timestamped run directory for this run.

        Returns:
            The complete environment mapping for the subprocess.
        """
        env = os.environ.copy()
        user_env = dict(spec.env or {})
        wr_path = Path(ws.path)
        ensure_layout(wr_path)
        # One unique run directory per invocation, named by timestamp.
        run_dir = wr_path / DIR_RUNS / f"run_{datetime.now().strftime('%Y%m%dT%H%M%S.%f')}"
        run_dir.mkdir(parents=True, exist_ok=True)

        base_env = {
            WORKSPACE_ENV_DIR_KEY: ws.path,
            ENV_SKILLS_DIR: str(Path(ws.path) / DIR_SKILLS),
            ENV_WORK_DIR: str(Path(ws.path) / DIR_WORK),
            ENV_OUTPUT_DIR: str(Path(ws.path) / DIR_OUT),
            ENV_RUN_DIR: str(run_dir),
        }
        # Inject well-known variables only when the user did not set them...
        for key, value in base_env.items():
            if key not in user_env:
                env[key] = value
        # ...then overlay all user-provided variables.
        if user_env:
            env.update(user_env)
        return env
+
@override
async def run_program(self,
ws: WorkspaceInfo,
@@ -591,37 +633,13 @@ async def run_program(self,
Returns:
Execution result
"""
+ spec = self._apply_provider_env(spec, ctx)
# Resolve cwd under workspace
cwd = Path(path_join(ws.path, spec.cwd))
cwd.mkdir(parents=True, exist_ok=True)
timeout = spec.timeout or float(DEFAULT_TIMEOUT_SEC)
-
- # Build environment
- env = os.environ.copy()
-
- wr_path = Path(ws.path)
- # Ensure layout exists and compute run dir
- ensure_layout(wr_path)
- run_dir = wr_path / DIR_RUNS / f"run_{datetime.now().strftime('%Y%m%dT%H%M%S.%f')}"
- run_dir.mkdir(parents=True, exist_ok=True)
-
- # Inject well-known variables if not set
- base_env = {
- WORKSPACE_ENV_DIR_KEY: ws.path,
- ENV_SKILLS_DIR: str(Path(ws.path) / DIR_SKILLS),
- ENV_WORK_DIR: str(Path(ws.path) / DIR_WORK),
- ENV_OUTPUT_DIR: str(Path(ws.path) / DIR_OUT),
- ENV_RUN_DIR: str(run_dir),
- }
-
- for key, value in base_env.items():
- if key not in spec.env:
- env[key] = value
-
- # Add user-provided environment variables
- if spec.env:
- env.update(spec.env)
+ env = self._build_program_env(ws, spec)
# Prepare command
cmd_args = [spec.cmd] + (spec.args or [])
@@ -639,6 +657,62 @@ async def run_program(self,
duration=time.time() - start_time,
timed_out=result.is_timeout)
+ async def start_program(
+ self,
+ ctx: Optional[InvocationContext],
+ ws: WorkspaceInfo,
+ spec: WorkspaceRunProgramSpec,
+ ) -> BaseProgramSession:
+ """Start an interactive program session in workspace."""
+ if spec.tty and sys.platform == "win32":
+ raise ValueError("interactive tty is not supported on windows")
+
+ spec = self._apply_provider_env(spec, ctx)
+ cwd = Path(path_join(ws.path, spec.cwd))
+ cwd.mkdir(parents=True, exist_ok=True)
+ env = self._build_program_env(ws, spec)
+ timeout = spec.timeout or float(DEFAULT_TIMEOUT_SEC)
+
+ cmd_args = [spec.cmd] + (spec.args or [])
+ if spec.tty:
+ master_fd, slave_fd = pty.openpty()
+ try:
+ process = await asyncio.create_subprocess_exec(
+ *cmd_args,
+ cwd=str(cwd),
+ env=env,
+ stdin=slave_fd,
+ stdout=slave_fd,
+ stderr=slave_fd,
+ close_fds=True,
+ preexec_fn=os.setsid,
+ )
+ except Exception:
+ os.close(master_fd)
+ os.close(slave_fd)
+ raise
+ finally:
+ try:
+ os.close(slave_fd)
+ except OSError:
+ pass
+ session = LocalProgramSession(process, master_fd=master_fd)
+ else:
+ process = await asyncio.create_subprocess_exec(
+ *cmd_args,
+ cwd=str(cwd),
+ env=env,
+ stdin=asyncio.subprocess.PIPE,
+ stdout=asyncio.subprocess.PIPE,
+ stderr=asyncio.subprocess.PIPE,
+ )
+ session = LocalProgramSession(process)
+ if timeout > 0:
+ asyncio.create_task(session.enforce_timeout(float(timeout)))
+ if spec.stdin:
+ await session.write(spec.stdin, newline=False)
+ return session
+
class LocalWorkspaceRuntime(BaseWorkspaceRuntime):
"""Local workspace for executing commands in skill workspaces."""
@@ -647,9 +721,11 @@ def __init__(self,
work_root: str = '',
read_only_staged_skill: bool = False,
auto_inputs: bool = True,
- inputs_host_base: str = ""):
+ inputs_host_base: str = "",
+ provider: Optional[RunEnvProvider] = None,
+ enable_provider_env: bool = False):
self._fs = LocalWorkspaceFS(read_only_staged_skill)
- self._runner = LocalProgramRunner()
+ self._runner = LocalProgramRunner(provider=provider, enable_provider_env=enable_provider_env)
self._manager = LocalWorkspaceManager(work_root, auto_inputs, inputs_host_base, self._fs)
@override
@@ -681,6 +757,9 @@ def describe(self, ctx: Optional[InvocationContext] = None) -> WorkspaceCapabili
def create_local_workspace_runtime(work_root: str = '',
read_only_staged_skill: bool = False,
auto_inputs: bool = True,
- inputs_host_base: str = "") -> LocalWorkspaceRuntime:
+ inputs_host_base: str = "",
+ provider: Optional[RunEnvProvider] = None,
+ enable_provider_env: bool = False) -> LocalWorkspaceRuntime:
"""Create a new local workspace runtime."""
- return LocalWorkspaceRuntime(work_root, read_only_staged_skill, auto_inputs, inputs_host_base)
+ return LocalWorkspaceRuntime(work_root, read_only_staged_skill, auto_inputs, inputs_host_base, provider,
+ enable_provider_env)
diff --git a/trpc_agent_sdk/code_executors/local/_unsafe_local_code_executor.py b/trpc_agent_sdk/code_executors/local/_unsafe_local_code_executor.py
index 9593ca6..3b3aa19 100644
--- a/trpc_agent_sdk/code_executors/local/_unsafe_local_code_executor.py
+++ b/trpc_agent_sdk/code_executors/local/_unsafe_local_code_executor.py
@@ -17,7 +17,6 @@
from typing_extensions import override
from pydantic import Field
-
from trpc_agent_sdk.context import InvocationContext
from trpc_agent_sdk.utils import async_execute_command
@@ -44,7 +43,7 @@ class UnsafeLocalCodeExecutor(BaseCodeExecutor):
work_dir: str = Field(default="", description="The working directory for the code execution.")
- timeout: float = Field(default=0, description="The timeout for the code execution.")
+ timeout: float = Field(default=0, description="The timeout seconds for the code execution.")
clean_temp_files: bool = Field(default=True,
description="Whether to clean temporary files after the code execution.")
diff --git a/trpc_agent_sdk/code_executors/utils/_code_execution.py b/trpc_agent_sdk/code_executors/utils/_code_execution.py
index bae808a..8635015 100644
--- a/trpc_agent_sdk/code_executors/utils/_code_execution.py
+++ b/trpc_agent_sdk/code_executors/utils/_code_execution.py
@@ -24,6 +24,20 @@
class CodeExecutionUtils:
"""Utility functions for code execution."""
+ @classmethod
+ def _is_ignored_code_block(cls, code: str, ignore_codes: list[str]) -> bool:
+ """Return True when code block first line is in ignore list."""
+ if not code:
+ return False
+ lines = code.splitlines()
+ if not lines:
+ return False
+ first_line = lines[0].strip()
+ if not first_line:
+ return False
+ ignore_set = {item.strip() for item in ignore_codes if item and item.strip()}
+ return first_line in ignore_set
+
@classmethod
def prepare_globals(cls, code: str, globals_: dict[str, Any]) -> None:
"""Prepare globals for code execution, injecting __name__ if needed."""
@@ -61,6 +75,7 @@ def extract_code_and_truncate_content(
cls,
content: Content,
code_block_delimiters: list[CodeBlockDelimiter],
+ ignore_codes: list[str] = None,
) -> list[CodeBlock]:
"""Extracts all code blocks from the content and reconstructs content.parts.
@@ -71,10 +86,12 @@ def extract_code_and_truncate_content(
content: The mutable content to extract the code from.
code_block_delimiters: The list of the enclosing delimiters to identify
the code blocks.
+ ignore_codes: The list of codes to ignore.
Returns:
The first code block if found; otherwise, None.
"""
+ ignore_codes = ignore_codes or []
code_blocks = []
if not content or not content.parts:
return code_blocks
@@ -84,12 +101,13 @@ def extract_code_and_truncate_content(
total_len = len(content.parts)
for idx, part in enumerate(content.parts):
if part.executable_code:
+ code_str = part.executable_code.code or ""
+ if cls._is_ignored_code_block(code_str, ignore_codes):
+ continue
if idx < total_len - 1 and not content.parts[idx + 1].code_execution_result:
- code_blocks.append(CodeBlock(code=part.executable_code.code,
- language=part.executable_code.language))
+ code_blocks.append(CodeBlock(code=code_str, language=part.executable_code.language))
if idx == total_len - 1:
- code_blocks.append(CodeBlock(code=part.executable_code.code,
- language=part.executable_code.language))
+ code_blocks.append(CodeBlock(code=code_str, language=part.executable_code.language))
# If there are code blocks, return them.
if code_blocks:
return code_blocks
@@ -140,6 +158,9 @@ def extract_code_and_truncate_content(
# Extract code content, removing leading/trailing whitespace and newlines
code_str = code_content.strip()
+ if cls._is_ignored_code_block(code_str, ignore_codes):
+ last_end = match.end()
+ continue
# Store first code block for return value
if first_code is None:
diff --git a/trpc_agent_sdk/server/langfuse/tracing/opentelemetry.py b/trpc_agent_sdk/server/langfuse/tracing/opentelemetry.py
index 166f20a..5dab42d 100644
--- a/trpc_agent_sdk/server/langfuse/tracing/opentelemetry.py
+++ b/trpc_agent_sdk/server/langfuse/tracing/opentelemetry.py
@@ -58,7 +58,6 @@ def _should_skip_span(self, span: ReadableSpan) -> bool:
Returns:
True if the span should be skipped, False otherwise.
"""
- global _langfuse_config # pylint: disable=invalid-name
# If enable_a2a_trace is True, don't filter out any spans
if _langfuse_config.enable_a2a_trace:
return False
@@ -99,7 +98,6 @@ def _should_skip_span(self, span: ReadableSpan) -> bool:
def _transform_span_for_langfuse(self, span: ReadableSpan) -> ReadableSpan:
"""Transform TRPC agent span attributes to Langfuse format."""
- global _langfuse_config # pylint: disable=invalid-name
trpc_span_name = get_trpc_agent_span_name()
if span.name == "invocation":
span_name = span.attributes.get(f"{trpc_span_name}.runner.name", "unknown")
@@ -262,6 +260,8 @@ def _map_generation_attributes(self, attributes: Dict[str, Any]) -> Dict[str, An
# Map usage details
gen_attrs["gen_ai.usage.input_tokens"] = attributes.get("gen_ai.usage.input_tokens", "0")
gen_attrs["gen_ai.usage.output_tokens"] = attributes.get("gen_ai.usage.output_tokens", "0")
+ if "gen_ai.request.model" in attributes:
+ gen_attrs["gen_ai.request.model"] = attributes["gen_ai.request.model"]
# Map generation metadata
gen_metadata = {}
@@ -603,7 +603,6 @@ def setup(config: Optional[LangfuseConfig] = None) -> TracerProvider:
Raises:
ValueError: If required configuration is missing.
"""
- global _langfuse_config # pylint: disable=invalid-name
if config is None:
config = LangfuseConfig()
@@ -619,6 +618,7 @@ def setup(config: Optional[LangfuseConfig] = None) -> TracerProvider:
if config.host is None:
config.host = os.getenv("LANGFUSE_HOST")
+ global _langfuse_config # pylint: disable=invalid-name
# Set the global config and use it in span processor(Only be setted once)
_langfuse_config = config
diff --git a/trpc_agent_sdk/sessions/_base_session_service.py b/trpc_agent_sdk/sessions/_base_session_service.py
index 0d984c1..dd41fb7 100644
--- a/trpc_agent_sdk/sessions/_base_session_service.py
+++ b/trpc_agent_sdk/sessions/_base_session_service.py
@@ -84,6 +84,9 @@ async def append_event(self, session: Session, event: Event) -> Event:
"""Appends an event to a session object."""
if event.partial:
return event
+ # Apply temp-scoped state to in-memory session before trimming event delta,
+ # so same-invocation consumers can still read temp values.
+ self._apply_temp_state(session, event)
event = self._trim_temp_delta_state(event)
self.__update_session_state(session, event)
session.add_event(event,
@@ -91,6 +94,18 @@ async def append_event(self, session: Session, event: Event) -> Event:
max_events=self._session_config.max_events)
return event
+ def _apply_temp_state(self, session: Session, event: Event) -> None:
+ """Apply temp-scoped state delta to in-memory session state only.
+
+ Temp state is intentionally ephemeral: it should be visible within
+ current invocation memory but not persisted into stored event deltas.
+ """
+ if not event.actions or not event.actions.state_delta:
+ return
+ for key, value in event.actions.state_delta.items():
+ if key.startswith(State.TEMP_PREFIX):
+ session.state[key] = value
+
def _trim_temp_delta_state(self, event: Event) -> Event:
"""Removes temporary state delta keys from the event."""
if not event.actions or not event.actions.state_delta:
diff --git a/trpc_agent_sdk/sessions/_history_record.py b/trpc_agent_sdk/sessions/_history_record.py
index 62086a6..1950e09 100644
--- a/trpc_agent_sdk/sessions/_history_record.py
+++ b/trpc_agent_sdk/sessions/_history_record.py
@@ -23,6 +23,7 @@ class HistoryRecord(BaseModel):
user_texts: list[str] = Field(default_factory=list, description="List of user text")
# The text of the user
assistant_texts: list[str] = Field(default_factory=list, description="List of assistant text")
+
# The text of the assistant
def add_record(self, user_text: str, assistant_text: str | None = ""):
diff --git a/trpc_agent_sdk/sessions/_session.py b/trpc_agent_sdk/sessions/_session.py
index 7598a47..28c0b91 100644
--- a/trpc_agent_sdk/sessions/_session.py
+++ b/trpc_agent_sdk/sessions/_session.py
@@ -11,7 +11,6 @@
from typing import List
from pydantic import Field
-
from trpc_agent_sdk.abc import SessionABC
from trpc_agent_sdk.events import Event
@@ -54,10 +53,7 @@ def apply_event_filtering(self, event_ttl_seconds: float = 0.0, max_events: int
2. Count filtering: Keep only the most recent max_events events
If both filters result in removing all events, the method attempts to
- preserve the first user message and all events after it from the original events.
-
- The filtering logic is inspired by the ApplyEventFiltering function
- from trpc-agent-go: https://github.com/trpc-group/trpc-agent-go
+ preserve the first user message and all events after it from the original events.
Args:
event_ttl_seconds: Time-to-live in seconds for events. If 0, no TTL filtering is applied.
diff --git a/trpc_agent_sdk/skills/__init__.py b/trpc_agent_sdk/skills/__init__.py
index b8afbb3..d457bd6 100644
--- a/trpc_agent_sdk/skills/__init__.py
+++ b/trpc_agent_sdk/skills/__init__.py
@@ -1,3 +1,7 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright @ 2026 Tencent.com
+
# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
#
# Copyright (C) 2026 Tencent. All rights reserved.
@@ -21,120 +25,156 @@
>>> tools = await toolset.get_tools()
"""
-from ._common import BaseSelectionResult
from ._common import SelectionMode
-from ._common import add_selection
-from ._common import clear_selection
+from ._common import docs_scan_prefix
+from ._common import docs_state_key
from ._common import generic_get_selection
-from ._common import generic_select_items
-from ._common import get_previous_selection
-from ._common import get_state_delta_value
-from ._common import replace_selection
-from ._common import set_state_delta_for_selection
+from ._common import loaded_order_state_key
+from ._common import loaded_scan_prefix
+from ._common import loaded_state_key
+from ._common import tool_scan_prefix
+from ._common import tool_state_key
+from ._common import use_session_skill_state
from ._constants import ENV_SKILLS_ROOT
+from ._constants import SKILL_ARTIFACTS_STATE_KEY
+from ._constants import SKILL_DOCS_BY_AGENT_STATE_KEY_PREFIX
from ._constants import SKILL_DOCS_STATE_KEY_PREFIX
from ._constants import SKILL_FILE
+from ._constants import SKILL_LOADED_BY_AGENT_STATE_KEY_PREFIX
+from ._constants import SKILL_LOADED_ORDER_BY_AGENT_STATE_KEY_PREFIX
+from ._constants import SKILL_LOADED_ORDER_STATE_KEY_PREFIX
from ._constants import SKILL_LOADED_STATE_KEY_PREFIX
+from ._constants import SKILL_LOAD_MODE_VALUES
from ._constants import SKILL_REGISTRY_KEY
from ._constants import SKILL_REPOSITORY_KEY
+from ._constants import SKILL_TOOLS_BY_AGENT_STATE_KEY_PREFIX
+from ._constants import SKILL_TOOLS_NAMES
from ._constants import SKILL_TOOLS_STATE_KEY_PREFIX
+from ._constants import SkillLoadModeNames
+from ._constants import SkillProfileNames
+from ._constants import SkillToolsNames
from ._dynamic_toolset import DynamicSkillToolSet
from ._registry import SkillRegistry
from ._repository import BaseSkillRepository
from ._repository import FsSkillRepository
+from ._repository import VisibilityFilter
from ._repository import create_default_skill_repository
+from ._skill_config import get_skill_config
+from ._skill_config import get_skill_load_mode
+from ._skill_config import set_skill_config
+from ._skill_profile import SkillProfileFlags
+from ._state_keys import docs_key
+from ._state_keys import docs_prefix
+from ._state_keys import docs_session_key
+from ._state_keys import docs_session_prefix
+from ._state_keys import loaded_key
+from ._state_keys import loaded_order_key
+from ._state_keys import loaded_prefix
+from ._state_keys import loaded_session_key
+from ._state_keys import loaded_session_order_key
+from ._state_keys import loaded_session_prefix
+from ._state_keys import tool_key
+from ._state_keys import tool_prefix
+from ._state_keys import tool_session_key
+from ._state_keys import tool_session_prefix
+from ._state_migration import SKILLS_LEGACY_MIGRATION_STATE_KEY
+from ._state_migration import maybe_migrate_legacy_skill_state
+from ._state_order import marshal_loaded_order
+from ._state_order import parse_loaded_order
+from ._state_order import touch_loaded_order
from ._toolset import SkillToolSet
from ._types import Skill
from ._types import SkillConfig
from ._types import SkillFrontMatter
-from ._types import SkillMetadata
from ._types import SkillRequires
from ._types import SkillResource
from ._types import SkillSummary
-from ._types import SkillWorkspaceInputRecord
-from ._types import SkillWorkspaceMetadata
-from ._types import SkillWorkspaceOutputRecord
-from ._types import format_datetime
-from ._types import parse_datetime
-from ._url_root import ArchiveExt
from ._url_root import ArchiveExtractor
-from ._url_root import ArchiveKind
-from ._url_root import CacheConfig
-from ._url_root import FilePerm
-from ._url_root import SizeLimit
from ._url_root import SkillRootResolver
-from ._url_root import TarPerm
-from ._utils import compute_dir_digest
-from ._utils import ensure_layout
from ._utils import get_state_delta
-from ._utils import load_metadata
-from ._utils import save_metadata
from ._utils import set_state_delta
-from ._utils import shell_quote
+from .tools import SkillLoadTool
from .tools import SkillRunTool
from .tools import skill_list
from .tools import skill_list_docs
from .tools import skill_list_tools
-from .tools import skill_load
from .tools import skill_select_docs
from .tools import skill_select_tools
__all__ = [
- "BaseSelectionResult",
"SelectionMode",
- "add_selection",
- "clear_selection",
+ "docs_scan_prefix",
+ "docs_state_key",
"generic_get_selection",
- "generic_select_items",
- "get_previous_selection",
- "get_state_delta_value",
- "replace_selection",
- "set_state_delta_for_selection",
+ "loaded_order_state_key",
+ "loaded_scan_prefix",
+ "loaded_state_key",
+ "tool_scan_prefix",
+ "tool_state_key",
+ "use_session_skill_state",
"ENV_SKILLS_ROOT",
+ "SKILL_ARTIFACTS_STATE_KEY",
+ "SKILL_DOCS_BY_AGENT_STATE_KEY_PREFIX",
"SKILL_DOCS_STATE_KEY_PREFIX",
"SKILL_FILE",
+ "SKILL_LOADED_BY_AGENT_STATE_KEY_PREFIX",
+ "SKILL_LOADED_ORDER_BY_AGENT_STATE_KEY_PREFIX",
+ "SKILL_LOADED_ORDER_STATE_KEY_PREFIX",
"SKILL_LOADED_STATE_KEY_PREFIX",
+ "SKILL_LOAD_MODE_VALUES",
"SKILL_REGISTRY_KEY",
"SKILL_REPOSITORY_KEY",
+ "SKILL_TOOLS_BY_AGENT_STATE_KEY_PREFIX",
+ "SKILL_TOOLS_NAMES",
"SKILL_TOOLS_STATE_KEY_PREFIX",
+ "SkillLoadModeNames",
+ "SkillProfileNames",
+ "SkillToolsNames",
"DynamicSkillToolSet",
"SkillRegistry",
"BaseSkillRepository",
"FsSkillRepository",
+ "VisibilityFilter",
"create_default_skill_repository",
+ "get_skill_config",
+ "get_skill_load_mode",
+ "set_skill_config",
+ "SkillProfileFlags",
+ "docs_key",
+ "docs_prefix",
+ "docs_session_key",
+ "docs_session_prefix",
+ "loaded_key",
+ "loaded_order_key",
+ "loaded_prefix",
+ "loaded_session_key",
+ "loaded_session_order_key",
+ "loaded_session_prefix",
+ "tool_key",
+ "tool_prefix",
+ "tool_session_key",
+ "tool_session_prefix",
+ "SKILLS_LEGACY_MIGRATION_STATE_KEY",
+ "maybe_migrate_legacy_skill_state",
+ "marshal_loaded_order",
+ "parse_loaded_order",
+ "touch_loaded_order",
"SkillToolSet",
"Skill",
"SkillConfig",
"SkillFrontMatter",
- "SkillMetadata",
"SkillRequires",
"SkillResource",
"SkillSummary",
- "SkillWorkspaceInputRecord",
- "SkillWorkspaceMetadata",
- "SkillWorkspaceOutputRecord",
- "format_datetime",
- "parse_datetime",
- "ArchiveExt",
"ArchiveExtractor",
- "ArchiveKind",
- "CacheConfig",
- "FilePerm",
- "SizeLimit",
"SkillRootResolver",
- "TarPerm",
- "compute_dir_digest",
- "ensure_layout",
"get_state_delta",
- "load_metadata",
- "save_metadata",
"set_state_delta",
- "shell_quote",
+ "SkillLoadTool",
"SkillRunTool",
"skill_list",
"skill_list_docs",
"skill_list_tools",
- "skill_load",
"skill_select_docs",
"skill_select_tools",
]
diff --git a/trpc_agent_sdk/skills/_common.py b/trpc_agent_sdk/skills/_common.py
index a50f5a7..0eefc2c 100644
--- a/trpc_agent_sdk/skills/_common.py
+++ b/trpc_agent_sdk/skills/_common.py
@@ -24,6 +24,26 @@
from trpc_agent_sdk.context import InvocationContext
from trpc_agent_sdk.log import logger
+from ._constants import SkillLoadModeNames
+from ._skill_config import get_skill_load_mode
+from ._state_keys import docs_key
+from ._state_keys import docs_prefix
+from ._state_keys import docs_session_key
+from ._state_keys import docs_session_prefix
+from ._state_keys import loaded_key
+from ._state_keys import loaded_order_key
+from ._state_keys import loaded_prefix
+from ._state_keys import loaded_session_key
+from ._state_keys import loaded_session_order_key
+from ._state_keys import loaded_session_prefix
+from ._state_keys import tool_key
+from ._state_keys import tool_prefix
+from ._state_keys import tool_session_key
+from ._state_keys import tool_session_prefix
+from ._state_order import marshal_loaded_order
+from ._state_order import parse_loaded_order
+from ._state_order import touch_loaded_order
+
# Generic type for selection results
T = TypeVar('T', bound=BaseModel) # pylint: disable=invalid-name
@@ -59,8 +79,45 @@ class BaseSelectionResult(BaseModel):
mode: str = Field(default="", description="The mode used for selecting")
-def _set_state_delta(invocation_context: InvocationContext, key: str, value: Any) -> None:
- """Set the state delta for a key.
+def get_agent_name(invocation_context: InvocationContext) -> str:
+ """Get normalized agent name from invocation context."""
+ name = getattr(invocation_context, "agent_name", "")
+ return name.strip() if isinstance(name, str) else ""
+
+
+def use_session_skill_state(invocation_context: InvocationContext) -> bool:
+ """Whether skill-related state should use persistent (non-temp) keys."""
+ return get_skill_load_mode(invocation_context) == SkillLoadModeNames.SESSION
+
+
+def normalize_selection_mode(mode: str) -> str:
+ """Normalize selection mode; defaults to replace."""
+ m = (mode or "").strip().lower()
+ if m in ("add", "replace", "clear"):
+ return m
+ return "replace"
+
+
+def get_previous_selection_by_key(invocation_context: InvocationContext, key: str) -> tuple[list[str], bool]:
+ """Get previous selection by exact state key.
+
+ Returns:
+ (selected_items, had_all)
+ """
+ value = get_state_delta_value(invocation_context, key)
+ if not value:
+ return [], False
+ if value == "*":
+ return [], True
+ try:
+ parsed = json.loads(value)
+ except (json.JSONDecodeError, TypeError):
+ return [], False
+ return parsed if isinstance(parsed, list) else [], False
+
+
+def set_state_delta(invocation_context: InvocationContext, key: str, value: Any) -> None:
+    """Set the state delta for a key in the invocation context.
Args:
invocation_context: Invocation context
@@ -70,6 +127,35 @@ def _set_state_delta(invocation_context: InvocationContext, key: str, value: Any
invocation_context.actions.state_delta[key] = value
+def set_selection_state_delta_by_key(
+ invocation_context: InvocationContext,
+ key: str,
+ selected_items: list[str],
+ include_all: bool,
+) -> None:
+ """Set selection state delta by exact state key."""
+ if include_all:
+ set_state_delta(invocation_context, key, "*")
+ else:
+ set_state_delta(invocation_context, key, json.dumps(selected_items or []))
+
+
+def append_loaded_order_state_delta(
+ invocation_context: InvocationContext,
+ agent_name: str,
+ skill_name: str,
+) -> None:
+ """Append loaded-order state delta for skill touch order."""
+ key = loaded_order_key(agent_name)
+ if use_session_skill_state(invocation_context):
+ key = loaded_session_order_key(key)
+ current = parse_loaded_order(get_state_delta_value(invocation_context, key))
+ next_order = touch_loaded_order(current, skill_name)
+ encoded = marshal_loaded_order(next_order)
+ if encoded:
+ set_state_delta(invocation_context, key, encoded)
+
+
def get_previous_selection(invocation_context: InvocationContext, state_key_prefix: str,
skill_name: str) -> Optional[list[str]]:
"""Get the previous selection for a skill from session state.
@@ -191,12 +277,12 @@ def set_state_delta_for_selection(invocation_context: InvocationContext, state_k
include_all = getattr(result, 'include_all', False)
if include_all:
- _set_state_delta(invocation_context, key, '*')
+ set_state_delta(invocation_context, key, '*')
return
selected_items = getattr(result, 'selected_items', [])
selected_json = json.dumps(selected_items)
- _set_state_delta(invocation_context, key, selected_json)
+ set_state_delta(invocation_context, key, selected_json)
def generic_select_items(tool_context: InvocationContext, skill_name: str, items: Optional[list[str]],
@@ -354,3 +440,52 @@ def get_all_tools(skill_name):
except json.JSONDecodeError:
logger.warning("Failed to parse selection for skill '%s' with key '%s': %s", skill_name, key, v_str)
return []
+
+
+def loaded_state_key(ctx: InvocationContext, skill_name: str) -> str:
+ """Return the loaded-state key for a skill."""
+ agent_name = ctx.agent_name.strip()
+ load_keys = loaded_key(agent_name, skill_name)
+ return loaded_session_key(load_keys) if use_session_skill_state(ctx) else load_keys
+
+
+def docs_state_key(ctx: InvocationContext, skill_name: str) -> str:
+ """Return the docs-state key for a skill."""
+ agent_name = ctx.agent_name.strip()
+ key = docs_key(agent_name, skill_name)
+ return docs_session_key(key) if use_session_skill_state(ctx) else key
+
+
+def tool_state_key(ctx: InvocationContext, skill_name: str) -> str:
+ """Return the tools-state key for a skill."""
+ agent_name = ctx.agent_name.strip()
+ key = tool_key(agent_name, skill_name)
+ return tool_session_key(key) if use_session_skill_state(ctx) else key
+
+
+def loaded_scan_prefix(ctx: InvocationContext) -> str:
+ """Return the loaded-state scan prefix for an agent."""
+ agent_name = ctx.agent_name.strip()
+ key = loaded_prefix(agent_name)
+ return loaded_session_prefix(key) if use_session_skill_state(ctx) else key
+
+
+def docs_scan_prefix(ctx: InvocationContext) -> str:
+ """Return the docs-state scan prefix for an agent."""
+ agent_name = ctx.agent_name.strip()
+ key = docs_prefix(agent_name)
+ return docs_session_prefix(key) if use_session_skill_state(ctx) else key
+
+
+def tool_scan_prefix(ctx: InvocationContext) -> str:
+ """Return the tools-state scan prefix for an agent."""
+ agent_name = ctx.agent_name.strip()
+ key = tool_prefix(agent_name)
+ return tool_session_prefix(key) if use_session_skill_state(ctx) else key
+
+
+def loaded_order_state_key(ctx: InvocationContext) -> str:
+ """Return the loaded-order state key for an agent."""
+ agent_name = ctx.agent_name.strip()
+ key = loaded_order_key(agent_name)
+ return loaded_session_order_key(key) if use_session_skill_state(ctx) else key
diff --git a/trpc_agent_sdk/skills/_constants.py b/trpc_agent_sdk/skills/_constants.py
index 292fba8..9a36626 100644
--- a/trpc_agent_sdk/skills/_constants.py
+++ b/trpc_agent_sdk/skills/_constants.py
@@ -8,6 +8,8 @@
This module defines constants for the skills system.
"""
+from enum import Enum
+
SKILL_FILE = "SKILL.md"
# Environment variable name for skills root directory
@@ -24,15 +26,80 @@
SKILL_LOADED_STATE_KEY_PREFIX = "temp:skill:loaded:"
"""State key for loaded skills."""
+# State key for loaded skills scoped by agent
+SKILL_LOADED_BY_AGENT_STATE_KEY_PREFIX = "temp:skill:loaded_by_agent:"
+"""State key prefix for loaded skills scoped by agent."""
+
# State key for docs of skills
SKILL_DOCS_STATE_KEY_PREFIX = "temp:skill:docs:"
"""State key for docs of skills."""
+# State key for docs of skills scoped by agent
+SKILL_DOCS_BY_AGENT_STATE_KEY_PREFIX = "temp:skill:docs_by_agent:"
+"""State key prefix for docs scoped by agent."""
+
+# State key for loaded skill order
+SKILL_LOADED_ORDER_STATE_KEY_PREFIX = "temp:skill:loaded_order:"
+"""State key prefix for loaded skill touch order."""
+
+# State key for loaded skill order scoped by agent
+SKILL_LOADED_ORDER_BY_AGENT_STATE_KEY_PREFIX = "temp:skill:loaded_order_by_agent:"
+"""State key prefix for loaded skill touch order scoped by agent."""
+
# State key for tools of skills
SKILL_TOOLS_STATE_KEY_PREFIX = "temp:skill:tools:"
"""State key for tools of skills."""
+# State key for tools of skills scoped by agent
+SKILL_TOOLS_BY_AGENT_STATE_KEY_PREFIX = "temp:skill:tools_by_agent:"
+"""State key prefix for tools scoped by agent."""
+
+# State key for per-tool-call artifact refs (replay support)
+SKILL_ARTIFACTS_STATE_KEY = "temp:skill:artifacts"
+"""State key for skill tool-call artifact references."""
+
# EnvSkillsCacheDir overrides where URL-based skills roots are cached.
# When empty, the user cache directory is used.
ENV_SKILLS_CACHE_DIR = "SKILLS_CACHE_DIR"
"""Environment variable name for skills cache directory."""
+
+
+class SkillProfileNames(str, Enum):
+ FULL = "full"
+ KNOWLEDGE_ONLY = "knowledge_only"
+
+ def __str__(self) -> str:
+ return self.value
+
+
+class SkillToolsNames(str, Enum):
+ LOAD = "skill_load"
+ SELECT_DOCS = "skill_select_docs"
+ LIST_DOCS = "skill_list_docs"
+ RUN = "skill_run"
+ SELECT_TOOLS = "skill_select_tools"
+ LIST_SKILLS = "skill_list_skills"
+ EXEC = "skill_exec"
+ WRITE_STDIN = "skill_write_stdin"
+ POLL_SESSION = "skill_poll_session"
+ KILL_SESSION = "skill_kill_session"
+
+ def __str__(self) -> str:
+ return self.value
+
+
+SKILL_TOOLS_NAMES = [tool.value for tool in SkillToolsNames.__members__.values()]
+
+
+class SkillLoadModeNames(str, Enum):
+ ONCE = "once"
+ TURN = "turn"
+ SESSION = "session"
+
+ def __str__(self) -> str:
+ return self.value
+
+
+SKILL_LOAD_MODE_VALUES = [mode.value for mode in SkillLoadModeNames.__members__.values()]
+
+SKILL_CONFIG_KEY = "__trpc_agent_skills_config"
diff --git a/trpc_agent_sdk/skills/_dynamic_toolset.py b/trpc_agent_sdk/skills/_dynamic_toolset.py
index ba648cf..7d11ca5 100644
--- a/trpc_agent_sdk/skills/_dynamic_toolset.py
+++ b/trpc_agent_sdk/skills/_dynamic_toolset.py
@@ -25,8 +25,9 @@
from trpc_agent_sdk.tools import get_tool
from trpc_agent_sdk.tools import get_tool_set
-from ._constants import SKILL_LOADED_STATE_KEY_PREFIX
-from ._constants import SKILL_TOOLS_STATE_KEY_PREFIX
+from ._common import loaded_scan_prefix
+from ._common import tool_scan_prefix
+from ._common import tool_state_key
from ._repository import BaseSkillRepository
from ._utils import get_state_delta
@@ -130,6 +131,7 @@ def __init__(self,
self._find_tool_by_name(tool)
else:
self._find_tool_by_type(tool)
+
logger.info("DynamicSkillToolSet initialized: %s tools, %s toolsets, only_active_skills=%s",
len(self._available_tools), len(self._available_toolsets), only_active_skills)
@@ -266,9 +268,8 @@ async def get_tools(self, ctx: InvocationContext) -> List[BaseTool]:
tool = await self._resolve_tool(tool_name, ctx)
if tool is None:
logger.warning(
- "Tool '%s' required by skill '%s' could not be resolved. Checked: available_tools (%s), "
- "available_toolsets (%s), global registry", tool_name, skill_name, len(self._available_tools),
- len(self._available_toolsets))
+ "Tool '%s' required by skill '%s' could not be resolved. Checked: available_tools (%s), available_toolsets (%s), global registry",
+ tool_name, skill_name, len(self._available_tools), len(self._available_toolsets))
continue
selected_tools.append(tool)
@@ -293,13 +294,14 @@ def _get_loaded_skills_from_state(self, ctx: InvocationContext) -> List[str]:
List of loaded skill names
"""
loaded_skills: List[str] = []
- # Combine session state and current state delta
state = dict(ctx.session_state.copy())
state.update(ctx.actions.state_delta)
-
+ prefix = loaded_scan_prefix(ctx)
for key, value in state.items():
- if key.startswith(SKILL_LOADED_STATE_KEY_PREFIX) and value:
- skill_name = key[len(SKILL_LOADED_STATE_KEY_PREFIX):]
+ if not value or not key.startswith(prefix):
+ continue
+ skill_name = key[len(prefix):].strip()
+ if skill_name:
loaded_skills.append(skill_name)
return loaded_skills
@@ -321,18 +323,20 @@ def _get_active_skills_from_delta(self, ctx: InvocationContext) -> List[str]:
"""
active_skills: set[str] = set()
- # Check state_delta for skill-related changes
+ loaded_state_prefix = loaded_scan_prefix(ctx)
+ tools_state_prefix = tool_scan_prefix(ctx)
+
for key, value in ctx.actions.state_delta.items():
- if key.startswith(SKILL_LOADED_STATE_KEY_PREFIX) and value:
- # Skill was just loaded
- skill_name = key[len(SKILL_LOADED_STATE_KEY_PREFIX):]
- active_skills.add(skill_name)
- logger.debug("Skill '%s' is active (just loaded)", skill_name)
- elif key.startswith(SKILL_TOOLS_STATE_KEY_PREFIX):
- # Skill's tools were just selected/modified
- skill_name = key[len(SKILL_TOOLS_STATE_KEY_PREFIX):]
- active_skills.add(skill_name)
- logger.debug("Skill '%s' is active (tools modified)", skill_name)
+ if key.startswith(loaded_state_prefix) and value:
+ skill_name = key[len(loaded_state_prefix):].strip()
+ if skill_name:
+ active_skills.add(skill_name)
+ logger.debug("Skill '%s' is active (just loaded)", skill_name)
+ if key.startswith(tools_state_prefix):
+ skill_name = key[len(tools_state_prefix):].strip()
+ if skill_name:
+ active_skills.add(skill_name)
+ logger.debug("Skill '%s' is active (tools modified)", skill_name)
return list(active_skills)
@@ -346,7 +350,7 @@ def _get_tools_selection(self, ctx: InvocationContext, skill_name: str) -> list[
Returns:
List of selected tool names
"""
- key = SKILL_TOOLS_STATE_KEY_PREFIX + skill_name
+ key = tool_state_key(ctx, skill_name)
v = get_state_delta(ctx, key)
if not v:
# Fallback to SKILL.md defaults when explicit selection state is absent.
diff --git a/trpc_agent_sdk/skills/_hot_reload.py b/trpc_agent_sdk/skills/_hot_reload.py
new file mode 100644
index 0000000..270c76c
--- /dev/null
+++ b/trpc_agent_sdk/skills/_hot_reload.py
@@ -0,0 +1,163 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""Skill hot-reload helpers.
+
+This module encapsulates filesystem event handling and dirty-directory queues
+for skill repositories. Repository implementations can stay focused on indexing
+while delegating watcher integration here.
+"""
+
+from __future__ import annotations
+
+import importlib
+import threading
+from pathlib import Path
+from typing import Callable
+
+from trpc_agent_sdk.log import logger
+
+
+class SkillHotReloadTracker:
+ """Track changed skill directories and optional watchdog observer state."""
+
+ def __init__(self, skill_file_name: str):
+ self._skill_file_name = skill_file_name
+ self._watchdog_init_attempted = False
+ self._watchdog_observer: object | None = None
+ self._changed_dirs_by_root: dict[str, set[str]] = {}
+ self._changed_dirs_lock = threading.Lock()
+
+ def clear(self) -> None:
+ """Clear queued directory changes."""
+ with self._changed_dirs_lock:
+ self._changed_dirs_by_root = {}
+
+ def mark_changed_path(self, raw_path: str, is_directory: bool, skill_roots: list[str]) -> None:
+ """Mark a changed path as dirty for the next incremental reload."""
+ path = Path(raw_path)
+ if not is_directory and path.name.lower() != self._skill_file_name.lower():
+ return
+ target_dir = path if is_directory else path.parent
+ self.mark_changed_dir(target_dir, skill_roots)
+
+ def mark_changed_dir(self, path: Path, skill_roots: list[str]) -> None:
+ """Queue a changed directory; consumed by repository incremental scans."""
+ root_key = self.resolve_root_key(path, skill_roots)
+ if root_key is None:
+ return
+ with self._changed_dirs_lock:
+ self._changed_dirs_by_root.setdefault(root_key, set()).add(str(path.resolve(strict=False)))
+
+ def pop_changed_dirs(self, root_key: str) -> list[Path]:
+ """Pop and return queued changed directories for a root."""
+ with self._changed_dirs_lock:
+ raw_dirs = self._changed_dirs_by_root.pop(root_key, set())
+ return [Path(raw) for raw in sorted(raw_dirs)]
+
+ def collect_changed_dirs(
+ self,
+ root_key: str,
+ tracked_dirs: set[str],
+ dir_mtime_ns: dict[str, int],
+ mtime_reader: Callable[[Path], int],
+ ) -> list[Path]:
+ """Collect changed directories via event queue or mtime probing."""
+ changed_dirs = self.pop_changed_dirs(root_key)
+ if not changed_dirs:
+ for dir_key in sorted(tracked_dirs):
+ path = Path(dir_key)
+ if not path.exists():
+ continue
+ current_mtime = mtime_reader(path)
+ if dir_mtime_ns.get(dir_key) != current_mtime:
+ changed_dirs.append(path)
+ dir_mtime_ns[dir_key] = current_mtime
+ return self.normalize_scan_targets(changed_dirs)
+
+ @staticmethod
+ def resolve_root_key(path: Path, skill_roots: list[str]) -> str | None:
+ """Find matching root key for a path."""
+ resolved = path.resolve(strict=False)
+ seen_roots: set[str] = set()
+ for root in skill_roots:
+ if not root:
+ continue
+ root_path = Path(root).resolve()
+ root_key = str(root_path)
+ if root_key in seen_roots:
+ continue
+ seen_roots.add(root_key)
+ if resolved.is_relative_to(root_path):
+ return root_key
+ return None
+
+ @staticmethod
+ def normalize_scan_targets(changed_dirs: list[Path]) -> list[Path]:
+ """Drop nested paths when their parent is already queued."""
+ result: list[Path] = []
+ for candidate in sorted(changed_dirs, key=lambda p: len(p.parts)):
+ if any(candidate.is_relative_to(parent) for parent in result):
+ continue
+ result.append(candidate)
+ return result
+
+ def start_watcher_if_possible(self, skill_roots: list[str]) -> None:
+ """Start filesystem watcher for near real-time hot reload if available."""
+ if self._watchdog_init_attempted:
+ return
+ self._watchdog_init_attempted = True
+ try:
+ events_module = importlib.import_module("watchdog.events")
+ observers_module = importlib.import_module("watchdog.observers")
+ except ImportError:
+ logger.debug("watchdog is unavailable; skill hot reload falls back to mtime probing")
+ return
+ except Exception as ex: # pylint: disable=broad-except
+ logger.warning("Failed to initialize watchdog imports: %s", ex)
+ return
+
+ file_system_event_handler_cls = getattr(events_module, "FileSystemEventHandler", None)
+ observer_cls = getattr(observers_module, "Observer", None)
+ if file_system_event_handler_cls is None or observer_cls is None:
+ logger.warning("watchdog is installed but required symbols are missing")
+ return
+
+ tracker = self
+
+ class _SkillDirChangeHandler(file_system_event_handler_cls): # type: ignore[misc,valid-type]
+
+ def on_any_event(self, event) -> None: # type: ignore[no-untyped-def]
+ src_path = getattr(event, "src_path", None)
+ if src_path:
+ tracker.mark_changed_path(src_path, bool(event.is_directory), skill_roots)
+ dest_path = getattr(event, "dest_path", None)
+ if dest_path:
+ tracker.mark_changed_path(dest_path, bool(event.is_directory), skill_roots)
+
+ observer = observer_cls()
+ handler = _SkillDirChangeHandler()
+ seen_roots: set[str] = set()
+ for root in skill_roots:
+ if not root:
+ continue
+ root_path = Path(root).resolve()
+ root_key = str(root_path)
+ if root_key in seen_roots:
+ continue
+ seen_roots.add(root_key)
+ if not root_path.is_dir():
+ continue
+ try:
+ observer.schedule(handler, path=root_key, recursive=True)
+ except Exception as ex: # pylint: disable=broad-except
+ logger.warning("Failed to watch skill root %s: %s", root_key, ex)
+
+ if not observer.emitters:
+ logger.debug("No valid skill roots to watch; hot reload watcher not started")
+ return
+ observer.start()
+ self._watchdog_observer = observer
+ logger.debug("Skill hot reload watcher started with %d root(s)", len(observer.emitters))
diff --git a/trpc_agent_sdk/skills/_repository.py b/trpc_agent_sdk/skills/_repository.py
index 0877f4c..08189bf 100644
--- a/trpc_agent_sdk/skills/_repository.py
+++ b/trpc_agent_sdk/skills/_repository.py
@@ -18,9 +18,9 @@
import abc
import os
from pathlib import Path
+from typing import Callable
from typing import List
from typing import Optional
-from typing import Literal
from typing_extensions import override
import yaml
@@ -29,74 +29,16 @@
from trpc_agent_sdk.log import logger
from ._constants import SKILL_FILE
+from ._hot_reload import SkillHotReloadTracker
from ._types import Skill
from ._types import SkillResource
from ._types import SkillSummary
from ._url_root import SkillRootResolver
+from ._utils import is_doc_file
+from ._utils import is_script_file
BASE_DIR_PLACEHOLDER = "__BASE_DIR__"
-
-
-def _split_front_matter(content: str) -> tuple[dict[str, str], str]:
- """Split markdown into (front matter dict, body) with optional YAML front matter."""
- text = content.replace("\r\n", "\n")
- if not text.startswith("---\n"):
- return {}, text
- idx = text.find("\n---\n", 4)
- if idx < 0:
- return {}, text
- raw_yaml = text[4:idx]
- body = text[idx + 5:]
- try:
- parsed = yaml.safe_load(raw_yaml) or {}
- if not isinstance(parsed, dict):
- return {}, body
- except Exception:
- return {}, body
- out: dict[str, str] = {}
- for k, v in parsed.items():
- key = str(k).strip()
- if not key:
- continue
- if v is None:
- out[key] = ""
- else:
- out[key] = str(v)
- return out, body
-
-
-def _parse_tools_from_body(body: str) -> list[str]:
- """Parse tool names from the Tools section in body text."""
- tool_names: list[str] = []
- in_tools_section = False
- for line in body.split("\n"):
- stripped = line.strip()
- if stripped.lower().startswith("tools:"):
- in_tools_section = True
- continue
- if not in_tools_section:
- continue
- if stripped and not stripped.startswith("-") and not stripped.startswith("#"):
- if ":" in stripped or (stripped[0].isupper() and any(
- stripped.startswith(s) for s in ["Overview", "Examples", "Usage", "Description", "Installation"])):
- break
- if stripped.startswith("#"):
- continue
- if stripped.startswith("-"):
- tool_name = stripped[1:].strip()
- if tool_name and not tool_name.startswith("#"):
- tool_names.append(tool_name)
- return tool_names
-
-
-def _is_doc_file(name: str) -> bool:
- name_lower = name.lower()
- return name_lower.endswith(".md") or name_lower.endswith(".txt")
-
-
-def _read_skill_file(path: Path) -> tuple[dict[str, str], str]:
- content = path.read_text(encoding="utf-8")
- return _split_front_matter(content)
+VisibilityFilter = Callable[[SkillSummary], bool]
class BaseSkillRepository(abc.ABC):
@@ -107,13 +49,19 @@ class BaseSkillRepository(abc.ABC):
must satisfy. Parsing internals are left entirely to subclasses.
"""
- def __init__(self, workspace_runtime: BaseWorkspaceRuntime):
+ def __init__(self, workspace_runtime: BaseWorkspaceRuntime, visibility_filter: VisibilityFilter | None = None):
self._workspace_runtime = workspace_runtime
+ self._visibility_filter = visibility_filter
@property
def workspace_runtime(self) -> BaseWorkspaceRuntime:
return self._workspace_runtime
+ @property
+ def visibility_filter(self) -> VisibilityFilter | None:
+ """Return the filter function."""
+ return self._visibility_filter
+
def user_prompt(self) -> str:
return ""
@@ -132,18 +80,8 @@ def get(self, name: str) -> Skill:
raise NotImplementedError
@abc.abstractmethod
- def skill_list(self, mode: Literal["all", "enabled", "disabled"] = "all") -> list[str]:
- """Return the names of all indexed skills.
-
- Args:
- mode: The mode to list the skills.
- - all: List all skills.
- - enabled: List enabled skills.
- - disabled: List disabled skills.
-
- Returns:
- A list of skill names.
- """
+ def skill_list(self, mode: str = 'all') -> list[str]:
+        """Return the names of indexed skills filtered by ``mode`` ("all", "enabled", or "disabled")."""
raise NotImplementedError
@abc.abstractmethod
@@ -163,8 +101,24 @@ def refresh(self) -> None:
def skill_run_env(self, skill_name: str) -> dict[str, str]:
"""Return the environment variables for the given skill.
"""
+ if self._visibility_filter is not None:
+ if not self._skill_visible_by_name(skill_name, self.summaries()):
+ raise ValueError(f"skill '{skill_name}' not found")
return {}
+ def _filter_summaries(
+ self,
+ summaries: list[SkillSummary],
+ ) -> list[SkillSummary]:
+ if not summaries:
+ return []
+ if self._visibility_filter is None:
+ return [SkillSummary(name=s.name, description=s.description) for s in summaries]
+ return [s for s in summaries if self._visibility_filter(s)]
+
+ def _skill_visible_by_name(self, name: str, summaries: list[SkillSummary]) -> bool:
+ return any(summary.name == name for summary in summaries)
+
class FsSkillRepository(BaseSkillRepository):
"""
@@ -181,6 +135,7 @@ def __init__(
*roots: str,
workspace_runtime: Optional[BaseWorkspaceRuntime] = None,
resolver: Optional[SkillRootResolver] = None,
+ enable_hot_reload: bool = False,
):
"""
Create a FsSkillRepository scanning the given roots.
@@ -190,6 +145,7 @@ def __init__(
``file://`` URLs, or ``http(s)://`` archive URLs).
workspace_runtime: Optional workspace runtime to use.
resolver: Optional skill root resolver to use.
+ enable_hot_reload: Whether to enable skill hot reload checks.
"""
if workspace_runtime is None:
workspace_runtime = create_local_workspace_runtime()
@@ -197,6 +153,11 @@ def __init__(
self._resolver = resolver or SkillRootResolver()
self._skill_paths: dict[str, str] = {} # name -> base dir
self._all_descriptions: dict[str, str] = {} # name -> description
+ self._discovered_skill_files: set[str] = set()
+ self._tracked_dirs_by_root: dict[str, set[str]] = {}
+ self._dir_mtime_ns: dict[str, int] = {}
+ self._hot_reload_tracker = SkillHotReloadTracker(skill_file_name=SKILL_FILE, )
+ self._enable_hot_reload = enable_hot_reload
self._skill_roots: list[str] = []
flat_roots: list[str] = []
@@ -212,6 +173,11 @@ def __init__(
# Root resolution
# ------------------------------------------------------------------
+ @property
+ def hot_reload_enabled(self) -> bool:
+ """Whether hot reload checks are enabled."""
+ return self._enable_hot_reload
+
def _resolve_skill_roots(self, roots: list[str]) -> None:
"""
Resolve a skill root string to a local directory path.
@@ -238,44 +204,140 @@ def _index(self) -> None:
"""Scan all roots and index available skills."""
self._skill_paths = {}
self._all_descriptions = {}
+ self._discovered_skill_files = set()
+ self._tracked_dirs_by_root = {}
+ self._dir_mtime_ns = {}
+ self._hot_reload_tracker.clear()
seen: set[str] = set()
for root in self._skill_roots:
if not root:
continue
root_path = Path(root).resolve()
- if str(root_path) in seen:
+ root_key = str(root_path)
+ if root_key in seen:
continue
- seen.add(str(root_path))
+ seen.add(root_key)
try:
- for dirpath, _dirs, _files in os.walk(root_path):
- skill_file_path = Path(dirpath) / SKILL_FILE
- if not skill_file_path.is_file():
- continue
- try:
- self._index_one(dirpath, skill_file_path)
- except Exception as ex: # pylint: disable=broad-except
- logger.debug("Failed to index skill at %s: %s", skill_file_path, ex)
-
+ self._scan_root(root_path, root_key=root_key)
except Exception as ex: # pylint: disable=broad-except
logger.warning("Error scanning root %s: %s", root_path, ex)
+ if self._enable_hot_reload:
+ self._hot_reload_tracker.start_watcher_if_possible(self._skill_roots)
+
+ def _scan_root(self, root_path: Path, root_key: str, start_path: Optional[Path] = None) -> None:
+ """Scan a full root or one of its changed subtrees."""
+ target = start_path or root_path
+ for dirpath, _dirs, _files in os.walk(target):
+ self._track_dir(root_key, Path(dirpath))
+ skill_file_path = Path(dirpath) / SKILL_FILE
+ if not skill_file_path.is_file():
+ continue
+ try:
+ self._index_one(dirpath, skill_file_path)
+ except Exception as ex: # pylint: disable=broad-except
+ logger.debug("Failed to index skill at %s: %s dirpath: %s, _files: %s", skill_file_path, ex, _dirs,
+ _files)
+
+ def _track_dir(self, root_key: str, path: Path) -> None:
+ """Store directory metadata used by incremental hot reload."""
+ dir_key = str(path.resolve())
+ self._tracked_dirs_by_root.setdefault(root_key, set()).add(dir_key)
+ self._dir_mtime_ns[dir_key] = self._safe_mtime_ns(path)
- def _index_one(self, dirpath: str, skill_file_path: Path) -> None:
+ @staticmethod
+ def _safe_mtime_ns(path: Path) -> int:
+ try:
+ return path.stat().st_mtime_ns
+ except OSError:
+ return -1
+
+ def _mark_changed_dir_for_hot_reload(self, path: Path) -> None:
+ """Queue a changed directory; consumed by _scan_changed_dirs."""
+ if not self._enable_hot_reload:
+ return
+ self._hot_reload_tracker.mark_changed_dir(path, self._skill_roots)
+
+ def _index_one(self, dirpath: str, skill_file_path: Path) -> bool:
        """Index the skill directory at *dirpath*; return ``True`` only if it was newly indexed."""
- front_matter, _ = _read_skill_file(skill_file_path)
+ skill_file_key = str(skill_file_path.resolve())
+ if skill_file_key in self._discovered_skill_files:
+ return False
+ self._discovered_skill_files.add(skill_file_key)
+
+ front_matter, _ = self._read_skill_file(skill_file_path)
name = front_matter.get("name", "").strip()
if not name:
name = Path(dirpath).name.strip()
if not name:
- return
+ return False
# First occurrence wins.
if name in self._skill_paths:
- return
+ return False
self._all_descriptions[name] = front_matter.get("description", "").strip()
self._skill_paths[name] = dirpath
logger.debug("Found skill '%s' at %s", name, dirpath)
+ return True
+
+ def _scan_changed_dirs(self) -> None:
+ """Fast probe + incremental scan for newly added SKILL.md files."""
+ if not self._enable_hot_reload:
+ return
+ seen_roots: set[str] = set()
+ for root in self._skill_roots:
+ if not root:
+ continue
+ root_path = Path(root).resolve()
+ root_key = str(root_path)
+ if root_key in seen_roots:
+ continue
+ seen_roots.add(root_key)
+
+ if not root_path.is_dir():
+ continue
+
+ tracked_dirs = self._tracked_dirs_by_root.get(root_key, {root_key})
+ changed_dirs = self._hot_reload_tracker.collect_changed_dirs(
+ root_key=root_key,
+ tracked_dirs=tracked_dirs,
+ dir_mtime_ns=self._dir_mtime_ns,
+ mtime_reader=self._safe_mtime_ns,
+ )
+ if not changed_dirs:
+ continue
+
+ for target in changed_dirs:
+ self._scan_root(root_path=root_path, root_key=root_key, start_path=target)
+ self._prune_deleted_skills()
+
+ def _prune_deleted_skills(self) -> None:
+ """Remove indexed skills whose directory or SKILL.md no longer exists."""
+ removed_names: list[str] = []
+ for name, dirpath in list(self._skill_paths.items()):
+ skill_file_path = Path(dirpath) / SKILL_FILE
+ if skill_file_path.is_file():
+ continue
+ removed_names.append(name)
+
+ for name in removed_names:
+ dirpath = self._skill_paths.pop(name, "")
+ self._all_descriptions.pop(name, None)
+ if dirpath:
+ removed_file_key = str((Path(dirpath) / SKILL_FILE).resolve(strict=False))
+ self._discovered_skill_files.discard(removed_file_key)
+ logger.debug("Pruned deleted skill '%s' from repository index", name)
+
+ stale_files = {path for path in self._discovered_skill_files if not Path(path).is_file()}
+ if stale_files:
+ self._discovered_skill_files.difference_update(stale_files)
+
+ @classmethod
+ def _read_skill_file(cls, path: Path) -> tuple[dict[str, str], str]:
+ """Read the skill file and return the front matter and body."""
+ content = path.read_text(encoding="utf-8")
+ return cls.from_markdown(content)
# ------------------------------------------------------------------
# Public API
@@ -290,17 +352,24 @@ def path(self, name: str) -> str:
"""
key = name.strip()
if key not in self._skill_paths:
- raise ValueError(f"skill '{name}' not found")
+ logger.warning("skill '%s' not found, refreshing repository", name)
+ self.refresh()
+ if key not in self._skill_paths:
+ raise ValueError(f"skill '{name}' not found")
+ if self._visibility_filter is not None:
+ if not self._skill_visible_by_name(key, self.summaries()):
+ raise ValueError(f"skill '{key}' not found")
return self._skill_paths[key]
@override
def summaries(self) -> List[SkillSummary]:
"""Return summaries for all indexed skills, sorted by name."""
+ self._scan_changed_dirs()
out: list[SkillSummary] = []
for name in sorted(self._skill_paths):
skill_file_path = Path(self._skill_paths[name]) / SKILL_FILE
try:
- front_matter, _ = _read_skill_file(skill_file_path)
+ front_matter, _ = self._read_skill_file(skill_file_path)
summary = SkillSummary(
name=front_matter.get("name", "").strip(),
description=front_matter.get("description", "").strip(),
@@ -310,6 +379,8 @@ def summaries(self) -> List[SkillSummary]:
out.append(summary)
except Exception as ex: # pylint: disable=broad-except
logger.warning("Failed to parse summary for skill '%s': %s", name, ex)
+ if self._visibility_filter is not None:
+ out = self._filter_summaries(out)
return out
@override
@@ -323,13 +394,13 @@ def get(self, name: str) -> Skill:
ValueError: If the skill is not found.
"""
dir_path = Path(self.path(name))
- front_matter, body = _read_skill_file(dir_path / SKILL_FILE)
+ front_matter, body = self._read_skill_file(dir_path / SKILL_FILE)
skill = Skill()
skill.base_dir = str(dir_path)
skill.summary.name = front_matter.get("name", "").strip() or name
skill.summary.description = front_matter.get("description", "").strip()
skill.body = body
- skill.tools = _parse_tools_from_body(skill.body)
+ skill.tools = self._parse_tools_from_body(skill.body)
if skill.base_dir:
skill.body = skill.body.replace(BASE_DIR_PLACEHOLDER, skill.base_dir)
@@ -338,12 +409,20 @@ def get(self, name: str) -> Skill:
return skill
@override
- def skill_list(self, mode: Literal["all", "enabled", "disabled"] = "all") -> list[str]:
- """Return the names of all indexed skills, sorted."""
+ def skill_list(self, mode: str = 'all') -> list[str]:
+ """Return the names of all indexed skills, sorted.
+
+ Args:
+            mode: The mode to list skills ("all", "enabled" or "disabled").
+ Returns:
+ A list of skill names.
+ """
+ self._scan_changed_dirs()
return sorted(self._skill_paths)
@override
def refresh(self) -> None:
+ """Refresh the skill repository."""
self._index()
def _read_docs(self, dir_path: Path, base_dir: str) -> list[SkillResource]:
@@ -356,7 +435,7 @@ def _read_docs(self, dir_path: Path, base_dir: str) -> list[SkillResource]:
continue
if entry.name.lower() == SKILL_FILE.lower():
continue
- if not _is_doc_file(entry.name):
+ if not is_doc_file(entry.name) and not is_script_file(entry.name):
continue
try:
content = entry.read_text(encoding="utf-8")
@@ -380,7 +459,30 @@ def from_markdown(cls, content: str) -> tuple[dict[str, str], str]:
.. deprecated::
Prefer repository-native front matter splitting directly.
"""
- return _split_front_matter(content)
+ text = content.replace("\r\n", "\n")
+ if not text.startswith("---\n"):
+ return {}, text
+ idx = text.find("\n---\n", 4)
+ if idx < 0:
+ return {}, text
+ raw_yaml = text[4:idx]
+ body = text[idx + 5:]
+ try:
+ parsed = yaml.safe_load(raw_yaml) or {}
+ if not isinstance(parsed, dict):
+ return {}, body
+ except Exception:
+ return {}, body
+ out: dict[str, str] = {}
+ for k, v in parsed.items():
+ key = str(k).strip()
+ if not key:
+ continue
+ if v is None:
+ out[key] = ""
+ else:
+ out[key] = str(v)
+ return out, body
@staticmethod
def _parse_tools_from_body(body: str) -> list[str]:
@@ -389,18 +491,40 @@ def _parse_tools_from_body(body: str) -> list[str]:
.. deprecated::
Prefer repository-native tool parser directly.
"""
- return _parse_tools_from_body(body)
+ tool_names: list[str] = []
+ in_tools_section = False
+ for line in body.split("\n"):
+ stripped = line.strip()
+ if stripped.lower().startswith("tools:"):
+ in_tools_section = True
+ continue
+ if not in_tools_section:
+ continue
+ if stripped and not stripped.startswith("-") and not stripped.startswith("#"):
+ if ":" in stripped or (stripped[0].isupper() and any(
+ stripped.startswith(s)
+ for s in ["Overview", "Examples", "Usage", "Description", "Installation"])):
+ break
+ if stripped.startswith("#"):
+ continue
+ if stripped.startswith("-"):
+ tool_name = stripped[1:].strip()
+ if tool_name and not tool_name.startswith("#"):
+ tool_names.append(tool_name)
+ return tool_names
def create_default_skill_repository(
*roots: str,
workspace_runtime: Optional[BaseWorkspaceRuntime] = None,
+ enable_hot_reload: bool = True,
) -> FsSkillRepository:
"""Create a new filesystem skill repository.
Args:
roots: Root directories (or URLs) to scan for skills.
workspace_runtime: Optional workspace runtime.
+ enable_hot_reload: Whether to enable skill hot reload checks.
Returns:
A configured :class:`FsSkillRepository`.
"""
@@ -409,4 +533,5 @@ def create_default_skill_repository(
return FsSkillRepository(
*roots,
workspace_runtime=workspace_runtime,
+ enable_hot_reload=enable_hot_reload,
)
diff --git a/trpc_agent_sdk/skills/_skill_config.py b/trpc_agent_sdk/skills/_skill_config.py
new file mode 100644
index 0000000..00f15fe
--- /dev/null
+++ b/trpc_agent_sdk/skills/_skill_config.py
@@ -0,0 +1,60 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+
+from typing import Any
+
+from trpc_agent_sdk.context import AgentContext
+from trpc_agent_sdk.context import InvocationContext
+
+from ._constants import SKILL_CONFIG_KEY
+from ._constants import SKILL_LOAD_MODE_VALUES
+from ._constants import SkillLoadModeNames
+
+DEFAULT_SKILL_CONFIG = {
+ "skill_processor": {
+ "load_mode": "turn",
+ "tooling_guidance": "",
+ "tool_result_mode": False,
+ "tool_profile": "full",
+ "forbidden_tools": [],
+ "tool_flags": None,
+ "exec_tools_disabled": False,
+ "repo_resolver": None,
+ "max_loaded_skills": 0,
+ },
+ "workspace_exec_processor": {
+ "session_tools": False,
+ "has_skills_repo": False,
+ "repo_resolver": None,
+ "enabled_resolver": None,
+ "sessions_resolver": None,
+ },
+ "skills_tool_result_processor": {
+ "skip_fallback_on_session_summary": True,
+ "repo_resolver": None,
+ "tool_result_mode": False,
+ },
+}
+
+
+def get_skill_config(agent_context: AgentContext) -> dict[str, Any]:
+ return agent_context.get_metadata(SKILL_CONFIG_KEY, DEFAULT_SKILL_CONFIG)
+
+
+def set_skill_config(agent_context: AgentContext, config: dict[str, Any] = DEFAULT_SKILL_CONFIG) -> None:
+ agent_context.with_metadata(SKILL_CONFIG_KEY, config)
+
+
+def get_skill_load_mode(ctx: InvocationContext) -> str:
+ skill_config = get_skill_config(ctx.agent_context)
+ load_mode = skill_config["skill_processor"].get("load_mode", SkillLoadModeNames.TURN.value)
+ if load_mode not in SKILL_LOAD_MODE_VALUES:
+ load_mode = SkillLoadModeNames.TURN.value
+ return str(load_mode)
+
+
+def is_exist_skill_config(agent_context: AgentContext) -> bool:
+ return SKILL_CONFIG_KEY in agent_context.metadata
diff --git a/trpc_agent_sdk/skills/_skill_profile.py b/trpc_agent_sdk/skills/_skill_profile.py
new file mode 100644
index 0000000..e1f07da
--- /dev/null
+++ b/trpc_agent_sdk/skills/_skill_profile.py
@@ -0,0 +1,120 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""Skill tool profile options and flags."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Optional
+
+from ._constants import SKILL_TOOLS_NAMES
+from ._constants import SkillProfileNames
+from ._constants import SkillToolsNames
+
+
+@dataclass
+class SkillProfileFlags:
+    """Built-in skill tool flags."""
+ load: bool = False
+ select_docs: bool = False
+ list_docs: bool = False
+ run: bool = False
+ select_tools: bool = False
+ list_skills: bool = False
+ exec: bool = False
+ write_stdin: bool = False
+ poll_session: bool = False
+ kill_session: bool = False
+
+ @classmethod
+ def normalize_profile(cls, profile: str) -> str:
+ p = (profile or "").strip().lower()
+ if p == str(SkillProfileNames.KNOWLEDGE_ONLY):
+ return str(SkillProfileNames.KNOWLEDGE_ONLY)
+ return str(SkillProfileNames.FULL)
+
+ @classmethod
+ def normalize_tool(cls, name: str) -> str:
+ return (name or "").strip().lower()
+
+ @classmethod
+ def preset_flags(cls, profile: str, forbidden_tools: Optional[list[str]] = None) -> "SkillProfileFlags":
+ normalized = cls.normalize_profile(profile)
+ if normalized == SkillProfileNames.KNOWLEDGE_ONLY:
+ return cls(load=True, select_docs=True, list_docs=True)
+
+ flags = cls(
+ load=True,
+ select_docs=True,
+ list_docs=True,
+ run=True,
+ exec=True,
+ write_stdin=True,
+ poll_session=True,
+ kill_session=True,
+ select_tools=True,
+ list_skills=True,
+ )
+ if forbidden_tools:
+ flags = flags.flags_from_forbidden_tools(forbidden_tools, flags)
+ return flags
+
+ @classmethod
+ def flags_from_forbidden_tools(cls, forbidden_tools: list[str], flags: "SkillProfileFlags") -> "SkillProfileFlags":
+ for raw in forbidden_tools:
+ name = cls.normalize_tool(raw)
+ if name in SKILL_TOOLS_NAMES:
+ flags_mem = name[name.index("_") + 1:]
+ setattr(flags, flags_mem, False)
+ flags.validate()
+ return flags
+
+ @classmethod
+ def resolve_flags(cls, profile: str, forbidden_tools: Optional[list[str]] = None) -> "SkillProfileFlags":
+ flags = cls.preset_flags(profile, forbidden_tools)
+ flags.validate()
+ return flags
+
+ def validate(self) -> None:
+ if self.exec and not self.run:
+ raise ValueError(f"{SkillToolsNames.EXEC} requires {SkillToolsNames.RUN}")
+ if self.write_stdin and not self.exec:
+ raise ValueError(f"{SkillToolsNames.WRITE_STDIN} requires {SkillToolsNames.EXEC}")
+ if self.poll_session and not self.exec:
+ raise ValueError(f"{SkillToolsNames.POLL_SESSION} requires {SkillToolsNames.EXEC}")
+ if self.kill_session and not self.exec:
+ raise ValueError(f"{SkillToolsNames.KILL_SESSION} requires {SkillToolsNames.EXEC}")
+
+ def is_any(self) -> bool:
+ return (self.load or self.select_docs or self.list_docs or self.run or self.exec or self.write_stdin
+ or self.poll_session or self.kill_session or self.select_tools or self.list_skills)
+
+ def has_knowledge_tools(self) -> bool:
+ return self.load or self.select_docs or self.list_docs
+
+ def has_doc_helpers(self) -> bool:
+ return self.select_docs or self.list_docs
+
+ def has_select_tools(self) -> bool:
+ return self.select_tools
+
+ def requires_execution_tools(self) -> bool:
+ return self.run or self.exec or self.write_stdin or self.poll_session or self.kill_session
+
+ def requires_exec_session_tools(self) -> bool:
+ return self.exec or self.write_stdin or self.poll_session or self.kill_session
+
+ def without_interactive_execution(self) -> "SkillProfileFlags":
+ return SkillProfileFlags(
+ load=self.load,
+ select_docs=self.select_docs,
+ list_docs=self.list_docs,
+ run=self.run,
+ exec=False,
+ write_stdin=False,
+ poll_session=False,
+ kill_session=False,
+ )
diff --git a/trpc_agent_sdk/skills/_state_keys.py b/trpc_agent_sdk/skills/_state_keys.py
new file mode 100644
index 0000000..586f530
--- /dev/null
+++ b/trpc_agent_sdk/skills/_state_keys.py
@@ -0,0 +1,151 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""State key builders for skills."""
+
+from __future__ import annotations
+
+from urllib.parse import quote
+
+from ._constants import SKILL_DOCS_BY_AGENT_STATE_KEY_PREFIX
+from ._constants import SKILL_DOCS_STATE_KEY_PREFIX
+from ._constants import SKILL_LOADED_BY_AGENT_STATE_KEY_PREFIX
+from ._constants import SKILL_LOADED_ORDER_BY_AGENT_STATE_KEY_PREFIX
+from ._constants import SKILL_LOADED_ORDER_STATE_KEY_PREFIX
+from ._constants import SKILL_LOADED_STATE_KEY_PREFIX
+from ._constants import SKILL_TOOLS_BY_AGENT_STATE_KEY_PREFIX
+from ._constants import SKILL_TOOLS_STATE_KEY_PREFIX
+
+_STATE_KEY_SCOPE_DELIMITER = "/"
+_STATE_KEY_TEMP_PREFIX = "temp:"
+
+
+def _escape_scope_segment(value: str) -> str:
+ """Escape a scoped-key segment when it contains the delimiter."""
+ if _STATE_KEY_SCOPE_DELIMITER in value:
+ return quote(value, safe="")
+ return value
+
+
+def to_persistent_prefix(prefix: str) -> str:
+ """Convert temp-prefixed state prefix to persistent prefix."""
+ if prefix.startswith(_STATE_KEY_TEMP_PREFIX):
+ return prefix[len(_STATE_KEY_TEMP_PREFIX):]
+ return prefix
+
+
+def loaded_key(agent_name: str, skill_name: str) -> str:
+ """Return the loaded-state key for a skill.
+
+    When ``agent_name`` is empty, fall back to the legacy unscoped key.
+ """
+ agent_name = agent_name.strip()
+ skill_name = skill_name.strip()
+ if not agent_name:
+ return f"{SKILL_LOADED_STATE_KEY_PREFIX}{skill_name}"
+ return (f"{SKILL_LOADED_BY_AGENT_STATE_KEY_PREFIX}"
+ f"{_escape_scope_segment(agent_name)}"
+ f"{_STATE_KEY_SCOPE_DELIMITER}{skill_name}")
+
+
+def loaded_session_key(keys: str) -> str:
+ """Return persistent loaded-state key for session mode."""
+ return to_persistent_prefix(keys)
+
+
+def docs_key(agent_name: str, skill_name: str) -> str:
+ """Return the docs-state key for a skill.
+
+    When ``agent_name`` is empty, fall back to the legacy unscoped key.
+ """
+ agent_name = agent_name.strip()
+ skill_name = skill_name.strip()
+ if not agent_name:
+ return f"{SKILL_DOCS_STATE_KEY_PREFIX}{skill_name}"
+ return (f"{SKILL_DOCS_BY_AGENT_STATE_KEY_PREFIX}"
+ f"{_escape_scope_segment(agent_name)}"
+ f"{_STATE_KEY_SCOPE_DELIMITER}{skill_name}")
+
+
+def docs_session_key(keys: str) -> str:
+ """Return persistent docs-state key for session mode."""
+ return to_persistent_prefix(keys)
+
+
+def tool_key(agent_name: str, skill_name: str) -> str:
+ """Return the tools-state key for a skill.
+
+    When ``agent_name`` is empty, fall back to the legacy unscoped key.
+ """
+ agent_name = agent_name.strip()
+ skill_name = skill_name.strip()
+ if not agent_name:
+ return f"{SKILL_TOOLS_STATE_KEY_PREFIX}{skill_name}"
+ return (f"{SKILL_TOOLS_BY_AGENT_STATE_KEY_PREFIX}"
+ f"{_escape_scope_segment(agent_name)}"
+ f"{_STATE_KEY_SCOPE_DELIMITER}{skill_name}")
+
+
+def tool_session_key(keys: str) -> str:
+ """Return persistent tools-state key for session mode."""
+ return to_persistent_prefix(keys)
+
+
+def loaded_prefix(agent_name: str) -> str:
+ """Return the loaded-state scan prefix for an agent."""
+ agent_name = agent_name.strip()
+ if not agent_name:
+ return SKILL_LOADED_STATE_KEY_PREFIX
+ return (f"{SKILL_LOADED_BY_AGENT_STATE_KEY_PREFIX}"
+ f"{_escape_scope_segment(agent_name)}"
+ f"{_STATE_KEY_SCOPE_DELIMITER}")
+
+
+def loaded_session_prefix(keys: str) -> str:
+ """Return persistent loaded-state scan prefix for session mode."""
+ return to_persistent_prefix(keys)
+
+
+def docs_prefix(agent_name: str) -> str:
+ """Return the docs-state scan prefix for an agent."""
+ agent_name = agent_name.strip()
+ if not agent_name:
+ return SKILL_DOCS_STATE_KEY_PREFIX
+ return (f"{SKILL_DOCS_BY_AGENT_STATE_KEY_PREFIX}"
+ f"{_escape_scope_segment(agent_name)}"
+ f"{_STATE_KEY_SCOPE_DELIMITER}")
+
+
+def docs_session_prefix(keys: str) -> str:
+ """Return persistent docs-state scan prefix for session mode."""
+ return to_persistent_prefix(keys)
+
+
+def tool_prefix(agent_name: str) -> str:
+ """Return the tools-state scan prefix for an agent."""
+ agent_name = agent_name.strip()
+ if not agent_name:
+ return SKILL_TOOLS_STATE_KEY_PREFIX
+ return (f"{SKILL_TOOLS_BY_AGENT_STATE_KEY_PREFIX}"
+ f"{_escape_scope_segment(agent_name)}"
+ f"{_STATE_KEY_SCOPE_DELIMITER}")
+
+
+def tool_session_prefix(keys: str) -> str:
+ """Return persistent tools-state scan prefix for session mode."""
+ return to_persistent_prefix(keys)
+
+
+def loaded_order_key(agent_name: str) -> str:
+ """Return the loaded-order key for an agent."""
+ agent_name = agent_name.strip()
+ if not agent_name:
+ return SKILL_LOADED_ORDER_STATE_KEY_PREFIX
+ return f"{SKILL_LOADED_ORDER_BY_AGENT_STATE_KEY_PREFIX}{_escape_scope_segment(agent_name)}"
+
+
+def loaded_session_order_key(keys: str) -> str:
+ """Return persistent loaded-order key for session mode."""
+ return to_persistent_prefix(keys)
diff --git a/trpc_agent_sdk/skills/_state_migration.py b/trpc_agent_sdk/skills/_state_migration.py
new file mode 100644
index 0000000..90feb15
--- /dev/null
+++ b/trpc_agent_sdk/skills/_state_migration.py
@@ -0,0 +1,165 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""Legacy skill state migration utilities.
+
+- migrate legacy unscoped skill state keys once per session
+- infer skill owners from historical tool responses
+- write new scoped keys and clear old legacy keys
+"""
+
+from __future__ import annotations
+
+import re
+from typing import Any
+from typing import Callable
+
+from trpc_agent_sdk.context import InvocationContext
+from trpc_agent_sdk.events import Event
+
+from ._constants import SKILL_DOCS_STATE_KEY_PREFIX
+from ._constants import SKILL_LOADED_STATE_KEY_PREFIX
+from ._constants import SkillToolsNames
+from ._state_keys import docs_key
+from ._state_keys import loaded_key
+
+SKILLS_LEGACY_MIGRATION_STATE_KEY = "processor:skills:legacy_migrated"
+
+ScopedKeysBuilder = Callable[[str, str], str]
+
+
+def _state_has_key(ctx: InvocationContext, key: str) -> bool:
+ if key in ctx.actions.state_delta:
+ return True
+ return key in ctx.session.state
+
+
+def _snapshot_state(ctx: InvocationContext) -> dict[str, Any]:
+ state = dict(ctx.session.state or {})
+ for k, v in ctx.actions.state_delta.items():
+ if v is None:
+ state.pop(k, None)
+ else:
+ state[k] = v
+ return state
+
+
+def _migrate_legacy_state_key(
+ ctx: InvocationContext,
+ state: dict[str, Any],
+ delta: dict[str, Any],
+ legacy_key: str,
+ legacy_val: Any,
+ skill_name: str,
+ owners: dict[str, str],
+ build_keys: ScopedKeysBuilder,
+) -> None:
+ name = (skill_name or "").strip()
+ if not name:
+ return
+ # Skip already scoped entries.
+ if ":" in name:
+ return
+
+ owner = (owners.get(name, "") or "").strip()
+ if not owner:
+ owner = (getattr(ctx.agent, "name", "") or "").strip()
+ if not owner:
+ return
+
+ temp_key = build_keys(owner, name)
+ temp_existing = state.get(temp_key, None)
+ if temp_existing:
+ delta[legacy_key] = None
+ return
+
+ delta[temp_key] = legacy_val
+ delta[legacy_key] = None
+
+
+def _legacy_skill_owners(events: list[Event]) -> dict[str, str]:
+ owners: dict[str, str] = {}
+ for ev in reversed(events or []):
+ _add_owners_from_event(ev, owners)
+ return owners
+
+
+def _add_owners_from_event(ev: Event, owners: dict[str, str]) -> None:
+ if not ev or not ev.content or not ev.content.parts:
+ return
+ author = (ev.author or "").strip()
+ if not author:
+ return
+ for part in reversed(ev.content.parts):
+ fr = part.function_response
+ if not fr:
+ continue
+ tool_name = (fr.name or "").strip()
+ if tool_name not in (SkillToolsNames.LOAD, SkillToolsNames.SELECT_DOCS):
+ continue
+ skill_name = _skill_name_from_tool_response(fr.response)
+ if not skill_name or skill_name in owners:
+ continue
+ owners[skill_name] = author
+
+
+def _skill_name_from_tool_response(response: Any) -> str:
+ """Extract skill name from tool response payload."""
+ if isinstance(response, dict):
+ for key in ("skill", "skill_name", "name"):
+ value = response.get(key)
+ if isinstance(value, str) and value.strip():
+ return value.strip()
+ result = response.get("result")
+ if isinstance(result, str):
+ matched = re.search(r"skill\s+'([^']+)'\s+loaded", result)
+ if matched:
+ return matched.group(1).strip()
+ elif isinstance(response, str):
+ matched = re.search(r"skill\s+'([^']+)'\s+loaded", response)
+ if matched:
+ return matched.group(1).strip()
+ return ""
+
+
+def maybe_migrate_legacy_skill_state(ctx: InvocationContext) -> None:
+ """Migrate legacy skill state keys into scoped keys once.
+
+ This function is idempotent per session via
+ ``SKILLS_LEGACY_MIGRATION_STATE_KEY``.
+ """
+ if ctx is None or ctx.session is None:
+ return
+ if _state_has_key(ctx, SKILLS_LEGACY_MIGRATION_STATE_KEY):
+ return
+ ctx.actions.state_delta[SKILLS_LEGACY_MIGRATION_STATE_KEY] = True
+
+ state = _snapshot_state(ctx)
+ if not state:
+ return
+ has_loaded = any(k.startswith(SKILL_LOADED_STATE_KEY_PREFIX) for k in state.keys())
+ has_docs = any(k.startswith(SKILL_DOCS_STATE_KEY_PREFIX) for k in state.keys())
+ if not has_loaded and not has_docs:
+ return
+
+ owners: dict[str, str] | None = None
+ delta: dict[str, Any] = {}
+
+ for key, value in state.items():
+ if value is None or value == "":
+ continue
+ if key.startswith(SKILL_LOADED_STATE_KEY_PREFIX):
+ if owners is None:
+ owners = _legacy_skill_owners(getattr(ctx.session, "events", []))
+ name = key[len(SKILL_LOADED_STATE_KEY_PREFIX):].strip()
+ _migrate_legacy_state_key(ctx, state, delta, key, value, name, owners, loaded_key)
+ elif key.startswith(SKILL_DOCS_STATE_KEY_PREFIX):
+ if owners is None:
+ owners = _legacy_skill_owners(getattr(ctx.session, "events", []))
+ name = key[len(SKILL_DOCS_STATE_KEY_PREFIX):].strip()
+ _migrate_legacy_state_key(ctx, state, delta, key, value, name, owners, docs_key)
+
+ if delta:
+ ctx.actions.state_delta.update(delta)
diff --git a/trpc_agent_sdk/skills/_state_order.py b/trpc_agent_sdk/skills/_state_order.py
new file mode 100644
index 0000000..8cae773
--- /dev/null
+++ b/trpc_agent_sdk/skills/_state_order.py
@@ -0,0 +1,82 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""Loaded-skill order helpers.
+"""
+
+from __future__ import annotations
+
+import json
+from typing import Any
+
+
+def parse_loaded_order(raw: Any) -> list[str]:
+ """Parse a stored loaded-order payload.
+
+ Returns a normalized list of unique non-empty skill names.
+ """
+ if raw is None:
+ return []
+ if isinstance(raw, bytes):
+ if not raw:
+ return []
+ try:
+ raw = raw.decode("utf-8")
+ except UnicodeDecodeError:
+ return []
+ if isinstance(raw, str):
+ if not raw:
+ return []
+ try:
+ raw = json.loads(raw)
+ except json.JSONDecodeError:
+ return []
+ if not isinstance(raw, list):
+ return []
+ return _normalize_loaded_order(raw)
+
+
+def marshal_loaded_order(names: list[str]) -> str:
+ """Serialize a normalized loaded-order payload."""
+ normalized = _normalize_loaded_order(names)
+ if not normalized:
+ return ""
+ return json.dumps(normalized, ensure_ascii=False)
+
+
+def touch_loaded_order(names: list[str], *touched: str) -> list[str]:
+ """Move touched skills to the tail of the loaded order."""
+ order = _normalize_loaded_order(names)
+ for name in touched:
+ candidate = (name or "").strip()
+ if not candidate:
+ continue
+ order = _remove_loaded_order_name(order, candidate)
+ order.append(candidate)
+ return order
+
+
+def _normalize_loaded_order(names: list[Any]) -> list[str]:
+ if not names:
+ return []
+ out: list[str] = []
+ seen: set[str] = set()
+ for name in names:
+ if not isinstance(name, str):
+ continue
+ candidate = name.strip()
+ if not candidate or candidate in seen:
+ continue
+ out.append(candidate)
+ seen.add(candidate)
+ return out
+
+
+def _remove_loaded_order_name(order: list[str], target: str) -> list[str]:
+ for i, name in enumerate(order):
+ if name != target:
+ continue
+ return order[:i] + order[i + 1:]
+ return order
diff --git a/trpc_agent_sdk/skills/_toolset.py b/trpc_agent_sdk/skills/_toolset.py
index 024a1db..f276322 100644
--- a/trpc_agent_sdk/skills/_toolset.py
+++ b/trpc_agent_sdk/skills/_toolset.py
@@ -12,33 +12,44 @@
from __future__ import annotations
from typing import Any
-from typing import Callable
from typing import List
from typing import Optional
from typing import Union
from typing_extensions import override
-from trpc_agent_sdk.abc import ToolABC
-from trpc_agent_sdk.abc import ToolPredicate
from trpc_agent_sdk.abc import ToolSetABC
+from trpc_agent_sdk.abc import ToolPredicate
+from trpc_agent_sdk.abc import ToolABC
from trpc_agent_sdk.context import InvocationContext
-from trpc_agent_sdk.log import logger
+from trpc_agent_sdk.context import get_invocation_ctx
from trpc_agent_sdk.tools import FunctionTool
+from trpc_agent_sdk.log import logger
from ._constants import SKILL_REGISTRY_KEY
from ._constants import SKILL_REPOSITORY_KEY
+from ._repository import FsSkillRepository
+from ._repository import BaseSkillRepository
from ._registry import SKILL_REGISTRY
from ._registry import SkillToolFunction
-from ._repository import BaseSkillRepository
-from ._repository import FsSkillRepository
-from .tools import SkillExecTool
-from .tools import SkillRunTool
-from .tools import skill_list
+from ._skill_config import DEFAULT_SKILL_CONFIG
+from ._skill_config import set_skill_config
+from ._skill_config import is_exist_skill_config
from .tools import skill_list_docs
from .tools import skill_list_tools
-from .tools import skill_load
+from .tools import SkillLoadTool
from .tools import skill_select_docs
from .tools import skill_select_tools
+from .tools import skill_list
+from .tools import SkillExecTool
+from .tools import SkillRunTool
+from .tools import SaveArtifactTool
+from .tools import WorkspaceExecTool
+from .tools import WorkspaceWriteStdinTool
+from .tools import WorkspaceKillSessionTool
+from .tools import CreateWorkspaceNameCallback
+from .tools import default_create_ws_name_callback
+from .tools import CopySkillStager
+from .stager import Stager
class SkillToolSet(ToolSetABC):
@@ -58,31 +69,63 @@ class SkillToolSet(ToolSetABC):
def __init__(self,
paths: Optional[List[str]] = None,
repository: BaseSkillRepository = None,
+ enable_hot_reload: bool = False,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
is_include_all_tools: bool = True,
+ create_ws_name_cb: Optional[CreateWorkspaceNameCallback] = None,
+ runtime_tools: Optional[List[ToolABC]] = None,
+ skill_stager: Optional[Stager] = None,
+ skill_config: Optional[dict[str, Any]] = None,
**run_tool_kwargs: dict[str, Any]):
"""Initialize the skill toolset.
Args:
paths: Optional list of skill paths. If None, will create a new one.
repository: Skill repository. If None, will be retrieved from context metadata.
+ enable_hot_reload: Whether to enable skill hot reload checks for
+ auto-created repositories.
tool_filter: Optional tool filter. If None, will include all tools.
is_include_all_tools: Optional flag to include all tools. If True, will include all tools.
+ runtime_tools: Optional list of runtime tools. If None, default workspace runtime tools are created.
run_tool_kwargs: Optional keyword arguments for skill run tool. If None, will use default values.
"""
super().__init__(tool_filter=tool_filter, is_include_all_tools=is_include_all_tools)
self.name = "skill_toolset"
- self._repository = repository or FsSkillRepository(*(paths or []))
- self._run_tool = SkillRunTool(repository=self._repository, **run_tool_kwargs)
- self._exec_tool = SkillExecTool(run_tool=self._run_tool)
- self._tools: List[Callable] = [
- skill_load,
+
+ self._repository = repository or FsSkillRepository(
+ *(paths or []),
+ enable_hot_reload=enable_hot_reload,
+ )
+ self._skill_config = skill_config or DEFAULT_SKILL_CONFIG
+ self._create_ws_name_cb = create_ws_name_cb or default_create_ws_name_callback
+ self._skill_stager = skill_stager or CopySkillStager()
+ self._load_tool = SkillLoadTool(repository=self._repository,
+ skill_stager=self._skill_stager,
+ create_ws_name_cb=self._create_ws_name_cb)
+ self._run_tool = SkillRunTool(repository=self._repository,
+ create_ws_name_cb=self._create_ws_name_cb,
+ skill_stager=self._skill_stager,
+ **run_tool_kwargs)
+ self._exec_tool = SkillExecTool(run_tool=self._run_tool, create_ws_name_cb=self._create_ws_name_cb)
+ self._function_tools: List[SkillToolFunction] = [
skill_list,
skill_list_docs,
skill_list_tools,
skill_select_docs,
skill_select_tools,
]
+ if runtime_tools:
+ self._runtime_tools = runtime_tools
+ else:
+ workspace_exec_tool = WorkspaceExecTool(workspace_runtime=self._repository.workspace_runtime,
+ create_ws_name_cb=self._create_ws_name_cb)
+ self._runtime_tools: List[ToolABC] = [
+ SaveArtifactTool(workspace_runtime=self._repository.workspace_runtime,
+ create_ws_name_cb=self._create_ws_name_cb),
+ workspace_exec_tool,
+ WorkspaceWriteStdinTool(workspace_exec_tool),
+ WorkspaceKillSessionTool(workspace_exec_tool),
+ ]
@property
def repository(self) -> BaseSkillRepository:
@@ -99,14 +142,21 @@ async def get_tools(self, invocation_context: Optional[InvocationContext] = None
Returns:
List of tools from all registered skills
"""
- tools: List[FunctionTool] = []
+ tools: List[ToolABC] = []
skill_functions: List[SkillToolFunction] = SKILL_REGISTRY.get_all()
- skill_functions.extend(self._tools)
- agent_context = invocation_context.agent_context
- agent_context.with_metadata(SKILL_REGISTRY_KEY, SKILL_REGISTRY)
- agent_context.with_metadata(SKILL_REPOSITORY_KEY, self._repository)
+ skill_functions.extend(self._function_tools)
+ if not invocation_context:
+ invocation_context = get_invocation_ctx()
+ if invocation_context:
+ agent_context = invocation_context.agent_context
+ agent_context.with_metadata(SKILL_REGISTRY_KEY, SKILL_REGISTRY)
+ agent_context.with_metadata(SKILL_REPOSITORY_KEY, self._repository)
+ if not is_exist_skill_config(agent_context):
+ set_skill_config(agent_context, self._skill_config)
+ tools.append(self._load_tool)
tools.append(self._run_tool)
tools.append(self._exec_tool)
+ tools.extend(self._runtime_tools)
for skill_function in skill_functions:
try:
tools.append(FunctionTool(func=skill_function))
diff --git a/trpc_agent_sdk/skills/_utils.py b/trpc_agent_sdk/skills/_utils.py
index c9434d7..61bccc0 100644
--- a/trpc_agent_sdk/skills/_utils.py
+++ b/trpc_agent_sdk/skills/_utils.py
@@ -194,3 +194,60 @@ def get_state_delta(invocation_context: InvocationContext, key: str) -> Optional
state = dict(invocation_context.session_state.copy())
state.update(invocation_context.actions.state_delta)
return state.get(key, None)
+
+
+FILE_EXTENSIONS_DOC = (".md", ".txt")
+
+
+def is_doc_file(name: str) -> bool:
+ """Check if a file is a document file.
+
+ Args:
+ name: The name of the file
+
+ Returns:
+ True if the file is a document file, False otherwise
+ """
+ name_lower = name.lower()
+ return name_lower.endswith(FILE_EXTENSIONS_DOC)
+
+
+_SCRIPT_FILE_EXTENSIONS = (
+ # Shell
+ ".sh",
+ ".bash",
+ ".zsh",
+ ".fish",
+ # Python / Node
+ ".py",
+ ".pyw",
+ ".js",
+ ".mjs",
+ ".cjs",
+ ".ts",
+ # Other interpreted / script languages
+ ".rb",
+ ".pl",
+ ".php",
+ ".lua",
+ ".r",
+ ".ps1",
+ ".awk",
+ ".tcl",
+ ".groovy",
+ ".kts",
+ ".jl",
+)
+
+
+def is_script_file(name: str) -> bool:
+ """Check if a file is a script file.
+
+ Args:
+ name: The name of the file
+
+ Returns:
+ True if the file is a script file, False otherwise
+ """
+ name_lower = name.lower()
+ return name_lower.endswith(_SCRIPT_FILE_EXTENSIONS)
diff --git a/trpc_agent_sdk/skills/stager/_base_stager.py b/trpc_agent_sdk/skills/stager/_base_stager.py
index 8bfffe3..3119193 100644
--- a/trpc_agent_sdk/skills/stager/_base_stager.py
+++ b/trpc_agent_sdk/skills/stager/_base_stager.py
@@ -27,6 +27,7 @@
from trpc_agent_sdk.code_executors import WorkspaceRunProgramSpec
from trpc_agent_sdk.code_executors import WorkspaceStageOptions
from trpc_agent_sdk.context import InvocationContext
+from trpc_agent_sdk.log import logger
from .._types import SkillMetadata
from .._types import SkillWorkspaceMetadata
@@ -46,12 +47,14 @@
class Stager:
"""Materializes skill package contents into a workspace.
-
- Mirrors Go's ``internal/skillstage.Stager``. Create an instance with
- ``Stager()`` (or the module-level helper :func:`new`) and call
- :meth:`stage_skill` from async tool code.
+ Stager is responsible for staging the skill package contents into a workspace.
"""
+ def __init__(self) -> None:
+ # Deduplicate noisy link warnings within one process lifetime.
+ # Key format: "<invocation_id>|<skill_name>|<stderr>".
+ self._link_error_warned_keys: set[str] = set()
+
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
@@ -76,7 +79,7 @@ async def stage_skill(self, request: SkillStageRequest) -> SkillStageResult:
ctx = request.ctx
ws = request.workspace
root = request.repository.path(request.skill_name)
- runtime = request.engine or request.repository.workspace_runtime
+ runtime = request.repository.workspace_runtime
name = request.skill_name
digest = compute_dir_digest(root)
md = await self.load_workspace_metadata(ctx, runtime, ws)
@@ -277,7 +280,17 @@ async def _link_workspace_dirs(
cmd = (f"set -e; cd {shell_quote(skill_root)}"
f"; rm -rf out work {_SKILL_DIR_INPUTS} {shell_quote(_SKILL_DIR_VENV)}"
- f"; mkdir -p {shell_quote(to_inputs)} {shell_quote(_SKILL_DIR_VENV)}"
+ f"; if [ -L {shell_quote(to_work)} ]; then"
+ f" rm -rf {shell_quote(to_work)}; fi"
+ f"; if [ -e {shell_quote(to_work)} ] && [ ! -d {shell_quote(to_work)} ]; then"
+ f" rm -rf {shell_quote(to_work)}; fi"
+ f"; if [ ! -d {shell_quote(to_work)} ]; then mkdir -p {shell_quote(to_work)}; fi"
+ f"; if [ -L {shell_quote(to_inputs)} ]; then"
+ f" rm -rf {shell_quote(to_inputs)}; fi"
+ f"; if [ -e {shell_quote(to_inputs)} ] && [ ! -d {shell_quote(to_inputs)} ]; then"
+ f" rm -rf {shell_quote(to_inputs)}; fi"
+ f"; if [ ! -d {shell_quote(to_inputs)} ]; then mkdir -p {shell_quote(to_inputs)}; fi"
+ f"; mkdir -p {shell_quote(_SKILL_DIR_VENV)}"
f"; ln -sfn {shell_quote(to_out)} out"
f"; ln -sfn {shell_quote(to_work)} work"
f"; ln -sfn {shell_quote(to_inputs)} {_SKILL_DIR_INPUTS}")
@@ -294,8 +307,18 @@ async def _link_workspace_dirs(
ctx,
)
if ret.exit_code != 0:
- from trpc_agent_sdk.log import logger # noqa: PLC0415
- logger.info("Stager._link_workspace_dirs failed for %r: %s", name, ret.stderr)
+ inv_id = ctx.invocation_id
+ err = (ret.stderr or "").strip()
+ dedupe_key = f"{inv_id}|{name}|{err}"
+ if dedupe_key in self._link_error_warned_keys:
+ logger.debug("Stager._link_workspace_dirs retry failed for %r: %s", name, ret.stderr)
+ return
+
+ self._link_error_warned_keys.add(dedupe_key)
+ # Keep the set bounded for long-lived processes.
+ if len(self._link_error_warned_keys) > 2000:
+ self._link_error_warned_keys.clear()
+ logger.warning("Stager._link_workspace_dirs failed for %r: %s", name, ret.stderr)
async def _read_only_except_symlinks(
self,
@@ -327,7 +350,6 @@ async def _read_only_except_symlinks(
ctx,
)
if ret.exit_code != 0:
- from trpc_agent_sdk.log import logger # noqa: PLC0415
logger.info("Stager._read_only_except_symlinks failed for %r: %s", dest, ret.stderr)
@classmethod
diff --git a/trpc_agent_sdk/skills/stager/_types.py b/trpc_agent_sdk/skills/stager/_types.py
index fec0411..0d599cc 100644
--- a/trpc_agent_sdk/skills/stager/_types.py
+++ b/trpc_agent_sdk/skills/stager/_types.py
@@ -6,7 +6,6 @@
"""Skill staging types."""
from dataclasses import dataclass
-from typing import Optional
from trpc_agent_sdk.code_executors import BaseWorkspaceRuntime
from trpc_agent_sdk.code_executors import WorkspaceInfo
@@ -17,10 +16,7 @@
@dataclass
class SkillStageRequest:
- """Describes the skill staging context for one run.
-
- Mirrors Go's ``SkillStageRequest`` in ``tool/skill/stager.go``.
- """
+ """Describes the skill staging context for one run."""
skill_name: str
"""Name of the skill to stage."""
@@ -34,11 +30,6 @@ class SkillStageRequest:
ctx: InvocationContext
"""Invocation context (used for workspace FS/runner access)."""
- engine: Optional[BaseWorkspaceRuntime] = None
- """Explicit workspace runtime (Go: ``Engine``).
- When ``None`` the runtime is obtained from ``repository.workspace_runtime``.
- """
-
timeout: float = 300.0
"""Timeout in seconds for internal staging helpers."""
diff --git a/trpc_agent_sdk/skills/tools/__init__.py b/trpc_agent_sdk/skills/tools/__init__.py
index 9ecf8a5..a6fc68d 100644
--- a/trpc_agent_sdk/skills/tools/__init__.py
+++ b/trpc_agent_sdk/skills/tools/__init__.py
@@ -1,64 +1,40 @@
-# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+# -*- coding: utf-8 -*-
#
-# Copyright (C) 2026 Tencent. All rights reserved.
-#
-# tRPC-Agent-Python is licensed under Apache-2.0.
+# Copyright @ 2025 Tencent.com
"""Skill tools package."""
+from ._common import CreateWorkspaceNameCallback
+from ._common import default_create_ws_name_callback
from ._copy_stager import CopySkillStager
-from ._copy_stager import normalize_workspace_skill_dir
-from ._skill_exec import ExecInput
-from ._skill_exec import ExecOutput
-from ._skill_exec import KillSessionInput
-from ._skill_exec import KillSessionTool
-from ._skill_exec import PollSessionInput
-from ._skill_exec import PollSessionTool
-from ._skill_exec import SessionInteraction
-from ._skill_exec import SessionKillOutput
+from ._save_artifact import SaveArtifactTool
from ._skill_exec import SkillExecTool
-from ._skill_exec import WriteStdinInput
-from ._skill_exec import WriteStdinTool
-from ._skill_exec import create_exec_tools
from ._skill_list import skill_list
from ._skill_list_docs import skill_list_docs
from ._skill_list_tool import skill_list_tools
-from ._skill_load import skill_load
-from ._skill_run import ArtifactInfo
-from ._skill_run import SkillRunFile
-from ._skill_run import SkillRunInput
-from ._skill_run import SkillRunOutput
+from ._skill_load import SkillLoadTool
from ._skill_run import SkillRunTool
-from ._skill_select_docs import SkillSelectDocsResult
from ._skill_select_docs import skill_select_docs
-from ._skill_select_tools import SkillSelectToolsResult
from ._skill_select_tools import skill_select_tools
+from ._workspace_exec import WorkspaceExecTool
+from ._workspace_exec import WorkspaceKillSessionTool
+from ._workspace_exec import WorkspaceWriteStdinTool
+from ._workspace_exec import create_workspace_exec_tools
__all__ = [
+ "CreateWorkspaceNameCallback",
+ "default_create_ws_name_callback",
"CopySkillStager",
- "normalize_workspace_skill_dir",
- "ExecInput",
- "ExecOutput",
- "KillSessionInput",
- "KillSessionTool",
- "PollSessionInput",
- "PollSessionTool",
- "SessionInteraction",
- "SessionKillOutput",
+ "SaveArtifactTool",
"SkillExecTool",
- "WriteStdinInput",
- "WriteStdinTool",
- "create_exec_tools",
"skill_list",
"skill_list_docs",
"skill_list_tools",
- "skill_load",
- "ArtifactInfo",
- "SkillRunFile",
- "SkillRunInput",
- "SkillRunOutput",
+ "SkillLoadTool",
"SkillRunTool",
- "SkillSelectDocsResult",
"skill_select_docs",
- "SkillSelectToolsResult",
"skill_select_tools",
+ "WorkspaceExecTool",
+ "WorkspaceKillSessionTool",
+ "WorkspaceWriteStdinTool",
+ "create_workspace_exec_tools",
]
diff --git a/trpc_agent_sdk/skills/tools/_common.py b/trpc_agent_sdk/skills/tools/_common.py
new file mode 100644
index 0000000..2cd6bf3
--- /dev/null
+++ b/trpc_agent_sdk/skills/tools/_common.py
@@ -0,0 +1,159 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""Shared helpers for skill/workspace tools."""
+
+from __future__ import annotations
+
+import asyncio
+import inspect
+import time
+from typing import Any
+from typing import Awaitable
+from typing import Callable
+from typing import TypeVar
+
+from trpc_agent_sdk.context import InvocationContext
+
+T = TypeVar("T")
+
+CreateWorkspaceNameCallback = Callable[[InvocationContext], str]
+"""Callback to create a workspace name."""
+
+
+def default_create_ws_name_callback(ctx: InvocationContext) -> str:
+ """Default callback to create a workspace name."""
+ return ctx.session.id
+
+
+def require_non_empty(value: str, *, field_name: str) -> str:
+ """Validate a required string field and return trimmed value."""
+ normalized = (value or "").strip()
+ if not normalized:
+ raise ValueError(f"{field_name} is required")
+ return normalized
+
+
+async def put_session(
+ sessions: dict[str, T],
+ lock: asyncio.Lock,
+ sid: str,
+ session: T,
+ gc: Callable[[], Awaitable[None]],
+) -> None:
+ """Insert session after running gc in lock."""
+ async with lock:
+ await gc()
+ sessions[sid] = session
+
+
+async def get_session(
+ sessions: dict[str, T],
+ lock: asyncio.Lock,
+ sid: str,
+ gc: Callable[[], Awaitable[None]],
+) -> T:
+ """Lookup session after running gc in lock."""
+ async with lock:
+ await gc()
+ session = sessions.get(sid)
+ if session is None:
+ raise ValueError(f"unknown session_id: {sid}")
+ return session
+
+
+async def remove_session(
+ sessions: dict[str, T],
+ lock: asyncio.Lock,
+ sid: str,
+ gc: Callable[[], Awaitable[None]],
+) -> T:
+ """Remove session after running gc in lock."""
+ async with lock:
+ await gc()
+ session = sessions.pop(sid, None)
+ if session is None:
+ raise ValueError(f"unknown session_id: {sid}")
+ return session
+
+
+async def _await_if_needed(value: object) -> None:
+ if inspect.isawaitable(value):
+ await value
+
+
+async def cleanup_expired_sessions(
+ sessions: dict[str, T],
+ *,
+ ttl: float,
+ refresh_exit_state: Callable[[T, float], object],
+ close_session: Callable[[T], object],
+) -> None:
+ """Refresh exit state and evict expired sessions in-place."""
+ if ttl <= 0:
+ return
+ now = time.time()
+ expired: list[str] = []
+ for sid, session in sessions.items():
+ await _await_if_needed(refresh_exit_state(session, now))
+ exited_at = getattr(session, "exited_at", None)
+ if exited_at is not None and (now - exited_at) >= ttl:
+ expired.append(sid)
+
+ for sid in expired:
+ session = sessions.pop(sid, None)
+ if session is None:
+ continue
+ try:
+ await _await_if_needed(close_session(session))
+ except Exception: # pylint: disable=broad-except
+ # Best-effort cleanup: mirror Go behavior and keep evicting others.
+ pass
+
+
+def inline_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]:
+ """Inline $ref references in JSON Schema by replacing them with actual definitions."""
+ defs = schema.get('$defs', {})
+ if not defs:
+ return schema
+
+ def resolve_ref(obj: Any) -> Any:
+ if isinstance(obj, dict):
+ if '$ref' in obj:
+ ref_path = obj['$ref']
+ if ref_path.startswith('#/$defs/'):
+ ref_name = ref_path.replace('#/$defs/', '')
+ if ref_name in defs:
+ resolved = resolve_ref(defs[ref_name])
+ merged = {**resolved, **{k: v for k, v in obj.items() if k != '$ref'}}
+ return merged
+ return obj
+ else:
+ return {k: resolve_ref(v) for k, v in obj.items()}
+ elif isinstance(obj, list):
+ return [resolve_ref(item) for item in obj]
+ else:
+ return obj
+
+ result = {k: v for k, v in schema.items() if k != '$defs'}
+ result = resolve_ref(result)
+ return result
+
+
+SKILL_STAGED_WORKSPACE_DIR_KEY = "__trpc_agent_skills_staged_workspace_dir"
+"""Key for staged workspace directory."""
+
+
+def get_staged_workspace_dir(ctx: InvocationContext, skill_name: str) -> str:
+ """Get the staged workspace directory."""
+ dir_map = ctx.agent_context.get_metadata(SKILL_STAGED_WORKSPACE_DIR_KEY, {})
+ return dir_map.get(skill_name, "")
+
+
+def set_staged_workspace_dir(ctx: InvocationContext, skill_name: str, dir: str) -> None:
+ """Set the staged workspace directory."""
+ dir_map = ctx.agent_context.get_metadata(SKILL_STAGED_WORKSPACE_DIR_KEY, {})
+ dir_map[skill_name] = dir
+ ctx.agent_context.with_metadata(SKILL_STAGED_WORKSPACE_DIR_KEY, dir_map)
diff --git a/trpc_agent_sdk/skills/tools/_copy_stager.py b/trpc_agent_sdk/skills/tools/_copy_stager.py
index 534f616..0ca5288 100644
--- a/trpc_agent_sdk/skills/tools/_copy_stager.py
+++ b/trpc_agent_sdk/skills/tools/_copy_stager.py
@@ -22,17 +22,8 @@
from ..stager import SkillStageResult
from ..stager import Stager
-# ---------------------------------------------------------------------------
-# Error messages (mirrors Go const block in tool/skill/stager.go)
-# ---------------------------------------------------------------------------
-
-_ERR_STAGER_NOT_CONFIGURED = "skill stager is not configured"
_ERR_REPO_NOT_CONFIGURED = "skill repository is not configured"
-# ---------------------------------------------------------------------------
-# Allowed workspace roots (mirrors Go isAllowedWorkspacePath)
-# ---------------------------------------------------------------------------
-
_ALLOWED_WS_ROOTS = (
DIR_SKILLS, # "skills"
DIR_WORK, # "work"
@@ -40,17 +31,12 @@
DIR_RUNS, # "runs"
)
-# ---------------------------------------------------------------------------
-# Path normalization helpers
-# (mirrors normalizeWorkspaceSkillDir / normalizeSkillStageResult)
-# ---------------------------------------------------------------------------
-
def normalize_workspace_skill_dir(raw: str) -> str:
"""Normalize and validate a workspace-relative skill directory.
Mirrors Go's ``normalizeWorkspaceSkillDir``. Strips leading slashes,
- normalizes path separators, and ensures the result stays within a
+ normalizes path separators, and ensures the result stays within a
known workspace root.
Raises:
@@ -74,7 +60,7 @@ def normalize_workspace_skill_dir(raw: str) -> str:
def _normalize_skill_stage_result(result: SkillStageResult) -> SkillStageResult:
- """Return *result* with :attr:`workspace_skill_dir` normalized.
+ """Return *result* with :attr:`workspace_skill_dir` normalized.
Mirrors Go's ``normalizeSkillStageResult``.
@@ -84,39 +70,15 @@ def _normalize_skill_stage_result(result: SkillStageResult) -> SkillStageResult:
return SkillStageResult(workspace_skill_dir=normalize_workspace_skill_dir(result.workspace_skill_dir))
-# ---------------------------------------------------------------------------
-# Default copy-based stager (mirrors Go copySkillStager)
-# ---------------------------------------------------------------------------
-
-
class CopySkillStager(Stager):
"""Default stager: copies the skill directory into ``skills/``.
- Mirrors Go's ``copySkillStager`` struct. Holds an optional
- back-reference to the owning ``SkillRunTool`` (``run_tool``), matching
- the Go pattern of ``copySkillStager{tool: tool}``.
-
The actual filesystem work — digest check, ``stage_directory``, symlink
- creation, read-only chmod — is delegated to
- :class:`~trpc_agent_sdk.skills.stager.Stager`, which mirrors
- :class:`~trpc_agent_sdk.skills.stager.Stager`.
-
- Construct via :func:`new_copy_skill_stager` or
- :class:`~trpc_agent_sdk.skills.tools.CopySkillStager`.
+ creation, read-only chmod — is delegated to :class:`~trpc_agent_sdk.skills.stager.Stager`.
"""
async def stage_skill(self, request: SkillStageRequest) -> SkillStageResult:
- """Stage the skill and return the normalized workspace skill dir.
-
- Mirrors Go's ``copySkillStager.StageSkill``:
-
- 1. Validate repository is present.
- 2. Resolve the on-disk skill root via the repository.
- 3. Resolve the runtime (``request.engine`` takes priority, falls back
- to ``repository.workspace_runtime``).
- 4. Delegate copy / link / chmod to :class:`~trpc_agent_sdk.skills.stager.Stager`.
- 5. Return a normalized :class:`SkillStageResult`.
- """
+ """Stage the skill and return the normalized workspace skill dir."""
if request.repository is None:
raise ValueError(_ERR_REPO_NOT_CONFIGURED)
diff --git a/trpc_agent_sdk/skills/tools/_save_artifact.py b/trpc_agent_sdk/skills/tools/_save_artifact.py
new file mode 100644
index 0000000..17cb70c
--- /dev/null
+++ b/trpc_agent_sdk/skills/tools/_save_artifact.py
@@ -0,0 +1,262 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright @ 2025 Tencent.com
+"""Workspace artifact save tool.
+
+This tool persists an existing file from the current workspace as an artifact.
+"""
+
+from __future__ import annotations
+
+import mimetypes
+import os
+import posixpath
+from typing import Any
+from typing import Optional
+
+from trpc_agent_sdk.code_executors import BaseWorkspaceRuntime
+from trpc_agent_sdk.code_executors import DIR_OUT
+from trpc_agent_sdk.code_executors import DIR_RUNS
+from trpc_agent_sdk.code_executors import DIR_WORK
+from trpc_agent_sdk.code_executors import WORKSPACE_ENV_DIR_KEY
+from trpc_agent_sdk.code_executors.utils import normalize_globs
+from trpc_agent_sdk.context import InvocationContext
+from trpc_agent_sdk.filter import BaseFilter
+from trpc_agent_sdk.tools import BaseTool
+from trpc_agent_sdk.types import FunctionDeclaration
+from trpc_agent_sdk.types import Part
+from trpc_agent_sdk.types import Schema
+from trpc_agent_sdk.types import Type
+
+from .._constants import SKILL_ARTIFACTS_STATE_KEY
+from ._common import CreateWorkspaceNameCallback
+from ._common import default_create_ws_name_callback
+
+_DEFAULT_MAX_BYTES = 64 * 1024 * 1024
+_ALLOWED_ROOTS = (DIR_WORK, DIR_OUT, DIR_RUNS)
+_SAVE_REASON_NO_SERVICE = "artifact service is not configured"
+_SAVE_REASON_NO_SESSION = "session is missing from invocation context"
+_SAVE_REASON_NO_SESSION_IDS = "session app/user/session IDs are missing"
+
+
+def _has_glob_meta(s: str) -> bool:
+ return any(ch in s for ch in ("*", "?", "["))
+
+
+def _normalize_workspace_prefix(path: str) -> str:
+ """Normalize workspace env-style prefixes in path."""
+ s = path.strip().replace("\\", "/")
+ if s.startswith("workspace://"):
+ s = s[len("workspace://"):]
+ replacements = (
+ ("${WORKSPACE_DIR}/", ""),
+ ("$WORKSPACE_DIR/", ""),
+ ("${WORK_DIR}/", f"{DIR_WORK}/"),
+ ("$WORK_DIR/", f"{DIR_WORK}/"),
+ ("${OUTPUT_DIR}/", f"{DIR_OUT}/"),
+ ("$OUTPUT_DIR/", f"{DIR_OUT}/"),
+ ("${RUN_DIR}/", f"{DIR_RUNS}/"),
+ ("$RUN_DIR/", f"{DIR_RUNS}/"),
+ )
+ for src, dst in replacements:
+ if s.startswith(src):
+ return dst + s[len(src):]
+ if s in ("$WORKSPACE_DIR", "${WORKSPACE_DIR}"):
+ return ""
+ if s in ("$WORK_DIR", "${WORK_DIR}"):
+ return DIR_WORK
+ if s in ("$OUTPUT_DIR", "${OUTPUT_DIR}"):
+ return DIR_OUT
+ if s in ("$RUN_DIR", "${RUN_DIR}"):
+ return DIR_RUNS
+ return s
+
+
+def _is_workspace_env_path(path: str) -> bool:
+ s = path.strip()
+ if not s:
+ return False
+ return s.startswith("$") or s.startswith("${")
+
+
+def _is_allowed_publish_path(rel: str) -> bool:
+ return any(rel == root or rel.startswith(f"{root}/") for root in _ALLOWED_ROOTS)
+
+
+def _artifact_save_skip_reason(ctx: InvocationContext) -> str:
+ if ctx.artifact_service is None:
+ return _SAVE_REASON_NO_SERVICE
+ if ctx.session is None:
+ return _SAVE_REASON_NO_SESSION
+ if not (ctx.app_name and ctx.user_id and ctx.session_id):
+ return _SAVE_REASON_NO_SESSION_IDS
+ return ""
+
+
+def _apply_artifact_state_delta(ctx: InvocationContext, saved_as: str, version: int, ref: str) -> None:
+ tool_call_id = (ctx.function_call_id or "").strip()
+ if not tool_call_id or not saved_as or version < 0:
+ return
+ artifact_ref = ref.strip() or f"artifact://{saved_as}@{version}"
+ ctx.actions.state_delta[SKILL_ARTIFACTS_STATE_KEY] = {
+ "tool_call_id": tool_call_id,
+ "artifacts": [{
+ "name": saved_as,
+ "version": version,
+ "ref": artifact_ref,
+ }],
+ }
+
+
+def _resolve_workspace_root_from_config(ctx: InvocationContext) -> str:
+ """Resolve workspace root from run_config/env; fallback to cwd."""
+ if ctx.run_config and isinstance(ctx.run_config.custom_data, dict):
+ for k in ("workspace_dir", "workspace_root", "workspace_path"):
+ v = ctx.run_config.custom_data.get(k)
+ if isinstance(v, str) and v.strip():
+ return os.path.abspath(v.strip())
+ env_root = os.environ.get(WORKSPACE_ENV_DIR_KEY, "").strip()
+ if env_root:
+ return os.path.abspath(env_root)
+ return os.path.abspath(os.getcwd())
+
+
+def _normalize_artifact_path(raw: str, workspace_root: str) -> tuple[str, str]:
+ """Return (workspace-relative path, absolute file path)."""
+ s = _normalize_workspace_prefix(raw)
+ if not s:
+ raise ValueError("path is required")
+ if _has_glob_meta(s):
+ raise ValueError("path must not contain glob patterns")
+ if _is_workspace_env_path(s):
+ out = normalize_globs([s.replace("${RUN_DIR}", DIR_RUNS).replace("$RUN_DIR", DIR_RUNS)])
+ if not out:
+ raise ValueError("invalid path")
+ s = out[0]
+
+ cleaned = posixpath.normpath(s)
+ if posixpath.isabs(cleaned):
+ rel = cleaned.lstrip("/")
+ rel = posixpath.normpath(rel)
+ if rel in ("", ".", ".."):
+ raise ValueError("path must point to a file inside the workspace")
+ else:
+ rel = cleaned
+ if rel in (".", "..") or rel.startswith("../"):
+ raise ValueError("path must stay within the workspace")
+
+ if not _is_allowed_publish_path(rel):
+ raise ValueError("path must stay under work/, out/, or runs/")
+
+ abs_path = os.path.abspath(os.path.join(workspace_root, rel))
+ rel_check = os.path.relpath(abs_path, workspace_root).replace("\\", "/")
+ if rel_check in (".", "..") or rel_check.startswith("../"):
+ raise ValueError("path must stay within the workspace")
+
+ rel = posixpath.normpath(rel)
+ return rel, abs_path
+
+
+class SaveArtifactTool(BaseTool):
+ """Persist an existing workspace file as an artifact."""
+
+ def __init__(
+ self,
+ max_file_bytes: int = _DEFAULT_MAX_BYTES,
+ workspace_runtime: Optional[BaseWorkspaceRuntime] = None,
+ create_ws_name_cb: Optional[CreateWorkspaceNameCallback] = None,
+ filters_name: Optional[list[str]] = None,
+ filters: Optional[list[BaseFilter]] = None,
+ ):
+ super().__init__(
+ name="workspace_save_artifact",
+ description=("Save an existing file from the current workspace as an artifact. "
+ "Path must be under work/, out/, or runs/."),
+ filters_name=filters_name,
+ filters=filters,
+ )
+ self._max_file_bytes = max_file_bytes
+ self._workspace_runtime = workspace_runtime
+ self._create_ws_name_cb = create_ws_name_cb or default_create_ws_name_callback
+
+ async def _resolve_workspace_root(self, ctx: InvocationContext) -> str:
+ """Resolve workspace root, preferring the shared workspace_exec workspace."""
+ runtime = self._workspace_runtime
+ if runtime is not None:
+ workspace_id = self._create_ws_name_cb(ctx)
+ ws = await runtime.manager(ctx).create_workspace(workspace_id, ctx)
+ if ws.path:
+ return os.path.abspath(ws.path)
+ return _resolve_workspace_root_from_config(ctx)
+
+ def _get_declaration(self) -> Optional[FunctionDeclaration]:
+ return FunctionDeclaration(
+ name="workspace_save_artifact",
+ description=("Save an existing file from the current workspace as an artifact. "
+ "Use this to get a stable artifact:// reference for files under "
+ "work/, out/, or runs/."),
+ parameters=Schema(
+ type=Type.OBJECT,
+ required=["path"],
+ properties={
+ "path":
+ Schema(
+ type=Type.STRING,
+ description=("Workspace-relative file path to save. "
+ "Supports prefixes like $WORK_DIR/, $OUTPUT_DIR/, "
+ "$RUN_DIR/, and workspace://."),
+ ),
+ },
+ ),
+ response=Schema(
+ type=Type.OBJECT,
+ required=["path", "saved_as", "version", "ref", "size_bytes"],
+ properties={
+ "path": Schema(type=Type.STRING, description="Workspace-relative source path."),
+ "saved_as": Schema(type=Type.STRING, description="Artifact name used when saving."),
+ "version": Schema(type=Type.INTEGER, description="Artifact version."),
+ "ref": Schema(type=Type.STRING, description="artifact:// reference for saved artifact."),
+ "mime_type": Schema(type=Type.STRING, description="Detected MIME type."),
+ "size_bytes": Schema(type=Type.INTEGER, description="File size in bytes."),
+ },
+ ),
+ )
+
+ async def _run_async_impl(self, *, tool_context: InvocationContext, args: dict[str, Any]) -> Any:
+ raw_path = str(args.get("path", "")).strip()
+ if not raw_path:
+ return {"error": "INVALID_PARAMETER: path is required"}
+ reason = _artifact_save_skip_reason(tool_context)
+ if reason:
+ return {"error": f"ARTIFACT_SAVE_UNAVAILABLE: {reason}"}
+
+ try:
+ workspace_root = await self._resolve_workspace_root(tool_context)
+ rel, abs_path = _normalize_artifact_path(raw_path, workspace_root)
+ except ValueError as ex:
+ return {"error": f"INVALID_PARAMETER: {str(ex)}"}
+
+ if not os.path.exists(abs_path):
+ return {"error": f"FILE_NOT_FOUND: workspace artifact file not found: {rel}"}
+ if not os.path.isfile(abs_path):
+ return {"error": f"INVALID_PATH: path is not a file: {rel}"}
+
+ size_bytes = os.path.getsize(abs_path)
+ if size_bytes > self._max_file_bytes:
+ return {"error": f"FILE_TOO_LARGE: file exceeds {self._max_file_bytes} bytes limit"}
+
+ with open(abs_path, "rb") as f:
+ data = f.read()
+
+ mime_type = mimetypes.guess_type(abs_path)[0] or "application/octet-stream"
+ version = await tool_context.save_artifact(rel, Part.from_bytes(data=data, mime_type=mime_type))
+ ref = f"artifact://{rel}@{version}"
+ _apply_artifact_state_delta(tool_context, rel, version, ref)
+ return {
+ "path": rel,
+ "saved_as": rel,
+ "version": version,
+ "ref": ref,
+ "mime_type": mime_type,
+ "size_bytes": size_bytes,
+ }
diff --git a/trpc_agent_sdk/skills/tools/_skill_exec.py b/trpc_agent_sdk/skills/tools/_skill_exec.py
index 7e4cfe6..6a0e511 100644
--- a/trpc_agent_sdk/skills/tools/_skill_exec.py
+++ b/trpc_agent_sdk/skills/tools/_skill_exec.py
@@ -5,15 +5,13 @@
# tRPC-Agent-Python is licensed under Apache-2.0.
"""Interactive skill execution tools.
-Provides four tools that mirror the Go :mod:`~trpc_agent_sdk.skills.tools.SkillExecTool` implementation:
-
-* :class:`~trpc_agent_sdk.skills.tools.SkillExecTool` — start an interactive session
-* :class:`~trpc_agent_sdk.skills.tools.WriteStdinTool` — write stdin to a running session
-* :class:`~trpc_agent_sdk.skills.tools.PollSessionTool` — poll a session for new output
-* :class:`~trpc_agent_sdk.skills.tools.KillSessionTool` — terminate and remove a session
+* ``skill_exec`` — start an interactive session (SkillExecTool)
+* ``skill_write_stdin`` — write stdin to a running session (WriteStdinTool)
+* ``skill_poll_session`` — poll a session for new output (PollSessionTool)
+* ``skill_kill_session`` — terminate and remove a session (KillSessionTool)
Sessions run real sub-processes inside the staged skill workspace. When
-:attr:`~trpc_agent_sdk.skills.tools.ExecInput.tty` is ``True`` a POSIX pseudo-terminal is allocated so TTY-aware programs work
+``tty=True`` a POSIX pseudo-terminal is allocated so TTY-aware programs work
correctly (e.g. interactive shells, ncurses UIs).
Usage example::
@@ -22,7 +20,7 @@
result = await skill_exec_tool.run(ctx, {
"skill": "my_skill",
"command": "python interactive.py",
- "yield_ms": 500,
+ "yield_time_ms": 500,
})
sid = result["session_id"]
@@ -46,6 +44,7 @@
import uuid
from dataclasses import dataclass
from dataclasses import field
+from datetime import datetime
from pathlib import Path
from typing import Any
from typing import Dict
@@ -55,13 +54,29 @@
from pydantic import BaseModel
from pydantic import Field
+from trpc_agent_sdk.code_executors import BaseProgramRunner
+from trpc_agent_sdk.code_executors import BaseProgramSession
+from trpc_agent_sdk.code_executors import DEFAULT_EXEC_YIELD_MS
+from trpc_agent_sdk.code_executors import DEFAULT_IO_YIELD_MS
+from trpc_agent_sdk.code_executors import DEFAULT_SESSION_KILL_SEC
+from trpc_agent_sdk.code_executors import DEFAULT_SESSION_TTL_SEC
from trpc_agent_sdk.code_executors import DIR_OUT
+from trpc_agent_sdk.code_executors import DIR_RUNS
from trpc_agent_sdk.code_executors import DIR_SKILLS
from trpc_agent_sdk.code_executors import DIR_WORK
+from trpc_agent_sdk.code_executors import ENV_OUTPUT_DIR
+from trpc_agent_sdk.code_executors import ENV_SKILLS_DIR
from trpc_agent_sdk.code_executors import ENV_SKILL_NAME
+from trpc_agent_sdk.code_executors import ENV_WORK_DIR
+from trpc_agent_sdk.code_executors import PROGRAM_STATUS_EXITED
+from trpc_agent_sdk.code_executors import WORKSPACE_ENV_DIR_KEY
from trpc_agent_sdk.code_executors import WorkspaceInfo
from trpc_agent_sdk.code_executors import WorkspaceInputSpec
from trpc_agent_sdk.code_executors import WorkspaceOutputSpec
+from trpc_agent_sdk.code_executors import WorkspaceRunProgramSpec
+from trpc_agent_sdk.code_executors import poll_line_limit
+from trpc_agent_sdk.code_executors import wait_for_program_output
+from trpc_agent_sdk.code_executors import yield_duration_ms
from trpc_agent_sdk.context import InvocationContext
from trpc_agent_sdk.filter import BaseFilter
from trpc_agent_sdk.log import logger
@@ -69,24 +84,20 @@
from trpc_agent_sdk.types import FunctionDeclaration
from trpc_agent_sdk.types import Schema
+from .._constants import SKILL_ARTIFACTS_STATE_KEY
+from ._common import CreateWorkspaceNameCallback
+from ._common import cleanup_expired_sessions
+from ._common import default_create_ws_name_callback
+from ._common import inline_json_schema_refs
+from ._common import require_non_empty
from ._copy_stager import SkillStageRequest
from ._skill_run import SkillRunInput
from ._skill_run import SkillRunOutput
from ._skill_run import SkillRunTool
from ._skill_run import _filter_failed_empty_outputs
-from ._skill_run import _inline_json_schema_refs
from ._skill_run import _select_primary_output
from ._skill_run import _truncate_output
-# ---------------------------------------------------------------------------
-# Defaults (mirrors Go program's session defaults)
-# ---------------------------------------------------------------------------
-
-DEFAULT_EXEC_YIELD_MS: int = 300 # wait time on skill_exec
-DEFAULT_IO_YIELD_MS: int = 100 # wait time on write/poll
-DEFAULT_POLL_LINES: int = 50 # max lines returned per call
-DEFAULT_SESSION_TTL: float = 300.0 # seconds after exit before GC
-
# Status strings
_STATUS_RUNNING = "running"
_STATUS_EXITED = "exited"
@@ -109,7 +120,7 @@ class ExecInput(BaseModel):
env: dict[str, str] = Field(default_factory=dict, description="Extra environment variables")
stdin: str = Field(default="", description="Optional initial stdin written before yielding")
tty: bool = Field(default=False, description="Allocate a pseudo-TTY")
- yield_ms: int = Field(default=0, description="Milliseconds to wait for initial output before returning")
+ yield_time_ms: int = Field(default=0, description="Milliseconds to wait for initial output before returning")
poll_lines: int = Field(default=0, description="Maximum output lines to return per call")
output_files: list[str] = Field(default_factory=list, description="Glob patterns to collect on exit")
timeout: int = Field(default=0, description="Timeout in seconds (0 = no timeout)")
@@ -126,7 +137,7 @@ class WriteStdinInput(BaseModel):
session_id: str = Field(..., description="Session id returned by skill_exec")
chars: str = Field(default="", description="Text to write to stdin")
submit: bool = Field(default=False, description="Append a newline after chars")
- yield_ms: int = Field(default=0, description="Milliseconds to wait for new output")
+ yield_time_ms: int = Field(default=0, description="Milliseconds to wait for new output")
poll_lines: int = Field(default=0, description="Maximum output lines to return")
@@ -134,7 +145,7 @@ class PollSessionInput(BaseModel):
"""Input for skill_poll_session."""
session_id: str = Field(..., description="Session id returned by skill_exec")
- yield_ms: int = Field(default=0, description="Milliseconds to wait for new output")
+ yield_time_ms: int = Field(default=0, description="Milliseconds to wait for new output")
poll_lines: int = Field(default=0, description="Maximum output lines to return")
@@ -173,6 +184,31 @@ class SessionKillOutput(BaseModel):
status: str = Field(default="", description="Final status after kill")
+def _apply_artifacts_state_delta(ctx: InvocationContext, output: ExecOutput) -> None:
+ """Store replayable artifact refs in state delta (Go execArtifactsStateDelta parity)."""
+ tool_call_id = (ctx.function_call_id or "").strip()
+ if not tool_call_id or output.result is None or not output.result.artifact_files:
+ return
+
+ artifacts: list[dict[str, Any]] = []
+ for item in output.result.artifact_files:
+ name = (item.name or "").strip()
+ version = int(item.version)
+ if not name or version < 0:
+ continue
+ artifacts.append({
+ "name": name,
+ "version": version,
+ "ref": f"artifact://{name}@{version}",
+ })
+ if not artifacts:
+ return
+ ctx.actions.state_delta[SKILL_ARTIFACTS_STATE_KEY] = {
+ "tool_call_id": tool_call_id,
+ "artifacts": artifacts,
+ }
+
+
# ---------------------------------------------------------------------------
# Internal session state
# ---------------------------------------------------------------------------
@@ -182,169 +218,31 @@ class SessionKillOutput(BaseModel):
class _ExecSession:
"""Holds state for one running interactive skill session."""
- proc: asyncio.subprocess.Process
+ proc: BaseProgramSession
ws: WorkspaceInfo
in_data: ExecInput
- # Output buffer (all output since start as raw bytes → decoded text)
- _output_buf: list[str] = field(default_factory=list)
- _output_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
- _output_event: asyncio.Event = field(default_factory=asyncio.Event)
-
- # Current byte offset for incremental reads
- _read_offset: int = 0
-
- # Background reader task
- reader_task: Optional[asyncio.Task] = None
-
- # PTY master fd (None for non-TTY)
- master_fd: Optional[int] = None
-
# Final state
exit_code: Optional[int] = None
exited_at: Optional[float] = None
final_result: Optional[SkillRunOutput] = None
finalized: bool = False
- async def append_output(self, chunk: str) -> None:
- async with self._output_lock:
- self._output_buf.append(chunk)
- self._output_event.set()
-
- async def total_output(self) -> str:
- async with self._output_lock:
- return "".join(self._output_buf)
-
- async def yield_output(self, yield_ms: int, poll_lines: int) -> tuple[str, str, int, int]:
- """Wait *yield_ms* ms for new output then return a chunk.
+ async def yield_output(self, yield_time_ms: int, poll_lines: int) -> tuple[str, str, int, int]:
+ """Wait *yield_time_ms* ms for new output then return a chunk.
Returns ``(status, output_chunk, offset, next_offset)``.
"""
- yield_sec = (yield_ms or DEFAULT_EXEC_YIELD_MS) / 1000.0
- deadline = asyncio.get_event_loop().time() + yield_sec
-
- while True:
- remaining = deadline - asyncio.get_event_loop().time()
- if remaining <= 0:
- break
- self._output_event.clear()
- # Break early if process has already exited and we have all output
- if self.proc.returncode is not None:
- break
- try:
- await asyncio.wait_for(asyncio.shield(self._output_event.wait()), timeout=remaining)
- except asyncio.TimeoutError:
- break
-
- # Determine status
- rc = self.proc.returncode
- if rc is not None:
- self.exit_code = rc
- if self.exited_at is None:
- self.exited_at = time.time()
- status = _STATUS_RUNNING if rc is None else _STATUS_EXITED
-
- # Slice the output since last read
- async with self._output_lock:
- full = "".join(self._output_buf)
-
- chunk = full[self._read_offset:]
-
- # Apply poll_lines limit
- limit = poll_lines or DEFAULT_POLL_LINES
- if limit > 0 and chunk:
- lines = chunk.split("\n")
- if len(lines) > limit:
- chunk = "\n".join(lines[:limit])
- if not chunk.endswith("\n"):
- chunk += "\n"
-
- offset = self._read_offset
- self._read_offset += len(chunk)
- next_offset = self._read_offset
-
- return status, chunk, offset, next_offset
-
-
-# ---------------------------------------------------------------------------
-# Background output readers
-# ---------------------------------------------------------------------------
-
-
-async def _read_pipe(session: _ExecSession, stream: asyncio.StreamReader) -> None:
- """Continuously read from *stream* and append to session output."""
- try:
- while True:
- data = await stream.read(4096)
- if not data:
- break
- await session.append_output(data.decode("utf-8", errors="replace"))
- except Exception: # pylint: disable=broad-except
- pass
- finally:
- # Ensure exit is captured
- try:
- await session.proc.wait()
- except Exception: # pylint: disable=broad-except
- pass
- session._output_event.set() # unblock any waiting yield
-
-
-async def _read_pty(session: _ExecSession, master_fd: int) -> None:
- """Continuously read from a PTY master fd and append to session output."""
- loop = asyncio.get_event_loop()
- try:
- # Set master_fd non-blocking
- flags = fcntl.fcntl(master_fd, fcntl.F_GETFL)
- fcntl.fcntl(master_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
-
- # Use loop.add_reader for non-blocking reads
- read_event = asyncio.Event()
-
- def _on_readable() -> None:
- read_event.set()
-
- loop.add_reader(master_fd, _on_readable)
- try:
- while True:
- read_event.clear()
- # Check if process has exited
- rc = session.proc.returncode
- if rc is not None:
- # Drain remaining data
- while True:
- try:
- data = os.read(master_fd, 4096)
- if data:
- await session.append_output(data.decode("utf-8", errors="replace"))
- else:
- break
- except OSError:
- break
- break
- # Wait for readable or timeout
- try:
- await asyncio.wait_for(read_event.wait(), timeout=0.05)
- except asyncio.TimeoutError:
- pass
- try:
- data = os.read(master_fd, 4096)
- if data:
- await session.append_output(data.decode("utf-8", errors="replace"))
- except BlockingIOError:
- pass
- except OSError:
- break
- finally:
- loop.remove_reader(master_fd)
- except Exception: # pylint: disable=broad-except
- pass
- finally:
- try:
- await session.proc.wait()
- except Exception: # pylint: disable=broad-except
- pass
- session._output_event.set()
+ poll = await wait_for_program_output(
+ self.proc,
+ yield_duration_ms(yield_time_ms, DEFAULT_EXEC_YIELD_MS),
+ poll_line_limit(poll_lines),
+ )
+ if poll.exit_code is not None:
+ self.exit_code = poll.exit_code
+ if poll.status == _STATUS_EXITED and self.exited_at is None:
+ self.exited_at = time.time()
+ return poll.status, poll.output, poll.offset, poll.next_offset
# ---------------------------------------------------------------------------
@@ -395,10 +293,6 @@ def _detect_interaction(status: str, output: str) -> Optional[SessionInteraction
def _build_exec_env(ws: WorkspaceInfo, extra: dict[str, str]) -> dict[str, str]:
"""Build the merged environment for a subprocess in *ws*."""
- from trpc_agent_sdk.code_executors._constants import ( # lazy import to avoid circular
- WORKSPACE_ENV_DIR_KEY, ENV_SKILLS_DIR, ENV_WORK_DIR, ENV_OUTPUT_DIR, DIR_RUNS,
- )
- from datetime import datetime
env = os.environ.copy()
run_dir = str(Path(ws.path) / DIR_RUNS / f"run_{datetime.now().strftime('%Y%m%dT%H%M%S_%f')}")
@@ -416,15 +310,6 @@ def _build_exec_env(ws: WorkspaceInfo, extra: dict[str, str]) -> dict[str, str]:
return env
-def _resolve_abs_cwd(ws_path: str, rel_cwd: str) -> str:
- """Return the absolute cwd by joining *ws_path* and *rel_cwd*."""
- if rel_cwd and os.path.isabs(rel_cwd):
- return rel_cwd
- resolved = os.path.normpath(os.path.join(ws_path, rel_cwd or "."))
- os.makedirs(resolved, exist_ok=True)
- return resolved
-
-
# ---------------------------------------------------------------------------
# SkillExecTool (skill_exec)
# ---------------------------------------------------------------------------
@@ -442,7 +327,8 @@ def __init__(
self,
run_tool: SkillRunTool,
filters: Optional[List[BaseFilter]] = None,
- session_ttl: float = DEFAULT_SESSION_TTL,
+ session_ttl: float = DEFAULT_SESSION_TTL_SEC,
+ create_ws_name_cb: Optional[CreateWorkspaceNameCallback] = None,
):
super().__init__(name="skill_exec",
description=("Start an interactive command inside a skill workspace. "
@@ -452,8 +338,8 @@ def __init__(
filters=filters)
self._run_tool = run_tool
self._ttl = session_ttl
+ self._create_ws_name_cb = create_ws_name_cb or default_create_ws_name_callback
self._sessions: dict[str, _ExecSession] = {}
- self._sessions_lock = asyncio.Lock()
# ------------------------------------------------------------------
# Declaration
@@ -461,8 +347,8 @@ def __init__(
@override
def _get_declaration(self) -> FunctionDeclaration:
- params_schema = _inline_json_schema_refs(ExecInput.model_json_schema())
- response_schema = _inline_json_schema_refs(ExecOutput.model_json_schema())
+ params_schema = inline_json_schema_refs(ExecInput.model_json_schema())
+ response_schema = inline_json_schema_refs(ExecOutput.model_json_schema())
return FunctionDeclaration(
name="skill_exec",
description=("Start an interactive command inside a skill workspace. "
@@ -476,37 +362,39 @@ def _get_declaration(self) -> FunctionDeclaration:
# Session management
# ------------------------------------------------------------------
- async def _put_session(self, sid: str, sess: _ExecSession) -> None:
- async with self._sessions_lock:
- await self._gc_expired_locked()
- self._sessions[sid] = sess
+ async def _put_session(self, sid: str, exec_session: _ExecSession) -> None:
+ await self._gc_expired_sessions()
+ self._sessions[sid] = exec_session
- async def get_session(self, sid: str) -> _ExecSession:
- async with self._sessions_lock:
- await self._gc_expired_locked()
- sess = self._sessions.get(sid)
- if sess is None:
+ async def _get_session(self, sid: str) -> _ExecSession:
+ await self._gc_expired_sessions()
+ session = self._sessions.get(sid)
+ if session is None:
raise ValueError(f"unknown session_id: {sid}")
- return sess
+ return session
- async def remove_session(self, sid: str) -> _ExecSession:
- async with self._sessions_lock:
- await self._gc_expired_locked()
- sess = self._sessions.pop(sid, None)
- if sess is None:
+ async def _remove_session(self, sid: str) -> _ExecSession:
+ await self._gc_expired_sessions()
+ session = self._sessions.pop(sid, None)
+ if session is None:
raise ValueError(f"unknown session_id: {sid}")
- return sess
-
- async def _gc_expired_locked(self) -> None:
- if self._ttl <= 0:
- return
- now = time.time()
- expired = [
- sid for sid, s in self._sessions.items() if s.exited_at is not None and (now - s.exited_at) >= self._ttl
- ]
- for sid in expired:
- s = self._sessions.pop(sid)
- _close_session(s)
+ return session
+
+ async def _gc_expired_sessions(self) -> None:
+
+ async def _refresh_exit_state(session: _ExecSession, now: float) -> None:
+ if session.exited_at is not None:
+ return
+ session_state = await session.proc.state()
+ if session_state.status == PROGRAM_STATUS_EXITED:
+ session.exited_at = now
+
+ await cleanup_expired_sessions(
+ self._sessions,
+ ttl=self._ttl,
+ refresh_exit_state=_refresh_exit_state,
+ close_session=_close_session,
+ )
# ------------------------------------------------------------------
# Main execution
@@ -523,26 +411,30 @@ async def _run_async_impl(
inputs = ExecInput.model_validate(args)
except Exception as ex: # pylint: disable=broad-except
raise ValueError(f"Invalid skill_exec arguments: {ex}") from ex
+ normalized_skill = inputs.skill.strip()
+ normalized_command = inputs.command.strip()
+ if not normalized_skill or not normalized_command:
+ raise ValueError("skill and command are required")
+ inputs = inputs.model_copy(update={"skill": normalized_skill, "command": normalized_command})
+
+ if self._run_tool.require_skill_loaded and not self._run_tool._is_skill_loaded(tool_context, normalized_skill):
+ raise ValueError(f"skill_exec requires skill_load first for {normalized_skill!r}")
repository = self._run_tool._get_repository(tool_context)
# Workspace creation
- session_id_ws = inputs.skill
- if tool_context.session and tool_context.session.id:
- session_id_ws = tool_context.session.id
-
workspace_runtime = repository.workspace_runtime
manager = workspace_runtime.manager(tool_context)
- ws = await manager.create_workspace(session_id_ws, tool_context)
+ workspace_id = self._create_ws_name_cb(tool_context)
+ ws = await manager.create_workspace(workspace_id, tool_context)
# Stage skill via the same pluggable stager used by SkillRunTool
stage_result = await self._run_tool.skill_stager.stage_skill(
SkillStageRequest(
- skill_name=inputs.skill,
+ skill_name=normalized_skill,
repository=repository,
workspace=ws,
ctx=tool_context,
- engine=workspace_runtime,
timeout=self._run_tool._timeout,
))
workspace_skill_dir = stage_result.workspace_skill_dir
@@ -553,29 +445,35 @@ async def _run_async_impl(
# Resolve cwd and env
rel_cwd = self._run_tool._resolve_cwd(inputs.cwd, workspace_skill_dir)
- abs_cwd = _resolve_abs_cwd(ws.path, rel_cwd)
extra_env: dict[str, str] = dict(inputs.env)
if ENV_SKILL_NAME not in extra_env:
- extra_env[ENV_SKILL_NAME] = inputs.skill
+ extra_env[ENV_SKILL_NAME] = normalized_skill
merged_env = _build_exec_env(ws, extra_env)
- # Start subprocess
+ # Start interactive program session via runtime runner.
+ runner = workspace_runtime.runner(tool_context)
+ start_program = getattr(runner, "start_program", None)
+ if start_program is None:
+ raise ValueError("skill_exec is not supported by the current executor")
sid = str(uuid.uuid4())
- sess = await _start_session(inputs, ws, abs_cwd, merged_env)
- await self._put_session(sid, sess)
-
- # Write initial stdin if provided
- if inputs.stdin:
- await _write_stdin(sess, inputs.stdin, submit=False)
+ exec_session = await _start_session(
+ runner=runner,
+ tool_context=tool_context,
+ inputs=inputs,
+ ws=ws,
+ rel_cwd=rel_cwd,
+ env=merged_env,
+ )
+ await self._put_session(sid, exec_session)
- yield_ms = inputs.yield_ms or DEFAULT_EXEC_YIELD_MS
- status, chunk, offset, next_offset = await sess.yield_output(yield_ms, inputs.poll_lines)
+ yield_time_ms = inputs.yield_time_ms or DEFAULT_EXEC_YIELD_MS
+ status, chunk, offset, next_offset = await exec_session.yield_output(yield_time_ms, inputs.poll_lines)
# Attempt to collect final result if already exited
final_result = None
- if status == _STATUS_EXITED and not sess.finalized:
- final_result = await _collect_final_result(tool_context, sess, self._run_tool)
+ if status == _STATUS_EXITED and not exec_session.finalized:
+ final_result = await _collect_final_result(tool_context, exec_session, self._run_tool)
out = ExecOutput(
status=status,
@@ -583,10 +481,11 @@ async def _run_async_impl(
output=chunk,
offset=offset,
next_offset=next_offset,
- exit_code=sess.exit_code,
+ exit_code=exec_session.exit_code,
interaction=_detect_interaction(status, chunk),
result=final_result,
)
+ _apply_artifacts_state_delta(tool_context, out)
return out.model_dump(exclude_none=True)
@@ -613,8 +512,8 @@ def __init__(self, exec_tool: SkillExecTool, filters: Optional[List[BaseFilter]]
@override
def _get_declaration(self) -> FunctionDeclaration:
- params_schema = _inline_json_schema_refs(WriteStdinInput.model_json_schema())
- response_schema = _inline_json_schema_refs(ExecOutput.model_json_schema())
+ params_schema = inline_json_schema_refs(WriteStdinInput.model_json_schema())
+ response_schema = inline_json_schema_refs(ExecOutput.model_json_schema())
return FunctionDeclaration(
name="skill_write_stdin",
description=("Write to a running skill_exec session. Set submit=true to "
@@ -635,29 +534,32 @@ async def _run_async_impl(
inputs = WriteStdinInput.model_validate(args)
except Exception as ex: # pylint: disable=broad-except
raise ValueError(f"Invalid skill_write_stdin arguments: {ex}") from ex
+ normalized_session_id = require_non_empty(inputs.session_id, field_name="session_id")
+ inputs = inputs.model_copy(update={"session_id": normalized_session_id})
- sess = await self._exec.get_session(inputs.session_id)
+ exec_session = await self._exec._get_session(inputs.session_id)
if inputs.chars or inputs.submit:
- await _write_stdin(sess, inputs.chars, submit=inputs.submit)
+ await _write_stdin(exec_session, inputs.chars, submit=inputs.submit)
- yield_ms = inputs.yield_ms or DEFAULT_IO_YIELD_MS
- status, chunk, offset, next_offset = await sess.yield_output(yield_ms, inputs.poll_lines)
+ yield_time_ms = inputs.yield_time_ms or DEFAULT_IO_YIELD_MS
+ status, chunk, offset, next_offset = await exec_session.yield_output(yield_time_ms, inputs.poll_lines)
final_result = None
- if status == _STATUS_EXITED and not sess.finalized:
- final_result = await _collect_final_result(tool_context, sess, self._exec._run_tool)
+ if status == _STATUS_EXITED and not exec_session.finalized:
+ final_result = await _collect_final_result(tool_context, exec_session, self._exec._run_tool)
out = ExecOutput(
status=status,
- session_id=inputs.session_id,
+ session_id=normalized_session_id,
output=chunk,
offset=offset,
next_offset=next_offset,
- exit_code=sess.exit_code,
+ exit_code=exec_session.exit_code,
interaction=_detect_interaction(status, chunk),
result=final_result,
)
+ _apply_artifacts_state_delta(tool_context, out)
return out.model_dump(exclude_none=True)
@@ -678,8 +580,8 @@ def __init__(self, exec_tool: SkillExecTool, filters: Optional[List[BaseFilter]]
@override
def _get_declaration(self) -> FunctionDeclaration:
- params_schema = _inline_json_schema_refs(PollSessionInput.model_json_schema())
- response_schema = _inline_json_schema_refs(ExecOutput.model_json_schema())
+ params_schema = inline_json_schema_refs(PollSessionInput.model_json_schema())
+ response_schema = inline_json_schema_refs(ExecOutput.model_json_schema())
return FunctionDeclaration(
name="skill_poll_session",
description=("Poll a running or recently exited skill_exec session for "
@@ -699,26 +601,29 @@ async def _run_async_impl(
inputs = PollSessionInput.model_validate(args)
except Exception as ex: # pylint: disable=broad-except
raise ValueError(f"Invalid skill_poll_session arguments: {ex}") from ex
+ normalized_session_id = require_non_empty(inputs.session_id, field_name="session_id")
+ inputs = inputs.model_copy(update={"session_id": normalized_session_id})
- sess = await self._exec.get_session(inputs.session_id)
+ exec_session = await self._exec._get_session(inputs.session_id)
- yield_ms = inputs.yield_ms or DEFAULT_IO_YIELD_MS
- status, chunk, offset, next_offset = await sess.yield_output(yield_ms, inputs.poll_lines)
+ yield_time_ms = inputs.yield_time_ms or DEFAULT_IO_YIELD_MS
+ status, chunk, offset, next_offset = await exec_session.yield_output(yield_time_ms, inputs.poll_lines)
final_result = None
- if status == _STATUS_EXITED and not sess.finalized:
- final_result = await _collect_final_result(tool_context, sess, self._exec._run_tool)
+ if status == _STATUS_EXITED and not exec_session.finalized:
+ final_result = await _collect_final_result(tool_context, exec_session, self._exec._run_tool)
out = ExecOutput(
status=status,
- session_id=inputs.session_id,
+ session_id=normalized_session_id,
output=chunk,
offset=offset,
next_offset=next_offset,
- exit_code=sess.exit_code,
+ exit_code=exec_session.exit_code,
interaction=_detect_interaction(status, chunk),
result=final_result,
)
+ _apply_artifacts_state_delta(tool_context, out)
return out.model_dump(exclude_none=True)
@@ -738,8 +643,8 @@ def __init__(self, exec_tool: SkillExecTool, filters: Optional[List[BaseFilter]]
@override
def _get_declaration(self) -> FunctionDeclaration:
- params_schema = _inline_json_schema_refs(KillSessionInput.model_json_schema())
- response_schema = _inline_json_schema_refs(SessionKillOutput.model_json_schema())
+ params_schema = inline_json_schema_refs(KillSessionInput.model_json_schema())
+ response_schema = inline_json_schema_refs(SessionKillOutput.model_json_schema())
return FunctionDeclaration(
name="skill_kill_session",
description="Terminate and remove a skill_exec session.",
@@ -758,25 +663,26 @@ async def _run_async_impl(
inputs = KillSessionInput.model_validate(args)
except Exception as ex: # pylint: disable=broad-except
raise ValueError(f"Invalid skill_kill_session arguments: {ex}") from ex
+ normalized_session_id = require_non_empty(inputs.session_id, field_name="session_id")
+ inputs = inputs.model_copy(update={"session_id": normalized_session_id})
+ exec_session = await self._exec._get_session(normalized_session_id)
- sess = await self._exec.remove_session(inputs.session_id)
-
- rc = sess.proc.returncode
final_status = _STATUS_EXITED
- if rc is None:
+ poll = await exec_session.proc.poll(None)
+ if poll.status == _STATUS_RUNNING:
try:
- sess.proc.kill()
- await asyncio.wait_for(sess.proc.wait(), timeout=5.0)
+ await exec_session.proc.kill(DEFAULT_SESSION_KILL_SEC)
except Exception: # pylint: disable=broad-except
pass
final_status = "killed"
- _close_session(sess)
+ await self._exec._remove_session(normalized_session_id)
+ await _close_session(exec_session)
out = SessionKillOutput(
ok=True,
- session_id=inputs.session_id,
+ session_id=normalized_session_id,
status=final_status,
)
return out.model_dump()
@@ -790,7 +696,7 @@ async def _run_async_impl(
def create_exec_tools(
run_tool: SkillRunTool,
filters: Optional[List[BaseFilter]] = None,
- session_ttl: float = DEFAULT_SESSION_TTL,
+ session_ttl: float = DEFAULT_SESSION_TTL_SEC,
) -> tuple[SkillExecTool, WriteStdinTool, PollSessionTool, KillSessionTool]:
"""Create the full set of interactive exec tools sharing one session store.
@@ -823,85 +729,46 @@ def create_exec_tools(
async def _start_session(
+ *,
+ runner: BaseProgramRunner,
+ tool_context: InvocationContext,
inputs: ExecInput,
ws: WorkspaceInfo,
- abs_cwd: str,
+ rel_cwd: str,
env: dict[str, str],
) -> _ExecSession:
- """Spawn a subprocess and return an initialized :class:`_ExecSession`."""
- command = inputs.command
- master_fd: Optional[int] = None
-
- if inputs.tty:
- # Allocate a pseudo-TTY.
- master_fd, slave_fd = pty.openpty()
- try:
- proc = await asyncio.create_subprocess_exec(
- "bash",
- "-c",
- command,
- stdin=slave_fd,
- stdout=slave_fd,
- stderr=slave_fd,
- cwd=abs_cwd,
- env=env,
- close_fds=True,
- preexec_fn=os.setsid,
- )
- finally:
- os.close(slave_fd) # parent doesn't need the slave end
-
- sess = _ExecSession(proc=proc, ws=ws, in_data=inputs, master_fd=master_fd)
- sess.reader_task = asyncio.create_task(_read_pty(sess, master_fd))
- else:
- proc = await asyncio.create_subprocess_exec(
- "bash",
- "-c",
- command,
- stdin=asyncio.subprocess.PIPE,
- stdout=asyncio.subprocess.PIPE,
- stderr=asyncio.subprocess.STDOUT,
- cwd=abs_cwd,
- env=env,
- )
- sess = _ExecSession(proc=proc, ws=ws, in_data=inputs)
- if proc.stdout:
- sess.reader_task = asyncio.create_task(_read_pipe(sess, proc.stdout))
-
- return sess
+ """Start a ProgramSession via runtime runner."""
+ spec = WorkspaceRunProgramSpec(
+ cmd="bash",
+ args=["-c", inputs.command],
+ env=env,
+ cwd=rel_cwd,
+ stdin=inputs.stdin,
+ timeout=float(inputs.timeout or 0),
+ tty=inputs.tty,
+ )
+ proc = await runner.start_program(tool_context, ws, spec)
+ return _ExecSession(proc=proc, ws=ws, in_data=inputs)
-async def _write_stdin(sess: _ExecSession, chars: str, submit: bool) -> None:
+async def _write_stdin(exec_session: _ExecSession, chars: str, submit: bool) -> None:
"""Write *chars* (and optionally a newline) to the session's stdin."""
- if sess.master_fd is not None:
- # PTY path: write to master fd
- data = (chars + ("\n" if submit else "")).encode("utf-8")
- if data:
- try:
- os.write(sess.master_fd, data)
- except OSError as ex:
- logger.debug("skill_exec: write to pty failed: %s", ex)
- elif sess.proc.stdin:
- # Pipe path: use asyncio StreamWriter
- data = (chars + ("\n" if submit else "")).encode("utf-8")
- if data:
- try:
- sess.proc.stdin.write(data)
- await sess.proc.stdin.drain()
- except Exception as ex: # pylint: disable=broad-except
- logger.debug("skill_exec: write to stdin failed: %s", ex)
+ try:
+ await exec_session.proc.write(chars, submit)
+ except Exception as ex: # pylint: disable=broad-except
+ logger.debug("skill_exec: write to stdin failed: %s", ex)
async def _collect_final_result(
ctx: InvocationContext,
- sess: _ExecSession,
+ exec_session: _ExecSession,
run_tool: SkillRunTool,
) -> Optional[SkillRunOutput]:
"""Collect output files and build the final :class:`SkillRunOutput`."""
- if sess.finalized:
- return sess.final_result
+ if exec_session.finalized:
+ return exec_session.final_result
- in_data = sess.in_data
+ in_data = exec_session.in_data
fake_run_input = SkillRunInput(
skill=in_data.skill,
command=in_data.command,
@@ -916,13 +783,19 @@ async def _collect_final_result(
outputs=in_data.outputs,
)
try:
- files, manifest = await run_tool._prepare_outputs(ctx, sess.ws, fake_run_input)
+ files, manifest = await run_tool._prepare_outputs(ctx, exec_session.ws, fake_run_input)
except Exception as ex: # pylint: disable=broad-except
logger.warning("skill_exec: collect outputs failed: %s", ex)
files, manifest = [], None
- total_out = await sess.total_output()
- exit_code = sess.exit_code or 0
+ try:
+ run_result = await exec_session.proc.run_result()
+ total_out = (run_result.stdout or "") + (run_result.stderr or "")
+ exit_code = run_result.exit_code
+ except Exception: # pylint: disable=broad-except
+ run_log = await exec_session.proc.log(None, None)
+ total_out = run_log.output or ""
+ exit_code = exec_session.exit_code or 0
# Reuse the same output-quality helpers as skill_run
warnings: list[str] = []
@@ -944,30 +817,21 @@ async def _collect_final_result(
)
try:
- await run_tool._attach_artifacts_if_requested(ctx, sess.ws, fake_run_input, result, files)
+ await run_tool._attach_artifacts_if_requested(ctx, exec_session.ws, fake_run_input, result, files)
except Exception as ex: # pylint: disable=broad-except
logger.warning("skill_exec: attach artifacts failed: %s", ex)
if manifest:
run_tool._merge_manifest_artifact_refs(manifest, result)
- sess.final_result = result
- sess.finalized = True
+ exec_session.final_result = result
+ exec_session.finalized = True
return result
-def _close_session(sess: _ExecSession) -> None:
- """Cancel background tasks and close any open fds."""
- if sess.reader_task and not sess.reader_task.done():
- sess.reader_task.cancel()
- if sess.master_fd is not None:
- try:
- os.close(sess.master_fd)
- except OSError:
- pass
- sess.master_fd = None
- if sess.proc.stdin and not sess.proc.stdin.is_closing():
- try:
- sess.proc.stdin.close()
- except Exception: # pylint: disable=broad-except
- pass
+async def _close_session(exec_session: _ExecSession) -> None:
+ """Release program session resources."""
+ try:
+ await exec_session.proc.close()
+ except Exception: # pylint: disable=broad-except
+ pass
diff --git a/trpc_agent_sdk/skills/tools/_skill_list.py b/trpc_agent_sdk/skills/tools/_skill_list.py
index 83d0903..1bfa72e 100644
--- a/trpc_agent_sdk/skills/tools/_skill_list.py
+++ b/trpc_agent_sdk/skills/tools/_skill_list.py
@@ -8,8 +8,8 @@
from __future__ import annotations
-from typing import Optional
from typing import Literal
+from typing import Optional
from trpc_agent_sdk.context import InvocationContext
@@ -17,16 +17,16 @@
from .._repository import BaseSkillRepository
-def skill_list(tool_context: InvocationContext, mode: Literal["all", "enabled", "disabled"] = "all") -> list[str]:
+def skill_list(tool_context: InvocationContext, mode: Literal['all', 'available'] = 'all') -> list[str]:
"""List all discovered skills.
Args:
tool_context: The tool context.
-
+ mode: Which skills to list: 'all' or 'available'.
Returns:
A list of skill names.
"""
repository: Optional[BaseSkillRepository] = tool_context.agent_context.get_metadata(SKILL_REPOSITORY_KEY)
if repository is None:
raise ValueError("repository not found")
- return repository.skill_list()
+ return repository.skill_list(mode)
diff --git a/trpc_agent_sdk/skills/tools/_skill_list_docs.py b/trpc_agent_sdk/skills/tools/_skill_list_docs.py
index f1e82e9..02f4abb 100644
--- a/trpc_agent_sdk/skills/tools/_skill_list_docs.py
+++ b/trpc_agent_sdk/skills/tools/_skill_list_docs.py
@@ -8,32 +8,33 @@
from __future__ import annotations
-from typing import Any
from typing import Optional
from trpc_agent_sdk.context import InvocationContext
-from trpc_agent_sdk.log import logger
from .._constants import SKILL_REPOSITORY_KEY
from .._repository import BaseSkillRepository
-def skill_list_docs(tool_context: InvocationContext, skill_name: str) -> dict[str, Any]:
+def skill_list_docs(tool_context: InvocationContext, skill_name: str) -> list[str]:
"""List doc filenames for a skill.
+
Args:
- skill_name: The name of the skill to load.
+ skill_name: The name of the skill.
Returns:
- Object containing docs list and whether SKILL.md body is loaded.
+ Array of doc filenames.
"""
+ normalized_skill = (skill_name or "").strip()
+ if not normalized_skill:
+ raise ValueError("skill is required")
+
repository: Optional[BaseSkillRepository] = tool_context.agent_context.get_metadata(SKILL_REPOSITORY_KEY)
if repository is None:
- raise ValueError("repository not found")
- skill = repository.get(skill_name)
- if skill is None:
- logger.error("Skill %s not found", repr(skill_name))
- return {"docs": [], "body_loaded": False}
- return {
- "docs": [resource.path for resource in skill.resources],
- "body_loaded": bool(skill.body),
- }
+ return []
+
+ try:
+ skill = repository.get(normalized_skill)
+ except ValueError as ex:
+ raise ValueError(f"unknown skill: {normalized_skill}") from ex
+ return [resource.path for resource in (skill.resources or [])]
diff --git a/trpc_agent_sdk/skills/tools/_skill_list_tool.py b/trpc_agent_sdk/skills/tools/_skill_list_tool.py
index 4ecfbb2..942d127 100644
--- a/trpc_agent_sdk/skills/tools/_skill_list_tool.py
+++ b/trpc_agent_sdk/skills/tools/_skill_list_tool.py
@@ -1,4 +1,4 @@
-# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+# -*- coding: utf-8 -*-
#
# Copyright (C) 2026 Tencent. All rights reserved.
#
@@ -9,7 +9,6 @@
from __future__ import annotations
-import re
from typing import Any
from typing import Optional
@@ -20,63 +19,13 @@
from .._repository import BaseSkillRepository
-def _extract_shell_examples_from_skill_body(body: str, limit: int = 5) -> list[str]:
- """Extract likely runnable command examples from SKILL.md body."""
- if not body:
- return []
- out: list[str] = []
- seen: set[str] = set()
- lines = body.splitlines()
-
- def maybe_add(cmd: str) -> None:
- cmd = re.sub(r"\s+", " ", (cmd or "").strip())
- if not cmd or cmd in seen:
- return
- if not re.match(r"^[A-Za-z0-9_./$\"'`-]", cmd):
- return
- seen.add(cmd)
- out.append(cmd)
-
- i = 0
- while i < len(lines) and len(out) < limit:
- cur = lines[i].strip()
- if cur.lower() != "command:":
- i += 1
- continue
- i += 1
- block: list[str] = []
- while i < len(lines):
- raw = lines[i]
- s = raw.strip()
- if not s:
- if block:
- break
- i += 1
- continue
- if s.lower() in ("command:", "output files", "overview", "examples", "tools:"):
- break
- if re.match(r"^\d+\)", s):
- break
- if raw.startswith(" ") or raw.startswith("\t"):
- block.append(s)
- i += 1
- continue
- if block:
- break
- i += 1
- if block:
- merged = " ".join(part.rstrip("\\").strip() for part in block)
- maybe_add(merged)
- return out
-
-
def skill_list_tools(tool_context: InvocationContext, skill_name: str) -> dict[str, Any]:
- """List executable guidance for a skill.
+ """List callable tools declared for a skill.
Args:
skill_name: The name of the skill to load.
Returns:
- Object containing declared tools and command examples from SKILL.md.
+ Object containing available tools.
"""
repository: Optional[BaseSkillRepository] = tool_context.agent_context.get_metadata(SKILL_REPOSITORY_KEY)
if repository is None:
@@ -84,8 +33,5 @@ def skill_list_tools(tool_context: InvocationContext, skill_name: str) -> dict[s
skill = repository.get(skill_name)
if skill is None:
logger.error("Skill %s not found", repr(skill_name))
- return {"tools": [], "command_examples": []}
- return {
- "tools": list(skill.tools or []),
- "command_examples": _extract_shell_examples_from_skill_body(skill.body, limit=5),
- }
+ return {"available_tools": []}
+ return {"available_tools": list(skill.tools or [])}
diff --git a/trpc_agent_sdk/skills/tools/_skill_load.py b/trpc_agent_sdk/skills/tools/_skill_load.py
index ec2c62c..3579d31 100644
--- a/trpc_agent_sdk/skills/tools/_skill_load.py
+++ b/trpc_agent_sdk/skills/tools/_skill_load.py
@@ -1,4 +1,4 @@
-# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+# -*- coding: utf-8 -*-
#
# Copyright (C) 2026 Tencent. All rights reserved.
#
@@ -12,65 +12,138 @@
import json
from typing import Any
+from typing import List
from typing import Optional
+from typing_extensions import override
from trpc_agent_sdk.context import InvocationContext
+from trpc_agent_sdk.filter import BaseFilter
+from trpc_agent_sdk.tools import BaseTool
+from trpc_agent_sdk.types import FunctionDeclaration
+from trpc_agent_sdk.types import Schema
+from trpc_agent_sdk.types import Type
-from .._constants import SKILL_DOCS_STATE_KEY_PREFIX
-from .._constants import SKILL_LOADED_STATE_KEY_PREFIX
-from .._constants import SKILL_REPOSITORY_KEY
-from .._constants import SKILL_TOOLS_STATE_KEY_PREFIX
+from .._common import append_loaded_order_state_delta
+from .._common import docs_state_key
+from .._common import loaded_state_key
+from .._common import set_state_delta
+from .._common import tool_state_key
from .._repository import BaseSkillRepository
+from .._types import Skill
+from ..stager import SkillStageRequest
+from ..stager import Stager
+from ._common import CreateWorkspaceNameCallback
+from ._common import default_create_ws_name_callback
+from ._common import set_staged_workspace_dir
+from ._copy_stager import CopySkillStager
-def _set_state_delta(invocation_context: InvocationContext, key: str, value: Any) -> None:
- """Set the state delta of a skill loaded."""
- invocation_context.actions.state_delta[key] = value
+class SkillLoadTool(BaseTool):
+ """Tool for loading a skill."""
+ def __init__(
+ self,
+ repository: BaseSkillRepository,
+ skill_stager: Optional[Stager] = None,
+ create_ws_name_cb: Optional[CreateWorkspaceNameCallback] = None,
+ filters: Optional[List[BaseFilter]] = None,
+ ):
+ super().__init__(name="skill_load", description="Load a skill.", filters=filters)
+ self._repository = repository
+ self._skill_stager: Stager = skill_stager or CopySkillStager()
+ self._create_ws_name_cb: Optional[
+ CreateWorkspaceNameCallback] = create_ws_name_cb or default_create_ws_name_callback
-def _set_state_delta_for_skill_load(invocation_context: InvocationContext,
- skill_name: str,
- docs: list[str],
- include_all_docs: bool = False) -> None:
- """Set the state delta of a skill loaded."""
- key = f"{SKILL_LOADED_STATE_KEY_PREFIX}{skill_name}"
- _set_state_delta(invocation_context, key, True)
- key = f"{SKILL_DOCS_STATE_KEY_PREFIX}{skill_name}"
- if include_all_docs:
- _set_state_delta(invocation_context, key, '*')
- else:
- _set_state_delta(invocation_context, key, json.dumps(docs or []))
+ @override
+ def _get_declaration(self) -> Optional[FunctionDeclaration]:
+ return FunctionDeclaration(
+ name="skill_load",
+ description=("Load a skill body and optional docs. Safe to call multiple times to add or replace docs. "
+ "Do not call this to list skills; names and descriptions are already in context. "
+ "Use when a task needs a skill's SKILL.md body and selected docs in context."),
+ parameters=Schema(
+ type=Type.OBJECT,
+ required=["skill_name"],
+ properties={
+ "skill_name":
+ Schema(type=Type.STRING, description="The name of the skill to load."),
+ "docs":
+ Schema(type=Type.ARRAY,
+ default=None,
+ items=Schema(type=Type.STRING),
+ description="The docs of the skill to load."),
+ "include_all_docs":
+ Schema(type=Type.BOOLEAN, default=False, description="Whether to include all docs of the skill."),
+ },
+ ),
+ response=Schema(type=Type.STRING,
+ description="A string message indicating the skill was loaded."),
+ )
+ @override
+ async def _run_async_impl(self, *, tool_context: InvocationContext, args: dict[str, Any]) -> str:
+ if not (args["skill_name"] or "").strip():
+ raise ValueError("skill_name is required")
+ skill_name = args["skill_name"]
+ docs = args.get("docs", [])
+ include_all_docs = args.get("include_all_docs", False)
+ normalized_skill = skill_name.strip()
+ skill = self._repository.get(normalized_skill)
+ await self._ensure_staged(ctx=tool_context, skill_name=skill_name)
+ clean_docs = [doc.strip() for doc in (docs or []) if isinstance(doc, str) and doc.strip()]
+ self.__set_state_delta_for_skill_load(tool_context, skill_name, clean_docs, include_all_docs)
+ if skill.tools:
+ self.__set_state_delta_for_skill_tools(tool_context, skill)
+ return f"skill {skill_name!r} loaded"
-def _set_state_delta_for_skill_tools(invocation_context: InvocationContext, skill_name: str, tools: list[str]) -> None:
- """Set the state delta of a skill tools."""
- key = f"{SKILL_TOOLS_STATE_KEY_PREFIX}{skill_name}"
- _set_state_delta(invocation_context, key, json.dumps(tools or []))
+ async def _ensure_staged(self, *, ctx: InvocationContext, skill_name: str) -> None:
+ runtime = self._repository.workspace_runtime
+ manager = runtime.manager(ctx)
+ ws_id = self._create_ws_name_cb(ctx)
+ ws = await manager.create_workspace(ws_id, ctx)
+ result = await self._skill_stager.stage_skill(
+ SkillStageRequest(skill_name=skill_name, repository=self._repository, workspace=ws, ctx=ctx))
+ set_staged_workspace_dir(ctx, skill_name, result.workspace_skill_dir)
+ def __set_state_delta_for_skill_load(self,
+ invocation_context: InvocationContext,
+ skill_name: str,
+ docs: list[str],
+ include_all_docs: bool = False) -> None:
+ """Set state delta for skill_load, aligned with Go StateDeltaForInvocation."""
+ agent_name = invocation_context.agent_name.strip()
+ delta, normalized_skill = self.__build_state_delta_for_skill_load(
+ invocation_context=invocation_context,
+ skill_name=skill_name,
+ docs=docs,
+ include_all_docs=include_all_docs,
+ )
+ invocation_context.actions.state_delta.update(delta)
+ append_loaded_order_state_delta(invocation_context, agent_name, normalized_skill)
-def skill_load(tool_context: InvocationContext,
- skill_name: str,
- docs: Optional[list[str]] = None,
- include_all_docs: bool = False) -> str:
- """Load a skill body and optional docs. Safe to call multiple times to add or replace docs.
- Do not call this to list skills; names and descriptions are already in context.
- Use when a task needs a skill's SKILL.md body and selected docs in context.
- Args:
- skill_name: The name of the skill to load.
- docs: The docs of the skill to load.
- include_all_docs: Whether to include all docs of the skill.
+ def __build_state_delta_for_skill_load(
+ self,
+ invocation_context: InvocationContext,
+ skill_name: str,
+ docs: list[str],
+ include_all_docs: bool = False,
+ ) -> tuple[dict[str, Any], str]:
+ """Build skill_load state delta."""
+ normalized_skill = skill_name.strip()
+ if not normalized_skill:
+ return {}, ""
+ delta: dict[str, Any] = {}
+ delta[loaded_state_key(invocation_context, normalized_skill)] = True
+ if include_all_docs:
+ delta[docs_state_key(invocation_context, normalized_skill)] = '*'
+ else:
+ delta[docs_state_key(invocation_context, normalized_skill)] = json.dumps(docs or [])
+ return delta, normalized_skill
- Returns:
- A message indicating the skill was loaded.
- """
-
- repository: Optional[BaseSkillRepository] = tool_context.agent_context.get_metadata(SKILL_REPOSITORY_KEY)
- if repository is None:
- raise ValueError("repository not found")
- skill = repository.get(skill_name)
- if skill is None:
- return f"skill {skill_name!r} not found"
- _set_state_delta_for_skill_load(tool_context, skill_name, docs or [], include_all_docs)
- if skill.tools:
- _set_state_delta_for_skill_tools(tool_context, skill_name, skill.tools)
- return f"skill {skill_name!r} loaded"
+ def __set_state_delta_for_skill_tools(self, invocation_context: InvocationContext, skill: Skill) -> None:
+ """Set the state delta of a skill tools."""
+ normalized_skill = skill.summary.name.strip()
+ if not normalized_skill:
+ return
+ key = tool_state_key(invocation_context, normalized_skill)
+ set_state_delta(invocation_context, key, json.dumps(skill.tools))
diff --git a/trpc_agent_sdk/skills/tools/_skill_run.py b/trpc_agent_sdk/skills/tools/_skill_run.py
index 23ac4fa..01dcdc8 100644
--- a/trpc_agent_sdk/skills/tools/_skill_run.py
+++ b/trpc_agent_sdk/skills/tools/_skill_run.py
@@ -3,19 +3,12 @@
# Copyright (C) 2026 Tencent. All rights reserved.
#
# tRPC-Agent-Python is licensed under Apache-2.0.
-"""Skill run tool for executing commands in skill workspaces.
-
-This module provides the SkillRunTool class which allows LLM to execute commands
-inside a skill workspace. It stages the entire skill directory and runs commands,
-aligned with the Go implementation at:
-https://github.com/trpc-group/trpc-agent-go/blob/main/tool/skill/run.go
-"""
+"""Skill run tool for executing commands in skill workspaces."""
from __future__ import annotations
import os
-import re
-import shlex
+import posixpath
from pathlib import Path
from typing import Any
from typing import Dict
@@ -45,12 +38,19 @@
from trpc_agent_sdk.types import Part
from trpc_agent_sdk.types import Schema
-from .._constants import SKILL_LOADED_STATE_KEY_PREFIX
+from .._common import get_state_delta_value
+from .._common import loaded_state_key
+from .._constants import SKILL_ARTIFACTS_STATE_KEY
from .._constants import SKILL_REPOSITORY_KEY
from .._repository import BaseSkillRepository
from .._utils import shell_quote
from ..stager import SkillStageRequest
from ..stager import Stager
+from ..stager import default_workspace_skill_dir
+from ._common import CreateWorkspaceNameCallback
+from ._common import default_create_ws_name_callback
+from ._common import get_staged_workspace_dir
+from ._common import inline_json_schema_refs
from ._copy_stager import CopySkillStager
# ---------------------------------------------------------------------------
@@ -68,7 +68,7 @@
_ENV_EDITOR = "EDITOR"
_ENV_VISUAL = "VISUAL"
-_EDITOR_HELPER_DIR = ".trpc_agent_sdk"
+_EDITOR_HELPER_DIR = ".trpc_agent"
_EDITOR_CONTENT_FILE = "editor_input.txt"
_EDITOR_SCRIPT_FILE = "editor_write.sh"
@@ -115,35 +115,6 @@
# ---------------------------------------------------------------------------
-def _inline_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]:
- """Inline $ref references in JSON Schema by replacing them with actual definitions."""
- defs = schema.get('$defs', {})
- if not defs:
- return schema
-
- def resolve_ref(obj: Any) -> Any:
- if isinstance(obj, dict):
- if '$ref' in obj:
- ref_path = obj['$ref']
- if ref_path.startswith('#/$defs/'):
- ref_name = ref_path.replace('#/$defs/', '')
- if ref_name in defs:
- resolved = resolve_ref(defs[ref_name])
- merged = {**resolved, **{k: v for k, v in obj.items() if k != '$ref'}}
- return merged
- return obj
- else:
- return {k: resolve_ref(v) for k, v in obj.items()}
- elif isinstance(obj, list):
- return [resolve_ref(item) for item in obj]
- else:
- return obj
-
- result = {k: v for k, v in schema.items() if k != '$defs'}
- result = resolve_ref(result)
- return result
-
-
def _is_text_mime(mime: str) -> bool:
"""Return True when *mime* is a text-like content type."""
if not mime:
@@ -176,6 +147,51 @@ def _workspace_ref(name: str) -> str:
return f"workspace://{name}" if name else ""
+def _normalize_input_dst(dst: str) -> str:
+ """Normalize declarative input destination like Go normalizeInputTo."""
+ s = (dst or "").strip().replace("\\", "/")
+ if not s:
+ return ""
+ cleaned = posixpath.normpath(s)
+ if cleaned in (".", "inputs"):
+ return ""
+ prefix = "inputs/"
+ if cleaned.startswith(prefix):
+ rest = cleaned[len(prefix):]
+ return posixpath.join(DIR_WORK, "inputs", rest)
+ return cleaned
+
+
+def _normalize_run_input(input_data: "SkillRunInput") -> "SkillRunInput":
+ """Normalize run input fields like Go normalizeRunInput."""
+ if not input_data.inputs:
+ return input_data
+ normalized_inputs: list[WorkspaceInputSpec] = []
+ for spec in input_data.inputs:
+ normalized_inputs.append(spec.model_copy(update={"dst": _normalize_input_dst(spec.dst)}))
+ return input_data.model_copy(update={"inputs": normalized_inputs})
+
+
+def _apply_run_artifacts_state_delta(ctx: InvocationContext, output: "SkillRunOutput") -> None:
+ """Write run artifact refs to state delta (Go RunTool.StateDelta parity)."""
+ tool_call_id = (ctx.function_call_id or "").strip()
+ if not tool_call_id or not output.artifact_files:
+ return
+ artifacts: list[dict[str, Any]] = []
+ for item in output.artifact_files:
+ name = (item.name or "").strip()
+ version = int(item.version)
+ if not name or version < 0:
+ continue
+ artifacts.append({"name": name, "version": version, "ref": f"artifact://{name}@{version}"})
+ if not artifacts:
+ return
+ ctx.actions.state_delta[SKILL_ARTIFACTS_STATE_KEY] = {
+ "tool_call_id": tool_call_id,
+ "artifacts": artifacts,
+ }
+
+
def _filter_failed_empty_outputs(
exit_code: int,
timed_out: bool,
@@ -215,7 +231,7 @@ def _split_command_line(cmd: str) -> list[str]:
raise ValueError("skill_run: command is empty")
for ch in _DISALLOWED_SHELL_META:
if ch in cmd:
- raise ValueError(f"skill_run: shell metacharacter {ch!r} is not allowed when "
+ raise ValueError(f"skill_run: shell metacharacter {ch!r} is not allowed when "
"command restrictions are enabled. Provide a single executable "
"with args only (no redirects/pipes/chaining).")
args: list[str] = []
@@ -345,16 +361,6 @@ class SkillRunOutput(BaseModel):
default_factory=list,
description="Non-fatal warnings about truncation, persistence, or empty outputs",
)
- suggested_commands: Optional[list[str]] = Field(
- default=None,
- description=("Suggested runnable commands extracted from SKILL.md when the provided "
- "command is not found."),
- )
- suggested_tools: Optional[list[str]] = Field(
- default=None,
- description=("Suggested tool names from SKILL.md Tools section when command-not-found "
- "indicates the request should use tool calls instead of shell execution."),
- )
# ---------------------------------------------------------------------------
@@ -375,11 +381,11 @@ def __init__(
filters: Optional[List[BaseFilter]] = None,
*,
require_skill_loaded: bool = False,
- block_inline_python_rewrite: bool = False,
force_save_artifacts: bool = False,
allowed_cmds: Optional[List[str]] = None,
denied_cmds: Optional[List[str]] = None,
skill_stager: Optional[Stager] = None,
+ create_ws_name_cb: Optional[CreateWorkspaceNameCallback] = None,
**kwargs,
):
"""Initialize SkillRunTool.
@@ -389,9 +395,6 @@ def __init__(
filters: Optional tool filters.
require_skill_loaded: When True, skill_run raises unless skill_load was called first
for this skill in the current session.
- block_inline_python_rewrite: When True, reject ad-hoc ``python -c`` commands
- if SKILL.md already provides script-based python
- command examples (for example ``python3 scripts/foo.py``).
force_save_artifacts: When True, always attempt to persist collected output files
via the artifact service (if available).
allowed_cmds: When set, only these command names (first token) are allowed.
@@ -414,7 +417,6 @@ def __init__(
)
self._repository = repository
self._require_skill_loaded = require_skill_loaded
- self._block_inline_python_rewrite = block_inline_python_rewrite
self._force_save_artifacts = force_save_artifacts
self._allowed_cmds: frozenset[str] = frozenset(c.strip() for c in (allowed_cmds or []) if c.strip())
self._denied_cmds: frozenset[str] = frozenset(c.strip() for c in (denied_cmds or []) if c.strip())
@@ -429,10 +431,17 @@ def __init__(
self._kwargs = kwargs
self._run_tool_kwargs: dict = kwargs.pop("run_tool_kwargs", {})
self._timeout = self._run_tool_kwargs.pop("timeout", 300.0)
+ self._create_ws_name_cb: Optional[CreateWorkspaceNameCallback] = \
+ create_ws_name_cb or default_create_ws_name_callback
# Staging strategy: default is copy-based stager (mirrors Go newCopySkillStager)
self._skill_stager: Stager = skill_stager or CopySkillStager()
+ @property
+ def require_skill_loaded(self) -> bool:
+ """Get the require_skill_loaded flag."""
+ return self._require_skill_loaded
+
@property
def skill_stager(self) -> Stager:
"""Get the skill stager."""
@@ -444,8 +453,8 @@ def skill_stager(self) -> Stager:
@override
def _get_declaration(self) -> FunctionDeclaration:
- params_schema = _inline_json_schema_refs(SkillRunInput.model_json_schema())
- response_schema = _inline_json_schema_refs(SkillRunOutput.model_json_schema())
+ params_schema = inline_json_schema_refs(SkillRunInput.model_json_schema())
+ response_schema = inline_json_schema_refs(SkillRunOutput.model_json_schema())
desc = ("Run a command inside a skill workspace. "
"Use it only for commands required by the skill docs (not for generic shell tasks). "
"User-uploaded file inputs are staged under $WORK_DIR/inputs (also visible as inputs/). "
@@ -482,11 +491,10 @@ def _get_repository(self, context: InvocationContext) -> BaseSkillRepository:
def _is_skill_loaded(self, ctx: InvocationContext, skill_name: str) -> bool:
"""Return True when the skill was loaded in the current session."""
try:
- key = f"{SKILL_LOADED_STATE_KEY_PREFIX}{skill_name.strip()}"
- val = ctx.session_state.get(key)
- return bool(val)
+ key = loaded_state_key(ctx, skill_name.strip())
+ return bool(get_state_delta_value(ctx, key))
except Exception: # pylint: disable=broad-except
- return True # default to allowed when state is not accessible
+ return False
# ------------------------------------------------------------------
# Editor helper
@@ -639,6 +647,10 @@ async def _run_async_impl(
inputs = SkillRunInput.model_validate(args)
except Exception as ex: # pylint: disable=broad-except
raise ValueError(f"Invalid skill_run arguments: {ex}") from ex
+ if not (inputs.skill or "").strip() or not (inputs.command or "").strip():
+ raise ValueError("skill and command are required")
+ inputs = inputs.model_copy(update={"skill": inputs.skill.strip(), "command": inputs.command.strip()})
+ inputs = _normalize_run_input(inputs)
# require_skill_loaded gate
if self._require_skill_loaded and not self._is_skill_loaded(tool_context, inputs.skill):
@@ -653,25 +665,27 @@ async def _run_async_impl(
repository = self._get_repository(tool_context)
- session_id = inputs.skill
- if tool_context.session and tool_context.session.id:
- session_id = tool_context.session.id
-
+ workspace_id = self._create_ws_name_cb(tool_context)
workspace_runtime = repository.workspace_runtime
manager = workspace_runtime.manager(tool_context)
- ws = await manager.create_workspace(session_id, tool_context)
-
- # Stage skill via the pluggable stager strategy
- stage_result = await self._skill_stager.stage_skill(
- SkillStageRequest(
- skill_name=inputs.skill,
- repository=repository,
- workspace=ws,
- ctx=tool_context,
- engine=workspace_runtime,
- timeout=self._timeout,
- ))
- workspace_skill_dir = stage_result.workspace_skill_dir
+ ws = await manager.create_workspace(workspace_id, tool_context)
+
+ # Static stage is handled by skill_load. Keep fallback staging here for
+ # backward compatibility when callers skip skill_load.
+ if self._is_skill_loaded(tool_context, inputs.skill):
+ workspace_skill_dir = get_staged_workspace_dir(tool_context, inputs.skill)
+ if not workspace_skill_dir:
+ workspace_skill_dir = default_workspace_skill_dir(inputs.skill)
+ else:
+ stage_result = await self._skill_stager.stage_skill(
+ SkillStageRequest(
+ skill_name=inputs.skill,
+ repository=repository,
+ workspace=ws,
+ ctx=tool_context,
+ timeout=self._timeout,
+ ))
+ workspace_skill_dir = stage_result.workspace_skill_dir
if inputs.inputs:
fs = workspace_runtime.fs(tool_context)
@@ -705,16 +719,6 @@ async def _run_async_impl(
# Select primary output
primary = _select_primary_output(files)
- suggested_commands = self._suggest_commands_for_missing_command(
- result,
- repository,
- inputs.skill,
- )
- suggested_tools = self._suggest_tools_for_missing_command(
- result,
- repository,
- inputs.skill,
- )
output = SkillRunOutput(
stdout=stdout,
@@ -725,20 +729,20 @@ async def _run_async_impl(
output_files=files,
primary_output=primary,
warnings=warnings,
- suggested_commands=suggested_commands,
- suggested_tools=suggested_tools,
)
await self._attach_artifacts_if_requested(tool_context, ws, inputs, output, files)
self._merge_manifest_artifact_refs(manifest, output)
# omit_inline_content
- if inputs.omit_inline_content and output.artifact_files:
+ if inputs.omit_inline_content:
for f in output.output_files:
f.content = ""
if output.primary_output:
output.primary_output.content = ""
+ _apply_run_artifacts_state_delta(tool_context, output)
+
return output.model_dump(exclude_none=True)
# ------------------------------------------------------------------
@@ -761,13 +765,15 @@ async def _run_program(
# Inject skill-specific env from repository (e.g. api_key → primary_env)
repository = self._get_repository(ctx)
try:
- skill_env: dict[str, str] = repository.skill_run_env(input_data.skill)
+ skill_env: dict[str, str] = repository.skill_run_env(ctx, input_data.skill)
for k, v in skill_env.items():
k = k.strip()
if not k or not v.strip():
continue
if k in env: # don't override explicit tool-call env
continue
+ if os.environ.get(k, "").strip(): # don't override host env
+ continue
if k.upper() in _BLOCKED_SKILL_ENV_KEYS:
continue
env[k] = v
@@ -778,9 +784,6 @@ async def _run_program(
await self._prepare_editor_env(ctx, ws, env, input_data.editor_text)
# Build command (with venv activation or command restrictions)
- blocked = self._precheck_inline_python_rewrite(repository, input_data)
- if blocked is not None:
- return blocked
cmd, cmd_args = self._build_command(input_data.command, ws.path, cwd)
workspace_runtime = repository.workspace_runtime
@@ -798,300 +801,10 @@ async def _run_program(
ctx,
)
if ret.exit_code != 0:
- ret = self._with_missing_command_hint(ret, input_data)
- ret = self._with_skill_doc_command_hint(ret, repository, input_data)
- ret = self._with_missing_entrypoint_hint(ret, input_data, ws, cwd)
- logger.info("Failed to run program: %s", ret.stderr)
+ raw_stderr = ret.stderr or ""
+ logger.info("Failed to run program: exit_code=%s, stderr=%s", ret.exit_code, raw_stderr.strip())
return ret
- def _precheck_inline_python_rewrite(
- self,
- repository: BaseSkillRepository,
- input_data: SkillRunInput,
- ) -> Optional[WorkspaceRunResult]:
- """Optionally block ad-hoc python -c when SKILL.md has script examples."""
- if not self._block_inline_python_rewrite:
- return None
- cmd = (input_data.command or "").strip()
- if not re.match(r"^python(?:\d+(?:\.\d+)?)?\s+-c\b", cmd):
- return None
- try:
- sk = repository.get(input_data.skill)
- except Exception: # pylint: disable=broad-except
- return None
- if sk is None:
- return None
- examples = self._extract_shell_examples_from_skill_body(sk.body, limit=8)
- script_examples = [
- e for e in examples if re.match(r"^python(?:\d+(?:\.\d+)?)?\s+scripts/[^\s]+(?:\s|$)", e.strip())
- ]
- if not script_examples:
- return None
- stderr = ("skill_run rejected this ad-hoc inline python command.\n"
- "This skill already provides script-based Python command examples in SKILL.md.\n"
- "Use one of these commands exactly instead of `python -c` rewrites:\n" +
- "\n".join(f"- {c}" for c in script_examples) + "\n")
- return WorkspaceRunResult(
- stdout="",
- stderr=stderr,
- exit_code=2,
- duration=0,
- timed_out=False,
- )
-
- @staticmethod
- def _with_missing_command_hint(ret: WorkspaceRunResult, input_data: SkillRunInput) -> WorkspaceRunResult:
- """Add a targeted hint when the shell command does not exist."""
- stderr_lower = (ret.stderr or "").lower()
- is_missing_cmd = ret.exit_code == 127 and ("command not found" in stderr_lower or
- "not recognized as an internal or external command" in stderr_lower)
- if not is_missing_cmd:
- return ret
- hint = ("\n\nSkill command hint:\n"
- f"- The command `{input_data.command}` was not found in this skill workspace.\n"
- "- Do not invent command names.\n"
- "- Read the loaded `SKILL.md` and execute one of its exact shell examples.\n"
- "- If needed, call `skill_load` first so the full skill body is injected "
- "before calling `skill_run`.\n")
- return ret.model_copy(update={"stderr": f"{ret.stderr}{hint}"})
-
- @staticmethod
- def _extract_shell_examples_from_skill_body(body: str, limit: int = 5) -> list[str]:
- """Extract likely runnable shell command lines from SKILL.md body."""
- if not body:
- return []
- out: list[str] = []
- seen: set[str] = set()
- lines = body.splitlines()
-
- def maybe_add(cmd: str) -> None:
- cmd = re.sub(r"\s+", " ", (cmd or "").strip())
- if not cmd:
- return
- if cmd in seen:
- return
- # Keep only command-like lines.
- if not re.match(r"^[A-Za-z0-9_./$\"'`-]", cmd):
- return
- # Skip function-style examples like tool_name(arg="..."), which are
- # LLM tool calls rather than shell commands.
- if re.match(r"^[A-Za-z_][A-Za-z0-9_]*\s*\(.*\)\s*$", cmd):
- return
- seen.add(cmd)
- out.append(cmd)
-
- # 1) Parse markdown fenced code blocks.
- in_fence = False
- for raw in lines:
- line = raw.strip()
- if line.startswith("```"):
- in_fence = not in_fence
- continue
- if not in_fence:
- continue
- if not line or line.startswith("#"):
- continue
- if line in ("PY", "EOF") or line.startswith(("-", "*")):
- continue
- maybe_add(line)
- if len(out) >= limit:
- return out
-
- # 2) Parse "Command:" sections with indented multi-line commands.
- i = 0
- while i < len(lines) and len(out) < limit:
- cur = lines[i].strip()
- if cur.lower() != "command:":
- i += 1
- continue
- i += 1
- block: list[str] = []
- while i < len(lines):
- raw = lines[i]
- s = raw.strip()
- if not s:
- if block:
- break
- i += 1
- continue
- # Next section marker
- if s.lower() in ("command:", "output files", "overview", "examples", "tools:"):
- break
- if re.match(r"^\d+\)", s):
- break
- # Command content is commonly indented in SKILL.md examples.
- if raw.startswith(" ") or raw.startswith("\t"):
- block.append(s)
- i += 1
- continue
- if block:
- break
- i += 1
- if block:
- merged = " ".join(part.rstrip("\\").strip() for part in block)
- maybe_add(merged)
- return out[:limit]
-
- def _with_skill_doc_command_hint(
- self,
- ret: WorkspaceRunResult,
- repository: BaseSkillRepository,
- input_data: SkillRunInput,
- ) -> WorkspaceRunResult:
- """Append SKILL.md command examples when command is not found."""
- stderr = ret.stderr or ""
- stderr_lower = stderr.lower()
- is_missing_cmd = ret.exit_code == 127 and ("command not found" in stderr_lower or
- "not recognized as an internal or external command" in stderr_lower)
- if not is_missing_cmd:
- return ret
- try:
- sk = repository.get(input_data.skill)
- except Exception: # pylint: disable=broad-except
- return ret
- examples = self._extract_shell_examples_from_skill_body(sk.body, limit=5)
- tools = list(getattr(sk, "tools", []) or [])
- if not examples and not tools:
- return ret
- hint_parts: list[str] = []
- if examples:
- hint_parts.append("\n\nSKILL.md command examples:\n" + "\n".join(f"- {cmd}" for cmd in examples) + "\n")
- if tools:
- hint_parts.append("\nSKILL.md tools suggest this is a tool-call workflow:\n" +
- "\n".join(f"- {name}" for name in tools) + "\n" +
- "- Do not run these tool names via `skill_run` shell commands.\n" +
- "- Call those tools directly (e.g. function/tool call) after `skill_load`.\n")
- hint = "".join(hint_parts)
- return ret.model_copy(update={"stderr": f"{stderr}{hint}"})
-
- @staticmethod
- def _is_missing_command_result(ret: WorkspaceRunResult) -> bool:
- """Return True when stderr indicates command-not-found failure."""
- stderr_lower = (ret.stderr or "").lower()
- return ret.exit_code == 127 and ("command not found" in stderr_lower
- or "not recognized as an internal or external command" in stderr_lower)
-
- def _suggest_commands_for_missing_command(
- self,
- ret: WorkspaceRunResult,
- repository: BaseSkillRepository,
- skill_name: str,
- ) -> Optional[list[str]]:
- """Return SKILL.md command suggestions for missing-command failures."""
- if not self._is_missing_command_result(ret):
- return None
- try:
- sk = repository.get(skill_name)
- except Exception: # pylint: disable=broad-except
- return None
- suggestions = self._extract_shell_examples_from_skill_body(sk.body, limit=5)
- return suggestions or None
-
- def _suggest_tools_for_missing_command(
- self,
- ret: WorkspaceRunResult,
- repository: BaseSkillRepository,
- skill_name: str,
- ) -> Optional[list[str]]:
- """Return SKILL.md tool names when command-not-found implies tool calls."""
- if not self._is_missing_command_result(ret):
- return None
- try:
- sk = repository.get(skill_name)
- except Exception: # pylint: disable=broad-except
- return None
- tools = list(getattr(sk, "tools", []) or [])
- return tools or None
-
- @staticmethod
- def _extract_command_path_candidates(command: str) -> list[str]:
- """Extract likely relative path tokens from a shell command."""
- try:
- tokens = shlex.split(command, posix=True)
- except ValueError:
- tokens = command.split()
- out: list[str] = []
- seen: set[str] = set()
- for tok in tokens:
- s = (tok or "").strip()
- if not s or s.startswith("-"):
- continue
- # Keep relative path-like tokens only.
- is_path_like = "/" in s or s.endswith((".py", ".sh", ".pl", ".rb", ".js", ".ts"))
- if not is_path_like or os.path.isabs(s):
- continue
- if s in seen:
- continue
- seen.add(s)
- out.append(s)
- return out
-
- @staticmethod
- def _list_entrypoint_suggestions(run_dir: Path, limit: int = 20) -> list[str]:
- """List likely runnable files from common locations under the current cwd."""
- suggestions: list[str] = []
- seen: set[str] = set()
- roots = [
- run_dir,
- run_dir / "scripts",
- run_dir / "bin",
- run_dir / "tools",
- ]
- for root in roots:
- try:
- if not root.is_dir():
- continue
- for p in sorted(root.rglob("*")):
- if len(suggestions) >= limit:
- return suggestions
- if not p.is_file():
- continue
- rel = p.relative_to(run_dir).as_posix()
- # Prefer obviously runnable entries.
- is_runnable = os.access(p.as_posix(), os.X_OK) or rel.endswith(
- (".py", ".sh", ".pl", ".rb", ".js", ".ts"))
- if not is_runnable or rel in seen:
- continue
- seen.add(rel)
- suggestions.append(rel)
- except Exception: # pylint: disable=broad-except
- continue
- return suggestions
-
- @staticmethod
- def _with_missing_entrypoint_hint(
- ret: WorkspaceRunResult,
- input_data: SkillRunInput,
- ws: WorkspaceInfo,
- cwd: str,
- ) -> WorkspaceRunResult:
- """Add a generic hint when command references missing relative entrypoints."""
- stderr = ret.stderr or ""
- stderr_lower = stderr.lower()
- if ret.exit_code == 0:
- return ret
- if "no such file or directory" not in stderr_lower and "can't open file" not in stderr_lower:
- return ret
- missing_candidates = SkillRunTool._extract_command_path_candidates(input_data.command)
- if not missing_candidates:
- return ret
-
- run_dir = Path(ws.path) / cwd
- # Show only candidates that are truly missing under cwd.
- missing = [p for p in missing_candidates if not (run_dir / p).exists()]
- if not missing:
- return ret
- available = SkillRunTool._list_entrypoint_suggestions(run_dir, limit=20)
- if not available:
- return ret
-
- hint = ("\n\nSkill entrypoint hint:\n"
- f"- Missing relative path(s) in command: {', '.join(f'`{m}`' for m in missing)}\n"
- f"- Runnable file candidates from current skill cwd: {', '.join(available)}\n"
- "- Do not invent file names or paths.\n"
- "- Read the loaded `SKILL.md` and run one of those commands exactly.\n")
- return ret.model_copy(update={"stderr": f"{stderr}{hint}"})
-
def _resolve_cwd(self, cwd: str, skill_dir: str) -> str:
"""Resolve the working directory relative to the workspace root.
diff --git a/trpc_agent_sdk/skills/tools/_skill_select_docs.py b/trpc_agent_sdk/skills/tools/_skill_select_docs.py
index ad5c756..1e07dc4 100644
--- a/trpc_agent_sdk/skills/tools/_skill_select_docs.py
+++ b/trpc_agent_sdk/skills/tools/_skill_select_docs.py
@@ -15,8 +15,14 @@
from trpc_agent_sdk.context import InvocationContext
from .._common import BaseSelectionResult
-from .._common import generic_select_items
-from .._constants import SKILL_DOCS_STATE_KEY_PREFIX
+from .._common import append_loaded_order_state_delta
+from .._common import docs_state_key
+from .._common import get_agent_name
+from .._common import get_previous_selection_by_key
+from .._common import normalize_selection_mode
+from .._common import set_selection_state_delta_by_key
+from .._constants import SKILL_REPOSITORY_KEY
+from .._repository import BaseSkillRepository
class SkillSelectDocsResult(BaseSelectionResult):
@@ -53,11 +59,58 @@ def skill_select_docs(tool_context: InvocationContext,
Returns:
A message indicating the docs were selected.
"""
- result = generic_select_items(tool_context=tool_context,
- skill_name=skill_name,
- items=docs,
- include_all=include_all_docs,
- mode=mode,
- state_key_prefix=SKILL_DOCS_STATE_KEY_PREFIX,
- result_class=SkillSelectDocsResult)
+ normalized_skill = (skill_name or "").strip()
+ if not normalized_skill:
+ raise ValueError("skill is required")
+ normalized_mode = normalize_selection_mode(mode)
+ agent_name = get_agent_name(tool_context)
+
+ repository: Optional[BaseSkillRepository] = tool_context.agent_context.get_metadata(SKILL_REPOSITORY_KEY)
+ if repository is not None:
+ try:
+ _ = repository.get(normalized_skill)
+ except ValueError as ex:
+ raise ValueError(f"unknown skill: {normalized_skill}") from ex
+
+ docs_selection_key = docs_state_key(tool_context, normalized_skill)
+ previous_items, had_all = get_previous_selection_by_key(tool_context, docs_selection_key)
+ if had_all and normalized_mode != "clear":
+ result = SkillSelectDocsResult(
+ skill=normalized_skill,
+ selected_items=[],
+ include_all=True,
+ mode=normalized_mode,
+ )
+ elif normalized_mode == "clear":
+ result = SkillSelectDocsResult(
+ skill=normalized_skill,
+ selected_items=[],
+ include_all=False,
+ mode="clear",
+ )
+ elif normalized_mode == "add":
+ selected = set(previous_items)
+ for item in docs or []:
+ selected.add(item)
+ result = SkillSelectDocsResult(
+ skill=normalized_skill,
+ selected_items=[] if include_all_docs else list(selected),
+ include_all=include_all_docs,
+ mode="add",
+ )
+ else:
+ result = SkillSelectDocsResult(
+ skill=normalized_skill,
+ selected_items=[] if include_all_docs else list(docs or []),
+ include_all=include_all_docs,
+ mode="replace",
+ )
+
+ set_selection_state_delta_by_key(
+ tool_context,
+ docs_selection_key,
+ result.selected_docs,
+ result.include_all_docs,
+ )
+ append_loaded_order_state_delta(tool_context, agent_name, normalized_skill)
return result
diff --git a/trpc_agent_sdk/skills/tools/_skill_select_tools.py b/trpc_agent_sdk/skills/tools/_skill_select_tools.py
index 0c44b0f..f2e9745 100644
--- a/trpc_agent_sdk/skills/tools/_skill_select_tools.py
+++ b/trpc_agent_sdk/skills/tools/_skill_select_tools.py
@@ -15,8 +15,14 @@
from trpc_agent_sdk.context import InvocationContext
from .._common import BaseSelectionResult
-from .._common import generic_select_items
-from .._constants import SKILL_TOOLS_STATE_KEY_PREFIX
+from .._common import append_loaded_order_state_delta
+from .._common import get_agent_name
+from .._common import get_previous_selection_by_key
+from .._common import normalize_selection_mode
+from .._common import set_selection_state_delta_by_key
+from .._common import tool_state_key
+from .._constants import SKILL_REPOSITORY_KEY
+from .._repository import BaseSkillRepository
class SkillSelectToolsResult(BaseSelectionResult):
@@ -73,11 +79,58 @@ def skill_select_tools(tool_context: InvocationContext,
tools=["get_current_weather"],
mode="replace")
"""
- result = generic_select_items(tool_context=tool_context,
- skill_name=skill_name,
- items=tools,
- include_all=include_all_tools,
- mode=mode,
- state_key_prefix=SKILL_TOOLS_STATE_KEY_PREFIX,
- result_class=SkillSelectToolsResult)
+ normalized_skill = (skill_name or "").strip()
+ if not normalized_skill:
+ raise ValueError("skill is required")
+ normalized_mode = normalize_selection_mode(mode)
+ agent_name = get_agent_name(tool_context)
+
+ repository: Optional[BaseSkillRepository] = tool_context.agent_context.get_metadata(SKILL_REPOSITORY_KEY)
+ if repository is not None:
+ try:
+ _ = repository.get(normalized_skill)
+ except ValueError as ex:
+ raise ValueError(f"unknown skill: {normalized_skill}") from ex
+
+ tools_selection_key = tool_state_key(tool_context, normalized_skill)
+ previous_items, had_all = get_previous_selection_by_key(tool_context, tools_selection_key)
+ if had_all and normalized_mode != "clear":
+ result = SkillSelectToolsResult(
+ skill=normalized_skill,
+ selected_items=[],
+ include_all=True,
+ mode=normalized_mode,
+ )
+ elif normalized_mode == "clear":
+ result = SkillSelectToolsResult(
+ skill=normalized_skill,
+ selected_items=[],
+ include_all=False,
+ mode="clear",
+ )
+ elif normalized_mode == "add":
+ selected = set(previous_items)
+ for item in tools or []:
+ selected.add(item)
+ result = SkillSelectToolsResult(
+ skill=normalized_skill,
+ selected_items=[] if include_all_tools else list(selected),
+ include_all=include_all_tools,
+ mode="add",
+ )
+ else:
+ result = SkillSelectToolsResult(
+ skill=normalized_skill,
+ selected_items=[] if include_all_tools else list(tools or []),
+ include_all=include_all_tools,
+ mode="replace",
+ )
+
+ set_selection_state_delta_by_key(
+ tool_context,
+ tools_selection_key,
+ result.selected_tools,
+ result.include_all_tools,
+ )
+ append_loaded_order_state_delta(tool_context, agent_name, normalized_skill)
return result
diff --git a/trpc_agent_sdk/skills/tools/_workspace_exec.py b/trpc_agent_sdk/skills/tools/_workspace_exec.py
new file mode 100644
index 0000000..395457f
--- /dev/null
+++ b/trpc_agent_sdk/skills/tools/_workspace_exec.py
@@ -0,0 +1,518 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""Shared executor workspace tools."""
+
+from __future__ import annotations
+
+import posixpath
+import time
+from dataclasses import dataclass
+from typing import Any
+from typing import Optional
+from typing_extensions import override
+
+from pydantic import BaseModel
+from pydantic import Field
+from trpc_agent_sdk.code_executors import BaseCodeExecutor
+from trpc_agent_sdk.code_executors import BaseProgramSession
+from trpc_agent_sdk.code_executors import BaseWorkspaceRuntime
+from trpc_agent_sdk.code_executors import DEFAULT_EXEC_YIELD_MS
+from trpc_agent_sdk.code_executors import DEFAULT_SESSION_KILL_SEC
+from trpc_agent_sdk.code_executors import DEFAULT_SESSION_TTL_SEC
+from trpc_agent_sdk.code_executors import DIR_OUT
+from trpc_agent_sdk.code_executors import DIR_RUNS
+from trpc_agent_sdk.code_executors import DIR_SKILLS
+from trpc_agent_sdk.code_executors import DIR_WORK
+from trpc_agent_sdk.code_executors import PROGRAM_STATUS_EXITED
+from trpc_agent_sdk.code_executors import PROGRAM_STATUS_RUNNING
+from trpc_agent_sdk.code_executors import ProgramPoll
+from trpc_agent_sdk.code_executors import WorkspaceRunProgramSpec
+from trpc_agent_sdk.code_executors import poll_line_limit
+from trpc_agent_sdk.code_executors import wait_for_program_output
+from trpc_agent_sdk.code_executors import yield_duration_ms
+from trpc_agent_sdk.code_executors.utils import normalize_globs
+from trpc_agent_sdk.context import InvocationContext
+from trpc_agent_sdk.filter import BaseFilter
+from trpc_agent_sdk.tools import BaseTool
+from trpc_agent_sdk.types import FunctionDeclaration
+from trpc_agent_sdk.types import Schema
+from trpc_agent_sdk.types import Type
+
+from ._common import CreateWorkspaceNameCallback
+from ._common import cleanup_expired_sessions
+from ._common import default_create_ws_name_callback
+from ._common import require_non_empty
+
+_DEFAULT_WORKSPACE_EXEC_TIMEOUT_SEC = 5 * 60
+_DEFAULT_WORKSPACE_WRITE_YIELD_MS = 200
+
+
+def _combine_output(stdout: str, stderr: str) -> str:
+ if not stdout:
+ return stderr
+ if not stderr:
+ return stdout
+ return stdout + stderr
+
+
+def _has_glob_meta(s: str) -> bool:
+ return any(ch in s for ch in ("*", "?", "["))
+
+
+def _has_env_prefix(s: str, name: str) -> bool:
+ if s.startswith(f"${name}"):
+ tail = s[len(name) + 1:]
+ return tail == "" or tail.startswith("/") or tail.startswith("\\")
+ brace = f"${{{name}}}"
+ if s.startswith(brace):
+ tail = s[len(brace):]
+ return tail == "" or tail.startswith("/") or tail.startswith("\\")
+ return False
+
+
+def _is_workspace_env_path(s: str) -> bool:
+ return (_has_env_prefix(s, "WORKSPACE_DIR") or _has_env_prefix(s, "SKILLS_DIR") or _has_env_prefix(s, "WORK_DIR")
+ or _has_env_prefix(s, "OUTPUT_DIR") or _has_env_prefix(s, "RUN_DIR"))
+
+
+def _is_allowed_workspace_path(rel: str) -> bool:
+ return any(rel == root or rel.startswith(f"{root}/") for root in (DIR_SKILLS, DIR_WORK, DIR_OUT, DIR_RUNS))
+
+
+def _normalize_cwd(raw: str) -> str:
+ s = (raw or "").strip().replace("\\", "/")
+ if not s:
+ return "."
+ if _has_glob_meta(s):
+ raise ValueError("cwd must not contain glob patterns")
+ if _is_workspace_env_path(s):
+ out = normalize_globs([s])
+ if not out:
+ raise ValueError("invalid cwd")
+ s = out[0]
+ if s.startswith("/"):
+ rel = posixpath.normpath(s).lstrip("/")
+ if rel in ("", "."):
+ return "."
+ if not _is_allowed_workspace_path(rel):
+ raise ValueError(f"cwd must stay under workspace roots: {raw!r}")
+ return rel
+ rel = posixpath.normpath(s)
+ if rel == ".":
+ return "."
+ if rel == ".." or rel.startswith("../"):
+ raise ValueError("cwd must stay within the workspace")
+ if not _is_allowed_workspace_path(rel):
+ raise ValueError(
+ f"cwd must stay under supported workspace roots such as skills/, work/, out/, or runs/: {raw!r}")
+ return rel
+
+
+def _exec_timeout_seconds(raw: int) -> float:
+ if raw <= 0:
+ return float(_DEFAULT_WORKSPACE_EXEC_TIMEOUT_SEC)
+ return float(raw)
+
+
+def _exec_yield_seconds(background: bool, raw_ms: Optional[int]) -> float:
+ if background:
+ if raw_ms is not None and raw_ms > 0:
+ return raw_ms / 1000.0
+ return 0.0
+ return yield_duration_ms(raw_ms or 0, DEFAULT_EXEC_YIELD_MS)
+
+
+def _write_yield_seconds(raw_ms: Optional[int]) -> float:
+ if raw_ms is None:
+ return _DEFAULT_WORKSPACE_WRITE_YIELD_MS / 1000.0
+ if raw_ms < 0:
+ return 0.0
+ return raw_ms / 1000.0
+
+
+class _ExecInput(BaseModel):
+ command: str = Field(default="")
+ cwd: str = Field(default="")
+ env: dict[str, str] = Field(default_factory=dict)
+ stdin: str = Field(default="")
+ yield_time_ms: int = Field(default=0)
+ background: bool = Field(default=False)
+ timeout_sec: int = Field(default=0)
+ tty: bool = Field(default=False)
+
+
+class _WriteInput(BaseModel):
+ session_id: str = Field(default="")
+ chars: str = Field(default="")
+ yield_time_ms: Optional[int] = Field(default=None)
+ append_newline: bool = Field(default=False)
+
+
+class _KillInput(BaseModel):
+ session_id: str = Field(default="")
+
+
+@dataclass
+class _ExecSession:
+ proc: BaseProgramSession
+ exited_at: Optional[float] = None
+ finalized: bool = False
+ finalized_at: Optional[float] = None
+
+
+class WorkspaceExecTool(BaseTool):
+ """Execute shell commands in shared executor workspace."""
+
+ def __init__(
+ self,
+ workspace_runtime: BaseWorkspaceRuntime,
+ create_ws_name_cb: Optional[CreateWorkspaceNameCallback] = None,
+ session_ttl: float = DEFAULT_SESSION_TTL_SEC,
+ filters_name: Optional[list[str]] = None,
+ filters: Optional[list[BaseFilter]] = None,
+ ):
+ super().__init__(
+ name="workspace_exec",
+ description=("Execute a shell command inside the current executor workspace. "
+ "This is the default shell runner for executor-side work not bound to a specific skill."),
+ filters_name=filters_name,
+ filters=filters,
+ )
+ self._workspace_runtime = workspace_runtime
+ self._create_ws_name_cb = create_ws_name_cb or default_create_ws_name_callback
+ self._ttl = session_ttl
+ self._sessions: dict[str, _ExecSession] = {}
+
+ def _runtime(self) -> BaseWorkspaceRuntime:
+ runtime = self._workspace_runtime
+ if runtime is None:
+ raise ValueError("workspace_exec requires an executor with live workspace support")
+ return runtime
+
+ async def _workspace(self, ctx: InvocationContext):
+ runtime = self._runtime()
+ manager = runtime.manager(ctx)
+ workspace_id = self._create_ws_name_cb(ctx)
+ ws = await manager.create_workspace(workspace_id, ctx)
+ return runtime, ws
+
+ def _supports_interactive(self, ctx: InvocationContext) -> bool:
+ runner = self._runtime().runner(ctx)
+ start_program = getattr(runner, "start_program", None)
+ return start_program is not None
+
+ async def _put_session(self, sid: str, session: _ExecSession) -> None:
+ await self._cleanup_expired_locked()
+ self._sessions[sid] = session
+
+ async def _get_session(self, sid: str) -> _ExecSession:
+ await self._cleanup_expired_locked()
+ session = self._sessions.get(sid)
+ if session is None:
+ raise ValueError(f"unknown session_id: {sid}")
+ return session
+
+ async def _remove_session(self, sid: str) -> _ExecSession:
+ await self._cleanup_expired_locked()
+ session = self._sessions.pop(sid, None)
+ if session is None:
+ raise ValueError(f"unknown session_id: {sid}")
+ return session
+
+    async def _finalize_and_remove_session(self, sid: str) -> None:
+        session = await self._get_session(sid)
+        now = time.time()
+        session.finalized, session.finalized_at, session.exited_at = True, now, now
+        try:
+            await session.proc.close()
+        finally:
+            await self._remove_session(sid)
+
+ async def _cleanup_expired_locked(self) -> None:
+
+ async def _refresh_exit_state(session: _ExecSession, now: float) -> None:
+ if session.exited_at is not None:
+ return
+ session_state = await session.proc.state()
+ if session_state.status == PROGRAM_STATUS_EXITED:
+ session.exited_at = now
+
+ await cleanup_expired_sessions(
+ self._sessions,
+ ttl=self._ttl,
+ refresh_exit_state=_refresh_exit_state,
+ close_session=lambda s: s.proc.close(),
+ )
+
+ @override
+ def _get_declaration(self) -> FunctionDeclaration:
+ return FunctionDeclaration(
+ name="workspace_exec",
+ description=("Execute a shell command in the shared executor workspace. "
+ "Use for workspace-level file operations and validation commands."),
+ parameters=Schema(
+ type=Type.OBJECT,
+ required=["command"],
+ properties={
+ "command": Schema(type=Type.STRING, description="Shell command to execute."),
+ "cwd": Schema(type=Type.STRING, description="Optional workspace-relative cwd."),
+ "env": Schema(type=Type.OBJECT, description="Optional environment overrides."),
+ "stdin": Schema(type=Type.STRING, description="Optional initial stdin text."),
+ "timeout_sec": Schema(type=Type.INTEGER, description="Maximum command runtime in seconds."),
+ "yield_time_ms": Schema(type=Type.INTEGER, description="How long to wait before returning."),
+ "background": Schema(type=Type.BOOLEAN, description="Start command in background session."),
+ "tty": Schema(type=Type.BOOLEAN, description="Allocate TTY for interactive commands."),
+ },
+ ),
+ response=_exec_output_schema(
+ "Result of workspace_exec. output is aggregated terminal text and may combine stdout/stderr."),
+ )
+
+ @override
+ async def _run_async_impl(self, *, tool_context: InvocationContext, args: dict[str, Any]) -> Any:
+ inputs = _ExecInput.model_validate(args)
+ command = (inputs.command or "").strip()
+ if not command:
+ raise ValueError("command is required")
+ cwd = _normalize_cwd(inputs.cwd)
+ timeout_raw = inputs.timeout_sec
+ tty = inputs.tty
+ yield_ms = inputs.yield_time_ms
+
+ runtime, ws = await self._workspace(tool_context)
+ spec = WorkspaceRunProgramSpec(
+ cmd="sh",
+ args=["-lc", command],
+ env=inputs.env,
+ cwd=cwd,
+ stdin=inputs.stdin,
+ timeout=_exec_timeout_seconds(timeout_raw),
+ )
+
+ if not self._supports_interactive(tool_context):
+ if inputs.background or tty:
+ raise ValueError("workspace_exec interactive sessions are not supported by the current executor")
+ return await _run_one_shot(runtime, ws, spec, tool_context)
+
+ if (not inputs.background) and (not tty) and yield_ms <= 0:
+ return await _run_one_shot(runtime, ws, spec, tool_context)
+
+ runner = runtime.runner(tool_context)
+ interactive_spec = WorkspaceRunProgramSpec(
+ cmd=spec.cmd,
+ args=spec.args,
+ env=spec.env,
+ cwd=spec.cwd,
+ stdin=spec.stdin,
+ timeout=spec.timeout,
+ limits=spec.limits,
+ tty=tty,
+ )
+ proc = await runner.start_program(tool_context, ws, interactive_spec) # type: ignore[attr-defined]
+ sid = proc.id()
+ await self._put_session(sid, _ExecSession(proc=proc))
+
+ if inputs.background and yield_ms <= 0:
+ poll = await proc.poll(poll_line_limit(0))
+ else:
+ poll = await wait_for_program_output(
+ proc,
+ _exec_yield_seconds(inputs.background, yield_ms if yield_ms > 0 else None),
+ poll_line_limit(0),
+ )
+
+ out = _poll_output(sid, poll)
+ if poll.status == PROGRAM_STATUS_EXITED:
+ try:
+ await self._finalize_and_remove_session(sid)
+ except Exception: # pylint: disable=broad-except
+ out["session_id"] = sid
+ return out
+
+
+class WorkspaceWriteStdinTool(BaseTool):
+ """Write stdin to a running workspace_exec session or poll it."""
+
+ def __init__(
+ self,
+ exec_tool: WorkspaceExecTool,
+ *,
+ filters_name: Optional[list[str]] = None,
+ filters: Optional[list[BaseFilter]] = None,
+ ):
+ super().__init__(
+ name="workspace_write_stdin",
+ description="Write to a running workspace_exec session. Empty chars acts like poll.",
+ filters_name=filters_name,
+ filters=filters,
+ )
+ self._exec = exec_tool
+
+ @override
+ def _get_declaration(self) -> FunctionDeclaration:
+ return FunctionDeclaration(
+ name="workspace_write_stdin",
+ description="Write to a running workspace_exec session. Empty chars acts like poll.",
+ parameters=Schema(
+ type=Type.OBJECT,
+ required=["session_id"],
+ properties={
+ "session_id": Schema(type=Type.STRING, description="Session id returned by workspace_exec."),
+ "chars": Schema(type=Type.STRING, description="Characters to write."),
+ "yield_time_ms": Schema(type=Type.INTEGER, description="Optional wait before polling output."),
+ "append_newline": Schema(type=Type.BOOLEAN, description="Append newline after chars."),
+ },
+ ),
+ response=_exec_output_schema("Result of stdin write or follow-up poll."),
+ )
+
+ @override
+ async def _run_async_impl(self, *, tool_context: InvocationContext, args: dict[str, Any]) -> Any:
+ inputs = _WriteInput.model_validate(args)
+ session_id = require_non_empty(inputs.session_id, field_name="session_id")
+ session = await self._exec._get_session(session_id)
+
+ append_newline = inputs.append_newline
+ if inputs.chars or append_newline:
+ await session.proc.write(inputs.chars or "", append_newline)
+
+ yield_ms = inputs.yield_time_ms or 0
+ user_set_yield = inputs.yield_time_ms is not None
+ poll = await wait_for_program_output(
+ session.proc,
+ _write_yield_seconds(yield_ms if user_set_yield else None),
+ poll_line_limit(0),
+ )
+ out = _poll_output(session_id, poll)
+ if poll.status == PROGRAM_STATUS_EXITED:
+ try:
+ await self._exec._finalize_and_remove_session(session_id)
+ except Exception: # pylint: disable=broad-except
+ out["session_id"] = session_id
+ return out
+
+
+class WorkspaceKillSessionTool(BaseTool):
+ """Terminate a running workspace_exec session."""
+
+ def __init__(
+ self,
+ exec_tool: WorkspaceExecTool,
+ *,
+ filters_name: Optional[list[str]] = None,
+ filters: Optional[list[BaseFilter]] = None,
+ ):
+ super().__init__(
+ name="workspace_kill_session",
+ description="Terminate a running workspace_exec session.",
+ filters_name=filters_name,
+ filters=filters,
+ )
+ self._exec = exec_tool
+
+ @override
+ def _get_declaration(self) -> FunctionDeclaration:
+ return FunctionDeclaration(
+ name="workspace_kill_session",
+ description="Terminate a running workspace_exec session.",
+ parameters=Schema(
+ type=Type.OBJECT,
+ required=["session_id"],
+ properties={
+ "session_id": Schema(type=Type.STRING, description="Session id returned by workspace_exec."),
+ },
+ ),
+ response=Schema(
+ type=Type.OBJECT,
+ required=["ok", "session_id", "status"],
+ properties={
+ "ok": Schema(type=Type.BOOLEAN, description="True when session was removed."),
+ "session_id": Schema(type=Type.STRING, description="Session id."),
+ "status": Schema(type=Type.STRING, description="Final status."),
+ },
+ ),
+ )
+
+ @override
+    async def _run_async_impl(self, *, tool_context: InvocationContext, args: dict[str, Any]) -> Any:
+        inputs = _KillInput.model_validate(args)
+        session_id = require_non_empty(inputs.session_id, field_name="session_id")
+        # Unknown or expired ids raise ValueError inside _get_session.
+        session = await self._exec._get_session(session_id)
+ status = PROGRAM_STATUS_EXITED
+ poll = await session.proc.poll(None)
+ if poll.status == PROGRAM_STATUS_RUNNING:
+ await session.proc.kill(DEFAULT_SESSION_KILL_SEC)
+ status = "killed"
+ await self._exec._finalize_and_remove_session(session_id)
+ return {"ok": True, "session_id": session_id, "status": status}
+
+
+def _exec_output_schema(description: str) -> Schema:
+ return Schema(
+ type=Type.OBJECT,
+ description=description,
+ required=["status", "offset", "next_offset"],
+ properties={
+ "status": Schema(type=Type.STRING, description="running or exited"),
+ "output": Schema(type=Type.STRING, description="Aggregated terminal text observed for this call."),
+ "exit_code": Schema(type=Type.INTEGER, description="Exit code when session has exited."),
+ "session_id": Schema(type=Type.STRING, description="Interactive session id when still running."),
+ "offset": Schema(type=Type.INTEGER, description="Start offset of returned output."),
+ "next_offset": Schema(type=Type.INTEGER, description="Next output offset."),
+ },
+ )
+
+
+def _poll_output(session_id: str, poll: ProgramPoll) -> dict[str, Any]:
+ out: dict[str, Any] = {
+ "status": poll.status,
+ "output": poll.output,
+ "offset": poll.offset,
+ "next_offset": poll.next_offset,
+ }
+ if poll.exit_code is not None:
+ out["exit_code"] = poll.exit_code
+ if poll.status == PROGRAM_STATUS_RUNNING:
+ out["session_id"] = session_id
+ return out
+
+
+async def _run_one_shot(
+ runtime: BaseWorkspaceRuntime,
+ workspace: Any,
+ spec: WorkspaceRunProgramSpec,
+ ctx: InvocationContext,
+) -> dict[str, Any]:
+ rr = await runtime.runner(ctx).run_program(workspace, spec, ctx)
+ return {
+ "status": PROGRAM_STATUS_EXITED,
+ "output": _combine_output(rr.stdout, rr.stderr),
+ "exit_code": rr.exit_code,
+ "offset": 0,
+ "next_offset": 0,
+ }
+
+
+def create_workspace_exec_tools(
+ code_executor: BaseCodeExecutor,
+ *,
+ workspace_runtime: Optional[BaseWorkspaceRuntime] = None,
+ session_ttl: float = DEFAULT_SESSION_TTL_SEC,
+ filters_name: Optional[list[str]] = None,
+ filters: Optional[list[BaseFilter]] = None,
+) -> tuple[WorkspaceExecTool, WorkspaceWriteStdinTool, WorkspaceKillSessionTool]:
+ """Create workspace_exec tool trio."""
+    runtime = workspace_runtime or getattr(code_executor, "workspace_runtime", None)
+    exec_tool = WorkspaceExecTool(
+        workspace_runtime=runtime,
+        session_ttl=session_ttl,
+        filters_name=filters_name,
+        filters=filters,
+    )
+ write_tool = WorkspaceWriteStdinTool(exec_tool, filters_name=filters_name, filters=filters)
+ kill_tool = WorkspaceKillSessionTool(exec_tool, filters_name=filters_name, filters=filters)
+ return exec_tool, write_tool, kill_tool
diff --git a/trpc_agent_sdk/storage/_sql_common.py b/trpc_agent_sdk/storage/_sql_common.py
index d9d4c04..8f5ef17 100644
--- a/trpc_agent_sdk/storage/_sql_common.py
+++ b/trpc_agent_sdk/storage/_sql_common.py
@@ -172,6 +172,8 @@ class DynamicPickleType(TypeDecorator):
def load_dialect_impl(self, dialect: Dialect) -> TypeDecorator:
if dialect.name == "spanner+spanner":
return dialect.type_descriptor(SpannerPickleType) # type: ignore
+ if dialect.name == "mysql":
+ return dialect.type_descriptor(mysql.LONGBLOB) # type: ignore
return self.impl
def process_bind_param(self, value: Any, dialect: Dialect) -> Any: