-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathinteractive_annotate_model.py
More file actions
86 lines (67 loc) · 2.67 KB
/
interactive_annotate_model.py
File metadata and controls
86 lines (67 loc) · 2.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import cv2
import numpy as np
from PIL import Image
from typing import Any, Literal
from ultralytics import SAM
from utils import rle_encode
class InteractiveAnnotateModel:
    """Wrapper around a lazily loaded SAM model for interactive annotation.

    The model is prompted with click points or bounding boxes, and the raw
    prediction can be converted into plain annotation dicts with
    ``decode_result``.
    """

    def __init__(self):
        # The loaded model object, or None while nothing is loaded.
        self.model: Any = None
        # Name of the currently loaded model, or None while nothing is loaded.
        self.loaded: Any = None

    def load_model(self, model_name: Literal['sam2.1_l'] = 'sam2.1_l'):
        """Load the model named ``model_name`` unless one is already loaded.

        Args:
            model_name: Name of the model to load. Only ``'sam2.1_l'`` is
                handled; any other value leaves the instance unchanged.
        """
        if self.loaded:
            return
        if model_name == 'sam2.1_l':
            self.model = SAM("sam2.1_l.pt")
            self.loaded = model_name

    def unload_model(self):
        """Drop the reference to the model and mark the instance as unloaded."""
        self.model = None
        self.loaded = None

    def inference_points(
        self,
        image: "Image.Image",
        points: list[tuple[float, float]],
        labels: list[int],
    ):
        """Run inference prompted by click points.

        Args:
            image: PIL image.
            points: Normalized point coordinates ``[(x, y), ...]``; each is
                scaled by the image width/height before prediction.
            labels: One label per point; 1 marks a point to include,
                0 a point to exclude.

        Returns:
            Whatever the model's ``predict`` returns, or ``None`` when the
            loaded model is not a supported one.
        """
        if self.loaded is None:
            self.load_model()
        if self.loaded != 'sam2.1_l':
            return None

        width, height = image.size
        # Convert normalized coordinates to integer pixel coordinates.
        points_int = [
            (int(round(xf * width)), int(round(yf * height)))
            for xf, yf in points
        ]
        # NOTE(review): points/labels are forwarded as flat lists. Confirm
        # against the Ultralytics SAM docs whether this prompts a single
        # object or one object per point.
        return self.model.predict(image, points=points_int, labels=labels)

    def inference_box(self, image: "Image.Image", boxes: list[list[int]]):
        """Run inference prompted by bounding boxes.

        Args:
            image: PIL image.
            boxes: Boxes in pixel coordinates ``[[x1, y1, x2, y2], ...]``.

        Returns:
            Whatever the model's ``predict`` returns, or ``None`` when the
            loaded model is not a supported one.
        """
        if self.loaded is None:
            self.load_model()
        if self.loaded != 'sam2.1_l':
            return None
        return self.model.predict(image, bboxes=boxes)

    def decode_result(self, image: "Image.Image", result: Any, class_name: str):
        """Convert a prediction result into a list of annotation dicts.

        Args:
            image: PIL image the prediction was made on; its size is used to
                normalize the box corners.
            result: Value returned by ``inference_points`` / ``inference_box``.
                Only its first element is read.
            class_name: Class name stored in every returned annotation.

        Returns:
            One dict per detected mask with the keys ``mode``, ``class``,
            ``score``, ``points`` (box corners divided by image width and
            height), ``mask`` (``rle_encode`` output plus the first two
            dimensions of the mask array) and ``polygon``. The list is empty
            when the loaded model is not supported or the result carries no
            boxes or masks.
        """
        ret = []
        if self.loaded != 'sam2.1_l':
            return ret
        # The inference methods may return None, and a prediction without
        # detections may carry no boxes/masks; treat all of these as "nothing
        # found" instead of failing on attribute access.
        if result is None or len(result) == 0:
            return ret
        first = result[0]
        if first.boxes is None or first.masks is None:
            return ret

        width, height = image.size
        rows = first.boxes.data.tolist()
        # Each box row is unpacked as (x1, y1, x2, y2, score, <sixth value>);
        # the sixth value is not used.
        for (x1, y1, x2, y2, score, _class_id), mask, polygon in zip(
            rows, first.masks.data, first.masks.xyn
        ):
            data = mask.cpu().numpy()
            ret.append({
                "mode": "M",
                "class": class_name,
                "score": float(score),
                "points": [(x1 / width, y1 / height), (x2 / width, y2 / height)],
                "mask": {'counts': rle_encode(data), 'size': [data.shape[0], data.shape[1]]},
                "polygon": polygon.tolist()
            })
        return ret