Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
large_files/ filter=lfs diff=lfs merge=lfs -text

# Keep shell scripts with Unix line endings on all platforms
*.sh text eol=lf
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,7 @@ sam3_src/
# Local processing & debug
arrow_processing/
debug_output/

# Local planning notes & AI tool session data
.amir-zone/
.claude/
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ Powered by SAM 3 and multimodal large models, it enables high-fidelity reconstru

> [!WARNING]
> **Please note**: Our GitHub repository currently trails behind our web-based service. For the most up-to-date features and performance, we recommend using our web platform.

> [!NOTE]
> **Known limitation — multi-panel / schematic figures.** Recent updates (v2 prompts: 19 → 78; new hierarchical layer assignment; [Stage 8 vector export to SVG/PDF](https://github.com/Sdamirsa/Edit-Banana/commit/d6c445a)) significantly improved single-element extraction, but the pipeline still struggles to fully understand complex **multi-panel scientific schematics**. Detection is per-element; the global semantics — which arrow connects which box across panel boundaries, which legend swatch labels which plot — are not yet modeled.
>
> **Roadmap directions on this challenge:**
> 1. **Two-pass extraction with panel splitting.** First detect sub-figure panels (Stage 8's `section_detector` already produces panel candidates), split the source image into per-panel crops, then recursively run the full pipeline on each crop. The orchestration loop is the missing piece.
> 2. **Smart per-element-type margin padding** around cropped rasters. Tight bboxes clip strokes; loose ones bleed neighbors. A per-type heuristic (icon vs. photo vs. schematic illustration) would help, but the logic is hard to pin down cleanly.
---
## 💬 Join WeChat Group

Expand Down
53 changes: 49 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
python main.py -i input/test.png -o output/custom/
python main.py -i input/test.png --refine
python main.py -i input/test.png --no-text
python main.py -i input/test.png --vector-level=all
python main.py -i input/test.png --no-vectors
"""

import os
Expand Down Expand Up @@ -39,7 +41,10 @@
XMLMerger,
MetricEvaluator,
RefinementProcessor,


# Stage 8: Vector export
VectorExporter,

# Text (modules/text/)
TextRestorer,

Expand Down Expand Up @@ -89,6 +94,7 @@ def __init__(self, config: dict = None):
self._xml_merger = None
self._metric_evaluator = None
self._refinement_processor = None
self._vector_exporter = None

@property
def text_restorer(self):
Expand Down Expand Up @@ -138,13 +144,21 @@ def refinement_processor(self) -> RefinementProcessor:
if self._refinement_processor is None:
self._refinement_processor = RefinementProcessor()
return self._refinement_processor

@property
def vector_exporter(self) -> VectorExporter:
    """Lazily create and cache the Stage 8 VectorExporter.

    Constructed on first access so pipelines that run with --no-vectors
    never pay the instantiation cost; the same instance is reused for
    every subsequent image processed by this pipeline object.
    """
    if self._vector_exporter is None:
        self._vector_exporter = VectorExporter()
    return self._vector_exporter

def process_image(self,
image_path: str,
output_dir: str = None,
with_refinement: bool = False,
with_text: bool = True,
groups: List[PromptGroup] = None) -> Optional[str]:
groups: List[PromptGroup] = None,
vector_level: str = "granular",
no_vectors: bool = False) -> Optional[str]:
"""Run pipeline on one image. Returns output XML path or None."""
print(f"\n{'='*60}")
print(f"Processing: {image_path}")
Expand Down Expand Up @@ -264,8 +278,28 @@ def process_image(self,

output_path = merge_result.metadata.get('output_path')
print(f" Output: {output_path}")

# ============ Stage 8: Vector Export ============
if not no_vectors:
print(f"\n[8] Vector export (level={vector_level})...")
context.intermediate_results['vector_level'] = vector_level
try:
vec_result = self.vector_exporter.process(context)
if vec_result.success:
vec_count = vec_result.metadata.get('exported_count', 0)
vec_dir = vec_result.metadata.get('vector_dir', '')
print(f" Exported {vec_count} elements -> {vec_dir}")
else:
print(f" Vector export failed: {vec_result.error_message}")
except Exception as e:
print(f" Vector export failed: {e}")
import traceback
traceback.print_exc()
else:
print("\n[8] Vector export (skipped)")

print(f"\n{'='*60}\nDone.\n{'='*60}")

return output_path

except Exception as e:
Expand Down Expand Up @@ -332,6 +366,8 @@ def main():
python main.py
python main.py -i test.png --refine
python main.py -i test.png --groups image arrow
python main.py -i test.png --vector-level=all
python main.py -i test.png --no-vectors
"""
)

Expand All @@ -348,6 +384,13 @@ def main():
help="Prompt groups to process (default: all)")
parser.add_argument("--show-prompts", action="store_true",
help="Show prompt config")

# Stage 8: Vector export options
parser.add_argument("--vector-level", type=str, default="granular",
choices=['granular', 'section', 'component', 'all'],
help="Vector export granularity (default: granular)")
parser.add_argument("--no-vectors", action="store_true",
help="Skip vector export (Stage 8)")

args = parser.parse_args()

Expand Down Expand Up @@ -417,7 +460,9 @@ def main():
output_dir=output_dir,
with_refinement=args.refine,
with_text=not args.no_text,
groups=groups
groups=groups,
vector_level=args.vector_level,
no_vectors=args.no_vectors,
)
if result:
success_count += 1
Expand Down
11 changes: 11 additions & 0 deletions modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
from .metric_evaluator import MetricEvaluator
from .refinement_processor import RefinementProcessor

# Stage 8: Vector export
from .vector_exporter import VectorExporter
from .svg_generator import SVGGenerator
from .pdf_combiner import PDFCombiner
from .section_detector import SectionDetector

# Text (modules/text/); optional if ocr/coord_processor missing
try:
from .text.restorer import TextRestorer
Expand Down Expand Up @@ -53,4 +59,9 @@
'BasicShapeProcessor',
'MetricEvaluator',
'RefinementProcessor',
# Stage 8: Vector export
'VectorExporter',
'SVGGenerator',
'PDFCombiner',
'SectionDetector',
]
76 changes: 60 additions & 16 deletions modules/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,37 +255,81 @@ def from_yaml(cls, yaml_path: str) -> 'ProcessingConfig':


# ======================== Helper functions ========================

def _expand_forms(prompts):
"""Return set containing both lowercase-with-spaces and lowercase-with-underscores forms."""
out = set()
for p in prompts:
low = p.lower()
out.add(low)
out.add(low.replace(" ", "_"))
return out


# Lazy-built prompt-derived type sets. Built on first call of get_layer_level
# to avoid import-time cycles (prompts has no deps, but be safe).
_TYPE_SETS_CACHE = {}


def _get_type_sets():
    """Build (once) and return the cached prompt-derived type sets.

    Keys are "image", "shape", "arrow", "background"; values are sets
    produced by _expand_forms(). Built lazily on first call to avoid
    import-time cycles with the prompts package. If the prompts package
    cannot be imported, every set is empty and the legacy hardcoded
    lists in get_layer_level() still apply.
    """
    if _TYPE_SETS_CACHE:
        return _TYPE_SETS_CACHE

    try:
        from prompts.image import IMAGE_PROMPT
        from prompts.shape import SHAPE_PROMPT
        from prompts.arrow import ARROW_PROMPT
        from prompts.background import BACKGROUND_PROMPT
        sources = {
            "image": IMAGE_PROMPT,
            "shape": SHAPE_PROMPT,
            "arrow": ARROW_PROMPT,
            "background": BACKGROUND_PROMPT,
        }
    except ImportError:
        # Fallback: empty sets; callers fall back to their legacy hardcoded lists.
        sources = {key: [] for key in ("image", "shape", "arrow", "background")}

    for key, prompt_list in sources.items():
        _TYPE_SETS_CACHE[key] = _expand_forms(prompt_list)
    return _TYPE_SETS_CACHE


def get_layer_level(element_type: str) -> int:
    """
    Return the default z-stacking layer for a detected element type.

    Shared by all sub-modules so layer assignment stays consistent.

    v2 fix: the image/shape/arrow/background sets are derived from the
    prompt files, so specific prompts like "3D heart model" or
    "MRI image" (which previously fell through to LayerLevel.OTHER and
    broke stacking) now land on the correct IMAGE layer.
    """
    etype = element_type.lower()
    type_sets = _get_type_sets()

    # Background / container types (bottom layer) — legacy names + prompt-derived.
    if etype in {'section_panel', 'title_bar'} or etype in type_sets["background"]:
        return LayerLevel.BACKGROUND.value

    # Arrows / connector lines — legacy names + prompt-derived.
    if etype in {'arrow', 'line', 'connector'} or etype in type_sets["arrow"]:
        return LayerLevel.ARROW.value

    # Text elements.
    if etype == 'text':
        return LayerLevel.TEXT.value

    # Image-like elements — legacy names + prompt-derived
    # (this is the fix path for the "3D heart model" bug).
    legacy_images = {'icon', 'picture', 'image', 'logo', 'chart', 'function_graph'}
    if etype in legacy_images or etype in type_sets["image"]:
        return LayerLevel.IMAGE.value

    # Basic geometric shapes — legacy names + prompt-derived.
    legacy_shapes = {
        'rectangle', 'rounded_rectangle', 'rounded rectangle',
        'diamond', 'ellipse', 'circle', 'cylinder', 'cloud',
        'hexagon', 'triangle', 'parallelogram', 'actor',
    }
    if etype in legacy_shapes or etype in type_sets["shape"]:
        return LayerLevel.BASIC_SHAPE.value

    # Everything else.
    return LayerLevel.OTHER.value
11 changes: 9 additions & 2 deletions modules/icon_picture_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,15 @@ def process(self, context: ProcessingContext) -> ProcessingResult:
)

def _get_elements_to_process(self, elements: List[ElementInfo]) -> List[ElementInfo]:
"""Filter elements to process (icons, arrows, etc.; arrows treated as icon crop)."""
all_types = set(IMAGE_PROMPT) | {"arrow", "line", "connector"}
"""Filter elements to process (icons, arrows, etc.; arrows treated as icon crop).

NOTE: IMAGE_PROMPT contains mixed-case strings (e.g. "3D heart model",
"MRI image", "CT scan image"). Comparing `.lower()` against the raw
set caused those detections to be silently skipped — no base64 was
generated and downstream SVG rendering fell back to a plain polygon
outline. Always normalize both sides.
"""
all_types = {t.lower() for t in IMAGE_PROMPT} | {"arrow", "line", "connector"}
return [
e for e in elements
if e.element_type.lower() in all_types and e.base64 is None
Expand Down
Loading