1
- from typing import List , Iterable , Iterator
2
1
import glob
3
2
import logging
3
+ import os
4
+ from queue import Empty
5
+ from typing import List , Iterable , Iterator , Optional
4
6
5
7
import numpy as np
6
- from torch .multiprocessing import Manager , Process , Queue , log_to_stderr
8
+ from torch .multiprocessing import Process , Queue , Value , log_to_stderr
7
9
8
10
from allennlp .data .dataset_readers .dataset_reader import DatasetReader
9
11
from allennlp .data .instance import Instance
10
12
13
+
11
14
class logger :
12
15
"""
13
16
multiprocessing.log_to_stderr causes some output in the logs
@@ -30,26 +33,133 @@ def info(cls, message: str) -> None:
30
33
def _worker (reader : DatasetReader ,
31
34
input_queue : Queue ,
32
35
output_queue : Queue ,
33
- index : int ) -> None :
36
+ num_active_workers : Value ,
37
+ num_inflight_items : Value ,
38
+ worker_id : int ) -> None :
34
39
"""
35
40
A worker that pulls filenames off the input queue, uses the dataset reader
36
- to read them, and places the generated instances on the output queue.
37
- When there are no filenames left on the input queue, it puts its ``index``
38
- on the output queue and doesn't do anything else .
41
+ to read them, and places the generated instances on the output queue. When
42
+ there are no filenames left on the input queue, it decrements
43
+ num_active_workers to signal completion .
39
44
"""
45
+ logger .info (f"Reader worker: { worker_id } PID: { os .getpid ()} " )
40
46
# Keep going until you get a file_path that's None.
41
47
while True :
42
48
file_path = input_queue .get ()
43
49
if file_path is None :
44
- # Put my index on the queue to signify that I'm finished
45
- output_queue .put (index )
50
+ # It's important that we close and join the queue here before
51
+ # decrementing num_active_workers. Otherwise our parent may join us
52
+ # before the queue's feeder thread has passed all buffered items to
53
+ # the underlying pipe resulting in a deadlock.
54
+ #
55
+ # See:
56
+ # https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#pipes-and-queues
57
+ # https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#programming-guidelines
58
+ output_queue .close ()
59
+ output_queue .join_thread ()
60
+ # Decrementing is not atomic.
61
+ # See https://docs.python.org/2/library/multiprocessing.html#multiprocessing.Value.
62
+ with num_active_workers .get_lock ():
63
+ num_active_workers .value -= 1
64
+ logger .info (f"Reader worker { worker_id } finished" )
46
65
break
47
66
48
67
logger .info (f"reading instances from { file_path } " )
49
68
for instance in reader .read (file_path ):
69
+ with num_inflight_items .get_lock ():
70
+ num_inflight_items .value += 1
50
71
output_queue .put (instance )
51
72
52
73
74
+ class QIterable (Iterable [Instance ]):
75
+ """
76
+ You can't set attributes on Iterators, so this is just a dumb wrapper
77
+ that exposes the output_queue.
78
+ """
79
+ def __init__ (self , output_queue_size , epochs_per_read , num_workers , reader , file_path ) -> None :
80
+ self .output_queue = Queue (output_queue_size )
81
+ self .epochs_per_read = epochs_per_read
82
+ self .num_workers = num_workers
83
+ self .reader = reader
84
+ self .file_path = file_path
85
+
86
+ # Initialized in start.
87
+ self .input_queue : Optional [Queue ] = None
88
+ self .processes : List [Process ] = []
89
+ # The num_active_workers and num_inflight_items counts in conjunction
90
+ # determine whether there could be any outstanding instances.
91
+ self .num_active_workers : Optional [Value ] = None
92
+ self .num_inflight_items : Optional [Value ] = None
93
+
94
+ def __iter__ (self ) -> Iterator [Instance ]:
95
+ self .start ()
96
+
97
+ # Keep going as long as not all the workers have finished or there are items in flight.
98
+ while self .num_active_workers .value > 0 or self .num_inflight_items .value > 0 :
99
+ # Inner loop to minimize locking on self.num_active_workers.
100
+ while True :
101
+ try :
102
+ # Non-blocking to handle the empty-queue case.
103
+ yield self .output_queue .get (block = False , timeout = 1.0 )
104
+ with self .num_inflight_items .get_lock ():
105
+ self .num_inflight_items .value -= 1
106
+ except Empty :
107
+ # The queue could be empty because the workers are
108
+ # all finished or because they're busy processing.
109
+ # The outer loop distinguishes between these two
110
+ # cases.
111
+ break
112
+
113
+ self .join ()
114
+
115
+ def start (self ) -> None :
116
+ shards = glob .glob (self .file_path )
117
+ # Ensure a consistent order before shuffling for testing.
118
+ shards .sort ()
119
+ num_shards = len (shards )
120
+
121
+ # If we want multiple epochs per read, put shards in the queue multiple times.
122
+ self .input_queue = Queue (num_shards * self .epochs_per_read + self .num_workers )
123
+ for _ in range (self .epochs_per_read ):
124
+ np .random .shuffle (shards )
125
+ for shard in shards :
126
+ self .input_queue .put (shard )
127
+
128
+ # Then put a None per worker to signify no more files.
129
+ for _ in range (self .num_workers ):
130
+ self .input_queue .put (None )
131
+
132
+
133
+ assert not self .processes , "Process list non-empty! You must call QIterable.join() before restarting."
134
+ self .num_active_workers = Value ('i' , self .num_workers )
135
+ self .num_inflight_items = Value ('i' , 0 )
136
+ for worker_id in range (self .num_workers ):
137
+ process = Process (target = _worker ,
138
+ args = (self .reader , self .input_queue , self .output_queue ,
139
+ self .num_active_workers , self .num_inflight_items , worker_id ))
140
+ logger .info (f"starting worker { worker_id } " )
141
+ process .start ()
142
+ self .processes .append (process )
143
+
144
+ def join (self ) -> None :
145
+ for process in self .processes :
146
+ process .join ()
147
+ self .processes .clear ()
148
+
149
+ def __del__ (self ) -> None :
150
+ """
151
+ Terminate processes if the user hasn't joined. This is necessary as
152
+ leaving stray processes running can corrupt shared state. In brief,
153
+ we've observed shared memory counters being reused (when the memory was
154
+ free from the perspective of the parent process) while the stray
155
+ workers still held a reference to them.
156
+
157
+ For a discussion of using destructors in Python in this manner, see
158
+ https://eli.thegreenplace.net/2009/06/12/safely-using-destructors-in-python/.
159
+ """
160
+ for process in self .processes :
161
+ process .terminate ()
162
+
53
163
54
164
@DatasetReader .register ('multiprocess' )
55
165
class MultiprocessDatasetReader (DatasetReader ):
@@ -103,70 +213,10 @@ def _read(self, file_path: str) -> Iterable[Instance]:
103
213
raise RuntimeError ("Multiprocess reader implements read() directly." )
104
214
105
215
def read (self , file_path : str ) -> Iterable [Instance ]:
106
- outer_self = self
107
-
108
- class QIterable (Iterable [Instance ]):
109
- """
110
- You can't set attributes on Iterators, so this is just a dumb wrapper
111
- that exposes the output_queue. Currently you probably shouldn't touch
112
- the output queue, but this is done with an eye toward implementing
113
- a data iterator that can read directly from the queue (instead of having
114
- to use the _instances iterator we define here.)
115
- """
116
- def __init__ (self ) -> None :
117
- self .manager = Manager ()
118
- self .output_queue = self .manager .Queue (outer_self .output_queue_size )
119
- self .num_workers = outer_self .num_workers
120
-
121
- def __iter__ (self ) -> Iterator [Instance ]:
122
- # pylint: disable=protected-access
123
- return outer_self ._instances (file_path , self .manager , self .output_queue )
124
-
125
- return QIterable ()
126
-
127
- def _instances (self , file_path : str , manager : Manager , output_queue : Queue ) -> Iterator [Instance ]:
128
- """
129
- A generator that reads instances off the output queue and yields them up
130
- until none are left (signified by all ``num_workers`` workers putting their
131
- ids into the queue).
132
- """
133
- shards = glob .glob (file_path )
134
- # Ensure a consistent order before shuffling for testing.
135
- shards .sort ()
136
- num_shards = len (shards )
137
-
138
- # If we want multiple epochs per read, put shards in the queue multiple times.
139
- input_queue = manager .Queue (num_shards * self .epochs_per_read + self .num_workers )
140
- for _ in range (self .epochs_per_read ):
141
- np .random .shuffle (shards )
142
- for shard in shards :
143
- input_queue .put (shard )
144
-
145
- # Then put a None per worker to signify no more files.
146
- for _ in range (self .num_workers ):
147
- input_queue .put (None )
148
-
149
- processes : List [Process ] = []
150
- num_finished = 0
151
-
152
- for worker_id in range (self .num_workers ):
153
- process = Process (target = _worker ,
154
- args = (self .reader , input_queue , output_queue , worker_id ))
155
- logger .info (f"starting worker { worker_id } " )
156
- process .start ()
157
- processes .append (process )
158
-
159
- # Keep going as long as not all the workers have finished.
160
- while num_finished < self .num_workers :
161
- item = output_queue .get ()
162
- if isinstance (item , int ):
163
- # Means a worker has finished, so increment the finished count.
164
- num_finished += 1
165
- logger .info (f"worker { item } finished ({ num_finished } /{ self .num_workers } )" )
166
- else :
167
- # Otherwise it's an ``Instance``, so yield it up.
168
- yield item
169
-
170
- for process in processes :
171
- process .join ()
172
- processes .clear ()
216
+ return QIterable (
217
+ output_queue_size = self .output_queue_size ,
218
+ epochs_per_read = self .epochs_per_read ,
219
+ num_workers = self .num_workers ,
220
+ reader = self .reader ,
221
+ file_path = file_path
222
+ )
0 commit comments