This repository was archived by the owner on Dec 16, 2022. It is now read-only.

improve signal handling and worker cleanup #5378

Merged 5 commits on Aug 28, 2021
Changes from 3 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased

### Added

- Added more documentation to the learning rate schedulers, including a sample config object showing how to use each one.
- Moved the PyTorch learning rate scheduler wrappers to their own file, `pytorch_lr_schedulers.py`, so that they will have their own documentation page.
- Added a module `allennlp.nn.parallel` with a new base class, `DdpAccelerator`, which generalizes
@@ -44,6 +45,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `ConfigurationError` is now pickleable.
- Multitask models now support `TextFieldTensor` in heads, not just in the backbone.
- Fixed the signature of `ScaledDotProductAttention` to match the other `Attention` classes.
- `allennlp` commands will now catch `SIGTERM` signals and handle them similarly to `SIGINT` (keyboard interrupt).
- The `MultiProcessDataLoader` will properly shut down its workers when a `SIGTERM` is received.
- Fixed the way names are applied to Tango `Step` instances.

### Changed
4 changes: 4 additions & 0 deletions allennlp/__main__.py
@@ -38,6 +38,10 @@ def run():
)

from allennlp.commands import main # noqa
from allennlp.common.util import install_sigterm_handler

# We want to be able to catch SIGTERM signals in addition to SIGINT (keyboard interrupt).
install_sigterm_handler()

main(prog="allennlp")

2 changes: 1 addition & 1 deletion allennlp/commands/train.py
@@ -502,7 +502,7 @@ def _train_worker(
dist.barrier()

metrics = train_loop.run()
except KeyboardInterrupt:
except (KeyboardInterrupt, common_util.SigTermReceived):
# if we have completed an epoch, try to create a model archive.
if primary:
best_weights_path = train_loop.trainer.get_best_weights_path()
6 changes: 4 additions & 2 deletions allennlp/common/logging.py
@@ -113,10 +113,12 @@ def prepare_global_logging(
root_logger.addHandler(stdout_handler)
root_logger.addHandler(stderr_handler)

from allennlp.common.util import SigTermReceived

# write uncaught exceptions to the logs
def excepthook(exctype, value, traceback):
# For a KeyboardInterrupt, call the original exception handler.
if issubclass(exctype, KeyboardInterrupt):
# For interruptions, call the original exception handler.
if issubclass(exctype, (KeyboardInterrupt, SigTermReceived)):
sys.__excepthook__(exctype, value, traceback)
return
root_logger.critical("Uncaught exception", exc_info=(exctype, value, traceback))
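`sys.excepthook` is the hook Python invokes for any uncaught exception. The line that installs this hook is folded away in the diff, but the standard pattern (and presumably what `prepare_global_logging` does) is a plain assignment. A runnable sketch of the dispatch logic, with `print` standing in for the root logger and a local stand-in for `SigTermReceived`:

```python
import sys


class SigTermReceived(Exception):  # stand-in for allennlp.common.util.SigTermReceived
    pass


def excepthook(exctype, value, traceback):
    # Deliberate interruptions go to the default handler...
    if issubclass(exctype, (KeyboardInterrupt, SigTermReceived)):
        sys.__excepthook__(exctype, value, traceback)
        return
    # ...everything else is logged as a crash.
    print("CRITICAL: uncaught exception:", exctype.__name__, value)


sys.excepthook = excepthook

raise SigTermReceived  # dispatched to the default hook, not logged as a crash
```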
13 changes: 13 additions & 0 deletions allennlp/common/util.py
@@ -11,6 +11,7 @@
import pkgutil
import random
import sys
import signal
from contextlib import contextmanager
from itertools import islice, zip_longest
from pathlib import Path
@@ -729,3 +730,15 @@ def hash_object(o: Any) -> str:
dill.dump(o, buffer)
m.update(buffer.getbuffer())
return base58.b58encode(m.digest()).decode()


class SigTermReceived(Exception):
pass


def _handle_sigterm(sig, frame):
raise SigTermReceived


def install_sigterm_handler():
signal.signal(signal.SIGTERM, _handle_sigterm)
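Taken together, these three additions turn SIGTERM into an ordinary Python exception that `allennlp` commands can catch. A minimal sketch of the resulting behavior (POSIX only; `os.kill` stands in for an external `kill <pid>`):

```python
import os
import signal

from allennlp.common.util import SigTermReceived, install_sigterm_handler

install_sigterm_handler()

try:
    # Simulate an external `kill <pid>` by signaling ourselves.
    os.kill(os.getpid(), signal.SIGTERM)
except SigTermReceived:
    # SIGTERM now surfaces as an exception, so cleanup code
    # (archiving a model, joining workers) gets a chance to run.
    print("caught SIGTERM, cleaning up")
```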
127 changes: 102 additions & 25 deletions allennlp/data/data_loaders/multiprocess_data_loader.py
@@ -1,9 +1,12 @@
from collections import deque
import logging
from multiprocessing.process import BaseProcess
from multiprocessing.connection import Connection
import random
import traceback
from typing import List, Iterator, Optional, Iterable, Union, TypeVar
import select
from queue import Full
from typing import List, Iterator, Optional, Iterable, Union, TypeVar, Tuple, Any

from overrides import overrides
import torch
@@ -374,7 +377,7 @@ def iter_instances(self) -> Iterator[Instance]:
if self._max_instance_queue_size is None
else ctx.JoinableQueue(maxsize=self._max_instance_queue_size)
)
workers = self._start_instance_workers(queue, ctx)
workers, txs = self._start_instance_workers(queue, ctx)

try:
for instance in self._maybe_tqdm(
@@ -386,7 +389,7 @@ def iter_instances(self) -> Iterator[Instance]:
finally:
if hasattr(queue, "close"): # for compat with different Python versions.
queue.close() # type: ignore[attr-defined]
self._join_workers(workers, queue)
self._join_workers(workers, queue, txs)

@overrides
def set_target_device(self, device: torch.device) -> None:
@@ -404,7 +407,7 @@ def _iter_batches(self) -> Iterator[TensorDict]:
if self._max_batch_queue_size is None
else ctx.JoinableQueue(maxsize=self._max_batch_queue_size)
)
workers = self._start_batch_workers(queue, ctx)
workers, txs = self._start_batch_workers(queue, ctx)

try:
# We can now start consuming from the `queue` as the batch workers
@@ -425,47 +428,107 @@ def _iter_batches(self) -> Iterator[TensorDict]:
finally:
if hasattr(queue, "close"): # for compat with different Python versions.
queue.close() # type: ignore[attr-defined]
self._join_workers(workers, queue)
self._join_workers(workers, queue, txs)

def _start_instance_workers(self, queue: mp.JoinableQueue, ctx) -> List[BaseProcess]:
def _start_instance_workers(
self, queue: mp.JoinableQueue, ctx
) -> Tuple[List[BaseProcess], List[Connection]]:
Tqdm.set_lock(mp.RLock())
workers: List[BaseProcess] = []
txs: List[Connection] = []
for worker_id in range(self.num_workers):
rx, tx = ctx.Pipe(duplex=False)
worker: BaseProcess = ctx.Process(
target=self._instance_worker, args=(worker_id, queue, Tqdm.get_lock()), daemon=True
target=self._instance_worker,
args=(worker_id, queue, Tqdm.get_lock(), rx),
daemon=True,
)
worker.start()
workers.append(worker)
return workers
txs.append(tx)
return workers, txs
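Each worker gets its own one-way control channel from `ctx.Pipe(duplex=False)`: the parent keeps the sending end (`tx`), the worker receives on `rx`. A self-contained sketch of that pattern, using only the standard library:

```python
import multiprocessing as mp


def worker(rx) -> None:
    # Block until the parent sends a control message, then exit.
    msg = rx.recv()
    print("worker received:", msg)


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    rx, tx = ctx.Pipe(duplex=False)  # rx: receiving end, tx: sending end
    p = ctx.Process(target=worker, args=(rx,), daemon=True)
    p.start()
    tx.send("stop")  # parent -> worker control message
    p.join()
```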

def _start_batch_workers(self, queue: mp.JoinableQueue, ctx) -> List[BaseProcess]:
def _start_batch_workers(
self, queue: mp.JoinableQueue, ctx
) -> Tuple[List[BaseProcess], List[Connection]]:
Tqdm.set_lock(mp.RLock())
workers: List[BaseProcess] = []
txs: List[Connection] = []
for worker_id in range(self.num_workers):
rx, tx = ctx.Pipe(duplex=False)
worker: BaseProcess = ctx.Process(
target=self._batch_worker, args=(worker_id, queue, Tqdm.get_lock()), daemon=True
target=self._batch_worker, args=(worker_id, queue, Tqdm.get_lock(), rx), daemon=True
)
worker.start()
workers.append(worker)
return workers

def _join_workers(self, workers: List[BaseProcess], queue) -> None:
# Each worker will be blocking on a call to `queue.join()`,
# calling `queue.task_done()` times the number of workers will
# call the `queue.join()` to return, and each worker should exit on its own.
txs.append(tx)
return workers, txs

def _join_workers(self, workers: List[BaseProcess], queue, txs: List[Connection]) -> None:
# If the workers have exhausted their batch/instance generators,
# they will be blocking on a call to `queue.join()`,
# so calling `queue.task_done()` times the number of workers will
# allow the `queue.join()` to return and each worker should exit on its own.
for _ in range(len(workers)):
try:
queue.task_done()
except ValueError:
# This happens if a worker died early.
break
# If for some reason the workers don't exit properly, we go through and terminate
# them anyway.
for worker in workers:
# But if we're joining the workers due to an exception in the main process,
# they probably won't be finished, so we need to tell them to stop.
# We first do this nicely by sending them a message through their corresponding
# tx connection.
for tx in txs:
tx.send("stop")

# If for some reason the workers still haven't exited, we go through and terminate
# them.
for i, worker in enumerate(workers):
worker.join(1)
if worker.is_alive():
logger.warning("terminating worker %s", i)
worker.terminate()
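The `task_done()`/`join()` handshake that the comments above rely on is a property of `JoinableQueue`: every `put` increments an internal counter, every `task_done` decrements it, and `queue.join()` blocks until the counter reaches zero. A tiny standalone illustration of the same handshake the workers use:

```python
import multiprocessing as mp


def producer(queue) -> None:
    queue.put("work")
    queue.put(None)  # sentinel: "this producer is finished"
    queue.join()     # blocks until task_done() has been called once per put;
                     # only then is it safe for this process to exit.


if __name__ == "__main__":
    queue = mp.JoinableQueue()
    p = mp.Process(target=producer, args=(queue,), daemon=True)
    p.start()
    while True:
        item = queue.get()
        queue.task_done()
        if item is None:
            break  # the sentinel is accounted for, so the producer's join() returns
        print("got", item)
    p.join()
```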

def _instance_worker(self, worker_id: int, queue: mp.JoinableQueue, lock) -> None:
def _safe_queue_put(
self, worker_id: int, item: Any, queue: mp.JoinableQueue, rx: Connection
) -> bool:
while True:
# First we have to check to make sure the parent process is still alive
# and consuming from the queue because there are circumstances where the
# parent process can exit without automatically cleaning up its children (the workers).
# For example, when the parent process is killed with `kill -9`.
# So the first thing we do is check to see if the parent has notified
# us (the worker) to stop through the rx connection.
# Of course this only works if the parent was able to send out a notification,
# which may not always be the case. So we have a backup check below.
if rx.poll():
logger.warning(
"worker %d received stop message from parent, exiting now", worker_id
)
queue.cancel_join_thread()
return False
# This is the backup check, but it only works if the worker was spawned
# (as opposed to being created from a fork).
if self.start_method == "spawn":
# The file descriptor associated with the rx (receiver) connection will
# be readable if and only if the parent process has exited.
fds, _, _ = select.select([rx.fileno()], [], [], 0)
if fds:
logger.warning("worker %d parent process has died, exiting now", worker_id)
queue.cancel_join_thread()
return False
# If we're down here the parent process is still alive to the best of our
# knowledge, so we can continue putting things on the queue.
try:
queue.put(item, True, 0.1)
return True
except Full:
continue
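Condensed into a standalone function, the loop above looks roughly like the sketch below (`safe_put` and `spawned` are hypothetical names; the real method additionally calls `queue.cancel_join_thread()` before bailing out so the worker can exit without flushing the queue):

```python
import select
from multiprocessing.connection import Connection
from queue import Full


def safe_put(item, queue, rx: Connection, spawned: bool) -> bool:
    """Keep trying to enqueue `item`; return False if the parent is gone."""
    while True:
        if rx.poll():
            # The parent explicitly sent a stop message.
            return False
        if spawned:
            # Under the spawn start method, the worker's end of the pipe
            # becomes readable (EOF) once the parent process has died.
            readable, _, _ = select.select([rx.fileno()], [], [], 0)
            if readable:
                return False
        try:
            queue.put(item, True, 0.1)  # block for at most 0.1 seconds
            return True
        except Full:
            continue  # queue still full: loop back and re-check liveness
```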

def _instance_worker(
self, worker_id: int, queue: mp.JoinableQueue, lock, rx: Connection
) -> None:
Tqdm.set_lock(lock)
try:
self.reader._set_worker_info(WorkerInfo(self.num_workers, worker_id))
@@ -488,27 +551,41 @@ def _instance_worker(self, worker_id: int, queue: mp.JoinableQueue, lock) -> None:
"so already)."
)
checked_for_token_indexers = True
queue.put((instance, None))
if self._safe_queue_put(worker_id, (instance, None), queue, rx):
continue
else:
# Couldn't put item on queue because parent process has exited.
return
except Exception as e:
queue.put((None, (repr(e), traceback.format_exc())))
if not self._safe_queue_put(
worker_id, (None, (repr(e), traceback.format_exc())), queue, rx
):
return

# Indicate to the consumer that this worker is finished.
queue.put((None, None))

# Wait until this process can safely exit.
queue.join()

def _batch_worker(self, worker_id: int, queue: mp.JoinableQueue, lock) -> None:
def _batch_worker(self, worker_id: int, queue: mp.JoinableQueue, lock, rx: Connection) -> None:
Tqdm.set_lock(lock)
try:
self.reader._set_worker_info(WorkerInfo(self.num_workers, worker_id))
instances = self.reader.read(self.data_path)
for batch in self._instances_to_batches(
instances, move_to_device=self._worker_cuda_safe
):
queue.put((batch, None))
if self._safe_queue_put(worker_id, (batch, None), queue, rx):
continue
else:
# Couldn't put item on queue because parent process has exited.
return
except Exception as e:
queue.put((None, (repr(e), traceback.format_exc())))
if not self._safe_queue_put(
worker_id, (None, (repr(e), traceback.format_exc())), queue, rx
):
return

# Indicate to the consumer (main thread) that this worker is finished.
queue.put((None, None))
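Both worker types speak the same queue protocol: `(item, None)` carries data, `(None, (repr, traceback))` reports a worker failure, and `(None, None)` marks a clean finish. A hypothetical consumer loop that interprets these messages, roughly what the loader does inside `iter_instances` and `_iter_batches`:

```python
def consume(queue, num_workers: int):
    finished = 0
    while finished < num_workers:
        item, error = queue.get()
        queue.task_done()
        if error is not None:
            # Re-raise the worker's failure in the main process.
            raise RuntimeError(f"worker error: {error}")
        if item is None:
            finished += 1  # (None, None): one worker finished cleanly
        else:
            yield item
```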