@@ -70,6 +70,56 @@ def validate_value_type(self, tag: str, value: Any, try_auto_type_fix: bool = Fa
70
70
tuple[str, bool, Any]: The tag, whether the value is of the correct type, and the possibly fixed value.
71
71
"""
72
72
73
+ def is_equal_to (self , val1 : Any | list [Any ], obj2 : AbstractTag , val2 : Any | list [Any ]) -> bool :
74
+ """Check if the two values are equal.
75
+
76
+ Args:
77
+ val1 (Any): The value of this tag object.
78
+ obj2 (AbstractTag): The other tag object.
79
+ val2 (Any): The value of the other tag object.
80
+
81
+ Returns:
82
+ bool: True if the two tag object/value pairs are equal, False otherwise.
83
+ """
84
+ if self .can_repeat :
85
+ if not obj2 .can_repeat :
86
+ return False
87
+ val1 = val1 if isinstance (val1 , list ) else [val1 ]
88
+ val2 = val2 if isinstance (val2 , list ) else [val2 ]
89
+ if len (val1 ) != len (val2 ):
90
+ return False
91
+ return all (True in [self ._is_equal_to (v1 , obj2 , v2 ) for v2 in val2 ] for v1 in val1 )
92
+ return self ._is_equal_to (val1 , obj2 , val2 )
93
+
94
+ @abstractmethod
95
+ def _is_equal_to (self , val1 : Any , obj2 : AbstractTag , val2 : Any ) -> bool :
96
+ """Check if the two values are equal.
97
+
98
+ Used to check if the two values are equal. Assumes val1 and val2 are single elements.
99
+
100
+ Args:
101
+ val1 (Any): The value of this tag object.
102
+ obj2 (AbstractTag): The other tag object.
103
+ val2 (Any): The value of the other tag object.
104
+
105
+ Returns:
106
+ bool: True if the two tag object/value pairs are equal, False otherwise.
107
+ """
108
+
109
+ def _is_same_tagtype (
110
+ self ,
111
+ obj2 : AbstractTag ,
112
+ ) -> bool :
113
+ """Check if the two values are equal.
114
+
115
+ Args:
116
+ obj2 (AbstractTag): The other tag object.
117
+
118
+ Returns:
119
+ bool: True if the two tag object/value pairs are equal, False otherwise.
120
+ """
121
+ return isinstance (self , type (obj2 ))
122
+
73
123
def _validate_value_type (
74
124
self , type_check : type , tag : str , value : Any , try_auto_type_fix : bool = False
75
125
) -> tuple [str , bool , Any ]:
@@ -258,6 +308,19 @@ def validate_value_type(self, tag: str, value: Any, try_auto_type_fix: bool = Fa
258
308
"""
259
309
return self ._validate_value_type (bool , tag , value , try_auto_type_fix = try_auto_type_fix )
260
310
311
+ def _is_equal_to (self , val1 : Any , obj2 : AbstractTag , val2 : Any ) -> bool :
312
+ """Check if the two values are equal.
313
+
314
+ Args:
315
+ val1 (Any): The value of this tag object.
316
+ obj2 (AbstractTag): The other tag object.
317
+ val2 (Any): The value of the other tag object.
318
+
319
+ Returns:
320
+ bool: True if the two tag object/value pairs are equal, False otherwise.
321
+ """
322
+ return self ._is_same_tagtype (obj2 ) and val1 == val2
323
+
261
324
def raise_value_error (self , tag : str , value : str ) -> None :
262
325
"""Raise a ValueError for the value string.
263
326
@@ -335,6 +398,23 @@ def validate_value_type(self, tag: str, value: Any, try_auto_type_fix: bool = Fa
335
398
"""
336
399
return self ._validate_value_type (str , tag , value , try_auto_type_fix = try_auto_type_fix )
337
400
401
+ def _is_equal_to (self , val1 : Any , obj2 : AbstractTag , val2 : Any ) -> bool :
402
+ """Check if the two values are equal.
403
+
404
+ Args:
405
+ val1 (Any): The value of this tag object.
406
+ obj2 (AbstractTag): The other tag object.
407
+ val2 (Any): The value of the other tag object.
408
+
409
+ Returns:
410
+ bool: True if the two tag object/value pairs are equal, False otherwise.
411
+ """
412
+ if self ._is_same_tagtype (obj2 ):
413
+ if not all (isinstance (x , str ) for x in (val1 , val2 )):
414
+ raise ValueError ("Both values must be strings for StrTag comparison" )
415
+ return val1 .strip () == val2 .strip ()
416
+ return False
417
+
338
418
def read (self , tag : str , value : str ) -> str :
339
419
"""Read the value string for this tag.
340
420
@@ -379,6 +459,8 @@ class AbstractNumericTag(AbstractTag):
379
459
ub : float | None = None # upper bound
380
460
lb_incl : bool = True # lower bound inclusive
381
461
ub_incl : bool = True # upper bound inclusive
462
+ eq_atol : float = 1.0e-8 # absolute tolerance for equality check
463
+ eq_rtol : float = 1.0e-5 # relative tolerance for equality check
382
464
383
465
def val_is_within_bounds (self , value : float ) -> bool :
384
466
"""Check if the value is within the bounds.
@@ -425,6 +507,22 @@ def validate_value_bounds(
425
507
return False , self .get_invalid_value_error_str (tag , value )
426
508
return True , ""
427
509
510
+ def _is_equal_to (self , val1 , obj2 , val2 ):
511
+ """Check if the two values are equal.
512
+
513
+ Used to check if the two values are equal. Doesn't need to be redefined for IntTag and FloatTag.
514
+
515
+ Args:
516
+ val1 (Any): The value of this tag object.
517
+ obj2 (AbstractTag): The other tag object.
518
+ val2 (Any): The value of the other tag object.
519
+ rtol (float, optional): Relative tolerance. Defaults to 1.e-5.
520
+ atol (float, optional): Absolute tolerance. Defaults to 1.e-8.
521
+ Returns:
522
+ bool: True if the two tag object/value pairs are equal, False otherwise.
523
+ """
524
+ return self ._is_same_tagtype (obj2 ) and np .isclose (val1 , val2 , rtol = self .eq_rtol , atol = self .eq_atol )
525
+
428
526
429
527
@dataclass
430
528
class IntTag (AbstractNumericTag ):
@@ -620,6 +718,10 @@ def get_token_len(self) -> int:
620
718
"""
621
719
return self ._get_token_len ()
622
720
721
+ def _is_equal_to (self , val1 , obj2 , val2 ):
722
+ return True # TODO: We still need to actually implement initmagmom as a multi-format tag
723
+ # raise NotImplementedError("equality not yet implemented for InitMagMomTag")
724
+
623
725
624
726
@dataclass
625
727
class TagContainer (AbstractTag ):
@@ -1013,6 +1115,28 @@ def get_dict_representation(self, tag: str, value: list) -> dict | list[dict]:
1013
1115
list_value = self ._make_str_for_dict (tag , value )
1014
1116
return self .read (tag , list_value )
1015
1117
1118
+ def _is_equal_to (self , val1 , obj2 , val2 ):
1119
+ """Check if the two values are equal.
1120
+
1121
+ Return False if (checked in following order)
1122
+ - obj2 is not a TagContainer
1123
+ - all of val1's subtags are not in val2
1124
+ - val1 and val2 are not the same length (different number of subtags)
1125
+ - at least one subtag in val1 is not equal to the corresponding subtag in val2
1126
+ """
1127
+ if self ._is_same_tagtype (obj2 ):
1128
+ if isinstance (val1 , dict ) and isinstance (val2 , dict ):
1129
+ if all (subtag in val2 for subtag in val1 ) and (len (list (val1 .keys ())) == len (list (val2 .keys ()))):
1130
+ for subtag , subtag_type in self .subtags .items ():
1131
+ if (subtag in val1 ) and (
1132
+ not subtag_type .is_equal_to (val1 [subtag ], obj2 .subtags [subtag ], val2 [subtag ])
1133
+ ):
1134
+ return False
1135
+ return True
1136
+ return False
1137
+ raise ValueError ("Values must be in dictionary format for TagContainer comparison" )
1138
+ return False
1139
+
1016
1140
1017
1141
# TODO: Write StructureDefferedTagContainer back in (commented out code block removed
1018
1142
# on 11/4/24) and make usable for tags like initial-magnetic-moments
@@ -1162,6 +1286,9 @@ def get_token_len(self) -> int:
1162
1286
"""
1163
1287
raise NotImplementedError ("This method is not supposed to be called directly on MultiformatTag objects!" )
1164
1288
1289
+ def _is_equal_to (self , val1 , obj2 , val2 ):
1290
+ raise NotImplementedError ("This method is not supposed to be called directly on MultiformatTag objects!" )
1291
+
1165
1292
1166
1293
@dataclass
1167
1294
class BoolTagContainer (TagContainer ):
0 commit comments