This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Dataset remix #5372

Merged
merged 39 commits on Aug 25, 2021
Changes from 34 commits

Commits (39)
10c5479
Adds a dataset that can be read and written lazily
dirkgr Aug 7, 2021
36f9b67
This approach might work better.
dirkgr Aug 7, 2021
e74540d
Make ShuffledSequence take indices
dirkgr Aug 7, 2021
0eb53bf
Formatting
dirkgr Aug 7, 2021
dcedfd5
Adds failing test
dirkgr Aug 7, 2021
36948ce
Merge remote-tracking branch 'origin/main' into TangoBigData
dirkgr Aug 11, 2021
44eccf9
Fix sparse sequence tests
dirkgr Aug 11, 2021
f305de7
Fixes the Sqlite format
dirkgr Aug 11, 2021
61f8810
Quality-of-life hack
dirkgr Aug 11, 2021
989f15c
Makes an internal string less alarming
dirkgr Aug 11, 2021
9c461b7
Save the files to the right place
dirkgr Aug 11, 2021
15e0be4
Merge remote-tracking branch 'origin/main' into TangoBigData
dirkgr Aug 18, 2021
ca26abe
Formatting
dirkgr Aug 19, 2021
f2f0a34
Merge remote-tracking branch 'origin/main' into TangoBigData
dirkgr Aug 19, 2021
bb572b3
Fix for SqliteDatasetFormat
dirkgr Aug 20, 2021
6953d7d
Performance improvement for SqliteSparseSequence
dirkgr Aug 20, 2021
3f99be7
Changelog
dirkgr Aug 20, 2021
d69ea38
Merge branch 'main' into TangoBigData
dirkgr Aug 20, 2021
d58a52f
Global imports
dirkgr Aug 20, 2021
104777d
More Sequence classes
dirkgr Aug 21, 2021
b6b5f05
Say DatasetDict when we mean DatasetDict
dirkgr Aug 21, 2021
05c4dd6
Test for the sequences
dirkgr Aug 21, 2021
4304a93
Use the step name correctly in the error message
dirkgr Aug 21, 2021
d6cb8ab
Use and consume step_name correctly in Step.from_params()
dirkgr Aug 21, 2021
fd305a6
Uncacheable steps don't get cached even if they have a name
dirkgr Aug 21, 2021
3ae61eb
Adds a step that can remix a dataset
dirkgr Aug 21, 2021
2004fd2
Improve log message
dirkgr Aug 21, 2021
b0c3626
Fix relative import
dirkgr Aug 21, 2021
fcf651f
Changelog
dirkgr Aug 21, 2021
aa82e3d
Merge branch 'main' into DatasetRemix
dirkgr Aug 23, 2021
ca5cad3
Adds documentation
dirkgr Aug 23, 2021
d5f11f4
Merge branch 'DatasetRemix' of https://github.com/allenai/allennlp in…
dirkgr Aug 23, 2021
c52b050
Give the option of changing a det_hash simply()
dirkgr Aug 23, 2021
a32c7f2
Tix fypo
dirkgr Aug 23, 2021
6cccd64
Adds ability to shuffle datasets
dirkgr Aug 24, 2021
765575d
Test for det_hash
dirkgr Aug 24, 2021
c69df7e
Merge branch 'main' into DatasetRemix
dirkgr Aug 24, 2021
451e4ee
We don't use relative imports
dirkgr Aug 25, 2021
1d71b69
Merge branch 'DatasetRemix' of https://github.com/allenai/allennlp in…
dirkgr Aug 25, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `ScaledDotProductMatrixAttention`, and converted the transformer toolkit to use it
- Added tests to ensure that all `Attention` and `MatrixAttention` implementations are interchangeable
- Added a way for AllenNLP Tango to read and write datasets lazily.
- Added a way to remix datasets flexibly

### Fixed

@@ -42,6 +43,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `ConfigurationError` is now pickleable.
- Multitask models now support `TextFieldTensor` in heads, not just in the backbone.
- Fixed the signature of `ScaledDotProductAttention` to match the other `Attention` classes
- Fixed the way names are applied to Tango `Step` instances.

### Changed

48 changes: 45 additions & 3 deletions allennlp/common/det_hash.py
@@ -1,6 +1,7 @@
import collections
import hashlib
import io
from typing import Any
from typing import Any, MutableMapping

import base58
import dill
@@ -13,6 +14,9 @@ def det_hash_object(self) -> Any:
representation. Sometimes you want to take control over what goes into
that hash. In that case, implement this method. `det_hash()` will pickle the
result of this method instead of the object itself.

If you return `None`, `det_hash()` falls back to the original behavior and pickles
the object.
"""
raise NotImplementedError()

@@ -38,10 +42,48 @@ def det_hash_object(self) -> Any:
return self._det_hash_object


class DetHashWithVersion(CustomDetHash):
"""
Add this class as a mixin base class to make sure your class's det_hash can be modified
by altering a static `VERSION` member of your class.
"""

VERSION = None

def det_hash_object(self) -> Any:
if self.VERSION is not None:
return self.VERSION, self
else:
return None


class _DetHashPickler(dill.Pickler):
def __init__(self, buffer: io.BytesIO):
super().__init__(buffer)

# We keep track of how deeply we are nesting the pickling of an object.
# If a class returns `self` as part of `det_hash_object()`, it causes an
# infinite recursion, because we try to pickle the `det_hash_object()`, which
# contains `self`, which returns a `det_hash_object()`, etc.
# So we keep track of how many times recursively we are trying to pickle the
# same object. We only call `det_hash_object()` the first time. We assume that
# if `det_hash_object()` returns `self` in any way, we want the second time
# to just pickle the object as normal. `DetHashWithVersion` takes advantage
# of this ability.
self.recursively_pickled_ids: MutableMapping[int, int] = collections.Counter()

def save(self, obj, save_persistent_id=True):
self.recursively_pickled_ids[id(obj)] += 1
super().save(obj, save_persistent_id)
self.recursively_pickled_ids[id(obj)] -= 1

def persistent_id(self, obj: Any) -> Any:
if isinstance(obj, CustomDetHash):
return obj.__class__.__qualname__, obj.det_hash_object()
if isinstance(obj, CustomDetHash) and self.recursively_pickled_ids[id(obj)] <= 1:
det_hash_object = obj.det_hash_object()
if det_hash_object is not None:
return obj.__class__.__module__, obj.__class__.__qualname__, det_hash_object
else:
return None
elif isinstance(obj, type):
return obj.__module__, obj.__qualname__
else:
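The recursion counter above is what makes the `DetHashWithVersion` mixin workable: its `det_hash_object()` returns `(VERSION, self)`, and when the pickler meets `self` the second time it pickles it normally instead of recursing forever. Below is a minimal sketch of the intended usage, assuming `det_hash()` is the module-level entry point the docstrings refer to; `MyTokenizer` is a made-up class for illustration.

```python
from allennlp.common.det_hash import DetHashWithVersion, det_hash


class MyTokenizer(DetHashWithVersion):
    # Hypothetical example class. Bump VERSION whenever behavior changes in a
    # way that should invalidate previously computed hashes (and cached results).
    VERSION = "002"

    def __init__(self, lowercase: bool = True):
        self.lowercase = lowercase


# The hash covers both VERSION and the object's pickled state, so changing
# either VERSION or a field like `lowercase` yields a different det_hash.
print(det_hash(MyTokenizer()))
```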
85 changes: 85 additions & 0 deletions allennlp/common/sequences.py
@@ -0,0 +1,85 @@
import bisect
import random
from collections import abc
from typing import Sequence, Optional, Union


class ShuffledSequence(abc.Sequence):
"""
Produces a shuffled view of a sequence, such as a list.

This assumes that the inner sequence never changes. If it does, the results
are undefined.
"""

def __init__(self, inner_sequence: Sequence, indices: Optional[Sequence[int]] = None):
self.inner = inner_sequence
self.indices: Sequence[int]
if indices is None:
self.indices = list(range(len(inner_sequence)))
random.shuffle(self.indices)
else:
self.indices = indices

def __len__(self) -> int:
return len(self.indices)

def __getitem__(self, i: Union[int, slice]):
if isinstance(i, int):
return self.inner[self.indices[i]]
else:
return ShuffledSequence(self.inner, self.indices[i])

def __contains__(self, item) -> bool:
for i in self.indices:
if self.inner[i] == item:
return True
return False


class SlicedSequence(ShuffledSequence):
"""
Produces a sequence that's a slice into another sequence, without copying the elements.

This assumes that the inner sequence never changes. If it does, the results
are undefined.
"""

def __init__(self, inner_sequence: Sequence, s: slice):
super().__init__(inner_sequence, range(*s.indices(len(inner_sequence))))


class ConcatenatedSequence(abc.Sequence):
"""
Produces a sequence that's the concatenation of multiple other sequences, without
copying the elements.

This assumes that the inner sequence never changes. If it does, the results
are undefined.
"""

def __init__(self, *sequences: Sequence):
self.sequences = sequences
self.cumulative_sequence_lengths = [0]
for sequence in sequences:
self.cumulative_sequence_lengths.append(
self.cumulative_sequence_lengths[-1] + len(sequence)
)

def __len__(self):
return self.cumulative_sequence_lengths[-1]

def __getitem__(self, i: Union[int, slice]):
if isinstance(i, int):
if i < 0:
i += len(self)
if i < 0 or i >= len(self):
raise IndexError("list index out of range")
sequence_index = bisect.bisect_right(self.cumulative_sequence_lengths, i) - 1
i -= self.cumulative_sequence_lengths[sequence_index]
return self.sequences[sequence_index][i]
else:
return SlicedSequence(self, i)

def __contains__(self, item) -> bool:
return any(s.__contains__(item) for s in self.sequences)
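
Taken together, these classes provide lazy, copy-free views over sequences, and they compose. A short sketch (the sample data is made up):

```python
from allennlp.common.sequences import (
    ConcatenatedSequence,
    ShuffledSequence,
    SlicedSequence,
)

numbers = list(range(10))
letters = ["a", "b", "c"]

shuffled = ShuffledSequence(numbers)           # random permutation view
head = SlicedSequence(numbers, slice(0, 5))    # view of numbers[0:5], no copy
combined = ConcatenatedSequence(numbers, letters)

assert sorted(shuffled) == numbers
assert list(head) == [0, 1, 2, 3, 4]
assert len(combined) == 13
assert combined[10] == "a"   # bisect maps index 10 into `letters`
assert "b" in combined
```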
5 changes: 2 additions & 3 deletions allennlp/common/sqlite_sparse_sequence.py
@@ -2,10 +2,9 @@
import shutil
from os import PathLike
from typing import MutableSequence, Any, Union, Iterable

from sqlitedict import SqliteDict

from allennlp.tango.dataloader import ShuffledSequence
from allennlp.common.sequences import SlicedSequence


class SqliteSparseSequence(MutableSequence[Any]):
@@ -28,7 +27,7 @@ def __getitem__(self, i: Union[int, slice]) -> Any:
else:
return None
elif isinstance(i, slice):
return ShuffledSequence(self, range(*i.indices(len(self))))
return SlicedSequence(self, i)
else:
raise TypeError(f"list indices must be integers or slices, not {i.__class__.__name__}")

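With this change, slicing a `SqliteSparseSequence` returns a genuine `SlicedSequence` view instead of a `ShuffledSequence` pressed into that role. A hypothetical usage sketch; the constructor argument is assumed from the `PathLike` import rather than shown in this diff:

```python
from allennlp.common.sqlite_sparse_sequence import SqliteSparseSequence

seq = SqliteSparseSequence("instances.sqlite")  # assumed filename argument
seq.append({"text": "hello"})                   # standard MutableSequence API
window = seq[0:100]  # a lazy SlicedSequence view; nothing is copied
```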
35 changes: 2 additions & 33 deletions allennlp/tango/dataloader.py
@@ -4,10 +4,8 @@
"""

import logging
import random
from collections import abc
from math import floor, ceil
from typing import Optional, Iterator, Sequence, Union, Dict, Any
from typing import Optional, Iterator, Sequence, Dict, Any

import more_itertools
import torch
@@ -22,6 +20,7 @@
Vocabulary,
)
from allennlp.nn.util import move_to_device
from allennlp.common.sequences import ShuffledSequence


class TangoDataLoader(Registrable):
@@ -86,36 +85,6 @@ def set_target_device(self, device: torch.device) -> None:
self.target_device = device


class ShuffledSequence(abc.Sequence):
"""
Produces a shuffled view of a sequence, such as a list.

This assumes that the inner sequence never changes. If it does, the results
are undefined.
"""

def __init__(self, inner_sequence: Sequence, indices: Optional[Sequence[int]] = None):
self.inner = inner_sequence
self.indices: Sequence[int]
if indices is None:
self.indices = list(range(len(inner_sequence)))
random.shuffle(self.indices)
else:
self.indices = indices

def __len__(self) -> int:
return len(self.inner)

def __getitem__(self, i: Union[int, slice]):
if isinstance(i, int):
return self.inner[self.indices[i]]
else:
return ShuffledSequence(self.inner, self.indices[i])

def __contains__(self, item) -> bool:
return self.inner.__contains__(item)


@TangoDataLoader.register("batch_size")
class BatchSizeDataLoader(TangoDataLoader):
"""A data loader that turns instances into batches with a constant number of instances
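Besides relocating the class, this removal fixes two bugs in the old dataloader-local `ShuffledSequence`: `__len__` returned `len(self.inner)` even when an explicit `indices` subset was given, and `__contains__` consulted the inner sequence directly. The version now in `allennlp.common.sequences` keys both off `self.indices`. A sketch of the behavioral difference:

```python
from allennlp.common.sequences import ShuffledSequence

view = ShuffledSequence(list(range(100)), indices=[3, 1, 4])

assert len(view) == 3    # the removed version would have reported 100
assert 4 in view         # 4 sits at one of the chosen indices
assert 99 not in view    # the removed version would have said True
```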
49 changes: 47 additions & 2 deletions allennlp/tango/dataset.py
@@ -4,11 +4,13 @@
"""

import itertools
import re
from dataclasses import dataclass, field
from typing import Mapping, Any, Optional, Sequence, Dict

from allennlp.data import Vocabulary, DatasetReader, Instance
from allennlp.tango.step import Step
from allennlp.common.sequences import SlicedSequence, ConcatenatedSequence
from tqdm import tqdm


@@ -39,9 +41,9 @@ def __len__(self) -> int:
@Step.register("dataset_reader_adapter")
class DatasetReaderAdapterStep(Step):
"""
This step creates an `AllenNlpDataset` from old-school dataset readers. If you're
This step creates a `DatasetDict` from old-school dataset readers. If you're
tempted to write a new `DatasetReader`, and then use this step with it, don't.
Just write a `Step` that creates the `AllenNlpDataset` you need directly.
Just write a `Step` that creates the `DatasetDict` you need directly.
"""

DETERMINISTIC = True # We're giving the dataset readers some credit here.
@@ -72,3 +74,46 @@ def run(self, reader: DatasetReader, splits: Dict[str, str]) -> DatasetDict: #
instance.index_fields(vocab)

return DatasetDict(splits=instances_map, vocab=vocab)


@Step.register("dataset_remix")
class DatasetRemixStep(Step):
"""
This step can remix splits in a dataset into new splits.
"""

DETERMINISTIC = True
CACHEABLE = False # This is so fast it's not worth caching.
VERSION = "001"

def run( # type: ignore
self, input: DatasetDict, new_splits: Dict[str, str], keep_old_splits: bool = True
) -> DatasetDict:
def get_slice(split_name: str) -> Sequence[Any]:
slice_match = re.match(r"(.*)\[([0123456789:]*)]", split_name)

Review thread on this line:

Contributor:
This won't work for something like train[:50000] + dev[:10000]. Is it supposed to?

Member (author):
Yes, it will work. This function is only called on the parts after .split("+").

Contributor:
Ok, I see. And train[:50000] in that case is interpreted as it should be?

Member (author):
Yes. It supports full Python slice syntax on line 98. In condensed form, it does slice(*match.split(":")).

if slice_match is None:
return input[split_name]
else:
split_name = slice_match[1]
slice_args = [int(a) if len(a) > 0 else None for a in slice_match[2].split(":")]
return SlicedSequence(input[split_name], slice(*slice_args))

def parse_split_spec(split_spec: str):
parts = [get_slice(name.strip()) for name in split_spec.split("+")]
if len(parts) == 1:
return parts[0]
else:
return ConcatenatedSequence(*parts)

if keep_old_splits:
result = dict(input.splits.items())
else:
result = {}
result.update(
{
new_split_name: parse_split_spec(new_split_spec)
for new_split_name, new_split_spec in new_splits.items()
}
)

return DatasetDict(vocab=input.vocab, metadata=input.metadata, splits=result)
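
To make the split-spec syntax discussed in the review thread concrete, here is a hypothetical invocation. Calling `run()` directly bypasses the step-execution machinery and is shown only for illustration; the `DatasetDict` named `dataset` is assumed to have `train` and `dev` splits.

```python
step = DatasetRemixStep()
remixed = step.run(
    input=dataset,
    new_splits={
        # Full Python slice syntax per split name; "+" concatenates lazily.
        "train": "train[:50000] + dev[:10000]",
        "tiny": "train[:128]",
    },
    keep_old_splits=False,
)
# remixed.splits now holds only "train" and "tiny", backed by
# SlicedSequence/ConcatenatedSequence views over the original data.
```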
2 changes: 1 addition & 1 deletion allennlp/tango/hf_dataset.py
@@ -10,7 +10,7 @@

@Step.register("hf_dataset")
class HuggingfaceDataset(Step):
"""This steps reads a huggingface dataset and returns it in `AllenNlpDataset` format."""
"""This steps reads a huggingface dataset and returns it in `DatasetDict` format."""

DETERMINISTIC = True
VERSION = "001"