diff --git a/scripts/whatsapp_bridge/bridge.mjs b/scripts/whatsapp_bridge/bridge.mjs index 6252a177..5cd90b87 100644 --- a/scripts/whatsapp_bridge/bridge.mjs +++ b/scripts/whatsapp_bridge/bridge.mjs @@ -60,9 +60,9 @@ function renderQrTerminal(qr) { }); } -function normalizeDirectJid(raw) { +function normalizeChatJid(raw) { if (!raw) return ""; - if (raw.includes("@g.us") || raw.includes("@broadcast") || raw === "status@broadcast") { + if (raw.includes("@broadcast") || raw === "status@broadcast") { return ""; } if (raw.includes("@")) { @@ -255,9 +255,10 @@ async function bootstrapSocket() { for (const item of messages || []) { if (!item?.message) continue; const remoteJid = item?.key?.remoteJid || ""; - if (remoteJid.includes("@g.us") || remoteJid.includes("@broadcast") || remoteJid === "status@broadcast") { + if (remoteJid.includes("@broadcast") || remoteJid === "status@broadcast") { continue; } + const group = remoteJid.includes("@g.us"); const fromMe = Boolean(item?.key?.fromMe); const messageId = String(item?.key?.id || "").trim(); if (fromMe && messageId && outboundMessageIds.has(messageId)) { @@ -265,17 +266,17 @@ async function bootstrapSocket() { continue; } const selfNumber = senderFromJid(selfId); - const senderJid = fromMe ? (selfId || sock.user?.id || "") : remoteJid; - const sender = senderFromJid(remoteJid); + const senderJid = fromMe ? (selfId || sock.user?.id || "") : (group ? (item?.key?.participant || "") : remoteJid); const actualSender = senderFromJid(senderJid); - const conversation = sender; + const conversation = group ? remoteJid : senderFromJid(remoteJid); const text = extractText(item.message).trim(); const mediaPayload = await extractMediaPayload(item); - const selfChat = Boolean(fromMe && selfNumber && conversation && selfNumber === conversation); + const selfChat = Boolean(!group && fromMe && selfNumber && conversation && selfNumber === conversation); if (!actualSender || !conversation || (!text && !mediaPayload)) continue; await postInbound({ sender: actualSender, conversation, + group, fromMe, self: selfNumber, selfChat, @@ -339,7 +340,7 @@ const server = http.createServer(async (req, res) => { } if (req.method === "POST" && url.pathname === "/send") { const payload = await readJson(req); - const to = normalizeDirectJid(payload.to || ""); + const to = normalizeChatJid(payload.to || ""); const text = String(payload.text || "").trim(); if (!sock || !to || !text) { return await jsonResponse(res, 400, { ok: false, error: "missing_to_or_text" }); @@ -350,7 +351,7 @@ const server = http.createServer(async (req, res) => { } if (req.method === "POST" && url.pathname === "/send-file") { const payload = await readJson(req); - const to = normalizeDirectJid(payload.to || ""); + const to = normalizeChatJid(payload.to || ""); const filePath = String(payload.path || "").trim(); const caption = String(payload.caption || "").trim(); const kind = String(payload.kind || "").trim(); @@ -396,8 +397,8 @@ const server = http.createServer(async (req, res) => { } if (req.method === "POST" && url.pathname === "/react") { const payload = await readJson(req); - const to = normalizeDirectJid(payload.to || payload.remoteJid || ""); - const remoteJid = normalizeDirectJid(payload.remoteJid || payload.to || ""); + const to = normalizeChatJid(payload.to || payload.remoteJid || ""); + const remoteJid = normalizeChatJid(payload.remoteJid || payload.to || ""); const emoji = String(payload.emoji || "").trim(); const messageId = String(payload.messageId || "").trim(); const targetFromMe = Boolean(payload.targetFromMe); diff --git a/src/octopal/channels/group_addressing.py b/src/octopal/channels/group_addressing.py new file mode 100644 index 00000000..75b96f3c --- /dev/null +++ b/src/octopal/channels/group_addressing.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import Any, Literal + +import structlog + +from octopal.infrastructure.providers.base import InferenceProvider, Message + +logger = structlog.get_logger(__name__) + +GroupAddressingAction = Literal["ignore", "respond_self", "respond_all_agents", "continue_thread"] + + +@dataclass(frozen=True) +class GroupAddressingIdentity: + agent_name: str + agent_aliases: list[str] + collective_aliases: list[str] + + +@dataclass(frozen=True) +class GroupAddressingDecision: + action: GroupAddressingAction + reason: str = "" + confidence: float = 0.0 + + @property + def should_process(self) -> bool: + return self.action != "ignore" + + +def resolve_group_addressing_identity(settings: Any) -> GroupAddressingIdentity: + configured_name = str(getattr(settings, "group_agent_name", "") or "").strip() + a2a_config = getattr(settings, "a2a", None) + a2a_name = str(getattr(a2a_config, "agent_name", "") or "").strip() + agent_name = configured_name or a2a_name or "Octopal" + + agent_aliases = _split_aliases(getattr(settings, "group_agent_aliases", "")) + collective_aliases = _split_aliases(getattr(settings, "group_collective_aliases", "")) + + agent_aliases = _dedupe([agent_name, *agent_aliases]) + collective_aliases = _dedupe(collective_aliases) + return GroupAddressingIdentity( + agent_name=agent_name, + agent_aliases=agent_aliases, + collective_aliases=collective_aliases, + ) + + +async def decide_group_addressing( + *, + provider: InferenceProvider | None, + settings: Any, + channel: str, + chat_id: int, + text: str, + has_attachments: bool = False, + reply_to_agent: bool = False, + sender_label: str | None = None, +) -> GroupAddressingDecision: + if not bool(getattr(settings, "group_addressing_enabled", True)): + return GroupAddressingDecision("respond_self", "group addressing disabled", 1.0) + if reply_to_agent: + return GroupAddressingDecision("continue_thread", "message replies to this agent", 1.0) + + clean_text = (text or "").strip() + if not clean_text: + reason = "attachment-only group message without an explicit reply" + if has_attachments: + return GroupAddressingDecision("ignore", reason, 1.0) + return GroupAddressingDecision("ignore", "empty group message", 1.0) + + if provider is None: + return GroupAddressingDecision("ignore", "no provider available for group addressing", 0.0) + + identity = resolve_group_addressing_identity(settings) + messages = [ + Message( + role="system", + content=( + "You are a strict group-chat addressing gate for an AI agent. " + "Decide whether the incoming group-chat message is addressed to this agent, " + "to all agents, or to nobody. Use semantic understanding, not substring rules. " + "Return only compact JSON with keys action, confidence, reason. " + "action must be one of: ignore, respond_self, respond_all_agents, continue_thread. " + "Use respond_self for clear direct requests to this agent by name, alias, role, or " + "unambiguous second-person address. Use respond_all_agents when the user addresses " + "all agents collectively. Use ignore when the message is for other named agents, " + "for humans, or is ambient group conversation." + ), + ), + Message( + role="user", + content=json.dumps( + { + "channel": channel, + "chat_id": chat_id, + "sender": sender_label or "", + "agent_name": identity.agent_name, + "agent_aliases": identity.agent_aliases, + "collective_aliases": identity.collective_aliases, + "has_attachments": has_attachments, + "message": clean_text, + }, + ensure_ascii=False, + ), + ), + ] + + try: + raw = await provider.complete(messages) + except Exception: + logger.warning("Group addressing provider call failed", chat_id=chat_id, exc_info=True) + return GroupAddressingDecision("ignore", "group addressing provider failed", 0.0) + + decision = _parse_decision(raw) + logger.debug( + "Group addressing decision", + channel=channel, + chat_id=chat_id, + action=decision.action, + confidence=decision.confidence, + reason=decision.reason, + ) + return decision + + +def _parse_decision(raw: str) -> GroupAddressingDecision: + payload = _extract_json_object(raw) + if payload is None: + return GroupAddressingDecision("ignore", "invalid group addressing JSON", 0.0) + + action = str(payload.get("action", "") or "").strip().lower() + if action not in {"ignore", "respond_self", "respond_all_agents", "continue_thread"}: + return GroupAddressingDecision("ignore", "unknown group addressing action", 0.0) + + confidence = _coerce_confidence(payload.get("confidence")) + reason = str(payload.get("reason", "") or "").strip() + return GroupAddressingDecision(action=action, reason=reason, confidence=confidence) # type: ignore[arg-type] + + +def _extract_json_object(raw: str) -> dict[str, Any] | None: + text = (raw or "").strip() + if text.startswith("```"): + lines = text.splitlines() + if lines and lines[0].startswith("```"): + lines = lines[1:] + if lines and lines[-1].strip() == "```": + lines = lines[:-1] + text = "\n".join(lines).strip() + + candidates = [text] + start = text.find("{") + end = text.rfind("}") + if start >= 0 and end > start: + candidates.append(text[start : end + 1]) + + for candidate in candidates: + try: + parsed = json.loads(candidate) + except json.JSONDecodeError: + continue + if isinstance(parsed, dict): + return parsed + return None + + +def _coerce_confidence(raw: object) -> float: + try: + value = float(raw) + except (TypeError, ValueError): + return 0.0 + return max(0.0, min(1.0, value)) + + +def _split_aliases(raw: object) -> list[str]: + if isinstance(raw, (list, tuple)): + values = [str(item).strip() for item in raw] + else: + values = [chunk.strip() for chunk in str(raw or "").split(",")] + return [value for value in values if value] + + +def _dedupe(values: list[str]) -> list[str]: + out: list[str] = [] + seen: set[str] = set() + for value in values: + key = value.casefold() + if not value or key in seen: + continue + seen.add(key) + out.append(value) + return out diff --git a/src/octopal/channels/telegram/handlers.py b/src/octopal/channels/telegram/handlers.py index fa91dd2e..9f13bd29 100644 --- a/src/octopal/channels/telegram/handlers.py +++ b/src/octopal/channels/telegram/handlers.py @@ -18,6 +18,7 @@ from aiogram.filters import Command, CommandObject from aiogram.types import CallbackQuery, FSInputFile, Message, ReactionTypeEmoji +from octopal.channels.group_addressing import decide_group_addressing from octopal.channels.telegram.access import is_allowed_chat, parse_allowed_chat_ids from octopal.channels.telegram.approvals import ApprovalManager from octopal.infrastructure.config.settings import Settings @@ -291,10 +292,14 @@ async def _internal_typing_control(chat_id: int, active: bool) -> None: import importlib.metadata @dp.message(Command("help")) - async def cmd_help(message: Message) -> None: + async def cmd_help(message: Message, command: CommandObject) -> None: if not is_allowed_chat(message.chat.id, allowed_chat_ids): await _reject_unauthorized_message(message) return + if not await _telegram_group_command_should_run( + message, command=command, settings=settings, octo=octo, bot=bot + ): + return help_text = ( "Available commands:\n" "/help - Show this help message\n" @@ -306,10 +311,14 @@ async def cmd_help(message: Message) -> None: await message.answer(help_text) @dp.message(Command("version")) - async def cmd_version(message: Message) -> None: + async def cmd_version(message: Message, command: CommandObject) -> None: if not is_allowed_chat(message.chat.id, allowed_chat_ids): await _reject_unauthorized_message(message) return + if not await _telegram_group_command_should_run( + message, command=command, settings=settings, octo=octo, bot=bot + ): + return try: version = importlib.metadata.version("octopal") except importlib.metadata.PackageNotFoundError: @@ -317,10 +326,14 @@ async def cmd_version(message: Message) -> None: await message.answer(f"Octopal v{version}") @dp.message(Command("status")) - async def cmd_status(message: Message) -> None: + async def cmd_status(message: Message, command: CommandObject) -> None: if not is_allowed_chat(message.chat.id, allowed_chat_ids): await _reject_unauthorized_message(message) return + if not await _telegram_group_command_should_run( + message, command=command, settings=settings, octo=octo, bot=bot + ): + return active_workers = await asyncio.to_thread(octo.store.get_active_workers) metrics = read_metrics_snapshot(settings.state_dir) octo_status = build_octo_status((metrics or {}).get("octo", {})) @@ -343,10 +356,14 @@ async def cmd_status(message: Message) -> None: await message.answer(status_text, parse_mode="Markdown") @dp.message(Command("workers")) - async def cmd_workers(message: Message) -> None: + async def cmd_workers(message: Message, command: CommandObject) -> None: if not is_allowed_chat(message.chat.id, allowed_chat_ids): await _reject_unauthorized_message(message) return + if not await _telegram_group_command_should_run( + message, command=command, settings=settings, octo=octo, bot=bot + ): + return templates = await asyncio.to_thread(octo.store.list_worker_templates) if not templates: await message.answer("No worker templates found.") @@ -362,6 +379,10 @@ async def cmd_memory(message: Message, command: CommandObject) -> None: if not is_allowed_chat(message.chat.id, allowed_chat_ids): await _reject_unauthorized_message(message) return + if not await _telegram_group_command_should_run( + message, command=command, settings=settings, octo=octo, bot=bot + ): + return limit = 300 if command.args and command.args.isdigit(): limit = int(command.args) @@ -525,7 +546,12 @@ async def handle_message(message: Message) -> None: text=text, images=images, saved_file_paths=saved_file_paths, - metadata={"reply_to_message_id": message.message_id}, + metadata={ + "reply_to_message_id": message.message_id, + "is_group_chat": _is_telegram_group_chat(message), + "reply_to_agent": _telegram_reply_targets_this_bot(message, bot), + "sender_label": _telegram_sender_label(message), + }, ) return @@ -852,6 +878,27 @@ async def _flush_pending_turn( ) -> None: lock = _CHAT_LOCKS.setdefault(chat_id, asyncio.Lock()) reply_to_message_id = metadata.get("reply_to_message_id") + is_group_chat = bool(metadata.get("is_group_chat")) + if is_group_chat: + provider = getattr(octo, "provider", None) + decision = await decide_group_addressing( + provider=provider, + settings=settings, + channel="telegram", + chat_id=chat_id, + text=text, + has_attachments=bool(images or saved_file_paths), + reply_to_agent=bool(metadata.get("reply_to_agent")), + sender_label=str(metadata.get("sender_label", "") or "") or None, + ) + if not decision.should_process: + logger.info( + "Ignoring non-addressed Telegram group message", + chat_id=chat_id, + reason=decision.reason, + confidence=decision.confidence, + ) + return # Immediate feedback if reply_to_message_id is not None: @@ -953,6 +1000,85 @@ async def _flush_pending_turn( return _flush_pending_turn +def _is_telegram_group_chat(message: Message) -> bool: + chat_type = str(getattr(getattr(message, "chat", None), "type", "") or "").lower() + return chat_type in {"group", "supergroup"} + + +def _telegram_reply_targets_this_bot(message: Message, bot: Bot) -> bool: + reply_to = getattr(message, "reply_to_message", None) + reply_from = getattr(reply_to, "from_user", None) + if reply_from is None: + return False + bot_id = getattr(bot, "id", None) + reply_from_id = getattr(reply_from, "id", None) + if bot_id is not None and reply_from_id is not None: + return int(bot_id) == int(reply_from_id) + return False + + +async def _telegram_group_command_should_run( + message: Message, + *, + command: CommandObject | None, + settings: Settings, + octo: Octo, + bot: Bot, +) -> bool: + if not _is_telegram_group_chat(message): + return True + provider = getattr(octo, "provider", None) + decision = await decide_group_addressing( + provider=provider, + settings=settings, + channel="telegram", + chat_id=message.chat.id, + text=message.text or message.caption or "", + reply_to_agent=( + _telegram_reply_targets_this_bot(message, bot) + or _telegram_command_targets_this_bot(command, bot) + ), + sender_label=_telegram_sender_label(message) or None, + ) + if decision.should_process: + return True + logger.info( + "Ignoring non-addressed Telegram group command", + chat_id=message.chat.id, + command=getattr(command, "command", None), + reason=decision.reason, + confidence=decision.confidence, + ) + return False + + +def _telegram_command_targets_this_bot(command: CommandObject | None, bot: Bot) -> bool: + mention = str(getattr(command, "mention", "") or "").strip().lstrip("@").casefold() + if not mention: + return False + username = _telegram_bot_username(bot) + return bool(username and mention == username.casefold()) + + +def _telegram_bot_username(bot: Bot) -> str: + for attr in ("username", "bot_username"): + value = str(getattr(bot, attr, "") or "").strip() + if value: + return value.lstrip("@") + me = getattr(bot, "_me", None) + value = str(getattr(me, "username", "") or "").strip() + return value.lstrip("@") + + +def _telegram_sender_label(message: Message) -> str: + user = getattr(message, "from_user", None) + parts = [ + str(getattr(user, "full_name", "") or "").strip(), + str(getattr(user, "username", "") or "").strip(), + ] + return " / ".join(part for part in parts if part) + + def _chunk_text(text: str, limit: int) -> list[str]: if len(text) <= limit: return [text] diff --git a/src/octopal/channels/whatsapp/ids.py b/src/octopal/channels/whatsapp/ids.py index d8aa7034..aaae4207 100644 --- a/src/octopal/channels/whatsapp/ids.py +++ b/src/octopal/channels/whatsapp/ids.py @@ -27,7 +27,31 @@ def parse_allowed_whatsapp_numbers(raw: str) -> list[str]: return out +def normalize_whatsapp_chat(value: str) -> str: + raw = (value or "").strip() + if not raw: + return "" + if "@g.us" in raw: + return raw + normalized_number = normalize_whatsapp_number(raw) + if normalized_number: + return normalized_number + return raw + + +def parse_allowed_whatsapp_chats(raw: str) -> list[str]: + out: list[str] = [] + seen: set[str] = set() + for chunk in (raw or "").split(","): + normalized = normalize_whatsapp_chat(chunk) + if normalized and normalized not in seen: + seen.add(normalized) + out.append(normalized) + return out + + def whatsapp_chat_id(sender: str) -> int: - normalized = normalize_whatsapp_number(sender) or sender.strip() + raw = (sender or "").strip() + normalized = raw if "@g.us" in raw else normalize_whatsapp_number(raw) or raw digest = hashlib.sha256(normalized.encode("utf-8")).digest() return int.from_bytes(digest[:8], byteorder="big", signed=False) & 0x7FFF_FFFF_FFFF_FFFF diff --git a/src/octopal/channels/whatsapp/runtime.py b/src/octopal/channels/whatsapp/runtime.py index 46522843..f88a7274 100644 --- a/src/octopal/channels/whatsapp/runtime.py +++ b/src/octopal/channels/whatsapp/runtime.py @@ -11,9 +11,12 @@ import structlog +from octopal.channels.group_addressing import decide_group_addressing from octopal.channels.whatsapp.bridge import WhatsAppBridgeController from octopal.channels.whatsapp.ids import ( + normalize_whatsapp_chat, normalize_whatsapp_number, + parse_allowed_whatsapp_chats, parse_allowed_whatsapp_numbers, whatsapp_chat_id, ) @@ -141,9 +144,16 @@ async def start(self) -> Octo: allowed_numbers = parse_allowed_whatsapp_numbers(self.settings.allowed_whatsapp_numbers) for number in allowed_numbers: self._number_by_chat_id[whatsapp_chat_id(number)] = number + allowed_chats = parse_allowed_whatsapp_chats( + str(getattr(self.settings, "allowed_whatsapp_chats", "") or "") + ) + for chat in allowed_chats: + self._number_by_chat_id[whatsapp_chat_id(chat)] = chat await self.octo.initialize_system( bot=None, - allowed_chat_ids=[whatsapp_chat_id(number) for number in allowed_numbers], + allowed_chat_ids=[ + whatsapp_chat_id(chat) for chat in [*allowed_numbers, *allowed_chats] + ], ) self._publish_metrics(connected=True) return self.octo @@ -179,7 +189,20 @@ async def handle_inbound(self, payload: dict[str, Any]) -> dict[str, Any]: ) return {"accepted": False, "reason": "not_self_chat"} allowed = parse_allowed_whatsapp_numbers(self.settings.allowed_whatsapp_numbers) - if allowed and sender not in allowed: + allowed_chats = parse_allowed_whatsapp_chats( + str(getattr(self.settings, "allowed_whatsapp_chats", "") or "") + ) + group_chat = _is_whatsapp_group_chat(conversation, payload) + chat_number = normalize_whatsapp_chat(conversation) or conversation + if group_chat: + if chat_number not in allowed_chats: + logger.warning( + "Rejected WhatsApp group message from unauthorized chat", + sender=sender, + conversation=conversation, + ) + return {"accepted": False, "reason": "unauthorized_group"} + elif allowed and sender not in allowed: logger.warning("Rejected WhatsApp message from unauthorized sender", sender=sender) return {"accepted": False, "reason": "unauthorized"} @@ -191,7 +214,6 @@ async def handle_inbound(self, payload: dict[str, Any]) -> dict[str, Any]: ) return {"accepted": False, "reason": "invalid_self_chat_sender"} - chat_number = normalize_whatsapp_number(conversation) or conversation chat_id = whatsapp_chat_id(chat_number) self._number_by_chat_id[chat_id] = chat_number await self._pending_turns.submit( @@ -203,6 +225,9 @@ async def handle_inbound(self, payload: dict[str, Any]) -> dict[str, Any]: "message_id": str(payload.get("messageId", "") or "").strip(), "remote_jid": str(payload.get("remoteJid", "") or "").strip(), "target_from_me": from_me, + "is_group_chat": group_chat, + "reply_to_agent": bool(payload.get("replyToAgent")), + "sender_label": sender, }, ) self._publish_metrics(last_sender=sender) @@ -290,6 +315,27 @@ async def _flush_pending_turn( remote_jid = str(metadata.get("remote_jid", "") or "").strip() or None target_from_me = bool(metadata.get("target_from_me")) + if bool(metadata.get("is_group_chat")): + provider = getattr(self.octo, "provider", None) + decision = await decide_group_addressing( + provider=provider, + settings=self.settings, + channel="whatsapp", + chat_id=chat_id, + text=text, + has_attachments=bool(images or saved_file_paths), + reply_to_agent=bool(metadata.get("reply_to_agent")), + sender_label=str(metadata.get("sender_label", "") or "") or None, + ) + if not decision.should_process: + logger.info( + "Ignoring non-addressed WhatsApp group message", + chat_id=chat_id, + reason=decision.reason, + confidence=decision.confidence, + ) + return + # Immediate feedback if to and message_id: try: @@ -366,6 +412,13 @@ def _persist_whatsapp_media_payload( return str(file_path) +def _is_whatsapp_group_chat(conversation: str, payload: dict[str, Any]) -> bool: + if bool(payload.get("group")): + return True + remote_jid = str(payload.get("remoteJid", "") or "").strip() + return "@g.us" in conversation or "@g.us" in remote_jid + + def _chunk_text(text: str, limit: int) -> list[str]: if len(text) <= limit: return [_whatsappify(text)] diff --git a/src/octopal/infrastructure/config/models.py b/src/octopal/infrastructure/config/models.py index 0a4cb7cc..8a4bb615 100644 --- a/src/octopal/infrastructure/config/models.py +++ b/src/octopal/infrastructure/config/models.py @@ -13,6 +13,13 @@ class TelegramConfig(BaseModel): parse_mode: str = "MarkdownV2" +class GroupAddressingConfig(BaseModel): + enabled: bool = True + agent_name: str | None = None + agent_aliases: list[str] = Field(default_factory=list) + collective_aliases: list[str] = Field(default_factory=list) + + class LLMConfig(BaseModel): provider_id: str | None = None model: str | None = None @@ -69,6 +76,7 @@ class WorkerRuntimeConfig(BaseModel): class WhatsAppConfig(BaseModel): mode: str = "separate" allowed_numbers: list[str] = Field(default_factory=list) + allowed_chats: list[str] = Field(default_factory=list) auth_dir: Path | None = None bridge_host: str = "127.0.0.1" bridge_port: int = 8765 @@ -169,6 +177,7 @@ class ConnectorsConfig(BaseModel): class OctopalConfig(BaseModel): user_channel: str = DEFAULT_USER_CHANNEL telegram: TelegramConfig = Field(default_factory=TelegramConfig) + group_addressing: GroupAddressingConfig = Field(default_factory=GroupAddressingConfig) # Octo LLM settings llm: LLMConfig = Field(default_factory=LLMConfig) diff --git a/src/octopal/infrastructure/config/settings.py b/src/octopal/infrastructure/config/settings.py index 171156f3..832b5003 100644 --- a/src/octopal/infrastructure/config/settings.py +++ b/src/octopal/infrastructure/config/settings.py @@ -111,6 +111,10 @@ class Settings(BaseSettings): heartbeat_interval_seconds: int = Field(900, alias="OCTOPAL_HEARTBEAT_INTERVAL_SECONDS") user_message_grace_seconds: float = Field(5.0, alias="OCTOPAL_USER_MESSAGE_GRACE_SECONDS") + group_addressing_enabled: bool = Field(True, alias="OCTOPAL_GROUP_ADDRESSING_ENABLED") + group_agent_name: str = Field("", alias="OCTOPAL_GROUP_AGENT_NAME") + group_agent_aliases: str = Field("", alias="OCTOPAL_GROUP_AGENT_ALIASES") + group_collective_aliases: str = Field("", alias="OCTOPAL_GROUP_COLLECTIVE_ALIASES") # Connectors connectors: ConnectorsConfig = Field(default_factory=ConnectorsConfig) @@ -122,6 +126,7 @@ class Settings(BaseSettings): telegram_parse_mode: str = Field("MarkdownV2", alias="OCTOPAL_TELEGRAM_PARSE_MODE") whatsapp_mode: str = Field("separate", alias="OCTOPAL_WHATSAPP_MODE") allowed_whatsapp_numbers: str = Field("", alias="ALLOWED_WHATSAPP_NUMBERS") + allowed_whatsapp_chats: str = Field("", alias="ALLOWED_WHATSAPP_CHATS") whatsapp_auth_dir: Path | None = Field(default=None, alias="OCTOPAL_WHATSAPP_AUTH_DIR") whatsapp_bridge_host: str = Field("127.0.0.1", alias="OCTOPAL_WHATSAPP_BRIDGE_HOST") whatsapp_bridge_port: int = Field(8765, alias="OCTOPAL_WHATSAPP_BRIDGE_PORT") @@ -204,6 +209,12 @@ def _sync_settings_from_config(settings: Settings, config: OctopalConfig) -> Non updates["allowed_telegram_chat_ids"] = ",".join(config.telegram.allowed_chat_ids) updates["telegram_parse_mode"] = config.telegram.parse_mode + # Group addressing + updates["group_addressing_enabled"] = config.group_addressing.enabled + updates["group_agent_name"] = config.group_addressing.agent_name or "" + updates["group_agent_aliases"] = ",".join(config.group_addressing.agent_aliases) + updates["group_collective_aliases"] = ",".join(config.group_addressing.collective_aliases) + # LLM (Octo) updates["litellm_provider_id"] = config.llm.provider_id updates["litellm_model"] = config.llm.model @@ -257,6 +268,7 @@ def _sync_settings_from_config(settings: Settings, config: OctopalConfig) -> Non # WhatsApp updates["whatsapp_mode"] = config.whatsapp.mode updates["allowed_whatsapp_numbers"] = ",".join(config.whatsapp.allowed_numbers) + updates["allowed_whatsapp_chats"] = ",".join(config.whatsapp.allowed_chats) updates["whatsapp_auth_dir"] = config.whatsapp.auth_dir updates["whatsapp_bridge_host"] = config.whatsapp.bridge_host updates["whatsapp_bridge_port"] = config.whatsapp.bridge_port diff --git a/tests/test_group_addressing.py b/tests/test_group_addressing.py new file mode 100644 index 00000000..99f20623 --- /dev/null +++ b/tests/test_group_addressing.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import json +from types import SimpleNamespace + +from octopal.channels.group_addressing import ( + decide_group_addressing, + resolve_group_addressing_identity, +) + + +class _FakeProvider: + def __init__(self, payload: dict | str) -> None: + self.payload = payload + self.messages = [] + + async def complete(self, messages, **kwargs): + self.messages.append(messages) + if isinstance(self.payload, str): + return self.payload + return json.dumps(self.payload) + + +def _settings(**kwargs) -> SimpleNamespace: + defaults = { + "group_addressing_enabled": True, + "group_agent_name": "Alice", + "group_agent_aliases": "Alice,AliceBot", + "group_collective_aliases": "Octopals,agents", + "a2a": SimpleNamespace(agent_name="Fallback"), + } + defaults.update(kwargs) + return SimpleNamespace(**defaults) + + +def test_resolve_group_addressing_identity_uses_configured_values_only() -> None: + identity = resolve_group_addressing_identity(_settings()) + + assert identity.agent_name == "Alice" + assert identity.agent_aliases == ["Alice", "AliceBot"] + assert identity.collective_aliases == ["Octopals", "agents"] + + +def test_group_addressing_uses_provider_decision_for_group_message() -> None: + provider = _FakeProvider( + {"action": "respond_all_agents", "confidence": 0.91, "reason": "collective request"} + ) + + async def scenario(): + return await decide_group_addressing( + provider=provider, + settings=_settings(), + channel="telegram", + chat_id=-100, + text="Octopals, update yourselves", + sender_label="Slava", + ) + + import asyncio + + decision = asyncio.run(scenario()) + + assert decision.action == "respond_all_agents" + assert decision.should_process is True + assert provider.messages + + +def test_group_addressing_reply_to_agent_continues_without_provider() -> None: + async def scenario(): + return await decide_group_addressing( + provider=None, + settings=_settings(), + channel="telegram", + chat_id=-100, + text="yes", + reply_to_agent=True, + ) + + import asyncio + + decision = asyncio.run(scenario()) + + assert decision.action == "continue_thread" + assert decision.should_process is True + + +def test_group_addressing_is_conservative_without_provider() -> None: + async def scenario(): + return await decide_group_addressing( + provider=None, + settings=_settings(), + channel="whatsapp", + chat_id=100, + text="Alice, what is the status?", + ) + + import asyncio + + decision = asyncio.run(scenario()) + + assert decision.action == "ignore" + assert decision.should_process is False diff --git a/tests/test_settings_config_sync.py b/tests/test_settings_config_sync.py index 3cd0544d..02fce8da 100644 --- a/tests/test_settings_config_sync.py +++ b/tests/test_settings_config_sync.py @@ -172,3 +172,33 @@ def test_load_settings_syncs_a2a_config(tmp_path, monkeypatch) -> None: assert settings.a2a.public_base_url == "https://octo.example" assert settings.a2a.agent_name == "Alice" assert settings.a2a.peers["bob"].token == "peer-secret" + + +def test_load_settings_syncs_group_addressing_and_whatsapp_group_chats( + tmp_path, monkeypatch +) -> None: + (tmp_path / "config.json").write_text( + json.dumps( + { + "group_addressing": { + "enabled": True, + "agent_name": "Alice", + "agent_aliases": ["Alice", "AliceBot"], + "collective_aliases": ["Octopals", "agents"], + }, + "whatsapp": { + "allowed_chats": ["120363123456789@g.us"], + }, + } + ), + encoding="utf-8", + ) + monkeypatch.chdir(tmp_path) + + settings = load_settings() + + assert settings.group_addressing_enabled is True + assert settings.group_agent_name == "Alice" + assert settings.group_agent_aliases == "Alice,AliceBot" + assert settings.group_collective_aliases == "Octopals,agents" + assert settings.allowed_whatsapp_chats == "120363123456789@g.us" diff --git a/tests/test_telegram_react_tag_sanitization.py b/tests/test_telegram_react_tag_sanitization.py index 19b4734b..055e4e25 100644 --- a/tests/test_telegram_react_tag_sanitization.py +++ b/tests/test_telegram_react_tag_sanitization.py @@ -3,6 +3,9 @@ import asyncio import sys import types +from types import SimpleNamespace + +from aiogram.filters import CommandObject from octopal.infrastructure.config.settings import Settings from octopal.runtime.octo.core import OctoReply @@ -40,9 +43,9 @@ def test_strip_reaction_tags_removes_unknown_react_markup() -> None: def test_extract_edge_reaction_fallback_handles_short_confirmation_text() -> None: - emoji, text = extract_edge_reaction_fallback("Поставила! 👻") + emoji, text = extract_edge_reaction_fallback("Set it! 👻") assert emoji == "👻" - assert text == "Поставила!" + assert text == "Set it!" def test_telegram_uses_reply_reaction_fallback_when_immediate_loses_tag(tmp_path) -> None: @@ -51,7 +54,7 @@ async def handle_message( self, text: str, chat_id: int, images=None, saved_file_paths=None, **kwargs ): return OctoReply( - immediate="Поставила! Посмотрим, появится ли 👻", + immediate="Set it! Let us see if it appears 👻", followup=None, followup_required=False, reaction="👍", @@ -109,7 +112,7 @@ async def scenario() -> None: (211619002, 4740, "👍"), ] assert queued_messages == [ - (211619002, "Поставила! Посмотрим, появится ли 👻", 4740), + (211619002, "Set it! Let us see if it appears 👻", 4740), ] @@ -119,7 +122,7 @@ async def handle_message( self, text: str, chat_id: int, images=None, saved_file_paths=None, **kwargs ): return OctoReply( - immediate="Поставила! 👻", + immediate="Set it! 👻", followup=None, followup_required=False, reaction=None, @@ -177,5 +180,173 @@ async def scenario() -> None: (211619002, 4741, "👻"), ] assert queued_messages == [ - (211619002, "Поставила!", 4741), + (211619002, "Set it!", 4741), ] + + +def test_telegram_group_message_ignored_before_reaction_or_octo_call(tmp_path) -> None: + class DummyProvider: + async def complete(self, messages, **kwargs): + return '{"action":"ignore","confidence":0.97,"reason":"addressed to another agent"}' + + class DummyOcto: + provider = DummyProvider() + + def __init__(self) -> None: + self.calls = 0 + + async def handle_message( + self, text: str, chat_id: int, images=None, saved_file_paths=None, **kwargs + ): + self.calls += 1 + return OctoReply( + immediate="should not happen", + followup=None, + followup_required=False, + reaction=None, + ) + + class DummyBot: + def __init__(self) -> None: + self.reactions: list[tuple[int, int, str]] = [] + + async def set_message_reaction(self, chat_id: int, message_id: int, reaction): + self.reactions.append((chat_id, message_id, reaction[0].emoji)) + + queued_messages: list[tuple[int, str, int | None]] = [] + + async def fake_enqueue_send( + bot, chat_id: int, text: str, reply_to_message_id: int | None = None + ) -> None: + queued_messages.append((chat_id, text, reply_to_message_id)) + + original_enqueue = telegram_handlers._enqueue_send + telegram_handlers._enqueue_send = fake_enqueue_send + + settings = Settings( + OCTOPAL_STATE_DIR=tmp_path / "state", + OCTOPAL_WORKSPACE_DIR=tmp_path / "workspace", + OCTOPAL_TELEGRAM_PARSE_MODE="MarkdownV2", + ) + octo = DummyOcto() + bot = DummyBot() + flush = _flush_pending_turn_factory(octo, settings, bot) + + try: + + async def scenario() -> None: + await flush( + -100211619002, + "Bob, what is the status?", + [], + [], + { + "reply_to_message_id": 4742, + "is_group_chat": True, + "reply_to_agent": False, + "sender_label": "Slava", + }, + ) + + asyncio.run(scenario()) + finally: + telegram_handlers._enqueue_send = original_enqueue + + assert octo.calls == 0 + assert bot.reactions == [] + assert queued_messages == [] + + +def test_telegram_plain_group_command_is_gated(tmp_path) -> None: + class DummyProvider: + async def complete(self, messages, **kwargs): + return '{"action":"ignore","confidence":0.98,"reason":"ambient command"}' + + message = SimpleNamespace( + chat=SimpleNamespace(id=-100211619002, type="supergroup"), + text="/status", + caption=None, + from_user=SimpleNamespace(full_name="Slava", username="slava"), + reply_to_message=None, + ) + octo = SimpleNamespace(provider=DummyProvider()) + settings = Settings( + OCTOPAL_STATE_DIR=tmp_path / "state", + OCTOPAL_WORKSPACE_DIR=tmp_path / "workspace", + ) + bot = SimpleNamespace(id=123456, username="AliceBot") + + async def scenario() -> bool: + return await telegram_handlers._telegram_group_command_should_run( + message, + command=CommandObject(command="status"), + settings=settings, + octo=octo, + bot=bot, + ) + + assert asyncio.run(scenario()) is False + + +def test_telegram_targeted_group_command_skips_provider_gate(tmp_path) -> None: + class FailingProvider: + async def complete(self, messages, **kwargs): + raise AssertionError("targeted command should not need provider classification") + + message = SimpleNamespace( + chat=SimpleNamespace(id=-100211619002, type="supergroup"), + text="/status@AliceBot", + caption=None, + from_user=SimpleNamespace(full_name="Slava", username="slava"), + reply_to_message=None, + ) + octo = SimpleNamespace(provider=FailingProvider()) + settings = Settings( + OCTOPAL_STATE_DIR=tmp_path / "state", + OCTOPAL_WORKSPACE_DIR=tmp_path / "workspace", + ) + bot = SimpleNamespace(id=123456, username="AliceBot") + + async def scenario() -> bool: + return await telegram_handlers._telegram_group_command_should_run( + message, + command=CommandObject(command="status", mention="AliceBot"), + settings=settings, + octo=octo, + bot=bot, + ) + + assert asyncio.run(scenario()) is True + + +def test_telegram_reply_to_unknown_bot_does_not_bypass_group_gate(tmp_path) -> None: + class DummyProvider: + async def complete(self, messages, **kwargs): + return '{"action":"ignore","confidence":0.98,"reason":"other bot"}' + + message = SimpleNamespace( + chat=SimpleNamespace(id=-100211619002, type="supergroup"), + text="/status", + caption=None, + from_user=SimpleNamespace(full_name="Slava", username="slava"), + reply_to_message=SimpleNamespace( + from_user=SimpleNamespace(is_bot=True, id=999999, username="OtherBot") + ), + ) + octo = SimpleNamespace(provider=DummyProvider()) + settings = Settings( + OCTOPAL_STATE_DIR=tmp_path / "state", + OCTOPAL_WORKSPACE_DIR=tmp_path / "workspace", + ) + bot = SimpleNamespace(id=123456, username="AliceBot") + + async def scenario() -> bool: + return await telegram_handlers._telegram_group_command_should_run( + message, + command=CommandObject(command="status"), + settings=settings, + octo=octo, + bot=bot, + ) + + assert asyncio.run(scenario()) is False diff --git a/tests/test_whatsapp_ids.py b/tests/test_whatsapp_ids.py index e378ef93..a17511ae 100644 --- a/tests/test_whatsapp_ids.py +++ b/tests/test_whatsapp_ids.py @@ -2,6 +2,7 @@ from octopal.channels.whatsapp.ids import ( normalize_whatsapp_number, + parse_allowed_whatsapp_chats, parse_allowed_whatsapp_numbers, whatsapp_chat_id, ) @@ -20,5 +21,16 @@ def test_parse_allowed_whatsapp_numbers_is_deduplicated() -> None: assert parsed == ["+15551234567", "+447700900123"] +def test_parse_allowed_whatsapp_chats_preserves_group_jids() -> None: + parsed = parse_allowed_whatsapp_chats( + "120363123456789@g.us, +15551234567, 120363123456789@g.us" + ) + assert parsed == ["120363123456789@g.us", "+15551234567"] + + def test_whatsapp_chat_id_is_stable() -> None: assert whatsapp_chat_id("+15551234567") == whatsapp_chat_id("+1 (555) 123-4567") + + +def test_whatsapp_chat_id_keeps_group_jids_distinct_from_phone_numbers() -> None: + assert whatsapp_chat_id("120363123456789@g.us") != whatsapp_chat_id("+120363123456789") diff --git a/tests/test_whatsapp_runtime.py b/tests/test_whatsapp_runtime.py index 1926d674..c47ebe97 100644 --- a/tests/test_whatsapp_runtime.py +++ b/tests/test_whatsapp_runtime.py @@ -58,11 +58,22 @@ def __init__(self, immediate: str) -> None: self.immediate = immediate +class _FakeProvider: + def __init__(self, action: str = "respond_self") -> None: + self.action = action + self.messages = [] + + async def complete(self, messages, **kwargs): + self.messages.append(messages) + return f'{{"action":"{self.action}","confidence":0.95,"reason":"test"}}' + + class _FakeOcto: def __init__(self) -> None: self.handled: list[dict] = [] self.initialized: list[int] = [] self.internal_send = None + self.provider = _FakeProvider() async def initialize_system(self, *, bot=None, allowed_chat_ids=None) -> None: self.initialized = list(allowed_chat_ids or []) @@ -75,11 +86,18 @@ async def stop_background_tasks(self) -> None: return None -def _make_settings(*, mode: str, allowed_numbers: str) -> SimpleNamespace: +def _make_settings( + *, mode: str, allowed_numbers: str, allowed_chats: str = "" +) -> SimpleNamespace: return SimpleNamespace( whatsapp_mode=mode, allowed_whatsapp_numbers=allowed_numbers, + allowed_whatsapp_chats=allowed_chats, user_message_grace_seconds=0.0, + group_addressing_enabled=True, + group_agent_name="Alice", + group_agent_aliases="Alice,AliceBot", + group_collective_aliases="Octopals,agents", gateway_port=8000, whatsapp_callback_token="", whatsapp_bridge_host="127.0.0.1", @@ -427,3 +445,114 @@ async def scenario() -> None: assert runtime.bridge.sent_files == [ {"to": "+15551234567", "file_path": str(file_path), "caption": "Take this"} ] + + +def test_whatsapp_runtime_rejects_unconfigured_group_chat(monkeypatch) -> None: + fake_octo = _FakeOcto() + monkeypatch.setattr(whatsapp_runtime_module, "build_octo", lambda settings: fake_octo) + monkeypatch.setattr(whatsapp_runtime_module, "WhatsAppBridgeController", _FakeBridgeController) + monkeypatch.setattr( + whatsapp_runtime_module, "update_component_gauges", lambda *args, **kwargs: None + ) + + runtime = WhatsAppRuntime(_make_settings(mode="separate", allowed_numbers="+15551234567")) + runtime.attach_octo_output() + + async def scenario() -> None: + result = await runtime.handle_inbound( + { + "sender": "+15550000000", + "conversation": "120363123456789@g.us", + "remoteJid": "120363123456789@g.us", + "group": True, + "fromMe": False, + "text": "Alice, status?", + } + ) + assert result == {"accepted": False, "reason": "unauthorized_group"} + assert fake_octo.handled == [] + + asyncio.run(scenario()) + + +def test_whatsapp_runtime_ignores_non_addressed_group_message(monkeypatch) -> None: + fake_octo = _FakeOcto() + fake_octo.provider = _FakeProvider("ignore") + monkeypatch.setattr(whatsapp_runtime_module, "build_octo", lambda settings: fake_octo) + monkeypatch.setattr(whatsapp_runtime_module, "WhatsAppBridgeController", _FakeBridgeController) + monkeypatch.setattr( + whatsapp_runtime_module, "update_component_gauges", lambda *args, **kwargs: None + ) + monkeypatch.setattr( + whatsapp_runtime_module, "update_last_message", lambda *args, **kwargs: None + ) + + runtime = WhatsAppRuntime( + _make_settings( + mode="separate", + allowed_numbers="+15551234567", + allowed_chats="120363123456789@g.us", + ) + ) + runtime.attach_octo_output() + + async def scenario() -> None: + result = await runtime.handle_inbound( + { + "sender": "+15550000000", + "conversation": "120363123456789@g.us", + "remoteJid": "120363123456789@g.us", + "group": True, + "fromMe": False, + "text": "Bob, status?", + "messageId": "wamid-group-1", + } + ) + assert result["accepted"] is True + assert fake_octo.provider.messages + assert fake_octo.handled == [] + assert runtime.bridge.reactions == [] + assert runtime.bridge.sent == [] + + asyncio.run(scenario()) + + +def test_whatsapp_runtime_handles_addressed_group_message(monkeypatch) -> None: + fake_octo = _FakeOcto() + fake_octo.provider = _FakeProvider("respond_all_agents") + monkeypatch.setattr(whatsapp_runtime_module, "build_octo", lambda settings: fake_octo) + monkeypatch.setattr(whatsapp_runtime_module, "WhatsAppBridgeController", _FakeBridgeController) + monkeypatch.setattr( + whatsapp_runtime_module, "update_component_gauges", lambda *args, **kwargs: None + ) + monkeypatch.setattr( + whatsapp_runtime_module, "update_last_message", lambda *args, **kwargs: None + ) + + runtime = WhatsAppRuntime( + _make_settings( + mode="separate", + allowed_numbers="+15551234567", + allowed_chats="120363123456789@g.us", + ) + ) + runtime.attach_octo_output() + + async def scenario() -> None: + result = await runtime.handle_inbound( + { + "sender": "+15550000000", + "conversation": "120363123456789@g.us", + "remoteJid": "120363123456789@g.us", + "group": True, + "fromMe": False, + "text": "Octopals, status?", + "messageId": "wamid-group-2", + } + ) + assert result["accepted"] is True + assert fake_octo.handled[-1]["text"] == "Octopals, status?" + assert fake_octo.handled[-1]["kwargs"]["source_channel"] == "whatsapp" + assert runtime.bridge.sent == [("120363123456789@g.us", "hello back")] + + asyncio.run(scenario())