Commit 422d824

Parallelize onmt_build_vocab (#1897)

1 parent 9df8a17 commit 422d824

File tree

4 files changed: +183 -80

onmt/bin/build_vocab.py (+3 -3)

@@ -4,7 +4,7 @@
 from onmt.utils.misc import set_random_seed, check_path
 from onmt.utils.parse import ArgumentParser
 from onmt.opts import dynamic_prepare_opts
-from onmt.inputters.corpus import save_transformed_sample
+from onmt.inputters.corpus import build_vocab
 from onmt.transforms import make_transforms, get_transforms_cls


@@ -32,8 +32,8 @@ def build_vocab_main(opts):
     transforms = make_transforms(opts, transforms_cls, fields)

     logger.info(f"Counter vocab from {opts.n_sample} samples.")
-    src_counter, tgt_counter = save_transformed_sample(
-        opts, transforms, n_sample=opts.n_sample, build_vocab=True)
+    src_counter, tgt_counter = build_vocab(
+        opts, transforms, n_sample=opts.n_sample)

     logger.info(f"Counters src:{len(src_counter)}")
     logger.info(f"Counters tgt:{len(tgt_counter)}")

onmt/inputters/corpus.py (+168 -10)

@@ -3,10 +3,15 @@
 from onmt.utils.logging import logger
 from onmt.constants import CorpusName
 from onmt.transforms import TransformPipe
+from onmt.inputters.dataset_base import _dynamic_dict
+from torchtext.data import Dataset as TorchtextDataset, \
+    Example as TorchtextExample

 from collections import Counter
 from contextlib import contextmanager

+import multiprocessing as mp
+

 @contextmanager
 def exfile_open(filename, *args, **kwargs):
@@ -36,6 +41,69 @@ def exfile_open(filename, *args, **kwargs):
         _file.close()


+class DatasetAdapter(object):
+    """Adapt a bucket of tuples into examples of a torchtext Dataset."""
+
+    valid_field_name = (
+        'src', 'tgt', 'indices', 'src_map', 'src_ex_vocab', 'alignment',
+        'align')
+
+    def __init__(self, fields, is_train):
+        self.fields_dict = self._valid_fields(fields)
+        self.is_train = is_train
+
+    @classmethod
+    def _valid_fields(cls, fields):
+        """Return valid fields in dict format."""
+        return {
+            f_k: f_v for f_k, f_v in fields.items()
+            if f_k in cls.valid_field_name
+        }
+
+    @staticmethod
+    def _process(item, is_train):
+        """Return valid transformed example from `item`."""
+        example, transform, cid = item
+        # this is a hack: appears quicker to apply it here
+        # than in the ParallelCorpusIterator
+        maybe_example = transform.apply(
+            example, is_train=is_train, corpus_name=cid)
+        if maybe_example is None:
+            return None
+        maybe_example['src'] = ' '.join(maybe_example['src'])
+        maybe_example['tgt'] = ' '.join(maybe_example['tgt'])
+        if 'align' in maybe_example:
+            maybe_example['align'] = ' '.join(maybe_example['align'])
+        return maybe_example
+
+    def _maybe_add_dynamic_dict(self, example, fields):
+        """Maybe update `example` with dynamic_dict related fields."""
+        if 'src_map' in fields and 'alignment' in fields:
+            example = _dynamic_dict(
+                example,
+                fields['src'].base_field,
+                fields['tgt'].base_field)
+        return example
+
+    def _to_examples(self, bucket, is_train=False):
+        examples = []
+        for item in bucket:
+            maybe_example = self._process(item, is_train=is_train)
+            if maybe_example is not None:
+                example = self._maybe_add_dynamic_dict(
+                    maybe_example, self.fields_dict)
+                ex_fields = {k: [(k, v)] for k, v in self.fields_dict.items()
+                             if k in example}
+                ex = TorchtextExample.fromdict(example, ex_fields)
+                examples.append(ex)
+        return examples
+
+    def __call__(self, bucket):
+        examples = self._to_examples(bucket, is_train=self.is_train)
+        dataset = TorchtextDataset(examples, self.fields_dict)
+        return dataset
+
+
 class ParallelCorpus(object):
     """A parallel corpus file pair that can be loaded to iterate."""

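Each bucket item handled by _process is an (example, transform, corpus_id) tuple whose 'src'/'tgt' values are token lists until the adapter joins them into strings. A tiny, self-contained mimic of that step (the identity transform and corpus id below are illustrative stand-ins, not OpenNMT-py classes):

    class IdentityTransform:
        def apply(self, example, is_train=False, corpus_name=None):
            return example  # a real transform may tokenize, filter, or return None

    item = ({'src': ['a', 'b'], 'tgt': ['c']}, IdentityTransform(), 'corpus_1')
    example, transform, cid = item
    maybe_example = transform.apply(example, is_train=True, corpus_name=cid)
    if maybe_example is not None:
        maybe_example['src'] = ' '.join(maybe_example['src'])
        maybe_example['tgt'] = ' '.join(maybe_example['tgt'])
    print(maybe_example)  # {'src': 'a b', 'tgt': 'c'}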
@@ -192,7 +260,106 @@ def build_corpora_iters(corpora, transforms, corpora_info, is_train=False,
     return corpora_iters


-def save_transformed_sample(opts, transforms, n_sample=3, build_vocab=False):
+def write_files_from_queues(sample_path, queues):
+    """
+    Standalone process that reads data from
+    queues in order and writes to sample files.
+    """
+    os.makedirs(sample_path, exist_ok=True)
+    for c_name in queues.keys():
+        dest_base = os.path.join(
+            sample_path, "{}.{}".format(c_name, CorpusName.SAMPLE))
+        with open(dest_base + ".src", 'w', encoding="utf-8") as f_src,\
+                open(dest_base + ".tgt", 'w', encoding="utf-8") as f_tgt:
+            while True:
+                _next = False
+                for i, q in enumerate(queues[c_name]):
+                    item = q.get()
+                    if item == "break":
+                        _next = True
+                        break
+                    j, src_line, tgt_line = item
+                    f_src.write(src_line + '\n')
+                    f_tgt.write(tgt_line + '\n')
+                if _next:
+                    break


+def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
+    """Build vocab on a (strided) subpart of the data."""
+    sub_counter_src = Counter()
+    sub_counter_tgt = Counter()
+    datasets_iterables = build_corpora_iters(
+        corpora, transforms, opts.data, is_train=False,
+        skip_empty_level=opts.skip_empty_level,
+        stride=stride, offset=offset)
+    for c_name, c_iter in datasets_iterables.items():
+        for i, item in enumerate(c_iter):
+            maybe_example = DatasetAdapter._process(item, is_train=True)
+            if maybe_example is None:
+                continue
+            src_line, tgt_line = maybe_example['src'], maybe_example['tgt']
+            sub_counter_src.update(src_line.split(' '))
+            sub_counter_tgt.update(tgt_line.split(' '))
+            if opts.dump_samples:
+                build_sub_vocab.queues[c_name][offset].put(
+                    (i, src_line, tgt_line))
+            if n_sample > 0 and ((i+1) * stride + offset) >= n_sample:
+                if opts.dump_samples:
+                    build_sub_vocab.queues[c_name][offset].put("break")
+                break
+        if opts.dump_samples:
+            build_sub_vocab.queues[c_name][offset].put("break")
+    return sub_counter_src, sub_counter_tgt


+def init_pool(queues):
+    """Add the queues as an attribute of the pooled function."""
+    build_sub_vocab.queues = queues


+def build_vocab(opts, transforms, n_sample=3):
+    """Build vocabulary from data."""
+
+    if n_sample == -1:
+        logger.info(f"n_sample={n_sample}: Build vocab on full datasets.")
+    elif n_sample > 0:
+        logger.info(f"Build vocab on {n_sample} transformed examples/corpus.")
+    else:
+        raise ValueError(f"n_sample should be > 0 or == -1, got {n_sample}.")
+
+    if opts.dump_samples:
+        logger.info("The samples on which the vocab is built will be "
+                    "dumped to disk. It may slow down the process.")
+    corpora = get_corpora(opts, is_train=True)
+    counter_src = Counter()
+    counter_tgt = Counter()
+    from functools import partial
+    queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size)
+                       for i in range(opts.num_threads)]
+              for c_name in corpora.keys()}
+    sample_path = os.path.join(
+        os.path.dirname(opts.save_data), CorpusName.SAMPLE)
+    if opts.dump_samples:
+        write_process = mp.Process(
+            target=write_files_from_queues,
+            args=(sample_path, queues),
+            daemon=True)
+        write_process.start()
+    with mp.Pool(opts.num_threads, init_pool, [queues]) as p:
+        func = partial(
+            build_sub_vocab, corpora, transforms,
+            opts, n_sample, opts.num_threads)
+        for sub_counter_src, sub_counter_tgt in p.imap(
+                func, range(0, opts.num_threads)):
+            counter_src.update(sub_counter_src)
+            counter_tgt.update(sub_counter_tgt)
+    if opts.dump_samples:
+        write_process.join()
+    return counter_src, counter_tgt


+def save_transformed_sample(opts, transforms, n_sample=3):
     """Save transformed data sample as specified in opts."""

     if n_sample == -1:
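The heart of the parallelization is the stride/offset split: worker k of N only reads lines k, k+N, k+2N, ..., so the workers partition the corpus without coordination and their Counters merge associatively in the parent. A minimal, self-contained sketch of the same pattern (the file name and thread count are illustrative; this is not the commit's code):

    from collections import Counter
    from functools import partial
    from itertools import islice
    import multiprocessing as mp

    def count_sub_vocab(path, stride, offset):
        """Count tokens on lines offset, offset+stride, offset+2*stride, ..."""
        counter = Counter()
        with open(path, encoding="utf-8") as f:
            for line in islice(f, offset, None, stride):
                counter.update(line.split())
        return counter

    if __name__ == "__main__":
        num_threads = 4
        counter = Counter()
        with mp.Pool(num_threads) as pool:
            # bind path and stride; the pool supplies each worker's offset
            func = partial(count_sub_vocab, "corpus.txt", num_threads)
            for sub_counter in pool.imap(func, range(num_threads)):
                counter.update(sub_counter)
        print(counter.most_common(10))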
@@ -205,11 +372,7 @@ def save_transformed_sample(opts, transforms, n_sample=3, build_vocab=False):
     else:
         raise ValueError(f"n_sample should be >= -1, got {n_sample}.")

-    from onmt.inputters.dynamic_iterator import DatasetAdapter
     corpora = get_corpora(opts, is_train=True)
-    if build_vocab:
-        counter_src = Counter()
-        counter_tgt = Counter()
     datasets_iterables = build_corpora_iters(
         corpora, transforms, opts.data, is_train=False,
         skip_empty_level=opts.skip_empty_level)

@@ -226,12 +389,7 @@ def save_transformed_sample(opts, transforms, n_sample=3, build_vocab=False):
                 if maybe_example is None:
                     continue
                 src_line, tgt_line = maybe_example['src'], maybe_example['tgt']
-                if build_vocab:
-                    counter_src.update(src_line.split(' '))
-                    counter_tgt.update(tgt_line.split(' '))
                 f_src.write(src_line + '\n')
                 f_tgt.write(tgt_line + '\n')
                 if n_sample > 0 and i >= n_sample:
                     break
-    if build_vocab:
-        return counter_src, counter_tgt
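When -dump_samples is set, sample writing is decoupled from counting: each worker owns one queue per corpus, and a single writer process drains the queues round-robin so the dumped files keep the original line order; the "break" sentinel tells the writer a worker is done. A stripped-down, self-contained sketch of that protocol (the names and two-worker setup are illustrative, not the commit's code):

    import multiprocessing as mp

    def writer(queues, dest):
        """Drain worker queues round-robin until a 'break' sentinel arrives."""
        with open(dest, 'w', encoding="utf-8") as f:
            done = False
            while not done:
                for q in queues:
                    item = q.get()
                    if item == "break":
                        done = True
                        break
                    f.write(item + '\n')

    def worker(q, offset):
        for i in range(3):
            q.put("line {}-{}".format(offset, i))
        q.put("break")  # signal this worker is finished

    if __name__ == "__main__":
        queues = [mp.Queue(100) for _ in range(2)]
        w = mp.Process(target=writer, args=(queues, "sample.txt"), daemon=True)
        w.start()
        workers = [mp.Process(target=worker, args=(q, i))
                   for i, q in enumerate(queues)]
        for p in workers:
            p.start()
        for p in workers:
            p.join()
        w.join()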

onmt/inputters/dynamic_iterator.py (+3 -67)

@@ -1,77 +1,13 @@
 """Module that contain iterator used for dynamic data."""
 from itertools import cycle

-from torchtext.data import Dataset as TorchtextDataset, \
-    Example as TorchtextExample, batch as torchtext_batch
+from torchtext.data import batch as torchtext_batch
 from onmt.inputters import str2sortkey, max_tok_len, OrderedIterator
-from onmt.inputters.dataset_base import _dynamic_dict
-from onmt.inputters.corpus import get_corpora, build_corpora_iters
+from onmt.inputters.corpus import get_corpora, build_corpora_iters,\
+    DatasetAdapter
 from onmt.transforms import make_transforms


-class DatasetAdapter(object):
-    """Adapte a buckets of tuples into examples of a torchtext Dataset."""
-
-    valid_field_name = (
-        'src', 'tgt', 'indices', 'src_map', 'src_ex_vocab', 'alignment',
-        'align')
-
-    def __init__(self, fields, is_train):
-        self.fields_dict = self._valid_fields(fields)
-        self.is_train = is_train
-
-    @classmethod
-    def _valid_fields(cls, fields):
-        """Return valid fields in dict format."""
-        return {
-            f_k: f_v for f_k, f_v in fields.items()
-            if f_k in cls.valid_field_name
-        }
-
-    @staticmethod
-    def _process(item, is_train):
-        """Return valid transformed example from `item`."""
-        example, transform, cid = item
-        # this is a hack: appears quicker to apply it here
-        # than in the ParallelCorpusIterator
-        maybe_example = transform.apply(
-            example, is_train=is_train, corpus_name=cid)
-        if maybe_example is None:
-            return None
-        maybe_example['src'] = ' '.join(maybe_example['src'])
-        maybe_example['tgt'] = ' '.join(maybe_example['tgt'])
-        if 'align' in maybe_example:
-            maybe_example['align'] = ' '.join(maybe_example['align'])
-        return maybe_example
-
-    def _maybe_add_dynamic_dict(self, example, fields):
-        """maybe update `example` with dynamic_dict related fields."""
-        if 'src_map' in fields and 'alignment' in fields:
-            example = _dynamic_dict(
-                example,
-                fields['src'].base_field,
-                fields['tgt'].base_field)
-        return example
-
-    def _to_examples(self, bucket, is_train=False):
-        examples = []
-        for item in bucket:
-            maybe_example = self._process(item, is_train=is_train)
-            if maybe_example is not None:
-                example = self._maybe_add_dynamic_dict(
-                    maybe_example, self.fields_dict)
-                ex_fields = {k: [(k, v)] for k, v in self.fields_dict.items()
-                             if k in example}
-                ex = TorchtextExample.fromdict(example, ex_fields)
-                examples.append(ex)
-        return examples
-
-    def __call__(self, bucket):
-        examples = self._to_examples(bucket, is_train=self.is_train)
-        dataset = TorchtextDataset(examples, self.fields_dict)
-        return dataset
-
-
 class MixingStrategy(object):
     """Mixing strategy that should be used in Data Iterator."""

onmt/opts.py (+9 -0)

@@ -99,6 +99,15 @@ def _add_dynamic_corpus_opts(parser, build_vocab_only=False):
         group.add('-dump_transforms', '--dump_transforms', action='store_true',
                   help="Dump transforms `*.transforms.pt` to disk."
                        " -save_data should be set as saving prefix.")
+    else:
+        group.add('-dump_samples', '--dump_samples', action='store_true',
+                  help="Dump samples when building vocab. "
+                       "Warning: this may slow down the process.")
+        group.add('-num_threads', '--num_threads', type=int, default=1,
+                  help="Number of parallel threads to build the vocab.")
+        group.add('-vocab_sample_queue_size', '--vocab_sample_queue_size',
+                  type=int, default=100,
+                  help="Size of queues used in the build_vocab dump path.")


 def _add_dynamic_fields_opts(parser, build_vocab_only=False):
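Together these flags expose the parallel path on the command line; a hypothetical invocation (the config file name is illustrative) would be:

    onmt_build_vocab -config data.yaml -n_sample -1 -num_threads 4 -dump_samples

-num_threads sets the size of the multiprocessing pool, while -vocab_sample_queue_size only matters on the -dump_samples path, where the bounded queues keep counting from running arbitrarily far ahead of the writer process.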
