
Commit 76a65a8

Handsome Zebra authored and matt-gardner committed
BiMPM model (#1594)
* Adding Quora data reader and BiMPM model.
* Refactoring and renaming to pass pylint.
* Reduce batch size.
* Adding docs.
* Make title underline longer.
* Adding doc toctree.
* Various improvements to speed and memory.
* Improve comments.
* 1. Remove zip file handling. 2. Use allennlp s3 for quora data download. 3. Move masked_max, masked_mean to nn.util. 4. Various variable renaming, comments improvements, etc.
* Remove unused url pattern match and change num_perspective to num_perspectives.
1 parent 58119c0 commit 76a65a8

20 files changed, +1157 -0 lines changed

allennlp/data/dataset_readers/__init__.py (+1)

@@ -24,3 +24,4 @@
 from allennlp.data.dataset_readers.stanford_sentiment_tree_bank import (
     StanfordSentimentTreeBankDatasetReader)
 from allennlp.data.dataset_readers.wikitables import WikiTablesDatasetReader
+from allennlp.data.dataset_readers.quora_paraphrase import QuoraParaphraseDatasetReader
allennlp/data/dataset_readers/quora_paraphrase.py (new file, +73 lines)

@@ -0,0 +1,73 @@
from typing import Dict
import logging
import csv

from overrides import overrides

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import LabelField, TextField, Field
from allennlp.data.instance import Instance
from allennlp.data.tokenizers import Tokenizer, WordTokenizer
from allennlp.data.tokenizers.word_splitter import JustSpacesWordSplitter
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


@DatasetReader.register("quora_paraphrase")
class QuoraParaphraseDatasetReader(DatasetReader):
    """
    Reads a file from the Quora Paraphrase dataset. The train/validation/test split of the data
    comes from the paper `Bilateral Multi-Perspective Matching for Natural Language Sentences
    <https://arxiv.org/abs/1702.03814>`_ by Zhiguo Wang et al., 2017. Each file of the data
    is a tsv file without a header. The columns are is_duplicate, question1, question2, and id.
    All questions are pre-tokenized and tokens are space separated. We convert these keys into
    fields named "label", "premise" and "hypothesis", so that it is compatible with some
    existing natural language inference algorithms.

    Parameters
    ----------
    lazy : ``bool`` (optional, default=False)
        Passed to ``DatasetReader``. If this is ``True``, training will start sooner, but will
        take longer per batch. This also allows training with datasets that are too large to fit
        in memory.
    tokenizer : ``Tokenizer``, optional
        Tokenizer to use to split the premise and hypothesis into words or other kinds of tokens.
        Defaults to ``WordTokenizer(JustSpacesWordSplitter())``.
    token_indexers : ``Dict[str, TokenIndexer]``, optional
        Indexers used to define input token representations. Defaults to ``{"tokens":
        SingleIdTokenIndexer()}``.
    """
    def __init__(self,
                 lazy: bool = False,
                 tokenizer: Tokenizer = None,
                 token_indexers: Dict[str, TokenIndexer] = None) -> None:
        super().__init__(lazy)
        self._tokenizer = tokenizer or WordTokenizer(JustSpacesWordSplitter())
        self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}

    @overrides
    def _read(self, file_path):
        logger.info("Reading instances from lines in file at: %s", file_path)
        with open(cached_path(file_path), "r") as data_file:
            tsv_in = csv.reader(data_file, delimiter='\t')
            for row in tsv_in:
                if len(row) == 4:
                    yield self.text_to_instance(premise=row[1], hypothesis=row[2], label=row[0])

    @overrides
    def text_to_instance(self,  # type: ignore
                         premise: str,
                         hypothesis: str,
                         label: str = None) -> Instance:
        # pylint: disable=arguments-differ
        fields: Dict[str, Field] = {}
        tokenized_premise = self._tokenizer.tokenize(premise)
        tokenized_hypothesis = self._tokenizer.tokenize(hypothesis)
        fields["premise"] = TextField(tokenized_premise, self._token_indexers)
        fields["hypothesis"] = TextField(tokenized_hypothesis, self._token_indexers)
        if label is not None:
            fields['label'] = LabelField(label)

        return Instance(fields)
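The reader registers itself under the name "quora_paraphrase", so it can be selected in a configuration file or instantiated directly. A minimal usage sketch (not part of this commit; "quora_train.tsv" is a hypothetical local copy of the tab-separated data described in the docstring):

from allennlp.data.dataset_readers import QuoraParaphraseDatasetReader

# The default tokenizer splits on whitespace, matching the pre-tokenized Quora data.
reader = QuoraParaphraseDatasetReader()

# DatasetReader.read yields Instances with "premise", "hypothesis", and "label" fields.
for instance in reader.read("quora_train.tsv"):
    print(instance.fields["label"].label,
          [token.text for token in instance.fields["premise"].tokens])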

allennlp/models/__init__.py (+1)

@@ -20,3 +20,4 @@
 from allennlp.models.semantic_role_labeler import SemanticRoleLabeler
 from allennlp.models.simple_tagger import SimpleTagger
 from allennlp.models.esim import ESIM
+from allennlp.models.bimpm import BiMpm

allennlp/models/bimpm.py (new file, +200 lines)

@@ -0,0 +1,200 @@
"""
BiMPM (Bilateral Multi-Perspective Matching) model implementation.
"""

from typing import Dict, Optional, List, Any

from overrides import overrides
import torch

from allennlp.common.checks import check_dimensions_match
from allennlp.data import Vocabulary
from allennlp.modules import FeedForward, Seq2SeqEncoder, Seq2VecEncoder, TextFieldEmbedder
from allennlp.models.model import Model
from allennlp.nn import InitializerApplicator, RegularizerApplicator
from allennlp.nn import util
from allennlp.training.metrics import CategoricalAccuracy

from allennlp.modules.bimpm_matching import BiMpmMatching


@Model.register("bimpm")
class BiMpm(Model):
    """
    This ``Model`` implements the BiMPM model described in `Bilateral Multi-Perspective Matching
    for Natural Language Sentences <https://arxiv.org/abs/1702.03814>`_ by Zhiguo Wang et al., 2017.
    Also please refer to the `TensorFlow implementation <https://github.com/zhiguowang/BiMPM/>`_ and
    `PyTorch implementation <https://github.com/galsang/BIMPM-pytorch>`_.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``premise`` and ``hypothesis`` ``TextFields`` we get as input to the
        model.
    matcher_word : ``BiMpmMatching``
        BiMPM matching on the output of word embeddings of premise and hypothesis.
    encoder1 : ``Seq2SeqEncoder``
        First encoder layer for the premise and hypothesis.
    matcher_forward1 : ``BiMpmMatching``
        BiMPM matching for the forward output of the first encoder layer.
    matcher_backward1 : ``BiMpmMatching``
        BiMPM matching for the backward output of the first encoder layer.
    encoder2 : ``Seq2SeqEncoder``
        Second encoder layer for the premise and hypothesis.
    matcher_forward2 : ``BiMpmMatching``
        BiMPM matching for the forward output of the second encoder layer.
    matcher_backward2 : ``BiMpmMatching``
        BiMPM matching for the backward output of the second encoder layer.
    aggregator : ``Seq2VecEncoder``
        Aggregator of all BiMPM matching vectors.
    classifier_feedforward : ``FeedForward``
        Fully connected layers for classification.
    dropout : ``float``, optional (default=0.1)
        Dropout percentage to use.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        If provided, will be used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 matcher_word: BiMpmMatching,
                 encoder1: Seq2SeqEncoder,
                 matcher_forward1: BiMpmMatching,
                 matcher_backward1: BiMpmMatching,
                 encoder2: Seq2SeqEncoder,
                 matcher_forward2: BiMpmMatching,
                 matcher_backward2: BiMpmMatching,
                 aggregator: Seq2VecEncoder,
                 classifier_feedforward: FeedForward,
                 dropout: float = 0.1,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BiMpm, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder

        self.matcher_word = matcher_word

        self.encoder1 = encoder1
        self.matcher_forward1 = matcher_forward1
        self.matcher_backward1 = matcher_backward1

        self.encoder2 = encoder2
        self.matcher_forward2 = matcher_forward2
        self.matcher_backward2 = matcher_backward2

        self.aggregator = aggregator

        matching_dim = self.matcher_word.get_output_dim() + \
                       self.matcher_forward1.get_output_dim() + self.matcher_backward1.get_output_dim() + \
                       self.matcher_forward2.get_output_dim() + self.matcher_backward2.get_output_dim()

        check_dimensions_match(matching_dim, self.aggregator.get_input_dim(),
                               "sum of dim of all matching layers", "aggregator input dim")

        self.classifier_feedforward = classifier_feedforward

        self.dropout = torch.nn.Dropout(dropout)

        self.metrics = {"accuracy": CategoricalAccuracy()}

        self.loss = torch.nn.CrossEntropyLoss()

        initializer(self)

    @overrides
    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
               ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            The premise from a ``TextField``.
        hypothesis : Dict[str, torch.LongTensor]
            The hypothesis from a ``TextField``.
        label : torch.LongTensor, optional (default = None)
            The label for the pair of the premise and the hypothesis.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Additional information about the pair.

        Returns
        -------
        An output dictionary consisting of:

        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        mask_premise = util.get_text_field_mask(premise)
        mask_hypothesis = util.get_text_field_mask(hypothesis)

        # embedding and encoding of the premise
        embedded_premise = self.dropout(self.text_field_embedder(premise))
        encoded_premise1 = self.dropout(self.encoder1(embedded_premise, mask_premise))
        encoded_premise2 = self.dropout(self.encoder2(encoded_premise1, mask_premise))

        # embedding and encoding of the hypothesis
        embedded_hypothesis = self.dropout(self.text_field_embedder(hypothesis))
        encoded_hypothesis1 = self.dropout(self.encoder1(embedded_hypothesis, mask_hypothesis))
        encoded_hypothesis2 = self.dropout(self.encoder2(encoded_hypothesis1, mask_hypothesis))

        matching_vector_premise: List[torch.Tensor] = []
        matching_vector_hypothesis: List[torch.Tensor] = []

        def add_matching_result(matcher, encoded_premise, encoded_hypothesis):
            # utility function to get matching result and add to the result list
            matching_result = matcher(encoded_premise, mask_premise, encoded_hypothesis, mask_hypothesis)
            matching_vector_premise.extend(matching_result[0])
            matching_vector_hypothesis.extend(matching_result[1])

        # calculate matching vectors from word embedding, first layer encoding, and second layer encoding
        add_matching_result(self.matcher_word, embedded_premise, embedded_hypothesis)
        half_hidden_size_1 = self.encoder1.get_output_dim() // 2
        add_matching_result(self.matcher_forward1,
                            encoded_premise1[:, :, :half_hidden_size_1],
                            encoded_hypothesis1[:, :, :half_hidden_size_1])
        add_matching_result(self.matcher_backward1,
                            encoded_premise1[:, :, half_hidden_size_1:],
                            encoded_hypothesis1[:, :, half_hidden_size_1:])

        half_hidden_size_2 = self.encoder2.get_output_dim() // 2
        add_matching_result(self.matcher_forward2,
                            encoded_premise2[:, :, :half_hidden_size_2],
                            encoded_hypothesis2[:, :, :half_hidden_size_2])
        add_matching_result(self.matcher_backward2,
                            encoded_premise2[:, :, half_hidden_size_2:],
                            encoded_hypothesis2[:, :, half_hidden_size_2:])

        # concat the matching vectors
        matching_vector_cat_premise = self.dropout(torch.cat(matching_vector_premise, dim=2))
        matching_vector_cat_hypothesis = self.dropout(torch.cat(matching_vector_hypothesis, dim=2))

        # aggregate the matching vectors
        aggregated_premise = self.dropout(self.aggregator(matching_vector_cat_premise, mask_premise))
        aggregated_hypothesis = self.dropout(self.aggregator(matching_vector_cat_hypothesis, mask_hypothesis))

        # the final forward layer
        logits = self.classifier_feedforward(torch.cat([aggregated_premise, aggregated_hypothesis], dim=-1))

        output_dict = {'logits': logits}
        if label is not None:
            loss = self.loss(logits, label)
            for metric in self.metrics.values():
                metric(logits, label)
            output_dict["loss"] = loss

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {metric_name: metric.get_metric(reset) for metric_name, metric in self.metrics.items()}
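Since ``encoder1`` and ``encoder2`` are expected to be bidirectional, their outputs concatenate forward and backward hidden states along the last dimension, and ``forward`` slices the two halves apart before matching. A standalone sketch of that slicing, with hypothetical sizes standing in for a real encoder output:

import torch

# Stand-in for a bidirectional Seq2SeqEncoder output of shape (batch, seq_len, hidden).
batch_size, seq_len, hidden_size = 2, 7, 8
encoded = torch.randn(batch_size, seq_len, hidden_size)

half = hidden_size // 2
forward_states = encoded[:, :, :half]   # first half: forward-direction states
backward_states = encoded[:, :, half:]  # second half: backward-direction states

assert forward_states.shape == (batch_size, seq_len, half)
assert backward_states.shape == (batch_size, seq_len, half)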

allennlp/modules/__init__.py (+1)

@@ -22,3 +22,4 @@
 from allennlp.modules.matrix_attention import MatrixAttention
 from allennlp.modules.attention import Attention
 from allennlp.modules.input_variational_dropout import InputVariationalDropout
+from allennlp.modules.bimpm_matching import BiMpmMatching
