14
14
15
15
from itertools import islice , zip_longest
16
16
from copy import deepcopy
17
+ from collections import defaultdict
18
+ from argparse import Namespace
17
19
18
20
from onmt .constants import DefaultTokens
19
21
from onmt .utils .logging import init_logger
22
24
from onmt .utils .alignment import to_word_align
23
25
from onmt .utils .parse import ArgumentParser
24
26
from onmt .translate .translator import build_translator
27
+ from onmt .transforms .features import InferFeatsTransform
25
28
26
29
27
30
def critical (func ):
@@ -192,6 +195,7 @@ def start(self, config_file):
192
195
{}),
193
196
'ct2_translate_batch_args' : conf .get (
194
197
'ct2_translate_batch_args' , {}),
198
+ 'features_opt' : conf .get ('features' , None )
195
199
}
196
200
kwargs = {k : v for (k , v ) in kwargs .items () if v is not None }
197
201
model_id = conf .get ("id" , None )
@@ -299,14 +303,16 @@ class ServerModel(object):
299
303
def __init__ (self , opt , model_id , preprocess_opt = None , tokenizer_opt = None ,
300
304
postprocess_opt = None , custom_opt = None , load = False , timeout = - 1 ,
301
305
on_timeout = "to_cpu" , model_root = "./" , ct2_model = None ,
302
- ct2_translator_args = None , ct2_translate_batch_args = None ):
306
+ ct2_translator_args = None , ct2_translate_batch_args = None ,
307
+ features_opt = None ):
303
308
self .model_root = model_root
304
309
self .opt = self .parse_opt (opt )
305
310
self .custom_opt = custom_opt
306
311
307
312
self .model_id = model_id
308
313
self .preprocess_opt = preprocess_opt
309
314
self .tokenizers_opt = tokenizer_opt
315
+ self .features_opt = features_opt
310
316
self .postprocess_opt = postprocess_opt
311
317
self .timeout = timeout
312
318
self .on_timeout = on_timeout
@@ -319,6 +325,7 @@ def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None,
319
325
self .unload_timer = None
320
326
self .user_opt = opt
321
327
self .tokenizers = None
328
+ self .feats_transform = None
322
329
323
330
if len (self .opt .log_file ) > 0 :
324
331
log_file = os .path .join (model_root , self .opt .log_file )
@@ -361,6 +368,10 @@ def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None,
361
368
'tgt' : tokenizer
362
369
}
363
370
371
+ if self .features_opt is not None :
372
+ self .feats_transform = InferFeatsTransform (
373
+ Namespace (** self .features_opt ))
374
+
364
375
if self .postprocess_opt is not None :
365
376
self .logger .info ("Loading postprocessor" )
366
377
self .postprocessor = []
@@ -497,21 +508,27 @@ def run(self, inputs):
497
508
# every segment becomes a dict for flexibility purposes
498
509
seg_dict = self .maybe_preprocess (inp )
499
510
all_preprocessed .append (seg_dict )
500
- for seg , ref in zip_longest (seg_dict ["seg" ], seg_dict ["ref" ]):
511
+ for seg , ref , feats in zip_longest (
512
+ seg_dict ["seg" ], seg_dict ["ref" ],
513
+ seg_dict ["src_feats" ]):
501
514
tok = self .maybe_tokenize (seg )
502
515
if ref is not None :
503
516
ref = self .maybe_tokenize (ref , side = 'tgt' )
504
- texts .append ((tok , ref ))
517
+ inferred_feats = self .transform_feats (seg , tok , feats )
518
+ texts .append ((tok , ref , inferred_feats ))
505
519
tail_spaces .append (whitespaces_after )
506
520
507
521
empty_indices = []
508
522
texts_to_translate , texts_ref = [], []
509
- for i , (tok , ref_tok ) in enumerate (texts ):
523
+ texts_features = defaultdict (list )
524
+ for i , (tok , ref_tok , feats ) in enumerate (texts ):
510
525
if tok == "" :
511
526
empty_indices .append (i )
512
527
else :
513
528
texts_to_translate .append (tok )
514
529
texts_ref .append (ref_tok )
530
+ for feat_name , feat_values in feats .items ():
531
+ texts_features [feat_name ].append (feat_values )
515
532
if any ([item is None for item in texts_ref ]):
516
533
texts_ref = None
517
534
@@ -522,6 +539,7 @@ def run(self, inputs):
522
539
try :
523
540
scores , predictions = self .translator .translate (
524
541
texts_to_translate ,
542
+ src_feats = texts_features ,
525
543
tgt = texts_ref ,
526
544
batch_size = len (texts_to_translate )
527
545
if self .opt .batch_size == 0
@@ -682,6 +700,7 @@ def maybe_preprocess(self, sequence):
682
700
sequence ["seg" ] = [sequence ["src" ].strip ()]
683
701
sequence .pop ("src" )
684
702
sequence ["ref" ] = [sequence .get ('ref' , None )]
703
+ sequence ["src_feats" ] = [sequence .get ('src_feats' , {})]
685
704
sequence ["n_seg" ] = 1
686
705
if self .preprocess_opt is not None :
687
706
return self .preprocess (sequence )
@@ -702,6 +721,23 @@ def preprocess(self, sequence):
702
721
sequence = function (sequence , self )
703
722
return sequence
704
723
724
def transform_feats(self, raw_src, tok_src, feats):
    """Apply the configured ``InferFeatsTransform`` to source features.

    Args:
        raw_src (str): untokenized source segment (whitespace-separated).
        tok_src (str): tokenized source segment (whitespace-separated).
        feats (dict): feature name -> whitespace-joined string of
            per-token feature values.

    Returns:
        dict: feature name -> whitespace-joined transformed feature
        values. If no feature transform is configured
        (``self.feats_transform is None``), ``feats`` is returned
        unchanged.

    Raises:
        ValueError: if the transform rejects (filters out) the example.
    """
    if self.feats_transform is None:
        # Model was loaded without a 'features' configuration section.
        return feats
    # The transform works on token lists, so split the string form the
    # server pipeline carries around.
    ex = {
        "src": tok_src.split(' '),
        "src_original": raw_src.split(' '),
        "src_feats": {k: v.split(' ') for k, v in feats.items()},
    }
    transformed_ex = self.feats_transform.apply(ex)
    if not transformed_ex:
        # apply() returns a falsy value when the example is dropped.
        raise ValueError("Error inferring feats")
    # Re-join per-token values into the whitespace-joined string form
    # expected by the rest of the pipeline.
    return {
        feat_name: " ".join(feat_values)
        for feat_name, feat_values in transformed_ex["src_feats"].items()
    }
705
741
def build_tokenizer (self , tokenizer_opt ):
706
742
"""Build tokenizer described by `tokenizer_opt`."""
707
743
if "type" not in tokenizer_opt :
0 commit comments