Skip to content

Commit 2d2445b

Browse files
authored
Modified server to allow source features (#2109)
1 parent 97d93de commit 2d2445b

File tree

1 file changed

+40
-4
lines changed

1 file changed

+40
-4
lines changed

onmt/translate/translation_server.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from itertools import islice, zip_longest
1616
from copy import deepcopy
17+
from collections import defaultdict
18+
from argparse import Namespace
1719

1820
from onmt.constants import DefaultTokens
1921
from onmt.utils.logging import init_logger
@@ -22,6 +24,7 @@
2224
from onmt.utils.alignment import to_word_align
2325
from onmt.utils.parse import ArgumentParser
2426
from onmt.translate.translator import build_translator
27+
from onmt.transforms.features import InferFeatsTransform
2528

2629

2730
def critical(func):
@@ -192,6 +195,7 @@ def start(self, config_file):
192195
{}),
193196
'ct2_translate_batch_args': conf.get(
194197
'ct2_translate_batch_args', {}),
198+
'features_opt': conf.get('features', None)
195199
}
196200
kwargs = {k: v for (k, v) in kwargs.items() if v is not None}
197201
model_id = conf.get("id", None)
@@ -299,14 +303,16 @@ class ServerModel(object):
299303
def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None,
300304
postprocess_opt=None, custom_opt=None, load=False, timeout=-1,
301305
on_timeout="to_cpu", model_root="./", ct2_model=None,
302-
ct2_translator_args=None, ct2_translate_batch_args=None):
306+
ct2_translator_args=None, ct2_translate_batch_args=None,
307+
features_opt=None):
303308
self.model_root = model_root
304309
self.opt = self.parse_opt(opt)
305310
self.custom_opt = custom_opt
306311

307312
self.model_id = model_id
308313
self.preprocess_opt = preprocess_opt
309314
self.tokenizers_opt = tokenizer_opt
315+
self.features_opt = features_opt
310316
self.postprocess_opt = postprocess_opt
311317
self.timeout = timeout
312318
self.on_timeout = on_timeout
@@ -319,6 +325,7 @@ def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None,
319325
self.unload_timer = None
320326
self.user_opt = opt
321327
self.tokenizers = None
328+
self.feats_transform = None
322329

323330
if len(self.opt.log_file) > 0:
324331
log_file = os.path.join(model_root, self.opt.log_file)
@@ -361,6 +368,10 @@ def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None,
361368
'tgt': tokenizer
362369
}
363370

371+
if self.features_opt is not None:
372+
self.feats_transform = InferFeatsTransform(
373+
Namespace(**self.features_opt))
374+
364375
if self.postprocess_opt is not None:
365376
self.logger.info("Loading postprocessor")
366377
self.postprocessor = []
@@ -497,21 +508,27 @@ def run(self, inputs):
497508
# every segment becomes a dict for flexibility purposes
498509
seg_dict = self.maybe_preprocess(inp)
499510
all_preprocessed.append(seg_dict)
500-
for seg, ref in zip_longest(seg_dict["seg"], seg_dict["ref"]):
511+
for seg, ref, feats in zip_longest(
512+
seg_dict["seg"], seg_dict["ref"],
513+
seg_dict["src_feats"]):
501514
tok = self.maybe_tokenize(seg)
502515
if ref is not None:
503516
ref = self.maybe_tokenize(ref, side='tgt')
504-
texts.append((tok, ref))
517+
inferred_feats = self.transform_feats(seg, tok, feats)
518+
texts.append((tok, ref, inferred_feats))
505519
tail_spaces.append(whitespaces_after)
506520

507521
empty_indices = []
508522
texts_to_translate, texts_ref = [], []
509-
for i, (tok, ref_tok) in enumerate(texts):
523+
texts_features = defaultdict(list)
524+
for i, (tok, ref_tok, feats) in enumerate(texts):
510525
if tok == "":
511526
empty_indices.append(i)
512527
else:
513528
texts_to_translate.append(tok)
514529
texts_ref.append(ref_tok)
530+
for feat_name, feat_values in feats.items():
531+
texts_features[feat_name].append(feat_values)
515532
if any([item is None for item in texts_ref]):
516533
texts_ref = None
517534

@@ -522,6 +539,7 @@ def run(self, inputs):
522539
try:
523540
scores, predictions = self.translator.translate(
524541
texts_to_translate,
542+
src_feats=texts_features,
525543
tgt=texts_ref,
526544
batch_size=len(texts_to_translate)
527545
if self.opt.batch_size == 0
@@ -682,6 +700,7 @@ def maybe_preprocess(self, sequence):
682700
sequence["seg"] = [sequence["src"].strip()]
683701
sequence.pop("src")
684702
sequence["ref"] = [sequence.get('ref', None)]
703+
sequence["src_feats"] = [sequence.get('src_feats', {})]
685704
sequence["n_seg"] = 1
686705
if self.preprocess_opt is not None:
687706
return self.preprocess(sequence)
@@ -702,6 +721,23 @@ def preprocess(self, sequence):
702721
sequence = function(sequence, self)
703722
return sequence
704723

724+
def transform_feats(self, raw_src, tok_src, feats):
    """Run the configured InferFeatsTransform over source features.

    Args:
        raw_src: untokenized source segment (space-separated words).
        tok_src: tokenized source segment (space-separated tokens).
        feats: dict mapping feature name -> space-joined feature values.

    Returns:
        dict with the same keys, feature values re-aligned to the
        tokenized source; returned unchanged when no transform is
        configured for this model.

    Raises:
        Exception: if the transform rejects the example.
    """
    transform = self.feats_transform
    if transform is None:
        # No feature transform configured: pass features through as-is.
        return feats
    example = {
        "src": tok_src.split(' '),
        "src_original": raw_src.split(' '),
        "src_feats": {name: values.split(' ')
                      for name, values in feats.items()},
    }
    result = transform.apply(example)
    if not result:
        raise Exception("Error inferring feats")
    # Re-join per-token feature values back into space-separated strings.
    return {name: " ".join(values)
            for name, values in result["src_feats"].items()}
740+
705741
def build_tokenizer(self, tokenizer_opt):
706742
"""Build tokenizer described by `tokenizer_opt`."""
707743
if "type" not in tokenizer_opt:

0 commit comments

Comments
 (0)