-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathtrain_sentencepiece.py
101 lines (88 loc) · 3.12 KB
/
train_sentencepiece.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import argparse
import glob
import os
import sys
from os.path import join

from loguru import logger
from tokenizers import SentencePieceBPETokenizer
"""
SentencePiece BPE Tokenizer
as outlined in Kudo 2018 Subword Regularization: Improving Neural Network Translation Modelswith Multiple Subword Candidates
The central idea is to `virtually augment training data with on-the-fly subword sampling`,
which helps to improve the accuracy as well as robustness of NMT models.
For better subword sampling they use the unigram language model, which unlike the greedy BPE approach
(takes two tokens, looks at the frequency of each pair and then merges the pairs that have
the highest combined frequency count) chooses the most likely likely combination.
Algorithm is performed in Expectation Maximization (EM) setting:
0) convert all the input into unicode, even spaces (as underscores, '_')
1) calculate probabilities (frequency-based) of each subword token (can seed the subword token set with BPE)
2) with EM estimate a loss which would result if each subword token was discarded
3) discard tokens with the largest loss (can adjust the fraction of the worst tokens to drop with param )
<-- insert fraction param
4) repeat steps 1-3 until reached final vocabulary size or until there is no change in token numbers after successive iterations
Pecularities:
- spaces encoded as "_", or symbol U+2581
"""
# --- Command-line interface ---------------------------------------------------
parser = argparse.ArgumentParser(
    description='Train a SentencePiece-style BPE tokenizer and smoke-test it.'
)
parser.add_argument(
    '--files',
    metavar='path',
    type=str,
    required=True,
    help='The files to use as training; accept a string in format `"**/*.txt"`',
)
parser.add_argument(
    '--out',
    type=str,
    required=True,
    help='Path to the output directory, where the files will be saved',
)
parser.add_argument(
    '--name',
    default='sentencepiece',
    type=str,
    help='The name of the output vocab files',
)
parser.add_argument(
    '--vocab_size',
    default=30000,
    type=int,
    help='Vocabulary size (default: %(default)s)',
)
parser.add_argument(
    '--limit_alphabet',
    default=1000,
    type=int,
    help='Maximum number of distinct characters kept in the initial alphabet '
         '(e.g., for English, |alphabet|=26)',
)
parser.add_argument(
    '--min_frequency',
    default=2,
    type=int,
    help='Minimum frequency a pair needs in order to be merged (default: %(default)s)',
)
args = parser.parse_args()

# --- Collect training files ---------------------------------------------------
# recursive=True is required for `**` in the pattern to descend into nested
# directories, as the --files help text advertises; without it `**` acts like `*`.
# Only regular files are kept, and sorting makes the file order deterministic
# across runs (glob returns entries in arbitrary filesystem order).
files = sorted(
    path for path in glob.glob(args.files, recursive=True) if os.path.isfile(path)
)
if not files:
    logger.error(f'No files match pattern: {args.files}')
    sys.exit(1)

# --- Train --------------------------------------------------------------------
# Initialize an empty tokenizer, then train it on the collected files.
tokenizer = SentencePieceBPETokenizer(add_prefix_space=True)
tokenizer.train(
    files,
    vocab_size=args.vocab_size,
    min_frequency=args.min_frequency,
    show_progress=True,
    special_tokens=['<unk>'],
    limit_alphabet=args.limit_alphabet,
)

# --- Save ---------------------------------------------------------------------
# Ensure the output directory exists before the tokenizer writes into it.
os.makedirs(args.out, exist_ok=True)
# NOTE(review): the reload below assumes `save(directory, name)` writes
# `<name>-vocab.json` and `<name>-merges.txt`; newer `tokenizers` releases
# moved this to `save_model(directory, prefix)` — confirm against the pinned version.
tokenizer.save(args.out, args.name)

# --- Reload from the learned vocab/merges and smoke-test -----------------------
tokenizer = SentencePieceBPETokenizer(
    join(args.out, '{}-vocab.json'.format(args.name)),
    join(args.out, '{}-merges.txt'.format(args.name)),
    add_prefix_space=True,
)

# Sample input (described in the log message as a GFP protein sequence).
SAMPLE_SEQUENCE = 'MSKGEE LFTGVVPILVELDGDVNGHKFSVSGEGEG DAT'
logger.info(
    f'Tokens and their ids from SentencePiece with GFP protein sequence: \n {SAMPLE_SEQUENCE}'
)
encoded = tokenizer.encode(SAMPLE_SEQUENCE)
logger.info(encoded.tokens)
logger.info(encoded.ids)
logger.info('done!')