diff --git a/onmt/translate/translation_server.py b/onmt/translate/translation_server.py
index e0e968f57b..e3cc084c0e 100644
--- a/onmt/translate/translation_server.py
+++ b/onmt/translate/translation_server.py
@@ -14,6 +14,8 @@
 
 from itertools import islice, zip_longest
 from copy import deepcopy
+from collections import defaultdict
+from argparse import Namespace
 
 from onmt.constants import DefaultTokens
 from onmt.utils.logging import init_logger
@@ -22,6 +24,7 @@
 from onmt.utils.alignment import to_word_align
 from onmt.utils.parse import ArgumentParser
 from onmt.translate.translator import build_translator
+from onmt.transforms.features import InferFeatsTransform
 
 
 def critical(func):
@@ -192,6 +195,7 @@ def start(self, config_file):
                                                       {}),
                       'ct2_translate_batch_args': conf.get(
                           'ct2_translate_batch_args', {}),
+                      'features_opt': conf.get('features', None)
                       }
             kwargs = {k: v for (k, v) in kwargs.items() if v is not None}
             model_id = conf.get("id", None)
@@ -299,7 +303,8 @@ class ServerModel(object):
     def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None,
                  postprocess_opt=None, custom_opt=None, load=False, timeout=-1,
                  on_timeout="to_cpu", model_root="./", ct2_model=None,
-                 ct2_translator_args=None, ct2_translate_batch_args=None):
+                 ct2_translator_args=None, ct2_translate_batch_args=None,
+                 features_opt=None):
         self.model_root = model_root
         self.opt = self.parse_opt(opt)
         self.custom_opt = custom_opt
@@ -307,6 +312,7 @@ def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None,
         self.model_id = model_id
         self.preprocess_opt = preprocess_opt
         self.tokenizers_opt = tokenizer_opt
+        self.features_opt = features_opt
         self.postprocess_opt = postprocess_opt
         self.timeout = timeout
         self.on_timeout = on_timeout
@@ -319,6 +325,7 @@ def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None,
         self.unload_timer = None
         self.user_opt = opt
         self.tokenizers = None
+        self.feats_transform = None
 
         if len(self.opt.log_file) > 0:
             log_file = os.path.join(model_root, self.opt.log_file)
@@ -361,6 +368,10 @@ def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None,
                 'tgt': tokenizer
             }
 
+        if self.features_opt is not None:
+            self.feats_transform = InferFeatsTransform(
+                Namespace(**self.features_opt))
+
         if self.postprocess_opt is not None:
             self.logger.info("Loading postprocessor")
             self.postprocessor = []
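(Aside, not part of the patch: the `features_opt` kwarg above is fed straight from a `features` entry in the server's conf.json and unpacked into a `Namespace` for `InferFeatsTransform`. A minimal sketch of such an entry follows; the model id, checkpoint name, and the `reversible_tokenization` key are illustrative assumptions about InferFeatsTransform's option set, not values this diff defines:

    import json

    conf = {
        "models_root": "./available_models",
        "models": [{
            "id": 100,                 # illustrative model id
            "model": "model_0.pt",     # illustrative checkpoint name
            # everything under "features" reaches
            # Namespace(**self.features_opt) unchanged;
            # the key below is an assumption
            "features": {"reversible_tokenization": "joiner"},
        }],
    }
    with open("available_models/conf.json", "w") as f:
        json.dump(conf, f, indent=4)

Omitting the `features` entry leaves `features_opt` at None, so `feats_transform` stays unset and `transform_feats()` passes features through untouched.)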
@@ -497,21 +508,27 @@ def run(self, inputs):
             # every segment becomes a dict for flexibility purposes
             seg_dict = self.maybe_preprocess(inp)
             all_preprocessed.append(seg_dict)
-            for seg, ref in zip_longest(seg_dict["seg"], seg_dict["ref"]):
+            for seg, ref, feats in zip_longest(
+                    seg_dict["seg"], seg_dict["ref"],
+                    seg_dict["src_feats"]):
                 tok = self.maybe_tokenize(seg)
                 if ref is not None:
                     ref = self.maybe_tokenize(ref, side='tgt')
-                texts.append((tok, ref))
+                inferred_feats = self.transform_feats(seg, tok, feats)
+                texts.append((tok, ref, inferred_feats))
             tail_spaces.append(whitespaces_after)
 
         empty_indices = []
         texts_to_translate, texts_ref = [], []
-        for i, (tok, ref_tok) in enumerate(texts):
+        texts_features = defaultdict(list)
+        for i, (tok, ref_tok, feats) in enumerate(texts):
             if tok == "":
                 empty_indices.append(i)
             else:
                 texts_to_translate.append(tok)
                 texts_ref.append(ref_tok)
+                for feat_name, feat_values in feats.items():
+                    texts_features[feat_name].append(feat_values)
 
         if any([item is None for item in texts_ref]):
             texts_ref = None
@@ -522,6 +539,7 @@ def run(self, inputs):
             try:
                 scores, predictions = self.translator.translate(
                     texts_to_translate,
+                    src_feats=texts_features,
                     tgt=texts_ref,
                     batch_size=len(texts_to_translate)
                     if self.opt.batch_size == 0
@@ -682,6 +700,7 @@ def maybe_preprocess(self, sequence):
             sequence["seg"] = [sequence["src"].strip()]
             sequence.pop("src")
             sequence["ref"] = [sequence.get('ref', None)]
+            sequence["src_feats"] = [sequence.get('src_feats', {})]
             sequence["n_seg"] = 1
         if self.preprocess_opt is not None:
             return self.preprocess(sequence)
@@ -702,6 +721,23 @@ def preprocess(self, sequence):
             sequence = function(sequence, self)
         return sequence
 
+    def transform_feats(self, raw_src, tok_src, feats):
+        """Apply InferFeatsTransform to features"""
+        if self.feats_transform is None:
+            return feats
+        ex = {
+            "src": tok_src.split(' '),
+            "src_original": raw_src.split(' '),
+            "src_feats": {k: v.split(' ') for k, v in feats.items()}
+        }
+        transformed_ex = self.feats_transform.apply(ex)
+        if not transformed_ex:
+            raise Exception("Error inferring feats")
+        transformed_feats = dict()
+        for feat_name, feat_values in transformed_ex["src_feats"].items():
+            transformed_feats[feat_name] = " ".join(feat_values)
+        return transformed_feats
+
     def build_tokenizer(self, tokenizer_opt):
         """Build tokenizer described by `tokenizer_opt`."""
         if "type" not in tokenizer_opt:
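(Aside, not part of the patch: end to end, a client exercises this path by attaching `src_feats` to each input, shaped the way `maybe_preprocess()` and `transform_feats()` expect: a dict mapping a feature name to a space-separated string with one value per source token. A minimal sketch; the host, port, route, and the `feat_0` name are illustrative assumptions, not defined by this diff:

    import requests  # any HTTP client works; requests is assumed here

    payload = [{
        "id": 100,                          # model id from conf.json
        "src": "this is a test",            # 4 source tokens
        "src_feats": {"feat_0": "A A A B"}  # 4 values, one per token
    }]
    r = requests.post("http://localhost:5000/translator/translate",
                      json=payload)
    print(r.json())

After tokenization, `transform_feats()` hands the tokenized segment and the raw values to InferFeatsTransform, which is meant to keep the value sequence aligned with the possibly subword-split tokens, so `translator.translate()` receives `src_feats` batched per feature name.)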