Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 3f54fc8

Browse files
joelgrus authored and DeNeutoy committed
make openai transformer byte pair indexer add to the vocab (#1705)
fixes #1700
1 parent 7bf930f commit 3f54fc8

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

allennlp/data/token_indexers/openai_transformer_byte_pair_indexer.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ def __init__(self,
2727
encoder: Dict[str, int] = None,
2828
byte_pairs: List[Tuple[str, str]] = None,
2929
n_ctx: int = 512,
30-
model_path: str = None) -> None:
30+
model_path: str = None,
31+
namespace: str = 'openai_transformer') -> None:
32+
self._namespace = namespace
33+
self._added_to_vocabulary = False
3134

3235
too_much_information = model_path and (encoder or byte_pairs)
3336
too_little_information = not model_path and not (encoder and byte_pairs)
@@ -143,12 +146,21 @@ def byte_pair_encode(self, token: Token, lowercase: bool = True) -> List[str]:
143146
self.cache[text] = word
144147
return word
145148

149+
def _add_encoding_to_vocabulary(self, vocabulary: Vocabulary) -> None:
150+
# pylint: disable=protected-access
151+
for word, idx in self.encoder.items():
152+
vocabulary._token_to_index[self._namespace][word] = idx
153+
vocabulary._index_to_token[self._namespace][idx] = word
146154

147155
@overrides
148156
def tokens_to_indices(self,
149157
tokens: List[Token],
150-
_vocabulary: Vocabulary,
158+
vocabulary: Vocabulary,
151159
index_name: str) -> Dict[str, List[int]]:
160+
if not self._added_to_vocabulary:
161+
self._add_encoding_to_vocabulary(vocabulary)
162+
self._added_to_vocabulary = True
163+
152164
text_tokens = []
153165
offsets = []
154166
offset = -1

allennlp/tests/data/token_indexers/openai_transformer_byte_pair_indexer_test.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# pylint: disable=no-self-use,invalid-name
1+
# pylint: disable=no-self-use,invalid-name,protected-access
22
import json
33
import tarfile
44

@@ -7,6 +7,7 @@
77
from allennlp.common.testing import AllenNlpTestCase
88
from allennlp.data import Token
99
from allennlp.data.token_indexers import OpenaiTransformerBytePairIndexer
10+
from allennlp.data.vocabulary import Vocabulary
1011

1112

1213
class TestOpenaiTransformerBytePairIndexer(AllenNlpTestCase):
@@ -39,6 +40,7 @@ def setUp(self):
3940
tf.add(bpe_path, 'model/vocab_40000.bpe')
4041

4142
self.indexer = OpenaiTransformerBytePairIndexer(encoding, byte_pairs)
43+
self.vocab = Vocabulary(non_padded_namespaces=['openai_transformer'])
4244

4345
def test_bpe(self):
4446

@@ -63,7 +65,17 @@ def test_bpe(self):
6365
def test_tokens_to_indices(self):
6466
tokens = [Token('ewoe'), Token('woe'), Token('ewe'), Token('ee')]
6567

66-
indices = self.indexer.tokens_to_indices(tokens, None, 'test')
68+
# vocab should be empty initially
69+
assert 'openai_transformer' not in self.vocab._index_to_token
70+
assert 'openai_transformer' not in self.vocab._token_to_index
71+
72+
indices = self.indexer.tokens_to_indices(tokens, self.vocab, 'test')
73+
74+
# vocab should be full now
75+
i2t = self.vocab._index_to_token.get('openai_transformer')
76+
t2i = self.vocab._token_to_index.get('openai_transformer')
77+
assert len(i2t) == 5 * 5 * 2
78+
assert len(t2i) == 5 * 5 * 2
6779

6880
assert set(indices.keys()) == {"test", "test-offsets", "mask"}
6981

@@ -86,4 +98,4 @@ def test_raises_with_too_long_sentence(self):
8698
tokens = [Token('a') for _ in range(513)]
8799

88100
with pytest.raises(RuntimeError):
89-
self.indexer.tokens_to_indices(tokens, None, 'should-fail')
101+
self.indexer.tokens_to_indices(tokens, self.vocab, 'should-fail')

0 commit comments

Comments (0)