diff --git a/deepmd/main.py b/deepmd/main.py index 60b8da2850..d2a52568b4 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -424,29 +424,30 @@ def main_parser() -> argparse.ArgumentParser: parser_compress = subparsers.add_parser( "compress", parents=[parser_log, parser_mpi_log], - help="(Supported backend: TensorFlow) compress a model", + help="Compress a model", formatter_class=RawTextArgumentDefaultsHelpFormatter, epilog=textwrap.dedent( """\ examples: dp compress - dp compress -i graph.pb -o compressed.pb + dp --tf compress -i frozen_model.pb -o compressed_model.pb + dp --pt compress -i frozen_model.pth -o compressed_model.pth """ ), ) parser_compress.add_argument( "-i", "--input", - default="frozen_model.pb", + default="frozen_model", type=str, - help="The original frozen model, which will be compressed by the code", + help="The original frozen model, which will be compressed by the code. Filename (prefix) of the input model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth", ) parser_compress.add_argument( "-o", "--output", - default="frozen_model_compressed.pb", + default="frozen_model_compressed", type=str, - help="The compressed model", + help="The compressed model. Filename (prefix) of the output model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth", ) parser_compress.add_argument( "-s", diff --git a/deepmd/pt/entrypoints/compress.py b/deepmd/pt/entrypoints/compress.py new file mode 100644 index 0000000000..1042af3335 --- /dev/null +++ b/deepmd/pt/entrypoints/compress.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json + +import torch + +from deepmd.pt.model.model import ( + get_model, +) + + +def enable_compression( + input_file: str, + output: str, + stride: float = 0.01, + extrapolate: int = 5, + check_frequency: int = -1, +): + saved_model = torch.jit.load(input_file, map_location="cpu") + model_def_script = json.loads(saved_model.model_def_script) + model = get_model(model_def_script) + model.load_state_dict(saved_model.state_dict()) + + model.enable_compression( + extrapolate, + stride, + stride * 10, + check_frequency, + ) + + model = torch.jit.script(model) + torch.jit.save(model, output) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index c56e7f0731..7daa29d0f9 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -38,6 +38,9 @@ from deepmd.pt.cxx_op import ( ENABLE_CUSTOMIZED_OP, ) +from deepmd.pt.entrypoints.compress import ( + enable_compression, +) from deepmd.pt.infer import ( inference, ) @@ -346,10 +349,14 @@ def train( # save min_nbor_dist if min_nbor_dist is not None: if not multi_task: - trainer.model.min_nbor_dist = min_nbor_dist + trainer.model.min_nbor_dist = torch.tensor( + min_nbor_dist, dtype=torch.float64, device=DEVICE + ) else: for model_item in min_nbor_dist: - trainer.model[model_item].min_nbor_dist = min_nbor_dist[model_item] + trainer.model[model_item].min_nbor_dist = torch.tensor( + min_nbor_dist[model_item], dtype=torch.float64, device=DEVICE + ) trainer.run() @@ -549,6 +556,16 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None): model_branch=FLAGS.model_branch, output=FLAGS.output, ) + elif FLAGS.command == "compress": + FLAGS.input = str(Path(FLAGS.input).with_suffix(".pth")) + FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth")) + enable_compression( + input_file=FLAGS.input, + output=FLAGS.output, + stride=FLAGS.step, + extrapolate=FLAGS.extrapolate, + check_frequency=FLAGS.frequency, 
+ ) else: raise RuntimeError(f"Invalid command {FLAGS.command}!") diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 630b96ce9b..eadce86963 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -9,6 +9,7 @@ import numpy as np import torch +import torch.nn as nn from deepmd.dpmodel.utils.seed import ( child_seed, @@ -437,10 +438,6 @@ def update_sel( class DescrptBlockSeA(DescriptorBlock): ndescrpt: Final[int] __constants__: ClassVar[list] = ["ndescrpt"] - lower: dict[str, int] - upper: dict[str, int] - table_data: dict[str, torch.Tensor] - table_config: list[Union[int, float]] def __init__( self, @@ -500,13 +497,6 @@ def __init__( self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) - # add for compression - self.compress = False - self.lower = {} - self.upper = {} - self.table_data = {} - self.table_config = [] - ndim = 1 if self.type_one_side else 2 filter_layers = NetworkCollection( ndim=ndim, ntypes=len(sel), network_type="embedding_network" @@ -529,6 +519,21 @@ def __init__( for param in self.parameters(): param.requires_grad = trainable + # add for compression + self.compress = False + self.compress_info = nn.ParameterList( + [ + nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu")) + for _ in range(len(self.filter_layers.networks)) + ] + ) + self.compress_data = nn.ParameterList( + [ + nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE)) + for _ in range(len(self.filter_layers.networks)) + ] + ) + def get_rcut(self) -> float: """Returns the cut-off radius.""" return self.rcut @@ -667,16 +672,39 @@ def reinit_exclude( def enable_compression( self, - table_data, - table_config, - lower, - upper, + table_data: dict[str, torch.Tensor], + table_config: list[Union[int, float]], + lower: dict[str, int], + upper: dict[str, int], ) -> None: + for embedding_idx, ll in enumerate(self.filter_layers.networks): + if self.type_one_side: + ii = embedding_idx + ti = -1 + else: + # ti: center atom type, ii: neighbor type... 
+ ii = embedding_idx // self.ntypes + ti = embedding_idx % self.ntypes + if self.type_one_side: + net = "filter_-1_net_" + str(ii) + else: + net = "filter_" + str(ti) + "_net_" + str(ii) + info_ii = torch.as_tensor( + [ + lower[net], + upper[net], + upper[net] * table_config[0], + table_config[1], + table_config[2], + table_config[3], + ], + dtype=self.prec, + device="cpu", + ) + tensor_data_ii = table_data[net].to(device=env.DEVICE, dtype=self.prec) + self.compress_data[embedding_idx] = tensor_data_ii + self.compress_info[embedding_idx] = info_ii self.compress = True - self.table_data = table_data - self.table_config = table_config - self.lower = lower - self.upper = upper def forward( self, @@ -724,7 +752,9 @@ def forward( ) # nfnl x nnei exclude_mask = self.emask(nlist, extended_atype).view(nfnl, self.nnei) - for embedding_idx, ll in enumerate(self.filter_layers.networks): + for embedding_idx, (ll, compress_data_ii, compress_info_ii) in enumerate( + zip(self.filter_layers.networks, self.compress_data, self.compress_info) + ): if self.type_one_side: ii = embedding_idx ti = -1 @@ -751,23 +781,11 @@ def forward( ss = rr[:, :, :1] if self.compress: - if self.type_one_side: - net = "filter_-1_net_" + str(ii) - else: - net = "filter_" + str(ti) + "_net_" + str(ii) - info = [ - self.lower[net], - self.upper[net], - self.upper[net] * self.table_config[0], - self.table_config[1], - self.table_config[2], - self.table_config[3], - ] ss = ss.reshape(-1, 1) # xyz_scatter_tensor in tf - tensor_data = self.table_data[net].to(ss.device).to(dtype=self.prec) + gr = torch.ops.deepmd.tabulate_fusion_se_a( - tensor_data.contiguous(), - torch.tensor(info, dtype=self.prec, device="cpu").contiguous(), + compress_data_ii.contiguous(), + compress_info_ii.cpu().contiguous(), ss.contiguous(), rr.contiguous(), self.filter_neuron[-1], diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 8c56ccf827..6ec02de514 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -71,11 +71,6 @@ def tabulate_fusion_se_atten( @DescriptorBlock.register("se_atten") class DescrptBlockSeAtten(DescriptorBlock): - lower: dict[str, int] - upper: dict[str, int] - table_data: dict[str, torch.Tensor] - table_config: list[Union[int, float]] - def __init__( self, rcut: float, @@ -202,14 +197,6 @@ def __init__( ln_eps = 1e-5 self.ln_eps = ln_eps - # add for compression - self.compress = False - self.is_sorted = False - self.lower = {} - self.upper = {} - self.table_data = {} - self.table_config = [] - if isinstance(sel, int): sel = [sel] @@ -282,6 +269,16 @@ def __init__( self.filter_layers_strip = filter_layers_strip self.stats = None + # add for compression + self.compress = False + self.is_sorted = False + self.compress_info = nn.ParameterList( + [nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu"))] + ) + self.compress_data = nn.ParameterList( + [nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))] + ) + def get_rcut(self) -> float: """Returns the cut-off radius.""" return self.rcut @@ -431,11 +428,21 @@ def enable_compression( lower, upper, ) -> None: + net = "filter_net" + self.compress_info[0] = torch.as_tensor( + [ + lower[net], + upper[net], + upper[net] * table_config[0], + table_config[1], + table_config[2], + table_config[3], + ], + dtype=self.prec, + device="cpu", + ) + self.compress_data[0] = table_data[net].to(device=env.DEVICE, dtype=self.prec) self.compress = True - self.table_data = table_data - self.table_config = 
table_config - self.lower = lower - self.upper = upper def forward( self, @@ -544,15 +551,6 @@ def forward( xyz_scatter = torch.matmul(rr.permute(0, 2, 1), gg) elif self.tebd_input_mode in ["strip"]: if self.compress: - net = "filter_net" - info = [ - self.lower[net], - self.upper[net], - self.upper[net] * self.table_config[0], - self.table_config[1], - self.table_config[2], - self.table_config[3], - ] ss = ss.reshape(-1, 1) # nfnl x nnei x ng # gg_s = self.filter_layers.networks[0](ss) @@ -569,14 +567,12 @@ def forward( gg_t = gg_t * sw.reshape(-1, self.nnei, 1) # nfnl x nnei x ng # gg = gg_s * gg_t + gg_s - tensor_data = self.table_data[net].to(gg_t.device).to(dtype=self.prec) - info_tensor = torch.tensor(info, dtype=self.prec, device="cpu") gg_t = gg_t.reshape(-1, gg_t.size(-1)) # Convert all tensors to the required precision at once ss, rr, gg_t = (t.to(self.prec) for t in (ss, rr, gg_t)) xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_atten( - tensor_data.contiguous(), - info_tensor.contiguous(), + self.compress_data[0].contiguous(), + self.compress_info[0].cpu().contiguous(), ss.contiguous(), rr.contiguous(), gg_t.contiguous(), diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py index 4a74b7671f..f70fdfa9f1 100644 --- a/deepmd/pt/model/descriptor/se_r.py +++ b/deepmd/pt/model/descriptor/se_r.py @@ -7,6 +7,7 @@ import numpy as np import torch +import torch.nn as nn from deepmd.dpmodel.utils import EnvMat as DPEnvMat from deepmd.dpmodel.utils.seed import ( @@ -78,11 +79,6 @@ def tabulate_fusion_se_r( @BaseDescriptor.register("se_e2_r") @BaseDescriptor.register("se_r") class DescrptSeR(BaseDescriptor, torch.nn.Module): - lower: dict[str, int] - upper: dict[str, int] - table_data: dict[str, torch.Tensor] - table_config: list[Union[int, float]] - def __init__( self, rcut, @@ -117,12 +113,6 @@ def __init__( # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) self.env_protection = env_protection - # add for compression - self.compress = False - self.lower = {} - self.upper = {} - self.table_data = {} - self.table_config = [] self.sel = sel self.sec = torch.tensor( @@ -160,6 +150,21 @@ def __init__( for param in self.parameters(): param.requires_grad = trainable + # add for compression + self.compress = False + self.compress_info = nn.ParameterList( + [ + nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu")) + for _ in range(len(self.filter_layers.networks)) + ] + ) + self.compress_data = nn.ParameterList( + [ + nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE)) + for _ in range(len(self.filter_layers.networks)) + ] + ) + def get_rcut(self) -> float: """Returns the cut-off radius.""" return self.rcut @@ -373,23 +378,42 @@ def enable_compression( if self.compress: raise ValueError("Compression is already enabled.") data = self.serialize() - self.table = DPTabulate( + table = DPTabulate( self, data["neuron"], data["type_one_side"], data["exclude_types"], ActivationFn(data["activation_function"]), ) - self.table_config = [ + table_config = [ table_extrapolate, table_stride_1, table_stride_2, check_frequency, ] - self.lower, self.upper = self.table.build( + lower, upper = table.build( min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2 ) - self.table_data = self.table.data + table_data = table.data + + for ii, ll in enumerate(self.filter_layers.networks): + net = "filter_-1_net_" + str(ii) + info_ii = torch.as_tensor( + [ + lower[net], + upper[net], + upper[net] * table_config[0], + 
table_config[1], + table_config[2], + table_config[3], + ], + dtype=self.prec, + device="cpu", + ) + tensor_data_ii = table_data[net].to(device=env.DEVICE, dtype=self.prec) + self.compress_data[ii] = tensor_data_ii + self.compress_info[ii] = info_ii + self.compress = True def forward( @@ -460,7 +484,9 @@ def forward( # nfnl x nnei exclude_mask = self.emask(nlist, atype_ext).view(nfnl, self.nnei) xyz_scatter_total = [] - for ii, ll in enumerate(self.filter_layers.networks): + for ii, (ll, compress_data_ii, compress_info_ii) in enumerate( + zip(self.filter_layers.networks, self.compress_data, self.compress_info) + ): # nfnl x nt mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]] # nfnl x nt x 1 @@ -468,19 +494,9 @@ def forward( ss = ss * mm[:, :, None] if self.compress: ss = ss.squeeze(-1) - net = "filter_-1_net_" + str(ii) - info = [ - self.lower[net], - self.upper[net], - self.upper[net] * self.table_config[0], - self.table_config[1], - self.table_config[2], - self.table_config[3], - ] - tensor_data = self.table_data[net].to(ss.device).to(dtype=self.prec) xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_r( - tensor_data.contiguous(), - torch.tensor(info, dtype=self.prec, device="cpu").contiguous(), + compress_data_ii.contiguous(), + compress_info_ii.cpu().contiguous(), ss, self.filter_neuron[-1], )[0] diff --git a/deepmd/pt/model/descriptor/se_t.py b/deepmd/pt/model/descriptor/se_t.py index 5a634d7549..0eec78fd2f 100644 --- a/deepmd/pt/model/descriptor/se_t.py +++ b/deepmd/pt/model/descriptor/se_t.py @@ -9,6 +9,7 @@ import numpy as np import torch +import torch.nn as nn from deepmd.dpmodel.utils.seed import ( child_seed, @@ -468,10 +469,6 @@ def update_sel( class DescrptBlockSeT(DescriptorBlock): ndescrpt: Final[int] __constants__: ClassVar[list] = ["ndescrpt"] - lower: dict[str, int] - upper: dict[str, int] - table_data: dict[str, torch.Tensor] - table_config: list[Union[int, float]] def __init__( self, @@ -543,12 +540,6 @@ def __init__( self.split_sel = self.sel self.nnei = sum(sel) self.ndescrpt = self.nnei * 4 - # add for compression - self.compress = False - self.lower = {} - self.upper = {} - self.table_data = {} - self.table_config = [] wanted_shape = (self.ntypes, self.nnei, 4) mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) @@ -578,6 +569,21 @@ def __init__( for param in self.parameters(): param.requires_grad = trainable + # add for compression + self.compress = False + self.compress_info = nn.ParameterList( + [ + nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu")) + for _ in range(len(self.filter_layers.networks)) + ] + ) + self.compress_data = nn.ParameterList( + [ + nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE)) + for _ in range(len(self.filter_layers.networks)) + ] + ) + def get_rcut(self) -> float: """Returns the cut-off radius.""" return self.rcut @@ -717,11 +723,27 @@ def enable_compression( lower, upper, ) -> None: + for embedding_idx, ll in enumerate(self.filter_layers.networks): + ti = embedding_idx % self.ntypes + tj = embedding_idx // self.ntypes + if ti <= tj: + net = "filter_" + str(ti) + "_net_" + str(tj) + info_ii = torch.as_tensor( + [ + lower[net], + upper[net], + upper[net] * table_config[0], + table_config[1], + table_config[2], + table_config[3], + ], + dtype=self.prec, + device="cpu", + ) + tensor_data_ii = table_data[net].to(device=env.DEVICE, dtype=self.prec) + self.compress_data[embedding_idx] = tensor_data_ii + self.compress_info[embedding_idx] = info_ii self.compress = True - self.table_data = table_data 
- self.table_config = table_config - self.lower = lower - self.upper = upper def forward( self, @@ -789,7 +811,9 @@ def forward( ) # nfnl x nnei exclude_mask = self.emask(nlist, extended_atype).view(nfnl, self.nnei) - for embedding_idx, ll in enumerate(self.filter_layers.networks): + for embedding_idx, (ll, compress_data_ii, compress_info_ii) in enumerate( + zip(self.filter_layers.networks, self.compress_data, self.compress_info) + ): ti = embedding_idx % self.ntypes nei_type_j = self.sel[ti] tj = embedding_idx // self.ntypes @@ -808,23 +832,11 @@ def forward( env_ij = torch.einsum("ijm,ikm->ijk", rr_i, rr_j) if self.compress: ebd_env_ij = env_ij.view(-1, 1) - net = "filter_" + str(ti) + "_net_" + str(tj) - info = [ - self.lower[net], - self.upper[net], - self.upper[net] * self.table_config[0], - self.table_config[1], - self.table_config[2], - self.table_config[3], - ] - tensor_data = ( - self.table_data[net].to(env_ij.device).to(dtype=self.prec) - ) ebd_env_ij = ebd_env_ij.to(dtype=self.prec) env_ij = env_ij.to(dtype=self.prec) res_ij = torch.ops.deepmd.tabulate_fusion_se_t( - tensor_data.contiguous(), - torch.tensor(info, dtype=self.prec, device="cpu").contiguous(), + compress_data_ii.contiguous(), + compress_info_ii.cpu().contiguous(), ebd_env_ij.contiguous(), env_ij.contiguous(), self.filter_neuron[-1], diff --git a/deepmd/pt/model/model/model.py b/deepmd/pt/model/model/model.py index d3670737ba..c7b5986f0c 100644 --- a/deepmd/pt/model/model/model.py +++ b/deepmd/pt/model/model/model.py @@ -8,6 +8,9 @@ from deepmd.dpmodel.model.base_model import ( make_base_model, ) +from deepmd.pt.utils import ( + env, +) from deepmd.utils.path import ( DPPath, ) @@ -18,7 +21,9 @@ def __init__(self, *args, **kwargs): """Construct a basic model for different tasks.""" torch.nn.Module.__init__(self) self.model_def_script = "" - self.min_nbor_dist = None + self.register_buffer( + "min_nbor_dist", torch.tensor(-1.0, dtype=torch.float64, device=env.DEVICE) + ) def compute_or_load_stat( self, @@ -50,7 +55,9 @@ def get_model_def_script(self) -> str: @torch.jit.export def get_min_nbor_dist(self) -> Optional[float]: """Get the minimum distance between two atoms.""" - return self.min_nbor_dist + if self.min_nbor_dist.item() == -1.0: + return None + return self.min_nbor_dist.item() @torch.jit.export def get_ntypes(self): diff --git a/deepmd/pt/utils/serialization.py b/deepmd/pt/utils/serialization.py index 1c6ea096aa..5d3b02482a 100644 --- a/deepmd/pt/utils/serialization.py +++ b/deepmd/pt/utils/serialization.py @@ -12,6 +12,9 @@ from deepmd.pt.train.wrapper import ( ModelWrapper, ) +from deepmd.pt.utils import ( + env, +) def serialize_from_file(model_file: str) -> dict: @@ -73,6 +76,10 @@ def deserialize_to_file(model_file: str, data: dict) -> None: # JIT will happy in this way... model.model_def_script = json.dumps(data["model_def_script"]) if "min_nbor_dist" in data.get("@variables", {}): - model.min_nbor_dist = float(data["@variables"]["min_nbor_dist"]) + model.min_nbor_dist = torch.tensor( + float(data["@variables"]["min_nbor_dist"]), + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.DEVICE, + ) model = torch.jit.script(model) torch.jit.save(model, model_file) diff --git a/deepmd/pt/utils/tabulate.py b/deepmd/pt/utils/tabulate.py index 7394ac082d..796f7dcd52 100644 --- a/deepmd/pt/utils/tabulate.py +++ b/deepmd/pt/utils/tabulate.py @@ -426,53 +426,69 @@ def _n_all_excluded(self) -> int: # customized op -def grad(xbar, y, functype): # functype=tanh, gelu, .. 
+def grad(xbar: torch.Tensor, y: torch.Tensor, functype: int): if functype == 1: return 1 - y * y + elif functype == 2: - var = np.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3)) + var = torch.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3)) return ( 0.5 * SQRT_2_PI * xbar * (1 - var**2) * (3 * GGELU * xbar**2 + 1) + 0.5 * var + 0.5 ) + elif functype == 3: - return 0.0 if xbar <= 0 else 1.0 + return torch.where(xbar > 0, torch.ones_like(xbar), torch.zeros_like(xbar)) + elif functype == 4: - return 0.0 if xbar <= 0 or xbar >= 6 else 1.0 + return torch.where( + (xbar > 0) & (xbar < 6), torch.ones_like(xbar), torch.zeros_like(xbar) + ) + elif functype == 5: - return 1.0 - 1.0 / (1.0 + np.exp(xbar)) + return 1.0 - 1.0 / (1.0 + torch.exp(xbar)) + elif functype == 6: return y * (1 - y) - raise ValueError(f"Unsupported function type: {functype}") + else: + raise ValueError(f"Unsupported function type: {functype}") -def grad_grad(xbar, y, functype): +def grad_grad(xbar: torch.Tensor, y: torch.Tensor, functype: int): if functype == 1: return -2 * y * (1 - y * y) + elif functype == 2: - var1 = np.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3)) + var1 = torch.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3)) var2 = SQRT_2_PI * (1 - var1**2) * (3 * GGELU * xbar**2 + 1) return ( 3 * GGELU * SQRT_2_PI * xbar**2 * (1 - var1**2) - SQRT_2_PI * xbar * var2 * (3 * GGELU * xbar**2 + 1) * var1 + var2 ) + elif functype in [3, 4]: - return 0 + return torch.zeros_like(xbar) + elif functype == 5: - return np.exp(xbar) / ((1 + np.exp(xbar)) * (1 + np.exp(xbar))) + exp_xbar = torch.exp(xbar) + return exp_xbar / ((1 + exp_xbar) * (1 + exp_xbar)) + elif functype == 6: return y * (1 - y) * (1 - 2 * y) + else: - return -1 + return -torch.ones_like(xbar) def unaggregated_dy_dx_s( y: torch.Tensor, w_np: np.ndarray, xbar: torch.Tensor, functype: int ): w = torch.from_numpy(w_np).to(env.DEVICE) + y = y.to(env.DEVICE) + xbar = xbar.to(env.DEVICE) if y.dim() != 2: raise ValueError("Dim of input y should be 2") if w.dim() != 2: @@ -480,13 +496,11 @@ def unaggregated_dy_dx_s( if xbar.dim() != 2: raise ValueError("Dim of input xbar should be 2") - length, width = y.shape - dy_dx = torch.zeros_like(y) - w = torch.flatten(w) + grad_xbar_y = grad(xbar, y, functype) - for ii in range(length): - for jj in range(width): - dy_dx[ii, jj] = grad(xbar[ii, jj], y[ii, jj], functype) * w[jj] + w = torch.flatten(w)[: y.shape[1]].repeat(y.shape[0], 1) + + dy_dx = grad_xbar_y * w return dy_dx @@ -499,6 +513,9 @@ def unaggregated_dy2_dx_s( functype: int, ): w = torch.from_numpy(w_np).to(env.DEVICE) + y = y.to(env.DEVICE) + dy = dy.to(env.DEVICE) + xbar = xbar.to(env.DEVICE) if y.dim() != 2: raise ValueError("Dim of input y should be 2") if dy.dim() != 2: @@ -508,15 +525,11 @@ def unaggregated_dy2_dx_s( if xbar.dim() != 2: raise ValueError("Dim of input xbar should be 2") - length, width = y.shape - dy2_dx = torch.zeros_like(y) - w = torch.flatten(w) + grad_grad_result = grad_grad(xbar, y, functype) - for ii in range(length): - for jj in range(width): - dy2_dx[ii, jj] = ( - grad_grad(xbar[ii, jj], y[ii, jj], functype) * w[jj] * w[jj] - ) + w_flattened = torch.flatten(w)[: y.shape[1]].repeat(y.shape[0], 1) + + dy2_dx = grad_grad_result * w_flattened * w_flattened return dy2_dx @@ -540,22 +553,22 @@ def unaggregated_dy_dx( length, width = z.shape size = w.shape[0] - dy_dx = torch.flatten(dy_dx) - dz_dx = torch.zeros_like(z) + grad_ybar_z = grad(ybar, z, functype) + + dy_dx = dy_dx.view(-1)[: (length * size)].view(length, size) + + accumulator = dy_dx @ w + + dz_drou = 
grad_ybar_z * accumulator - for kk in range(length): - for ii in range(width): - dz_drou = grad(ybar[kk, ii], z[kk, ii], functype) - accumulator = 0.0 - for jj in range(size): - accumulator += w[jj, ii] * dy_dx[kk * size + jj] - dz_drou *= accumulator - if width == 2 * size or width == size: - dz_drou += dy_dx[kk * size + ii % size] - dz_dx[kk, ii] = dz_drou + if width == size: + dz_drou += dy_dx + if width == 2 * size: + dy_dx = torch.cat((dy_dx, dy_dx), dim=1) + dz_drou += dy_dx - return dz_dx + return dz_drou def unaggregated_dy2_dx( @@ -580,28 +593,24 @@ def unaggregated_dy2_dx( length, width = z.shape size = w.shape[0] - dy_dx = torch.flatten(dy_dx) - dy2_dx = torch.flatten(dy2_dx) - - dz2_dx = torch.zeros_like(z) - - for kk in range(length): - for ii in range(width): - dz_drou = grad(ybar[kk, ii], z[kk, ii], functype) - accumulator1 = 0.0 - for jj in range(size): - accumulator1 += w[jj, ii] * dy2_dx[kk * size + jj] - dz_drou *= accumulator1 - accumulator2 = 0.0 - for jj in range(size): - accumulator2 += w[jj, ii] * dy_dx[kk * size + jj] - dz_drou += ( - grad_grad(ybar[kk, ii], z[kk, ii], functype) - * accumulator2 - * accumulator2 - ) - if width == 2 * size or width == size: - dz_drou += dy2_dx[kk * size + ii % size] - dz2_dx[kk, ii] = dz_drou - return dz2_dx + grad_ybar_z = grad(ybar, z, functype) + grad_grad_ybar_z = grad_grad(ybar, z, functype) + + dy2_dx = dy2_dx.view(-1)[: (length * size)].view(length, size) + dy_dx = dy_dx.view(-1)[: (length * size)].view(length, size) + + accumulator1 = dy2_dx @ w + accumulator2 = dy_dx @ w + + dz_drou = ( + grad_ybar_z * accumulator1 + grad_grad_ybar_z * accumulator2 * accumulator2 + ) + + if width == size: + dz_drou += dy2_dx + if width == 2 * size: + dy2_dx = torch.cat((dy2_dx, dy2_dx), dim=1) + dz_drou += dy2_dx + + return dz_drou diff --git a/deepmd/tf/entrypoints/main.py b/deepmd/tf/entrypoints/main.py index b8bfdef6d8..9c3759288a 100644 --- a/deepmd/tf/entrypoints/main.py +++ b/deepmd/tf/entrypoints/main.py @@ -77,6 +77,12 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None): elif args.command == "transfer": transfer(**dict_args) elif args.command == "compress": + dict_args["input"] = format_model_suffix( + dict_args["input"], preferred_backend=args.backend, strict_prefer=True + ) + dict_args["output"] = format_model_suffix( + dict_args["output"], preferred_backend=args.backend, strict_prefer=True + ) compress(**dict_args) elif args.command == "convert-from": convert(**dict_args) diff --git a/source/tests/pt/common.py b/source/tests/pt/common.py index 173e9d52dc..8709c8b4f9 100644 --- a/source/tests/pt/common.py +++ b/source/tests/pt/common.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import pathlib from typing import ( Optional, Union, @@ -7,6 +8,7 @@ import numpy as np import torch +from deepmd.common import j_loader as dp_j_loader from deepmd.main import ( main, ) @@ -15,6 +17,12 @@ GLOBAL_PT_FLOAT_PRECISION, ) +tests_path = pathlib.Path(__file__).parent.absolute() + + +def j_loader(filename): + return dp_j_loader(tests_path / filename) + def run_dp(cmd: str) -> int: """Run DP directly from the entry point instead of the subprocess. 
diff --git a/source/tests/pt/model/test_descriptor_dpa1.py b/source/tests/pt/model/test_descriptor_dpa1.py index ddd5dc6c3c..9652a63944 100644 --- a/source/tests/pt/model/test_descriptor_dpa1.py +++ b/source/tests/pt/model/test_descriptor_dpa1.py @@ -245,7 +245,11 @@ def test_descriptor_block(self): des = DescrptBlockSeAtten( **dparams, ).to(env.DEVICE) - des.load_state_dict(torch.load(self.file_model_param, weights_only=True)) + state_dict = torch.load(self.file_model_param, weights_only=True) + # this is an old state dict, modify manually + state_dict["compress_info.0"] = des.compress_info[0] + state_dict["compress_data.0"] = des.compress_data[0] + des.load_state_dict(state_dict) coord = self.coord atype = self.atype box = self.cell @@ -371,5 +375,7 @@ def translate_se_atten_and_type_embd_dicts_to_dpa1( tk = "type_embedding." + kk record[all_keys.index(tk)] = True target_dict[tk] = type_embd_dict[kk] + record[all_keys.index("se_atten.compress_data.0")] = True + record[all_keys.index("se_atten.compress_info.0")] = True assert all(record) return target_dict diff --git a/source/tests/pt/model/test_descriptor_dpa2.py b/source/tests/pt/model/test_descriptor_dpa2.py index 17d609a2f9..7efbe0a921 100644 --- a/source/tests/pt/model/test_descriptor_dpa2.py +++ b/source/tests/pt/model/test_descriptor_dpa2.py @@ -194,5 +194,7 @@ def translate_type_embd_dicts_to_dpa2( tk = "type_embedding." + kk record[all_keys.index(tk)] = True target_dict[tk] = type_embd_dict[kk] + record[all_keys.index("repinit.compress_data.0")] = True + record[all_keys.index("repinit.compress_info.0")] = True assert all(record) return target_dict diff --git a/source/tests/pt/model/test_model.py b/source/tests/pt/model/test_model.py index 84f5a113a3..0fa6baec68 100644 --- a/source/tests/pt/model/test_model.py +++ b/source/tests/pt/model/test_model.py @@ -62,6 +62,8 @@ def torch2tf(torch_name, last_layer_id=None): offset = int(fields[3] == "networks") + 1 element_id = int(fields[2 + offset]) if fields[1] == "descriptor": + if fields[2].startswith("compress_"): + return None layer_id = int(fields[4 + offset]) + 1 weight_type = fields[5 + offset] ret = "filter_type_all/%s_%d_%d:0" % (weight_type, layer_id, element_id) @@ -318,6 +320,8 @@ def test_consistency(self): for name, param in my_model.named_parameters(): name = name.replace("sea.", "") var_name = torch2tf(name, last_layer_id=len(self.n_neuron)) + if var_name is None: + continue var = vs_dict[var_name].value with torch.no_grad(): src = torch.from_numpy(var) @@ -412,6 +416,8 @@ def step(step_id): for name, param in my_model.named_parameters(): name = name.replace("sea.", "") var_name = torch2tf(name, last_layer_id=len(self.n_neuron)) + if var_name is None: + continue var_grad = vs_dict[var_name].gradient param_grad = param.grad.cpu() var_grad = torch.tensor(var_grad, device="cpu") diff --git a/source/tests/pt/model/test_unused_params.py b/source/tests/pt/model/test_unused_params.py index 98bbe7040e..8c223d7590 100644 --- a/source/tests/pt/model/test_unused_params.py +++ b/source/tests/pt/model/test_unused_params.py @@ -86,7 +86,8 @@ def get_contributing_params(y, top_level=True): contributing_parameters = set(get_contributing_params(ret0["energy"])) all_parameters = set(self.model.parameters()) non_contributing = all_parameters - contributing_parameters - self.assertEqual(len(non_contributing), 0) + # 2 for compression + self.assertEqual(len(non_contributing), 2) if __name__ == "__main__": diff --git a/source/tests/pt/model_compression/data/set.000/box.npy 
b/source/tests/pt/model_compression/data/set.000/box.npy new file mode 100644 index 0000000000..aa092a9aad Binary files /dev/null and b/source/tests/pt/model_compression/data/set.000/box.npy differ diff --git a/source/tests/pt/model_compression/data/set.000/coord.npy b/source/tests/pt/model_compression/data/set.000/coord.npy new file mode 100644 index 0000000000..2205bfd452 Binary files /dev/null and b/source/tests/pt/model_compression/data/set.000/coord.npy differ diff --git a/source/tests/pt/model_compression/data/set.000/energy.npy b/source/tests/pt/model_compression/data/set.000/energy.npy new file mode 100644 index 0000000000..6c1e5ce145 Binary files /dev/null and b/source/tests/pt/model_compression/data/set.000/energy.npy differ diff --git a/source/tests/pt/model_compression/data/set.000/force.npy b/source/tests/pt/model_compression/data/set.000/force.npy new file mode 100644 index 0000000000..10698ebb59 Binary files /dev/null and b/source/tests/pt/model_compression/data/set.000/force.npy differ diff --git a/source/tests/pt/model_compression/data/type.raw b/source/tests/pt/model_compression/data/type.raw new file mode 100644 index 0000000000..4eeae61de1 --- /dev/null +++ b/source/tests/pt/model_compression/data/type.raw @@ -0,0 +1,6 @@ +0 +1 +1 +0 +1 +1 diff --git a/source/tests/pt/model_compression/data/type_map.raw b/source/tests/pt/model_compression/data/type_map.raw new file mode 100644 index 0000000000..e900768b1d --- /dev/null +++ b/source/tests/pt/model_compression/data/type_map.raw @@ -0,0 +1,2 @@ +O +H diff --git a/source/tests/pt/model_compression/input.json b/source/tests/pt/model_compression/input.json new file mode 100644 index 0000000000..69adad0f35 --- /dev/null +++ b/source/tests/pt/model_compression/input.json @@ -0,0 +1,77 @@ +{ + "_comment1": " model parameters", + "model": { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "se_e2_a", + "sel": [ + 46, + 92 + ], + "rcut_smth": 0.50, + "rcut": 6.00, + "_comment": "N2=2N1, N2=N1, and otherwise can be tested", + "neuron": [ + 4, + 8, + 17, + 17 + ], + "resnet_dt": false, + "axis_neuron": 16, + "seed": 1, + "_comment2": " that's all" + }, + "fitting_net": { + "neuron": [ + 20, + 20, + 20 + ], + "resnet_dt": true, + "seed": 1, + "_comment3": " that's all" + }, + "_comment4": " that's all" + }, + + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + "_comment5": "that's all" + }, + + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + "_comment6": " that's all" + }, + + "training": { + "training_data": { + "systems": [ + "model_compression/data" + ], + "batch_size": "auto", + "_comment7": "that's all" + }, + "numb_steps": 1, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": 1, + "_comment9": "that's all" + }, + + "_comment10": "that's all" +} diff --git a/source/tests/pt/test_model_compression_se_a.py b/source/tests/pt/test_model_compression_se_a.py new file mode 100644 index 0000000000..0e7bf0b69a --- /dev/null +++ b/source/tests/pt/test_model_compression_se_a.py @@ -0,0 +1,576 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import unittest + +import numpy as np + +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.infer.deep_eval import ( + DeepEval, +) + +from .common import ( + j_loader, + run_dp, + tests_path, +) + +if GLOBAL_NP_FLOAT_PRECISION == np.float32: + default_places 
= 4 +else: + default_places = 10 + + +def _file_delete(file): + if os.path.isdir(file): + os.rmdir(file) + elif os.path.isfile(file): + os.remove(file) + + +def _init_models(): + data_file = str(tests_path / os.path.join("model_compression", "data")) + frozen_model = str(tests_path / "dp-original.pth") + compressed_model = str(tests_path / "dp-compressed.pth") + INPUT = str(tests_path / "input.json") + jdata = j_loader(str(tests_path / os.path.join("model_compression", "input.json"))) + jdata["training"]["training_data"]["systems"] = data_file + with open(INPUT, "w") as fp: + json.dump(jdata, fp, indent=4) + + ret = run_dp("dp --pt train " + INPUT) + np.testing.assert_equal(ret, 0, "DP train failed!") + ret = run_dp("dp --pt freeze -o " + frozen_model) + np.testing.assert_equal(ret, 0, "DP freeze failed!") + ret = run_dp( + "dp --pt compress " + " -i " + frozen_model + " -o " + compressed_model + ) + np.testing.assert_equal(ret, 0, "DP model compression failed!") + return INPUT, frozen_model, compressed_model + + +def _init_models_exclude_types(): + data_file = str(tests_path / os.path.join("model_compression", "data")) + frozen_model = str(tests_path / "dp-original-exclude-types.pth") + compressed_model = str(tests_path / "dp-compressed-exclude-types.pth") + INPUT = str(tests_path / "input.json") + jdata = j_loader(str(tests_path / os.path.join("model_compression", "input.json"))) + jdata["model"]["descriptor"]["exclude_types"] = [[0, 1]] + jdata["training"]["training_data"]["systems"] = data_file + with open(INPUT, "w") as fp: + json.dump(jdata, fp, indent=4) + + ret = run_dp("dp --pt train " + INPUT) + np.testing.assert_equal(ret, 0, "DP train failed!") + ret = run_dp("dp --pt freeze -o " + frozen_model) + np.testing.assert_equal(ret, 0, "DP freeze failed!") + ret = run_dp( + "dp --pt compress " + " -i " + frozen_model + " -o " + compressed_model + ) + np.testing.assert_equal(ret, 0, "DP model compression failed!") + return INPUT, frozen_model, compressed_model + + +def setUpModule(): + global \ + INPUT, \ + FROZEN_MODEL, \ + COMPRESSED_MODEL, \ + INPUT_ET, \ + FROZEN_MODEL_ET, \ + COMPRESSED_MODEL_ET + INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() + INPUT_ET, FROZEN_MODEL_ET, COMPRESSED_MODEL_ET = _init_models_exclude_types() + + +class TestDeepPotAPBC(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.dp_original = DeepEval(FROZEN_MODEL) + cls.dp_compressed = DeepEval(COMPRESSED_MODEL) + cls.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ] + ) + cls.atype = [0, 1, 1, 0, 1, 1] + cls.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0]) + + def test_attrs(self): + self.assertEqual(self.dp_original.get_ntypes(), 2) + self.assertAlmostEqual(self.dp_original.get_rcut(), 6.0, places=default_places) + self.assertEqual(self.dp_original.get_type_map(), ["O", "H"]) + self.assertEqual(self.dp_original.get_dim_fparam(), 0) + self.assertEqual(self.dp_original.get_dim_aparam(), 0) + + self.assertEqual(self.dp_compressed.get_ntypes(), 2) + self.assertAlmostEqual( + self.dp_compressed.get_rcut(), 6.0, places=default_places + ) + self.assertEqual(self.dp_compressed.get_type_map(), ["O", "H"]) + self.assertEqual(self.dp_compressed.get_dim_fparam(), 0) + self.assertEqual(self.dp_compressed.get_dim_aparam(), 0) + + def test_1frame(self): + ee0, ff0, vv0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=False + ) + ee1, ff1, vv1 = 
self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=False + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_1frame_atm(self): + ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=True + ) + ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=True + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ae0.shape, (nframes, natoms, 1)) + self.assertEqual(av0.shape, (nframes, natoms, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + self.assertEqual(ae1.shape, (nframes, natoms, 1)) + self.assertEqual(av1.shape, (nframes, natoms, 9)) + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ae0, ae1, default_places) + np.testing.assert_almost_equal(av0, av1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_2frame_atm(self): + coords2 = np.concatenate((self.coords, self.coords)) + box2 = np.concatenate((self.box, self.box)) + ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( + coords2, box2, self.atype, atomic=True + ) + ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( + coords2, box2, self.atype, atomic=True + ) + # check shape of the returns + nframes = 2 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ae0.shape, (nframes, natoms, 1)) + self.assertEqual(av0.shape, (nframes, natoms, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + self.assertEqual(ae1.shape, (nframes, natoms, 1)) + self.assertEqual(av1.shape, (nframes, natoms, 9)) + + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ae0, ae1, default_places) + np.testing.assert_almost_equal(av0, av1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + +class TestDeepPotANoPBC(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.dp_original = DeepEval(FROZEN_MODEL) + cls.dp_compressed = DeepEval(COMPRESSED_MODEL) + cls.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ] + ) + cls.atype = [0, 1, 1, 0, 1, 1] + cls.box = None + + def test_1frame(self): + ee0, ff0, vv0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=False + ) + ee1, ff1, vv1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, 
atomic=False + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_1frame_atm(self): + ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=True + ) + ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=True + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ae0.shape, (nframes, natoms, 1)) + self.assertEqual(av0.shape, (nframes, natoms, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + self.assertEqual(ae1.shape, (nframes, natoms, 1)) + self.assertEqual(av1.shape, (nframes, natoms, 9)) + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ae0, ae1, default_places) + np.testing.assert_almost_equal(av0, av1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_2frame_atm(self): + coords2 = np.concatenate((self.coords, self.coords)) + ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( + coords2, self.box, self.atype, atomic=True + ) + ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( + coords2, self.box, self.atype, atomic=True + ) + # check shape of the returns + nframes = 2 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ae0.shape, (nframes, natoms, 1)) + self.assertEqual(av0.shape, (nframes, natoms, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + self.assertEqual(ae1.shape, (nframes, natoms, 1)) + self.assertEqual(av1.shape, (nframes, natoms, 9)) + + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ae0, ae1, default_places) + np.testing.assert_almost_equal(av0, av1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + +class TestDeepPotALargeBoxNoPBC(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.dp_original = DeepEval(FROZEN_MODEL) + cls.dp_compressed = DeepEval(COMPRESSED_MODEL) + cls.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ] + ) + cls.atype = [0, 1, 1, 0, 1, 1] + cls.box = np.array([19.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0]) + + def test_1frame(self): + ee0, ff0, vv0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=False + ) + ee1, ff1, vv1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=False + ) + # check shape of the 
returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_1frame_atm(self): + ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=True + ) + ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=True + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ae0.shape, (nframes, natoms, 1)) + self.assertEqual(av0.shape, (nframes, natoms, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + self.assertEqual(ae1.shape, (nframes, natoms, 1)) + self.assertEqual(av1.shape, (nframes, natoms, 9)) + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ae0, ae1, default_places) + np.testing.assert_almost_equal(av0, av1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_ase(self): + from ase import ( + Atoms, + ) + + from deepmd.tf.calculator import ( + DP, + ) + + water0 = Atoms( + "OHHOHH", + positions=self.coords.reshape((-1, 3)), + cell=self.box.reshape((3, 3)), + calculator=DP(FROZEN_MODEL), + ) + water1 = Atoms( + "OHHOHH", + positions=self.coords.reshape((-1, 3)), + cell=self.box.reshape((3, 3)), + calculator=DP(COMPRESSED_MODEL), + ) + ee0 = water0.get_potential_energy() + ff0 = water0.get_forces() + ee1 = water1.get_potential_energy() + ff1 = water1.get_forces() + # nframes = 1 + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + + +class TestDeepPotAPBCExcludeTypes(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.dp_original = DeepEval(FROZEN_MODEL_ET) + cls.dp_compressed = DeepEval(COMPRESSED_MODEL_ET) + cls.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ] + ) + cls.atype = [0, 1, 1, 0, 1, 1] + cls.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0]) + + @classmethod + def tearDownClass(cls): + _file_delete(INPUT_ET) + _file_delete(FROZEN_MODEL_ET) + _file_delete(COMPRESSED_MODEL_ET) + _file_delete("out.json") + _file_delete("compress.json") + _file_delete("checkpoint") + _file_delete("model.ckpt.meta") + _file_delete("model.ckpt.index") + _file_delete("model.ckpt.data-00000-of-00001") + _file_delete("model.ckpt-100.meta") + _file_delete("model.ckpt-100.index") + _file_delete("model.ckpt-100.data-00000-of-00001") + _file_delete("model-compression/checkpoint") + _file_delete("model-compression/model.ckpt.meta") + _file_delete("model-compression/model.ckpt.index") + _file_delete("model-compression/model.ckpt.data-00000-of-00001") + _file_delete("model-compression") + _file_delete("input_v2_compat.json") + 
_file_delete("lcurve.out") + + def test_attrs(self): + self.assertEqual(self.dp_original.get_ntypes(), 2) + self.assertAlmostEqual(self.dp_original.get_rcut(), 6.0, places=default_places) + self.assertEqual(self.dp_original.get_type_map(), ["O", "H"]) + self.assertEqual(self.dp_original.get_dim_fparam(), 0) + self.assertEqual(self.dp_original.get_dim_aparam(), 0) + + self.assertEqual(self.dp_compressed.get_ntypes(), 2) + self.assertAlmostEqual( + self.dp_compressed.get_rcut(), 6.0, places=default_places + ) + self.assertEqual(self.dp_compressed.get_type_map(), ["O", "H"]) + self.assertEqual(self.dp_compressed.get_dim_fparam(), 0) + self.assertEqual(self.dp_compressed.get_dim_aparam(), 0) + + def test_1frame(self): + ee0, ff0, vv0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=False + ) + ee1, ff1, vv1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=False + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_1frame_atm(self): + ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=True + ) + ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=True + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ae0.shape, (nframes, natoms, 1)) + self.assertEqual(av0.shape, (nframes, natoms, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + self.assertEqual(ae1.shape, (nframes, natoms, 1)) + self.assertEqual(av1.shape, (nframes, natoms, 9)) + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ae0, ae1, default_places) + np.testing.assert_almost_equal(av0, av1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_2frame_atm(self): + coords2 = np.concatenate((self.coords, self.coords)) + box2 = np.concatenate((self.box, self.box)) + ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( + coords2, box2, self.atype, atomic=True + ) + ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( + coords2, box2, self.atype, atomic=True + ) + # check shape of the returns + nframes = 2 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ae0.shape, (nframes, natoms, 1)) + self.assertEqual(av0.shape, (nframes, natoms, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + self.assertEqual(ae1.shape, (nframes, natoms, 1)) + self.assertEqual(av1.shape, (nframes, natoms, 9)) + + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + 
np.testing.assert_almost_equal(ae0, ae1, default_places) + np.testing.assert_almost_equal(av0, av1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_tabulate.py b/source/tests/pt/test_tabulate.py index c03773827d..5a812f420c 100644 --- a/source/tests/pt/test_tabulate.py +++ b/source/tests/pt/test_tabulate.py @@ -61,7 +61,7 @@ def test_ops(self): ) dy_tf_numpy = dy_tf.numpy() - dy_pt_numpy = dy_pt.detach().numpy() + dy_pt_numpy = dy_pt.detach().cpu().numpy() np.testing.assert_almost_equal(dy_tf_numpy, dy_pt_numpy, decimal=10) @@ -82,7 +82,7 @@ def test_ops(self): ) dy2_tf_numpy = dy2_tf.numpy() - dy2_pt_numpy = dy2_pt.detach().numpy() + dy2_pt_numpy = dy2_pt.detach().cpu().numpy() np.testing.assert_almost_equal(dy2_tf_numpy, dy2_pt_numpy, decimal=10)
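
For reference, the compression path introduced above can be driven either from the CLI (`dp --pt compress -i frozen_model.pth -o compressed_model.pth`, as in the updated epilog; the dispatcher in `deepmd/pt/entrypoints/main.py` forces the `.pth` suffix on both paths) or directly from Python. Below is a minimal usage sketch of the Python route; the file names are placeholders, and the keyword defaults mirror the signature of `enable_compression` in the new `deepmd/pt/entrypoints/compress.py`.

```python
# Minimal usage sketch (not part of the diff): driving the new PyTorch
# compression entry point directly. The model must be a frozen TorchScript
# model produced by `dp --pt freeze` that carries its model_def_script and a
# recorded min_nbor_dist (both serialized by the training changes in this PR).
from deepmd.pt.entrypoints.compress import enable_compression

enable_compression(
    input_file="frozen_model.pth",         # frozen PyTorch model to compress
    output="frozen_model_compressed.pth",  # compressed TorchScript output
    stride=0.01,         # first table stride; the second stride is stride * 10
    extrapolate=5,       # extrapolation factor beyond the sampled range
    check_frequency=-1,  # -1 disables the overflow check
)
```

Internally this loads the TorchScript file, rebuilds the model from its `model_def_script`, copies the state dict, calls `model.enable_compression(extrapolate, stride, stride * 10, check_frequency)`, and re-scripts the result to the output path.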
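
The `deepmd/pt/utils/tabulate.py` hunks replace per-element Python loops with broadcasted tensor expressions. The sketch below (an illustrative check, not taken from the repository) shows why the two forms agree for `unaggregated_dy_dx_s` in the tanh case (`functype == 1`, where `grad(xbar, y) = 1 - y * y`); the shapes assume a `(1, width)` first-layer weight matrix, as the flatten-and-slice in the diff does.

```python
# Illustrative equivalence check: loop form (old code) vs. broadcast form (new code).
import torch

torch.manual_seed(0)
length, width = 3, 4
xbar = torch.randn(length, width, dtype=torch.float64)
y = torch.tanh(xbar)                                   # layer output
w = torch.randn(1, width, dtype=torch.float64)         # first-layer weights

# Original loop formulation: dy_dx[i, j] = grad(xbar[i, j], y[i, j]) * w[j]
loop = torch.zeros_like(y)
w_flat = torch.flatten(w)
for ii in range(length):
    for jj in range(width):
        loop[ii, jj] = (1 - y[ii, jj] * y[ii, jj]) * w_flat[jj]

# Vectorized formulation from the diff: grad(xbar, y) * flattened weights,
# sliced to the output width and repeated along the batch dimension.
vec = (1 - y * y) * torch.flatten(w)[:width].repeat(length, 1)

assert torch.allclose(loop, vec)
```

The same pattern (elementwise `grad`/`grad_grad` times a broadcast weight row, plus matrix products `dy_dx @ w` for the accumulators) underlies the rewritten `unaggregated_dy_dx` and `unaggregated_dy2_dx`.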