Skip to content

Commit e445372

Browse files
authored
Freeze TorchScript modules with preserved attributes
1 parent ff7767c commit e445372

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

source/api_cc/src/DeepPotPT.cc

+2
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ void DeepPotPT::init(const std::string& model,
8989
std::unordered_map<std::string, std::string> metadata = {{"type", ""}};
9090
module = torch::jit::load(model, device, metadata);
9191
module.eval();
92+
const std::vector<std::string>& preserved_attrs = {"forward_lower", "has_message_passing", "get_rcut", "get_ntypes", "get_dim_fparam", "get_dim_aparam", "is_aparam_nall", "get_type_map"};
93+
module = torch::jit::freeze(module, preserved_attrs);
9294
do_message_passing = module.run_method("has_message_passing").toBool();
9395
torch::jit::FusionStrategy strategy;
9496
strategy = {{torch::jit::FusionBehavior::DYNAMIC, 10}};

source/api_cc/src/DeepSpinPT.cc

+2
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ void DeepSpinPT::init(const std::string& model,
8989
std::unordered_map<std::string, std::string> metadata = {{"type", ""}};
9090
module = torch::jit::load(model, device, metadata);
9191
module.eval();
92+
const std::vector<std::string>& preserved_attrs = {"forward_lower", "has_message_passing", "get_rcut", "get_ntypes", "get_dim_fparam", "get_dim_aparam", "is_aparam_nall", "get_type_map"};
93+
module = torch::jit::freeze(module, preserved_attrs);
9294
do_message_passing = module.run_method("has_message_passing").toBool();
9395
torch::jit::FusionStrategy strategy;
9496
strategy = {{torch::jit::FusionBehavior::DYNAMIC, 10}};

0 commit comments

Comments
 (0)