1
+
2
+ # Copyright (C) 2019 Intel Corporation
3
+ #
4
+ # SPDX-License-Identifier: MIT
5
+
6
+ from collections import OrderedDict
7
+ import os .path as osp
8
+ import re
9
+
10
+ from datumaro .components .extractor import (Extractor , DatasetItem ,
11
+ AnnotationType , BboxObject , LabelCategories
12
+ )
13
+ from datumaro .components .formats .yolo import YoloPath
14
+ from datumaro .util .image import lazy_image
15
+
16
+
17
class YoloExtractor(Extractor):
    """Extractor for a dataset stored in the YOLO layout.

    The dataset is described by a text descriptor file made of
    ``key = value`` lines. The ``names`` key points to the label list
    file; every key found in ``YoloPath.SUBSET_NAMES`` points to a file
    listing the image paths of that subset, one per line. Annotations
    for an image are read from a ``.txt`` file with the same base path
    as the image.

    Images and annotations are loaded lazily, on first access to an item.
    """

    class Subset(Extractor):
        """A named subset whose items are produced by the parent extractor."""

        def __init__(self, name, parent):
            super().__init__()
            self._name = name
            self._parent = parent
            # item id -> image path (str) until the item is first requested,
            # then the DatasetItem cached by YoloExtractor._get()
            self.items = OrderedDict()

        def __iter__(self):
            # _get() only reassigns values of existing keys, so the dict
            # size does not change while it is being iterated
            for item_id in self.items:
                yield self._parent._get(item_id, self._name)

        def __len__(self):
            return len(self.items)

        def categories(self):
            return self._parent.categories()

    def __init__(self, config_path):
        """Parse the descriptor file and index the subsets it declares.

        Args:
            config_path: path to the dataset descriptor file. Paths inside
                the descriptor are resolved against its directory.

        Raises:
            Exception: if the descriptor is not a file, has no usable
                ``names`` entry, a subset list file is missing, or a
                listed image file is missing.
        """
        super().__init__()

        if not osp.isfile(config_path):
            raise Exception("Can't read dataset descriptor file '%s'" %
                config_path)

        # must be set before any _make_local_path() call
        self._path = osp.dirname(config_path)

        names_path, subset_lists = self._parse_config(config_path)
        if not names_path:
            raise Exception("Failed to parse labels path from '%s'" %
                config_path)

        subsets = OrderedDict()
        for subset_name, list_path in subset_lists.items():
            subsets[subset_name] = self._load_subset(subset_name, list_path)
        self._subsets = subsets

        self._categories = {
            AnnotationType.label:
                self._load_categories(self._make_local_path(names_path))
        }

    @staticmethod
    def _parse_config(config_path):
        """Read ``key = value`` lines from the descriptor file.

        Lines that do not match the pattern and keys other than ``names``
        or the ones in ``YoloPath.SUBSET_NAMES`` are ignored.

        Returns:
            A ``(names_path, subset_lists)`` tuple: the value of the
            ``names`` key (``None`` if absent) and an ordered mapping of
            subset name to the path of its image list file.
        """
        names_path = None
        subset_lists = OrderedDict()

        with open(config_path, 'r') as f:
            for line in f:
                match = re.match(r'(\w+)\s*=\s*(.+)$', line)
                if not match:
                    continue

                key = match.group(1)
                # '.' also matches '\r' and trailing blanks, which must not
                # end up in a file path
                value = match.group(2).strip()
                if key == 'names':
                    names_path = value
                elif key in YoloPath.SUBSET_NAMES:
                    subset_lists[key] = value

        return names_path, subset_lists

    def _load_subset(self, subset_name, list_path):
        """Build a Subset from an image list file and check the images exist.

        Item ids are the image file names without directory and extension.
        Blank lines in the list file are skipped.

        Raises:
            Exception: if the list file or any listed image is not a file.
        """
        list_path = self._make_local_path(list_path)
        if not osp.isfile(list_path):
            raise Exception("Not found '%s' subset list file" % subset_name)

        with open(list_path, 'r') as f:
            image_paths = [line.strip() for line in f]

        subset = YoloExtractor.Subset(subset_name, self)
        subset.items = OrderedDict(
            (osp.splitext(osp.basename(p))[0], p) for p in image_paths if p)

        for image_path in subset.items.values():
            image_path = self._make_local_path(image_path)
            if not osp.isfile(image_path):
                raise Exception("Can't find image '%s'" % image_path)

        return subset

    def _make_local_path(self, path):
        """Resolve a path from the dataset files against the dataset root.

        A leading ``data`` directory component is dropped first. An absolute
        path is returned unchanged, because ``os.path.join`` discards the
        preceding components in that case.
        """
        default_base = osp.join('data', '')
        if path.startswith(default_base): # default path
            path = path[len(default_base) : ]
        return osp.join(self._path, path) # relative or absolute path

    def _get(self, item_id, subset_name):
        """Return the DatasetItem for the id, building and caching it on
        first access.

        Raises:
            KeyError: if the subset or the item id is unknown.
        """
        subset = self._subsets[subset_name]
        item = subset.items[item_id]

        # a str value is a not-yet-loaded image path
        if isinstance(item, str):
            image_path = self._make_local_path(item)
            image = lazy_image(image_path)
            # take only the first two dimensions, so that arrays without
            # a channel axis are accepted too
            h, w = image().shape[:2]
            anno_path = osp.splitext(image_path)[0] + '.txt'
            annotations = self._parse_annotations(anno_path, w, h)

            item = DatasetItem(id=item_id, subset=subset_name,
                image=image, annotations=annotations)
            subset.items[item_id] = item

        return item

    @staticmethod
    def _parse_annotations(anno_path, image_width, image_height):
        """Read bounding boxes from an annotation file.

        Every non-blank line holds five whitespace-separated fields:
        ``label_id xc yc w h``. ``xc``/``yc`` are the box center and
        ``w``/``h`` its size; all four are scaled by the image size here,
        and the center is converted to the top-left corner.

        Returns:
            A list of BboxObject, in file order.

        Raises:
            ValueError: if a non-blank line does not have exactly five
                fields or a field is not a number.
        """
        annotations = []
        with open(anno_path, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue # tolerate blank lines

                label_id, xc, yc, w, h = fields
                w = float(w)
                h = float(h)
                x = float(xc) - w * 0.5
                y = float(yc) - h * 0.5
                annotations.append(BboxObject(
                    x * image_width, y * image_height,
                    w * image_width, h * image_height,
                    label=int(label_id)
                ))
        return annotations

    @staticmethod
    def _load_categories(names_path):
        """Load label names from a file with one name per line.

        Names are stripped of surrounding whitespace (including the line
        terminator) and blank lines are skipped.
        """
        label_categories = LabelCategories()

        with open(names_path, 'r') as f:
            for label in f:
                label = label.strip()
                if label:
                    label_categories.add(label)

        return label_categories

    def categories(self):
        return self._categories

    def __iter__(self):
        for subset in self._subsets.values():
            for item in subset:
                yield item

    def __len__(self):
        return sum(len(subset) for subset in self._subsets.values())

    def subsets(self):
        """Return the names of the loaded subsets, in descriptor order."""
        return list(self._subsets)

    def get_subset(self, name):
        """Return the Subset with the given name; raises KeyError if absent."""
        return self._subsets[name]
0 commit comments