Skip to content

Commit 6efb6f5

Browse files
authored
Allow expr_matches to better handle presence of extra data (#93)
* Allow expr_matches to better handle presence of extra data * Switch to checking isinstance instead of package name * Add OQPyExpression._expr_matches to give subclasses ability to control expr_matches behavior * mypy * fix coverage
1 parent a1c746c commit 6efb6f5

File tree

2 files changed

+61
-2
lines changed

2 files changed

+61
-2
lines changed

oqpy/base.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,12 @@ def __bool__(self) -> bool:
180180
"the equality of expressions using == instead of expr_matches."
181181
)
182182

183+
def _expr_matches(self, other: Any) -> bool:
184+
"""Called by expr_matches to compare expression instances."""
185+
if not isinstance(other, type(self)):
186+
return False
187+
return expr_matches(self.__dict__, other.__dict__)
188+
183189

184190
def _get_type(val: AstConvertible) -> Optional[ast.ClassicalType]:
185191
if isinstance(val, OQPyExpression):
@@ -318,6 +324,8 @@ def expr_matches(a: Any, b: Any) -> bool:
318324
319325
This bypasses calling ``__eq__`` on expr objects.
320326
"""
327+
if a is b:
328+
return True
321329
if type(a) is not type(b):
322330
return False
323331
if isinstance(a, (list, np.ndarray)):
@@ -328,8 +336,9 @@ def expr_matches(a: Any, b: Any) -> bool:
328336
if a.keys() != b.keys():
329337
return False
330338
return all(expr_matches(va, b[k]) for k, va in a.items())
331-
if hasattr(a, "__dict__"):
332-
return expr_matches(a.__dict__, b.__dict__)
339+
if isinstance(a, OQPyExpression):
340+
# Bypass `__eq__` which is overloaded on OQPyExpressions
341+
return a._expr_matches(b)
333342
else:
334343
return a == b
335344

tests/test_directives.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2624,3 +2624,53 @@ def test_box_with_negative_duration():
26242624
with pytest.raises(ValueError, match="Expected a non-negative duration, but got -4e-09"):
26252625
with Box(prog, -4e-9):
26262626
pass
2627+
2628+
2629+
def test_expr_matches_handles_outside_data():
2630+
x1 = oqpy.FloatVar(3, name="x")
2631+
x2 = oqpy.FloatVar(3, name="x")
2632+
class MyEntity:
2633+
def __init__(self):
2634+
self.self_ref = self
2635+
2636+
def __eq__(self, other):
2637+
return True
2638+
2639+
x1._entity = MyEntity()
2640+
x2._entity = MyEntity()
2641+
assert oqpy.base.expr_matches(x1, x2)
2642+
2643+
class MyEntityNoEq:
2644+
def __init__(self):
2645+
self.self_ref = self
2646+
def __eq__(self, other):
2647+
raise RuntimeError("Eq not allowed")
2648+
2649+
x1._entity = MyEntityNoEq()
2650+
x2._entity = x1._entity
2651+
oqpy.base.expr_matches(x1, x2)
2652+
2653+
class MyFloatVar(oqpy.FloatVar):
2654+
...
2655+
2656+
x1 = MyFloatVar(3, name="x")
2657+
x2 = MyFloatVar(3, name="x")
2658+
assert not x1._expr_matches(oqpy.FloatVar(3, name="x"))
2659+
assert oqpy.base.expr_matches(x1, x2)
2660+
2661+
class MyFloatVarWithIgnoredData(oqpy.FloatVar):
2662+
ignored: int
2663+
def _expr_matches(self, other):
2664+
if not isinstance(other, type(self)):
2665+
return False
2666+
d1 = self.__dict__.copy()
2667+
d2 = other.__dict__.copy()
2668+
d1.pop("ignored")
2669+
d2.pop("ignored")
2670+
return oqpy.base.expr_matches(d1, d2)
2671+
2672+
x1 = MyFloatVarWithIgnoredData(3, name="x")
2673+
x1.ignored = 1
2674+
x2 = MyFloatVarWithIgnoredData(3, name="x")
2675+
x2.ignored = 2
2676+
assert oqpy.base.expr_matches(x1, x2)

0 commit comments

Comments
 (0)