diff --git a/astrbot/core/star/star_tools.py b/astrbot/core/star/star_tools.py index fe5563b7dd..0a390c4bfe 100644 --- a/astrbot/core/star/star_tools.py +++ b/astrbot/core/star/star_tools.py @@ -20,8 +20,9 @@ import inspect import os import uuid -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Mapping from pathlib import Path +from types import ModuleType from typing import Any, ClassVar from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType @@ -35,10 +36,132 @@ AiocqhttpAdapter, ) from astrbot.core.star.context import Context -from astrbot.core.star.star import star_map -from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.star.star import StarMetadata, star_map +from astrbot.core.utils.astrbot_path import ( + get_astrbot_data_path, + get_astrbot_path, + get_astrbot_plugin_path, +) from astrbot.core.utils.io import ensure_dir +_PLUGIN_MODULE_FLAGS = {"plugins", "builtin_stars"} + + +def _split_module_path(module_path: str | None) -> list[str]: + if not module_path: + return [] + return module_path.split(".") + + +def _plugin_root_from_module_path(module_path: str | None) -> tuple[str, str] | None: + parts = _split_module_path(module_path) + for index, part in enumerate(parts): + if part in _PLUGIN_MODULE_FLAGS and index + 1 < len(parts): + return part, parts[index + 1] + return None + + +def _metadata_root_dir_name( + metadata: StarMetadata, + module_path: str | None, +) -> str | None: + if metadata.root_dir_name: + return metadata.root_dir_name + + root_info = _plugin_root_from_module_path(metadata.module_path or module_path) + return root_info[1] if root_info else None + + +def _iter_star_metadata( + stars: Mapping[str, StarMetadata], +) -> list[tuple[str, StarMetadata]]: + seen: set[int] = set() + metadata_items: list[tuple[str, StarMetadata]] = [] + for module_path, metadata in reversed(tuple(stars.items())): + metadata_id = id(metadata) + if metadata_id in seen: + continue + seen.add(metadata_id) + metadata_items.append((module_path, metadata)) + return metadata_items + + +def _resolve_plugin_from_root_dir( + root_dir_name: str, + stars: Mapping[str, StarMetadata], + module_flag: str | None = None, +) -> StarMetadata | None: + for module_path, metadata in _iter_star_metadata(stars): + registered_module_path = metadata.module_path or module_path + registered_root = _plugin_root_from_module_path(registered_module_path) + if module_flag and registered_root and registered_root[0] != module_flag: + continue + if _metadata_root_dir_name(metadata, module_path) == root_dir_name: + return metadata + return None + + +def _resolve_plugin_from_registered_package( + module_path: str, + stars: Mapping[str, StarMetadata], +) -> StarMetadata | None: + root_info = _plugin_root_from_module_path(module_path) + if not root_info: + return None + + module_flag, root_dir_name = root_info + return _resolve_plugin_from_root_dir(root_dir_name, stars, module_flag) + + +def _plugin_search_roots() -> tuple[tuple[str, Path], ...]: + return ( + ("plugins", Path(get_astrbot_plugin_path()).resolve()), + ( + "builtin_stars", + Path(get_astrbot_path()).resolve() / "astrbot" / "builtin_stars", + ), + ) + + +def _resolve_plugin_from_file_path( + module: ModuleType, + stars: Mapping[str, StarMetadata], +) -> StarMetadata | None: + module_file = getattr(module, "__file__", None) + if not module_file: + return None + + try: + module_path = Path(module_file).resolve() + except Exception: + return None + + for module_flag, plugin_root in _plugin_search_roots(): + try: + relative_parts = module_path.relative_to(plugin_root).parts + except ValueError: + continue + + if relative_parts: + return _resolve_plugin_from_root_dir( + relative_parts[0], + stars, + module_flag, + ) + + return None + + +def _resolve_plugin_metadata( + module: ModuleType, + stars: Mapping[str, StarMetadata], +) -> StarMetadata | None: + return ( + stars.get(module.__name__) + or _resolve_plugin_from_registered_package(module.__name__, stars) + or _resolve_plugin_from_file_path(module, stars) + ) + class StarTools: """提供给插件使用的便捷工具函数集合 @@ -291,7 +414,7 @@ def get_data_dir(cls, plugin_name: str | None = None) -> Path: if not module: raise RuntimeError("无法获取调用者模块信息") - metadata = star_map.get(module.__name__, None) + metadata = _resolve_plugin_metadata(module, star_map) if not metadata: raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息") diff --git a/tests/unit/test_star_tools.py b/tests/unit/test_star_tools.py new file mode 100644 index 0000000000..8b4578ea7f --- /dev/null +++ b/tests/unit/test_star_tools.py @@ -0,0 +1,79 @@ +from types import ModuleType + +import pytest + +from astrbot.core.star import star_tools +from astrbot.core.star.star import StarMetadata, star_map +from astrbot.core.star.star_tools import StarTools + + +@pytest.fixture(autouse=True) +def restore_star_map(): + original_map = dict(star_map) + star_map.clear() + try: + yield + finally: + star_map.clear() + star_map.update(original_map) + + +def make_module(name: str, file_path: str | None = None) -> ModuleType: + module = ModuleType(name) + if file_path: + module.__file__ = file_path + return module + + +def set_caller_module(monkeypatch: pytest.MonkeyPatch, module: ModuleType) -> None: + monkeypatch.setattr(star_tools.inspect, "getmodule", lambda _frame: module) + + +def test_get_data_dir_resolves_registered_plugin_submodule(monkeypatch, tmp_path): + data_path = tmp_path / "data" + monkeypatch.setattr(star_tools, "get_astrbot_data_path", lambda: str(data_path)) + set_caller_module( + monkeypatch, + make_module("data.plugins.demo_plugin.services.cache"), + ) + star_map["data.plugins.demo_plugin.main"] = StarMetadata( + name="demo", + module_path="data.plugins.demo_plugin.main", + root_dir_name="demo_plugin", + ) + + data_dir = StarTools.get_data_dir() + + assert data_dir == (data_path / "plugin_data" / "demo").resolve() + + +def test_get_data_dir_resolves_debug_module_from_plugin_path(monkeypatch, tmp_path): + data_path = tmp_path / "data" + plugin_root = tmp_path / "plugins" + debug_file = plugin_root / "demo_plugin" / "scripts" / "debug.py" + debug_file.parent.mkdir(parents=True) + debug_file.write_text("", encoding="utf-8") + monkeypatch.setattr(star_tools, "get_astrbot_data_path", lambda: str(data_path)) + monkeypatch.setattr(star_tools, "get_astrbot_plugin_path", lambda: str(plugin_root)) + monkeypatch.setattr(star_tools, "get_astrbot_path", lambda: str(tmp_path / "src")) + set_caller_module(monkeypatch, make_module("__main__", str(debug_file))) + star_map["data.plugins.demo_plugin.main"] = StarMetadata( + name="demo", + module_path="data.plugins.demo_plugin.main", + root_dir_name="demo_plugin", + ) + + data_dir = StarTools.get_data_dir() + + assert data_dir == (data_path / "plugin_data" / "demo").resolve() + + +def test_get_data_dir_keeps_unknown_module_failure(monkeypatch, tmp_path): + data_path = tmp_path / "data" + monkeypatch.setattr(star_tools, "get_astrbot_data_path", lambda: str(data_path)) + set_caller_module(monkeypatch, make_module("external.module")) + + with pytest.raises(RuntimeError, match="无法获取模块 external.module 的元数据信息"): + StarTools.get_data_dir() + + assert not (data_path / "plugin_data" / "unknown_plugin").exists()