|
| 1 | + |
| 2 | +from datatorch.api import ApiClient, Where, Label, Annotation, Segmentations, BoundingBox |
| 3 | + |
| 4 | +from pycocotools.coco import COCO |
| 5 | + |
| 6 | + |
| 7 | +if __name__ == "__main__": |
| 8 | + |
| 9 | + # DataTorch project ID |
| 10 | + project_id = 'DATATORCH_PROJECT_ID' |
| 11 | + # Path to annotation file |
| 12 | + anno_file = 'path/to/coco' |
| 13 | + # Only add annotations above this score |
| 14 | + min_source = 0.8 |
| 15 | + |
| 16 | + # Connect to DataTorch |
| 17 | + api = ApiClient() |
| 18 | + project = api.project(project_id) |
| 19 | + labels = project.labels() |
| 20 | + |
| 21 | + print('=' * 50) |
| 22 | + print(f'Project: {project.name}') |
| 23 | + names = [label.name for label in labels] |
| 24 | + print(f'Labels: {" ".join(names)}') |
| 25 | + print('=' * 50) |
| 26 | + |
| 27 | + |
| 28 | + def category_in_project(name: str) -> Label: |
| 29 | + """ Returns category in project """ |
| 30 | + for cat in labels: |
| 31 | + if cat.name.lower() == name.lower(): |
| 32 | + return cat |
| 33 | + return None |
| 34 | + |
| 35 | + |
| 36 | + # Load coco |
| 37 | + coco = COCO(anno_file) |
| 38 | + cats = coco.loadCats(coco.getCatIds()) |
| 39 | + names = [cat['name'] for cat in cats] |
| 40 | + |
| 41 | + label_maping = {} |
| 42 | + for cat in cats: |
| 43 | + name = cat['name'] |
| 44 | + found = category_in_project(name) |
| 45 | + |
| 46 | + if not found: |
| 47 | + print(f'label "{name}" not found in project') |
| 48 | + else: |
| 49 | + label_maping[cat['id']] = found |
| 50 | + |
| 51 | + |
| 52 | + print(f'COCO Categories: {" ".join(names)}') |
| 53 | + |
| 54 | + for img_id in coco.getImgIds(): |
| 55 | + img, = coco.loadImgs(img_id) |
| 56 | + name = img['file_name'] |
| 57 | + |
| 58 | + find_by_name = Where(name=name, mimetype__starts_with='image') |
| 59 | + dt_files = project.files(find_by_name) |
| 60 | + files_count = len(dt_files) |
| 61 | + |
| 62 | + if files_count > 1: |
| 63 | + print(f'\nMultiple files found with name {name}, skipping') |
| 64 | + continue |
| 65 | + |
| 66 | + if files_count == 0: |
| 67 | + print(f'\nNo files found with name {name}, skipping') |
| 68 | + continue |
| 69 | + |
| 70 | + print(f'\n{name} found. Importing annotations') |
| 71 | + dt_file = dt_files[0] |
| 72 | + |
| 73 | + # load file annotations |
| 74 | + anno_ids = coco.getAnnIds(imgIds=img['id']) |
| 75 | + annos = coco.loadAnns(anno_ids) |
| 76 | + |
| 77 | + for anno in annos: |
| 78 | + # Create annotation |
| 79 | + if anno.get('datatorch_id') is not None: |
| 80 | + print(f'Annotation {anno["id"]} already exists in DataTorch, skipping') |
| 81 | + |
| 82 | + source = anno.get('source') |
| 83 | + if anno.get('source') is not None and source < min_source: |
| 84 | + continue |
| 85 | + |
| 86 | + label = label_maping[anno['category_id']] |
| 87 | + if label is None: |
| 88 | + continue |
| 89 | + |
| 90 | + dt_anno = Annotation(label=label) |
| 91 | + bbox = BoundingBox.xywh(*anno['bbox']) |
| 92 | + dt_anno.add(bbox) |
| 93 | + print(bbox.__dict__) |
| 94 | + dt_file.add(dt_anno) |
| 95 | + |
| 96 | + |
| 97 | + |
| 98 | + |
0 commit comments