Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit b764bef

Browse files
authored
simplify dataset classes, fix multi-process lazy loading (#4344)
simplify dataset classes, fix multi-process lazy loading (#4344) * simplify dataset classes, fix multi-process lazy loading * remove unnecessary overrides decorators * warn about tokenizers deadlock * clean up * fix race conditions * fix type hint * remove outdated docstring * fixes * fix caching again * issue warning when can't write to cache safely * comments * update CHANGELOG * update docstring of _read * revert generic type * test and fix 'multi_worker_islice * clean up * no more tuples :( * revert, revert, revert * revert * revert * non-lazy locking * add another test * revert * make mypy happy * update CHANGELOG * improvements * fix comment * doc fixes * update CHANGELOG * tweak docs * improve caching system * improve caching * add test * add another logging statement * warnings * add UserWarning about manual sharding
1 parent 884a614 commit b764bef

File tree

12 files changed

+645
-231
lines changed

12 files changed

+645
-231
lines changed

CHANGELOG.md

+9-1
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## Unreleased
99

10-
## Fixed
10+
### Fixed
11+
12+
- Lazy dataset readers now work correctly with multi-process data loading.
13+
- Fixed race conditions that could occur when using a dataset cache.
14+
15+
### Added
1116

1217
- A bug where all datasets would be loaded for vocab creation even if not needed.
18+
- A parameter to the `DatasetReader` class: `manual_multi_process_sharding`. This is similar
19+
to the `manual_distributed_sharding` parameter, but applies when using a multi-process
20+
`DataLoader`.
1321

1422
## [v1.0.0rc6](https://github.com/allenai/allennlp/releases/tag/v1.0.0rc6) - 2020-06-11
1523

allennlp/common/file_utils.py

+49-23
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import glob
66
import os
77
import logging
8-
import shutil
98
import tempfile
109
import json
1110
from urllib.parse import urlparse
@@ -243,6 +242,46 @@ def _find_latest_cached(url: str, cache_dir: str) -> Optional[str]:
243242
return None
244243

245244

245+
class CacheFile:
    """
    This is a context manager that makes robust caching easier.

    On `__enter__`, an IO handle to a temporary file is returned, which can
    be treated as if it's the actual cache file.

    On `__exit__`, the temporary file is renamed to the cache file. If anything
    goes wrong while writing to the temporary file, it will be removed.
    """

    def __init__(self, cache_filename: Union[Path, str], mode: str = "w+b") -> None:
        # Normalize to a `Path` so downstream filesystem operations are consistent.
        self.cache_filename = (
            cache_filename if isinstance(cache_filename, Path) else Path(cache_filename)
        )
        self.cache_directory = os.path.dirname(self.cache_filename)
        self.mode = mode
        # The temp file is created in the same directory as the final cache file
        # so that the `os.replace` in `__exit__` is a same-filesystem (atomic) rename.
        # `delete=False` because we rename it ourselves on success.
        self.temp_file = tempfile.NamedTemporaryFile(
            self.mode, dir=self.cache_directory, delete=False, suffix=".tmp"
        )

    def __enter__(self):
        # Callers write to this handle as if it were the cache file itself.
        return self.temp_file

    def __exit__(self, exc_type, exc_value, traceback):
        self.temp_file.close()
        if exc_value is None:
            # Success.
            logger.info(
                "Renaming temp file %s to cache at %s", self.temp_file.name, self.cache_filename
            )
            # Rename the temp file to the actual cache filename.
            os.replace(self.temp_file.name, self.cache_filename)
            return True
        # Something went wrong while writing: remove the temp file so a partial
        # write never becomes a corrupt cache entry, and let the exception propagate.
        logger.info("removing temp file %s", self.temp_file.name)
        os.remove(self.temp_file.name)
        return False
283+
284+
246285
# TODO(joelgrus): do we want to do checksums or anything like that?
247286
def get_from_cache(url: str, cache_dir: str = None) -> str:
248287
"""
@@ -303,33 +342,20 @@ def get_from_cache(url: str, cache_dir: str = None) -> str:
303342
if os.path.exists(cache_path):
304343
logger.info("cache of %s is up-to-date", url)
305344
else:
306-
# Download to temporary file, then copy to cache dir once finished.
307-
# Otherwise you get corrupt cache entries if the download gets interrupted.
308-
with tempfile.NamedTemporaryFile() as temp_file:
309-
logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
345+
with CacheFile(cache_path) as cache_file:
346+
logger.info("%s not found in cache, downloading to %s", url, cache_file.name)
310347

311348
# GET file object
312349
if url.startswith("s3://"):
313-
_s3_get(url, temp_file)
350+
_s3_get(url, cache_file)
314351
else:
315-
_http_get(url, temp_file)
316-
317-
# we are copying the file before closing it, so flush to avoid truncation
318-
temp_file.flush()
319-
# shutil.copyfileobj() starts at the current position, so go to the start
320-
temp_file.seek(0)
321-
322-
logger.info("copying %s to cache at %s", temp_file.name, cache_path)
323-
with open(cache_path, "wb") as cache_file:
324-
shutil.copyfileobj(temp_file, cache_file) # type: ignore
325-
326-
logger.info("creating metadata file for %s", cache_path)
327-
meta = {"url": url, "etag": etag}
328-
meta_path = cache_path + ".json"
329-
with open(meta_path, "w") as meta_file:
330-
json.dump(meta, meta_file)
352+
_http_get(url, cache_file)
331353

332-
logger.info("removing temp file %s", temp_file.name)
354+
logger.info("creating metadata file for %s", cache_path)
355+
meta = {"url": url, "etag": etag}
356+
meta_path = cache_path + ".json"
357+
with open(meta_path, "w") as meta_file:
358+
json.dump(meta, meta_file)
333359

334360
return cache_path
335361

allennlp/data/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
from allennlp.data.dataloader import DataLoader, allennlp_collate
2-
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
2+
from allennlp.data.dataset_readers.dataset_reader import (
3+
DatasetReader,
4+
AllennlpDataset,
5+
AllennlpLazyDataset,
6+
)
37
from allennlp.data.fields.field import DataArray, Field
48
from allennlp.data.fields.text_field import TextFieldTensors
59
from allennlp.data.instance import Instance

allennlp/data/dataloader.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from typing import List, Dict, Union
2+
import warnings
23

34
import torch
45
from torch.utils import data
56

67
from allennlp.common.registrable import Registrable
7-
from allennlp.data.instance import Instance
8-
98
from allennlp.common.lazy import Lazy
9+
from allennlp.data.instance import Instance
10+
from allennlp.data.dataset_readers.dataset_reader import AllennlpLazyDataset
1011
from allennlp.data.batch import Batch
1112
from allennlp.data.samplers import Sampler, BatchSampler
1213

@@ -65,6 +66,13 @@ def __init__(
6566
multiprocessing_context: str = None,
6667
batches_per_epoch: int = None,
6768
):
69+
if num_workers and isinstance(dataset, AllennlpLazyDataset):
70+
warnings.warn(
71+
"Using multi-process data loading with a lazy dataset could lead to "
72+
"deadlocks with certain tokenizers. See:\n"
73+
" https://github.com/allenai/allennlp/issues/4330\n",
74+
UserWarning,
75+
)
6876
super().__init__(
6977
dataset=dataset,
7078
batch_size=batch_size,

allennlp/data/dataset_readers/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88

99

1010
from allennlp.data.dataset_readers.conll2003 import Conll2003DatasetReader
11-
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
11+
from allennlp.data.dataset_readers.dataset_reader import (
12+
DatasetReader,
13+
AllennlpDataset,
14+
AllennlpLazyDataset,
15+
)
1216
from allennlp.data.dataset_readers.interleaving_dataset_reader import InterleavingDatasetReader
1317
from allennlp.data.dataset_readers.sequence_tagging import SequenceTaggingDatasetReader
1418
from allennlp.data.dataset_readers.sharded_dataset_reader import ShardedDatasetReader

0 commit comments

Comments
 (0)