
Commit db808e6

cached MT server

1 parent 331fdb8 commit db808e6

7 files changed (+32586 −25 lines)

.gitmodules

+3

[submodule "subword-nmt"]
	path = subword-nmt
	url = https://github.com/rsennrich/subword-nmt
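
A usage note: after cloning, the submodule registered above is fetched with the standard git command:

    git submodule update --init subword-nmt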

apply_bpe.py

+373
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Rico Sennrich

"""Use operations learned with learn_bpe.py to encode a new text.
The text will not become smaller, but will use only a fixed vocabulary, with rare words
encoded as variable-length sequences of subword units.

Reference:
Rico Sennrich, Barry Haddow and Alexandra Birch (2015). Neural Machine Translation of Rare Words with Subword Units.
Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (ACL 2016). Berlin, Germany.
"""

from __future__ import unicode_literals, division

import sys
import os
import inspect
import codecs
import io
import argparse
import re
import warnings

# hack for python2/3 compatibility
from io import open
argparse.open = open

class BPE(object):

    def __init__(self, codes, merges=-1, separator='@@', vocab=None, glossaries=None):

        codes.seek(0)
        offset = 1

        # check version information
        firstline = codes.readline()
        if firstline.startswith('#version:'):
            self.version = tuple([int(x) for x in re.sub(r'(\.0+)*$', '', firstline.split()[-1]).split(".")])
            offset += 1
        else:
            self.version = (0, 1)
            codes.seek(0)

        self.bpe_codes = [tuple(item.strip('\r\n ').split(' ')) for (n, item) in enumerate(codes) if (n < merges or merges == -1)]

        for i, item in enumerate(self.bpe_codes):
            if len(item) != 2:
                sys.stderr.write('Error: invalid line {0} in BPE codes file: {1}\n'.format(i+offset, ' '.join(item)))
                sys.stderr.write('Each line must consist of exactly two subword units, separated by whitespace\n')
                sys.exit(1)

        # some hacking to deal with duplicates (only consider first instance)
        self.bpe_codes = dict([(code, i) for (i, code) in reversed(list(enumerate(self.bpe_codes)))])

        self.bpe_codes_reverse = dict([(pair[0] + pair[1], pair) for pair, i in self.bpe_codes.items()])

        self.separator = separator

        self.vocab = vocab

        self.glossaries = glossaries if glossaries else []

        self.cache = {}

    def process_line(self, line):
        """segment line, dealing with leading and trailing whitespace"""

        out = ""

        leading_whitespace = len(line) - len(line.lstrip('\r\n '))
        if leading_whitespace:
            out += line[:leading_whitespace]

        out += self.segment(line)

        trailing_whitespace = len(line) - len(line.rstrip('\r\n '))
        if trailing_whitespace and trailing_whitespace != len(line):
            out += line[-trailing_whitespace:]

        return out

    def segment(self, sentence):
        """segment single sentence (whitespace-tokenized string) with BPE encoding"""
        segments = self.segment_tokens(sentence.strip('\r\n ').split(' '))
        return ' '.join(segments)

    def segment_tokens(self, tokens):
        """segment a sequence of tokens with BPE encoding"""
        output = []
        for word in tokens:
            # eliminate double spaces
            if not word:
                continue
            new_word = [out for segment in self._isolate_glossaries(word)
                        for out in encode(segment,
                                          self.bpe_codes,
                                          self.bpe_codes_reverse,
                                          self.vocab,
                                          self.separator,
                                          self.version,
                                          self.cache,
                                          self.glossaries)]

            # mark all non-final subword units of a word with the separator
            for item in new_word[:-1]:
                output.append(item + self.separator)
            output.append(new_word[-1])

        return output

    def _isolate_glossaries(self, word):
        word_segments = [word]
        for gloss in self.glossaries:
            word_segments = [out_segments for segment in word_segments
                             for out_segments in isolate_glossary(segment, gloss)]
        return word_segments

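A minimal usage sketch of the BPE class above ('codes.bpe' is an illustrative placeholder for a codes file produced by learn_bpe.py):

    import codecs
    codes = codecs.open('codes.bpe', encoding='utf-8')   # hypothetical codes file
    bpe = BPE(codes)
    print(bpe.process_line('a sentence with rarewords\n'))
    # rare words come back as subword units, non-final units marked with '@@'
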
def create_parser(subparsers=None):

    if subparsers:
        parser = subparsers.add_parser('apply-bpe',
            formatter_class=argparse.RawDescriptionHelpFormatter,
            description="apply BPE-based word segmentation")
    else:
        parser = argparse.ArgumentParser(
            formatter_class=argparse.RawDescriptionHelpFormatter,
            description="apply BPE-based word segmentation")

    parser.add_argument(
        '--input', '-i', type=argparse.FileType('r'), default=sys.stdin,
        metavar='PATH',
        help="Input file (default: standard input).")
    parser.add_argument(
        '--codes', '-c', type=argparse.FileType('r'), metavar='PATH',
        required=True,
        help="File with BPE codes (created by learn_bpe.py).")
    parser.add_argument(
        '--merges', '-m', type=int, default=-1,
        metavar='INT',
        help="Use this many BPE operations (<= number of learned symbols; " +
             "default: apply all the learned merge operations)")
    parser.add_argument(
        '--output', '-o', type=argparse.FileType('w'), default=sys.stdout,
        metavar='PATH',
        help="Output file (default: standard output)")
    parser.add_argument(
        '--separator', '-s', type=str, default='@@', metavar='STR',
        help="Separator between non-final subword units (default: '%(default)s')")
    parser.add_argument(
        '--vocabulary', type=argparse.FileType('r'), default=None,
        metavar="PATH",
        help="Vocabulary file (built with get_vocab.py). If provided, this script reverts any merge operations that produce an OOV.")
    parser.add_argument(
        '--vocabulary-threshold', type=int, default=None,
        metavar="INT",
        help="Vocabulary threshold. If vocabulary is provided, any word with frequency < threshold will be treated as OOV")
    parser.add_argument(
        '--glossaries', type=str, nargs='+', default=None,
        metavar="STR",
        help="Glossaries. Words matching any of the words/regexes provided in glossaries will not be affected " +
             "by the BPE (i.e. they will neither be broken into subwords, nor concatenated with other subwords). " +
             "Can be provided as a list of words/regexes after the --glossaries argument. Enclose each regex in quotes.")

    return parser

def get_pairs(word):
    """Return set of symbol pairs in a word.

    word is represented as a tuple of symbols (symbols being variable-length strings)
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

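A quick illustration of the helper above:

    get_pairs(('l', 'o', 'w', '</w>'))
    # -> {('l', 'o'), ('o', 'w'), ('w', '</w>')}
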
def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache, glossaries=None):
    """Encode word based on list of BPE merge operations, which are applied consecutively"""

    if orig in cache:
        return cache[orig]

    # glossary matches are passed through unchanged (guard: glossaries may be None)
    if glossaries and re.match('^({})$'.format('|'.join(glossaries)), orig):
        cache[orig] = (orig,)
        return (orig,)

    if version == (0, 1):
        word = tuple(orig) + ('</w>',)
    elif version == (0, 2): # more consistent handling of word-final segments
        word = tuple(orig[:-1]) + (orig[-1] + '</w>',)
    else:
        raise NotImplementedError

    pairs = get_pairs(word)

    if not pairs:
        return orig

    while True:
        # apply the highest-priority (lowest-index) merge that occurs in the word
        bigram = min(pairs, key=lambda pair: bpe_codes.get(pair, float('inf')))
        if bigram not in bpe_codes:
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)
                new_word.extend(word[i:j])
                i = j
            except ValueError:
                new_word.extend(word[i:])
                break

            if word[i] == first and i < len(word) - 1 and word[i+1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)

    # don't print end-of-word symbols
    if word[-1] == '</w>':
        word = word[:-1]
    elif word[-1].endswith('</w>'):
        word = word[:-1] + (word[-1].replace('</w>', ''),)

    if vocab:
        word = check_vocab_and_split(word, bpe_codes_reverse, vocab, separator)

    cache[orig] = word
    return word

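A small walk-through of the merge loop above with an invented two-rule codes table:

    toy_codes = {('l', 'o'): 0, ('lo', 'w</w>'): 1}   # hypothetical merge priorities
    encode('low', toy_codes, {}, None, '@@', (0, 2), {})
    # ('l','o','w</w>') -> ('lo','w</w>') -> ('low</w>',) -> returns ('low',)
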
def recursive_split(segment, bpe_codes, vocab, separator, final=False):
    """Recursively split segment into smaller units (by reversing BPE merges)
    until all units are either in-vocabulary, or cannot be split further."""

    # note: bpe_codes here is the reverse table (merged symbol -> pair of parts)
    try:
        if final:
            left, right = bpe_codes[segment + '</w>']
            right = right[:-4]
        else:
            left, right = bpe_codes[segment]
    except KeyError:
        #sys.stderr.write('cannot split {0} further.\n'.format(segment))
        yield segment
        return

    if left + separator in vocab:
        yield left
    else:
        for item in recursive_split(left, bpe_codes, vocab, separator, False):
            yield item

    if (final and right in vocab) or (not final and right + separator in vocab):
        yield right
    else:
        for item in recursive_split(right, bpe_codes, vocab, separator, final):
            yield item

def check_vocab_and_split(orig, bpe_codes, vocab, separator):
    """Check for each segment in word if it is in-vocabulary,
    and segment OOV segments into smaller units by reversing the BPE merge operations"""

    out = []

    for segment in orig[:-1]:
        if segment + separator in vocab:
            out.append(segment)
        else:
            #sys.stderr.write('OOV: {0}\n'.format(segment))
            for item in recursive_split(segment, bpe_codes, vocab, separator, False):
                out.append(item)

    segment = orig[-1]
    if segment in vocab:
        out.append(segment)
    else:
        #sys.stderr.write('OOV: {0}\n'.format(segment))
        for item in recursive_split(segment, bpe_codes, vocab, separator, True):
            out.append(item)

    return out

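A sketch of the vocabulary-driven reverting, with an invented reverse table and vocabulary (check_vocab_and_split receives bpe_codes_reverse, i.e. merged symbol -> pair):

    toy_reverse = {'low</w>': ('lo', 'w</w>'), 'lo': ('l', 'o')}   # hypothetical
    toy_vocab = {'lo@@', 'w'}
    check_vocab_and_split(('low',), toy_reverse, toy_vocab, '@@')
    # -> ['lo', 'w']   ('low' is OOV, so its last merge is undone)
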
def read_vocabulary(vocab_file, threshold):
    """read vocabulary file produced by get_vocab.py, and filter according to frequency threshold."""

    vocabulary = set()

    for line in vocab_file:
        word, freq = line.strip('\r\n ').split(' ')
        freq = int(freq)
        if threshold is None or freq >= threshold:
            vocabulary.add(word)

    return vocabulary

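The vocabulary format is one 'word count' pair per line; a self-contained illustration (contents invented):

    import io
    read_vocabulary(io.StringIO('the 1000\nrare@@ 3\n'), threshold=10)
    # -> {'the'}   ('rare@@' falls below the threshold)
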
def isolate_glossary(word, glossary):
    """
    Isolate a glossary present inside a word.

    Returns a list of subwords in which all occurrences of 'glossary' are isolated.

    For example, if 'USA' is the glossary and '1934USABUSA' the word, the return value is:
    ['1934', 'USA', 'B', 'USA']
    """
    # regex equivalent of (if word == glossary or glossary not in word)
    if re.match('^' + glossary + '$', word) or not re.search(glossary, word):
        return [word]
    else:
        segments = re.split(r'({})'.format(glossary), word)
        segments, ending = segments[:-1], segments[-1]
        segments = list(filter(None, segments)) # Remove empty strings in regex group.
        return segments + [ending.strip('\r\n ')] if ending != '' else segments

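Glossary entries are treated as regexes by the function above; an illustrative call:

    isolate_glossary('x2019y', '[0-9]+')
    # -> ['x', '2019', 'y']   (re.split keeps the captured group)
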
if __name__ == '__main__':

    currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
    newdir = os.path.join(currentdir, 'subword_nmt')
    if os.path.isdir(newdir):
        warnings.simplefilter('default')
        warnings.warn(
            "this script's location has moved to {0}. This symbolic link will be removed in a future version. Please point to the new location, or install the package and use the command 'subword-nmt'".format(newdir),
            DeprecationWarning
        )

    # python 2/3 compatibility
    if sys.version_info < (3, 0):
        sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
        sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
        sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
    else:
        sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
        sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
        sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', write_through=True, line_buffering=True)

    parser = create_parser()
    args = parser.parse_args()

    # read/write files as UTF-8
    args.codes = codecs.open(args.codes.name, encoding='utf-8')
    if args.input.name != '<stdin>':
        args.input = codecs.open(args.input.name, encoding='utf-8')
    if args.output.name != '<stdout>':
        args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
    if args.vocabulary:
        args.vocabulary = codecs.open(args.vocabulary.name, encoding='utf-8')

    if args.vocabulary:
        vocabulary = read_vocabulary(args.vocabulary, args.vocabulary_threshold)
    else:
        vocabulary = None

    if sys.version_info < (3, 0):
        args.separator = args.separator.decode('UTF-8')
        if args.glossaries:
            args.glossaries = [g.decode('UTF-8') for g in args.glossaries]

    bpe = BPE(args.codes, args.merges, args.separator, vocabulary, args.glossaries)

    for line in args.input:
        args.output.write(bpe.process_line(line))
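
Typical invocations of the script (file names are illustrative):

    python apply_bpe.py -c codes.bpe < corpus.tok > corpus.bpe
    python apply_bpe.py -c codes.bpe --vocabulary vocab.txt --vocabulary-threshold 50 < corpus.tok > corpus.bpe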

opentrans-client.py

+2 −1

@@ -13,11 +13,12 @@
 # handle command-line options
 parser = argparse.ArgumentParser()
 parser.add_argument("-b", "--batch-size", type=int, default=1)
+parser.add_argument("-H", "--host", type=str, default='localhost')  # "-H": argparse reserves -h for --help
 parser.add_argument("-p", "--port", type=int, default=8080)
 args = parser.parse_args()

 # open connection
-ws = create_connection("ws://86.50.168.81:{}/translate".format(args.port))
+ws = create_connection("ws://{}:{}/translate".format(args.host, args.port))

 count = 0
 batch = ""
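
With the host configurable rather than hard-coded, the client can be pointed at any server, e.g. (assuming, as the surrounding code suggests, that input is read from stdin):

    python opentrans-client.py -H localhost -p 8080 < input.txt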
