# train_wordpiece.py
"""
Train a WordPiece tokenizer, the sub-word algorithm Google used to train BERT and DistilBERT.

Algorithm:
- converts all the input into unicode characters (language/character agnostic)
- similar to BPE, it uses frequency counts to identify candidate merges,
  but makes the final merge decision based on the likelihood of the merged token.

Important:
- prepends a prefix '##' (`wordpieces_prefix`) to the sub-word pieces of less common
  words (those not in the vocabulary), e.g. `hypatia = h ##yp ##ati ##a`
- support for Chinese characters
- special tokens: `"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"`
"""
import argparse
import glob
from os.path import join

from loguru import logger
from tokenizers import BertWordPieceTokenizer
parser = argparse.ArgumentParser()
parser.add_argument(
    "--files",
    default=None,
    metavar="path",
    type=str,
    required=True,
    help="The files to use for training; accepts glob patterns such as `**/*.txt` when enclosed in quotes",
)
parser.add_argument(
    "--out",
    default="./",
    type=str,
    help="Path to the output directory where the vocab files will be saved",
)
parser.add_argument(
    "--name",
    default="bert-wordpiece",
    type=str,
    help="The name of the output vocab files",
)
parser.add_argument(
    "--vocab_size",
    default=30000,
    type=int,
    help="Vocabulary size",
)
args = parser.parse_args()
files = glob.glob(args.files, recursive=True)
if not files:
    logger.error(f"No files matched the pattern: {args.files}")
    exit(1)
# Initialize an empty tokenizer.
# Chinese-character handling is disabled here; enable `handle_chinese_chars`
# if the training corpus contains Chinese text.
tokenizer = BertWordPieceTokenizer(
    clean_text=True,
    handle_chinese_chars=False,
    strip_accents=True,
    lowercase=True,
)
# And then train
tokenizer.train(
    files,
    vocab_size=args.vocab_size,
    min_frequency=2,
    show_progress=True,
    special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"],
    limit_alphabet=1000,
    wordpieces_prefix="##",
)
# Save the learned vocab to the output directory as `{name}-vocab.txt`
tokenizer.save(args.out, args.name)
# Restore the tokenizer from the learned vocab (settings must match training)
tokenizer = BertWordPieceTokenizer(
    vocab_file=join(args.out, "{}-vocab.txt".format(args.name)),
    clean_text=True,
    handle_chinese_chars=False,
    strip_accents=True,
    lowercase=True,
    wordpieces_prefix="##",
)
# Test encoding
logger.info('Testing BertWordPieceTokenizer with GFP protein sequence: \n MSKGEE LFTGVVPILVELDGDVNGHKFSVSGEGEG DAT')
encoded = tokenizer.encode('MSKGEE LFTGVVPILVELDGDVNGHKFSVSGEGEG DAT')
logger.info(encoded.tokens)
logger.info(encoded.ids)
logger.info('done!')
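# Example usage (paths and values are illustrative):
#   python train_wordpiece.py --files "data/**/*.txt" --out ./vocab --name bert-wordpiece --vocab_size 30000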