Skip to content

feat: add has_message_passing API #3851

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def mixed_types(self) -> bool:
"""
return self.descriptor.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return self.descriptor.has_message_passing()

def forward_atomic(
self,
extended_coord: np.ndarray,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@
"""
return True

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return any(model.has_message_passing() for model in self.models)

Check warning on line 97 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L97

Added line #L97 was not covered by tests

def get_rcut(self) -> float:
"""Get the cut-off radius."""
return max(self.get_model_rcuts())
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ def mixed_types(self) -> bool:
"""
pass

@abstractmethod
def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""

@abstractmethod
def fwd(
self,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@
# to match DPA1 and DPA2.
return True

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return False

Check warning on line 136 in deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py#L136

Added line #L136 was not covered by tests

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,10 @@
"""
return self.se_atten.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

Check warning on line 354 in deepmd/dpmodel/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/dpa1.py#L354

Added line #L354 was not covered by tests

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.se_atten.get_env_protection()
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,10 @@
"""
return True

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return True

Check warning on line 529 in deepmd/dpmodel/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/dpa2.py#L529

Added line #L529 was not covered by tests

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@
"""
return any(descrpt.mixed_types() for descrpt in self.descrpt_list)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(descrpt.has_message_passing() for descrpt in self.descrpt_list)

Check warning on line 143 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L143

Added line #L143 was not covered by tests

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix. All descriptors should be the same."""
all_protection = [descrpt.get_env_protection() for descrpt in self.descrpt_list]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def mixed_types(self) -> bool:
"""
pass

@abstractmethod
def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""

@abstractmethod
def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,10 @@ def mixed_types(self):
"""
return False

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,10 @@
"""
return False

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

Check warning on line 210 in deepmd/dpmodel/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/se_r.py#L210

Added line #L210 was not covered by tests

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@
"""
return False

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

Check warning on line 190 in deepmd/dpmodel/descriptor/se_t.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/se_t.py#L190

Added line #L190 was not covered by tests

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,10 @@ def mixed_types(self) -> bool:
"""
return self.atomic_model.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the model has message passing."""
return self.atomic_model.has_message_passing()

def atomic_output_def(self) -> FittingOutputDef:
"""Get the output def of the atomic model."""
return self.atomic_model.atomic_output_def()
Expand Down
5 changes: 1 addition & 4 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,7 @@ def train(FLAGS):

def freeze(FLAGS):
model = torch.jit.script(inference.Tester(FLAGS.model, head=FLAGS.head).model)
if '"type": "dpa2"' in model.get_model_def_script():
extra_files = {"type": "dpa2"}
else:
extra_files = {"type": "else"}
extra_files = {}
torch.jit.save(
model,
FLAGS.output,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def mixed_types(self) -> bool:
"""
return self.descriptor.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return self.descriptor.has_message_passing()

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def mixed_types(self) -> bool:
"""
return True

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return any(model.has_message_passing() for model in self.models)

def get_out_bias(self) -> torch.Tensor:
return self.out_bias

Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ def mixed_types(self) -> bool:
# to match DPA1 and DPA2.
return True

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return False

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,10 @@ def mixed_types(self) -> bool:
"""
return self.se_atten.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.se_atten.get_env_protection()
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,10 @@ def mixed_types(self) -> bool:
"""
return True

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return True

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
# the env_protection of repinit is the same as that of the repformer
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ def mixed_types(self):
"""
return any(descrpt.mixed_types() for descrpt in self.descrpt_list)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(descrpt.has_message_passing() for descrpt in self.descrpt_list)

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix. All descriptors should be the same."""
all_protection = [descrpt.get_env_protection() for descrpt in self.descrpt_list]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ def mixed_types(self):
"""
return self.sea.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.sea.get_env_protection()
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@
"""
return False

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

Check warning on line 175 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L175

Added line #L175 was not covered by tests

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@
"""
return self.seat.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

Check warning on line 179 in deepmd/pt/model/descriptor/se_t.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_t.py#L179

Added line #L179 was not covered by tests

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.seat.get_env_protection()
Expand Down
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/frozen.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@
"""
return self.model.mixed_types()

@torch.jit.export
def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.model.has_message_passing()

Check warning on line 112 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L112

Added line #L112 was not covered by tests

@torch.jit.export
def forward(
self,
Expand Down
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,11 @@ def mixed_types(self) -> bool:
"""
return self.atomic_model.mixed_types()

@torch.jit.export
def has_message_passing(self) -> bool:
"""Returns whether the model has message passing."""
return self.atomic_model.has_message_passing()

def forward(
self,
coord,
Expand Down
8 changes: 1 addition & 7 deletions source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,7 @@ void DeepPotPT::init(const std::string& model,
}
std::unordered_map<std::string, std::string> metadata = {{"type", ""}};
module = torch::jit::load(model, device, metadata);
// TODO: This should be fixed after implement api to decide whether need to
// message passing and rename this metadata
if (metadata["type"] == "dpa2") {
do_message_passing = 1;
} else {
do_message_passing = 0;
}
do_message_passing = module.run_method("has_message_passing").toBool();
torch::jit::FusionStrategy strategy;
strategy = {{torch::jit::FusionBehavior::DYNAMIC, 10}};
torch::jit::setFusionStrategy(strategy);
Expand Down
65 changes: 43 additions & 22 deletions source/api_cc/tests/test_deeppot_dpa_pt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,30 +25,51 @@ class TestInferDeepPotDpaPt : public ::testing::Test {
3.51, 2.51, 2.60, 4.27, 3.22, 1.56};
std::vector<int> atype = {0, 1, 1, 0, 1, 1};
std::vector<VALUETYPE> box = {13., 0., 0., 0., 13., 0., 0., 0., 13.};
std::vector<VALUETYPE> expected_e = {-93.295296030283, -186.548183879333,
-186.988827037855, -93.295307298571,
-186.799369383945, -186.507754447584};
// Generated by the following Python code:
// import numpy as np
// from deepmd.infer import DeepPot
// coord = np.array([
// 12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
// 00.25, 3.32, 1.68, 3.36, 3.00, 1.81,
// 3.51, 2.51, 2.60, 4.27, 3.22, 1.56
// ]).reshape(1, -1)
// atype = np.array([0, 1, 1, 0, 1, 1])
// box = np.array([13., 0., 0., 0., 13., 0., 0., 0., 13.]).reshape(1, -1)
// dp = DeepPot("deeppot_dpa.pth")
// e, v, f, ae, av = dp.eval(coord, box, atype, atomic=True)
// np.set_printoptions(precision=16)
// print(f"{e.ravel()=} {v.ravel()=} {f.ravel()=} {ae.ravel()=}
// {av.ravel()=}")

std::vector<VALUETYPE> expected_e = {
-94.37720733019096, -187.43155959873033, -187.37830241580824,
-94.34880710985752, -187.38869830422271, -187.33919952642458};
std::vector<VALUETYPE> expected_f = {
4.964133039248, -0.542378158452, -0.381267990914, -0.563388054735,
0.340320322541, 0.473406268590, 0.159774831398, 0.684651816874,
-0.377008867620, -4.718603033927, -0.012604322920, -0.425121993870,
-0.500302936762, -0.637586419292, 0.930351899011, 0.658386154778,
0.167596761250, -0.220359315197};
5.402355596838843, -1.263284191331685, -0.697693239979719,
-1.025144852453706, 0.6554396369933394, 0.8817286288078215,
0.4364579972147229, 1.2150079148857598, -0.6778076371985796,
-6.939243547937094, 0.1571084862688049, -0.9017435514431825,
0.3597967524845581, -1.328808718007412, 2.0974306454214653,
1.7657780538526762, 0.5645368711911929, -0.7019148456078053};
std::vector<VALUETYPE> expected_v = {
-5.055176133632, -0.743392222876, 0.330846378467, -0.031111229868,
0.018004461517, 0.170047655301, -0.063087726831, -0.004361215202,
-0.042920299661, 3.624188578021, -0.252818122305, -0.026516806138,
-0.014510755893, 0.103726553937, 0.181001311123, -0.508673535094,
0.142101134395, 0.135339636607, -0.460067993361, 0.120541583338,
-0.206396390140, -0.630991740522, 0.397670086144, -0.427022150075,
0.656463775044, -0.209989614377, 0.288974239790, -7.603428707029,
-0.912313971544, 0.882084544041, -0.807760666057, -0.070519570327,
0.022164414763, 0.569448616709, 0.028522950109, 0.051641619288,
-1.452133900157, 0.037653156584, -0.144421326931, -0.308825789350,
0.302020522568, -0.446073217801, 0.313539058423, -0.461052923736,
0.678235442273, 1.429780276456, 0.080472825760, -0.103424652500,
0.123343430648, 0.011879908277, -0.018897229721, -0.235518441452,
-0.013999547600, 0.027007016662};
9.5175137906314511e-01, -2.0801835688892991e+00, 4.6860789988973117e-01,
-6.0178723966859824e+00, 1.2556002911926123e-01, 4.7887097832213565e-02,
5.6216590124464116e-01, 1.7071246159044051e-01, 8.4990129293690209e-02,
-1.2558035496847255e+00, -3.1123763096053136e-02, -4.4100135935181761e-01,
6.4707184007995455e-01, 1.5574441384822924e-01, 3.2409058144551339e-01,
2.8631311270672963e+00, -3.0375434485037031e-04, 3.9533024424985619e-01,
3.2722174727830535e+00, 1.1867224518409690e-01, -2.2250901443705223e-01,
5.0337980348311300e+00, 6.0517723355290898e-01, -5.5204995585567707e-01,
-3.8335680797875722e+00, -2.3083403461022087e-01, 3.1281970616476651e-01,
-1.0733902445454071e+01, -2.7634498084191517e-01, 1.5720135955951031e+00,
-2.9262906180354680e+00, 1.0845127764896278e-01, -1.1142053272645919e-01,
3.6066832583682209e+00, -1.9002351752094526e-01, 3.1875602887687587e-01,
3.6971839777382898e-01, -2.7352380159430506e-02, 1.0670299036230046e-01,
1.8155828042674422e+00, 4.9170982983933986e-01, -6.7166291183351579e-01,
-2.9003369690467395e+00, -7.6647630459927585e-01, 1.0566933380800889e+00,
-4.8620953903555858e-01, 4.0440213825136057e-01, -6.5227187264812003e-01,
-4.4421997400831864e-01, 1.4811202361724179e-01, -2.4354470120979710e-01,
5.3346700156430571e-01, -1.8977527286286849e-01, 3.1383559345422440e-01};
int natoms;
double expected_tot_e;
std::vector<VALUETYPE> expected_tot_v;
Expand Down
Loading