 from onmt.utils.logging import logger
 from onmt.constants import CorpusName
 from onmt.transforms import TransformPipe
+from onmt.inputters.dataset_base import _dynamic_dict
+from torchtext.data import Dataset as TorchtextDataset, \
+    Example as TorchtextExample

 from collections import Counter
 from contextlib import contextmanager

+import multiprocessing as mp
+

 @contextmanager
 def exfile_open(filename, *args, **kwargs):
@@ -36,6 +41,69 @@ def exfile_open(filename, *args, **kwargs):
         _file.close()


+class DatasetAdapter(object):
+    """Adapt a bucket of tuples into examples of a torchtext Dataset."""
+
+    valid_field_name = (
+        'src', 'tgt', 'indices', 'src_map', 'src_ex_vocab', 'alignment',
+        'align')
+
+    def __init__(self, fields, is_train):
+        self.fields_dict = self._valid_fields(fields)
+        self.is_train = is_train
+
+    @classmethod
+    def _valid_fields(cls, fields):
+        """Return valid fields in dict format."""
+        return {
+            f_k: f_v for f_k, f_v in fields.items()
+            if f_k in cls.valid_field_name
+        }
+
+    @staticmethod
+    def _process(item, is_train):
+        """Return valid transformed example from `item`."""
+        example, transform, cid = item
+        # this is a hack: appears quicker to apply it here
+        # than in the ParallelCorpusIterator
+        maybe_example = transform.apply(
+            example, is_train=is_train, corpus_name=cid)
+        if maybe_example is None:
+            return None
+        maybe_example['src'] = ' '.join(maybe_example['src'])
+        maybe_example['tgt'] = ' '.join(maybe_example['tgt'])
+        if 'align' in maybe_example:
+            maybe_example['align'] = ' '.join(maybe_example['align'])
+        return maybe_example
+
+    def _maybe_add_dynamic_dict(self, example, fields):
+        """Maybe update `example` with dynamic_dict related fields."""
+        if 'src_map' in fields and 'alignment' in fields:
+            example = _dynamic_dict(
+                example,
+                fields['src'].base_field,
+                fields['tgt'].base_field)
+        return example
+
+    def _to_examples(self, bucket, is_train=False):
+        examples = []
+        for item in bucket:
+            maybe_example = self._process(item, is_train=is_train)
+            if maybe_example is not None:
+                example = self._maybe_add_dynamic_dict(
+                    maybe_example, self.fields_dict)
+                ex_fields = {k: [(k, v)] for k, v in self.fields_dict.items()
+                             if k in example}
+                ex = TorchtextExample.fromdict(example, ex_fields)
+                examples.append(ex)
+        return examples
+
+    def __call__(self, bucket):
+        examples = self._to_examples(bucket, is_train=self.is_train)
+        dataset = TorchtextDataset(examples, self.fields_dict)
+        return dataset
+
+
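+# Usage sketch (illustrative; the shapes of `fields` and `bucket` are
+# assumed from the code above, not part of this commit): a bucket is a
+# list of (example, transform, corpus_id) tuples, and calling the adapter
+# turns it into a torchtext Dataset ready for batching:
+#
+#     adapter = DatasetAdapter(fields, is_train=True)
+#     dataset = adapter(bucket)
+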
 class ParallelCorpus(object):
     """A parallel corpus file pair that can be loaded to iterate."""

@@ -192,7 +260,106 @@ def build_corpora_iters(corpora, transforms, corpora_info, is_train=False,
     return corpora_iters


-def save_transformed_sample(opts, transforms, n_sample=3, build_vocab=False):
+def write_files_from_queues(sample_path, queues):
+    """
+    Standalone process that reads data from
+    queues in order and writes to sample files.
+    """
+    os.makedirs(sample_path, exist_ok=True)
+    for c_name in queues.keys():
+        dest_base = os.path.join(
+            sample_path, "{}.{}".format(c_name, CorpusName.SAMPLE))
+        with open(dest_base + ".src", 'w', encoding="utf-8") as f_src,\
+                open(dest_base + ".tgt", 'w', encoding="utf-8") as f_tgt:
+            while True:
+                _next = False
+                for i, q in enumerate(queues[c_name]):
+                    item = q.get()
+                    if item == "break":
+                        _next = True
+                        break
+                    j, src_line, tgt_line = item
+                    f_src.write(src_line + '\n')
+                    f_tgt.write(tgt_line + '\n')
+                if _next:
+                    break
+
+
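+# Queue protocol sketch (as read from the code above and in build_sub_vocab
+# below): each worker puts (line_index, src_line, tgt_line) tuples on its
+# per-corpus queue, then a final "break" sentinel; the writer drains the
+# workers' queues round-robin so the dumped sample files preserve the
+# original corpus line order:
+#
+#     q.put((0, "a source line", "a target line"))
+#     q.put("break")  # this worker is done with this corpus
+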
+def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
+    """Build vocab on (strided) subpart of the data."""
+    sub_counter_src = Counter()
+    sub_counter_tgt = Counter()
+    datasets_iterables = build_corpora_iters(
+        corpora, transforms, opts.data, is_train=False,
+        skip_empty_level=opts.skip_empty_level,
+        stride=stride, offset=offset)
+    for c_name, c_iter in datasets_iterables.items():
+        for i, item in enumerate(c_iter):
+            maybe_example = DatasetAdapter._process(item, is_train=True)
+            if maybe_example is None:
+                continue
+            src_line, tgt_line = maybe_example['src'], maybe_example['tgt']
+            sub_counter_src.update(src_line.split(' '))
+            sub_counter_tgt.update(tgt_line.split(' '))
+            if opts.dump_samples:
+                build_sub_vocab.queues[c_name][offset].put(
+                    (i, src_line, tgt_line))
+            if n_sample > 0 and ((i + 1) * stride + offset) >= n_sample:
+                if opts.dump_samples:
+                    build_sub_vocab.queues[c_name][offset].put("break")
+                break
+        if opts.dump_samples:
+            build_sub_vocab.queues[c_name][offset].put("break")
+    return sub_counter_src, sub_counter_tgt
+
+
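+# Striding sketch (illustrative, assuming build_corpora_iters yields every
+# `stride`-th line starting at `offset`): with num_threads=4, the worker at
+# offset 1 sees lines 1, 5, 9, ... so the disjoint slices cover each corpus
+# line exactly once and the merged sub-counters give the full counts.
+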
+def init_pool(queues):
+    """Add the queues as an attribute of the pooled function."""
+    build_sub_vocab.queues = queues
+
+
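+# Note on the function attribute (explanatory comment, not original code):
+# mp.Queue objects cannot be pickled into Pool task arguments, but they can
+# be handed to workers at pool-creation time through the initializer, which
+# stores them on build_sub_vocab so each worker can reach its own queue.
+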
+def build_vocab(opts, transforms, n_sample=3):
+    """Build vocabulary from data."""
+
+    if n_sample == -1:
+        logger.info(f"n_sample={n_sample}: Build vocab on full datasets.")
+    elif n_sample > 0:
+        logger.info(f"Build vocab on {n_sample} transformed examples/corpus.")
+    else:
+        raise ValueError(f"n_sample should be > 0 or == -1, got {n_sample}.")
+
+    if opts.dump_samples:
+        logger.info("The samples on which the vocab is built will be "
+                    "dumped to disk. It may slow down the process.")
+    corpora = get_corpora(opts, is_train=True)
+    counter_src = Counter()
+    counter_tgt = Counter()
+    from functools import partial
+    queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size)
+                       for i in range(opts.num_threads)]
+              for c_name in corpora.keys()}
+    sample_path = os.path.join(
+        os.path.dirname(opts.save_data), CorpusName.SAMPLE)
+    if opts.dump_samples:
+        write_process = mp.Process(
+            target=write_files_from_queues,
+            args=(sample_path, queues),
+            daemon=True)
+        write_process.start()
+    with mp.Pool(opts.num_threads, init_pool, [queues]) as p:
+        func = partial(
+            build_sub_vocab, corpora, transforms,
+            opts, n_sample, opts.num_threads)
+        for sub_counter_src, sub_counter_tgt in p.imap(
+                func, range(0, opts.num_threads)):
+            counter_src.update(sub_counter_src)
+            counter_tgt.update(sub_counter_tgt)
+    if opts.dump_samples:
+        write_process.join()
+    return counter_src, counter_tgt
+
+
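+# Call-site sketch (hypothetical; mirrors how the returned counters appear
+# to be meant to be consumed):
+#
+#     counter_src, counter_tgt = build_vocab(opts, transforms, n_sample=-1)
+#     # each Counter maps token -> frequency for the src / tgt vocabulary
+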
+def save_transformed_sample(opts, transforms, n_sample=3):
     """Save transformed data sample as specified in opts."""

     if n_sample == -1:
@@ -205,11 +372,7 @@ def save_transformed_sample(opts, transforms, n_sample=3, build_vocab=False):
     else:
         raise ValueError(f"n_sample should be >= -1, got {n_sample}.")

-    from onmt.inputters.dynamic_iterator import DatasetAdapter
     corpora = get_corpora(opts, is_train=True)
-    if build_vocab:
-        counter_src = Counter()
-        counter_tgt = Counter()
     datasets_iterables = build_corpora_iters(
         corpora, transforms, opts.data, is_train=False,
         skip_empty_level=opts.skip_empty_level)
@@ -226,12 +389,7 @@ def save_transformed_sample(opts, transforms, n_sample=3, build_vocab=False):
             if maybe_example is None:
                 continue
             src_line, tgt_line = maybe_example['src'], maybe_example['tgt']
-            if build_vocab:
-                counter_src.update(src_line.split(' '))
-                counter_tgt.update(tgt_line.split(' '))
             f_src.write(src_line + '\n')
             f_tgt.write(tgt_line + '\n')
             if n_sample > 0 and i >= n_sample:
                 break
-    if build_vocab:
-        return counter_src, counter_tgt