Skip to content

Commit 7295ebc

Browse files
committed
chore: added import coco example
1 parent 92b0f18 commit 7295ebc

7 files changed

Lines changed: 155 additions & 14 deletions

File tree

datatorch/api/entity/annotation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ def label(self, label: Label) -> None:
6262
self.label_id = label.id
6363

6464
def create(self, client=None):
65-
super().save(client=client)
65+
super().create(client=client)
6666

67-
assert self.label_id is not None
68-
assert self.file_id is not None
67+
assert self.label_id is not None, "Annotation must have a file ID"
68+
assert self.file_id is not None, "Annotation must have label ID"
6969

7070
params = {
7171
"id": self.id,
@@ -75,9 +75,9 @@ def create(self, client=None):
7575
"color": self.color,
7676
}
7777
results = self.client.execute(_CREATE_ANNOTATION, params=params)
78-
7978
r_anno = results.get("annotation")
8079
self.__dict__.update(camel_to_snake(r_anno))
8180

8281
for source in self.sources:
82+
source.annotation_id = self.id
8383
source.create(client=self.client)

datatorch/api/entity/base.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,10 @@ def to_json(self, indent: int = 2) -> str:
7272
return json.dumps(self.dict(), indent=indent)
7373

7474
def create(self, client=None):
75-
if self.id is not None:
76-
ValueError("Entity already has an ID.")
75+
assert self.id is None
7776
if client:
7877
self.client = client
79-
if self.client is None:
80-
ValueError("Entity does not have a client.")
78+
assert self.client is not None
8179

8280
def save(self, client=None):
8381
assert self.id is not None

datatorch/api/entity/dataset.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,25 @@
11
from .base import BaseEntity
22

33

4+
_CREATE_DATASET = """
5+
mutation CreateDataset(
6+
$projectId: ID!
7+
$name: String!
8+
description: String
9+
) {
10+
dataset: createDataset(
11+
input: {
12+
projectId: $projectId
13+
name: $name
14+
description: $description
15+
}
16+
) {
17+
id
18+
}
19+
}
20+
"""
21+
22+
423
class Dataset(BaseEntity):
524

625
id: str
@@ -11,3 +30,18 @@ class Dataset(BaseEntity):
1130
formatted_bytes: int
1231
created_at: str
1332
updated_at: str
33+
34+
def create(self, client=None):
35+
super().create(client=client)
36+
37+
assert self.project_id is not None
38+
results = self.execute(
39+
_CREATE_DATASET,
40+
params={
41+
"projectId": self.project_id,
42+
"name": self.name,
43+
"description": self.description,
44+
},
45+
)
46+
47+
self.id = results.get("dataset").get("id")

datatorch/api/entity/file.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,7 @@ def add(self, anno: Annotation) -> None:
124124
anno (:obj:`Annotation`): annotation to be added
125125
"""
126126
self.annotations.append(anno)
127-
128-
if self.id is None:
127+
if self.id is not None:
129128
anno.file_id = self.id
130129
anno.create(client=self.client)
131130

@@ -146,5 +145,5 @@ def annotator(self):
146145
def to_json(self, indent: int = 2) -> str:
147146
dic = self.__dict__.copy()
148147
dic.pop("client")
149-
dic["annotations"] = [anno.__dict__ for anno in dic["annotations"]]
148+
dic["annotations"] = [anno.to_json() for anno in dic["annotations"]]
150149
return json.dumps(dic, indent=indent)

datatorch/api/entity/sources/image/bounding_box.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class BoundingBox(Source):
1414
height: float
1515

1616
@classmethod
17-
def create(cls, x, y, width, height):
17+
def xywh(cls, x, y, width, height):
1818
return cls(dict(x=x, y=y, width=width, height=height))
1919

2020
@property

datatorch/api/entity/sources/source.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
data: $data
1717
) {
1818
id
19-
annotationId
2019
}
2120
}
2221
"""
@@ -39,5 +38,18 @@ def data(self):
3938
del obj["annotation_id"]
4039
return dict([(snake_to_camel(k), v) for k, v in obj.items()])
4140

42-
def save(self, client=None):
41+
def create(self, client=None):
4342
super().create(client=client)
43+
44+
assert self.type is not None, "Source must have a type"
45+
results = self.client.execute(
46+
_CREATE_SOURCE,
47+
params={
48+
"id": self.id,
49+
"annotationId": self.annotation_id,
50+
"type": self.type,
51+
"data": self.data(),
52+
},
53+
)
54+
r_source = results.get("source")
55+
self.id = r_source.get("id")

examples/import-coco.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
2+
from datatorch.api import ApiClient, Where, Label, Annotation, Segmentations, BoundingBox
3+
4+
from pycocotools.coco import COCO
5+
6+
7+
if __name__ == "__main__":
8+
9+
# DataTorch project ID
10+
project_id = 'DATATORCH_PROJECT_ID'
11+
# Path to annotation file
12+
anno_file = 'path/to/coco'
13+
# Only add annotations above this score
14+
min_source = 0.8
15+
16+
# Connect to DataTorch
17+
api = ApiClient()
18+
project = api.project(project_id)
19+
labels = project.labels()
20+
21+
print('=' * 50)
22+
print(f'Project: {project.name}')
23+
names = [label.name for label in labels]
24+
print(f'Labels: {" ".join(names)}')
25+
print('=' * 50)
26+
27+
28+
def category_in_project(name: str) -> Label:
29+
""" Returns category in project """
30+
for cat in labels:
31+
if cat.name.lower() == name.lower():
32+
return cat
33+
return None
34+
35+
36+
# Load coco
37+
coco = COCO(anno_file)
38+
cats = coco.loadCats(coco.getCatIds())
39+
names = [cat['name'] for cat in cats]
40+
41+
label_maping = {}
42+
for cat in cats:
43+
name = cat['name']
44+
found = category_in_project(name)
45+
46+
if not found:
47+
print(f'label "{name}" not found in project')
48+
else:
49+
label_maping[cat['id']] = found
50+
51+
52+
print(f'COCO Categories: {" ".join(names)}')
53+
54+
for img_id in coco.getImgIds():
55+
img, = coco.loadImgs(img_id)
56+
name = img['file_name']
57+
58+
find_by_name = Where(name=name, mimetype__starts_with='image')
59+
dt_files = project.files(find_by_name)
60+
files_count = len(dt_files)
61+
62+
if files_count > 1:
63+
print(f'\nMultiple files found with name {name}, skipping')
64+
continue
65+
66+
if files_count == 0:
67+
print(f'\nNo files found with name {name}, skipping')
68+
continue
69+
70+
print(f'\n{name} found. Importing annotations')
71+
dt_file = dt_files[0]
72+
73+
# load file annotations
74+
anno_ids = coco.getAnnIds(imgIds=img['id'])
75+
annos = coco.loadAnns(anno_ids)
76+
77+
for anno in annos:
78+
# Create annotation
79+
if anno.get('datatorch_id') is not None:
80+
print(f'Annotation {anno["id"]} already exists in DataTorch, skipping')
81+
82+
source = anno.get('source')
83+
if anno.get('source') is not None and source < min_source:
84+
continue
85+
86+
label = label_maping[anno['category_id']]
87+
if label is None:
88+
continue
89+
90+
dt_anno = Annotation(label=label)
91+
bbox = BoundingBox.xywh(*anno['bbox'])
92+
dt_anno.add(bbox)
93+
print(bbox.__dict__)
94+
dt_file.add(dt_anno)
95+
96+
97+
98+

0 commit comments

Comments
 (0)