Skip to content

Commit d5a7c04

Browse files
committed
Revert "fix: bugs in uts for property fit (deepmodeling#4120)"
This reverts commit 96ed5df.
1 parent 81b9d20 commit d5a7c04

File tree

4 files changed

+107
-175
lines changed

4 files changed

+107
-175
lines changed

backend/read_env.py

-2
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,6 @@ def get_argument_from_env() -> Tuple[str, list, list, dict, str, str]:
6060
cmake_minimum_required_version = "3.21"
6161
cmake_args.append("-DUSE_ROCM_TOOLKIT:BOOL=TRUE")
6262
rocm_root = os.environ.get("ROCM_ROOT")
63-
if not rocm_root:
64-
rocm_root = os.environ.get("ROCM_PATH")
6563
if rocm_root:
6664
cmake_args.append(f"-DCMAKE_HIP_COMPILER_ROCM_ROOT:STRING={rocm_root}")
6765
hipcc_flags = os.environ.get("HIP_HIPCC_FLAGS")

deepmd/pt/train/training.py

+21
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,21 @@ def get_lr(lr_params):
392392
# JIT
393393
if JIT:
394394
self.model = torch.jit.script(self.model)
395+
396+
# Initialize the fparam
397+
if model_params["fitting_net"]["numb_fparam"] > 0:
398+
nbatches = 10
399+
datasets = training_data.systems
400+
dataloaders = training_data.dataloaders
401+
fparams = []
402+
for i in range(len(datasets)):
403+
iterator = iter(dataloaders[i])
404+
numb_batches = min(nbatches, len(dataloaders[i]))
405+
for _ in range(numb_batches):
406+
stat_data = next(iterator)
407+
fparams.append(stat_data['fparam'])
408+
fparams = torch.tensor(fparams)
409+
init_fparam(self.model, fparams)
395410

396411
# Model Wrapper
397412
self.wrapper = ModelWrapper(self.model, self.loss, model_params=model_params)
@@ -1212,6 +1227,12 @@ def get_additional_data_requirement(_model):
12121227
return additional_data_requirement
12131228

12141229

1230+
def init_fparam(_model, fparams):
1231+
fitting = _model.get_fitting_net()
1232+
fitting['fparam_avg'] = torch.unsqueeze(torch.mean(fparams, dim=0), dim=-1).to(DEVICE)
1233+
fitting['fparam_inv_std'] = torch.unsqueeze(1. / torch.std(fparams, dim=0), dim=-1).to(DEVICE)
1234+
1235+
12151236
def get_loss(loss_params, start_lr, _ntypes, _model):
12161237
loss_type = loss_params.get("type", "ener")
12171238
if loss_type == "ener":

doc/install/install-from-source.md

+1-2
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,7 @@ The path to the CUDA toolkit directory. CUDA 9.0 or later is supported. NVCC is
155155

156156
**Type**: Path; **Default**: Detected automatically
157157

158-
The path to the ROCM toolkit directory. If `ROCM_ROOT` is not set, it will look for `ROCM_PATH`; if `ROCM_PATH` is also not set, it will be detected using `hipconfig --rocmpath`.
159-
158+
The path to the ROCM toolkit directory.
160159
:::
161160

162161
:::{envvar} DP_ENABLE_TENSORFLOW

0 commit comments

Comments
 (0)