# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import os
import tempfile

from deepmd.tf.entrypoints import (
    freeze,
)
from deepmd.tf.env import (
    GLOBAL_TF_FLOAT_PRECISION,
    tf,
)
from deepmd.tf.model.model import (
    Model,
)
from deepmd.tf.utils.errors import (
    GraphWithoutTensorError,
)
from deepmd.tf.utils.graph import (
    get_tensor_by_name_from_graph,
    load_graph_def,
)
from deepmd.tf.utils.sess import (
    run_sess,
)


def serialize_from_file(model_file: str) -> dict:
    """Serialize the model file to a dictionary.

    Parameters
    ----------
    model_file : str
        The model file to be serialized.

    Returns
    -------
    dict
        The serialized model data.
    """
    graph, graph_def = load_graph_def(model_file)
    t_jdata = get_tensor_by_name_from_graph(graph, "train_attr/training_script")
    jdata = json.loads(t_jdata)
    model = Model(**jdata["model"])
    # important! must be called before serialize
    model.init_variables(graph=graph, graph_def=graph_def)
    model_dict = model.serialize()
    data = {
        "backend": "TensorFlow",
        "tf_version": tf.__version__,
        "model": model_dict,
        "model_def_script": jdata["model"],
    }
    # neighbor stat information
    try:
        t_min_nbor_dist = get_tensor_by_name_from_graph(
            graph, "train_attr/min_nbor_dist"
        )
    except GraphWithoutTensorError:
        pass
    else:
        data.setdefault("@variables", {})
        data["@variables"]["min_nbor_dist"] = t_min_nbor_dist
    return data


def deserialize_to_file(model_file: str, data: dict) -> None:
    """Deserialize the dictionary to a model file.

    Parameters
    ----------
    model_file : str
        The model file to be saved.
    data : dict
        The dictionary to be deserialized.
    """
    model = Model.deserialize(data["model"])
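    # Rebuild an inference graph for the deserialized model, then freeze it
    # into the requested model file.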
    with tf.Graph().as_default() as graph, tf.Session(graph=graph) as sess:
        place_holders = {}
        for ii in ["coord", "box"]:
            place_holders[ii] = tf.placeholder(
                GLOBAL_TF_FLOAT_PRECISION, [None], name="t_" + ii
            )
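        # Atom types plus the natoms bookkeeping vector: natoms[0] is the
        # number of local atoms, natoms[1] the total number of atoms
        # (local + ghost), and natoms[2:] the number of atoms of each type.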
place_holders["type"] = tf.placeholder(tf.int32, [None], name="t_type")
place_holders["natoms_vec"] = tf.placeholder(
tf.int32, [model.get_ntypes() + 2], name="t_natoms"
)
place_holders["default_mesh"] = tf.placeholder(tf.int32, [None], name="t_mesh")
inputs = {}
# fparam, aparam
if model.get_numb_fparam() > 0:
inputs["fparam"] = tf.placeholder(
GLOBAL_TF_FLOAT_PRECISION,
[None],
name="t_fparam",
)
if model.get_numb_aparam() > 0:
inputs["aparam"] = tf.placeholder(
GLOBAL_TF_FLOAT_PRECISION,
[None],
name="t_aparam",
)
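        # Build the model's forward graph from the placeholders defined above.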
        model.build(
            place_holders["coord"],
            place_holders["type"],
            place_holders["natoms_vec"],
            place_holders["box"],
            place_holders["default_mesh"],
            inputs,
            reuse=False,
        )
        init = tf.global_variables_initializer()
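        # Embed the model definition script (and, if available, the minimal
        # neighbor distance) as graph constants so that serialize_from_file
        # can recover them later.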
        tf.constant(
            json.dumps({"model": data["model_def_script"]}, separators=(",", ":")),
            name="train_attr/training_script",
            dtype=tf.string,
        )
        if "min_nbor_dist" in data.get("@variables", {}):
            tf.constant(
                data["@variables"]["min_nbor_dist"],
                name="train_attr/min_nbor_dist",
                dtype=GLOBAL_TF_FLOAT_PRECISION,
            )
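        # Initialize the variables, write a temporary checkpoint, and freeze
        # it into the target model file.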
        run_sess(sess, init)
        saver = tf.train.Saver()
        with tempfile.TemporaryDirectory() as nt:
            saver.save(
                sess,
                os.path.join(nt, "model.ckpt"),
                global_step=0,
            )
            freeze(checkpoint_folder=nt, output=model_file, node_names=None)
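

# A minimal round-trip sketch, assuming a frozen TensorFlow model already
# exists on disk; the file names below are illustrative, not part of the API.
if __name__ == "__main__":
    # Convert a frozen model to its dict representation and back again.
    serialized = serialize_from_file("frozen_model.pb")  # hypothetical input path
    deserialize_to_file("roundtrip_model.pb", serialized)  # hypothetical output path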