Skip to content

Commit 55f70b2

Browse files
authored
JDFTXInfile Comparison Methods (materialsproject#4416)
* Organizing `JDFTXInfile` contents * `is_equal_to` method for AbstractTag for equality checking, default parameters for JDFTx, unwritten `is_comparable_to` method for `JDFTXInfile` * `get_tag_object_on_val` alternative to `get_tag_object` that auto-sets format option for `MultiformatTag`s * Fixing value types in default inputs, repeatable multiformattags to be set as a list with elements of different format options, changing `JDFTXStructure.get_str` to build a dictionary representation of each ion instead of a list (not sure how it was even working before) * changing `is_equal_to` to be defined in `AbstractTag` as a repeatability handler, and now `_is_equal_to` to the abstractmethod to be implemented which compares two non-list values * Series of methods for gathering differing tags between two JDFTXInfile objects, along convenience arguments to filter which tags to pay attention to. Not sure if `is_comparable_to` is implementable in a way generalizable enough to be useful. * typing check * Adding back for now, seems bloaty but could be helpful for readability in implementations * Removing commented out code * Changing my mind about ion-species inclusion as a default ensure_include_tag * reordering * TODOs * TODOs * Partial fix for setting selective_dynamics through site_properties in `JDFTXInfile.from_structure` * mypy fix
1 parent 12c573a commit 55f70b2

File tree

6 files changed

+904
-341
lines changed

6 files changed

+904
-341
lines changed

src/pymatgen/io/jdftx/generic_tags.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,56 @@ def validate_value_type(self, tag: str, value: Any, try_auto_type_fix: bool = Fa
7070
tuple[str, bool, Any]: The tag, whether the value is of the correct type, and the possibly fixed value.
7171
"""
7272

73+
def is_equal_to(self, val1: Any | list[Any], obj2: AbstractTag, val2: Any | list[Any]) -> bool:
74+
"""Check if the two values are equal.
75+
76+
Args:
77+
val1 (Any): The value of this tag object.
78+
obj2 (AbstractTag): The other tag object.
79+
val2 (Any): The value of the other tag object.
80+
81+
Returns:
82+
bool: True if the two tag object/value pairs are equal, False otherwise.
83+
"""
84+
if self.can_repeat:
85+
if not obj2.can_repeat:
86+
return False
87+
val1 = val1 if isinstance(val1, list) else [val1]
88+
val2 = val2 if isinstance(val2, list) else [val2]
89+
if len(val1) != len(val2):
90+
return False
91+
return all(True in [self._is_equal_to(v1, obj2, v2) for v2 in val2] for v1 in val1)
92+
return self._is_equal_to(val1, obj2, val2)
93+
94+
@abstractmethod
95+
def _is_equal_to(self, val1: Any, obj2: AbstractTag, val2: Any) -> bool:
96+
"""Check if the two values are equal.
97+
98+
Used to check if the two values are equal. Assumes val1 and val2 are single elements.
99+
100+
Args:
101+
val1 (Any): The value of this tag object.
102+
obj2 (AbstractTag): The other tag object.
103+
val2 (Any): The value of the other tag object.
104+
105+
Returns:
106+
bool: True if the two tag object/value pairs are equal, False otherwise.
107+
"""
108+
109+
def _is_same_tagtype(
110+
self,
111+
obj2: AbstractTag,
112+
) -> bool:
113+
"""Check if the two values are equal.
114+
115+
Args:
116+
obj2 (AbstractTag): The other tag object.
117+
118+
Returns:
119+
bool: True if the two tag object/value pairs are equal, False otherwise.
120+
"""
121+
return isinstance(self, type(obj2))
122+
73123
def _validate_value_type(
74124
self, type_check: type, tag: str, value: Any, try_auto_type_fix: bool = False
75125
) -> tuple[str, bool, Any]:
@@ -258,6 +308,19 @@ def validate_value_type(self, tag: str, value: Any, try_auto_type_fix: bool = Fa
258308
"""
259309
return self._validate_value_type(bool, tag, value, try_auto_type_fix=try_auto_type_fix)
260310

311+
def _is_equal_to(self, val1: Any, obj2: AbstractTag, val2: Any) -> bool:
312+
"""Check if the two values are equal.
313+
314+
Args:
315+
val1 (Any): The value of this tag object.
316+
obj2 (AbstractTag): The other tag object.
317+
val2 (Any): The value of the other tag object.
318+
319+
Returns:
320+
bool: True if the two tag object/value pairs are equal, False otherwise.
321+
"""
322+
return self._is_same_tagtype(obj2) and val1 == val2
323+
261324
def raise_value_error(self, tag: str, value: str) -> None:
262325
"""Raise a ValueError for the value string.
263326
@@ -335,6 +398,23 @@ def validate_value_type(self, tag: str, value: Any, try_auto_type_fix: bool = Fa
335398
"""
336399
return self._validate_value_type(str, tag, value, try_auto_type_fix=try_auto_type_fix)
337400

401+
def _is_equal_to(self, val1: Any, obj2: AbstractTag, val2: Any) -> bool:
402+
"""Check if the two values are equal.
403+
404+
Args:
405+
val1 (Any): The value of this tag object.
406+
obj2 (AbstractTag): The other tag object.
407+
val2 (Any): The value of the other tag object.
408+
409+
Returns:
410+
bool: True if the two tag object/value pairs are equal, False otherwise.
411+
"""
412+
if self._is_same_tagtype(obj2):
413+
if not all(isinstance(x, str) for x in (val1, val2)):
414+
raise ValueError("Both values must be strings for StrTag comparison")
415+
return val1.strip() == val2.strip()
416+
return False
417+
338418
def read(self, tag: str, value: str) -> str:
339419
"""Read the value string for this tag.
340420
@@ -379,6 +459,8 @@ class AbstractNumericTag(AbstractTag):
379459
ub: float | None = None # upper bound
380460
lb_incl: bool = True # lower bound inclusive
381461
ub_incl: bool = True # upper bound inclusive
462+
eq_atol: float = 1.0e-8 # absolute tolerance for equality check
463+
eq_rtol: float = 1.0e-5 # relative tolerance for equality check
382464

383465
def val_is_within_bounds(self, value: float) -> bool:
384466
"""Check if the value is within the bounds.
@@ -425,6 +507,22 @@ def validate_value_bounds(
425507
return False, self.get_invalid_value_error_str(tag, value)
426508
return True, ""
427509

510+
def _is_equal_to(self, val1, obj2, val2):
511+
"""Check if the two values are equal.
512+
513+
Used to check if the two values are equal. Doesn't need to be redefined for IntTag and FloatTag.
514+
515+
Args:
516+
val1 (Any): The value of this tag object.
517+
obj2 (AbstractTag): The other tag object.
518+
val2 (Any): The value of the other tag object.
519+
rtol (float, optional): Relative tolerance. Defaults to 1.e-5.
520+
atol (float, optional): Absolute tolerance. Defaults to 1.e-8.
521+
Returns:
522+
bool: True if the two tag object/value pairs are equal, False otherwise.
523+
"""
524+
return self._is_same_tagtype(obj2) and np.isclose(val1, val2, rtol=self.eq_rtol, atol=self.eq_atol)
525+
428526

429527
@dataclass
430528
class IntTag(AbstractNumericTag):
@@ -620,6 +718,10 @@ def get_token_len(self) -> int:
620718
"""
621719
return self._get_token_len()
622720

721+
def _is_equal_to(self, val1, obj2, val2):
722+
return True # TODO: We still need to actually implement initmagmom as a multi-format tag
723+
# raise NotImplementedError("equality not yet implemented for InitMagMomTag")
724+
623725

624726
@dataclass
625727
class TagContainer(AbstractTag):
@@ -1013,6 +1115,28 @@ def get_dict_representation(self, tag: str, value: list) -> dict | list[dict]:
10131115
list_value = self._make_str_for_dict(tag, value)
10141116
return self.read(tag, list_value)
10151117

1118+
def _is_equal_to(self, val1, obj2, val2):
1119+
"""Check if the two values are equal.
1120+
1121+
Return False if (checked in following order)
1122+
- obj2 is not a TagContainer
1123+
- all of val1's subtags are not in val2
1124+
- val1 and val2 are not the same length (different number of subtags)
1125+
- at least one subtag in val1 is not equal to the corresponding subtag in val2
1126+
"""
1127+
if self._is_same_tagtype(obj2):
1128+
if isinstance(val1, dict) and isinstance(val2, dict):
1129+
if all(subtag in val2 for subtag in val1) and (len(list(val1.keys())) == len(list(val2.keys()))):
1130+
for subtag, subtag_type in self.subtags.items():
1131+
if (subtag in val1) and (
1132+
not subtag_type.is_equal_to(val1[subtag], obj2.subtags[subtag], val2[subtag])
1133+
):
1134+
return False
1135+
return True
1136+
return False
1137+
raise ValueError("Values must be in dictionary format for TagContainer comparison")
1138+
return False
1139+
10161140

10171141
# TODO: Write StructureDefferedTagContainer back in (commented out code block removed
10181142
# on 11/4/24) and make usable for tags like initial-magnetic-moments
@@ -1162,6 +1286,9 @@ def get_token_len(self) -> int:
11621286
"""
11631287
raise NotImplementedError("This method is not supposed to be called directly on MultiformatTag objects!")
11641288

1289+
def _is_equal_to(self, val1, obj2, val2):
1290+
raise NotImplementedError("This method is not supposed to be called directly on MultiformatTag objects!")
1291+
11651292

11661293
@dataclass
11671294
class BoolTagContainer(TagContainer):

0 commit comments

Comments
 (0)