@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Tuple, Union
 
 from overrides import overrides
 import torch
@@ -25,13 +25,14 @@ def __init__(self, scorer: torch.nn.Module) -> None:
     def forward(self, # pylint: disable=arguments-differ
                 embeddings: torch.FloatTensor,
                 mask: torch.LongTensor,
-                num_items_to_keep: int) -> Tuple[torch.FloatTensor, torch.LongTensor,
-                                                 torch.LongTensor, torch.FloatTensor]:
+                num_items_to_keep: Union[int, torch.LongTensor]) -> Tuple[torch.FloatTensor, torch.LongTensor,
+                                                                          torch.LongTensor, torch.FloatTensor]:
         """
         Extracts the top-k scoring items with respect to the scorer. We additionally return
         the indices of the top-k in their original order, not ordered by score, so that downstream
         components can rely on the original ordering (e.g., for knowing what spans are valid
-        antecedents in a coreference resolution model).
+        antecedents in a coreference resolution model). May use the same k for all sentences in
+        the minibatch, or a different k for each.
 
         Parameters
         ----------
@@ -41,26 +42,37 @@ def forward(self, # pylint: disable=arguments-differ
         mask : ``torch.LongTensor``, required.
             A tensor of shape (batch_size, num_items), denoting unpadded elements of
             ``embeddings``.
-        num_items_to_keep : ``int``, required.
-            The number of items to keep when pruning.
+        num_items_to_keep : ``Union[int, torch.LongTensor]``, required.
+            If a tensor of shape (batch_size), specifies the number of items to keep for each
+            individual sentence in the minibatch.
+            If an int, keep the same number of items for all sentences.
 
         Returns
         -------
         top_embeddings : ``torch.FloatTensor``
             The representations of the top-k scoring items.
-            Has shape (batch_size, num_items_to_keep, embedding_size).
+            Has shape (batch_size, max_num_items_to_keep, embedding_size).
         top_mask : ``torch.LongTensor``
             The corresponding mask for ``top_embeddings``.
-            Has shape (batch_size, num_items_to_keep).
+            Has shape (batch_size, max_num_items_to_keep).
         top_indices : ``torch.IntTensor``
             The indices of the top-k scoring items into the original ``embeddings``
             tensor. This is returned because it can be useful to retain pointers to
             the original items, if each item is being scored by multiple distinct
-            scorers, for instance. Has shape (batch_size, num_items_to_keep).
+            scorers, for instance. Has shape (batch_size, max_num_items_to_keep).
         top_item_scores : ``torch.FloatTensor``
             The values of the top-k scoring items.
-            Has shape (batch_size, num_items_to_keep, 1).
+            Has shape (batch_size, max_num_items_to_keep, 1).
         """
+        # If an int was given for the number of items to keep, construct a tensor by repeating it.
+        if isinstance(num_items_to_keep, int):
+            batch_size = mask.size(0)
+            # Put the tensor on the same device as the mask.
+            num_items_to_keep = num_items_to_keep * torch.ones([batch_size], dtype=torch.long,
+                                                               device=mask.device)
+
+        max_items_to_keep = num_items_to_keep.max()
+
         mask = mask.unsqueeze(-1)
         num_items = embeddings.size(1)
         # Shape: (batch_size, num_items, 1)
@@ -73,28 +85,47 @@ def forward(self, # pylint: disable=arguments-differ
         # negative. These are logits, typically, so -1e20 should be plenty negative.
         scores = util.replace_masked_values(scores, mask, -1e20)
 
-        # Shape: (batch_size, num_items_to_keep, 1)
-        _, top_indices = scores.topk(num_items_to_keep, 1)
+        # Shape: (batch_size, max_num_items_to_keep, 1)
+        _, top_indices = scores.topk(max_items_to_keep, 1)
+
+        # Mask based on the number of items to keep for each sentence.
+        # Shape: (batch_size, max_num_items_to_keep)
+        top_indices_mask = util.get_mask_from_sequence_lengths(num_items_to_keep, max_items_to_keep)
+        top_indices_mask = top_indices_mask.byte()
+
+        # Shape: (batch_size, max_num_items_to_keep)
+        top_indices = top_indices.squeeze(-1)
+
+        # Fill all masked indices with the largest "top" index for that sentence, so that all
+        # masked indices will be sorted to the end.
+        # Shape: (batch_size, 1)
+        fill_value, _ = top_indices.max(dim=1)
+        fill_value = fill_value.unsqueeze(-1)
+        # Shape: (batch_size, max_num_items_to_keep)
+        top_indices = torch.where(top_indices_mask, top_indices, fill_value)
 
         # Now we order the selected indices in increasing order with
         # respect to their indices (and hence, with respect to the
         # order they originally appeared in the ``embeddings`` tensor).
         top_indices, _ = torch.sort(top_indices, 1)
 
-        # Shape: (batch_size, num_items_to_keep)
-        top_indices = top_indices.squeeze(-1)
-
-        # Shape: (batch_size * num_items_to_keep)
+        # Shape: (batch_size * max_num_items_to_keep)
         # torch.index_select only accepts 1D indices, but here
         # we need to select items for each element in the batch.
         flat_top_indices = util.flatten_and_batch_shift_indices(top_indices, num_items)
 
-        # Shape: (batch_size, num_items_to_keep, embedding_size)
+        # Shape: (batch_size, max_num_items_to_keep, embedding_size)
         top_embeddings = util.batched_index_select(embeddings, top_indices, flat_top_indices)
-        # Shape: (batch_size, num_items_to_keep)
-        top_mask = util.batched_index_select(mask, top_indices, flat_top_indices)
 
-        # Shape: (batch_size, num_items_to_keep, 1)
+        # Combine the mask on spans that are out-of-bounds with the mask on spans that fall
+        # outside the top k for each sentence.
+        # Shape: (batch_size, max_num_items_to_keep)
+        sequence_mask = util.batched_index_select(mask, top_indices, flat_top_indices)
+        sequence_mask = sequence_mask.squeeze(-1).byte()
+        top_mask = top_indices_mask & sequence_mask
+        top_mask = top_mask.long()
+
+        # Shape: (batch_size, max_num_items_to_keep, 1)
         top_scores = util.batched_index_select(scores, top_indices, flat_top_indices)
 
-        return top_embeddings, top_mask.squeeze(-1), top_indices, top_scores
+        return top_embeddings, top_mask, top_indices, top_scores
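
The snippet below is a minimal usage sketch of the change above: the same pruner can now be called either with a single int (the existing behaviour) or with a per-sentence LongTensor of keep counts. The Linear scorer, tensor sizes, and the import path allennlp.modules.pruner are illustrative assumptions, not part of the diff.

# Usage sketch (illustrative); scorer, sizes, and import path are assumptions for demonstration.
import torch

from allennlp.modules.pruner import Pruner

# A trivial scorer mapping each item embedding to a single score, shape (batch_size, num_items, 1).
scorer = torch.nn.Linear(10, 1)
pruner = Pruner(scorer)

batch_size, num_items, embedding_size = 2, 7, 10
embeddings = torch.randn(batch_size, num_items, embedding_size)
mask = torch.ones(batch_size, num_items, dtype=torch.long)

# Existing behaviour, still supported: keep the same number of items for every sentence.
top_embeddings, top_mask, top_indices, top_scores = pruner(embeddings, mask, 3)

# New behaviour: keep a different number of items for each sentence in the minibatch.
num_items_to_keep = torch.tensor([2, 4], dtype=torch.long)
top_embeddings, top_mask, top_indices, top_scores = pruner(embeddings, mask, num_items_to_keep)

# Outputs are padded to the largest k in the batch (here 4); top_mask marks the padded slots.
print(top_embeddings.shape)  # torch.Size([2, 4, 10])
print(top_mask)              # [[1, 1, 0, 0], [1, 1, 1, 1]]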