# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
from __future__ import absolute_import
import os
import platform
import sys
import tempfile
from abc import ABCMeta
from abc import abstractmethod
from six import with_metaclass
from six.moves.urllib.parse import urlparse
import sagemaker.amazon.common
import sagemaker.local.utils
import sagemaker.utils
def get_data_source_instance(data_source, sagemaker_session):
"""Return an Instance of :class:`sagemaker.local.data.DataSource`.
The instance can handle the provided data_source URI.
data_source can be either file:// or s3://
Args:
data_source (str): a valid URI that points to a data source.
sagemaker_session (:class:`sagemaker.session.Session`): a SageMaker Session to
interact with S3 if required.
Returns:
sagemaker.local.data.DataSource: an Instance of a Data Source
Raises:
ValueError: If parsed_uri scheme is neither `file` nor `s3` , raise an
error.
"""
parsed_uri = urlparse(data_source)
if parsed_uri.scheme == "file":
return LocalFileDataSource(parsed_uri.netloc + parsed_uri.path)
if parsed_uri.scheme == "s3":
return S3DataSource(parsed_uri.netloc, parsed_uri.path, sagemaker_session)
raise ValueError(
"data_source must be either file or s3. parsed_uri.scheme: {}".format(parsed_uri.scheme)
)
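
# A minimal usage sketch (not part of the original module; the path and the
# `session` name are hypothetical). For `file://` URIs the session argument
# is not actually used, as the code above shows:
#
#     source = get_data_source_instance("file:///tmp/training", session)
#     print(source.get_root_dir())
#     for path in source.get_file_list():
#         print(path)
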
def get_splitter_instance(split_type):
"""Return an Instance of :class:`sagemaker.local.data.Splitter`.
The instance returned is according to the specified `split_type`.
Args:
split_type (str): either 'Line' or 'RecordIO'. Can be left as None to
signal no data split will happen.
Returns
:class:`sagemaker.local.data.Splitter`: an Instance of a Splitter
"""
if split_type == "None" or split_type is None:
return NoneSplitter()
if split_type == "Line":
return LineSplitter()
if split_type == "RecordIO":
return RecordIOSplitter()
raise ValueError("Invalid Split Type: %s" % split_type)
def get_batch_strategy_instance(strategy, splitter):
"""Return an Instance of :class:`sagemaker.local.data.BatchStrategy` according to `strategy`
Args:
strategy (str): Either 'SingleRecord' or 'MultiRecord'
splitter (:class:`sagemaker.local.data.Splitter): splitter to get the data from.
Returns
:class:`sagemaker.local.data.BatchStrategy`: an Instance of a BatchStrategy
"""
if strategy == "SingleRecord":
return SingleRecordStrategy(splitter)
if strategy == "MultiRecord":
return MultiRecordStrategy(splitter)
    raise ValueError(
        'Invalid Batch Strategy: %s - Valid Strategies: "SingleRecord", "MultiRecord"' % strategy
    )
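
# Usage sketch combining both factories (hypothetical path and `send`
# callback); `size` is the maximum payload size in MB:
#
#     splitter = get_splitter_instance("Line")
#     strategy = get_batch_strategy_instance("MultiRecord", splitter)
#     for batch in strategy.pad("/tmp/data.csv", size=6):
#         send(batch)
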
class DataSource(with_metaclass(ABCMeta, object)):
"""Placeholder docstring"""
@abstractmethod
def get_file_list(self):
"""Retrieve the list of absolute paths to all the files in this data source.
Returns:
List[str]: List of absolute paths.
"""
@abstractmethod
def get_root_dir(self):
"""Retrieve the absolute path to the root directory of this data source.
Returns:
str: absolute path to the root directory of this data source.
"""
class LocalFileDataSource(DataSource):
"""Represents a data source within the local filesystem."""
def __init__(self, root_path):
super(LocalFileDataSource, self).__init__()
self.root_path = os.path.abspath(root_path)
if not os.path.exists(self.root_path):
raise RuntimeError("Invalid data source: %s does not exist." % self.root_path)
def get_file_list(self):
"""Retrieve the list of absolute paths to all the files in this data source.
Returns:
        List[str]: List of absolute paths.
"""
if os.path.isdir(self.root_path):
return [
os.path.join(self.root_path, f)
for f in os.listdir(self.root_path)
if os.path.isfile(os.path.join(self.root_path, f))
]
return [self.root_path]
def get_root_dir(self):
"""Retrieve the absolute path to the root directory of this data source.
Returns:
str: absolute path to the root directory of this data source.
"""
if os.path.isdir(self.root_path):
return self.root_path
return os.path.dirname(self.root_path)
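
# Usage sketch (hypothetical path). Note that get_file_list() is not
# recursive: for a directory it returns only the top-level files, and for a
# single file it returns just that file:
#
#     ds = LocalFileDataSource("/tmp/data")
#     root = ds.get_root_dir()      # /tmp/data itself, or its parent if a file
#     files = ds.get_file_list()
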
class S3DataSource(DataSource):
"""Defines a data source given by a bucket and S3 prefix.
The contents will be downloaded and then processed as local data.
"""
def __init__(self, bucket, prefix, sagemaker_session):
"""Create an S3DataSource instance.
Args:
bucket (str): S3 bucket name
prefix (str): S3 prefix path to the data
            sagemaker_session (:class:`sagemaker.session.Session`): a SageMaker Session
                configured with the desired settings to talk to S3.
"""
super(S3DataSource, self).__init__()
# Create a temporary dir to store the S3 contents
root_dir = sagemaker.utils.get_config_value(
"local.container_root", sagemaker_session.config
)
if root_dir:
root_dir = os.path.abspath(root_dir)
working_dir = tempfile.mkdtemp(dir=root_dir)
# Docker cannot mount Mac OS /var folder properly see
# https://forums.docker.com/t/var-folders-isnt-mounted-properly/9600
# Only apply this workaround if the user didn't provide an alternate storage root dir.
if root_dir is None and platform.system() == "Darwin":
working_dir = "/private{}".format(working_dir)
sagemaker.utils.download_folder(bucket, prefix, working_dir, sagemaker_session)
self.files = LocalFileDataSource(working_dir)
def get_file_list(self):
"""Retrieve the list of absolute paths to all the files in this data source.
Returns:
List[str]: List of absolute paths.
"""
return self.files.get_file_list()
def get_root_dir(self):
"""Retrieve the absolute path to the root directory of this data source.
Returns:
str: absolute path to the root directory of this data source.
"""
return self.files.get_root_dir()
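
# Usage sketch (hypothetical bucket/prefix; assumes valid AWS credentials and
# a configured sagemaker.session.Session named `session`). Construction
# eagerly downloads the prefix into a temporary directory, after which the
# object behaves like a LocalFileDataSource:
#
#     s3_source = S3DataSource("my-bucket", "path/to/data", session)
#     local_dir = s3_source.get_root_dir()   # temp dir with the downloaded files
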
class Splitter(with_metaclass(ABCMeta, object)):
"""Placeholder docstring"""
@abstractmethod
def split(self, file):
"""Split a file into records using a specific strategy
Args:
file (str): path to the file to split
Returns:
generator for the individual records that were split from the file
"""
class NoneSplitter(Splitter):
"""Does not split records, essentially reads the whole file."""
    # Byte values considered text; bytes outside this set mark a payload as binary.
_textchars = bytearray({7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F})
def split(self, filename):
"""Split a file into records using a specific strategy.
        For this NoneSplitter no actual split happens; the file is returned
        as a whole.
Args:
filename (str): path to the file to split
Returns: generator for the individual records that were split from
the file
"""
with open(filename, "rb") as f:
buf = f.read()
if not self._is_binary(buf):
buf = buf.decode()
yield buf
def _is_binary(self, buf):
"""Check whether `buf` contains binary data.
        Returns True if `buf` contains any bytes outside the `_textchars` set of text characters.
Args:
buf (bytes): data to inspect
Returns:
True if data is binary, otherwise False
"""
return bool(buf.translate(None, self._textchars))
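
# Behavior sketch (hypothetical path): the whole file is yielded as a single
# record, decoded to str when it looks like text, left as bytes otherwise:
#
#     records = list(NoneSplitter().split("/tmp/payload.json"))
#     assert len(records) == 1
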
class LineSplitter(Splitter):
"""Split records by new line."""
def split(self, file):
"""Split a file into records using a specific strategy
This LineSplitter splits the file on each line break.
Args:
file (str): path to the file to split
Returns: generator for the individual records that were split from
the file
"""
with open(file, "r") as f:
for line in f:
yield line
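
# Behavior sketch (hypothetical path): each yielded record keeps its trailing
# newline, matching Python's file iteration:
#
#     lines = list(LineSplitter().split("/tmp/data.csv"))
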
class RecordIOSplitter(Splitter):
"""Split using Amazon Recordio.
Not useful for string content.
"""
def split(self, file):
"""Split a file into records using a specific strategy
This RecordIOSplitter splits the data into individual RecordIO
records.
Args:
file (str): path to the file to split
Returns: generator for the individual records that were split from
the file
"""
with open(file, "rb") as f:
for record in sagemaker.amazon.common.read_recordio(f):
yield record
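
# Behavior sketch (hypothetical path): each record is yielded as bytes, so
# this splitter pairs with binary payloads rather than text:
#
#     for rec in RecordIOSplitter().split("/tmp/data.recordio"):
#         ...  # `rec` is one RecordIO-encoded record
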
class BatchStrategy(with_metaclass(ABCMeta, object)):
"""Placeholder docstring"""
def __init__(self, splitter):
"""Create a Batch Strategy Instance
Args:
splitter (sagemaker.local.data.Splitter): A Splitter to pre-process
the data before batching.
"""
self.splitter = splitter
@abstractmethod
def pad(self, file, size):
"""Group together as many records as possible to fit in the specified size.
Args:
file (str): file path to read the records from.
            size (int): maximum size in MB that each group of records will be
                fitted to. Passing 0 means unlimited size.
Returns:
generator of records
"""
class MultiRecordStrategy(BatchStrategy):
"""Feed multiple records at a time for batch inference.
    Groups as many records as possible within the specified payload size.
"""
def pad(self, file, size=6):
"""Group together as many records as possible to fit in the specified size.
Args:
file (str): file path to read the records from.
            size (int): maximum size in MB that each group of records will be
                fitted to. Passing 0 means unlimited size.
Returns:
generator of records
"""
buffer = ""
for element in self.splitter.split(file):
if _payload_size_within_limit(buffer + element, size):
buffer += element
else:
tmp = buffer
buffer = element
yield tmp
if _validate_payload_size(buffer, size):
yield buffer
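
# Behavior sketch (hypothetical path): records are accumulated until adding
# one more would exceed `size` MB; the buffer is then yielded and restarted.
# With a LineSplitter the batches are newline-delimited chunks:
#
#     strategy = MultiRecordStrategy(LineSplitter())
#     for batch in strategy.pad("/tmp/data.csv", size=6):
#         ...  # each `batch` holds as many records as fit in 6 MB
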
class SingleRecordStrategy(BatchStrategy):
"""Feed a single record at a time for batch inference.
    If a single record does not fit within the specified payload size, a
    RuntimeError is raised.
"""
def pad(self, file, size=6):
"""Group together as many records as possible to fit in the specified size.
        This SingleRecordStrategy does not group records; it returns them one
        by one as long as each is within the maximum size.
Args:
file (str): file path to read the records from.
            size (int): maximum size in MB that each group of records will be
                fitted to. Passing 0 means unlimited size.
Returns:
generator of records
"""
for element in self.splitter.split(file):
if _validate_payload_size(element, size):
yield element
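
# Behavior sketch (hypothetical path): records are yielded one at a time, and
# any record over the limit raises RuntimeError rather than being skipped:
#
#     strategy = SingleRecordStrategy(LineSplitter())
#     for record in strategy.pad("/tmp/data.csv", size=6):
#         ...
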
def _payload_size_within_limit(payload, size):
"""Placeholder docstring."""
size_in_bytes = size * 1024 * 1024
if size == 0:
return True
return sys.getsizeof(payload) < size_in_bytes
def _validate_payload_size(payload, size):
"""Check if a payload is within the size in MB threshold.
Raise an exception if the payload is beyond the size in MB threshold.
Args:
payload: data that will be checked
size (int): max size in MB
Returns:
bool: True if within bounds. if size=0 it will always return True
Raises:
RuntimeError: If the payload is larger a runtime error is thrown.
"""
if _payload_size_within_limit(payload, size):
return True
raise RuntimeError("Record is larger than %sMB. Please increase your max_payload" % size)