This repository was archived by the owner on Dec 16, 2022. It is now read-only.
Dataset remix #5372
Merged
Changes from 32 commits
39 commits
10c5479 Adds a dataset that can be read and written lazily (dirkgr)
36f9b67 This approach might work better. (dirkgr)
e74540d Make ShuffledSequence take indices (dirkgr)
0eb53bf Formatting (dirkgr)
dcedfd5 Adds failing test (dirkgr)
36948ce Merge remote-tracking branch 'origin/main' into TangoBigData (dirkgr)
44eccf9 Fix sparse sequence tests (dirkgr)
f305de7 Fixes the Sqlite format (dirkgr)
61f8810 Quality-of-life hack (dirkgr)
989f15c Makes an internal string less alarming (dirkgr)
9c461b7 Save the files to the right place (dirkgr)
15e0be4 Merge remote-tracking branch 'origin/main' into TangoBigData (dirkgr)
ca26abe Formatting (dirkgr)
f2f0a34 Merge remote-tracking branch 'origin/main' into TangoBigData (dirkgr)
bb572b3 Fix for SqliteDatasetFormat (dirkgr)
6953d7d Performance improvement for SqliteSparseSequence (dirkgr)
3f99be7 Changelog (dirkgr)
d69ea38 Merge branch 'main' into TangoBigData (dirkgr)
d58a52f Global imports (dirkgr)
104777d More Sequence classes (dirkgr)
b6b5f05 Say DatasetDict when we mean DatasetDict (dirkgr)
05c4dd6 Test for the sequences (dirkgr)
4304a93 Use the step name correctly in the error message (dirkgr)
d6cb8ab Use and consume step_name correctly in Step.from_params() (dirkgr)
fd305a6 Uncacheable steps don't get cached even if they have a name (dirkgr)
3ae61eb Adds a step that can remix a dataset (dirkgr)
2004fd2 Improve log message (dirkgr)
b0c3626 Fix relative import (dirkgr)
fcf651f Changelog (dirkgr)
aa82e3d Merge branch 'main' into DatasetRemix (dirkgr)
ca5cad3 Adds documentation (dirkgr)
d5f11f4 Merge branch 'DatasetRemix' of https://github.com/allenai/allennlp in… (dirkgr)
c52b050 Give the option of changing a det_hash simply() (dirkgr)
a32c7f2 Tix fypo (dirkgr)
6cccd64 Adds ability to shuffle datasets (dirkgr)
765575d Test for det_hash (dirkgr)
c69df7e Merge branch 'main' into DatasetRemix (dirkgr)
451e4ee We don't use relative imports (dirkgr)
1d71b69 Merge branch 'DatasetRemix' of https://github.com/allenai/allennlp in… (dirkgr)
@@ -0,0 +1,85 @@
import bisect
import random
from collections import abc
from typing import Sequence, Optional, Union


class ShuffledSequence(abc.Sequence):
    """
    Produces a shuffled view of a sequence, such as a list.

    This assumes that the inner sequence never changes. If it does, the results
    are undefined.
    """

    def __init__(self, inner_sequence: Sequence, indices: Optional[Sequence[int]] = None):
        self.inner = inner_sequence
        self.indices: Sequence[int]
        if indices is None:
            self.indices = list(range(len(inner_sequence)))
            random.shuffle(self.indices)
        else:
            self.indices = indices

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: Union[int, slice]):
        if isinstance(i, int):
            return self.inner[self.indices[i]]
        else:
            return ShuffledSequence(self.inner, self.indices[i])

    def __contains__(self, item) -> bool:
        for i in self.indices:
            if self.inner[i] == item:
                return True
        return False


class SlicedSequence(ShuffledSequence):
    """
    Produces a sequence that's a slice into another sequence, without copying the elements.

    This assumes that the inner sequence never changes. If it does, the results
    are undefined.
    """

    def __init__(self, inner_sequence: Sequence, s: slice):
        super().__init__(inner_sequence, range(*s.indices(len(inner_sequence))))


class ConcatenatedSequence(abc.Sequence):
    """
    Produces a sequence that's the concatenation of multiple other sequences, without
    copying the elements.

    This assumes that the inner sequence never changes. If it does, the results
    are undefined.
    """

    def __init__(self, *sequences: Sequence):
        self.sequences = sequences
        self.cumulative_sequence_lengths = [0]
        for sequence in sequences:
            self.cumulative_sequence_lengths.append(
                self.cumulative_sequence_lengths[-1] + len(sequence)
            )

    def __len__(self):
        return self.cumulative_sequence_lengths[-1]

    def __getitem__(self, i: Union[int, slice]):
        if isinstance(i, int):
            if i < 0:
                i += len(self)
            if i < 0 or i >= len(self):
                raise IndexError("list index out of range")
            sequence_index = bisect.bisect_right(self.cumulative_sequence_lengths, i) - 1
            i -= self.cumulative_sequence_lengths[sequence_index]
            return self.sequences[sequence_index][i]
        else:
            return SlicedSequence(self, i)

    def __contains__(self, item) -> bool:
        return any(s.__contains__(item) for s in self.sequences)
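
The index lookup in ConcatenatedSequence.__getitem__ is the least obvious part of this file, so here is a minimal standalone sketch of it. The function name `locate` and its `lengths` argument are made up for illustration and are not part of the PR; the sketch only mirrors the cumulative-length and `bisect_right` logic shown above, and shows that empty inner sequences are skipped because their cumulative length repeats.

```python
import bisect
from itertools import accumulate
from typing import List, Tuple


def locate(lengths: List[int], i: int) -> Tuple[int, int]:
    """Map a global index to (sequence_index, local_index).

    Hypothetical helper that mirrors the lookup in
    ConcatenatedSequence.__getitem__ for integer indices.
    """
    # Same shape as cumulative_sequence_lengths: starts with 0, one entry per sequence.
    cumulative = [0] + list(accumulate(lengths))
    total = cumulative[-1]
    if i < 0:
        i += total
    if i < 0 or i >= total:
        raise IndexError("index out of range")
    # bisect_right lands after every entry <= i, so repeated entries
    # (empty sequences) are skipped and we pick the sequence that holds i.
    sequence_index = bisect.bisect_right(cumulative, i) - 1
    return sequence_index, i - cumulative[sequence_index]


lengths = [2, 0, 3, 5, 0]    # cumulative lengths: [0, 2, 2, 5, 10, 10]
print(locate(lengths, 0))    # (0, 0)
print(locate(lengths, 2))    # (2, 0), the empty sequence at position 1 is skipped
print(locate(lengths, -1))   # (3, 4), the last element of the last non-empty sequence
```

Using `bisect_left` instead would return the position of the first matching cumulative entry, which for index 2 above would point at the wrong (earlier) sequence.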
@@ -0,0 +1,57 @@
import pytest
from allennlp.common.sequences import ConcatenatedSequence


def assert_equal_including_exceptions(expected_fn, actual_fn):
    try:
        expected = expected_fn()
    except Exception as e:
        with pytest.raises(e.__class__):
            actual_fn()
    else:
        assert expected == actual_fn()


def test_concatenated_sequence():
    l1 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    l2 = ConcatenatedSequence([0, 1], [], [2, 3, 4], [5, 6, 7, 8, 9], [])

    # __len__()
    assert len(l1) == len(l2)

    # index()
    for item in l1 + [999]:
        # no indices
        assert_equal_including_exceptions(lambda: l1.index(item), lambda: l2.index(item))

        # only start index
        for index in range(-15, 15):
            assert_equal_including_exceptions(
                lambda: l1.index(item, index), lambda: l2.index(item, index)
            )

        # start and stop index
        for start_index in range(-15, 15):
            for end_index in range(-15, 15):
                assert_equal_including_exceptions(
                    lambda: l1.index(item, start_index, end_index),
                    lambda: l2.index(item, start_index, end_index),
                )

    # __getitem__()
    for index in range(-15, 15):
        assert_equal_including_exceptions(lambda: l1[index], lambda: l2[index])

    for start_index in range(-15, 15):
        for end_index in range(-15, 15):
            assert_equal_including_exceptions(
                lambda: l1[start_index:end_index], lambda: list(l2[start_index:end_index])
            )

    # count()
    for item in l1 + [999]:
        assert_equal_including_exceptions(lambda: l1.count(item), lambda: l2.count(item))

    # __contains__()
    for item in l1 + [999]:
        assert_equal_including_exceptions(lambda: item in l1, lambda: item in l2)
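
The test above checks a new sequence type against a plain list by requiring the same result or the same kind of exception for every call. Below is a minimal, pytest-free sketch of that pattern. The names `outcome` and `same_outcome` are hypothetical, and unlike `pytest.raises`, which also accepts subclasses of the expected exception, this version compares exception types exactly.

```python
def outcome(fn):
    """Run fn and describe what happened: a value or an exception type."""
    try:
        return ("ok", fn())
    except Exception as e:
        return ("raised", type(e))


def same_outcome(expected_fn, actual_fn) -> bool:
    """True if both callables return equal values or raise the same exception type."""
    return outcome(expected_fn) == outcome(actual_fn)


reference = [0, 1, 2]
candidate = (0, 1, 2)  # stand-in for the sequence under test

# In-range and out-of-range indices must behave identically.
for index in range(-5, 5):
    assert same_outcome(lambda: reference[index], lambda: candidate[index])

# A missing item must raise the same exception (ValueError) on both sides.
assert same_outcome(lambda: reference.index(999), lambda: candidate.index(999))

# A mismatch is detected: list raises IndexError, dict raises KeyError.
assert not same_outcome(lambda: [][0], lambda: {}[0])
```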
This won't work for something like train[:50000] + dev[:10000]. Is it supposed to?
Yes, it will work. This function is only called on the parts after .split("+").
Ok I see, and train[:50000] in that case is interpreted as it should be?
Yes. It supports full Python slice syntax on line 98. In condensed form, it does slice(*match.split(":")).
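
For readers without the diff in front of them, here is a rough sketch of the parsing described in this thread: split the spec on "+", then read an optional Python-style slice after each split name. This is not the PR's code. The names `remix` and `parse_slice`, the regular expression, and the error handling are all assumptions made for illustration. One detail differs from the condensed form quoted above: the fields are converted to int (or None when empty) before being passed to slice().

```python
import re
from typing import Dict, List, Sequence

# Hypothetical pattern: a split name, optionally followed by a bracketed slice.
_SPEC = re.compile(r"^\s*(\w+)\s*(?:\[([^\]]*)\])?\s*$")


def parse_slice(text: str) -> slice:
    """Turn "start:stop" or "start:stop:step" into a slice; empty fields become None."""
    parts = [int(p) if p.strip() else None for p in text.split(":")]
    if not 2 <= len(parts) <= 3:
        raise ValueError(f"Not a slice: {text!r}")
    return slice(*parts)


def remix(splits: Dict[str, Sequence], spec: str) -> List:
    """Build a new split from a spec such as "train[:50000] + dev[:10000]"."""
    result: List = []
    for part in spec.split("+"):
        match = _SPEC.match(part)
        if match is None:
            raise ValueError(f"Cannot parse {part!r}")
        name, slice_text = match.groups()
        data = splits[name]
        # No brackets means the whole split is used.
        result.extend(data if slice_text is None else data[parse_slice(slice_text)])
    return result


splits = {"train": list(range(100)), "dev": list(range(100, 110))}
print(remix(splits, "train[:3] + dev[-2:]"))  # [0, 1, 2, 108, 109]
```

This sketch copies elements into a list for simplicity; the PR's sequence classes exist precisely so that slicing and concatenation can be done as views without copying.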