
Modified server to allow source features #2109

Merged · 4 commits · Nov 4, 2021
onmt/translate/translation_server.py (44 changes: 40 additions & 4 deletions)
```diff
@@ -14,6 +14,8 @@

 from itertools import islice, zip_longest
 from copy import deepcopy
+from collections import defaultdict
+from argparse import Namespace

 from onmt.constants import DefaultTokens
 from onmt.utils.logging import init_logger
```
```diff
@@ -22,6 +24,7 @@
 from onmt.utils.alignment import to_word_align
 from onmt.utils.parse import ArgumentParser
 from onmt.translate.translator import build_translator
+from onmt.transforms.features import InferFeatsTransform


 def critical(func):
```
```diff
@@ -192,6 +195,7 @@ def start(self, config_file):
                           {}),
                       'ct2_translate_batch_args': conf.get(
                           'ct2_translate_batch_args', {}),
+                      'features_opt': conf.get('features', None)
                       }
             kwargs = {k: v for (k, v) in kwargs.items() if v is not None}
             model_id = conf.get("id", None)
```
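For context, `conf` here is one model entry parsed from the server's JSON configuration. A hypothetical entry with the new `features` block, shown as the Python dict it parses to (the option name `reversible_tokenization` is an assumption for illustration, not taken from this PR):

```python
# Hypothetical model entry after json.load(); only "features" is new here.
conf = {
    "id": 100,
    "model": "model_0.pt",
    "tokenizer": {"type": "pyonmttok", "mode": "conservative"},
    "features": {"reversible_tokenization": "joiner"},  # assumed option name
}

features_opt = conf.get("features", None)  # what start() now forwards
```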
```diff
@@ -299,14 +303,16 @@ class ServerModel(object):
     def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None,
                  postprocess_opt=None, custom_opt=None, load=False, timeout=-1,
                  on_timeout="to_cpu", model_root="./", ct2_model=None,
-                 ct2_translator_args=None, ct2_translate_batch_args=None):
+                 ct2_translator_args=None, ct2_translate_batch_args=None,
+                 features_opt=None):
         self.model_root = model_root
         self.opt = self.parse_opt(opt)
         self.custom_opt = custom_opt

         self.model_id = model_id
         self.preprocess_opt = preprocess_opt
         self.tokenizers_opt = tokenizer_opt
+        self.features_opt = features_opt
         self.postprocess_opt = postprocess_opt
         self.timeout = timeout
         self.on_timeout = on_timeout
```
```diff
@@ -319,6 +325,7 @@ def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None,
         self.unload_timer = None
         self.user_opt = opt
         self.tokenizers = None
+        self.feats_transform = None

         if len(self.opt.log_file) > 0:
             log_file = os.path.join(model_root, self.opt.log_file)
```
```diff
@@ -361,6 +368,10 @@ def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None,
                 'tgt': tokenizer
             }

+        if self.features_opt is not None:
+            self.feats_transform = InferFeatsTransform(
+                Namespace(**self.features_opt))
+
         if self.postprocess_opt is not None:
             self.logger.info("Loading postprocessor")
             self.postprocessor = []
```
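The transform machinery expects argparse-style options, so the plain dict coming from the config is wrapped in a `Namespace` before `InferFeatsTransform` sees it. A minimal sketch of that wrapping (option name assumed, as above):

```python
from argparse import Namespace

features_opt = {"reversible_tokenization": "joiner"}  # assumed option name
opts = Namespace(**features_opt)
assert opts.reversible_tokenization == "joiner"
```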
```diff
@@ -497,21 +508,27 @@ def run(self, inputs):
             # every segment becomes a dict for flexibility purposes
             seg_dict = self.maybe_preprocess(inp)
             all_preprocessed.append(seg_dict)
-            for seg, ref in zip_longest(seg_dict["seg"], seg_dict["ref"]):
+            for seg, ref, feats in zip_longest(
+                    seg_dict["seg"], seg_dict["ref"],
+                    seg_dict["src_feats"]):
                 tok = self.maybe_tokenize(seg)
                 if ref is not None:
                     ref = self.maybe_tokenize(ref, side='tgt')
-                texts.append((tok, ref))
+                inferred_feats = self.transform_feats(seg, tok, feats)
+                texts.append((tok, ref, inferred_feats))
             tail_spaces.append(whitespaces_after)

         empty_indices = []
         texts_to_translate, texts_ref = [], []
-        for i, (tok, ref_tok) in enumerate(texts):
+        texts_features = defaultdict(list)
+        for i, (tok, ref_tok, feats) in enumerate(texts):
             if tok == "":
                 empty_indices.append(i)
             else:
                 texts_to_translate.append(tok)
                 texts_ref.append(ref_tok)
+                for feat_name, feat_values in feats.items():
+                    texts_features[feat_name].append(feat_values)
         if any([item is None for item in texts_ref]):
             texts_ref = None

```
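`translator.translate` consumes features batched per feature name rather than per segment, so the loop above inverts the layout: each non-empty segment's `{name: values}` dict is folded into a single `defaultdict(list)` of parallel lists. A standalone sketch of that regrouping (the `pos` feature and values are illustrative):

```python
from collections import defaultdict

# One (tok, ref, feats) triple per segment, as built in run() above.
texts = [
    ("she is a doctor", None, {"pos": "PRON AUX DET NOUN"}),
    ("he reads", None, {"pos": "PRON VERB"}),
]

texts_features = defaultdict(list)
for tok, ref, feats in texts:
    for feat_name, feat_values in feats.items():
        texts_features[feat_name].append(feat_values)

assert texts_features == {"pos": ["PRON AUX DET NOUN", "PRON VERB"]}
```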
```diff
@@ -522,6 +539,7 @@ def run(self, inputs):
         try:
             scores, predictions = self.translator.translate(
                 texts_to_translate,
+                src_feats=texts_features,
                 tgt=texts_ref,
                 batch_size=len(texts_to_translate)
                 if self.opt.batch_size == 0
```
```diff
@@ -682,6 +700,7 @@ def maybe_preprocess(self, sequence):
             sequence["seg"] = [sequence["src"].strip()]
             sequence.pop("src")
             sequence["ref"] = [sequence.get('ref', None)]
+            sequence["src_feats"] = [sequence.get('src_feats', {})]
             sequence["n_seg"] = 1
         if self.preprocess_opt is not None:
             return self.preprocess(sequence)
```
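With `src_feats` read here, a client can attach word-level source features to a request. A hedged example of such a request (endpoint path, port, model id, and feature name are assumptions based on the server's usual setup, not part of this diff):

```python
import requests

payload = [{
    "id": 100,
    "src": "she is a doctor",
    # New with this PR: one space-separated feature value per source word.
    "src_feats": {"pos": "PRON AUX DET NOUN"},
}]
r = requests.post("http://localhost:5000/translator/translate", json=payload)
print(r.json())
```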
```diff
@@ -702,6 +721,23 @@ def preprocess(self, sequence):
             sequence = function(sequence, self)
         return sequence

+    def transform_feats(self, raw_src, tok_src, feats):
+        """Apply InferFeatsTransform to features"""
+        if self.feats_transform is None:
+            return feats
+        ex = {
+            "src": tok_src.split(' '),
+            "src_original": raw_src.split(' '),
+            "src_feats": {k: v.split(' ') for k, v in feats.items()}
+        }
+        transformed_ex = self.feats_transform.apply(ex)
+        if not transformed_ex:
+            raise Exception("Error inferring feats")
+        transformed_feats = dict()
+        for feat_name, feat_values in transformed_ex["src_feats"].items():
+            transformed_feats[feat_name] = " ".join(feat_values)
+        return transformed_feats
+
     def build_tokenizer(self, tokenizer_opt):
         """Build tokenizer described by `tokenizer_opt`."""
         if "type" not in tokenizer_opt:
```
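`transform_feats` exists because tokenization can split one word into several subword tokens, leaving the word-level features misaligned; `InferFeatsTransform` re-derives one value per token. The sketch below illustrates the idea only and is not the actual transform logic: it repeats each word's feature across that word's pieces, assuming a '￭' joiner marks subword splits:

```python
def infer_feats_sketch(tok_src, word_feats):
    """Illustration only: repeat each word-level feature over its subwords."""
    out, i = [], 0
    for piece in tok_src.split(' '):
        out.append(word_feats[i])
        # A piece ending with the joiner continues the same word.
        if not piece.endswith('￭'):
            i += 1
    return ' '.join(out)

# "doctor" split into two pieces gets its POS tag twice.
print(infer_feats_sketch('she is a doc￭ tor', ['PRON', 'AUX', 'DET', 'NOUN']))
# -> 'PRON AUX DET NOUN NOUN'
```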