train_char_level_bpe.py
import argparse
import glob
from os.path import join
from loguru import logger
from tokenizers import CharBPETokenizer
"""
The original BPE tokenizer, as proposed in Sennrich, Haddow and Birch, Neural Machine Translation
of Rare Words with Subword Units. ACL 2016
https://arxiv.org/abs/1508.07909
https://github.com/rsennrich/subword-nmt
https://github.com/EdinburghNLP/nematus
BPE algorithm explanation:
BPE first splits each word into individual characters.
The most frequent adjacent pairs of characters are then consecutively
merged until a desired vocabulary size is reached. Subword segmentation is
performed by applying the same merge operations to the test sentence.
Frequent substrings are joined early, so common words remain
as a single symbol, while words built from rare character
combinations are split into smaller units (substrings or characters).
For example, given a dictionary with the following word frequencies:
```
5 low
2 lower
6 newest
3 widest
```
- a starting vocabulary with all the characters is initialized:
`{l,o,w,e,r,n,s,t,i,d}`
- `es` is the most frequent adjacent character pair (it appears 9 times), so add it to the vocabulary:
`{l,o,w,e,r,n,s,t,i,d, es}`
- `es t` is now the most frequent pair, so append `est` to the vocabulary too:
`{l,o,w,e,r,n,s,t,i,d, es, est}`
- then `l o`, which appears 7 times:
`{l,o,w,e,r,n,s,t,i,d, es, est, lo}`
- then `lo w`:
`{l,o,w,e,r,n,s,t,i,d, es, est, lo, low}`
- continue merging until the vocabulary reaches a pre-defined size (see the sketch below)
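A minimal pure-Python sketch of this merge loop (for illustration only, not the
implementation used by `CharBPETokenizer`; the `</w>` end-of-word marker is omitted
to match the walkthrough above):
```
from collections import Counter

# toy corpus from above: each word is a tuple of symbols mapped to its frequency
corpus = {
    ('l', 'o', 'w'): 5,
    ('l', 'o', 'w', 'e', 'r'): 2,
    ('n', 'e', 'w', 'e', 's', 't'): 6,
    ('w', 'i', 'd', 'e', 's', 't'): 3,
}

def count_pairs(corpus):
    # count adjacent symbol pairs, weighted by word frequency
    pairs = Counter()
    for word, freq in corpus.items():
        for a, b in zip(word, word[1:]):
            pairs[(a, b)] += freq
    return pairs

def apply_merge(pair, corpus):
    # replace every occurrence of `pair` with its concatenation
    merged = {}
    for word, freq in corpus.items():
        out, i = [], 0
        while i < len(word):
            if i + 1 < len(word) and (word[i], word[i + 1]) == pair:
                out.append(word[i] + word[i + 1])
                i += 2
            else:
                out.append(word[i])
                i += 1
        merged[tuple(out)] = freq
    return merged

merges = []
for _ in range(4):  # in practice: stop once the vocabulary reaches vocab_size
    pairs = count_pairs(corpus)
    best = max(pairs, key=pairs.get)  # ties broken by first occurrence
    merges.append(best)
    corpus = apply_merge(best, corpus)
print(merges)  # [('e', 's'), ('es', 't'), ('l', 'o'), ('lo', 'w')]
```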
Example usage:
python train_char_level_bpe.py --files /Users/flp/Box/Molecular_SysBio/data/paccmann/paccmann_proteomics/uniprot_sprot/uniprot_sprot_100_seq.txt --out /Users/flp/Box/Molecular_SysBio/data/paccmann/paccmann_proteomics/tokenized_uniprot_sprot/tests --vocab_size
Important:
- Adds a special end-of-word token (suffix) "</w>",
e.g., the word `tokenization` becomes ['to', 'ken', 'ization</w>'] (see the snippet below)
- If needed, the initial alphabet size can be limited with `limit_alphabet: int`
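A minimal sketch of loading a trained model and encoding with it (file names are
hypothetical, and the exact split depends on the learned merges):
```
from tokenizers import CharBPETokenizer

# load a previously trained character-level BPE model (hypothetical paths)
tok = CharBPETokenizer('char-bpe-vocab.json', 'char-bpe-merges.txt')
print(tok.encode('tokenization').tokens)
# e.g. ['to', 'ken', 'ization</w>'] for a model trained on English text
```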
"""
parser = argparse.ArgumentParser()
parser.add_argument(
    '--files',
    default=None,
    metavar='path',
    type=str,
    required=True,
    help='The files to use for training; accepts a glob pattern, e.g. `"**/*.txt"`'
)
parser.add_argument(
    '--out',
    # default='./',
    type=str,
    required=True,
    help='Path to the output directory where the vocab and merges files will be saved'
)
parser.add_argument(
    '--name',
    default='char-bpe',
    type=str,
    help='The name (prefix) of the output vocab files',
)
parser.add_argument(
    '--vocab_size',
    default=30000,
    type=int,
    help='Vocabulary size (number of subword symbols to learn)',
)
parser.add_argument(
    '--limit_alphabet',
    default=100,
    type=int,
    help='The maximum size of the initial alphabet character set (e.g., for English, |alphabet|=26)',
)
args = parser.parse_args()
files = glob.glob(args.files)
if not files:
    logger.error(f'No files found matching pattern: {args.files}')
    exit(1)
# Initialize an empty character-level BPE tokenizer
# (constructor options such as `suffix` and `unk_token` are left at their defaults here)
tokenizer = CharBPETokenizer()
# And then train
tokenizer.train(
    files,
    vocab_size=args.vocab_size,
    min_frequency=2,
    show_progress=True,
    special_tokens=['<unk>'],
    suffix='</w>',
    limit_alphabet=args.limit_alphabet,
)
# Save the vocab and merges files
# (on recent `tokenizers` releases this call is `tokenizer.save_model(args.out, args.name)`)
tokenizer.save(args.out, args.name)
# Restoring model from learned vocab/merges
tokenizer = CharBPETokenizer(
    join(args.out, '{}-vocab.json'.format(args.name)),
    join(args.out, '{}-merges.txt'.format(args.name)),
)
# Test encoding
logger.info('Tokens and their ids from CharBPETokenizer with GFP protein sequence: \n MSKGEE LFTGVVPILVELDGDVNGHKFSVSGEGEG DAT')
encoded = tokenizer.encode('MSKGEE LFTGVVPILVELDGDVNGHKFSVSGEGEG DAT')
logger.info(encoded.tokens)
logger.info(encoded.ids)
logger.info('done!')