-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlocalize_objects_tool.py
More file actions
119 lines (102 loc) · 5.01 KB
/
localize_objects_tool.py
File metadata and controls
119 lines (102 loc) · 5.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""Tool for localizing objects in images with bounding boxes using SAM3."""
from typing import Union, Dict, List
import torch
from PIL import Image
from tool.base_tool import ModelBasedTool, register_tool
from tool.utils.image_utils import image_processing
from tool.visualize_regions_tool import VisualizeRegionsOnImageTool
@register_tool(name="localize_objects")
class LocalizeObjectsTool(ModelBasedTool):
    """Localize named objects in an image and report their bounding boxes.

    Each requested object name is sent to a SAM3 model as a text prompt.
    Detections that pass the score cut-off are returned as boxes whose
    coordinates are divided by the image width/height, together with an
    annotated image produced by ``VisualizeRegionsOnImageTool``.
    """

    name = "localize_objects"
    model_id = "sam3"
    description = "Localize objects and return their bounding boxes. Use this to get bbox for other region-based tools. Note: Cannot detect text labels or annotation markers (e.g., 'A', 'B', 'point 1') drawn on images."
    parameters = {
        "type": "object",
        "properties": {
            "image": {"type": "string", "description": "Image ID (e.g., 'img_0')"},
            "objects": {"type": "array", "items": {"type": "string"}, "description": "A list of object names to localize. e.g. ['dog', 'cat', 'person']."}
        },
        "required": ["image", "objects"]
    }
    example = '{"image": "img_0", "objects": ["dog", "cat"]}'

    # Single source of truth for the cut-offs handed to the post-processor
    # and re-checked when collecting detections.
    _SCORE_THRESHOLD = 0.5
    _MASK_THRESHOLD = 0.5

    def load_model(self, device: str) -> None:
        """Load the SAM3 processor and model and move the model to ``device``.

        The ``transformers`` classes and the model path are imported here
        rather than at module import time.

        Args:
            device: Device identifier passed to ``.to()`` for the model and,
                later, for the processed inputs.
        """
        from transformers import Sam3Processor, Sam3Model
        from tool.model_config import SAM3_MODEL_PATH

        self.processor = Sam3Processor.from_pretrained(SAM3_MODEL_PATH)
        self.model = Sam3Model.from_pretrained(SAM3_MODEL_PATH).to(device)
        self.device = device
        self.is_loaded = True

    def _call_impl(self, params: Union[str, Dict]) -> Dict:
        """Run localization for every requested object name.

        Args:
            params: Raw tool parameters; decoded with ``self.parse_params``.
                Must provide ``image`` and a non-empty list of string
                ``objects``.

        Returns:
            On success, a dict with ``output_image`` (whatever the
            visualization tool reports under that key) and ``regions``, a
            list of ``{"label", "bbox_2d", "score"}`` dicts. On failure, a
            dict with a single ``error`` message.
        """
        params_dict = self.parse_params(params)
        # .get() so that a missing key yields an error dict like every other
        # failure path instead of an uncaught KeyError.
        image_path = params_dict.get("image")
        objects = params_dict.get("objects")

        if image_path is None:
            return {"error": "image is required"}
        if (
            not isinstance(objects, list)
            or not objects
            or not all(isinstance(obj, str) for obj in objects)
        ):
            return {
                "error": "objects must be a non-empty list of strings"
            }

        # Drop repeated names (keeping first-seen order) so the same prompt is
        # not run twice and its detections are not reported twice.
        prompts = list(dict.fromkeys(objects))

        try:
            image = image_processing(image_path)
            width, height = image.size

            regions = []
            for obj_name in prompts:
                regions.extend(self._localize_one(image, obj_name, width, height))

            # Draw the detected boxes on the image.
            visualize_tool = VisualizeRegionsOnImageTool()
            visualize_params = {
                "image": image_path,
                "regions": [
                    {"bbox_2d": region["bbox_2d"], "label": region["label"]}
                    for region in regions
                ],
            }
            output_image_result = visualize_tool.call(visualize_params)
            # NOTE(review): output_image is None when the visualization result
            # has no "output_image" key; presumably a PIL image otherwise.
            output_image = output_image_result.get("output_image")

            return {
                "output_image": output_image,
                "regions": regions
            }
        except FileNotFoundError as e:
            return {"error": f"Image file not found: {str(e)}"}
        except Exception as e:
            # Tool boundary: report any other failure as an error dict.
            return {"error": f"Error localizing objects: {str(e)}"}

    def _localize_one(self, image, obj_name: str, width: int, height: int) -> List[Dict]:
        """Detect every instance of one object name in ``image``.

        Args:
            image: Image object returned by ``image_processing``.
            obj_name: Text prompt naming the object to find.
            width: Image width used to scale x coordinates.
            height: Image height used to scale y coordinates.

        Returns:
            One dict per kept detection with ``label``, ``bbox_2d``
            (``[x1, y1, x2, y2]`` divided by width/height, rounded to 4
            places) and ``score``. The first detection is labelled with the
            bare name; later ones get a ``-2``, ``-3``, ... suffix.
        """
        inputs = self.processor(images=image, text=obj_name, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Post-process to get instance segmentation results for the one image.
        results = self.processor.post_process_instance_segmentation(
            outputs,
            threshold=self._SCORE_THRESHOLD,
            mask_threshold=self._MASK_THRESHOLD,
            target_sizes=inputs.get("original_sizes").tolist(),
        )[0]

        regions = []
        for box, score in zip(results["boxes"], results["scores"]):
            score_val = score.item()
            if score_val < self._SCORE_THRESHOLD:
                continue

            x1, y1, x2, y2 = (float(value) for value in box.tolist()[:4])
            index = len(regions) + 1
            label_out = f"{obj_name}-{index}" if index > 1 else obj_name
            regions.append({
                "label": label_out,
                "bbox_2d": [
                    round(x1 / width, 4),
                    round(y1 / height, 4),
                    round(x2 / width, 4),
                    round(y2 / height, 4),
                ],
                "score": round(score_val, 4),
            })
        return regions

    def generate_description(self, properties, observation):
        """Return a one-line summary of a localization call.

        Args:
            properties: Parameters of the call; ``image`` and ``objects`` are
                read, with fallbacks when either is missing.
            observation: Unused by this implementation.

        Returns:
            A string of the form ``"Localized <objects> in <image>"``.
        """
        img = properties.get("image", "image")
        objects = properties.get("objects", [])
        if isinstance(objects, list):
            # str() each entry so a non-string item cannot break the join.
            objects_str = ", ".join(str(obj) for obj in objects) if objects else "objects"
        else:
            objects_str = str(objects)
        return f"Localized {objects_str} in {img}"