diff --git a/plugin/__init__.py b/plugin/__init__.py index 2fd7b33..5dd48dc 100644 --- a/plugin/__init__.py +++ b/plugin/__init__.py @@ -525,6 +525,27 @@ def _collect_from_frame(vf): return found_fns def _tick(): + # --- Drain the main-thread dispatch queue --- + try: + from .core.binary_operations import _main_thread_queue + from .core import binary_operations as _bo_mod + import threading as _threading + if _bo_mod._main_thread_id is None: + _bo_mod._main_thread_id = _threading.current_thread().ident + while not _main_thread_queue.empty(): + try: + func, done_event, result_holder = _main_thread_queue.get_nowait() + try: + result_holder[0] = func() + except Exception as e: + result_holder[1] = e + finally: + done_event.set() + except Exception: + break + except Exception: + pass + try: ops = ( plugin.server.binary_ops @@ -532,6 +553,7 @@ def _tick(): else None ) if not ops: + bn.log_debug("MCP _tick: ops is None") return # First, prune internal weakrefs and get a snapshot of tracked views @@ -542,17 +564,13 @@ def _tick(): # Discover all open BVs from UI and sync registry (returns filenames) try: - _discover_all_open_bvs(ops) or set() - except Exception: - pass + found = _discover_all_open_bvs(ops) or set() + if found: + bn.log_debug(f"MCP _tick: discovered BVs: {found}") + except Exception as e: + bn.log_debug(f"MCP _tick: discover error: {e}") - # Do not prune solely based on UI heuristics; UI enumeration may miss open tabs. - # Rely on explicit close notifications and weakref pruning in ops. - - # Keep MCP-selected active view independent of UI focus. - # Only adopt a UI-active view if there is no current selection - # (e.g., after the previously selected view was actually closed - # and pruned by weakrefs). + # Always try to adopt UI-active view if current_view is None try: if ops.current_view is None: try: @@ -562,14 +580,20 @@ def _tick(): act_bv = None if act_ctx: vf = act_ctx.getCurrentViewFrame() + bn.log_debug(f"MCP _tick: viewFrame={vf}, has getCurrentBinaryView={hasattr(vf, 'getCurrentBinaryView') if vf else 'N/A'}") if vf and hasattr(vf, "getCurrentBinaryView"): act_bv = vf.getCurrentBinaryView() - ops.current_view = act_bv - if act_bv: + bn.log_debug(f"MCP _tick: getCurrentBinaryView returned: {act_bv}, type={type(act_bv).__name__ if act_bv else 'None'}") + else: + bn.log_debug("MCP _tick: no active UIContext") + if act_bv is not None: + ops.current_view = act_bv ops.register_view(act_bv) - except Exception: - # If UI is unavailable or no active view, leave as None - pass + bn.log_info(f"MCP _tick: auto-set current_view to {act_bv.file.filename}") + else: + bn.log_debug("MCP _tick: current_view is None and no active BV found") + except Exception as e: + bn.log_debug(f"MCP _tick: error acquiring BV: {e}") except Exception: pass except Exception: @@ -577,7 +601,7 @@ def _tick(): pass _bv_monitor_timer = QTimer() - _bv_monitor_timer.setInterval(1000) # 1s; light periodic sync + _bv_monitor_timer.setInterval(200) # 200ms; fast dispatch for HTTP API calls _bv_monitor_timer.timeout.connect(_tick) def _start(): diff --git a/plugin/api/endpoints.py b/plugin/api/endpoints.py index b574c5b..2d32474 100644 --- a/plugin/api/endpoints.py +++ b/plugin/api/endpoints.py @@ -210,25 +210,29 @@ def search_functions( if not search_term: return [] - matches = [] - for func in self.binary_ops.current_view.functions: - if search_term.lower() in func.name.lower(): - matches.append( - { - "name": func.name, - "address": hex(func.start), - "raw_name": func.raw_name if hasattr(func, "raw_name") else func.name, - "symbol": { - "type": str(func.symbol.type) if func.symbol else None, - "full_name": func.symbol.full_name if func.symbol else None, + from ..core.binary_operations import _run_on_main_thread + + def _inner(): + matches = [] + for func in self.binary_ops.current_view.functions: + if search_term.lower() in func.name.lower(): + matches.append( + { + "name": func.name, + "address": hex(func.start), + "raw_name": func.raw_name if hasattr(func, "raw_name") else func.name, + "symbol": { + "type": str(func.symbol.type) if func.symbol else None, + "full_name": func.symbol.full_name if func.symbol else None, + } + if func.symbol + else None, } - if func.symbol - else None, - } - ) + ) + matches.sort(key=lambda x: x["name"]) + return matches[offset : offset + limit] - matches.sort(key=lambda x: x["name"]) - return matches[offset : offset + limit] + return _run_on_main_thread(_inner) def decompile_function(self, identifier: str) -> str | None: """Decompile a function by name or address""" diff --git a/plugin/core/binary_operations.py b/plugin/core/binary_operations.py index 1ac91e1..92ca2b0 100644 --- a/plugin/core/binary_operations.py +++ b/plugin/core/binary_operations.py @@ -1,6 +1,8 @@ import platform +import queue import re import subprocess +import threading import weakref from typing import Any @@ -10,6 +12,31 @@ from ..utils.string_utils import escape_non_ascii from .config import BinaryNinjaConfig +# Global dispatch queue: HTTP thread puts (func, result_event, result_holder) tuples, +# a QTimer on the main thread picks them up and executes. +_main_thread_queue: queue.Queue = queue.Queue() + +# Thread ID of the main thread (set by QTimer on first tick) +_main_thread_id: int | None = None + + +def _run_on_main_thread(func, timeout=120): + """Dispatch func to the main thread via _main_thread_queue and wait for result. + + If already on the main thread, execute directly (avoids nested deadlock). + """ + if _main_thread_id is not None and threading.current_thread().ident == _main_thread_id: + return func() + + result_holder = [None, None] # [result, exception] + done_event = threading.Event() + _main_thread_queue.put((func, done_event, result_holder)) + if not done_event.wait(timeout=timeout): + raise TimeoutError("Main-thread dispatch timed out") + if result_holder[1] is not None: + raise result_holder[1] + return result_holder[0] + class BinaryOperations: def __init__(self, config: BinaryNinjaConfig): @@ -102,14 +129,11 @@ def _prune_views(self) -> None: new_fn_map[fn] = vid self._views_by_id = alive self._id_by_filename = new_fn_map - # If current_view no longer exists among alive views, clear it - try: - if self._current_view is not None and all( - obj is not self._current_view for obj in alive_objs - ): - self._current_view = None - except Exception: - self._current_view = None + # NOTE: Do NOT clear _current_view here based on weakref liveness. + # _current_view is a strong reference and keeps the BV alive intentionally. + # It should only be cleared by explicit stop_server / close actions. + # The weakref-based _views_by_id may contain stale entries from UI wrapper + # objects that are different Python objects for the same underlying BV. def _register_view(self, bv: bn.BinaryView) -> str: """Add a view to the managed list if not present, return its id.""" @@ -313,42 +337,45 @@ def get_function_by_name_or_address(self, identifier: str | int) -> bn.Function if not self._current_view: raise RuntimeError("No binary loaded") - # Handle address-based lookup - try: - if isinstance(identifier, str) and identifier.startswith("0x"): - addr = int(identifier, 16) - elif isinstance(identifier, (int, str)): - addr = int(identifier) if isinstance(identifier, str) else identifier - - func = self._current_view.get_function_at(addr) - if func: - bn.log_info(f"Found function at address {hex(addr)}: {func.name}") - return func - except ValueError: - pass + def _inner(): + # Handle address-based lookup + try: + if isinstance(identifier, str) and identifier.startswith("0x"): + addr = int(identifier, 16) + elif isinstance(identifier, (int, str)): + addr = int(identifier) if isinstance(identifier, str) else identifier + + func = self._current_view.get_function_at(addr) + if func: + bn.log_info(f"Found function at address {hex(addr)}: {func.name}") + return func + except ValueError: + pass - # Handle name-based lookup with case sensitivity - for func in self._current_view.functions: - if func.name == identifier: - bn.log_info(f"Found function by name: {func.name}") - return func - - # Try case-insensitive match as fallback - for func in self._current_view.functions: - if func.name.lower() == str(identifier).lower(): - bn.log_info(f"Found function by case-insensitive name: {func.name}") - return func - - # Try symbol table lookup as last resort - symbol = self._current_view.get_symbol_by_raw_name(str(identifier)) - if symbol and symbol.address: - func = self._current_view.get_function_at(symbol.address) - if func: - bn.log_info(f"Found function through symbol lookup: {func.name}") - return func - - bn.log_error(f"Could not find function: {identifier}") - return None + # Handle name-based lookup with case sensitivity + for func in self._current_view.functions: + if func.name == identifier: + bn.log_info(f"Found function by name: {func.name}") + return func + + # Try case-insensitive match as fallback + for func in self._current_view.functions: + if func.name.lower() == str(identifier).lower(): + bn.log_info(f"Found function by case-insensitive name: {func.name}") + return func + + # Try symbol table lookup as last resort + symbol = self._current_view.get_symbol_by_raw_name(str(identifier)) + if symbol and symbol.address: + func = self._current_view.get_function_at(symbol.address) + if func: + bn.log_info(f"Found function through symbol lookup: {func.name}") + return func + + bn.log_error(f"Could not find function: {identifier}") + return None + + return _run_on_main_thread(_inner) def _normalize_identifier_list(self, identifiers: Any) -> list[Any]: """Normalize comma-delimited strings or iterables into a list of identifiers.""" @@ -553,17 +580,19 @@ def get_function_names(self, offset: int = 0, limit: int = 100) -> list[dict[str if not self._current_view: raise RuntimeError("No binary loaded") - functions = [] - for func in self._current_view.functions: - functions.append( - { - "name": func.name, - "address": hex(func.start), - "raw_name": func.raw_name if hasattr(func, "raw_name") else func.name, - } - ) + def _inner(): + functions = [] + for func in self._current_view.functions: + functions.append( + { + "name": func.name, + "address": hex(func.start), + "raw_name": func.raw_name if hasattr(func, "raw_name") else func.name, + } + ) + return functions[offset : offset + limit] - return functions[offset : offset + limit] + return _run_on_main_thread(_inner) def get_class_names(self, offset: int = 0, limit: int = 100) -> list[str]: """Get list of class names with pagination""" @@ -857,14 +886,7 @@ def get_function_info(self, identifier: str | int) -> dict[str, Any] | None: return info def decompile_function(self, identifier: str | int) -> str | None: - """Decompile a function and include addresses per statement. - - Args: - identifier: Function name or address - - Returns: - Decompiled HLIL-like code with address prefixes per line - """ + """Decompile a function and include addresses per statement.""" if not self._current_view: raise RuntimeError("No binary loaded") @@ -872,49 +894,53 @@ def decompile_function(self, identifier: str | int) -> str | None: if not func: return None - # analyze func in case it was skipped - func.analysis_skipped = False - self._current_view.update_analysis_and_wait() + def _inner(): + # analyze func in case it was skipped + func.analysis_skipped = False + # NOTE: Do NOT call update_analysis_and_wait() here as it blocks + # the main thread event loop when run from QTimer dispatch. - try: - il = getattr(func, "hlil", None) - if il and hasattr(il, "instructions"): - lines: list[str] = [] - last_addr: int | None = None - for ins in il.instructions: - try: - addr = getattr(ins, "address", None) - except Exception: - addr = None - if addr is None: - addr = last_addr if last_addr is not None else func.start - last_addr = addr - addr_str = f"{int(addr):08x}" - text = str(ins) - lines.append(f"{addr_str} {text}") - return "\n".join(lines) - # Fall back to MLIL with addresses - mil = getattr(func, "mlil", None) - if mil and hasattr(mil, "instructions"): - lines: list[str] = [] - last_addr: int | None = None - for ins in mil.instructions: - try: - addr = getattr(ins, "address", None) - except Exception: - addr = None - if addr is None: - addr = last_addr if last_addr is not None else func.start - last_addr = addr - addr_str = f"{int(addr):08x}" - text = str(ins) - lines.append(f"{addr_str} {text}") - return "\n".join(lines) - # Last resort - return str(func) - except Exception as e: - bn.log_error(f"Error decompiling function: {e!s}") - return None + try: + il = getattr(func, "hlil", None) + if il and hasattr(il, "instructions"): + lines: list[str] = [] + last_addr: int | None = None + for ins in il.instructions: + try: + addr = getattr(ins, "address", None) + except Exception: + addr = None + if addr is None: + addr = last_addr if last_addr is not None else func.start + last_addr = addr + addr_str = f"{int(addr):08x}" + text = str(ins) + lines.append(f"{addr_str} {text}") + return "\n".join(lines) + # Fall back to MLIL with addresses + mil = getattr(func, "mlil", None) + if mil and hasattr(mil, "instructions"): + lines: list[str] = [] + last_addr: int | None = None + for ins in mil.instructions: + try: + addr = getattr(ins, "address", None) + except Exception: + addr = None + if addr is None: + addr = last_addr if last_addr is not None else func.start + last_addr = addr + addr_str = f"{int(addr):08x}" + text = str(ins) + lines.append(f"{addr_str} {text}") + return "\n".join(lines) + # Last resort + return str(func) + except Exception as e: + bn.log_error(f"Error decompiling function: {e!s}") + return None + + return _run_on_main_thread(_inner) def get_function_il( self, identifier: str | int, view: str = "hlil", ssa: bool = False @@ -939,7 +965,7 @@ def get_function_il( # Ensure analysis has run for this function try: func.analysis_skipped = False - self._current_view.update_analysis_and_wait() + # NOTE: Do NOT call update_analysis_and_wait() here. except Exception: pass diff --git a/plugin/server/http_server.py b/plugin/server/http_server.py index 8a2df22..75d4f50 100644 --- a/plugin/server/http_server.py +++ b/plugin/server/http_server.py @@ -236,6 +236,15 @@ def _check_binary_loaded(self): return False return True + def _dispatch_on_main_thread(self, func): + """Run func via the main-thread dispatch queue. + + This avoids BN API deadlocks when called from the HTTP handler thread. + The func may freely call self._send_json_response etc. + """ + from ..core.binary_operations import _run_on_main_thread + return _run_on_main_thread(func) + def do_GET(self): try: # For all endpoints except /status, /convertNumber, /platforms, /binaries, /views, /selectBinary, check loaded @@ -254,13 +263,8 @@ def do_GET(self): params = self._parse_query_params() path = urllib.parse.urlparse(self.path).path - offset = parse_int_or_default(params.get("offset"), 0) - # Support both `limit` and `count` (alias) for pagination - if params.get("count") is not None: - limit = parse_int_or_default(params.get("count"), 100) - else: - limit = parse_int_or_default(params.get("limit"), 100) + # Handle /status directly (no BN API needed) if path == "/status": status = { "loaded": self.binary_ops and self.binary_ops.current_view is not None, @@ -269,8 +273,26 @@ def do_GET(self): else None, } self._send_json_response(status) + return + + # All other GET endpoints: dispatch to main thread + self._dispatch_on_main_thread(lambda: self._do_GET_inner(params, path)) + except TimeoutError: + self._send_json_response({"error": "Request timed out (main thread busy)"}, 504) + except Exception as e: + self._send_json_response({"error": str(e)}, 500) - elif path == "/functions" or path == "/methods": + def _do_GET_inner(self, params, path): + """Actual GET handler logic — runs on the main thread.""" + try: + offset = parse_int_or_default(params.get("offset"), 0) + # Support both `limit` and `count` (alias) for pagination + if params.get("count") is not None: + limit = parse_int_or_default(params.get("count"), 100) + else: + limit = parse_int_or_default(params.get("limit"), 100) + + if path == "/functions" or path == "/methods": functions = self.binary_ops.get_function_names(offset, limit) bn.log_info(f"Found {len(functions)} functions") self._send_json_response({"functions": functions}) @@ -1910,6 +1932,17 @@ def do_POST(self): bn.log_info(f"POST {path} with params: {params}") + # Dispatch all POST handling to main thread + self._dispatch_on_main_thread(lambda: self._do_POST_inner(params, path)) + except TimeoutError: + self._send_json_response({"error": "Request timed out (main thread busy)"}, 504) + except Exception as e: + bn.log_error(f"Error handling POST request: {e}") + self._send_json_response({"error": str(e)}, 500) + + def _do_POST_inner(self, params, path): + """Actual POST handler logic — runs on the main thread.""" + try: if path == "/load": filepath = params.get("filepath") if not filepath: