Skip to content

Commit 2f58c19

Browse files
committed
Add OQPyExpression._expr_matches to give subclasses ability to control expr_matches behavior
1 parent 6de16a9 commit 2f58c19

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

oqpy/base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,12 @@ def __bool__(self) -> bool:
180180
"the equality of expressions using == instead of expr_matches."
181181
)
182182

183+
def _expr_matches(self, other) -> bool:
184+
"""Called by expr_matches to compare expression instances."""
185+
if not isinstance(other, type(self)):
186+
return False
187+
return expr_matches(self.__dict__, other.__dict__)
188+
183189

184190
def _get_type(val: AstConvertible) -> Optional[ast.ClassicalType]:
185191
if isinstance(val, OQPyExpression):
@@ -332,7 +338,7 @@ def expr_matches(a: Any, b: Any) -> bool:
332338
return all(expr_matches(va, b[k]) for k, va in a.items())
333339
if isinstance(a, OQPyExpression):
334340
# Bypass `__eq__` which is overloaded on OQPyExpressions
335-
return expr_matches(a.__dict__, b.__dict__)
341+
return a._expr_matches(b)
336342
else:
337343
return a == b
338344

tests/test_directives.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2656,3 +2656,21 @@ class MyFloatVar(oqpy.FloatVar):
26562656
x1 = MyFloatVar(3, name="x")
26572657
x2 = MyFloatVar(3, name="x")
26582658
assert oqpy.base.expr_matches(x1, x2)
2659+
2660+
class MyFloatVarWithIgnoredData(oqpy.FloatVar):
2661+
ignored: int
2662+
def _expr_matches(self, other):
2663+
if not isinstance(other, type(self)):
2664+
return False
2665+
d1 = self.__dict__.copy()
2666+
d2 = other.__dict__.copy()
2667+
d1.pop("ignored")
2668+
d2.pop("ignored")
2669+
return oqpy.base.expr_matches(d1, d2)
2670+
2671+
2672+
x1 = MyFloatVarWithIgnoredData(3, name="x")
2673+
x1.ignored = 1
2674+
x2 = MyFloatVarWithIgnoredData(3, name="x")
2675+
x2.ignored = 2
2676+
assert oqpy.base.expr_matches(x1, x2)

0 commit comments

Comments
 (0)