"""
BiMPM (Bilateral Multi-Perspective Matching) model implementation.
"""

from typing import Dict, Optional, List, Any

from overrides import overrides
import torch

from allennlp.common.checks import check_dimensions_match
from allennlp.data import Vocabulary
from allennlp.modules import FeedForward, Seq2SeqEncoder, Seq2VecEncoder, TextFieldEmbedder
from allennlp.models.model import Model
from allennlp.nn import InitializerApplicator, RegularizerApplicator
from allennlp.nn import util
from allennlp.training.metrics import CategoricalAccuracy

from allennlp.modules.bimpm_matching import BiMpmMatching


@Model.register("bimpm")
class BiMpm(Model):
    """
    This ``Model`` implements the BiMPM model described in `Bilateral Multi-Perspective Matching
    for Natural Language Sentences <https://arxiv.org/abs/1702.03814>`_ by Zhiguo Wang et al., 2017.
    Also please refer to the `TensorFlow implementation <https://github.com/zhiguowang/BiMPM/>`_ and
    the `PyTorch implementation <https://github.com/galsang/BIMPM-pytorch>`_.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``premise`` and ``hypothesis`` ``TextFields`` we get as input to the
        model.
    matcher_word : ``BiMpmMatching``
        BiMPM matching on the word embeddings of the premise and hypothesis.
    encoder1 : ``Seq2SeqEncoder``
        First encoder layer for the premise and hypothesis.
    matcher_forward1 : ``BiMpmMatching``
        BiMPM matching for the forward output of the first encoder layer.
    matcher_backward1 : ``BiMpmMatching``
        BiMPM matching for the backward output of the first encoder layer.
    encoder2 : ``Seq2SeqEncoder``
        Second encoder layer for the premise and hypothesis.
    matcher_forward2 : ``BiMpmMatching``
        BiMPM matching for the forward output of the second encoder layer.
    matcher_backward2 : ``BiMpmMatching``
        BiMPM matching for the backward output of the second encoder layer.
    aggregator : ``Seq2VecEncoder``
        Aggregator of all BiMPM matching vectors.
    classifier_feedforward : ``FeedForward``
        Fully connected layers for classification.
    dropout : ``float``, optional (default=0.1)
        Dropout probability to use.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        If provided, will be used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 matcher_word: BiMpmMatching,
                 encoder1: Seq2SeqEncoder,
                 matcher_forward1: BiMpmMatching,
                 matcher_backward1: BiMpmMatching,
                 encoder2: Seq2SeqEncoder,
                 matcher_forward2: BiMpmMatching,
                 matcher_backward2: BiMpmMatching,
                 aggregator: Seq2VecEncoder,
                 classifier_feedforward: FeedForward,
                 dropout: float = 0.1,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BiMpm, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder

        self.matcher_word = matcher_word

        self.encoder1 = encoder1
        self.matcher_forward1 = matcher_forward1
        self.matcher_backward1 = matcher_backward1

        self.encoder2 = encoder2
        self.matcher_forward2 = matcher_forward2
        self.matcher_backward2 = matcher_backward2

        self.aggregator = aggregator

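        # The word-level matcher and the forward/backward matchers of both encoder layers each
        # contribute a slice of the per-token matching vector, so the aggregator's input dim
        # must equal the sum of all five matcher output dims.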
        matching_dim = self.matcher_word.get_output_dim() + \
                       self.matcher_forward1.get_output_dim() + self.matcher_backward1.get_output_dim() + \
                       self.matcher_forward2.get_output_dim() + self.matcher_backward2.get_output_dim()

        check_dimensions_match(matching_dim, self.aggregator.get_input_dim(),
                               "sum of dim of all matching layers", "aggregator input dim")

        self.classifier_feedforward = classifier_feedforward

        self.dropout = torch.nn.Dropout(dropout)

        self.metrics = {"accuracy": CategoricalAccuracy()}

        self.loss = torch.nn.CrossEntropyLoss()

        initializer(self)

    @overrides
    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
                ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            The premise from a ``TextField``.
        hypothesis : Dict[str, torch.LongTensor]
            The hypothesis from a ``TextField``.
        label : torch.LongTensor, optional (default = None)
            The label for the pair of the premise and the hypothesis.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Additional information about the pair.

        Returns
        -------
        An output dictionary consisting of:

        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """

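        # Masks have shape (batch_size, seq_len), with 1 for real tokens and 0 for padding.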
        mask_premise = util.get_text_field_mask(premise)
        mask_hypothesis = util.get_text_field_mask(hypothesis)

        # embedding and encoding of the premise
        embedded_premise = self.dropout(self.text_field_embedder(premise))
        encoded_premise1 = self.dropout(self.encoder1(embedded_premise, mask_premise))
        encoded_premise2 = self.dropout(self.encoder2(encoded_premise1, mask_premise))

        # embedding and encoding of the hypothesis
        embedded_hypothesis = self.dropout(self.text_field_embedder(hypothesis))
        encoded_hypothesis1 = self.dropout(self.encoder1(embedded_hypothesis, mask_hypothesis))
        encoded_hypothesis2 = self.dropout(self.encoder2(encoded_hypothesis1, mask_hypothesis))

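        # Note that the embedder and both encoders are applied to the premise and the
        # hypothesis alike, so the two sentences are encoded with shared weights.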
        matching_vector_premise: List[torch.Tensor] = []
        matching_vector_hypothesis: List[torch.Tensor] = []

        def add_matching_result(matcher, encoded_premise, encoded_hypothesis):
            # utility function to get the matching result and add it to the result lists;
            # each matcher returns a pair of lists of matching vectors (premise side and
            # hypothesis side), which we flatten into the two lists above
            matching_result = matcher(encoded_premise, mask_premise, encoded_hypothesis, mask_hypothesis)
            matching_vector_premise.extend(matching_result[0])
            matching_vector_hypothesis.extend(matching_result[1])

        # calculate matching vectors from the word embeddings, the first encoder layer, and the
        # second encoder layer
        add_matching_result(self.matcher_word, embedded_premise, embedded_hypothesis)
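        # The bidirectional encoder output concatenates forward and backward states, so the
        # first half of the hidden size is the forward direction and the second half is the
        # backward direction.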
        half_hidden_size_1 = self.encoder1.get_output_dim() // 2
        add_matching_result(self.matcher_forward1,
                            encoded_premise1[:, :, :half_hidden_size_1],
                            encoded_hypothesis1[:, :, :half_hidden_size_1])
        add_matching_result(self.matcher_backward1,
                            encoded_premise1[:, :, half_hidden_size_1:],
                            encoded_hypothesis1[:, :, half_hidden_size_1:])

        half_hidden_size_2 = self.encoder2.get_output_dim() // 2
        add_matching_result(self.matcher_forward2,
                            encoded_premise2[:, :, :half_hidden_size_2],
                            encoded_hypothesis2[:, :, :half_hidden_size_2])
        add_matching_result(self.matcher_backward2,
                            encoded_premise2[:, :, half_hidden_size_2:],
                            encoded_hypothesis2[:, :, half_hidden_size_2:])

        # concatenate the matching vectors
        matching_vector_cat_premise = self.dropout(torch.cat(matching_vector_premise, dim=2))
        matching_vector_cat_hypothesis = self.dropout(torch.cat(matching_vector_hypothesis, dim=2))

        # aggregate the matching vectors into a single vector per sentence
        aggregated_premise = self.dropout(self.aggregator(matching_vector_cat_premise, mask_premise))
        aggregated_hypothesis = self.dropout(self.aggregator(matching_vector_cat_hypothesis, mask_hypothesis))

        # final feedforward classification layer over the concatenated sentence vectors
        logits = self.classifier_feedforward(torch.cat([aggregated_premise, aggregated_hypothesis], dim=-1))

        output_dict = {'logits': logits}
        if label is not None:
            loss = self.loss(logits, label)
            for metric in self.metrics.values():
                metric(logits, label)
            output_dict["loss"] = loss

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {metric_name: metric.get_metric(reset) for metric_name, metric in self.metrics.items()}
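
# A minimal call sketch (hypothetical tensors, for illustration only; in practice the inputs
# come from ``TextField``s produced by a dataset reader):
#
#     output = model(premise={"tokens": premise_token_ids},
#                    hypothesis={"tokens": hypothesis_token_ids},
#                    label=labels)
#     output["loss"].backward()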