Skip to content

Commit e8ffc3a

Browse files
fix mutating programs (#54)
* fix mutating programs * fix mypy * fix __iadd__
1 parent e9ab80a commit e8ffc3a

File tree

3 files changed

+72
-15
lines changed

3 files changed

+72
-15
lines changed

oqpy/program.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,9 @@ def __iadd__(self, other: Program) -> Program:
123123
self._state.finalize_if_clause()
124124
self.defcals.update(other.defcals)
125125
for name, subroutine_stmt in other.subroutines.items():
126-
self._add_subroutine(name, subroutine_stmt)
126+
self._add_subroutine(
127+
name, subroutine_stmt, needs_declaration=name not in other.declared_subroutines
128+
)
127129
self.externs.update(other.externs)
128130
for var in other.declared_vars.values():
129131
self._mark_var_declared(var)
@@ -276,26 +278,31 @@ def to_ast(
276278
if the variables have openpulse types, automatically wrap the
277279
declarations in cal blocks.
278280
"""
279-
if not ignore_needs_declaration and self.undeclared_vars:
280-
self.autodeclare(encal=encal_declarations)
281+
mutating_prog = Program(self.version, self.simplify_constants)
282+
mutating_prog += self
281283

282-
assert len(self.stack) == 1
283-
self._state.finalize_if_clause()
284-
if self._state.annotations:
285-
warnings.warn(f"Annotation(s) {self._state.annotations} not applied to any statement")
284+
if not ignore_needs_declaration and mutating_prog.undeclared_vars:
285+
mutating_prog.autodeclare(encal=encal_declarations)
286+
287+
assert len(mutating_prog.stack) == 1
288+
mutating_prog._state.finalize_if_clause()
289+
if mutating_prog._state.annotations:
290+
warnings.warn(
291+
f"Annotation(s) {mutating_prog._state.annotations} not applied to any statement"
292+
)
286293
statements = []
287294
if include_externs:
288-
statements += self._make_externs_statements(encal_declarations)
295+
statements += mutating_prog._make_externs_statements(encal_declarations)
289296
statements += [
290-
self.subroutines[subroutine_name]
291-
for subroutine_name in self.subroutines
292-
if subroutine_name not in self.declared_subroutines
293-
] + self._state.body
297+
mutating_prog.subroutines[subroutine_name]
298+
for subroutine_name in mutating_prog.subroutines
299+
if subroutine_name not in mutating_prog.declared_subroutines
300+
] + mutating_prog._state.body
294301
if encal:
295302
statements = [ast.CalibrationStatement(statements)]
296303
if encal_declarations:
297304
statements = [ast.CalibrationGrammarDeclaration("openpulse")] + statements
298-
prog = ast.Program(statements=statements, version=self.version)
305+
prog = ast.Program(statements=statements, version=mutating_prog.version)
299306
if encal_declarations:
300307
MergeCalStatementsPass().visit(prog)
301308
return prog

oqpy/quantum_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def defcal(
111111
elif len(variables) == 1:
112112
yield variables[0]
113113
else:
114-
yield
114+
yield None
115115
state = program._pop()
116116

117117
stmt = ast.CalibrationDefinition(

tests/test_directives.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1149,7 +1149,6 @@ def increment_variable_return(int[32] i) -> int[32] {
11491149
k = increment_variable_return(j);
11501150
"""
11511151
).strip()
1152-
print(prog.to_qasm())
11531152
assert prog.to_qasm() == expected
11541153
_check_respects_type_hints(prog)
11551154

@@ -1530,9 +1529,60 @@ def test_needs_declaration():
15301529
"""
15311530
).strip()
15321531

1532+
declared_vars = {}
1533+
undeclared_vars= ["i1", "i2", "f1", "f2", "q1", "q2"]
1534+
statement_ast = [
1535+
ast.ClassicalAssignment(
1536+
lvalue=ast.Identifier(name="i1"),
1537+
op=ast.AssignmentOperator["+="],
1538+
rvalue=ast.IntegerLiteral(value=1),
1539+
),
1540+
ast.ClassicalAssignment(
1541+
lvalue=ast.Identifier(name="i2"),
1542+
op=ast.AssignmentOperator["+="],
1543+
rvalue=ast.IntegerLiteral(value=1),
1544+
),
1545+
ast.ExpressionStatement(
1546+
expression=ast.FunctionCall(
1547+
name=ast.Identifier(name="set_phase"),
1548+
arguments=[ast.Identifier(name="f1"), ast.IntegerLiteral(value=0)],
1549+
)
1550+
),
1551+
ast.ExpressionStatement(
1552+
expression=ast.FunctionCall(
1553+
name=ast.Identifier(name="set_phase"),
1554+
arguments=[ast.Identifier(name="f2"), ast.IntegerLiteral(value=0)],
1555+
)
1556+
),
1557+
ast.QuantumGate(
1558+
modifiers=[],
1559+
name=ast.Identifier(name="X"),
1560+
arguments=[],
1561+
qubits=[ast.Identifier(name="q1")],
1562+
duration=None,
1563+
),
1564+
ast.QuantumGate(
1565+
modifiers=[],
1566+
name=ast.Identifier(name="X"),
1567+
arguments=[],
1568+
qubits=[ast.Identifier(name="q2")],
1569+
duration=None,
1570+
),
1571+
]
1572+
1573+
# testing variables before calling to_ast
1574+
assert prog.declared_vars == declared_vars
1575+
assert list(prog.undeclared_vars.keys()) == undeclared_vars
1576+
assert prog._state.body == statement_ast
1577+
15331578
assert prog.to_qasm() == expected
15341579
_check_respects_type_hints(prog)
15351580

1581+
# testing variables after calling to_ast, checking mutations
1582+
assert prog.declared_vars == declared_vars
1583+
assert list(prog.undeclared_vars.keys()) == undeclared_vars
1584+
assert prog._state.body == statement_ast
1585+
15361586

15371587
def test_discrete_waveform():
15381588
port = PortVar("port")

0 commit comments

Comments
 (0)