diff --git a/oqpy/base.py b/oqpy/base.py index 43dece9..4950b92 100644 --- a/oqpy/base.py +++ b/oqpy/base.py @@ -107,6 +107,36 @@ def __pow__(self, other: AstConvertible) -> OQPyBinaryExpression: def __rpow__(self, other: AstConvertible) -> OQPyBinaryExpression: return self._to_binary("**", other, self) + def __lshift__(self, other: AstConvertible) -> OQPyBinaryExpression: + return self._to_binary("<<", self, other) + + def __rlshift__(self, other: AstConvertible) -> OQPyBinaryExpression: + return self._to_binary("<<", other, self) + + def __rshift__(self, other: AstConvertible) -> OQPyBinaryExpression: + return self._to_binary(">>", self, other) + + def __rrshift__(self, other: AstConvertible) -> OQPyBinaryExpression: + return self._to_binary(">>", other, self) + + def __and__(self, other: AstConvertible) -> OQPyBinaryExpression: + return self._to_binary("&", self, other) + + def __rand__(self, other: AstConvertible) -> OQPyBinaryExpression: + return self._to_binary("&", other, self) + + def __or__(self, other: AstConvertible) -> OQPyBinaryExpression: + return self._to_binary("|", self, other) + + def __ror__(self, other: AstConvertible) -> OQPyBinaryExpression: + return self._to_binary("|", other, self) + + def __xor__(self, other: AstConvertible) -> OQPyBinaryExpression: + return self._to_binary("^", self, other) + + def __rxor__(self, other: AstConvertible) -> OQPyBinaryExpression: + return self._to_binary("^", other, self) + def __eq__(self, other: AstConvertible) -> OQPyBinaryExpression: # type: ignore[override] return self._to_binary("==", self, other) @@ -125,6 +155,9 @@ def __ge__(self, other: AstConvertible) -> OQPyBinaryExpression: def __le__(self, other: AstConvertible) -> OQPyBinaryExpression: return self._to_binary("<=", self, other) + def __invert__(self) -> OQPyUnaryExpression: + return self._to_unary("~", self) + def __bool__(self) -> bool: raise RuntimeError( "OQPy expressions cannot be converted to bool. This can occur if you try to check " @@ -132,6 +165,16 @@ def __bool__(self) -> bool: ) +def logical_and(first: AstConvertible, second: AstConvertible) -> OQPyBinaryExpression: + """Logical AND.""" + return OQPyBinaryExpression(ast.BinaryOperator["&&"], first, second) + + +def logical_or(first: AstConvertible, second: AstConvertible) -> OQPyBinaryExpression: + """Logical OR.""" + return OQPyBinaryExpression(ast.BinaryOperator["||"], first, second) + + def expr_matches(a: Any, b: Any) -> bool: """Check equality of the given objects. diff --git a/tests/test_directives.py b/tests/test_directives.py index 92ad181..f16a064 100644 --- a/tests/test_directives.py +++ b/tests/test_directives.py @@ -24,7 +24,7 @@ import oqpy from oqpy import * -from oqpy.base import expr_matches +from oqpy.base import expr_matches, logical_and, logical_or from oqpy.quantum_types import PhysicalQubits from oqpy.timing import OQDurationLiteral @@ -269,22 +269,68 @@ def test_binary_expressions(): prog = Program() i = IntVar(5, "i") j = IntVar(2, "j") + k = IntVar(0, "k") + b1 = BoolVar(False, "b1") + b2 = BoolVar(True, "b2") + b3 = BoolVar(False, "b3") prog.set(i, 2 * (i + j)) prog.set(j, 2 % (2 - i) % 2) prog.set(j, 1 + oqpy.pi) prog.set(j, 1 / oqpy.pi**2 / 2 + 2**oqpy.pi) prog.set(j, -oqpy.pi * oqpy.pi - i**j) + prog.set(k, i & 51966) + prog.set(k, 51966 & i) + prog.set(k, i & j) + prog.set(k, i | 51966) + prog.set(k, 51966 | i) + prog.set(k, i | j) + prog.set(k, i ^ 51966) + prog.set(k, 51966 & i) + prog.set(k, i ^ j) + prog.set(k, i >> 1) + prog.set(k, 1 >> i) + prog.set(k, i >> j) + prog.set(k, i << 1) + prog.set(k, 1 << j) + prog.set(k, i << j) + prog.set(k, ~k) + prog.set(b1, logical_or(b2, b3)) + prog.set(b1, logical_and(b2, True)) + prog.set(b1, logical_or(False, b3)) expected = textwrap.dedent( """ OPENQASM 3.0; int[32] i = 5; int[32] j = 2; + int[32] k = 0; + bool b1 = false; + bool b2 = true; + bool b3 = false; i = 2 * (i + j); j = 2 % (2 - i) % 2; j = 1 + pi; j = 1 / pi ** 2 / 2 + 2 ** pi; j = -pi * pi - i ** j; + k = i & 51966; + k = 51966 & i; + k = i & j; + k = i | 51966; + k = 51966 | i; + k = i | j; + k = i ^ 51966; + k = 51966 & i; + k = i ^ j; + k = i >> 1; + k = 1 >> i; + k = i >> j; + k = i << 1; + k = 1 << j; + k = i << j; + k = ~k; + b1 = b2 || b3; + b1 = b2 && true; + b1 = false || b3; """ ).strip()