Skip to content

Commit e4ab06c

Browse files
authored
Add more checks before converting checkpoints to CTranslate2 (#2053)
1 parent b1a4615 commit e4ab06c

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

onmt/bin/release_model.py

+5
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,19 @@
22
import argparse
33
import torch
44

5+
from onmt.modules.position_ffn import ActivationFunction
6+
57

68
def get_ctranslate2_model_spec(opt):
79
"""Creates a CTranslate2 model specification from the model options."""
810
with_relative_position = getattr(opt, "max_relative_positions", 0) > 0
11+
relu = ActivationFunction.relu
912
is_ct2_compatible = (
1013
opt.encoder_type == "transformer"
1114
and opt.decoder_type == "transformer"
15+
and not getattr(opt, "aan_useffn", False)
1216
and getattr(opt, "self_attn_type", "scaled-dot") == "scaled-dot"
17+
and getattr(opt, "pos_ffn_activation_fn", relu) == relu
1318
and ((opt.position_encoding and not with_relative_position)
1419
or (with_relative_position and not opt.position_encoding)))
1520
if not is_ct2_compatible:

0 commit comments

Comments
 (0)