From 750f13b6bce5458d2728e0d109abf5448c904178 Mon Sep 17 00:00:00 2001 From: axisrow Date: Sun, 10 May 2026 23:57:20 +0700 Subject: [PATCH 1/2] feat(search): add chat filters to search queries --- docs/features/search.md | 3 + src/agent/tools/search_queries.py | 25 ++- src/cli/commands/search_query.py | 42 ++++ src/cli/parser_domains/search_query.py | 3 + src/database/bundles.py | 8 +- src/database/migrations.py | 1 + src/database/repositories/messages.py | 18 ++ src/database/repositories/search_queries.py | 9 +- src/database/schema.py | 1 + src/models.py | 1 + src/services/notification_matcher.py | 8 +- src/services/search_query_service.py | 23 ++ src/telegram/collector.py | 11 +- src/utils/search_query_chat_filter.py | 196 ++++++++++++++++++ src/web/bootstrap.py | 2 +- src/web/routes/search_queries.py | 18 +- src/web/static/js/app.js | 11 +- src/web/templates/search_queries.html | 33 ++- .../repositories/test_messages_repository.py | 33 +++ .../test_search_queries_repository.py | 5 +- tests/routes/test_search_queries_routes.py | 16 ++ tests/test_agent_tools_search_queries.py | 20 ++ ...t_cli_process_database_repository_paths.py | 1 + .../test_migrations_worker_bootstrap_paths.py | 1 + tests/test_notification_matcher.py | 33 +++ tests/test_search_queries.py | 25 +++ tests/test_search_query_service.py | 4 + 27 files changed, 532 insertions(+), 19 deletions(-) create mode 100644 src/utils/search_query_chat_filter.py diff --git a/docs/features/search.md b/docs/features/search.md index 5ee48f06..3e47fa21 100644 --- a/docs/features/search.md +++ b/docs/features/search.md @@ -42,6 +42,7 @@ python -m src.main agent chat "найди сообщения про блокче ```bash python -m src.main search-query list python -m src.main search-query add "ключевое слово" + python -m src.main search-query add "ключевое слово" --chats "@chat1, -1001234567890" python -m src.main search-query run 1 # разовый запуск python -m src.main search-query stats 1 # статистика совпадений ``` @@ -49,4 +50,6 @@ python -m src.main agent chat "найди сообщения про блокче === "Web" `GET /search-queries/` · `POST /search-queries/add` +Поле `Чаты` у сохранённого запроса необязательно. Пустое значение ищет по всем чатам; непустое ограничивает поиск списком `channel_id`, `@username`, `username` или ссылок `t.me`, разделённых пробелами или запятыми. + Уведомления содержат ссылку на оригинальное сообщение (`t.me/channel/message_id`). diff --git a/src/agent/tools/search_queries.py b/src/agent/tools/search_queries.py index 1c52c6ef..27ee853d 100644 --- a/src/agent/tools/search_queries.py +++ b/src/agent/tools/search_queries.py @@ -42,9 +42,10 @@ async def list_search_queries(args): if getattr(sq, "notify_on_collect", False): flags.append("notify") flags_str = f" [{', '.join(flags)}]" if flags else "" + chats_str = f" chats={sq.chat_filter}" if getattr(sq, "chat_filter", "") else " chats=all" lines.append( f"- id={sq.id}: '{sq.query}' interval={sq.interval_minutes}m " - f"{status}{flags_str}" + f"{status}{flags_str}{chats_str}" ) return _text_response("\n".join(lines)) except Exception as e: @@ -76,7 +77,11 @@ async def get_search_query(args): f"is_regex: {sq.is_regex}", f"is_fts: {sq.is_fts}", f"notify_on_collect: {sq.notify_on_collect}", + f"chat_filter: {sq.chat_filter or 'all'}", ] + warning = (await svc.validate_chat_filter(sq.chat_filter)).warning_text() + if warning: + lines.append(f"warning: {warning}") return _text_response("\n".join(lines)) except Exception as e: return _text_response(f"Ошибка получения поискового запроса: {e}") @@ -101,6 +106,7 @@ async def get_search_query(args): "track_stats": Annotated[bool, "Записывать ежедневную статистику совпадений"], "exclude_patterns": Annotated[str, "Паттерны исключения через запятую"], "max_length": Annotated[int, "Максимальная длина сообщения для совпадения"], + "chat_filter": Annotated[str, "Список чатов через запятую или пробел"], "confirm": Annotated[bool, "Установите true для подтверждения действия"], }, ) @@ -122,6 +128,7 @@ async def add_search_query(args): track_stats = bool(args.get("track_stats", True)) exclude_patterns = args.get("exclude_patterns", "") max_length = args.get("max_length") + chat_filter = args.get("chat_filter", args.get("chats", "")) sq_id = await svc.add( query, interval_minutes=interval, @@ -131,8 +138,13 @@ async def add_search_query(args): track_stats=track_stats, exclude_patterns=exclude_patterns or "", max_length=int(max_length) if max_length is not None else None, + chat_filter=chat_filter or "", ) - return _text_response(f"Поисковый запрос создан (id={sq_id}).") + warning = (await svc.validate_chat_filter(chat_filter or "")).warning_text() + text = f"Поисковый запрос создан (id={sq_id})." + if warning: + text += f"\nПредупреждение: {warning}" + return _text_response(text) except Exception as e: return _text_response(f"Ошибка добавления поискового запроса: {e}") @@ -151,6 +163,7 @@ async def add_search_query(args): "track_stats": Annotated[bool, "Записывать ежедневную статистику совпадений"], "exclude_patterns": Annotated[str, "Паттерны исключения через запятую"], "max_length": Annotated[int, "Максимальная длина сообщения для совпадения"], + "chat_filter": Annotated[str, "Список чатов через запятую или пробел"], "confirm": Annotated[bool, "Установите true для подтверждения действия"], }, ) @@ -176,6 +189,7 @@ async def edit_search_query(args): track_stats = bool(args.get("track_stats", getattr(existing, "track_stats", True))) exclude_patterns = args.get("exclude_patterns", getattr(existing, "exclude_patterns", "")) max_length_raw = args.get("max_length", getattr(existing, "max_length", None)) + chat_filter = args.get("chat_filter", args.get("chats", getattr(existing, "chat_filter", ""))) ok = await svc.update( int(sq_id), query, @@ -186,9 +200,14 @@ async def edit_search_query(args): track_stats=track_stats, exclude_patterns=exclude_patterns or "", max_length=int(max_length_raw) if max_length_raw is not None else None, + chat_filter=chat_filter or "", ) if ok: - return _text_response(f"Поисковый запрос id={sq_id} обновлён.") + warning = (await svc.validate_chat_filter(chat_filter or "")).warning_text() + text = f"Поисковый запрос id={sq_id} обновлён." + if warning: + text += f"\nПредупреждение: {warning}" + return _text_response(text) return _text_response(f"Не удалось обновить запрос id={sq_id}.") except Exception as e: return _text_response(f"Ошибка редактирования поискового запроса: {e}") diff --git a/src/cli/commands/search_query.py b/src/cli/commands/search_query.py index a101cf80..d6e9bf49 100644 --- a/src/cli/commands/search_query.py +++ b/src/cli/commands/search_query.py @@ -2,6 +2,7 @@ import argparse import asyncio +import inspect from pydantic import ValidationError @@ -10,6 +11,20 @@ from src.services.search_query_service import SearchQueryService +async def _chat_filter_warning(svc: SearchQueryService, chat_filter: str) -> str: + validator = getattr(svc, "validate_chat_filter", None) + if validator is None: + return "" + try: + result = validator(chat_filter) + validation = await result if inspect.isawaitable(result) else result + warning_text = getattr(validation, "warning_text", None) + text = warning_text() if callable(warning_text) else "" + return text if isinstance(text, str) else "" + except Exception: + return "" + + def run(args: argparse.Namespace) -> None: async def _run() -> None: _, db = await runtime.init_db(args.config) @@ -35,6 +50,11 @@ async def _run() -> None: (item["last_run"] or "—")[:20], ) ) + chat_filter = getattr(sq, "chat_filter", "") + if chat_filter: + print(f" chats: {chat_filter}") + if item.get("chat_filter_warnings"): + print(f" warning: {item['chat_filter_warnings']}") elif args.search_query_action == "get": sq = await svc.get(args.id) @@ -51,6 +71,11 @@ async def _run() -> None: print(f"Track stats: {sq.track_stats}") print(f"Max length: {sq.max_length if sq.max_length is not None else '—'}") print(f"Exclude patterns: {sq.exclude_patterns or '—'}") + chat_filter = getattr(sq, "chat_filter", "") + print(f"Chats: {chat_filter or 'all'}") + warning = await _chat_filter_warning(svc, chat_filter) + if warning: + print(f"Warning: {warning}") elif args.search_query_action == "add": exclude = ( @@ -66,11 +91,15 @@ async def _run() -> None: track_stats=args.track_stats, exclude_patterns=exclude, max_length=args.max_length, + chat_filter=getattr(args, "chats", ""), ) except ValidationError as e: print(f"Error: {e.errors()[0]['msg']}") return print(f"Added search query id={sq_id}: {args.query}") + warning = await _chat_filter_warning(svc, getattr(args, "chats", "")) + if warning: + print(f"Warning: {warning}") elif args.search_query_action == "edit": sq = await svc.get(args.id) @@ -101,11 +130,24 @@ async def _run() -> None: track_stats=tstats, exclude_patterns=exclude, max_length=max_len, + chat_filter=( + args.chats + if getattr(args, "chats", None) is not None + else getattr(sq, "chat_filter", "") + ), ) except ValidationError as e: print(f"Error: {e.errors()[0]['msg']}") return print(f"Updated search query id={args.id}") + warning = await _chat_filter_warning( + svc, + args.chats + if getattr(args, "chats", None) is not None + else getattr(sq, "chat_filter", ""), + ) + if warning: + print(f"Warning: {warning}") elif args.search_query_action == "delete": await svc.delete(args.id) diff --git a/src/cli/parser_domains/search_query.py b/src/cli/parser_domains/search_query.py index 316bd8d8..b56031ea 100644 --- a/src/cli/parser_domains/search_query.py +++ b/src/cli/parser_domains/search_query.py @@ -22,6 +22,7 @@ def register(subparsers: argparse._SubParsersAction) -> argparse.ArgumentParser "--exclude-patterns", default="", help="Exclude patterns, one per line (use \\n)" ) sq_add.add_argument("--max-length", type=int, default=None, help="Max message text length") + sq_add.add_argument("--chats", default="", help="Chat filter: IDs, usernames or t.me links") sq_edit = sq_sub.add_parser("edit", help="Edit search query") sq_edit.add_argument("id", type=int, help="Search query id") @@ -38,6 +39,8 @@ def register(subparsers: argparse._SubParsersAction) -> argparse.ArgumentParser sq_edit.add_argument("--exclude-patterns", default=None, help="Exclude patterns (use \\n)") sq_edit.add_argument("--max-length", type=int, default=None, help="Max message text length") sq_edit.add_argument("--no-max-length", dest="max_length", action="store_const", const=-1) + sq_edit.add_argument("--chats", default=None, help="Chat filter: IDs, usernames or t.me links") + sq_edit.add_argument("--clear-chats", dest="chats", action="store_const", const="") sq_del = sq_sub.add_parser("delete", help="Delete search query") sq_del.add_argument("id", type=int, help="Search query id") diff --git a/src/database/bundles.py b/src/database/bundles.py index fbe7baf8..8f153fdd 100644 --- a/src/database/bundles.py +++ b/src/database/bundles.py @@ -757,11 +757,12 @@ async def get_recent_searches(self, limit: int = 20) -> list[dict]: class SearchQueryBundle: search_queries: SearchQueriesRepository messages: MessagesRepository + channels: ChannelsRepository | None = None @classmethod def from_database(cls, db: "Database") -> "SearchQueryBundle": repos = db.repos - return cls(repos.search_queries, repos.messages) + return cls(repos.search_queries, repos.messages, repos.channels) async def add(self, sq: SearchQuery) -> int: return await self.search_queries.add(sq) @@ -809,6 +810,11 @@ async def get_last_recorded_at(self, query_id: int) -> str | None: async def get_last_recorded_at_all(self) -> dict[int, str]: return await self.search_queries.get_last_recorded_at_all() + async def get_channels(self) -> list[Channel]: + if self.channels is None: + return [] + return await self.channels.get_channels() + @dataclass(frozen=True) class PipelineBundle: diff --git a/src/database/migrations.py b/src/database/migrations.py index 1f315b97..9f0f8516 100644 --- a/src/database/migrations.py +++ b/src/database/migrations.py @@ -64,6 +64,7 @@ "track_stats": "track_stats INTEGER DEFAULT 1", "exclude_patterns": "exclude_patterns TEXT DEFAULT ''", "max_length": "max_length INTEGER DEFAULT NULL", + "chat_filter": "chat_filter TEXT DEFAULT ''", }, "notification_bots": { "tg_username": "tg_username TEXT", diff --git a/src/database/repositories/messages.py b/src/database/repositories/messages.py index adb368af..00173594 100644 --- a/src/database/repositories/messages.py +++ b/src/database/repositories/messages.py @@ -10,6 +10,7 @@ from src.models import Message, SearchQuery from src.utils.datetime import parse_datetime, parse_required_datetime +from src.utils.search_query_chat_filter import parse_chat_filter logger = logging.getLogger(__name__) _DATE_ONLY_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$") @@ -767,6 +768,23 @@ def _build_extra_conditions(sq: SearchQuery) -> tuple[list[str], list]: continue conditions.append("m.text NOT LIKE ?") params.append(f"%{stripped}%") + chat_filter = parse_chat_filter(sq.chat_filter) + if chat_filter.has_filter: + chat_parts = [] + if chat_filter.numeric_values: + placeholders = ", ".join("?" for _ in chat_filter.numeric_values) + chat_parts.append(f"m.channel_id IN ({placeholders})") + params.extend(chat_filter.numeric_values) + chat_parts.append(f"c.id IN ({placeholders})") + params.extend(chat_filter.numeric_values) + if chat_filter.usernames: + placeholders = ", ".join("?" for _ in chat_filter.usernames) + chat_parts.append(f"LOWER(c.username) IN ({placeholders})") + params.extend(chat_filter.usernames) + if chat_parts: + conditions.append("(" + " OR ".join(chat_parts) + ")") + else: + conditions.append("0 = 1") return conditions, params def _build_sq_parts(self, sq: SearchQuery) -> tuple[str, list[str], list]: diff --git a/src/database/repositories/search_queries.py b/src/database/repositories/search_queries.py index 5df72678..1bf99e1a 100644 --- a/src/database/repositories/search_queries.py +++ b/src/database/repositories/search_queries.py @@ -14,8 +14,8 @@ async def add(self, sq: SearchQuery) -> int: cur = await self._db.execute( "INSERT INTO search_queries " "(name, query, is_regex, is_fts, is_active, notify_on_collect, " - "track_stats, interval_minutes, exclude_patterns, max_length) " - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + "track_stats, interval_minutes, exclude_patterns, max_length, chat_filter) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", ( sq.name, sq.query, @@ -27,6 +27,7 @@ async def add(self, sq: SearchQuery) -> int: sq.interval_minutes, sq.exclude_patterns, sq.max_length, + sq.chat_filter, ), ) await self._db.commit() @@ -56,7 +57,7 @@ async def update(self, sq_id: int, sq: SearchQuery) -> None: await self._db.execute( "UPDATE search_queries SET name = ?, query = ?, is_regex = ?, is_fts = ?, " "notify_on_collect = ?, track_stats = ?, interval_minutes = ?, " - "exclude_patterns = ?, max_length = ? " + "exclude_patterns = ?, max_length = ?, chat_filter = ? " "WHERE id = ?", ( sq.name, @@ -68,6 +69,7 @@ async def update(self, sq_id: int, sq: SearchQuery) -> None: sq.interval_minutes, sq.exclude_patterns, sq.max_length, + sq.chat_filter, sq_id, ), ) @@ -163,5 +165,6 @@ def _row_to_model(row) -> SearchQuery: interval_minutes=row["interval_minutes"], exclude_patterns=row["exclude_patterns"] or "", max_length=row["max_length"], + chat_filter=row["chat_filter"] if "chat_filter" in row.keys() and row["chat_filter"] else "", created_at=parse_datetime(row["created_at"]), ) diff --git a/src/database/schema.py b/src/database/schema.py index e691b087..30dcdefc 100644 --- a/src/database/schema.py +++ b/src/database/schema.py @@ -161,6 +161,7 @@ interval_minutes INTEGER NOT NULL DEFAULT 60, exclude_patterns TEXT DEFAULT '', max_length INTEGER DEFAULT NULL, + chat_filter TEXT DEFAULT '', created_at TEXT DEFAULT (datetime('now')) ); diff --git a/src/models.py b/src/models.py index d8982d6c..2c7e0d96 100644 --- a/src/models.py +++ b/src/models.py @@ -363,6 +363,7 @@ class SearchQuery(BaseModel): interval_minutes: int = Field(60, ge=1) exclude_patterns: str = "" max_length: int | None = None + chat_filter: str = "" created_at: datetime | None = None @model_validator(mode="after") diff --git a/src/services/notification_matcher.py b/src/services/notification_matcher.py index b39fa7fb..18c4779b 100644 --- a/src/services/notification_matcher.py +++ b/src/services/notification_matcher.py @@ -3,8 +3,9 @@ import logging import re -from src.models import Message, SearchQuery +from src.models import Channel, Message, SearchQuery from src.telegram.notifier import Notifier +from src.utils.search_query_chat_filter import chat_filter_matches_message logger = logging.getLogger(__name__) @@ -12,8 +13,9 @@ class NotificationMatcher: """Match messages against notification queries and send batched notifications.""" - def __init__(self, notifier: Notifier): + def __init__(self, notifier: Notifier, *, channels: list[Channel] | None = None): self._notifier = notifier + self._channels = channels or [] async def match_and_notify( self, @@ -30,6 +32,8 @@ async def match_and_notify( if not msg.text: continue for sq in queries: + if not chat_filter_matches_message(sq.chat_filter, msg, channels=self._channels): + continue if sq.max_length is not None and len(msg.text) >= sq.max_length: continue if any(p.lower() in msg.text.lower() for p in sq.exclude_patterns_list): diff --git a/src/services/search_query_service.py b/src/services/search_query_service.py index 7c505707..4c9c314c 100644 --- a/src/services/search_query_service.py +++ b/src/services/search_query_service.py @@ -7,6 +7,11 @@ from src.database import Database from src.database.bundles import SearchQueryBundle from src.models import SearchQuery, SearchQueryDailyStat +from src.utils.search_query_chat_filter import ( + ChatFilterValidation, + single_resolved_channel_id, + validate_chat_filter, +) logger = logging.getLogger(__name__) @@ -28,6 +33,7 @@ async def add( track_stats: bool = True, exclude_patterns: str = "", max_length: int | None = None, + chat_filter: str = "", ) -> int: sq = SearchQuery( query=query, @@ -38,6 +44,7 @@ async def add( track_stats=track_stats, exclude_patterns=exclude_patterns, max_length=max_length, + chat_filter=chat_filter, ) return await self._bundle.add(sq) @@ -64,6 +71,7 @@ async def update( track_stats: bool = True, exclude_patterns: str = "", max_length: int | None = None, + chat_filter: str | None = None, ) -> bool: existing = await self._bundle.get_by_id(sq_id) if not existing: @@ -78,6 +86,7 @@ async def update( track_stats=track_stats, exclude_patterns=exclude_patterns, max_length=max_length, + chat_filter=chat_filter if chat_filter is not None else existing.chat_filter, ) await self._bundle.update(sq_id, sq) return True @@ -103,9 +112,14 @@ async def run_once(self, sq_id: int) -> int: async def get_daily_stats(self, sq_id: int, days: int = 30) -> list[SearchQueryDailyStat]: return await self._bundle.get_daily_stats(sq_id, days) + async def validate_chat_filter(self, chat_filter: str) -> ChatFilterValidation: + channels = await self._get_channels() + return validate_chat_filter(chat_filter, channels) + async def get_with_stats(self, days: int = 30) -> list[dict]: queries = await self._bundle.get_all() last_runs = await self._bundle.get_last_recorded_at_all() + channels = await self._get_channels() # Regex queries can't be counted via FTS5; exclude them from FTS stats batch tracked = [sq for sq in queries if sq.track_stats and not sq.is_regex] tracked_ids = {sq.id for sq in tracked} @@ -116,12 +130,15 @@ async def get_with_stats(self, days: int = 30) -> list[dict]: raw = stats_map.get(sq.id, []) if sq.id in tracked_ids else None daily = self._fill_missing_days(raw, days) total = sum(s.count for s in daily) + chat_validation = validate_chat_filter(sq.chat_filter, channels) result.append( { "query": sq, "total_30d": total, "last_run": last_runs.get(sq.id), "daily_stats": daily, + "chat_filter_warnings": chat_validation.warning_text(), + "chat_filter_channel_id": single_resolved_channel_id(sq.chat_filter, channels), } ) return result @@ -147,3 +164,9 @@ def _fill_missing_days( day_str = today.isoformat() filled.append(existing.get(day_str, SearchQueryDailyStat(day=day_str, count=0))) return filled + + async def _get_channels(self): + get_channels = getattr(self._bundle, "get_channels", None) + if get_channels is None: + return [] + return await get_channels() diff --git a/src/telegram/collector.py b/src/telegram/collector.py index a3e94886..842e0543 100644 --- a/src/telegram/collector.py +++ b/src/telegram/collector.py @@ -1136,7 +1136,16 @@ async def _check_notification_queries(self, messages: list[Message]) -> None: from src.services.notification_matcher import NotificationMatcher - matcher = NotificationMatcher(self._notifier) + get_channels = getattr(self._db, "get_channels", None) + channels = [] + if get_channels: + import inspect + + maybe_channels = get_channels() + channels = await maybe_channels if inspect.isawaitable(maybe_channels) else maybe_channels + if not isinstance(channels, list): + channels = [] + matcher = NotificationMatcher(self._notifier, channels=channels) await matcher.match_and_notify(messages, queries) async def _channel_still_exists(self, channel_id: int) -> bool: diff --git a/src/utils/search_query_chat_filter.py b/src/utils/search_query_chat_filter.py new file mode 100644 index 00000000..dcb6a9a5 --- /dev/null +++ b/src/utils/search_query_chat_filter.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Literal +from urllib.parse import parse_qs, urlsplit + +from src.models import Channel, Message + +ChatFilterTokenKind = Literal["numeric", "username", "invalid"] + +_USERNAME_RE = re.compile(r"^[A-Za-z0-9_]+$") +_SPLIT_RE = re.compile(r"[\s,]+") + + +@dataclass(frozen=True) +class ChatFilterToken: + raw: str + kind: ChatFilterTokenKind + value: int | str | None = None + + +@dataclass(frozen=True) +class ParsedChatFilter: + entries: tuple[ChatFilterToken, ...] + + @property + def has_filter(self) -> bool: + return bool(self.entries) + + @property + def numeric_values(self) -> tuple[int, ...]: + return tuple(entry.value for entry in self.entries if entry.kind == "numeric" and isinstance(entry.value, int)) + + @property + def usernames(self) -> tuple[str, ...]: + return tuple(entry.value for entry in self.entries if entry.kind == "username" and isinstance(entry.value, str)) + + @property + def invalid_tokens(self) -> tuple[str, ...]: + return tuple(entry.raw for entry in self.entries if entry.kind == "invalid") + + @property + def has_valid_tokens(self) -> bool: + return bool(self.numeric_values or self.usernames) + + +@dataclass(frozen=True) +class ChatFilterValidation: + invalid_tokens: tuple[str, ...] = () + unknown_tokens: tuple[str, ...] = () + matched_channel_ids: tuple[int, ...] = () + + @property + def has_warnings(self) -> bool: + return bool(self.invalid_tokens or self.unknown_tokens) + + def warning_text(self) -> str: + parts = [] + if self.invalid_tokens: + parts.append("некорректные: " + ", ".join(self.invalid_tokens)) + if self.unknown_tokens: + parts.append("не найдены: " + ", ".join(self.unknown_tokens)) + return "Чаты в фильтре сохранены, но есть предупреждения: " + "; ".join(parts) if parts else "" + + +def parse_chat_filter(raw_filter: str | None) -> ParsedChatFilter: + entries: list[ChatFilterToken] = [] + seen: set[tuple[ChatFilterTokenKind, int | str | None]] = set() + for raw in _SPLIT_RE.split((raw_filter or "").strip()): + token = raw.strip().strip(",;") + if not token: + continue + entry = _parse_token(token) + key = (entry.kind, entry.value if entry.kind != "invalid" else entry.raw) + if key in seen: + continue + seen.add(key) + entries.append(entry) + return ParsedChatFilter(tuple(entries)) + + +def validate_chat_filter(raw_filter: str | None, channels: list[Channel]) -> ChatFilterValidation: + parsed = parse_chat_filter(raw_filter) + if not parsed.has_filter: + return ChatFilterValidation() + + matched_ids: set[int] = set() + unknown: list[str] = [] + invalid = list(parsed.invalid_tokens) + + channels_by_username = { + (ch.username or "").lower(): ch + for ch in channels + if ch.username + } + for entry in parsed.entries: + if entry.kind == "invalid": + continue + if entry.kind == "numeric" and isinstance(entry.value, int): + matches = [ + ch + for ch in channels + if ch.channel_id == entry.value or ch.id == entry.value + ] + elif entry.kind == "username" and isinstance(entry.value, str): + match = channels_by_username.get(entry.value) + matches = [match] if match else [] + else: + matches = [] + + if matches: + matched_ids.update(ch.channel_id for ch in matches) + else: + unknown.append(entry.raw) + + return ChatFilterValidation( + invalid_tokens=tuple(invalid), + unknown_tokens=tuple(unknown), + matched_channel_ids=tuple(sorted(matched_ids)), + ) + + +def chat_filter_matches_message( + raw_filter: str | None, + msg: Message, + *, + channels: list[Channel] | None = None, +) -> bool: + parsed = parse_chat_filter(raw_filter) + if not parsed.has_filter: + return True + if not parsed.has_valid_tokens: + return False + + numeric_values = set(parsed.numeric_values) + usernames = set(parsed.usernames) + if msg.channel_id in numeric_values: + return True + if msg.channel_username and msg.channel_username.lower() in usernames: + return True + + for ch in channels or []: + if ch.channel_id != msg.channel_id: + continue + if ch.id in numeric_values or ch.channel_id in numeric_values: + return True + if ch.username and ch.username.lower() in usernames: + return True + return False + + +def single_resolved_channel_id(raw_filter: str | None, channels: list[Channel]) -> int | None: + validation = validate_chat_filter(raw_filter, channels) + if validation.invalid_tokens or validation.unknown_tokens: + return None + if len(validation.matched_channel_ids) == 1: + return validation.matched_channel_ids[0] + return None + + +def _parse_token(token: str) -> ChatFilterToken: + normalized = _normalize_token(token) + if not normalized: + return ChatFilterToken(raw=token, kind="invalid") + try: + return ChatFilterToken(raw=token, kind="numeric", value=int(normalized)) + except ValueError: + pass + if _USERNAME_RE.match(normalized): + return ChatFilterToken(raw=token, kind="username", value=normalized.lower()) + return ChatFilterToken(raw=token, kind="invalid") + + +def _normalize_token(token: str) -> str | None: + token = token.strip() + if not token: + return None + if token.startswith("@"): + return token[1:].strip() + if token.startswith(("https://", "http://", "t.me/", "telegram.me/")): + url = token if token.startswith(("https://", "http://")) else f"https://{token}" + parsed = urlsplit(url) + host = parsed.netloc.lower() + if host not in {"t.me", "telegram.me", "www.t.me", "www.telegram.me"}: + return None + query_domain = parse_qs(parsed.query).get("domain") + if query_domain: + return query_domain[0].strip().lstrip("@") + parts = [part for part in parsed.path.split("/") if part] + if not parts: + return None + if parts[0] == "c" and len(parts) > 1 and parts[1].isdigit(): + return f"-100{parts[1]}" + return parts[0].strip().lstrip("@") + return token diff --git a/src/web/bootstrap.py b/src/web/bootstrap.py index 332946fb..d18bb688 100644 --- a/src/web/bootstrap.py +++ b/src/web/bootstrap.py @@ -254,7 +254,7 @@ async def build_container_with_templates( search_pool = None search_engine = SearchEngine(search_bundle, search_pool, config=config) ai_search = AISearchEngine(config.llm, search_bundle) - search_query_bundle = SearchQueryBundle(repos.search_queries, repos.messages) + search_query_bundle = SearchQueryBundle(repos.search_queries, repos.messages, repos.channels) from src.services.collection_service import CollectionService diff --git a/src/web/routes/search_queries.py b/src/web/routes/search_queries.py index e441bf56..d6820a6b 100644 --- a/src/web/routes/search_queries.py +++ b/src/web/routes/search_queries.py @@ -28,6 +28,7 @@ async def add_search_query( track_stats: bool = Form(False), exclude_patterns: str = Form(""), max_length: int | None = Form(None), + chat_filter: str = Form(""), ): if not query.strip(): return flash_redirect("/search-queries", error="invalid_value") @@ -42,13 +43,19 @@ async def add_search_query( track_stats=track_stats, exclude_patterns=exclude_patterns, max_length=max_length, + chat_filter=chat_filter, ) except ValidationError: return flash_redirect("/search-queries", error="invalid_value") + chat_validation = await svc.validate_chat_filter(chat_filter) scheduler = deps.get_scheduler(request) if scheduler.is_running: await scheduler.sync_search_query_jobs() - return flash_redirect("/search-queries", msg="sq_added") + return flash_redirect( + "/search-queries", + msg="sq_added", + extra={"warning": chat_validation.warning_text() or None}, + ) @router.post("/{sq_id}/toggle") @@ -73,6 +80,7 @@ async def edit_search_query( track_stats: bool = Form(False), exclude_patterns: str = Form(""), max_length: int | None = Form(None), + chat_filter: str = Form(""), ): if not query.strip(): return flash_redirect("/search-queries", error="invalid_value") @@ -88,13 +96,19 @@ async def edit_search_query( track_stats=track_stats, exclude_patterns=exclude_patterns, max_length=max_length, + chat_filter=chat_filter, ) except ValidationError: return flash_redirect("/search-queries", error="invalid_value") + chat_validation = await svc.validate_chat_filter(chat_filter) scheduler = deps.get_scheduler(request) if scheduler.is_running: await scheduler.sync_search_query_jobs() - return flash_redirect("/search-queries", msg="sq_edited") + return flash_redirect( + "/search-queries", + msg="sq_edited", + extra={"warning": chat_validation.warning_text() or None}, + ) @router.post("/{sq_id}/delete") diff --git a/src/web/static/js/app.js b/src/web/static/js/app.js index a81f4390..31756c7e 100644 --- a/src/web/static/js/app.js +++ b/src/web/static/js/app.js @@ -4,6 +4,7 @@ var params = new URLSearchParams(window.location.search); var code = params.get("msg"); var errCode = params.get("error"); + var warningText = params.get("warning"); if (code) window.__flashMsg = code; if (errCode) window.__flashError = errCode; var container = document.getElementById("flash-container"); @@ -22,9 +23,17 @@ } container.appendChild(errDiv); } - if (code || errCode) { + if (warningText) { + var warnDiv = document.createElement("div"); + warnDiv.className = "alert alert-warning"; + warnDiv.setAttribute("role", "alert"); + warnDiv.textContent = warningText; + container.appendChild(warnDiv); + } + if (code || errCode || warningText) { params.delete("msg"); params.delete("error"); + params.delete("warning"); var qs = params.toString(); var url = window.location.pathname + (qs ? "?" + qs : "") + window.location.hash; window.history.replaceState(null, "", url); diff --git a/src/web/templates/search_queries.html b/src/web/templates/search_queries.html index cbec7c5c..459b0613 100644 --- a/src/web/templates/search_queries.html +++ b/src/web/templates/search_queries.html @@ -21,6 +21,10 @@

Поисковые запросы

+
+ + +
@@ -61,6 +65,7 @@

Поисковые запросы

Запрос + Чаты Режим Интервал Активен @@ -74,6 +79,14 @@

Поисковые запросы

{% set sq = item.query %} {{ sq.query }}{% if sq.is_regex %} (regex){% endif %}{% if sq.is_fts %} (fts){% endif %} + + {% if sq.chat_filter %} + {{ sq.chat_filter }} + {% if item.chat_filter_warnings %}
{{ item.chat_filter_warnings }}
{% endif %} + {% else %} + Все + {% endif %} + {% if sq.notify_on_collect %}{{ icon("notification") }}{% endif %} {% if sq.track_stats %}{{ icon("stats") }}{% endif %} @@ -86,7 +99,7 @@

Поисковые запросы

{% if sq.is_regex %} {{ icon("search") }} {% else %} - {{ icon("search") }} {% endif %} @@ -104,7 +117,7 @@

Поисковые запросы

- +
@@ -116,6 +129,10 @@

Поисковые запросы

+
+ + +
@@ -151,7 +168,7 @@

Поисковые запросы

{% if item.daily_stats %} - +
График по дням ({{ item.daily_stats | length }} дн.)
@@ -207,9 +224,17 @@

Поисковые запросы

· {{ item.total_30d }} за 30д {% if item.last_run %}· {{ item.last_run[:16] }}{% endif %} +
+ {% if sq.chat_filter %} + Чаты: {{ sq.chat_filter }} + {% if item.chat_filter_warnings %}
{{ item.chat_filter_warnings }}
{% endif %} + {% else %} + Все чаты + {% endif %} +
{% if not sq.is_regex %} - {{ icon("search") }} {% endif %} diff --git a/tests/repositories/test_messages_repository.py b/tests/repositories/test_messages_repository.py index d98e5c8c..d89b07df 100644 --- a/tests/repositories/test_messages_repository.py +++ b/tests/repositories/test_messages_repository.py @@ -525,6 +525,21 @@ async def test_count_fts_matches_for_query_with_exclude_patterns(messages_repo): assert count == 1 +async def test_count_fts_matches_for_query_with_chat_filter(messages_repo): + """Test counting FTS matches restricted to selected chats.""" + sq = SearchQuery(query="hello", is_fts=True, chat_filter="1") + + await messages_repo.insert_messages_batch( + [ + make_message(1, 100, "Hello world"), + make_message(2, 101, "Hello there"), + ] + ) + + count = await messages_repo.count_fts_matches_for_query(sq) + assert count == 1 + + # get_fts_daily_stats_for_query tests @@ -569,6 +584,24 @@ async def test_get_fts_daily_stats_batch(messages_repo): assert 2 in result +async def test_get_fts_daily_stats_batch_with_chat_filter(messages_repo): + """Test batch FTS daily stats respects chat filters per query.""" + sq1 = SearchQuery(id=1, query="hello", is_fts=True, chat_filter="10") + sq2 = SearchQuery(id=2, query="hello", is_fts=True, chat_filter="20") + + await messages_repo.insert_messages_batch( + [ + make_message(10, 100, "Hello one"), + make_message(20, 101, "Hello two"), + ] + ) + + result = await messages_repo.get_fts_daily_stats_batch([sq1, sq2], days=7) + + assert sum(s.count for s in result[1]) == 1 + assert sum(s.count for s in result[2]) == 1 + + async def test_get_fts_daily_stats_batch_empty(messages_repo): """Test batch FTS stats with empty list.""" result = await messages_repo.get_fts_daily_stats_batch([], days=7) diff --git a/tests/repositories/test_search_queries_repository.py b/tests/repositories/test_search_queries_repository.py index 7e29c29d..86717614 100644 --- a/tests/repositories/test_search_queries_repository.py +++ b/tests/repositories/test_search_queries_repository.py @@ -47,6 +47,7 @@ async def test_add_with_all_fields(search_queries_repo): interval_minutes=30, exclude_patterns="spam\njunk", max_length=500, + chat_filter="@chat_one, -1001", ) pk = await search_queries_repo.add(sq) @@ -58,6 +59,7 @@ async def test_add_with_all_fields(search_queries_repo): assert result.interval_minutes == 30 assert result.exclude_patterns == "spam\njunk" assert result.max_length == 500 + assert result.chat_filter == "@chat_one, -1001" # get_all tests @@ -151,13 +153,14 @@ async def test_update_query(search_queries_repo): """Test updating a query.""" pk = await search_queries_repo.add(make_query("old query")) - updated = make_query("new query", is_regex=True, interval_minutes=120) + updated = make_query("new query", is_regex=True, interval_minutes=120, chat_filter="@updated") await search_queries_repo.update(pk, updated) result = await search_queries_repo.get_by_id(pk) assert result.query == "new query" assert result.is_regex is True assert result.interval_minutes == 120 + assert result.chat_filter == "@updated" async def test_update_preserves_id(search_queries_repo): diff --git a/tests/routes/test_search_queries_routes.py b/tests/routes/test_search_queries_routes.py index 63c6e83d..c2ed414a 100644 --- a/tests/routes/test_search_queries_routes.py +++ b/tests/routes/test_search_queries_routes.py @@ -66,6 +66,22 @@ async def test_add_search_query_with_all_fields(route_client): assert "msg=sq_added" in resp.headers["location"] +@pytest.mark.anyio +async def test_add_search_query_with_chat_filter_saves_and_warns(route_client, db): + """Chat filter is saved even when it cannot be resolved locally.""" + resp = await route_client.post( + "/search-queries/add", + data={"query": "scoped query", "interval_minutes": "60", "chat_filter": "@missing_chat"}, + follow_redirects=False, + ) + + assert resp.status_code == 303 + assert "msg=sq_added" in resp.headers["location"] + assert "warning=" in resp.headers["location"] + queries = await db.repos.search_queries.get_all() + assert queries[0].chat_filter == "@missing_chat" + + @pytest.mark.anyio async def test_toggle_search_query(route_client): """Test toggle search query.""" diff --git a/tests/test_agent_tools_search_queries.py b/tests/test_agent_tools_search_queries.py index a4a0d73c..40464b34 100644 --- a/tests/test_agent_tools_search_queries.py +++ b/tests/test_agent_tools_search_queries.py @@ -77,6 +77,7 @@ async def test_found_shows_fields(self, db, sq_handlers): assert f"id: {sq_id}" in text assert "is_active" in text assert "interval_minutes" in text + assert "chat_filter" in text class TestAddSearchQueryTool: @@ -104,6 +105,15 @@ async def test_with_custom_interval_creates(self, sq_handlers): ) assert "создан" in _text(result) + @pytest.mark.anyio + async def test_with_chat_filter_creates_and_warns(self, sq_handlers): + result = await sq_handlers["add_search_query"]( + {"query": "chat scoped", "chat_filter": "@missing_chat", "confirm": True} + ) + text = _text(result) + assert "создан" in text + assert "Предупреждение" in text + class TestEditSearchQueryTool: @pytest.mark.anyio @@ -137,6 +147,16 @@ async def test_updates_interval(self, db, sq_handlers): ) assert "обновлён" in _text(result) + @pytest.mark.anyio + async def test_updates_chat_filter(self, db, sq_handlers): + sq_id = await _add_query(db, "some query") + result = await sq_handlers["edit_search_query"]( + {"sq_id": sq_id, "chat_filter": "@missing_chat", "confirm": True} + ) + text = _text(result) + assert "обновлён" in text + assert "Предупреждение" in text + class TestDeleteSearchQueryTool: @pytest.mark.anyio diff --git a/tests/test_cli_process_database_repository_paths.py b/tests/test_cli_process_database_repository_paths.py index bae3e3ee..786d5af3 100644 --- a/tests/test_cli_process_database_repository_paths.py +++ b/tests/test_cli_process_database_repository_paths.py @@ -1117,6 +1117,7 @@ async def test_migrations_search_queries_columns(): assert "is_regex" in cols assert "notify_on_collect" in cols assert "is_fts" in cols + assert "chat_filter" in cols @pytest.mark.anyio diff --git a/tests/test_migrations_worker_bootstrap_paths.py b/tests/test_migrations_worker_bootstrap_paths.py index 65b2a4da..437113be 100644 --- a/tests/test_migrations_worker_bootstrap_paths.py +++ b/tests/test_migrations_worker_bootstrap_paths.py @@ -288,6 +288,7 @@ async def test_run_migrations_adds_search_query_columns(fresh_db): assert "is_fts" in cols assert "exclude_patterns" in cols assert "max_length" in cols + assert "chat_filter" in cols @pytest.mark.anyio diff --git a/tests/test_notification_matcher.py b/tests/test_notification_matcher.py index 20a8c0c7..dc64b63c 100644 --- a/tests/test_notification_matcher.py +++ b/tests/test_notification_matcher.py @@ -36,6 +36,7 @@ def make_query( is_fts: bool = False, exclude_patterns: str = "", max_length: int | None = None, + chat_filter: str = "", ) -> SearchQuery: return SearchQuery( id=sq_id, @@ -44,6 +45,7 @@ def make_query( is_fts=is_fts, exclude_patterns=exclude_patterns, max_length=max_length, + chat_filter=chat_filter, ) @@ -209,6 +211,37 @@ async def test_match_and_notify_max_length_filter(): assert result == {1: 1} +@pytest.mark.anyio +async def test_match_and_notify_respects_chat_filter(): + """Chat filters limit live notification matches.""" + notifier = AsyncMock() + matcher = NotificationMatcher(notifier) + + messages = [ + make_message("hello one", channel_id=1, message_id=1), + make_message("hello two", channel_id=2, message_id=2), + ] + queries = [make_query("hello", chat_filter="2")] + + result = await matcher.match_and_notify(messages, queries) + + assert result == {1: 1} + + +@pytest.mark.anyio +async def test_match_and_notify_unknown_chat_filter_matches_nothing(): + """A non-empty unknown chat filter must not fall back to all chats.""" + notifier = AsyncMock() + matcher = NotificationMatcher(notifier) + + messages = [make_message("hello one", channel_id=1, message_id=1)] + queries = [make_query("hello", chat_filter="missing_chat")] + + result = await matcher.match_and_notify(messages, queries) + + assert result == {} + + @pytest.mark.anyio async def test_match_and_notify_message_no_text(): """Messages with None text are skipped.""" diff --git a/tests/test_search_queries.py b/tests/test_search_queries.py index a83891ac..499822a0 100644 --- a/tests/test_search_queries.py +++ b/tests/test_search_queries.py @@ -242,6 +242,31 @@ async def test_max_length_filter(bundle, db): assert total == 1 # only "short" passes length filter +@pytest.mark.anyio +async def test_chat_filter_limits_fts_stats(bundle, db): + await _insert_messages(db, ["target in one"], channel_id=401) + await _insert_messages(db, ["target in two"], channel_id=402) + + sq = SearchQuery(query="target", chat_filter="401") + sq_id = await bundle.add(sq) + sq = await bundle.get_by_id(sq_id) + stats = await bundle.get_fts_daily_stats_for_query(sq, days=30) + + assert sum(s.count for s in stats) == 1 + + +@pytest.mark.anyio +async def test_unknown_chat_filter_matches_nothing(bundle, db): + await _insert_messages(db, ["target in one"], channel_id=411) + + sq = SearchQuery(query="target", chat_filter="definitely_unknown") + sq_id = await bundle.add(sq) + sq = await bundle.get_by_id(sq_id) + stats = await bundle.get_fts_daily_stats_for_query(sq, days=30) + + assert sum(s.count for s in stats) == 0 + + @pytest.mark.anyio async def test_fts_collector_matching(): """Test the _fts_query_matches function.""" diff --git a/tests/test_search_query_service.py b/tests/test_search_query_service.py index 1da15b04..df9ea04c 100644 --- a/tests/test_search_query_service.py +++ b/tests/test_search_query_service.py @@ -98,6 +98,7 @@ async def test_add_creates_search_query(service, bundle): track_stats=True, exclude_patterns="spam", max_length=500, + chat_filter="@chat_one", ) assert sq_id == 1 @@ -109,6 +110,7 @@ async def test_add_creates_search_query(service, bundle): assert stored.notify_on_collect is True assert stored.exclude_patterns == "spam" assert stored.max_length == 500 + assert stored.chat_filter == "@chat_one" @pytest.mark.anyio @@ -195,6 +197,7 @@ async def test_update_modifies_query(service, bundle): track_stats=False, exclude_patterns="", max_length=None, + chat_filter="@chat_two", ) assert result is True @@ -203,6 +206,7 @@ async def test_update_modifies_query(service, bundle): assert stored.query == "new query" assert stored.interval_minutes == 120 assert stored.is_active is False # Preserved + assert stored.chat_filter == "@chat_two" @pytest.mark.anyio From a47806668c87a8956494342dbf01b81609b2f88f Mon Sep 17 00:00:00 2001 From: axisrow Date: Mon, 11 May 2026 00:29:20 +0700 Subject: [PATCH 2/2] fix(search): handle t.me preview links in chat filters --- src/utils/search_query_chat_filter.py | 4 ++++ tests/test_notification_matcher.py | 17 ++++++++++++++++ tests/test_search_queries.py | 28 +++++++++++++++++++++++++-- 3 files changed, 47 insertions(+), 2 deletions(-) diff --git a/src/utils/search_query_chat_filter.py b/src/utils/search_query_chat_filter.py index dcb6a9a5..9fbaff73 100644 --- a/src/utils/search_query_chat_filter.py +++ b/src/utils/search_query_chat_filter.py @@ -190,6 +190,10 @@ def _normalize_token(token: str) -> str | None: parts = [part for part in parsed.path.split("/") if part] if not parts: return None + if parts[0] == "s": + parts = parts[1:] + if not parts: + return None if parts[0] == "c" and len(parts) > 1 and parts[1].isdigit(): return f"-100{parts[1]}" return parts[0].strip().lstrip("@") diff --git a/tests/test_notification_matcher.py b/tests/test_notification_matcher.py index dc64b63c..028e8894 100644 --- a/tests/test_notification_matcher.py +++ b/tests/test_notification_matcher.py @@ -228,6 +228,23 @@ async def test_match_and_notify_respects_chat_filter(): assert result == {1: 1} +@pytest.mark.anyio +async def test_match_and_notify_tme_s_link_chat_filter(): + """t.me/s/{username} chat filters match the referenced channel.""" + notifier = AsyncMock() + matcher = NotificationMatcher(notifier) + + messages = [ + make_message("hello one", channel_id=1, message_id=1, channel_username="public_chat"), + make_message("hello two", channel_id=2, message_id=2, channel_username="other_chat"), + ] + queries = [make_query("hello", chat_filter="https://t.me/s/public_chat/123")] + + result = await matcher.match_and_notify(messages, queries) + + assert result == {1: 1} + + @pytest.mark.anyio async def test_match_and_notify_unknown_chat_filter_matches_nothing(): """A non-empty unknown chat filter must not fall back to all chats.""" diff --git a/tests/test_search_queries.py b/tests/test_search_queries.py index 499822a0..86740db6 100644 --- a/tests/test_search_queries.py +++ b/tests/test_search_queries.py @@ -19,11 +19,11 @@ def svc(bundle): return SearchQueryService(bundle) -async def _insert_messages(db, texts, channel_id=100, base_date=None): +async def _insert_messages(db, texts, channel_id=100, base_date=None, username=None): """Helper: add a channel and insert messages with given texts.""" if base_date is None: base_date = datetime.now() - ch = Channel(channel_id=channel_id, title="test") + ch = Channel(channel_id=channel_id, title="test", username=username) await db.repos.channels.add_channel(ch) for i, text in enumerate(texts): msg = Message( @@ -255,6 +255,30 @@ async def test_chat_filter_limits_fts_stats(bundle, db): assert sum(s.count for s in stats) == 1 +@pytest.mark.anyio +async def test_chat_filter_tme_s_link_limits_fts_stats(bundle, db): + await _insert_messages(db, ["target in public"], channel_id=405, username="public_chat") + await _insert_messages(db, ["target in other"], channel_id=406, username="other_chat") + + sq = SearchQuery(query="target", chat_filter="https://t.me/s/public_chat/123") + sq_id = await bundle.add(sq) + sq = await bundle.get_by_id(sq_id) + stats = await bundle.get_fts_daily_stats_for_query(sq, days=30) + + assert sum(s.count for s in stats) == 1 + + +@pytest.mark.anyio +async def test_validate_chat_filter_tme_s_link_resolves_channel(svc, db): + await db.repos.channels.add_channel(Channel(channel_id=407, title="Public", username="public_chat")) + + validation = await svc.validate_chat_filter("https://t.me/s/public_chat/123") + + assert validation.invalid_tokens == () + assert validation.unknown_tokens == () + assert validation.matched_channel_ids == (407,) + + @pytest.mark.anyio async def test_unknown_chat_filter_matches_nothing(bundle, db): await _insert_messages(db, ["target in one"], channel_id=411)