This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Removes confusing zero mask from VilBERT #5264

Merged: 5 commits, Jun 17, 2021
Changes from 4 commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed Broken link in `allennlp.fairness.fairness_metrics.Separation` docs
- Ensured all `allennlp` submodules are imported with `allennlp.common.plugins.import_plugins()`.
- Removed confusing zero mask from VilBERT


## [v2.5.0](https://github.com/allenai/allennlp/releases/tag/v2.5.0) - 2021-06-03
11 changes: 0 additions & 11 deletions allennlp/modules/backbones/vilbert_backbone.py
@@ -156,7 +156,6 @@ def forward(

# Shape: (rolled_dimensions_product, num_tokens, embedding_dim)
embedding_output = self.text_embeddings(token_ids, token_type_ids)
num_tokens = embedding_output.size(1)

# this attention mask is more simple than the triangular masking of
# causal attention used in OpenAI GPT, we just need to prepare the
@@ -168,15 +167,6 @@

extended_image_attention_mask = box_mask

# Shape: (rolled_dimensions_product, feature_size, num_tokens)
# TODO (epwalsh): Why all zeros?? This doesn't seem right.
extended_co_attention_mask = torch.zeros(
extended_image_attention_mask.shape[0],
feature_size,
num_tokens,
dtype=extended_image_attention_mask.dtype,
)

# Shape: (rolled_dimensions_product, num_boxes, image_embedding_dim)
v_embedding_output = self.image_embeddings(box_features, box_coordinates)

@@ -185,7 +175,6 @@
v_embedding_output,
extended_attention_mask,
extended_image_attention_mask,
extended_co_attention_mask,
)

# Shape: (rolled_dimensions_product, num_tokens, embedding_dim)
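Note on the hunks above: the deleted `extended_co_attention_mask` was an all-zeros placeholder that never influenced attention, because `BimodalEncoder.forward` hardcoded `use_co_attention_mask = False` (see the `bimodal_encoder.py` hunks below). A minimal sketch of the before/after pattern, with names and shapes assumed for illustration rather than taken from the library:

```python
import torch

# Illustrative shapes only (assumed, not the library's actual values).
batch_size, num_boxes, num_tokens = 2, 36, 16
box_mask = torch.ones(batch_size, num_boxes)

# Old pattern: allocate an all-zeros co-attention mask that nothing consumes,
# since the downstream flag that would apply it was hardcoded to False.
extended_co_attention_mask = torch.zeros(
    box_mask.shape[0], num_boxes, num_tokens, dtype=box_mask.dtype
)

# New pattern: pass nothing; downstream layers apply a co-attention mask only
# when the caller actually provides one (see the `is not None` checks below).
co_attention_mask = None
```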
9 changes: 3 additions & 6 deletions allennlp/modules/transformer/bimodal_attention.py
@@ -118,8 +118,7 @@ def forward(
input_tensor2,
attention_mask1=None,
attention_mask2=None,
co_attention_mask=None, # TODO: is this flag necessary?
use_co_attention_mask=False,
co_attention_mask=None,
):
"""
# Parameters
@@ -144,8 +143,6 @@
about the interaction between the two modalities. For example,
if you know which words correspond to which regions in the image,
this mask can be applied to limit the attention given the bias.
use_co_attention_mask : `bool`
Whether to use co_attention_mask or not, default = `False`.
"""

# for the first modality:
@@ -170,7 +167,7 @@
attention_scores1 = self.attn1(query_layer2, key_layer1.transpose(-1, -2))
if attention_mask1 is not None:
attention_scores1 = apply_mask(attention_scores1, attention_mask1)
if use_co_attention_mask:
if co_attention_mask is not None:
attention_scores1 = apply_mask(attention_scores1, co_attention_mask.permute(0, 1, 3, 2))

attention_probs1 = torch.nn.Softmax(dim=-1)(attention_scores1)
@@ -189,7 +186,7 @@
# we can comment this line for single flow.
if attention_mask2 is not None:
attention_scores2 = apply_mask(attention_scores2, attention_mask2)
if use_co_attention_mask:
if co_attention_mask is not None:
attention_scores2 = apply_mask(attention_scores2, co_attention_mask)

attention_probs2 = torch.nn.Softmax(dim=-1)(attention_scores2)
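With the `use_co_attention_mask` flag gone, the mask is applied exactly when the caller supplies one. A self-contained sketch of that pattern, using an assumed additive-masking convention (not necessarily the exact behavior of the library's `apply_mask`) and assumed shapes:

```python
import torch
from typing import Optional

def apply_additive_mask(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Assumed convention: positions where mask == 0 receive a large negative
    # bias before the softmax, so they get (near-)zero attention weight.
    return scores + (1.0 - mask.to(scores.dtype)) * torch.finfo(scores.dtype).min

def cross_attention_probs(
    queries: torch.Tensor,                             # (batch, heads, num_regions, dim)
    keys: torch.Tensor,                                # (batch, heads, num_words, dim)
    co_attention_mask: Optional[torch.Tensor] = None,  # (batch, 1, num_regions, num_words)
) -> torch.Tensor:
    scores = queries @ keys.transpose(-1, -2)          # (batch, heads, num_regions, num_words)
    # No boolean flag: the mask is applied iff one is passed.
    if co_attention_mask is not None:
        scores = apply_additive_mask(scores, co_attention_mask)
    return torch.nn.Softmax(dim=-1)(scores)

q, k = torch.randn(2, 1, 4, 8), torch.randn(2, 1, 6, 8)
print(cross_attention_probs(q, k).shape)           # torch.Size([2, 1, 4, 6])
mask = torch.randint(0, 2, (2, 1, 4, 6)).float()
print(cross_attention_probs(q, k, mask).shape)     # torch.Size([2, 1, 4, 6])
```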
2 changes: 0 additions & 2 deletions allennlp/modules/transformer/bimodal_connection_layer.py
@@ -92,7 +92,6 @@ def forward(
input_tensor2,
attention_mask2,
co_attention_mask=None,
use_co_attention_mask=False,
):

bi_output1, bi_output2 = self.bimodal_attention(
@@ -101,7 +100,6 @@
attention_mask1,
attention_mask2,
co_attention_mask,
use_co_attention_mask,
)

attention_output1, attention_output2 = self.bimodal_output(
15 changes: 7 additions & 8 deletions allennlp/modules/transformer/bimodal_encoder.py
@@ -140,7 +140,6 @@ def forward(
batch_size, num_words, hidden_size1 = embedding1.size()
_, num_regions, hidden_size2 = embedding2.size()

use_co_attention_mask = False
for layer_id2, layer_id1 in zip(self.biattention_id2, self.biattention_id1):
end1 = layer_id1
end2 = layer_id2
@@ -191,12 +190,13 @@
.contiguous()
.view(batch_size * batch_size, 1, 1, num_words)
)
co_attention_mask = (
co_attention_mask.unsqueeze(1)
.expand(batch_size, batch_size, 1, num_regions, num_words)
.contiguous()
.view(batch_size * batch_size, 1, num_regions, num_words)
)
if co_attention_mask is not None:
co_attention_mask = (
co_attention_mask.unsqueeze(1)
.expand(batch_size, batch_size, 1, num_regions, num_words)
.contiguous()
.view(batch_size * batch_size, 1, num_regions, num_words)
)

if count == 0 and self.FAST_MODE:
embedding1 = embedding1.expand(
@@ -218,7 +218,6 @@
embedding2,
attention_mask2,
co_attention_mask,
use_co_attention_mask,
)

start2 = end2
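For reference, the reshape that is now guarded by `if co_attention_mask is not None:` can be checked in isolation. The input shape below is assumed from the surrounding diff (matching the 3-D mask the backbone used to allocate), not taken from documentation:

```python
import torch

# Assumed shape: (batch_size, num_regions, num_words). The guarded code above
# replicates the mask batch_size times so the flattened result has
# batch_size * batch_size rows, and skips all of this when the mask is None.
batch_size, num_regions, num_words = 2, 4, 6
co_attention_mask = torch.randint(0, 2, (batch_size, num_regions, num_words)).float()

tiled = (
    co_attention_mask.unsqueeze(1)
    .expand(batch_size, batch_size, 1, num_regions, num_words)
    .contiguous()
    .view(batch_size * batch_size, 1, num_regions, num_words)
)
print(tiled.shape)  # torch.Size([4, 1, 4, 6])
```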