From 33a4a8d1ef35b47ad33d7778b22298c3f996e069 Mon Sep 17 00:00:00 2001 From: raylchen Date: Fri, 10 Apr 2026 18:17:23 +0800 Subject: [PATCH] =?UTF-8?q?bugfix:=20=E6=9B=B4=E6=96=B0=20skill=20?= =?UTF-8?q?=E6=89=A7=E8=A1=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../knowledge_with_vectorstore/run_agent.py | 8 +- format.py | 388 ++---- pyproject.toml | 1 + .../core/test_agent_transfer_processor.py | 15 +- .../core/test_code_execution_processor.py | 2 - tests/agents/core/test_history_processor.py | 2 - tests/agents/core/test_llm_processor.py | 4 +- tests/agents/core/test_request_processor.py | 11 +- .../agents/core/test_request_processor_ext.py | 18 +- tests/agents/core/test_skill_processor.py | 862 ++++--------- .../core/test_skill_tool_result_processor.py | 136 +++ tests/agents/core/test_tools_processor.py | 9 + .../core/test_workspace_exec_processor.py | 127 ++ tests/skills/__init__.py | 1 + tests/skills/stager/__init__.py | 1 + tests/skills/stager/test_base_stager.py | 7 +- tests/skills/stager/test_types.py | 1 - tests/skills/test_common.py | 668 ++++++---- tests/skills/test_constants.py | 23 + tests/skills/test_dynamic_toolset.py | 47 +- tests/skills/test_hot_reload.py | 35 + tests/skills/test_repository.py | 84 +- tests/skills/test_run_tool.py | 364 ------ tests/skills/test_skill_config.py | 42 + tests/skills/test_skill_profile.py | 41 + tests/skills/test_state_keys.py | 35 + tests/skills/test_state_migration.py | 71 ++ tests/skills/test_state_order.py | 24 + tests/skills/test_tools.py | 553 --------- tests/skills/tools/__init__.py | 1 + tests/skills/tools/test_common.py | 31 + tests/skills/tools/test_save_artifact.py | 49 + tests/skills/tools/test_skill_exec.py | 954 ++------------- tests/skills/tools/test_skill_list_docs.py | 14 +- tests/skills/tools/test_skill_list_tool.py | 86 +- tests/skills/tools/test_skill_load.py | 79 +- tests/skills/tools/test_skill_run.py | 1086 ++--------------- 
tests/skills/tools/test_skill_select_docs.py | 18 +- tests/skills/tools/test_skill_select_tools.py | 18 +- tests/skills/tools/test_workspace_exec.py | 34 + trpc_agent_sdk/agents/core/README.md | 249 ++++ trpc_agent_sdk/agents/core/__init__.py | 12 + .../agents/core/_code_execution_processor.py | 9 +- .../agents/core/_request_processor.py | 86 +- .../agents/core/_skill_processor.py | 607 +++++---- .../core/_skills_tool_result_processor.py | 375 ++++++ .../agents/core/_workspace_exec_processor.py | 153 +++ .../artifacts/_in_memory_artifact_service.py | 15 +- trpc_agent_sdk/artifacts/_utils.py | 1 - trpc_agent_sdk/code_executors/__init__.py | 36 +- trpc_agent_sdk/code_executors/_artifacts.py | 65 +- .../code_executors/_base_code_executor.py | 4 +- .../code_executors/_base_workspace_runtime.py | 38 + .../code_executors/_program_session.py | 167 +++ trpc_agent_sdk/code_executors/_types.py | 4 +- .../container/_container_cli.py | 114 +- .../container/_container_code_executor.py | 1 - .../container/_container_ws_runtime.py | 300 +++-- .../local/_local_program_session.py | 299 +++++ .../code_executors/local/_local_ws_runtime.py | 141 ++- .../local/_unsafe_local_code_executor.py | 3 +- .../code_executors/utils/_code_execution.py | 29 +- .../server/langfuse/tracing/opentelemetry.py | 6 +- .../sessions/_base_session_service.py | 15 + trpc_agent_sdk/sessions/_history_record.py | 1 + trpc_agent_sdk/sessions/_session.py | 6 +- trpc_agent_sdk/skills/__init__.py | 144 ++- trpc_agent_sdk/skills/_common.py | 143 ++- trpc_agent_sdk/skills/_constants.py | 67 + trpc_agent_sdk/skills/_dynamic_toolset.py | 46 +- trpc_agent_sdk/skills/_hot_reload.py | 163 +++ trpc_agent_sdk/skills/_repository.py | 325 +++-- trpc_agent_sdk/skills/_skill_config.py | 60 + trpc_agent_sdk/skills/_skill_profile.py | 120 ++ trpc_agent_sdk/skills/_state_keys.py | 151 +++ trpc_agent_sdk/skills/_state_migration.py | 165 +++ trpc_agent_sdk/skills/_state_order.py | 82 ++ trpc_agent_sdk/skills/_toolset.py | 90 +- 
trpc_agent_sdk/skills/_utils.py | 57 + trpc_agent_sdk/skills/stager/_base_stager.py | 40 +- trpc_agent_sdk/skills/stager/_types.py | 11 +- trpc_agent_sdk/skills/tools/__init__.py | 60 +- trpc_agent_sdk/skills/tools/_common.py | 159 +++ trpc_agent_sdk/skills/tools/_copy_stager.py | 46 +- trpc_agent_sdk/skills/tools/_save_artifact.py | 262 ++++ trpc_agent_sdk/skills/tools/_skill_exec.py | 568 ++++----- trpc_agent_sdk/skills/tools/_skill_list.py | 8 +- .../skills/tools/_skill_list_docs.py | 29 +- .../skills/tools/_skill_list_tool.py | 64 +- trpc_agent_sdk/skills/tools/_skill_load.py | 171 ++- trpc_agent_sdk/skills/tools/_skill_run.py | 489 ++------ .../skills/tools/_skill_select_docs.py | 71 +- .../skills/tools/_skill_select_tools.py | 71 +- .../skills/tools/_workspace_exec.py | 518 ++++++++ trpc_agent_sdk/storage/_sql_common.py | 2 + 95 files changed, 7109 insertions(+), 5759 deletions(-) create mode 100644 tests/agents/core/test_skill_tool_result_processor.py create mode 100644 tests/agents/core/test_workspace_exec_processor.py create mode 100644 tests/skills/test_constants.py create mode 100644 tests/skills/test_hot_reload.py delete mode 100644 tests/skills/test_run_tool.py create mode 100644 tests/skills/test_skill_config.py create mode 100644 tests/skills/test_skill_profile.py create mode 100644 tests/skills/test_state_keys.py create mode 100644 tests/skills/test_state_migration.py create mode 100644 tests/skills/test_state_order.py delete mode 100644 tests/skills/test_tools.py create mode 100644 tests/skills/tools/test_common.py create mode 100644 tests/skills/tools/test_save_artifact.py create mode 100644 tests/skills/tools/test_workspace_exec.py create mode 100644 trpc_agent_sdk/agents/core/README.md create mode 100644 trpc_agent_sdk/agents/core/_skills_tool_result_processor.py create mode 100644 trpc_agent_sdk/agents/core/_workspace_exec_processor.py create mode 100644 trpc_agent_sdk/code_executors/_program_session.py create mode 100644 
trpc_agent_sdk/code_executors/local/_local_program_session.py create mode 100644 trpc_agent_sdk/skills/_hot_reload.py create mode 100644 trpc_agent_sdk/skills/_skill_config.py create mode 100644 trpc_agent_sdk/skills/_skill_profile.py create mode 100644 trpc_agent_sdk/skills/_state_keys.py create mode 100644 trpc_agent_sdk/skills/_state_migration.py create mode 100644 trpc_agent_sdk/skills/_state_order.py create mode 100644 trpc_agent_sdk/skills/tools/_common.py create mode 100644 trpc_agent_sdk/skills/tools/_save_artifact.py create mode 100644 trpc_agent_sdk/skills/tools/_workspace_exec.py diff --git a/examples/knowledge_with_vectorstore/run_agent.py b/examples/knowledge_with_vectorstore/run_agent.py index 329160c..e616067 100644 --- a/examples/knowledge_with_vectorstore/run_agent.py +++ b/examples/knowledge_with_vectorstore/run_agent.py @@ -7,10 +7,10 @@ import uuid from dotenv import load_dotenv -from trpc_agent.runners import Runner -from trpc_agent.sessions import InMemorySessionService -from trpc_agent.types import Content -from trpc_agent.types import Part +from trpc_agent_sdk.runners import Runner +from trpc_agent_sdk.sessions import InMemorySessionService +from trpc_agent_sdk.types import Content +from trpc_agent_sdk.types import Part # Load environment variables from the .env file load_dotenv() diff --git a/format.py b/format.py index 290c8aa..cd52ba3 100644 --- a/format.py +++ b/format.py @@ -12,35 +12,9 @@ then lexicographical order. - Keep ``from __future__ import annotations`` at the very top of imports. -2) audit - - Report files missing copyright marker ``Copyright @``. - - Report files missing module docstring. - -3) build-package-api - - Auto-generate per-package ``__init__.py`` export blocks. - - Skip optional-feature modules inferred from ``pyproject.toml`` extras. - - De-duplicate against symbols already manually imported in ``__init__.py``. 
- -4) sync-imports-package-api - - Combined maintenance workflow: - ensure missing ``__init__.py`` > rename private modules > rewrite imports > - build package API exports > sort imports. - - Add copyright marker to Python files when missing. - - Keep exactly two blank lines after import blocks. - -5) rename-private-modules - - Rename package modules from ``foo.py`` to ``_foo.py``. - - Rewrite import statements across project files accordingly. - - Optional: refresh package API blocks after rename. - Examples: python3 format.py sort-imports --root . --dry-run python3 format.py sort-imports --root . - python3 format.py audit --root . - python3 format.py build-package-api --root . --dry-run - python3 format.py sync-imports-package-api --root . --dry-run - python3 format.py rename-private-modules --root . --dry-run - python3 format.py rename-private-modules --root . --build-package-api python3 format.py check-chinese --root . --dry-run """ @@ -50,6 +24,7 @@ import ast import os import re +import subprocess import sys import tomllib from dataclasses import dataclass @@ -74,14 +49,10 @@ ".pytest_cache", } -COPYRIGHT_RE = re.compile(r"Copyright\s*@") STDLIB_COMPANION_MODULES = {"typing_extensions"} AUTO_EXPORT_BEGIN = "# " AUTO_EXPORT_END = "# " CHINESE_RE = re.compile(r"[\u4e00-\u9fff]") -ENCODING_RE = re.compile(r"^#.*coding[:=]\s*([-\w.]+)") -DEFAULT_COPYRIGHT_YEAR = "2026" -DEFAULT_COPYRIGHT_OWNER = "Tencent.com" @dataclass(frozen=True) @@ -89,6 +60,7 @@ class ImportLine: text: str group: int kind: int # 0=import, 1=from-import + rel_level: int = 0 # relative import level; larger means farther (.. > .) 
@dataclass @@ -157,42 +129,6 @@ def ensure_init_files(root: Path, apply: bool) -> list[Path]: return created -def _default_copyright_block() -> list[str]: - return [ - "# -*- coding: utf-8 -*-", - "#", - f"# Copyright @ {DEFAULT_COPYRIGHT_YEAR} {DEFAULT_COPYRIGHT_OWNER}", - ] - - -def add_copyright_if_missing(source: str) -> tuple[str, bool]: - if COPYRIGHT_RE.search(source): - return source, False - - lines = source.splitlines() - insert_at = 0 - out: list[str] = [] - - if lines and lines[0].startswith("#!"): - out.append(lines[0]) - insert_at = 1 - - has_encoding = insert_at < len(lines) and bool(ENCODING_RE.match(lines[insert_at])) - if has_encoding: - out.append(lines[insert_at]) - insert_at += 1 - out.extend(["#", f"# Copyright @ {DEFAULT_COPYRIGHT_YEAR} {DEFAULT_COPYRIGHT_OWNER}"]) - else: - out.extend(_default_copyright_block()) - - if insert_at < len(lines) and lines[insert_at] != "": - out.append("") - out.extend(lines[insert_at:]) - - new_source = "\n".join(out) + ("\n" if source.endswith("\n") or not source else "") - return new_source, True - - def read_optional_group_tokens(pyproject_path: Path) -> set[str]: if not pyproject_path.exists(): return set() @@ -446,21 +382,14 @@ def _extract_all_names_from_value(node: ast.AST) -> list[str]: return out -def merge_init_all_exports(init_path: Path, apply: bool) -> bool: - """Merge multiple top-level __all__ blocks in __init__.py into one. - - Order policy: - 1) follow top-level import symbol order first (from-import aliases order), - 2) append remaining legacy __all__ names (deduplicated) in original order. 
- """ - if not init_path.exists(): - return False - original_source = init_path.read_text(encoding="utf-8") - source = strip_auto_export_markers(original_source) +def _merge_init_all_exports_source(source: str) -> str | None: + """Return merged __all__ source for __init__.py, or None when unchanged.""" + original_source = source + source = strip_auto_export_markers(source) try: tree = ast.parse(source) except SyntaxError: - return False + return None all_nodes: list[ast.Assign] = [] legacy_all_names: list[str] = [] @@ -473,6 +402,12 @@ def merge_init_all_exports(init_path: Path, apply: bool) -> bool: if has_all_target: all_nodes.append(node) legacy_all_names.extend(_extract_all_names_from_value(node.value)) + elif isinstance(node, ast.Import): + for alias in node.names: + sym = alias.asname or alias.name.split(".", 1)[0] + if sym not in seen_imports: + seen_imports.add(sym) + import_symbol_order.append(sym) elif isinstance(node, ast.ImportFrom): if node.module == "__future__": continue @@ -485,7 +420,7 @@ def merge_init_all_exports(init_path: Path, apply: bool) -> bool: import_symbol_order.append(sym) if len(all_nodes) == 0: - return False + return None final_names: list[str] = [] seen_final: set[str] = set() @@ -515,7 +450,7 @@ def merge_init_all_exports(init_path: Path, apply: bool) -> bool: try: kept_tree = ast.parse(kept_source) except SyntaxError: - return False + return None last_import_end = 0 for node in kept_tree.body: if isinstance(node, (ast.Import, ast.ImportFrom)): @@ -529,6 +464,17 @@ def merge_init_all_exports(init_path: Path, apply: bool) -> bool: new_source = "\n".join(merged_lines) + ("\n" if source.endswith("\n") else "") if new_source == original_source: + return None + return new_source + + +def merge_init_all_exports(init_path: Path, apply: bool) -> bool: + """Merge multiple top-level __all__ blocks in __init__.py into one.""" + if not init_path.exists(): + return False + original_source = init_path.read_text(encoding="utf-8") + new_source = 
_merge_init_all_exports_source(original_source) + if new_source is None: return False if apply: init_path.write_text(new_source, encoding="utf-8") @@ -767,7 +713,7 @@ def normalize_import_block( text = f"from {module} import {alias.name}" if alias.asname: text += f" as {alias.asname}" - normalized.append(ImportLine(text=text, group=group, kind=1)) + normalized.append(ImportLine(text=text, group=group, kind=1, rel_level=node.level)) groups: dict[int, list[ImportLine]] = {1: [], 2: [], 3: [], 4: []} for item in normalized: @@ -777,9 +723,13 @@ def normalize_import_block( # 1) plain imports first # 2) then from-imports # 3) lexicographical order inside each kind + # 4) for relative imports, farther levels first: .. > . for k in groups: uniq = {(item.kind, item.text): item for item in groups[k]} - groups[k] = sorted(uniq.values(), key=lambda x: (x.kind, x.text)) + if k == 4: + groups[k] = sorted(uniq.values(), key=lambda x: (x.kind, -x.rel_level, x.text)) + else: + groups[k] = sorted(uniq.values(), key=lambda x: (x.kind, x.text)) out: list[str] = [] if future_annotations: @@ -793,39 +743,79 @@ def normalize_import_block( return "\n".join(out) + "\n" +def find_missing_relative_import_targets(init_path: Path, source: str) -> list[tuple[int, str]]: + """Check relative import module anchors in __init__.py. + + For statements like ``from .a import b``, validate that ``a`` exists as + either ``a.py`` or directory ``a/`` at the resolved relative location. 
+ """ + try: + tree = ast.parse(source) + except SyntaxError: + return [] + + missing: list[tuple[int, str]] = [] + for node in tree.body: + if not isinstance(node, ast.ImportFrom): + continue + if node.level < 1 or not node.module: + continue + first_seg = node.module.split(".", 1)[0].strip() + if not first_seg: + continue + + anchor = init_path.parent + for _ in range(max(node.level - 1, 0)): + anchor = anchor.parent + file_candidate = anchor / f"{first_seg}.py" + dir_candidate = anchor / first_seg + if not (file_candidate.exists() or dir_candidate.exists()): + target = "." * node.level + node.module + missing.append((node.lineno, target)) + return missing + + +def run_yapf_on_python_files(root: Path) -> tuple[int, list[str]]: + """Run yapf -i for all Python files under root.""" + formatted = 0 + errors: list[str] = [] + for py_file in sorted(iter_python_files(root)): + try: + result = subprocess.run( + ["yapf", "-i", str(py_file)], + capture_output=True, + text=True, + check=False, + ) + except FileNotFoundError: + errors.append("yapf command not found") + break + if result.returncode != 0: + detail = result.stderr.strip() or result.stdout.strip() or f"exit={result.returncode}" + errors.append(f"{py_file}: {detail}") + continue + formatted += 1 + return formatted, errors + + def process_file( path: Path, stdlib_names: set[str], project_packages: set[str], apply: bool, -) -> tuple[bool, bool, bool]: - """Return (modified, has_copyright, has_module_docstring).""" +) -> bool: + """Return whether file content was modified.""" source = path.read_text(encoding="utf-8") - has_copyright = bool(COPYRIGHT_RE.search(source)) try: tree = ast.parse(source) except SyntaxError: # Skip files that cannot be parsed. 
- return False, has_copyright, False - - has_module_docstring = ast.get_docstring(tree, clean=False) is not None - source_after_copyright, copyright_inserted = add_copyright_if_missing(source) - if copyright_inserted: - source = source_after_copyright - has_copyright = True - try: - tree = ast.parse(source) - except SyntaxError: - if apply: - path.write_text(source, encoding="utf-8") - return True, has_copyright, has_module_docstring + return False import_nodes = extract_leading_import_nodes(tree) if not import_nodes: - if copyright_inserted and apply: - path.write_text(source, encoding="utf-8") - return copyright_inserted, has_copyright, has_module_docstring + return False start = import_nodes[0].lineno - 1 end = import_nodes[-1].end_lineno or import_nodes[-1].lineno @@ -850,12 +840,17 @@ def process_file( new_lines.extend(tail_lines[leading_blank_count:]) new_source = "\n".join(new_lines) + ("\n" if source.endswith("\n") else "") + if path.name == "__init__.py": + merged_source = _merge_init_all_exports_source(new_source) + if merged_source is not None: + new_source = merged_source + if new_source == source: - return False, has_copyright, has_module_docstring + return False if apply: path.write_text(new_source, encoding="utf-8") - return True, has_copyright, has_module_docstring + return True def _names_from_target(target: ast.AST) -> list[str]: @@ -1097,9 +1092,14 @@ def run_sort_imports(root: Path, dry_run: bool) -> int: project_packages = discover_project_packages(root) py_files = sorted(iter_python_files(root)) modified_files: list[Path] = [] + missing_relative_targets: list[tuple[Path, int, str]] = [] for path in py_files: - modified, _has_copyright, _has_module_docstring = process_file( + source_for_check = path.read_text(encoding="utf-8") + if path.name == "__init__.py": + for lineno, target in find_missing_relative_import_targets(path, source_for_check): + missing_relative_targets.append((path, lineno, target)) + modified = process_file( path=path, 
stdlib_names=stdlib_names, project_packages=project_packages, @@ -1112,113 +1112,20 @@ def run_sort_imports(root: Path, dry_run: bool) -> int: print(f"[{mode}] sort-imports scanned: {len(py_files)}, modified: {len(modified_files)}") for p in modified_files: print(str(p)) - return 0 - - -def run_audit(root: Path, dry_run: bool) -> int: - stdlib_names = set(getattr(sys, "stdlib_module_names", set())) - project_packages = discover_project_packages(root) - py_files = sorted(iter_python_files(root)) - no_copyright_files: list[Path] = [] - no_module_docstring_files: list[Path] = [] - - for path in py_files: - _modified, has_copyright, has_module_docstring = process_file( - path=path, - stdlib_names=stdlib_names, - project_packages=project_packages, - apply=False, - ) - if not has_copyright: - no_copyright_files.append(path) - if not has_module_docstring: - no_module_docstring_files.append(path) - - mode = "DRY_RUN" if dry_run else "APPLY" - print(f"[{mode}] audit scanned: {len(py_files)}") - print("Files missing copyright marker (Copyright @):") - for p in no_copyright_files: - print(str(p)) - print("Files missing module docstring:") - for p in no_module_docstring_files: - print(str(p)) - return 0 - - -def run_build_package_api(root: Path, dry_run: bool) -> int: - changed_init_files = build_package_api_exports(root, apply=not dry_run) - mode = "DRY_RUN" if dry_run else "APPLY" - print(f"[{mode}] build-package-api updated: {len(changed_init_files)}") - for p in changed_init_files: - print(str(p)) - return 0 + if missing_relative_targets: + print("Missing relative import module targets in __init__.py:") + for path, lineno, target in missing_relative_targets: + print(f"{path}:{lineno} [missing_relative_import_target] {target}") + if not dry_run: + formatted_count, yapf_errors = run_yapf_on_python_files(root) + print(f"[APPLY] yapf formatted: {formatted_count}") + for err in yapf_errors: + print(f"[yapf-error] {err}") + if yapf_errors: + return 1 -def 
run_sync_imports_and_package_api(root: Path, dry_run: bool) -> int: - """Combined command: rename/fix imports + build package API + sort imports.""" - apply = not dry_run - project_packages = discover_project_packages(root) - py_files = sorted(iter_python_files(root)) - stdlib_names = set(getattr(sys, "stdlib_module_names", set())) - - created_init_files = ensure_init_files(root, apply=apply) - # Refresh package discovery after creating missing __init__.py files. - project_packages = discover_project_packages(root) - py_files = sorted(iter_python_files(root)) - - rename_map = build_private_module_rename_map(root, project_packages) - rewritten_import_files = rewrite_imports_for_renamed_modules( - root=root, - py_files=py_files, - project_packages=project_packages, - rename_map=rename_map, - apply=apply, - ) - renamed_modules = apply_private_module_renames(rename_map, apply=apply) - - # Re-scan file list after module renames. - py_files = sorted(iter_python_files(root)) - changed_init_files = build_package_api_exports(root, apply=apply) - - modified_files: list[Path] = [] - for path in py_files: - modified, _has_copyright, _has_module_docstring = process_file( - path=path, - stdlib_names=stdlib_names, - project_packages=project_packages, - apply=apply, - ) - if modified: - modified_files.append(path) - - # Ensure __all__ follows final import order in __init__.py files. 
- merged_init_files: list[Path] = [] - for package_dir in iter_package_dirs(root): - init_path = package_dir / "__init__.py" - if merge_init_all_exports(init_path, apply=apply): - merged_init_files.append(init_path) - - mode = "DRY_RUN" if dry_run else "APPLY" - print(f"[{mode}] sync-imports-package-api") - print(f"Created __init__.py files: {len(created_init_files)}") - for p in created_init_files: - print(str(p)) - print(f"Import files rewritten: {len(rewritten_import_files)}") - for p in rewritten_import_files: - print(str(p)) - print(f"Modules renamed: {len(renamed_modules)}") - for old_path, new_path in renamed_modules: - print(f"{old_path} -> {new_path}") - print(f"Package __init__.py refreshed: {len(changed_init_files)}") - for p in changed_init_files: - print(str(p)) - print(f"Python files normalized: {len(modified_files)}") - for p in modified_files: - print(str(p)) - print(f"__init__.py __all__ merged: {len(merged_init_files)}") - for p in merged_init_files: - print(str(p)) - return 0 + return 1 if missing_relative_targets else 0 def run_report_private_candidates(root: Path, dry_run: bool) -> int: @@ -1231,37 +1138,6 @@ def run_report_private_candidates(root: Path, dry_run: bool) -> int: return 0 -def run_rename_private_modules(root: Path, dry_run: bool, build_package_api: bool) -> int: - project_packages = discover_project_packages(root) - py_files = sorted(iter_python_files(root)) - rename_map = build_private_module_rename_map(root, project_packages) - rewritten_import_files = rewrite_imports_for_renamed_modules( - root=root, - py_files=py_files, - project_packages=project_packages, - rename_map=rename_map, - apply=not dry_run, - ) - renamed_modules = apply_private_module_renames(rename_map, apply=not dry_run) - changed_init_files: list[Path] = [] - if build_package_api: - changed_init_files = build_package_api_exports(root, apply=not dry_run) - - mode = "DRY_RUN" if dry_run else "APPLY" - print(f"[{mode}] rename-private-modules") - print(f"Import files 
rewritten: {len(rewritten_import_files)}") - for p in rewritten_import_files: - print(str(p)) - print(f"Modules renamed: {len(renamed_modules)}") - for old_path, new_path in renamed_modules: - print(f"{old_path} -> {new_path}") - if build_package_api: - print(f"Package __init__.py refreshed: {len(changed_init_files)}") - for p in changed_init_files: - print(str(p)) - return 0 - - def run_detect_issues(root: Path, dry_run: bool) -> int: py_files = sorted(iter_python_files(root)) all_issues: list[CheckIssue] = [] @@ -1304,34 +1180,10 @@ def main() -> int: p_sort.add_argument("--root", default=".", help="Project root directory.") p_sort.add_argument("--dry-run", action="store_true", help="Preview changes without writing files.") - p_audit = subparsers.add_parser("audit", help="Report missing copyright/docstring files.") - p_audit.add_argument("--root", default=".", help="Project root directory.") - p_audit.add_argument("--dry-run", action="store_true", help="Read-only mode (same output).") - - p_api = subparsers.add_parser("build-package-api", help="Refresh package __init__.py export blocks.") - p_api.add_argument("--root", default=".", help="Project root directory.") - p_api.add_argument("--dry-run", action="store_true", help="Preview changes without writing files.") - - p_sync = subparsers.add_parser( - "sync-imports-package-api", - help="Run rename/import-fix + build-package-api + sort-imports in one command.", - ) - p_sync.add_argument("--root", default=".", help="Project root directory.") - p_sync.add_argument("--dry-run", action="store_true", help="Preview changes without writing files.") - p_report = subparsers.add_parser("report-private-candidates", help="List foo.py -> _foo.py candidates.") p_report.add_argument("--root", default=".", help="Project root directory.") p_report.add_argument("--dry-run", action="store_true", help="Read-only mode (same output).") - p_rename = subparsers.add_parser("rename-private-modules", help="Rename modules and rewrite 
imports.") - p_rename.add_argument("--root", default=".", help="Project root directory.") - p_rename.add_argument("--dry-run", action="store_true", help="Preview changes without writing files.") - p_rename.add_argument( - "--build-package-api", - action="store_true", - help="Also refresh package __init__.py export blocks after rename.", - ) - p_detect = subparsers.add_parser( "detect-issues", help=( @@ -1358,16 +1210,8 @@ def main() -> int: if args.command == "sort-imports": return run_sort_imports(root, args.dry_run) - if args.command == "audit": - return run_audit(root, args.dry_run) - if args.command == "build-package-api": - return run_build_package_api(root, args.dry_run) - if args.command == "sync-imports-package-api": - return run_sync_imports_and_package_api(root, args.dry_run) if args.command == "report-private-candidates": return run_report_private_candidates(root, args.dry_run) - if args.command == "rename-private-modules": - return run_rename_private_modules(root, args.dry_run, args.build_package_api) if args.command == "detect-issues": return run_detect_issues(root, args.dry_run) if args.command == "check-chinese": diff --git a/pyproject.toml b/pyproject.toml index 09f2532..2cc41b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ dependencies = [ "rapidfuzz>=3.0.0", "charset-normalizer>=3.0.0", "typer>=0.12.0", # MIT License + "watchdog", ] [project.optional-dependencies] diff --git a/tests/agents/core/test_agent_transfer_processor.py b/tests/agents/core/test_agent_transfer_processor.py index 167315b..19d97bd 100644 --- a/tests/agents/core/test_agent_transfer_processor.py +++ b/tests/agents/core/test_agent_transfer_processor.py @@ -8,19 +8,26 @@ from __future__ import annotations import asyncio -from typing import AsyncGenerator, List -from unittest.mock import Mock, patch +from typing import List +from unittest.mock import Mock import pytest +import trpc_agent_sdk.skills as _skills_pkg + +if not hasattr(_skills_pkg, 
"get_skill_processor_parameters"): + + def _compat_get_skill_processor_parameters(agent_context): + from trpc_agent_sdk.agents.core._skill_processor import get_skill_processor_parameters as _impl + return _impl(agent_context) + + _skills_pkg.get_skill_processor_parameters = _compat_get_skill_processor_parameters from trpc_agent_sdk.agents._llm_agent import LlmAgent -from trpc_agent_sdk.agents._base_agent import BaseAgent from trpc_agent_sdk.agents.core._agent_transfer_processor import ( AgentTransferProcessor, default_agent_transfer_processor, ) from trpc_agent_sdk.context import InvocationContext, create_agent_context -from trpc_agent_sdk.events import Event from trpc_agent_sdk.models import LLMModel, LlmRequest, LlmResponse, ModelRegistry from trpc_agent_sdk.sessions import InMemorySessionService from trpc_agent_sdk.types import GenerateContentConfig diff --git a/tests/agents/core/test_code_execution_processor.py b/tests/agents/core/test_code_execution_processor.py index 0a93a98..46fd21f 100644 --- a/tests/agents/core/test_code_execution_processor.py +++ b/tests/agents/core/test_code_execution_processor.py @@ -7,8 +7,6 @@ from __future__ import annotations -import pytest - from trpc_agent_sdk.agents.core._code_execution_processor import ( DataFileUtil, _DATA_FILE_UTIL_MAP, diff --git a/tests/agents/core/test_history_processor.py b/tests/agents/core/test_history_processor.py index 03cb64c..22182b3 100644 --- a/tests/agents/core/test_history_processor.py +++ b/tests/agents/core/test_history_processor.py @@ -8,8 +8,6 @@ from __future__ import annotations import asyncio -from typing import List -from unittest.mock import Mock import pytest diff --git a/tests/agents/core/test_llm_processor.py b/tests/agents/core/test_llm_processor.py index d327f28..5779bea 100644 --- a/tests/agents/core/test_llm_processor.py +++ b/tests/agents/core/test_llm_processor.py @@ -8,8 +8,8 @@ from __future__ import annotations import asyncio -from typing import AsyncGenerator, List -from 
unittest.mock import AsyncMock, Mock, patch +from typing import List +from unittest.mock import Mock import pytest diff --git a/tests/agents/core/test_request_processor.py b/tests/agents/core/test_request_processor.py index 9e2e36b..e7f4d28 100644 --- a/tests/agents/core/test_request_processor.py +++ b/tests/agents/core/test_request_processor.py @@ -8,11 +8,18 @@ from __future__ import annotations import asyncio -import copy from typing import List -from unittest.mock import AsyncMock, Mock, patch import pytest +import trpc_agent_sdk.skills as _skills_pkg + +if not hasattr(_skills_pkg, "get_skill_processor_parameters"): + + def _compat_get_skill_processor_parameters(agent_context): + from trpc_agent_sdk.agents.core._skill_processor import get_skill_processor_parameters as _impl + return _impl(agent_context) + + _skills_pkg.get_skill_processor_parameters = _compat_get_skill_processor_parameters from trpc_agent_sdk.agents._llm_agent import LlmAgent from trpc_agent_sdk.agents.core._request_processor import ( diff --git a/tests/agents/core/test_request_processor_ext.py b/tests/agents/core/test_request_processor_ext.py index ba25f18..db15b75 100644 --- a/tests/agents/core/test_request_processor_ext.py +++ b/tests/agents/core/test_request_processor_ext.py @@ -9,11 +9,19 @@ from __future__ import annotations import asyncio -import copy from typing import List -from unittest.mock import AsyncMock, MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest +import trpc_agent_sdk.skills as _skills_pkg + +if not hasattr(_skills_pkg, "get_skill_processor_parameters"): + + def _compat_get_skill_processor_parameters(agent_context): + from trpc_agent_sdk.agents.core._skill_processor import get_skill_processor_parameters as _impl + return _impl(agent_context) + + _skills_pkg.get_skill_processor_parameters = _compat_get_skill_processor_parameters from trpc_agent_sdk.agents._llm_agent import LlmAgent from trpc_agent_sdk.agents.core._request_processor 
import ( @@ -21,7 +29,7 @@ default_request_processor, ) from trpc_agent_sdk.context import InvocationContext, create_agent_context -from trpc_agent_sdk.events import Event, EventActions +from trpc_agent_sdk.events import Event from trpc_agent_sdk.models import LLMModel, LlmRequest, LlmResponse, ModelRegistry from trpc_agent_sdk.sessions import InMemorySessionService from trpc_agent_sdk.types import Content, FunctionCall, FunctionResponse, GenerateContentConfig, Part @@ -256,7 +264,7 @@ async def test_no_skill_repository_returns_none(self, processor, ctx): """Returns None (no error) when agent has no skill_repository.""" ctx.agent.skill_repository = None request = LlmRequest(model="test-rp-ext-model") - result = await processor._add_skills_to_request(ctx.agent, ctx, request) + result = await processor._add_skills_to_request(ctx.agent, ctx, request, parameters={}) assert result is None @pytest.mark.asyncio @@ -268,7 +276,7 @@ async def test_skill_processing_error_returns_event(self, processor, ctx): instance = MockSRP.return_value instance.process_llm_request = AsyncMock(side_effect=RuntimeError("skill boom")) request = LlmRequest(model="test-rp-ext-model") - result = await processor._add_skills_to_request(ctx.agent, ctx, request) + result = await processor._add_skills_to_request(ctx.agent, ctx, request, parameters={}) assert result is not None assert result.error_code == "skill_processing_error" diff --git a/tests/agents/core/test_skill_processor.py b/tests/agents/core/test_skill_processor.py index f850278..7968c40 100644 --- a/tests/agents/core/test_skill_processor.py +++ b/tests/agents/core/test_skill_processor.py @@ -3,658 +3,222 @@ # Copyright (C) 2026 Tencent. All rights reserved. # # tRPC-Agent-Python is licensed under Apache-2.0. 
-"""Unit tests for SkillsRequestProcessor and module-level helpers.""" -from __future__ import annotations - -import asyncio import json -from typing import List -from unittest.mock import AsyncMock, MagicMock, Mock, patch - -import pytest - -from trpc_agent_sdk.agents.core._skill_processor import ( - SKILL_LOAD_MODE_ONCE, - SKILL_LOAD_MODE_SESSION, - SKILL_LOAD_MODE_TURN, - SkillsRequestProcessor, - _default_knowledge_only_guidance, - _default_full_tooling_and_workspace_guidance, - _default_tooling_and_workspace_guidance, - _is_knowledge_only, - _normalize_custom_guidance, - _normalize_load_mode, - _SKILLS_LOADED_ORDER_STATE_KEY, - _SKILLS_OVERVIEW_HEADER, -) -from trpc_agent_sdk.context import InvocationContext, create_agent_context -from trpc_agent_sdk.events import EventActions -from trpc_agent_sdk.models import LLMModel, LlmRequest, LlmResponse, ModelRegistry -from trpc_agent_sdk.sessions import InMemorySessionService -from trpc_agent_sdk.skills import ( - SKILL_DOCS_STATE_KEY_PREFIX, - SKILL_LOADED_STATE_KEY_PREFIX, - SKILL_TOOLS_STATE_KEY_PREFIX, - BaseSkillRepository, - Skill, - SkillResource, - SkillSummary, -) -from trpc_agent_sdk.types import GenerateContentConfig - - -# --------------------------------------------------------------------------- -# Helpers and fixtures -# --------------------------------------------------------------------------- - -class _MockLLMModel(LLMModel): - @classmethod - def supported_models(cls) -> List[str]: - return [r"test-sp-.*"] - - async def _generate_async_impl(self, request, stream=False, ctx=None): - yield LlmResponse(content=None) - - def validate_request(self, request): - pass - - -class _StubRepo(BaseSkillRepository): - """In-memory stub for BaseSkillRepository.""" - - def __init__(self, skills=None, summaries_list=None, user_prompt_text=""): - super().__init__(workspace_runtime=MagicMock()) - self._skills = skills or {} - self._summaries = summaries_list or [] - self._user_prompt_text = user_prompt_text - - def 
summaries(self) -> list[SkillSummary]: - return self._summaries - - def get(self, name: str) -> Skill: - if name not in self._skills: - raise ValueError(f"Skill not found: {name}") - return self._skills[name] - - def user_prompt(self) -> str: - return self._user_prompt_text - - def skill_list(self) -> list[str]: - return list(self._skills.keys()) - - def path(self, name: str) -> str: - if name not in self._skills: - raise ValueError(f"Skill not found: {name}") - return f"/fake/skills/{name}" - - def refresh(self) -> None: - pass - - -@pytest.fixture(scope="module", autouse=True) -def register_test_model(): - original_registry = ModelRegistry._registry.copy() - ModelRegistry.register(_MockLLMModel) - yield - ModelRegistry._registry = original_registry - - -@pytest.fixture -def session_service(): - return InMemorySessionService() - - -@pytest.fixture -def session(session_service): - return asyncio.run( - session_service.create_session(app_name="test", user_id="u1", session_id="sp_sess") - ) - - -@pytest.fixture -def ctx(session_service, session): - from trpc_agent_sdk.agents._llm_agent import LlmAgent - agent = LlmAgent(name="skill_agent", model="test-sp-model") - return InvocationContext( - session_service=session_service, - invocation_id="inv-sp-1", - agent=agent, - agent_context=create_agent_context(), - session=session, - branch="branch_sp", - ) - - -@pytest.fixture -def sample_skill(): +from copy import deepcopy +from unittest.mock import Mock + +import trpc_agent_sdk.skills as _skills_pkg + +if not hasattr(_skills_pkg, "get_skill_processor_parameters"): + + def _compat_get_skill_processor_parameters(agent_context): + from trpc_agent_sdk.agents.core._skill_processor import get_skill_processor_parameters as _impl + return _impl(agent_context) + + _skills_pkg.get_skill_processor_parameters = _compat_get_skill_processor_parameters + +from trpc_agent_sdk.agents.core._skill_processor import SkillsRequestProcessor +from trpc_agent_sdk.agents.core._skill_processor 
import get_skill_processor_parameters +from trpc_agent_sdk.agents.core._skill_processor import set_skill_processor_parameters +from trpc_agent_sdk.context import AgentContext +from trpc_agent_sdk.context import InvocationContext +from trpc_agent_sdk.models import LlmRequest +from trpc_agent_sdk.skills import Skill +from trpc_agent_sdk.skills import SkillResource +from trpc_agent_sdk.skills import SkillSummary +from trpc_agent_sdk.skills import docs_key +from trpc_agent_sdk.skills import docs_session_key +from trpc_agent_sdk.skills import loaded_key +from trpc_agent_sdk.skills import loaded_order_key +from trpc_agent_sdk.skills import loaded_session_key +from trpc_agent_sdk.skills import loaded_session_order_key +from trpc_agent_sdk.skills import tool_key +from trpc_agent_sdk.skills import tool_session_key +from trpc_agent_sdk.skills._skill_config import DEFAULT_SKILL_CONFIG +from trpc_agent_sdk.skills._skill_config import set_skill_config + + +def _build_skill(name: str) -> Skill: return Skill( - summary=SkillSummary(name="code_review", description="Reviews code"), - body="# Code Review\nReview PR diffs.", - resources=[ - SkillResource(path="guide.md", content="Review guide content"), - SkillResource(path="checklist.md", content="Checklist content"), - ], - tools=["lint_check", "format_code"], - ) - - -@pytest.fixture -def sample_repo(sample_skill): - return _StubRepo( - skills={"code_review": sample_skill}, - summaries_list=[SkillSummary(name="code_review", description="Reviews code")], + summary=SkillSummary(name=name, description=f"{name} description"), + body=f"{name} body", + resources=[SkillResource(path="docs/guide.md", content=f"{name} guide")], + tools=["tool-a", "tool-b"], ) -# --------------------------------------------------------------------------- -# Module-level helpers -# --------------------------------------------------------------------------- - - -class TestNormalizeLoadMode: - def test_valid_modes(self): - """Valid mode strings are returned 
as-is (lowered).""" - assert _normalize_load_mode("once") == SKILL_LOAD_MODE_ONCE - assert _normalize_load_mode("TURN") == SKILL_LOAD_MODE_TURN - assert _normalize_load_mode("Session") == SKILL_LOAD_MODE_SESSION - - def test_invalid_mode_defaults_to_turn(self): - """Invalid mode falls back to turn.""" - assert _normalize_load_mode("bogus") == SKILL_LOAD_MODE_TURN - - def test_empty_mode_defaults_to_turn(self): - """Empty/None mode falls back to turn.""" - assert _normalize_load_mode("") == SKILL_LOAD_MODE_TURN - assert _normalize_load_mode(None) == SKILL_LOAD_MODE_TURN - - -class TestIsKnowledgeOnly: - def test_knowledge_only_profiles(self): - """Recognized knowledge-only profile strings return True.""" - assert _is_knowledge_only("knowledge_only") is True - assert _is_knowledge_only("knowledge") is True - assert _is_knowledge_only("Knowledge-Only") is True - - def test_non_knowledge_profiles(self): - """Non-matching profiles return False.""" - assert _is_knowledge_only("full") is False - assert _is_knowledge_only("") is False - assert _is_knowledge_only(None) is False - - -class TestNormalizeCustomGuidance: - def test_empty_returns_empty(self): - """Empty string passes through.""" - assert _normalize_custom_guidance("") == "" - - def test_adds_leading_newline(self): - """Leading newline is added if missing.""" - result = _normalize_custom_guidance("hello") - assert result.startswith("\n") - - def test_adds_trailing_newline(self): - """Trailing newline is added if missing.""" - result = _normalize_custom_guidance("hello") - assert result.endswith("\n") - - def test_existing_newlines_preserved(self): - """Existing leading/trailing newlines are not doubled.""" - result = _normalize_custom_guidance("\nhello\n") - assert result == "\nhello\n" - - -class TestGuidanceTextBuilders: - def test_knowledge_only_guidance_contains_header(self): - """Knowledge-only guidance includes the tooling guidance header.""" - text = _default_knowledge_only_guidance() - assert "Tooling and 
workspace guidance" in text - - def test_full_tooling_guidance_exec_enabled(self): - """Full tooling guidance with exec tools enabled mentions skill_exec.""" - text = _default_full_tooling_and_workspace_guidance(exec_tools_disabled=False) - assert "skill_exec" in text - - def test_full_tooling_guidance_exec_disabled(self): - """Full tooling guidance with exec tools disabled omits skill_exec mention.""" - text = _default_full_tooling_and_workspace_guidance(exec_tools_disabled=True) - assert "interactive execution is available" in text - - def test_default_routing_knowledge_only(self): - """Dispatcher routes to knowledge-only guidance for matching profile.""" - text = _default_tooling_and_workspace_guidance("knowledge_only", False) - assert "progressive disclosure" in text - - def test_default_routing_full(self): - """Dispatcher routes to full guidance for non-matching profile.""" - text = _default_tooling_and_workspace_guidance("", False) - assert "skill_run" in text - - -# --------------------------------------------------------------------------- -# SkillsRequestProcessor.__init__ -# --------------------------------------------------------------------------- - - -class TestSkillsRequestProcessorInit: - def test_defaults(self, sample_repo): - """Default parameters are applied correctly.""" - proc = SkillsRequestProcessor(sample_repo) - assert proc._load_mode == SKILL_LOAD_MODE_TURN - assert proc._tool_result_mode is False - assert proc._max_loaded_skills == 0 - - def test_custom_parameters(self, sample_repo): - """Custom init parameters are stored.""" - proc = SkillsRequestProcessor( - sample_repo, - load_mode="once", - tooling_guidance="custom", - tool_result_mode=True, - tool_profile="knowledge_only", - exec_tools_disabled=True, - max_loaded_skills=5, - ) - assert proc._load_mode == SKILL_LOAD_MODE_ONCE - assert proc._tooling_guidance == "custom" - assert proc._tool_result_mode is True - assert proc._tool_profile == "knowledge_only" - assert 
proc._exec_tools_disabled is True - assert proc._max_loaded_skills == 5 - - -# --------------------------------------------------------------------------- -# _get_repository -# --------------------------------------------------------------------------- - - -class TestGetRepository: - def test_returns_default_repo(self, sample_repo): - """Returns default repository when no resolver is set.""" - proc = SkillsRequestProcessor(sample_repo) - ctx_mock = MagicMock() - assert proc._get_repository(ctx_mock) is sample_repo - - def test_resolver_overrides(self, sample_repo): - """Repo resolver takes precedence over the default repository.""" - other_repo = _StubRepo() - proc = SkillsRequestProcessor(sample_repo, repo_resolver=lambda c: other_repo) - ctx_mock = MagicMock() - assert proc._get_repository(ctx_mock) is other_repo - - -# --------------------------------------------------------------------------- -# _snapshot_state / _read_state -# --------------------------------------------------------------------------- - - -class TestStateHelpers: - def test_snapshot_merges_delta(self, ctx, sample_repo): - """Snapshot merges session state with pending delta.""" - proc = SkillsRequestProcessor(sample_repo) - ctx.session.state["key_a"] = "from_session" - ctx.actions.state_delta["key_b"] = "from_delta" - snap = proc._snapshot_state(ctx) - assert snap["key_a"] == "from_session" - assert snap["key_b"] == "from_delta" - - def test_snapshot_delta_none_removes_key(self, ctx, sample_repo): - """Delta value of None removes the key from snapshot.""" - proc = SkillsRequestProcessor(sample_repo) - ctx.session.state["key_a"] = "exists" - ctx.actions.state_delta["key_a"] = None - snap = proc._snapshot_state(ctx) - assert "key_a" not in snap - - def test_read_state_from_delta(self, ctx, sample_repo): - """read_state prefers delta over session state.""" - proc = SkillsRequestProcessor(sample_repo) - ctx.session.state["k"] = "old" - ctx.actions.state_delta["k"] = "new" - assert 
proc._read_state(ctx, "k") == "new" - - def test_read_state_from_session(self, ctx, sample_repo): - """read_state falls back to session state when delta has no key.""" - proc = SkillsRequestProcessor(sample_repo) - ctx.session.state["k"] = "val" - assert proc._read_state(ctx, "k") == "val" - - def test_read_state_default(self, ctx, sample_repo): - """read_state returns default when key not in delta or session.""" - proc = SkillsRequestProcessor(sample_repo) - assert proc._read_state(ctx, "missing", "fallback") == "fallback" - - -# --------------------------------------------------------------------------- -# _get_loaded_skills -# --------------------------------------------------------------------------- - - -class TestGetLoadedSkills: - def test_no_loaded_skills(self, ctx, sample_repo): - """Returns empty list when no skills are loaded.""" - proc = SkillsRequestProcessor(sample_repo) - assert proc._get_loaded_skills(ctx) == [] - - def test_loaded_skills_detected(self, ctx, sample_repo): - """Skills with state key prefix are detected.""" - proc = SkillsRequestProcessor(sample_repo) - ctx.session.state[SKILL_LOADED_STATE_KEY_PREFIX + "code_review"] = "1" - result = proc._get_loaded_skills(ctx) - assert "code_review" in result - - def test_falsy_value_ignored(self, ctx, sample_repo): - """Skills with falsy state values are not returned.""" - proc = SkillsRequestProcessor(sample_repo) - ctx.session.state[SKILL_LOADED_STATE_KEY_PREFIX + "empty_skill"] = "" - assert proc._get_loaded_skills(ctx) == [] - - -# --------------------------------------------------------------------------- -# _inject_overview -# --------------------------------------------------------------------------- - - -class TestInjectOverview: - def test_overview_injected(self, sample_repo): - """Overview is injected into system instruction.""" - proc = SkillsRequestProcessor(sample_repo) - request = LlmRequest(model="test-sp-model") - proc._inject_overview(request, sample_repo) - sys_instr = 
str(request.config.system_instruction) - assert "code_review" in sys_instr - assert "Reviews code" in sys_instr - - def test_no_summaries_no_injection(self): - """No injection when repo has no summaries.""" - repo = _StubRepo(summaries_list=[]) - proc = SkillsRequestProcessor(repo) - request = LlmRequest(model="test-sp-model") - proc._inject_overview(request, repo) - assert request.config is None or request.config.system_instruction is None - - def test_double_injection_guard(self, sample_repo): - """Overview is not injected twice.""" - proc = SkillsRequestProcessor(sample_repo) - request = LlmRequest(model="test-sp-model") - proc._inject_overview(request, sample_repo) - first_instr = str(request.config.system_instruction) - proc._inject_overview(request, sample_repo) - second_instr = str(request.config.system_instruction) - assert first_instr == second_instr - - def test_user_prompt_prepended(self): - """Repository user_prompt is prepended to overview.""" - repo = _StubRepo( - summaries_list=[SkillSummary(name="s1", description="desc")], - user_prompt_text="Custom prompt", - ) - proc = SkillsRequestProcessor(repo) - request = LlmRequest(model="test-sp-model") - proc._inject_overview(request, repo) - sys_instr = str(request.config.system_instruction) - assert sys_instr.index("Custom prompt") < sys_instr.index("s1") - - -# --------------------------------------------------------------------------- -# _maybe_clear_skill_state_for_turn -# --------------------------------------------------------------------------- - - -class TestMaybeClearSkillStateForTurn: - def test_clears_on_first_invocation(self, ctx, sample_repo): - """State is cleared on first invocation in turn mode.""" - proc = SkillsRequestProcessor(sample_repo, load_mode="turn") - ctx.session.state[SKILL_LOADED_STATE_KEY_PREFIX + "sk1"] = "1" - proc._maybe_clear_skill_state_for_turn(ctx) - assert (SKILL_LOADED_STATE_KEY_PREFIX + "sk1") in ctx.actions.state_delta - assert 
ctx.actions.state_delta[SKILL_LOADED_STATE_KEY_PREFIX + "sk1"] is None - - def test_no_clear_on_second_call_same_invocation(self, ctx, sample_repo): - """State is NOT cleared on second call within same invocation.""" - proc = SkillsRequestProcessor(sample_repo, load_mode="turn") - proc._maybe_clear_skill_state_for_turn(ctx) - ctx.actions.state_delta.clear() - ctx.session.state[SKILL_LOADED_STATE_KEY_PREFIX + "sk2"] = "1" - proc._maybe_clear_skill_state_for_turn(ctx) - assert (SKILL_LOADED_STATE_KEY_PREFIX + "sk2") not in ctx.actions.state_delta - - def test_no_clear_in_session_mode(self, ctx, sample_repo): - """No clearing in session mode.""" - proc = SkillsRequestProcessor(sample_repo, load_mode="session") - ctx.session.state[SKILL_LOADED_STATE_KEY_PREFIX + "sk1"] = "1" - proc._maybe_clear_skill_state_for_turn(ctx) - assert (SKILL_LOADED_STATE_KEY_PREFIX + "sk1") not in ctx.actions.state_delta - - -# --------------------------------------------------------------------------- -# _maybe_offload_loaded_skills -# --------------------------------------------------------------------------- - - -class TestMaybeOffloadLoadedSkills: - def test_offloads_in_once_mode(self, ctx, sample_repo): - """Skill state is cleared after injection in once mode.""" - proc = SkillsRequestProcessor(sample_repo, load_mode="once") - proc._maybe_offload_loaded_skills(ctx, ["code_review"]) - assert ctx.actions.state_delta[SKILL_LOADED_STATE_KEY_PREFIX + "code_review"] is None - assert ctx.actions.state_delta[SKILL_DOCS_STATE_KEY_PREFIX + "code_review"] is None - assert ctx.actions.state_delta[SKILL_TOOLS_STATE_KEY_PREFIX + "code_review"] is None - assert ctx.actions.state_delta[_SKILLS_LOADED_ORDER_STATE_KEY] is None - - def test_no_offload_in_turn_mode(self, ctx, sample_repo): - """No offloading in turn mode.""" - proc = SkillsRequestProcessor(sample_repo, load_mode="turn") - proc._maybe_offload_loaded_skills(ctx, ["code_review"]) - assert SKILL_LOADED_STATE_KEY_PREFIX + "code_review" not in 
ctx.actions.state_delta - - def test_no_offload_when_empty(self, ctx, sample_repo): - """No offloading when loaded list is empty.""" - proc = SkillsRequestProcessor(sample_repo, load_mode="once") - proc._maybe_offload_loaded_skills(ctx, []) - assert len(ctx.actions.state_delta) == 0 - - -# --------------------------------------------------------------------------- -# _maybe_cap_loaded_skills -# --------------------------------------------------------------------------- - - -class TestMaybeCapLoadedSkills: - def test_no_cap_returns_all(self, ctx, sample_repo): - """All skills returned when cap is 0 (disabled).""" - proc = SkillsRequestProcessor(sample_repo, max_loaded_skills=0) - result = proc._maybe_cap_loaded_skills(ctx, ["a", "b", "c"]) - assert result == ["a", "b", "c"] - - def test_under_cap_returns_all(self, ctx, sample_repo): - """All skills returned when count is at or under cap.""" - proc = SkillsRequestProcessor(sample_repo, max_loaded_skills=5) - result = proc._maybe_cap_loaded_skills(ctx, ["a", "b"]) - assert result == ["a", "b"] - - def test_over_cap_evicts_oldest(self, ctx, sample_repo): - """Excess skills are evicted, keeping most recent.""" - proc = SkillsRequestProcessor(sample_repo, max_loaded_skills=2) - ctx.actions.state_delta[_SKILLS_LOADED_ORDER_STATE_KEY] = json.dumps(["a", "b", "c"]) - result = proc._maybe_cap_loaded_skills(ctx, ["a", "b", "c"]) - assert len(result) == 2 - assert "a" not in result - assert "b" in result - assert "c" in result - - -# --------------------------------------------------------------------------- -# _build_docs_text -# --------------------------------------------------------------------------- - - -class TestBuildDocsText: - def test_selected_docs_included(self, sample_repo, sample_skill): - """Only selected docs are included in output.""" - proc = SkillsRequestProcessor(sample_repo) - text = proc._build_docs_text(sample_skill, ["guide.md"]) - assert "Review guide content" in text - assert "Checklist content" not 
in text - - def test_no_docs_returns_empty(self, sample_repo): - """Empty string for skill with no resources.""" - proc = SkillsRequestProcessor(sample_repo) - empty_skill = Skill(summary=SkillSummary(name="empty", description="")) - assert proc._build_docs_text(empty_skill, ["any.md"]) == "" - - def test_none_skill_returns_empty(self, sample_repo): - """None skill returns empty string.""" - proc = SkillsRequestProcessor(sample_repo) - assert proc._build_docs_text(None, ["any.md"]) == "" - - -# --------------------------------------------------------------------------- -# _merge_into_system -# --------------------------------------------------------------------------- - - -class TestMergeIntoSystem: - def test_appends_to_system(self, sample_repo): - """Content is appended to system instruction.""" - proc = SkillsRequestProcessor(sample_repo) - request = LlmRequest(model="test-sp-model") - proc._merge_into_system(request, "extra guidance") - assert "extra guidance" in str(request.config.system_instruction) - - def test_empty_content_no_op(self, sample_repo): - """Empty string is a no-op.""" - proc = SkillsRequestProcessor(sample_repo) - request = LlmRequest(model="test-sp-model") - proc._merge_into_system(request, "") - assert request.config is None or request.config.system_instruction is None - - -# --------------------------------------------------------------------------- -# _capability_guidance_text -# --------------------------------------------------------------------------- - - -class TestCapabilityGuidanceText: - def test_non_knowledge_profile_empty(self, sample_repo): - """Non-knowledge profile returns empty string.""" - proc = SkillsRequestProcessor(sample_repo, tool_profile="full") - assert proc._capability_guidance_text() == "" - - def test_knowledge_profile_returns_guidance(self, sample_repo): - """Knowledge-only profile returns capability guidance.""" - proc = SkillsRequestProcessor(sample_repo, tool_profile="knowledge_only") - text = 
proc._capability_guidance_text() - assert "knowledge loading only" in text - - def test_knowledge_profile_with_empty_guidance_suppressed(self, sample_repo): - """Knowledge profile with explicit empty guidance suppresses capability block.""" - proc = SkillsRequestProcessor(sample_repo, tool_profile="knowledge_only", tooling_guidance="") - assert proc._capability_guidance_text() == "" - - -# --------------------------------------------------------------------------- -# process_llm_request (integration) -# --------------------------------------------------------------------------- - - -class TestProcessLlmRequest: - @pytest.mark.asyncio - async def test_none_request_returns_empty(self, ctx, sample_repo): - """Returns empty list for None request.""" - proc = SkillsRequestProcessor(sample_repo) - result = await proc.process_llm_request(ctx, None) - assert result == [] - - @pytest.mark.asyncio - async def test_none_ctx_returns_empty(self, sample_repo): - """Returns empty list for None ctx.""" - proc = SkillsRequestProcessor(sample_repo) - request = LlmRequest(model="test-sp-model") - result = await proc.process_llm_request(None, request) - assert result == [] - - @pytest.mark.asyncio - async def test_overview_injected_no_loaded(self, ctx, sample_repo): - """Overview is injected even when no skills are loaded.""" - proc = SkillsRequestProcessor(sample_repo) - request = LlmRequest(model="test-sp-model") - result = await proc.process_llm_request(ctx, request) - assert result == [] - assert "code_review" in str(request.config.system_instruction) - - @pytest.mark.asyncio - async def test_loaded_skill_body_injected(self, ctx, sample_repo): - """Loaded skill body is injected into system instruction.""" - proc = SkillsRequestProcessor(sample_repo, load_mode="session") - ctx.session.state[SKILL_LOADED_STATE_KEY_PREFIX + "code_review"] = "1" - request = LlmRequest(model="test-sp-model") - result = await proc.process_llm_request(ctx, request) - assert "code_review" in result - 
sys_instr = str(request.config.system_instruction) - assert "Review PR diffs" in sys_instr - - @pytest.mark.asyncio - async def test_tool_result_mode_skips_body_injection(self, ctx, sample_repo): - """In tool_result_mode, loaded skill bodies are NOT injected.""" - proc = SkillsRequestProcessor(sample_repo, tool_result_mode=True, load_mode="session") - ctx.session.state[SKILL_LOADED_STATE_KEY_PREFIX + "code_review"] = "1" - request = LlmRequest(model="test-sp-model") - result = await proc.process_llm_request(ctx, request) - assert "code_review" in result - sys_instr = str(request.config.system_instruction) if request.config and request.config.system_instruction else "" - assert "Review PR diffs" not in sys_instr - - @pytest.mark.asyncio - async def test_once_mode_offloads_after_injection(self, ctx, sample_repo): - """In once mode, skill state is cleared after injection.""" - proc = SkillsRequestProcessor(sample_repo, load_mode="once") - ctx.session.state[SKILL_LOADED_STATE_KEY_PREFIX + "code_review"] = "1" - request = LlmRequest(model="test-sp-model") - await proc.process_llm_request(ctx, request) - assert ctx.actions.state_delta.get(SKILL_LOADED_STATE_KEY_PREFIX + "code_review") is None - - @pytest.mark.asyncio - async def test_none_repo_returns_empty(self, ctx): - """Returns empty list when repository resolves to None.""" - proc = SkillsRequestProcessor( - MagicMock(), - repo_resolver=lambda c: None, - ) - request = LlmRequest(model="test-sp-model") - result = await proc.process_llm_request(ctx, request) - assert result == [] - - -# --------------------------------------------------------------------------- -# _get_loaded_skill_order -# --------------------------------------------------------------------------- - - -class TestGetLoadedSkillOrder: - def test_no_persisted_order(self, ctx, sample_repo): - """Missing order key returns alphabetical order.""" - proc = SkillsRequestProcessor(sample_repo) - order = proc._get_loaded_skill_order(ctx, ["beta", "alpha"]) - 
assert order == ["alpha", "beta"] - - def test_persisted_order_respected(self, ctx, sample_repo): - """Persisted order is respected for known skills.""" - proc = SkillsRequestProcessor(sample_repo) - ctx.session.state[_SKILLS_LOADED_ORDER_STATE_KEY] = json.dumps(["beta", "alpha"]) - order = proc._get_loaded_skill_order(ctx, ["alpha", "beta"]) - assert order == ["beta", "alpha"] - - def test_new_skills_appended_alphabetically(self, ctx, sample_repo): - """Skills not in persisted order are appended alphabetically.""" - proc = SkillsRequestProcessor(sample_repo) - ctx.session.state[_SKILLS_LOADED_ORDER_STATE_KEY] = json.dumps(["beta"]) - order = proc._get_loaded_skill_order(ctx, ["alpha", "beta", "gamma"]) - assert order == ["beta", "alpha", "gamma"] - - def test_invalid_json_falls_back(self, ctx, sample_repo): - """Invalid JSON in order key falls back to alphabetical.""" - proc = SkillsRequestProcessor(sample_repo) - ctx.session.state[_SKILLS_LOADED_ORDER_STATE_KEY] = "not-json" - order = proc._get_loaded_skill_order(ctx, ["c", "a", "b"]) - assert order == ["a", "b", "c"] +def _build_repo(skills: dict[str, Skill]) -> Mock: + repo = Mock() + repo.summaries.return_value = [skill.summary for skill in skills.values()] + repo.user_prompt.return_value = "repo user prompt" + repo.get.side_effect = lambda name: skills.get(name) + return repo + + +def _build_ctx(state: dict, *, agent_name: str = "demo-agent", load_mode: str = "turn") -> Mock: + ctx = Mock(spec=InvocationContext) + ctx.agent_name = agent_name + ctx.session_state = state + ctx.actions = Mock() + ctx.actions.state_delta = {} + ctx.agent_context = AgentContext() + config = deepcopy(DEFAULT_SKILL_CONFIG) + config["skill_processor"]["load_mode"] = load_mode + set_skill_config(ctx.agent_context, config) + ctx.session = Mock() + ctx.session.state = state + ctx.session.events = [] + return ctx + + +class TestSkillsRequestProcessor: + + async def test_injects_overview_and_loaded_content(self): + skill = 
_build_skill("demo-skill") + repo = _build_repo({"demo-skill": skill}) + loaded_state_key = loaded_session_key(loaded_key("demo-agent", "demo-skill")) + docs_state_key = docs_session_key(docs_key("demo-agent", "demo-skill")) + tools_state_key = tool_session_key(tool_key("demo-agent", "demo-skill")) + ctx = _build_ctx({ + loaded_state_key: True, + docs_state_key: json.dumps(["docs/guide.md"]), + tools_state_key: json.dumps(["tool-a"]), + }, load_mode="session") + request = LlmRequest(contents=[], tools_dict={}) + processor = SkillsRequestProcessor(repo, load_mode="session") + + loaded = await processor.process_llm_request(ctx, request) + + assert loaded == ["demo-skill"] + assert request.config is not None + text = request.config.system_instruction + assert "repo user prompt" in text + assert "Available skills:" in text + assert "- demo-skill: demo-skill description" in text + assert "[Loaded] demo-skill" in text + assert "Docs loaded: docs/guide.md" in text + assert "[Doc] docs/guide.md" in text + assert "Tools selected: tool-a" in text + + async def test_tool_result_mode_skips_loaded_materialization(self): + skill = _build_skill("demo-skill") + repo = _build_repo({"demo-skill": skill}) + loaded_state_key = loaded_session_key(loaded_key("demo-agent", "demo-skill")) + ctx = _build_ctx({loaded_state_key: True}, load_mode="session") + request = LlmRequest(contents=[], tools_dict={}) + processor = SkillsRequestProcessor(repo, tool_result_mode=True, load_mode="session") + + loaded = await processor.process_llm_request(ctx, request) + + assert loaded == ["demo-skill"] + assert request.config is not None + text = request.config.system_instruction + assert "Available skills:" in text + assert "[Loaded] demo-skill" not in text + + async def test_once_mode_offloads_loaded_state(self): + skill = _build_skill("demo-skill") + repo = _build_repo({"demo-skill": skill}) + ctx = _build_ctx({ + loaded_key("demo-agent", "demo-skill"): True, + docs_key("demo-agent", "demo-skill"): 
json.dumps(["docs/guide.md"]), + tool_key("demo-agent", "demo-skill"): json.dumps(["tool-a"]), + }) + request = LlmRequest(contents=[], tools_dict={}) + processor = SkillsRequestProcessor(repo, load_mode="once") + + await processor.process_llm_request(ctx, request) + + assert ctx.actions.state_delta[loaded_key("demo-agent", "demo-skill")] is None + assert ctx.actions.state_delta[docs_key("demo-agent", "demo-skill")] is None + assert ctx.actions.state_delta[tool_key("demo-agent", "demo-skill")] is None + assert ctx.actions.state_delta[loaded_order_key("demo-agent")] is None + + async def test_turn_mode_clears_previous_skill_state(self): + skill = _build_skill("demo-skill") + repo = _build_repo({"demo-skill": skill}) + ctx = _build_ctx({ + loaded_key("demo-agent", "demo-skill"): True, + docs_key("demo-agent", "demo-skill"): json.dumps(["docs/guide.md"]), + tool_key("demo-agent", "demo-skill"): json.dumps(["tool-a"]), + loaded_order_key("demo-agent"): json.dumps(["demo-skill"]), + }) + request = LlmRequest(contents=[], tools_dict={}) + processor = SkillsRequestProcessor(repo, load_mode="turn") + + loaded = await processor.process_llm_request(ctx, request) + + assert loaded == [] + assert ctx.actions.state_delta[loaded_key("demo-agent", "demo-skill")] is None + assert ctx.actions.state_delta[docs_key("demo-agent", "demo-skill")] is None + assert ctx.actions.state_delta[tool_key("demo-agent", "demo-skill")] is None + assert ctx.actions.state_delta[loaded_order_key("demo-agent")] is None + assert ctx.agent_context.get_metadata("processor:skills:turn_init") is True + + async def test_session_mode_uses_temp_only_state(self): + skill = _build_skill("demo-skill") + repo = _build_repo({"demo-skill": skill}) + loaded_state_key = loaded_session_key(loaded_key("demo-agent", "demo-skill")) + docs_state_key = docs_session_key(docs_key("demo-agent", "demo-skill")) + tools_state_key = tool_session_key(tool_key("demo-agent", "demo-skill")) + session_order_key = 
loaded_session_order_key(loaded_order_key("demo-agent")) + ctx = _build_ctx({ + loaded_state_key: True, + docs_state_key: json.dumps(["docs/guide.md"]), + tools_state_key: json.dumps(["tool-a"]), + session_order_key: json.dumps(["demo-skill"]), + }, load_mode="session") + request = LlmRequest(contents=[], tools_dict={}) + processor = SkillsRequestProcessor(repo, load_mode="session") + + loaded = await processor.process_llm_request(ctx, request) + + assert loaded == ["demo-skill"] + assert loaded_key("demo-agent", "demo-skill") not in ctx.actions.state_delta + assert docs_key("demo-agent", "demo-skill") not in ctx.actions.state_delta + assert tool_key("demo-agent", "demo-skill") not in ctx.actions.state_delta + + async def test_max_loaded_skills_evicts_lru_skills_in_session_state(self): + skill_a = _build_skill("skill-a") + skill_b = _build_skill("skill-b") + repo = _build_repo({"skill-a": skill_a, "skill-b": skill_b}) + loaded_a = loaded_session_key(loaded_key("demo-agent", "skill-a")) + loaded_b = loaded_session_key(loaded_key("demo-agent", "skill-b")) + docs_a = docs_session_key(docs_key("demo-agent", "skill-a")) + docs_b = docs_session_key(docs_key("demo-agent", "skill-b")) + tools_a = tool_session_key(tool_key("demo-agent", "skill-a")) + tools_b = tool_session_key(tool_key("demo-agent", "skill-b")) + order_key = loaded_session_order_key(loaded_order_key("demo-agent")) + ctx = _build_ctx({ + loaded_a: True, + loaded_b: True, + docs_a: json.dumps(["docs/guide.md"]), + docs_b: json.dumps(["docs/guide.md"]), + tools_a: json.dumps(["tool-a"]), + tools_b: json.dumps(["tool-a"]), + order_key: json.dumps(["skill-a", "skill-b"]), + }, load_mode="session") + request = LlmRequest(contents=[], tools_dict={}) + processor = SkillsRequestProcessor(repo, max_loaded_skills=1, load_mode="session") + + loaded = await processor.process_llm_request(ctx, request) + + assert loaded == ["skill-b"] + assert ctx.actions.state_delta[loaded_a] is None + assert 
ctx.actions.state_delta[docs_a] is None + assert ctx.actions.state_delta[tools_a] is None + assert ctx.actions.state_delta[order_key] == json.dumps(["skill-b"]) + + +class TestSkillProcessorParameters: + + def test_set_and_get_skill_processor_parameters(self): + agent_context = AgentContext() + parameters = {"load_mode": "session", "max_loaded_skills": 3} + + set_skill_processor_parameters(agent_context, parameters) + got = get_skill_processor_parameters(agent_context) + + assert got["load_mode"] == "session" + assert got["max_loaded_skills"] == 3 diff --git a/tests/agents/core/test_skill_tool_result_processor.py b/tests/agents/core/test_skill_tool_result_processor.py new file mode 100644 index 0000000..b635677 --- /dev/null +++ b/tests/agents/core/test_skill_tool_result_processor.py @@ -0,0 +1,136 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+ +import json +from copy import deepcopy +from unittest.mock import Mock + +import trpc_agent_sdk.skills as _skills_pkg + +if not hasattr(_skills_pkg, "get_skill_processor_parameters"): + + def _compat_get_skill_processor_parameters(agent_context): + from trpc_agent_sdk.agents.core._skill_processor import get_skill_processor_parameters as _impl + return _impl(agent_context) + + _skills_pkg.get_skill_processor_parameters = _compat_get_skill_processor_parameters + +from trpc_agent_sdk.agents.core._skills_tool_result_processor import SkillsToolResultRequestProcessor +from trpc_agent_sdk.context import AgentContext +from trpc_agent_sdk.context import InvocationContext +from trpc_agent_sdk.models import LlmRequest +from trpc_agent_sdk.skills import Skill +from trpc_agent_sdk.skills import SkillResource +from trpc_agent_sdk.skills import SkillSummary +from trpc_agent_sdk.skills import SkillToolsNames +from trpc_agent_sdk.skills import docs_key +from trpc_agent_sdk.skills import loaded_key +from trpc_agent_sdk.skills import loaded_order_key +from trpc_agent_sdk.skills._skill_config import DEFAULT_SKILL_CONFIG +from trpc_agent_sdk.skills._skill_config import set_skill_config +from trpc_agent_sdk.types import Content +from trpc_agent_sdk.types import Part + + +def _build_context(agent_name: str, state: dict, *, load_mode: str = "turn") -> Mock: + ctx = Mock(spec=InvocationContext) + ctx.agent_name = agent_name + ctx.session_state = state + ctx.actions = Mock() + ctx.actions.state_delta = {} + ctx.agent_context = AgentContext() + config = deepcopy(DEFAULT_SKILL_CONFIG) + config["skill_processor"]["load_mode"] = load_mode + set_skill_config(ctx.agent_context, config) + return ctx + + +def _build_skill() -> Skill: + return Skill( + summary=SkillSummary(name="demo-skill", description="demo"), + body="Use this skill body.", + resources=[SkillResource(path="docs/guide.md", content="Guide content.")], + ) + + +class TestSkillsToolResultRequestProcessor: + + async def 
test_materialize_tool_result_from_skill_load(self): + repo = Mock() + repo.get.return_value = _build_skill() + + ctx = _build_context( + "demo-agent", + { + loaded_key("demo-agent", "demo-skill"): True, + docs_key("demo-agent", "demo-skill"): json.dumps(["docs/guide.md"]), + }, + ) + processor = SkillsToolResultRequestProcessor(repo) + + call_part = Part.from_function_call(name=SkillToolsNames.LOAD, args={"skill": "demo-skill"}) + call_part.function_call.id = "call_1" + response_part = Part.from_function_response(name=SkillToolsNames.LOAD, response={"result": "skill 'demo-skill' loaded"}) + response_part.function_response.id = "call_1" + request = LlmRequest( + contents=[ + Content(role="model", parts=[call_part]), + Content(role="user", parts=[response_part]), + ], + tools_dict={}, + ) + + loaded = await processor.process_llm_request(ctx, request) + + assert loaded == ["demo-skill"] + result = response_part.function_response.response["result"] + assert "[Loaded] demo-skill" in result + assert "Docs loaded: docs/guide.md" in result + assert "[Doc] docs/guide.md" in result + + async def test_fallback_to_system_instruction_when_tool_result_missing(self): + repo = Mock() + repo.get.return_value = _build_skill() + ctx = _build_context( + "demo-agent", + {loaded_key("demo-agent", "demo-skill"): True}, + ) + processor = SkillsToolResultRequestProcessor(repo) + request = LlmRequest(contents=[Content(role="user", parts=[Part.from_text(text="hello")])], tools_dict={}) + + await processor.process_llm_request(ctx, request) + + assert request.config is not None + assert "Loaded skill context:" in request.config.system_instruction + assert "[Loaded] demo-skill" in request.config.system_instruction + + async def test_once_mode_offloads_loaded_skill_state(self): + repo = Mock() + repo.get.return_value = _build_skill() + state = { + loaded_key("demo-agent", "demo-skill"): True, + docs_key("demo-agent", "demo-skill"): json.dumps(["docs/guide.md"]), + } + ctx = 
_build_context("demo-agent", state, load_mode="once") + processor = SkillsToolResultRequestProcessor(repo) + + call_part = Part.from_function_call(name=SkillToolsNames.LOAD, args={"skill": "demo-skill"}) + call_part.function_call.id = "call_1" + response_part = Part.from_function_response(name=SkillToolsNames.LOAD, response={"result": "skill 'demo-skill' loaded"}) + response_part.function_response.id = "call_1" + request = LlmRequest( + contents=[ + Content(role="model", parts=[call_part]), + Content(role="user", parts=[response_part]), + ], + tools_dict={}, + ) + + await processor.process_llm_request(ctx, request) + + assert ctx.actions.state_delta[loaded_key("demo-agent", "demo-skill")] is None + assert ctx.actions.state_delta[docs_key("demo-agent", "demo-skill")] is None + assert ctx.actions.state_delta[loaded_order_key("demo-agent")] is None diff --git a/tests/agents/core/test_tools_processor.py b/tests/agents/core/test_tools_processor.py index a7c49d5..5ebe512 100644 --- a/tests/agents/core/test_tools_processor.py +++ b/tests/agents/core/test_tools_processor.py @@ -12,6 +12,15 @@ from unittest.mock import AsyncMock, Mock, patch import pytest +import trpc_agent_sdk.skills as _skills_pkg + +if not hasattr(_skills_pkg, "get_skill_processor_parameters"): + + def _compat_get_skill_processor_parameters(agent_context): + from trpc_agent_sdk.agents.core._skill_processor import get_skill_processor_parameters as _impl + return _impl(agent_context) + + _skills_pkg.get_skill_processor_parameters = _compat_get_skill_processor_parameters from trpc_agent_sdk.agents._base_agent import BaseAgent from trpc_agent_sdk.agents.core._tools_processor import ToolsProcessor diff --git a/tests/agents/core/test_workspace_exec_processor.py b/tests/agents/core/test_workspace_exec_processor.py new file mode 100644 index 0000000..4fa100e --- /dev/null +++ b/tests/agents/core/test_workspace_exec_processor.py @@ -0,0 +1,127 @@ +# Tencent is pleased to support the open source community by 
making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. + +from unittest.mock import Mock + +import trpc_agent_sdk.skills as _skills_pkg + +if not hasattr(_skills_pkg, "get_skill_processor_parameters"): + + def _compat_get_skill_processor_parameters(agent_context): + from trpc_agent_sdk.agents.core._skill_processor import get_skill_processor_parameters as _impl + return _impl(agent_context) + + _skills_pkg.get_skill_processor_parameters = _compat_get_skill_processor_parameters + +from trpc_agent_sdk.agents.core._workspace_exec_processor import WorkspaceExecRequestProcessor +from trpc_agent_sdk.agents.core._workspace_exec_processor import get_workspace_exec_processor_parameters +from trpc_agent_sdk.agents.core._workspace_exec_processor import set_workspace_exec_processor_parameters +from trpc_agent_sdk.context import AgentContext +from trpc_agent_sdk.context import InvocationContext +from trpc_agent_sdk.models import LlmRequest + + +def _build_ctx() -> Mock: + ctx = Mock(spec=InvocationContext) + ctx.agent_name = "demo-agent" + return ctx + + +def _build_request_with_tools(*tool_names: str) -> LlmRequest: + request = LlmRequest(contents=[], tools_dict={}) + config = Mock() + config.system_instruction = "" + config.tools = [] + for name in tool_names: + declaration = Mock() + declaration.name = name + tool = Mock() + tool.function_declarations = [declaration] + config.tools.append(tool) + request.config = config + return request + + +class TestWorkspaceExecRequestProcessor: + + async def test_injects_guidance_when_workspace_exec_present(self): + processor = WorkspaceExecRequestProcessor() + ctx = _build_ctx() + request = _build_request_with_tools("workspace_exec") + + await processor.process_llm_request(ctx, request) + + assert "Executor workspace guidance:" in request.config.system_instruction + assert "workspace_exec" in request.config.system_instruction + + async def 
test_does_not_duplicate_guidance_header(self): + processor = WorkspaceExecRequestProcessor() + ctx = _build_ctx() + request = _build_request_with_tools("workspace_exec") + request.config.system_instruction = "Executor workspace guidance:\nexisting" + + await processor.process_llm_request(ctx, request) + + assert request.config.system_instruction == "Executor workspace guidance:\nexisting" + + async def test_includes_artifact_and_session_hints_when_tools_available(self): + processor = WorkspaceExecRequestProcessor() + ctx = _build_ctx() + request = _build_request_with_tools( + "workspace_exec", + "workspace_save_artifact", + "workspace_write_stdin", + "workspace_kill_session", + ) + + await processor.process_llm_request(ctx, request) + + text = request.config.system_instruction + assert "workspace_save_artifact" in text + assert "workspace_write_stdin" in text + assert "workspace_kill_session" in text + + async def test_includes_skills_repo_hint_via_repo_resolver(self): + processor = WorkspaceExecRequestProcessor(repo_resolver=lambda _ctx: object()) + ctx = _build_ctx() + request = _build_request_with_tools("workspace_exec") + + await processor.process_llm_request(ctx, request) + + assert "Paths under skills/" in request.config.system_instruction + + async def test_respects_enabled_resolver(self): + processor = WorkspaceExecRequestProcessor(enabled_resolver=lambda _ctx: False) + ctx = _build_ctx() + request = _build_request_with_tools("workspace_exec") + + await processor.process_llm_request(ctx, request) + + assert request.config.system_instruction == "" + + async def test_sessions_resolver_can_enable_session_guidance_without_session_tools(self): + processor = WorkspaceExecRequestProcessor(sessions_resolver=lambda _ctx: True) + ctx = _build_ctx() + request = _build_request_with_tools("workspace_exec") + + await processor.process_llm_request(ctx, request) + + text = request.config.system_instruction + assert "workspace_write_stdin" in text + assert 
"workspace_kill_session" in text + + +class TestWorkspaceExecProcessorParameters: + + def test_set_and_get_workspace_exec_processor_parameters(self): + agent_context = AgentContext() + parameters = {"session_tools": True, "has_skills_repo": True} + + set_workspace_exec_processor_parameters(agent_context, parameters) + got = get_workspace_exec_processor_parameters(agent_context) + + assert got["session_tools"] is True + assert got["has_skills_repo"] is True diff --git a/tests/skills/__init__.py b/tests/skills/__init__.py index e69de29..8b13789 100644 --- a/tests/skills/__init__.py +++ b/tests/skills/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/skills/stager/__init__.py b/tests/skills/stager/__init__.py index e69de29..8b13789 100644 --- a/tests/skills/stager/__init__.py +++ b/tests/skills/stager/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/skills/stager/test_base_stager.py b/tests/skills/stager/test_base_stager.py index c5a8a00..aecc774 100644 --- a/tests/skills/stager/test_base_stager.py +++ b/tests/skills/stager/test_base_stager.py @@ -67,14 +67,13 @@ def _make_repository(path="/skills/test-skill"): return repo -def _make_request(skill_name="test-skill", runtime=None, repo=None, ws=None, ctx=None): +def _make_request(skill_name="test-skill", repo=None, ws=None, ctx=None): from trpc_agent_sdk.skills.stager._types import SkillStageRequest return SkillStageRequest( skill_name=skill_name, repository=repo or _make_repository(), workspace=ws or _make_workspace(), ctx=ctx or _make_ctx(), - engine=runtime, ) @@ -262,7 +261,7 @@ class TestStageSkill: async def test_fresh_staging(self, mock_digest): stager = Stager() request = _make_request() - runtime = request.engine or request.repository.workspace_runtime + runtime = request.repository.workspace_runtime mock_file = MagicMock() mock_file.content = json.dumps({"version": 1, "skills": {}}) @@ -275,7 +274,7 @@ async def test_fresh_staging(self, mock_digest): async def test_cached_staging_with_links(self, mock_digest): 
stager = Stager() request = _make_request() - runtime = request.engine or request.repository.workspace_runtime + runtime = request.repository.workspace_runtime md_data = { "version": 1, diff --git a/tests/skills/stager/test_types.py b/tests/skills/stager/test_types.py index efef3ee..b269c9c 100644 --- a/tests/skills/stager/test_types.py +++ b/tests/skills/stager/test_types.py @@ -26,7 +26,6 @@ def test_creation(self): ctx=MagicMock(), ) assert req.skill_name == "test" - assert req.engine is None assert req.timeout == 300.0 def test_custom_timeout(self): diff --git a/tests/skills/test_common.py b/tests/skills/test_common.py index a8f7d85..221c7f6 100644 --- a/tests/skills/test_common.py +++ b/tests/skills/test_common.py @@ -3,308 +3,530 @@ # Copyright (C) 2026 Tencent. All rights reserved. # # tRPC-Agent-Python is licensed under Apache-2.0. -"""Unit tests for trpc_agent_sdk.skills._common. - -Covers: -- SelectionMode enum -- BaseSelectionResult model -- get_state_delta_value -- get_previous_selection -- clear_selection, add_selection, replace_selection -- set_state_delta_for_selection -- generic_select_items (all modes, edge cases) -- generic_get_selection (all branches) -""" - -from __future__ import annotations - -import json -from unittest.mock import MagicMock, Mock import pytest -from pydantic import BaseModel, Field +import json +from unittest.mock import Mock +from pydantic import Field from trpc_agent_sdk.skills._common import ( - BaseSelectionResult, + get_state_delta_value, SelectionMode, - add_selection, - clear_selection, - generic_get_selection, - generic_select_items, + BaseSelectionResult, + append_loaded_order_state_delta, get_previous_selection, - get_state_delta_value, + clear_selection, + add_selection, replace_selection, set_state_delta_for_selection, + generic_select_items, + generic_get_selection, ) +from trpc_agent_sdk.context import InvocationContext +from trpc_agent_sdk.skills._state_keys import loaded_order_key -class 
_TestSelectionResult(BaseSelectionResult): - """Concrete test subclass for BaseSelectionResult.""" +class MockSelectionResult(BaseSelectionResult): + """Mock selection result class for testing.""" selected_items: list[str] = Field(default_factory=list) include_all: bool = Field(default=False) -def _make_ctx(state_delta=None, session_state=None): - ctx = MagicMock() - ctx.actions.state_delta = state_delta or {} - ctx.session_state = session_state or {} - return ctx +@pytest.fixture(autouse=True) +def _mock_turn_load_mode(monkeypatch): + monkeypatch.setattr("trpc_agent_sdk.skills._common.get_skill_load_mode", lambda _ctx: "turn") + + +class TestGetStateDeltaValue: + """Test suite for get_state_delta_value function.""" + + def test_get_from_state_delta(self): + """Test getting value from state_delta.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {"key1": "value1"} + mock_ctx.session_state = {"key1": "old_value"} + result = get_state_delta_value(mock_ctx, "key1") + + assert result == "value1" + + def test_get_from_session_state(self): + """Test getting value from session_state when not in state_delta.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {} + mock_ctx.session_state = {"key1": "value1"} + + result = get_state_delta_value(mock_ctx, "key1") + + assert result == "value1" + + def test_get_nonexistent_key(self): + """Test getting nonexistent key returns None.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {} + mock_ctx.session_state = {} + + result = get_state_delta_value(mock_ctx, "nonexistent") + + assert result is None -# --------------------------------------------------------------------------- -# SelectionMode -# --------------------------------------------------------------------------- class TestSelectionMode: - def test_values(self): - assert SelectionMode.ADD == "add" - assert 
SelectionMode.REPLACE == "replace" - assert SelectionMode.CLEAR == "clear" + """Test suite for SelectionMode enum.""" - def test_from_string(self): + def test_selection_mode_values(self): + """Test SelectionMode enum values.""" + assert SelectionMode.ADD.value == "add" + assert SelectionMode.REPLACE.value == "replace" + assert SelectionMode.CLEAR.value == "clear" + + def test_selection_mode_from_string(self): + """Test creating SelectionMode from string.""" assert SelectionMode("add") == SelectionMode.ADD + assert SelectionMode("replace") == SelectionMode.REPLACE + assert SelectionMode("clear") == SelectionMode.CLEAR + def test_selection_mode_invalid_string(self): + """Test invalid string raises ValueError.""" + with pytest.raises(ValueError): + SelectionMode("invalid") -# --------------------------------------------------------------------------- -# get_state_delta_value -# --------------------------------------------------------------------------- -class TestGetStateDeltaValue: - def test_from_state_delta(self): - ctx = _make_ctx(state_delta={"key": "delta_value"}) - assert get_state_delta_value(ctx, "key") == "delta_value" +class TestBaseSelectionResult: + """Test suite for BaseSelectionResult class.""" - def test_from_session_state(self): - ctx = _make_ctx(session_state={"key": "session_value"}) - assert get_state_delta_value(ctx, "key") == "session_value" + def test_create_base_selection_result(self): + """Test creating BaseSelectionResult.""" + result = BaseSelectionResult(skill="test-skill", mode="replace") - def test_delta_takes_precedence(self): - ctx = _make_ctx(state_delta={"k": "delta"}, session_state={"k": "session"}) - assert get_state_delta_value(ctx, "k") == "delta" + assert result.skill == "test-skill" + assert result.mode == "replace" - def test_missing_returns_none(self): - ctx = _make_ctx() - assert get_state_delta_value(ctx, "missing") is None + def test_default_mode(self): + """Test default mode is empty string.""" + result = 
BaseSelectionResult(skill="test-skill") + assert result.mode == "" -# --------------------------------------------------------------------------- -# get_previous_selection -# --------------------------------------------------------------------------- class TestGetPreviousSelection: - def test_no_value_returns_empty_list(self): - ctx = _make_ctx() - result = get_previous_selection(ctx, "prefix:", "skill") - assert result == [] + """Test suite for get_previous_selection function.""" + + def test_get_previous_selection_json(self): + """Test getting previous selection from JSON.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.session_state = {"temp:skill:docs:test-skill": json.dumps(["doc1.md", "doc2.md"])} + + result = get_previous_selection(mock_ctx, "temp:skill:docs:", "test-skill") + + assert result == ["doc1.md", "doc2.md"] + + def test_get_previous_selection_all(self): + """Test getting previous selection when all items selected.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.session_state = {"temp:skill:docs:test-skill": '*'} + + result = get_previous_selection(mock_ctx, "temp:skill:docs:", "test-skill") - def test_star_returns_none(self): - ctx = _make_ctx(session_state={"prefix:skill": "*"}) - result = get_previous_selection(ctx, "prefix:", "skill") assert result is None - def test_json_array(self): - ctx = _make_ctx(session_state={"prefix:skill": json.dumps(["a", "b"])}) - result = get_previous_selection(ctx, "prefix:", "skill") - assert result == ["a", "b"] + def test_get_previous_selection_not_found(self): + """Test getting previous selection when not found.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.session_state = {} - def test_invalid_json_returns_empty(self): - ctx = _make_ctx(session_state={"prefix:skill": "not json"}) - result = get_previous_selection(ctx, "prefix:", "skill") - assert result == [] + result = get_previous_selection(mock_ctx, "temp:skill:docs:", "test-skill") - def test_empty_string_returns_empty(self): - 
ctx = _make_ctx(session_state={"prefix:skill": ""}) - result = get_previous_selection(ctx, "prefix:", "skill") assert result == [] -# --------------------------------------------------------------------------- -# clear_selection -# --------------------------------------------------------------------------- +class TestAppendLoadedOrderStateDelta: + """Test suite for append_loaded_order_state_delta function.""" + + def test_appends_to_temp_order_when_temp_exists(self): + mock_ctx = Mock(spec=InvocationContext) + temp_key = loaded_order_key("demo-agent") + mock_ctx.session_state = {temp_key: json.dumps(["skill-a"])} + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {} + mock_ctx.agent_context = Mock() + mock_ctx.agent_context.get_metadata = Mock(side_effect=lambda _key, default=None: default) + + append_loaded_order_state_delta(mock_ctx, "demo-agent", "skill-b") + + assert mock_ctx.actions.state_delta[temp_key] == json.dumps(["skill-a", "skill-b"]) + + def test_initializes_order_when_missing(self): + mock_ctx = Mock(spec=InvocationContext) + temp_key = loaded_order_key("demo-agent") + mock_ctx.session_state = {} + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {} + mock_ctx.agent_context = Mock() + mock_ctx.agent_context.get_metadata = Mock(side_effect=lambda _key, default=None: default) + + append_loaded_order_state_delta(mock_ctx, "demo-agent", "skill-b") + + assert mock_ctx.actions.state_delta[temp_key] == json.dumps(["skill-b"]) + + def test_get_previous_selection_invalid_json(self): + """Test getting previous selection with invalid JSON.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.session_state = {"temp:skill:docs:test-skill": "invalid json"} + + result = get_previous_selection(mock_ctx, "temp:skill:docs:", "test-skill") + + assert result == [] + class TestClearSelection: - def test_clear(self): - result = clear_selection("skill", ["a", "b"], True, ["old"], _TestSelectionResult) - assert result.skill == "skill" + """Test suite 
for clear_selection function.""" + + def test_clear_selection(self): + """Test clearing selection.""" + result = clear_selection(skill_name="test-skill", + items=["item1"], + include_all=False, + previous_items=["item1", "item2"], + result_class=MockSelectionResult) + + assert isinstance(result, MockSelectionResult) + assert result.skill == "test-skill" + assert result.mode == "clear" assert result.selected_items == [] assert result.include_all is False - assert result.mode == "clear" - -# --------------------------------------------------------------------------- -# add_selection -# --------------------------------------------------------------------------- class TestAddSelection: - def test_add_to_empty(self): - result = add_selection("skill", ["a", "b"], False, [], _TestSelectionResult) - assert set(result.selected_items) == {"a", "b"} - assert result.include_all is False + """Test suite for add_selection function.""" + + def test_add_selection(self): + """Test adding to selection.""" + result = add_selection(skill_name="test-skill", + items=["item2", "item3"], + include_all=False, + previous_items=["item1"], + result_class=MockSelectionResult) + + assert isinstance(result, MockSelectionResult) + assert result.skill == "test-skill" assert result.mode == "add" + assert set(result.selected_items) == {"item1", "item2", "item3"} + assert result.include_all is False - def test_add_to_existing(self): - result = add_selection("skill", ["c"], False, ["a", "b"], _TestSelectionResult) - assert set(result.selected_items) == {"a", "b", "c"} - - def test_add_deduplicate(self): - result = add_selection("skill", ["a"], False, ["a"], _TestSelectionResult) - assert result.selected_items == ["a"] + def test_add_selection_duplicates(self): + """Test adding selection removes duplicates.""" + result = add_selection(skill_name="test-skill", + items=["item1", "item2"], + include_all=False, + previous_items=["item1"], + result_class=MockSelectionResult) + + assert 
len(result.selected_items) == 2 + assert result.selected_items.count("item1") == 1 + + def test_add_selection_include_all(self): + """Test adding selection with include_all.""" + result = add_selection(skill_name="test-skill", + items=["item2"], + include_all=True, + previous_items=["item1"], + result_class=MockSelectionResult) - def test_add_include_all(self): - result = add_selection("skill", ["a"], True, ["b"], _TestSelectionResult) assert result.selected_items == [] assert result.include_all is True -# --------------------------------------------------------------------------- -# replace_selection -# --------------------------------------------------------------------------- - class TestReplaceSelection: - def test_replace(self): - result = replace_selection("skill", ["x", "y"], False, ["old"], _TestSelectionResult) - assert result.selected_items == ["x", "y"] + """Test suite for replace_selection function.""" + + def test_replace_selection(self): + """Test replacing selection.""" + result = replace_selection(skill_name="test-skill", + items=["item2", "item3"], + include_all=False, + previous_items=["item1"], + result_class=MockSelectionResult) + + assert isinstance(result, MockSelectionResult) + assert result.skill == "test-skill" assert result.mode == "replace" + assert result.selected_items == ["item2", "item3"] + assert result.include_all is False + + def test_replace_selection_include_all(self): + """Test replacing selection with include_all.""" + result = replace_selection(skill_name="test-skill", + items=["item2"], + include_all=True, + previous_items=["item1"], + result_class=MockSelectionResult) - def test_replace_include_all(self): - result = replace_selection("skill", ["x"], True, [], _TestSelectionResult) assert result.selected_items == [] assert result.include_all is True -# --------------------------------------------------------------------------- -# set_state_delta_for_selection -# 
--------------------------------------------------------------------------- - class TestSetStateDeltaForSelection: - def test_sets_json_array(self): - ctx = _make_ctx() - result = _TestSelectionResult(skill="s", selected_items=["a", "b"], include_all=False) - set_state_delta_for_selection(ctx, "prefix:", result) - assert json.loads(ctx.actions.state_delta["prefix:s"]) == ["a", "b"] + """Test suite for set_state_delta_for_selection function.""" + + def test_set_state_delta_with_items(self): + """Test setting state delta with items.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {} + + result = MockSelectionResult(skill="test-skill", + selected_items=["item1", "item2"], + include_all=False, + mode="replace") + + set_state_delta_for_selection(mock_ctx, "temp:skill:test:", result) + + key = "temp:skill:test:test-skill" + assert key in mock_ctx.actions.state_delta + assert json.loads(mock_ctx.actions.state_delta[key]) == ["item1", "item2"] + + def test_set_state_delta_include_all(self): + """Test setting state delta with include_all.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {} + + result = MockSelectionResult(skill="test-skill", selected_items=[], include_all=True, mode="replace") + + set_state_delta_for_selection(mock_ctx, "temp:skill:test:", result) - def test_sets_star_for_include_all(self): - ctx = _make_ctx() - result = _TestSelectionResult(skill="s", selected_items=[], include_all=True) - set_state_delta_for_selection(ctx, "prefix:", result) - assert ctx.actions.state_delta["prefix:s"] == "*" + key = "temp:skill:test:test-skill" + assert mock_ctx.actions.state_delta[key] == '*' - def test_no_skill_is_noop(self): - ctx = _make_ctx() - result = _TestSelectionResult(skill="", selected_items=[]) - set_state_delta_for_selection(ctx, "prefix:", result) - assert len(ctx.actions.state_delta) == 0 + def test_set_state_delta_empty_skill(self): + 
"""Test setting state delta with empty skill name does nothing.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {} + result = MockSelectionResult(skill="", selected_items=["item1"], include_all=False, mode="replace") + + set_state_delta_for_selection(mock_ctx, "temp:skill:test:", result) + + assert len(mock_ctx.actions.state_delta) == 0 -# --------------------------------------------------------------------------- -# generic_select_items -# --------------------------------------------------------------------------- class TestGenericSelectItems: - def test_replace_mode(self): - ctx = _make_ctx() - result = generic_select_items( - ctx, "skill", ["a", "b"], False, "replace", "prefix:", _TestSelectionResult - ) - assert result.selected_items == ["a", "b"] + """Test suite for generic_select_items function.""" + + def test_generic_select_items_replace(self): + """Test generic select items with replace mode.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.session_state = {} + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {} + + result = generic_select_items(tool_context=mock_ctx, + skill_name="test-skill", + items=["item1", "item2"], + include_all=False, + mode="replace", + state_key_prefix="temp:skill:test:", + result_class=MockSelectionResult) + + assert isinstance(result, MockSelectionResult) + assert result.skill == "test-skill" assert result.mode == "replace" + assert result.selected_items == ["item1", "item2"] + + def test_generic_select_items_add(self): + """Test generic select items with add mode.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.session_state = {"temp:skill:test:test-skill": json.dumps(["item1"])} + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {} + + result = generic_select_items(tool_context=mock_ctx, + skill_name="test-skill", + items=["item2"], + include_all=False, + mode="add", + state_key_prefix="temp:skill:test:", + 
result_class=MockSelectionResult) - def test_add_mode(self): - ctx = _make_ctx(session_state={"prefix:skill": json.dumps(["a"])}) - result = generic_select_items( - ctx, "skill", ["b"], False, "add", "prefix:", _TestSelectionResult - ) - assert set(result.selected_items) == {"a", "b"} assert result.mode == "add" + assert set(result.selected_items) == {"item1", "item2"} + + def test_generic_select_items_clear(self): + """Test generic select items with clear mode.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.session_state = {"temp:skill:test:test-skill": json.dumps(["item1"])} + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {} + + result = generic_select_items(tool_context=mock_ctx, + skill_name="test-skill", + items=None, + include_all=False, + mode="clear", + state_key_prefix="temp:skill:test:", + result_class=MockSelectionResult) - def test_clear_mode(self): - ctx = _make_ctx(session_state={"prefix:skill": json.dumps(["a"])}) - result = generic_select_items( - ctx, "skill", [], False, "clear", "prefix:", _TestSelectionResult - ) - assert result.selected_items == [] assert result.mode == "clear" + assert result.selected_items == [] + + def test_generic_select_items_invalid_mode(self): + """Test generic select items with invalid mode defaults to replace.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.session_state = {} + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {} + + result = generic_select_items(tool_context=mock_ctx, + skill_name="test-skill", + items=["item1"], + include_all=False, + mode="invalid", + state_key_prefix="temp:skill:test:", + result_class=MockSelectionResult) - def test_invalid_mode_defaults_to_replace(self): - ctx = _make_ctx() - result = generic_select_items( - ctx, "skill", ["x"], False, "invalid_mode", "prefix:", _TestSelectionResult - ) assert result.mode == "replace" - def test_previous_star_and_not_clearing_keeps_include_all(self): - ctx = _make_ctx(session_state={"prefix:skill": "*"}) - 
result = generic_select_items( - ctx, "skill", ["a"], False, "add", "prefix:", _TestSelectionResult - ) + def test_generic_select_items_previous_all(self): + """Test generic select items when previous was all.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.session_state = {"temp:skill:test:test-skill": '*'} + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {} + + result = generic_select_items(tool_context=mock_ctx, + skill_name="test-skill", + items=["item1"], + include_all=False, + mode="add", + state_key_prefix="temp:skill:test:", + result_class=MockSelectionResult) + + # Should maintain include_all=True assert result.include_all is True - def test_previous_star_and_clear(self): - ctx = _make_ctx(session_state={"prefix:skill": "*"}) - result = generic_select_items( - ctx, "skill", [], False, "clear", "prefix:", _TestSelectionResult - ) - assert result.include_all is False - assert result.selected_items == [] + def test_generic_select_items_updates_state(self): + """Test generic select items updates state delta.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.session_state = {} + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {} - def test_none_items_treated_as_empty(self): - ctx = _make_ctx() - result = generic_select_items( - ctx, "skill", None, False, "replace", "prefix:", _TestSelectionResult - ) - assert result.selected_items == [] - - def test_updates_state_delta(self): - ctx = _make_ctx() - generic_select_items( - ctx, "skill", ["a"], False, "replace", "prefix:", _TestSelectionResult - ) - assert "prefix:skill" in ctx.actions.state_delta + result = generic_select_items(tool_context=mock_ctx, + skill_name="test-skill", + items=["item1"], + include_all=False, + mode="replace", + state_key_prefix="temp:skill:test:", + result_class=MockSelectionResult) + key = "temp:skill:test:test-skill" + assert key in mock_ctx.actions.state_delta -# --------------------------------------------------------------------------- -# 
generic_get_selection -# --------------------------------------------------------------------------- class TestGenericGetSelection: - def test_no_value_returns_empty(self): - ctx = _make_ctx() - assert generic_get_selection(ctx, "skill", "prefix:") == [] - - def test_json_array(self): - ctx = _make_ctx(state_delta={"prefix:skill": json.dumps(["a", "b"])}) - assert generic_get_selection(ctx, "skill", "prefix:") == ["a", "b"] - - def test_star_with_callback(self): - ctx = _make_ctx(state_delta={"prefix:skill": "*"}) - callback = Mock(return_value=["all_a", "all_b"]) - result = generic_get_selection(ctx, "skill", "prefix:", callback) - assert result == ["all_a", "all_b"] - callback.assert_called_once_with("skill") - - def test_star_without_callback(self): - ctx = _make_ctx(state_delta={"prefix:skill": "*"}) - assert generic_get_selection(ctx, "skill", "prefix:") == [] - - def test_star_callback_exception_returns_empty(self): - ctx = _make_ctx(state_delta={"prefix:skill": "*"}) - callback = Mock(side_effect=RuntimeError("boom")) - assert generic_get_selection(ctx, "skill", "prefix:", callback) == [] - - def test_invalid_json_returns_empty(self): - ctx = _make_ctx(state_delta={"prefix:skill": "not_json"}) - assert generic_get_selection(ctx, "skill", "prefix:") == [] - - def test_bytes_value_decoded(self): - ctx = _make_ctx(state_delta={"prefix:skill": json.dumps(["x"]).encode("utf-8")}) - assert generic_get_selection(ctx, "skill", "prefix:") == ["x"] - - def test_bytes_star_with_callback(self): - ctx = _make_ctx(state_delta={"prefix:skill": b"*"}) - callback = Mock(return_value=["all"]) - result = generic_get_selection(ctx, "skill", "prefix:", callback) - assert result == ["all"] - - def test_non_list_json_returns_empty(self): - ctx = _make_ctx(state_delta={"prefix:skill": json.dumps({"not": "list"})}) - assert generic_get_selection(ctx, "skill", "prefix:") == [] + """Test suite for generic_get_selection function.""" + + def test_generic_get_selection_json_array(self): 
+ """Test getting selection from JSON array.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {} + mock_ctx.session_state = {"temp:skill:test:test-skill": json.dumps(["item1", "item2"])} + + result = generic_get_selection(ctx=mock_ctx, skill_name="test-skill", state_key_prefix="temp:skill:test:") + + assert result == ["item1", "item2"] + + def test_generic_get_selection_all_with_callback(self): + """Test getting selection with '*' and callback.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {} + mock_ctx.session_state = {"temp:skill:test:test-skill": '*'} + + def get_all_items(skill_name): + return ["item1", "item2", "item3"] + + result = generic_get_selection(ctx=mock_ctx, + skill_name="test-skill", + state_key_prefix="temp:skill:test:", + get_all_items_callback=get_all_items) + + assert result == ["item1", "item2", "item3"] + + def test_generic_get_selection_all_without_callback(self): + """Test getting selection with '*' but no callback.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {} + mock_ctx.session_state = {"temp:skill:test:test-skill": '*'} + + result = generic_get_selection(ctx=mock_ctx, skill_name="test-skill", state_key_prefix="temp:skill:test:") + + assert result == [] + + def test_generic_get_selection_not_found(self): + """Test getting selection when not found.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {} + mock_ctx.session_state = {} + + result = generic_get_selection(ctx=mock_ctx, skill_name="test-skill", state_key_prefix="temp:skill:test:") + + assert result == [] + + def test_generic_get_selection_invalid_json(self): + """Test getting selection with invalid JSON.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {} + mock_ctx.session_state = 
{"temp:skill:test:test-skill": "invalid json"} + + result = generic_get_selection(ctx=mock_ctx, skill_name="test-skill", state_key_prefix="temp:skill:test:") + + assert result == [] + + def test_generic_get_selection_bytes_value(self): + """Test getting selection when value is bytes.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {} + mock_ctx.session_state = {"temp:skill:test:test-skill": json.dumps(["item1"]).encode('utf-8')} + + result = generic_get_selection(ctx=mock_ctx, skill_name="test-skill", state_key_prefix="temp:skill:test:") + + assert result == ["item1"] + + def test_generic_get_selection_callback_exception(self): + """Test getting selection when callback raises exception.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {} + mock_ctx.session_state = {"temp:skill:test:test-skill": '*'} + + def get_all_items_error(skill_name): + raise Exception("Test error") + + result = generic_get_selection(ctx=mock_ctx, + skill_name="test-skill", + state_key_prefix="temp:skill:test:", + get_all_items_callback=get_all_items_error) + + assert result == [] + + def test_generic_get_selection_from_state_delta(self): + """Test getting selection prefers state_delta over session_state.""" + mock_ctx = Mock(spec=InvocationContext) + mock_ctx.actions = Mock() + mock_ctx.actions.state_delta = {"temp:skill:test:test-skill": json.dumps(["item_new"])} + mock_ctx.session_state = {"temp:skill:test:test-skill": json.dumps(["item_old"])} + + result = generic_get_selection(ctx=mock_ctx, skill_name="test-skill", state_key_prefix="temp:skill:test:") + + assert result == ["item_new"] diff --git a/tests/skills/test_constants.py b/tests/skills/test_constants.py new file mode 100644 index 0000000..26b80fa --- /dev/null +++ b/tests/skills/test_constants.py @@ -0,0 +1,23 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. 
+# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. + +from trpc_agent_sdk.skills._constants import SKILL_FILE +from trpc_agent_sdk.skills._constants import SKILL_LOAD_MODE_VALUES +from trpc_agent_sdk.skills._constants import SKILL_TOOLS_NAMES +from trpc_agent_sdk.skills._constants import SkillLoadModeNames +from trpc_agent_sdk.skills._constants import SkillToolsNames + + +def test_skill_file_constant(): + assert SKILL_FILE == "SKILL.md" + + +def test_skill_tools_names_matches_enum(): + assert SKILL_TOOLS_NAMES == [item.value for item in SkillToolsNames] + + +def test_skill_load_mode_values_matches_enum(): + assert SKILL_LOAD_MODE_VALUES == [item.value for item in SkillLoadModeNames] diff --git a/tests/skills/test_dynamic_toolset.py b/tests/skills/test_dynamic_toolset.py index 26cd3b8..df65fb8 100644 --- a/tests/skills/test_dynamic_toolset.py +++ b/tests/skills/test_dynamic_toolset.py @@ -23,12 +23,19 @@ import pytest from trpc_agent_sdk.skills._dynamic_toolset import DynamicSkillToolSet +from trpc_agent_sdk.skills._common import loaded_state_key +from trpc_agent_sdk.skills._common import tool_state_key +from trpc_agent_sdk.skills._constants import SKILL_CONFIG_KEY +from trpc_agent_sdk.skills._skill_config import DEFAULT_SKILL_CONFIG def _make_ctx(state_delta=None, session_state=None): ctx = MagicMock() ctx.actions.state_delta = state_delta or {} ctx.session_state = session_state or {} + ctx.agent_name = "" + ctx.agent_context.get_metadata = MagicMock( + side_effect=lambda key, default=None: DEFAULT_SKILL_CONFIG if key == SKILL_CONFIG_KEY else default) return ctx @@ -113,10 +120,9 @@ def test_no_loaded_skills(self, *_): def test_loaded_skills_from_session_and_delta(self, *_): repo = _make_mock_repository() ts = DynamicSkillToolSet(skill_repository=repo) - ctx = _make_ctx( - session_state={"temp:skill:loaded:skill-a": True}, - state_delta={"temp:skill:loaded:skill-b": True}, - ) + ctx = _make_ctx() + 
ctx.session_state = {loaded_state_key(ctx, "skill-a"): True} + ctx.actions.state_delta = {loaded_state_key(ctx, "skill-b"): True} result = ts._get_loaded_skills_from_state(ctx) assert set(result) == {"skill-a", "skill-b"} @@ -139,7 +145,7 @@ def test_no_active_skills(self, *_): def test_active_from_loaded(self, *_): repo = _make_mock_repository() ts = DynamicSkillToolSet(skill_repository=repo) - ctx = _make_ctx(state_delta={"temp:skill:loaded:s1": True}) + ctx = _make_ctx(state_delta={loaded_state_key(_make_ctx(), "s1"): True}) result = ts._get_active_skills_from_delta(ctx) assert "s1" in result @@ -148,7 +154,7 @@ def test_active_from_loaded(self, *_): def test_active_from_tools_modified(self, *_): repo = _make_mock_repository() ts = DynamicSkillToolSet(skill_repository=repo) - ctx = _make_ctx(state_delta={"temp:skill:tools:s2": json.dumps(["t1"])}) + ctx = _make_ctx(state_delta={tool_state_key(_make_ctx(), "s2"): json.dumps(["t1"])}) result = ts._get_active_skills_from_delta(ctx) assert "s2" in result @@ -157,7 +163,7 @@ def test_active_from_tools_modified(self, *_): def test_falsy_loaded_value_ignored(self, *_): repo = _make_mock_repository() ts = DynamicSkillToolSet(skill_repository=repo) - ctx = _make_ctx(state_delta={"temp:skill:loaded:s1": False}) + ctx = _make_ctx(state_delta={loaded_state_key(_make_ctx(), "s1"): False}) assert ts._get_active_skills_from_delta(ctx) == [] @@ -173,7 +179,7 @@ def test_json_array(self, *_): skill.tools = ["default_tool"] repo = _make_mock_repository({"s1": skill}) ts = DynamicSkillToolSet(skill_repository=repo) - ctx = _make_ctx(state_delta={"temp:skill:tools:s1": json.dumps(["tool_a", "tool_b"])}) + ctx = _make_ctx(state_delta={tool_state_key(_make_ctx(), "s1"): json.dumps(["tool_a", "tool_b"])}) result = ts._get_tools_selection(ctx, "s1") assert result == ["tool_a", "tool_b"] @@ -184,7 +190,7 @@ def test_star_returns_defaults(self, *_): skill.tools = ["default_tool"] repo = _make_mock_repository({"s1": skill}) ts = 
DynamicSkillToolSet(skill_repository=repo) - ctx = _make_ctx(state_delta={"temp:skill:tools:s1": "*"}) + ctx = _make_ctx(state_delta={tool_state_key(_make_ctx(), "s1"): "*"}) result = ts._get_tools_selection(ctx, "s1") assert result == ["default_tool"] @@ -206,7 +212,7 @@ def test_invalid_json_falls_back(self, *_): skill.tools = ["fallback"] repo = _make_mock_repository({"s1": skill}) ts = DynamicSkillToolSet(skill_repository=repo) - ctx = _make_ctx(state_delta={"temp:skill:tools:s1": "not_json"}) + ctx = _make_ctx(state_delta={tool_state_key(_make_ctx(), "s1"): "not_json"}) result = ts._get_tools_selection(ctx, "s1") assert result == ["fallback"] @@ -301,7 +307,7 @@ async def test_active_skills_with_tools(self, *_): mock_tool = _make_mock_tool("my_tool") ts._available_tools["my_tool"] = mock_tool - ctx = _make_ctx(state_delta={"temp:skill:loaded:s1": True}) + ctx = _make_ctx(state_delta={loaded_state_key(_make_ctx(), "s1"): True}) result = await ts.get_tools(ctx) assert len(result) == 1 assert result[0] is mock_tool @@ -316,7 +322,7 @@ async def test_only_active_false_uses_all_loaded(self, *_): mock_tool = _make_mock_tool("tool_a") ts._available_tools["tool_a"] = mock_tool - ctx = _make_ctx(session_state={"temp:skill:loaded:s1": True}) + ctx = _make_ctx(session_state={loaded_state_key(_make_ctx(), "s1"): True}) result = await ts.get_tools(ctx) assert len(result) == 1 @@ -332,9 +338,10 @@ async def test_deduplicates_tools(self, *_): mock_tool = _make_mock_tool("shared_tool") ts._available_tools["shared_tool"] = mock_tool + key_ctx = _make_ctx() ctx = _make_ctx(session_state={ - "temp:skill:loaded:s1": True, - "temp:skill:loaded:s2": True, + loaded_state_key(key_ctx, "s1"): True, + loaded_state_key(key_ctx, "s2"): True, }) result = await ts.get_tools(ctx) assert len(result) == 1 @@ -349,7 +356,7 @@ async def test_fallback_to_loaded_when_no_active(self, *_): mock_tool = _make_mock_tool("fallback_tool") ts._available_tools["fallback_tool"] = mock_tool - ctx = 
_make_ctx(session_state={"temp:skill:loaded:s1": True}) + ctx = _make_ctx(session_state={loaded_state_key(_make_ctx(), "s1"): True}) result = await ts.get_tools(ctx) assert len(result) == 1 @@ -360,7 +367,7 @@ async def test_unresolvable_tool_skipped(self, *_): skill.tools = ["nonexistent_tool"] repo = _make_mock_repository({"s1": skill}) ts = DynamicSkillToolSet(skill_repository=repo) - ctx = _make_ctx(state_delta={"temp:skill:loaded:s1": True}) + ctx = _make_ctx(state_delta={loaded_state_key(_make_ctx(), "s1"): True}) result = await ts.get_tools(ctx) assert len(result) == 0 @@ -371,7 +378,7 @@ async def test_no_tools_for_skill(self, *_): skill.tools = [] repo = _make_mock_repository({"s1": skill}) ts = DynamicSkillToolSet(skill_repository=repo) - ctx = _make_ctx(state_delta={"temp:skill:loaded:s1": True}) + ctx = _make_ctx(state_delta={loaded_state_key(_make_ctx(), "s1"): True}) result = await ts.get_tools(ctx) assert result == [] @@ -439,7 +446,7 @@ def test_bytes_json_value(self, *_): skill.tools = ["default"] repo = _make_mock_repository({"s1": skill}) ts = DynamicSkillToolSet(skill_repository=repo) - ctx = _make_ctx(state_delta={"temp:skill:tools:s1": json.dumps(["t1"]).encode()}) + ctx = _make_ctx(state_delta={tool_state_key(_make_ctx(), "s1"): json.dumps(["t1"]).encode()}) result = ts._get_tools_selection(ctx, "s1") assert result == ["t1"] @@ -450,7 +457,7 @@ def test_bytes_star_value(self, *_): skill.tools = ["default"] repo = _make_mock_repository({"s1": skill}) ts = DynamicSkillToolSet(skill_repository=repo) - ctx = _make_ctx(state_delta={"temp:skill:tools:s1": b"*"}) + ctx = _make_ctx(state_delta={tool_state_key(_make_ctx(), "s1"): b"*"}) result = ts._get_tools_selection(ctx, "s1") assert result == ["default"] @@ -461,6 +468,6 @@ def test_non_list_json_falls_back(self, *_): skill.tools = ["default"] repo = _make_mock_repository({"s1": skill}) ts = DynamicSkillToolSet(skill_repository=repo) - ctx = _make_ctx(state_delta={"temp:skill:tools:s1": 
json.dumps({"not": "list"})}) + ctx = _make_ctx(state_delta={tool_state_key(_make_ctx(), "s1"): json.dumps({"not": "list"})}) result = ts._get_tools_selection(ctx, "s1") assert result == ["default"] diff --git a/tests/skills/test_hot_reload.py b/tests/skills/test_hot_reload.py new file mode 100644 index 0000000..bc8411b --- /dev/null +++ b/tests/skills/test_hot_reload.py @@ -0,0 +1,35 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. + +from pathlib import Path + +from trpc_agent_sdk.skills._hot_reload import SkillHotReloadTracker + + +def test_mark_changed_path_only_tracks_skill_file(tmp_path: Path): + root = tmp_path / "skills" + skill_dir = root / "demo" + skill_dir.mkdir(parents=True) + tracker = SkillHotReloadTracker("SKILL.md") + + tracker.mark_changed_path(str(skill_dir / "notes.txt"), is_directory=False, skill_roots=[str(root)]) + assert tracker.pop_changed_dirs(str(root.resolve())) == [] + + tracker.mark_changed_path(str(skill_dir / "SKILL.md"), is_directory=False, skill_roots=[str(root)]) + changed = tracker.pop_changed_dirs(str(root.resolve())) + assert changed == [skill_dir.resolve()] + + +def test_resolve_root_key_and_normalize_targets(tmp_path: Path): + root = tmp_path / "skills" + nested = root / "a" / "b" + nested.mkdir(parents=True) + + key = SkillHotReloadTracker.resolve_root_key(nested, [str(root)]) + assert key == str(root.resolve()) + + deduped = SkillHotReloadTracker.normalize_scan_targets([root / "a", nested, root / "c"]) + assert deduped == [root / "a", root / "c"] diff --git a/tests/skills/test_repository.py b/tests/skills/test_repository.py index 22a6ad7..6f24581 100644 --- a/tests/skills/test_repository.py +++ b/tests/skills/test_repository.py @@ -27,11 +27,9 @@ BASE_DIR_PLACEHOLDER, BaseSkillRepository, FsSkillRepository, - _is_doc_file, - _parse_tools_from_body, - 
_split_front_matter, create_default_skill_repository, ) +from trpc_agent_sdk.skills._utils import is_doc_file # --------------------------------------------------------------------------- @@ -40,54 +38,54 @@ class TestSplitFrontMatter: def test_no_front_matter(self): - fm, body = _split_front_matter("# Hello\nworld") + fm, body = FsSkillRepository.from_markdown("# Hello\nworld") assert fm == {} assert body == "# Hello\nworld" def test_with_front_matter(self): content = "---\nname: test\ndescription: Test skill\n---\n# Body" - fm, body = _split_front_matter(content) + fm, body = FsSkillRepository.from_markdown(content) assert fm["name"] == "test" assert fm["description"] == "Test skill" assert body == "# Body" def test_crlf_normalization(self): content = "---\r\nname: test\r\n---\r\nbody" - fm, body = _split_front_matter(content) + fm, body = FsSkillRepository.from_markdown(content) assert fm["name"] == "test" assert body == "body" def test_invalid_yaml_returns_empty_dict(self): content = "---\n: : : invalid\n---\nbody" - fm, body = _split_front_matter(content) + fm, body = FsSkillRepository.from_markdown(content) assert body == "body" def test_non_dict_yaml_returns_empty_dict(self): content = "---\n- item1\n- item2\n---\nbody" - fm, body = _split_front_matter(content) + fm, body = FsSkillRepository.from_markdown(content) assert fm == {} assert body == "body" def test_unclosed_front_matter(self): content = "---\nname: test\nno closing" - fm, body = _split_front_matter(content) + fm, body = FsSkillRepository.from_markdown(content) assert fm == {} assert body == content def test_none_values_become_empty_string(self): content = "---\nname:\n---\nbody" - fm, body = _split_front_matter(content) + fm, body = FsSkillRepository.from_markdown(content) assert fm["name"] == "" def test_none_key_converted_to_string(self): content = "---\nname: test\n---\nbody" - fm, body = _split_front_matter(content) + fm, body = FsSkillRepository.from_markdown(content) assert fm["name"] == 
"test" assert body == "body" def test_no_dash_prefix(self): content = "no front matter at all" - fm, body = _split_front_matter(content) + fm, body = FsSkillRepository.from_markdown(content) assert fm == {} assert body == content @@ -99,29 +97,29 @@ def test_no_dash_prefix(self): class TestParseToolsFromBody: def test_basic_tools_section(self): body = "Tools:\n- tool_a\n- tool_b\n\nOverview" - tools = _parse_tools_from_body(body) + tools = FsSkillRepository._parse_tools_from_body(body) assert tools == ["tool_a", "tool_b"] def test_no_tools_section(self): body = "# Just markdown\nNo tools here" - assert _parse_tools_from_body(body) == [] + assert FsSkillRepository._parse_tools_from_body(body) == [] def test_tools_section_stops_at_next_section(self): body = "Tools:\n- tool_a\nOverview\nMore content" - tools = _parse_tools_from_body(body) + tools = FsSkillRepository._parse_tools_from_body(body) assert tools == ["tool_a"] def test_tools_section_skips_headings(self): body = "Tools:\n# Comment\n- tool_a\n" - tools = _parse_tools_from_body(body) + tools = FsSkillRepository._parse_tools_from_body(body) assert tools == ["tool_a"] def test_empty_body(self): - assert _parse_tools_from_body("") == [] + assert FsSkillRepository._parse_tools_from_body("") == [] def test_tools_with_description_colon(self): body = "Tools:\n- tool_a\nDescription: something\n" - tools = _parse_tools_from_body(body) + tools = FsSkillRepository._parse_tools_from_body(body) assert tools == ["tool_a"] @@ -131,16 +129,16 @@ def test_tools_with_description_colon(self): class TestIsDocFile: def test_markdown(self): - assert _is_doc_file("readme.md") is True - assert _is_doc_file("README.MD") is True + assert is_doc_file("readme.md") is True + assert is_doc_file("README.MD") is True def test_text(self): - assert _is_doc_file("notes.txt") is True - assert _is_doc_file("NOTES.TXT") is True + assert is_doc_file("notes.txt") is True + assert is_doc_file("NOTES.TXT") is True def test_non_doc(self): - assert 
_is_doc_file("script.py") is False - assert _is_doc_file("data.json") is False + assert is_doc_file("script.py") is False + assert is_doc_file("data.json") is False # --------------------------------------------------------------------------- @@ -374,10 +372,42 @@ def test_read_docs_error_handling(self, tmp_path): class TestBaseSkillRepositoryAbstract: def test_user_prompt_default(self): - repo = MagicMock(spec=BaseSkillRepository) - BaseSkillRepository.user_prompt(repo) + class _Repo(BaseSkillRepository): + def summaries(self): + return [] + + def get(self, name: str): + raise ValueError(name) + + def skill_list(self, mode: str = "all"): + return [] + + def path(self, name: str) -> str: + return "" + + def refresh(self) -> None: + return None + + repo = _Repo(workspace_runtime=MagicMock()) + assert BaseSkillRepository.user_prompt(repo) == "" def test_skill_run_env_default(self): - repo = MagicMock(spec=BaseSkillRepository) + class _Repo(BaseSkillRepository): + def summaries(self): + return [] + + def get(self, name: str): + raise ValueError(name) + + def skill_list(self, mode: str = "all"): + return [] + + def path(self, name: str) -> str: + return "" + + def refresh(self) -> None: + return None + + repo = _Repo(workspace_runtime=MagicMock()) result = BaseSkillRepository.skill_run_env(repo, "skill") assert result == {} diff --git a/tests/skills/test_run_tool.py b/tests/skills/test_run_tool.py deleted file mode 100644 index 9e4f577..0000000 --- a/tests/skills/test_run_tool.py +++ /dev/null @@ -1,364 +0,0 @@ -# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. -# -# Copyright (C) 2026 Tencent. All rights reserved. -# -# tRPC-Agent-Python is licensed under Apache-2.0. 
- -from unittest.mock import AsyncMock -from unittest.mock import Mock -from unittest.mock import patch - -import pytest -from trpc_agent_sdk.code_executors import BaseProgramRunner -from trpc_agent_sdk.code_executors import BaseWorkspaceFS -from trpc_agent_sdk.code_executors import BaseWorkspaceManager -from trpc_agent_sdk.code_executors import BaseWorkspaceRuntime -from trpc_agent_sdk.code_executors import CodeFile -from trpc_agent_sdk.code_executors import WorkspaceInfo -from trpc_agent_sdk.code_executors import WorkspaceRunResult -from trpc_agent_sdk.context import InvocationContext -from trpc_agent_sdk.skills import BaseSkillRepository -from trpc_agent_sdk.skills.tools import ArtifactInfo -from trpc_agent_sdk.skills.tools import SkillRunFile -from trpc_agent_sdk.skills.tools import SkillRunInput -from trpc_agent_sdk.skills.tools import SkillRunOutput -from trpc_agent_sdk.skills.tools import SkillRunTool -from trpc_agent_sdk.skills.tools._skill_run import _inline_json_schema_refs - - -class TestInlineJsonSchemaRefs: - """Test suite for _inline_json_schema_refs function.""" - - def test_inline_json_schema_refs_no_refs(self): - """Test inlining schema with no $ref references.""" - schema = { - "type": "object", - "properties": { - "name": {"type": "string"} - } - } - - result = _inline_json_schema_refs(schema) - - assert result == schema - - def test_inline_json_schema_refs_with_refs(self): - """Test inlining schema with $ref references.""" - schema = { - "type": "object", - "properties": { - "item": {"$ref": "#/$defs/Item"} - }, - "$defs": { - "Item": { - "type": "object", - "properties": { - "name": {"type": "string"} - } - } - } - } - - result = _inline_json_schema_refs(schema) - - assert "$defs" not in result - assert "$ref" not in str(result) - assert "name" in str(result) - - def test_inline_json_schema_refs_nested_refs(self): - """Test inlining schema with nested $ref references.""" - schema = { - "type": "object", - "properties": { - "item": {"$ref": 
"#/$defs/Item"} - }, - "$defs": { - "Item": { - "type": "object", - "properties": { - "nested": {"$ref": "#/$defs/Nested"} - } - }, - "Nested": { - "type": "string" - } - } - } - - result = _inline_json_schema_refs(schema) - - assert "$defs" not in result - assert "$ref" not in str(result) - - -class TestSkillRunInput: - """Test suite for SkillRunInput class.""" - - def test_create_skill_run_input(self): - """Test creating skill run input.""" - input_data = SkillRunInput( - skill="test-skill", - command="python script.py", - cwd="work", - env={"VAR": "value"}, - output_files=["out/*.txt"], - timeout=30, - save_as_artifacts=True, - ) - - assert input_data.skill == "test-skill" - assert input_data.command == "python script.py" - assert input_data.cwd == "work" - assert input_data.env == {"VAR": "value"} - assert input_data.output_files == ["out/*.txt"] - assert input_data.timeout == 30 - assert input_data.save_as_artifacts is True - - def test_create_skill_run_input_defaults(self): - """Test creating skill run input with defaults.""" - input_data = SkillRunInput(skill="test-skill", command="echo hello") - - assert input_data.skill == "test-skill" - assert input_data.command == "echo hello" - assert input_data.cwd == "" - assert input_data.env == {} - assert input_data.output_files == [] - assert input_data.timeout == 0 - assert input_data.save_as_artifacts is False - - -class TestSkillRunOutput: - """Test suite for SkillRunOutput class.""" - - def test_create_skill_run_output(self): - """Test creating skill run output.""" - output_files = [SkillRunFile(name="output.txt", content="content", mime_type="text/plain")] - artifact_files = [ArtifactInfo(name="artifact.txt", version=1)] - - output = SkillRunOutput( - stdout="output", - stderr="error", - exit_code=0, - timed_out=False, - duration_ms=1000, - output_files=output_files, - artifact_files=artifact_files, - ) - - assert output.stdout == "output" - assert output.stderr == "error" - assert output.exit_code == 0 - 
assert output.timed_out is False - assert output.duration_ms == 1000 - assert len(output.output_files) == 1 - assert len(output.artifact_files) == 1 - - def test_create_skill_run_output_defaults(self): - """Test creating skill run output with defaults.""" - output = SkillRunOutput() - - assert output.stdout == "" - assert output.stderr == "" - assert output.exit_code == 0 - assert output.timed_out is False - assert output.duration_ms == 0 - assert output.output_files == [] - assert output.artifact_files == [] - - -class TestArtifactInfo: - """Test suite for ArtifactInfo class.""" - - def test_create_artifact_info(self): - """Test creating artifact info.""" - info = ArtifactInfo(name="artifact.txt", version=1) - - assert info.name == "artifact.txt" - assert info.version == 1 - - def test_create_artifact_info_defaults(self): - """Test creating artifact info with defaults.""" - info = ArtifactInfo() - - assert info.name == "" - assert info.version == 0 - - -class TestSkillRunTool: - """Test suite for SkillRunTool class.""" - - def setup_method(self): - """Set up test fixtures before each test.""" - self.mock_repository = Mock(spec=BaseSkillRepository) - self.mock_runtime = Mock(spec=BaseWorkspaceRuntime) - self.mock_manager = Mock(spec=BaseWorkspaceManager) - self.mock_fs = Mock(spec=BaseWorkspaceFS) - self.mock_runner = Mock(spec=BaseProgramRunner) - self.mock_repository.workspace_runtime = self.mock_runtime - self.mock_runtime.manager = Mock(return_value=self.mock_manager) - self.mock_runtime.fs = Mock(return_value=self.mock_fs) - self.mock_runtime.runner = Mock(return_value=self.mock_runner) - - self.mock_ctx = Mock(spec=InvocationContext) - self.mock_ctx.agent_context = Mock() - self.mock_ctx.agent_context.get_metadata = Mock(return_value=None) - self.mock_ctx.session = Mock() - self.mock_ctx.session.id = "session-123" - self.mock_ctx.actions = Mock() - self.mock_ctx.actions.state_delta = {} - - def test_init(self): - """Test SkillRunTool initialization.""" - tool 
= SkillRunTool(repository=self.mock_repository) - - assert tool.name == "skill_run" - assert tool._repository == self.mock_repository - - def test_get_declaration(self): - """Test getting function declaration.""" - tool = SkillRunTool(repository=self.mock_repository) - - declaration = tool._get_declaration() - - assert declaration.name == "skill_run" - assert declaration.parameters is not None - assert declaration.response is not None - - def test_get_repository_from_instance(self): - """Test getting repository from instance.""" - tool = SkillRunTool(repository=self.mock_repository) - - result = tool._get_repository(self.mock_ctx) - - assert result == self.mock_repository - - def test_get_repository_from_context(self): - """Test getting repository from context.""" - tool = SkillRunTool(repository=None) - self.mock_ctx.agent_context.get_metadata = Mock(return_value=self.mock_repository) - - result = tool._get_repository(self.mock_ctx) - - assert result == self.mock_repository - - @pytest.mark.asyncio - async def test_run_async_impl_success(self): - """Test running skill_run tool successfully.""" - from trpc_agent_sdk.skills.stager import SkillStageResult - - mock_stager = AsyncMock() - mock_stager.stage_skill = AsyncMock(return_value=SkillStageResult(workspace_skill_dir="skills/test-skill")) - tool = SkillRunTool(repository=self.mock_repository, skill_stager=mock_stager) - - workspace = WorkspaceInfo(id="ws-123", path="/tmp/workspace") - self.mock_repository.path = Mock(return_value="/path/to/skill") - self.mock_repository.skill_run_env = Mock(return_value={}) - self.mock_manager.create_workspace = AsyncMock(return_value=workspace) - self.mock_fs.stage_directory = AsyncMock() - self.mock_fs.stage_inputs = AsyncMock() - self.mock_fs.collect_outputs = AsyncMock(return_value=Mock(files=[])) - self.mock_runner.run_program = AsyncMock(return_value=WorkspaceRunResult( - stdout="output", - stderr="", - exit_code=0, - duration=1.0, - timed_out=False - )) - - args = { - 
"skill": "test-skill", - "command": "echo hello" - } - result = await tool._run_async_impl(tool_context=self.mock_ctx, args=args) - - assert isinstance(result, dict) - assert result["stdout"] == "output" - assert result["exit_code"] == 0 - - @pytest.mark.asyncio - async def test_run_async_impl_with_output_files(self): - """Test running skill_run tool with output files.""" - from trpc_agent_sdk.skills.stager import SkillStageResult - - mock_stager = AsyncMock() - mock_stager.stage_skill = AsyncMock(return_value=SkillStageResult(workspace_skill_dir="skills/test-skill")) - tool = SkillRunTool(repository=self.mock_repository, skill_stager=mock_stager) - - workspace = WorkspaceInfo(id="ws-123", path="/tmp/workspace") - self.mock_repository.path = Mock(return_value="/path/to/skill") - self.mock_repository.skill_run_env = Mock(return_value={}) - self.mock_manager.create_workspace = AsyncMock(return_value=workspace) - self.mock_fs.stage_directory = AsyncMock() - self.mock_fs.stage_inputs = AsyncMock() - self.mock_fs.collect = AsyncMock(return_value=[CodeFile(name="output.txt", content="content", mime_type="text/plain")]) - self.mock_runner.run_program = AsyncMock(return_value=WorkspaceRunResult( - stdout="output", - stderr="", - exit_code=0, - duration=1.0, - timed_out=False - )) - - from trpc_agent_sdk.code_executors._types import ManifestOutput, ManifestFileRef - mock_output = ManifestOutput(files=[ - ManifestFileRef(name="output.txt", content="content", mime_type="text/plain") - ]) - self.mock_fs.collect_outputs = AsyncMock(return_value=mock_output) - - args = { - "skill": "test-skill", - "command": "echo hello", - "output_files": ["out/*.txt"] - } - result = await tool._run_async_impl(tool_context=self.mock_ctx, args=args) - - assert len(result["output_files"]) == 1 - - @pytest.mark.asyncio - async def test_run_async_impl_invalid_args(self): - """Test running skill_run tool with invalid arguments.""" - tool = SkillRunTool(repository=self.mock_repository) - - args = { - 
"skill": "test-skill", - # Missing required 'command' field - } - - with pytest.raises(ValueError, match="Invalid skill_run arguments"): - await tool._run_async_impl(tool_context=self.mock_ctx, args=args) - - @pytest.mark.asyncio - async def test_run_async_impl_with_kwargs(self): - """Test running skill_run tool with kwargs.""" - from trpc_agent_sdk.skills.stager import SkillStageResult - - mock_stager = AsyncMock() - mock_stager.stage_skill = AsyncMock(return_value=SkillStageResult(workspace_skill_dir="skills/test-skill")) - tool = SkillRunTool(repository=self.mock_repository, timeout=30, skill_stager=mock_stager) - - workspace = WorkspaceInfo(id="ws-123", path="/tmp/workspace") - self.mock_repository.path = Mock(return_value="/path/to/skill") - self.mock_repository.skill_run_env = Mock(return_value={}) - self.mock_manager.create_workspace = AsyncMock(return_value=workspace) - self.mock_fs.stage_directory = AsyncMock() - self.mock_fs.stage_inputs = AsyncMock() - self.mock_fs.collect_outputs = AsyncMock(return_value=Mock(files=[])) - self.mock_runner.run_program = AsyncMock(return_value=WorkspaceRunResult( - stdout="output", - stderr="", - exit_code=0, - duration=1.0, - timed_out=False - )) - - args = { - "skill": "test-skill", - "command": "echo hello" - } - - result = await tool._run_async_impl(tool_context=self.mock_ctx, args=args) - - assert isinstance(result, dict) - assert result["stdout"] == "output" - assert result["exit_code"] == 0 - diff --git a/tests/skills/test_skill_config.py b/tests/skills/test_skill_config.py new file mode 100644 index 0000000..b8284dd --- /dev/null +++ b/tests/skills/test_skill_config.py @@ -0,0 +1,42 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+ +from unittest.mock import MagicMock + +from trpc_agent_sdk.skills._constants import SKILL_CONFIG_KEY +from trpc_agent_sdk.skills._constants import SkillLoadModeNames +from trpc_agent_sdk.skills._skill_config import DEFAULT_SKILL_CONFIG +from trpc_agent_sdk.skills._skill_config import get_skill_config +from trpc_agent_sdk.skills._skill_config import get_skill_load_mode +from trpc_agent_sdk.skills._skill_config import is_exist_skill_config +from trpc_agent_sdk.skills._skill_config import set_skill_config + + +def test_get_skill_config_uses_metadata_default(): + agent_ctx = MagicMock() + agent_ctx.get_metadata = MagicMock(return_value=DEFAULT_SKILL_CONFIG) + assert get_skill_config(agent_ctx) == DEFAULT_SKILL_CONFIG + + +def test_set_skill_config_writes_metadata(): + agent_ctx = MagicMock() + config = {"skill_processor": {"load_mode": "session"}} + set_skill_config(agent_ctx, config) + agent_ctx.with_metadata.assert_called_once_with(SKILL_CONFIG_KEY, config) + + +def test_get_skill_load_mode_fallback_turn_on_invalid(): + agent_ctx = MagicMock() + agent_ctx.get_metadata = MagicMock(return_value={"skill_processor": {"load_mode": "bad"}}) + ctx = MagicMock() + ctx.agent_context = agent_ctx + assert get_skill_load_mode(ctx) == SkillLoadModeNames.TURN.value + + +def test_is_exist_skill_config_checks_key(): + agent_ctx = MagicMock() + agent_ctx.metadata = {SKILL_CONFIG_KEY: {}} + assert is_exist_skill_config(agent_ctx) is True diff --git a/tests/skills/test_skill_profile.py b/tests/skills/test_skill_profile.py new file mode 100644 index 0000000..b091c15 --- /dev/null +++ b/tests/skills/test_skill_profile.py @@ -0,0 +1,41 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+ +import pytest + +from trpc_agent_sdk.skills._constants import SkillProfileNames +from trpc_agent_sdk.skills._skill_profile import SkillProfileFlags + + +def test_normalize_profile(): + assert SkillProfileFlags.normalize_profile("knowledge_only") == SkillProfileNames.KNOWLEDGE_ONLY.value + assert SkillProfileFlags.normalize_profile("unknown") == SkillProfileNames.FULL.value + + +def test_preset_flags_knowledge_only(): + flags = SkillProfileFlags.preset_flags("knowledge_only") + assert flags.has_knowledge_tools() is True + assert flags.requires_execution_tools() is False + + +def test_resolve_flags_with_forbidden_tool(): + flags = SkillProfileFlags.resolve_flags("full", forbidden_tools=["skill_write_stdin"]) + assert flags.exec is True + assert flags.write_stdin is False + + +def test_validate_dependency_error(): + flags = SkillProfileFlags(run=False, exec=True) + with pytest.raises(ValueError, match="requires"): + flags.validate() + + +def test_without_interactive_execution(): + flags = SkillProfileFlags.resolve_flags("full") + narrowed = flags.without_interactive_execution() + assert narrowed.run is True + assert narrowed.exec is False + assert narrowed.poll_session is False diff --git a/tests/skills/test_state_keys.py b/tests/skills/test_state_keys.py new file mode 100644 index 0000000..e8df27f --- /dev/null +++ b/tests/skills/test_state_keys.py @@ -0,0 +1,35 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+ +from trpc_agent_sdk.skills._state_keys import docs_key +from trpc_agent_sdk.skills._state_keys import docs_prefix +from trpc_agent_sdk.skills._state_keys import loaded_key +from trpc_agent_sdk.skills._state_keys import loaded_order_key +from trpc_agent_sdk.skills._state_keys import loaded_prefix +from trpc_agent_sdk.skills._state_keys import to_persistent_prefix +from trpc_agent_sdk.skills._state_keys import tool_key +from trpc_agent_sdk.skills._state_keys import tool_prefix + + +def test_loaded_key_legacy_fallback(): + assert loaded_key("", "demo") == "temp:skill:loaded:demo" + + +def test_scoped_keys_escape_agent_name(): + assert loaded_key("agent/a", "demo") == "temp:skill:loaded_by_agent:agent%2Fa/demo" + assert docs_key("agent/a", "demo") == "temp:skill:docs_by_agent:agent%2Fa/demo" + assert tool_key("agent/a", "demo") == "temp:skill:tools_by_agent:agent%2Fa/demo" + + +def test_prefix_helpers(): + assert loaded_prefix("") == "temp:skill:loaded:" + assert docs_prefix("") == "temp:skill:docs:" + assert tool_prefix("") == "temp:skill:tools:" + assert loaded_order_key("agent/a") == "temp:skill:loaded_order_by_agent:agent%2Fa" + + +def test_to_persistent_prefix(): + assert to_persistent_prefix("temp:skill:loaded:demo") == "skill:loaded:demo" diff --git a/tests/skills/test_state_migration.py b/tests/skills/test_state_migration.py new file mode 100644 index 0000000..02495a3 --- /dev/null +++ b/tests/skills/test_state_migration.py @@ -0,0 +1,71 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+"""Tests for legacy skill state migration.""" + +from unittest.mock import Mock + +from trpc_agent_sdk.skills._constants import SKILL_DOCS_STATE_KEY_PREFIX +from trpc_agent_sdk.skills._constants import SKILL_LOADED_STATE_KEY_PREFIX +from trpc_agent_sdk.skills._state_keys import docs_key +from trpc_agent_sdk.skills._state_keys import loaded_key +from trpc_agent_sdk.skills._state_migration import SKILLS_LEGACY_MIGRATION_STATE_KEY +from trpc_agent_sdk.skills._state_migration import maybe_migrate_legacy_skill_state + + +def _build_ctx(*, state=None, delta=None, agent_name: str = "agent-a") -> Mock: + ctx = Mock() + ctx.session = Mock() + ctx.session.state = dict(state or {}) + ctx.session.events = [] + ctx.actions = Mock() + ctx.actions.state_delta = dict(delta or {}) + ctx.agent = Mock() + ctx.agent.name = agent_name + return ctx + + +class TestMaybeMigrateLegacySkillState: + def test_migrates_loaded_legacy_key_to_temp(self): + legacy_key = f"{SKILL_LOADED_STATE_KEY_PREFIX}demo-skill" + temp_key = loaded_key("agent-a", "demo-skill") + ctx = _build_ctx(state={legacy_key: "1"}) + + maybe_migrate_legacy_skill_state(ctx) + + assert ctx.actions.state_delta[SKILLS_LEGACY_MIGRATION_STATE_KEY] is True + assert ctx.actions.state_delta[temp_key] == "1" + assert ctx.actions.state_delta[legacy_key] is None + + def test_migrates_docs_legacy_key_to_temp(self): + legacy_key = f"{SKILL_DOCS_STATE_KEY_PREFIX}demo-skill" + temp_key = docs_key("agent-a", "demo-skill") + value = '["README.md"]' + ctx = _build_ctx(state={legacy_key: value}) + + maybe_migrate_legacy_skill_state(ctx) + + assert ctx.actions.state_delta[SKILLS_LEGACY_MIGRATION_STATE_KEY] is True + assert ctx.actions.state_delta[temp_key] == value + assert ctx.actions.state_delta[legacy_key] is None + + def test_existing_scoped_key_skips_copy_and_only_clears_legacy(self): + legacy_key = f"{SKILL_LOADED_STATE_KEY_PREFIX}demo-skill" + temp_key = loaded_key("agent-a", "demo-skill") + ctx = _build_ctx(state={legacy_key: "legacy", 
temp_key: "existing"}) + + maybe_migrate_legacy_skill_state(ctx) + + assert ctx.actions.state_delta[SKILLS_LEGACY_MIGRATION_STATE_KEY] is True + assert ctx.actions.state_delta[legacy_key] is None + assert temp_key not in ctx.actions.state_delta + + def test_migration_is_idempotent_when_marker_exists(self): + legacy_key = f"{SKILL_LOADED_STATE_KEY_PREFIX}demo-skill" + ctx = _build_ctx(state={legacy_key: "1", SKILLS_LEGACY_MIGRATION_STATE_KEY: True}) + + maybe_migrate_legacy_skill_state(ctx) + + assert ctx.actions.state_delta == {} diff --git a/tests/skills/test_state_order.py b/tests/skills/test_state_order.py new file mode 100644 index 0000000..360d58b --- /dev/null +++ b/tests/skills/test_state_order.py @@ -0,0 +1,24 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. + +from trpc_agent_sdk.skills._state_order import marshal_loaded_order +from trpc_agent_sdk.skills._state_order import parse_loaded_order +from trpc_agent_sdk.skills._state_order import touch_loaded_order + + +def test_parse_loaded_order_from_json_and_bytes(): + assert parse_loaded_order('["a","b","a",""]') == ["a", "b"] + assert parse_loaded_order(b'["x","y"]') == ["x", "y"] + assert parse_loaded_order(b"\xff") == [] + + +def test_marshal_loaded_order_normalizes(): + assert marshal_loaded_order(["a", "a", " ", "b"]) == '["a", "b"]' + assert marshal_loaded_order([]) == "" + + +def test_touch_loaded_order_moves_items_to_tail(): + assert touch_loaded_order(["a", "b", "c"], "b", "a") == ["c", "b", "a"] diff --git a/tests/skills/test_tools.py b/tests/skills/test_tools.py deleted file mode 100644 index 6c49fdd..0000000 --- a/tests/skills/test_tools.py +++ /dev/null @@ -1,553 +0,0 @@ -# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. -# -# Copyright (C) 2026 Tencent. All rights reserved. 
-# -# tRPC-Agent-Python is licensed under Apache-2.0. - -import json -from unittest.mock import Mock - -import pytest -from trpc_agent_sdk.context import InvocationContext -from trpc_agent_sdk.skills.tools import SkillSelectDocsResult -from trpc_agent_sdk.skills.tools import SkillSelectToolsResult -from trpc_agent_sdk.skills import skill_list -from trpc_agent_sdk.skills import skill_list_docs -from trpc_agent_sdk.skills import skill_list_tools -from trpc_agent_sdk.skills import skill_load -from trpc_agent_sdk.skills import skill_select_docs -from trpc_agent_sdk.skills import skill_select_tools -from trpc_agent_sdk.skills.tools._skill_load import _set_state_delta_for_skill_load -from trpc_agent_sdk.skills.tools._skill_load import _set_state_delta_for_skill_tools -from trpc_agent_sdk.skills import Skill -from trpc_agent_sdk.skills import SkillResource - - -class TestSkillList: - """Test suite for skill_list function.""" - - def test_skill_list_success(self): - """Test listing all skills.""" - mock_repository = Mock() - mock_repository.skill_list.return_value = ["skill1", "skill2", "skill3"] - - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.agent_context = Mock() - mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository) - - result = skill_list(mock_ctx) - - assert len(result) == 3 - assert "skill1" in result - assert "skill2" in result - assert "skill3" in result - - def test_skill_list_empty(self): - """Test listing skills when none exist.""" - mock_repository = Mock() - mock_repository.skill_list.return_value = [] - - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.agent_context = Mock() - mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository) - - result = skill_list(mock_ctx) - - assert result == [] - - def test_skill_list_repository_not_found(self): - """Test listing skills when repository not found.""" - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.agent_context = Mock() - mock_ctx.agent_context.get_metadata = 
Mock(return_value=None) - - with pytest.raises(ValueError, match="repository not found"): - skill_list(mock_ctx) - - -class TestSkillListDocs: - """Test suite for skill_list_docs function.""" - - def test_skill_list_docs_success(self): - """Test listing docs for a skill.""" - mock_repository = Mock() - skill = Skill( - resources=[ - SkillResource(path="doc1.md", content="content1"), - SkillResource(path="doc2.md", content="content2"), - ] - ) - mock_repository.get.return_value = skill - - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.agent_context = Mock() - mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository) - - result = skill_list_docs(mock_ctx, "test-skill") - - assert len(result["docs"]) == 2 - assert "doc1.md" in result["docs"] - assert "doc2.md" in result["docs"] - - def test_skill_list_docs_no_resources(self): - """Test listing docs for skill with no resources.""" - mock_repository = Mock() - skill = Skill(resources=[]) - mock_repository.get.return_value = skill - - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.agent_context = Mock() - mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository) - - result = skill_list_docs(mock_ctx, "test-skill") - - assert result["docs"] == [] - - def test_skill_list_docs_skill_not_found(self): - """Test listing docs for nonexistent skill.""" - mock_repository = Mock() - mock_repository.get.return_value = None - - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.agent_context = Mock() - mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository) - - result = skill_list_docs(mock_ctx, "nonexistent-skill") - - assert result["docs"] == [] - - def test_skill_list_docs_repository_not_found(self): - """Test listing docs when repository not found.""" - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.agent_context = Mock() - mock_ctx.agent_context.get_metadata = Mock(return_value=None) - - with pytest.raises(ValueError, match="repository not found"): - 
skill_list_docs(mock_ctx, "test-skill") - - -class TestSkillListTools: - """Test suite for skill_list_tools function.""" - - def test_skill_list_tools_success(self): - """Test listing tools for a skill.""" - mock_repository = Mock() - skill = Skill(tools=["tool1", "tool2", "tool3"]) - mock_repository.get.return_value = skill - - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.agent_context = Mock() - mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository) - - result = skill_list_tools(mock_ctx, "test-skill") - - assert len(result["tools"]) == 3 - assert "tool1" in result["tools"] - assert "tool2" in result["tools"] - assert "tool3" in result["tools"] - - def test_skill_list_tools_no_tools(self): - """Test listing tools for skill with no tools.""" - mock_repository = Mock() - skill = Skill(tools=[]) - mock_repository.get.return_value = skill - - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.agent_context = Mock() - mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository) - - result = skill_list_tools(mock_ctx, "test-skill") - - assert result["tools"] == [] - - def test_skill_list_tools_skill_not_found(self): - """Test listing tools for nonexistent skill.""" - mock_repository = Mock() - mock_repository.get.return_value = None - - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.agent_context = Mock() - mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository) - - result = skill_list_tools(mock_ctx, "nonexistent-skill") - - assert result["tools"] == [] - - def test_skill_list_tools_repository_not_found(self): - """Test listing tools when repository not found.""" - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.agent_context = Mock() - mock_ctx.agent_context.get_metadata = Mock(return_value=None) - - with pytest.raises(ValueError, match="repository not found"): - skill_list_tools(mock_ctx, "test-skill") - - -class TestSkillLoad: - """Test suite for skill_load function.""" - - def 
test_skill_load_success(self): - """Test loading a skill.""" - mock_repository = Mock() - skill = Skill(body="skill body", tools=[]) - mock_repository.get.return_value = skill - - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.agent_context = Mock() - mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository) - mock_ctx.actions = Mock() - mock_ctx.actions.state_delta = {} - - result = skill_load(mock_ctx, "test-skill") - - assert "loaded" in result - assert "test-skill" in result - - def test_skill_load_with_tools(self): - """Test loading a skill with tools.""" - mock_repository = Mock() - skill = Skill(body="skill body", tools=["tool1", "tool2"]) - mock_repository.get.return_value = skill - - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.agent_context = Mock() - mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository) - mock_ctx.actions = Mock() - mock_ctx.actions.state_delta = {} - - result = skill_load(mock_ctx, "test-skill") - - assert "loaded" in result - # Check that tools state was set - tools_key = "temp:skill:tools:test-skill" - assert tools_key in mock_ctx.actions.state_delta - assert json.loads(mock_ctx.actions.state_delta[tools_key]) == ["tool1", "tool2"] - - def test_skill_load_with_docs(self): - """Test loading a skill with specific docs.""" - mock_repository = Mock() - skill = Skill(body="skill body", tools=[]) - mock_repository.get.return_value = skill - - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.agent_context = Mock() - mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository) - mock_ctx.actions = Mock() - mock_ctx.actions.state_delta = {} - - result = skill_load(mock_ctx, "test-skill", docs=["doc1.md"]) - - assert "loaded" in result - docs_key = "temp:skill:docs:test-skill" - assert docs_key in mock_ctx.actions.state_delta - - def test_skill_load_with_include_all_docs(self): - """Test loading a skill with include_all_docs=True.""" - mock_repository = Mock() - skill = Skill(body="skill 
body", tools=[]) - mock_repository.get.return_value = skill - - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.agent_context = Mock() - mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository) - mock_ctx.actions = Mock() - mock_ctx.actions.state_delta = {} - - result = skill_load(mock_ctx, "test-skill", include_all_docs=True) - - assert "loaded" in result - assert mock_ctx.actions.state_delta.get("temp:skill:docs:test-skill") == '*' - - def test_skill_load_skill_not_found(self): - """Test loading nonexistent skill.""" - mock_repository = Mock() - mock_repository.get.return_value = None - - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.agent_context = Mock() - mock_ctx.agent_context.get_metadata = Mock(return_value=mock_repository) - - result = skill_load(mock_ctx, "nonexistent-skill") - - assert "not found" in result - - def test_skill_load_repository_not_found(self): - """Test loading skill when repository not found.""" - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.agent_context = Mock() - mock_ctx.agent_context.get_metadata = Mock(return_value=None) - - with pytest.raises(ValueError, match="repository not found"): - skill_load(mock_ctx, "test-skill") - - -class TestSkillSelectDocs: - """Test suite for skill_select_docs function.""" - - def test_skill_select_docs_replace_mode(self): - """Test selecting docs with replace mode.""" - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.session_state = {} - mock_ctx.actions = Mock() - mock_ctx.actions.state_delta = {} - - result = skill_select_docs(mock_ctx, "test-skill", docs=["doc1.md"], mode="replace") - - assert isinstance(result, SkillSelectDocsResult) - assert result.skill == "test-skill" - assert result.mode == "replace" - assert "doc1.md" in result.selected_docs - - def test_skill_select_docs_add_mode(self): - """Test selecting docs with add mode.""" - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.session_state = { - "temp:skill:docs:test-skill": json.dumps(["doc1.md"]) - } - 
mock_ctx.actions = Mock() - mock_ctx.actions.state_delta = {} - - result = skill_select_docs(mock_ctx, "test-skill", docs=["doc2.md"], mode="add") - - assert result.mode == "add" - assert "doc1.md" in result.selected_docs - assert "doc2.md" in result.selected_docs - - def test_skill_select_docs_clear_mode(self): - """Test selecting docs with clear mode.""" - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.session_state = { - "temp:skill:docs:test-skill": json.dumps(["doc1.md"]) - } - mock_ctx.actions = Mock() - mock_ctx.actions.state_delta = {} - - result = skill_select_docs(mock_ctx, "test-skill", mode="clear") - - assert result.mode == "clear" - assert result.include_all_docs is False - assert result.selected_docs == [] - - def test_skill_select_docs_with_include_all(self): - """Test selecting docs with include_all_docs=True.""" - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.session_state = {} - mock_ctx.actions = Mock() - mock_ctx.actions.state_delta = {} - - result = skill_select_docs(mock_ctx, "test-skill", include_all_docs=True, mode="replace") - - assert result.include_all_docs is True - assert result.selected_docs == [] - - def test_skill_select_docs_invalid_mode(self): - """Test selecting docs with invalid mode defaults to replace.""" - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.session_state = {} - mock_ctx.actions = Mock() - mock_ctx.actions.state_delta = {} - - result = skill_select_docs(mock_ctx, "test-skill", docs=["doc1.md"], mode="invalid") - - assert result.mode == "replace" - - def test_skill_select_docs_previous_all_docs(self): - """Test selecting docs when previous state has all docs.""" - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.session_state = { - "temp:skill:docs:test-skill": '*' - } - mock_ctx.actions = Mock() - mock_ctx.actions.state_delta = {} - - result = skill_select_docs(mock_ctx, "test-skill", docs=["doc1.md"], mode="add") - - assert result.include_all_docs is True - - -class TestSkillSelectTools: - """Test 
suite for skill_select_tools function.""" - - def test_skill_select_tools_replace_mode(self): - """Test selecting tools with replace mode.""" - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.session_state = {} - mock_ctx.actions = Mock() - mock_ctx.actions.state_delta = {} - - result = skill_select_tools(mock_ctx, "test-skill", tools=["tool1"], mode="replace") - - assert isinstance(result, SkillSelectToolsResult) - assert result.skill == "test-skill" - assert result.mode == "replace" - assert "tool1" in result.selected_tools - - def test_skill_select_tools_add_mode(self): - """Test selecting tools with add mode.""" - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.session_state = { - "temp:skill:tools:test-skill": json.dumps(["tool1"]) - } - mock_ctx.actions = Mock() - mock_ctx.actions.state_delta = {} - - result = skill_select_tools(mock_ctx, "test-skill", tools=["tool2"], mode="add") - - assert result.mode == "add" - assert "tool1" in result.selected_tools - assert "tool2" in result.selected_tools - - def test_skill_select_tools_clear_mode(self): - """Test selecting tools with clear mode.""" - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.session_state = { - "temp:skill:tools:test-skill": json.dumps(["tool1"]) - } - mock_ctx.actions = Mock() - mock_ctx.actions.state_delta = {} - - result = skill_select_tools(mock_ctx, "test-skill", mode="clear") - - assert result.mode == "clear" - assert result.include_all_tools is False - assert result.selected_tools == [] - - def test_skill_select_tools_with_include_all(self): - """Test selecting tools with include_all_tools=True.""" - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.session_state = {} - mock_ctx.actions = Mock() - mock_ctx.actions.state_delta = {} - - result = skill_select_tools(mock_ctx, "test-skill", include_all_tools=True, mode="replace") - - assert result.include_all_tools is True - assert result.selected_tools == [] - - def test_skill_select_tools_invalid_mode(self): - """Test selecting 
tools with invalid mode defaults to replace.""" - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.session_state = {} - mock_ctx.actions = Mock() - mock_ctx.actions.state_delta = {} - - result = skill_select_tools(mock_ctx, "test-skill", tools=["tool1"], mode="invalid") - - assert result.mode == "replace" - - def test_skill_select_tools_previous_all_tools(self): - """Test selecting tools when previous state has all tools.""" - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.session_state = { - "temp:skill:tools:test-skill": '*' - } - mock_ctx.actions = Mock() - mock_ctx.actions.state_delta = {} - - result = skill_select_tools(mock_ctx, "test-skill", tools=["tool1"], mode="add") - - assert result.include_all_tools is True - - -class TestSkillSelectDocsResult: - """Test suite for SkillSelectDocsResult class.""" - - def test_create_result(self): - """Test creating SkillSelectDocsResult.""" - result = SkillSelectDocsResult( - skill="test-skill", - selected_docs=["doc1.md"], - include_all_docs=True, - mode="replace" - ) - - assert result.skill == "test-skill" - assert result.selected_docs == ["doc1.md"] - assert result.include_all_docs is True - assert result.mode == "replace" - - def test_create_result_with_alias_fields(self): - """Test creating SkillSelectDocsResult with alias fields.""" - result = SkillSelectDocsResult( - skill="test-skill", - selected_items=["doc1.md", "doc2.md"], - include_all=True, - mode="replace" - ) - - # Alias fields should be mapped to actual fields - assert result.selected_docs == ["doc1.md", "doc2.md"] - assert result.include_all_docs is True - - -class TestSkillSelectToolsResult: - """Test suite for SkillSelectToolsResult class.""" - - def test_create_result(self): - """Test creating SkillSelectToolsResult.""" - result = SkillSelectToolsResult( - skill="test-skill", - selected_tools=["tool1"], - include_all_tools=True, - mode="replace" - ) - - assert result.skill == "test-skill" - assert result.selected_tools == ["tool1"] - assert 
result.include_all_tools is True - assert result.mode == "replace" - - def test_create_result_with_alias_fields(self): - """Test creating SkillSelectToolsResult with alias fields.""" - result = SkillSelectToolsResult( - skill="test-skill", - selected_items=["tool1", "tool2"], - include_all=True, - mode="replace" - ) - - # Alias fields should be mapped to actual fields - assert result.selected_tools == ["tool1", "tool2"] - assert result.include_all_tools is True - - -class TestStateDeltaHelpers: - """Test suite for state delta helper functions.""" - - def test_set_state_delta_for_skill_load(self): - """Test setting state delta for skill load.""" - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.actions = Mock() - mock_ctx.actions.state_delta = {} - - _set_state_delta_for_skill_load(mock_ctx, "test-skill", ["doc1.md"], False) - - loaded_key = "temp:skill:loaded:test-skill" - docs_key = "temp:skill:docs:test-skill" - - assert loaded_key in mock_ctx.actions.state_delta - assert docs_key in mock_ctx.actions.state_delta - assert json.loads(mock_ctx.actions.state_delta[docs_key]) == ["doc1.md"] - - def test_set_state_delta_for_skill_load_include_all(self): - """Test setting state delta with include_all_docs.""" - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.actions = Mock() - mock_ctx.actions.state_delta = {} - - _set_state_delta_for_skill_load(mock_ctx, "test-skill", [], True) - - docs_key = "temp:skill:docs:test-skill" - assert mock_ctx.actions.state_delta[docs_key] == '*' - - def test_set_state_delta_for_skill_tools(self): - """Test setting state delta for skill tools.""" - mock_ctx = Mock(spec=InvocationContext) - mock_ctx.actions = Mock() - mock_ctx.actions.state_delta = {} - - _set_state_delta_for_skill_tools(mock_ctx, "test-skill", ["tool1", "tool2"]) - - tools_key = "temp:skill:tools:test-skill" - assert tools_key in mock_ctx.actions.state_delta - assert json.loads(mock_ctx.actions.state_delta[tools_key]) == ["tool1", "tool2"] diff --git 
a/tests/skills/tools/__init__.py b/tests/skills/tools/__init__.py index e69de29..8b13789 100644 --- a/tests/skills/tools/__init__.py +++ b/tests/skills/tools/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/skills/tools/test_common.py b/tests/skills/tools/test_common.py new file mode 100644 index 0000000..52d712f --- /dev/null +++ b/tests/skills/tools/test_common.py @@ -0,0 +1,31 @@ +from unittest.mock import MagicMock + +import pytest + +from trpc_agent_sdk.skills.tools._common import get_staged_workspace_dir +from trpc_agent_sdk.skills.tools._common import inline_json_schema_refs +from trpc_agent_sdk.skills.tools._common import require_non_empty +from trpc_agent_sdk.skills.tools._common import set_staged_workspace_dir + + +def test_require_non_empty(): + assert require_non_empty(" ok ", field_name="x") == "ok" + with pytest.raises(ValueError, match="x is required"): + require_non_empty(" ", field_name="x") + + +def test_inline_json_schema_refs(): + schema = {"$defs": {"S": {"type": "string"}}, "properties": {"name": {"$ref": "#/$defs/S"}}} + out = inline_json_schema_refs(schema) + assert "$defs" not in out + assert out["properties"]["name"]["type"] == "string" + + +def test_staged_workspace_dir_round_trip(): + metadata = {} + ctx = MagicMock() + ctx.agent_context.get_metadata = MagicMock(side_effect=lambda key, default=None: metadata.get(key, default)) + ctx.agent_context.with_metadata = MagicMock(side_effect=lambda key, value: metadata.__setitem__(key, value)) + + set_staged_workspace_dir(ctx, "skill-a", "skills/skill-a") + assert get_staged_workspace_dir(ctx, "skill-a") == "skills/skill-a" diff --git a/tests/skills/tools/test_save_artifact.py b/tests/skills/tools/test_save_artifact.py new file mode 100644 index 0000000..837b771 --- /dev/null +++ b/tests/skills/tools/test_save_artifact.py @@ -0,0 +1,49 @@ +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from trpc_agent_sdk.skills._constants import SKILL_ARTIFACTS_STATE_KEY +from 
trpc_agent_sdk.skills.tools._save_artifact import SaveArtifactTool +from trpc_agent_sdk.skills.tools._save_artifact import _apply_artifact_state_delta +from trpc_agent_sdk.skills.tools._save_artifact import _artifact_save_skip_reason +from trpc_agent_sdk.skills.tools._save_artifact import _normalize_artifact_path +from trpc_agent_sdk.skills.tools._save_artifact import _normalize_workspace_prefix + + +def test_normalize_workspace_prefix(): + assert _normalize_workspace_prefix("workspace://work/a.txt") == "work/a.txt" + assert _normalize_workspace_prefix("$WORK_DIR/a.txt") == "work/a.txt" + + +def test_normalize_artifact_path_valid_and_invalid(tmp_path: Path): + workspace_root = str(tmp_path) + rel, abs_path = _normalize_artifact_path("work/a.txt", workspace_root) + assert rel == "work/a.txt" + assert abs_path.endswith("work/a.txt") + + with pytest.raises(ValueError, match="stay within the workspace"): + _normalize_artifact_path("../a.txt", workspace_root) + + +def test_artifact_save_skip_reason_and_state_delta(): + ctx = MagicMock() + ctx.artifact_service = object() + ctx.session = object() + ctx.app_name = "app" + ctx.user_id = "u" + ctx.session_id = "s" + ctx.function_call_id = "fc-1" + ctx.actions.state_delta = {} + assert _artifact_save_skip_reason(ctx) == "" + + _apply_artifact_state_delta(ctx, "work/a.txt", 2, "artifact://work/a.txt@2") + value = ctx.actions.state_delta[SKILL_ARTIFACTS_STATE_KEY] + assert value["tool_call_id"] == "fc-1" + assert value["artifacts"][0]["version"] == 2 + + +def test_save_artifact_declaration_name(): + declaration = SaveArtifactTool()._get_declaration() + assert declaration is not None + assert declaration.name == "workspace_save_artifact" diff --git a/tests/skills/tools/test_skill_exec.py b/tests/skills/tools/test_skill_exec.py index 82a0ecb..1593e5d 100644 --- a/tests/skills/tools/test_skill_exec.py +++ b/tests/skills/tools/test_skill_exec.py @@ -3,887 +3,107 @@ # Copyright (C) 2026 Tencent. All rights reserved. 
# # tRPC-Agent-Python is licensed under Apache-2.0. -"""Unit tests for trpc_agent_sdk.skills.tools._skill_exec. -Covers: -- Pydantic I/O models: ExecInput, WriteStdinInput, PollSessionInput, - KillSessionInput, SessionInteraction, ExecOutput, SessionKillOutput -- Helper functions: _last_non_empty_line, _has_selection_items, - _detect_interaction, _build_exec_env, _resolve_abs_cwd -- SkillExecTool: session management, declarations -- create_exec_tools factory -- _close_session -""" - -from __future__ import annotations - -import asyncio -import os -from pathlib import Path -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock +from unittest.mock import MagicMock import pytest - -from trpc_agent_sdk.skills.tools._skill_exec import ( - DEFAULT_EXEC_YIELD_MS, - DEFAULT_IO_YIELD_MS, - DEFAULT_POLL_LINES, - DEFAULT_SESSION_TTL, - ExecInput, - ExecOutput, - KillSessionInput, - KillSessionTool, - PollSessionInput, - PollSessionTool, - SessionInteraction, - SessionKillOutput, - SkillExecTool, - WriteStdinInput, - WriteStdinTool, - _close_session, - _detect_interaction, - _has_selection_items, - _last_non_empty_line, - _resolve_abs_cwd, - create_exec_tools, -) - - -# --------------------------------------------------------------------------- -# _last_non_empty_line -# --------------------------------------------------------------------------- - -class TestLastNonEmptyLine: - def test_normal(self): - assert _last_non_empty_line("line1\nline2\nline3") == "line3" - - def test_trailing_empty_lines(self): - assert _last_non_empty_line("line1\nline2\n\n\n") == "line2" - - def test_all_empty(self): - assert _last_non_empty_line("\n\n") == "" - - def test_empty_string(self): - assert _last_non_empty_line("") == "" - - def test_single_line(self): - assert _last_non_empty_line("hello") == "hello" - - def test_whitespace_lines(self): - assert _last_non_empty_line(" \n \n content \n ") == "content" - - -# 
--------------------------------------------------------------------------- -# _has_selection_items -# --------------------------------------------------------------------------- - -class TestHasSelectionItems: - def test_numbered_list(self): - text = "Choose an option:\n1. Option A\n2. Option B\n3. Option C" - assert _has_selection_items(text) is True - - def test_numbered_paren(self): - text = "1) Option A\n2) Option B" - assert _has_selection_items(text) is True - - def test_no_numbers(self): - text = "No numbered items here" - assert _has_selection_items(text) is False - - def test_single_number(self): - text = "1. Only one item" - assert _has_selection_items(text) is False - - def test_empty(self): - assert _has_selection_items("") is False - - -# --------------------------------------------------------------------------- -# _detect_interaction -# --------------------------------------------------------------------------- - -class TestDetectInteraction: - def test_exited_returns_none(self): - assert _detect_interaction("exited", "some output") is None - - def test_empty_output_returns_none(self): - assert _detect_interaction("running", "") is None - - def test_colon_prompt(self): - result = _detect_interaction("running", "Enter your name:") - assert result is not None - assert result.needs_input is True - assert result.kind == "prompt" - - def test_question_mark_prompt(self): - result = _detect_interaction("running", "Continue?") - assert result is not None - assert result.needs_input is True - - def test_press_enter(self): - result = _detect_interaction("running", "Press Enter to continue") - assert result is not None - assert result.needs_input is True - - def test_selection_detection(self): - text = "Choose a number:\n1. Option A\n2. 
Option B\nEnter the number:" - result = _detect_interaction("running", text) - assert result is not None - assert result.kind == "selection" - - def test_normal_output(self): - result = _detect_interaction("running", "Processing data...\nDone.") - assert result is None - - def test_type_your_prompt(self): - result = _detect_interaction("running", "Type your answer here") - assert result is not None - assert result.needs_input is True - - -# --------------------------------------------------------------------------- -# _resolve_abs_cwd -# --------------------------------------------------------------------------- - -class TestResolveAbsCwd: - def test_relative_cwd(self, tmp_path): - result = _resolve_abs_cwd(str(tmp_path), "sub") - assert os.path.isabs(result) - assert os.path.isdir(result) - - def test_absolute_cwd(self, tmp_path): - result = _resolve_abs_cwd(str(tmp_path), str(tmp_path / "abs")) - assert result == str(tmp_path / "abs") - - def test_empty_cwd(self, tmp_path): - result = _resolve_abs_cwd(str(tmp_path), "") - assert os.path.isabs(result) - - -# --------------------------------------------------------------------------- -# Pydantic I/O models -# --------------------------------------------------------------------------- - -class TestExecModels: - def test_exec_input_required(self): - inp = ExecInput(skill="test", command="ls") - assert inp.skill == "test" - assert inp.command == "ls" +from trpc_agent_sdk.code_executors import DEFAULT_EXEC_YIELD_MS +from trpc_agent_sdk.code_executors import DEFAULT_IO_YIELD_MS +from trpc_agent_sdk.code_executors import DEFAULT_POLL_LINES +from trpc_agent_sdk.code_executors import DEFAULT_SESSION_TTL_SEC +from trpc_agent_sdk.skills.tools._skill_exec import ExecInput +from trpc_agent_sdk.skills.tools._skill_exec import PollSessionTool +from trpc_agent_sdk.skills.tools._skill_exec import SkillExecTool +from trpc_agent_sdk.skills.tools._skill_exec import WriteStdinTool +from trpc_agent_sdk.skills.tools._skill_exec import 
_close_session +from trpc_agent_sdk.skills.tools._skill_exec import _detect_interaction +from trpc_agent_sdk.skills.tools._skill_exec import _has_selection_items +from trpc_agent_sdk.skills.tools._skill_exec import _last_non_empty_line +from trpc_agent_sdk.skills.tools._skill_exec import create_exec_tools + + +def _make_exec_tool() -> SkillExecTool: + run_tool = MagicMock() + run_tool._repository = MagicMock() + run_tool._timeout = 300.0 + run_tool._resolve_cwd = MagicMock(return_value="skills/test") + run_tool._build_command = MagicMock(return_value=("bash", ["-lc", "echo hello"])) + run_tool._prepare_outputs = AsyncMock(return_value=([], None)) + run_tool._attach_artifacts_if_requested = AsyncMock() + run_tool._merge_manifest_artifact_refs = MagicMock() + return SkillExecTool(run_tool) + + +class TestHelpers: + def test_last_non_empty_line(self): + assert _last_non_empty_line("a\n\nb\n") == "b" + + def test_has_selection_items(self): + assert _has_selection_items("1. a\n2. b") is True + assert _has_selection_items("1. a") is False + + def test_detect_interaction_prompt(self): + ret = _detect_interaction("running", "Enter your name:") + assert ret is not None + assert ret.needs_input is True + + def test_detect_interaction_selection(self): + ret = _detect_interaction("running", "Choose:\n1. A\n2. 
B\nEnter the number:") + assert ret is not None + assert ret.kind == "selection" + + +class TestModelsAndConstants: + def test_exec_input_defaults(self): + inp = ExecInput(skill="s", command="echo hi") + assert inp.yield_time_ms == 0 + assert inp.poll_lines == 0 assert inp.tty is False - assert inp.yield_ms == 0 - - def test_write_stdin_input(self): - inp = WriteStdinInput(session_id="abc") - assert inp.session_id == "abc" - assert inp.chars == "" - assert inp.submit is False - - def test_poll_session_input(self): - inp = PollSessionInput(session_id="abc") - assert inp.session_id == "abc" - - def test_kill_session_input(self): - inp = KillSessionInput(session_id="abc") - assert inp.session_id == "abc" - - def test_session_interaction(self): - si = SessionInteraction(needs_input=True, kind="prompt", hint="Enter:") - assert si.needs_input is True - assert si.kind == "prompt" - - def test_exec_output_defaults(self): - out = ExecOutput() - assert out.status == "running" - assert out.session_id == "" - assert out.output == "" - assert out.exit_code is None - - def test_session_kill_output(self): - out = SessionKillOutput(ok=True, session_id="abc", status="killed") - assert out.ok is True - - -# --------------------------------------------------------------------------- -# Constants -# --------------------------------------------------------------------------- - -class TestExecConstants: - def test_defaults(self): - assert DEFAULT_EXEC_YIELD_MS == 300 - assert DEFAULT_IO_YIELD_MS == 100 - assert DEFAULT_POLL_LINES == 50 - assert DEFAULT_SESSION_TTL == 300.0 - - -# --------------------------------------------------------------------------- -# SkillExecTool — session management -# --------------------------------------------------------------------------- - -class TestSkillExecToolSessions: - def _make_exec_tool(self): - run_tool = MagicMock() - run_tool._repository = MagicMock() - run_tool._timeout = 300.0 - return SkillExecTool(run_tool) - - async def 
test_put_and_get_session(self): - tool = self._make_exec_tool() - mock_session = MagicMock() - mock_session.exited_at = None - await tool._put_session("s1", mock_session) - result = await tool.get_session("s1") - assert result is mock_session - - async def test_get_unknown_session_raises(self): - tool = self._make_exec_tool() - with pytest.raises(ValueError, match="unknown session_id"): - await tool.get_session("nonexistent") - - async def test_remove_session(self): - tool = self._make_exec_tool() - mock_session = MagicMock() - mock_session.exited_at = None - await tool._put_session("s1", mock_session) - result = await tool.remove_session("s1") - assert result is mock_session - - async def test_remove_unknown_session_raises(self): - tool = self._make_exec_tool() - with pytest.raises(ValueError, match="unknown session_id"): - await tool.remove_session("nonexistent") - - def test_declaration(self): - tool = self._make_exec_tool() - decl = tool._get_declaration() - assert decl.name == "skill_exec" - - -# --------------------------------------------------------------------------- -# WriteStdinTool, PollSessionTool, KillSessionTool — declarations -# --------------------------------------------------------------------------- - -class TestSubToolDeclarations: - def _make_exec_tool(self): - run_tool = MagicMock() - run_tool._repository = MagicMock() - run_tool._timeout = 300.0 - return SkillExecTool(run_tool) - def test_write_stdin_declaration(self): - exec_tool = self._make_exec_tool() - tool = WriteStdinTool(exec_tool) - decl = tool._get_declaration() - assert decl.name == "skill_write_stdin" + def test_default_constants(self): + assert DEFAULT_EXEC_YIELD_MS > 0 + assert DEFAULT_IO_YIELD_MS > 0 + assert DEFAULT_POLL_LINES > 0 + assert DEFAULT_SESSION_TTL_SEC > 0 - def test_poll_session_declaration(self): - exec_tool = self._make_exec_tool() - tool = PollSessionTool(exec_tool) - decl = tool._get_declaration() - assert decl.name == "skill_poll_session" - def 
test_kill_session_declaration(self): - exec_tool = self._make_exec_tool() - tool = KillSessionTool(exec_tool) - decl = tool._get_declaration() - assert decl.name == "skill_kill_session" - - -# --------------------------------------------------------------------------- -# create_exec_tools -# --------------------------------------------------------------------------- +class TestSessionStore: + @pytest.mark.asyncio + async def test_put_get_remove(self): + tool = _make_exec_tool() + sess = MagicMock() + sess.exited_at = None + sess.proc.state = AsyncMock(return_value=MagicMock(status="running", exit_code=None)) + await tool._put_session("s1", sess) + got = await tool._get_session("s1") + assert got is sess + removed = await tool._remove_session("s1") + assert removed is sess -class TestCreateExecTools: - def test_creates_four_tools(self): - run_tool = MagicMock() - run_tool._repository = MagicMock() - run_tool._timeout = 300.0 - result = create_exec_tools(run_tool) - assert len(result) == 4 - exec_tool, write_tool, poll_tool, kill_tool = result - assert isinstance(exec_tool, SkillExecTool) - assert isinstance(write_tool, WriteStdinTool) - assert isinstance(poll_tool, PollSessionTool) - assert isinstance(kill_tool, KillSessionTool) - def test_custom_ttl(self): +class TestFactoryAndDeclarations: + def test_create_exec_tools(self): run_tool = MagicMock() run_tool._repository = MagicMock() run_tool._timeout = 300.0 - exec_tool, *_ = create_exec_tools(run_tool, session_ttl=60.0) - assert exec_tool._ttl == 60.0 + tools = create_exec_tools(run_tool) + assert len(tools) == 4 + assert isinstance(tools[0], SkillExecTool) + assert isinstance(tools[1], WriteStdinTool) + assert isinstance(tools[2], PollSessionTool) + def test_declaration_names(self): + exec_tool = _make_exec_tool() + assert exec_tool._get_declaration().name == "skill_exec" + assert WriteStdinTool(exec_tool)._get_declaration().name == "skill_write_stdin" + assert PollSessionTool(exec_tool)._get_declaration().name 
== "skill_poll_session" -# --------------------------------------------------------------------------- -# _close_session -# --------------------------------------------------------------------------- class TestCloseSession: - def test_close_with_reader_task(self): - sess = MagicMock() - sess.reader_task = MagicMock() - sess.reader_task.done = MagicMock(return_value=False) - sess.master_fd = None - sess.proc.stdin = None - _close_session(sess) - sess.reader_task.cancel.assert_called_once() - - def test_close_with_master_fd(self): - sess = MagicMock() - sess.reader_task = None - sess.master_fd = 42 - sess.proc.stdin = None - with patch("os.close") as mock_close: - _close_session(sess) - mock_close.assert_called_once_with(42) - assert sess.master_fd is None - - def test_close_with_stdin(self): - sess = MagicMock() - sess.reader_task = None - sess.master_fd = None - sess.proc.stdin = MagicMock() - sess.proc.stdin.is_closing = MagicMock(return_value=False) - _close_session(sess) - sess.proc.stdin.close.assert_called_once() - - def test_close_already_done(self): + @pytest.mark.asyncio + async def test_close_session(self): sess = MagicMock() - sess.reader_task = MagicMock() - sess.reader_task.done = MagicMock(return_value=True) - sess.master_fd = None - sess.proc.stdin = None - _close_session(sess) - sess.reader_task.cancel.assert_not_called() - - def test_close_no_reader_task(self): - sess = MagicMock() - sess.reader_task = None - sess.master_fd = None - sess.proc.stdin = None - _close_session(sess) - - def test_close_master_fd_oserror(self): - sess = MagicMock() - sess.reader_task = None - sess.master_fd = 42 - sess.proc.stdin = None - with patch("os.close", side_effect=OSError("test")): - _close_session(sess) - assert sess.master_fd is None - - def test_close_stdin_closing(self): - sess = MagicMock() - sess.reader_task = None - sess.master_fd = None - sess.proc.stdin = MagicMock() - sess.proc.stdin.is_closing = MagicMock(return_value=True) - _close_session(sess) - 
sess.proc.stdin.close.assert_not_called() - - -# --------------------------------------------------------------------------- -# _ExecSession -# --------------------------------------------------------------------------- - -class TestExecSession: - async def test_append_and_total_output(self): - from trpc_agent_sdk.skills.tools._skill_exec import _ExecSession - proc = MagicMock() - proc.returncode = None - ws = MagicMock() - in_data = ExecInput(skill="test", command="echo hello") - sess = _ExecSession(proc=proc, ws=ws, in_data=in_data) - await sess.append_output("hello ") - await sess.append_output("world") - total = await sess.total_output() - assert total == "hello world" - - async def test_yield_output_exited(self): - from trpc_agent_sdk.skills.tools._skill_exec import _ExecSession - proc = MagicMock() - proc.returncode = 0 - ws = MagicMock() - in_data = ExecInput(skill="test", command="echo hello") - sess = _ExecSession(proc=proc, ws=ws, in_data=in_data) - await sess.append_output("output text") - status, chunk, offset, next_offset = await sess.yield_output(50, 0) - assert status == "exited" - assert "output text" in chunk - assert sess.exit_code == 0 - - -# --------------------------------------------------------------------------- -# _build_exec_env -# --------------------------------------------------------------------------- - -class TestBuildExecEnv: - def test_builds_env(self, tmp_path): - from trpc_agent_sdk.skills.tools._skill_exec import _build_exec_env - ws = MagicMock() - ws.path = str(tmp_path) - env = _build_exec_env(ws, {"MY_VAR": "test"}) - assert "MY_VAR" in env - assert env["MY_VAR"] == "test" - - def test_sets_workspace_dirs(self, tmp_path): - from trpc_agent_sdk.skills.tools._skill_exec import _build_exec_env - ws = MagicMock() - ws.path = str(tmp_path) - env = _build_exec_env(ws, {}) - assert any("skills" in v.lower() or "SKILLS" in k for k, v in env.items()) - - -# --------------------------------------------------------------------------- 
-# GC expired sessions -# --------------------------------------------------------------------------- - -class TestGcExpired: - async def test_gc_expired_sessions(self): - import time - run_tool = MagicMock() - run_tool._repository = MagicMock() - run_tool._timeout = 300.0 - tool = SkillExecTool(run_tool, session_ttl=1.0) - - mock_session = MagicMock() - mock_session.exited_at = time.time() - 100 - mock_session.reader_task = None - mock_session.master_fd = None - mock_session.proc.stdin = None - tool._sessions["expired_session"] = mock_session - - # gc is called within put_session - new_session = MagicMock() - new_session.exited_at = None - await tool._put_session("new", new_session) - - assert "expired_session" not in tool._sessions - assert "new" in tool._sessions - - async def test_gc_ttl_zero_skips(self): - run_tool = MagicMock() - run_tool._repository = MagicMock() - run_tool._timeout = 300.0 - tool = SkillExecTool(run_tool, session_ttl=0) - - mock_session = MagicMock() - mock_session.exited_at = 1.0 - tool._sessions["s1"] = mock_session - - new_session = MagicMock() - new_session.exited_at = None - await tool._put_session("s2", new_session) - - assert "s1" in tool._sessions - - -# --------------------------------------------------------------------------- -# _write_stdin helper -# --------------------------------------------------------------------------- - -class TestWriteStdin: - async def test_write_pipe(self): - from trpc_agent_sdk.skills.tools._skill_exec import _write_stdin, _ExecSession - proc = MagicMock() - proc.returncode = None - proc.stdin = MagicMock() - proc.stdin.write = MagicMock() - proc.stdin.drain = AsyncMock() - ws = MagicMock() - in_data = ExecInput(skill="test", command="cat") - sess = _ExecSession(proc=proc, ws=ws, in_data=in_data) - - await _write_stdin(sess, "hello", submit=True) - proc.stdin.write.assert_called_once() - written = proc.stdin.write.call_args[0][0] - assert b"hello\n" == written - - async def 
test_write_pipe_no_submit(self): - from trpc_agent_sdk.skills.tools._skill_exec import _write_stdin, _ExecSession - proc = MagicMock() - proc.returncode = None - proc.stdin = MagicMock() - proc.stdin.write = MagicMock() - proc.stdin.drain = AsyncMock() - ws = MagicMock() - in_data = ExecInput(skill="test", command="cat") - sess = _ExecSession(proc=proc, ws=ws, in_data=in_data) - - await _write_stdin(sess, "hello", submit=False) - written = proc.stdin.write.call_args[0][0] - assert b"hello" == written - - async def test_write_pipe_error_handled(self): - from trpc_agent_sdk.skills.tools._skill_exec import _write_stdin, _ExecSession - proc = MagicMock() - proc.returncode = None - proc.stdin = MagicMock() - proc.stdin.write = MagicMock(side_effect=RuntimeError("broken")) - proc.stdin.drain = AsyncMock() - ws = MagicMock() - in_data = ExecInput(skill="test", command="cat") - sess = _ExecSession(proc=proc, ws=ws, in_data=in_data) - - await _write_stdin(sess, "test", submit=False) - - async def test_write_pty(self): - from trpc_agent_sdk.skills.tools._skill_exec import _write_stdin, _ExecSession - proc = MagicMock() - proc.returncode = None - proc.stdin = None - ws = MagicMock() - in_data = ExecInput(skill="test", command="cat") - sess = _ExecSession(proc=proc, ws=ws, in_data=in_data) - sess.master_fd = 42 - - with patch("os.write") as mock_write: - await _write_stdin(sess, "hello", submit=True) - mock_write.assert_called_once_with(42, b"hello\n") - - async def test_write_pty_error(self): - from trpc_agent_sdk.skills.tools._skill_exec import _write_stdin, _ExecSession - proc = MagicMock() - proc.returncode = None - proc.stdin = None - ws = MagicMock() - in_data = ExecInput(skill="test", command="cat") - sess = _ExecSession(proc=proc, ws=ws, in_data=in_data) - sess.master_fd = 42 - - with patch("os.write", side_effect=OSError("broken")): - await _write_stdin(sess, "hello", submit=False) - - async def test_write_no_stdin_no_pty(self): - from 
trpc_agent_sdk.skills.tools._skill_exec import _write_stdin, _ExecSession - proc = MagicMock() - proc.returncode = None - proc.stdin = None - ws = MagicMock() - in_data = ExecInput(skill="test", command="cat") - sess = _ExecSession(proc=proc, ws=ws, in_data=in_data) - await _write_stdin(sess, "hello", submit=True) - - -# --------------------------------------------------------------------------- -# _collect_final_result -# --------------------------------------------------------------------------- - -class TestCollectFinalResult: - async def test_already_finalized(self): - from trpc_agent_sdk.skills.tools._skill_exec import _collect_final_result, _ExecSession - from trpc_agent_sdk.skills.tools._skill_run import SkillRunOutput - proc = MagicMock() - proc.returncode = 0 - ws = MagicMock() - in_data = ExecInput(skill="test", command="echo") - sess = _ExecSession(proc=proc, ws=ws, in_data=in_data) - sess.finalized = True - expected = SkillRunOutput(stdout="cached") - sess.final_result = expected - - ctx = MagicMock() - run_tool = MagicMock() - result = await _collect_final_result(ctx, sess, run_tool) - assert result is expected - - async def test_collect_with_outputs(self): - from trpc_agent_sdk.skills.tools._skill_exec import _collect_final_result, _ExecSession - from trpc_agent_sdk.skills.tools._skill_run import SkillRunFile - proc = MagicMock() - proc.returncode = 0 - ws = MagicMock() - in_data = ExecInput(skill="test", command="echo", output_files=["out/*.txt"]) - sess = _ExecSession(proc=proc, ws=ws, in_data=in_data) - sess.exit_code = 0 - await sess.append_output("test output") - - ctx = MagicMock() - ctx.artifact_service = None - run_tool = MagicMock() - run_tool._prepare_outputs = AsyncMock(return_value=([], None)) - run_tool._attach_artifacts_if_requested = AsyncMock() - run_tool._merge_manifest_artifact_refs = MagicMock() - - result = await _collect_final_result(ctx, sess, run_tool) - assert result is not None - assert result.stdout == "test output" - assert 
sess.finalized is True - - async def test_collect_prepare_outputs_error(self): - from trpc_agent_sdk.skills.tools._skill_exec import _collect_final_result, _ExecSession - proc = MagicMock() - proc.returncode = 0 - ws = MagicMock() - in_data = ExecInput(skill="test", command="echo") - sess = _ExecSession(proc=proc, ws=ws, in_data=in_data) - sess.exit_code = 0 - await sess.append_output("output") - - ctx = MagicMock() - ctx.artifact_service = None - run_tool = MagicMock() - run_tool._prepare_outputs = AsyncMock(side_effect=RuntimeError("fail")) - run_tool._attach_artifacts_if_requested = AsyncMock() - run_tool._merge_manifest_artifact_refs = MagicMock() - - result = await _collect_final_result(ctx, sess, run_tool) - assert result is not None - assert sess.finalized is True - - -# --------------------------------------------------------------------------- -# _ExecSession — yield_output more tests -# --------------------------------------------------------------------------- - -class TestExecSessionYieldOutput: - async def test_yield_with_poll_lines_limit(self): - from trpc_agent_sdk.skills.tools._skill_exec import _ExecSession - proc = MagicMock() - proc.returncode = 0 - ws = MagicMock() - in_data = ExecInput(skill="test", command="echo") - sess = _ExecSession(proc=proc, ws=ws, in_data=in_data) - - long_output = "\n".join([f"line {i}" for i in range(100)]) - await sess.append_output(long_output) - - status, chunk, offset, next_offset = await sess.yield_output(10, 5) - lines = chunk.strip().split("\n") - assert len(lines) <= 5 - - async def test_yield_incremental(self): - from trpc_agent_sdk.skills.tools._skill_exec import _ExecSession - proc = MagicMock() - proc.returncode = 0 - ws = MagicMock() - in_data = ExecInput(skill="test", command="echo") - sess = _ExecSession(proc=proc, ws=ws, in_data=in_data) - - await sess.append_output("first") - _, chunk1, _, _ = await sess.yield_output(10, 0) - assert "first" in chunk1 - - await sess.append_output("second") - _, chunk2, 
_, _ = await sess.yield_output(10, 0) - assert "second" in chunk2 - assert "first" not in chunk2 - - -# --------------------------------------------------------------------------- -# _read_pipe -# --------------------------------------------------------------------------- - -class TestReadPipe: - async def test_read_pipe_empty(self): - from trpc_agent_sdk.skills.tools._skill_exec import _read_pipe, _ExecSession - proc = MagicMock() - proc.returncode = 0 - proc.wait = AsyncMock() - ws = MagicMock() - in_data = ExecInput(skill="test", command="echo") - sess = _ExecSession(proc=proc, ws=ws, in_data=in_data) - - stream = AsyncMock() - stream.read = AsyncMock(return_value=b"") - - await _read_pipe(sess, stream) - - async def test_read_pipe_with_data(self): - from trpc_agent_sdk.skills.tools._skill_exec import _read_pipe, _ExecSession - proc = MagicMock() - proc.returncode = 0 - proc.wait = AsyncMock() - ws = MagicMock() - in_data = ExecInput(skill="test", command="echo") - sess = _ExecSession(proc=proc, ws=ws, in_data=in_data) - - stream = AsyncMock() - stream.read = AsyncMock(side_effect=[b"hello", b""]) - - await _read_pipe(sess, stream) - total = await sess.total_output() - assert "hello" in total - - -# --------------------------------------------------------------------------- -# KillSessionTool._run_async_impl -# --------------------------------------------------------------------------- - -class TestKillSessionToolRun: - def _make_exec_tool(self): - run_tool = MagicMock() - run_tool._repository = MagicMock() - run_tool._timeout = 300.0 - return SkillExecTool(run_tool) - - async def test_kill_exited_session(self): - exec_tool = self._make_exec_tool() - kill_tool = KillSessionTool(exec_tool) - - mock_session = MagicMock() - mock_session.proc.returncode = 0 - mock_session.exited_at = None - mock_session.reader_task = None - mock_session.master_fd = None - mock_session.proc.stdin = None - - await exec_tool._put_session("s1", mock_session) - - ctx = MagicMock() - 
result = await kill_tool._run_async_impl( - tool_context=ctx, - args={"session_id": "s1"}, - ) - assert result["ok"] is True - assert result["status"] == "exited" - - async def test_kill_running_session(self): - exec_tool = self._make_exec_tool() - kill_tool = KillSessionTool(exec_tool) - - mock_session = MagicMock() - mock_session.proc.returncode = None - mock_session.proc.kill = MagicMock() - mock_session.proc.wait = AsyncMock() - mock_session.exited_at = None - mock_session.reader_task = None - mock_session.master_fd = None - mock_session.proc.stdin = None - - await exec_tool._put_session("s1", mock_session) - - ctx = MagicMock() - result = await kill_tool._run_async_impl( - tool_context=ctx, - args={"session_id": "s1"}, - ) - assert result["ok"] is True - assert result["status"] == "killed" - mock_session.proc.kill.assert_called_once() - - async def test_kill_invalid_args_raises(self): - exec_tool = self._make_exec_tool() - kill_tool = KillSessionTool(exec_tool) - ctx = MagicMock() - with pytest.raises(ValueError, match="Invalid"): - await kill_tool._run_async_impl(tool_context=ctx, args={}) - - async def test_kill_unknown_session_raises(self): - exec_tool = self._make_exec_tool() - kill_tool = KillSessionTool(exec_tool) - ctx = MagicMock() - with pytest.raises(ValueError, match="unknown session_id"): - await kill_tool._run_async_impl( - tool_context=ctx, - args={"session_id": "nonexistent"}, - ) - - -# --------------------------------------------------------------------------- -# PollSessionTool._run_async_impl -# --------------------------------------------------------------------------- - -class TestPollSessionToolRun: - async def test_poll_running_session(self): - run_tool = MagicMock() - run_tool._repository = MagicMock() - run_tool._timeout = 300.0 - exec_tool = SkillExecTool(run_tool) - poll_tool = PollSessionTool(exec_tool) - - from trpc_agent_sdk.skills.tools._skill_exec import _ExecSession - proc = MagicMock() - proc.returncode = 0 - ws = MagicMock() 
- in_data = ExecInput(skill="test", command="echo") - sess = _ExecSession(proc=proc, ws=ws, in_data=in_data) - await sess.append_output("poll output") - - await exec_tool._put_session("s1", sess) - - ctx = MagicMock() - ctx.artifact_service = None - result = await poll_tool._run_async_impl( - tool_context=ctx, - args={"session_id": "s1", "yield_ms": 10}, - ) - assert result["status"] == "exited" - assert "poll output" in result["output"] - - async def test_poll_invalid_args_raises(self): - run_tool = MagicMock() - run_tool._repository = MagicMock() - run_tool._timeout = 300.0 - exec_tool = SkillExecTool(run_tool) - poll_tool = PollSessionTool(exec_tool) - ctx = MagicMock() - with pytest.raises(ValueError, match="Invalid"): - await poll_tool._run_async_impl(tool_context=ctx, args={}) - - -# --------------------------------------------------------------------------- -# WriteStdinTool._run_async_impl -# --------------------------------------------------------------------------- - -class TestWriteStdinToolRun: - async def test_write_and_poll(self): - run_tool = MagicMock() - run_tool._repository = MagicMock() - run_tool._timeout = 300.0 - exec_tool = SkillExecTool(run_tool) - write_tool = WriteStdinTool(exec_tool) - - from trpc_agent_sdk.skills.tools._skill_exec import _ExecSession - proc = MagicMock() - proc.returncode = 0 - proc.stdin = MagicMock() - proc.stdin.write = MagicMock() - proc.stdin.drain = AsyncMock() - ws = MagicMock() - in_data = ExecInput(skill="test", command="cat") - sess = _ExecSession(proc=proc, ws=ws, in_data=in_data) - await sess.append_output("response") - - await exec_tool._put_session("s1", sess) - - ctx = MagicMock() - ctx.artifact_service = None - result = await write_tool._run_async_impl( - tool_context=ctx, - args={"session_id": "s1", "chars": "input", "submit": True, "yield_ms": 10}, - ) - assert result["status"] == "exited" - - async def test_write_empty_chars(self): - run_tool = MagicMock() - run_tool._repository = MagicMock() - 
run_tool._timeout = 300.0 - exec_tool = SkillExecTool(run_tool) - write_tool = WriteStdinTool(exec_tool) - - from trpc_agent_sdk.skills.tools._skill_exec import _ExecSession - proc = MagicMock() - proc.returncode = 0 - ws = MagicMock() - in_data = ExecInput(skill="test", command="cat") - sess = _ExecSession(proc=proc, ws=ws, in_data=in_data) - - await exec_tool._put_session("s1", sess) - - ctx = MagicMock() - ctx.artifact_service = None - result = await write_tool._run_async_impl( - tool_context=ctx, - args={"session_id": "s1", "yield_ms": 10}, - ) - assert "status" in result - - async def test_write_invalid_args_raises(self): - run_tool = MagicMock() - run_tool._repository = MagicMock() - run_tool._timeout = 300.0 - exec_tool = SkillExecTool(run_tool) - write_tool = WriteStdinTool(exec_tool) - ctx = MagicMock() - with pytest.raises(ValueError, match="Invalid"): - await write_tool._run_async_impl(tool_context=ctx, args={}) + sess.proc.close = AsyncMock() + await _close_session(sess) + sess.proc.close.assert_awaited_once() diff --git a/tests/skills/tools/test_skill_list_docs.py b/tests/skills/tools/test_skill_list_docs.py index 33e6a76..26999d9 100644 --- a/tests/skills/tools/test_skill_list_docs.py +++ b/tests/skills/tools/test_skill_list_docs.py @@ -42,8 +42,7 @@ def test_returns_docs(self): ctx = _make_ctx(repository=repo) result = skill_list_docs(ctx, "test") - assert result["docs"] == ["guide.md", "api.md"] - assert result["body_loaded"] is True + assert result == ["guide.md", "api.md"] def test_no_body(self): skill = Skill(summary=SkillSummary(name="test"), body="") @@ -52,17 +51,16 @@ def test_no_body(self): ctx = _make_ctx(repository=repo) result = skill_list_docs(ctx, "test") - assert result["body_loaded"] is False + assert result == [] def test_skill_not_found(self): repo = MagicMock() - repo.get = MagicMock(return_value=None) + repo.get = MagicMock(side_effect=ValueError("not found")) ctx = _make_ctx(repository=repo) - result = skill_list_docs(ctx, 
"nonexistent") - assert result == {"docs": [], "body_loaded": False} + with pytest.raises(ValueError, match="unknown skill"): + skill_list_docs(ctx, "nonexistent") def test_no_repository_raises(self): ctx = _make_ctx(repository=None) - with pytest.raises(ValueError, match="repository not found"): - skill_list_docs(ctx, "test") + assert skill_list_docs(ctx, "test") == [] diff --git a/tests/skills/tools/test_skill_list_tool.py b/tests/skills/tools/test_skill_list_tool.py index 38769d4..220524a 100644 --- a/tests/skills/tools/test_skill_list_tool.py +++ b/tests/skills/tools/test_skill_list_tool.py @@ -3,13 +3,7 @@ # Copyright (C) 2026 Tencent. All rights reserved. # # tRPC-Agent-Python is licensed under Apache-2.0. -"""Unit tests for trpc_agent_sdk.skills.tools._skill_list_tool. - -Covers: -- _extract_shell_examples_from_skill_body: Command section parsing -- skill_list_tools: returns tools and command examples -- skill_list_tools: handles missing skill / repository -""" +"""Unit tests for trpc_agent_sdk.skills.tools._skill_list_tool.""" from __future__ import annotations @@ -19,79 +13,10 @@ from trpc_agent_sdk.skills._types import Skill, SkillSummary from trpc_agent_sdk.skills.tools._skill_list_tool import ( - _extract_shell_examples_from_skill_body, skill_list_tools, ) -# --------------------------------------------------------------------------- -# _extract_shell_examples_from_skill_body -# --------------------------------------------------------------------------- - -class TestExtractShellExamples: - def test_empty_body(self): - assert _extract_shell_examples_from_skill_body("") == [] - - def test_command_section(self): - body = "Command:\n python scripts/run.py --input data.csv\n\nOverview" - result = _extract_shell_examples_from_skill_body(body) - assert len(result) >= 1 - assert "python scripts/run.py" in result[0] - - def test_limit(self): - body = "" - for i in range(10): - body += f"Command:\n cmd_{i} --arg\n\n" - result = 
_extract_shell_examples_from_skill_body(body, limit=3) - assert len(result) <= 3 - - def test_stops_at_section_break(self): - body = "Command:\n python run.py\n\nOutput files\nMore content" - result = _extract_shell_examples_from_skill_body(body) - assert len(result) == 1 - - def test_multiline_command(self): - body = "Command:\n python scripts/long.py \\\n --arg1 val1 \\\n --arg2 val2\n\n" - result = _extract_shell_examples_from_skill_body(body) - assert len(result) >= 1 - - def test_no_command_section(self): - body = "# Overview\nJust a description.\n" - result = _extract_shell_examples_from_skill_body(body) - assert result == [] - - def test_deduplication(self): - body = "Command:\n python run.py\n\nCommand:\n python run.py\n" - result = _extract_shell_examples_from_skill_body(body) - assert len(result) == 1 - - def test_rejects_non_command_starting_chars(self): - body = "Command:\n !not_a_command\n\n" - result = _extract_shell_examples_from_skill_body(body) - assert result == [] - - def test_command_with_numbered_break(self): - body = "Command:\n python run.py\n\n1) Next section\n" - result = _extract_shell_examples_from_skill_body(body) - assert len(result) == 1 - - def test_skips_empty_lines_before_command(self): - body = "Command:\n\n\n python run.py\n\nEnd" - result = _extract_shell_examples_from_skill_body(body) - assert len(result) >= 1 - - def test_stops_at_tools_section(self): - body = "Command:\n python run.py\n\ntools:\n- tool1" - result = _extract_shell_examples_from_skill_body(body) - assert len(result) == 1 - - def test_whitespace_normalization(self): - body = "Command:\n python run.py --arg val\n\n" - result = _extract_shell_examples_from_skill_body(body) - assert len(result) == 1 - assert " " not in result[0] - - # --------------------------------------------------------------------------- # skill_list_tools # --------------------------------------------------------------------------- @@ -103,7 +28,7 @@ def _make_ctx(repository=None): class 
TestSkillListTools: - def test_returns_tools_and_examples(self): + def test_returns_tools(self): skill = Skill( summary=SkillSummary(name="test"), body="Command:\n python run.py\n\nOverview", @@ -114,8 +39,7 @@ def test_returns_tools_and_examples(self): ctx = _make_ctx(repository=repo) result = skill_list_tools(ctx, "test") - assert result["tools"] == ["get_weather", "get_data"] - assert len(result["command_examples"]) >= 1 + assert result["available_tools"] == ["get_weather", "get_data"] def test_skill_not_found(self): repo = MagicMock() @@ -123,7 +47,7 @@ def test_skill_not_found(self): ctx = _make_ctx(repository=repo) result = skill_list_tools(ctx, "nonexistent") - assert result == {"tools": [], "command_examples": []} + assert result == {"available_tools": []} def test_no_repository_raises(self): ctx = _make_ctx(repository=None) @@ -137,4 +61,4 @@ def test_no_tools_or_examples(self): ctx = _make_ctx(repository=repo) result = skill_list_tools(ctx, "test") - assert result["tools"] == [] + assert result["available_tools"] == [] diff --git a/tests/skills/tools/test_skill_load.py b/tests/skills/tools/test_skill_load.py index 3bb87f5..7c4987b 100644 --- a/tests/skills/tools/test_skill_load.py +++ b/tests/skills/tools/test_skill_load.py @@ -14,17 +14,22 @@ from __future__ import annotations +import asyncio import json +from unittest.mock import AsyncMock from unittest.mock import MagicMock +from unittest.mock import patch import pytest +from trpc_agent_sdk.skills._common import docs_state_key +from trpc_agent_sdk.skills._common import loaded_state_key +from trpc_agent_sdk.skills._common import set_state_delta +from trpc_agent_sdk.skills._common import tool_state_key +from trpc_agent_sdk.skills._constants import SKILL_REPOSITORY_KEY from trpc_agent_sdk.skills._types import Skill, SkillSummary from trpc_agent_sdk.skills.tools._skill_load import ( - _set_state_delta, - _set_state_delta_for_skill_load, - _set_state_delta_for_skill_tools, - skill_load, + SkillLoadTool, ) 
@@ -32,9 +37,44 @@ def _make_ctx(repository=None): ctx = MagicMock() ctx.actions.state_delta = {} ctx.agent_context.get_metadata = MagicMock(return_value=repository) + ctx.agent_name = "" return ctx +def _set_state_delta_for_skill_load(ctx, skill_name: str, docs: list[str], include_all_docs: bool = False): + set_state_delta(ctx, loaded_state_key(ctx, skill_name), True) + set_state_delta( + ctx, + docs_state_key(ctx, skill_name), + "*" if include_all_docs else json.dumps(docs or []), + ) + + +def _set_state_delta_for_skill_tools(ctx, skill_name: str, tools: list[str]): + set_state_delta(ctx, tool_state_key(ctx, skill_name), json.dumps(tools or [])) + + +def _set_state_delta(ctx, key: str, value: str): + set_state_delta(ctx, key, value) + + +def skill_load(ctx, skill_name: str, docs: list[str] | None = None, include_all_docs: bool = False) -> str: + repository = ctx.agent_context.get_metadata(SKILL_REPOSITORY_KEY) + if repository is None: + raise ValueError("repository not found") + tool = SkillLoadTool(repository=repository) + with patch.object(SkillLoadTool, "_ensure_staged", new=AsyncMock(return_value=None)): + return asyncio.run( + tool._run_async_impl( + tool_context=ctx, + args={ + "skill_name": skill_name, + "docs": docs or [], + "include_all_docs": include_all_docs, + }, + )) + + # --------------------------------------------------------------------------- # _set_state_delta # --------------------------------------------------------------------------- @@ -55,21 +95,24 @@ class TestSetStateDeltaForSkillLoad: def test_sets_loaded_flag(self): ctx = MagicMock() ctx.actions.state_delta = {} + ctx.agent_name = "" _set_state_delta_for_skill_load(ctx, "test-skill", []) - assert ctx.actions.state_delta["temp:skill:loaded:test-skill"] is True + assert ctx.actions.state_delta[loaded_state_key(ctx, "test-skill")] is True def test_sets_docs_as_json(self): ctx = MagicMock() ctx.actions.state_delta = {} + ctx.agent_name = "" _set_state_delta_for_skill_load(ctx, 
"test-skill", ["doc1.md", "doc2.md"]) - docs_value = ctx.actions.state_delta["temp:skill:docs:test-skill"] + docs_value = ctx.actions.state_delta[docs_state_key(ctx, "test-skill")] assert json.loads(docs_value) == ["doc1.md", "doc2.md"] def test_include_all_docs_sets_star(self): ctx = MagicMock() ctx.actions.state_delta = {} + ctx.agent_name = "" _set_state_delta_for_skill_load(ctx, "test-skill", [], include_all_docs=True) - assert ctx.actions.state_delta["temp:skill:docs:test-skill"] == "*" + assert ctx.actions.state_delta[docs_state_key(ctx, "test-skill")] == "*" # --------------------------------------------------------------------------- @@ -80,15 +123,17 @@ class TestSetStateDeltaForSkillTools: def test_sets_tools_as_json(self): ctx = MagicMock() ctx.actions.state_delta = {} + ctx.agent_name = "" _set_state_delta_for_skill_tools(ctx, "test-skill", ["tool_a", "tool_b"]) - tools_value = ctx.actions.state_delta["temp:skill:tools:test-skill"] + tools_value = ctx.actions.state_delta[tool_state_key(ctx, "test-skill")] assert json.loads(tools_value) == ["tool_a", "tool_b"] def test_empty_tools(self): ctx = MagicMock() ctx.actions.state_delta = {} + ctx.agent_name = "" _set_state_delta_for_skill_tools(ctx, "test-skill", []) - tools_value = ctx.actions.state_delta["temp:skill:tools:test-skill"] + tools_value = ctx.actions.state_delta[tool_state_key(ctx, "test-skill")] assert json.loads(tools_value) == [] @@ -105,15 +150,15 @@ def test_load_success(self): result = skill_load(ctx, "test") assert "loaded" in result - assert ctx.actions.state_delta["temp:skill:loaded:test"] is True + assert ctx.actions.state_delta[loaded_state_key(ctx, "test")] is True def test_load_not_found(self): repo = MagicMock() - repo.get = MagicMock(return_value=None) + repo.get = MagicMock(side_effect=ValueError("not found")) ctx = _make_ctx(repository=repo) - result = skill_load(ctx, "nonexistent") - assert "not found" in result + with pytest.raises(ValueError, match="not found"): + 
skill_load(ctx, "nonexistent") def test_load_no_repository_raises(self): ctx = _make_ctx(repository=None) @@ -131,7 +176,7 @@ def test_load_with_tools_sets_tools_state(self): ctx = _make_ctx(repository=repo) skill_load(ctx, "test") - tools_key = "temp:skill:tools:test" + tools_key = tool_state_key(ctx, "test") assert tools_key in ctx.actions.state_delta assert json.loads(ctx.actions.state_delta[tools_key]) == ["get_weather", "get_data"] @@ -142,7 +187,7 @@ def test_load_without_tools_does_not_set_tools_state(self): ctx = _make_ctx(repository=repo) skill_load(ctx, "test") - assert "temp:skill:tools:test" not in ctx.actions.state_delta + assert tool_state_key(ctx, "test") not in ctx.actions.state_delta def test_load_with_docs(self): skill = Skill(summary=SkillSummary(name="test"), body="# Body") @@ -151,7 +196,7 @@ def test_load_with_docs(self): ctx = _make_ctx(repository=repo) skill_load(ctx, "test", docs=["doc1.md"]) - docs_key = "temp:skill:docs:test" + docs_key = docs_state_key(ctx, "test") assert json.loads(ctx.actions.state_delta[docs_key]) == ["doc1.md"] def test_load_include_all_docs(self): @@ -161,5 +206,5 @@ def test_load_include_all_docs(self): ctx = _make_ctx(repository=repo) skill_load(ctx, "test", include_all_docs=True) - docs_key = "temp:skill:docs:test" + docs_key = docs_state_key(ctx, "test") assert ctx.actions.state_delta[docs_key] == "*" diff --git a/tests/skills/tools/test_skill_run.py b/tests/skills/tools/test_skill_run.py index c7489c2..b1e43ae 100644 --- a/tests/skills/tools/test_skill_run.py +++ b/tests/skills/tools/test_skill_run.py @@ -3,1050 +3,122 @@ # Copyright (C) 2026 Tencent. All rights reserved. # # tRPC-Agent-Python is licensed under Apache-2.0. -"""Unit tests for trpc_agent_sdk.skills.tools._skill_run. 
- -Covers: -- Module-level helpers: - _inline_json_schema_refs, _is_text_mime, _should_inline_file_content, - _truncate_output, _workspace_ref, _filter_failed_empty_outputs, - _select_primary_output, _split_command_line, _build_editor_wrapper_script -- Pydantic models: SkillRunFile, SkillRunInput, SkillRunOutput, ArtifactInfo -- SkillRunTool: - _resolve_cwd, _build_command, _wrap_with_venv, _is_skill_loaded, - _get_repository, _is_missing_command_result, _extract_command_path_candidates, - _extract_shell_examples_from_skill_body, _with_missing_command_hint -""" - -from __future__ import annotations import os -from pathlib import Path -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock import pytest - -from trpc_agent_sdk.skills.tools._skill_run import ( - ArtifactInfo, - SkillRunFile, - SkillRunInput, - SkillRunOutput, - SkillRunTool, - _build_editor_wrapper_script, - _filter_failed_empty_outputs, - _inline_json_schema_refs, - _is_text_mime, - _select_primary_output, - _should_inline_file_content, - _split_command_line, - _truncate_output, - _workspace_ref, -) - - -# --------------------------------------------------------------------------- -# _inline_json_schema_refs -# --------------------------------------------------------------------------- - -class TestInlineJsonSchemaRefs: - def test_no_refs(self): - schema = {"type": "object", "properties": {"name": {"type": "string"}}} - result = _inline_json_schema_refs(schema) - assert result == schema - - def test_with_refs(self): - schema = { - "$defs": {"Foo": {"type": "object", "properties": {"x": {"type": "integer"}}}}, - "properties": {"foo": {"$ref": "#/$defs/Foo"}}, - } - result = _inline_json_schema_refs(schema) - assert "$defs" not in result - assert result["properties"]["foo"]["type"] == "object" - - def test_nested_refs(self): - schema = { - "$defs": {"Bar": {"type": "string"}}, - "properties": {"items": {"type": "array", "items": {"$ref": "#/$defs/Bar"}}}, - } - 
result = _inline_json_schema_refs(schema) - assert result["properties"]["items"]["items"]["type"] == "string" - - -# --------------------------------------------------------------------------- -# _is_text_mime -# --------------------------------------------------------------------------- - -class TestIsTextMime: - def test_text_plain(self): +from trpc_agent_sdk.skills._common import loaded_state_key +from trpc_agent_sdk.skills.tools._common import inline_json_schema_refs +from trpc_agent_sdk.skills.tools._skill_run import ArtifactInfo +from trpc_agent_sdk.skills.tools._skill_run import SkillRunFile +from trpc_agent_sdk.skills.tools._skill_run import SkillRunInput +from trpc_agent_sdk.skills.tools._skill_run import SkillRunOutput +from trpc_agent_sdk.skills.tools._skill_run import SkillRunTool +from trpc_agent_sdk.skills.tools._skill_run import _build_editor_wrapper_script +from trpc_agent_sdk.skills.tools._skill_run import _filter_failed_empty_outputs +from trpc_agent_sdk.skills.tools._skill_run import _is_text_mime +from trpc_agent_sdk.skills.tools._skill_run import _select_primary_output +from trpc_agent_sdk.skills.tools._skill_run import _should_inline_file_content +from trpc_agent_sdk.skills.tools._skill_run import _split_command_line +from trpc_agent_sdk.skills.tools._skill_run import _truncate_output +from trpc_agent_sdk.skills.tools._skill_run import _workspace_ref + + +def _make_tool() -> SkillRunTool: + repo = MagicMock() + repo.workspace_runtime = MagicMock() + return SkillRunTool(repository=repo) + + +class TestSchemaHelpers: + def test_inline_json_schema_refs(self): + schema = {"$defs": {"X": {"type": "string"}}, "properties": {"x": {"$ref": "#/$defs/X"}}} + out = inline_json_schema_refs(schema) + assert "$defs" not in out + assert out["properties"]["x"]["type"] == "string" + + +class TestModuleHelpers: + def test_is_text_mime(self): assert _is_text_mime("text/plain") is True - - def test_text_html(self): - assert _is_text_mime("text/html") is True - - 
def test_application_json(self): assert _is_text_mime("application/json") is True - - def test_application_yaml(self): - assert _is_text_mime("application/yaml") is True - - def test_application_xml(self): - assert _is_text_mime("application/xml") is True - - def test_image_png(self): assert _is_text_mime("image/png") is False - def test_empty_string_is_text(self): - assert _is_text_mime("") is True - - def test_with_charset(self): - assert _is_text_mime("application/json; charset=utf-8") is True - - def test_octet_stream(self): - assert _is_text_mime("application/octet-stream") is False - - -# --------------------------------------------------------------------------- -# _should_inline_file_content -# --------------------------------------------------------------------------- - -class TestShouldInlineFileContent: - def test_text_file(self): - from trpc_agent_sdk.code_executors import CodeFile - f = CodeFile(name="test.txt", content="hello", mime_type="text/plain", size_bytes=5) - assert _should_inline_file_content(f) is True - - def test_binary_file(self): + def test_should_inline_file_content(self): from trpc_agent_sdk.code_executors import CodeFile - f = CodeFile(name="test.png", content="data", mime_type="image/png", size_bytes=4) - assert _should_inline_file_content(f) is False - - def test_null_bytes_rejected(self): - from trpc_agent_sdk.code_executors import CodeFile - f = CodeFile(name="test.txt", content="hello\x00world", mime_type="text/plain", size_bytes=11) - assert _should_inline_file_content(f) is False - - def test_empty_content(self): - from trpc_agent_sdk.code_executors import CodeFile - f = CodeFile(name="empty.txt", content="", mime_type="text/plain", size_bytes=0) + f = CodeFile(name="a.txt", content="ok", mime_type="text/plain", size_bytes=2) assert _should_inline_file_content(f) is True - -# --------------------------------------------------------------------------- -# _truncate_output -# 
--------------------------------------------------------------------------- - -class TestTruncateOutput: - def test_short_string(self): - s, truncated = _truncate_output("hello") - assert s == "hello" - assert truncated is False - - def test_long_string(self): + def test_truncate_output(self): s, truncated = _truncate_output("x" * 20000) assert truncated is True assert len(s) <= 16 * 1024 - def test_exact_limit(self): - s, truncated = _truncate_output("x" * (16 * 1024)) - assert truncated is False - + def test_workspace_ref(self): + assert _workspace_ref("a.txt") == "workspace://a.txt" -# --------------------------------------------------------------------------- -# _workspace_ref -# --------------------------------------------------------------------------- + def test_filter_failed_empty_outputs(self): + files = [SkillRunFile(name="a.txt", content="", size_bytes=0)] + kept, warns = _filter_failed_empty_outputs(1, False, files) + assert kept == [] + assert warns -class TestWorkspaceRef: - def test_with_name(self): - assert _workspace_ref("out/file.txt") == "workspace://out/file.txt" - - def test_empty_name(self): - assert _workspace_ref("") == "" - - -# --------------------------------------------------------------------------- -# _filter_failed_empty_outputs -# --------------------------------------------------------------------------- - -class TestFilterFailedEmptyOutputs: - def test_success_no_filter(self): - files = [SkillRunFile(name="a.txt", content="data", size_bytes=4)] - result, warns = _filter_failed_empty_outputs(0, False, files) - assert len(result) == 1 - assert warns == [] - - def test_failure_removes_empty(self): - files = [ - SkillRunFile(name="a.txt", content="data", size_bytes=4), - SkillRunFile(name="empty.txt", content="", size_bytes=0), - ] - result, warns = _filter_failed_empty_outputs(1, False, files) - assert len(result) == 1 - assert result[0].name == "a.txt" - assert len(warns) == 1 - - def test_failure_all_have_content(self): - files = 
[SkillRunFile(name="a.txt", content="data", size_bytes=4)] - result, warns = _filter_failed_empty_outputs(1, False, files) - assert len(result) == 1 - assert warns == [] - - def test_timeout_removes_empty(self): - files = [SkillRunFile(name="e.txt", content="", size_bytes=0)] - result, warns = _filter_failed_empty_outputs(0, True, files) - assert len(result) == 0 - assert len(warns) == 1 - - -# --------------------------------------------------------------------------- -# _select_primary_output -# --------------------------------------------------------------------------- - -class TestSelectPrimaryOutput: - def test_picks_smallest_text_file(self): + def test_select_primary_output(self): files = [ - SkillRunFile(name="b.txt", content="bbb", mime_type="text/plain"), - SkillRunFile(name="a.txt", content="aaa", mime_type="text/plain"), + SkillRunFile(name="b.txt", content="2", mime_type="text/plain"), + SkillRunFile(name="a.txt", content="1", mime_type="text/plain"), ] - result = _select_primary_output(files) - assert result.name == "a.txt" - - def test_skips_empty_content(self): - files = [SkillRunFile(name="a.txt", content="", mime_type="text/plain")] - result = _select_primary_output(files) - assert result is None - - def test_skips_binary(self): - files = [SkillRunFile(name="a.png", content="data", mime_type="image/png")] - result = _select_primary_output(files) - assert result is None - - def test_empty_files(self): - assert _select_primary_output([]) is None - - def test_skips_too_large(self): - content = "x" * (32 * 1024 + 1) - files = [SkillRunFile(name="big.txt", content=content, mime_type="text/plain")] - assert _select_primary_output(files) is None + best = _select_primary_output(files) + assert best is not None + assert best.name == "a.txt" - -# --------------------------------------------------------------------------- -# _split_command_line -# --------------------------------------------------------------------------- - -class TestSplitCommandLine: - def 
test_simple_command(self): + def test_split_command_line(self): assert _split_command_line("python run.py") == ["python", "run.py"] + with pytest.raises(ValueError): + _split_command_line("a | b") - def test_quoted_args(self): - result = _split_command_line("echo 'hello world'") - assert result == ["echo", "hello world"] - - def test_double_quoted(self): - result = _split_command_line('echo "hello world"') - assert result == ["echo", "hello world"] - - def test_empty_raises(self): - with pytest.raises(ValueError, match="empty"): - _split_command_line("") - - def test_whitespace_only_raises(self): - with pytest.raises(ValueError, match="empty"): - _split_command_line(" ") - - def test_shell_metachar_rejected(self): - with pytest.raises(ValueError, match="metacharacter"): - _split_command_line("cmd1 | cmd2") - - def test_semicolon_rejected(self): - with pytest.raises(ValueError, match="metacharacter"): - _split_command_line("cmd1; cmd2") - - def test_redirect_rejected(self): - with pytest.raises(ValueError, match="metacharacter"): - _split_command_line("cmd > file") - - def test_unterminated_quote_raises(self): - with pytest.raises(ValueError, match="unterminated"): - _split_command_line("echo 'hello") - - def test_trailing_escape_raises(self): - with pytest.raises(ValueError, match="trailing"): - _split_command_line("echo hello\\") - - def test_escaped_char(self): - result = _split_command_line("echo hello\\ world") - assert result == ["echo", "hello world"] - - def test_newline_rejected(self): - with pytest.raises(ValueError, match="metacharacter"): - _split_command_line("cmd1\ncmd2") - - -# --------------------------------------------------------------------------- -# _build_editor_wrapper_script -# --------------------------------------------------------------------------- - -class TestBuildEditorWrapperScript: - def test_contains_shebang(self): - script = _build_editor_wrapper_script("/tmp/content.txt") + def test_build_editor_wrapper_script(self): + script = 
_build_editor_wrapper_script("/tmp/file") assert script.startswith("#!/bin/sh") + assert "/tmp/file" in script - def test_contains_content_path(self): - script = _build_editor_wrapper_script("/tmp/content.txt") - assert "/tmp/content.txt" in script - -# --------------------------------------------------------------------------- -# Pydantic models -# --------------------------------------------------------------------------- - -class TestSkillRunModels: - def test_skill_run_file_defaults(self): - f = SkillRunFile() - assert f.name == "" - assert f.content == "" - assert f.size_bytes == 0 - - def test_skill_run_input_required_fields(self): - inp = SkillRunInput(skill="test", command="python run.py") - assert inp.skill == "test" - assert inp.command == "python run.py" - assert inp.env == {} - assert inp.output_files == [] - - def test_skill_run_output_defaults(self): +class TestModels: + def test_run_models(self): + inp = SkillRunInput(skill="s", command="echo hi") out = SkillRunOutput() - assert out.stdout == "" + art = ArtifactInfo(name="a.txt", version=1) + assert inp.skill == "s" assert out.exit_code == 0 - assert out.timed_out is False - assert out.output_files == [] - - def test_artifact_info(self): - a = ArtifactInfo(name="test.txt", version=1) - assert a.name == "test.txt" - assert a.version == 1 - - -# --------------------------------------------------------------------------- -# SkillRunTool — helpers -# --------------------------------------------------------------------------- - -class TestSkillRunToolHelpers: - def _make_run_tool(self, **kwargs): - repo = MagicMock() - repo.workspace_runtime = MagicMock() - return SkillRunTool(repository=repo, **kwargs) - - def test_resolve_cwd_empty(self): - tool = self._make_run_tool() - assert tool._resolve_cwd("", "skills/test") == "skills/test" - - def test_resolve_cwd_relative(self): - tool = self._make_run_tool() - result = tool._resolve_cwd("sub/dir", "skills/test") - assert result == os.path.join("skills/test", 
"sub/dir") + assert art.version == 1 - def test_resolve_cwd_absolute(self): - tool = self._make_run_tool() - assert tool._resolve_cwd("/abs/path", "skills/test") == "/abs/path" - def test_resolve_cwd_skills_dir_env(self): - tool = self._make_run_tool() - result = tool._resolve_cwd("$SKILLS_DIR/test", "skills/test") - assert "skills" in result +class TestSkillRunToolBasics: + def test_resolve_cwd(self): + tool = _make_tool() + assert tool._resolve_cwd("", "skills/x") == "skills/x" + assert tool._resolve_cwd("sub", "skills/x") == os.path.join("skills/x", "sub") - def test_build_command_no_restrictions(self): - tool = self._make_run_tool() - cmd, args = tool._build_command("python run.py", "/ws", "skills/test") + def test_build_command(self): + tool = _make_tool() + cmd, args = tool._build_command("python run.py", "/tmp/ws", "skills/x") assert cmd == "bash" assert "-c" in args - def test_build_command_with_allowed_cmds(self): - tool = self._make_run_tool(allowed_cmds=["python"]) - cmd, args = tool._build_command("python run.py", "/ws", "skills/test") - assert cmd == "python" - assert args == ["run.py"] - - def test_build_command_denied_cmd_raises(self): - tool = self._make_run_tool(denied_cmds=["rm"]) - with pytest.raises(ValueError, match="denied"): - tool._build_command("rm -rf /", "/ws", "skills/test") - - def test_build_command_not_in_allowed_raises(self): - tool = self._make_run_tool(allowed_cmds=["python"]) - with pytest.raises(ValueError, match="not in allowed"): - tool._build_command("bash script.sh", "/ws", "skills/test") - - def test_is_skill_loaded_true(self): - tool = self._make_run_tool() - ctx = MagicMock() - ctx.session_state = {"temp:skill:loaded:test": True} - assert tool._is_skill_loaded(ctx, "test") is True - - def test_is_skill_loaded_false(self): - tool = self._make_run_tool() - ctx = MagicMock() - ctx.session_state = {} - assert tool._is_skill_loaded(ctx, "test") is False - - def test_is_skill_loaded_exception_defaults_true(self): - tool = 
self._make_run_tool() - ctx = MagicMock() - ctx.session_state = MagicMock() - ctx.session_state.get = MagicMock(side_effect=RuntimeError("oops")) - assert tool._is_skill_loaded(ctx, "test") is True - - def test_get_repository_from_tool(self): + def test_get_repository(self): repo = MagicMock() repo.workspace_runtime = MagicMock() tool = SkillRunTool(repository=repo) ctx = MagicMock() assert tool._get_repository(ctx) is repo - def test_is_missing_command_result_true(self): - from trpc_agent_sdk.code_executors import WorkspaceRunResult - ret = WorkspaceRunResult( - stdout="", stderr="bash: foo: command not found", - exit_code=127, duration=0, timed_out=False, - ) - assert SkillRunTool._is_missing_command_result(ret) is True - - def test_is_missing_command_result_false(self): - from trpc_agent_sdk.code_executors import WorkspaceRunResult - ret = WorkspaceRunResult( - stdout="", stderr="error", exit_code=1, duration=0, timed_out=False, - ) - assert SkillRunTool._is_missing_command_result(ret) is False - - -# --------------------------------------------------------------------------- -# SkillRunTool — _extract_shell_examples_from_skill_body -# --------------------------------------------------------------------------- - -class TestExtractShellExamplesFromBody: - def test_fenced_code_block(self): - body = "# Usage\n```\npython scripts/run.py --input data.csv\n```\n" - result = SkillRunTool._extract_shell_examples_from_skill_body(body) - assert any("python" in r for r in result) - - def test_command_section(self): - body = "Command:\n python scripts/analyze.py\n\nOverview" - result = SkillRunTool._extract_shell_examples_from_skill_body(body) - assert len(result) >= 1 - - def test_empty_body(self): - assert SkillRunTool._extract_shell_examples_from_skill_body("") == [] - - def test_limit(self): - body = "" - for i in range(10): - body += f"```\ncmd_{i}\n```\n" - result = SkillRunTool._extract_shell_examples_from_skill_body(body, limit=3) - assert len(result) <= 3 - - def 
test_skips_function_calls(self): - body = "```\nmy_function(arg='value')\n```\n" - result = SkillRunTool._extract_shell_examples_from_skill_body(body) - assert not any("my_function" in r for r in result) - - -# --------------------------------------------------------------------------- -# SkillRunTool — _extract_command_path_candidates -# --------------------------------------------------------------------------- - -class TestExtractCommandPathCandidates: - def test_relative_path(self): - result = SkillRunTool._extract_command_path_candidates("python scripts/run.py") - assert "scripts/run.py" in result - - def test_no_path_like_tokens(self): - result = SkillRunTool._extract_command_path_candidates("ls -la") - assert result == [] - - def test_absolute_path_excluded(self): - result = SkillRunTool._extract_command_path_candidates("python /abs/path.py") - assert result == [] - - def test_flags_excluded(self): - result = SkillRunTool._extract_command_path_candidates("python --version") - assert result == [] - - def test_script_extensions(self): - result = SkillRunTool._extract_command_path_candidates("bash setup.sh") - assert "setup.sh" in result - - -# --------------------------------------------------------------------------- -# SkillRunTool — _with_missing_command_hint -# --------------------------------------------------------------------------- - -class TestWithMissingCommandHint: - def test_adds_hint_for_missing_command(self): - from trpc_agent_sdk.code_executors import WorkspaceRunResult - ret = WorkspaceRunResult( - stdout="", stderr="bash: nonexist: command not found", - exit_code=127, duration=0, timed_out=False, - ) - inp = SkillRunInput(skill="test", command="nonexist") - updated = SkillRunTool._with_missing_command_hint(ret, inp) - assert "hint" in updated.stderr.lower() - - def test_no_hint_for_success(self): - from trpc_agent_sdk.code_executors import WorkspaceRunResult - ret = WorkspaceRunResult( - stdout="ok", stderr="", exit_code=0, duration=0, 
timed_out=False, - ) - inp = SkillRunInput(skill="test", command="python run.py") - updated = SkillRunTool._with_missing_command_hint(ret, inp) - assert updated.stderr == "" - - -# --------------------------------------------------------------------------- -# SkillRunTool — skill_stager property -# --------------------------------------------------------------------------- - -class TestSkillRunToolProperties: - def test_skill_stager_property(self): - repo = MagicMock() - repo.workspace_runtime = MagicMock() - tool = SkillRunTool(repository=repo) - assert tool.skill_stager is not None - - def test_custom_stager(self): - repo = MagicMock() - repo.workspace_runtime = MagicMock() - custom_stager = MagicMock() - tool = SkillRunTool(repository=repo, skill_stager=custom_stager) - assert tool.skill_stager is custom_stager - - def test_declaration(self): - repo = MagicMock() - repo.workspace_runtime = MagicMock() - tool = SkillRunTool(repository=repo) - decl = tool._get_declaration() - assert decl.name == "skill_run" - - def test_declaration_with_allowed_cmds(self): - repo = MagicMock() - repo.workspace_runtime = MagicMock() - tool = SkillRunTool(repository=repo, allowed_cmds=["python", "bash"]) - decl = tool._get_declaration() - assert "python" in decl.description - - def test_declaration_with_denied_cmds(self): - repo = MagicMock() - repo.workspace_runtime = MagicMock() - tool = SkillRunTool(repository=repo, denied_cmds=["rm"]) - decl = tool._get_declaration() - assert "Restrictions" in decl.description - - -# --------------------------------------------------------------------------- -# SkillRunTool — _wrap_with_venv -# --------------------------------------------------------------------------- - -class TestWrapWithVenv: - def _make_tool(self): - repo = MagicMock() - repo.workspace_runtime = MagicMock() - return SkillRunTool(repository=repo) - - def test_wraps_command(self): - tool = self._make_tool() - result = tool._wrap_with_venv("python run.py", "/ws", "skills/test") 
- assert "VIRTUAL_ENV" in result - assert "python run.py" in result - assert ".venv" in result - - def test_non_skills_cwd(self): - tool = self._make_tool() - result = tool._wrap_with_venv("echo hi", "/ws", "work/custom") - assert "echo hi" in result - - -# --------------------------------------------------------------------------- -# SkillRunTool — _with_skill_doc_command_hint -# --------------------------------------------------------------------------- - -class TestWithSkillDocCommandHint: - def _make_tool(self): - repo = MagicMock() - repo.workspace_runtime = MagicMock() - return SkillRunTool(repository=repo) - - def test_adds_hint_for_missing_command(self): - from trpc_agent_sdk.code_executors import WorkspaceRunResult - tool = self._make_tool() - ret = WorkspaceRunResult( - stdout="", stderr="bash: nonexist: command not found", - exit_code=127, duration=0, timed_out=False, - ) - skill = MagicMock() - skill.body = "```\npython scripts/run.py\n```\n" - skill.tools = ["get_weather"] - repo = MagicMock() - repo.get = MagicMock(return_value=skill) - inp = SkillRunInput(skill="test", command="nonexist") - updated = tool._with_skill_doc_command_hint(ret, repo, inp) - assert "SKILL.md" in updated.stderr - - def test_no_hint_for_success(self): - from trpc_agent_sdk.code_executors import WorkspaceRunResult - tool = self._make_tool() - ret = WorkspaceRunResult( - stdout="ok", stderr="", exit_code=0, duration=0, timed_out=False, - ) - repo = MagicMock() - inp = SkillRunInput(skill="test", command="python run.py") - updated = tool._with_skill_doc_command_hint(ret, repo, inp) - assert updated.stderr == "" - - def test_repo_exception_returns_unchanged(self): - from trpc_agent_sdk.code_executors import WorkspaceRunResult - tool = self._make_tool() - ret = WorkspaceRunResult( - stdout="", stderr="bash: x: command not found", - exit_code=127, duration=0, timed_out=False, - ) - repo = MagicMock() - repo.get = MagicMock(side_effect=RuntimeError("fail")) - inp = 
SkillRunInput(skill="test", command="x") - updated = tool._with_skill_doc_command_hint(ret, repo, inp) - assert updated is ret - - -# --------------------------------------------------------------------------- -# SkillRunTool — _suggest_commands/_suggest_tools -# --------------------------------------------------------------------------- - -class TestSuggestMethods: - def _make_tool(self): - repo = MagicMock() - repo.workspace_runtime = MagicMock() - return SkillRunTool(repository=repo) - - def test_suggest_commands_for_missing(self): - from trpc_agent_sdk.code_executors import WorkspaceRunResult - tool = self._make_tool() - ret = WorkspaceRunResult( - stdout="", stderr="command not found", - exit_code=127, duration=0, timed_out=False, - ) - skill = MagicMock() - skill.body = "```\npython run.py\n```\n" - repo = MagicMock() - repo.get = MagicMock(return_value=skill) - result = tool._suggest_commands_for_missing_command(ret, repo, "test") - assert result is not None - - def test_suggest_commands_no_missing(self): - from trpc_agent_sdk.code_executors import WorkspaceRunResult - tool = self._make_tool() - ret = WorkspaceRunResult( - stdout="ok", stderr="", exit_code=0, duration=0, timed_out=False, - ) - repo = MagicMock() - result = tool._suggest_commands_for_missing_command(ret, repo, "test") - assert result is None - - def test_suggest_tools_for_missing(self): - from trpc_agent_sdk.code_executors import WorkspaceRunResult - tool = self._make_tool() - ret = WorkspaceRunResult( - stdout="", stderr="command not found", - exit_code=127, duration=0, timed_out=False, - ) - skill = MagicMock() - skill.tools = ["get_data"] - repo = MagicMock() - repo.get = MagicMock(return_value=skill) - result = tool._suggest_tools_for_missing_command(ret, repo, "test") - assert result == ["get_data"] - - def test_suggest_tools_no_tools(self): - from trpc_agent_sdk.code_executors import WorkspaceRunResult - tool = self._make_tool() - ret = WorkspaceRunResult( - stdout="", stderr="command 
not found", - exit_code=127, duration=0, timed_out=False, - ) - skill = MagicMock() - skill.tools = [] - repo = MagicMock() - repo.get = MagicMock(return_value=skill) - result = tool._suggest_tools_for_missing_command(ret, repo, "test") - assert result is None - - def test_suggest_commands_repo_error(self): - from trpc_agent_sdk.code_executors import WorkspaceRunResult - tool = self._make_tool() - ret = WorkspaceRunResult( - stdout="", stderr="command not found", - exit_code=127, duration=0, timed_out=False, - ) - repo = MagicMock() - repo.get = MagicMock(side_effect=RuntimeError("fail")) - result = tool._suggest_commands_for_missing_command(ret, repo, "test") - assert result is None - - -# --------------------------------------------------------------------------- -# SkillRunTool — _precheck_inline_python_rewrite -# --------------------------------------------------------------------------- - -class TestPrecheckInlinePythonRewrite: - def test_not_blocked_when_disabled(self): - repo = MagicMock() - repo.workspace_runtime = MagicMock() - tool = SkillRunTool(repository=repo, block_inline_python_rewrite=False) - inp = SkillRunInput(skill="test", command="python -c 'print(1)'") - assert tool._precheck_inline_python_rewrite(repo, inp) is None - - def test_not_python_c(self): - repo = MagicMock() - repo.workspace_runtime = MagicMock() - tool = SkillRunTool(repository=repo, block_inline_python_rewrite=True) - inp = SkillRunInput(skill="test", command="python run.py") - assert tool._precheck_inline_python_rewrite(repo, inp) is None - - def test_blocked_with_script_examples(self): - repo = MagicMock() - repo.workspace_runtime = MagicMock() - skill = MagicMock() - skill.body = "```\npython3 scripts/analyze.py --input data\n```\n" - repo.get = MagicMock(return_value=skill) - tool = SkillRunTool(repository=repo, block_inline_python_rewrite=True) - inp = SkillRunInput(skill="test", command="python -c 'import sys; print(sys.argv)'") - result = 
tool._precheck_inline_python_rewrite(repo, inp) - assert result is not None - assert result.exit_code == 2 - - def test_not_blocked_without_script_examples(self): - repo = MagicMock() - repo.workspace_runtime = MagicMock() - skill = MagicMock() - skill.body = "```\necho hello\n```\n" - repo.get = MagicMock(return_value=skill) - tool = SkillRunTool(repository=repo, block_inline_python_rewrite=True) - inp = SkillRunInput(skill="test", command="python -c 'print(1)'") - result = tool._precheck_inline_python_rewrite(repo, inp) - assert result is None - - def test_repo_error_returns_none(self): - repo = MagicMock() - repo.workspace_runtime = MagicMock() - repo.get = MagicMock(side_effect=RuntimeError("fail")) - tool = SkillRunTool(repository=repo, block_inline_python_rewrite=True) - inp = SkillRunInput(skill="test", command="python -c 'print(1)'") - result = tool._precheck_inline_python_rewrite(repo, inp) - assert result is None - - -# --------------------------------------------------------------------------- -# SkillRunTool — _list_entrypoint_suggestions -# --------------------------------------------------------------------------- - -class TestListEntrypointSuggestions: - def test_finds_scripts(self, tmp_path): - scripts = tmp_path / "scripts" - scripts.mkdir() - (scripts / "run.py").write_text("#!/usr/bin/env python") - result = SkillRunTool._list_entrypoint_suggestions(tmp_path) - assert any("run.py" in r for r in result) - - def test_empty_dir(self, tmp_path): - assert SkillRunTool._list_entrypoint_suggestions(tmp_path) == [] - - def test_limit(self, tmp_path): - scripts = tmp_path / "scripts" - scripts.mkdir() - for i in range(30): - (scripts / f"script_{i}.py").write_text(f"# script {i}") - result = SkillRunTool._list_entrypoint_suggestions(tmp_path, limit=5) - assert len(result) <= 5 - - -# --------------------------------------------------------------------------- -# SkillRunTool — _with_missing_entrypoint_hint -# 
--------------------------------------------------------------------------- - -class TestWithMissingEntrypointHint: - def test_adds_hint_for_missing_file(self, tmp_path): - from trpc_agent_sdk.code_executors import WorkspaceRunResult - scripts = tmp_path / "skills" / "test" / "scripts" - scripts.mkdir(parents=True) - (scripts / "real.py").write_text("# real") - ws = MagicMock() - ws.path = str(tmp_path) - ret = WorkspaceRunResult( - stdout="", stderr="No such file or directory", - exit_code=1, duration=0, timed_out=False, - ) - inp = SkillRunInput(skill="test", command="python scripts/missing.py") - updated = SkillRunTool._with_missing_entrypoint_hint(ret, inp, ws, "skills/test") - assert "hint" in updated.stderr.lower() or "entrypoint" in updated.stderr.lower() - - def test_no_hint_for_success(self): - from trpc_agent_sdk.code_executors import WorkspaceRunResult - ret = WorkspaceRunResult( - stdout="ok", stderr="", exit_code=0, duration=0, timed_out=False, - ) - ws = MagicMock() - ws.path = "/tmp/ws" - inp = SkillRunInput(skill="test", command="python run.py") - updated = SkillRunTool._with_missing_entrypoint_hint(ret, inp, ws, "skills/test") - assert updated is ret - - -# --------------------------------------------------------------------------- -# SkillRunTool — _merge_manifest_artifact_refs -# --------------------------------------------------------------------------- - -class TestMergeManifestArtifactRefs: - def _make_tool(self): - repo = MagicMock() - repo.workspace_runtime = MagicMock() - return SkillRunTool(repository=repo) - - def test_none_manifest_noop(self): - tool = self._make_tool() - output = SkillRunOutput() - tool._merge_manifest_artifact_refs(None, output) - assert output.artifact_files == [] - - def test_already_has_artifacts_noop(self): - tool = self._make_tool() - output = SkillRunOutput(artifact_files=[ArtifactInfo(name="x", version=1)]) - manifest = MagicMock() - tool._merge_manifest_artifact_refs(manifest, output) - assert 
len(output.artifact_files) == 1 - - def test_merges_from_manifest(self): - tool = self._make_tool() - output = SkillRunOutput() - manifest = MagicMock() - fr = MagicMock() - fr.saved_as = "artifact.txt" - fr.version = 2 - manifest.files = [fr] - tool._merge_manifest_artifact_refs(manifest, output) - assert len(output.artifact_files) == 1 - assert output.artifact_files[0].name == "artifact.txt" - - def test_skips_unsaved(self): - tool = self._make_tool() - output = SkillRunOutput() - manifest = MagicMock() - fr = MagicMock() - fr.saved_as = "" - manifest.files = [fr] - tool._merge_manifest_artifact_refs(manifest, output) - assert output.artifact_files == [] - - -# --------------------------------------------------------------------------- -# SkillRunTool — _to_run_file / _to_run_files -# --------------------------------------------------------------------------- - -class TestToRunFile: - def _make_tool(self): - repo = MagicMock() - repo.workspace_runtime = MagicMock() - return SkillRunTool(repository=repo) - - def test_text_file(self): - from trpc_agent_sdk.code_executors import CodeFile - tool = self._make_tool() - cf = CodeFile(name="test.txt", content="hello", mime_type="text/plain", size_bytes=5) - rf = tool._to_run_file(cf) - assert rf.name == "test.txt" - assert rf.content == "hello" - assert rf.ref == "workspace://test.txt" - - def test_binary_file_omits_content(self): - from trpc_agent_sdk.code_executors import CodeFile - tool = self._make_tool() - cf = CodeFile(name="img.png", content="binary", mime_type="image/png", size_bytes=100) - rf = tool._to_run_file(cf) - assert rf.content == "" - - def test_to_run_files(self): - from trpc_agent_sdk.code_executors import CodeFile - tool = self._make_tool() - files = [ - CodeFile(name="a.txt", content="a", mime_type="text/plain", size_bytes=1), - CodeFile(name="b.txt", content="b", mime_type="text/plain", size_bytes=1), - ] - result = tool._to_run_files(files) - assert len(result) == 2 - - -# 
--------------------------------------------------------------------------- -# SkillRunTool — _prepare_editor_env -# --------------------------------------------------------------------------- - -class TestPrepareEditorEnv: - def _make_tool(self): - repo = MagicMock() - repo.workspace_runtime = MagicMock() - return SkillRunTool(repository=repo) - - async def test_empty_editor_text_noop(self): - tool = self._make_tool() - ctx = MagicMock() - ws = MagicMock() - ws.path = "/tmp/ws" - env = {} - await tool._prepare_editor_env(ctx, ws, env, "") - assert "EDITOR" not in env - - async def test_editor_env_conflict_raises(self): - tool = self._make_tool() - ctx = MagicMock() - ws = MagicMock() - ws.path = "/tmp/ws" - env = {"EDITOR": "/usr/bin/vim"} - with pytest.raises(ValueError, match="editor_text cannot be combined"): - await tool._prepare_editor_env(ctx, ws, env, "some text") - - async def test_visual_env_conflict_raises(self): - tool = self._make_tool() - ctx = MagicMock() - ws = MagicMock() - ws.path = "/tmp/ws" - env = {"VISUAL": "/usr/bin/vim"} - with pytest.raises(ValueError, match="editor_text cannot be combined"): - await tool._prepare_editor_env(ctx, ws, env, "some text") - - async def test_stages_editor_files(self, tmp_path): - repo = MagicMock() - fs = MagicMock() - fs.put_files = AsyncMock() - runtime = MagicMock() - runtime.fs = MagicMock(return_value=fs) - repo.workspace_runtime = runtime - tool = SkillRunTool(repository=repo) - + def test_is_skill_loaded(self): + tool = _make_tool() ctx = MagicMock() - ws = MagicMock() - ws.path = str(tmp_path) - env = {} - await tool._prepare_editor_env(ctx, ws, env, "editor content") - assert "EDITOR" in env - assert "VISUAL" in env - - async def test_fallback_to_local_write(self, tmp_path): - repo = MagicMock() - fs = MagicMock() - fs.put_files = AsyncMock(side_effect=RuntimeError("workspace unavailable")) - runtime = MagicMock() - runtime.fs = MagicMock(return_value=fs) - repo.workspace_runtime = runtime - tool = 
SkillRunTool(repository=repo) - - ctx = MagicMock() - ws = MagicMock() - ws.path = str(tmp_path) - env = {} - await tool._prepare_editor_env(ctx, ws, env, "editor content") - assert "EDITOR" in env - - -# --------------------------------------------------------------------------- -# SkillRunTool — _attach_artifacts_if_requested -# --------------------------------------------------------------------------- - -class TestAttachArtifacts: - def _make_tool(self): - repo = MagicMock() - repo.workspace_runtime = MagicMock() - return SkillRunTool(repository=repo) - - async def test_no_files_noop(self): - tool = self._make_tool() - ctx = MagicMock() - ws = MagicMock() - inp = SkillRunInput(skill="test", command="echo", save_as_artifacts=True) - output = SkillRunOutput() - await tool._attach_artifacts_if_requested(ctx, ws, inp, output, []) - - async def test_not_requested_noop(self): - tool = self._make_tool() - ctx = MagicMock() - ws = MagicMock() - inp = SkillRunInput(skill="test", command="echo", save_as_artifacts=False) - output = SkillRunOutput() - files = [SkillRunFile(name="a.txt", content="hello")] - await tool._attach_artifacts_if_requested(ctx, ws, inp, output, files) - assert output.artifact_files == [] - - async def test_no_artifact_service_warns(self): - tool = self._make_tool() - ctx = MagicMock() - ctx.artifact_service = None - ws = MagicMock() - inp = SkillRunInput(skill="test", command="echo", save_as_artifacts=True) - output = SkillRunOutput() - files = [SkillRunFile(name="a.txt", content="hello")] - await tool._attach_artifacts_if_requested(ctx, ws, inp, output, files) - assert any("not configured" in w for w in output.warnings) - - async def test_save_artifacts(self): - tool = self._make_tool() - ctx = MagicMock() - ctx.artifact_service = MagicMock() - ctx.save_artifact = AsyncMock(return_value=1) - ws = MagicMock() - inp = SkillRunInput(skill="test", command="echo", save_as_artifacts=True) - output = SkillRunOutput() - files = [SkillRunFile(name="a.txt", 
content="hello", mime_type="text/plain")] - await tool._attach_artifacts_if_requested(ctx, ws, inp, output, files) - assert len(output.artifact_files) == 1 - assert output.artifact_files[0].name == "a.txt" - - async def test_save_artifacts_with_prefix(self): - tool = self._make_tool() - ctx = MagicMock() - ctx.artifact_service = MagicMock() - ctx.save_artifact = AsyncMock(return_value=1) - ws = MagicMock() - inp = SkillRunInput(skill="test", command="echo", save_as_artifacts=True, artifact_prefix="out/") - output = SkillRunOutput() - files = [SkillRunFile(name="a.txt", content="hello", mime_type="text/plain")] - await tool._attach_artifacts_if_requested(ctx, ws, inp, output, files) - assert output.artifact_files[0].name == "out/a.txt" - - -# --------------------------------------------------------------------------- -# SkillRunTool — _prepare_outputs -# --------------------------------------------------------------------------- - -class TestPrepareOutputs: - def _make_tool(self): - repo = MagicMock() - fs = MagicMock() - runtime = MagicMock() - runtime.fs = MagicMock(return_value=fs) - repo.workspace_runtime = runtime - return SkillRunTool(repository=repo), fs - - async def test_no_outputs(self): - tool, fs = self._make_tool() - ctx = MagicMock() - ws = MagicMock() - inp = SkillRunInput(skill="test", command="echo") - files, manifest = await tool._prepare_outputs(ctx, ws, inp) - assert files == [] - assert manifest is None - - async def test_output_files_patterns(self): - tool, fs = self._make_tool() - from trpc_agent_sdk.code_executors import CodeFile - fs.collect = AsyncMock(return_value=[ - CodeFile(name="out/a.txt", content="hello", mime_type="text/plain", size_bytes=5) - ]) - ctx = MagicMock() - ws = MagicMock() - inp = SkillRunInput(skill="test", command="echo", output_files=["out/*.txt"]) - files, manifest = await tool._prepare_outputs(ctx, ws, inp) - assert len(files) == 1 - assert manifest is None + ctx.agent_name = "" + ctx.actions = MagicMock() + key = 
loaded_state_key(ctx, "test") + ctx.actions.state_delta = {key: True} + ctx.session_state = {} + assert tool._is_skill_loaded(ctx, "test") is True diff --git a/tests/skills/tools/test_skill_select_docs.py b/tests/skills/tools/test_skill_select_docs.py index f52660b..1d8577e 100644 --- a/tests/skills/tools/test_skill_select_docs.py +++ b/tests/skills/tools/test_skill_select_docs.py @@ -17,16 +17,22 @@ import pytest +from trpc_agent_sdk.skills._constants import SKILL_CONFIG_KEY from trpc_agent_sdk.skills.tools._skill_select_docs import ( SkillSelectDocsResult, skill_select_docs, ) +from trpc_agent_sdk.skills._common import docs_state_key +from trpc_agent_sdk.skills._skill_config import DEFAULT_SKILL_CONFIG def _make_ctx(state_delta=None, session_state=None): ctx = MagicMock() ctx.actions.state_delta = state_delta or {} ctx.session_state = session_state or {} + ctx.agent_name = "" + ctx.agent_context.get_metadata = MagicMock( + side_effect=lambda key, default=None: DEFAULT_SKILL_CONFIG if key == SKILL_CONFIG_KEY else default) return ctx @@ -76,18 +82,16 @@ def test_replace_mode(self): assert result.selected_docs == ["doc1.md", "doc2.md"] def test_add_mode(self): - ctx = _make_ctx(session_state={ - "temp:skill:docs:test-skill": json.dumps(["existing.md"]), - }) + ctx = _make_ctx() + ctx.session_state = {docs_state_key(ctx, "test-skill"): json.dumps(["existing.md"])} result = skill_select_docs(ctx, "test-skill", docs=["new.md"], mode="add") assert result.mode == "add" assert "existing.md" in result.selected_docs assert "new.md" in result.selected_docs def test_clear_mode(self): - ctx = _make_ctx(session_state={ - "temp:skill:docs:test-skill": json.dumps(["some.md"]), - }) + ctx = _make_ctx() + ctx.session_state = {docs_state_key(ctx, "test-skill"): json.dumps(["some.md"])} result = skill_select_docs(ctx, "test-skill", mode="clear") assert result.mode == "clear" assert result.selected_docs == [] @@ -100,5 +104,5 @@ def test_include_all_docs(self): def 
test_updates_state_delta(self): ctx = _make_ctx() skill_select_docs(ctx, "test-skill", docs=["a.md"], mode="replace") - key = "temp:skill:docs:test-skill" + key = docs_state_key(ctx, "test-skill") assert key in ctx.actions.state_delta diff --git a/tests/skills/tools/test_skill_select_tools.py b/tests/skills/tools/test_skill_select_tools.py index 6bd2489..9b9fe98 100644 --- a/tests/skills/tools/test_skill_select_tools.py +++ b/tests/skills/tools/test_skill_select_tools.py @@ -17,16 +17,22 @@ import pytest +from trpc_agent_sdk.skills._constants import SKILL_CONFIG_KEY from trpc_agent_sdk.skills.tools._skill_select_tools import ( SkillSelectToolsResult, skill_select_tools, ) +from trpc_agent_sdk.skills._common import tool_state_key +from trpc_agent_sdk.skills._skill_config import DEFAULT_SKILL_CONFIG def _make_ctx(state_delta=None, session_state=None): ctx = MagicMock() ctx.actions.state_delta = state_delta or {} ctx.session_state = session_state or {} + ctx.agent_name = "" + ctx.agent_context.get_metadata = MagicMock( + side_effect=lambda key, default=None: DEFAULT_SKILL_CONFIG if key == SKILL_CONFIG_KEY else default) return ctx @@ -74,18 +80,16 @@ def test_replace_mode(self): assert result.selected_tools == ["tool_a", "tool_b"] def test_add_mode(self): - ctx = _make_ctx(session_state={ - "temp:skill:tools:test-skill": json.dumps(["existing_tool"]), - }) + ctx = _make_ctx() + ctx.session_state = {tool_state_key(ctx, "test-skill"): json.dumps(["existing_tool"])} result = skill_select_tools(ctx, "test-skill", tools=["new_tool"], mode="add") assert result.mode == "add" assert "existing_tool" in result.selected_tools assert "new_tool" in result.selected_tools def test_clear_mode(self): - ctx = _make_ctx(session_state={ - "temp:skill:tools:test-skill": json.dumps(["tool"]), - }) + ctx = _make_ctx() + ctx.session_state = {tool_state_key(ctx, "test-skill"): json.dumps(["tool"])} result = skill_select_tools(ctx, "test-skill", mode="clear") assert result.mode == "clear" 
assert result.selected_tools == [] @@ -98,5 +102,5 @@ def test_include_all_tools(self): def test_updates_state_delta(self): ctx = _make_ctx() skill_select_tools(ctx, "test-skill", tools=["t1"], mode="replace") - key = "temp:skill:tools:test-skill" + key = tool_state_key(ctx, "test-skill") assert key in ctx.actions.state_delta diff --git a/tests/skills/tools/test_workspace_exec.py b/tests/skills/tools/test_workspace_exec.py new file mode 100644 index 0000000..467e6a9 --- /dev/null +++ b/tests/skills/tools/test_workspace_exec.py @@ -0,0 +1,34 @@ +import pytest + +from trpc_agent_sdk.code_executors import PROGRAM_STATUS_RUNNING +from trpc_agent_sdk.code_executors import ProgramPoll +from trpc_agent_sdk.skills.tools._workspace_exec import _combine_output +from trpc_agent_sdk.skills.tools._workspace_exec import _exec_timeout_seconds +from trpc_agent_sdk.skills.tools._workspace_exec import _exec_yield_seconds +from trpc_agent_sdk.skills.tools._workspace_exec import _normalize_cwd +from trpc_agent_sdk.skills.tools._workspace_exec import _poll_output +from trpc_agent_sdk.skills.tools._workspace_exec import _write_yield_seconds + + +def test_normalize_cwd(): + assert _normalize_cwd("") == "." 
+ assert _normalize_cwd("work/demo") == "work/demo" + with pytest.raises(ValueError, match="within the workspace"): + _normalize_cwd("../demo") + + +def test_timeout_and_yield_helpers(): + assert _exec_timeout_seconds(0) > 0 + assert _exec_timeout_seconds(3) == 3.0 + assert _exec_yield_seconds(background=True, raw_ms=None) == 0.0 + assert _exec_yield_seconds(background=False, raw_ms=100) == 0.1 + assert _write_yield_seconds(None) > 0.0 + assert _write_yield_seconds(-1) == 0.0 + + +def test_poll_output_and_combine_output(): + poll = ProgramPoll(status=PROGRAM_STATUS_RUNNING, output="ok", offset=1, next_offset=2) + out = _poll_output("sid-1", poll) + assert out["session_id"] == "sid-1" + assert out["output"] == "ok" + assert _combine_output("a", "b") == "ab" diff --git a/trpc_agent_sdk/agents/core/README.md b/trpc_agent_sdk/agents/core/README.md new file mode 100644 index 0000000..4dbbaa2 --- /dev/null +++ b/trpc_agent_sdk/agents/core/README.md @@ -0,0 +1,249 @@ +# trpc_agent/agents/core 说明 + +本文档面向读者解释 [`trpc_agent/agents/core`](./) 中与 Skills 相关的请求处理逻辑,重点覆盖三个处理器: + +- `SkillsRequestProcessor`([`_skill_processor.py`](./_skill_processor.py)) +- `WorkspaceExecRequestProcessor`([`_workspace_exec_processor.py`](./_workspace_exec_processor.py)) +- `SkillsToolResultRequestProcessor`([`_skills_tool_result_processor.py`](./_skills_tool_result_processor.py)) + +文档目标是回答三个问题: + +1. 每个处理器解决什么问题 +2. 处理器在请求流水线中的位置与协作关系 +3. 如何在运行时判断“功能已生效” + +## 1. 请求流水线中的职责划分 + +在 `RequestProcessor` 的技能相关路径中(简化): + +1. 组装基础 instruction +2. 注入 tools +3. `SkillsRequestProcessor`:注入 skills 总览与(可选)已加载内容 +4. `WorkspaceExecRequestProcessor`:注入 `workspace_exec` guidance +5. 注入会话历史 +6. `SkillsToolResultRequestProcessor`:在 post-history 阶段做 tool result 物化 +7. 其他能力(planning/output schema 等) + +可理解为: + +- `SkillsRequestProcessor` 负责“技能上下文主编排” +- `WorkspaceExecRequestProcessor` 负责“执行器工具选择引导” +- `SkillsToolResultRequestProcessor` 负责“tool_result_mode 下的内容物化补强” + +## 2. 
SkillsRequestProcessor(核心入口) + +主入口: + +- `SkillsRequestProcessor.process_llm_request(ctx, request)` + +### 2.1 解决的问题 + +模型在多技能场景需要两类信息: + +- 可用 skill 总览(有哪些技能、各自做什么) +- 已加载 skill 的正文/文档/工具选择(当前上下文真正可用什么) + +`SkillsRequestProcessor` 提供统一策略来管理这些信息,并在 `turn/once/session` 三种加载模式下保持行为一致。 + +### 2.2 核心行为 + +一次调用中,典型流程如下: + +1. 获取 skill repo(支持 `repo_resolver(ctx)` 动态仓库) +2. 执行旧状态迁移(兼容历史 key)与 turn 模式清理 +3. 注入 skill 概览(始终执行) +4. 读取 loaded skills 并按 `max_loaded_skills` 裁剪 +5. `load_mode=session` 在 temp-only 设计下与 `turn` 读取语义一致 +6. 根据 `tool_result_mode` 分流: + - `False`:直接向 system instruction 注入 `[Loaded]`、`Docs loaded`、`[Doc]` + - `True`:跳过注入,交给 `SkillsToolResultRequestProcessor` 处理 +7. `load_mode=once` 时清理 loaded/docs/tools/order 状态(offload) + +### 2.3 状态语义 + +- `turn` + - 每次 invocation 开始清理一次技能状态 + - 对应:`_maybe_clear_skill_state_for_turn` +- `once` + - 本轮用完后清理,避免持续占用上下文 + - 对应:`_maybe_offload_loaded_skills` +- `session` + - 在 temp-only 状态模型下,不再维护 `user:skill:*` 双键 + - 对应:`_maybe_promote_skill_state_for_session`(当前为 no-op) + +读取策略为单一 temp key 读取(`session_state + state_delta` 视图)。 + +### 2.4 关键参数 + +- `load_mode`: `turn` / `once` / `session` +- `tool_result_mode`: 是否改为 tool result materialization 路径 +- `tool_profile` / `allowed_skill_tools` / `tool_flags`: 限制可用技能工具能力面 +- `exec_tools_disabled`: 关闭交互执行 guidance +- `repo_resolver`: invocation 级仓库解析 +- `max_loaded_skills`: loaded 上限(超限按顺序淘汰) + +参数入口: + +- `set_skill_processor_parameters(agent_context, parameters)` + +## 3. WorkspaceExecRequestProcessor(workspace_exec guidance) + +对应实现: + +- [`trpc_agent/agents/core/_workspace_exec_processor.py`](./_workspace_exec_processor.py) + +### 3.1 解决的问题 + +在多工具场景下,模型容易混淆: + +- 什么时候使用 `workspace_exec`(通用 shell) +- 什么时候使用 `skill_run`(技能内部执行) +- `workspace_exec` 的路径边界、会话工具、artifact 保存边界 + +处理器通过注入统一 guidance,降低误用和误判。 + +### 3.2 主要行为 + +`process_llm_request(ctx, request)` 典型步骤: + +1. 判断是否启用 guidance + - 默认按 request tools 是否包含 `workspace_exec` + - 支持 `enabled_resolver` 动态开关 +2. 
生成 guidance 主体 + - 通用 `workspace_exec` 使用建议 + - `work/out/runs` 路径建议 + - “先用小命令验证环境限制”的原则 +3. 按能力追加段落 + - 有 `workspace_save_artifact`:追加 artifact 保存边界说明 + - 有 skills repo:提示 `skills/` 目录并非自动 stage + - 有会话工具:追加 `workspace_write_stdin` / `workspace_kill_session` 生命周期提示 +4. 去重注入 + - 若已存在 `Executor workspace guidance:` header,则不重复追加 + +### 3.3 行为示例 + +当工具列表包含: + +- `workspace_exec` +- `workspace_write_stdin` +- `workspace_kill_session` +- `workspace_save_artifact` + +且 agent 绑定 skill repository 时,system instruction 会引导模型: + +- 通用 shell 优先走 `workspace_exec` +- 路径优先 `work/`、`out/`、`runs/` +- 限制不先假设,先验证 +- 仅在需要稳定引用时再调用 `workspace_save_artifact` + +### 3.4 常见误区 + +- 误区:`workspace_exec` 会自动准备 `skills/` 内容 + 实际:是否存在 `skills/...` 取决于是否有其他工具先 stage + +- 误区:遇到限制直接下结论“环境不支持” + 实际:应先做有界验证 + +- 误区:所有输出都必须保存 artifact + 实际:应按稳定引用需求再保存 + +### 3.5 如何验证生效 + +优先看“发给模型前的请求”而非仅看终端事件: + +1. `request.config.system_instruction` 是否包含 `Executor workspace guidance:` +2. 是否只注入一次(无重复 header) +3. 工具选择行为是否符合预期(通用 shell 走 `workspace_exec`) + +## 4. SkillsToolResultRequestProcessor(tool result 物化) + +对应实现: + +- [`trpc_agent/agents/core/_skills_tool_result_processor.py`](./_skills_tool_result_processor.py) + +### 4.1 解决的问题 + +仅靠 `skill_load` 的短回包(例如 `"skill 'python-math' loaded"`),模型往往拿不到可执行细节。 +这个处理器负责把“已加载 skill 的实质内容”物化到模型当前请求上下文。 + +### 4.2 主要行为 + +处理器会: + +1. 从 `session_state + state_delta` 读取已加载 skill +2. 在 `LlmRequest.contents` 中定位最近的 `skill_load` / `skill_select_docs` response +3. 条件满足时改写 response,注入: + - `[Loaded] ` + - `Docs loaded: ...` + - `[Doc] ...` +4. 若本轮没有可改写 response,fallback 到 system instruction 追加 `Loaded skill context:` +5. 
`load_mode=once` 时按策略清理 loaded/docs 状态 + +### 4.3 与 SkillsRequestProcessor 的分工 + +- `tool_result_mode=False` + - 由 `SkillsRequestProcessor` 直接注入 loaded 内容 +- `tool_result_mode=True` + - `SkillsRequestProcessor` 不注入 loaded 内容 + - `SkillsToolResultRequestProcessor` 在 post-history 做物化 + +### 4.4 最小示例 + +进入处理器前: + +- function call: `skill_load(demo-skill)` +- function response: `{"result":"skill 'demo-skill' loaded"}` + +处理器后(示意): + +```text +{ + "result": "[Loaded] demo-skill\n\n\n\nDocs loaded: docs/guide.md\n\n[Doc] docs/guide.md\n\n" +} +``` + +即使没有对应 tool response 可改写,也会通过 system instruction fallback 注入已加载上下文。 + +### 4.5 什么时候会被误判为“没生效” + +最常见误区:只看终端工具即时回包。 +该处理器真实生效点是“发给模型前的请求内容”,与外层事件流并不总是 1:1。 + +建议观察: + +- `request.config.system_instruction` +- `request.contents` 里的 function response 是否已被改写 + +### 4.6 参数入口 + +- `load_mode` +- `skip_fallback_on_session_summary` +- `repo_resolver` + +通过: + +- `set_skill_tool_result_processor_parameters(agent_context, {...})` + +注入请求构建链路。 + +## 5. 测试语义映射(与 examples 对齐) + +可配合 [`examples/skills/run_agent.py`](../../../examples/skills/run_agent.py) 与 +[`examples/skills/README.md`](../../../examples/skills/README.md) 观察实际行为。 + +- `workspace_exec_guidance` 类测试 + - 关注工具选择行为是否被 guidance 纠偏 + - 核心断言是 `workspace_exec` 与 `skill_run` 调用分布 + +- `skills_tool_result_mode` 类测试 + - 关注 materialization 信号是否出现(`[Loaded]` / `Docs loaded` / `[Doc]`) + - 允许“请求层可见但终端不完全回显”的情况 + +## 6. 给读者的排障建议 + +1. 先确认模式参数:`load_mode`、`tool_result_mode` +2. 再确认状态读写:`state_delta` 与 `session_state` 是否符合预期 +3. 
最后看请求最终形态: + - 是否注入了 guidance + - 是否注入了 loaded context + - 是否发生了 offload/clear diff --git a/trpc_agent_sdk/agents/core/__init__.py b/trpc_agent_sdk/agents/core/__init__.py index 5797916..c8b5922 100644 --- a/trpc_agent_sdk/agents/core/__init__.py +++ b/trpc_agent_sdk/agents/core/__init__.py @@ -23,7 +23,13 @@ from ._request_processor import RequestProcessor from ._request_processor import default_request_processor from ._skill_processor import SkillsRequestProcessor +from ._skill_processor import get_skill_processor_parameters +from ._skill_processor import set_skill_processor_parameters +from ._skills_tool_result_processor import get_skill_tool_result_processor_parameters +from ._skills_tool_result_processor import set_skill_tool_result_processor_parameters from ._tools_processor import ToolsProcessor +from ._workspace_exec_processor import get_workspace_exec_processor_parameters +from ._workspace_exec_processor import set_workspace_exec_processor_parameters __all__ = [ "AgentTransferProcessor", @@ -40,5 +46,11 @@ "RequestProcessor", "default_request_processor", "SkillsRequestProcessor", + "get_skill_processor_parameters", + "set_skill_processor_parameters", + "get_skill_tool_result_processor_parameters", + "set_skill_tool_result_processor_parameters", "ToolsProcessor", + "get_workspace_exec_processor_parameters", + "set_workspace_exec_processor_parameters", ] diff --git a/trpc_agent_sdk/agents/core/_code_execution_processor.py b/trpc_agent_sdk/agents/core/_code_execution_processor.py index 976712e..e164135 100644 --- a/trpc_agent_sdk/agents/core/_code_execution_processor.py +++ b/trpc_agent_sdk/agents/core/_code_execution_processor.py @@ -1,8 +1,6 @@ -# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# -*- coding: utf-8 -*- # -# Copyright (C) 2026 Tencent. All rights reserved. -# -# tRPC-Agent-Python is licensed under Apache-2.0. +# Copyright @ 2025 Tencent.com """Code execution processor for TRPC Agent framework. 
This module provides code execution processing capabilities for LLM agents, @@ -264,7 +262,8 @@ async def _run_post_processor( # content to the part with the first code block. response_content = llm_response.content code_blocks = CodeExecutionUtils.extract_code_and_truncate_content(response_content, - code_executor.code_block_delimiters) + code_executor.code_block_delimiters, + code_executor.ignore_codes) # Terminal state: no code to execute. if not code_blocks: return diff --git a/trpc_agent_sdk/agents/core/_request_processor.py b/trpc_agent_sdk/agents/core/_request_processor.py index 4640f43..0e414ae 100644 --- a/trpc_agent_sdk/agents/core/_request_processor.py +++ b/trpc_agent_sdk/agents/core/_request_processor.py @@ -1,8 +1,6 @@ # Tencent is pleased to support the open source community by making tRPC-Agent-Python available. # -# Copyright (C) 2026 Tencent. All rights reserved. -# -# tRPC-Agent-Python is licensed under Apache-2.0. +# Copyright @ 2025 Tencent.com """Request Processor implementation for TRPC Agent framework. 
This module provides the RequestProcessor class which handles building LlmRequest @@ -24,6 +22,7 @@ import copy import inspect import re +from typing import Any from typing import List from typing import Optional @@ -32,6 +31,7 @@ from trpc_agent_sdk.log import logger from trpc_agent_sdk.models import LlmRequest from trpc_agent_sdk.planners import default_planning_processor +from trpc_agent_sdk.skills import get_skill_processor_parameters from trpc_agent_sdk.tools import FunctionTool from trpc_agent_sdk.tools import transfer_to_agent from trpc_agent_sdk.types import Content @@ -43,7 +43,12 @@ from ._history_processor import HistoryProcessor from ._history_processor import TimelineFilterMode from ._skill_processor import SkillsRequestProcessor +from ._skill_processor import get_skill_processor_parameters +from ._skills_tool_result_processor import SkillsToolResultRequestProcessor +from ._skills_tool_result_processor import get_skill_tool_result_processor_parameters from ._tools_processor import ToolsProcessor +from ._workspace_exec_processor import WorkspaceExecRequestProcessor +from ._workspace_exec_processor import get_workspace_exec_processor_parameters class RequestProcessor: @@ -124,11 +129,18 @@ async def build_request( return error_event # 5. Add skills to the request - error_event = await self._add_skills_to_request(agent, ctx, request) + skill_parameters = get_skill_processor_parameters(ctx.agent_context) + error_event = await self._add_skills_to_request(agent, ctx, request, skill_parameters) + if error_event: + return error_event + + # 6. Add workspace_exec guidance (after skills, before history) + workspace_exec_parameters = get_workspace_exec_processor_parameters(ctx.agent_context) + error_event = await self._add_workspace_exec_guidance(agent, ctx, request, workspace_exec_parameters) if error_event: return error_event - # 6. Add conversation history (includes current user message in correct order) + # 7. 
Add conversation history (includes current user message in correct order) if override_messages is not None: # Use provided messages directly (for TeamAgent member control) for content in override_messages: @@ -150,12 +162,18 @@ async def build_request( if error_event: return error_event - # 7. Process planning if planner is available + # 8. Materialize loaded-skill content into tool results (post-history). + skill_tool_result_parameters = get_skill_tool_result_processor_parameters(ctx.agent_context) + error_event = await self._add_skills_tool_results(agent, ctx, request, skill_tool_result_parameters) + if error_event: + return error_event + + # 9. Process planning if planner is available error_event = await self._add_planning_capabilities(agent, ctx, request) if error_event: return error_event - # 8. Process output schema if needed (when tools are also present) + # 10. Process output schema if needed (when tools are also present) error_event = await self._add_output_schema_capabilities(agent, ctx, request) if error_event: return error_event @@ -303,8 +321,8 @@ async def _add_tools_to_request(self, agent: BaseAgent, ctx: InvocationContext, return None # Success - async def _add_skills_to_request(self, agent: BaseAgent, ctx: InvocationContext, - request: LlmRequest) -> Optional[Event]: + async def _add_skills_to_request(self, agent: BaseAgent, ctx: InvocationContext, request: LlmRequest, + parameters: dict[str, Any]) -> Optional[Event]: """Add skills to the model request. 
Args: @@ -318,13 +336,58 @@ async def _add_skills_to_request(self, agent: BaseAgent, ctx: InvocationContext, skill_repository = getattr(agent, 'skill_repository', None) if skill_repository: try: - skills_processor = SkillsRequestProcessor(skill_repository) + skills_processor = SkillsRequestProcessor(skill_repository, **parameters) skill_names = await skills_processor.process_llm_request(ctx, request) logger.debug("Processed %s skills for agent: %s", len(skill_names), agent.name) return None # Success except Exception as ex: # pylint: disable=broad-except logger.error("Error processing skills for agent %s: %s", agent.name, ex) return self._create_error_event(ctx, "skill_processing_error", f"Failed to process skills: {str(ex)}") + return None + + async def _add_workspace_exec_guidance(self, agent: BaseAgent, ctx: InvocationContext, request: LlmRequest, + parameters: dict[str, Any]) -> Optional[Event]: + """Inject workspace_exec guidance after skills and before history.""" + try: + skill_repository = getattr(agent, "skill_repository", None) + repo_resolver = parameters.get("repo_resolver") + processor = WorkspaceExecRequestProcessor( + has_skills_repo=bool(skill_repository), + repo_resolver=repo_resolver if callable(repo_resolver) else None, + ) + await processor.process_llm_request(ctx, request) + return None + except Exception as ex: # pylint: disable=broad-except + logger.error("Error injecting workspace_exec guidance for agent %s: %s", agent.name, ex) + return self._create_error_event( + ctx, + "workspace_exec_guidance_error", + f"Failed to inject workspace_exec guidance: {str(ex)}", + ) + + async def _add_skills_tool_results(self, agent: BaseAgent, ctx: InvocationContext, request: LlmRequest, + parameters: dict[str, Any]) -> Optional[Event]: + """Materialize loaded skills into tool results after history is attached.""" + if not parameters.get("tool_result_mode"): + return None + skill_repository = getattr(agent, "skill_repository", None) + if skill_repository is 
None: + return None + try: + processor = SkillsToolResultRequestProcessor( + skill_repository, + skip_fallback_on_session_summary=parameters.get("skip_fallback_on_session_summary", True), + repo_resolver=parameters.get("repo_resolver"), + ) + await processor.process_llm_request(ctx, request) + return None + except Exception as ex: # pylint: disable=broad-except + logger.error("Error processing skill tool results for agent %s: %s", agent.name, ex) + return self._create_error_event( + ctx, + "skill_tool_result_processing_error", + f"Failed to process skill tool results: {str(ex)}", + ) async def _add_agent_transfer_capabilities(self, agent: BaseAgent, ctx: InvocationContext, request: LlmRequest) -> Optional[Event]: @@ -879,7 +942,7 @@ def _apply_template_substitution(self, instruction: str, ctx: InvocationContext) {session_key}, etc. with actual values from the session state. This implementation is inspired by adk-python's inject_session_state but - adapted for trpc_agent_sdk's architecture. + adapted for trpc_agent's architecture. Args: instruction: The instruction string containing template placeholders @@ -920,7 +983,6 @@ def replace_placeholder(match): # This follows the behavior of the original SafeFormatter approach return match.group() - # Use regex pattern similar to adk-python but simpler for trpc_agent_sdk # This matches {variable_name} patterns including optional ones with ? pattern = r'\{[^{}]*\}' result = re.sub(pattern, replace_placeholder, instruction) diff --git a/trpc_agent_sdk/agents/core/_skill_processor.py b/trpc_agent_sdk/agents/core/_skill_processor.py index c34f6d8..1e1fbc2 100644 --- a/trpc_agent_sdk/agents/core/_skill_processor.py +++ b/trpc_agent_sdk/agents/core/_skill_processor.py @@ -4,55 +4,40 @@ # # tRPC-Agent-Python is licensed under Apache-2.0. """SkillsRequestProcessor — injects skill overviews and loaded contents. 
- -Mirrors ``internal/flow/processor/skills.go`` from trpc-agent-go, while -retaining Python-unique features (user_prompt, tools-selection summary). - -Behavior --------- -- Overview : always injected (skill names + descriptions). -- Loaded skills: full SKILL.md body injected into system prompt (or - deferred to tool-result mode). -- Docs : doc texts selected via session state keys. -- Tools : (Python-unique) tool selection summary for each loaded skill. - -Skill load modes ----------------- -- ``turn`` (default) – loaded skill content is available for all LLM - calls within the current invocation, then cleared at the start of the - next invocation. -- ``once`` – loaded skill content is injected once, then offloaded - (state keys cleared) immediately after injection. -- ``session`` – loaded skill content persists across invocations until - the session expires or state is cleared explicitly. """ from __future__ import annotations import json +from typing import Any from typing import Callable from typing import List from typing import Optional +from trpc_agent_sdk.context import AgentContext from trpc_agent_sdk.context import InvocationContext from trpc_agent_sdk.log import logger from trpc_agent_sdk.models import LlmRequest from trpc_agent_sdk.skills import BaseSkillRepository -from trpc_agent_sdk.skills import SKILL_DOCS_STATE_KEY_PREFIX -from trpc_agent_sdk.skills import SKILL_LOADED_STATE_KEY_PREFIX -from trpc_agent_sdk.skills import SKILL_TOOLS_STATE_KEY_PREFIX from trpc_agent_sdk.skills import Skill -from trpc_agent_sdk.skills import generic_get_selection - -# --------------------------------------------------------------------------- -# Load mode constants (mirrors Go SkillLoadModeXxx) -# --------------------------------------------------------------------------- - -SKILL_LOAD_MODE_ONCE = "once" -SKILL_LOAD_MODE_TURN = "turn" -SKILL_LOAD_MODE_SESSION = "session" - -_DEFAULT_SKILL_LOAD_MODE = SKILL_LOAD_MODE_TURN +from trpc_agent_sdk.skills import 
SkillLoadModeNames +from trpc_agent_sdk.skills import SkillProfileFlags +from trpc_agent_sdk.skills import SkillProfileNames +from trpc_agent_sdk.skills import SkillToolsNames +from trpc_agent_sdk.skills import docs_scan_prefix +from trpc_agent_sdk.skills import docs_state_key +from trpc_agent_sdk.skills import get_skill_config +from trpc_agent_sdk.skills import loaded_order_state_key +from trpc_agent_sdk.skills import loaded_scan_prefix +from trpc_agent_sdk.skills import loaded_state_key +from trpc_agent_sdk.skills import marshal_loaded_order +from trpc_agent_sdk.skills import parse_loaded_order +from trpc_agent_sdk.skills import set_skill_config +from trpc_agent_sdk.skills import tool_scan_prefix +from trpc_agent_sdk.skills import tool_state_key +from trpc_agent_sdk.skills import touch_loaded_order + +from ._skills_tool_result_processor import SKILL_LOADED_RE # --------------------------------------------------------------------------- # Prompt section headers (mirrors Go const block) @@ -62,49 +47,93 @@ _SKILLS_CAPABILITY_HEADER = "Skill tool availability:" _SKILLS_TOOLING_GUIDANCE_HEADER = "Tooling and workspace guidance:" -# --------------------------------------------------------------------------- -# Internal state keys -# --------------------------------------------------------------------------- +_SKILLS_TURN_INIT_STATE_KEY = "processor:skills:turn_init" + + +def normalize_load_mode(mode: str) -> str: + value = (mode or "").strip().lower() + if value in (SkillLoadModeNames.ONCE, SkillLoadModeNames.TURN, SkillLoadModeNames.SESSION): + return value + return SkillLoadModeNames.TURN + + +def _append_knowledge_guidance(lines: list[str], flags: SkillProfileFlags) -> None: + """Append docs-loading guidance mirroring Go's appendKnowledgeGuidance.""" + has_list_docs = flags.list_docs + has_select_docs = flags.select_docs + if has_list_docs and has_select_docs: + lines.append("- Use the available doc listing and selection helpers to keep" + " documentation loads 
targeted.\n") + elif has_list_docs: + lines.append("- Use the available doc listing helper to discover doc names," + " then load only the docs you need.\n") + elif has_select_docs: + lines.append("- If doc names are already known, use the available doc" + " selection helper to keep loaded docs targeted.\n") + else: + lines.append("- If you need docs, request them directly with skill_load.docs" + " or include_all_docs.\n") + lines.append("- Avoid include_all_docs unless the user asks or the task genuinely" + " needs the full doc set.\n") -# JSON array of skill names in load order — used by the max-cap eviction. -_SKILLS_LOADED_ORDER_STATE_KEY = "temp:skill:loaded_order" # --------------------------------------------------------------------------- -# Normalization helpers +# Guidance text builders (mirrors Go defaultXxxGuidance functions) # --------------------------------------------------------------------------- -def _normalize_load_mode(mode: str) -> str: - m = (mode or "").strip().lower() - if m in (SKILL_LOAD_MODE_ONCE, SKILL_LOAD_MODE_TURN, SKILL_LOAD_MODE_SESSION): - return m - return _DEFAULT_SKILL_LOAD_MODE - - -def _is_knowledge_only(profile: str) -> bool: - """Return True for profiles that support knowledge lookup only.""" - p = (profile or "").strip().lower().replace("-", "_") - return p in ("knowledge_only", "knowledge") +def _default_catalog_only_guidance() -> str: + return ("\n" + _SKILLS_TOOLING_GUIDANCE_HEADER + "\n" + + "- Use the skill overview as a catalog only. 
Built-in skill tools are" + " unavailable in this configuration; if a task depends on loading or" + " executing a skill, use other registered tools or explain the" + " limitation clearly.\n") -# --------------------------------------------------------------------------- -# Guidance text builders (mirrors Go defaultXxxGuidance functions) -# --------------------------------------------------------------------------- +def _default_doc_helpers_only_guidance(flags: SkillProfileFlags) -> str: + lines = [ + "\n", + _SKILLS_TOOLING_GUIDANCE_HEADER, + "\n", + ] + has_list_docs = flags.list_docs + has_select_docs = flags.select_docs + if has_list_docs and has_select_docs: + lines.append("- Use skills only to inspect available doc names or adjust" + " doc selection state.\n") + elif has_list_docs: + lines.append("- Use skills only to inspect available doc names.\n") + elif has_select_docs: + lines.append("- Use skills only to adjust doc selection when doc names are" + " already known.\n") + lines.append("- Built-in skill loading is unavailable, so doc helpers do not" + " inject SKILL.md or doc contents into context; if the task needs" + " loaded content or execution, use other registered tools or" + " explain the limitation clearly.\n") + return "".join(lines) -def _default_knowledge_only_guidance() -> str: - return ("\n" + _SKILLS_TOOLING_GUIDANCE_HEADER + "\n" + - "- Use skills for progressive disclosure only: load SKILL.md first," - " then inspect only the documentation needed for the current task.\n" + - "- Avoid include_all_docs unless the user asks or the task genuinely" - " needs the full doc set.\n" + "- Treat loaded skill content as domain guidance. 
Do not claim you" - " executed scripts, shell commands, or interactive flows described by" - " the skill.\n" + "- If a skill depends on execution to complete the task, switch to" - " other registered tools (for example, MCP tools) or explain the" - " limitation clearly.\n") +def _default_knowledge_only_guidance(flags: SkillProfileFlags) -> str: + lines = [ + "\n", + _SKILLS_TOOLING_GUIDANCE_HEADER, + "\n", + "- Use skills for progressive disclosure only: load SKILL.md first," + " then inspect only the documentation needed for the current task.\n", + ] + _append_knowledge_guidance(lines, flags) + lines += [ + "- Treat loaded skill content as domain guidance. Do not claim you" + " executed scripts, shell commands, or interactive flows described by" + " the skill.\n", + "- If a skill depends on execution to complete the task, switch to" + " other registered tools (for example, MCP tools) or explain the" + " limitation clearly.\n", + ] + return "".join(lines) -def _default_full_tooling_and_workspace_guidance(exec_tools_disabled: bool) -> str: +def _default_full_tooling_and_workspace_guidance(flags: SkillProfileFlags) -> str: lines: list[str] = [ "\n", _SKILLS_TOOLING_GUIDANCE_HEADER, @@ -132,6 +161,9 @@ def _default_full_tooling_and_workspace_guidance(exec_tools_disabled: bool) -> s "- Prefer writing new files under $OUTPUT_DIR or a skill's out/" " directory and include output_files globs (or an outputs spec) so" " files can be collected or saved as artifacts.\n", + "- Use stdout/stderr for logs or short status text. 
If the model needs" + " large or structured text, write it to files under $OUTPUT_DIR and" + " return it via output_files or outputs.\n", "- For Python skills that need third-party packages, create a virtualenv" " under the skill's .venv/ directory (it is writable inside the" " workspace).\n", @@ -153,69 +185,78 @@ def _default_full_tooling_and_workspace_guidance(exec_tools_disabled: bool) -> s "- When chaining multiple skills, read previous results from $OUTPUT_DIR" " (or a skill's out/ directory) instead of copying them back into inputs" " directories.\n", - "- Treat loaded skill docs as guidance, not perfect truth; when runtime" - " help or stderr disagrees, trust observed runtime behavior.\n", - "- Loading a skill gives you instructions and bundled resources; it does" - " not execute the skill by itself.\n", - "- The skill summaries above are routing summaries only; they do not" - " replace SKILL.md or other loaded docs.\n", - "- If the loaded content already provides enough guidance to answer or" - " produce the requested result, respond directly.\n", - "- A skill can still be executable even when it has no extra docs" - " or no custom tools. If SKILL.md provides runnable commands," - " proceed with skill_run using those commands.\n", - "- If a skill is not loaded, call skill_load; you may pass docs or" - " include_all_docs.\n", - "- If the body is loaded and docs are missing, treat docs as optional" - " unless the task explicitly requires extra references; then call" - " skill_select_docs or skill_load again to add docs.\n", - "- If the skill defines tools in its SKILL.md, they will be" - " automatically selected when you load the skill. 
You can refine tool" - " selection with skill_select_tools.\n", - "- If you decide to use a skill, load SKILL.md before", ] - if exec_tools_disabled: - lines.append(" the first skill_run for that skill, then load only" - " the docs you still need.\n") + if flags.load: + lines += [ + "- Treat loaded skill docs as guidance, not perfect truth; when runtime" + " help or stderr disagrees, trust observed runtime behavior.\n", + "- Loading a skill gives you instructions and bundled resources; it does" + " not execute the skill by itself.\n", + "- The skill summaries above are routing summaries only; they do not" + " replace SKILL.md or other loaded docs.\n", + "- If the loaded content already provides enough guidance to answer or" + " produce the requested result, respond directly.\n", + "- If you decide to use a skill, load SKILL.md before", + ] + if flags.requires_exec_session_tools(): + lines.append(" the first skill_run or skill_exec for that skill, then load" + " only the docs you still need.\n") + else: + lines.append(" the first skill_run for that skill, then load only the docs" + " you still need.\n") + lines += [ + "- Do not infer commands, script entrypoints, or resource layouts from" + " the short summary alone.\n", + ] + _append_knowledge_guidance(lines, flags) + elif flags.has_doc_helpers(): + lines += [ + "- Built-in skill loading is unavailable in this configuration. 
Doc" + " listing or selection helpers can inspect doc names or selection" + " state, but they do not inject SKILL.md or doc contents into" + " context.\n", + ] else: - lines.append(" the first skill_run or skill_exec for that skill," - " then load only the docs you still need.\n") + lines += [ + "- Built-in skill loading is unavailable in this configuration; do not" + " assume SKILL.md or doc contents are in context.\n", + ] + lines += [ - "- Do not infer commands, script entrypoints, or resource layouts" - " from the short summary alone.\n", - "- For docs, prefer skill_list_docs + skill_select_docs to load only" - " what you need.\n", - "- Avoid include_all_docs unless you need every doc or the user asks.\n", - "- Use execution tools only when running a command will reveal or" - " produce information or files you still need.\n", + "- Use execution tools only when running a command will reveal or produce" + " information or files you still need.\n", ] - if not exec_tools_disabled: + if flags.requires_exec_session_tools(): lines.append("- Use skill_exec only when a command needs incremental stdin or" " TTY-style interaction; otherwise prefer one-shot execution.\n") else: lines.append("- Do not assume interactive execution is available when only" " one-shot execution tools are present.\n") lines += [ - "- Prefer script-based commands from SKILL.md examples (for example," - " python3 scripts/foo.py) instead of ad-hoc python -c rewrites" - " unless the skill explicitly recommends inline execution.\n", - "- skill_run is a command runner inside the skill workspace, not a" - " magic capability. 
It does not automatically add the skill directory" - " to PATH or install dependencies; invoke scripts via an explicit" - " interpreter and path (e.g., python3 scripts/foo.py).\n", - "- When you execute, follow the tool description, loaded skill docs," - " bundled scripts, and observed runtime behavior rather than inventing" - " shell syntax or command arguments.\n", - "- If skill_list_tools returns command_examples, execute one of those" - " commands directly before trying ad-hoc shell alternatives.\n", + "- skill_run is a command runner inside the skill workspace, not a magic" + " capability. It does not automatically add the skill directory to PATH" + " or install dependencies; invoke scripts via an explicit interpreter and" + " path (e.g., python3 scripts/foo.py).\n", + "- When you execute, follow the tool description, ", ] + if flags.load: + lines[-1] += "loaded skill docs, " + lines[ + -1] += "bundled scripts, and observed runtime behavior rather than inventing shell syntax or command arguments.\n" return "".join(lines) -def _default_tooling_and_workspace_guidance(profile: str, exec_tools_disabled: bool) -> str: - if _is_knowledge_only(profile): - return _default_knowledge_only_guidance() - return _default_full_tooling_and_workspace_guidance(exec_tools_disabled) +def _default_tooling_and_workspace_guidance(flags: SkillProfileFlags) -> str: + if not flags.is_any(): + return _default_catalog_only_guidance() + + if not flags.run: + if flags.load: + return _default_knowledge_only_guidance(flags) + if flags.has_doc_helpers(): + return _default_doc_helpers_only_guidance(flags) + return _default_catalog_only_guidance() + return _default_full_tooling_and_workspace_guidance(flags) def _normalize_custom_guidance(guidance: str) -> str: @@ -236,8 +277,6 @@ def _normalize_custom_guidance(guidance: str) -> str: class SkillsRequestProcessor: """Injects skill overviews and loaded contents into LLM requests. - Mirrors Go's ``SkillsRequestProcessor``. 
- Args: skill_repository: Default skill repository. load_mode: ``"turn"`` (default), ``"once"``, or ``"session"``. @@ -247,6 +286,9 @@ class SkillsRequestProcessor: tool_result_mode: When ``True``, skip loaded-skill injection here (content is materialized into tool results instead). tool_profile: Profile string (e.g. ``"knowledge_only"``). + forbidden_tools: Optional explicit blacklist of built-in skill tools. + tool_flags: Optional resolved flags; when set, takes precedence + over ``tool_profile``/``forbidden_tools``. exec_tools_disabled: When ``True``, omit skill_exec guidance lines. repo_resolver: Optional ``(ctx) -> BaseSkillRepository`` callable that returns an invocation-specific repository. @@ -257,25 +299,30 @@ def __init__( self, skill_repository: BaseSkillRepository, *, - load_mode: str = SKILL_LOAD_MODE_TURN, + load_mode: str = str(SkillLoadModeNames.TURN), tooling_guidance: Optional[str] = None, tool_result_mode: bool = False, - tool_profile: str = "", + tool_profile: str = str(SkillProfileNames.FULL), + forbidden_tools: Optional[list[str]] = None, + tool_flags: Optional[SkillProfileFlags] = None, exec_tools_disabled: bool = False, repo_resolver: Optional[Callable[[InvocationContext], BaseSkillRepository]] = None, max_loaded_skills: int = 0, ) -> None: self._skill_repository = skill_repository - self._load_mode = _normalize_load_mode(load_mode) + self._load_mode = normalize_load_mode(load_mode) self._tooling_guidance = tooling_guidance self._tool_result_mode = tool_result_mode - self._tool_profile = (tool_profile or "").strip() - self._exec_tools_disabled = exec_tools_disabled + try: + resolved_flags = tool_flags or SkillProfileFlags.resolve_flags(tool_profile, forbidden_tools) + except ValueError as ex: + logger.warning("skills: invalid skill tool flags config, fallback to full profile: %s", ex) + resolved_flags = SkillProfileFlags.preset_flags(tool_profile, forbidden_tools) + if exec_tools_disabled: + resolved_flags = 
resolved_flags.without_interactive_execution() + self._tool_flags = resolved_flags self._repo_resolver = repo_resolver self._max_loaded_skills = max_loaded_skills - # Tracks which invocation IDs have already had their turn-init clearing - # applied. This is ephemeral instance state — not persisted to session. - self._initialized_invocations: set[str] = set() # ------------------------------------------------------------------ # Public entry point @@ -301,17 +348,17 @@ async def process_llm_request( self._maybe_clear_skill_state_for_turn(ctx) # 1) Always inject overview (names + descriptions). - self._inject_overview(request, repo) + self._inject_overview(ctx, request, repo) loaded = self._get_loaded_skills(ctx) loaded = self._maybe_cap_loaded_skills(ctx, loaded) if self._tool_result_mode: - # Loaded skill bodies/docs are injected into tool results by a - # separate post-content processor — skip injection here. + # Materialization is handled by a dedicated post-history processor + # in request pipeline (Go-aligned ordering). return loaded - # 2) Loaded skills: full body + docs + tools (sorted for stable prompts). + # 2) Loaded skills: full body + docs (sorted for stable prompts). loaded.sort() parts: list[str] = [] @@ -338,8 +385,6 @@ async def process_llm_request( doc_text = self._build_docs_text(sk, sel) if doc_text: parts.append(doc_text) - - # Tools (Python-unique: skill_select_tools integration) tool_sel = self._get_tools_selection(ctx, name) parts.append("Tools selected: ") if not tool_sel: @@ -373,38 +418,36 @@ def _maybe_clear_skill_state_for_turn(self, ctx: InvocationContext) -> None: Uses ``ctx.invocation_id`` to detect when a new invocation has started without persisting an extra key to session state. 
""" - if self._load_mode != SKILL_LOAD_MODE_TURN: + if self._load_mode != SkillLoadModeNames.TURN: return - inv_id = ctx.invocation_id - if inv_id in self._initialized_invocations: + if ctx.agent_context.get_metadata(_SKILLS_TURN_INIT_STATE_KEY): return - self._initialized_invocations.add(inv_id) - # Bound the set size to avoid unbounded growth across long-running servers. - if len(self._initialized_invocations) > 2000: - oldest = list(self._initialized_invocations)[:1000] - for old_id in oldest: - self._initialized_invocations.discard(old_id) + ctx.agent_context.with_metadata(_SKILLS_TURN_INIT_STATE_KEY, True) self._clear_skill_state(ctx) def _clear_skill_state(self, ctx: InvocationContext) -> None: """Clear all loaded-skill state keys from the session.""" + loaded_state_prefix = loaded_scan_prefix(ctx) + docs_state_prefix = docs_scan_prefix(ctx) + tools_state_prefix = tool_scan_prefix(ctx) + order_state_key = loaded_order_state_key(ctx) state = self._snapshot_state(ctx) for k, v in state.items(): if not v: continue - if (k.startswith(SKILL_LOADED_STATE_KEY_PREFIX) or k.startswith(SKILL_DOCS_STATE_KEY_PREFIX) - or k.startswith(SKILL_TOOLS_STATE_KEY_PREFIX) or k == _SKILLS_LOADED_ORDER_STATE_KEY): + if (k.startswith(loaded_state_prefix) or k.startswith(docs_state_prefix) or k.startswith(tools_state_prefix) + or k == order_state_key): ctx.actions.state_delta[k] = None def _maybe_offload_loaded_skills(self, ctx: InvocationContext, loaded: list[str]) -> None: """After injection, clear skill state for once mode.""" - if self._load_mode != SKILL_LOAD_MODE_ONCE or not loaded: + if self._load_mode != SkillLoadModeNames.ONCE or not loaded: return for name in loaded: - ctx.actions.state_delta[SKILL_LOADED_STATE_KEY_PREFIX + name] = None - ctx.actions.state_delta[SKILL_DOCS_STATE_KEY_PREFIX + name] = None - ctx.actions.state_delta[SKILL_TOOLS_STATE_KEY_PREFIX + name] = None - ctx.actions.state_delta[_SKILLS_LOADED_ORDER_STATE_KEY] = None + 
ctx.actions.state_delta[loaded_state_key(ctx, name)] = None + ctx.actions.state_delta[docs_state_key(ctx, name)] = None + ctx.actions.state_delta[tool_state_key(ctx, name)] = None + ctx.actions.state_delta[loaded_order_state_key(ctx)] = None # ------------------------------------------------------------------ # Max-loaded-skills cap @@ -416,40 +459,31 @@ def _maybe_cap_loaded_skills(self, ctx: InvocationContext, loaded: list[str]) -> return loaded order = self._get_loaded_skill_order(ctx, loaded) - # Keep the most recently touched skills (tail of the order list). + if not order: + return loaded keep_count = self._max_loaded_skills - keep_set = set(order[-keep_count:]) if len(order) >= keep_count else set(order) + keep_set = set(order[-keep_count:]) kept: list[str] = [] for name in loaded: if name in keep_set: kept.append(name) else: - ctx.actions.state_delta[SKILL_LOADED_STATE_KEY_PREFIX + name] = None - ctx.actions.state_delta[SKILL_DOCS_STATE_KEY_PREFIX + name] = None - ctx.actions.state_delta[SKILL_TOOLS_STATE_KEY_PREFIX + name] = None - + ctx.actions.state_delta[loaded_state_key(ctx, name)] = None + ctx.actions.state_delta[docs_state_key(ctx, name)] = None + ctx.actions.state_delta[tool_state_key(ctx, name)] = None new_order = [n for n in order if n in keep_set] - ctx.actions.state_delta[_SKILLS_LOADED_ORDER_STATE_KEY] = json.dumps(new_order) + encoded_order = marshal_loaded_order(new_order) + ctx.actions.state_delta[loaded_order_state_key(ctx)] = encoded_order return kept def _get_loaded_skill_order(self, ctx: InvocationContext, loaded: list[str]) -> list[str]: - """Return skill names in load order (oldest first, most-recent last). - - Reads the persisted order key; fills in any missing names - alphabetically (mirrors Go's fillLoadedSkillOrderAlphabetically). 
- """ - loaded_set = set(loaded) - raw = self._read_state(ctx, _SKILLS_LOADED_ORDER_STATE_KEY) - order: list[str] = [] - if raw: - try: - parsed = json.loads(raw) if isinstance(raw, str) else raw - if isinstance(parsed, list): - order = [n for n in parsed if n in loaded_set] - except (json.JSONDecodeError, TypeError, ValueError): - pass - + loaded_set = self._loaded_skill_set(loaded) + if not loaded_set: + return [] + order = self._loaded_skill_order_from_state(ctx, loaded_set) + if len(order) < len(loaded_set): + order = self._append_skills_to_order_from_events(ctx, order, loaded_set) seen = set(order) for name in sorted(n for n in loaded_set if n not in seen): order.append(name) @@ -459,7 +493,7 @@ def _get_loaded_skill_order(self, ctx: InvocationContext, loaded: list[str]) -> # Overview injection # ------------------------------------------------------------------ - def _inject_overview(self, request: LlmRequest, repo: BaseSkillRepository) -> None: + def _inject_overview(self, ctx: InvocationContext, request: LlmRequest, repo: BaseSkillRepository) -> None: sums = repo.summaries() if not sums: return @@ -501,23 +535,44 @@ def _inject_overview(self, request: LlmRequest, repo: BaseSkillRepository) -> No def _tooling_guidance_text(self) -> str: if self._tooling_guidance is None: - return _default_tooling_and_workspace_guidance(self._tool_profile, self._exec_tools_disabled) - return _normalize_custom_guidance(self._tooling_guidance) + tool_prompt = _default_tooling_and_workspace_guidance(self._tool_flags) + else: + tool_prompt = _normalize_custom_guidance(self._tooling_guidance) + if self._tool_flags.has_select_tools(): + tool_prompt += """ + - Use the skill_select_tools tool to select tools for the current task only when user asks for it." + """ + if self._tool_flags.list_skills: + tool_prompt += """ + - Use the skill_list_skills tool to list skills for the current task only when user asks for it." 
+ """ + return tool_prompt def _capability_guidance_text(self) -> str: - """Inject capability block for knowledge-only profiles.""" - if not _is_knowledge_only(self._tool_profile): - return "" + """Inject capability block for constrained skill-tool profiles.""" # Omit when caller explicitly cleared guidance. - if self._tooling_guidance is not None and self._tooling_guidance == "": + if self._tooling_guidance == "" or self._tool_flags.run: return "" + if self._tool_flags.load: + return ("\n" + _SKILLS_CAPABILITY_HEADER + "\n" + + "- This configuration supports skill discovery and knowledge loading only.\n" + + "- Built-in skill execution tools are unavailable in the current mode.\n" + + "- If a loaded skill describes scripts, shell commands, workspace paths," + " generated files, or interactive flows, treat that content as reference" + " only. Use other registered tools for real actions, or explain that" + " execution is unavailable in the current mode.\n") + if self._tool_flags.has_doc_helpers(): + return ("\n" + _SKILLS_CAPABILITY_HEADER + "\n" + + "- This configuration supports skill discovery and skill doc inspection only.\n" + + "- Built-in skill loading and execution tools are unavailable in the" + " current mode.\n- Listing or selecting docs does not inject SKILL.md or doc contents" + " into model context by itself.\n") return ("\n" + _SKILLS_CAPABILITY_HEADER + "\n" + - "- This profile supports skill discovery and knowledge loading only.\n" + - "- Execution-oriented skill tools are unavailable in the current mode.\n" + - "- If a loaded skill describes scripts, shell commands, workspace paths," - " generated files, or interactive flows, treat that content as reference" - " only. Use other registered tools for real actions, or explain that" - " execution is unavailable in the current mode.\n") + "- This configuration exposes skill summaries only. 
Built-in skill tools" + " are unavailable in the current mode.\n" + + "- Treat the skill overview as a catalog of possible capabilities. Use" + " other registered tools, or explain the limitation clearly when the task" + " depends on skill loading or execution.\n") # ------------------------------------------------------------------ # State helpers @@ -547,57 +602,150 @@ def _get_loaded_skills(self, ctx: InvocationContext) -> list[str]: """Return names of all currently loaded skills.""" names: list[str] = [] state = self._snapshot_state(ctx) + scan_prefix = loaded_scan_prefix(ctx) for k, v in state.items(): - if not k.startswith(SKILL_LOADED_STATE_KEY_PREFIX) or not v: + if not k.startswith(scan_prefix) or not v: continue - name = k[len(SKILL_LOADED_STATE_KEY_PREFIX):] - names.append(name) - return names + name = k[len(scan_prefix):].strip() + if name: + names.append(name) + if names: + return sorted(set(names)) + return [] # ------------------------------------------------------------------ - # Docs and tools selection + # Docs / tools selection # ------------------------------------------------------------------ def _get_docs_selection(self, ctx: InvocationContext, name: str) -> list[str]: - - def get_all_docs(skill_name: str) -> list[str]: + value = self._read_state(ctx, docs_state_key(ctx, name), default=None) + if not value: + return [] + if isinstance(value, bytes): try: - repo = self._get_repository(ctx) - sk = repo.get(skill_name) if repo else None - if sk is None: - return [] - return [d.path for d in sk.resources] + value = value.decode("utf-8") + except UnicodeDecodeError: + return [] + if value == "*": + repo = self._get_repository(ctx) + if repo is None: + return [] + try: + sk = repo.get(name) + return [doc.path for doc in sk.resources] except Exception as ex: # pylint: disable=broad-except - logger.warning("Failed to get docs for skill '%s': %s", skill_name, ex) + logger.warning("Failed to get docs for skill '%s': %s", name, ex) return [] - - return 
generic_get_selection( - ctx=ctx, - skill_name=name, - state_key_prefix=SKILL_DOCS_STATE_KEY_PREFIX, - get_all_items_callback=get_all_docs, - ) + if not isinstance(value, str): + return [] + try: + arr = json.loads(value) + except json.JSONDecodeError: + return [] + if not isinstance(arr, list): + return [] + return [doc for doc in arr if isinstance(doc, str) and doc.strip()] def _get_tools_selection(self, ctx: InvocationContext, name: str) -> list[str]: - """Python-unique: return selected tool names for *name*.""" - - def get_all_tools(skill_name: str) -> list[str]: + value = self._read_state(ctx, tool_state_key(ctx, name), default=None) + if not value: + return [] + if isinstance(value, bytes): try: - repo = self._get_repository(ctx) - sk = repo.get(skill_name) if repo else None - if sk is None: - return [] + value = value.decode("utf-8") + except UnicodeDecodeError: + return [] + if value == "*": + repo = self._get_repository(ctx) + if repo is None: + return [] + try: + sk = repo.get(name) return sk.tools except Exception as ex: # pylint: disable=broad-except - logger.warning("Failed to get tools for skill '%s': %s", skill_name, ex) + logger.warning("Failed to get tools for skill '%s': %s", name, ex) return [] + if not isinstance(value, str): + return [] + try: + arr = json.loads(value) + except json.JSONDecodeError: + return [] + if not isinstance(arr, list): + return [] + return [tool for tool in arr if isinstance(tool, str) and tool.strip()] + + def _loaded_skill_set(self, loaded: list[str]) -> set[str]: + out: set[str] = set() + for name in loaded: + candidate = (name or "").strip() + if candidate: + out.add(candidate) + return out + + def _loaded_skill_order_from_state(self, ctx: InvocationContext, loaded_set: set[str]) -> list[str]: + order = parse_loaded_order(self._read_state(ctx, loaded_order_state_key(ctx))) + if not order: + return [] + out: list[str] = [] + seen: set[str] = set() + for name in order: + if name not in loaded_set or name in seen: + 
continue + out.append(name) + seen.add(name) + return out - return generic_get_selection( - ctx=ctx, - skill_name=name, - state_key_prefix=SKILL_TOOLS_STATE_KEY_PREFIX, - get_all_items_callback=get_all_tools, - ) + def _append_skills_to_order_from_events( + self, + ctx: InvocationContext, + order: list[str], + loaded_set: set[str], + ) -> list[str]: + events = list(getattr(ctx.session, "events", []) or []) + if not events: + return order + for event in events: + if ctx.agent_name and getattr(event, "author", "") != ctx.agent_name: + continue + content = getattr(event, "content", None) + if content is None or not getattr(content, "parts", None): + continue + for part in content.parts: + response = getattr(part, "function_response", None) + if response is None: + continue + tool_name = (getattr(response, "name", "") or "").strip() + if tool_name not in (SkillToolsNames.LOAD, SkillToolsNames.SELECT_DOCS): + continue + skill_name = self._skill_name_from_tool_response(tool_name, getattr(response, "response", None)) + if not skill_name or skill_name not in loaded_set: + continue + order = touch_loaded_order(order, skill_name) + return order + + def _skill_name_from_tool_response(self, tool_name: str, response: Any) -> str: + if tool_name == str(SkillToolsNames.SELECT_DOCS) and isinstance(response, dict): + for key in ("skill", "skill_name", "name"): + value = response.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + return "" + if tool_name == SkillToolsNames.LOAD: + if isinstance(response, dict): + for key in ("skill", "skill_name", "name", "result"): + value = response.get(key) + if isinstance(value, str) and value.strip(): + match = SKILL_LOADED_RE.search(value) + if match: + return match.group(1).strip() + if key in ("skill", "skill_name", "name"): + return value.strip() + if isinstance(response, str): + match = SKILL_LOADED_RE.search(response) + if match: + return match.group(1).strip() + return "" # 
------------------------------------------------------------------ # Doc text assembly @@ -623,3 +771,28 @@ def _merge_into_system(self, request: LlmRequest, content: str) -> None: if not content: return request.append_instructions([content]) + + +def set_skill_processor_parameters(agent_context: AgentContext, parameters: dict[str, Any]) -> None: + """Set the parameters of a skill processor by agent context. + + Args: + agent_context: AgentContext object + parameters: Parameters to set + """ + skill_config = get_skill_config(agent_context) + skill_config["skill_processor"].update(parameters) + set_skill_config(agent_context, skill_config) + + +def get_skill_processor_parameters(agent_context: AgentContext) -> dict[str, Any]: + """Get the parameters of a skill processor. + + Args: + invocation_context: InvocationContext object + + Returns: + Parameters of the skill processor + """ + skill_config = get_skill_config(agent_context) + return skill_config["skill_processor"] diff --git a/trpc_agent_sdk/agents/core/_skills_tool_result_processor.py b/trpc_agent_sdk/agents/core/_skills_tool_result_processor.py new file mode 100644 index 0000000..fcd38fe --- /dev/null +++ b/trpc_agent_sdk/agents/core/_skills_tool_result_processor.py @@ -0,0 +1,375 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+"""Materialize loaded skill context into skill tool results.""" + +from __future__ import annotations + +import json +import re +from typing import Any +from typing import Callable +from typing import Optional +from typing import Tuple + +from trpc_agent_sdk.context import AgentContext +from trpc_agent_sdk.context import InvocationContext +from trpc_agent_sdk.log import logger +from trpc_agent_sdk.models import LlmRequest +from trpc_agent_sdk.skills import BaseSkillRepository +from trpc_agent_sdk.skills import Skill +from trpc_agent_sdk.skills import SkillLoadModeNames +from trpc_agent_sdk.skills import SkillToolsNames +from trpc_agent_sdk.skills import docs_state_key +from trpc_agent_sdk.skills import get_skill_config +from trpc_agent_sdk.skills import get_skill_load_mode +from trpc_agent_sdk.skills import loaded_order_state_key +from trpc_agent_sdk.skills import loaded_scan_prefix +from trpc_agent_sdk.skills import loaded_state_key +from trpc_agent_sdk.skills import set_skill_config + +_SKILLS_LOADED_CONTEXT_HEADER = "Loaded skill context:" +_SESSION_SUMMARY_PREFIX = "Here is a brief summary of your previous interactions:" +SKILL_LOADED_RE = re.compile(r"skill\s+'([^']+)'\s+loaded", re.IGNORECASE) + + +class SkillsToolResultRequestProcessor: + """Materialize loaded skill content into skill tool results.""" + + def __init__( + self, + skill_repository: BaseSkillRepository, + *, + skip_fallback_on_session_summary: bool = True, + repo_resolver: Optional[Callable[[InvocationContext], BaseSkillRepository]] = None, + ) -> None: + self._skill_repository = skill_repository + self._repo_resolver = repo_resolver + self._skip_fallback_on_session_summary = skip_fallback_on_session_summary + + async def process_llm_request(self, ctx: InvocationContext, request: LlmRequest) -> list[str]: + """Apply loaded-skill materialization to tool results and fallback prompt.""" + if request is None or ctx is None: + return [] + repo = self._get_repository(ctx) + if repo is None: + return 
[] + + loaded = self._get_loaded_skills(ctx) + if not loaded: + return [] + loaded.sort() + + tool_calls = self._index_tool_calls(request) + last_tool_parts = self._last_skill_tool_parts(request, tool_calls) + + materialized: set[str] = set() + for skill_name, (content_idx, part_idx) in last_tool_parts.items(): + content = request.contents[content_idx] + part = content.parts[part_idx] + function_response = part.function_response + if function_response is None: + continue + base = self._response_to_text(getattr(function_response, "response", None)) + rendered = self._build_tool_result_content(ctx, repo, skill_name, base) + if not rendered: + continue + function_response.response = {"result": rendered} + materialized.add(skill_name) + + fallback = self._build_fallback_system_content(ctx, repo, loaded, materialized) + if fallback: + if not (self._skip_fallback_on_session_summary and self._has_session_summary(request) + and not last_tool_parts): + request.append_instructions([fallback]) + + self._maybe_offload_loaded_skills(ctx, loaded) + return loaded + + def _get_repository(self, ctx: InvocationContext) -> Optional[BaseSkillRepository]: + if self._repo_resolver is not None: + return self._repo_resolver(ctx) + return self._skill_repository + + def _snapshot_state(self, ctx: InvocationContext) -> dict[str, Any]: + state = dict(ctx.session_state) + for key, value in ctx.actions.state_delta.items(): + if value is None: + state.pop(key, None) + else: + state[key] = value + return state + + def _read_state(self, ctx: InvocationContext, key: str, default=None): + if key in ctx.actions.state_delta: + return ctx.actions.state_delta[key] + return ctx.session_state.get(key, default) + + def _get_loaded_skills(self, ctx: InvocationContext) -> list[str]: + names_set: set[str] = set() + state = self._snapshot_state(ctx) + scan_prefix = loaded_scan_prefix(ctx) + for key, value in state.items(): + if not key.startswith(scan_prefix) or not value: + continue + name = 
key[len(scan_prefix):].strip() + if name: + names_set.add(name) + return sorted(names_set) + + def _index_tool_calls(self, request: LlmRequest) -> dict[str, Any]: + out: dict[str, Any] = {} + for content in request.contents: + if content.role not in ("model", "assistant") or not content.parts: + continue + for part in content.parts: + function_call = getattr(part, "function_call", None) + if function_call is None: + continue + call_id = (getattr(function_call, "id", "") or "").strip() + if not call_id: + continue + out[call_id] = function_call + return out + + def _last_skill_tool_parts( + self, + request: LlmRequest, + tool_calls: dict[str, Any], + ) -> dict[str, Tuple[int, int]]: + out: dict[str, Tuple[int, int]] = {} + for content_idx, content in enumerate(request.contents): + if content.role != "user" or not content.parts: + continue + for part_idx, part in enumerate(content.parts): + function_response = getattr(part, "function_response", None) + if function_response is None: + continue + tool_name = (getattr(function_response, "name", "") or "").strip() + if tool_name not in (SkillToolsNames.LOAD, SkillToolsNames.SELECT_DOCS): + continue + skill_name = self._skill_name_from_tool_response(function_response, tool_calls) + if not skill_name: + continue + out[skill_name] = (content_idx, part_idx) + return out + + def _skill_name_from_tool_response(self, function_response: Any, tool_calls: dict[str, Any]) -> str: + response = getattr(function_response, "response", None) + if isinstance(response, dict): + for key in ("skill", "skill_name", "name"): + value = response.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + + call_id = (getattr(function_response, "id", "") or "").strip() + if call_id and call_id in tool_calls: + function_call = tool_calls[call_id] + args = getattr(function_call, "args", None) + for key in ("skill", "skill_name", "name"): + value = self._get_arg_value(args, key) + if value: + return value + + return 
self._parse_loaded_skill_from_text(self._response_to_text(response)) + + def _get_arg_value(self, args: Any, key: str) -> str: + if isinstance(args, str): + try: + args = json.loads(args) + except json.JSONDecodeError: + return "" + if isinstance(args, dict): + value = args.get(key) + if isinstance(value, str): + return value.strip() + return "" + + def _parse_loaded_skill_from_text(self, content: str) -> str: + text = (content or "").strip() + if not text: + return "" + match = SKILL_LOADED_RE.search(text) + if match: + return match.group(1).strip() + lower = text.lower() + if lower.startswith("loaded:"): + return text[len("loaded:"):].strip() + return "" + + def _response_to_text(self, response: Any) -> str: + if response is None: + return "" + if isinstance(response, str): + return response.strip() + if isinstance(response, dict): + result = response.get("result") + if isinstance(result, str): + return result.strip() + if result is not None: + return str(result).strip() + return json.dumps(response, ensure_ascii=False).strip() + return str(response).strip() + + def _is_loaded_tool_stub(self, tool_output: str, skill_name: str) -> bool: + loaded = self._parse_loaded_skill_from_text(tool_output) + if not loaded: + return False + return loaded.lower() == skill_name.lower() + + def _build_tool_result_content( + self, + ctx: InvocationContext, + repo: BaseSkillRepository, + skill_name: str, + tool_output: str, + ) -> str: + try: + sk = repo.get(skill_name) + except Exception as ex: # pylint: disable=broad-except + logger.warning("skills: get %s failed: %s", skill_name, ex) + return "" + if sk is None: + logger.warning("skills: get %s failed: skill not found", skill_name) + return "" + + parts: list[str] = [] + base = tool_output.strip() + if base and self._is_loaded_tool_stub(base, skill_name): + base = "" + if base: + parts.append(base) + parts.append("\n\n") + + if sk.body.strip(): + parts.append(f"[Loaded] {skill_name}\n\n{sk.body}\n") + + selected_docs = 
self._get_docs_selection(ctx, skill_name, repo) + parts.append("Docs loaded: ") + if not selected_docs: + parts.append("none\n") + else: + parts.append(", ".join(selected_docs) + "\n") + docs_text = self._build_docs_text(sk, selected_docs) + if docs_text: + parts.append(docs_text) + return "".join(parts).strip() + + def _build_fallback_system_content( + self, + ctx: InvocationContext, + repo: BaseSkillRepository, + loaded: list[str], + materialized: set[str], + ) -> str: + missing = [name for name in loaded if name not in materialized] + if not missing: + return "" + + parts: list[str] = [_SKILLS_LOADED_CONTEXT_HEADER, "\n"] + appended = False + for name in missing: + try: + sk = repo.get(name) + except Exception as ex: # pylint: disable=broad-except + logger.warning("skills: get %s failed: %s", name, ex) + continue + if sk is None: + logger.warning("skills: get %s failed: skill not found", name) + continue + if sk.body.strip(): + parts.append(f"\n[Loaded] {name}\n\n{sk.body}\n") + appended = True + selected_docs = self._get_docs_selection(ctx, name, repo) + parts.append("Docs loaded: ") + if not selected_docs: + parts.append("none\n") + else: + parts.append(", ".join(selected_docs) + "\n") + docs_text = self._build_docs_text(sk, selected_docs) + if docs_text: + parts.append(docs_text) + appended = True + if not appended: + return "" + return "".join(parts).strip() + + def _has_session_summary(self, request: LlmRequest) -> bool: + if request is None or request.config is None: + return False + system_instruction = str(request.config.system_instruction or "") + return _SESSION_SUMMARY_PREFIX in system_instruction + + def _get_docs_selection(self, ctx: InvocationContext, skill_name: str, repo: BaseSkillRepository) -> list[str]: + value = self._read_state(ctx, docs_state_key(ctx, skill_name), default=None) + if not value: + return [] + if isinstance(value, bytes): + try: + value = value.decode("utf-8") + except UnicodeDecodeError: + return [] + if value == "*": + try: 
+ sk = repo.get(skill_name) + except Exception as ex: # pylint: disable=broad-except + logger.warning("skills: get %s failed: %s", skill_name, ex) + return [] + if sk is None: + return [] + return [doc.path for doc in sk.resources] + if not isinstance(value, str): + return [] + try: + arr = json.loads(value) + except json.JSONDecodeError: + return [] + if not isinstance(arr, list): + return [] + return [doc for doc in arr if isinstance(doc, str) and doc.strip()] + + def _build_docs_text(self, sk: Skill, wanted: list[str]) -> str: + if sk is None or not sk.resources: + return "" + want = set(wanted) + parts: list[str] = [] + for resource in sk.resources: + if resource.path not in want or not resource.content: + continue + parts.append(f"\n[Doc] {resource.path}\n\n{resource.content}\n") + return "".join(parts) + + def _maybe_offload_loaded_skills(self, ctx: InvocationContext, loaded: list[str]) -> None: + if get_skill_load_mode(ctx) != SkillLoadModeNames.ONCE or not loaded: + return + for skill_name in loaded: + ctx.actions.state_delta[loaded_state_key(ctx, skill_name)] = None + ctx.actions.state_delta[docs_state_key(ctx, skill_name)] = None + ctx.actions.state_delta[loaded_order_state_key(ctx)] = None + + +def set_skill_tool_result_processor_parameters(agent_context: AgentContext, parameters: dict[str, Any]) -> None: + """Set the parameters of a skill tool result processor by agent context. + + Args: + agent_context: AgentContext object + parameters: Parameters to set + """ + skill_config = get_skill_config(agent_context) + skill_config["skills_tool_result_processor"].update(parameters) + set_skill_config(agent_context, skill_config) + + +def get_skill_tool_result_processor_parameters(agent_context: AgentContext) -> dict[str, Any]: + """Get the parameters of a skill tool result processor. 
+ + Args: + agent_context: AgentContext object + + Returns: + Parameters of the skill tool result processor + """ + skill_config = get_skill_config(agent_context) + return skill_config["skills_tool_result_processor"] diff --git a/trpc_agent_sdk/agents/core/_workspace_exec_processor.py b/trpc_agent_sdk/agents/core/_workspace_exec_processor.py new file mode 100644 index 0000000..fc2d4d9 --- /dev/null +++ b/trpc_agent_sdk/agents/core/_workspace_exec_processor.py @@ -0,0 +1,153 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Inject workspace_exec guidance into request system instructions. +""" + +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import Optional + +from trpc_agent_sdk.context import AgentContext +from trpc_agent_sdk.context import InvocationContext +from trpc_agent_sdk.models import LlmRequest +from trpc_agent_sdk.skills import BaseSkillRepository +from trpc_agent_sdk.skills import get_skill_config +from trpc_agent_sdk.skills import set_skill_config + +_WORKSPACE_EXEC_GUIDANCE_HEADER = "Executor workspace guidance:" + + +class WorkspaceExecRequestProcessor: + """Request processor for workspace_exec guidance injection.""" + + def __init__( + self, + *, + session_tools: bool = False, + has_skills_repo: bool = False, + repo_resolver: Optional[Callable[[InvocationContext], Optional[BaseSkillRepository]]] = None, + enabled_resolver: Optional[Callable[[InvocationContext], bool]] = None, + sessions_resolver: Optional[Callable[[InvocationContext], bool]] = None, + ) -> None: + self._session_tools = session_tools + self._static_skills_repo = has_skills_repo + self._repo_resolver = repo_resolver + self._enabled_resolver = enabled_resolver + self._sessions_resolver = sessions_resolver + + async def process_llm_request(self, ctx: InvocationContext, 
request: LlmRequest) -> None: + """Inject workspace guidance into request.config.system_instruction.""" + if ctx is None or request is None: + return + guidance = self._guidance_text(ctx, request) + if not guidance: + return + + existing = "" + if request.config and request.config.system_instruction: + existing = str(request.config.system_instruction) + if _WORKSPACE_EXEC_GUIDANCE_HEADER in existing: + return + request.append_instructions([guidance]) + + def _guidance_text(self, ctx: InvocationContext, request: LlmRequest) -> str: + if not self._enabled_for_invocation(ctx, request): + return "" + lines: list[str] = [ + _WORKSPACE_EXEC_GUIDANCE_HEADER, + "- Treat workspace_exec as the default general shell runner for shared " + "executor-side work. It runs inside the current executor workspace, not " + "on the agent host; workspace is its scope, not its capability limit.", + "- workspace_exec starts at the workspace root by default. Prefer work/, " + "out/, and runs/ for shared executor-side work, and treat cwd as a " + "workspace-relative path.", + "- Network access depends on the current executor environment. If you " + "need a network command such as curl, use a small bounded command to " + "verify whether that environment allows it.", + "- When a limitation depends on the executor environment and a small " + "bounded command can verify it, verify first before claiming the " + "limitation. This applies to checks such as command availability, file " + "presence, or access to a known URL.", + ] + if self._supports_artifact_save(request): + lines.append("- Use workspace_save_artifact only when you need a stable artifact " + "reference for an already existing file in work/, out/, or runs/. " + "Intermediate files usually stay in the workspace.") + if self._has_skills_repo(ctx): + lines.append("- Paths under skills/ are only useful when some other tool has " + "already placed content there. 
workspace_exec does not stage skills " + "automatically.") + if self._session_tools_for_invocation(ctx, request): + lines.append("- When workspace_exec starts a command that keeps running or waits " + "for stdin, continue with workspace_write_stdin. When chars is empty, " + "workspace_write_stdin acts like a poll. Use workspace_kill_session " + "to stop a running workspace_exec session.") + lines.append("- Interactive workspace_exec sessions are only guaranteed within the " + "current invocation. Do not assume a later user message can resume " + "the same session.") + return "\n".join(lines).strip() + + def _enabled_for_invocation(self, ctx: InvocationContext, request: LlmRequest) -> bool: + if self._enabled_resolver is not None: + return bool(self._enabled_resolver(ctx)) + return self._has_tool(request, "workspace_exec") + + def _session_tools_for_invocation(self, ctx: InvocationContext, request: LlmRequest) -> bool: + if self._sessions_resolver is not None: + return bool(self._sessions_resolver(ctx)) + if self._session_tools: + return True + return self._has_tool(request, "workspace_write_stdin") and self._has_tool(request, "workspace_kill_session") + + def _has_skills_repo(self, ctx: InvocationContext) -> bool: + if self._repo_resolver is not None: + return self._repo_resolver(ctx) is not None + return self._static_skills_repo + + def _supports_artifact_save(self, request: LlmRequest) -> bool: + return self._has_tool(request, "workspace_save_artifact") + + @staticmethod + def _has_tool(request: LlmRequest, tool_name: str) -> bool: + if request is None or request.config is None or not request.config.tools: + return False + target = (tool_name or "").strip() + if not target: + return False + for tool in request.config.tools: + declarations = getattr(tool, "function_declarations", None) or [] + for declaration in declarations: + name = getattr(declaration, "name", "") + if (name or "").strip() == target: + return True + return False + + +def 
set_workspace_exec_processor_parameters(agent_context: AgentContext, parameters: dict[str, Any]) -> None: + """Set the parameters of a workspace exec processor by agent context. + + Args: + agent_context: AgentContext object + parameters: Parameters to set + """ + skill_config = get_skill_config(agent_context) + skill_config["workspace_exec_processor"].update(parameters) + set_skill_config(agent_context, skill_config) + + +def get_workspace_exec_processor_parameters(agent_context: AgentContext) -> dict[str, Any]: + """Get the parameters of a workspace exec processor. + + Args: + agent_context: AgentContext object + + Returns: + Parameters of the workspace exec processor + """ + skill_config = get_skill_config(agent_context) + return skill_config["workspace_exec_processor"] diff --git a/trpc_agent_sdk/artifacts/_in_memory_artifact_service.py b/trpc_agent_sdk/artifacts/_in_memory_artifact_service.py index add9962..ab6c5aa 100644 --- a/trpc_agent_sdk/artifacts/_in_memory_artifact_service.py +++ b/trpc_agent_sdk/artifacts/_in_memory_artifact_service.py @@ -1,8 +1,6 @@ -# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# -*- coding: utf-8 -*- # -# Copyright (C) 2026 Tencent. All rights reserved. -# -# tRPC-Agent-Python is licensed under Apache-2.0. 
+# Copyright @ 2025 Tencent.com # # Directly reuse the types from adk-python # Below code are copy and modified from https://github.com/google/adk-python.git @@ -27,7 +25,6 @@ from pydantic import BaseModel from pydantic import Field - from trpc_agent_sdk.abc import ArtifactEntry from trpc_agent_sdk.abc import ArtifactId from trpc_agent_sdk.abc import ArtifactServiceABC @@ -80,7 +77,7 @@ async def save_artifact( else: raise ValueError("Not supported artifact type.") - self.artifacts[path].append(ArtifactEntry(data=artifact, artifact_version=artifact_version)) + self.artifacts[path].append(ArtifactEntry(data=artifact, version=artifact_version)) return version @override @@ -125,6 +122,8 @@ async def load_artifact( if (artifact_data == Part() or artifact_data == Part(text="") or (artifact_data.inline_data and not artifact_data.inline_data.data)): return None + if artifact_entry.version.mime_type is None: + artifact_entry.version.mime_type = "application/octet-stream" return artifact_entry @override @@ -175,7 +174,7 @@ async def list_artifact_versions( entries = self.artifacts.get(path) if not entries: return [] - return [entry.artifact_version for entry in entries] + return [entry.version for entry in entries] @override async def get_artifact_version( @@ -192,6 +191,6 @@ async def get_artifact_version( if version is None: version = -1 try: - return entries[version].artifact_version + return entries[version].version except IndexError: return None diff --git a/trpc_agent_sdk/artifacts/_utils.py b/trpc_agent_sdk/artifacts/_utils.py index b9f0656..ee1444e 100644 --- a/trpc_agent_sdk/artifacts/_utils.py +++ b/trpc_agent_sdk/artifacts/_utils.py @@ -29,7 +29,6 @@ from typing import Optional from google.genai import types - from trpc_agent_sdk.abc import ArtifactId diff --git a/trpc_agent_sdk/code_executors/__init__.py b/trpc_agent_sdk/code_executors/__init__.py index d16645e..276f4db 100644 --- a/trpc_agent_sdk/code_executors/__init__.py +++ 
b/trpc_agent_sdk/code_executors/__init__.py @@ -9,13 +9,9 @@ including base classes and implementations. """ -from ._artifacts import artifact_service_from_context -from ._artifacts import artifact_session_from_context from ._artifacts import load_artifact_helper from ._artifacts import parse_artifact_ref from ._artifacts import save_artifact_helper -from ._artifacts import with_artifact_service -from ._artifacts import with_artifact_session from ._base_code_executor import BaseCodeExecutor from ._base_workspace_runtime import BaseProgramRunner from ._base_workspace_runtime import BaseWorkspaceFS @@ -47,6 +43,20 @@ from ._constants import META_FILE_NAME from ._constants import TMP_FILE_NAME from ._constants import WORKSPACE_ENV_DIR_KEY +from ._program_session import BaseProgramSession +from ._program_session import DEFAULT_EXEC_YIELD_MS +from ._program_session import DEFAULT_IO_YIELD_MS +from ._program_session import DEFAULT_POLL_LINES +from ._program_session import DEFAULT_SESSION_KILL_SEC +from ._program_session import DEFAULT_SESSION_TTL_SEC +from ._program_session import PROGRAM_STATUS_EXITED +from ._program_session import PROGRAM_STATUS_RUNNING +from ._program_session import ProgramLog +from ._program_session import ProgramPoll +from ._program_session import ProgramState +from ._program_session import poll_line_limit +from ._program_session import wait_for_program_output +from ._program_session import yield_duration_ms from ._types import CodeBlock from ._types import CodeBlockDelimiter from ._types import CodeExecutionInput @@ -82,13 +92,9 @@ from .utils import CodeExecutionUtils __all__ = [ - "artifact_service_from_context", - "artifact_session_from_context", "load_artifact_helper", "parse_artifact_ref", "save_artifact_helper", - "with_artifact_service", - "with_artifact_session", "BaseCodeExecutor", "BaseProgramRunner", "BaseWorkspaceFS", @@ -120,6 +126,20 @@ "META_FILE_NAME", "TMP_FILE_NAME", "WORKSPACE_ENV_DIR_KEY", + "BaseProgramSession", + 
"DEFAULT_EXEC_YIELD_MS", + "DEFAULT_IO_YIELD_MS", + "DEFAULT_POLL_LINES", + "DEFAULT_SESSION_KILL_SEC", + "DEFAULT_SESSION_TTL_SEC", + "PROGRAM_STATUS_EXITED", + "PROGRAM_STATUS_RUNNING", + "ProgramLog", + "ProgramPoll", + "ProgramState", + "poll_line_limit", + "wait_for_program_output", + "yield_duration_ms", "CodeBlock", "CodeBlockDelimiter", "CodeExecutionInput", diff --git a/trpc_agent_sdk/code_executors/_artifacts.py b/trpc_agent_sdk/code_executors/_artifacts.py index 08e0cda..3c1087b 100644 --- a/trpc_agent_sdk/code_executors/_artifacts.py +++ b/trpc_agent_sdk/code_executors/_artifacts.py @@ -10,11 +10,9 @@ through context, enabling artifact resolution without importing higher-level packages. """ -from typing import Any from typing import Optional from typing import Tuple -from trpc_agent_sdk.abc import ArtifactServiceABC from trpc_agent_sdk.context import InvocationContext from trpc_agent_sdk.types import Blob from trpc_agent_sdk.types import Part @@ -36,6 +34,8 @@ async def load_artifact_helper(ctx: InvocationContext, Returns: The artifact. """ + if not ctx: + raise ValueError("ctx is required") artifact_entry = await ctx.load_artifact(name, version) if artifact_entry is None: return None @@ -55,17 +55,20 @@ def parse_artifact_ref(ref: str) -> Tuple[str, Optional[int]]: The artifact name and version. 
""" parts = ref.split("@") + name = parts[0] + if not name: + raise ValueError(f"invalid ref: {ref}") if len(parts) == 1: - return parts[0], None + return name, None - if len(parts) == 2: + if len(parts) >= 2: # Try to parse version as integer - version_str = parts[1] + version_str = "".join(parts[1:]) if not version_str.isdigit(): - raise ValueError(f"invalid version: {version_str}") + return name, None - return parts[0], int(version_str) + return name, int(version_str) raise ValueError(f"invalid ref: {ref}") @@ -85,51 +88,3 @@ async def save_artifact_helper(ctx: InvocationContext, filename: str, data: byte """ artifact = Part(inline_data=Blob(data=data, mime_type=mime), ) return await ctx.save_artifact(filename, artifact) - - -def with_artifact_service(ctx: InvocationContext, svc: ArtifactServiceABC) -> InvocationContext: - """ - Store an artifact.Service in the context. - - Callers retrieve it in lower layers to load/save artifacts - without importing higher-level packages. - - Args: - ctx: The context to store the service in - svc: The artifact service - - Returns: - Updated context with artifact service - """ - ctx.artifact_service = svc - return ctx - - -def artifact_service_from_context(ctx: InvocationContext) -> Optional[ArtifactServiceABC]: - """ - Fetch the artifact.Service previously stored by with_artifact_service. - - Args: - ctx: The context to retrieve the service from - - Returns: - Tuple of (service, ok) where ok indicates presence - """ - return ctx.artifact_service - - -def with_artifact_session(ctx: InvocationContext, info: Any) -> InvocationContext: - assert False, "Not implemented" - - -def artifact_session_from_context(ctx: InvocationContext) -> Any: - """ - Retrieve artifact session info from context. 
- - Args: - ctx: The context to retrieve the session info from - - Returns: - SessionInfo object (empty if not found) - """ - assert False, "Not implemented" diff --git a/trpc_agent_sdk/code_executors/_base_code_executor.py b/trpc_agent_sdk/code_executors/_base_code_executor.py index 17ea06d..f7c6b06 100644 --- a/trpc_agent_sdk/code_executors/_base_code_executor.py +++ b/trpc_agent_sdk/code_executors/_base_code_executor.py @@ -16,7 +16,6 @@ from typing import Optional from pydantic import BaseModel - from trpc_agent_sdk.context import InvocationContext from ._base_workspace_runtime import BaseWorkspaceRuntime @@ -82,6 +81,9 @@ class BaseCodeExecutor(BaseModel): workspace_runtime: Optional[BaseWorkspaceRuntime] = None """The workspace runtime for the code execution.""" + ignore_codes: list[str] = [] + """The list of codes to ignore in the code execution.""" + @abc.abstractmethod async def execute_code( self, diff --git a/trpc_agent_sdk/code_executors/_base_workspace_runtime.py b/trpc_agent_sdk/code_executors/_base_workspace_runtime.py index b57b6a9..3ee1e1e 100644 --- a/trpc_agent_sdk/code_executors/_base_workspace_runtime.py +++ b/trpc_agent_sdk/code_executors/_base_workspace_runtime.py @@ -12,10 +12,12 @@ from abc import ABC from abc import abstractmethod +from typing import Callable from typing import List from typing import Optional from trpc_agent_sdk.context import InvocationContext +from trpc_agent_sdk.log import logger from ._types import CodeFile from ._types import ManifestOutput @@ -28,6 +30,8 @@ from ._types import WorkspaceRunResult from ._types import WorkspaceStageOptions +RunEnvProvider = Callable[[Optional[InvocationContext]], dict[str, str]] + class BaseWorkspaceManager(ABC): """ @@ -130,6 +134,40 @@ class BaseProgramRunner(ABC): Executes programs within a workspace. 
""" + def __init__( + self, + provider: Optional[RunEnvProvider] = None, + enable_provider_env: bool = False, + ) -> None: + self._run_env_provider = provider + self._enable_provider_env = bool(enable_provider_env and provider) + + def _apply_provider_env( + self, + spec: WorkspaceRunProgramSpec, + ctx: Optional[InvocationContext] = None, + ) -> WorkspaceRunProgramSpec: + """Return spec with provider env merged when enabled. + + Provider values never override keys already present in ``spec.env``. + The input ``spec`` is not mutated. + """ + provider = getattr(self, "_run_env_provider", None) + if not getattr(self, "_enable_provider_env", False) or provider is None: + return spec + try: + extra = provider(ctx) or {} + except Exception as ex: # pylint: disable=broad-except + logger.warning("run env provider failed: %s", ex) + return spec + if not extra: + return spec + merged = dict(spec.env or {}) + for key, value in extra.items(): + if key not in merged: + merged[key] = value + return spec.model_copy(update={"env": merged}, deep=True) + @abstractmethod async def run_program( self, diff --git a/trpc_agent_sdk/code_executors/_program_session.py b/trpc_agent_sdk/code_executors/_program_session.py new file mode 100644 index 0000000..177a00a --- /dev/null +++ b/trpc_agent_sdk/code_executors/_program_session.py @@ -0,0 +1,167 @@ +# -*- coding: utf-8 -*- +# +# Copyright @ 2026 Tencent.com +"""Program-session helpers. +""" + +from __future__ import annotations + +import asyncio +import time +from abc import ABC +from abc import abstractmethod +from dataclasses import dataclass +from typing import Optional + +from ._types import WorkspaceRunResult + +# Default wait windows (milliseconds). +DEFAULT_EXEC_YIELD_MS = 1_000 +DEFAULT_IO_YIELD_MS = 400 +DEFAULT_POLL_LINES = 40 + +# Poll pacing / settle windows (seconds). +DEFAULT_POLL_WAIT_SEC = 0.05 +DEFAULT_POLL_SETTLE_SEC = 0.075 + +# Session lifecycle defaults (seconds). 
+DEFAULT_SESSION_TTL_SEC = 30 * 60 +DEFAULT_SESSION_KILL_SEC = 2.0 + +PROGRAM_STATUS_RUNNING = "running" +PROGRAM_STATUS_EXITED = "exited" + + +@dataclass +class ProgramPoll: + """Incremental output chunk for a running or exited session.""" + + status: str = PROGRAM_STATUS_RUNNING + output: str = "" + offset: int = 0 + next_offset: int = 0 + exit_code: Optional[int] = None + + +@dataclass +class ProgramLog: + """Non-destructive output window from a specific offset.""" + + output: str = "" + offset: int = 0 + next_offset: int = 0 + + +@dataclass +class ProgramState: + """Non-streaming session status without cursor mutation.""" + + status: str = PROGRAM_STATUS_RUNNING + exit_code: Optional[int] = None + + +class BaseProgramSession(ABC): + """Base class for program sessions.""" + + @abstractmethod + def id(self) -> str: + """Return stable session id.""" + + @abstractmethod + async def poll(self, limit: Optional[int] = None) -> ProgramPoll: + """Advance cursor and return incremental output.""" + + @abstractmethod + async def log(self, offset: Optional[int] = None, limit: Optional[int] = None) -> ProgramLog: + """Read output from offset without advancing cursor.""" + + @abstractmethod + async def write(self, data: str, newline: bool) -> None: + """Write input to session.""" + + @abstractmethod + async def kill(self, grace_seconds: float) -> None: + """Terminate session, escalating after grace period.""" + + @abstractmethod + async def close(self) -> None: + """Release resources and stop background routines.""" + + @abstractmethod + async def state(self) -> ProgramState: + """Return current state snapshot.""" + + @abstractmethod + async def run_result(self) -> WorkspaceRunResult: + """Return final run result after session exits.""" + + +def yield_duration_ms(ms: int, fallback_ms: int) -> float: + """Normalize milliseconds into seconds with fallback and clamping.""" + if ms < 0: + ms = 0 + if ms == 0: + ms = fallback_ms + return ms / 1000.0 + + +def poll_line_limit(lines: 
int) -> int: + """Return a positive poll-line limit with default fallback.""" + if lines <= 0: + lines = DEFAULT_POLL_LINES + return lines + + +async def wait_for_program_output( + proc: BaseProgramSession, + yield_seconds: float, + limit: Optional[int], +) -> ProgramPoll: + """Poll until session exits or output settles within the yield window.""" + deadline = time.monotonic() + if yield_seconds > 0: + deadline += yield_seconds + + out_parts: list[str] = [] + offset = 0 + next_offset = 0 + have_chunk = False + settle_deadline = 0.0 + + while True: + poll = await proc.poll(limit) + if poll.output: + if not have_chunk: + offset = poll.offset + have_chunk = True + out_parts.append(poll.output) + next_offset = poll.next_offset + settle_deadline = time.monotonic() + DEFAULT_POLL_SETTLE_SEC + if yield_seconds <= 0: + deadline = settle_deadline + elif not have_chunk: + offset = poll.offset + next_offset = poll.next_offset + else: + next_offset = poll.next_offset + + if poll.status == PROGRAM_STATUS_EXITED: + poll.output = "".join(out_parts) + poll.offset = offset + poll.next_offset = next_offset + return poll + + now = time.monotonic() + if settle_deadline and now > settle_deadline: + poll.output = "".join(out_parts) + poll.offset = offset + poll.next_offset = next_offset + return poll + + if yield_seconds > 0 and now > deadline: + poll.output = "".join(out_parts) + poll.offset = offset + poll.next_offset = next_offset + return poll + + await asyncio.sleep(DEFAULT_POLL_WAIT_SEC) diff --git a/trpc_agent_sdk/code_executors/_types.py b/trpc_agent_sdk/code_executors/_types.py index 3323201..5130d48 100644 --- a/trpc_agent_sdk/code_executors/_types.py +++ b/trpc_agent_sdk/code_executors/_types.py @@ -11,7 +11,6 @@ from pydantic import BaseModel from pydantic import Field - from trpc_agent_sdk.types import CodeExecutionResult from trpc_agent_sdk.types import Outcome @@ -149,6 +148,9 @@ class WorkspaceRunProgramSpec(BaseModel): limits: WorkspaceResourceLimits = 
Field(default_factory=WorkspaceResourceLimits) """ resource limits""" + tty: bool = Field(default=False, description="Allocate pseudo-TTY") + """ whether to allocate pseudo-TTY""" + class WorkspaceRunResult(BaseModel): """ diff --git a/trpc_agent_sdk/code_executors/container/_container_cli.py b/trpc_agent_sdk/code_executors/container/_container_cli.py index 2e06d0c..39447ad 100644 --- a/trpc_agent_sdk/code_executors/container/_container_cli.py +++ b/trpc_agent_sdk/code_executors/container/_container_cli.py @@ -1,8 +1,6 @@ -# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# -*- coding: utf-8 -*- # -# Copyright (C) 2026 Tencent. All rights reserved. -# -# tRPC-Agent-Python is licensed under Apache-2.0. +# Copyright @ 2025 Tencent.com """Container code executor for TRPC Agent framework. This module provides a code executor that uses a custom container to execute code. @@ -14,12 +12,15 @@ import asyncio import atexit import os +import socket as pysocket from dataclasses import dataclass from typing import Optional import docker from docker.models.containers import Container - +from docker.utils.socket import consume_socket_output +from docker.utils.socket import demux_adaptor +from docker.utils.socket import frames_iter from trpc_agent_sdk.log import logger from trpc_agent_sdk.utils import CommandExecResult @@ -48,6 +49,8 @@ class CommandArgs: """The environment variables for the command execution.""" timeout: Optional[float] = None """The timeout for the command execution in seconds.""" + stdin: Optional[str] = None + """Optional stdin content to write once before reading output.""" class ContainerClient: @@ -151,6 +154,16 @@ def _init_container(self): # docker SDK `run` supports bind specs via `volumes`. 
run_kwargs["volumes"] = binds logger.info("Container bind mounts enabled: %s", binds) + command = self.host_config.get("command", ["tail", "-f", "/dev/null"]) + stdin = self.host_config.get("stdin", True) + working_dir = self.host_config.get("working_dir", "/") + network_mode = self.host_config.get("network_mode", "none") + auto_remove = self.host_config.get("auto_remove", True) + run_kwargs.setdefault("command", command) + run_kwargs.setdefault("stdin_open", stdin) + run_kwargs.setdefault("working_dir", working_dir) + run_kwargs.setdefault("network_mode", network_mode) + run_kwargs.setdefault("auto_remove", auto_remove) self._container = self._client.containers.run( image=self.image, detach=True, @@ -189,23 +202,100 @@ def _cleanup_container(self): return logger.info("[Cleanup] Stopping the container...") - self._container.stop() - self._container.remove() + try: + self._container.stop() + except Exception: # pylint: disable=broad-except + pass + try: + self._container.remove() + except Exception: # pylint: disable=broad-except + pass logger.info("Container %s stopped and removed.", self._container.id) # self._container = None + def _exec_run_with_stdin( + self, + cmd: list[str], + environment: dict[str, str], + stdin: str, + ) -> CommandExecResult: + """Execute command with attached stdin, similar to docker exec attach.""" + resp = self.container.client.api.exec_create( + self.container.container.id, + cmd=cmd[:], + stdout=True, + stderr=True, + stdin=True, + tty=False, + environment=environment, + ) + exec_id = resp["Id"] + sock = self.container.client.api.exec_start( + exec_id, + detach=False, + tty=False, + stream=False, + socket=True, + demux=False, + ) + try: + data = (stdin or "").encode("utf-8") + if data: + try: + sock.sendall(data) + except Exception: # pylint: disable=broad-except + # Some transports expose the real socket as _sock. 
+ sock._sock.sendall(data) # pylint: disable=protected-access + + try: + sock.shutdown(pysocket.SHUT_WR) + except Exception: # pylint: disable=broad-except + close_write = getattr(sock, "close_write", None) + if callable(close_write): + close_write() + + frames = frames_iter(sock, tty=False) + demux_frames = (demux_adaptor(*frame) for frame in frames) + output = consume_socket_output(demux_frames, demux=True) + stdout = output[0].decode("utf-8") if output and output[0] else "" + stderr = output[1].decode("utf-8") if output and output[1] else "" + finally: + try: + sock.close() + except Exception: # pylint: disable=broad-except + pass + + inspect = self.container.client.api.exec_inspect(exec_id) + exit_code = int(inspect.get("ExitCode", -1)) + return CommandExecResult(stdout=stdout, stderr=stderr, exit_code=exit_code, is_timeout=False) + async def exec_run(self, cmd: list[str], command_args: CommandArgs) -> CommandExecResult: """Execute command in container.""" timeout = command_args.timeout try: loop = asyncio.get_event_loop() - co = loop.run_in_executor( - None, - lambda: self.container.exec_run(cmd=cmd[:], demux=True, environment=command_args.environment or {})) + if command_args.stdin: + co = loop.run_in_executor( + None, + lambda: self._exec_run_with_stdin( + cmd, + command_args.environment or {}, + command_args.stdin or "", + ), + ) + else: + co = loop.run_in_executor( + None, + lambda: self.container.exec_run(cmd=cmd[:], demux=True, environment=command_args.environment or {})) if command_args.timeout: - exit_code, output = await asyncio.wait_for(co, timeout=command_args.timeout) + result = await asyncio.wait_for(co, timeout=command_args.timeout) else: - exit_code, output = await co + result = await co + + if command_args.stdin: + return result + + exit_code, output = result stdout = output[0].decode('utf-8') if output[0] else "" stderr = output[1].decode('utf-8') if output[1] else "" except asyncio.TimeoutError: diff --git 
a/trpc_agent_sdk/code_executors/container/_container_code_executor.py b/trpc_agent_sdk/code_executors/container/_container_code_executor.py index 13ba145..56eec06 100644 --- a/trpc_agent_sdk/code_executors/container/_container_code_executor.py +++ b/trpc_agent_sdk/code_executors/container/_container_code_executor.py @@ -15,7 +15,6 @@ from typing_extensions import override from pydantic import Field - from trpc_agent_sdk.context import InvocationContext from .._base_code_executor import BaseCodeExecutor diff --git a/trpc_agent_sdk/code_executors/container/_container_ws_runtime.py b/trpc_agent_sdk/code_executors/container/_container_ws_runtime.py index 7bf5b34..ce9fb0f 100644 --- a/trpc_agent_sdk/code_executors/container/_container_ws_runtime.py +++ b/trpc_agent_sdk/code_executors/container/_container_ws_runtime.py @@ -15,12 +15,15 @@ """ import io +import json import os import tarfile import time from dataclasses import dataclass from dataclasses import field +from datetime import datetime from pathlib import Path +from typing import Any from typing import Dict from typing import List from typing import Optional @@ -37,9 +40,13 @@ from .._base_workspace_runtime import BaseWorkspaceFS from .._base_workspace_runtime import BaseWorkspaceManager from .._base_workspace_runtime import BaseWorkspaceRuntime +from .._base_workspace_runtime import RunEnvProvider from .._constants import DEFAULT_INPUTS_CONTAINER +from .._constants import DEFAULT_MAX_FILES +from .._constants import DEFAULT_MAX_TOTAL_BYTES from .._constants import DEFAULT_RUN_CONTAINER_BASE from .._constants import DEFAULT_SKILLS_CONTAINER +from .._constants import DEFAULT_TIMEOUT_SEC from .._constants import DIR_OUT from .._constants import DIR_RUNS from .._constants import DIR_SKILLS @@ -62,7 +69,10 @@ from .._types import WorkspaceRunProgramSpec from .._types import WorkspaceRunResult from .._types import WorkspaceStageOptions +from ..utils import InputRecordMeta +from ..utils import WorkspaceMetadata from 
..utils import get_rel_path +from ..utils import normalize_globs from ._container_cli import CommandArgs from ._container_cli import ContainerClient from ._container_cli import ContainerConfig @@ -78,10 +88,16 @@ class RuntimeConfig: run_container_base: str = DEFAULT_RUN_CONTAINER_BASE inputs_host_base: str = "" inputs_container_base: str = DEFAULT_INPUTS_CONTAINER - auto_map_inputs: bool = False + auto_map_inputs: bool = True command_args: CommandArgs = field(default_factory=CommandArgs) +def _shell_quote(s: str) -> str: + if not s: + return "''" + return "'" + s.replace("'", "'\\''") + "'" + + class ContainerWorkspaceManager(BaseWorkspaceManager): """ Docker container-based workspace manager implementation. @@ -124,12 +140,15 @@ async def create_workspace(self, exec_id: str, ctx: Optional[InvocationContext] ws_path = str(Path(self.config.run_container_base) / f"ws_{safe_id}_{suffix}") # Create standard directory layout - cmd_parts = [ - f"mkdir -p '{ws_path}'", f"'{ws_path}/{DIR_SKILLS}'", f"'{ws_path}/{DIR_WORK}'", f"'{ws_path}/{DIR_RUNS}'", - f"'{ws_path}/{DIR_OUT}'", - f"&& [ -f '{ws_path}/{META_FILE_NAME}' ] || echo '{{}}' > '{ws_path}/{META_FILE_NAME}'" - ] - cmd = ["/bin/bash", "-lc", " ".join(cmd_parts)] + cmd_str = ("set -e; " + f"mkdir -p {_shell_quote(ws_path)} " + f"{_shell_quote(str(Path(ws_path) / DIR_SKILLS))} " + f"{_shell_quote(str(Path(ws_path) / DIR_WORK))} " + f"{_shell_quote(str(Path(ws_path) / DIR_RUNS))} " + f"{_shell_quote(str(Path(ws_path) / DIR_OUT))}; " + f"[ -f {_shell_quote(str(Path(ws_path) / META_FILE_NAME))} ] || " + f"echo '{{}}' > {_shell_quote(str(Path(ws_path) / META_FILE_NAME))}") + cmd = ["/bin/bash", "-lc", cmd_str] result = await self.container.exec_run(cmd=cmd, command_args=self.config.command_args) if result.exit_code != 0: @@ -268,8 +287,8 @@ async def stage_directory(self, cmd = ["/bin/bash", "-lc", cmd_str] result = await self.container.exec_run(cmd=cmd, command_args=self.config.command_args) if result.exit_code != 0: - 
logger.debug("Failed to stage directory: %s", result.stderr) - logger.debug("Staged directory using mount: %s -> %s", container_src, container_dst) + raise RuntimeError(f"Failed to stage directory: {result.stderr}") + return # Fallback: tar copy await self._put_directory(ws, src_abs_path, dst) @@ -327,8 +346,15 @@ async def collect(self, continue seen.add(rel_path) - data, mime = self._copy_file_out(line) - files.append(CodeFile(name=rel_path, content=data.decode('utf-8', errors='replace'), mime_type=mime)) + data, size_bytes, mime = self._copy_file_out(line) + files.append( + CodeFile( + name=rel_path, + content=data.decode('utf-8', errors='replace'), + mime_type=mime, + size_bytes=size_bytes, + truncated=size_bytes > len(data), + )) logger.info("Collected %s files from workspace", len(files)) return files @@ -349,32 +375,56 @@ async def stage_inputs(self, Raises: RuntimeError: If staging fails """ + md = await self._load_workspace_metadata(ws) for spec in specs: - mode = spec.mode.lower().strip() or "copy" - dst = spec.dst.strip() or str(Path(DIR_WORK) / "inputs" / self._input_base(spec.src)) - dst = os.path.join(ws.path, dst) + mode = (spec.mode or "").lower().strip() or "copy" + dst_rel = (spec.dst or "").strip() or str(Path(DIR_WORK) / "inputs" / self._input_base(spec.src)) + dst_abs = str(Path(ws.path) / dst_rel) + + resolved = "" + version: Optional[int] = None if spec.src.startswith("artifact://"): - name = spec.src.removeprefix("artifact://") - resolved, ver = parse_artifact_ref(name) if not ctx: raise ValueError("Context is required to load artifacts") - content, ver = await load_artifact_helper(ctx, resolved, ver) - await self._put_bytes_tar(content, dst) + name = spec.src.removeprefix("artifact://") + artifact_name, requested_ver = parse_artifact_ref(name) + use_ver = requested_ver + if use_ver is None and spec.pin: + use_ver = self._pinned_artifact_version(md, artifact_name, dst_rel) + content, actual_ver = await load_artifact_helper(ctx, 
artifact_name, use_ver) + await self._put_bytes_tar(content, dst_abs) + resolved = artifact_name + version = use_ver if use_ver is not None else actual_ver elif spec.src.startswith("host://"): host_path = spec.src.removeprefix("host://") - await self._stage_host_input(ws, host_path, dst, mode) + await self._stage_host_input(ws, host_path, dst_abs, mode, dst_rel) + resolved = host_path elif spec.src.startswith("workspace://"): rel = spec.src.removeprefix("workspace://") src = str(Path(ws.path) / rel) - await self._stage_workspace_input(src, dst, mode) + await self._stage_workspace_input(src, dst_abs, mode) + resolved = rel elif spec.src.startswith("skill://"): rest = spec.src.removeprefix("skill://") src = str(Path(ws.path) / DIR_SKILLS / rest) - await self._stage_workspace_input(src, dst, mode) + await self._stage_workspace_input(src, dst_abs, mode) + resolved = src else: raise RuntimeError(f"Unsupported input: {spec.src}") + md.inputs.append( + InputRecordMeta( + src=spec.src, + dst=dst_rel, + resolved=resolved, + version=version, + mode=mode, + timestamp=datetime.now(), + )) + + await self._save_workspace_metadata(ws, md) + logger.info("Staged %s inputs into workspace", len(specs)) @override @@ -409,13 +459,15 @@ async def collect_outputs(self, raise RuntimeError(f"Failed to collect outputs: {result.stderr}") stdout = result.stdout - max_files = spec.max_files or 100 + max_files = spec.max_files or DEFAULT_MAX_FILES max_file_bytes = spec.max_file_bytes or MAX_READ_SIZE_BYTES - max_total = spec.max_total_bytes or 64 * 1024 * 1024 + max_total = spec.max_total_bytes or DEFAULT_MAX_TOTAL_BYTES manifest = ManifestOutput() total_bytes = 0 count = 0 + saved_names: list[str] = [] + saved_versions: list[int] = [] for line in stdout.strip().split('\n'): line = line.strip() if not line: @@ -425,14 +477,25 @@ async def collect_outputs(self, manifest.limits_hit = True break - data, mime = self._copy_file_out(line) + data, raw_size, mime = self._copy_file_out(line) if 
len(data) > max_file_bytes: data = data[:max_file_bytes] manifest.limits_hit = True + if total_bytes + len(data) > max_total: + remain = max_total - total_bytes + if remain <= 0: + manifest.limits_hit = True + break + data = data[:remain] + manifest.limits_hit = True + total_bytes += len(data) rel_path = line.removeprefix(f"{ws.path}/") + truncated = raw_size > len(data) + if truncated and spec.save: + raise RuntimeError(f"cannot save truncated output file: {rel_path}") file_ref = ManifestFileRef(name=rel_path, mime_type=mime) if spec.inline: @@ -446,6 +509,8 @@ async def collect_outputs(self, version = await save_artifact_helper(ctx, save_name, data, mime) file_ref.saved_as = save_name file_ref.version = version + saved_names.append(save_name) + saved_versions.append(version) manifest.files.append(file_ref) count += 1 @@ -455,6 +520,8 @@ async def collect_outputs(self, async def _put_directory(self, ws: WorkspaceInfo, src: str, dst: str) -> None: """Copy directory to container using tar.""" + if not src or not str(src).strip(): + raise ValueError("source path is empty") abs_src = os.path.abspath(src) container_dst = str(Path(ws.path) / dst) if dst else ws.path if self.config.skills_host_base: @@ -464,8 +531,9 @@ async def _put_directory(self, ws: WorkspaceInfo, src: str, dst: str) -> None: # Create destination directory cmd = ["/bin/bash", "-lc", f"mkdir -p '{container_dst}' && cp -a '{container_src}/.' 
'{container_dst}'"] result = await self.container.exec_run(cmd=cmd, command_args=self.config.command_args) - if result.exit_code: - logger.debug("Failed to stage directory: %s", result.stderr) + if result.exit_code == 0: + return None + logger.debug("Failed to stage directory via mount copy, fallback to tar: %s", result.stderr) cmd = ["/bin/bash", "-lc", f"[ -e '{container_dst}' ] || mkdir -p '{container_dst}'"] result = await self.container.exec_run(cmd=cmd, command_args=self.config.command_args) @@ -507,7 +575,7 @@ async def _put_bytes_tar(self, data: bytes, dest: str, mode: int = 0o644) -> Non if not success: raise RuntimeError(f"Failed to copy bytes to {dest}") - async def _stage_host_input(self, ws: WorkspaceInfo, host: str, dst: str, mode: str) -> None: + async def _stage_host_input(self, ws: WorkspaceInfo, host: str, dst: str, mode: str, dst_rel: str) -> None: """Stage input from host path.""" if self.config.inputs_host_base: rel_path = get_rel_path(self.config.inputs_host_base, host) @@ -525,11 +593,11 @@ async def _stage_host_input(self, ws: WorkspaceInfo, host: str, dst: str, mode: cmd = ["/bin/bash", "-lc", cmd_str] result = await self.container.exec_run(cmd=cmd, command_args=self.config.command_args) - if result.exit_code: - logger.debug("Failed to stage input: %s", result.stderr) + if result.exit_code != 0: + raise RuntimeError(f"Failed to stage host input: {result.stderr}") return # Fallback to tar copy - await self._put_directory(ws, host, str(Path(dst).parent)) + await self._put_directory(ws, host, str(Path(dst_rel).parent)) async def _stage_workspace_input(self, src: str, dst: str, mode: str) -> None: """Stage input from workspace path.""" @@ -547,7 +615,7 @@ async def _stage_workspace_input(self, src: str, dst: str, mode: str) -> None: if result.exit_code: raise RuntimeError(f"Failed to stage input: {result.stderr}") - def _copy_file_out(self, full_path: str) -> Tuple[bytes, str]: + def _copy_file_out(self, full_path: str) -> Tuple[bytes, int, 
str]: """ Copy file out of container. @@ -555,7 +623,7 @@ def _copy_file_out(self, full_path: str) -> Tuple[bytes, str]: full_path: Full path to file in container Returns: - Tuple of (file_data, mime_type) + Tuple of (file_data, size_bytes, mime_type) Raises: RuntimeError: If copy fails @@ -570,7 +638,7 @@ def _copy_file_out(self, full_path: str) -> Tuple[bytes, str]: f = tar.extractfile(member) data = f.read(MAX_READ_SIZE_BYTES) mime = self._detect_mime_type(data) - return data, mime + return data, member.size, mime raise RuntimeError(f"No file found in archive: {full_path}") @@ -597,24 +665,75 @@ def _create_tar_from_files(files: List[WorkspacePutFileInfo]) -> io.BytesIO: @staticmethod def _normalize_globs(patterns: List[str]) -> List[str]: """Normalize glob patterns.""" - normalized = [] - for p in patterns: - p = p.strip() - if p: - # Simple normalization - replace environment variables - p = p.replace("$OUTPUT_DIR", DIR_OUT) - p = p.replace("${OUTPUT_DIR}", DIR_OUT) - p = p.replace("$WORK_DIR", DIR_WORK) - p = p.replace("${WORK_DIR}", DIR_WORK) - p = p.replace("$WORKSPACE_DIR", ".") - p = p.replace("${WORKSPACE_DIR}", ".") - normalized.append(p) - return normalized + return normalize_globs(patterns) @staticmethod - def _input_base(path: str) -> str: + def _input_base(src: str) -> str: """Extract base name from input path.""" - return Path(path).name + s = (src or "").strip() + if s.startswith("artifact://"): + ref = s.removeprefix("artifact://") + try: + name, _ = parse_artifact_ref(ref) + base = Path(name.strip()).name + if base and base not in (".", "..", "/"): + return base + except Exception: # pylint: disable=broad-except + pass + return Path(s).name + + @staticmethod + def _pinned_artifact_version(md: Any, artifact_name: str, dst: str) -> Optional[int]: + for record in reversed(md.inputs or []): + if (record.dst or "") != dst: + continue + if record.version is None: + continue + if (record.resolved or "") == artifact_name: + return record.version + src = 
record.src or "" + if not src.startswith("artifact://"): + continue + try: + name, _ = parse_artifact_ref(src.removeprefix("artifact://")) + except Exception: # pylint: disable=broad-except + continue + if name == artifact_name: + return record.version + return None + + async def _load_workspace_metadata(self, ws: WorkspaceInfo): + now = datetime.now() + cmd = ["/bin/bash", "-lc", f"cat {_shell_quote(str(Path(ws.path) / META_FILE_NAME))}"] + result = await self.container.exec_run(cmd=cmd, command_args=self.config.command_args) + if result.exit_code != 0 or not result.stdout.strip(): + return WorkspaceMetadata(version=1, created_at=now, updated_at=now, last_access=now, skills={}) + try: + data = json.loads(result.stdout) + md = WorkspaceMetadata(**data) + except Exception as ex: # pylint: disable=broad-except + raise RuntimeError(f"Failed to parse workspace metadata: {ex}") from ex + if not md.version: + md.version = 1 + if md.created_at is None: + md.created_at = now + md.last_access = now + if md.skills is None: + md.skills = {} + return md + + async def _save_workspace_metadata(self, ws: WorkspaceInfo, md: Any) -> None: + now = datetime.now() + if not md.version: + md.version = 1 + if md.created_at is None: + md.created_at = now + md.updated_at = now + md.last_access = now + if md.skills is None: + md.skills = {} + payload = json.dumps(md.model_dump(exclude_none=True, by_alias=True, mode="json"), ensure_ascii=False, indent=2) + await self._put_bytes_tar(payload.encode("utf-8"), str(Path(ws.path) / META_FILE_NAME), mode=0o600) @staticmethod def _detect_mime_type(data: bytes) -> str: @@ -637,7 +756,13 @@ class ContainerProgramRunner(BaseProgramRunner): Docker container-based program runner implementation. 
""" - def __init__(self, container: ContainerClient, config: RuntimeConfig): + def __init__( + self, + container: ContainerClient, + config: RuntimeConfig, + provider: Optional[RunEnvProvider] = None, + enable_provider_env: bool = False, + ): """ Initialize container program runner. @@ -646,6 +771,7 @@ def __init__(self, container: ContainerClient, config: RuntimeConfig): container: Docker container to use config: Runtime configuration """ + super().__init__(provider=provider, enable_provider_env=enable_provider_env) self.container = container self.config = config @@ -667,6 +793,7 @@ async def run_program(self, Raises: RuntimeError: If execution fails """ + spec = self._apply_provider_env(spec, ctx) cwd = f"{ws.path}/{spec.cwd}" if spec.cwd else ws.path # Prepare directories @@ -685,35 +812,41 @@ async def run_program(self, } env_parts = [] + user_env = dict(spec.env or {}) for k, v in base_env.items(): - if k not in spec.env: - env_parts.append(f"{k}={self._shell_quote(v)}") + if k not in user_env: + env_parts.append(f"{k}={_shell_quote(v)}") - for k, v in spec.env.items(): - env_parts.append(f"{k}={self._shell_quote(v)}") + for k, v in user_env.items(): + env_parts.append(f"{k}={_shell_quote(v)}") env_str = " ".join(env_parts) # Build command line cmd_parts = [ - f"mkdir -p {self._shell_quote(run_dir)} {self._shell_quote(out_dir)}", f"&& cd {self._shell_quote(cwd)}", + f"mkdir -p {_shell_quote(run_dir)} {_shell_quote(out_dir)}", f"&& cd {_shell_quote(cwd)}", "&& env" if env_str else "", env_str, - self._shell_quote(spec.cmd) + _shell_quote(spec.cmd) ] for arg in spec.args: - cmd_parts.append(self._shell_quote(arg)) + cmd_parts.append(_shell_quote(arg)) cmd_str = " ".join(filter(None, cmd_parts)) cmd = ["/bin/bash", "-lc", cmd_str] start_time = time.time() - timeout = spec.timeout or self.config.command_args.timeout - if timeout is None: + if spec.timeout and spec.timeout > 0: timeout = spec.timeout + elif self.config.command_args.timeout and 
self.config.command_args.timeout > 0: + timeout = self.config.command_args.timeout else: - timeout = min(timeout, spec.timeout) - command_args = CommandArgs(environment=None, timeout=timeout) + timeout = float(DEFAULT_TIMEOUT_SEC) + command_args = CommandArgs( + environment=None, + timeout=timeout, + stdin=spec.stdin or None, + ) result = await self.container.exec_run(cmd=cmd, command_args=command_args) return WorkspaceRunResult(stdout=result.stdout, stderr=result.stderr, @@ -721,28 +854,20 @@ async def run_program(self, duration=time.time() - start_time, timed_out=result.is_timeout) - @staticmethod - def _shell_quote(s: str) -> str: - """ - Quote string for safe shell usage. - - Args: - s: String to quote - - Returns: - Quoted string - """ - if not s: - return "''" - return "'" + s.replace("'", "'\\''") + "'" - class ContainerWorkspaceRuntime(BaseWorkspaceRuntime): """ Docker container-based execution engine. """ - def __init__(self, container: ContainerClient, host_config: Optional[Dict] = None, auto_inputs: bool = False): + def __init__( + self, + container: ContainerClient, + host_config: Optional[Dict] = None, + auto_inputs: bool = True, + provider: Optional[RunEnvProvider] = None, + enable_provider_env: bool = False, + ): """ Initialize container engine. 
@@ -763,7 +888,12 @@ def __init__(self, container: ContainerClient, host_config: Optional[Dict] = Non self._fs = ContainerWorkspaceFS(self.container, config) self._manager = ContainerWorkspaceManager(self.container, config, self._fs) - self._runner = ContainerProgramRunner(self.container, config) + self._runner = ContainerProgramRunner( + self.container, + config, + provider=provider, + enable_provider_env=enable_provider_env, + ) @override def manager(self, ctx: Optional[InvocationContext] = None) -> ContainerWorkspaceManager: @@ -797,8 +927,8 @@ def _find_bind_source(binds: List[str], dest: str) -> str: if len(parts) < 2: continue - # Handle format: source:dest[:mode] - bind_dest = parts[-2] if len(parts) >= 2 else "" + # Handle format: source:dest[:mode], parse from right. + bind_dest = parts[-2] if bind_dest == dest: source = ':'.join(parts[:-2]) if len(parts) > 2 else parts[0] if Path(source).is_dir(): @@ -820,7 +950,9 @@ def describe(self, ctx: Optional[InvocationContext] = None) -> WorkspaceCapabili def create_container_workspace_runtime( container_config: Optional[ContainerConfig] = None, host_config: Optional[Dict] = None, - auto_inputs: bool = False, + auto_inputs: bool = True, + provider: Optional[RunEnvProvider] = None, + enable_provider_env: bool = False, ) -> ContainerWorkspaceRuntime: """Create a new container workspace runtime. 
Args: @@ -838,4 +970,10 @@ def create_container_workspace_runtime( container = ContainerClient(config=cfg) else: container = ContainerClient(config=ContainerConfig(host_config=host_config)) - return ContainerWorkspaceRuntime(container=container, host_config=host_config, auto_inputs=auto_inputs) + return ContainerWorkspaceRuntime( + container=container, + host_config=host_config, + auto_inputs=auto_inputs, + provider=provider, + enable_provider_env=enable_provider_env, + ) diff --git a/trpc_agent_sdk/code_executors/local/_local_program_session.py b/trpc_agent_sdk/code_executors/local/_local_program_session.py new file mode 100644 index 0000000..fb1be8a --- /dev/null +++ b/trpc_agent_sdk/code_executors/local/_local_program_session.py @@ -0,0 +1,299 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""WorkspaceInfo runtime for local code execution. + +This module provides the WorkspaceRuntime class which allows local code execution. +It provides methods for staging directories and inputs into the workspace. 
+""" + +from __future__ import annotations + +import asyncio +import os +import sys +import time +import uuid +from typing import Optional +from typing_extensions import override + +from .._program_session import BaseProgramSession +from .._program_session import PROGRAM_STATUS_EXITED +from .._program_session import PROGRAM_STATUS_RUNNING +from .._program_session import ProgramLog +from .._program_session import ProgramPoll +from .._program_session import ProgramState +from .._types import WorkspaceRunResult + +_DEFAULT_INTERACTIVE_MAX_LINES = 20_000 +if sys.platform != "win32": + import fcntl + + +def _split_lines_with_partial(text: str) -> tuple[list[str], str]: + normalized = text.replace("\r\n", "\n") + parts = normalized.split("\n") + if len(parts) == 1: + return [], parts[0] + return parts[:-1], parts[-1] + + +class LocalProgramSession(BaseProgramSession): + """Local interactive subprocess session.""" + + def __init__( + self, + process: asyncio.subprocess.Process, + *, + max_lines: int = _DEFAULT_INTERACTIVE_MAX_LINES, + master_fd: Optional[int] = None, + ) -> None: + self._id = uuid.uuid4().hex + self._process = process + self._max_lines = max_lines + self._master_fd = master_fd + self._lock = asyncio.Lock() + self._closed = False + + self._started_at = time.time() + self._finished_at: Optional[float] = None + self._exit_code: Optional[int] = None + self._timed_out = False + + self._line_base = 0 + self._lines: list[str] = [] + self._partial = "" + self._poll_cursor = 0 + + self._stdout = "" + self._stderr = "" + if self._master_fd is not None: + self._stdout_task = asyncio.create_task(self._read_pty(self._master_fd)) + self._stderr_task = asyncio.create_task(asyncio.sleep(0)) + else: + self._stdout_task = asyncio.create_task(self._read_stream(self._process.stdout, stream_name="stdout")) + self._stderr_task = asyncio.create_task(self._read_stream(self._process.stderr, stream_name="stderr")) + self._wait_task = asyncio.create_task(self._watch_process_exit()) 
+ + @override + def id(self) -> str: + return self._id + + async def _read_stream(self, reader: Optional[asyncio.StreamReader], *, stream_name: str) -> None: + if reader is None: + return + while True: + chunk = await reader.read(4096) + if not chunk: + return + await self._append_output(chunk.decode("utf-8", errors="replace"), stream=stream_name) + + async def _read_pty(self, master_fd: int) -> None: + loop = asyncio.get_running_loop() + flags = fcntl.fcntl(master_fd, fcntl.F_GETFL) + fcntl.fcntl(master_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) + + read_event = asyncio.Event() + + def _on_readable() -> None: + read_event.set() + + loop.add_reader(master_fd, _on_readable) + try: + while True: + read_event.clear() + if self._process.returncode is not None: + while True: + try: + data = os.read(master_fd, 4096) + if not data: + break + await self._append_output(data.decode("utf-8", errors="replace"), stream="stdout") + except BlockingIOError: + break + except OSError: + break + return + + try: + await asyncio.wait_for(read_event.wait(), timeout=0.05) + except asyncio.TimeoutError: + pass + try: + data = os.read(master_fd, 4096) + if data: + await self._append_output(data.decode("utf-8", errors="replace"), stream="stdout") + except BlockingIOError: + pass + except OSError: + return + finally: + loop.remove_reader(master_fd) + + async def _append_output(self, chunk: str, *, stream: str) -> None: + normalized = chunk.replace("\r\n", "\n") + async with self._lock: + if stream == "stderr": + self._stderr += normalized + else: + self._stdout += normalized + + merged = self._partial + normalized + lines, self._partial = _split_lines_with_partial(merged) + self._lines.extend(lines) + self._trim_lines_locked() + + async def _watch_process_exit(self) -> None: + code = await self._process.wait() + await asyncio.gather(self._stdout_task, self._stderr_task, return_exceptions=True) + async with self._lock: + if self._finished_at is not None: + return + if self._partial: + 
self._lines.append(self._partial) + self._partial = "" + self._trim_lines_locked() + self._exit_code = code + self._finished_at = time.time() + + def _trim_lines_locked(self) -> None: + if self._max_lines <= 0: + return + if len(self._lines) <= self._max_lines: + return + drop = len(self._lines) - self._max_lines + self._lines = self._lines[drop:] + self._line_base += drop + if self._poll_cursor < self._line_base: + self._poll_cursor = self._line_base + + @override + async def poll(self, limit: Optional[int] = None) -> ProgramPoll: + async with self._lock: + start = self._poll_cursor + if start < self._line_base: + start = self._line_base + self._poll_cursor = start + end = self._line_base + len(self._lines) + if limit is not None and limit > 0: + end = min(end, start + limit) + + out = "" + if end > start: + out = "\n".join(self._lines[start - self._line_base:end - self._line_base]) + if end == self._line_base + len(self._lines) and self._partial: + out = f"{out}\n{self._partial}" if out else self._partial + + self._poll_cursor = end + status = PROGRAM_STATUS_RUNNING if self._finished_at is None else PROGRAM_STATUS_EXITED + return ProgramPoll( + status=status, + output=out, + offset=start, + next_offset=end, + exit_code=self._exit_code, + ) + + @override + async def log(self, offset: Optional[int] = None, limit: Optional[int] = None) -> ProgramLog: + async with self._lock: + start = self._line_base if offset is None else offset + end = self._line_base + len(self._lines) + + if start < self._line_base: + start = self._line_base + if start > end: + start = end + if limit is not None and limit > 0: + end = min(end, start + limit) + + out = "" + if end > start: + out = "\n".join(self._lines[start - self._line_base:end - self._line_base]) + if end == self._line_base + len(self._lines) and self._partial: + out = f"{out}\n{self._partial}" if out else self._partial + return ProgramLog(output=out, offset=start, next_offset=end) + + @override + async def write(self, data: 
str, newline: bool) -> None: + if not data and not newline: + return + if self._process.returncode is not None: + raise ValueError("session is not running") + text = data + if newline: + text += "\n" + if self._master_fd is not None: + try: + os.write(self._master_fd, text.encode("utf-8")) + return + except OSError as ex: + raise ValueError("stdin is not available") from ex + stdin = self._process.stdin + if stdin is None: + raise ValueError("stdin is not available") + stdin.write(text.encode("utf-8")) + await stdin.drain() + + @override + async def kill(self, grace_seconds: float) -> None: + if self._process.returncode is not None: + return + self._process.terminate() + try: + await asyncio.wait_for(self._process.wait(), timeout=max(0.0, grace_seconds)) + return + except asyncio.TimeoutError: + pass + if self._process.returncode is None: + self._process.kill() + await self._process.wait() + + @override + async def close(self) -> None: + async with self._lock: + if self._closed: + return + self._closed = True + if self._process.returncode is None: + await self.kill(0.5) + if self._process.stdin is not None: + try: + self._process.stdin.close() + except Exception: # pylint: disable=broad-except + pass + await asyncio.gather(self._stdout_task, self._stderr_task, self._wait_task, return_exceptions=True) + if self._master_fd is not None: + try: + os.close(self._master_fd) + except OSError: + pass + + async def enforce_timeout(self, timeout_sec: float) -> None: + if timeout_sec <= 0: + return + await asyncio.sleep(timeout_sec) + if self._process.returncode is None: + self._timed_out = True + await self.kill(0.5) + + @override + async def state(self) -> ProgramState: + if self._finished_at is None: + return ProgramState(status=PROGRAM_STATUS_RUNNING) + return ProgramState(status=PROGRAM_STATUS_EXITED, exit_code=self._exit_code) + + @override + async def run_result(self) -> WorkspaceRunResult: + duration = 0.0 + if self._finished_at is not None: + duration = 
self._finished_at - self._started_at + return WorkspaceRunResult( + stdout=self._stdout, + stderr=self._stderr, + exit_code=self._exit_code or 0, + duration=duration, + timed_out=self._timed_out, + ) diff --git a/trpc_agent_sdk/code_executors/local/_local_ws_runtime.py b/trpc_agent_sdk/code_executors/local/_local_ws_runtime.py index 87151c6..1b245ee 100644 --- a/trpc_agent_sdk/code_executors/local/_local_ws_runtime.py +++ b/trpc_agent_sdk/code_executors/local/_local_ws_runtime.py @@ -11,11 +11,14 @@ from __future__ import annotations +import asyncio import os import re import shutil +import sys import tempfile import time +import uuid from datetime import datetime from pathlib import Path from typing import List @@ -32,6 +35,7 @@ from .._base_workspace_runtime import BaseWorkspaceFS from .._base_workspace_runtime import BaseWorkspaceManager from .._base_workspace_runtime import BaseWorkspaceRuntime +from .._base_workspace_runtime import RunEnvProvider from .._constants import DEFAULT_FILE_MODE from .._constants import DEFAULT_MAX_FILES from .._constants import DEFAULT_MAX_TOTAL_BYTES @@ -46,6 +50,7 @@ from .._constants import ENV_WORK_DIR from .._constants import MAX_READ_SIZE_BYTES from .._constants import WORKSPACE_ENV_DIR_KEY +from .._program_session import BaseProgramSession from .._types import CodeFile from .._types import ManifestFileRef from .._types import ManifestOutput @@ -69,6 +74,11 @@ from ..utils import path_join from ..utils import save_metadata +if sys.platform != "win32": + import pty + +from ._local_program_session import LocalProgramSession + class LocalWorkspaceManager(BaseWorkspaceManager): """Local workspace manager for executing commands in skill workspaces.""" @@ -380,7 +390,10 @@ async def stage_inputs( if mode == "link": make_symlink(ws.path, dst.as_posix(), host_path) else: - self._put_directory(ws, host_path, dst.parent.as_posix()) + # Preserve caller-provided destination path exactly. + # For files, copy to the exact dst filename. 
+ # For directories, copy tree into dst directory. + copy_path(host_path, path_join(ws.path, dst.as_posix())) elif spec.src.startswith("workspace://"): # Handle workspace inputs rel = spec.src[len("workspace://"):] @@ -574,6 +587,35 @@ def _read_limited_with_cap( class LocalProgramRunner(BaseProgramRunner): """Local program runner for executing commands in skill workspaces.""" + def __init__( + self, + provider: Optional[RunEnvProvider] = None, + enable_provider_env: bool = False, + ): + super().__init__(provider=provider, enable_provider_env=enable_provider_env) + + def _build_program_env(self, ws: WorkspaceInfo, spec: WorkspaceRunProgramSpec) -> dict[str, str]: + env = os.environ.copy() + user_env = dict(spec.env or {}) + wr_path = Path(ws.path) + ensure_layout(wr_path) + run_dir = wr_path / DIR_RUNS / f"run_{datetime.now().strftime('%Y%m%dT%H%M%S.%f')}" + run_dir.mkdir(parents=True, exist_ok=True) + + base_env = { + WORKSPACE_ENV_DIR_KEY: ws.path, + ENV_SKILLS_DIR: str(Path(ws.path) / DIR_SKILLS), + ENV_WORK_DIR: str(Path(ws.path) / DIR_WORK), + ENV_OUTPUT_DIR: str(Path(ws.path) / DIR_OUT), + ENV_RUN_DIR: str(run_dir), + } + for key, value in base_env.items(): + if key not in user_env: + env[key] = value + if user_env: + env.update(user_env) + return env + @override async def run_program(self, ws: WorkspaceInfo, @@ -591,37 +633,13 @@ async def run_program(self, Returns: Execution result """ + spec = self._apply_provider_env(spec, ctx) # Resolve cwd under workspace cwd = Path(path_join(ws.path, spec.cwd)) cwd.mkdir(parents=True, exist_ok=True) timeout = spec.timeout or float(DEFAULT_TIMEOUT_SEC) - - # Build environment - env = os.environ.copy() - - wr_path = Path(ws.path) - # Ensure layout exists and compute run dir - ensure_layout(wr_path) - run_dir = wr_path / DIR_RUNS / f"run_{datetime.now().strftime('%Y%m%dT%H%M%S.%f')}" - run_dir.mkdir(parents=True, exist_ok=True) - - # Inject well-known variables if not set - base_env = { - WORKSPACE_ENV_DIR_KEY: ws.path, - 
ENV_SKILLS_DIR: str(Path(ws.path) / DIR_SKILLS), - ENV_WORK_DIR: str(Path(ws.path) / DIR_WORK), - ENV_OUTPUT_DIR: str(Path(ws.path) / DIR_OUT), - ENV_RUN_DIR: str(run_dir), - } - - for key, value in base_env.items(): - if key not in spec.env: - env[key] = value - - # Add user-provided environment variables - if spec.env: - env.update(spec.env) + env = self._build_program_env(ws, spec) # Prepare command cmd_args = [spec.cmd] + (spec.args or []) @@ -639,6 +657,62 @@ async def run_program(self, duration=time.time() - start_time, timed_out=result.is_timeout) + async def start_program( + self, + ctx: Optional[InvocationContext], + ws: WorkspaceInfo, + spec: WorkspaceRunProgramSpec, + ) -> BaseProgramSession: + """Start an interactive program session in workspace.""" + if spec.tty and sys.platform == "win32": + raise ValueError("interactive tty is not supported on windows") + + spec = self._apply_provider_env(spec, ctx) + cwd = Path(path_join(ws.path, spec.cwd)) + cwd.mkdir(parents=True, exist_ok=True) + env = self._build_program_env(ws, spec) + timeout = spec.timeout or float(DEFAULT_TIMEOUT_SEC) + + cmd_args = [spec.cmd] + (spec.args or []) + if spec.tty: + master_fd, slave_fd = pty.openpty() + try: + process = await asyncio.create_subprocess_exec( + *cmd_args, + cwd=str(cwd), + env=env, + stdin=slave_fd, + stdout=slave_fd, + stderr=slave_fd, + close_fds=True, + preexec_fn=os.setsid, + ) + except Exception: + os.close(master_fd) + os.close(slave_fd) + raise + finally: + try: + os.close(slave_fd) + except OSError: + pass + session = LocalProgramSession(process, master_fd=master_fd) + else: + process = await asyncio.create_subprocess_exec( + *cmd_args, + cwd=str(cwd), + env=env, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + session = LocalProgramSession(process) + if timeout > 0: + asyncio.create_task(session.enforce_timeout(float(timeout))) + if spec.stdin: + await session.write(spec.stdin, newline=False) + 
return session + class LocalWorkspaceRuntime(BaseWorkspaceRuntime): """Local workspace for executing commands in skill workspaces.""" @@ -647,9 +721,11 @@ def __init__(self, work_root: str = '', read_only_staged_skill: bool = False, auto_inputs: bool = True, - inputs_host_base: str = ""): + inputs_host_base: str = "", + provider: Optional[RunEnvProvider] = None, + enable_provider_env: bool = False): self._fs = LocalWorkspaceFS(read_only_staged_skill) - self._runner = LocalProgramRunner() + self._runner = LocalProgramRunner(provider=provider, enable_provider_env=enable_provider_env) self._manager = LocalWorkspaceManager(work_root, auto_inputs, inputs_host_base, self._fs) @override @@ -681,6 +757,9 @@ def describe(self, ctx: Optional[InvocationContext] = None) -> WorkspaceCapabili def create_local_workspace_runtime(work_root: str = '', read_only_staged_skill: bool = False, auto_inputs: bool = True, - inputs_host_base: str = "") -> LocalWorkspaceRuntime: + inputs_host_base: str = "", + provider: Optional[RunEnvProvider] = None, + enable_provider_env: bool = False) -> LocalWorkspaceRuntime: """Create a new local workspace runtime.""" - return LocalWorkspaceRuntime(work_root, read_only_staged_skill, auto_inputs, inputs_host_base) + return LocalWorkspaceRuntime(work_root, read_only_staged_skill, auto_inputs, inputs_host_base, provider, + enable_provider_env) diff --git a/trpc_agent_sdk/code_executors/local/_unsafe_local_code_executor.py b/trpc_agent_sdk/code_executors/local/_unsafe_local_code_executor.py index 9593ca6..3b3aa19 100644 --- a/trpc_agent_sdk/code_executors/local/_unsafe_local_code_executor.py +++ b/trpc_agent_sdk/code_executors/local/_unsafe_local_code_executor.py @@ -17,7 +17,6 @@ from typing_extensions import override from pydantic import Field - from trpc_agent_sdk.context import InvocationContext from trpc_agent_sdk.utils import async_execute_command @@ -44,7 +43,7 @@ class UnsafeLocalCodeExecutor(BaseCodeExecutor): work_dir: str = Field(default="", 
description="The working directory for the code execution.") - timeout: float = Field(default=0, description="The timeout for the code execution.") + timeout: float = Field(default=0, description="The timeout seconds for the code execution.") clean_temp_files: bool = Field(default=True, description="Whether to clean temporary files after the code execution.") diff --git a/trpc_agent_sdk/code_executors/utils/_code_execution.py b/trpc_agent_sdk/code_executors/utils/_code_execution.py index bae808a..8635015 100644 --- a/trpc_agent_sdk/code_executors/utils/_code_execution.py +++ b/trpc_agent_sdk/code_executors/utils/_code_execution.py @@ -24,6 +24,20 @@ class CodeExecutionUtils: """Utility functions for code execution.""" + @classmethod + def _is_ignored_code_block(cls, code: str, ignore_codes: list[str]) -> bool: + """Return True when code block first line is in ignore list.""" + if not code: + return False + lines = code.splitlines() + if not lines: + return False + first_line = lines[0].strip() + if not first_line: + return False + ignore_set = {item.strip() for item in ignore_codes if item and item.strip()} + return first_line in ignore_set + @classmethod def prepare_globals(cls, code: str, globals_: dict[str, Any]) -> None: """Prepare globals for code execution, injecting __name__ if needed.""" @@ -61,6 +75,7 @@ def extract_code_and_truncate_content( cls, content: Content, code_block_delimiters: list[CodeBlockDelimiter], + ignore_codes: list[str] = None, ) -> list[CodeBlock]: """Extracts all code blocks from the content and reconstructs content.parts. @@ -71,10 +86,12 @@ def extract_code_and_truncate_content( content: The mutable content to extract the code from. code_block_delimiters: The list of the enclosing delimiters to identify the code blocks. + ignore_codes: The list of codes to ignore. Returns: The first code block if found; otherwise, None. 
""" + ignore_codes = ignore_codes or [] code_blocks = [] if not content or not content.parts: return code_blocks @@ -84,12 +101,13 @@ def extract_code_and_truncate_content( total_len = len(content.parts) for idx, part in enumerate(content.parts): if part.executable_code: + code_str = part.executable_code.code or "" + if cls._is_ignored_code_block(code_str, ignore_codes): + continue if idx < total_len - 1 and not content.parts[idx + 1].code_execution_result: - code_blocks.append(CodeBlock(code=part.executable_code.code, - language=part.executable_code.language)) + code_blocks.append(CodeBlock(code=code_str, language=part.executable_code.language)) if idx == total_len - 1: - code_blocks.append(CodeBlock(code=part.executable_code.code, - language=part.executable_code.language)) + code_blocks.append(CodeBlock(code=code_str, language=part.executable_code.language)) # If there are code blocks, return them. if code_blocks: return code_blocks @@ -140,6 +158,9 @@ def extract_code_and_truncate_content( # Extract code content, removing leading/trailing whitespace and newlines code_str = code_content.strip() + if cls._is_ignored_code_block(code_str, ignore_codes): + last_end = match.end() + continue # Store first code block for return value if first_code is None: diff --git a/trpc_agent_sdk/server/langfuse/tracing/opentelemetry.py b/trpc_agent_sdk/server/langfuse/tracing/opentelemetry.py index 166f20a..5dab42d 100644 --- a/trpc_agent_sdk/server/langfuse/tracing/opentelemetry.py +++ b/trpc_agent_sdk/server/langfuse/tracing/opentelemetry.py @@ -58,7 +58,6 @@ def _should_skip_span(self, span: ReadableSpan) -> bool: Returns: True if the span should be skipped, False otherwise. 
""" - global _langfuse_config # pylint: disable=invalid-name # If enable_a2a_trace is True, don't filter out any spans if _langfuse_config.enable_a2a_trace: return False @@ -99,7 +98,6 @@ def _should_skip_span(self, span: ReadableSpan) -> bool: def _transform_span_for_langfuse(self, span: ReadableSpan) -> ReadableSpan: """Transform TRPC agent span attributes to Langfuse format.""" - global _langfuse_config # pylint: disable=invalid-name trpc_span_name = get_trpc_agent_span_name() if span.name == "invocation": span_name = span.attributes.get(f"{trpc_span_name}.runner.name", "unknown") @@ -262,6 +260,8 @@ def _map_generation_attributes(self, attributes: Dict[str, Any]) -> Dict[str, An # Map usage details gen_attrs["gen_ai.usage.input_tokens"] = attributes.get("gen_ai.usage.input_tokens", "0") gen_attrs["gen_ai.usage.output_tokens"] = attributes.get("gen_ai.usage.output_tokens", "0") + if "gen_ai.request.model" in attributes: + gen_attrs["gen_ai.request.model"] = attributes["gen_ai.request.model"] # Map generation metadata gen_metadata = {} @@ -603,7 +603,6 @@ def setup(config: Optional[LangfuseConfig] = None) -> TracerProvider: Raises: ValueError: If required configuration is missing. 
""" - global _langfuse_config # pylint: disable=invalid-name if config is None: config = LangfuseConfig() @@ -619,6 +618,7 @@ def setup(config: Optional[LangfuseConfig] = None) -> TracerProvider: if config.host is None: config.host = os.getenv("LANGFUSE_HOST") + global _langfuse_config # pylint: disable=invalid-name # Set the global config and use it in span processor(Only be setted once) _langfuse_config = config diff --git a/trpc_agent_sdk/sessions/_base_session_service.py b/trpc_agent_sdk/sessions/_base_session_service.py index 0d984c1..dd41fb7 100644 --- a/trpc_agent_sdk/sessions/_base_session_service.py +++ b/trpc_agent_sdk/sessions/_base_session_service.py @@ -84,6 +84,9 @@ async def append_event(self, session: Session, event: Event) -> Event: """Appends an event to a session object.""" if event.partial: return event + # Apply temp-scoped state to in-memory session before trimming event delta, + # so same-invocation consumers can still read temp values. + self._apply_temp_state(session, event) event = self._trim_temp_delta_state(event) self.__update_session_state(session, event) session.add_event(event, @@ -91,6 +94,18 @@ async def append_event(self, session: Session, event: Event) -> Event: max_events=self._session_config.max_events) return event + def _apply_temp_state(self, session: Session, event: Event) -> None: + """Apply temp-scoped state delta to in-memory session state only. + + Temp state is intentionally ephemeral: it should be visible within + current invocation memory but not persisted into stored event deltas. 
+ """ + if not event.actions or not event.actions.state_delta: + return + for key, value in event.actions.state_delta.items(): + if key.startswith(State.TEMP_PREFIX): + session.state[key] = value + def _trim_temp_delta_state(self, event: Event) -> Event: """Removes temporary state delta keys from the event.""" if not event.actions or not event.actions.state_delta: diff --git a/trpc_agent_sdk/sessions/_history_record.py b/trpc_agent_sdk/sessions/_history_record.py index 62086a6..1950e09 100644 --- a/trpc_agent_sdk/sessions/_history_record.py +++ b/trpc_agent_sdk/sessions/_history_record.py @@ -23,6 +23,7 @@ class HistoryRecord(BaseModel): user_texts: list[str] = Field(default_factory=list, description="List of user text") # The text of the user assistant_texts: list[str] = Field(default_factory=list, description="List of assistant text") + # The text of the assistant def add_record(self, user_text: str, assistant_text: str | None = ""): diff --git a/trpc_agent_sdk/sessions/_session.py b/trpc_agent_sdk/sessions/_session.py index 7598a47..28c0b91 100644 --- a/trpc_agent_sdk/sessions/_session.py +++ b/trpc_agent_sdk/sessions/_session.py @@ -11,7 +11,6 @@ from typing import List from pydantic import Field - from trpc_agent_sdk.abc import SessionABC from trpc_agent_sdk.events import Event @@ -54,10 +53,7 @@ def apply_event_filtering(self, event_ttl_seconds: float = 0.0, max_events: int 2. Count filtering: Keep only the most recent max_events events If both filters result in removing all events, the method attempts to - preserve the first user message and all events after it from the original events. - - The filtering logic is inspired by the ApplyEventFiltering function - from trpc-agent-go: https://github.com/trpc-group/trpc-agent-go + preserve the first user message and all events after it from the original events. Args: event_ttl_seconds: Time-to-live in seconds for events. If 0, no TTL filtering is applied. 
diff --git a/trpc_agent_sdk/skills/__init__.py b/trpc_agent_sdk/skills/__init__.py index b8afbb3..d457bd6 100644 --- a/trpc_agent_sdk/skills/__init__.py +++ b/trpc_agent_sdk/skills/__init__.py @@ -1,3 +1,7 @@ +# -*- coding: utf-8 -*- +# +# Copyright @ 2026 Tencent.com + # Tencent is pleased to support the open source community by making tRPC-Agent-Python available. # # Copyright (C) 2026 Tencent. All rights reserved. @@ -21,120 +25,156 @@ >>> tools = await toolset.get_tools() """ -from ._common import BaseSelectionResult from ._common import SelectionMode -from ._common import add_selection -from ._common import clear_selection +from ._common import docs_scan_prefix +from ._common import docs_state_key from ._common import generic_get_selection -from ._common import generic_select_items -from ._common import get_previous_selection -from ._common import get_state_delta_value -from ._common import replace_selection -from ._common import set_state_delta_for_selection +from ._common import loaded_order_state_key +from ._common import loaded_scan_prefix +from ._common import loaded_state_key +from ._common import tool_scan_prefix +from ._common import tool_state_key +from ._common import use_session_skill_state from ._constants import ENV_SKILLS_ROOT +from ._constants import SKILL_ARTIFACTS_STATE_KEY +from ._constants import SKILL_DOCS_BY_AGENT_STATE_KEY_PREFIX from ._constants import SKILL_DOCS_STATE_KEY_PREFIX from ._constants import SKILL_FILE +from ._constants import SKILL_LOADED_BY_AGENT_STATE_KEY_PREFIX +from ._constants import SKILL_LOADED_ORDER_BY_AGENT_STATE_KEY_PREFIX +from ._constants import SKILL_LOADED_ORDER_STATE_KEY_PREFIX from ._constants import SKILL_LOADED_STATE_KEY_PREFIX +from ._constants import SKILL_LOAD_MODE_VALUES from ._constants import SKILL_REGISTRY_KEY from ._constants import SKILL_REPOSITORY_KEY +from ._constants import SKILL_TOOLS_BY_AGENT_STATE_KEY_PREFIX +from ._constants import SKILL_TOOLS_NAMES from ._constants import 
SKILL_TOOLS_STATE_KEY_PREFIX +from ._constants import SkillLoadModeNames +from ._constants import SkillProfileNames +from ._constants import SkillToolsNames from ._dynamic_toolset import DynamicSkillToolSet from ._registry import SkillRegistry from ._repository import BaseSkillRepository from ._repository import FsSkillRepository +from ._repository import VisibilityFilter from ._repository import create_default_skill_repository +from ._skill_config import get_skill_config +from ._skill_config import get_skill_load_mode +from ._skill_config import set_skill_config +from ._skill_profile import SkillProfileFlags +from ._state_keys import docs_key +from ._state_keys import docs_prefix +from ._state_keys import docs_session_key +from ._state_keys import docs_session_prefix +from ._state_keys import loaded_key +from ._state_keys import loaded_order_key +from ._state_keys import loaded_prefix +from ._state_keys import loaded_session_key +from ._state_keys import loaded_session_order_key +from ._state_keys import loaded_session_prefix +from ._state_keys import tool_key +from ._state_keys import tool_prefix +from ._state_keys import tool_session_key +from ._state_keys import tool_session_prefix +from ._state_migration import SKILLS_LEGACY_MIGRATION_STATE_KEY +from ._state_migration import maybe_migrate_legacy_skill_state +from ._state_order import marshal_loaded_order +from ._state_order import parse_loaded_order +from ._state_order import touch_loaded_order from ._toolset import SkillToolSet from ._types import Skill from ._types import SkillConfig from ._types import SkillFrontMatter -from ._types import SkillMetadata from ._types import SkillRequires from ._types import SkillResource from ._types import SkillSummary -from ._types import SkillWorkspaceInputRecord -from ._types import SkillWorkspaceMetadata -from ._types import SkillWorkspaceOutputRecord -from ._types import format_datetime -from ._types import parse_datetime -from ._url_root import ArchiveExt from 
._url_root import ArchiveExtractor -from ._url_root import ArchiveKind -from ._url_root import CacheConfig -from ._url_root import FilePerm -from ._url_root import SizeLimit from ._url_root import SkillRootResolver -from ._url_root import TarPerm -from ._utils import compute_dir_digest -from ._utils import ensure_layout from ._utils import get_state_delta -from ._utils import load_metadata -from ._utils import save_metadata from ._utils import set_state_delta -from ._utils import shell_quote +from .tools import SkillLoadTool from .tools import SkillRunTool from .tools import skill_list from .tools import skill_list_docs from .tools import skill_list_tools -from .tools import skill_load from .tools import skill_select_docs from .tools import skill_select_tools __all__ = [ - "BaseSelectionResult", "SelectionMode", - "add_selection", - "clear_selection", + "docs_scan_prefix", + "docs_state_key", "generic_get_selection", - "generic_select_items", - "get_previous_selection", - "get_state_delta_value", - "replace_selection", - "set_state_delta_for_selection", + "loaded_order_state_key", + "loaded_scan_prefix", + "loaded_state_key", + "tool_scan_prefix", + "tool_state_key", + "use_session_skill_state", "ENV_SKILLS_ROOT", + "SKILL_ARTIFACTS_STATE_KEY", + "SKILL_DOCS_BY_AGENT_STATE_KEY_PREFIX", "SKILL_DOCS_STATE_KEY_PREFIX", "SKILL_FILE", + "SKILL_LOADED_BY_AGENT_STATE_KEY_PREFIX", + "SKILL_LOADED_ORDER_BY_AGENT_STATE_KEY_PREFIX", + "SKILL_LOADED_ORDER_STATE_KEY_PREFIX", "SKILL_LOADED_STATE_KEY_PREFIX", + "SKILL_LOAD_MODE_VALUES", "SKILL_REGISTRY_KEY", "SKILL_REPOSITORY_KEY", + "SKILL_TOOLS_BY_AGENT_STATE_KEY_PREFIX", + "SKILL_TOOLS_NAMES", "SKILL_TOOLS_STATE_KEY_PREFIX", + "SkillLoadModeNames", + "SkillProfileNames", + "SkillToolsNames", "DynamicSkillToolSet", "SkillRegistry", "BaseSkillRepository", "FsSkillRepository", + "VisibilityFilter", "create_default_skill_repository", + "get_skill_config", + "get_skill_load_mode", + "set_skill_config", + "SkillProfileFlags", + 
"docs_key", + "docs_prefix", + "docs_session_key", + "docs_session_prefix", + "loaded_key", + "loaded_order_key", + "loaded_prefix", + "loaded_session_key", + "loaded_session_order_key", + "loaded_session_prefix", + "tool_key", + "tool_prefix", + "tool_session_key", + "tool_session_prefix", + "SKILLS_LEGACY_MIGRATION_STATE_KEY", + "maybe_migrate_legacy_skill_state", + "marshal_loaded_order", + "parse_loaded_order", + "touch_loaded_order", "SkillToolSet", "Skill", "SkillConfig", "SkillFrontMatter", - "SkillMetadata", "SkillRequires", "SkillResource", "SkillSummary", - "SkillWorkspaceInputRecord", - "SkillWorkspaceMetadata", - "SkillWorkspaceOutputRecord", - "format_datetime", - "parse_datetime", - "ArchiveExt", "ArchiveExtractor", - "ArchiveKind", - "CacheConfig", - "FilePerm", - "SizeLimit", "SkillRootResolver", - "TarPerm", - "compute_dir_digest", - "ensure_layout", "get_state_delta", - "load_metadata", - "save_metadata", "set_state_delta", - "shell_quote", + "SkillLoadTool", "SkillRunTool", "skill_list", "skill_list_docs", "skill_list_tools", - "skill_load", "skill_select_docs", "skill_select_tools", ] diff --git a/trpc_agent_sdk/skills/_common.py b/trpc_agent_sdk/skills/_common.py index a50f5a7..0eefc2c 100644 --- a/trpc_agent_sdk/skills/_common.py +++ b/trpc_agent_sdk/skills/_common.py @@ -24,6 +24,26 @@ from trpc_agent_sdk.context import InvocationContext from trpc_agent_sdk.log import logger +from ._constants import SkillLoadModeNames +from ._skill_config import get_skill_load_mode +from ._state_keys import docs_key +from ._state_keys import docs_prefix +from ._state_keys import docs_session_key +from ._state_keys import docs_session_prefix +from ._state_keys import loaded_key +from ._state_keys import loaded_order_key +from ._state_keys import loaded_prefix +from ._state_keys import loaded_session_key +from ._state_keys import loaded_session_order_key +from ._state_keys import loaded_session_prefix +from ._state_keys import tool_key +from ._state_keys import 
tool_prefix +from ._state_keys import tool_session_key +from ._state_keys import tool_session_prefix +from ._state_order import marshal_loaded_order +from ._state_order import parse_loaded_order +from ._state_order import touch_loaded_order + # Generic type for selection results T = TypeVar('T', bound=BaseModel) # pylint: disable=invalid-name @@ -59,8 +79,45 @@ class BaseSelectionResult(BaseModel): mode: str = Field(default="", description="The mode used for selecting") -def _set_state_delta(invocation_context: InvocationContext, key: str, value: Any) -> None: - """Set the state delta for a key. +def get_agent_name(invocation_context: InvocationContext) -> str: + """Get normalized agent name from invocation context.""" + name = getattr(invocation_context, "agent_name", "") + return name.strip() if isinstance(name, str) else "" + + +def use_session_skill_state(invocation_context: InvocationContext) -> bool: + """Whether skill-related state should use persistent (non-temp) keys.""" + return get_skill_load_mode(invocation_context) == SkillLoadModeNames.SESSION + + +def normalize_selection_mode(mode: str) -> str: + """Normalize selection mode; defaults to replace.""" + m = (mode or "").strip().lower() + if m in ("add", "replace", "clear"): + return m + return "replace" + + +def get_previous_selection_by_key(invocation_context: InvocationContext, key: str) -> tuple[list[str], bool]: + """Get previous selection by exact state key. + + Returns: + (selected_items, had_all) + """ + value = get_state_delta_value(invocation_context, key) + if not value: + return [], False + if value == "*": + return [], True + try: + parsed = json.loads(value) + except (json.JSONDecodeError, TypeError): + return [], False + return parsed if isinstance(parsed, list) else [], False + + +def set_state_delta(invocation_context: InvocationContext, key: str, value: Any) -> None: + """Set the state delta for a key. 
Args: invocation_context: Invocation context @@ -70,6 +127,35 @@ def _set_state_delta(invocation_context: InvocationContext, key: str, value: Any invocation_context.actions.state_delta[key] = value +def set_selection_state_delta_by_key( + invocation_context: InvocationContext, + key: str, + selected_items: list[str], + include_all: bool, +) -> None: + """Set selection state delta by exact state key.""" + if include_all: + set_state_delta(invocation_context, key, "*") + else: + set_state_delta(invocation_context, key, json.dumps(selected_items or [])) + + +def append_loaded_order_state_delta( + invocation_context: InvocationContext, + agent_name: str, + skill_name: str, +) -> None: + """Append loaded-order state delta for skill touch order.""" + key = loaded_order_key(agent_name) + if use_session_skill_state(invocation_context): + key = loaded_session_order_key(key) + current = parse_loaded_order(get_state_delta_value(invocation_context, key)) + next_order = touch_loaded_order(current, skill_name) + encoded = marshal_loaded_order(next_order) + if encoded: + set_state_delta(invocation_context, key, encoded) + + def get_previous_selection(invocation_context: InvocationContext, state_key_prefix: str, skill_name: str) -> Optional[list[str]]: """Get the previous selection for a skill from session state. 
@@ -191,12 +277,12 @@ def set_state_delta_for_selection(invocation_context: InvocationContext, state_k include_all = getattr(result, 'include_all', False) if include_all: - _set_state_delta(invocation_context, key, '*') + set_state_delta(invocation_context, key, '*') return selected_items = getattr(result, 'selected_items', []) selected_json = json.dumps(selected_items) - _set_state_delta(invocation_context, key, selected_json) + set_state_delta(invocation_context, key, selected_json) def generic_select_items(tool_context: InvocationContext, skill_name: str, items: Optional[list[str]], @@ -354,3 +440,52 @@ def get_all_tools(skill_name): except json.JSONDecodeError: logger.warning("Failed to parse selection for skill '%s' with key '%s': %s", skill_name, key, v_str) return [] + + +def loaded_state_key(ctx: InvocationContext, skill_name: str) -> str: + """Return the loaded-state key for a skill.""" + agent_name = ctx.agent_name.strip() + load_keys = loaded_key(agent_name, skill_name) + return loaded_session_key(load_keys) if use_session_skill_state(ctx) else load_keys + + +def docs_state_key(ctx: InvocationContext, skill_name: str) -> str: + """Return the docs-state key for a skill.""" + agent_name = ctx.agent_name.strip() + key = docs_key(agent_name, skill_name) + return docs_session_key(key) if use_session_skill_state(ctx) else key + + +def tool_state_key(ctx: InvocationContext, skill_name: str) -> str: + """Return the tools-state key for a skill.""" + agent_name = ctx.agent_name.strip() + key = tool_key(agent_name, skill_name) + return tool_session_key(key) if use_session_skill_state(ctx) else key + + +def loaded_scan_prefix(ctx: InvocationContext) -> str: + """Return the loaded-state scan prefix for an agent.""" + agent_name = ctx.agent_name.strip() + key = loaded_prefix(agent_name) + return loaded_session_prefix(key) if use_session_skill_state(ctx) else key + + +def docs_scan_prefix(ctx: InvocationContext) -> str: + """Return the docs-state scan prefix for an 
agent.""" + agent_name = ctx.agent_name.strip() + key = docs_prefix(agent_name) + return docs_session_prefix(key) if use_session_skill_state(ctx) else key + + +def tool_scan_prefix(ctx: InvocationContext) -> str: + """Return the tools-state scan prefix for an agent.""" + agent_name = ctx.agent_name.strip() + key = tool_prefix(agent_name) + return tool_session_prefix(key) if use_session_skill_state(ctx) else key + + +def loaded_order_state_key(ctx: InvocationContext) -> str: + """Return the loaded-order state key for an agent.""" + agent_name = ctx.agent_name.strip() + key = loaded_order_key(agent_name) + return loaded_session_order_key(key) if use_session_skill_state(ctx) else key diff --git a/trpc_agent_sdk/skills/_constants.py b/trpc_agent_sdk/skills/_constants.py index 292fba8..9a36626 100644 --- a/trpc_agent_sdk/skills/_constants.py +++ b/trpc_agent_sdk/skills/_constants.py @@ -8,6 +8,8 @@ This module defines constants for the skills system. """ +from enum import Enum + SKILL_FILE = "SKILL.md" # Environment variable name for skills root directory @@ -24,15 +26,80 @@ SKILL_LOADED_STATE_KEY_PREFIX = "temp:skill:loaded:" """State key for loaded skills.""" +# State key for loaded skills scoped by agent +SKILL_LOADED_BY_AGENT_STATE_KEY_PREFIX = "temp:skill:loaded_by_agent:" +"""State key prefix for loaded skills scoped by agent.""" + # State key for docs of skills SKILL_DOCS_STATE_KEY_PREFIX = "temp:skill:docs:" """State key for docs of skills.""" +# State key for docs of skills scoped by agent +SKILL_DOCS_BY_AGENT_STATE_KEY_PREFIX = "temp:skill:docs_by_agent:" +"""State key prefix for docs scoped by agent.""" + +# State key for loaded skill order +SKILL_LOADED_ORDER_STATE_KEY_PREFIX = "temp:skill:loaded_order:" +"""State key prefix for loaded skill touch order.""" + +# State key for loaded skill order scoped by agent +SKILL_LOADED_ORDER_BY_AGENT_STATE_KEY_PREFIX = "temp:skill:loaded_order_by_agent:" +"""State key prefix for loaded skill touch order scoped by 
agent.""" + # State key for tools of skills SKILL_TOOLS_STATE_KEY_PREFIX = "temp:skill:tools:" """State key for tools of skills.""" +# State key for tools of skills scoped by agent +SKILL_TOOLS_BY_AGENT_STATE_KEY_PREFIX = "temp:skill:tools_by_agent:" +"""State key prefix for tools scoped by agent.""" + +# State key for per-tool-call artifact refs (replay support) +SKILL_ARTIFACTS_STATE_KEY = "temp:skill:artifacts" +"""State key for skill tool-call artifact references.""" + # EnvSkillsCacheDir overrides where URL-based skills roots are cached. # When empty, the user cache directory is used. ENV_SKILLS_CACHE_DIR = "SKILLS_CACHE_DIR" """Environment variable name for skills cache directory.""" + + +class SkillProfileNames(str, Enum): + FULL = "full" + KNOWLEDGE_ONLY = "knowledge_only" + + def __str__(self) -> str: + return self.value + + +class SkillToolsNames(str, Enum): + LOAD = "skill_load" + SELECT_DOCS = "skill_select_docs" + LIST_DOCS = "skill_list_docs" + RUN = "skill_run" + SELECT_TOOLS = "skill_select_tools" + LIST_SKILLS = "skill_list_skills" + EXEC = "skill_exec" + WRITE_STDIN = "skill_write_stdin" + POLL_SESSION = "skill_poll_session" + KILL_SESSION = "skill_kill_session" + + def __str__(self) -> str: + return self.value + + +SKILL_TOOLS_NAMES = [tool.value for tool in SkillToolsNames.__members__.values()] + + +class SkillLoadModeNames(str, Enum): + ONCE = "once" + TURN = "turn" + SESSION = "session" + + def __str__(self) -> str: + return self.value + + +SKILL_LOAD_MODE_VALUES = [mode.value for mode in SkillLoadModeNames.__members__.values()] + +SKILL_CONFIG_KEY = "__trpc_agent_skills_config" diff --git a/trpc_agent_sdk/skills/_dynamic_toolset.py b/trpc_agent_sdk/skills/_dynamic_toolset.py index ba648cf..7d11ca5 100644 --- a/trpc_agent_sdk/skills/_dynamic_toolset.py +++ b/trpc_agent_sdk/skills/_dynamic_toolset.py @@ -25,8 +25,9 @@ from trpc_agent_sdk.tools import get_tool from trpc_agent_sdk.tools import get_tool_set -from ._constants import 
SKILL_LOADED_STATE_KEY_PREFIX -from ._constants import SKILL_TOOLS_STATE_KEY_PREFIX +from ._common import loaded_scan_prefix +from ._common import tool_scan_prefix +from ._common import tool_state_key from ._repository import BaseSkillRepository from ._utils import get_state_delta @@ -130,6 +131,7 @@ def __init__(self, self._find_tool_by_name(tool) else: self._find_tool_by_type(tool) + logger.info("DynamicSkillToolSet initialized: %s tools, %s toolsets, only_active_skills=%s", len(self._available_tools), len(self._available_toolsets), only_active_skills) @@ -266,9 +268,8 @@ async def get_tools(self, ctx: InvocationContext) -> List[BaseTool]: tool = await self._resolve_tool(tool_name, ctx) if tool is None: logger.warning( - "Tool '%s' required by skill '%s' could not be resolved. Checked: available_tools (%s), " - "available_toolsets (%s), global registry", tool_name, skill_name, len(self._available_tools), - len(self._available_toolsets)) + "Tool '%s' required by skill '%s' could not be resolved. 
Checked: available_tools (%s), available_toolsets (%s), global registry", + tool_name, skill_name, len(self._available_tools), len(self._available_toolsets)) continue selected_tools.append(tool) @@ -293,13 +294,14 @@ def _get_loaded_skills_from_state(self, ctx: InvocationContext) -> List[str]: List of loaded skill names """ loaded_skills: List[str] = [] - # Combine session state and current state delta state = dict(ctx.session_state.copy()) state.update(ctx.actions.state_delta) - + prefix = loaded_scan_prefix(ctx) for key, value in state.items(): - if key.startswith(SKILL_LOADED_STATE_KEY_PREFIX) and value: - skill_name = key[len(SKILL_LOADED_STATE_KEY_PREFIX):] + if not value or not key.startswith(prefix): + continue + skill_name = key[len(prefix):].strip() + if skill_name: loaded_skills.append(skill_name) return loaded_skills @@ -321,18 +323,20 @@ def _get_active_skills_from_delta(self, ctx: InvocationContext) -> List[str]: """ active_skills: set[str] = set() - # Check state_delta for skill-related changes + loaded_state_prefix = loaded_scan_prefix(ctx) + tools_state_prefix = tool_scan_prefix(ctx) + for key, value in ctx.actions.state_delta.items(): - if key.startswith(SKILL_LOADED_STATE_KEY_PREFIX) and value: - # Skill was just loaded - skill_name = key[len(SKILL_LOADED_STATE_KEY_PREFIX):] - active_skills.add(skill_name) - logger.debug("Skill '%s' is active (just loaded)", skill_name) - elif key.startswith(SKILL_TOOLS_STATE_KEY_PREFIX): - # Skill's tools were just selected/modified - skill_name = key[len(SKILL_TOOLS_STATE_KEY_PREFIX):] - active_skills.add(skill_name) - logger.debug("Skill '%s' is active (tools modified)", skill_name) + if key.startswith(loaded_state_prefix) and value: + skill_name = key[len(loaded_state_prefix):].strip() + if skill_name: + active_skills.add(skill_name) + logger.debug("Skill '%s' is active (just loaded)", skill_name) + if key.startswith(tools_state_prefix): + skill_name = key[len(tools_state_prefix):].strip() + if skill_name: + 
active_skills.add(skill_name) + logger.debug("Skill '%s' is active (tools modified)", skill_name) return list(active_skills) @@ -346,7 +350,7 @@ def _get_tools_selection(self, ctx: InvocationContext, skill_name: str) -> list[ Returns: List of selected tool names """ - key = SKILL_TOOLS_STATE_KEY_PREFIX + skill_name + key = tool_state_key(ctx, skill_name) v = get_state_delta(ctx, key) if not v: # Fallback to SKILL.md defaults when explicit selection state is absent. diff --git a/trpc_agent_sdk/skills/_hot_reload.py b/trpc_agent_sdk/skills/_hot_reload.py new file mode 100644 index 0000000..270c76c --- /dev/null +++ b/trpc_agent_sdk/skills/_hot_reload.py @@ -0,0 +1,163 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Skill hot-reload helpers. + +This module encapsulates filesystem event handling and dirty-directory queues +for skill repositories. Repository implementations can stay focused on indexing +while delegating watcher integration here. 
+""" + +from __future__ import annotations + +import importlib +import threading +from pathlib import Path +from typing import Callable + +from trpc_agent_sdk.log import logger + + +class SkillHotReloadTracker: + """Track changed skill directories and optional watchdog observer state.""" + + def __init__(self, skill_file_name: str): + self._skill_file_name = skill_file_name + self._watchdog_init_attempted = False + self._watchdog_observer: object | None = None + self._changed_dirs_by_root: dict[str, set[str]] = {} + self._changed_dirs_lock = threading.Lock() + + def clear(self) -> None: + """Clear queued directory changes.""" + with self._changed_dirs_lock: + self._changed_dirs_by_root = {} + + def mark_changed_path(self, raw_path: str, is_directory: bool, skill_roots: list[str]) -> None: + """Mark a changed path as dirty for the next incremental reload.""" + path = Path(raw_path) + if not is_directory and path.name.lower() != self._skill_file_name.lower(): + return + target_dir = path if is_directory else path.parent + self.mark_changed_dir(target_dir, skill_roots) + + def mark_changed_dir(self, path: Path, skill_roots: list[str]) -> None: + """Queue a changed directory; consumed by repository incremental scans.""" + root_key = self.resolve_root_key(path, skill_roots) + if root_key is None: + return + with self._changed_dirs_lock: + self._changed_dirs_by_root.setdefault(root_key, set()).add(str(path.resolve(strict=False))) + + def pop_changed_dirs(self, root_key: str) -> list[Path]: + """Pop and return queued changed directories for a root.""" + with self._changed_dirs_lock: + raw_dirs = self._changed_dirs_by_root.pop(root_key, set()) + return [Path(raw) for raw in sorted(raw_dirs)] + + def collect_changed_dirs( + self, + root_key: str, + tracked_dirs: set[str], + dir_mtime_ns: dict[str, int], + mtime_reader: Callable[[Path], int], + ) -> list[Path]: + """Collect changed directories via event queue or mtime probing.""" + changed_dirs = 
self.pop_changed_dirs(root_key) + if not changed_dirs: + for dir_key in sorted(tracked_dirs): + path = Path(dir_key) + if not path.exists(): + continue + current_mtime = mtime_reader(path) + if dir_mtime_ns.get(dir_key) != current_mtime: + changed_dirs.append(path) + dir_mtime_ns[dir_key] = current_mtime + return self.normalize_scan_targets(changed_dirs) + + @staticmethod + def resolve_root_key(path: Path, skill_roots: list[str]) -> str | None: + """Find matching root key for a path.""" + resolved = path.resolve(strict=False) + seen_roots: set[str] = set() + for root in skill_roots: + if not root: + continue + root_path = Path(root).resolve() + root_key = str(root_path) + if root_key in seen_roots: + continue + seen_roots.add(root_key) + if resolved.is_relative_to(root_path): + return root_key + return None + + @staticmethod + def normalize_scan_targets(changed_dirs: list[Path]) -> list[Path]: + """Drop nested paths when their parent is already queued.""" + result: list[Path] = [] + for candidate in sorted(changed_dirs, key=lambda p: len(p.parts)): + if any(candidate.is_relative_to(parent) for parent in result): + continue + result.append(candidate) + return result + + def start_watcher_if_possible(self, skill_roots: list[str]) -> None: + """Start filesystem watcher for near real-time hot reload if available.""" + if self._watchdog_init_attempted: + return + self._watchdog_init_attempted = True + try: + events_module = importlib.import_module("watchdog.events") + observers_module = importlib.import_module("watchdog.observers") + except ImportError: + logger.debug("watchdog is unavailable; skill hot reload falls back to mtime probing") + return + except Exception as ex: # pylint: disable=broad-except + logger.warning("Failed to initialize watchdog imports: %s", ex) + return + + file_system_event_handler_cls = getattr(events_module, "FileSystemEventHandler", None) + observer_cls = getattr(observers_module, "Observer", None) + if file_system_event_handler_cls is None 
or observer_cls is None: + logger.warning("watchdog is installed but required symbols are missing") + return + + tracker = self + + class _SkillDirChangeHandler(file_system_event_handler_cls): # type: ignore[misc,valid-type] + + def on_any_event(self, event) -> None: # type: ignore[no-untyped-def] + src_path = getattr(event, "src_path", None) + if src_path: + tracker.mark_changed_path(src_path, bool(event.is_directory), skill_roots) + dest_path = getattr(event, "dest_path", None) + if dest_path: + tracker.mark_changed_path(dest_path, bool(event.is_directory), skill_roots) + + observer = observer_cls() + handler = _SkillDirChangeHandler() + seen_roots: set[str] = set() + for root in skill_roots: + if not root: + continue + root_path = Path(root).resolve() + root_key = str(root_path) + if root_key in seen_roots: + continue + seen_roots.add(root_key) + if not root_path.is_dir(): + continue + try: + observer.schedule(handler, path=root_key, recursive=True) + except Exception as ex: # pylint: disable=broad-except + logger.warning("Failed to watch skill root %s: %s", root_key, ex) + + if not observer.emitters: + logger.debug("No valid skill roots to watch; hot reload watcher not started") + return + observer.start() + self._watchdog_observer = observer + logger.debug("Skill hot reload watcher started with %d root(s)", len(observer.emitters)) diff --git a/trpc_agent_sdk/skills/_repository.py b/trpc_agent_sdk/skills/_repository.py index 0877f4c..08189bf 100644 --- a/trpc_agent_sdk/skills/_repository.py +++ b/trpc_agent_sdk/skills/_repository.py @@ -18,9 +18,9 @@ import abc import os from pathlib import Path +from typing import Callable from typing import List from typing import Optional -from typing import Literal from typing_extensions import override import yaml @@ -29,74 +29,16 @@ from trpc_agent_sdk.log import logger from ._constants import SKILL_FILE +from ._hot_reload import SkillHotReloadTracker from ._types import Skill from ._types import SkillResource from 
._types import SkillSummary from ._url_root import SkillRootResolver +from ._utils import is_doc_file +from ._utils import is_script_file BASE_DIR_PLACEHOLDER = "__BASE_DIR__" - - -def _split_front_matter(content: str) -> tuple[dict[str, str], str]: - """Split markdown into (front matter dict, body) with optional YAML front matter.""" - text = content.replace("\r\n", "\n") - if not text.startswith("---\n"): - return {}, text - idx = text.find("\n---\n", 4) - if idx < 0: - return {}, text - raw_yaml = text[4:idx] - body = text[idx + 5:] - try: - parsed = yaml.safe_load(raw_yaml) or {} - if not isinstance(parsed, dict): - return {}, body - except Exception: - return {}, body - out: dict[str, str] = {} - for k, v in parsed.items(): - key = str(k).strip() - if not key: - continue - if v is None: - out[key] = "" - else: - out[key] = str(v) - return out, body - - -def _parse_tools_from_body(body: str) -> list[str]: - """Parse tool names from the Tools section in body text.""" - tool_names: list[str] = [] - in_tools_section = False - for line in body.split("\n"): - stripped = line.strip() - if stripped.lower().startswith("tools:"): - in_tools_section = True - continue - if not in_tools_section: - continue - if stripped and not stripped.startswith("-") and not stripped.startswith("#"): - if ":" in stripped or (stripped[0].isupper() and any( - stripped.startswith(s) for s in ["Overview", "Examples", "Usage", "Description", "Installation"])): - break - if stripped.startswith("#"): - continue - if stripped.startswith("-"): - tool_name = stripped[1:].strip() - if tool_name and not tool_name.startswith("#"): - tool_names.append(tool_name) - return tool_names - - -def _is_doc_file(name: str) -> bool: - name_lower = name.lower() - return name_lower.endswith(".md") or name_lower.endswith(".txt") - - -def _read_skill_file(path: Path) -> tuple[dict[str, str], str]: - content = path.read_text(encoding="utf-8") - return _split_front_matter(content) +VisibilityFilter = 
Callable[[SkillSummary], bool] class BaseSkillRepository(abc.ABC): @@ -107,13 +49,19 @@ class BaseSkillRepository(abc.ABC): must satisfy. Parsing internals are left entirely to subclasses. """ - def __init__(self, workspace_runtime: BaseWorkspaceRuntime): + def __init__(self, workspace_runtime: BaseWorkspaceRuntime, visibility_filter: VisibilityFilter | None = None): self._workspace_runtime = workspace_runtime + self._visibility_filter = visibility_filter @property def workspace_runtime(self) -> BaseWorkspaceRuntime: return self._workspace_runtime + @property + def visibility_filter(self) -> VisibilityFilter | None: + """Return the filter function.""" + return self._visibility_filter + def user_prompt(self) -> str: return "" @@ -132,18 +80,8 @@ def get(self, name: str) -> Skill: raise NotImplementedError @abc.abstractmethod - def skill_list(self, mode: Literal["all", "enabled", "disabled"] = "all") -> list[str]: - """Return the names of all indexed skills. - - Args: - mode: The mode to list the skills. - - all: List all skills. - - enabled: List enabled skills. - - disabled: List disabled skills. - - Returns: - A list of skill names. - """ + def skill_list(self, mode: str = 'all') -> list[str]: + """Return the names of all indexed skills.""" raise NotImplementedError @abc.abstractmethod @@ -163,8 +101,24 @@ def refresh(self) -> None: def skill_run_env(self, skill_name: str) -> dict[str, str]: """Return the environment variables for the given skill. 
""" + if self._visibility_filter is not None: + if not self._skill_visible_by_name(skill_name, self.summaries()): + raise ValueError(f"skill '{skill_name}' not found") return {} + def _filter_summaries( + self, + summaries: list[SkillSummary], + ) -> list[SkillSummary]: + if not summaries: + return [] + if self._visibility_filter is None: + return [SkillSummary(name=s.name, description=s.description) for s in summaries] + return [s for s in summaries if self._visibility_filter(s)] + + def _skill_visible_by_name(self, name: str, summaries: list[SkillSummary]) -> bool: + return any(summary.name == name for summary in summaries) + class FsSkillRepository(BaseSkillRepository): """ @@ -181,6 +135,7 @@ def __init__( *roots: str, workspace_runtime: Optional[BaseWorkspaceRuntime] = None, resolver: Optional[SkillRootResolver] = None, + enable_hot_reload: bool = False, ): """ Create a FsSkillRepository scanning the given roots. @@ -190,6 +145,7 @@ def __init__( ``file://`` URLs, or ``http(s)://`` archive URLs). workspace_runtime: Optional workspace runtime to use. resolver: Optional skill root resolver to use. + enable_hot_reload: Whether to enable skill hot reload checks. 
""" if workspace_runtime is None: workspace_runtime = create_local_workspace_runtime() @@ -197,6 +153,11 @@ def __init__( self._resolver = resolver or SkillRootResolver() self._skill_paths: dict[str, str] = {} # name -> base dir self._all_descriptions: dict[str, str] = {} # name -> description + self._discovered_skill_files: set[str] = set() + self._tracked_dirs_by_root: dict[str, set[str]] = {} + self._dir_mtime_ns: dict[str, int] = {} + self._hot_reload_tracker = SkillHotReloadTracker(skill_file_name=SKILL_FILE, ) + self._enable_hot_reload = enable_hot_reload self._skill_roots: list[str] = [] flat_roots: list[str] = [] @@ -212,6 +173,11 @@ def __init__( # Root resolution # ------------------------------------------------------------------ + @property + def hot_reload_enabled(self) -> bool: + """Whether hot reload checks are enabled.""" + return self._enable_hot_reload + def _resolve_skill_roots(self, roots: list[str]) -> None: """ Resolve a skill root string to a local directory path. 
@@ -238,44 +204,140 @@ def _index(self) -> None: """Scan all roots and index available skills.""" self._skill_paths = {} self._all_descriptions = {} + self._discovered_skill_files = set() + self._tracked_dirs_by_root = {} + self._dir_mtime_ns = {} + self._hot_reload_tracker.clear() seen: set[str] = set() for root in self._skill_roots: if not root: continue root_path = Path(root).resolve() - if str(root_path) in seen: + root_key = str(root_path) + if root_key in seen: continue - seen.add(str(root_path)) + seen.add(root_key) try: - for dirpath, _dirs, _files in os.walk(root_path): - skill_file_path = Path(dirpath) / SKILL_FILE - if not skill_file_path.is_file(): - continue - try: - self._index_one(dirpath, skill_file_path) - except Exception as ex: # pylint: disable=broad-except - logger.debug("Failed to index skill at %s: %s", skill_file_path, ex) - + self._scan_root(root_path, root_key=root_key) except Exception as ex: # pylint: disable=broad-except logger.warning("Error scanning root %s: %s", root_path, ex) + if self._enable_hot_reload: + self._hot_reload_tracker.start_watcher_if_possible(self._skill_roots) + + def _scan_root(self, root_path: Path, root_key: str, start_path: Optional[Path] = None) -> None: + """Scan a full root or one of its changed subtrees.""" + target = start_path or root_path + for dirpath, _dirs, _files in os.walk(target): + self._track_dir(root_key, Path(dirpath)) + skill_file_path = Path(dirpath) / SKILL_FILE + if not skill_file_path.is_file(): + continue + try: + self._index_one(dirpath, skill_file_path) + except Exception as ex: # pylint: disable=broad-except + logger.debug("Failed to index skill at %s: %s dirpath: %s, _files: %s", skill_file_path, ex, _dirs, + _files) + + def _track_dir(self, root_key: str, path: Path) -> None: + """Store directory metadata used by incremental hot reload.""" + dir_key = str(path.resolve()) + self._tracked_dirs_by_root.setdefault(root_key, set()).add(dir_key) + self._dir_mtime_ns[dir_key] = 
self._safe_mtime_ns(path) - def _index_one(self, dirpath: str, skill_file_path: Path) -> None: + @staticmethod + def _safe_mtime_ns(path: Path) -> int: + try: + return path.stat().st_mtime_ns + except OSError: + return -1 + + def _mark_changed_dir_for_hot_reload(self, path: Path) -> None: + """Queue a changed directory; consumed by _scan_changed_dirs.""" + if not self._enable_hot_reload: + return + self._hot_reload_tracker.mark_changed_dir(path, self._skill_roots) + + def _index_one(self, dirpath: str, skill_file_path: Path) -> bool: """Index a single skill directory found at *dirpath*.""" - front_matter, _ = _read_skill_file(skill_file_path) + skill_file_key = str(skill_file_path.resolve()) + if skill_file_key in self._discovered_skill_files: + return False + self._discovered_skill_files.add(skill_file_key) + + front_matter, _ = self._read_skill_file(skill_file_path) name = front_matter.get("name", "").strip() if not name: name = Path(dirpath).name.strip() if not name: - return + return False # First occurrence wins. 
if name in self._skill_paths: - return + return False self._all_descriptions[name] = front_matter.get("description", "").strip() self._skill_paths[name] = dirpath logger.debug("Found skill '%s' at %s", name, dirpath) + return True + + def _scan_changed_dirs(self) -> None: + """Fast probe + incremental scan for newly added SKILL.md files.""" + if not self._enable_hot_reload: + return + seen_roots: set[str] = set() + for root in self._skill_roots: + if not root: + continue + root_path = Path(root).resolve() + root_key = str(root_path) + if root_key in seen_roots: + continue + seen_roots.add(root_key) + + if not root_path.is_dir(): + continue + + tracked_dirs = self._tracked_dirs_by_root.get(root_key, {root_key}) + changed_dirs = self._hot_reload_tracker.collect_changed_dirs( + root_key=root_key, + tracked_dirs=tracked_dirs, + dir_mtime_ns=self._dir_mtime_ns, + mtime_reader=self._safe_mtime_ns, + ) + if not changed_dirs: + continue + + for target in changed_dirs: + self._scan_root(root_path=root_path, root_key=root_key, start_path=target) + self._prune_deleted_skills() + + def _prune_deleted_skills(self) -> None: + """Remove indexed skills whose directory or SKILL.md no longer exists.""" + removed_names: list[str] = [] + for name, dirpath in list(self._skill_paths.items()): + skill_file_path = Path(dirpath) / SKILL_FILE + if skill_file_path.is_file(): + continue + removed_names.append(name) + + for name in removed_names: + dirpath = self._skill_paths.pop(name, "") + self._all_descriptions.pop(name, None) + if dirpath: + removed_file_key = str((Path(dirpath) / SKILL_FILE).resolve(strict=False)) + self._discovered_skill_files.discard(removed_file_key) + logger.debug("Pruned deleted skill '%s' from repository index", name) + + stale_files = {path for path in self._discovered_skill_files if not Path(path).is_file()} + if stale_files: + self._discovered_skill_files.difference_update(stale_files) + + @classmethod + def _read_skill_file(cls, path: Path) -> tuple[dict[str, 
str], str]: + """Read the skill file and return the front matter and body.""" + content = path.read_text(encoding="utf-8") + return cls.from_markdown(content) # ------------------------------------------------------------------ # Public API @@ -290,17 +352,24 @@ def path(self, name: str) -> str: """ key = name.strip() if key not in self._skill_paths: - raise ValueError(f"skill '{name}' not found") + logger.warning("skill '%s' not found, refreshing repository", name) + self.refresh() + if key not in self._skill_paths: + raise ValueError(f"skill '{name}' not found") + if self._visibility_filter is not None: + if not self._skill_visible_by_name(key, self.summaries()): + raise ValueError(f"skill '{key}' not found") return self._skill_paths[key] @override def summaries(self) -> List[SkillSummary]: """Return summaries for all indexed skills, sorted by name.""" + self._scan_changed_dirs() out: list[SkillSummary] = [] for name in sorted(self._skill_paths): skill_file_path = Path(self._skill_paths[name]) / SKILL_FILE try: - front_matter, _ = _read_skill_file(skill_file_path) + front_matter, _ = self._read_skill_file(skill_file_path) summary = SkillSummary( name=front_matter.get("name", "").strip(), description=front_matter.get("description", "").strip(), @@ -310,6 +379,8 @@ def summaries(self) -> List[SkillSummary]: out.append(summary) except Exception as ex: # pylint: disable=broad-except logger.warning("Failed to parse summary for skill '%s': %s", name, ex) + if self._visibility_filter is not None: + out = self._filter_summaries(out) return out @override @@ -323,13 +394,13 @@ def get(self, name: str) -> Skill: ValueError: If the skill is not found. 
""" dir_path = Path(self.path(name)) - front_matter, body = _read_skill_file(dir_path / SKILL_FILE) + front_matter, body = self._read_skill_file(dir_path / SKILL_FILE) skill = Skill() skill.base_dir = str(dir_path) skill.summary.name = front_matter.get("name", "").strip() or name skill.summary.description = front_matter.get("description", "").strip() skill.body = body - skill.tools = _parse_tools_from_body(skill.body) + skill.tools = self._parse_tools_from_body(skill.body) if skill.base_dir: skill.body = skill.body.replace(BASE_DIR_PLACEHOLDER, skill.base_dir) @@ -338,12 +409,20 @@ def get(self, name: str) -> Skill: return skill @override - def skill_list(self, mode: Literal["all", "enabled", "disabled"] = "all") -> list[str]: - """Return the names of all indexed skills, sorted.""" + def skill_list(self, mode: str = 'all') -> list[str]: + """Return the names of all indexed skills, sorted. + + Args: + mode: The mode to list skills. + Returns: + A list of skill names. + """ + self._scan_changed_dirs() return sorted(self._skill_paths) @override def refresh(self) -> None: + """Refresh the skill repository.""" self._index() def _read_docs(self, dir_path: Path, base_dir: str) -> list[SkillResource]: @@ -356,7 +435,7 @@ def _read_docs(self, dir_path: Path, base_dir: str) -> list[SkillResource]: continue if entry.name.lower() == SKILL_FILE.lower(): continue - if not _is_doc_file(entry.name): + if not is_doc_file(entry.name) and not is_script_file(entry.name): continue try: content = entry.read_text(encoding="utf-8") @@ -380,7 +459,30 @@ def from_markdown(cls, content: str) -> tuple[dict[str, str], str]: .. deprecated:: Prefer repository-native front matter splitting directly. 
""" - return _split_front_matter(content) + text = content.replace("\r\n", "\n") + if not text.startswith("---\n"): + return {}, text + idx = text.find("\n---\n", 4) + if idx < 0: + return {}, text + raw_yaml = text[4:idx] + body = text[idx + 5:] + try: + parsed = yaml.safe_load(raw_yaml) or {} + if not isinstance(parsed, dict): + return {}, body + except Exception: + return {}, body + out: dict[str, str] = {} + for k, v in parsed.items(): + key = str(k).strip() + if not key: + continue + if v is None: + out[key] = "" + else: + out[key] = str(v) + return out, body @staticmethod def _parse_tools_from_body(body: str) -> list[str]: @@ -389,18 +491,40 @@ def _parse_tools_from_body(body: str) -> list[str]: .. deprecated:: Prefer repository-native tool parser directly. """ - return _parse_tools_from_body(body) + tool_names: list[str] = [] + in_tools_section = False + for line in body.split("\n"): + stripped = line.strip() + if stripped.lower().startswith("tools:"): + in_tools_section = True + continue + if not in_tools_section: + continue + if stripped and not stripped.startswith("-") and not stripped.startswith("#"): + if ":" in stripped or (stripped[0].isupper() and any( + stripped.startswith(s) + for s in ["Overview", "Examples", "Usage", "Description", "Installation"])): + break + if stripped.startswith("#"): + continue + if stripped.startswith("-"): + tool_name = stripped[1:].strip() + if tool_name and not tool_name.startswith("#"): + tool_names.append(tool_name) + return tool_names def create_default_skill_repository( *roots: str, workspace_runtime: Optional[BaseWorkspaceRuntime] = None, + enable_hot_reload: bool = True, ) -> FsSkillRepository: """Create a new filesystem skill repository. Args: roots: Root directories (or URLs) to scan for skills. workspace_runtime: Optional workspace runtime. + enable_hot_reload: Whether to enable skill hot reload checks. Returns: A configured :class:`FsSkillRepository`. 
""" @@ -409,4 +533,5 @@ def create_default_skill_repository( return FsSkillRepository( *roots, workspace_runtime=workspace_runtime, + enable_hot_reload=enable_hot_reload, ) diff --git a/trpc_agent_sdk/skills/_skill_config.py b/trpc_agent_sdk/skills/_skill_config.py new file mode 100644 index 0000000..00f15fe --- /dev/null +++ b/trpc_agent_sdk/skills/_skill_config.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. + +from typing import Any + +from trpc_agent_sdk.context import AgentContext +from trpc_agent_sdk.context import InvocationContext + +from ._constants import SKILL_CONFIG_KEY +from ._constants import SKILL_LOAD_MODE_VALUES +from ._constants import SkillLoadModeNames + +DEFAULT_SKILL_CONFIG = { + "skill_processor": { + "load_mode": "turn", + "tooling_guidance": "", + "tool_result_mode": False, + "tool_profile": "full", + "forbidden_tools": [], + "tool_flags": None, + "exec_tools_disabled": False, + "repo_resolver": None, + "max_loaded_skills": 0, + }, + "workspace_exec_processor": { + "session_tools": False, + "has_skills_repo": False, + "repo_resolver": None, + "enabled_resolver": None, + "sessions_resolver": None, + }, + "skills_tool_result_processor": { + "skip_fallback_on_session_summary": True, + "repo_resolver": None, + "tool_result_mode": False, + }, +} + + +def get_skill_config(agent_context: AgentContext) -> dict[str, Any]: + return agent_context.get_metadata(SKILL_CONFIG_KEY, DEFAULT_SKILL_CONFIG) + + +def set_skill_config(agent_context: AgentContext, config: dict[str, Any] = DEFAULT_SKILL_CONFIG) -> None: + agent_context.with_metadata(SKILL_CONFIG_KEY, config) + + +def get_skill_load_mode(ctx: InvocationContext) -> str: + skill_config = get_skill_config(ctx.agent_context) + load_mode = skill_config["skill_processor"].get("load_mode", SkillLoadModeNames.TURN.value) + if 
load_mode not in SKILL_LOAD_MODE_VALUES: + load_mode = SkillLoadModeNames.TURN.value + return str(load_mode) + + +def is_exist_skill_config(agent_context: AgentContext) -> bool: + return SKILL_CONFIG_KEY in agent_context.metadata diff --git a/trpc_agent_sdk/skills/_skill_profile.py b/trpc_agent_sdk/skills/_skill_profile.py new file mode 100644 index 0000000..e1f07da --- /dev/null +++ b/trpc_agent_sdk/skills/_skill_profile.py @@ -0,0 +1,120 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Skill tool profile options and flags.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +from ._constants import SKILL_TOOLS_NAMES +from ._constants import SkillProfileNames +from ._constants import SkillToolsNames + + +@dataclass +class SkillProfileFlags: + """Built-in skill tool flags""" + load: bool = False + select_docs: bool = False + list_docs: bool = False + run: bool = False + select_tools: bool = False + list_skills: bool = False + exec: bool = False + write_stdin: bool = False + poll_session: bool = False + kill_session: bool = False + + @classmethod + def normalize_profile(cls, profile: str) -> str: + p = (profile or "").strip().lower() + if p == str(SkillProfileNames.KNOWLEDGE_ONLY): + return str(SkillProfileNames.KNOWLEDGE_ONLY) + return str(SkillProfileNames.FULL) + + @classmethod + def normalize_tool(cls, name: str) -> str: + return (name or "").strip().lower() + + @classmethod + def preset_flags(cls, profile: str, forbidden_tools: Optional[list[str]] = None) -> "SkillProfileFlags": + normalized = cls.normalize_profile(profile) + if normalized == SkillProfileNames.KNOWLEDGE_ONLY: + return cls(load=True, select_docs=True, list_docs=True) + + flags = cls( + load=True, + select_docs=True, + list_docs=True, + run=True, + exec=True, + write_stdin=True, 
+ poll_session=True, + kill_session=True, + select_tools=True, + list_skills=True, + ) + if forbidden_tools: + flags = flags.flags_from_forbidden_tools(forbidden_tools, flags) + return flags + + @classmethod + def flags_from_forbidden_tools(cls, forbidden_tools: list[str], flags: "SkillProfileFlags") -> "SkillProfileFlags": + for raw in forbidden_tools: + name = cls.normalize_tool(raw) + if name in SKILL_TOOLS_NAMES: + flags_mem = name[name.index("_") + 1:] + setattr(flags, flags_mem, False) + flags.validate() + return flags + + @classmethod + def resolve_flags(cls, profile: str, forbidden_tools: Optional[list[str]] = None) -> "SkillProfileFlags": + flags = cls.preset_flags(profile, forbidden_tools) + flags.validate() + return flags + + def validate(self) -> None: + if self.exec and not self.run: + raise ValueError(f"{SkillToolsNames.EXEC} requires {SkillToolsNames.RUN}") + if self.write_stdin and not self.exec: + raise ValueError(f"{SkillToolsNames.WRITE_STDIN} requires {SkillToolsNames.EXEC}") + if self.poll_session and not self.exec: + raise ValueError(f"{SkillToolsNames.POLL_SESSION} requires {SkillToolsNames.EXEC}") + if self.kill_session and not self.exec: + raise ValueError(f"{SkillToolsNames.KILL_SESSION} requires {SkillToolsNames.EXEC}") + + def is_any(self) -> bool: + return (self.load or self.select_docs or self.list_docs or self.run or self.exec or self.write_stdin + or self.poll_session or self.kill_session or self.select_tools or self.list_skills) + + def has_knowledge_tools(self) -> bool: + return self.load or self.select_docs or self.list_docs + + def has_doc_helpers(self) -> bool: + return self.select_docs or self.list_docs + + def has_select_tools(self) -> bool: + return self.select_tools + + def requires_execution_tools(self) -> bool: + return self.run or self.exec or self.write_stdin or self.poll_session or self.kill_session + + def requires_exec_session_tools(self) -> bool: + return self.exec or self.write_stdin or self.poll_session or 
self.kill_session + + def without_interactive_execution(self) -> "SkillProfileFlags": + return SkillProfileFlags( + load=self.load, + select_docs=self.select_docs, + list_docs=self.list_docs, + run=self.run, + exec=False, + write_stdin=False, + poll_session=False, + kill_session=False, + ) diff --git a/trpc_agent_sdk/skills/_state_keys.py b/trpc_agent_sdk/skills/_state_keys.py new file mode 100644 index 0000000..586f530 --- /dev/null +++ b/trpc_agent_sdk/skills/_state_keys.py @@ -0,0 +1,151 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""State key builders for skills.""" + +from __future__ import annotations + +from urllib.parse import quote + +from ._constants import SKILL_DOCS_BY_AGENT_STATE_KEY_PREFIX +from ._constants import SKILL_DOCS_STATE_KEY_PREFIX +from ._constants import SKILL_LOADED_BY_AGENT_STATE_KEY_PREFIX +from ._constants import SKILL_LOADED_ORDER_BY_AGENT_STATE_KEY_PREFIX +from ._constants import SKILL_LOADED_ORDER_STATE_KEY_PREFIX +from ._constants import SKILL_LOADED_STATE_KEY_PREFIX +from ._constants import SKILL_TOOLS_BY_AGENT_STATE_KEY_PREFIX +from ._constants import SKILL_TOOLS_STATE_KEY_PREFIX + +_STATE_KEY_SCOPE_DELIMITER = "/" +_STATE_KEY_TEMP_PREFIX = "temp:" + + +def _escape_scope_segment(value: str) -> str: + """Escape a scoped-key segment when it contains the delimiter.""" + if _STATE_KEY_SCOPE_DELIMITER in value: + return quote(value, safe="") + return value + + +def to_persistent_prefix(prefix: str) -> str: + """Convert temp-prefixed state prefix to persistent prefix.""" + if prefix.startswith(_STATE_KEY_TEMP_PREFIX): + return prefix[len(_STATE_KEY_TEMP_PREFIX):] + return prefix + + +def loaded_key(agent_name: str, skill_name: str) -> str: + """Return the loaded-state key for a skill. + + When ``agent_name`` is empty, fallback to the legacy unscoped key. 
+ """ + agent_name = agent_name.strip() + skill_name = skill_name.strip() + if not agent_name: + return f"{SKILL_LOADED_STATE_KEY_PREFIX}{skill_name}" + return (f"{SKILL_LOADED_BY_AGENT_STATE_KEY_PREFIX}" + f"{_escape_scope_segment(agent_name)}" + f"{_STATE_KEY_SCOPE_DELIMITER}{skill_name}") + + +def loaded_session_key(keys: str) -> str: + """Return persistent loaded-state key for session mode.""" + return to_persistent_prefix(keys) + + +def docs_key(agent_name: str, skill_name: str) -> str: + """Return the docs-state key for a skill. + + When ``agent_name`` is empty, fallback to the legacy unscoped key. + """ + agent_name = agent_name.strip() + skill_name = skill_name.strip() + if not agent_name: + return f"{SKILL_DOCS_STATE_KEY_PREFIX}{skill_name}" + return (f"{SKILL_DOCS_BY_AGENT_STATE_KEY_PREFIX}" + f"{_escape_scope_segment(agent_name)}" + f"{_STATE_KEY_SCOPE_DELIMITER}{skill_name}") + + +def docs_session_key(keys: str) -> str: + """Return persistent docs-state key for session mode.""" + return to_persistent_prefix(keys) + + +def tool_key(agent_name: str, skill_name: str) -> str: + """Return the tools-state key for a skill. + + When ``agent_name`` is empty, fallback to the legacy unscoped key. 
+ """ + agent_name = agent_name.strip() + skill_name = skill_name.strip() + if not agent_name: + return f"{SKILL_TOOLS_STATE_KEY_PREFIX}{skill_name}" + return (f"{SKILL_TOOLS_BY_AGENT_STATE_KEY_PREFIX}" + f"{_escape_scope_segment(agent_name)}" + f"{_STATE_KEY_SCOPE_DELIMITER}{skill_name}") + + +def tool_session_key(keys: str) -> str: + """Return persistent tools-state key for session mode.""" + return to_persistent_prefix(keys) + + +def loaded_prefix(agent_name: str) -> str: + """Return the loaded-state scan prefix for an agent.""" + agent_name = agent_name.strip() + if not agent_name: + return SKILL_LOADED_STATE_KEY_PREFIX + return (f"{SKILL_LOADED_BY_AGENT_STATE_KEY_PREFIX}" + f"{_escape_scope_segment(agent_name)}" + f"{_STATE_KEY_SCOPE_DELIMITER}") + + +def loaded_session_prefix(keys: str) -> str: + """Return persistent loaded-state scan prefix for session mode.""" + return to_persistent_prefix(keys) + + +def docs_prefix(agent_name: str) -> str: + """Return the docs-state scan prefix for an agent.""" + agent_name = agent_name.strip() + if not agent_name: + return SKILL_DOCS_STATE_KEY_PREFIX + return (f"{SKILL_DOCS_BY_AGENT_STATE_KEY_PREFIX}" + f"{_escape_scope_segment(agent_name)}" + f"{_STATE_KEY_SCOPE_DELIMITER}") + + +def docs_session_prefix(keys: str) -> str: + """Return persistent docs-state scan prefix for session mode.""" + return to_persistent_prefix(keys) + + +def tool_prefix(agent_name: str) -> str: + """Return the tools-state scan prefix for an agent.""" + agent_name = agent_name.strip() + if not agent_name: + return SKILL_TOOLS_STATE_KEY_PREFIX + return (f"{SKILL_TOOLS_BY_AGENT_STATE_KEY_PREFIX}" + f"{_escape_scope_segment(agent_name)}" + f"{_STATE_KEY_SCOPE_DELIMITER}") + + +def tool_session_prefix(keys: str) -> str: + """Return persistent tools-state scan prefix for session mode.""" + return to_persistent_prefix(keys) + + +def loaded_order_key(agent_name: str) -> str: + """Return the loaded-order key for an agent.""" + agent_name = 
agent_name.strip() + if not agent_name: + return SKILL_LOADED_ORDER_STATE_KEY_PREFIX + return f"{SKILL_LOADED_ORDER_BY_AGENT_STATE_KEY_PREFIX}{_escape_scope_segment(agent_name)}" + + +def loaded_session_order_key(keys: str) -> str: + """Return persistent loaded-order key for session mode.""" + return to_persistent_prefix(keys) diff --git a/trpc_agent_sdk/skills/_state_migration.py b/trpc_agent_sdk/skills/_state_migration.py new file mode 100644 index 0000000..90feb15 --- /dev/null +++ b/trpc_agent_sdk/skills/_state_migration.py @@ -0,0 +1,165 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Legacy skill state migration utilities. + +- migrate legacy unscoped skill state keys once per session +- infer skill owners from historical tool responses +- write new scoped keys and clear old legacy keys +""" + +from __future__ import annotations + +import re +from typing import Any +from typing import Callable + +from trpc_agent_sdk.context import InvocationContext +from trpc_agent_sdk.events import Event + +from ._constants import SKILL_DOCS_STATE_KEY_PREFIX +from ._constants import SKILL_LOADED_STATE_KEY_PREFIX +from ._constants import SkillToolsNames +from ._state_keys import docs_key +from ._state_keys import loaded_key + +SKILLS_LEGACY_MIGRATION_STATE_KEY = "processor:skills:legacy_migrated" + +ScopedKeysBuilder = Callable[[str, str], str] + + +def _state_has_key(ctx: InvocationContext, key: str) -> bool: + if key in ctx.actions.state_delta: + return True + return key in ctx.session.state + + +def _snapshot_state(ctx: InvocationContext) -> dict[str, Any]: + state = dict(ctx.session.state or {}) + for k, v in ctx.actions.state_delta.items(): + if v is None: + state.pop(k, None) + else: + state[k] = v + return state + + +def _migrate_legacy_state_key( + ctx: InvocationContext, + state: dict[str, Any], + delta: 
dict[str, Any], + legacy_key: str, + legacy_val: Any, + skill_name: str, + owners: dict[str, str], + build_keys: ScopedKeysBuilder, +) -> None: + name = (skill_name or "").strip() + if not name: + return + # Skip already scoped entries. + if ":" in name: + return + + owner = (owners.get(name, "") or "").strip() + if not owner: + owner = (getattr(ctx.agent, "name", "") or "").strip() + if not owner: + return + + temp_key = build_keys(owner, name) + temp_existing = state.get(temp_key, None) + if temp_existing: + delta[legacy_key] = None + return + + delta[temp_key] = legacy_val + delta[legacy_key] = None + + +def _legacy_skill_owners(events: list[Event]) -> dict[str, str]: + owners: dict[str, str] = {} + for ev in reversed(events or []): + _add_owners_from_event(ev, owners) + return owners + + +def _add_owners_from_event(ev: Event, owners: dict[str, str]) -> None: + if not ev or not ev.content or not ev.content.parts: + return + author = (ev.author or "").strip() + if not author: + return + for part in reversed(ev.content.parts): + fr = part.function_response + if not fr: + continue + tool_name = (fr.name or "").strip() + if tool_name not in (SkillToolsNames.LOAD, SkillToolsNames.SELECT_DOCS): + continue + skill_name = _skill_name_from_tool_response(fr.response) + if not skill_name or skill_name in owners: + continue + owners[skill_name] = author + + +def _skill_name_from_tool_response(response: Any) -> str: + """Extract skill name from tool response payload.""" + if isinstance(response, dict): + for key in ("skill", "skill_name", "name"): + value = response.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + result = response.get("result") + if isinstance(result, str): + matched = re.search(r"skill\s+'([^']+)'\s+loaded", result) + if matched: + return matched.group(1).strip() + elif isinstance(response, str): + matched = re.search(r"skill\s+'([^']+)'\s+loaded", response) + if matched: + return matched.group(1).strip() + return "" + + 
+def maybe_migrate_legacy_skill_state(ctx: InvocationContext) -> None: + """Migrate legacy skill state keys into scoped keys once. + + This function is idempotent per session via + ``SKILLS_LEGACY_MIGRATION_STATE_KEY``. + """ + if ctx is None or ctx.session is None: + return + if _state_has_key(ctx, SKILLS_LEGACY_MIGRATION_STATE_KEY): + return + ctx.actions.state_delta[SKILLS_LEGACY_MIGRATION_STATE_KEY] = True + + state = _snapshot_state(ctx) + if not state: + return + has_loaded = any(k.startswith(SKILL_LOADED_STATE_KEY_PREFIX) for k in state.keys()) + has_docs = any(k.startswith(SKILL_DOCS_STATE_KEY_PREFIX) for k in state.keys()) + if not has_loaded and not has_docs: + return + + owners: dict[str, str] | None = None + delta: dict[str, Any] = {} + + for key, value in state.items(): + if value is None or value == "": + continue + if key.startswith(SKILL_LOADED_STATE_KEY_PREFIX): + if owners is None: + owners = _legacy_skill_owners(getattr(ctx.session, "events", [])) + name = key[len(SKILL_LOADED_STATE_KEY_PREFIX):].strip() + _migrate_legacy_state_key(ctx, state, delta, key, value, name, owners, loaded_key) + elif key.startswith(SKILL_DOCS_STATE_KEY_PREFIX): + if owners is None: + owners = _legacy_skill_owners(getattr(ctx.session, "events", [])) + name = key[len(SKILL_DOCS_STATE_KEY_PREFIX):].strip() + _migrate_legacy_state_key(ctx, state, delta, key, value, name, owners, docs_key) + + if delta: + ctx.actions.state_delta.update(delta) diff --git a/trpc_agent_sdk/skills/_state_order.py b/trpc_agent_sdk/skills/_state_order.py new file mode 100644 index 0000000..8cae773 --- /dev/null +++ b/trpc_agent_sdk/skills/_state_order.py @@ -0,0 +1,82 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Loaded-skill order helpers. 
+""" + +from __future__ import annotations + +import json +from typing import Any + + +def parse_loaded_order(raw: Any) -> list[str]: + """Parse a stored loaded-order payload. + + Returns a normalized list of unique non-empty skill names. + """ + if raw is None: + return [] + if isinstance(raw, bytes): + if not raw: + return [] + try: + raw = raw.decode("utf-8") + except UnicodeDecodeError: + return [] + if isinstance(raw, str): + if not raw: + return [] + try: + raw = json.loads(raw) + except json.JSONDecodeError: + return [] + if not isinstance(raw, list): + return [] + return _normalize_loaded_order(raw) + + +def marshal_loaded_order(names: list[str]) -> str: + """Serialize a normalized loaded-order payload.""" + normalized = _normalize_loaded_order(names) + if not normalized: + return "" + return json.dumps(normalized, ensure_ascii=False) + + +def touch_loaded_order(names: list[str], *touched: str) -> list[str]: + """Move touched skills to the tail of the loaded order.""" + order = _normalize_loaded_order(names) + for name in touched: + candidate = (name or "").strip() + if not candidate: + continue + order = _remove_loaded_order_name(order, candidate) + order.append(candidate) + return order + + +def _normalize_loaded_order(names: list[Any]) -> list[str]: + if not names: + return [] + out: list[str] = [] + seen: set[str] = set() + for name in names: + if not isinstance(name, str): + continue + candidate = name.strip() + if not candidate or candidate in seen: + continue + out.append(candidate) + seen.add(candidate) + return out + + +def _remove_loaded_order_name(order: list[str], target: str) -> list[str]: + for i, name in enumerate(order): + if name != target: + continue + return order[:i] + order[i + 1:] + return order diff --git a/trpc_agent_sdk/skills/_toolset.py b/trpc_agent_sdk/skills/_toolset.py index 024a1db..f276322 100644 --- a/trpc_agent_sdk/skills/_toolset.py +++ b/trpc_agent_sdk/skills/_toolset.py @@ -12,33 +12,44 @@ from __future__ import 
annotations from typing import Any -from typing import Callable from typing import List from typing import Optional from typing import Union from typing_extensions import override -from trpc_agent_sdk.abc import ToolABC -from trpc_agent_sdk.abc import ToolPredicate from trpc_agent_sdk.abc import ToolSetABC +from trpc_agent_sdk.abc import ToolPredicate +from trpc_agent_sdk.abc import ToolABC from trpc_agent_sdk.context import InvocationContext -from trpc_agent_sdk.log import logger +from trpc_agent_sdk.context import get_invocation_ctx from trpc_agent_sdk.tools import FunctionTool +from trpc_agent_sdk.log import logger from ._constants import SKILL_REGISTRY_KEY from ._constants import SKILL_REPOSITORY_KEY +from ._repository import FsSkillRepository +from ._repository import BaseSkillRepository from ._registry import SKILL_REGISTRY from ._registry import SkillToolFunction -from ._repository import BaseSkillRepository -from ._repository import FsSkillRepository -from .tools import SkillExecTool -from .tools import SkillRunTool -from .tools import skill_list +from ._skill_config import DEFAULT_SKILL_CONFIG +from ._skill_config import set_skill_config +from ._skill_config import is_exist_skill_config from .tools import skill_list_docs from .tools import skill_list_tools -from .tools import skill_load +from .tools import SkillLoadTool from .tools import skill_select_docs from .tools import skill_select_tools +from .tools import skill_list +from .tools import SkillExecTool +from .tools import SkillRunTool +from .tools import SaveArtifactTool +from .tools import WorkspaceExecTool +from .tools import WorkspaceWriteStdinTool +from .tools import WorkspaceKillSessionTool +from .tools import CreateWorkspaceNameCallback +from .tools import default_create_ws_name_callback +from .tools import CopySkillStager +from .stager import Stager class SkillToolSet(ToolSetABC): @@ -58,31 +69,63 @@ class SkillToolSet(ToolSetABC): def __init__(self, paths: Optional[List[str]] = None, 
repository: BaseSkillRepository = None, + enable_hot_reload: bool = False, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, is_include_all_tools: bool = True, + create_ws_name_cb: Optional[CreateWorkspaceNameCallback] = None, + runtime_tools: Optional[List[ToolABC]] = None, + skill_stager: Optional[Stager] = None, + skill_config: Optional[dict[str, Any]] = None, **run_tool_kwargs: dict[str, Any]): """Initialize the skill toolset. Args: paths: Optional list of skill paths. If None, will create a new one. repository: Skill repository. If None, will be retrieved from context metadata. + enable_hot_reload: Whether to enable skill hot reload checks for + auto-created repositories. tool_filter: Optional tool filter. If None, will include all tools. is_include_all_tools: Optional flag to include all tools. If True, will include all tools. + user_tools: Optional list of user tools. If None, will not include any user tools. run_tool_kwargs: Optional keyword arguments for skill run tool. If None, will use default values. 
""" super().__init__(tool_filter=tool_filter, is_include_all_tools=is_include_all_tools) self.name = "skill_toolset" - self._repository = repository or FsSkillRepository(*(paths or [])) - self._run_tool = SkillRunTool(repository=self._repository, **run_tool_kwargs) - self._exec_tool = SkillExecTool(run_tool=self._run_tool) - self._tools: List[Callable] = [ - skill_load, + + self._repository = repository or FsSkillRepository( + *(paths or []), + enable_hot_reload=enable_hot_reload, + ) + self._skill_config = skill_config or DEFAULT_SKILL_CONFIG + self._create_ws_name_cb = create_ws_name_cb or default_create_ws_name_callback + self._skill_stager = skill_stager or CopySkillStager() + self._load_tool = SkillLoadTool(repository=self._repository, + skill_stager=self._skill_stager, + create_ws_name_cb=self._create_ws_name_cb) + self._run_tool = SkillRunTool(repository=self._repository, + create_ws_name_cb=self._create_ws_name_cb, + skill_stager=self._skill_stager, + **run_tool_kwargs) + self._exec_tool = SkillExecTool(run_tool=self._run_tool, create_ws_name_cb=self._create_ws_name_cb) + self._function_tools: List[SkillToolFunction] = [ skill_list, skill_list_docs, skill_list_tools, skill_select_docs, skill_select_tools, ] + if runtime_tools: + self._runtime_tools = runtime_tools + else: + workspace_exec_tool = WorkspaceExecTool(workspace_runtime=self._repository.workspace_runtime, + create_ws_name_cb=self._create_ws_name_cb) + self._runtime_tools: List[ToolABC] = [ + SaveArtifactTool(workspace_runtime=self._repository.workspace_runtime, + create_ws_name_cb=self._create_ws_name_cb), + workspace_exec_tool, + WorkspaceWriteStdinTool(workspace_exec_tool), + WorkspaceKillSessionTool(workspace_exec_tool), + ] @property def repository(self) -> BaseSkillRepository: @@ -99,14 +142,21 @@ async def get_tools(self, invocation_context: Optional[InvocationContext] = None Returns: List of tools from all registered skills """ - tools: List[FunctionTool] = [] + tools: List[ToolABC] = [] 
skill_functions: List[SkillToolFunction] = SKILL_REGISTRY.get_all() - skill_functions.extend(self._tools) - agent_context = invocation_context.agent_context - agent_context.with_metadata(SKILL_REGISTRY_KEY, SKILL_REGISTRY) - agent_context.with_metadata(SKILL_REPOSITORY_KEY, self._repository) + skill_functions.extend(self._function_tools) + if not invocation_context: + invocation_context = get_invocation_ctx() + if invocation_context: + agent_context = invocation_context.agent_context + agent_context.with_metadata(SKILL_REGISTRY_KEY, SKILL_REGISTRY) + agent_context.with_metadata(SKILL_REPOSITORY_KEY, self._repository) + if not is_exist_skill_config(agent_context): + set_skill_config(agent_context, self._skill_config) + tools.append(self._load_tool) tools.append(self._run_tool) tools.append(self._exec_tool) + tools.extend(self._runtime_tools) for skill_function in skill_functions: try: tools.append(FunctionTool(func=skill_function)) diff --git a/trpc_agent_sdk/skills/_utils.py b/trpc_agent_sdk/skills/_utils.py index c9434d7..61bccc0 100644 --- a/trpc_agent_sdk/skills/_utils.py +++ b/trpc_agent_sdk/skills/_utils.py @@ -194,3 +194,60 @@ def get_state_delta(invocation_context: InvocationContext, key: str) -> Optional state = dict(invocation_context.session_state.copy()) state.update(invocation_context.actions.state_delta) return state.get(key, None) + + +FILE_EXTENSIONS_DOC = (".md", ".txt") + + +def is_doc_file(name: str) -> bool: + """Check if a file is a document file. 
+ + Args: + name: The name of the file + + Returns: + True if the file is a document file, False otherwise + """ + name_lower = name.lower() + return name_lower.endswith(FILE_EXTENSIONS_DOC) + + +_SCRIPT_FILE_EXTENSIONS = ( + # Shell + ".sh", + ".bash", + ".zsh", + ".fish", + # Python / Node + ".py", + ".pyw", + ".js", + ".mjs", + ".cjs", + ".ts", + # Other interpreted / script languages + ".rb", + ".pl", + ".php", + ".lua", + ".r", + ".ps1", + ".awk", + ".tcl", + ".groovy", + ".kts", + ".jl", +) + + +def is_script_file(name: str) -> bool: + """Check if a file is a script file. + + Args: + name: The name of the file + + Returns: + True if the file is a script file, False otherwise + """ + name_lower = name.lower() + return name_lower.endswith(_SCRIPT_FILE_EXTENSIONS) diff --git a/trpc_agent_sdk/skills/stager/_base_stager.py b/trpc_agent_sdk/skills/stager/_base_stager.py index 8bfffe3..3119193 100644 --- a/trpc_agent_sdk/skills/stager/_base_stager.py +++ b/trpc_agent_sdk/skills/stager/_base_stager.py @@ -27,6 +27,7 @@ from trpc_agent_sdk.code_executors import WorkspaceRunProgramSpec from trpc_agent_sdk.code_executors import WorkspaceStageOptions from trpc_agent_sdk.context import InvocationContext +from trpc_agent_sdk.log import logger from .._types import SkillMetadata from .._types import SkillWorkspaceMetadata @@ -46,12 +47,14 @@ class Stager: """Materializes skill package contents into a workspace. - - Mirrors Go's ``internal/skillstage.Stager``. Create an instance with - ``Stager()`` (or the module-level helper :func:`new`) and call - :meth:`stage_skill` from async tool code. + Stager is responsible for staging the skill package contents into a workspace. """ + def __init__(self) -> None: + # Deduplicate noisy link warnings within one process lifetime. + # Key format: "||". 
+ self._link_error_warned_keys: set[str] = set() + # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ @@ -76,7 +79,7 @@ async def stage_skill(self, request: SkillStageRequest) -> SkillStageResult: ctx = request.ctx ws = request.workspace root = request.repository.path(request.skill_name) - runtime = request.engine or request.repository.workspace_runtime + runtime = request.repository.workspace_runtime name = request.skill_name digest = compute_dir_digest(root) md = await self.load_workspace_metadata(ctx, runtime, ws) @@ -277,7 +280,17 @@ async def _link_workspace_dirs( cmd = (f"set -e; cd {shell_quote(skill_root)}" f"; rm -rf out work {_SKILL_DIR_INPUTS} {shell_quote(_SKILL_DIR_VENV)}" - f"; mkdir -p {shell_quote(to_inputs)} {shell_quote(_SKILL_DIR_VENV)}" + f"; if [ -L {shell_quote(to_work)} ]; then" + f" rm -rf {shell_quote(to_work)}; fi" + f"; if [ -e {shell_quote(to_work)} ] && [ ! -d {shell_quote(to_work)} ]; then" + f" rm -rf {shell_quote(to_work)}; fi" + f"; if [ ! -d {shell_quote(to_work)} ]; then mkdir -p {shell_quote(to_work)}; fi" + f"; if [ -L {shell_quote(to_inputs)} ]; then" + f" rm -rf {shell_quote(to_inputs)}; fi" + f"; if [ -e {shell_quote(to_inputs)} ] && [ ! -d {shell_quote(to_inputs)} ]; then" + f" rm -rf {shell_quote(to_inputs)}; fi" + f"; if [ ! 
-d {shell_quote(to_inputs)} ]; then mkdir -p {shell_quote(to_inputs)}; fi" + f"; mkdir -p {shell_quote(_SKILL_DIR_VENV)}" f"; ln -sfn {shell_quote(to_out)} out" f"; ln -sfn {shell_quote(to_work)} work" f"; ln -sfn {shell_quote(to_inputs)} {_SKILL_DIR_INPUTS}") @@ -294,8 +307,18 @@ async def _link_workspace_dirs( ctx, ) if ret.exit_code != 0: - from trpc_agent_sdk.log import logger # noqa: PLC0415 - logger.info("Stager._link_workspace_dirs failed for %r: %s", name, ret.stderr) + inv_id = ctx.invocation_id + err = (ret.stderr or "").strip() + dedupe_key = f"{inv_id}|{name}|{err}" + if dedupe_key in self._link_error_warned_keys: + logger.debug("Stager._link_workspace_dirs retry failed for %r: %s", name, ret.stderr) + return + + self._link_error_warned_keys.add(dedupe_key) + # Keep the set bounded for long-lived processes. + if len(self._link_error_warned_keys) > 2000: + self._link_error_warned_keys.clear() + logger.warning("Stager._link_workspace_dirs failed for %r: %s", name, ret.stderr) async def _read_only_except_symlinks( self, @@ -327,7 +350,6 @@ async def _read_only_except_symlinks( ctx, ) if ret.exit_code != 0: - from trpc_agent_sdk.log import logger # noqa: PLC0415 logger.info("Stager._read_only_except_symlinks failed for %r: %s", dest, ret.stderr) @classmethod diff --git a/trpc_agent_sdk/skills/stager/_types.py b/trpc_agent_sdk/skills/stager/_types.py index fec0411..0d599cc 100644 --- a/trpc_agent_sdk/skills/stager/_types.py +++ b/trpc_agent_sdk/skills/stager/_types.py @@ -6,7 +6,6 @@ """Skill staging types.""" from dataclasses import dataclass -from typing import Optional from trpc_agent_sdk.code_executors import BaseWorkspaceRuntime from trpc_agent_sdk.code_executors import WorkspaceInfo @@ -17,10 +16,7 @@ @dataclass class SkillStageRequest: - """Describes the skill staging context for one run. - - Mirrors Go's ``SkillStageRequest`` in ``tool/skill/stager.go``. 
- """ + """Describes the skill staging context for one run.""" skill_name: str """Name of the skill to stage.""" @@ -34,11 +30,6 @@ class SkillStageRequest: ctx: InvocationContext """Invocation context (used for workspace FS/runner access).""" - engine: Optional[BaseWorkspaceRuntime] = None - """Explicit workspace runtime (Go: ``Engine``). - When ``None`` the runtime is obtained from ``repository.workspace_runtime``. - """ - timeout: float = 300.0 """Timeout in seconds for internal staging helpers.""" diff --git a/trpc_agent_sdk/skills/tools/__init__.py b/trpc_agent_sdk/skills/tools/__init__.py index 9ecf8a5..a6fc68d 100644 --- a/trpc_agent_sdk/skills/tools/__init__.py +++ b/trpc_agent_sdk/skills/tools/__init__.py @@ -1,64 +1,40 @@ -# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# -*- coding: utf-8 -*- # -# Copyright (C) 2026 Tencent. All rights reserved. -# -# tRPC-Agent-Python is licensed under Apache-2.0. +# Copyright @ 2025 Tencent.com """Skill tools package.""" +from ._common import CreateWorkspaceNameCallback +from ._common import default_create_ws_name_callback from ._copy_stager import CopySkillStager -from ._copy_stager import normalize_workspace_skill_dir -from ._skill_exec import ExecInput -from ._skill_exec import ExecOutput -from ._skill_exec import KillSessionInput -from ._skill_exec import KillSessionTool -from ._skill_exec import PollSessionInput -from ._skill_exec import PollSessionTool -from ._skill_exec import SessionInteraction -from ._skill_exec import SessionKillOutput +from ._save_artifact import SaveArtifactTool from ._skill_exec import SkillExecTool -from ._skill_exec import WriteStdinInput -from ._skill_exec import WriteStdinTool -from ._skill_exec import create_exec_tools from ._skill_list import skill_list from ._skill_list_docs import skill_list_docs from ._skill_list_tool import skill_list_tools -from ._skill_load import skill_load -from ._skill_run import ArtifactInfo -from 
._skill_run import SkillRunFile -from ._skill_run import SkillRunInput -from ._skill_run import SkillRunOutput +from ._skill_load import SkillLoadTool from ._skill_run import SkillRunTool -from ._skill_select_docs import SkillSelectDocsResult from ._skill_select_docs import skill_select_docs -from ._skill_select_tools import SkillSelectToolsResult from ._skill_select_tools import skill_select_tools +from ._workspace_exec import WorkspaceExecTool +from ._workspace_exec import WorkspaceKillSessionTool +from ._workspace_exec import WorkspaceWriteStdinTool +from ._workspace_exec import create_workspace_exec_tools __all__ = [ + "CreateWorkspaceNameCallback", + "default_create_ws_name_callback", "CopySkillStager", - "normalize_workspace_skill_dir", - "ExecInput", - "ExecOutput", - "KillSessionInput", - "KillSessionTool", - "PollSessionInput", - "PollSessionTool", - "SessionInteraction", - "SessionKillOutput", + "SaveArtifactTool", "SkillExecTool", - "WriteStdinInput", - "WriteStdinTool", - "create_exec_tools", "skill_list", "skill_list_docs", "skill_list_tools", - "skill_load", - "ArtifactInfo", - "SkillRunFile", - "SkillRunInput", - "SkillRunOutput", + "SkillLoadTool", "SkillRunTool", - "SkillSelectDocsResult", "skill_select_docs", - "SkillSelectToolsResult", "skill_select_tools", + "WorkspaceExecTool", + "WorkspaceKillSessionTool", + "WorkspaceWriteStdinTool", + "create_workspace_exec_tools", ] diff --git a/trpc_agent_sdk/skills/tools/_common.py b/trpc_agent_sdk/skills/tools/_common.py new file mode 100644 index 0000000..2cd6bf3 --- /dev/null +++ b/trpc_agent_sdk/skills/tools/_common.py @@ -0,0 +1,159 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+"""Shared helpers for skill/workspace tools.""" + +from __future__ import annotations + +import asyncio +import inspect +import time +from typing import Any +from typing import Awaitable +from typing import Callable +from typing import TypeVar + +from trpc_agent_sdk.context import InvocationContext + +T = TypeVar("T") + +CreateWorkspaceNameCallback = Callable[[InvocationContext], str] +"""Callback to create a workspace name.""" + + +def default_create_ws_name_callback(ctx: InvocationContext) -> str: + """Default callback to create a workspace name.""" + return ctx.session.id + + +def require_non_empty(value: str, *, field_name: str) -> str: + """Validate a required string field and return trimmed value.""" + normalized = (value or "").strip() + if not normalized: + raise ValueError(f"{field_name} is required") + return normalized + + +async def put_session( + sessions: dict[str, T], + lock: asyncio.Lock, + sid: str, + session: T, + gc: Callable[[], Awaitable[None]], +) -> None: + """Insert session after running gc in lock.""" + async with lock: + await gc() + sessions[sid] = session + + +async def get_session( + sessions: dict[str, T], + lock: asyncio.Lock, + sid: str, + gc: Callable[[], Awaitable[None]], +) -> T: + """Lookup session after running gc in lock.""" + async with lock: + await gc() + session = sessions.get(sid) + if session is None: + raise ValueError(f"unknown session_id: {sid}") + return session + + +async def remove_session( + sessions: dict[str, T], + lock: asyncio.Lock, + sid: str, + gc: Callable[[], Awaitable[None]], +) -> T: + """Remove session after running gc in lock.""" + async with lock: + await gc() + session = sessions.pop(sid, None) + if session is None: + raise ValueError(f"unknown session_id: {sid}") + return session + + +async def _await_if_needed(value: object) -> None: + if inspect.isawaitable(value): + await value + + +async def cleanup_expired_sessions( + sessions: dict[str, T], + *, + ttl: float, + refresh_exit_state: Callable[[T, 
float], object], + close_session: Callable[[T], object], +) -> None: + """Refresh exit state and evict expired sessions in-place.""" + if ttl <= 0: + return + now = time.time() + expired: list[str] = [] + for sid, session in sessions.items(): + await _await_if_needed(refresh_exit_state(session, now)) + exited_at = getattr(session, "exited_at", None) + if exited_at is not None and (now - exited_at) >= ttl: + expired.append(sid) + + for sid in expired: + session = sessions.pop(sid, None) + if session is None: + continue + try: + await _await_if_needed(close_session(session)) + except Exception: # pylint: disable=broad-except + # Best-effort cleanup: mirror Go behavior and keep evicting others. + pass + + +def inline_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]: + """Inline $ref references in JSON Schema by replacing them with actual definitions.""" + defs = schema.get('$defs', {}) + if not defs: + return schema + + def resolve_ref(obj: Any) -> Any: + if isinstance(obj, dict): + if '$ref' in obj: + ref_path = obj['$ref'] + if ref_path.startswith('#/$defs/'): + ref_name = ref_path.replace('#/$defs/', '') + if ref_name in defs: + resolved = resolve_ref(defs[ref_name]) + merged = {**resolved, **{k: v for k, v in obj.items() if k != '$ref'}} + return merged + return obj + else: + return {k: resolve_ref(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [resolve_ref(item) for item in obj] + else: + return obj + + result = {k: v for k, v in schema.items() if k != '$defs'} + result = resolve_ref(result) + return result + + +SKILL_STAGED_WORKSPACE_DIR_KEY = "__trpc_agent_skills_staged_workspace_dir" +"""Key for staged workspace directory.""" + + +def get_staged_workspace_dir(ctx: InvocationContext, skill_name: str) -> str: + """Get the staged workspace directory.""" + dir_map = ctx.agent_context.get_metadata(SKILL_STAGED_WORKSPACE_DIR_KEY, {}) + return dir_map.get(skill_name, "") + + +def set_staged_workspace_dir(ctx: InvocationContext, 
skill_name: str, dir: str) -> None: + """Set the staged workspace directory.""" + dir_map = ctx.agent_context.get_metadata(SKILL_STAGED_WORKSPACE_DIR_KEY, {}) + dir_map[skill_name] = dir + ctx.agent_context.with_metadata(SKILL_STAGED_WORKSPACE_DIR_KEY, dir_map) diff --git a/trpc_agent_sdk/skills/tools/_copy_stager.py b/trpc_agent_sdk/skills/tools/_copy_stager.py index 534f616..0ca5288 100644 --- a/trpc_agent_sdk/skills/tools/_copy_stager.py +++ b/trpc_agent_sdk/skills/tools/_copy_stager.py @@ -22,17 +22,8 @@ from ..stager import SkillStageResult from ..stager import Stager -# --------------------------------------------------------------------------- -# Error messages (mirrors Go const block in tool/skill/stager.go) -# --------------------------------------------------------------------------- - -_ERR_STAGER_NOT_CONFIGURED = "skill stager is not configured" _ERR_REPO_NOT_CONFIGURED = "skill repository is not configured" -# --------------------------------------------------------------------------- -# Allowed workspace roots (mirrors Go isAllowedWorkspacePath) -# --------------------------------------------------------------------------- - _ALLOWED_WS_ROOTS = ( DIR_SKILLS, # "skills" DIR_WORK, # "work" @@ -40,17 +31,12 @@ DIR_RUNS, # "runs" ) -# --------------------------------------------------------------------------- -# Path normalization helpers -# (mirrors normalizeWorkspaceSkillDir / normalizeSkillStageResult) -# --------------------------------------------------------------------------- - def normalize_workspace_skill_dir(raw: str) -> str: """Normalize and validate a workspace-relative skill directory. Mirrors Go's ``normalizeWorkspaceSkillDir``. Strips leading slashes, - normalizes path separators, and ensures the result stays within a + normalises path separators, and ensures the result stays within a known workspace root. 
Raises: @@ -74,7 +60,7 @@ def normalize_workspace_skill_dir(raw: str) -> str: def _normalize_skill_stage_result(result: SkillStageResult) -> SkillStageResult: - """Return *result* with :attr:`workspace_skill_dir` normalized. + """Return *result* with :attr:`workspace_skill_dir` normalised. Mirrors Go's ``normalizeSkillStageResult``. @@ -84,39 +70,15 @@ def _normalize_skill_stage_result(result: SkillStageResult) -> SkillStageResult: return SkillStageResult(workspace_skill_dir=normalize_workspace_skill_dir(result.workspace_skill_dir)) -# --------------------------------------------------------------------------- -# Default copy-based stager (mirrors Go copySkillStager) -# --------------------------------------------------------------------------- - - class CopySkillStager(Stager): """Default stager: copies the skill directory into ``skills/``. - Mirrors Go's ``copySkillStager`` struct. Holds an optional - back-reference to the owning ``SkillRunTool`` (``run_tool``), matching - the Go pattern of ``copySkillStager{tool: tool}``. - The actual filesystem work — digest check, ``stage_directory``, symlink - creation, read-only chmod — is delegated to - :class:`~trpc_agent_sdk.skills.stager.Stager`, which mirrors - :class:`~trpc_agent_sdk.skills.stager.Stager`. - - Construct via :func:`new_copy_skill_stager` or - :class:`~trpc_agent_sdk.skills.tools.CopySkillStager`. + creation, read-only chmod — is delegated to :class:`~trpc_agent.skills.stager.Stager`. """ async def stage_skill(self, request: SkillStageRequest) -> SkillStageResult: - """Stage the skill and return the normalized workspace skill dir. - - Mirrors Go's ``copySkillStager.StageSkill``: - - 1. Validate repository is present. - 2. Resolve the on-disk skill root via the repository. - 3. Resolve the runtime (``request.engine`` takes priority, falls back - to ``repository.workspace_runtime``). - 4. Delegate copy / link / chmod to :class:`~trpc_agent_sdk.skills.stager.Stager`. - 5. 
Return a normalized :class:`SkillStageResult`. - """ + """Stage the skill and return the normalised workspace skill dir.""" if request.repository is None: raise ValueError(_ERR_REPO_NOT_CONFIGURED) diff --git a/trpc_agent_sdk/skills/tools/_save_artifact.py b/trpc_agent_sdk/skills/tools/_save_artifact.py new file mode 100644 index 0000000..17cb70c --- /dev/null +++ b/trpc_agent_sdk/skills/tools/_save_artifact.py @@ -0,0 +1,262 @@ +# -*- coding: utf-8 -*- +# +# Copyright @ 2025 Tencent.com +"""Workspace artifact save tool. + +This tool persists an existing file from the current workspace as an artifact. +""" + +from __future__ import annotations + +import mimetypes +import os +import posixpath +from typing import Any +from typing import Optional + +from trpc_agent_sdk.code_executors import BaseWorkspaceRuntime +from trpc_agent_sdk.code_executors import DIR_OUT +from trpc_agent_sdk.code_executors import DIR_RUNS +from trpc_agent_sdk.code_executors import DIR_WORK +from trpc_agent_sdk.code_executors import WORKSPACE_ENV_DIR_KEY +from trpc_agent_sdk.code_executors.utils import normalize_globs +from trpc_agent_sdk.context import InvocationContext +from trpc_agent_sdk.filter import BaseFilter +from trpc_agent_sdk.tools import BaseTool +from trpc_agent_sdk.types import FunctionDeclaration +from trpc_agent_sdk.types import Part +from trpc_agent_sdk.types import Schema +from trpc_agent_sdk.types import Type + +from .._constants import SKILL_ARTIFACTS_STATE_KEY +from ._common import CreateWorkspaceNameCallback +from ._common import default_create_ws_name_callback + +_DEFAULT_MAX_BYTES = 64 * 1024 * 1024 +_ALLOWED_ROOTS = (DIR_WORK, DIR_OUT, DIR_RUNS) +_SAVE_REASON_NO_SERVICE = "artifact service is not configured" +_SAVE_REASON_NO_SESSION = "session is missing from invocation context" +_SAVE_REASON_NO_SESSION_IDS = "session app/user/session IDs are missing" + + +def _has_glob_meta(s: str) -> bool: + return any(ch in s for ch in ("*", "?", "[")) + + +def 
_normalize_workspace_prefix(path: str) -> str: + """Normalize workspace env-style prefixes in path.""" + s = path.strip().replace("\\", "/") + if s.startswith("workspace://"): + s = s[len("workspace://"):] + replacements = ( + ("${WORKSPACE_DIR}/", ""), + ("$WORKSPACE_DIR/", ""), + ("${WORK_DIR}/", f"{DIR_WORK}/"), + ("$WORK_DIR/", f"{DIR_WORK}/"), + ("${OUTPUT_DIR}/", f"{DIR_OUT}/"), + ("$OUTPUT_DIR/", f"{DIR_OUT}/"), + ("${RUN_DIR}/", f"{DIR_RUNS}/"), + ("$RUN_DIR/", f"{DIR_RUNS}/"), + ) + for src, dst in replacements: + if s.startswith(src): + return dst + s[len(src):] + if s in ("$WORKSPACE_DIR", "${WORKSPACE_DIR}"): + return "" + if s in ("$WORK_DIR", "${WORK_DIR}"): + return DIR_WORK + if s in ("$OUTPUT_DIR", "${OUTPUT_DIR}"): + return DIR_OUT + if s in ("$RUN_DIR", "${RUN_DIR}"): + return DIR_RUNS + return s + + +def _is_workspace_env_path(path: str) -> bool: + s = path.strip() + if not s: + return False + return s.startswith("$") or s.startswith("${") + + +def _is_allowed_publish_path(rel: str) -> bool: + return any(rel == root or rel.startswith(f"{root}/") for root in _ALLOWED_ROOTS) + + +def _artifact_save_skip_reason(ctx: InvocationContext) -> str: + if ctx.artifact_service is None: + return _SAVE_REASON_NO_SERVICE + if ctx.session is None: + return _SAVE_REASON_NO_SESSION + if not (ctx.app_name and ctx.user_id and ctx.session_id): + return _SAVE_REASON_NO_SESSION_IDS + return "" + + +def _apply_artifact_state_delta(ctx: InvocationContext, saved_as: str, version: int, ref: str) -> None: + tool_call_id = (ctx.function_call_id or "").strip() + if not tool_call_id or not saved_as or version < 0: + return + artifact_ref = ref.strip() or f"artifact://{saved_as}@{version}" + ctx.actions.state_delta[SKILL_ARTIFACTS_STATE_KEY] = { + "tool_call_id": tool_call_id, + "artifacts": [{ + "name": saved_as, + "version": version, + "ref": artifact_ref, + }], + } + + +def _resolve_workspace_root_from_config(ctx: InvocationContext) -> str: + """Resolve workspace root from 
run_config/env; fallback to cwd.""" + if ctx.run_config and isinstance(ctx.run_config.custom_data, dict): + for k in ("workspace_dir", "workspace_root", "workspace_path"): + v = ctx.run_config.custom_data.get(k) + if isinstance(v, str) and v.strip(): + return os.path.abspath(v.strip()) + env_root = os.environ.get(WORKSPACE_ENV_DIR_KEY, "").strip() + if env_root: + return os.path.abspath(env_root) + return os.path.abspath(os.getcwd()) + + +def _normalize_artifact_path(raw: str, workspace_root: str) -> tuple[str, str]: + """Return (workspace-relative path, absolute file path).""" + s = _normalize_workspace_prefix(raw) + if not s: + raise ValueError("path is required") + if _has_glob_meta(s): + raise ValueError("path must not contain glob patterns") + if _is_workspace_env_path(s): + out = normalize_globs([s.replace("${RUN_DIR}", DIR_RUNS).replace("$RUN_DIR", DIR_RUNS)]) + if not out: + raise ValueError("invalid path") + s = out[0] + + cleaned = posixpath.normpath(s) + if posixpath.isabs(cleaned): + rel = cleaned.lstrip("/") + rel = posixpath.normpath(rel) + if rel in ("", ".", ".."): + raise ValueError("path must point to a file inside the workspace") + else: + rel = cleaned + if rel in (".", "..") or rel.startswith("../"): + raise ValueError("path must stay within the workspace") + + if not _is_allowed_publish_path(rel): + raise ValueError("path must stay under work/, out/, or runs/") + + abs_path = os.path.abspath(os.path.join(workspace_root, rel)) + rel_check = os.path.relpath(abs_path, workspace_root).replace("\\", "/") + if rel_check in (".", "..") or rel_check.startswith("../"): + raise ValueError("path must stay within the workspace") + + rel = posixpath.normpath(rel) + return rel, abs_path + + +class SaveArtifactTool(BaseTool): + """Persist an existing workspace file as an artifact.""" + + def __init__( + self, + max_file_bytes: int = _DEFAULT_MAX_BYTES, + workspace_runtime: Optional[BaseWorkspaceRuntime] = None, + create_ws_name_cb: 
Optional[CreateWorkspaceNameCallback] = None, + filters_name: Optional[list[str]] = None, + filters: Optional[list[BaseFilter]] = None, + ): + super().__init__( + name="workspace_save_artifact", + description=("Save an existing file from the current workspace as an artifact. " + "Path must be under work/, out/, or runs/."), + filters_name=filters_name, + filters=filters, + ) + self._max_file_bytes = max_file_bytes + self._workspace_runtime = workspace_runtime + self._create_ws_name_cb = create_ws_name_cb or default_create_ws_name_callback + + async def _resolve_workspace_root(self, ctx: InvocationContext) -> str: + """Resolve workspace root, preferring the shared workspace_exec workspace.""" + runtime = self._workspace_runtime + if runtime is not None: + workspace_id = self._create_ws_name_cb(ctx) + ws = await runtime.manager(ctx).create_workspace(workspace_id, ctx) + if ws.path: + return os.path.abspath(ws.path) + return _resolve_workspace_root_from_config(ctx) + + def _get_declaration(self) -> Optional[FunctionDeclaration]: + return FunctionDeclaration( + name="workspace_save_artifact", + description=("Save an existing file from the current workspace as an artifact. " + "Use this to get a stable artifact:// reference for files under " + "work/, out/, or runs/."), + parameters=Schema( + type=Type.OBJECT, + required=["path"], + properties={ + "path": + Schema( + type=Type.STRING, + description=("Workspace-relative file path to save. 
" + "Supports prefixes like $WORK_DIR/, $OUTPUT_DIR/, " + "$RUN_DIR/, and workspace://."), + ), + }, + ), + response=Schema( + type=Type.OBJECT, + required=["path", "saved_as", "version", "ref", "size_bytes"], + properties={ + "path": Schema(type=Type.STRING, description="Workspace-relative source path."), + "saved_as": Schema(type=Type.STRING, description="Artifact name used when saving."), + "version": Schema(type=Type.INTEGER, description="Artifact version."), + "ref": Schema(type=Type.STRING, description="artifact:// reference for saved artifact."), + "mime_type": Schema(type=Type.STRING, description="Detected MIME type."), + "size_bytes": Schema(type=Type.INTEGER, description="File size in bytes."), + }, + ), + ) + + async def _run_async_impl(self, *, tool_context: InvocationContext, args: dict[str, Any]) -> Any: + raw_path = str(args.get("path", "")).strip() + if not raw_path: + return {"error": "INVALID_PARAMETER: path is required"} + reason = _artifact_save_skip_reason(tool_context) + if reason: + return {"error": f"ARTIFACT_SAVE_UNAVAILABLE: {reason}"} + + try: + workspace_root = await self._resolve_workspace_root(tool_context) + rel, abs_path = _normalize_artifact_path(raw_path, workspace_root) + except ValueError as ex: + return {"error": f"INVALID_PARAMETER: {str(ex)}"} + + if not os.path.exists(abs_path): + return {"error": f"FILE_NOT_FOUND: workspace artifact file not found: {rel}"} + if not os.path.isfile(abs_path): + return {"error": f"INVALID_PATH: path is not a file: {rel}"} + + size_bytes = os.path.getsize(abs_path) + if size_bytes > self._max_file_bytes: + return {"error": f"FILE_TOO_LARGE: file exceeds {self._max_file_bytes} bytes limit"} + + with open(abs_path, "rb") as f: + data = f.read() + + mime_type = mimetypes.guess_type(abs_path)[0] or "application/octet-stream" + version = await tool_context.save_artifact(rel, Part.from_bytes(data=data, mime_type=mime_type)) + ref = f"artifact://{rel}@{version}" + 
_apply_artifact_state_delta(tool_context, rel, version, ref) + return { + "path": rel, + "saved_as": rel, + "version": version, + "ref": ref, + "mime_type": mime_type, + "size_bytes": size_bytes, + } diff --git a/trpc_agent_sdk/skills/tools/_skill_exec.py b/trpc_agent_sdk/skills/tools/_skill_exec.py index 7e4cfe6..6a0e511 100644 --- a/trpc_agent_sdk/skills/tools/_skill_exec.py +++ b/trpc_agent_sdk/skills/tools/_skill_exec.py @@ -5,15 +5,13 @@ # tRPC-Agent-Python is licensed under Apache-2.0. """Interactive skill execution tools. -Provides four tools that mirror the Go :mod:`~trpc_agent_sdk.skills.tools.SkillExecTool` implementation: - -* :class:`~trpc_agent_sdk.skills.tools.SkillExecTool` — start an interactive session -* :class:`~trpc_agent_sdk.skills.tools.WriteStdinTool` — write stdin to a running session -* :class:`~trpc_agent_sdk.skills.tools.PollSessionTool` — poll a session for new output -* :class:`~trpc_agent_sdk.skills.tools.KillSessionTool` — terminate and remove a session +* ``skill_exec`` — start an interactive session (SkillExecTool) +* ``skill_write_stdin`` — write stdin to a running session (WriteStdinTool) +* ``skill_poll_session`` — poll a session for new output (PollSessionTool) +* ``skill_kill_session`` — terminate and remove a session (KillSessionTool) Sessions run real sub-processes inside the staged skill workspace. When -:attr:`~trpc_agent_sdk.skills.tools.ExecInput.tty` is ``True`` a POSIX pseudo-terminal is allocated so TTY-aware programs work +``tty=True`` a POSIX pseudo-terminal is allocated so TTY-aware programs work correctly (e.g. interactive shells, ncurses UIs). 
Usage example:: @@ -22,7 +20,7 @@ result = await skill_exec_tool.run(ctx, { "skill": "my_skill", "command": "python interactive.py", - "yield_ms": 500, + "yield_time_ms": 500, }) sid = result["session_id"] @@ -46,6 +44,7 @@ import uuid from dataclasses import dataclass from dataclasses import field +from datetime import datetime from pathlib import Path from typing import Any from typing import Dict @@ -55,13 +54,29 @@ from pydantic import BaseModel from pydantic import Field +from trpc_agent_sdk.code_executors import BaseProgramRunner +from trpc_agent_sdk.code_executors import BaseProgramSession +from trpc_agent_sdk.code_executors import DEFAULT_EXEC_YIELD_MS +from trpc_agent_sdk.code_executors import DEFAULT_IO_YIELD_MS +from trpc_agent_sdk.code_executors import DEFAULT_SESSION_KILL_SEC +from trpc_agent_sdk.code_executors import DEFAULT_SESSION_TTL_SEC from trpc_agent_sdk.code_executors import DIR_OUT +from trpc_agent_sdk.code_executors import DIR_RUNS from trpc_agent_sdk.code_executors import DIR_SKILLS from trpc_agent_sdk.code_executors import DIR_WORK +from trpc_agent_sdk.code_executors import ENV_OUTPUT_DIR +from trpc_agent_sdk.code_executors import ENV_SKILLS_DIR from trpc_agent_sdk.code_executors import ENV_SKILL_NAME +from trpc_agent_sdk.code_executors import ENV_WORK_DIR +from trpc_agent_sdk.code_executors import PROGRAM_STATUS_EXITED +from trpc_agent_sdk.code_executors import WORKSPACE_ENV_DIR_KEY from trpc_agent_sdk.code_executors import WorkspaceInfo from trpc_agent_sdk.code_executors import WorkspaceInputSpec from trpc_agent_sdk.code_executors import WorkspaceOutputSpec +from trpc_agent_sdk.code_executors import WorkspaceRunProgramSpec +from trpc_agent_sdk.code_executors import poll_line_limit +from trpc_agent_sdk.code_executors import wait_for_program_output +from trpc_agent_sdk.code_executors import yield_duration_ms from trpc_agent_sdk.context import InvocationContext from trpc_agent_sdk.filter import BaseFilter from trpc_agent_sdk.log import 
logger @@ -69,24 +84,20 @@ from trpc_agent_sdk.types import FunctionDeclaration from trpc_agent_sdk.types import Schema +from .._constants import SKILL_ARTIFACTS_STATE_KEY +from ._common import CreateWorkspaceNameCallback +from ._common import cleanup_expired_sessions +from ._common import default_create_ws_name_callback +from ._common import inline_json_schema_refs +from ._common import require_non_empty from ._copy_stager import SkillStageRequest from ._skill_run import SkillRunInput from ._skill_run import SkillRunOutput from ._skill_run import SkillRunTool from ._skill_run import _filter_failed_empty_outputs -from ._skill_run import _inline_json_schema_refs from ._skill_run import _select_primary_output from ._skill_run import _truncate_output -# --------------------------------------------------------------------------- -# Defaults (mirrors Go program's session defaults) -# --------------------------------------------------------------------------- - -DEFAULT_EXEC_YIELD_MS: int = 300 # wait time on skill_exec -DEFAULT_IO_YIELD_MS: int = 100 # wait time on write/poll -DEFAULT_POLL_LINES: int = 50 # max lines returned per call -DEFAULT_SESSION_TTL: float = 300.0 # seconds after exit before GC - # Status strings _STATUS_RUNNING = "running" _STATUS_EXITED = "exited" @@ -109,7 +120,7 @@ class ExecInput(BaseModel): env: dict[str, str] = Field(default_factory=dict, description="Extra environment variables") stdin: str = Field(default="", description="Optional initial stdin written before yielding") tty: bool = Field(default=False, description="Allocate a pseudo-TTY") - yield_ms: int = Field(default=0, description="Milliseconds to wait for initial output before returning") + yield_time_ms: int = Field(default=0, description="Milliseconds to wait for initial output before returning") poll_lines: int = Field(default=0, description="Maximum output lines to return per call") output_files: list[str] = Field(default_factory=list, description="Glob patterns to collect on 
exit") timeout: int = Field(default=0, description="Timeout in seconds (0 = no timeout)") @@ -126,7 +137,7 @@ class WriteStdinInput(BaseModel): session_id: str = Field(..., description="Session id returned by skill_exec") chars: str = Field(default="", description="Text to write to stdin") submit: bool = Field(default=False, description="Append a newline after chars") - yield_ms: int = Field(default=0, description="Milliseconds to wait for new output") + yield_time_ms: int = Field(default=0, description="Milliseconds to wait for new output") poll_lines: int = Field(default=0, description="Maximum output lines to return") @@ -134,7 +145,7 @@ class PollSessionInput(BaseModel): """Input for skill_poll_session.""" session_id: str = Field(..., description="Session id returned by skill_exec") - yield_ms: int = Field(default=0, description="Milliseconds to wait for new output") + yield_time_ms: int = Field(default=0, description="Milliseconds to wait for new output") poll_lines: int = Field(default=0, description="Maximum output lines to return") @@ -173,6 +184,31 @@ class SessionKillOutput(BaseModel): status: str = Field(default="", description="Final status after kill") +def _apply_artifacts_state_delta(ctx: InvocationContext, output: ExecOutput) -> None: + """Store replayable artifact refs in state delta (Go execArtifactsStateDelta parity).""" + tool_call_id = (ctx.function_call_id or "").strip() + if not tool_call_id or output.result is None or not output.result.artifact_files: + return + + artifacts: list[dict[str, Any]] = [] + for item in output.result.artifact_files: + name = (item.name or "").strip() + version = int(item.version) + if not name or version < 0: + continue + artifacts.append({ + "name": name, + "version": version, + "ref": f"artifact://{name}@{version}", + }) + if not artifacts: + return + ctx.actions.state_delta[SKILL_ARTIFACTS_STATE_KEY] = { + "tool_call_id": tool_call_id, + "artifacts": artifacts, + } + + # 
--------------------------------------------------------------------------- # Internal session state # --------------------------------------------------------------------------- @@ -182,169 +218,31 @@ class SessionKillOutput(BaseModel): class _ExecSession: """Holds state for one running interactive skill session.""" - proc: asyncio.subprocess.Process + proc: BaseProgramSession ws: WorkspaceInfo in_data: ExecInput - # Output buffer (all output since start as raw bytes → decoded text) - _output_buf: list[str] = field(default_factory=list) - _output_lock: asyncio.Lock = field(default_factory=asyncio.Lock) - _output_event: asyncio.Event = field(default_factory=asyncio.Event) - - # Current byte offset for incremental reads - _read_offset: int = 0 - - # Background reader task - reader_task: Optional[asyncio.Task] = None - - # PTY master fd (None for non-TTY) - master_fd: Optional[int] = None - # Final state exit_code: Optional[int] = None exited_at: Optional[float] = None final_result: Optional[SkillRunOutput] = None finalized: bool = False - async def append_output(self, chunk: str) -> None: - async with self._output_lock: - self._output_buf.append(chunk) - self._output_event.set() - - async def total_output(self) -> str: - async with self._output_lock: - return "".join(self._output_buf) - - async def yield_output(self, yield_ms: int, poll_lines: int) -> tuple[str, str, int, int]: - """Wait *yield_ms* ms for new output then return a chunk. + async def yield_output(self, yield_time_ms: int, poll_lines: int) -> tuple[str, str, int, int]: + """Wait *yield_time_ms* ms for new output then return a chunk. Returns ``(status, output_chunk, offset, next_offset)``. 
""" - yield_sec = (yield_ms or DEFAULT_EXEC_YIELD_MS) / 1000.0 - deadline = asyncio.get_event_loop().time() + yield_sec - - while True: - remaining = deadline - asyncio.get_event_loop().time() - if remaining <= 0: - break - self._output_event.clear() - # Break early if process has already exited and we have all output - if self.proc.returncode is not None: - break - try: - await asyncio.wait_for(asyncio.shield(self._output_event.wait()), timeout=remaining) - except asyncio.TimeoutError: - break - - # Determine status - rc = self.proc.returncode - if rc is not None: - self.exit_code = rc - if self.exited_at is None: - self.exited_at = time.time() - status = _STATUS_RUNNING if rc is None else _STATUS_EXITED - - # Slice the output since last read - async with self._output_lock: - full = "".join(self._output_buf) - - chunk = full[self._read_offset:] - - # Apply poll_lines limit - limit = poll_lines or DEFAULT_POLL_LINES - if limit > 0 and chunk: - lines = chunk.split("\n") - if len(lines) > limit: - chunk = "\n".join(lines[:limit]) - if not chunk.endswith("\n"): - chunk += "\n" - - offset = self._read_offset - self._read_offset += len(chunk) - next_offset = self._read_offset - - return status, chunk, offset, next_offset - - -# --------------------------------------------------------------------------- -# Background output readers -# --------------------------------------------------------------------------- - - -async def _read_pipe(session: _ExecSession, stream: asyncio.StreamReader) -> None: - """Continuously read from *stream* and append to session output.""" - try: - while True: - data = await stream.read(4096) - if not data: - break - await session.append_output(data.decode("utf-8", errors="replace")) - except Exception: # pylint: disable=broad-except - pass - finally: - # Ensure exit is captured - try: - await session.proc.wait() - except Exception: # pylint: disable=broad-except - pass - session._output_event.set() # unblock any waiting yield - - -async def 
_read_pty(session: _ExecSession, master_fd: int) -> None: - """Continuously read from a PTY master fd and append to session output.""" - loop = asyncio.get_event_loop() - try: - # Set master_fd non-blocking - flags = fcntl.fcntl(master_fd, fcntl.F_GETFL) - fcntl.fcntl(master_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) - - # Use loop.add_reader for non-blocking reads - read_event = asyncio.Event() - - def _on_readable() -> None: - read_event.set() - - loop.add_reader(master_fd, _on_readable) - try: - while True: - read_event.clear() - # Check if process has exited - rc = session.proc.returncode - if rc is not None: - # Drain remaining data - while True: - try: - data = os.read(master_fd, 4096) - if data: - await session.append_output(data.decode("utf-8", errors="replace")) - else: - break - except OSError: - break - break - # Wait for readable or timeout - try: - await asyncio.wait_for(read_event.wait(), timeout=0.05) - except asyncio.TimeoutError: - pass - try: - data = os.read(master_fd, 4096) - if data: - await session.append_output(data.decode("utf-8", errors="replace")) - except BlockingIOError: - pass - except OSError: - break - finally: - loop.remove_reader(master_fd) - except Exception: # pylint: disable=broad-except - pass - finally: - try: - await session.proc.wait() - except Exception: # pylint: disable=broad-except - pass - session._output_event.set() + poll = await wait_for_program_output( + self.proc, + yield_duration_ms(yield_time_ms, DEFAULT_EXEC_YIELD_MS), + poll_line_limit(poll_lines), + ) + if poll.exit_code is not None: + self.exit_code = poll.exit_code + if poll.status == _STATUS_EXITED and self.exited_at is None: + self.exited_at = time.time() + return poll.status, poll.output, poll.offset, poll.next_offset # --------------------------------------------------------------------------- @@ -395,10 +293,6 @@ def _detect_interaction(status: str, output: str) -> Optional[SessionInteraction def _build_exec_env(ws: WorkspaceInfo, extra: dict[str, str]) 
-> dict[str, str]: """Build the merged environment for a subprocess in *ws*.""" - from trpc_agent_sdk.code_executors._constants import ( # lazy import to avoid circular - WORKSPACE_ENV_DIR_KEY, ENV_SKILLS_DIR, ENV_WORK_DIR, ENV_OUTPUT_DIR, DIR_RUNS, - ) - from datetime import datetime env = os.environ.copy() run_dir = str(Path(ws.path) / DIR_RUNS / f"run_{datetime.now().strftime('%Y%m%dT%H%M%S_%f')}") @@ -416,15 +310,6 @@ def _build_exec_env(ws: WorkspaceInfo, extra: dict[str, str]) -> dict[str, str]: return env -def _resolve_abs_cwd(ws_path: str, rel_cwd: str) -> str: - """Return the absolute cwd by joining *ws_path* and *rel_cwd*.""" - if rel_cwd and os.path.isabs(rel_cwd): - return rel_cwd - resolved = os.path.normpath(os.path.join(ws_path, rel_cwd or ".")) - os.makedirs(resolved, exist_ok=True) - return resolved - - # --------------------------------------------------------------------------- # SkillExecTool (skill_exec) # --------------------------------------------------------------------------- @@ -442,7 +327,8 @@ def __init__( self, run_tool: SkillRunTool, filters: Optional[List[BaseFilter]] = None, - session_ttl: float = DEFAULT_SESSION_TTL, + session_ttl: float = DEFAULT_SESSION_TTL_SEC, + create_ws_name_cb: Optional[CreateWorkspaceNameCallback] = None, ): super().__init__(name="skill_exec", description=("Start an interactive command inside a skill workspace. 
" @@ -452,8 +338,8 @@ def __init__( filters=filters) self._run_tool = run_tool self._ttl = session_ttl + self._create_ws_name_cb = create_ws_name_cb or default_create_ws_name_callback self._sessions: dict[str, _ExecSession] = {} - self._sessions_lock = asyncio.Lock() # ------------------------------------------------------------------ # Declaration @@ -461,8 +347,8 @@ def __init__( @override def _get_declaration(self) -> FunctionDeclaration: - params_schema = _inline_json_schema_refs(ExecInput.model_json_schema()) - response_schema = _inline_json_schema_refs(ExecOutput.model_json_schema()) + params_schema = inline_json_schema_refs(ExecInput.model_json_schema()) + response_schema = inline_json_schema_refs(ExecOutput.model_json_schema()) return FunctionDeclaration( name="skill_exec", description=("Start an interactive command inside a skill workspace. " @@ -476,37 +362,39 @@ def _get_declaration(self) -> FunctionDeclaration: # Session management # ------------------------------------------------------------------ - async def _put_session(self, sid: str, sess: _ExecSession) -> None: - async with self._sessions_lock: - await self._gc_expired_locked() - self._sessions[sid] = sess + async def _put_session(self, sid: str, exec_session: _ExecSession) -> None: + await self._gc_expired_sessions() + self._sessions[sid] = exec_session - async def get_session(self, sid: str) -> _ExecSession: - async with self._sessions_lock: - await self._gc_expired_locked() - sess = self._sessions.get(sid) - if sess is None: + async def _get_session(self, sid: str) -> _ExecSession: + await self._gc_expired_sessions() + session = self._sessions.get(sid) + if session is None: raise ValueError(f"unknown session_id: {sid}") - return sess + return session - async def remove_session(self, sid: str) -> _ExecSession: - async with self._sessions_lock: - await self._gc_expired_locked() - sess = self._sessions.pop(sid, None) - if sess is None: + async def _remove_session(self, sid: str) -> _ExecSession: 
+ await self._gc_expired_sessions() + session = self._sessions.pop(sid, None) + if session is None: raise ValueError(f"unknown session_id: {sid}") - return sess - - async def _gc_expired_locked(self) -> None: - if self._ttl <= 0: - return - now = time.time() - expired = [ - sid for sid, s in self._sessions.items() if s.exited_at is not None and (now - s.exited_at) >= self._ttl - ] - for sid in expired: - s = self._sessions.pop(sid) - _close_session(s) + return session + + async def _gc_expired_sessions(self) -> None: + + async def _refresh_exit_state(session: _ExecSession, now: float) -> None: + if session.exited_at is not None: + return + session_state = await session.proc.state() + if session_state.status == PROGRAM_STATUS_EXITED: + session.exited_at = now + + await cleanup_expired_sessions( + self._sessions, + ttl=self._ttl, + refresh_exit_state=_refresh_exit_state, + close_session=_close_session, + ) # ------------------------------------------------------------------ # Main execution @@ -523,26 +411,30 @@ async def _run_async_impl( inputs = ExecInput.model_validate(args) except Exception as ex: # pylint: disable=broad-except raise ValueError(f"Invalid skill_exec arguments: {ex}") from ex + normalized_skill = inputs.skill.strip() + normalized_command = inputs.command.strip() + if not normalized_skill or not normalized_command: + raise ValueError("skill and command are required") + inputs = inputs.model_copy(update={"skill": normalized_skill, "command": normalized_command}) + + if self._run_tool.require_skill_loaded and not self._run_tool._is_skill_loaded(tool_context, normalized_skill): + raise ValueError(f"skill_exec requires skill_load first for {normalized_skill!r}") repository = self._run_tool._get_repository(tool_context) # Workspace creation - session_id_ws = inputs.skill - if tool_context.session and tool_context.session.id: - session_id_ws = tool_context.session.id - workspace_runtime = repository.workspace_runtime manager = 
workspace_runtime.manager(tool_context) - ws = await manager.create_workspace(session_id_ws, tool_context) + workspace_id = self._create_ws_name_cb(tool_context) + ws = await manager.create_workspace(workspace_id, tool_context) # Stage skill via the same pluggable stager used by SkillRunTool stage_result = await self._run_tool.skill_stager.stage_skill( SkillStageRequest( - skill_name=inputs.skill, + skill_name=normalized_skill, repository=repository, workspace=ws, ctx=tool_context, - engine=workspace_runtime, timeout=self._run_tool._timeout, )) workspace_skill_dir = stage_result.workspace_skill_dir @@ -553,29 +445,35 @@ async def _run_async_impl( # Resolve cwd and env rel_cwd = self._run_tool._resolve_cwd(inputs.cwd, workspace_skill_dir) - abs_cwd = _resolve_abs_cwd(ws.path, rel_cwd) extra_env: dict[str, str] = dict(inputs.env) if ENV_SKILL_NAME not in extra_env: - extra_env[ENV_SKILL_NAME] = inputs.skill + extra_env[ENV_SKILL_NAME] = normalized_skill merged_env = _build_exec_env(ws, extra_env) - # Start subprocess + # Start interactive program session via runtime runner. 
+ runner = workspace_runtime.runner(tool_context) + start_program = getattr(runner, "start_program", None) + if start_program is None: + raise ValueError("skill_exec is not supported by the current executor") sid = str(uuid.uuid4()) - sess = await _start_session(inputs, ws, abs_cwd, merged_env) - await self._put_session(sid, sess) - - # Write initial stdin if provided - if inputs.stdin: - await _write_stdin(sess, inputs.stdin, submit=False) + exec_session = await _start_session( + runner=runner, + tool_context=tool_context, + inputs=inputs, + ws=ws, + rel_cwd=rel_cwd, + env=merged_env, + ) + await self._put_session(sid, exec_session) - yield_ms = inputs.yield_ms or DEFAULT_EXEC_YIELD_MS - status, chunk, offset, next_offset = await sess.yield_output(yield_ms, inputs.poll_lines) + yield_time_ms = inputs.yield_time_ms or DEFAULT_EXEC_YIELD_MS + status, chunk, offset, next_offset = await exec_session.yield_output(yield_time_ms, inputs.poll_lines) # Attempt to collect final result if already exited final_result = None - if status == _STATUS_EXITED and not sess.finalized: - final_result = await _collect_final_result(tool_context, sess, self._run_tool) + if status == _STATUS_EXITED and not exec_session.finalized: + final_result = await _collect_final_result(tool_context, exec_session, self._run_tool) out = ExecOutput( status=status, @@ -583,10 +481,11 @@ async def _run_async_impl( output=chunk, offset=offset, next_offset=next_offset, - exit_code=sess.exit_code, + exit_code=exec_session.exit_code, interaction=_detect_interaction(status, chunk), result=final_result, ) + _apply_artifacts_state_delta(tool_context, out) return out.model_dump(exclude_none=True) @@ -613,8 +512,8 @@ def __init__(self, exec_tool: SkillExecTool, filters: Optional[List[BaseFilter]] @override def _get_declaration(self) -> FunctionDeclaration: - params_schema = _inline_json_schema_refs(WriteStdinInput.model_json_schema()) - response_schema = _inline_json_schema_refs(ExecOutput.model_json_schema()) + 
params_schema = inline_json_schema_refs(WriteStdinInput.model_json_schema()) + response_schema = inline_json_schema_refs(ExecOutput.model_json_schema()) return FunctionDeclaration( name="skill_write_stdin", description=("Write to a running skill_exec session. Set submit=true to " @@ -635,29 +534,32 @@ async def _run_async_impl( inputs = WriteStdinInput.model_validate(args) except Exception as ex: # pylint: disable=broad-except raise ValueError(f"Invalid skill_write_stdin arguments: {ex}") from ex + normalized_session_id = require_non_empty(inputs.session_id, field_name="session_id") + inputs = inputs.model_copy(update={"session_id": normalized_session_id}) - sess = await self._exec.get_session(inputs.session_id) + exec_session = await self._exec._get_session(inputs.session_id) if inputs.chars or inputs.submit: - await _write_stdin(sess, inputs.chars, submit=inputs.submit) + await _write_stdin(exec_session, inputs.chars, submit=inputs.submit) - yield_ms = inputs.yield_ms or DEFAULT_IO_YIELD_MS - status, chunk, offset, next_offset = await sess.yield_output(yield_ms, inputs.poll_lines) + yield_time_ms = inputs.yield_time_ms or DEFAULT_IO_YIELD_MS + status, chunk, offset, next_offset = await exec_session.yield_output(yield_time_ms, inputs.poll_lines) final_result = None - if status == _STATUS_EXITED and not sess.finalized: - final_result = await _collect_final_result(tool_context, sess, self._exec._run_tool) + if status == _STATUS_EXITED and not exec_session.finalized: + final_result = await _collect_final_result(tool_context, exec_session, self._exec._run_tool) out = ExecOutput( status=status, - session_id=inputs.session_id, + session_id=normalized_session_id, output=chunk, offset=offset, next_offset=next_offset, - exit_code=sess.exit_code, + exit_code=exec_session.exit_code, interaction=_detect_interaction(status, chunk), result=final_result, ) + _apply_artifacts_state_delta(tool_context, out) return out.model_dump(exclude_none=True) @@ -678,8 +580,8 @@ def 
__init__(self, exec_tool: SkillExecTool, filters: Optional[List[BaseFilter]] @override def _get_declaration(self) -> FunctionDeclaration: - params_schema = _inline_json_schema_refs(PollSessionInput.model_json_schema()) - response_schema = _inline_json_schema_refs(ExecOutput.model_json_schema()) + params_schema = inline_json_schema_refs(PollSessionInput.model_json_schema()) + response_schema = inline_json_schema_refs(ExecOutput.model_json_schema()) return FunctionDeclaration( name="skill_poll_session", description=("Poll a running or recently exited skill_exec session for " @@ -699,26 +601,29 @@ async def _run_async_impl( inputs = PollSessionInput.model_validate(args) except Exception as ex: # pylint: disable=broad-except raise ValueError(f"Invalid skill_poll_session arguments: {ex}") from ex + normalized_session_id = require_non_empty(inputs.session_id, field_name="session_id") + inputs = inputs.model_copy(update={"session_id": normalized_session_id}) - sess = await self._exec.get_session(inputs.session_id) + exec_session = await self._exec._get_session(inputs.session_id) - yield_ms = inputs.yield_ms or DEFAULT_IO_YIELD_MS - status, chunk, offset, next_offset = await sess.yield_output(yield_ms, inputs.poll_lines) + yield_time_ms = inputs.yield_time_ms or DEFAULT_IO_YIELD_MS + status, chunk, offset, next_offset = await exec_session.yield_output(yield_time_ms, inputs.poll_lines) final_result = None - if status == _STATUS_EXITED and not sess.finalized: - final_result = await _collect_final_result(tool_context, sess, self._exec._run_tool) + if status == _STATUS_EXITED and not exec_session.finalized: + final_result = await _collect_final_result(tool_context, exec_session, self._exec._run_tool) out = ExecOutput( status=status, - session_id=inputs.session_id, + session_id=normalized_session_id, output=chunk, offset=offset, next_offset=next_offset, - exit_code=sess.exit_code, + exit_code=exec_session.exit_code, interaction=_detect_interaction(status, chunk), 
result=final_result, ) + _apply_artifacts_state_delta(tool_context, out) return out.model_dump(exclude_none=True) @@ -738,8 +643,8 @@ def __init__(self, exec_tool: SkillExecTool, filters: Optional[List[BaseFilter]] @override def _get_declaration(self) -> FunctionDeclaration: - params_schema = _inline_json_schema_refs(KillSessionInput.model_json_schema()) - response_schema = _inline_json_schema_refs(SessionKillOutput.model_json_schema()) + params_schema = inline_json_schema_refs(KillSessionInput.model_json_schema()) + response_schema = inline_json_schema_refs(SessionKillOutput.model_json_schema()) return FunctionDeclaration( name="skill_kill_session", description="Terminate and remove a skill_exec session.", @@ -758,25 +663,26 @@ async def _run_async_impl( inputs = KillSessionInput.model_validate(args) except Exception as ex: # pylint: disable=broad-except raise ValueError(f"Invalid skill_kill_session arguments: {ex}") from ex + normalized_session_id = require_non_empty(inputs.session_id, field_name="session_id") + inputs = inputs.model_copy(update={"session_id": normalized_session_id}) + exec_session = await self._exec._get_session(normalized_session_id) - sess = await self._exec.remove_session(inputs.session_id) - - rc = sess.proc.returncode final_status = _STATUS_EXITED - if rc is None: + poll = await exec_session.proc.poll(None) + if poll.status == _STATUS_RUNNING: try: - sess.proc.kill() - await asyncio.wait_for(sess.proc.wait(), timeout=5.0) + await exec_session.proc.kill(DEFAULT_SESSION_KILL_SEC) except Exception: # pylint: disable=broad-except pass final_status = "killed" - _close_session(sess) + await self._exec._remove_session(normalized_session_id) + await _close_session(exec_session) out = SessionKillOutput( ok=True, - session_id=inputs.session_id, + session_id=normalized_session_id, status=final_status, ) return out.model_dump() @@ -790,7 +696,7 @@ async def _run_async_impl( def create_exec_tools( run_tool: SkillRunTool, filters: 
Optional[List[BaseFilter]] = None, - session_ttl: float = DEFAULT_SESSION_TTL, + session_ttl: float = DEFAULT_SESSION_TTL_SEC, ) -> tuple[SkillExecTool, WriteStdinTool, PollSessionTool, KillSessionTool]: """Create the full set of interactive exec tools sharing one session store. @@ -823,85 +729,46 @@ def create_exec_tools( async def _start_session( + *, + runner: BaseProgramRunner, + tool_context: InvocationContext, inputs: ExecInput, ws: WorkspaceInfo, - abs_cwd: str, + rel_cwd: str, env: dict[str, str], ) -> _ExecSession: - """Spawn a subprocess and return an initialized :class:`_ExecSession`.""" - command = inputs.command - master_fd: Optional[int] = None - - if inputs.tty: - # Allocate a pseudo-TTY. - master_fd, slave_fd = pty.openpty() - try: - proc = await asyncio.create_subprocess_exec( - "bash", - "-c", - command, - stdin=slave_fd, - stdout=slave_fd, - stderr=slave_fd, - cwd=abs_cwd, - env=env, - close_fds=True, - preexec_fn=os.setsid, - ) - finally: - os.close(slave_fd) # parent doesn't need the slave end - - sess = _ExecSession(proc=proc, ws=ws, in_data=inputs, master_fd=master_fd) - sess.reader_task = asyncio.create_task(_read_pty(sess, master_fd)) - else: - proc = await asyncio.create_subprocess_exec( - "bash", - "-c", - command, - stdin=asyncio.subprocess.PIPE, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.STDOUT, - cwd=abs_cwd, - env=env, - ) - sess = _ExecSession(proc=proc, ws=ws, in_data=inputs) - if proc.stdout: - sess.reader_task = asyncio.create_task(_read_pipe(sess, proc.stdout)) - - return sess + """Start a ProgramSession via runtime runner.""" + spec = WorkspaceRunProgramSpec( + cmd="bash", + args=["-c", inputs.command], + env=env, + cwd=rel_cwd, + stdin=inputs.stdin, + timeout=float(inputs.timeout or 0), + tty=inputs.tty, + ) + proc = await runner.start_program(tool_context, ws, spec) + return _ExecSession(proc=proc, ws=ws, in_data=inputs) -async def _write_stdin(sess: _ExecSession, chars: str, submit: bool) -> None: +async 
def _write_stdin(exec_session: _ExecSession, chars: str, submit: bool) -> None: """Write *chars* (and optionally a newline) to the session's stdin.""" - if sess.master_fd is not None: - # PTY path: write to master fd - data = (chars + ("\n" if submit else "")).encode("utf-8") - if data: - try: - os.write(sess.master_fd, data) - except OSError as ex: - logger.debug("skill_exec: write to pty failed: %s", ex) - elif sess.proc.stdin: - # Pipe path: use asyncio StreamWriter - data = (chars + ("\n" if submit else "")).encode("utf-8") - if data: - try: - sess.proc.stdin.write(data) - await sess.proc.stdin.drain() - except Exception as ex: # pylint: disable=broad-except - logger.debug("skill_exec: write to stdin failed: %s", ex) + try: + await exec_session.proc.write(chars, submit) + except Exception as ex: # pylint: disable=broad-except + logger.debug("skill_exec: write to stdin failed: %s", ex) async def _collect_final_result( ctx: InvocationContext, - sess: _ExecSession, + exec_session: _ExecSession, run_tool: SkillRunTool, ) -> Optional[SkillRunOutput]: """Collect output files and build the final :class:`SkillRunOutput`.""" - if sess.finalized: - return sess.final_result + if exec_session.finalized: + return exec_session.final_result - in_data = sess.in_data + in_data = exec_session.in_data fake_run_input = SkillRunInput( skill=in_data.skill, command=in_data.command, @@ -916,13 +783,19 @@ async def _collect_final_result( outputs=in_data.outputs, ) try: - files, manifest = await run_tool._prepare_outputs(ctx, sess.ws, fake_run_input) + files, manifest = await run_tool._prepare_outputs(ctx, exec_session.ws, fake_run_input) except Exception as ex: # pylint: disable=broad-except logger.warning("skill_exec: collect outputs failed: %s", ex) files, manifest = [], None - total_out = await sess.total_output() - exit_code = sess.exit_code or 0 + try: + run_result = await exec_session.proc.run_result() + total_out = (run_result.stdout or "") + (run_result.stderr or "") + 
exit_code = run_result.exit_code + except Exception: # pylint: disable=broad-except + run_log = await exec_session.proc.log(None, None) + total_out = run_log.output or "" + exit_code = exec_session.exit_code or 0 # Reuse the same output-quality helpers as skill_run warnings: list[str] = [] @@ -944,30 +817,21 @@ async def _collect_final_result( ) try: - await run_tool._attach_artifacts_if_requested(ctx, sess.ws, fake_run_input, result, files) + await run_tool._attach_artifacts_if_requested(ctx, exec_session.ws, fake_run_input, result, files) except Exception as ex: # pylint: disable=broad-except logger.warning("skill_exec: attach artifacts failed: %s", ex) if manifest: run_tool._merge_manifest_artifact_refs(manifest, result) - sess.final_result = result - sess.finalized = True + exec_session.final_result = result + exec_session.finalized = True return result -def _close_session(sess: _ExecSession) -> None: - """Cancel background tasks and close any open fds.""" - if sess.reader_task and not sess.reader_task.done(): - sess.reader_task.cancel() - if sess.master_fd is not None: - try: - os.close(sess.master_fd) - except OSError: - pass - sess.master_fd = None - if sess.proc.stdin and not sess.proc.stdin.is_closing(): - try: - sess.proc.stdin.close() - except Exception: # pylint: disable=broad-except - pass +async def _close_session(exec_session: _ExecSession) -> None: + """Release program session resources.""" + try: + await exec_session.proc.close() + except Exception: # pylint: disable=broad-except + pass diff --git a/trpc_agent_sdk/skills/tools/_skill_list.py b/trpc_agent_sdk/skills/tools/_skill_list.py index 83d0903..1bfa72e 100644 --- a/trpc_agent_sdk/skills/tools/_skill_list.py +++ b/trpc_agent_sdk/skills/tools/_skill_list.py @@ -8,8 +8,8 @@ from __future__ import annotations -from typing import Optional from typing import Literal +from typing import Optional from trpc_agent_sdk.context import InvocationContext @@ -17,16 +17,16 @@ from .._repository import 
BaseSkillRepository -def skill_list(tool_context: InvocationContext, mode: Literal["all", "enabled", "disabled"] = "all") -> list[str]: +def skill_list(tool_context: InvocationContext, mode: Literal['all', 'available'] = 'all') -> list[str]: """List all discovered skills. Args: tool_context: The tool context. - + mode: The mode to list skills. Returns: A list of skill names. """ repository: Optional[BaseSkillRepository] = tool_context.agent_context.get_metadata(SKILL_REPOSITORY_KEY) if repository is None: raise ValueError("repository not found") - return repository.skill_list() + return repository.skill_list(mode) diff --git a/trpc_agent_sdk/skills/tools/_skill_list_docs.py b/trpc_agent_sdk/skills/tools/_skill_list_docs.py index f1e82e9..02f4abb 100644 --- a/trpc_agent_sdk/skills/tools/_skill_list_docs.py +++ b/trpc_agent_sdk/skills/tools/_skill_list_docs.py @@ -8,32 +8,33 @@ from __future__ import annotations -from typing import Any from typing import Optional from trpc_agent_sdk.context import InvocationContext -from trpc_agent_sdk.log import logger from .._constants import SKILL_REPOSITORY_KEY from .._repository import BaseSkillRepository -def skill_list_docs(tool_context: InvocationContext, skill_name: str) -> dict[str, Any]: +def skill_list_docs(tool_context: InvocationContext, skill_name: str) -> list[str]: """List doc filenames for a skill. + Args: - skill_name: The name of the skill to load. + skill_name: The name of the skill. Returns: - Object containing docs list and whether SKILL.md body is loaded. + Array of doc filenames. 
""" + normalized_skill = (skill_name or "").strip() + if not normalized_skill: + raise ValueError("skill is required") + repository: Optional[BaseSkillRepository] = tool_context.agent_context.get_metadata(SKILL_REPOSITORY_KEY) if repository is None: - raise ValueError("repository not found") - skill = repository.get(skill_name) - if skill is None: - logger.error("Skill %s not found", repr(skill_name)) - return {"docs": [], "body_loaded": False} - return { - "docs": [resource.path for resource in skill.resources], - "body_loaded": bool(skill.body), - } + return [] + + try: + skill = repository.get(normalized_skill) + except ValueError as ex: + raise ValueError(f"unknown skill: {normalized_skill}") from ex + return [resource.path for resource in (skill.resources or [])] diff --git a/trpc_agent_sdk/skills/tools/_skill_list_tool.py b/trpc_agent_sdk/skills/tools/_skill_list_tool.py index 4ecfbb2..942d127 100644 --- a/trpc_agent_sdk/skills/tools/_skill_list_tool.py +++ b/trpc_agent_sdk/skills/tools/_skill_list_tool.py @@ -1,4 +1,4 @@ -# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# -*- coding: utf-8 -*- # # Copyright (C) 2026 Tencent. All rights reserved. 
# @@ -9,7 +9,6 @@ from __future__ import annotations -import re from typing import Any from typing import Optional @@ -20,63 +19,13 @@ from .._repository import BaseSkillRepository -def _extract_shell_examples_from_skill_body(body: str, limit: int = 5) -> list[str]: - """Extract likely runnable command examples from SKILL.md body.""" - if not body: - return [] - out: list[str] = [] - seen: set[str] = set() - lines = body.splitlines() - - def maybe_add(cmd: str) -> None: - cmd = re.sub(r"\s+", " ", (cmd or "").strip()) - if not cmd or cmd in seen: - return - if not re.match(r"^[A-Za-z0-9_./$\"'`-]", cmd): - return - seen.add(cmd) - out.append(cmd) - - i = 0 - while i < len(lines) and len(out) < limit: - cur = lines[i].strip() - if cur.lower() != "command:": - i += 1 - continue - i += 1 - block: list[str] = [] - while i < len(lines): - raw = lines[i] - s = raw.strip() - if not s: - if block: - break - i += 1 - continue - if s.lower() in ("command:", "output files", "overview", "examples", "tools:"): - break - if re.match(r"^\d+\)", s): - break - if raw.startswith(" ") or raw.startswith("\t"): - block.append(s) - i += 1 - continue - if block: - break - i += 1 - if block: - merged = " ".join(part.rstrip("\\").strip() for part in block) - maybe_add(merged) - return out - - def skill_list_tools(tool_context: InvocationContext, skill_name: str) -> dict[str, Any]: - """List executable guidance for a skill. + """List callable tools declared for a skill. Args: skill_name: The name of the skill to load. Returns: - Object containing declared tools and command examples from SKILL.md. + Object containing available tools. 
""" repository: Optional[BaseSkillRepository] = tool_context.agent_context.get_metadata(SKILL_REPOSITORY_KEY) if repository is None: @@ -84,8 +33,5 @@ def skill_list_tools(tool_context: InvocationContext, skill_name: str) -> dict[s skill = repository.get(skill_name) if skill is None: logger.error("Skill %s not found", repr(skill_name)) - return {"tools": [], "command_examples": []} - return { - "tools": list(skill.tools or []), - "command_examples": _extract_shell_examples_from_skill_body(skill.body, limit=5), - } + return {"available_tools": []} + return {"available_tools": list(skill.tools or [])} diff --git a/trpc_agent_sdk/skills/tools/_skill_load.py b/trpc_agent_sdk/skills/tools/_skill_load.py index ec2c62c..3579d31 100644 --- a/trpc_agent_sdk/skills/tools/_skill_load.py +++ b/trpc_agent_sdk/skills/tools/_skill_load.py @@ -1,4 +1,4 @@ -# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# -*- coding: utf-8 -*- # # Copyright (C) 2026 Tencent. All rights reserved. 
# @@ -12,65 +12,138 @@ import json from typing import Any +from typing import List from typing import Optional +from typing_extensions import override from trpc_agent_sdk.context import InvocationContext +from trpc_agent_sdk.filter import BaseFilter +from trpc_agent_sdk.tools import BaseTool +from trpc_agent_sdk.types import FunctionDeclaration +from trpc_agent_sdk.types import Schema +from trpc_agent_sdk.types import Type -from .._constants import SKILL_DOCS_STATE_KEY_PREFIX -from .._constants import SKILL_LOADED_STATE_KEY_PREFIX -from .._constants import SKILL_REPOSITORY_KEY -from .._constants import SKILL_TOOLS_STATE_KEY_PREFIX +from .._common import append_loaded_order_state_delta +from .._common import docs_state_key +from .._common import loaded_state_key +from .._common import set_state_delta +from .._common import tool_state_key from .._repository import BaseSkillRepository +from .._types import Skill +from ..stager import SkillStageRequest +from ..stager import Stager +from ._common import CreateWorkspaceNameCallback +from ._common import default_create_ws_name_callback +from ._common import set_staged_workspace_dir +from ._copy_stager import CopySkillStager -def _set_state_delta(invocation_context: InvocationContext, key: str, value: Any) -> None: - """Set the state delta of a skill loaded.""" - invocation_context.actions.state_delta[key] = value +class SkillLoadTool(BaseTool): + """Tool for loading a skill.""" + def __init__( + self, + repository: BaseSkillRepository, + skill_stager: Optional[Stager] = None, + create_ws_name_cb: Optional[CreateWorkspaceNameCallback] = None, + filters: Optional[List[BaseFilter]] = None, + ): + super().__init__(name="skill_load", description="Load a skill.", filters=filters) + self._repository = repository + self._skill_stager: Stager = skill_stager or CopySkillStager() + self._create_ws_name_cb: Optional[ + CreateWorkspaceNameCallback] = create_ws_name_cb or default_create_ws_name_callback -def 
_set_state_delta_for_skill_load(invocation_context: InvocationContext, - skill_name: str, - docs: list[str], - include_all_docs: bool = False) -> None: - """Set the state delta of a skill loaded.""" - key = f"{SKILL_LOADED_STATE_KEY_PREFIX}{skill_name}" - _set_state_delta(invocation_context, key, True) - key = f"{SKILL_DOCS_STATE_KEY_PREFIX}{skill_name}" - if include_all_docs: - _set_state_delta(invocation_context, key, '*') - else: - _set_state_delta(invocation_context, key, json.dumps(docs or [])) + @override + def _get_declaration(self) -> Optional[FunctionDeclaration]: + return FunctionDeclaration( + name="skill_load", + description=("Load a skill body and optional docs. Safe to call multiple times to add or replace docs. " + "Do not call this to list skills; names and descriptions are already in context. " + "Use when a task needs a skill's SKILL.md body and selected docs in context."), + parameters=Schema( + type=Type.OBJECT, + required=["skill_name"], + properties={ + "skill_name": + Schema(type=Type.STRING, description="The name of the skill to load."), + "docs": + Schema(type=Type.ARRAY, + default=None, + items=Schema(type=Type.STRING), + description="The docs of the skill to load."), + "include_all_docs": + Schema(type=Type.BOOLEAN, default=False, description="Whether to include all docs of the skill."), + }, + ), + response=Schema(type=Type.STRING, + description="Result of skill_load. 
message is a string indicating the skill was loaded."), + ) + @override + async def _run_async_impl(self, *, tool_context: InvocationContext, args: dict[str, Any]) -> str: + if not (args["skill_name"] or "").strip(): + raise ValueError("skill_name is required") + skill_name = args["skill_name"] + docs = args.get("docs", []) + include_all_docs = args.get("include_all_docs", False) + normalized_skill = skill_name.strip() + skill = self._repository.get(normalized_skill) + await self._ensure_staged(ctx=tool_context, skill_name=skill_name) + clean_docs = [doc.strip() for doc in (docs or []) if isinstance(doc, str) and doc.strip()] + self.__set_state_delta_for_skill_load(tool_context, skill_name, clean_docs, include_all_docs) + if skill.tools: + self.__set_state_delta_for_skill_tools(tool_context, skill) + return f"skill {skill_name!r} loaded" -def _set_state_delta_for_skill_tools(invocation_context: InvocationContext, skill_name: str, tools: list[str]) -> None: - """Set the state delta of a skill tools.""" - key = f"{SKILL_TOOLS_STATE_KEY_PREFIX}{skill_name}" - _set_state_delta(invocation_context, key, json.dumps(tools or [])) + async def _ensure_staged(self, *, ctx: InvocationContext, skill_name: str) -> None: + runtime = self._repository.workspace_runtime + manager = runtime.manager(ctx) + ws_id = self._create_ws_name_cb(ctx) + ws = await manager.create_workspace(ws_id, ctx) + result = await self._skill_stager.stage_skill( + SkillStageRequest(skill_name=skill_name, repository=self._repository, workspace=ws, ctx=ctx)) + set_staged_workspace_dir(ctx, skill_name, result.workspace_skill_dir) + def __set_state_delta_for_skill_load(self, + invocation_context: InvocationContext, + skill_name: str, + docs: list[str], + include_all_docs: bool = False) -> None: + """Set state delta for skill_load, aligned with Go StateDeltaForInvocation.""" + agent_name = invocation_context.agent_name.strip() + delta, normalized_skill = self.__build_state_delta_for_skill_load( + 
invocation_context=invocation_context, + skill_name=skill_name, + docs=docs, + include_all_docs=include_all_docs, + ) + invocation_context.actions.state_delta.update(delta) + append_loaded_order_state_delta(invocation_context, agent_name, normalized_skill) -def skill_load(tool_context: InvocationContext, - skill_name: str, - docs: Optional[list[str]] = None, - include_all_docs: bool = False) -> str: - """Load a skill body and optional docs. Safe to call multiple times to add or replace docs. - Do not call this to list skills; names and descriptions are already in context. - Use when a task needs a skill's SKILL.md body and selected docs in context. - Args: - skill_name: The name of the skill to load. - docs: The docs of the skill to load. - include_all_docs: Whether to include all docs of the skill. + def __build_state_delta_for_skill_load( + self, + invocation_context: InvocationContext, + skill_name: str, + docs: list[str], + include_all_docs: bool = False, + ) -> tuple[dict[str, Any], str]: + """Build skill_load state delta.""" + normalized_skill = skill_name.strip() + if not normalized_skill: + return {}, "" + delta: dict[str, Any] = {} + delta[loaded_state_key(invocation_context, normalized_skill)] = True + if include_all_docs: + delta[docs_state_key(invocation_context, normalized_skill)] = '*' + else: + delta[docs_state_key(invocation_context, normalized_skill)] = json.dumps(docs or []) + return delta, normalized_skill - Returns: - A message indicating the skill was loaded. 
- """ - - repository: Optional[BaseSkillRepository] = tool_context.agent_context.get_metadata(SKILL_REPOSITORY_KEY) - if repository is None: - raise ValueError("repository not found") - skill = repository.get(skill_name) - if skill is None: - return f"skill {skill_name!r} not found" - _set_state_delta_for_skill_load(tool_context, skill_name, docs or [], include_all_docs) - if skill.tools: - _set_state_delta_for_skill_tools(tool_context, skill_name, skill.tools) - return f"skill {skill_name!r} loaded" + def __set_state_delta_for_skill_tools(self, invocation_context: InvocationContext, skill: Skill) -> None: + """Set the state delta of a skill tools.""" + normalized_skill = skill.summary.name.strip() + if not normalized_skill: + return + key = tool_state_key(invocation_context, normalized_skill) + set_state_delta(invocation_context, key, json.dumps(skill.tools)) diff --git a/trpc_agent_sdk/skills/tools/_skill_run.py b/trpc_agent_sdk/skills/tools/_skill_run.py index 23ac4fa..01dcdc8 100644 --- a/trpc_agent_sdk/skills/tools/_skill_run.py +++ b/trpc_agent_sdk/skills/tools/_skill_run.py @@ -3,19 +3,12 @@ # Copyright (C) 2026 Tencent. All rights reserved. # # tRPC-Agent-Python is licensed under Apache-2.0. -"""Skill run tool for executing commands in skill workspaces. - -This module provides the SkillRunTool class which allows LLM to execute commands -inside a skill workspace. 
It stages the entire skill directory and runs commands, -aligned with the Go implementation at: -https://github.com/trpc-group/trpc-agent-go/blob/main/tool/skill/run.go -""" +"""Skill run tool for executing commands in skill workspaces.""" from __future__ import annotations import os -import re -import shlex +import posixpath from pathlib import Path from typing import Any from typing import Dict @@ -45,12 +38,19 @@ from trpc_agent_sdk.types import Part from trpc_agent_sdk.types import Schema -from .._constants import SKILL_LOADED_STATE_KEY_PREFIX +from .._common import get_state_delta_value +from .._common import loaded_state_key +from .._constants import SKILL_ARTIFACTS_STATE_KEY from .._constants import SKILL_REPOSITORY_KEY from .._repository import BaseSkillRepository from .._utils import shell_quote from ..stager import SkillStageRequest from ..stager import Stager +from ..stager import default_workspace_skill_dir +from ._common import CreateWorkspaceNameCallback +from ._common import default_create_ws_name_callback +from ._common import get_staged_workspace_dir +from ._common import inline_json_schema_refs from ._copy_stager import CopySkillStager # --------------------------------------------------------------------------- @@ -68,7 +68,7 @@ _ENV_EDITOR = "EDITOR" _ENV_VISUAL = "VISUAL" -_EDITOR_HELPER_DIR = ".trpc_agent_sdk" +_EDITOR_HELPER_DIR = ".trpc_agent" _EDITOR_CONTENT_FILE = "editor_input.txt" _EDITOR_SCRIPT_FILE = "editor_write.sh" @@ -115,35 +115,6 @@ # --------------------------------------------------------------------------- -def _inline_json_schema_refs(schema: dict[str, Any]) -> dict[str, Any]: - """Inline $ref references in JSON Schema by replacing them with actual definitions.""" - defs = schema.get('$defs', {}) - if not defs: - return schema - - def resolve_ref(obj: Any) -> Any: - if isinstance(obj, dict): - if '$ref' in obj: - ref_path = obj['$ref'] - if ref_path.startswith('#/$defs/'): - ref_name = ref_path.replace('#/$defs/', '') - if 
ref_name in defs: - resolved = resolve_ref(defs[ref_name]) - merged = {**resolved, **{k: v for k, v in obj.items() if k != '$ref'}} - return merged - return obj - else: - return {k: resolve_ref(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [resolve_ref(item) for item in obj] - else: - return obj - - result = {k: v for k, v in schema.items() if k != '$defs'} - result = resolve_ref(result) - return result - - def _is_text_mime(mime: str) -> bool: """Return True when *mime* is a text-like content type.""" if not mime: @@ -176,6 +147,51 @@ def _workspace_ref(name: str) -> str: return f"workspace://{name}" if name else "" +def _normalize_input_dst(dst: str) -> str: + """Normalize declarative input destination like Go normalizeInputTo.""" + s = (dst or "").strip().replace("\\", "/") + if not s: + return "" + cleaned = posixpath.normpath(s) + if cleaned in (".", "inputs"): + return "" + prefix = "inputs/" + if cleaned.startswith(prefix): + rest = cleaned[len(prefix):] + return posixpath.join(DIR_WORK, "inputs", rest) + return cleaned + + +def _normalize_run_input(input_data: "SkillRunInput") -> "SkillRunInput": + """Normalize run input fields like Go normalizeRunInput.""" + if not input_data.inputs: + return input_data + normalized_inputs: list[WorkspaceInputSpec] = [] + for spec in input_data.inputs: + normalized_inputs.append(spec.model_copy(update={"dst": _normalize_input_dst(spec.dst)})) + return input_data.model_copy(update={"inputs": normalized_inputs}) + + +def _apply_run_artifacts_state_delta(ctx: InvocationContext, output: "SkillRunOutput") -> None: + """Write run artifact refs to state delta (Go RunTool.StateDelta parity).""" + tool_call_id = (ctx.function_call_id or "").strip() + if not tool_call_id or not output.artifact_files: + return + artifacts: list[dict[str, Any]] = [] + for item in output.artifact_files: + name = (item.name or "").strip() + version = int(item.version) + if not name or version < 0: + continue + 
artifacts.append({"name": name, "version": version, "ref": f"artifact://{name}@{version}"}) + if not artifacts: + return + ctx.actions.state_delta[SKILL_ARTIFACTS_STATE_KEY] = { + "tool_call_id": tool_call_id, + "artifacts": artifacts, + } + + def _filter_failed_empty_outputs( exit_code: int, timed_out: bool, @@ -215,7 +231,7 @@ def _split_command_line(cmd: str) -> list[str]: raise ValueError("skill_run: command is empty") for ch in _DISALLOWED_SHELL_META: if ch in cmd: - raise ValueError(f"skill_run: shell metacharacter {ch!r} is not allowed when " + raise ValueError(f"skill_run: shell meta character {ch!r} is not allowed when " "command restrictions are enabled. Provide a single executable " "with args only (no redirects/pipes/chaining).") args: list[str] = [] @@ -345,16 +361,6 @@ class SkillRunOutput(BaseModel): default_factory=list, description="Non-fatal warnings about truncation, persistence, or empty outputs", ) - suggested_commands: Optional[list[str]] = Field( - default=None, - description=("Suggested runnable commands extracted from SKILL.md when the provided " - "command is not found."), - ) - suggested_tools: Optional[list[str]] = Field( - default=None, - description=("Suggested tool names from SKILL.md Tools section when command-not-found " - "indicates the request should use tool calls instead of shell execution."), - ) # --------------------------------------------------------------------------- @@ -375,11 +381,11 @@ def __init__( filters: Optional[List[BaseFilter]] = None, *, require_skill_loaded: bool = False, - block_inline_python_rewrite: bool = False, force_save_artifacts: bool = False, allowed_cmds: Optional[List[str]] = None, denied_cmds: Optional[List[str]] = None, skill_stager: Optional[Stager] = None, + create_ws_name_cb: Optional[CreateWorkspaceNameCallback] = None, **kwargs, ): """Initialize SkillRunTool. @@ -389,9 +395,6 @@ def __init__( filters: Optional tool filters. 
require_skill_loaded: When True, skill_run raises unless skill_load was called first for this skill in the current session. - block_inline_python_rewrite: When True, reject ad-hoc ``python -c`` commands - if SKILL.md already provides script-based python - command examples (for example ``python3 scripts/foo.py``). force_save_artifacts: When True, always attempt to persist collected output files via the artifact service (if available). allowed_cmds: When set, only these command names (first token) are allowed. @@ -414,7 +417,6 @@ def __init__( ) self._repository = repository self._require_skill_loaded = require_skill_loaded - self._block_inline_python_rewrite = block_inline_python_rewrite self._force_save_artifacts = force_save_artifacts self._allowed_cmds: frozenset[str] = frozenset(c.strip() for c in (allowed_cmds or []) if c.strip()) self._denied_cmds: frozenset[str] = frozenset(c.strip() for c in (denied_cmds or []) if c.strip()) @@ -429,10 +431,17 @@ def __init__( self._kwargs = kwargs self._run_tool_kwargs: dict = kwargs.pop("run_tool_kwargs", {}) self._timeout = self._run_tool_kwargs.pop("timeout", 300.0) + self._create_ws_name_cb: Optional[CreateWorkspaceNameCallback] = \ + create_ws_name_cb or default_create_ws_name_callback # Staging strategy: default is copy-based stager (mirrors Go newCopySkillStager) self._skill_stager: Stager = skill_stager or CopySkillStager() + @property + def require_skill_loaded(self) -> bool: + """Get the require_skill_loaded flag.""" + return self._require_skill_loaded + @property def skill_stager(self) -> Stager: """Get the skill stager.""" @@ -444,8 +453,8 @@ def skill_stager(self) -> Stager: @override def _get_declaration(self) -> FunctionDeclaration: - params_schema = _inline_json_schema_refs(SkillRunInput.model_json_schema()) - response_schema = _inline_json_schema_refs(SkillRunOutput.model_json_schema()) + params_schema = inline_json_schema_refs(SkillRunInput.model_json_schema()) + response_schema = 
inline_json_schema_refs(SkillRunOutput.model_json_schema()) desc = ("Run a command inside a skill workspace. " "Use it only for commands required by the skill docs (not for generic shell tasks). " "User-uploaded file inputs are staged under $WORK_DIR/inputs (also visible as inputs/). " @@ -482,11 +491,10 @@ def _get_repository(self, context: InvocationContext) -> BaseSkillRepository: def _is_skill_loaded(self, ctx: InvocationContext, skill_name: str) -> bool: """Return True when the skill was loaded in the current session.""" try: - key = f"{SKILL_LOADED_STATE_KEY_PREFIX}{skill_name.strip()}" - val = ctx.session_state.get(key) - return bool(val) + key = loaded_state_key(ctx, skill_name.strip()) + return bool(get_state_delta_value(ctx, key)) except Exception: # pylint: disable=broad-except - return True # default to allowed when state is not accessible + return False # ------------------------------------------------------------------ # Editor helper @@ -639,6 +647,10 @@ async def _run_async_impl( inputs = SkillRunInput.model_validate(args) except Exception as ex: # pylint: disable=broad-except raise ValueError(f"Invalid skill_run arguments: {ex}") from ex + if not (inputs.skill or "").strip() or not (inputs.command or "").strip(): + raise ValueError("skill and command are required") + inputs = inputs.model_copy(update={"skill": inputs.skill.strip(), "command": inputs.command.strip()}) + inputs = _normalize_run_input(inputs) # require_skill_loaded gate if self._require_skill_loaded and not self._is_skill_loaded(tool_context, inputs.skill): @@ -653,25 +665,27 @@ async def _run_async_impl( repository = self._get_repository(tool_context) - session_id = inputs.skill - if tool_context.session and tool_context.session.id: - session_id = tool_context.session.id - + workspace_id = self._create_ws_name_cb(tool_context) workspace_runtime = repository.workspace_runtime manager = workspace_runtime.manager(tool_context) - ws = await manager.create_workspace(session_id, 
tool_context) - - # Stage skill via the pluggable stager strategy - stage_result = await self._skill_stager.stage_skill( - SkillStageRequest( - skill_name=inputs.skill, - repository=repository, - workspace=ws, - ctx=tool_context, - engine=workspace_runtime, - timeout=self._timeout, - )) - workspace_skill_dir = stage_result.workspace_skill_dir + ws = await manager.create_workspace(workspace_id, tool_context) + + # Static stage is handled by skill_load. Keep fallback staging here for + # backward compatibility when callers skip skill_load. + if self._is_skill_loaded(tool_context, inputs.skill): + workspace_skill_dir = get_staged_workspace_dir(tool_context, inputs.skill) + if not workspace_skill_dir: + workspace_skill_dir = default_workspace_skill_dir(inputs.skill) + else: + stage_result = await self._skill_stager.stage_skill( + SkillStageRequest( + skill_name=inputs.skill, + repository=repository, + workspace=ws, + ctx=tool_context, + timeout=self._timeout, + )) + workspace_skill_dir = stage_result.workspace_skill_dir if inputs.inputs: fs = workspace_runtime.fs(tool_context) @@ -705,16 +719,6 @@ async def _run_async_impl( # Select primary output primary = _select_primary_output(files) - suggested_commands = self._suggest_commands_for_missing_command( - result, - repository, - inputs.skill, - ) - suggested_tools = self._suggest_tools_for_missing_command( - result, - repository, - inputs.skill, - ) output = SkillRunOutput( stdout=stdout, @@ -725,20 +729,20 @@ async def _run_async_impl( output_files=files, primary_output=primary, warnings=warnings, - suggested_commands=suggested_commands, - suggested_tools=suggested_tools, ) await self._attach_artifacts_if_requested(tool_context, ws, inputs, output, files) self._merge_manifest_artifact_refs(manifest, output) # omit_inline_content - if inputs.omit_inline_content and output.artifact_files: + if inputs.omit_inline_content: for f in output.output_files: f.content = "" if output.primary_output: output.primary_output.content 
= "" + _apply_run_artifacts_state_delta(tool_context, output) + return output.model_dump(exclude_none=True) # ------------------------------------------------------------------ @@ -761,13 +765,15 @@ async def _run_program( # Inject skill-specific env from repository (e.g. api_key → primary_env) repository = self._get_repository(ctx) try: - skill_env: dict[str, str] = repository.skill_run_env(input_data.skill) + skill_env: dict[str, str] = repository.skill_run_env(ctx, input_data.skill) for k, v in skill_env.items(): k = k.strip() if not k or not v.strip(): continue if k in env: # don't override explicit tool-call env continue + if os.environ.get(k, "").strip(): # don't override host env + continue if k.upper() in _BLOCKED_SKILL_ENV_KEYS: continue env[k] = v @@ -778,9 +784,6 @@ async def _run_program( await self._prepare_editor_env(ctx, ws, env, input_data.editor_text) # Build command (with venv activation or command restrictions) - blocked = self._precheck_inline_python_rewrite(repository, input_data) - if blocked is not None: - return blocked cmd, cmd_args = self._build_command(input_data.command, ws.path, cwd) workspace_runtime = repository.workspace_runtime @@ -798,300 +801,10 @@ async def _run_program( ctx, ) if ret.exit_code != 0: - ret = self._with_missing_command_hint(ret, input_data) - ret = self._with_skill_doc_command_hint(ret, repository, input_data) - ret = self._with_missing_entrypoint_hint(ret, input_data, ws, cwd) - logger.info("Failed to run program: %s", ret.stderr) + raw_stderr = ret.stderr or "" + logger.info("Failed to run program: exit_code=%s, stderr=%s", ret.exit_code, raw_stderr.strip()) return ret - def _precheck_inline_python_rewrite( - self, - repository: BaseSkillRepository, - input_data: SkillRunInput, - ) -> Optional[WorkspaceRunResult]: - """Optionally block ad-hoc python -c when SKILL.md has script examples.""" - if not self._block_inline_python_rewrite: - return None - cmd = (input_data.command or "").strip() - if not 
re.match(r"^python(?:\d+(?:\.\d+)?)?\s+-c\b", cmd): - return None - try: - sk = repository.get(input_data.skill) - except Exception: # pylint: disable=broad-except - return None - if sk is None: - return None - examples = self._extract_shell_examples_from_skill_body(sk.body, limit=8) - script_examples = [ - e for e in examples if re.match(r"^python(?:\d+(?:\.\d+)?)?\s+scripts/[^\s]+(?:\s|$)", e.strip()) - ] - if not script_examples: - return None - stderr = ("skill_run rejected this ad-hoc inline python command.\n" - "This skill already provides script-based Python command examples in SKILL.md.\n" - "Use one of these commands exactly instead of `python -c` rewrites:\n" + - "\n".join(f"- {c}" for c in script_examples) + "\n") - return WorkspaceRunResult( - stdout="", - stderr=stderr, - exit_code=2, - duration=0, - timed_out=False, - ) - - @staticmethod - def _with_missing_command_hint(ret: WorkspaceRunResult, input_data: SkillRunInput) -> WorkspaceRunResult: - """Add a targeted hint when the shell command does not exist.""" - stderr_lower = (ret.stderr or "").lower() - is_missing_cmd = ret.exit_code == 127 and ("command not found" in stderr_lower or - "not recognized as an internal or external command" in stderr_lower) - if not is_missing_cmd: - return ret - hint = ("\n\nSkill command hint:\n" - f"- The command `{input_data.command}` was not found in this skill workspace.\n" - "- Do not invent command names.\n" - "- Read the loaded `SKILL.md` and execute one of its exact shell examples.\n" - "- If needed, call `skill_load` first so the full skill body is injected " - "before calling `skill_run`.\n") - return ret.model_copy(update={"stderr": f"{ret.stderr}{hint}"}) - - @staticmethod - def _extract_shell_examples_from_skill_body(body: str, limit: int = 5) -> list[str]: - """Extract likely runnable shell command lines from SKILL.md body.""" - if not body: - return [] - out: list[str] = [] - seen: set[str] = set() - lines = body.splitlines() - - def maybe_add(cmd: str) 
-> None: - cmd = re.sub(r"\s+", " ", (cmd or "").strip()) - if not cmd: - return - if cmd in seen: - return - # Keep only command-like lines. - if not re.match(r"^[A-Za-z0-9_./$\"'`-]", cmd): - return - # Skip function-style examples like tool_name(arg="..."), which are - # LLM tool calls rather than shell commands. - if re.match(r"^[A-Za-z_][A-Za-z0-9_]*\s*\(.*\)\s*$", cmd): - return - seen.add(cmd) - out.append(cmd) - - # 1) Parse markdown fenced code blocks. - in_fence = False - for raw in lines: - line = raw.strip() - if line.startswith("```"): - in_fence = not in_fence - continue - if not in_fence: - continue - if not line or line.startswith("#"): - continue - if line in ("PY", "EOF") or line.startswith(("-", "*")): - continue - maybe_add(line) - if len(out) >= limit: - return out - - # 2) Parse "Command:" sections with indented multi-line commands. - i = 0 - while i < len(lines) and len(out) < limit: - cur = lines[i].strip() - if cur.lower() != "command:": - i += 1 - continue - i += 1 - block: list[str] = [] - while i < len(lines): - raw = lines[i] - s = raw.strip() - if not s: - if block: - break - i += 1 - continue - # Next section marker - if s.lower() in ("command:", "output files", "overview", "examples", "tools:"): - break - if re.match(r"^\d+\)", s): - break - # Command content is commonly indented in SKILL.md examples. 
- if raw.startswith(" ") or raw.startswith("\t"): - block.append(s) - i += 1 - continue - if block: - break - i += 1 - if block: - merged = " ".join(part.rstrip("\\").strip() for part in block) - maybe_add(merged) - return out[:limit] - - def _with_skill_doc_command_hint( - self, - ret: WorkspaceRunResult, - repository: BaseSkillRepository, - input_data: SkillRunInput, - ) -> WorkspaceRunResult: - """Append SKILL.md command examples when command is not found.""" - stderr = ret.stderr or "" - stderr_lower = stderr.lower() - is_missing_cmd = ret.exit_code == 127 and ("command not found" in stderr_lower or - "not recognized as an internal or external command" in stderr_lower) - if not is_missing_cmd: - return ret - try: - sk = repository.get(input_data.skill) - except Exception: # pylint: disable=broad-except - return ret - examples = self._extract_shell_examples_from_skill_body(sk.body, limit=5) - tools = list(getattr(sk, "tools", []) or []) - if not examples and not tools: - return ret - hint_parts: list[str] = [] - if examples: - hint_parts.append("\n\nSKILL.md command examples:\n" + "\n".join(f"- {cmd}" for cmd in examples) + "\n") - if tools: - hint_parts.append("\nSKILL.md tools suggest this is a tool-call workflow:\n" + - "\n".join(f"- {name}" for name in tools) + "\n" + - "- Do not run these tool names via `skill_run` shell commands.\n" + - "- Call those tools directly (e.g. 
function/tool call) after `skill_load`.\n") - hint = "".join(hint_parts) - return ret.model_copy(update={"stderr": f"{stderr}{hint}"}) - - @staticmethod - def _is_missing_command_result(ret: WorkspaceRunResult) -> bool: - """Return True when stderr indicates command-not-found failure.""" - stderr_lower = (ret.stderr or "").lower() - return ret.exit_code == 127 and ("command not found" in stderr_lower - or "not recognized as an internal or external command" in stderr_lower) - - def _suggest_commands_for_missing_command( - self, - ret: WorkspaceRunResult, - repository: BaseSkillRepository, - skill_name: str, - ) -> Optional[list[str]]: - """Return SKILL.md command suggestions for missing-command failures.""" - if not self._is_missing_command_result(ret): - return None - try: - sk = repository.get(skill_name) - except Exception: # pylint: disable=broad-except - return None - suggestions = self._extract_shell_examples_from_skill_body(sk.body, limit=5) - return suggestions or None - - def _suggest_tools_for_missing_command( - self, - ret: WorkspaceRunResult, - repository: BaseSkillRepository, - skill_name: str, - ) -> Optional[list[str]]: - """Return SKILL.md tool names when command-not-found implies tool calls.""" - if not self._is_missing_command_result(ret): - return None - try: - sk = repository.get(skill_name) - except Exception: # pylint: disable=broad-except - return None - tools = list(getattr(sk, "tools", []) or []) - return tools or None - - @staticmethod - def _extract_command_path_candidates(command: str) -> list[str]: - """Extract likely relative path tokens from a shell command.""" - try: - tokens = shlex.split(command, posix=True) - except ValueError: - tokens = command.split() - out: list[str] = [] - seen: set[str] = set() - for tok in tokens: - s = (tok or "").strip() - if not s or s.startswith("-"): - continue - # Keep relative path-like tokens only. 
- is_path_like = "/" in s or s.endswith((".py", ".sh", ".pl", ".rb", ".js", ".ts")) - if not is_path_like or os.path.isabs(s): - continue - if s in seen: - continue - seen.add(s) - out.append(s) - return out - - @staticmethod - def _list_entrypoint_suggestions(run_dir: Path, limit: int = 20) -> list[str]: - """List likely runnable files from common locations under the current cwd.""" - suggestions: list[str] = [] - seen: set[str] = set() - roots = [ - run_dir, - run_dir / "scripts", - run_dir / "bin", - run_dir / "tools", - ] - for root in roots: - try: - if not root.is_dir(): - continue - for p in sorted(root.rglob("*")): - if len(suggestions) >= limit: - return suggestions - if not p.is_file(): - continue - rel = p.relative_to(run_dir).as_posix() - # Prefer obviously runnable entries. - is_runnable = os.access(p.as_posix(), os.X_OK) or rel.endswith( - (".py", ".sh", ".pl", ".rb", ".js", ".ts")) - if not is_runnable or rel in seen: - continue - seen.add(rel) - suggestions.append(rel) - except Exception: # pylint: disable=broad-except - continue - return suggestions - - @staticmethod - def _with_missing_entrypoint_hint( - ret: WorkspaceRunResult, - input_data: SkillRunInput, - ws: WorkspaceInfo, - cwd: str, - ) -> WorkspaceRunResult: - """Add a generic hint when command references missing relative entrypoints.""" - stderr = ret.stderr or "" - stderr_lower = stderr.lower() - if ret.exit_code == 0: - return ret - if "no such file or directory" not in stderr_lower and "can't open file" not in stderr_lower: - return ret - missing_candidates = SkillRunTool._extract_command_path_candidates(input_data.command) - if not missing_candidates: - return ret - - run_dir = Path(ws.path) / cwd - # Show only candidates that are truly missing under cwd. 
- missing = [p for p in missing_candidates if not (run_dir / p).exists()] - if not missing: - return ret - available = SkillRunTool._list_entrypoint_suggestions(run_dir, limit=20) - if not available: - return ret - - hint = ("\n\nSkill entrypoint hint:\n" - f"- Missing relative path(s) in command: {', '.join(f'`{m}`' for m in missing)}\n" - f"- Runnable file candidates from current skill cwd: {', '.join(available)}\n" - "- Do not invent file names or paths.\n" - "- Read the loaded `SKILL.md` and run one of those commands exactly.\n") - return ret.model_copy(update={"stderr": f"{stderr}{hint}"}) - def _resolve_cwd(self, cwd: str, skill_dir: str) -> str: """Resolve the working directory relative to the workspace root. diff --git a/trpc_agent_sdk/skills/tools/_skill_select_docs.py b/trpc_agent_sdk/skills/tools/_skill_select_docs.py index ad5c756..1e07dc4 100644 --- a/trpc_agent_sdk/skills/tools/_skill_select_docs.py +++ b/trpc_agent_sdk/skills/tools/_skill_select_docs.py @@ -15,8 +15,14 @@ from trpc_agent_sdk.context import InvocationContext from .._common import BaseSelectionResult -from .._common import generic_select_items -from .._constants import SKILL_DOCS_STATE_KEY_PREFIX +from .._common import append_loaded_order_state_delta +from .._common import docs_state_key +from .._common import get_agent_name +from .._common import get_previous_selection_by_key +from .._common import normalize_selection_mode +from .._common import set_selection_state_delta_by_key +from .._constants import SKILL_REPOSITORY_KEY +from .._repository import BaseSkillRepository class SkillSelectDocsResult(BaseSelectionResult): @@ -53,11 +59,58 @@ def skill_select_docs(tool_context: InvocationContext, Returns: A message indicating the docs were selected. 
""" - result = generic_select_items(tool_context=tool_context, - skill_name=skill_name, - items=docs, - include_all=include_all_docs, - mode=mode, - state_key_prefix=SKILL_DOCS_STATE_KEY_PREFIX, - result_class=SkillSelectDocsResult) + normalized_skill = (skill_name or "").strip() + if not normalized_skill: + raise ValueError("skill is required") + normalized_mode = normalize_selection_mode(mode) + agent_name = get_agent_name(tool_context) + + repository: Optional[BaseSkillRepository] = tool_context.agent_context.get_metadata(SKILL_REPOSITORY_KEY) + if repository is not None: + try: + _ = repository.get(normalized_skill) + except ValueError as ex: + raise ValueError(f"unknown skill: {normalized_skill}") from ex + + docs_selection_key = docs_state_key(tool_context, normalized_skill) + previous_items, had_all = get_previous_selection_by_key(tool_context, docs_selection_key) + if had_all and normalized_mode != "clear": + result = SkillSelectDocsResult( + skill=normalized_skill, + selected_items=[], + include_all=True, + mode=normalized_mode, + ) + elif normalized_mode == "clear": + result = SkillSelectDocsResult( + skill=normalized_skill, + selected_items=[], + include_all=False, + mode="clear", + ) + elif normalized_mode == "add": + selected = set(previous_items) + for item in docs or []: + selected.add(item) + result = SkillSelectDocsResult( + skill=normalized_skill, + selected_items=[] if include_all_docs else list(selected), + include_all=include_all_docs, + mode="add", + ) + else: + result = SkillSelectDocsResult( + skill=normalized_skill, + selected_items=[] if include_all_docs else list(docs or []), + include_all=include_all_docs, + mode="replace", + ) + + set_selection_state_delta_by_key( + tool_context, + docs_selection_key, + result.selected_docs, + result.include_all_docs, + ) + append_loaded_order_state_delta(tool_context, agent_name, normalized_skill) return result diff --git a/trpc_agent_sdk/skills/tools/_skill_select_tools.py 
b/trpc_agent_sdk/skills/tools/_skill_select_tools.py index 0c44b0f..f2e9745 100644 --- a/trpc_agent_sdk/skills/tools/_skill_select_tools.py +++ b/trpc_agent_sdk/skills/tools/_skill_select_tools.py @@ -15,8 +15,14 @@ from trpc_agent_sdk.context import InvocationContext from .._common import BaseSelectionResult -from .._common import generic_select_items -from .._constants import SKILL_TOOLS_STATE_KEY_PREFIX +from .._common import append_loaded_order_state_delta +from .._common import get_agent_name +from .._common import get_previous_selection_by_key +from .._common import normalize_selection_mode +from .._common import set_selection_state_delta_by_key +from .._common import tool_state_key +from .._constants import SKILL_REPOSITORY_KEY +from .._repository import BaseSkillRepository class SkillSelectToolsResult(BaseSelectionResult): @@ -73,11 +79,58 @@ def skill_select_tools(tool_context: InvocationContext, tools=["get_current_weather"], mode="replace") """ - result = generic_select_items(tool_context=tool_context, - skill_name=skill_name, - items=tools, - include_all=include_all_tools, - mode=mode, - state_key_prefix=SKILL_TOOLS_STATE_KEY_PREFIX, - result_class=SkillSelectToolsResult) + normalized_skill = (skill_name or "").strip() + if not normalized_skill: + raise ValueError("skill is required") + normalized_mode = normalize_selection_mode(mode) + agent_name = get_agent_name(tool_context) + + repository: Optional[BaseSkillRepository] = tool_context.agent_context.get_metadata(SKILL_REPOSITORY_KEY) + if repository is not None: + try: + _ = repository.get(normalized_skill) + except ValueError as ex: + raise ValueError(f"unknown skill: {normalized_skill}") from ex + + tools_selection_key = tool_state_key(tool_context, normalized_skill) + previous_items, had_all = get_previous_selection_by_key(tool_context, tools_selection_key) + if had_all and normalized_mode != "clear": + result = SkillSelectToolsResult( + skill=normalized_skill, + selected_items=[], + 
include_all=True, + mode=normalized_mode, + ) + elif normalized_mode == "clear": + result = SkillSelectToolsResult( + skill=normalized_skill, + selected_items=[], + include_all=False, + mode="clear", + ) + elif normalized_mode == "add": + selected = set(previous_items) + for item in tools or []: + selected.add(item) + result = SkillSelectToolsResult( + skill=normalized_skill, + selected_items=[] if include_all_tools else list(selected), + include_all=include_all_tools, + mode="add", + ) + else: + result = SkillSelectToolsResult( + skill=normalized_skill, + selected_items=[] if include_all_tools else list(tools or []), + include_all=include_all_tools, + mode="replace", + ) + + set_selection_state_delta_by_key( + tool_context, + tools_selection_key, + result.selected_tools, + result.include_all_tools, + ) + append_loaded_order_state_delta(tool_context, agent_name, normalized_skill) return result diff --git a/trpc_agent_sdk/skills/tools/_workspace_exec.py b/trpc_agent_sdk/skills/tools/_workspace_exec.py new file mode 100644 index 0000000..395457f --- /dev/null +++ b/trpc_agent_sdk/skills/tools/_workspace_exec.py @@ -0,0 +1,518 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+"""Shared executor workspace tools.""" + +from __future__ import annotations + +import posixpath +import time +from dataclasses import dataclass +from typing import Any +from typing import Optional +from typing_extensions import override + +from pydantic import BaseModel +from pydantic import Field +from trpc_agent_sdk.code_executors import BaseCodeExecutor +from trpc_agent_sdk.code_executors import BaseProgramSession +from trpc_agent_sdk.code_executors import BaseWorkspaceRuntime +from trpc_agent_sdk.code_executors import DEFAULT_EXEC_YIELD_MS +from trpc_agent_sdk.code_executors import DEFAULT_SESSION_KILL_SEC +from trpc_agent_sdk.code_executors import DEFAULT_SESSION_TTL_SEC +from trpc_agent_sdk.code_executors import DIR_OUT +from trpc_agent_sdk.code_executors import DIR_RUNS +from trpc_agent_sdk.code_executors import DIR_SKILLS +from trpc_agent_sdk.code_executors import DIR_WORK +from trpc_agent_sdk.code_executors import PROGRAM_STATUS_EXITED +from trpc_agent_sdk.code_executors import PROGRAM_STATUS_RUNNING +from trpc_agent_sdk.code_executors import ProgramPoll +from trpc_agent_sdk.code_executors import WorkspaceRunProgramSpec +from trpc_agent_sdk.code_executors import poll_line_limit +from trpc_agent_sdk.code_executors import wait_for_program_output +from trpc_agent_sdk.code_executors import yield_duration_ms +from trpc_agent_sdk.code_executors.utils import normalize_globs +from trpc_agent_sdk.context import InvocationContext +from trpc_agent_sdk.filter import BaseFilter +from trpc_agent_sdk.tools import BaseTool +from trpc_agent_sdk.types import FunctionDeclaration +from trpc_agent_sdk.types import Schema +from trpc_agent_sdk.types import Type + +from ._common import CreateWorkspaceNameCallback +from ._common import cleanup_expired_sessions +from ._common import default_create_ws_name_callback +from ._common import require_non_empty + +_DEFAULT_WORKSPACE_EXEC_TIMEOUT_SEC = 5 * 60 +_DEFAULT_WORKSPACE_WRITE_YIELD_MS = 200 + + +def _combine_output(stdout: str, 
stderr: str) -> str: + if not stdout: + return stderr + if not stderr: + return stdout + return stdout + stderr + + +def _has_glob_meta(s: str) -> bool: + return any(ch in s for ch in ("*", "?", "[")) + + +def _has_env_prefix(s: str, name: str) -> bool: + if s.startswith(f"${name}"): + tail = s[len(name) + 1:] + return tail == "" or tail.startswith("/") or tail.startswith("\\") + brace = f"${{{name}}}" + if s.startswith(brace): + tail = s[len(brace):] + return tail == "" or tail.startswith("/") or tail.startswith("\\") + return False + + +def _is_workspace_env_path(s: str) -> bool: + return (_has_env_prefix(s, "WORKSPACE_DIR") or _has_env_prefix(s, "SKILLS_DIR") or _has_env_prefix(s, "WORK_DIR") + or _has_env_prefix(s, "OUTPUT_DIR") or _has_env_prefix(s, "RUN_DIR")) + + +def _is_allowed_workspace_path(rel: str) -> bool: + return any(rel == root or rel.startswith(f"{root}/") for root in (DIR_SKILLS, DIR_WORK, DIR_OUT, DIR_RUNS)) + + +def _normalize_cwd(raw: str) -> str: + s = (raw or "").strip().replace("\\", "/") + if not s: + return "." + if _has_glob_meta(s): + raise ValueError("cwd must not contain glob patterns") + if _is_workspace_env_path(s): + out = normalize_globs([s]) + if not out: + raise ValueError("invalid cwd") + s = out[0] + if s.startswith("/"): + rel = posixpath.normpath(s).lstrip("/") + if rel in ("", "."): + return "." + if not _is_allowed_workspace_path(rel): + raise ValueError(f"cwd must stay under workspace roots: {raw!r}") + return rel + rel = posixpath.normpath(s) + if rel == ".": + return "." + if rel == ".." 
or rel.startswith("../"): + raise ValueError("cwd must stay within the workspace") + if not _is_allowed_workspace_path(rel): + raise ValueError( + f"cwd must stay under supported workspace roots such as skills/, work/, out/, or runs/: {raw!r}") + return rel + + +def _exec_timeout_seconds(raw: int) -> float: + if raw <= 0: + return float(_DEFAULT_WORKSPACE_EXEC_TIMEOUT_SEC) + return float(raw) + + +def _exec_yield_seconds(background: bool, raw_ms: Optional[int]) -> float: + if background: + if raw_ms is not None and raw_ms > 0: + return raw_ms / 1000.0 + return 0.0 + return yield_duration_ms(raw_ms or 0, DEFAULT_EXEC_YIELD_MS) + + +def _write_yield_seconds(raw_ms: Optional[int]) -> float: + if raw_ms is None: + return _DEFAULT_WORKSPACE_WRITE_YIELD_MS / 1000.0 + if raw_ms < 0: + return 0.0 + return raw_ms / 1000.0 + + +class _ExecInput(BaseModel): + command: str = Field(default="") + cwd: str = Field(default="") + env: dict[str, str] = Field(default_factory=dict) + stdin: str = Field(default="") + yield_time_ms: int = Field(default=0) + background: bool = Field(default=False) + timeout_sec: int = Field(default=0) + tty: bool = Field(default=False) + + +class _WriteInput(BaseModel): + session_id: str = Field(default="") + chars: str = Field(default="") + yield_time_ms: Optional[int] = Field(default=None) + append_newline: bool = Field(default=False) + + +class _KillInput(BaseModel): + session_id: str = Field(default="") + + +@dataclass +class _ExecSession: + proc: BaseProgramSession + exited_at: Optional[float] = None + finalized: bool = False + finalized_at: Optional[float] = None + + +class WorkspaceExecTool(BaseTool): + """Execute shell commands in shared executor workspace.""" + + def __init__( + self, + workspace_runtime: BaseWorkspaceRuntime, + create_ws_name_cb: Optional[CreateWorkspaceNameCallback] = None, + session_ttl: float = DEFAULT_SESSION_TTL_SEC, + filters_name: Optional[list[str]] = None, + filters: Optional[list[BaseFilter]] = None, + ): + 
super().__init__( + name="workspace_exec", + description=("Execute a shell command inside the current executor workspace. " + "This is the default shell runner for executor-side work not bound to a specific skill."), + filters_name=filters_name, + filters=filters, + ) + self._workspace_runtime = workspace_runtime + self._create_ws_name_cb = create_ws_name_cb or default_create_ws_name_callback + self._ttl = session_ttl + self._sessions: dict[str, _ExecSession] = {} + + def _runtime(self) -> BaseWorkspaceRuntime: + runtime = self._workspace_runtime + if runtime is None: + raise ValueError("workspace_exec requires an executor with live workspace support") + return runtime + + async def _workspace(self, ctx: InvocationContext): + runtime = self._runtime() + manager = runtime.manager(ctx) + workspace_id = self._create_ws_name_cb(ctx) + ws = await manager.create_workspace(workspace_id, ctx) + return runtime, ws + + def _supports_interactive(self, ctx: InvocationContext) -> bool: + runner = self._runtime().runner(ctx) + start_program = getattr(runner, "start_program", None) + return start_program is not None + + async def _put_session(self, sid: str, session: _ExecSession) -> None: + await self._cleanup_expired_locked() + self._sessions[sid] = session + + async def _get_session(self, sid: str) -> _ExecSession: + await self._cleanup_expired_locked() + session = self._sessions.get(sid) + if session is None: + raise ValueError(f"unknown session_id: {sid}") + return session + + async def _remove_session(self, sid: str) -> _ExecSession: + await self._cleanup_expired_locked() + session = self._sessions.pop(sid, None) + if session is None: + raise ValueError(f"unknown session_id: {sid}") + return session + + async def _finalize_and_remove_session(self, sid: str) -> None: + session = await self._get_session(sid) + now = time.time() + session.finalized = True + session.finalized_at = now + session.exited_at = now + await session.proc.close() + await self._remove_session(sid) + + 
async def _cleanup_expired_locked(self) -> None: + + async def _refresh_exit_state(session: _ExecSession, now: float) -> None: + if session.exited_at is not None: + return + session_state = await session.proc.state() + if session_state.status == PROGRAM_STATUS_EXITED: + session.exited_at = now + + await cleanup_expired_sessions( + self._sessions, + ttl=self._ttl, + refresh_exit_state=_refresh_exit_state, + close_session=lambda s: s.proc.close(), + ) + + @override + def _get_declaration(self) -> FunctionDeclaration: + return FunctionDeclaration( + name="workspace_exec", + description=("Execute a shell command in the shared executor workspace. " + "Use for workspace-level file operations and validation commands."), + parameters=Schema( + type=Type.OBJECT, + required=["command"], + properties={ + "command": Schema(type=Type.STRING, description="Shell command to execute."), + "cwd": Schema(type=Type.STRING, description="Optional workspace-relative cwd."), + "env": Schema(type=Type.OBJECT, description="Optional environment overrides."), + "stdin": Schema(type=Type.STRING, description="Optional initial stdin text."), + "timeout_sec": Schema(type=Type.INTEGER, description="Maximum command runtime in seconds."), + "yield_time_ms": Schema(type=Type.INTEGER, description="How long to wait before returning."), + "background": Schema(type=Type.BOOLEAN, description="Start command in background session."), + "tty": Schema(type=Type.BOOLEAN, description="Allocate TTY for interactive commands."), + }, + ), + response=_exec_output_schema( + "Result of workspace_exec. 
output is aggregated terminal text and may combine stdout/stderr."), + ) + + @override + async def _run_async_impl(self, *, tool_context: InvocationContext, args: dict[str, Any]) -> Any: + inputs = _ExecInput.model_validate(args) + command = (inputs.command or "").strip() + if not command: + raise ValueError("command is required") + cwd = _normalize_cwd(inputs.cwd) + timeout_raw = inputs.timeout_sec + tty = inputs.tty + yield_ms = inputs.yield_time_ms + + runtime, ws = await self._workspace(tool_context) + spec = WorkspaceRunProgramSpec( + cmd="sh", + args=["-lc", command], + env=inputs.env, + cwd=cwd, + stdin=inputs.stdin, + timeout=_exec_timeout_seconds(timeout_raw), + ) + + if not self._supports_interactive(tool_context): + if inputs.background or tty: + raise ValueError("workspace_exec interactive sessions are not supported by the current executor") + return await _run_one_shot(runtime, ws, spec, tool_context) + + if (not inputs.background) and (not tty) and yield_ms <= 0: + return await _run_one_shot(runtime, ws, spec, tool_context) + + runner = runtime.runner(tool_context) + interactive_spec = WorkspaceRunProgramSpec( + cmd=spec.cmd, + args=spec.args, + env=spec.env, + cwd=spec.cwd, + stdin=spec.stdin, + timeout=spec.timeout, + limits=spec.limits, + tty=tty, + ) + proc = await runner.start_program(tool_context, ws, interactive_spec) # type: ignore[attr-defined] + sid = proc.id() + await self._put_session(sid, _ExecSession(proc=proc)) + + if inputs.background and yield_ms <= 0: + poll = await proc.poll(poll_line_limit(0)) + else: + poll = await wait_for_program_output( + proc, + _exec_yield_seconds(inputs.background, yield_ms if yield_ms > 0 else None), + poll_line_limit(0), + ) + + out = _poll_output(sid, poll) + if poll.status == PROGRAM_STATUS_EXITED: + try: + await self._finalize_and_remove_session(sid) + except Exception: # pylint: disable=broad-except + out["session_id"] = sid + return out + + +class WorkspaceWriteStdinTool(BaseTool): + """Write stdin to 
a running workspace_exec session or poll it.""" + + def __init__( + self, + exec_tool: WorkspaceExecTool, + *, + filters_name: Optional[list[str]] = None, + filters: Optional[list[BaseFilter]] = None, + ): + super().__init__( + name="workspace_write_stdin", + description="Write to a running workspace_exec session. Empty chars acts like poll.", + filters_name=filters_name, + filters=filters, + ) + self._exec = exec_tool + + @override + def _get_declaration(self) -> FunctionDeclaration: + return FunctionDeclaration( + name="workspace_write_stdin", + description="Write to a running workspace_exec session. Empty chars acts like poll.", + parameters=Schema( + type=Type.OBJECT, + required=["session_id"], + properties={ + "session_id": Schema(type=Type.STRING, description="Session id returned by workspace_exec."), + "chars": Schema(type=Type.STRING, description="Characters to write."), + "yield_time_ms": Schema(type=Type.INTEGER, description="Optional wait before polling output."), + "append_newline": Schema(type=Type.BOOLEAN, description="Append newline after chars."), + }, + ), + response=_exec_output_schema("Result of stdin write or follow-up poll."), + ) + + @override + async def _run_async_impl(self, *, tool_context: InvocationContext, args: dict[str, Any]) -> Any: + inputs = _WriteInput.model_validate(args) + session_id = require_non_empty(inputs.session_id, field_name="session_id") + session = await self._exec._get_session(session_id) + + append_newline = inputs.append_newline + if inputs.chars or append_newline: + await session.proc.write(inputs.chars or "", append_newline) + + yield_ms = inputs.yield_time_ms or 0 + user_set_yield = inputs.yield_time_ms is not None + poll = await wait_for_program_output( + session.proc, + _write_yield_seconds(yield_ms if user_set_yield else None), + poll_line_limit(0), + ) + out = _poll_output(session_id, poll) + if poll.status == PROGRAM_STATUS_EXITED: + try: + await self._exec._finalize_and_remove_session(session_id) + except 
Exception: # pylint: disable=broad-except + out["session_id"] = session_id + return out + + +class WorkspaceKillSessionTool(BaseTool): + """Terminate a running workspace_exec session.""" + + def __init__( + self, + exec_tool: WorkspaceExecTool, + *, + filters_name: Optional[list[str]] = None, + filters: Optional[list[BaseFilter]] = None, + ): + super().__init__( + name="workspace_kill_session", + description="Terminate a running workspace_exec session.", + filters_name=filters_name, + filters=filters, + ) + self._exec = exec_tool + + @override + def _get_declaration(self) -> FunctionDeclaration: + return FunctionDeclaration( + name="workspace_kill_session", + description="Terminate a running workspace_exec session.", + parameters=Schema( + type=Type.OBJECT, + required=["session_id"], + properties={ + "session_id": Schema(type=Type.STRING, description="Session id returned by workspace_exec."), + }, + ), + response=Schema( + type=Type.OBJECT, + required=["ok", "session_id", "status"], + properties={ + "ok": Schema(type=Type.BOOLEAN, description="True when session was removed."), + "session_id": Schema(type=Type.STRING, description="Session id."), + "status": Schema(type=Type.STRING, description="Final status."), + }, + ), + ) + + @override + async def _run_async_impl(self, *, tool_context: InvocationContext, args: dict[str, Any]) -> Any: + if "session_id" not in args: + raise ValueError("session_id is required") + session_id = args["session_id"] + session = await self._exec._get_session(session_id) + status = PROGRAM_STATUS_EXITED + poll = await session.proc.poll(None) + if poll.status == PROGRAM_STATUS_RUNNING: + await session.proc.kill(DEFAULT_SESSION_KILL_SEC) + status = "killed" + await self._exec._finalize_and_remove_session(session_id) + return {"ok": True, "session_id": session_id, "status": status} + + +def _exec_output_schema(description: str) -> Schema: + return Schema( + type=Type.OBJECT, + description=description, + required=["status", "offset", 
"next_offset"], +        properties={ +            "status": Schema(type=Type.STRING, description="running or exited"), +            "output": Schema(type=Type.STRING, description="Aggregated terminal text observed for this call."), +            "exit_code": Schema(type=Type.INTEGER, description="Exit code when session has exited."), +            "session_id": Schema(type=Type.STRING, description="Interactive session id when still running."), +            "offset": Schema(type=Type.INTEGER, description="Start offset of returned output."), +            "next_offset": Schema(type=Type.INTEGER, description="Next output offset."), +        }, +    ) + + +def _poll_output(session_id: str, poll: ProgramPoll) -> dict[str, Any]: +    out: dict[str, Any] = { +        "status": poll.status, +        "output": poll.output, +        "offset": poll.offset, +        "next_offset": poll.next_offset, +    } +    if poll.exit_code is not None: +        out["exit_code"] = poll.exit_code +    if poll.status == PROGRAM_STATUS_RUNNING: +        out["session_id"] = session_id +    return out + + +async def _run_one_shot( +    runtime: BaseWorkspaceRuntime, +    workspace: Any, +    spec: WorkspaceRunProgramSpec, +    ctx: InvocationContext, +) -> dict[str, Any]: +    rr = await runtime.runner(ctx).run_program(workspace, spec, ctx) +    return { +        "status": PROGRAM_STATUS_EXITED, +        "output": _combine_output(rr.stdout, rr.stderr), +        "exit_code": rr.exit_code, +        "offset": 0, +        "next_offset": 0, +    } + + +def create_workspace_exec_tools( +    code_executor: BaseCodeExecutor, +    *, +    workspace_runtime: Optional[BaseWorkspaceRuntime] = None, +    session_ttl: float = DEFAULT_SESSION_TTL_SEC, +    filters_name: Optional[list[str]] = None, +    filters: Optional[list[BaseFilter]] = None, +) -> tuple[WorkspaceExecTool, WorkspaceWriteStdinTool, WorkspaceKillSessionTool]: +    """Create workspace_exec tool trio.""" +    exec_tool = WorkspaceExecTool( +        # NOTE(review): WorkspaceExecTool.__init__ has no 'code_executor' parameter; forwarding it raised TypeError. +        workspace_runtime=workspace_runtime, +        session_ttl=session_ttl, +        filters_name=filters_name, +        filters=filters, +    ) +    write_tool = WorkspaceWriteStdinTool(exec_tool, filters_name=filters_name,
filters=filters) + kill_tool = WorkspaceKillSessionTool(exec_tool, filters_name=filters_name, filters=filters) + return exec_tool, write_tool, kill_tool diff --git a/trpc_agent_sdk/storage/_sql_common.py b/trpc_agent_sdk/storage/_sql_common.py index d9d4c04..8f5ef17 100644 --- a/trpc_agent_sdk/storage/_sql_common.py +++ b/trpc_agent_sdk/storage/_sql_common.py @@ -172,6 +172,8 @@ class DynamicPickleType(TypeDecorator): def load_dialect_impl(self, dialect: Dialect) -> TypeDecorator: if dialect.name == "spanner+spanner": return dialect.type_descriptor(SpannerPickleType) # type: ignore + if dialect.name == "mysql": + return dialect.type_descriptor(mysql.LONGBLOB) # type: ignore return self.impl def process_bind_param(self, value: Any, dialect: Dialect) -> Any: