From 5a52fd15b49f2d740527c37e47adc7e41e66571d Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Wed, 25 Nov 2020 17:25:05 -0800 Subject: [PATCH 1/4] VE model code --- allennlp/models/vilbert_vqa.py | 209 +++----------- allennlp/models/vision_text_model.py | 256 ++++++++++++++++++ allennlp/models/visual_entailment.py | 107 ++++++++ .../experiment_from_huggingface.jsonnet | 8 +- 4 files changed, 399 insertions(+), 181 deletions(-) create mode 100644 allennlp/models/vision_text_model.py create mode 100644 allennlp/models/visual_entailment.py diff --git a/allennlp/models/vilbert_vqa.py b/allennlp/models/vilbert_vqa.py index 8e21718ded3..414082cd655 100644 --- a/allennlp/models/vilbert_vqa.py +++ b/allennlp/models/vilbert_vqa.py @@ -1,7 +1,6 @@ import collections import logging -from copy import deepcopy -from typing import Dict, List, Optional +from typing import Dict, Optional from overrides import overrides import torch @@ -12,24 +11,31 @@ TextEmbeddings, ImageFeatureEmbeddings, BiModalEncoder, - TransformerPooler, ) from allennlp.nn import util -from transformers.modeling_auto import AutoModel +from allennlp.models.vision_text_model import VisionTextModel + logger = logging.getLogger(__name__) @Model.register("vqa_vilbert") @Model.register("vqa_vilbert_from_huggingface", constructor="from_huggingface_model_name") -class VqaVilbert(Model): +class VqaVilbert(VisionTextModel): """ Model for VQA task based on the VilBERT paper. # Parameters vocab : `Vocabulary` + text_embeddings : `TextEmbeddings` + image_embeddings : `ImageFeatureEmbeddings` + encoder : `BiModalEncoder` + pooled_output_dim : `int` + fusion_method : `str`, optional (default = `"sum"`) + dropout : `float`, optional (default = `0.1`) + label_namespace : `str`, optional (default = `answers`) """ def __init__( @@ -43,7 +49,17 @@ def __init__( dropout: float = 0.1, label_namespace: str = "answers", ) -> None: - super().__init__(vocab) + super().__init__( + vocab, + text_embeddings, + image_embeddings, + encoder, + pooled_output_dim, + fusion_method, + dropout, + label_namespace, + ) + self.loss = torch.nn.BCELoss() self.consistency_wrong_map: Dict[str, int] = collections.Counter() from allennlp.training.metrics import F1MultiLabelMeasure @@ -52,114 +68,6 @@ def __init__( from allennlp.training.metrics.vqa import VqaMeasure self.vqa_metric = VqaMeasure() - self.fusion_method = fusion_method - - self.embeddings = text_embeddings - self.image_embeddings = image_embeddings - self.encoder = encoder - - self.t_pooler = TransformerPooler(encoder.hidden_size1, pooled_output_dim) - self.v_pooler = TransformerPooler(encoder.hidden_size2, pooled_output_dim) - - num_labels = vocab.get_vocab_size(label_namespace) - self.label_namespace = label_namespace - - self.classifier = torch.nn.Linear(pooled_output_dim, num_labels) - self.dropout = torch.nn.Dropout(dropout) - - @classmethod - def from_huggingface_model_name( - cls, - vocab: Vocabulary, - model_name: str, - image_feature_dim: int, - image_num_hidden_layers: int, - image_hidden_size: int, - image_num_attention_heads: int, - image_intermediate_size: int, - image_attention_dropout: float, - image_hidden_dropout: float, - image_biattention_id: List[int], - image_fixed_layer: int, - text_biattention_id: List[int], - text_fixed_layer: int, - combined_hidden_size: int, - combined_num_attention_heads: int, - pooled_output_dim: int, - pooled_dropout: float = 0.1, - fusion_method: str = "sum", - ): - transformer = AutoModel.from_pretrained(model_name) - - # TODO(mattg): This call to 
`transformer.embeddings` works with some transformers, but I'm - # not sure it works for all of them, or what to do if it fails. - # We should probably pull everything up until the instantiation of the image feature - # embedding out into a central "transformers_util" module, or something, and just have a - # method that pulls an initialized embedding layer out of a huggingface model. One place - # for this somewhat hacky code to live, instead of having to duplicate it in various models. - text_embeddings = deepcopy(transformer.embeddings) - - # Albert (and maybe others?) has this "embedding_size", that's different from "hidden_size". - # To get them to the same dimensionality, it uses a linear transform after the embedding - # layer, which we need to pull out and copy here. - if hasattr(transformer.config, "embedding_size"): - config = transformer.config - - from transformers.modeling_albert import AlbertModel - - if isinstance(transformer, AlbertModel): - linear_transform = deepcopy(transformer.encoder.embedding_hidden_mapping_in) - else: - logger.warning( - "Unknown model that uses separate embedding size; weights of the linear " - f"transform will not be initialized. Model type is: {transformer.__class__}" - ) - linear_transform = torch.nn.Linear(config.embedding_dim, config.hidden_dim) - - # We can't just use torch.nn.Sequential here, even though that's basically all this is, - # because Sequential doesn't accept *inputs, only a single argument. - - class EmbeddingsShim(torch.nn.Module): - def __init__(self, embeddings: torch.nn.Module, linear_transform: torch.nn.Module): - super().__init__() - self.linear_transform = linear_transform - self.embeddings = embeddings - - def forward(self, *inputs, **kwargs): - return self.linear_transform(self.embeddings(*inputs, **kwargs)) - - text_embeddings = EmbeddingsShim(text_embeddings, linear_transform) - - image_embeddings = ImageFeatureEmbeddings( - feature_dim=image_feature_dim, - hidden_dim=image_hidden_size, - dropout=image_hidden_dropout, - ) - - encoder = BiModalEncoder.from_pretrained_module( - pretrained_module=transformer, - num_hidden_layers2=image_num_hidden_layers, - hidden_size2=image_hidden_size, - num_attention_heads2=image_num_attention_heads, - combined_hidden_size=combined_hidden_size, - combined_num_attention_heads=combined_num_attention_heads, - intermediate_size2=image_intermediate_size, - attention_dropout2=image_attention_dropout, - hidden_dropout2=image_hidden_dropout, - biattention_id1=text_biattention_id, - biattention_id2=image_biattention_id, - fixed_layer1=text_fixed_layer, - fixed_layer2=image_fixed_layer, - ) - return cls( - vocab=vocab, - text_embeddings=text_embeddings, - image_embeddings=image_embeddings, - encoder=encoder, - pooled_output_dim=pooled_output_dim, - fusion_method=fusion_method, - dropout=pooled_dropout, - ) @overrides def forward( @@ -171,73 +79,19 @@ def forward( label_weights: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: - batch_size, _, feature_size = box_features.size() - - # TODO(mattg): have this make fewer assumptions. - input_ids = question["tokens"]["token_ids"] - token_type_ids = question["tokens"]["type_ids"] - attention_mask = question["tokens"]["mask"] - - # All batch instances will always have the same number of images and boxes, so no masking - # is necessary, and this is just a tensor of ones. 
- image_attention_mask = torch.ones_like(box_coordinates[:, :, 0]) - - # (batch_size, num_tokens, embedding_dim) - embedding_output = self.embeddings(input_ids, token_type_ids) - num_tokens = embedding_output.size(1) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of - # causal attention used in OpenAI GPT, we just need to prepare the - # broadcast dimension here. - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2).float().log() - extended_image_attention_mask = image_attention_mask.unsqueeze(1).unsqueeze(2).float().log() - - # TODO(matt): it looks like the co-attention logic is all currently commented out; not sure - # that this is necessary. - extended_co_attention_mask = torch.zeros( - batch_size, - feature_size, - num_tokens, - dtype=extended_image_attention_mask.dtype, + return super().forward( + box_features, box_coordinates, text=question, label=labels, label_weights=label_weights ) - # (batch_size, num_boxes, image_embedding_dim) - v_embedding_output = self.image_embeddings(box_features, box_coordinates) - - encoded_layers_t, encoded_layers_v = self.encoder( - embedding_output, - v_embedding_output, - extended_attention_mask, - extended_image_attention_mask, - extended_co_attention_mask, - ) - - sequence_output_t = encoded_layers_t[:, :, :, -1] - sequence_output_v = encoded_layers_v[:, :, :, -1] - - pooled_output_t = self.t_pooler(sequence_output_t) - pooled_output_v = self.v_pooler(sequence_output_v) - - if self.fusion_method == "sum": - pooled_output = self.dropout(pooled_output_t + pooled_output_v) - elif self.fusion_method == "mul": - pooled_output = self.dropout(pooled_output_t * pooled_output_v) - else: - raise ValueError(f"Fusion method '{self.fusion_method}' not supported") - - logits = self.classifier(pooled_output) - probs = torch.sigmoid(logits) - - outputs = {"logits": logits, "probs": probs} - if labels is not None and label_weights is not None: - label_mask = labels > 1 # 0 is padding, 1 is OOV, which we want to ignore + @overrides + def _compute_loss_and_metrics(self, batch_size, outputs, label, label_weights): + if label is not None and label_weights is not None: + logits = outputs["logits"] + label_mask = label > 1 # 0 is padding, 1 is OOV, which we want to ignore weighted_labels = util.masked_index_replace( logits.new_zeros(logits.size() + (1,)), - labels.clamp(min=0), + label.clamp(min=0), label_mask, label_weights.unsqueeze(-1), ).squeeze(-1) @@ -258,7 +112,8 @@ def forward( ) self.f1_metric(logits, weighted_labels, binary_label_mask.bool()) - self.vqa_metric(logits, labels, label_weights) + self.vqa_metric(logits, label, label_weights) + return outputs @overrides diff --git a/allennlp/models/vision_text_model.py b/allennlp/models/vision_text_model.py new file mode 100644 index 00000000000..d441e1cd354 --- /dev/null +++ b/allennlp/models/vision_text_model.py @@ -0,0 +1,256 @@ +import logging +from copy import deepcopy +from typing import Dict, List, Optional + +from overrides import overrides +import numpy as np +import torch + +from allennlp.data import TextFieldTensors, Vocabulary +from allennlp.models.model import Model +from allennlp.modules.transformer import ( + TextEmbeddings, + ImageFeatureEmbeddings, + BiModalEncoder, + TransformerPooler, +) + +from transformers.modeling_auto import AutoModel + +logger = logging.getLogger(__name__) 
+ + +@Model.register("vision_model") +class VisionTextModel(Model): + """ + `VisionTextModel` takes as input a single text input and a single image input + to produce some output. Example tasks include visual question-answering, visual + entailment, etc. + + # Parameters + + vocab : `Vocabulary` + text_embeddings : `TextEmbeddings` + image_embeddings : `ImageFeatureEmbeddings` + encoder : `BiModalEncoder` + pooled_output_dim : `int` + fusion_method : `str`, optional (default = `"sum"`) + dropout : `float`, optional (default = `0.1`) + label_namespace : `str`, optional (default = `"labels"`) + """ + + def __init__( + self, + vocab: Vocabulary, + text_embeddings: TextEmbeddings, + image_embeddings: ImageFeatureEmbeddings, + encoder: BiModalEncoder, + pooled_output_dim: int, + fusion_method: str = "sum", + dropout: float = 0.1, + label_namespace: str = "labels", + ) -> None: + + super().__init__(vocab) + + self.fusion_method = fusion_method + + self.embeddings = text_embeddings + self.image_embeddings = image_embeddings + self.encoder = encoder + + self.t_pooler = TransformerPooler(encoder.hidden_size1, pooled_output_dim) + self.v_pooler = TransformerPooler(encoder.hidden_size2, pooled_output_dim) + + num_labels = vocab.get_vocab_size(label_namespace) + self.label_namespace = label_namespace + + self.classifier = torch.nn.Linear(pooled_output_dim, num_labels) + self.dropout = torch.nn.Dropout(dropout) + + @classmethod + def from_huggingface_model_name( + cls, + vocab: Vocabulary, + model_name: str, + image_feature_dim: int, + image_num_hidden_layers: int, + image_hidden_size: int, + image_num_attention_heads: int, + combined_hidden_size: int, + combined_num_attention_heads: int, + pooled_output_dim: int, + image_intermediate_size: int, + image_attention_dropout: float, + image_hidden_dropout: float, + image_biattention_id: List[int], + text_biattention_id: List[int], + text_fixed_layer: int, + image_fixed_layer: int, + pooled_dropout: float = 0.1, + fusion_method: str = "sum", + ): + transformer = AutoModel.from_pretrained(model_name) + + text_embeddings = deepcopy(transformer.embeddings) + + # Albert (and maybe others?) has this "embedding_size", that's different from "hidden_size". + # To get them to the same dimensionality, it uses a linear transform after the embedding + # layer, which we need to pull out and copy here. + if hasattr(transformer.config, "embedding_size"): + config = transformer.config + + from transformers.modeling_albert import AlbertModel + + if isinstance(transformer, AlbertModel): + linear_transform = deepcopy(transformer.encoder.embedding_hidden_mapping_in) + else: + logger.warning( + "Unknown model that uses separate embedding size; weights of the linear " + f"transform will not be initialized. Model type is: {transformer.__class__}" + ) + linear_transform = torch.nn.Linear(config.embedding_dim, config.hidden_dim) + + # We can't just use torch.nn.Sequential here, even though that's basically all this is, + # because Sequential doesn't accept *inputs, only a single argument. 
+ + class EmbeddingsShim(torch.nn.Module): + def __init__(self, embeddings: torch.nn.Module, linear_transform: torch.nn.Module): + super().__init__() + self.linear_transform = linear_transform + self.embeddings = embeddings + + def forward(self, *inputs, **kwargs): + return self.linear_transform(self.embeddings(*inputs, **kwargs)) + + text_embeddings = EmbeddingsShim(text_embeddings, linear_transform) + + image_embeddings = ImageFeatureEmbeddings( + feature_dim=image_feature_dim, + hidden_dim=image_hidden_size, + dropout=image_hidden_dropout, + ) + + encoder = BiModalEncoder.from_pretrained_module( + pretrained_module=transformer, + num_hidden_layers2=image_num_hidden_layers, + hidden_size2=image_hidden_size, + num_attention_heads2=image_num_attention_heads, + combined_hidden_size=combined_hidden_size, + combined_num_attention_heads=combined_num_attention_heads, + intermediate_size2=image_intermediate_size, + attention_dropout2=image_attention_dropout, + hidden_dropout2=image_hidden_dropout, + biattention_id1=text_biattention_id, + biattention_id2=image_biattention_id, + fixed_layer1=text_fixed_layer, + fixed_layer2=image_fixed_layer, + ) + return cls( + vocab=vocab, + text_embeddings=text_embeddings, + image_embeddings=image_embeddings, + encoder=encoder, + pooled_output_dim=pooled_output_dim, + fusion_method=fusion_method, + dropout=pooled_dropout, + ) + + @overrides + def forward( + self, # type: ignore + box_features: torch.Tensor, + box_coordinates: torch.Tensor, + text: TextFieldTensors, + label: Optional[torch.Tensor] = None, + label_weights: Optional[torch.Tensor] = None, + ) -> Dict[str, torch.Tensor]: + + batch_size, _, feature_size = box_features.size() + + if "token_ids" in text["tokens"]: + token_ids = text["tokens"]["token_ids"] + else: + token_ids = text["tokens"]["tokens"] + + token_type_ids = text["tokens"].get("type_ids") + attention_mask = text["tokens"].get("mask") + + # All batch instances will always have the same number of images and boxes, so no masking + # is necessary, and this is just a tensor of ones. + image_attention_mask = torch.ones_like(box_coordinates[:, :, 0]) + + # (batch_size, num_tokens, embedding_dim) + embedding_output = self.embeddings(token_ids, token_type_ids) + num_tokens = embedding_output.size(1) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of + # causal attention used in OpenAI GPT, we just need to prepare the + # broadcast dimension here. 
+ if attention_mask is not None: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2).float().log() + else: + extended_attention_mask = None + + extended_image_attention_mask = image_attention_mask.unsqueeze(1).unsqueeze(2).float().log() + + extended_co_attention_mask = torch.zeros( + batch_size, + feature_size, + num_tokens, + dtype=extended_image_attention_mask.dtype, + ) + + # (batch_size, num_boxes, image_embedding_dim) + v_embedding_output = self.image_embeddings(box_features, box_coordinates) + + encoded_layers_t, encoded_layers_v = self.encoder( + embedding_output, + v_embedding_output, + extended_attention_mask, + extended_image_attention_mask, + extended_co_attention_mask, + ) + + sequence_output_t = encoded_layers_t[:, :, :, -1] + sequence_output_v = encoded_layers_v[:, :, :, -1] + + pooled_output_t = self.t_pooler(sequence_output_t) + pooled_output_v = self.v_pooler(sequence_output_v) + + if self.fusion_method == "sum": + pooled_output = self.dropout(pooled_output_t + pooled_output_v) + elif self.fusion_method == "mul": + pooled_output = self.dropout(pooled_output_t * pooled_output_v) + else: + raise ValueError(f"Fusion method '{self.fusion_method}' not supported") + + logits = self.classifier(pooled_output) + probs = torch.sigmoid(logits) + + outputs = {"logits": logits, "probs": probs} + outputs = self._compute_loss_and_metrics(batch_size, outputs, label, label_weights) + + return outputs + + def _compute_loss_and_metrics(self, batch_size, outputs, label, label_weights): + return outputs + + @overrides + def get_metrics(self, reset: bool = False) -> Dict[str, float]: + result = self.accuracy.get_metric(reset) + return {"accuracy": result} + + @overrides + def make_output_human_readable( + self, output_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + batch_labels = [] + for batch_index, batch in enumerate(output_dict["probs"]): + labels = np.argmax(batch, axis=-1) + batch_labels.append(labels) + output_dict["labels"] = batch_labels + return output_dict diff --git a/allennlp/models/visual_entailment.py b/allennlp/models/visual_entailment.py new file mode 100644 index 00000000000..73077a9b5d7 --- /dev/null +++ b/allennlp/models/visual_entailment.py @@ -0,0 +1,107 @@ +import logging +from typing import Dict, Optional + +from overrides import overrides +import numpy as np +import torch + +from allennlp.data import TextFieldTensors, Vocabulary +from allennlp.models.model import Model +from allennlp.modules.transformer import ( + TextEmbeddings, + ImageFeatureEmbeddings, + BiModalEncoder, +) +from allennlp.training.metrics import CategoricalAccuracy + + +from allennlp.models.vision_text_model import VisionTextModel + +logger = logging.getLogger(__name__) + + +@Model.register("ve_vilbert") +@Model.register("ve_vilbert_from_huggingface", constructor="from_huggingface_model_name") +class VisualEntailmentModel(VisionTextModel): + """ + Model for visual entailment task based on the paper + [Visual Entailment: A Novel Task for Fine-Grained Image Understanding] + (https://api.semanticscholar.org/CorpusID:58981654). 
+ + # Parameters + + vocab : `Vocabulary` + text_embeddings : `TextEmbeddings` + image_embeddings : `ImageFeatureEmbeddings` + encoder : `BiModalEncoder` + pooled_output_dim : `int` + fusion_method : `str`, optional (default = `"sum"`) + dropout : `float`, optional (default = `0.1`) + label_namespace : `str`, optional (default = `labels`) + """ + + def __init__( + self, + vocab: Vocabulary, + text_embeddings: TextEmbeddings, + image_embeddings: ImageFeatureEmbeddings, + encoder: BiModalEncoder, + pooled_output_dim: int, + fusion_method: str = "sum", + dropout: float = 0.1, + label_namespace: str = "labels", + ) -> None: + + super().__init__( + vocab, + text_embeddings, + image_embeddings, + encoder, + pooled_output_dim, + fusion_method, + dropout, + label_namespace, + ) + + self.accuracy = CategoricalAccuracy() + + @overrides + def forward( + self, # type: ignore + box_features: torch.Tensor, + box_coordinates: torch.Tensor, + hypothesis: TextFieldTensors, + label: Optional[torch.Tensor] = None, + ) -> Dict[str, torch.Tensor]: + + return super().forward( + box_features, box_coordinates, text=hypothesis, label=label, label_weights=None + ) + + @overrides + def _compute_loss_and_metrics(self, batch_size, outputs, label, label_weights=None): # type: ignore + assert label_weights is None + if label is not None: + outputs["loss"] = ( + torch.nn.functional.cross_entropy(outputs["logits"], label) / batch_size + ) + self.accuracy(outputs["logits"], label) + return outputs + + @overrides + def get_metrics(self, reset: bool = False) -> Dict[str, float]: + result = self.accuracy.get_metric(reset) + return {"accuracy": result} + + @overrides + def make_output_human_readable( + self, output_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + batch_labels = [] + for batch_index, batch in enumerate(output_dict["probs"]): + labels = np.argmax(batch, axis=-1) + batch_labels.append(labels) + output_dict["labels"] = batch_labels + return output_dict + + default_predictor = "vilbert_ve" diff --git a/test_fixtures/vilbert_ve/experiment_from_huggingface.jsonnet b/test_fixtures/vilbert_ve/experiment_from_huggingface.jsonnet index 23dc87bbd27..ff0df50329f 100644 --- a/test_fixtures/vilbert_ve/experiment_from_huggingface.jsonnet +++ b/test_fixtures/vilbert_ve/experiment_from_huggingface.jsonnet @@ -31,11 +31,11 @@ local model_name = "epwalsh/bert-xsmall-dummy"; "image_intermediate_size": 50, "image_attention_dropout": 0.0, "image_hidden_dropout": 0.0, - "v_biattention_id": [0, 1], - "fixed_v_layer": 0, + "image_biattention_id": [0, 1], + "image_fixed_layer": 0, - "t_biattention_id": [0, 1], - "fixed_t_layer": 0, + "text_biattention_id": [0, 1], + "text_fixed_layer": 0, "combined_hidden_size": 200, "combined_num_attention_heads": 2, From 790d582e61e32c991e6f15649b7864e1c4fcdd03 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 30 Nov 2020 12:00:19 -0800 Subject: [PATCH 2/4] adding VE model --- .../data/dataset_readers/vision_reader.py | 1 - allennlp/models/__init__.py | 1 + allennlp/predictors/__init__.py | 3 +- allennlp/predictors/visual_entailment.py | 39 ++++++++++ tests/models/visual_entailment_test.py | 75 +++++++++++++++++++ .../vilbert_ve_from_huggingface.jsonnet | 74 ++++++++++++++++++ 6 files changed, 191 insertions(+), 2 deletions(-) create mode 100644 allennlp/predictors/visual_entailment.py create mode 100644 tests/models/visual_entailment_test.py create mode 100644 training_configs/vilbert_ve_from_huggingface.jsonnet diff --git a/allennlp/data/dataset_readers/vision_reader.py 
b/allennlp/data/dataset_readers/vision_reader.py index 2644335aa29..75ef773acf2 100644 --- a/allennlp/data/dataset_readers/vision_reader.py +++ b/allennlp/data/dataset_readers/vision_reader.py @@ -157,7 +157,6 @@ def _process_image_paths( def yield_batch(): # process the images paths = list(unprocessed_paths) - print(len(paths)) images, sizes = self.image_loader(paths) with torch.no_grad(): images = images.to(self.cuda_device) diff --git a/allennlp/models/__init__.py b/allennlp/models/__init__.py index 0d380c58835..87424122803 100644 --- a/allennlp/models/__init__.py +++ b/allennlp/models/__init__.py @@ -9,3 +9,4 @@ from allennlp.models.multitask import MultiTaskModel from allennlp.models.simple_tagger import SimpleTagger from allennlp.models.vilbert_vqa import VqaVilbert +from allennlp.models.visual_entailment import VisualEntailmentModel diff --git a/allennlp/predictors/__init__.py b/allennlp/predictors/__init__.py index b8ef5943b9a..aa42fb4614b 100644 --- a/allennlp/predictors/__init__.py +++ b/allennlp/predictors/__init__.py @@ -12,6 +12,7 @@ try: from allennlp.predictors.vilbert_vqa import VilbertVqaPredictor + from allennlp.predictors.visual_entailment import VisualEntailmentPredictor except ImportError: - # VilbertVqaPredictor is not available if we don't have detectron. + # vision-based predictors are not available if we don't have detectron. pass diff --git a/allennlp/predictors/visual_entailment.py b/allennlp/predictors/visual_entailment.py new file mode 100644 index 00000000000..329ab34688e --- /dev/null +++ b/allennlp/predictors/visual_entailment.py @@ -0,0 +1,39 @@ +from typing import List, Dict + +from overrides import overrides +import numpy + +from allennlp.common.file_utils import cached_path +from allennlp.common.util import JsonDict +from allennlp.data import Instance +from allennlp.data.dataset_readers.visual_entailment import VisualEntailmentReader +from allennlp.data.fields import LabelField +from allennlp.predictors.predictor import Predictor + + +@Predictor.register("vilbert_ve") +class VisualEntailmentPredictor(Predictor): + def predict(self, image: str, hypothesis: str) -> JsonDict: + image = cached_path(image) + return self.predict_json({"image": image, "hypothesis": hypothesis}) + + @overrides + def _json_to_instance(self, json_dict: JsonDict) -> Instance: + image = cached_path(json_dict["image"]) + hypothesis = json_dict["hypothesis"] + if isinstance(self._dataset_reader, VisualEntailmentReader): + return self._dataset_reader.text_to_instance(image, hypothesis, use_cache=False) + else: + raise ValueError( + f"Dataset reader is of type f{self._dataset_reader.__class__.__name__}. " + f"Expected {VisualEntailmentReader.__name__}." 
+ ) + + @overrides + def predictions_to_labeled_instances( + self, instance: Instance, outputs: Dict[str, numpy.ndarray] + ) -> List[Instance]: + new_instance = instance.duplicate() + label = numpy.argmax(outputs["probs"]) + new_instance.add_field("label", LabelField(int(label), skip_indexing=True)) + return [new_instance] diff --git a/tests/models/visual_entailment_test.py b/tests/models/visual_entailment_test.py new file mode 100644 index 00000000000..1fe6cafcdb1 --- /dev/null +++ b/tests/models/visual_entailment_test.py @@ -0,0 +1,75 @@ +from transformers.modeling_auto import AutoModel + +from allennlp.common.testing import ModelTestCase +from allennlp.data import Vocabulary +from allennlp.models.visual_entailment import VisualEntailmentModel + + +class TestVEVilbert(ModelTestCase): + def test_model_can_train_save_and_load_small_model(self): + param_file = self.FIXTURES_ROOT / "vilbert_ve" / "experiment.jsonnet" + self.ensure_model_can_train_save_and_load(param_file) + + def test_model_can_train_save_and_load_with_cache(self): + import tempfile + + with tempfile.TemporaryDirectory(prefix=self.__class__.__name__) as d: + overrides = {"dataset_reader": {"feature_cache_dir": str(d)}} + import json + + overrides = json.dumps(overrides) + param_file = self.FIXTURES_ROOT / "vilbert_ve" / "experiment.jsonnet" + self.ensure_model_can_train_save_and_load(param_file, overrides=overrides) + + def test_model_can_train_save_and_load_from_huggingface(self): + param_file = self.FIXTURES_ROOT / "vilbert_ve" / "experiment_from_huggingface.jsonnet" + self.ensure_model_can_train_save_and_load(param_file) + + def test_model_loads_weights_correctly(self): + vocab = Vocabulary() + # vocab.add_tokens_to_namespace(["orange", "net", "netting", "pitcher", "catcher"], "answers") + + model_name = "epwalsh/bert-xsmall-dummy" + model = VisualEntailmentModel.from_huggingface_model_name( + vocab=vocab, + model_name=model_name, + image_feature_dim=2048, + image_num_hidden_layers=1, + image_hidden_size=3, + image_num_attention_heads=1, + combined_num_attention_heads=1, + combined_hidden_size=5, + pooled_output_dim=7, + image_intermediate_size=11, + image_attention_dropout=0.0, + image_hidden_dropout=0.0, + image_biattention_id=[0, 1], + text_biattention_id=[0, 1], + text_fixed_layer=0, + image_fixed_layer=0, + ) + + def convert_transformer_param_name(name: str): + # We wrap the encoder in a `TimeDistributed`, which gives us this extra _module. + name = name.replace("layer", "layers1") + name = name.replace("LayerNorm", "layer_norm") + return name + + transformer = AutoModel.from_pretrained(model_name) + model_parameters = dict(model.named_parameters()) + print(list(model_parameters.keys())) + transformer_parameters = dict(transformer.named_parameters()) + print(list(transformer_parameters.keys())) + + # We loop over the transformer parameters here, because the encoder check is easier from + # this side (all of these encoder parameters should match, but that's not true the other way + # around). + for name, parameter in transformer_parameters.items(): + if name.startswith("embeddings"): + # Embedding layer should be identical + assert parameter.allclose(model_parameters[name]) + if name.startswith("encoder"): + # Encoder parameters should also be identical, after we match up the names + # correctly. 
+ our_name = convert_transformer_param_name(name) + assert parameter.allclose(model_parameters[our_name]) diff --git a/training_configs/vilbert_ve_from_huggingface.jsonnet b/training_configs/vilbert_ve_from_huggingface.jsonnet new file mode 100644 index 00000000000..02dc53f689a --- /dev/null +++ b/training_configs/vilbert_ve_from_huggingface.jsonnet @@ -0,0 +1,74 @@ +local model_name = "bert-base-uncased"; +local effective_batch_size = 128; +local gpu_batch_size = 32; +local num_gpus = 4; + +local datadir = "/net/s3/allennlp/akshitab/data/SNLI-VE/data/"; + +{ + "dataset_reader": { + "type": "visual-entailment", + "image_dir": datadir + "Flickr30K/flickr30k_images", + "feature_cache_dir": datadir + "/feature_cache", + "image_loader": "detectron", + "image_featurizer": "resnet_backbone", + "region_detector": "faster_rcnn", + "tokenizer": { + "type": "pretrained_transformer", + "model_name": model_name + }, + "token_indexers": { + "tokens": { + "type": "pretrained_transformer", + "model_name": model_name + } + }, + "max_instances": 30000, + "image_processing_batch_size": 16, + }, + "validation_dataset_reader": self.dataset_reader, + "train_data_path": datadir + "snli_ve_train.jsonl", + "validation_data_path": datadir + "snli_ve_dev.jsonl", + "model": { + "type": "ve_vilbert_from_huggingface", + "model_name": model_name, + "image_feature_dim": 2048, + "image_hidden_size": 1024, + "image_num_attention_heads": 8, + "image_num_hidden_layers": 6, + "combined_hidden_size": 1024, + "combined_num_attention_heads": 8, + "pooled_output_dim": 1024, + "image_intermediate_size": 1024, + "image_attention_dropout": 0.1, + "image_hidden_dropout": 0.1, + "image_biattention_id": [0, 1, 2, 3, 4, 5], + "text_biattention_id": [6, 7, 8, 9, 10, 11], + "text_fixed_layer": 0, + "image_fixed_layer": 0, + "fusion_method": "mul" + }, + "data_loader": { + "batch_size": gpu_batch_size, + "shuffle": true, + "max_instances_in_memory": 1024 + }, + [if num_gpus > 1 then "distributed"]: { + "cuda_devices": std.range(0, num_gpus - 1) + #"cuda_devices": std.repeat([-1], num_gpus) # Use this for debugging on CPU + }, + "trainer": { + "optimizer": { + "type": "huggingface_adamw", + "lr": 4e-5 + }, + "learning_rate_scheduler": { + "type": "linear_with_warmup", + "warmup_steps": 2000, + "num_steps_per_epoch": std.ceil(30000 / $["data_loader"]["batch_size"] / $["trainer"]["num_gradient_accumulation_steps"]) + }, + "validation_metric": "+f1", + "num_epochs": 20, + "num_gradient_accumulation_steps": effective_batch_size / gpu_batch_size / std.max(1, num_gpus) + }, +} From 1af4d06fbc5b2389682e3801db6be1f84126102f Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Tue, 1 Dec 2020 11:00:17 -0800 Subject: [PATCH 3/4] misc minor updates --- allennlp/models/vilbert_vqa.py | 8 +++++++- allennlp/models/vision_text_model.py | 11 +++++++++-- allennlp/models/visual_entailment.py | 8 +++++++- tests/models/visual_entailment_test.py | 1 - 4 files changed, 23 insertions(+), 5 deletions(-) diff --git a/allennlp/models/vilbert_vqa.py b/allennlp/models/vilbert_vqa.py index 414082cd655..e8d66a54ae4 100644 --- a/allennlp/models/vilbert_vqa.py +++ b/allennlp/models/vilbert_vqa.py @@ -84,7 +84,13 @@ def forward( ) @overrides - def _compute_loss_and_metrics(self, batch_size, outputs, label, label_weights): + def _compute_loss_and_metrics( + self, + batch_size: int, + outputs: torch.Tensor, + label: torch.Tensor, + label_weights: Optional[torch.Tensor] = None, + ): if label is not None and label_weights is not None: logits = outputs["logits"] label_mask = 
label > 1 # 0 is padding, 1 is OOV, which we want to ignore diff --git a/allennlp/models/vision_text_model.py b/allennlp/models/vision_text_model.py index d441e1cd354..51cb3dfb4dc 100644 --- a/allennlp/models/vision_text_model.py +++ b/allennlp/models/vision_text_model.py @@ -6,7 +6,8 @@ import numpy as np import torch -from allennlp.data import TextFieldTensors, Vocabulary +from allennlp.data.fields.text_field import TextFieldTensors +from allennlp.data.vocabulary import Vocabulary from allennlp.models.model import Model from allennlp.modules.transformer import ( TextEmbeddings, @@ -236,7 +237,13 @@ def forward( return outputs - def _compute_loss_and_metrics(self, batch_size, outputs, label, label_weights): + def _compute_loss_and_metrics( + self, + batch_size: int, + outputs: torch.Tensor, + label: torch.Tensor, + label_weights: Optional[torch.Tensor] = None, + ): return outputs @overrides diff --git a/allennlp/models/visual_entailment.py b/allennlp/models/visual_entailment.py index 73077a9b5d7..9c07f2d8e12 100644 --- a/allennlp/models/visual_entailment.py +++ b/allennlp/models/visual_entailment.py @@ -79,7 +79,13 @@ def forward( ) @overrides - def _compute_loss_and_metrics(self, batch_size, outputs, label, label_weights=None): # type: ignore + def _compute_loss_and_metrics( + self, + batch_size: int, + outputs: torch.Tensor, + label: torch.Tensor, + label_weights: Optional[torch.Tensor] = None, + ): assert label_weights is None if label is not None: outputs["loss"] = ( diff --git a/tests/models/visual_entailment_test.py b/tests/models/visual_entailment_test.py index 1fe6cafcdb1..2748f05d7e2 100644 --- a/tests/models/visual_entailment_test.py +++ b/tests/models/visual_entailment_test.py @@ -27,7 +27,6 @@ def test_model_can_train_save_and_load_from_huggingface(self): def test_model_loads_weights_correctly(self): vocab = Vocabulary() - # vocab.add_tokens_to_namespace(["orange", "net", "netting", "pitcher", "catcher"], "answers") model_name = "epwalsh/bert-xsmall-dummy" model = VisualEntailmentModel.from_huggingface_model_name( From acc9f59100155d4d079281ac08b5b62e5a31cf6f Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Tue, 1 Dec 2020 11:56:44 -0800 Subject: [PATCH 4/4] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 10db439d70e..cfefc73cb41 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ dataset at every epoch) and a `MultiTaskScheduler` (for ordering the instances w - Added abstraction and demo implementation for an image augmentation module. - Added abstraction and concrete implementation for region detectors. - Transformer toolkit to plug and play with modular components of transformer architectures. +- `VisionReader` and `VisionTextModel` base classes added. `VisualEntailment` and `VQA` inherit from these. ### Changed
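Usage note (illustrative, not part of the commits above): once a model is trained with one of the configs in this series, the newly registered `vilbert_ve` predictor can drive inference end to end. A minimal sketch, assuming a trained archive at a placeholder path and a placeholder image URL:

```python
# Inference sketch for the visual entailment model/predictor added in this patch
# series. "model.tar.gz" and the image URL are placeholders, not part of the PR.
from allennlp.models.archival import load_archive
from allennlp.predictors import Predictor

# Load a model trained with, e.g., training_configs/vilbert_ve_from_huggingface.jsonnet.
archive = load_archive("model.tar.gz")
predictor = Predictor.from_archive(archive, "vilbert_ve")

# VisualEntailmentPredictor.predict() caches the image via cached_path() and
# builds an Instance with VisualEntailmentReader.text_to_instance().
result = predictor.predict(
    image="https://example.com/some_image.jpg",
    hypothesis="Two dogs are running through a field.",
)
print(result["probs"])   # class probabilities over the "labels" namespace
print(result["labels"])  # argmax label added by make_output_human_readable()
```

Both task models lean on the shared `VisionTextModel` base class introduced here: `forward` runs the ViLBERT encoder once in the base class, and each subclass only overrides `_compute_loss_and_metrics` (cross-entropy plus `CategoricalAccuracy` for visual entailment, a weighted multi-label loss plus F1/VQA metrics for VQA).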