1
- # pylint: disable=no-self-use,invalid-name
1
+ # pylint: disable=no-self-use,invalid-name,protected-access
2
2
import json
3
3
import tarfile
4
4
7
7
from allennlp .common .testing import AllenNlpTestCase
8
8
from allennlp .data import Token
9
9
from allennlp .data .token_indexers import OpenaiTransformerBytePairIndexer
10
+ from allennlp .data .vocabulary import Vocabulary
10
11
11
12
12
13
class TestOpenaiTransformerBytePairIndexer (AllenNlpTestCase ):
@@ -39,6 +40,7 @@ def setUp(self):
39
40
tf .add (bpe_path , 'model/vocab_40000.bpe' )
40
41
41
42
self .indexer = OpenaiTransformerBytePairIndexer (encoding , byte_pairs )
43
+ self .vocab = Vocabulary (non_padded_namespaces = ['openai_transformer' ])
42
44
43
45
def test_bpe (self ):
44
46
@@ -63,7 +65,17 @@ def test_bpe(self):
63
65
def test_tokens_to_indices (self ):
64
66
tokens = [Token ('ewoe' ), Token ('woe' ), Token ('ewe' ), Token ('ee' )]
65
67
66
- indices = self .indexer .tokens_to_indices (tokens , None , 'test' )
68
+ # vocab should be empty initially
69
+ assert 'openai_transformer' not in self .vocab ._index_to_token
70
+ assert 'openai_transformer' not in self .vocab ._token_to_index
71
+
72
+ indices = self .indexer .tokens_to_indices (tokens , self .vocab , 'test' )
73
+
74
+ # vocab should be full now
75
+ i2t = self .vocab ._index_to_token .get ('openai_transformer' )
76
+ t2i = self .vocab ._token_to_index .get ('openai_transformer' )
77
+ assert len (i2t ) == 5 * 5 * 2
78
+ assert len (t2i ) == 5 * 5 * 2
67
79
68
80
assert set (indices .keys ()) == {"test" , "test-offsets" , "mask" }
69
81
@@ -86,4 +98,4 @@ def test_raises_with_too_long_sentence(self):
86
98
tokens = [Token ('a' ) for _ in range (513 )]
87
99
88
100
with pytest .raises (RuntimeError ):
89
- self .indexer .tokens_to_indices (tokens , None , 'should-fail' )
101
+ self .indexer .tokens_to_indices (tokens , self . vocab , 'should-fail' )
0 commit comments