This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 8bbde0d

Fix index_with bug in basic iterator (#1715)
* Fix index_with bug in basic iterator
* fix test
* Move index_fields to ensure_batch_is_sufficiently_small
* Add comment
1 parent 3f54fc8 · commit 8bbde0d

File tree

4 files changed: +15 −2 lines changed

allennlp/data/iterators/data_iterator.py
allennlp/tests/data/iterators/basic_iterator_test.py
allennlp/tests/data/iterators/bucket_iterator_test.py
allennlp/tests/data/iterators/epoch_tracking_bucket_iterator_test.py

allennlp/data/iterators/data_iterator.py (+5)

@@ -243,6 +243,11 @@ def _ensure_batch_is_sufficiently_small(self, batch_instances: Iterable[Instance
         padding_length = -1
         list_batch_instances = list(batch_instances)
         for instance in list_batch_instances:
+            if self.vocab is not None:
+                # we index here to ensure that shape information is available,
+                # as in some cases (with self._maximum_samples_per_batch)
+                # we need access to shape information before batches are constructed
+                instance.index_fields(self.vocab)
             field_lengths = instance.get_padding_lengths()
             for _, lengths in field_lengths.items():
                 try:
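
For context, Instance.get_padding_lengths() only reports shape information once the
instance's fields have been indexed against a vocabulary, and that shape information
is exactly what the maximum_samples_per_batch logic needs before any batch exists.
A minimal sketch of the dependency, assuming the 0.x-era AllenNLP API used in this
diff (the token strings and the 'tokens' indexer name are illustrative):

    from allennlp.data import Instance, Token, Vocabulary
    from allennlp.data.fields import TextField
    from allennlp.data.token_indexers import SingleIdTokenIndexer

    tokens = [Token(t) for t in ["this", "is", "a", "sentence"]]
    instance = Instance({'text': TextField(tokens, {'tokens': SingleIdTokenIndexer()})})
    vocab = Vocabulary.from_instances([instance])

    # Un-indexed fields cannot report padding lengths, which is why the
    # iterator now indexes each instance (when it has a vocab) up front.
    instance.index_fields(vocab)
    lengths = instance.get_padding_lengths()  # e.g. {'text': {'num_tokens': 4, ...}}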


allennlp/tests/data/iterators/basic_iterator_test.py (+5 −2)

@@ -41,7 +41,6 @@ def __iter__(self):
     def create_instance(self, str_tokens: List[str]):
         tokens = [Token(t) for t in str_tokens]
         instance = Instance({'text': TextField(tokens, self.token_indexers)})
-        instance.index_fields(self.vocab)
         return instance
 
     def assert_instances_are_correct(self, candidate_instances):

@@ -69,6 +68,7 @@ def test_get_num_batches(self):
     def test_yield_one_epoch_iterates_over_the_data_once(self):
         for test_instances in (self.instances, self.lazy_instances):
             iterator = BasicIterator(batch_size=2)
+            iterator.index_with(self.vocab)
             batches = list(iterator(test_instances, num_epochs=1))
             # We just want to get the single-token array for the text field in the instance.
             instances = [tuple(instance.detach().cpu().numpy())

@@ -79,7 +79,9 @@ def test_yield_one_epoch_iterates_over_the_data_once(self):
 
     def test_call_iterates_over_data_forever(self):
         for test_instances in (self.instances, self.lazy_instances):
-            generator = BasicIterator(batch_size=2)(test_instances)
+            iterator = BasicIterator(batch_size=2)
+            iterator.index_with(self.vocab)
+            generator = iterator(test_instances)
             batches = [next(generator) for _ in range(18)]  # going over the data 6 times
             # We just want to get the single-token array for the text field in the instance.
             instances = [tuple(instance.detach().cpu().numpy())

@@ -218,6 +220,7 @@ def test_maximum_samples_per_batch(self):
         iterator = BasicIterator(
                 batch_size=3, maximum_samples_per_batch=['num_tokens', 9]
         )
+        iterator.index_with(self.vocab)
         batches = list(iterator._create_batches(test_instances, shuffle=False))
 
         # ensure all instances are in a batch
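
With indexing moved out of the fixture's create_instance, each test now hands the
iterator a vocabulary via index_with before calling it. A sketch of the resulting
caller-side pattern, reusing the hypothetical instance and vocab placeholders from
the sketch above:

    from allennlp.data.iterators import BasicIterator

    iterator = BasicIterator(batch_size=2)
    iterator.index_with(vocab)  # required now that the iterator does the indexing
    for batch in iterator([instance], num_epochs=1):
        # each batch is a dict of tensors keyed by field name, then indexer name
        token_ids = batch['text']['tokens']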


allennlp/tests/data/iterators/bucket_iterator_test.py (+4)

@@ -11,6 +11,7 @@ class TestBucketIterator(IteratorTest):
     # pylint: disable=protected-access
     def test_create_batches_groups_correctly(self):
         iterator = BucketIterator(batch_size=2, padding_noise=0, sorting_keys=[('text', 'num_tokens')])
+        iterator.index_with(self.vocab)
         batches = list(iterator._create_batches(self.instances, shuffle=False))
         grouped_instances = [batch.instances for batch in batches]
         assert grouped_instances == [[self.instances[4], self.instances[2]],

@@ -26,6 +27,7 @@ def test_create_batches_groups_correctly_with_max_instances(self):
                                   padding_noise=0,
                                   sorting_keys=[('text', 'num_tokens')],
                                   max_instances_in_memory=3)
+        iterator.index_with(self.vocab)
         for test_instances in (self.instances, self.lazy_instances):
             batches = list(iterator._create_batches(test_instances, shuffle=False))
             grouped_instances = [batch.instances for batch in batches]

@@ -38,6 +40,7 @@ def test_biggest_batch_first_works(self):
                                   padding_noise=0,
                                   sorting_keys=[('text', 'num_tokens')],
                                   biggest_batch_first=True)
+        iterator.index_with(self.vocab)
         batches = list(iterator._create_batches(self.instances, shuffle=False))
         grouped_instances = [batch.instances for batch in batches]
         assert grouped_instances == [[self.instances[3]],

@@ -79,6 +82,7 @@ def test_bucket_iterator_maximum_samples_per_batch(self):
                 sorting_keys=[('text', 'num_tokens')],
                 maximum_samples_per_batch=['num_tokens', 9]
         )
+        iterator.index_with(self.vocab)
         batches = list(iterator._create_batches(self.instances, shuffle=False))
 
         # ensure all instances are in a batch
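
The same requirement applies to BucketIterator: both sorting_keys and
maximum_samples_per_batch read per-field lengths, which exist only after indexing.
A sketch using the constructor arguments from these tests (again reusing the
instance and vocab placeholders from earlier):

    from allennlp.data.iterators import BucketIterator

    iterator = BucketIterator(batch_size=2,
                              padding_noise=0,
                              sorting_keys=[('text', 'num_tokens')],
                              maximum_samples_per_batch=['num_tokens', 9])
    iterator.index_with(vocab)  # without this there are no 'num_tokens' lengths to sort on
    batches = list(iterator([instance], num_epochs=1, shuffle=False))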


allennlp/tests/data/iterators/epoch_tracking_bucket_iterator_test.py (+1)

@@ -8,6 +8,7 @@ def setUp(self):
         # TextFields.
         super(EpochTrackingBucketIteratorTest, self).setUp()
         self.iterator = EpochTrackingBucketIterator(sorting_keys=[["text", "num_tokens"]])
+        self.iterator.index_with(self.vocab)
         # We'll add more to create a second dataset.
         self.more_instances = [
                 self.create_instance(["this", "is", "a", "sentence"]),
