From 41f79251f17213b4c32a00c58aba4c1a88359974 Mon Sep 17 00:00:00 2001 From: sdamirsa Date: Tue, 14 Apr 2026 12:20:56 +0300 Subject: [PATCH 1/3] fix: cross-platform setup and missing dependencies - Rewrite setup_sam3.sh for Windows (Git Bash, WSL), macOS, and Linux with automatic pip/uv detection, path conversion, and error messages - Install SAM3 with [dev,notebooks] extras to include einops, pycocotools and other transitive deps missing from SAM3 core requirements - Patch triton import (Linux-only) to be conditional so SAM3 loads on all platforms - Add verification step at end of setup script (green/red status) - Patch SAM3 fused BFloat16 MLP kernel (addmm_act) that causes dtype mismatch on consumer GPUs (RTX 3080/4090 etc.) - Add python-multipart, modelscope, einops to requirements.txt - Add .gitattributes rule to keep .sh files with LF line endings Co-Authored-By: Claude Opus 4.6 (1M context) --- .gitattributes | 3 + modules/sam3_info_extractor.py | 25 +++ requirements.txt | 3 + scripts/setup_sam3.sh | 352 +++++++++++++++++++++++++++++++-- 4 files changed, 365 insertions(+), 18 deletions(-) diff --git a/.gitattributes b/.gitattributes index f14936a..284d96b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1,4 @@ large_files/ filter=lfs diff=lfs merge=lfs -text + +# Keep shell scripts with Unix line endings on all platforms +*.sh text eol=lf diff --git a/modules/sam3_info_extractor.py b/modules/sam3_info_extractor.py index 00b7b28..17a7087 100644 --- a/modules/sam3_info_extractor.py +++ b/modules/sam3_info_extractor.py @@ -27,6 +27,31 @@ from .base import BaseProcessor, ProcessingContext, ModelWrapper from .data_types import ElementInfo, BoundingBox, ProcessingResult +# --------------------------------------------------------------------------- +# Patch SAM3's fused BFloat16 MLP kernel (perflib/fused.py addmm_act). +# Meta optimized this for H100 GPUs by casting all tensors to BFloat16, +# but this causes dtype mismatches on consumer GPUs (e.g. 
RTX 3080/4090). +# Replace with standard float32 linear + activation. +# --------------------------------------------------------------------------- +def _addmm_act_f32(activation, linear, mat1): + import torch.nn as nn + import torch.nn.functional as F + x = F.linear(mat1, linear.weight, linear.bias) + if activation in [F.gelu, nn.GELU]: + return F.gelu(x) + if activation in [F.relu, nn.ReLU]: + return F.relu(x) + raise ValueError(f"Unexpected activation {activation}") + +try: + import sam3.perflib.fused as _fused + _fused.addmm_act = _addmm_act_f32 + # Also patch any module that already imported the original reference + import sam3.model.vitdet as _vitdet + _vitdet.addmm_act = _addmm_act_f32 +except ImportError: + pass # SAM3 not installed yet + # ======================== 提示词分组枚举 ======================== class PromptGroup(Enum): diff --git a/requirements.txt b/requirements.txt index 09d6a4b..179fb4e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,9 @@ scikit-image requests fastapi uvicorn[standard] +python-multipart +modelscope +einops # Text OCR (local, default) pytesseract diff --git a/scripts/setup_sam3.sh b/scripts/setup_sam3.sh index f9915fb..c84c1bd 100755 --- a/scripts/setup_sam3.sh +++ b/scripts/setup_sam3.sh @@ -1,44 +1,360 @@ #!/usr/bin/env bash -# 安装 SAM3 库并把 BPE 词表复制到 models/ -# 用法:在项目根目录执行 bash scripts/setup_sam3.sh -# 若无法直连 GitHub,可用镜像:SAM3_CLONE_URL="https://gitclone.com/github.com/facebookresearch/sam3.git" bash scripts/setup_sam3.sh -# 模型权重需自行下载到 models/:推荐 ModelScope https://modelscope.cn/models/facebook/sam3 ,见 docs/SETUP_SAM3.md +# ============================================================================= +# setup_sam3.sh - Install SAM3 library and copy BPE vocab +# +# Works on: Linux, macOS, Windows (Git Bash, WSL, MSYS2) +# +# Usage: +# bash scripts/setup_sam3.sh +# +# Options (env vars): +# PIP_CMD Override pip command (e.g. PIP_CMD="pip3") +# SAM3_CLONE_URL Override git clone URL (e.g. 
for mirrors) +# SAM3_SRC Override clone target (default: ./sam3_src) +# MODELS_DIR Override models dir (default: ./models) +# +# Model weights must be downloaded separately (see docs/SETUP_SAM3.md) +# ============================================================================= set -e + +# ============================================================================= +# 1. PLATFORM DETECTION +# ============================================================================= +detect_platform() { + case "$(uname -s)" in + Linux*) + if grep -qi microsoft /proc/version 2>/dev/null; then + PLATFORM="wsl" + else + PLATFORM="linux" + fi + ;; + Darwin*) PLATFORM="macos" ;; + MINGW*|MSYS*|CYGWIN*) PLATFORM="gitbash" ;; + *) PLATFORM="unknown" ;; + esac +} + +detect_platform +echo "[setup] Platform detected: $PLATFORM" + +# ============================================================================= +# 2. PATH HELPERS +# ============================================================================= +# Convert a bash path to a Windows-native path (only needed on WSL/Git Bash +# when calling Windows .exe tools like uv.exe or pip.exe). + +to_native_path() { + local p="$1" + case "$PLATFORM" in + wsl) wslpath -w "$p" ;; + gitbash) cygpath -w "$p" 2>/dev/null || echo "$p" ;; + *) echo "$p" ;; + esac +} + +# True if PIP_CMD points to a Windows .exe (needs native path conversion) +pip_is_windows_exe() { + [[ "$PIP_CMD" == *".exe"* ]] +} + +# Run $PIP_CMD install, converting paths for Windows .exe tools if needed. +# Handles pip extras like "path/to/pkg[dev,notebooks]" by splitting path from extras +# before conversion, then rejoining. 
+pip_install() { + local args=() + for arg in "$@"; do + if pip_is_windows_exe; then + # Split path from pip extras: "/path/pkg[dev,extra]" -> "/path/pkg" + "[dev,extra]" + local path_part="${arg%%\[*}" + local extras_part="" + if [[ "$arg" == *"["* ]]; then + extras_part="[${arg#*\[}" + fi + if [[ -e "$path_part" || "$path_part" == /* ]]; then + args+=("$(to_native_path "$path_part")${extras_part}") + else + args+=("$arg") + fi + else + args+=("$arg") + fi + done + echo " Running: $PIP_CMD install ${args[*]}" + $PIP_CMD install "${args[@]}" +} + +# ============================================================================= +# 3. PIP DETECTION +# ============================================================================= +# Priority: PIP_CMD override > uv > .venv pip > python -m pip > system pip + +detect_pip() { + # User override + if [ -n "$PIP_CMD" ]; then + echo "[setup] Using PIP_CMD override: $PIP_CMD" + return 0 + fi + + # uv / uv.exe + for cmd in uv uv.exe; do + if command -v "$cmd" &>/dev/null; then + PIP_CMD="$cmd pip" + echo "[setup] Found $cmd" + return 0 + fi + done + + # .venv pip (Windows layout then Unix layout) + for venv_pip in \ + "$PROJECT_ROOT/.venv/Scripts/pip.exe" \ + "$PROJECT_ROOT/.venv/Scripts/pip" \ + "$PROJECT_ROOT/.venv/bin/pip" \ + ; do + if [ -f "$venv_pip" ]; then + PIP_CMD="$venv_pip" + echo "[setup] Found venv pip: $venv_pip" + return 0 + fi + done + + # python -m pip (.venv python first, then system) + for py in \ + "$PROJECT_ROOT/.venv/Scripts/python.exe" \ + "$PROJECT_ROOT/.venv/bin/python" \ + python3 python python3.exe python.exe \ + ; do + if command -v "$py" &>/dev/null || [ -f "$py" ]; then + if "$py" -m pip --version &>/dev/null 2>&1; then + PIP_CMD="$py -m pip" + echo "[setup] Found $py -m pip" + return 0 + fi + fi + done + + # Bare pip3 / pip + for cmd in pip3 pip3.exe pip pip.exe; do + if command -v "$cmd" &>/dev/null; then + PIP_CMD="$cmd" + echo "[setup] Found $cmd" + return 0 + fi + done + + return 1 +} + +# 
============================================================================= +# 4. PRE-FLIGHT CHECKS +# ============================================================================= SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" cd "$PROJECT_ROOT" SAM3_SRC="${SAM3_SRC:-$PROJECT_ROOT/sam3_src}" MODELS_DIR="${MODELS_DIR:-$PROJECT_ROOT/models}" -# 直连 GitHub 失败时可指定镜像,例如:SAM3_CLONE_URL="https://gitclone.com/github.com/facebookresearch/sam3.git" SAM3_CLONE_URL="${SAM3_CLONE_URL:-https://github.com/facebookresearch/sam3.git}" -echo "[1/3] 克隆 facebookresearch/sam3 (${SAM3_CLONE_URL}) ..." -if [[ -d "$SAM3_SRC/.git" ]]; then - echo " 已存在 $SAM3_SRC,跳过克隆(若需更新请先删掉该目录再运行)" +# Check git +if ! command -v git &>/dev/null; then + echo "[error] git is not installed. Please install git first." + exit 1 +fi + +# Check pip +if ! detect_pip; then + echo "" + echo "[error] No Python package installer found." + echo "" + case "$PLATFORM" in + linux|macos) + echo " Option 1: Install uv (recommended):" + echo " curl -LsSf https://astral.sh/uv/install.sh | sh" + echo "" + echo " Option 2: Create and activate a venv:" + echo " python3 -m venv .venv && source .venv/bin/activate" + ;; + wsl) + echo " Option 1: Install uv inside WSL:" + echo " curl -LsSf https://astral.sh/uv/install.sh | sh" + echo "" + echo " Option 2: Run this script from Windows cmd/PowerShell instead," + echo " where uv/pip are already available." 
+ echo "" + echo " Option 3: Override with a known path:" + echo " PIP_CMD=\"/path/to/pip\" bash scripts/setup_sam3.sh" + ;; + gitbash) + echo " Option 1: Run from Windows cmd/PowerShell where uv/pip are on PATH:" + echo " bash scripts/setup_sam3.sh" + echo "" + echo " Option 2: Activate your venv first:" + echo " source .venv/Scripts/activate && bash scripts/setup_sam3.sh" + echo "" + echo " Option 3: Override:" + echo " PIP_CMD=\"uv pip\" bash scripts/setup_sam3.sh" + ;; + *) + echo " Install uv: https://docs.astral.sh/uv/getting-started/installation/" + echo " Or override: PIP_CMD=\"pip3\" bash scripts/setup_sam3.sh" + ;; + esac + exit 1 +fi + +echo "[setup] Project root: $PROJECT_ROOT" +echo "" + +# ============================================================================= +# 5. CLONE SAM3 +# ============================================================================= +echo "[1/4] Cloning facebookresearch/sam3 ..." +if [ -d "$SAM3_SRC/.git" ]; then + echo " Already exists at $SAM3_SRC (delete to re-clone)" else rm -rf "$SAM3_SRC" - git clone --depth 1 "$SAM3_CLONE_URL" "$SAM3_SRC" + if ! git clone --depth 1 "$SAM3_CLONE_URL" "$SAM3_SRC"; then + echo "" + echo "[error] git clone failed. Possible fixes:" + echo " - Check your internet connection" + echo " - Use a mirror:" + echo " SAM3_CLONE_URL=\"https://gitclone.com/github.com/facebookresearch/sam3.git\" bash scripts/setup_sam3.sh" + exit 1 + fi fi -echo "[2/3] 安装 SAM3 包 (pip install -e $SAM3_SRC) ..." -pip install -e "$SAM3_SRC" +# ============================================================================= +# 6. INSTALL SAM3 PACKAGE (with extras for missing transitive dependencies) +# ============================================================================= +# SAM3's core deps omit einops, pycocotools, etc. that are imported unconditionally. +# Installing with [dev,notebooks] extras pulls them in so users don't hit ImportErrors. 
+echo "" +echo "[2/4] Installing SAM3 package (with dev+notebooks extras) ..." +pip_install -e "$SAM3_SRC[dev,notebooks]" + +# ============================================================================= +# 7. PATCH: Make triton import optional (Linux-only, not needed for image segmentation) +# ============================================================================= +# SAM3's edt.py imports triton at module level, but triton is Linux-only. +# Edit-Banana only uses image segmentation, not video tracking (which needs EDT). +# This patch makes the import conditional so SAM3 loads on all platforms. +echo "" +echo "[patch] Making triton import optional for cross-platform support ..." + +# Use Python for reliable cross-platform patching (sed behaves differently on macOS/Linux/Windows) +PATCH_PYTHON="" +for py in \ + "$PROJECT_ROOT/.venv/Scripts/python.exe" \ + "$PROJECT_ROOT/.venv/bin/python" \ + python3 python python3.exe python.exe \ +; do + if command -v "$py" &>/dev/null || [ -f "$py" ]; then + PATCH_PYTHON="$py" + break + fi +done + +if [ -n "$PATCH_PYTHON" ]; then + "$PATCH_PYTHON" -c " +import os, sys + +sam3_src = sys.argv[1] + +# Patch edt.py: wrap 'import triton' in try/except +edt = os.path.join(sam3_src, 'sam3', 'model', 'edt.py') +if os.path.isfile(edt): + text = open(edt).read() + if 'import triton\n' in text and 'HAS_TRITON' not in text: + text = text.replace( + 'import triton\nimport triton.language as tl', + 'try:\n import triton\n import triton.language as tl\n HAS_TRITON = True\nexcept ImportError:\n HAS_TRITON = False' + ) + open(edt, 'w').write(text) + print(' Patched edt.py') + else: + print(' edt.py already patched') -echo "[3/3] 复制 BPE 词表到 models/ ..." 
+# Patch sam3_tracker_utils.py: wrap 'from sam3.model.edt import edt_triton' in try/except
+tracker = os.path.join(sam3_src, 'sam3', 'model', 'sam3_tracker_utils.py')
+if os.path.isfile(tracker):
+    text = open(tracker).read()
+    old = 'from sam3.model.edt import edt_triton'
+    if old in text and 'edt_triton = None' not in text:
+        text = text.replace(old,
+            'try:\n    from sam3.model.edt import edt_triton\nexcept ImportError:\n    edt_triton = None  # triton unavailable (Windows/macOS); not needed for image segmentation'
+        )
+        open(tracker, 'w').write(text)
+        print('  Patched sam3_tracker_utils.py')
+    else:
+        print('  sam3_tracker_utils.py already patched')
+" "$SAM3_SRC"
+else
+    echo "  [warn] No Python found to apply the triton patch; SAM3 may fail to import on Windows/macOS until it is applied."
+fi
+
+# =============================================================================
+# 8. COPY BPE VOCAB
+# =============================================================================
+echo ""
+echo "[4/4] Copying BPE vocab to models/ ..."
 mkdir -p "$MODELS_DIR"
 BPE_NAME="bpe_simple_vocab_16e6.txt.gz"
-for BPE_SRC in "$SAM3_SRC/assets/$BPE_NAME" "$SAM3_SRC/sam3/assets/$BPE_NAME"; do
-    if [[ -f "$BPE_SRC" ]]; then
+BPE_FOUND=false
+
+for BPE_SRC in \
+    "$SAM3_SRC/assets/$BPE_NAME" \
+    "$SAM3_SRC/sam3/assets/$BPE_NAME" \
+; do
+    if [ -f "$BPE_SRC" ]; then
         cp "$BPE_SRC" "$MODELS_DIR/"
-        echo "  已复制到 $MODELS_DIR/$BPE_NAME"
+        echo "  Copied to $MODELS_DIR/$BPE_NAME"
+        BPE_FOUND=true
         break
     fi
 done
-if [[ ! -f "$MODELS_DIR/$BPE_NAME" ]]; then
-    echo "  未找到 BPE 文件,在仓库中查找:"
+
+if [ "$BPE_FOUND" = false ]; then
+    echo "  [warn] BPE file not found in expected locations."
+    echo "  Searching the cloned repo:"
     find "$SAM3_SRC" -name "*.gz" 2>/dev/null || true
+    echo ""
+    echo "  Copy the file manually: cp <path-to-found-file> $MODELS_DIR/$BPE_NAME"
 fi
 
+# =============================================================================
+# 9. 
VERIFY INSTALLATION +# ============================================================================= echo "" -echo "完成。下一步:将 SAM3 权重下载到 $MODELS_DIR/(推荐 ModelScope),并配置 config.yaml,详见 docs/SETUP_SAM3.md" +echo "[verify] Testing SAM3 import ..." + +# Use the same Python we found earlier for patching +VERIFY_PY="${PATCH_PYTHON:-python3}" + +if "$VERIFY_PY" -c "from sam3.model_builder import build_sam3_image_model; print('OK')" 2>/dev/null; then + echo "" + echo -e "\033[32m=========================================\033[0m" + echo -e "\033[32m SAM3 library installed successfully\033[0m" + echo -e "\033[32m=========================================\033[0m" + echo "" + echo "Next steps:" + echo " 1. Download SAM3 weights (sam3.pt) into models/" + echo " See docs/SETUP_SAM3.md for download links" + echo "" + echo " 2. Create config (if not done):" + echo " cp config/config.yaml.example config/config.yaml" +else + echo "" + echo -e "\033[31m=========================================\033[0m" + echo -e "\033[31m SAM3 import failed\033[0m" + echo -e "\033[31m=========================================\033[0m" + echo "" + echo "Run this to see the full error:" + echo " python -c \"from sam3.model_builder import build_sam3_image_model\"" + echo "" + echo "Common fixes:" + echo " - Missing dependency: pip install " + echo " - Wrong Python: make sure you're using the venv Python" + exit 1 +fi From d6c445aaea4423fdc10245f258bc7a28912b96e2 Mon Sep 17 00:00:00 2001 From: sdamirsa Date: Wed, 15 Apr 2026 21:19:52 +0300 Subject: [PATCH 2/3] feat: editable vector export pipeline (Stage 8) + prompt v2 + heart/MRI fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds an editable-vector export stage to the pipeline, broadens SAM3 prompt coverage for scientific/medical figures, and fixes a classification bug that was rendering medical-image detections as blank white outlines. 
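The classification bug comes down to comparing a lowercased query against a
mixed-case set. A minimal sketch of the failure mode and the normalization
fix (illustrative subset and hypothetical helper names, not the project's
actual code):

```python
# Mixed-case prompt names, as in prompts/image.py (illustrative subset)
IMAGE_PROMPT = ["MRI image", "3D heart model", "icon"]

def is_image_buggy(elem_type: str) -> bool:
    # Bug: only one side is normalized, so "MRI image" -> "mri image"
    # never matches the mixed-case entries and the element falls through
    # to plain polygon-outline rendering.
    return elem_type.lower() in set(IMAGE_PROMPT)

def is_image_fixed(elem_type: str) -> bool:
    # Fix: normalize both sides before the membership test.
    return elem_type.lower() in {p.lower() for p in IMAGE_PROMPT}

assert not is_image_buggy("MRI image")  # silently skipped before the fix
assert is_image_fixed("MRI image")      # routed to raster cropping after
assert is_image_buggy("icon")           # already-lowercase prompts worked
```

Deriving every type set through one normalization helper like this is what
lets the prompt files act as the single source of truth.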
## What's new ### Stage 8 — Vector export (new modules) Per-image output under `output/{image}/vectors/`: elements/ individual editable SVGs for every detected element rasters/ cropped transparent-background PNGs for image elements combined/ single combined.svg (layered) and combined.pdf manifest.json element index with bbox, score, layer, paths New modules: modules/svg_generator.py hybrid renderer — geometric primitives for known shapes, Chaikin-smoothed polygons for complex contours, base64-embedded crops for raster elements, editable for OCR modules/pdf_combiner.py svglib/cairosvg PDF backend modules/section_detector.py panel detection via SAM3 backgrounds + HoughLinesP modules/vector_exporter.py Stage 8 orchestrator (BaseProcessor subclass) CLI: --vector-level=granular|section|component|all (default: granular) --no-vectors skip Stage 8 ### Prompt v2 — broader coverage for scientific/medical figures Total prompts: 19 -> 78 prompts/image.py 5 -> 29 (CT/MRI/ultrasound, 3D heart/anatomy, person/crowd icons, computer monitors, checkerboard/grid patterns, image stacks) prompts/shape.py 7 -> 17 (trapezoid, parallelogram, 3D cube, isometric box, cylinder, color swatch, small colored square, stack of rectangles) prompts/arrow.py 3 -> 17 (thick/block/curved/looping/bidirectional/ dashed/dotted/L-shaped/skip variants) prompts/background.py 4 -> 15 (sub-figure panel, dashed border rectangle, legend box/panel, title bar, header strip) Config tuning to match (config/config.yaml — gitignored): shape.min_area: 200 -> 80 (catches 14x14 legend swatches) shape.score_threshold: 0.5 -> 0.45 arrow.score_threshold: 0.45 -> 0.4 image.score_threshold: 0.5 -> 0.45 ### Bug fix — heart/MRI rendered as blank white polygon outlines Type classification was scattered across three files using case-sensitive string comparisons. 
IMAGE_PROMPT contains mixed-case names like "3D heart model" and
"MRI image", but every comparison did `elem.type.lower() in CasedSet`, so
those specific scientific-image prompts silently fell through and got
rendered as white polygon outlines. Across 18 figures, this dropped 40
medical detections (36 MRI + 4 heart) to outline-only. After the fix all 40
are properly extracted as RGBA crops and embedded as base64 in their SVGs.

The fix made the prompt files the single source of truth:

  modules/svg_generator.py           RASTER_TYPES, GEOMETRIC_SHAPES, ARROW_TYPES
                                     now derived from prompt files via the
                                     `_expand_forms()` helper (covers both
                                     space-form and underscore-form normalization)
  modules/icon_picture_processor.py  lowercased IMAGE_PROMPT before comparison
  modules/data_types.py              get_layer_level() imports prompt lists;
                                     specific prompts land in the correct layer
                                     (IMAGE/BASIC_SHAPE/ARROW/BACKGROUND)
                                     instead of OTHER

Adding a new prompt now auto-registers it for routing, layer assignment, and
raster cropping — no parallel lists to keep in sync.

## Run results on the 18-figure test set

  1,071 individual element SVGs
    425 raster PNGs   (was 385 before fix; +40 = the heart/MRI recoveries)
     18 combined SVGs (one per figure)
     18 combined PDFs (one per figure, Affinity-ready)

## Known limitations & future work

Even with broader prompts and the new hierarchical layer assignment, the
pipeline still struggles to understand **multi-panel / schematic figures**.
Detection happens per element; the global semantics — which arrow connects
which box across panel boundaries, which legend swatch labels which plot — is
not modeled. Two directions worth exploring:

1. Two-pass extraction with explicit panel splitting. First pass: detect
sub-figure panels and split the source image into per-panel crops. Second
pass: run the full pipeline on each crop independently. This should help the
model focus on local structure and avoid cross-panel prompt confusion.
SAM3 backgrounds + HoughLinesP already give us panel candidates (see section_detector.py); the missing piece is the recursive split-and-rerun loop. 2. Smart margin padding around cropped rasters. Tight bboxes sometimes clip strokes or leave faint background ghosts. A per-type margin heuristic (icon vs. photo vs. schematic illustration) would clean this up, but the logic is hard to pin down — loose enough to capture the full visual element, tight enough to avoid neighbor bleed. Co-Authored-By: Claude Opus 4.6 --- .gitignore | 4 + main.py | 53 ++- modules/__init__.py | 11 + modules/data_types.py | 76 +++- modules/icon_picture_processor.py | 11 +- modules/pdf_combiner.py | 157 +++++++ modules/section_detector.py | 369 ++++++++++++++++ modules/svg_generator.py | 677 ++++++++++++++++++++++++++++++ modules/vector_exporter.py | 526 +++++++++++++++++++++++ prompts/arrow.py | 37 +- prompts/background.py | 31 +- prompts/image.py | 47 +++ prompts/shape.py | 36 +- requirements.txt | 4 + 14 files changed, 2014 insertions(+), 25 deletions(-) create mode 100644 modules/pdf_combiner.py create mode 100644 modules/section_detector.py create mode 100644 modules/svg_generator.py create mode 100644 modules/vector_exporter.py diff --git a/.gitignore b/.gitignore index f05287c..b37e9a6 100644 --- a/.gitignore +++ b/.gitignore @@ -50,3 +50,7 @@ sam3_src/ # Local processing & debug arrow_processing/ debug_output/ + +# Local planning notes & AI tool session data +.amir-zone/ +.claude/ diff --git a/main.py b/main.py index 28d1e1b..35a6d84 100644 --- a/main.py +++ b/main.py @@ -12,6 +12,8 @@ python main.py -i input/test.png -o output/custom/ python main.py -i input/test.png --refine python main.py -i input/test.png --no-text + python main.py -i input/test.png --vector-level=all + python main.py -i input/test.png --no-vectors """ import os @@ -39,7 +41,10 @@ XMLMerger, MetricEvaluator, RefinementProcessor, - + + # Stage 8: Vector export + VectorExporter, + # Text (modules/text/) TextRestorer, 
@@ -89,6 +94,7 @@ def __init__(self, config: dict = None): self._xml_merger = None self._metric_evaluator = None self._refinement_processor = None + self._vector_exporter = None @property def text_restorer(self): @@ -138,13 +144,21 @@ def refinement_processor(self) -> RefinementProcessor: if self._refinement_processor is None: self._refinement_processor = RefinementProcessor() return self._refinement_processor + + @property + def vector_exporter(self) -> VectorExporter: + if self._vector_exporter is None: + self._vector_exporter = VectorExporter() + return self._vector_exporter def process_image(self, image_path: str, output_dir: str = None, with_refinement: bool = False, with_text: bool = True, - groups: List[PromptGroup] = None) -> Optional[str]: + groups: List[PromptGroup] = None, + vector_level: str = "granular", + no_vectors: bool = False) -> Optional[str]: """Run pipeline on one image. Returns output XML path or None.""" print(f"\n{'='*60}") print(f"Processing: {image_path}") @@ -264,8 +278,28 @@ def process_image(self, output_path = merge_result.metadata.get('output_path') print(f" Output: {output_path}") + + # ============ Stage 8: Vector Export ============ + if not no_vectors: + print(f"\n[8] Vector export (level={vector_level})...") + context.intermediate_results['vector_level'] = vector_level + try: + vec_result = self.vector_exporter.process(context) + if vec_result.success: + vec_count = vec_result.metadata.get('exported_count', 0) + vec_dir = vec_result.metadata.get('vector_dir', '') + print(f" Exported {vec_count} elements -> {vec_dir}") + else: + print(f" Vector export failed: {vec_result.error_message}") + except Exception as e: + print(f" Vector export failed: {e}") + import traceback + traceback.print_exc() + else: + print("\n[8] Vector export (skipped)") + print(f"\n{'='*60}\nDone.\n{'='*60}") - + return output_path except Exception as e: @@ -332,6 +366,8 @@ def main(): python main.py python main.py -i test.png --refine python main.py -i 
test.png --groups image arrow + python main.py -i test.png --vector-level=all + python main.py -i test.png --no-vectors """ ) @@ -348,6 +384,13 @@ def main(): help="Prompt groups to process (default: all)") parser.add_argument("--show-prompts", action="store_true", help="Show prompt config") + + # Stage 8: Vector export options + parser.add_argument("--vector-level", type=str, default="granular", + choices=['granular', 'section', 'component', 'all'], + help="Vector export granularity (default: granular)") + parser.add_argument("--no-vectors", action="store_true", + help="Skip vector export (Stage 8)") args = parser.parse_args() @@ -417,7 +460,9 @@ def main(): output_dir=output_dir, with_refinement=args.refine, with_text=not args.no_text, - groups=groups + groups=groups, + vector_level=args.vector_level, + no_vectors=args.no_vectors, ) if result: success_count += 1 diff --git a/modules/__init__.py b/modules/__init__.py index 5aaee23..5a7433f 100644 --- a/modules/__init__.py +++ b/modules/__init__.py @@ -21,6 +21,12 @@ from .metric_evaluator import MetricEvaluator from .refinement_processor import RefinementProcessor +# Stage 8: Vector export +from .vector_exporter import VectorExporter +from .svg_generator import SVGGenerator +from .pdf_combiner import PDFCombiner +from .section_detector import SectionDetector + # Text (modules/text/); optional if ocr/coord_processor missing try: from .text.restorer import TextRestorer @@ -53,4 +59,9 @@ 'BasicShapeProcessor', 'MetricEvaluator', 'RefinementProcessor', + # Stage 8: Vector export + 'VectorExporter', + 'SVGGenerator', + 'PDFCombiner', + 'SectionDetector', ] diff --git a/modules/data_types.py b/modules/data_types.py index 0b9169c..84be747 100644 --- a/modules/data_types.py +++ b/modules/data_types.py @@ -255,37 +255,81 @@ def from_yaml(cls, yaml_path: str) -> 'ProcessingConfig': # ======================== 辅助函数 ======================== + +def _expand_forms(prompts): + """Return set containing both lowercase-with-spaces 
and lowercase-with-underscores forms.""" + out = set() + for p in prompts: + low = p.lower() + out.add(low) + out.add(low.replace(" ", "_")) + return out + + +# Lazy-built prompt-derived type sets. Built on first call of get_layer_level +# to avoid import-time cycles (prompts has no deps, but be safe). +_TYPE_SETS_CACHE = {} + + +def _get_type_sets(): + """Build and cache prompt-derived type sets.""" + if _TYPE_SETS_CACHE: + return _TYPE_SETS_CACHE + try: + from prompts.image import IMAGE_PROMPT + from prompts.shape import SHAPE_PROMPT + from prompts.arrow import ARROW_PROMPT + from prompts.background import BACKGROUND_PROMPT + except ImportError: + # Fallback: empty sets; legacy hardcoded lists below still apply. + IMAGE_PROMPT = SHAPE_PROMPT = ARROW_PROMPT = BACKGROUND_PROMPT = [] + + _TYPE_SETS_CACHE["image"] = _expand_forms(IMAGE_PROMPT) + _TYPE_SETS_CACHE["shape"] = _expand_forms(SHAPE_PROMPT) + _TYPE_SETS_CACHE["arrow"] = _expand_forms(ARROW_PROMPT) + _TYPE_SETS_CACHE["background"] = _expand_forms(BACKGROUND_PROMPT) + return _TYPE_SETS_CACHE + + def get_layer_level(element_type: str) -> int: """ 根据元素类型获取默认层级 - - 供各子模块使用,确保层级分配一致 + + 供各子模块使用,确保层级分配一致。 + + v2 fix: derive image/shape/arrow/background sets from prompt files so + specific prompts like "3D heart model" or "MRI image" (which were + silently falling through to LayerLevel.OTHER and breaking stacking) + now get the correct IMAGE layer. 
""" element_type = element_type.lower() - - # 背景/容器类(最底层) - if element_type in {'section_panel', 'title_bar'}: + sets = _get_type_sets() + + # 背景/容器类(最底层)— legacy names + prompt-derived + if element_type in {'section_panel', 'title_bar'} or element_type in sets["background"]: return LayerLevel.BACKGROUND.value - - # 箭头/连接线 - if element_type in {'arrow', 'line', 'connector'}: + + # 箭头/连接线 — legacy names + prompt-derived + if element_type in {'arrow', 'line', 'connector'} or element_type in sets["arrow"]: return LayerLevel.ARROW.value - + # 文字 if element_type == 'text': return LayerLevel.TEXT.value - - # 图片类 - if element_type in {'icon', 'picture', 'image', 'logo', 'chart', 'function_graph'}: + + # 图片类 — legacy names + prompt-derived (this is the fix path for the heart bug) + if element_type in { + 'icon', 'picture', 'image', 'logo', 'chart', 'function_graph' + } or element_type in sets["image"]: return LayerLevel.IMAGE.value - - # 基本图形 + + # 基本图形 — legacy names + prompt-derived if element_type in { 'rectangle', 'rounded_rectangle', 'rounded rectangle', 'diamond', 'ellipse', 'circle', 'cylinder', 'cloud', 'hexagon', 'triangle', 'parallelogram', 'actor' - }: + } or element_type in sets["shape"]: return LayerLevel.BASIC_SHAPE.value - + # 其他 return LayerLevel.OTHER.value diff --git a/modules/icon_picture_processor.py b/modules/icon_picture_processor.py index 15a5bf3..fdd7e10 100644 --- a/modules/icon_picture_processor.py +++ b/modules/icon_picture_processor.py @@ -315,8 +315,15 @@ def process(self, context: ProcessingContext) -> ProcessingResult: ) def _get_elements_to_process(self, elements: List[ElementInfo]) -> List[ElementInfo]: - """Filter elements to process (icons, arrows, etc.; arrows treated as icon crop).""" - all_types = set(IMAGE_PROMPT) | {"arrow", "line", "connector"} + """Filter elements to process (icons, arrows, etc.; arrows treated as icon crop). + + NOTE: IMAGE_PROMPT contains mixed-case strings (e.g. "3D heart model", + "MRI image", "CT scan image"). 
Comparing `.lower()` against the raw + set caused those detections to be silently skipped — no base64 was + generated and downstream SVG rendering fell back to a plain polygon + outline. Always normalize both sides. + """ + all_types = {t.lower() for t in IMAGE_PROMPT} | {"arrow", "line", "connector"} return [ e for e in elements if e.element_type.lower() in all_types and e.base64 is None diff --git a/modules/pdf_combiner.py b/modules/pdf_combiner.py new file mode 100644 index 0000000..73dac7c --- /dev/null +++ b/modules/pdf_combiner.py @@ -0,0 +1,157 @@ +""" +Stage 8b: PDF Combiner Module + +Converts a combined SVG (or set of individual SVGs) into a single layered PDF. +Uses svglib + reportlab for pure-Python PDF generation (no system dependencies). + +Fallback: if svglib is not installed, attempts CairoSVG. If neither is available, +skips PDF generation with a warning. + +Usage: + from modules.pdf_combiner import PDFCombiner + + combiner = PDFCombiner() + combiner.svg_to_pdf("combined.svg", "output.pdf") +""" + +import os +import warnings +from typing import Optional + +# Try to import svglib + reportlab (primary) +_SVGLIB_AVAILABLE = False +_CAIROSVG_AVAILABLE = False + +try: + from svglib.svglib import svg2rlg + from reportlab.graphics import renderPDF + from reportlab.lib.pagesizes import letter + _SVGLIB_AVAILABLE = True +except ImportError: + pass + +if not _SVGLIB_AVAILABLE: + try: + import cairosvg + _CAIROSVG_AVAILABLE = True + except ImportError: + pass + + +class PDFCombiner: + """ + Converts SVG files to PDF. + + Priority: + 1. svglib + reportlab (pure Python, recommended) + 2. CairoSVG (requires cairo system library) + 3. 
Skip with warning + + Methods: + svg_to_pdf(svg_path, pdf_path) -> bool + svg_string_to_pdf(svg_string, pdf_path) -> bool + """ + + def __init__(self): + self._log_prefix = "[PDFCombiner]" + self._backend = self._detect_backend() + + def _log(self, msg: str): + print(f"{self._log_prefix} {msg}") + + def _detect_backend(self) -> str: + if _SVGLIB_AVAILABLE: + return "svglib" + elif _CAIROSVG_AVAILABLE: + return "cairosvg" + else: + return "none" + + @property + def is_available(self) -> bool: + """Check if any PDF backend is available.""" + return self._backend != "none" + + def svg_to_pdf(self, svg_path: str, pdf_path: str) -> bool: + """ + Convert an SVG file to PDF. + + Args: + svg_path: Path to input SVG file + pdf_path: Path to output PDF file + + Returns: + True if successful, False otherwise + """ + if not os.path.exists(svg_path): + self._log(f"SVG file not found: {svg_path}") + return False + + os.makedirs(os.path.dirname(pdf_path), exist_ok=True) + + if self._backend == "svglib": + return self._svglib_convert(svg_path, pdf_path) + elif self._backend == "cairosvg": + return self._cairosvg_convert(svg_path, pdf_path) + else: + self._log( + "No PDF backend available. Install svglib+reportlab: " + "pip install svglib reportlab" + ) + return False + + def svg_string_to_pdf(self, svg_string: str, pdf_path: str) -> bool: + """ + Convert an SVG string to PDF by writing to temp file first. 
+ + Args: + svg_string: SVG content as string + pdf_path: Path to output PDF file + + Returns: + True if successful, False otherwise + """ + # Write SVG to temporary file next to output + tmp_svg = pdf_path.replace(".pdf", "_tmp.svg") + try: + os.makedirs(os.path.dirname(pdf_path), exist_ok=True) + with open(tmp_svg, "w", encoding="utf-8") as f: + f.write(svg_string) + result = self.svg_to_pdf(tmp_svg, pdf_path) + return result + finally: + # Clean up temp file + if os.path.exists(tmp_svg): + try: + os.remove(tmp_svg) + except OSError: + pass + + # ================================================================ + # Backend Implementations + # ================================================================ + + def _svglib_convert(self, svg_path: str, pdf_path: str) -> bool: + """Convert SVG to PDF using svglib + reportlab.""" + try: + drawing = svg2rlg(svg_path) + if drawing is None: + self._log(f"svglib could not parse: {svg_path}") + return False + renderPDF.drawToFile(drawing, pdf_path) + self._log(f"PDF created (svglib): {pdf_path}") + return True + except Exception as e: + self._log(f"svglib conversion failed: {e}") + return False + + def _cairosvg_convert(self, svg_path: str, pdf_path: str) -> bool: + """Convert SVG to PDF using CairoSVG.""" + try: + import cairosvg + cairosvg.svg2pdf(url=svg_path, write_to=pdf_path) + self._log(f"PDF created (cairosvg): {pdf_path}") + return True + except Exception as e: + self._log(f"CairoSVG conversion failed: {e}") + return False diff --git a/modules/section_detector.py b/modules/section_detector.py new file mode 100644 index 0000000..b3ebf8d --- /dev/null +++ b/modules/section_detector.py @@ -0,0 +1,369 @@ +""" +Stage 8c: Section Detector Module + +Identifies major panels/sections in scientific figures (e.g., panel (a), (b), (c)). + +Hybrid approach: + 1. Use SAM3 background elements (section_panel, title_bar) from existing segmentation + 2. Validate/refine with OpenCV line detection for panel dividers + 3. 
Assign elements to sections based on spatial overlap + +Usage: + from modules.section_detector import SectionDetector + + detector = SectionDetector() + sections = detector.detect_sections(context) + # Returns list of Section objects with bbox + child element IDs +""" + +import os +import cv2 +import numpy as np +from typing import List, Optional, Dict, Tuple +from dataclasses import dataclass, field + +from .base import ProcessingContext +from .data_types import ElementInfo, BoundingBox + + +@dataclass +class Section: + """Represents a detected panel/section of the figure.""" + id: int + label: str # e.g., "a", "b", "c" + bbox: BoundingBox # Section bounding box + child_element_ids: List[int] = field(default_factory=list) # IDs of elements inside + confidence: float = 1.0 # Detection confidence + source: str = "sam3" # Detection method: "sam3", "line", "merged" + + def contains_point(self, x: int, y: int) -> bool: + """Check if a point is inside this section.""" + return ( + self.bbox.x1 <= x <= self.bbox.x2 + and self.bbox.y1 <= y <= self.bbox.y2 + ) + + def contains_bbox(self, other: BoundingBox, overlap_threshold: float = 0.5) -> bool: + """Check if another bbox is mostly inside this section.""" + # Calculate intersection + ix1 = max(self.bbox.x1, other.x1) + iy1 = max(self.bbox.y1, other.y1) + ix2 = min(self.bbox.x2, other.x2) + iy2 = min(self.bbox.y2, other.y2) + + if ix2 <= ix1 or iy2 <= iy1: + return False + + intersection_area = (ix2 - ix1) * (iy2 - iy1) + other_area = other.area + if other_area == 0: + return False + + return (intersection_area / other_area) >= overlap_threshold + + def to_dict(self) -> Dict: + return { + "id": self.id, + "label": self.label, + "bbox": self.bbox.to_list(), + "child_element_ids": self.child_element_ids, + "confidence": self.confidence, + "source": self.source, + } + + +class SectionDetector: + """ + Detects major panels/sections in scientific figures. + + Strategy: + 1. 
Look for SAM3 background elements (section_panel, title_bar) + with large area — these are likely panel boundaries + 2. Use OpenCV to detect strong horizontal/vertical lines + that divide the image into panels + 3. Merge results from both methods + 4. Assign labels (a, b, c, ...) based on spatial ordering + 5. Assign child elements to sections + """ + + def __init__( + self, + min_section_area_ratio: float = 0.03, + line_threshold: int = 100, + merge_overlap_threshold: float = 0.6, + ): + """ + Args: + min_section_area_ratio: Minimum section area as fraction of image area + line_threshold: HoughLinesP threshold for line detection + merge_overlap_threshold: IoU threshold for merging SAM3 + line detections + """ + self._min_section_area_ratio = min_section_area_ratio + self._line_threshold = line_threshold + self._merge_overlap_threshold = merge_overlap_threshold + self._log_prefix = "[SectionDetector]" + + def _log(self, msg: str): + print(f"{self._log_prefix} {msg}") + + def detect_sections( + self, context: ProcessingContext + ) -> List[Section]: + """ + Detect sections from ProcessingContext. 
+ + Args: + context: Pipeline context with elements and image info + + Returns: + List of Section objects with child_element_ids populated + """ + image = cv2.imread(context.image_path) + if image is None: + self._log(f"Could not read image: {context.image_path}") + return [] + + canvas_area = context.canvas_width * context.canvas_height + if canvas_area == 0: + canvas_area = image.shape[0] * image.shape[1] + + # Step 1: Get sections from SAM3 background elements + sam3_sections = self._sections_from_sam3( + context.elements, canvas_area + ) + self._log(f"SAM3 sections: {len(sam3_sections)}") + + # Step 2: Get sections from line detection + line_sections = self._sections_from_lines( + image, canvas_area + ) + self._log(f"Line-detected sections: {len(line_sections)}") + + # Step 3: Merge + if sam3_sections and not line_sections: + sections = sam3_sections + elif line_sections and not sam3_sections: + sections = line_sections + elif sam3_sections and line_sections: + sections = self._merge_sections(sam3_sections, line_sections) + else: + self._log("No sections detected") + return [] + + # Step 4: Assign labels (a, b, c, ...) 
by top-left ordering + sections = self._assign_labels(sections) + + # Step 5: Assign child elements to sections + self._assign_elements_to_sections(sections, context.elements) + + self._log(f"Final sections: {len(sections)}") + for s in sections: + self._log( + f" Section {s.label}: bbox={s.bbox.to_list()}, " + f"children={len(s.child_element_ids)}, source={s.source}" + ) + + return sections + + # ================================================================ + # Step 1: SAM3-based section detection + # ================================================================ + + def _sections_from_sam3( + self, elements: List[ElementInfo], canvas_area: int + ) -> List[Section]: + """Extract sections from SAM3 background/panel elements.""" + sections = [] + min_area = canvas_area * self._min_section_area_ratio + + for elem in elements: + elem_type = elem.element_type.lower() + if elem_type not in ('section_panel', 'title_bar', 'panel', 'container'): + # Also check source_group + if not (hasattr(elem, 'source_prompt') and elem.source_prompt + and elem.source_prompt.lower() in ('panel', 'container', 'filled region', 'background')): + continue + + if elem.bbox.area < min_area: + continue + + sections.append( + Section( + id=len(sections), + label="", + bbox=elem.bbox, + confidence=elem.score, + source="sam3", + ) + ) + + return sections + + # ================================================================ + # Step 2: Line-based section detection + # ================================================================ + + def _sections_from_lines( + self, image: np.ndarray, canvas_area: int + ) -> List[Section]: + """ + Detect dividing lines (horizontal/vertical) and infer section bounding boxes. + + Only returns sections if strong full-width or full-height lines are found. 
+ """ + h, w = image.shape[:2] + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + + # Edge detection + edges = cv2.Canny(gray, 50, 150, apertureSize=3) + + # Detect lines with HoughLinesP + min_line_length = min(w, h) * 0.3 # At least 30% of image dimension + lines = cv2.HoughLinesP( + edges, + rho=1, + theta=np.pi / 180, + threshold=self._line_threshold, + minLineLength=int(min_line_length), + maxLineGap=10, + ) + + if lines is None: + return [] + + # Classify lines as horizontal or vertical + h_lines = [] # y-coordinates of horizontal dividers + v_lines = [] # x-coordinates of vertical dividers + + for line in lines: + x1, y1, x2, y2 = line[0] + length = np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + angle = abs(np.arctan2(y2 - y1, x2 - x1) * 180 / np.pi) + + # Horizontal line (angle close to 0 or 180) + if (angle < 5 or angle > 175) and length > w * 0.5: + h_lines.append((y1 + y2) // 2) + + # Vertical line (angle close to 90) + if 85 < angle < 95 and length > h * 0.3: + v_lines.append((x1 + x2) // 2) + + # Cluster nearby lines (within 20px) + h_dividers = self._cluster_values(h_lines, threshold=20) + v_dividers = self._cluster_values(v_lines, threshold=20) + + if not h_dividers and not v_dividers: + return [] + + # Build grid of sections from dividers + h_bounds = [0] + sorted(h_dividers) + [h] + v_bounds = [0] + sorted(v_dividers) + [w] + + sections = [] + min_area = canvas_area * self._min_section_area_ratio + + for i in range(len(h_bounds) - 1): + for j in range(len(v_bounds) - 1): + y1 = h_bounds[i] + y2 = h_bounds[i + 1] + x1 = v_bounds[j] + x2 = v_bounds[j + 1] + bbox = BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2) + if bbox.area >= min_area: + sections.append( + Section( + id=len(sections), + label="", + bbox=bbox, + confidence=0.7, + source="line", + ) + ) + + return sections + + # ================================================================ + # Step 3: Merge sections from multiple sources + # ================================================================ + 
+ def _merge_sections( + self, sam3_sections: List[Section], line_sections: List[Section] + ) -> List[Section]: + """Merge SAM3 and line-detected sections, preferring SAM3 when they overlap.""" + merged = list(sam3_sections) # Start with SAM3 sections + + for ls in line_sections: + # Check if any SAM3 section significantly overlaps + has_overlap = False + for ss in sam3_sections: + iou = self._bbox_iou(ls.bbox, ss.bbox) + if iou > self._merge_overlap_threshold: + has_overlap = True + break + if not has_overlap: + ls.id = len(merged) + merged.append(ls) + + return merged + + # ================================================================ + # Step 4 & 5: Label assignment and element mapping + # ================================================================ + + def _assign_labels(self, sections: List[Section]) -> List[Section]: + """Assign (a), (b), (c) labels based on top-to-bottom, left-to-right order.""" + # Sort by (y_center, x_center) — row-major order + sections.sort(key=lambda s: (s.bbox.center[1], s.bbox.center[0])) + + labels = "abcdefghijklmnopqrstuvwxyz" + for i, section in enumerate(sections): + section.id = i + section.label = labels[i] if i < len(labels) else f"s{i}" + + return sections + + def _assign_elements_to_sections( + self, sections: List[Section], elements: List[ElementInfo] + ): + """Assign each element to the section that contains its center.""" + for elem in elements: + cx, cy = elem.bbox.center + for section in sections: + if section.contains_point(cx, cy): + section.child_element_ids.append(elem.id) + break # Each element belongs to at most one section + + # ================================================================ + # Helpers + # ================================================================ + + @staticmethod + def _cluster_values(values: List[int], threshold: int = 20) -> List[int]: + """Cluster nearby values and return cluster centers.""" + if not values: + return [] + sorted_vals = sorted(values) + clusters = 
[[sorted_vals[0]]] + for v in sorted_vals[1:]: + if v - clusters[-1][-1] <= threshold: + clusters[-1].append(v) + else: + clusters.append([v]) + return [int(np.mean(c)) for c in clusters] + + @staticmethod + def _bbox_iou(a: BoundingBox, b: BoundingBox) -> float: + """Calculate Intersection over Union of two bboxes.""" + ix1 = max(a.x1, b.x1) + iy1 = max(a.y1, b.y1) + ix2 = min(a.x2, b.x2) + iy2 = min(a.y2, b.y2) + + if ix2 <= ix1 or iy2 <= iy1: + return 0.0 + + intersection = (ix2 - ix1) * (iy2 - iy1) + union = a.area + b.area - intersection + if union == 0: + return 0.0 + return intersection / union diff --git a/modules/svg_generator.py b/modules/svg_generator.py new file mode 100644 index 0000000..c4934dd --- /dev/null +++ b/modules/svg_generator.py @@ -0,0 +1,677 @@ +""" +Stage 8a: SVG Generator Module + +Converts ElementInfo objects into SVG files. + +Hybrid approach: + - Known shapes (rectangle, ellipse, diamond, etc.) -> clean geometric SVG primitives + - Complex/unknown shapes -> smoothed polygon SVG paths + - Raster elements (icon, picture, photo) -> SVG <image> with embedded base64 + - Text elements -> SVG <text> with editable text + - Arrows -> SVG <line>/<polyline> with marker arrowheads + +Output: individual SVG files per element + combined SVG with all elements grouped. 
+ +Usage: + from modules.svg_generator import SVGGenerator + + generator = SVGGenerator() + # Generate individual element SVG + svg_string = generator.element_to_svg(element, image) + # Generate combined SVG + combined_svg = generator.generate_combined_svg(elements, image, canvas_w, canvas_h) +""" + +import os +import io +import base64 +import math +from typing import List, Optional, Tuple, Dict, Any +from xml.sax.saxutils import escape as xml_escape + +import numpy as np +from PIL import Image +import cv2 + +from .base import BaseProcessor, ProcessingContext +from .data_types import ( + ElementInfo, + BoundingBox, + ProcessingResult, + LayerLevel, + get_layer_level, +) + +# Prompt lists are the single source of truth for element-type classification. +# Importing here means any new prompt auto-registers as its category — +# no silent "falls through to polygon outline" bugs (see heart-model regression). +from prompts.image import IMAGE_PROMPT +from prompts.shape import SHAPE_PROMPT +from prompts.arrow import ARROW_PROMPT + + +# ======================== Constants ======================== + +# SVG 1.1 header template (Canva requires SVG 1.1 profile) +SVG_HEADER = """ + +""" +SVG_FOOTER = "\n" + +# Arrowhead marker definition (reusable) +ARROW_MARKER_DEF = """ + + + + + + + +""" + + +def _expand_forms(prompts): + """ + For each prompt name, yield both the lowercase space form AND the + lowercase underscore form. This protects against the two different + element_type normalization conventions used downstream: + - svg_generator.py uses element.element_type.lower() ("3d heart model") + - vector_exporter.py uses ....lower().replace(" ", "_") ("3d_heart_model") + """ + out = set() + for p in prompts: + low = p.lower() + out.add(low) + out.add(low.replace(" ", "_")) + return out + + +# Known geometric shape types that get clean SVG primitives. +# Derived from SHAPE_PROMPT + legacy container types. 
+GEOMETRIC_SHAPES = _expand_forms(SHAPE_PROMPT) | { + 'section_panel', 'title_bar', # legacy container types (not in prompts) + 'actor', # legacy +} + +# Raster/image element types that get embedded base64. +# Derived from IMAGE_PROMPT + legacy ElementType enum entries. +# CRITICAL: without IMAGE_PROMPT expansion, "3D heart model" / "MRI image" +# fall through to polygon outline rendering (blank white shape bug). +RASTER_TYPES = _expand_forms(IMAGE_PROMPT) | { + 'function_graph', 'image', # legacy ElementType entries not in prompts +} + +# Arrow/connector types. Derived from ARROW_PROMPT + legacy names. +ARROW_TYPES = _expand_forms(ARROW_PROMPT) | { + 'arrow', 'line', 'connector', # legacy (overlap with prompts but explicit) +} + +# Web-safe font stack for text elements +DEFAULT_FONT_FAMILY = "Arial, Helvetica, sans-serif" + + +class SVGGenerator: + """ + Generates SVG files from ElementInfo objects. + + Hybrid strategy: + - Geometric shapes -> clean <rect>, <ellipse>, <polygon> primitives + - Complex shapes -> smoothed <path> from polygon contours + - Raster elements -> <image> with base64 data + - Text -> <text> with editable content + - Arrows -> <line>/<polyline> with arrowhead markers + """ + + def __init__(self): + self._log_prefix = "[SVGGenerator]" + + def _log(self, msg: str): + print(f"{self._log_prefix} {msg}") + + # ================================================================ + # PUBLIC API + # ================================================================ + + def element_to_svg( + self, + element: ElementInfo, + image: np.ndarray, + standalone: bool = True, + offset: Tuple[int, int] = (0, 0), + ) -> str: + """ + Convert a single ElementInfo to an SVG string. 
+ + Args: + element: The element to convert + image: Original image as numpy array (BGR, for cropping rasters) + standalone: If True, wrap in full SVG document; if False, return inner element only + offset: (ox, oy) to subtract from coordinates (for section-relative positioning) + + Returns: + SVG string + """ + elem_type = element.element_type.lower() + + # Determine which renderer to use + if elem_type in RASTER_TYPES: + inner = self._raster_element_svg(element, image, offset) + elif elem_type in ARROW_TYPES: + inner = self._arrow_element_svg(element, image, offset) + elif elem_type == 'text': + inner = self._text_element_svg(element, offset) + elif elem_type in GEOMETRIC_SHAPES: + inner = self._geometric_shape_svg(element, offset) + else: + # Unknown type: try polygon fallback, or raster crop + if element.polygon and len(element.polygon) >= 3: + inner = self._polygon_shape_svg(element, offset) + else: + inner = self._raster_element_svg(element, image, offset) + + if not standalone: + return inner + + # Wrap in standalone SVG document + bbox = element.bbox + w = bbox.width + 4 # small padding + h = bbox.height + 4 + header = SVG_HEADER.format(width=w, height=h) + + # If arrow, include marker defs + defs = "" + if elem_type in ARROW_TYPES: + stroke_color = element.stroke_color or "#000000" + defs = ARROW_MARKER_DEF.format(color=stroke_color) + + # Re-offset inner content to (2,2) padding origin + # We need to wrap in a with translate + ox, oy = offset + tx = -bbox.x1 + ox + 2 + ty = -bbox.y1 + oy + 2 + + return ( + header + + defs + + f' \n' + + inner + + " \n" + + SVG_FOOTER + ) + + def generate_combined_svg( + self, + elements: List[ElementInfo], + image: np.ndarray, + canvas_width: int, + canvas_height: int, + ) -> str: + """ + Generate a single SVG containing ALL elements as named groups. + + Elements are layered by their layer_level (background first, text on top). 
+ """ + header = SVG_HEADER.format(width=canvas_width, height=canvas_height) + + # Arrow marker defs + defs = ARROW_MARKER_DEF.format(color="#000000") + + # Sort by layer level (low = bottom = rendered first in SVG) + sorted_elems = sorted(elements, key=lambda e: e.layer_level) + + # Group by layer + layer_names = { + LayerLevel.BACKGROUND.value: "background", + LayerLevel.BASIC_SHAPE.value: "shapes", + LayerLevel.IMAGE.value: "images", + LayerLevel.ARROW.value: "arrows", + LayerLevel.TEXT.value: "text", + LayerLevel.OTHER.value: "other", + } + + body_parts = [] + current_layer = None + for elem in sorted_elems: + layer = elem.layer_level + if layer != current_layer: + if current_layer is not None: + body_parts.append(" \n") + layer_name = layer_names.get(layer, f"layer_{layer}") + body_parts.append( + f' \n' + ) + current_layer = layer + + # Element group + elem_type = elem.element_type.lower().replace(" ", "_") + group_id = f"elem-{elem.id}-{elem_type}" + body_parts.append(f' \n') + inner_svg = self.element_to_svg(elem, image, standalone=False) + # Indent inner content + for line in inner_svg.strip().split("\n"): + body_parts.append(f" {line}\n") + body_parts.append(" \n") + + # Close last layer group + if current_layer is not None: + body_parts.append(" \n") + + return header + defs + "".join(body_parts) + SVG_FOOTER + + def crop_raster_element( + self, element: ElementInfo, image: np.ndarray + ) -> Optional[Image.Image]: + """ + Crop a raster element from the image with transparent background. + + Uses the element's polygon mask if available, otherwise uses bbox crop. + Returns a PIL Image with RGBA (transparent background). 
+ """ + bbox = element.bbox + x1 = max(0, bbox.x1) + y1 = max(0, bbox.y1) + x2 = min(image.shape[1], bbox.x2) + y2 = min(image.shape[0], bbox.y2) + + if x2 <= x1 or y2 <= y1: + return None + + crop = image[y1:y2, x1:x2].copy() + + # Convert BGR to RGB + crop_rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB) + + # If polygon available, create alpha mask + if element.polygon and len(element.polygon) >= 3: + mask = np.zeros((y2 - y1, x2 - x1), dtype=np.uint8) + # Offset polygon to crop coordinates + pts = np.array(element.polygon, dtype=np.int32) + pts[:, 0] -= x1 + pts[:, 1] -= y1 + cv2.fillPoly(mask, [pts], 255) + # Create RGBA + rgba = np.dstack([crop_rgb, mask]) + else: + # No polygon — full opaque bbox crop + alpha = np.full((y2 - y1, x2 - x1), 255, dtype=np.uint8) + rgba = np.dstack([crop_rgb, alpha]) + + return Image.fromarray(rgba, "RGBA") + + def save_raster_crop( + self, element: ElementInfo, image: np.ndarray, output_path: str + ) -> Optional[str]: + """Crop and save raster element as PNG. 
Returns path or None.""" + pil_img = self.crop_raster_element(element, image) + if pil_img is None: + return None + os.makedirs(os.path.dirname(output_path), exist_ok=True) + pil_img.save(output_path, "PNG") + return output_path + + # ================================================================ + # PRIVATE: Shape Renderers + # ================================================================ + + def _geometric_shape_svg( + self, element: ElementInfo, offset: Tuple[int, int] = (0, 0) + ) -> str: + """Render known geometric shapes as clean SVG primitives.""" + bbox = element.bbox + ox, oy = offset + x = bbox.x1 - ox + y = bbox.y1 - oy + w = bbox.width + h = bbox.height + + fill = element.fill_color or "#ffffff" + stroke = element.stroke_color or "#000000" + sw = element.stroke_width or 1 + elem_type = element.element_type.lower() + + style = f'fill="{fill}" stroke="{stroke}" stroke-width="{sw}"' + + if elem_type in ('rectangle', 'section_panel', 'title_bar'): + return f' \n' + + elif elem_type in ('rounded_rectangle', 'rounded rectangle'): + rx = min(10, w // 6, h // 6) + return f' \n' + + elif elem_type in ('ellipse', 'circle'): + cx = x + w // 2 + cy = y + h // 2 + rx = w // 2 + ry = h // 2 + return f' \n' + + elif elem_type == 'diamond': + cx = x + w // 2 + cy = y + h // 2 + points = f"{cx},{y} {x + w},{cy} {cx},{y + h} {x},{cy}" + return f' \n' + + elif elem_type == 'triangle': + cx = x + w // 2 + points = f"{cx},{y} {x + w},{y + h} {x},{y + h}" + return f' \n' + + elif elem_type == 'hexagon': + # Flat-top hexagon + cx = x + w // 2 + cy = y + h // 2 + qw = w // 4 + points = ( + f"{x + qw},{y} {x + 3 * qw},{y} " + f"{x + w},{cy} " + f"{x + 3 * qw},{y + h} {x + qw},{y + h} " + f"{x},{cy}" + ) + return f' \n' + + elif elem_type == 'parallelogram': + skew = w // 5 + points = ( + f"{x + skew},{y} {x + w},{y} " + f"{x + w - skew},{y + h} {x},{y + h}" + ) + return f' \n' + + elif elem_type == 'trapezoid': + # Wider base at bottom (like encoder in ML figures) + inset 
= w // 6 + points = ( + f"{x + inset},{y} {x + w - inset},{y} " + f"{x + w},{y + h} {x},{y + h}" + ) + return f' \n' + + elif elem_type == 'square': + # Force aspect ratio (take min dimension) + s = min(w, h) + return f' \n' + + elif elem_type == 'cloud': + # Approximated cloud shape using a path with multiple arcs + cx = x + w // 2 + cy = y + h // 2 + r1 = min(w, h) // 4 + path_d = ( + f"M{x + r1},{cy} " + f"a{r1},{r1} 0 0,1 {r1 * 2},0 " + f"a{r1},{r1} 0 0,1 {r1 * 2},0 " + f"a{r1},{r1} 0 0,1 0,{r1 * 2} " + f"a{r1},{r1} 0 0,1 -{r1 * 4},0 " + f"a{r1},{r1} 0 0,1 0,-{r1 * 2} Z" + ) + return f' \n' + + elif elem_type in ('3d_cube', '3d cube', 'isometric_box', 'isometric box'): + # Isometric 3D cube: front face + top face + side face + depth = min(w, h) // 4 + # Front face + svg = ( + f' \n' + ) + # Top face (parallelogram) + top_pts = ( + f"{x},{y + depth} " + f"{x + depth},{y} " + f"{x + w},{y} " + f"{x + w - depth},{y + depth}" + ) + svg += f' \n' + # Right face (parallelogram) + right_pts = ( + f"{x + w - depth},{y + depth} " + f"{x + w},{y} " + f"{x + w},{y + h - depth} " + f"{x + w - depth},{y + h}" + ) + svg += f' \n' + return svg + + elif elem_type in ( + 'color_swatch', 'color swatch', + 'small_colored_square', 'small colored square', + ): + # Small colored square (legend swatch) — solid-filled rect + return f' \n' + + elif elem_type in ( + 'stack_of_rectangles', 'stack of rectangles', + 'layered_boxes', 'layered boxes', + ): + # Render as 3 overlapping rectangles to suggest a stack + offset_step = min(w, h) // 12 + svg = "" + # Back rectangle (offset up-right) + svg += ( + f' \n' + ) + # Middle rectangle + svg += ( + f' \n' + ) + # Front rectangle + svg += ( + f' \n' + ) + return svg + + elif elem_type == 'cylinder': + # Approximate cylinder as rect + two ellipses + ry_cap = min(15, h // 6) + svg = "" + # Body + svg += f' \n' + # Side lines + svg += f' \n' + svg += f' \n' + # Top ellipse + cx_e = x + w // 2 + svg += f' \n' + # Bottom ellipse (only bottom half 
visible) + svg += f' \n' + return svg + + # Fallback: use polygon if available + if element.polygon and len(element.polygon) >= 3: + return self._polygon_shape_svg(element, offset) + + # Last resort: rectangle + return f' \n' + + def _polygon_shape_svg( + self, element: ElementInfo, offset: Tuple[int, int] = (0, 0) + ) -> str: + """Render element using its polygon contour as an SVG path with smoothing.""" + fill = element.fill_color or "#ffffff" + stroke = element.stroke_color or "#000000" + sw = element.stroke_width or 1 + ox, oy = offset + + polygon = element.polygon + if not polygon or len(polygon) < 3: + # Fallback to bbox rect + return self._geometric_shape_svg(element, offset) + + # Smooth the polygon using Chaikin's algorithm (1 iteration) + smoothed = self._chaikin_smooth(polygon, iterations=2) + + # Build SVG path + path_d = self._polygon_to_svg_path(smoothed, ox, oy) + + style = f'fill="{fill}" stroke="{stroke}" stroke-width="{sw}"' + return f' \n' + + def _arrow_element_svg( + self, element: ElementInfo, image: np.ndarray, offset: Tuple[int, int] = (0, 0) + ) -> str: + """Render arrow/line/connector as SVG polyline with arrowhead markers.""" + ox, oy = offset + stroke = element.stroke_color or "#000000" + sw = element.stroke_width or 2 + + # If we have explicit start/end points, use them + if element.arrow_start and element.arrow_end: + x1 = element.arrow_start[0] - ox + y1 = element.arrow_start[1] - oy + x2 = element.arrow_end[0] - ox + y2 = element.arrow_end[1] - oy + return ( + f' \n' + ) + + # If we have vector_points, use polyline + if element.vector_points and len(element.vector_points) >= 2: + points = " ".join( + f"{p[0] - ox},{p[1] - oy}" for p in element.vector_points + ) + return ( + f' \n' + ) + + # Fallback: use polygon contour as path, or crop as raster + if element.polygon and len(element.polygon) >= 3: + # Use polygon outline (no fill) as the arrow shape + path_d = self._polygon_to_svg_path(element.polygon, ox, oy) + return ( + f' \n' + ) 
+ + # Last fallback: embed as raster crop + return self._raster_element_svg(element, image, offset) + + def _text_element_svg( + self, element: ElementInfo, offset: Tuple[int, int] = (0, 0) + ) -> str: + """Render text element as SVG (editable).""" + bbox = element.bbox + ox, oy = offset + + # Get text content from processing_notes or source_prompt + text_content = "" + for note in element.processing_notes: + if note.startswith("text:"): + text_content = note[5:].strip() + break + if not text_content and element.source_prompt: + text_content = element.source_prompt + + # Position at center of bbox + x = bbox.x1 - ox + bbox.width // 2 + y = bbox.y1 - oy + bbox.height // 2 + + # Estimate font size from bbox height + font_size = max(8, min(48, int(bbox.height * 0.7))) + + fill = element.fill_color or "#000000" + escaped_text = xml_escape(text_content) + + return ( + f' ' + f'{escaped_text}\n' + ) + + def _raster_element_svg( + self, element: ElementInfo, image: np.ndarray, offset: Tuple[int, int] = (0, 0) + ) -> str: + """Embed raster crop as base64 inside SVG.""" + bbox = element.bbox + ox, oy = offset + + # Use existing base64 if available (from IconPictureProcessor) + if element.base64: + b64_data = element.base64 + else: + # Crop and encode + pil_img = self.crop_raster_element(element, image) + if pil_img is None: + # Empty placeholder + return f' \n' + buf = io.BytesIO() + pil_img.save(buf, format="PNG") + b64_data = base64.b64encode(buf.getvalue()).decode("ascii") + + x = bbox.x1 - ox + y = bbox.y1 - oy + + return ( + f' \n' + ) + + # ================================================================ + # PRIVATE: Geometry Helpers + # ================================================================ + + @staticmethod + def _chaikin_smooth( + polygon: List[List[int]], iterations: int = 2 + ) -> List[List[float]]: + """ + Chaikin's corner-cutting algorithm for polygon smoothing. + Each iteration doubles the point count and rounds corners. 
+ """ + pts = [list(map(float, p)) for p in polygon] + for _ in range(iterations): + if len(pts) < 3: + break + new_pts = [] + n = len(pts) + for i in range(n): + p0 = pts[i] + p1 = pts[(i + 1) % n] + # Q = 3/4 * P_i + 1/4 * P_{i+1} + q = [0.75 * p0[0] + 0.25 * p1[0], 0.75 * p0[1] + 0.25 * p1[1]] + # R = 1/4 * P_i + 3/4 * P_{i+1} + r = [0.25 * p0[0] + 0.75 * p1[0], 0.25 * p0[1] + 0.75 * p1[1]] + new_pts.append(q) + new_pts.append(r) + pts = new_pts + return pts + + @staticmethod + def _polygon_to_svg_path( + polygon: List, ox: int = 0, oy: int = 0 + ) -> str: + """Convert polygon points to SVG path d attribute string.""" + if not polygon: + return "" + + parts = [] + for i, pt in enumerate(polygon): + x = pt[0] - ox + y = pt[1] - oy + cmd = "M" if i == 0 else "L" + parts.append(f"{cmd}{x:.1f},{y:.1f}") + parts.append("Z") + return " ".join(parts) + + @staticmethod + def _image_to_base64(pil_img: Image.Image, fmt: str = "PNG") -> str: + """Convert PIL image to base64 string.""" + buf = io.BytesIO() + pil_img.save(buf, format=fmt) + return base64.b64encode(buf.getvalue()).decode("ascii") diff --git a/modules/vector_exporter.py b/modules/vector_exporter.py new file mode 100644 index 0000000..0cf6c06 --- /dev/null +++ b/modules/vector_exporter.py @@ -0,0 +1,526 @@ +""" +Stage 8: Vector Exporter — Orchestrator Module + +Coordinates SVG generation, raster cropping, section detection, PDF combination, +and manifest generation. This is the main entry point for the vector export pipeline. + +Integrates into Edit Banana as Stage 8 (runs after XML Merge in Stage 7). 
+ +Reads from ProcessingContext: + - context.elements (all detected ElementInfo objects with bbox, polygon, type, colors) + - context.image_path (original image for raster cropping) + - context.canvas_width / canvas_height + - context.output_dir (base output directory for this image) + +Produces: + output/{image}/vectors/ + combined/ -> combined.svg + combined.pdf + elements/ -> individual element SVGs + rasters/ -> cropped PNGs for raster elements + sections/ -> section-level SVGs (on demand) + components/ -> grouped element SVGs (on demand) + manifest.json -> element index + +Usage: + from modules.vector_exporter import VectorExporter + + exporter = VectorExporter() + result = exporter.process(context) + +CLI flags (handled by main.py, passed via context.intermediate_results): + --vector-level=granular|section|component|all (default: granular) + --no-vectors (skip vector export entirely) +""" + +import os +import json +import re +import cv2 +import numpy as np +from datetime import datetime +from typing import List, Optional, Dict, Any + +from .base import BaseProcessor, ProcessingContext +from .data_types import ( + ElementInfo, + ProcessingResult, + LayerLevel, +) +from .svg_generator import SVGGenerator, RASTER_TYPES, ARROW_TYPES, GEOMETRIC_SHAPES +from .pdf_combiner import PDFCombiner +from .section_detector import SectionDetector, Section + + +# Valid vector-level options +VECTOR_LEVELS = {"granular", "section", "component", "all"} +DEFAULT_VECTOR_LEVEL = "granular" + + +class VectorExporter(BaseProcessor): + """ + Stage 8 orchestrator: exports editable vector assets from pipeline results. + + Inherits BaseProcessor for consistent interface with other pipeline stages. 
+ """ + + def __init__(self, config=None): + super().__init__(config) + self._svg_gen = SVGGenerator() + self._pdf_combiner = PDFCombiner() + self._section_detector = SectionDetector() + + def process(self, context: ProcessingContext) -> ProcessingResult: + """ + Main entry point — called by Pipeline after Stage 7 (XML Merge). + + Reads vector_level from context.intermediate_results['vector_level']. + Default is 'granular' (individual element SVGs + combined SVG/PDF). + """ + self._log("Starting vector export...") + + # Get configuration from context + vector_level = context.intermediate_results.get( + "vector_level", DEFAULT_VECTOR_LEVEL + ) + if vector_level not in VECTOR_LEVELS: + self._log(f"Unknown vector level '{vector_level}', using 'granular'") + vector_level = DEFAULT_VECTOR_LEVEL + + # Load image + image = cv2.imread(context.image_path) + if image is None: + return ProcessingResult( + success=False, + error_message=f"Could not read image: {context.image_path}", + ) + + # Create output directories + vectors_dir = os.path.join(context.output_dir, "vectors") + dirs = self._create_output_dirs(vectors_dir, vector_level) + + # Get canvas dimensions + canvas_w = context.canvas_width or image.shape[1] + canvas_h = context.canvas_height or image.shape[0] + + # Filter to elements that have been processed + elements = context.elements + if not elements: + self._log("No elements to export") + return ProcessingResult( + success=True, + metadata={"exported_count": 0, "vector_dir": vectors_dir}, + ) + + self._log(f"Exporting {len(elements)} elements (level={vector_level})") + + # Track manifest entries + manifest_entries = [] + + # ============================================================ + # GRANULAR: always runs — individual element SVGs + raster PNGs + # ============================================================ + self._log("Exporting individual elements...") + for elem in elements: + entry = self._export_element( + elem, image, dirs["elements"], dirs["rasters"] 
+ ) + if entry: + manifest_entries.append(entry) + + self._log(f" Exported {len(manifest_entries)} elements") + + # ============================================================ + # COMBINED SVG + PDF: always generated + # ============================================================ + self._log("Generating combined SVG...") + combined_svg = self._svg_gen.generate_combined_svg( + elements, image, canvas_w, canvas_h + ) + combined_svg_path = os.path.join(dirs["combined"], "combined.svg") + with open(combined_svg_path, "w", encoding="utf-8") as f: + f.write(combined_svg) + self._log(f" Saved: {combined_svg_path}") + + # PDF + self._log("Generating combined PDF...") + combined_pdf_path = os.path.join(dirs["combined"], "combined.pdf") + if self._pdf_combiner.is_available: + pdf_ok = self._pdf_combiner.svg_string_to_pdf( + combined_svg, combined_pdf_path + ) + if not pdf_ok: + self._log(" PDF generation failed (SVG may have unsupported features)") + else: + self._log( + " PDF generation skipped (install svglib+reportlab: " + "pip install svglib reportlab)" + ) + + # ============================================================ + # SECTION LEVEL: on demand + # ============================================================ + sections = [] + if vector_level in ("section", "all"): + self._log("Detecting sections...") + sections = self._section_detector.detect_sections(context) + if sections: + self._export_sections( + sections, elements, image, dirs["sections"], canvas_w, canvas_h + ) + else: + self._log(" No sections detected") + + # ============================================================ + # COMPONENT LEVEL: on demand (smart grouping) + # ============================================================ + if vector_level in ("component", "all"): + self._log("Generating component groups...") + self._export_components( + elements, image, dirs["components"], canvas_w, canvas_h + ) + + # ============================================================ + # MANIFEST + # 
============================================================ + manifest = self._build_manifest( + context, manifest_entries, sections, vector_level, canvas_w, canvas_h + ) + manifest_path = os.path.join(vectors_dir, "manifest.json") + with open(manifest_path, "w", encoding="utf-8") as f: + json.dump(manifest, f, indent=2, ensure_ascii=False) + self._log(f"Manifest saved: {manifest_path}") + + self._log("Vector export complete.") + + return ProcessingResult( + success=True, + elements=elements, + metadata={ + "exported_count": len(manifest_entries), + "vector_dir": vectors_dir, + "combined_svg": combined_svg_path, + "combined_pdf": combined_pdf_path + if os.path.exists(combined_pdf_path) + else None, + "sections_count": len(sections), + "manifest_path": manifest_path, + }, + ) + + # ================================================================ + # Element Export + # ================================================================ + + def _export_element( + self, + element: ElementInfo, + image: np.ndarray, + elements_dir: str, + rasters_dir: str, + ) -> Optional[Dict[str, Any]]: + """Export a single element as SVG + optional raster PNG.""" + elem_type = element.element_type.lower().replace(" ", "_") + elem_id = f"{elem_type}_{element.id:03d}" + + # Generate SVG + try: + svg_string = self._svg_gen.element_to_svg(element, image, standalone=True) + except Exception as e: + self._log(f" Failed to generate SVG for {elem_id}: {e}") + return None + + # Save SVG + svg_filename = f"{elem_id}.svg" + svg_path = os.path.join(elements_dir, svg_filename) + with open(svg_path, "w", encoding="utf-8") as f: + f.write(svg_string) + + # Save raster PNG for image-type elements + raster_path = None + raster_filename = None + if elem_type in RASTER_TYPES or element.base64: + raster_filename = f"{elem_id}.png" + raster_full_path = os.path.join(rasters_dir, raster_filename) + saved = self._svg_gen.save_raster_crop(element, image, raster_full_path) + if saved: + raster_path = 
f"rasters/{raster_filename}" + + # Build manifest entry + return { + "id": elem_id, + "element_id": element.id, + "type": element.element_type, + "svg_path": f"elements/{svg_filename}", + "raster_path": raster_path, + "bbox": { + "x": element.bbox.x1, + "y": element.bbox.y1, + "w": element.bbox.width, + "h": element.bbox.height, + }, + "colors": { + "fill": element.fill_color, + "stroke": element.stroke_color, + }, + "score": round(element.score, 4), + "layer": LayerLevel(element.layer_level).name + if element.layer_level in [lv.value for lv in LayerLevel] + else "OTHER", + } + + # ================================================================ + # Section Export + # ================================================================ + + def _export_sections( + self, + sections: List[Section], + elements: List[ElementInfo], + image: np.ndarray, + sections_dir: str, + canvas_w: int, + canvas_h: int, + ): + """Export each section as its own SVG with child elements.""" + # Build element lookup by ID + elem_by_id = {e.id: e for e in elements} + + for section in sections: + section_label = f"section_{section.label}" + section_subdir = os.path.join(sections_dir, section_label) + os.makedirs(section_subdir, exist_ok=True) + + # Get child elements + child_elements = [ + elem_by_id[eid] + for eid in section.child_element_ids + if eid in elem_by_id + ] + + if not child_elements: + continue + + # Generate section SVG (elements positioned relative to section bbox) + section_svg = self._svg_gen.generate_combined_svg( + child_elements, + image, + section.bbox.width, + section.bbox.height, + ) + + # Adjust viewBox to section coordinates + section_svg = section_svg.replace( + f'viewBox="0 0 {section.bbox.width} {section.bbox.height}"', + f'viewBox="{section.bbox.x1} {section.bbox.y1} ' + f'{section.bbox.width} {section.bbox.height}"', + ) + + svg_path = os.path.join(section_subdir, f"{section_label}.svg") + with open(svg_path, "w", encoding="utf-8") as f: + f.write(section_svg) + + 
# Also export individual elements within section
+            elements_subdir = os.path.join(section_subdir, "elements")
+            os.makedirs(elements_subdir, exist_ok=True)
+            for elem in child_elements:
+                elem_type = elem.element_type.lower().replace(" ", "_")
+                elem_filename = f"{elem_type}_{elem.id:03d}.svg"
+                try:
+                    elem_svg = self._svg_gen.element_to_svg(
+                        elem, image, standalone=True
+                    )
+                    with open(
+                        os.path.join(elements_subdir, elem_filename), "w", encoding="utf-8"
+                    ) as f:
+                        f.write(elem_svg)
+                except Exception as e:
+                    # Log instead of silently swallowing the failure,
+                    # matching the error handling in _export_element
+                    self._log(f"  Failed to export {elem_filename}: {e}")
+
+            self._log(
+                f"  Section {section.label}: {len(child_elements)} elements -> {svg_path}"
+            )
+
+    # ================================================================
+    # Component Export (Smart Grouping)
+    # ================================================================
+
+    def _export_components(
+        self,
+        elements: List[ElementInfo],
+        image: np.ndarray,
+        components_dir: str,
+        canvas_w: int,
+        canvas_h: int,
+    ):
+        """
+        Group related elements into components and export each as SVG.
+
+        Grouping heuristics:
+        - Shape + contained text -> "labeled_box" component
+        - Arrows -> individual "arrow" components
+        - Remaining elements -> individual per-type components
+        (Arrow chains and icon groups are planned but not implemented yet.)
+        """
+        used_ids = set()
+        component_id = 0
+
+        # Build spatial index
+        shapes = [
+            e for e in elements
+            if e.element_type.lower() in GEOMETRIC_SHAPES
+        ]
+        texts = [
+            e for e in elements
+            if e.element_type.lower() == "text"
+        ]
+        arrows = [
+            e for e in elements
+            if e.element_type.lower() in ARROW_TYPES
+        ]
+        others = [
+            e for e in elements
+            if e.id not in {s.id for s in shapes}
+            and e.id not in {t.id for t in texts}
+            and e.id not in {a.id for a in arrows}
+        ]
+
+        # Heuristic 1: Shape + contained text = labeled box
+        for shape in shapes:
+            if shape.id in used_ids:
+                continue
+
+            group = [shape]
+            used_ids.add(shape.id)
+
+            # Find text elements whose center falls within this shape's bbox
+            for text in texts:
+                if text.id in used_ids:
+                    continue
+                tcx, tcy = text.bbox.center
+                if (
+                    shape.bbox.x1 <= tcx <= shape.bbox.x2
+                    and shape.bbox.y1 <= tcy <= shape.bbox.y2
+                ):
+                    group.append(text)
+                    used_ids.add(text.id)
+
+            # group always contains at least the shape itself; only a
+            # shape+text pairing counts as a labeled box
+            comp_type = "labeled_box" if len(group) > 1 else "shape"
+            self._save_component_group(
+                group, image, components_dir, component_id, comp_type
+            )
+            component_id += 1
+
+        # Heuristic 2: Individual arrows
+        for arrow in arrows:
+            if arrow.id in used_ids:
+                continue
+            self._save_component_group(
+                [arrow], image, components_dir, component_id, "arrow"
+            )
+            used_ids.add(arrow.id)
+            component_id += 1
+
+        # Heuristic 3: Remaining elements as individual components
+        for elem in others + [t for t in texts if t.id not in used_ids]:
+            if elem.id in used_ids:
+                continue
+            elem_type = elem.element_type.lower().replace(" ", "_")
+            self._save_component_group(
+                [elem], image, components_dir, component_id, elem_type
+            )
+            used_ids.add(elem.id)
+            component_id += 1
+
+        self._log(f"  Exported {component_id} components")
+
+    def _save_component_group(
+        self,
+        elements: List[ElementInfo],
+        image: np.ndarray,
components_dir: str, + component_id: int, + component_type: str, + ): + """Save a group of elements as a single component SVG.""" + if not elements: + return + + # Calculate bounding box of all elements in group + x1 = min(e.bbox.x1 for e in elements) + y1 = min(e.bbox.y1 for e in elements) + x2 = max(e.bbox.x2 for e in elements) + y2 = max(e.bbox.y2 for e in elements) + w = x2 - x1 + 4 # padding + h = y2 - y1 + 4 + + # Generate combined SVG for this group + svg = self._svg_gen.generate_combined_svg(elements, image, w, h) + + # Adjust viewBox to component coordinates + svg = svg.replace( + f'viewBox="0 0 {w} {h}"', + f'viewBox="{x1 - 2} {y1 - 2} {w} {h}"', + ) + + filename = f"component_{component_id:03d}_{component_type}.svg" + filepath = os.path.join(components_dir, filename) + with open(filepath, "w", encoding="utf-8") as f: + f.write(svg) + + # ================================================================ + # Manifest + # ================================================================ + + def _build_manifest( + self, + context: ProcessingContext, + elements: List[Dict], + sections: List[Section], + vector_level: str, + canvas_w: int, + canvas_h: int, + ) -> Dict[str, Any]: + """Build manifest.json content.""" + return { + "source_image": os.path.basename(context.image_path), + "image_size": {"width": canvas_w, "height": canvas_h}, + "extraction_date": datetime.now().isoformat(), + "vector_level": vector_level, + "levels_generated": self._levels_generated(vector_level), + "total_elements": len(elements), + "elements": elements, + "sections": [s.to_dict() for s in sections] if sections else [], + } + + @staticmethod + def _levels_generated(vector_level: str) -> List[str]: + if vector_level == "all": + return ["granular", "section", "component"] + return [vector_level] + + # ================================================================ + # Directory Setup + # ================================================================ + + def _create_output_dirs( + 
self, vectors_dir: str, vector_level: str + ) -> Dict[str, str]: + """Create output directory structure.""" + dirs = { + "root": vectors_dir, + "combined": os.path.join(vectors_dir, "combined"), + "elements": os.path.join(vectors_dir, "elements"), + "rasters": os.path.join(vectors_dir, "rasters"), + } + + if vector_level in ("section", "all"): + dirs["sections"] = os.path.join(vectors_dir, "sections") + + if vector_level in ("component", "all"): + dirs["components"] = os.path.join(vectors_dir, "components") + + for d in dirs.values(): + os.makedirs(d, exist_ok=True) + + return dirs + + diff --git a/prompts/arrow.py b/prompts/arrow.py index 8ab4570..248d6c9 100644 --- a/prompts/arrow.py +++ b/prompts/arrow.py @@ -1,6 +1,41 @@ -# Arrow/connector prompts for SAM3 (minimal set) +# Arrow/connector prompts for SAM3 +# +# Design principles (from SAM3 best practices): +# - Short, specific noun phrases +# - Distinguish arrow types by visual feature (straight / curved / thick / dashed) +# - Include common diagram connector variants +# +# Note: SAM3 treats each prompt independently; similar prompts may produce +# overlapping detections that get merged in the dedup step. ARROW_PROMPT = [ + # Core arrow types "arrow", + "straight arrow", + "thin arrow", + + # Thick / block arrows (e.g. "weight transfer" orange arrows in ML figures) + "thick arrow", + "block arrow", + + # Curved arrows (e.g. 
"iterative refinement" loops — were missed in v1) + "curved arrow", + "looping arrow", + + # Multi-endpoint arrows + "bidirectional arrow", + "double-headed arrow", + + # Line-type connectors "line", "connector", + "connecting line", + + # Dashed / dotted (for skip connections, data flow — common in neural net figures) + "dashed line", + "dotted line", + "skip connection", + + # L-shaped / elbow connectors (common in flowcharts) + "L-shaped connector", + "elbow arrow", ] diff --git a/prompts/background.py b/prompts/background.py index c815338..297a4c7 100644 --- a/prompts/background.py +++ b/prompts/background.py @@ -1,7 +1,36 @@ # Background/container prompts for SAM3 +# +# Design principles (from SAM3 best practices): +# - Short, specific noun phrases +# - Target large container regions (panels, sub-figures, legend groups) +# - These get priority=1 (lowest) so they don't overshadow inner elements +# +# These are used for: +# 1. Detecting sub-figure panels (a), (b), (c), (d), (e) in scientific figures +# 2. 
Providing the BACKGROUND layer in SVG/DrawIO output BACKGROUND_PROMPT = [ + # Core panel / container types "panel", + "sub-figure panel", "container", + "background region", + + # Filled / shaded regions "filled region", - "background", + "shaded region", + "colored background", + + # Bordered / grouped sections + "dashed border rectangle", + "bordered section", + "grouped section", + + # Legend-specific (common in scientific figures with color-coded legends) + "color legend", + "legend box", + "legend panel", + + # Title bars / headers + "title bar", + "header strip", ] diff --git a/prompts/image.py b/prompts/image.py index 6dca28f..c0e7901 100644 --- a/prompts/image.py +++ b/prompts/image.py @@ -1,8 +1,55 @@ # Image/icon prompts for SAM3 +# +# Design principles (from SAM3 best practices): +# - Short, specific noun phrases +# - Include scientific/medical figure icons specifically (person, scan, render) +# - Cover the full spectrum: icons, photos, renders, charts, illustrations +# +# The IconPictureProcessor crops these regions and (optionally) removes +# backgrounds with RMBG-2.0. Outputs are embedded as base64 in vectors. IMAGE_PROMPT = [ + # Generic image categories "icon", "picture", + "photograph", "logo", + "illustration", + + # Data visualizations "chart", + "graph", + "plot", "diagram", + + # Scientific figure icons (were missed in v1 — CCT-FM panel had many) + "person icon", + "human silhouette", + "group of people", + "crowd icon", + + # Medical imaging (for figures like CT pipelines, Bern CCTA, etc.) 
+ "medical scan", + "CT scan image", + "MRI image", + "ultrasound image", + "anatomy rendering", + "3D heart model", + "3D rendering", + + # Medical icons + "stethoscope icon", + "heart icon", + + # Stacked / layered images (for dataset thumbnails) + "stack of images", + "image stack", + "thumbnail strip", + + # Texture / pattern (for masking blocks, checkerboards) + "checkerboard pattern", + "grid pattern", + + # Computer/UI icons (for HITL / expert correction figures) + "computer monitor icon", + "screen icon", ] diff --git a/prompts/shape.py b/prompts/shape.py index 0df8306..3dde40e 100644 --- a/prompts/shape.py +++ b/prompts/shape.py @@ -1,10 +1,44 @@ # Basic shape prompts for SAM3 +# +# Design principles (from SAM3 best practices): +# - Short, specific noun phrases +# - Cover scientific figure conventions (trapezoids for encoders, cubes for volumes) +# - Avoid redundancy; the dedup step handles overlaps across prompts +# +# Downstream support: basic_shape_processor handles rectangle, rounded_rectangle, +# diamond, ellipse, circle, cylinder, cloud, actor, hexagon, triangle, +# parallelogram, trapezoid, square (see modules/basic_shape_processor.py). 
SHAPE_PROMPT = [
+    # Core rectangles (most common in scientific figures)
     "rectangle",
     "rounded rectangle",
-    "diamond",
+    "square",
+
+    # Conic / pill shapes
     "ellipse",
     "circle",
+
+    # Rhombic
+    "diamond",
+
+    # Encoder/decoder shapes (very common in ML/DL figures — were missed in v1)
+    "trapezoid",
+    "parallelogram",
+
+    # Polygons
     "triangle",
     "hexagon",
+
+    # 3D / volumetric (common in medical imaging & ML figures — were missed in v1)
+    "3D cube",
+    "isometric box",
+    "cylinder",
+
+    # Small colored elements (for legends — were missed in v1 due to size)
+    "color swatch",
+    "small colored square",
+
+    # Stacked / layered (for image stacks, CT volumes — were missed in v1)
+    "stack of rectangles",
+    "layered boxes",
 ]
diff --git a/requirements.txt b/requirements.txt
index 179fb4e..bf816ce 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,6 +11,10 @@ python-multipart
 modelscope
 einops
 
+# Stage 8: Vector export (SVG -> PDF)
+svglib
+reportlab
+
 # Text OCR (local, default)
 pytesseract

From 53a7e0095c743015fa25e912e46b6c7383728450 Mon Sep 17 00:00:00 2001
From: sdamirsa
Date: Wed, 15 Apr 2026 21:23:21 +0300
Subject: [PATCH 3/3] docs: README note on the multi-panel / schematic-figure
 challenge

Adds a [!NOTE] admonition near the top of the README documenting the
remaining limitation after the v2 prompt expansion, hierarchical layer
assignment, and Stage 8 vector export: the pipeline still struggles to
understand multi-panel scientific schematics because detection is
per-element while global semantics (cross-panel arrows, legend-to-plot
mapping) is not modeled.

Lists two roadmap directions for tackling this challenge:
1. Two-pass extraction with explicit panel splitting + recursion
2.
Smart per-element-type margin padding around cropped rasters

Co-Authored-By: Claude Opus 4.6
---
 README.md | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/README.md b/README.md
index 92cfc3c..5d87209 100644
--- a/README.md
+++ b/README.md
@@ -36,6 +36,13 @@ Powered by SAM 3 and multimodal large models, it enables high-fidelity reconstru
 > [!WARNING]
 > **Please note**: Our GitHub repository currently trails behind our web-based service. For the most up-to-date features and performance, we recommend using our web platform.
+
+> [!NOTE]
+> **Known limitation — multi-panel / schematic figures.** Recent updates (v2 prompts: 19 → 78; new hierarchical layer assignment; [Stage 8 vector export to SVG/PDF](https://github.com/Sdamirsa/Edit-Banana/commit/d6c445a)) significantly improved single-element extraction, but the pipeline still struggles to understand complex **multi-panel scientific schematics**. Detection is per-element; the global semantics — which arrow connects which box across panel boundaries, which legend swatch labels which plot — are not yet modeled.
+>
+> **Roadmap directions on this challenge:**
+> 1. **Two-pass extraction with panel splitting.** First detect sub-figure panels (Stage 8's `section_detector` already produces panel candidates), split the source image into per-panel crops, then recursively run the full pipeline on each crop. The orchestration loop is the missing piece.
+> 2. **Smart per-element-type margin padding** around cropped rasters. Tight bboxes clip strokes; loose ones bleed into neighbors. A per-type heuristic (icon vs. photo vs. schematic illustration) would help, but the logic is hard to pin down cleanly.
 ---

 ## 💬 Join WeChat Group
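Roadmap item 1 above is mostly an orchestration loop around pieces that already exist. The sketch below shows the core of that loop, with the one non-obvious detail being the translation of per-crop detections back into source-image coordinates. This is a minimal illustration, not the project's implementation: `detect_panels` and `run_pipeline` are hypothetical stand-ins for the Stage 8 `SectionDetector` and the existing per-image pipeline, and the image is a plain nested list rather than an `np.ndarray`.

```python
from dataclasses import dataclass
from typing import Callable, List

@dataclass
class Box:
    x1: int
    y1: int
    x2: int
    y2: int

    def offset(self, dx: int, dy: int) -> "Box":
        """Translate this box back into the parent image's coordinates."""
        return Box(self.x1 + dx, self.y1 + dy, self.x2 + dx, self.y2 + dy)

def two_pass_extract(
    image: List[List[int]],                    # stand-in for an np.ndarray (H, W)
    detect_panels: Callable[..., List[Box]],   # pass 1: sub-figure panel bboxes
    run_pipeline: Callable[..., List[Box]],    # pass 2: per-element detections
) -> List[Box]:
    """Split into panels, run the pipeline per crop, then map each
    detection back into source-image coordinates."""
    panels = detect_panels(image)
    if not panels:
        # No panels found: fall back to the current single-pass behaviour
        return run_pipeline(image)
    results: List[Box] = []
    for p in panels:
        # Crop the panel region (row/column slicing on the nested list)
        crop = [row[p.x1:p.x2] for row in image[p.y1:p.y2]]
        for det in run_pipeline(crop):
            # Detections come back in crop-local coordinates; shift by the
            # panel origin so they line up with the source image
            results.append(det.offset(p.x1, p.y1))
    return results
```

A real version would also need a recursion-depth cap (a panel crop may itself contain panels) and would likely want the per-type margin padding from roadmap item 2 applied to the panel crops before re-running the pipeline.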