Skip to content

Commit 7d708b4

Browse files
Merge pull request cvat-ai#66 from openvinotoolkit/ay/transform-labels
Add function to transform labels
2 parents 0e48bb8 + 4702754 commit 7d708b4

File tree

3 files changed

+43
-1
lines changed

3 files changed

+43
-1
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99
## [Unreleased]
1010
### Added
1111
- `WiderFace` dataset format (<https://github.com/openvinotoolkit/datumaro/pull/65>)
12+
- Function to transform annotations to labels (<https://github.com/openvinotoolkit/datumaro/pull/66>)
1213

1314
### Changed
1415
-

datumaro/plugins/transforms.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pycocotools.mask as mask_utils
1313

1414
from datumaro.components.extractor import (Transform, AnnotationType,
15-
RleMask, Polygon, Bbox, DEFAULT_SUBSET_NAME,
15+
RleMask, Polygon, Bbox, Label, DEFAULT_SUBSET_NAME,
1616
LabelCategories, MaskCategories, PointsCategories
1717
)
1818
from datumaro.components.cli_plugin import CliPlugin
@@ -541,4 +541,19 @@ def transform_item(self, item):
541541
annotations.append(ann.wrap(label=conv_label))
542542
else:
543543
annotations.append(ann.wrap())
544+
return item.wrap(annotations=annotations)
545+
546+
class AnnsToLabels(Transform, CliPlugin):
547+
"""
548+
Collects all labels from annotations (of all types) and
549+
transforms them into a set of annotations of type Label
550+
"""
551+
552+
def transform_item(self, item):
553+
labels = set(p.label for p in item.annotations
554+
if getattr(p, 'label') != None)
555+
annotations = []
556+
for label in labels:
557+
annotations.append(Label(label=label))
558+
544559
return item.wrap(annotations=annotations)

tests/test_transforms.py

+26
Original file line numberDiff line numberDiff line change
@@ -386,3 +386,29 @@ def test_remap_labels_delete_unspecified(self):
386386
mapping={}, default='delete')
387387

388388
compare_datasets(self, target_dataset, actual)
389+
390+
def test_transform_labels(self):
391+
src_dataset = Dataset.from_iterable([
392+
DatasetItem(id=1, annotations=[
393+
Label(1),
394+
Bbox(1, 2, 3, 4, label=2),
395+
Bbox(1, 3, 3, 3),
396+
Mask(image=np.array([1]), label=3),
397+
Polygon([1, 1, 2, 2, 3, 4], label=4),
398+
PolyLine([1, 3, 4, 2, 5, 6], label=5)
399+
])
400+
], categories=['label%s' % i for i in range(6)])
401+
402+
dst_dataset = Dataset.from_iterable([
403+
DatasetItem(id=1, annotations=[
404+
Label(1),
405+
Label(2),
406+
Label(3),
407+
Label(4),
408+
Label(5)
409+
]),
410+
], categories=['label%s' % i for i in range(6)])
411+
412+
actual = transforms.AnnsToLabels(src_dataset)
413+
414+
compare_datasets(self, dst_dataset, actual)

0 commit comments

Comments
 (0)