Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit bbaf1fc

Browse files
authored
Benchmark iterator, avoid redundant queue, remove managers. (#3119)
- Adds a script to benchmark iterators.
  - Average speed
  - Introspects queues
- Removes a bottleneck when `MultiprocessDatasetReader` and `MultiprocessIterator` are used in conjunction.
  - Specifically, removes a redundant queue that was populated by a single process.
- Removes managers, which have significant overhead.
- Results on training_config/bidirectional_language_model.jsonnet:
  - original code, no multiprocessing: 0.047 s/b over 10000 batches
  - original code, workers = 1: 0.073 s/b over 10000 batches
  - original code, workers = 10: 0.078 s/b over 10000 batches
  - this PR (-queue), workers = 1: 0.073 s/b over 10000 batches
  - this PR (-queue), workers = 10: 0.046 s/b over 10000 batches
  - this PR (-queue, -manager), workers = 1: 0.063 s/b over 10000 batches
  - this PR (-queue, -manager), workers = 10: 0.020 s/b over 10000 batches
- Notably, we previously saw no benefit from scaling to multiple workers. Now we do, although the scaling is worse than linear. More work is required there.
- Related issues: #2962, #1890
1 parent 78ee3d8 commit bbaf1fc

File tree

7 files changed

+473
-106
lines changed

7 files changed

+473
-106
lines changed
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,16 @@
1-
from typing import List, Iterable, Iterator
21
import glob
32
import logging
3+
import os
4+
from queue import Empty
5+
from typing import List, Iterable, Iterator, Optional
46

57
import numpy as np
6-
from torch.multiprocessing import Manager, Process, Queue, log_to_stderr
8+
from torch.multiprocessing import Process, Queue, Value, log_to_stderr
79

810
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
911
from allennlp.data.instance import Instance
1012

13+
1114
class logger:
1215
"""
1316
multiprocessing.log_to_stderr causes some output in the logs
@@ -30,26 +33,133 @@ def info(cls, message: str) -> None:
3033
def _worker(reader: DatasetReader,
            input_queue: Queue,
            output_queue: Queue,
            num_active_workers: Value,
            num_inflight_items: Value,
            worker_id: int) -> None:
    """
    Body of a reader process.

    Takes file paths off ``input_queue`` one at a time, reads each with
    ``reader`` and puts every instance it yields on ``output_queue``,
    incrementing ``num_inflight_items`` immediately before each put.  A
    ``None`` on the input queue means there is nothing left to read: the
    worker then flushes its output queue, decrements ``num_active_workers``
    to signal completion, and returns.
    """
    logger.info(f"Reader worker: {worker_id} PID: {os.getpid()}")

    while True:
        shard_path = input_queue.get()

        if shard_path is not None:
            logger.info(f"reading instances from {shard_path}")
            for item in reader.read(shard_path):
                # The counter is shared between processes and ``+=`` on it is
                # not atomic, so the update happens under the Value's lock.
                with num_inflight_items.get_lock():
                    num_inflight_items.value += 1
                output_queue.put(item)
            continue

        # Got the ``None`` sentinel: shut down.
        #
        # The queue must be closed and its feeder thread joined *before*
        # num_active_workers is decremented.  Otherwise the parent may join
        # this process while buffered items have not yet been handed to the
        # underlying pipe, which results in a deadlock.
        #
        # See:
        # https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#pipes-and-queues
        # https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#programming-guidelines
        output_queue.close()
        output_queue.join_thread()

        # Decrementing is not atomic either.
        # See https://docs.python.org/2/library/multiprocessing.html#multiprocessing.Value.
        with num_active_workers.get_lock():
            num_active_workers.value -= 1
        logger.info(f"Reader worker {worker_id} finished")
        return
5172

5273

74+
class QIterable(Iterable[Instance]):
    """
    An iterable over the instances produced by a pool of ``_worker``
    processes.

    You can't set attributes on Iterators, so this is a thin wrapper that
    exposes ``output_queue`` (together with the ``num_active_workers`` and
    ``num_inflight_items`` counters) to callers.

    Constructing the object only creates the output queue.  Worker processes
    are launched by ``start()``, which ``__iter__`` calls for you, and are
    reaped by ``join()`` once iteration runs to completion.

    ``output_queue_size`` is the ``maxsize`` of the output queue,
    ``epochs_per_read`` is how many times every shard is enqueued per
    ``start()``, ``num_workers`` is the number of reader processes, ``reader``
    is the dataset reader each worker uses, and ``file_path`` is a glob
    pattern matching the shards to read.
    """
    # Longest time, in seconds, that ``__iter__`` sleeps waiting for the next
    # instance before it re-checks whether any work is still outstanding.
    _POLL_TIMEOUT_SECONDS = 1.0

    def __init__(self,
                 output_queue_size: int,
                 epochs_per_read: int,
                 num_workers: int,
                 reader: DatasetReader,
                 file_path: str) -> None:
        self.output_queue = Queue(output_queue_size)
        self.epochs_per_read = epochs_per_read
        self.num_workers = num_workers
        self.reader = reader
        self.file_path = file_path

        # Initialized in start.
        self.input_queue: Optional[Queue] = None
        self.processes: List[Process] = []
        # The num_active_workers and num_inflight_items counts in conjunction
        # determine whether there could be any outstanding instances.
        self.num_active_workers: Optional[Value] = None
        self.num_inflight_items: Optional[Value] = None

    def __iter__(self) -> Iterator[Instance]:
        """
        Start the workers and yield instances until every worker has finished
        and every instance counted as in flight has been consumed, then join
        the worker processes.
        """
        self.start()

        while True:
            try:
                # Fast path: take whatever is immediately available.  The
                # shared counters are only consulted when this comes up empty,
                # which keeps locking on them to a minimum.
                instance = self.output_queue.get(block=False)
            except Empty:
                # The queue could be empty because the workers are all
                # finished or because they're busy processing.  The counters
                # distinguish between these two cases.
                if not self._has_outstanding_work():
                    break
                # More instances are expected.  Sleep in a *blocking* get
                # instead of spinning: ``timeout`` is ignored when
                # ``block=False``, so a non-blocking get here would busy-wait
                # while contending with the workers for the counter locks.
                try:
                    instance = self.output_queue.get(timeout=self._POLL_TIMEOUT_SECONDS)
                except Empty:
                    continue

            # Account for the instance before handing it out, so the count is
            # already accurate while the consumer holds the generator
            # suspended at the ``yield``.
            with self.num_inflight_items.get_lock():
                self.num_inflight_items.value -= 1
            yield instance

        self.join()

    def _has_outstanding_work(self) -> bool:
        """
        Return True while any worker is still running or any instance counted
        as in flight has not been consumed yet.
        """
        return self.num_active_workers.value > 0 or self.num_inflight_items.value > 0

    def start(self) -> None:
        """
        Fill a fresh input queue with the shards matching ``file_path`` and
        launch ``num_workers`` worker processes.

        Fails with an ``AssertionError`` if processes from a previous
        ``start()`` have not been joined.
        """
        # Check the precondition first, before any state is replaced.
        assert not self.processes, "Process list non-empty! You must call QIterable.join() before restarting."

        # Ensure a consistent order before shuffling for testing.
        shards = sorted(glob.glob(self.file_path))
        num_shards = len(shards)

        # If we want multiple epochs per read, put shards in the queue multiple times.
        self.input_queue = Queue(num_shards * self.epochs_per_read + self.num_workers)
        for _ in range(self.epochs_per_read):
            np.random.shuffle(shards)
            for shard in shards:
                self.input_queue.put(shard)

        # Then put a None per worker to signify no more files.
        for _ in range(self.num_workers):
            self.input_queue.put(None)

        self.num_active_workers = Value('i', self.num_workers)
        self.num_inflight_items = Value('i', 0)
        for worker_id in range(self.num_workers):
            process = Process(target=_worker,
                              args=(self.reader, self.input_queue, self.output_queue,
                                    self.num_active_workers, self.num_inflight_items, worker_id))
            logger.info(f"starting worker {worker_id}")
            process.start()
            self.processes.append(process)

    def join(self) -> None:
        """
        Wait for every worker process to exit and forget them, after which
        ``start()`` may be called again.
        """
        for process in self.processes:
            process.join()
        self.processes.clear()

    def __del__(self) -> None:
        """
        Terminate processes if the user hasn't joined. This is necessary as
        leaving stray processes running can corrupt shared state. In brief,
        we've observed shared memory counters being reused (when the memory was
        free from the perspective of the parent process) while the stray
        workers still held a reference to them.

        For a discussion of using destructors in Python in this manner, see
        https://eli.thegreenplace.net/2009/06/12/safely-using-destructors-in-python/.
        """
        # ``processes`` is missing if ``__init__`` raised before assigning it;
        # the destructor still runs in that case and must not fail.
        for process in getattr(self, "processes", ()):
            process.terminate()
162+
53163

54164
@DatasetReader.register('multiprocess')
55165
class MultiprocessDatasetReader(DatasetReader):
@@ -103,70 +213,10 @@ def _read(self, file_path: str) -> Iterable[Instance]:
103213
raise RuntimeError("Multiprocess reader implements read() directly.")
104214

105215
def read(self, file_path: str) -> Iterable[Instance]:
    """
    Return a ``QIterable`` over the instances in the shards matching
    ``file_path``, configured with this reader's queue size, epochs per
    read, worker count and wrapped reader.

    No worker process is started here; that happens when the returned
    object is iterated.
    """
    # Arguments are given in QIterable's declared parameter order:
    # output_queue_size, epochs_per_read, num_workers, reader, file_path.
    return QIterable(self.output_queue_size,
                     self.epochs_per_read,
                     self.num_workers,
                     self.reader,
                     file_path)

0 commit comments

Comments
 (0)