
Commit f19c0ee

David Wadden authored and DeNeutoy committed
Enable Pruner class to keep different number of items for different entries in minibatch. (#2511)
This is a somewhat special case. There are occasionally situations where the entries in a minibatch need to be, e.g., ordered sentences from the same document; coreference resolution is an example. In AllenNLP, an entire document is treated as a single entry, but for very long documents (or on low-memory machines) it may be necessary to split a document into minibatches during training. In this situation smart batching can't be used, and entries in the same minibatch may have widely varying lengths, so keeping the same number of span candidates for each entry may not be desirable. This PR behaves exactly as before if an integer is passed to `num_items_to_keep`; if a tensor is passed instead, it keeps the desired number of items for each minibatch entry.
1 parent 3cdb7e2 commit f19c0ee
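
To illustrate the new calling convention, here is a minimal usage sketch (the Linear scorer, shapes, and values are hypothetical stand-ins, not part of this commit):

import torch
from allennlp.modules.pruner import Pruner

# Any module mapping (batch_size, num_items, embedding_size) to
# (batch_size, num_items, 1) can serve as the scorer.
scorer = torch.nn.Linear(5, 1)
pruner = Pruner(scorer=scorer)

embeddings = torch.randn(2, 4, 5)  # (batch_size, num_items, embedding_size)
mask = torch.ones(2, 4)            # all items are unpadded

# As before: keep the same number of items for every entry.
top_embeddings, top_mask, top_indices, top_scores = pruner(embeddings, mask, 2)

# New: keep a different number of items for each entry in the minibatch.
num_items_to_keep = torch.tensor([3, 1], dtype=torch.long)
top_embeddings, top_mask, top_indices, top_scores = pruner(embeddings, mask, num_items_to_keep)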

File tree

allennlp/modules/pruner.py
allennlp/tests/modules/pruner_test.py

2 files changed (+130 −22 lines)


allennlp/modules/pruner.py (+52 −21)
@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Tuple, Union
 
 from overrides import overrides
 import torch
@@ -25,13 +25,14 @@ def __init__(self, scorer: torch.nn.Module) -> None:
     def forward(self, # pylint: disable=arguments-differ
                 embeddings: torch.FloatTensor,
                 mask: torch.LongTensor,
-                num_items_to_keep: int) -> Tuple[torch.FloatTensor, torch.LongTensor,
-                                                 torch.LongTensor, torch.FloatTensor]:
+                num_items_to_keep: Union[int, torch.LongTensor]) -> Tuple[torch.FloatTensor, torch.LongTensor,
+                                                                          torch.LongTensor, torch.FloatTensor]:
         """
         Extracts the top-k scoring items with respect to the scorer. We additionally return
         the indices of the top-k in their original order, not ordered by score, so that downstream
         components can rely on the original ordering (e.g., for knowing what spans are valid
-        antecedents in a coreference resolution model).
+        antecedents in a coreference resolution model). May use the same k for all sentences in
+        minibatch, or different k for each.
 
         Parameters
         ----------
@@ -41,26 +42,37 @@ def forward(self, # pylint: disable=arguments-differ
         mask : ``torch.LongTensor``, required.
             A tensor of shape (batch_size, num_items), denoting unpadded elements of
             ``embeddings``.
-        num_items_to_keep : ``int``, required.
-            The number of items to keep when pruning.
+        num_items_to_keep : ``Union[int, torch.LongTensor]``, required.
+            If a tensor of shape (batch_size), specifies the number of items to keep for each
+            individual sentence in minibatch.
+            If an int, keep the same number of items for all sentences.
 
         Returns
         -------
         top_embeddings : ``torch.FloatTensor``
             The representations of the top-k scoring items.
-            Has shape (batch_size, num_items_to_keep, embedding_size).
+            Has shape (batch_size, max_num_items_to_keep, embedding_size).
         top_mask : ``torch.LongTensor``
             The corresponding mask for ``top_embeddings``.
-            Has shape (batch_size, num_items_to_keep).
+            Has shape (batch_size, max_num_items_to_keep).
         top_indices : ``torch.IntTensor``
             The indices of the top-k scoring items into the original ``embeddings``
             tensor. This is returned because it can be useful to retain pointers to
            the original items, if each item is being scored by multiple distinct
-            scorers, for instance. Has shape (batch_size, num_items_to_keep).
+            scorers, for instance. Has shape (batch_size, max_num_items_to_keep).
         top_item_scores : ``torch.FloatTensor``
             The values of the top-k scoring items.
-            Has shape (batch_size, num_items_to_keep, 1).
+            Has shape (batch_size, max_num_items_to_keep, 1).
         """
+        # If an int was given for number of items to keep, construct tensor by repeating the value.
+        if isinstance(num_items_to_keep, int):
+            batch_size = mask.size(0)
+            # Put the tensor on same device as the mask.
+            num_items_to_keep = num_items_to_keep * torch.ones([batch_size], dtype=torch.long,
+                                                               device=mask.device)
+
+        max_items_to_keep = num_items_to_keep.max()
+
         mask = mask.unsqueeze(-1)
         num_items = embeddings.size(1)
         # Shape: (batch_size, num_items, 1)
@@ -73,28 +85,47 @@ def forward(self, # pylint: disable=arguments-differ
         # negative. These are logits, typically, so -1e20 should be plenty negative.
         scores = util.replace_masked_values(scores, mask, -1e20)
 
-        # Shape: (batch_size, num_items_to_keep, 1)
-        _, top_indices = scores.topk(num_items_to_keep, 1)
+        # Shape: (batch_size, max_num_items_to_keep, 1)
+        _, top_indices = scores.topk(max_items_to_keep, 1)
+
+        # Mask based on number of items to keep for each sentence.
+        # Shape: (batch_size, max_num_items_to_keep)
+        top_indices_mask = util.get_mask_from_sequence_lengths(num_items_to_keep, max_items_to_keep)
+        top_indices_mask = top_indices_mask.byte()
+
+        # Shape: (batch_size, max_num_items_to_keep)
+        top_indices = top_indices.squeeze(-1)
+
+        # Fill all masked indices with largest "top" index for that sentence, so that all masked
+        # indices will be sorted to the end.
+        # Shape: (batch_size, 1)
+        fill_value, _ = top_indices.max(dim=1)
+        fill_value = fill_value.unsqueeze(-1)
+        # Shape: (batch_size, max_num_items_to_keep)
+        top_indices = torch.where(top_indices_mask, top_indices, fill_value)
 
         # Now we order the selected indices in increasing order with
         # respect to their indices (and hence, with respect to the
         # order they originally appeared in the ``embeddings`` tensor).
         top_indices, _ = torch.sort(top_indices, 1)
 
-        # Shape: (batch_size, num_items_to_keep)
-        top_indices = top_indices.squeeze(-1)
-
-        # Shape: (batch_size * num_items_to_keep)
+        # Shape: (batch_size * max_num_items_to_keep)
         # torch.index_select only accepts 1D indices, but here
         # we need to select items for each element in the batch.
         flat_top_indices = util.flatten_and_batch_shift_indices(top_indices, num_items)
 
-        # Shape: (batch_size, num_items_to_keep, embedding_size)
+        # Shape: (batch_size, max_num_items_to_keep, embedding_size)
         top_embeddings = util.batched_index_select(embeddings, top_indices, flat_top_indices)
-        # Shape: (batch_size, num_items_to_keep)
-        top_mask = util.batched_index_select(mask, top_indices, flat_top_indices)
 
-        # Shape: (batch_size, num_items_to_keep, 1)
+        # Combine the masks on spans that are out-of-bounds, and the mask on spans that are outside
+        # the top k for each sentence.
+        # Shape: (batch_size, max_num_items_to_keep)
+        sequence_mask = util.batched_index_select(mask, top_indices, flat_top_indices)
+        sequence_mask = sequence_mask.squeeze(-1).byte()
+        top_mask = top_indices_mask & sequence_mask
+        top_mask = top_mask.long()
+
+        # Shape: (batch_size, max_num_items_to_keep, 1)
         top_scores = util.batched_index_select(scores, top_indices, flat_top_indices)
 
-        return top_embeddings, top_mask.squeeze(-1), top_indices, top_scores
+        return top_embeddings, top_mask, top_indices, top_scores
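
To make the per-entry masking above concrete, here is a small standalone sketch of the fill-and-sort trick (plain PyTorch on hypothetical toy tensors; the arange comparison stands in for util.get_mask_from_sequence_lengths):

import torch

# Suppose topk(max_k) ran over the whole batch with max_k = 2, but entry 0
# requested 2 items while entry 1 requested only 1.
top_indices = torch.tensor([[3, 0],
                            [2, 1]])
num_items_to_keep = torch.tensor([2, 1])
max_k = 2

# Lengths -> mask: entry 1's second slot exceeds its requested k.
positions = torch.arange(max_k).unsqueeze(0)                   # (1, max_k)
top_indices_mask = positions < num_items_to_keep.unsqueeze(1)  # (batch_size, max_k)

# Replace invalid slots with the row's largest kept index, so that after
# sorting they collect at the end instead of scattering among valid indices.
fill_value, _ = top_indices.max(dim=1, keepdim=True)           # (batch_size, 1)
top_indices = torch.where(top_indices_mask, top_indices, fill_value)
top_indices, _ = torch.sort(top_indices, dim=1)
print(top_indices)  # tensor([[0, 3], [2, 2]]) -- row 1 repeats its kept index

The repeated positions are then zeroed out in top_mask, which is why the tests below expect duplicated indices alongside zeros in the mask.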

allennlp/tests/modules/pruner_test.py (+78 −1)
@@ -1,4 +1,4 @@
-# pylint: disable=no-self-use,invalid-name
+# pylint: disable=no-self-use,invalid-name,not-callable
 import numpy
 import pytest
 import torch
@@ -83,3 +83,80 @@ def test_scorer_works_for_completely_masked_rows(self):
         numpy.testing.assert_array_equal(correct_scores[:2], pruned_scores[:2].data.numpy())
         numpy.testing.assert_array_equal(pruned_scores[2] < -1e15, [[1], [1]])
         numpy.testing.assert_array_equal(pruned_scores[2] == float('-inf'), [[0], [0]])
+
+    def test_pruner_selects_top_scored_items_and_respects_masking_different_num_items(self):
+        # Really simple scorer - sum up the embedding_dim.
+        scorer = lambda tensor: tensor.sum(-1).unsqueeze(-1)
+        pruner = Pruner(scorer=scorer)
+
+        items = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
+        items[0, 0, :] = 1.5
+        items[0, 1, :] = 2
+        items[0, 3, :] = 1
+        items[1, 1:3, :] = 1
+        items[2, 0, :] = 1
+        items[2, 1, :] = 2
+        items[2, 2, :] = 1.5
+
+        mask = torch.ones([3, 4])
+        mask[1, 3] = 0
+
+        num_items_to_keep = torch.tensor([3, 2, 1], dtype=torch.long)
+
+        pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(
+                items, mask, num_items_to_keep)
+
+        # Second element in the batch would have indices 2, 3, but
+        # 3 and 0 are masked, so instead it has 1, 2.
+        numpy.testing.assert_array_equal(pruned_indices.data.numpy(), numpy.array([[0, 1, 3],
+                                                                                   [1, 2, 2],
+                                                                                   [1, 2, 2]]))
+        numpy.testing.assert_array_equal(pruned_mask.data.numpy(), numpy.array([[1, 1, 1],
+                                                                                [1, 1, 0],
+                                                                                [1, 0, 0]]))
+
+        # embeddings should be the result of index_selecting the pruned_indices.
+        correct_embeddings = batched_index_select(items, pruned_indices)
+        numpy.testing.assert_array_equal(correct_embeddings.data.numpy(),
+                                         pruned_embeddings.data.numpy())
+        # scores should be the sum of the correct embedding elements.
+        numpy.testing.assert_array_equal(correct_embeddings.sum(-1).unsqueeze(-1).data.numpy(),
+                                         pruned_scores.data.numpy())
+
+    def test_pruner_works_for_row_with_no_items_requested(self):
+        # Case where `num_items_to_keep` is a tensor rather than an int. Make sure it does the right
+        # thing when no items are requested for one of the rows.
+        scorer = lambda tensor: tensor.sum(-1).unsqueeze(-1)
+        pruner = Pruner(scorer=scorer)
+
+        items = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
+        items[0, :3, :] = 1
+        items[1, 2:, :] = 1
+        items[2, 2:, :] = 1
+
+        mask = torch.ones([3, 4])
+        mask[1, 0] = 0
+        mask[1, 3] = 0
+
+        num_items_to_keep = torch.tensor([3, 2, 0], dtype=torch.long)
+
+        pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(
+                items, mask, num_items_to_keep)
+
+        # First element just picks top three entries. Second would pick entries 2 and 3, but 0 and 3
+        # are masked, so it takes 1 and 2 (repeating the second index). The third element is
+        # entirely masked and just repeats the largest index with a top-3 score.
+        numpy.testing.assert_array_equal(pruned_indices.data.numpy(), numpy.array([[0, 1, 2],
+                                                                                   [1, 2, 2],
+                                                                                   [3, 3, 3]]))
+        numpy.testing.assert_array_equal(pruned_mask.data.numpy(), numpy.array([[1, 1, 1],
+                                                                                [1, 1, 0],
+                                                                                [0, 0, 0]]))
+
+        # embeddings should be the result of index_selecting the pruned_indices.
+        correct_embeddings = batched_index_select(items, pruned_indices)
+        numpy.testing.assert_array_equal(correct_embeddings.data.numpy(),
+                                         pruned_embeddings.data.numpy())
+        # scores should be the sum of the correct embedding elements.
+        numpy.testing.assert_array_equal(correct_embeddings.sum(-1).unsqueeze(-1).data.numpy(),
+                                         pruned_scores.data.numpy())
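
As a usage note, the per-entry k will typically be derived from each entry's unpadded length, which is the varying-length situation that motivated this change. A minimal sketch, assuming a hypothetical spans_per_word pruning ratio:

import torch

mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 0, 0]])  # (batch_size, num_items)
spans_per_word = 0.4                 # hypothetical ratio of candidates to keep

num_items = mask.sum(dim=1).float()  # unpadded items per entry: [4., 2.]
num_items_to_keep = torch.clamp((spans_per_word * num_items).long(), min=1)
# num_items_to_keep is tensor([1, 1]); pass it as pruner(items, mask, num_items_to_keep)

Scaling k with each entry's true length avoids keeping (and padding) a uniform number of candidates when lengths within a minibatch vary widely.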
