This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit c78bb36

Add a correctness test for Open AI transformer (#1801)
* Add a correctness check for OpenAI BPE encoding
* Add test fixtures
* Add correctness test for OpenAI
* pylint
1 parent e64373c commit c78bb36

File tree

6 files changed: +113 −2 lines changed

allennlp/data/token_indexers/openai_transformer_byte_pair_indexer.py (+19)

@@ -1,6 +1,7 @@
 from typing import Dict, List, Tuple
 import json
 import tarfile
+import re

 from overrides import overrides

@@ -11,6 +12,21 @@
 from allennlp.data.tokenizers.token import Token
 from allennlp.data.token_indexers.token_indexer import TokenIndexer

+def text_standardize(text):
+    """
+    Apply text standardization following the original implementation.
+    """
+    # pylint: disable=anomalous-backslash-in-string
+    text = text.replace('—', '-')
+    text = text.replace('–', '-')
+    text = text.replace('―', '-')
+    text = text.replace('…', '...')
+    text = text.replace('´', "'")
+    text = re.sub('''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
+    text = re.sub('\s*\n\s*', ' \n ', text)
+    text = re.sub('[^\S\n]+', ' ', text)
+    return text.strip()
+

 @TokenIndexer.register("openai_transformer_byte_pair")
 class OpenaiTransformerBytePairIndexer(TokenIndexer[int]):

@@ -21,6 +37,9 @@ class OpenaiTransformerBytePairIndexer(TokenIndexer[int]):
     This is unlike most of our TokenIndexers in that its
     indexing is not based on a `Vocabulary` but on a fixed
     set of mappings that are loaded by the constructor.
+
+    Note: the original implementation applied ``text_standardize`` before
+    tokenizing.
+    """
     # pylint: disable=no-self-use
     def __init__(self,
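
A quick illustration (my own, not part of the commit) of what ``text_standardize`` does: dash variants and ellipses are normalized, and runs of the punctuation characters in the regex above are padded with spaces so spaCy splits them into separate tokens. The output in the comment is what the substitutions above should produce.

    from allennlp.data.token_indexers.openai_transformer_byte_pair_indexer import text_standardize

    print(text_standardize('Hello—world… (really!)'))
    # 'Hello - world... ( really ! )'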

allennlp/tests/data/token_indexers/openai_transformer_byte_pair_indexer_test.py (+23)

@@ -1,12 +1,14 @@
 # pylint: disable=no-self-use,invalid-name,protected-access
 import json
 import tarfile
+import spacy

 import pytest

 from allennlp.common.testing import AllenNlpTestCase
 from allennlp.data import Token
 from allennlp.data.token_indexers import OpenaiTransformerBytePairIndexer
+from allennlp.data.token_indexers.openai_transformer_byte_pair_indexer import text_standardize
 from allennlp.data.vocabulary import Vocabulary


@@ -99,3 +101,24 @@ def test_raises_with_too_long_sentence(self):

         with pytest.raises(RuntimeError):
             self.indexer.tokens_to_indices(tokens, self.vocab, 'should-fail')
+
+    @pytest.mark.skip()
+    def test_for_correctness_with_fixture(self):
+        bpe_path = "https://s3-us-west-2.amazonaws.com/allennlp/models/openai-transformer-lm-2018.07.23.tar.gz"
+        indexer = OpenaiTransformerBytePairIndexer(model_path=bpe_path)
+
+        with open(self.FIXTURES_ROOT / 'openai_transformer' / 'text.txt', 'r') as fin:
+            sentences = fin.read().strip().split('\n')
+        with open(self.FIXTURES_ROOT / 'openai_transformer' / 'indexed_text.json', 'r') as fin:
+            expected_indices = json.load(fin)
+
+        # tokenize and check that indices are correct
+        nlp = spacy.load('en_core_web_sm')
+
+        for k, sentence in enumerate(sentences):
+            tokens = [token.text for token in nlp(text_standardize(sentence)) if not token.is_space]
+            indices = indexer.tokens_to_indices(
+                    [Token(token) for token in tokens], Vocabulary(), 'openai_indexer'
+            )
+            non_padded_indices = [i for i in indices['openai_indexer'] if i != 0]
+            assert non_padded_indices == expected_indices[k]
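
The skipped test above reads its expected output from the ``indexed_text.json`` fixture added below. For reference, a sketch of how that fixture could be regenerated (my own helper, mirroring the test body; it assumes network access to the same model archive and the ``en_core_web_sm`` spaCy model):

    import json
    import spacy

    from allennlp.data import Token
    from allennlp.data.token_indexers import OpenaiTransformerBytePairIndexer
    from allennlp.data.token_indexers.openai_transformer_byte_pair_indexer import text_standardize
    from allennlp.data.vocabulary import Vocabulary

    def regenerate_fixture(text_path: str, out_path: str) -> None:
        # hypothetical helper, not in the commit: index each sentence exactly as the test does
        model_path = "https://s3-us-west-2.amazonaws.com/allennlp/models/openai-transformer-lm-2018.07.23.tar.gz"
        indexer = OpenaiTransformerBytePairIndexer(model_path=model_path)
        nlp = spacy.load('en_core_web_sm')

        with open(text_path, 'r') as fin:
            sentences = fin.read().strip().split('\n')

        all_indices = []
        for sentence in sentences:
            tokens = [token.text for token in nlp(text_standardize(sentence)) if not token.is_space]
            indices = indexer.tokens_to_indices([Token(t) for t in tokens], Vocabulary(), 'openai_indexer')
            # drop padding (index 0) so the fixture stores only the real BPE ids
            all_indices.append([i for i in indices['openai_indexer'] if i != 0])

        with open(out_path, 'w') as fout:
            json.dump(all_indices, fout)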
allennlp/tests/fixtures/openai_transformer/expected_embeddings.hdf5

Binary file not shown.
allennlp/tests/fixtures/openai_transformer/indexed_text.json (+1)

@@ -0,0 +1 @@
+[[249, 1925, 485, 6231, 246, 4121, 669, 1662, 939, 715, 1009, 995, 239, 861, 1081, 822, 37700, 606, 1925, 504, 20267, 239], [2703, 13819, 566, 2795, 525, 487, 980, 538, 999, 524, 1114, 589, 850, 239, 246, 267, 305, 285, 267, 67, 3906, 23, 18493, 13103, 43, 38380, 49, 50, 54, 53, 48, 31446, 13103, 43, 13103, 43, 13103, 43, 13103, 15870, 239]]
allennlp/tests/fixtures/openai_transformer/text.txt (+2)

@@ -0,0 +1,2 @@
+I decided to rent a movie when friends came over last night. After looking through selections we decided on comedy.
+James realizes one afternoon that he hasn't left his house all day. A !@#!@OOVTOKEN 1234567890551231231231231.

allennlp/tests/modules/token_embedders/openai_transformer_embedder_test.py (+68 −2)

@@ -1,8 +1,19 @@
 # pylint: disable=no-self-use,invalid-name
 import pytest
+import spacy
+import torch
+import numpy
+import h5py

-from allennlp.common.testing import ModelTestCase
+from allennlp.common.testing import ModelTestCase, AllenNlpTestCase
 from allennlp.data.dataset import Batch
+from allennlp.data import Token
+from allennlp.data.token_indexers import OpenaiTransformerBytePairIndexer
+from allennlp.data.token_indexers.openai_transformer_byte_pair_indexer import text_standardize
+from allennlp.data.vocabulary import Vocabulary
+from allennlp.modules.openai_transformer import OpenaiTransformer
+from allennlp.nn.util import get_range_vector
+

 # Skip this one, it's an expensive test.
 @pytest.mark.skip()

@@ -54,6 +65,62 @@ def test_tagger_with_openai_token_embedder_forward_pass_runs_correctly(self):
         assert tag in {'O', 'I-ORG', 'I-PER', 'I-LOC'}


+# Skip this one, it's an expensive test.
+@pytest.mark.skip()
+class TestOpenAiTransformerEmbedderCorrectWithFixture(AllenNlpTestCase):
+    """
+    Test that the implementation produces the same embeddings as the TensorFlow model.
+    """
+    def test_openai_transformer_matches_tensorflow(self):
+        model_path = "https://s3-us-west-2.amazonaws.com/allennlp/models/openai-transformer-lm-2018.07.23.tar.gz"
+        indexer = OpenaiTransformerBytePairIndexer(model_path=model_path)
+        transformer = OpenaiTransformer(model_path=model_path)
+
+        # get the test sentences
+        with open(self.FIXTURES_ROOT / 'openai_transformer' / 'text.txt', 'r') as fin:
+            sentences = fin.read().strip().split('\n')
+
+        # tokenize and check that indices are correct
+        nlp = spacy.load('en_core_web_sm')
+
+        # make a batch of two sentences
+        batch_indices = []
+        batch_lengths = []
+        for k, sentence in enumerate(sentences):
+            tokens = [token.text for token in nlp(text_standardize(sentence)) if not token.is_space]
+            indices = indexer.tokens_to_indices(
+                    [Token(token) for token in tokens], Vocabulary(), 'openai_indexer'
+            )
+            batch_indices.append(indices['openai_indexer'])
+            batch_lengths.append(len([i for i in indices['openai_indexer'] if i != 0]))
+        batch_indices = torch.from_numpy(numpy.array(batch_indices))
+        batch_size, num_timesteps = batch_indices.size()
+        vocab_size = transformer.vocab_size - transformer.n_ctx
+        positional_encodings = get_range_vector(num_timesteps, device=-1) + vocab_size
+
+        # Combine the inputs with positional encodings
+        batch_tensor = torch.stack([
+                batch_indices,  # (batch_size, num_timesteps)
+                positional_encodings.expand(batch_size, num_timesteps)
+        ], dim=-1)
+
+        # run the LM
+        transformer.eval()
+        activations = transformer(batch_tensor)
+
+        # load the expected activations
+        expected_activations = []
+        with h5py.File(self.FIXTURES_ROOT / 'openai_transformer' / 'expected_embeddings.hdf5', 'r') as fin:
+            expected_activations.append(fin['0'][...])
+            expected_activations.append(fin['1'][...])
+
+        # just check the top layer
+        for k in range(2):
+            actual = activations[-1][k, :batch_lengths[k], :].numpy()
+            expected = expected_activations[k]
+            numpy.testing.assert_almost_equal(expected, actual, decimal=5)
+
+
 def create_small_test_fixture(output_dir: str = '/tmp') -> None:
     """
     This is how I created the transformer_model.tar.gz.

@@ -65,7 +132,6 @@ def create_small_test_fixture(output_dir: str = '/tmp') -> None:
     """
     import json
     import pathlib
-    from allennlp.modules.openai_transformer import OpenaiTransformer

     model_dir = pathlib.Path(output_dir) / 'model'
     model_dir.mkdir(exist_ok=True)  # pylint: disable=no-member
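
An aside on the positional-encoding step in the test above (my own toy illustration, not part of the commit): the OpenAI model stores positional embeddings in the same lookup table as the BPE embeddings, directly after the word pieces, so position ids are offset by the vocabulary size and stacked with the token ids along a trailing dimension.

    import torch

    # toy sizes; the released model uses vocab_size = 40478 and n_ctx = 512
    vocab_size, n_ctx = 10, 4
    batch_indices = torch.tensor([[3, 7, 2, 0],
                                  [5, 1, 0, 0]])  # (batch_size, num_timesteps)
    batch_size, num_timesteps = batch_indices.size()

    # position ids live *after* the word ids in the embedding table
    positional_encodings = torch.arange(num_timesteps) + vocab_size  # tensor([10, 11, 12, 13])

    batch_tensor = torch.stack([
            batch_indices,
            positional_encodings.expand(batch_size, num_timesteps),
    ], dim=-1)
    print(batch_tensor.shape)  # torch.Size([2, 4, 2])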

0 commit comments
