|
126 | 126 |
|
127 | 127 |
|
128 | 128 | def _get_model_suffix(jdata) -> str:
|
129 |
| - """return the model suffix based on the backend""" |
| 129 | + """Return the model suffix based on the backend""" |
130 | 130 | backend = jdata.get("train_backend", "tensorflow")
|
131 | 131 | if backend == "tensorflow":
|
132 | 132 | suffix = ".pb"
|
@@ -193,7 +193,10 @@ def copy_model(numb_model, prv_iter_index, cur_iter_index, suffix=".pb"):
|
193 | 193 | prv_train_task = os.path.join(prv_train_path, train_task_fmt % ii)
|
194 | 194 | os.chdir(cur_train_path)
|
195 | 195 | os.symlink(os.path.relpath(prv_train_task), train_task_fmt % ii)
|
196 |
| - os.symlink(os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix), "graph.%03d%s" % (ii, suffix)) |
| 196 | + os.symlink( |
| 197 | + os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix), |
| 198 | + "graph.%03d%s" % (ii, suffix), |
| 199 | + ) |
197 | 200 | os.chdir(cwd)
|
198 | 201 | with open(os.path.join(cur_train_path, "copied"), "w") as fp:
|
199 | 202 | None
|
@@ -657,7 +660,9 @@ def make_train(iter_index, jdata, mdata):
|
657 | 660 | )
|
658 | 661 | if copied_models is not None:
|
659 | 662 | for ii in range(len(copied_models)):
|
660 |
| - _link_old_models(work_path, [copied_models[ii]], ii, basename="init%s" % suffix) |
| 663 | + _link_old_models( |
| 664 | + work_path, [copied_models[ii]], ii, basename="init%s" % suffix |
| 665 | + ) |
661 | 666 | # Copy user defined forward files
|
662 | 667 | symlink_user_forward_files(mdata=mdata, task_type="train", work_path=work_path)
|
663 | 668 | # HDF5 format for training data
|
@@ -811,11 +816,18 @@ def run_train(iter_index, jdata, mdata):
|
811 | 816 | elif training_init_frozen_model is not None or training_finetune_model is not None:
|
812 | 817 | forward_files.append(os.path.join("old", "init%s" % suffix))
|
813 | 818 |
|
814 |
| - backward_files = ["frozen_model%s" % suffix, "lcurve.out", "train.log", "checkpoint"] |
| 819 | + backward_files = [ |
| 820 | + "frozen_model%s" % suffix, |
| 821 | + "lcurve.out", |
| 822 | + "train.log", |
| 823 | + "checkpoint", |
| 824 | + ] |
815 | 825 | if suffix == ".pb":
|
816 |
| - backward_files += ["model.ckpt.meta", |
817 |
| - "model.ckpt.index", |
818 |
| - "model.ckpt.data-00000-of-00001"] |
| 826 | + backward_files += [ |
| 827 | + "model.ckpt.meta", |
| 828 | + "model.ckpt.index", |
| 829 | + "model.ckpt.data-00000-of-00001", |
| 830 | + ] |
819 | 831 | if jdata.get("dp_compress", False):
|
820 | 832 | backward_files.append("frozen_model_compressed%s" % suffix)
|
821 | 833 |
|
|
0 commit comments