Skip to content

Commit 4af232a

Browse files
authored
Merge pull request #9 from aoxolotl/seg/fromMask
Create segmentations from binary mask
2 parents f4d375a + a31152b commit 4af232a

2 files changed

Lines changed: 71 additions & 5 deletions

File tree

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
import numpy as np
22

3-
from typing import List
3+
from typing import List, Optional
4+
from imantics import Polygons, Mask
5+
import shapely.ops
6+
from shapely import geometry
47

58
from .typings import Segment
69
from ..source import Source
10+
from ....scripts.utils.simplify import simplify_points
11+
from ....entity.annotation import Annotation
12+
from datatorch.api import ApiClient
713

814

915
__all__ = "Segmentations"
@@ -14,8 +20,66 @@ class Segmentations(Source):
1420
type: str = "PaperSegmentations"
1521
path_data: List[Segment]
1622

17-
def from_mask(self, mask: np.array):
18-
pass
23+
def __init__(self, path_data: Optional[List[Segment]] = None):
24+
super().__init__()
25+
self.path_data = path_data or []
26+
27+
def from_mask(self, mask: np.array, simplify: int = 0):
28+
# convert mask to polygons
29+
polygons = Mask(mask).polygons().points
30+
# TODO: handle shifted polygons in cropped case
31+
polygons = [polygon.tolist() for polygon in polygons]
32+
33+
# polygons -> path_data
34+
# simplify
35+
self.path_data = polygons
36+
if simplify:
37+
self.path_data = [
38+
simplify_points(polygon, tolerance=simplify, highestQuality=False)
39+
for polygon in polygons
40+
]
41+
42+
# filter polygons
43+
self.path_data = list(filter(lambda x: len(x) > 2, self.path_data))
44+
45+
def combine_segmentations(self, annotation):
46+
if len(self.path_data) == 0:
47+
raise ValueError("No path data to combine")
48+
49+
self.annotation_id = annotation.id
50+
existing_segmentation = next(
51+
x
52+
for x in annotation.get("sources")
53+
if x.get("type") == "PaperSegmentations"
54+
)
55+
poly_1 = [geometry.Polygon(points) for points in self.path_data]
56+
poly_2 = [geometry.Polygon(points) for points in existing_segmentation]
57+
58+
multi = shapely.ops.unary_union(poly_1 + poly_2)
59+
60+
if isinstance(multi, geometry.Polygon):
61+
self.path_data.append(list(multi.exterior.coords[:-1]))
62+
63+
if isinstance(multi, geometry.MultiPolygon):
64+
for polygon in multi:
65+
self.path_data.append(list(polygon.exterior.coords[:-1]))
66+
67+
print(
68+
f"Updated segmentation for annotation {annotation.id}",
69+
flush=True,
70+
)
71+
72+
def create_new_annotation(self, label_id: str, file_id: str):
73+
print("Creating new annotation")
74+
new_annotation = Annotation()
75+
new_annotation.label_id = label_id
76+
new_annotation.file_id = file_id
77+
new_annotation.create(ApiClient())
78+
annotation_id = new_annotation.id
79+
80+
self.annotation_id = annotation_id
81+
self.create(ApiClient())
82+
print("Segmentation created")
1983

2084
def to_mask(self) -> np.array:
2185
pass

setup.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
"gql==3.4.0",
1313
"websockets==10.4",
1414
"websocket-client",
15-
"requests==2.27.1",
15+
"requests==2.31.0",
1616
"typing_extensions>=4.1.0",
1717
"psutil~=5.9.4",
1818
"aiodocker~=0.19.0",
@@ -21,6 +21,8 @@
2121
"aiostream~=0.4.0",
2222
"markupsafe==2.0.1",
2323
"requests_toolbelt==0.10.1",
24+
"imantics==0.1.12",
25+
"shapely==2.0.1",
2426
"tqdm~=4.65.0",
2527
"urllib3==1.26.15",
2628
]
@@ -29,7 +31,7 @@
2931

3032
setup(
3133
name="datatorch",
32-
version="0.4.7.1",
34+
version="0.4.7.2",
3335
description="A CLI and library for interacting with DataTorch.",
3436
author="DataTorch",
3537
author_email="support@datatorch.io",

0 commit comments

Comments
 (0)