We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent b1a4615 commit e4ab06cCopy full SHA for e4ab06c
onmt/bin/release_model.py
@@ -2,14 +2,19 @@
2
import argparse
3
import torch
4
5
+from onmt.modules.position_ffn import ActivationFunction
6
+
7
8
def get_ctranslate2_model_spec(opt):
9
"""Creates a CTranslate2 model specification from the model options."""
10
with_relative_position = getattr(opt, "max_relative_positions", 0) > 0
11
+ relu = ActivationFunction.relu
12
is_ct2_compatible = (
13
opt.encoder_type == "transformer"
14
and opt.decoder_type == "transformer"
15
+ and not getattr(opt, "aan_useffn", False)
16
and getattr(opt, "self_attn_type", "scaled-dot") == "scaled-dot"
17
+ and getattr(opt, "pos_ffn_activation_fn", relu) == relu
18
and ((opt.position_encoding and not with_relative_position)
19
or (with_relative_position and not opt.position_encoding)))
20
if not is_ct2_compatible:
0 commit comments