Skip to content

Commit 2df694b

Browse files
committed
Allow expr_matches to better handle presence of extra data
1 parent a1c746c commit 2df694b

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

oqpy/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,8 @@ def expr_matches(a: Any, b: Any) -> bool:
318318
319319
This bypasses calling ``__eq__`` on expr objects.
320320
"""
321+
if a is b:
322+
return True
321323
if type(a) is not type(b):
322324
return False
323325
if isinstance(a, (list, np.ndarray)):
@@ -328,7 +330,7 @@ def expr_matches(a: Any, b: Any) -> bool:
328330
if a.keys() != b.keys():
329331
return False
330332
return all(expr_matches(va, b[k]) for k, va in a.items())
331-
if hasattr(a, "__dict__"):
333+
if hasattr(a, "__dict__") and type(a).__module__.startswith("oqpy"):
332334
return expr_matches(a.__dict__, b.__dict__)
333335
else:
334336
return a == b

tests/test_directives.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2624,3 +2624,28 @@ def test_box_with_negative_duration():
26242624
with pytest.raises(ValueError, match="Expected a non-negative duration, but got -4e-09"):
26252625
with Box(prog, -4e-9):
26262626
pass
2627+
2628+
2629+
def test_expr_matches_handles_outside_data():
2630+
x1 = oqpy.FloatVar(3, name="x")
2631+
x2 = oqpy.FloatVar(3, name="x")
2632+
class MyEntity:
2633+
def __init__(self):
2634+
self.self_ref = self
2635+
2636+
def __eq__(self, other):
2637+
return True
2638+
2639+
x1._entity = MyEntity()
2640+
x2._entity = MyEntity()
2641+
assert oqpy.base.expr_matches(x1, x2)
2642+
2643+
class MyEntityNoEq:
2644+
def __init__(self):
2645+
self.self_ref = self
2646+
def __eq__(self, other):
2647+
raise RuntimeError("Eq not allowed")
2648+
2649+
x1._entity = MyEntityNoEq()
2650+
x2._entity = x1._entity
2651+
oqpy.base.expr_matches(x1, x2)

0 commit comments

Comments
 (0)