Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 127 additions & 4 deletions astrbot/core/star/star_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
import inspect
import os
import uuid
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Mapping
from pathlib import Path
from types import ModuleType
from typing import Any, ClassVar

from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
Expand All @@ -35,10 +36,132 @@
AiocqhttpAdapter,
)
from astrbot.core.star.context import Context
from astrbot.core.star.star import star_map
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.star.star import StarMetadata, star_map
from astrbot.core.utils.astrbot_path import (
get_astrbot_data_path,
get_astrbot_path,
get_astrbot_plugin_path,
)
from astrbot.core.utils.io import ensure_dir

_PLUGIN_MODULE_FLAGS = {"plugins", "builtin_stars"}


def _split_module_path(module_path: str | None) -> list[str]:
if not module_path:
return []
return module_path.split(".")


def _plugin_root_from_module_path(module_path: str | None) -> tuple[str, str] | None:
parts = _split_module_path(module_path)
for index, part in enumerate(parts):
if part in _PLUGIN_MODULE_FLAGS and index + 1 < len(parts):
return part, parts[index + 1]
return None


def _metadata_root_dir_name(
metadata: StarMetadata,
module_path: str | None,
) -> str | None:
if metadata.root_dir_name:
return metadata.root_dir_name

root_info = _plugin_root_from_module_path(metadata.module_path or module_path)
return root_info[1] if root_info else None


def _iter_star_metadata(
stars: Mapping[str, StarMetadata],
) -> list[tuple[str, StarMetadata]]:
seen: set[int] = set()
metadata_items: list[tuple[str, StarMetadata]] = []
for module_path, metadata in reversed(tuple(stars.items())):
metadata_id = id(metadata)
if metadata_id in seen:
continue
seen.add(metadata_id)
metadata_items.append((module_path, metadata))
return metadata_items


def _resolve_plugin_from_root_dir(
root_dir_name: str,
stars: Mapping[str, StarMetadata],
module_flag: str | None = None,
) -> StarMetadata | None:
for module_path, metadata in _iter_star_metadata(stars):
registered_module_path = metadata.module_path or module_path
registered_root = _plugin_root_from_module_path(registered_module_path)
if module_flag and registered_root and registered_root[0] != module_flag:
continue
if _metadata_root_dir_name(metadata, module_path) == root_dir_name:
return metadata
return None


def _resolve_plugin_from_registered_package(
module_path: str,
stars: Mapping[str, StarMetadata],
) -> StarMetadata | None:
root_info = _plugin_root_from_module_path(module_path)
if not root_info:
return None

module_flag, root_dir_name = root_info
return _resolve_plugin_from_root_dir(root_dir_name, stars, module_flag)


def _plugin_search_roots() -> tuple[tuple[str, Path], ...]:
return (
("plugins", Path(get_astrbot_plugin_path()).resolve()),
(
"builtin_stars",
Path(get_astrbot_path()).resolve() / "astrbot" / "builtin_stars",
),
)


def _resolve_plugin_from_file_path(
module: ModuleType,
stars: Mapping[str, StarMetadata],
) -> StarMetadata | None:
module_file = getattr(module, "__file__", None)
if not module_file:
return None

try:
module_path = Path(module_file).resolve()
except Exception:
return None

for module_flag, plugin_root in _plugin_search_roots():
try:
relative_parts = module_path.relative_to(plugin_root).parts
except ValueError:
continue

if relative_parts:
return _resolve_plugin_from_root_dir(
relative_parts[0],
stars,
module_flag,
)

return None


def _resolve_plugin_metadata(
module: ModuleType,
stars: Mapping[str, StarMetadata],
) -> StarMetadata | None:
return (
stars.get(module.__name__)
or _resolve_plugin_from_registered_package(module.__name__, stars)
or _resolve_plugin_from_file_path(module, stars)
)


class StarTools:
"""提供给插件使用的便捷工具函数集合
Expand Down Expand Up @@ -291,7 +414,7 @@ def get_data_dir(cls, plugin_name: str | None = None) -> Path:
if not module:
raise RuntimeError("无法获取调用者模块信息")

metadata = star_map.get(module.__name__, None)
metadata = _resolve_plugin_metadata(module, star_map)

if not metadata:
raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息")
Expand Down
79 changes: 79 additions & 0 deletions tests/unit/test_star_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from types import ModuleType

import pytest

from astrbot.core.star import star_tools
from astrbot.core.star.star import StarMetadata, star_map
from astrbot.core.star.star_tools import StarTools


@pytest.fixture(autouse=True)
def restore_star_map():
original_map = dict(star_map)
star_map.clear()
try:
yield
finally:
star_map.clear()
star_map.update(original_map)


def make_module(name: str, file_path: str | None = None) -> ModuleType:
module = ModuleType(name)
if file_path:
module.__file__ = file_path
return module


def set_caller_module(monkeypatch: pytest.MonkeyPatch, module: ModuleType) -> None:
monkeypatch.setattr(star_tools.inspect, "getmodule", lambda _frame: module)


def test_get_data_dir_resolves_registered_plugin_submodule(monkeypatch, tmp_path):
data_path = tmp_path / "data"
monkeypatch.setattr(star_tools, "get_astrbot_data_path", lambda: str(data_path))
set_caller_module(
monkeypatch,
make_module("data.plugins.demo_plugin.services.cache"),
)
star_map["data.plugins.demo_plugin.main"] = StarMetadata(
name="demo",
module_path="data.plugins.demo_plugin.main",
root_dir_name="demo_plugin",
)

data_dir = StarTools.get_data_dir()

assert data_dir == (data_path / "plugin_data" / "demo").resolve()


def test_get_data_dir_resolves_debug_module_from_plugin_path(monkeypatch, tmp_path):
data_path = tmp_path / "data"
plugin_root = tmp_path / "plugins"
debug_file = plugin_root / "demo_plugin" / "scripts" / "debug.py"
debug_file.parent.mkdir(parents=True)
debug_file.write_text("", encoding="utf-8")
monkeypatch.setattr(star_tools, "get_astrbot_data_path", lambda: str(data_path))
monkeypatch.setattr(star_tools, "get_astrbot_plugin_path", lambda: str(plugin_root))
monkeypatch.setattr(star_tools, "get_astrbot_path", lambda: str(tmp_path / "src"))
set_caller_module(monkeypatch, make_module("__main__", str(debug_file)))
star_map["data.plugins.demo_plugin.main"] = StarMetadata(
name="demo",
module_path="data.plugins.demo_plugin.main",
root_dir_name="demo_plugin",
)

data_dir = StarTools.get_data_dir()

assert data_dir == (data_path / "plugin_data" / "demo").resolve()


def test_get_data_dir_keeps_unknown_module_failure(monkeypatch, tmp_path):
data_path = tmp_path / "data"
monkeypatch.setattr(star_tools, "get_astrbot_data_path", lambda: str(data_path))
set_caller_module(monkeypatch, make_module("external.module"))

with pytest.raises(RuntimeError, match="无法获取模块 external.module 的元数据信息"):
StarTools.get_data_dir()

assert not (data_path / "plugin_data" / "unknown_plugin").exists()
Loading