Skip to content

fix mutating programs #54

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions oqpy/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ def __iadd__(self, other: Program) -> Program:
self._state.finalize_if_clause()
self.defcals.update(other.defcals)
for name, subroutine_stmt in other.subroutines.items():
self._add_subroutine(name, subroutine_stmt)
self._add_subroutine(
name, subroutine_stmt, needs_declaration=name not in other.declared_subroutines
)
self.externs.update(other.externs)
for var in other.declared_vars.values():
self._mark_var_declared(var)
Expand Down Expand Up @@ -276,26 +278,31 @@ def to_ast(
if the variables have openpulse types, automatically wrap the
declarations in cal blocks.
"""
if not ignore_needs_declaration and self.undeclared_vars:
self.autodeclare(encal=encal_declarations)
mutating_prog = Program(self.version, self.simplify_constants)
mutating_prog += self

assert len(self.stack) == 1
self._state.finalize_if_clause()
if self._state.annotations:
warnings.warn(f"Annotation(s) {self._state.annotations} not applied to any statement")
if not ignore_needs_declaration and mutating_prog.undeclared_vars:
mutating_prog.autodeclare(encal=encal_declarations)

assert len(mutating_prog.stack) == 1
mutating_prog._state.finalize_if_clause()
if mutating_prog._state.annotations:
warnings.warn(
f"Annotation(s) {mutating_prog._state.annotations} not applied to any statement"
)
statements = []
if include_externs:
statements += self._make_externs_statements(encal_declarations)
statements += mutating_prog._make_externs_statements(encal_declarations)
statements += [
self.subroutines[subroutine_name]
for subroutine_name in self.subroutines
if subroutine_name not in self.declared_subroutines
] + self._state.body
mutating_prog.subroutines[subroutine_name]
for subroutine_name in mutating_prog.subroutines
if subroutine_name not in mutating_prog.declared_subroutines
] + mutating_prog._state.body
if encal:
statements = [ast.CalibrationStatement(statements)]
if encal_declarations:
statements = [ast.CalibrationGrammarDeclaration("openpulse")] + statements
prog = ast.Program(statements=statements, version=self.version)
prog = ast.Program(statements=statements, version=mutating_prog.version)
if encal_declarations:
MergeCalStatementsPass().visit(prog)
return prog
Expand Down
2 changes: 1 addition & 1 deletion oqpy/quantum_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def defcal(
elif len(variables) == 1:
yield variables[0]
else:
yield
yield None
state = program._pop()

stmt = ast.CalibrationDefinition(
Expand Down
52 changes: 51 additions & 1 deletion tests/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,6 @@ def increment_variable_return(int[32] i) -> int[32] {
k = increment_variable_return(j);
"""
).strip()
print(prog.to_qasm())
assert prog.to_qasm() == expected
_check_respects_type_hints(prog)

Expand Down Expand Up @@ -1530,9 +1529,60 @@ def test_needs_declaration():
"""
).strip()

declared_vars = {}
undeclared_vars= ["i1", "i2", "f1", "f2", "q1", "q2"]
statement_ast = [
ast.ClassicalAssignment(
lvalue=ast.Identifier(name="i1"),
op=ast.AssignmentOperator["+="],
rvalue=ast.IntegerLiteral(value=1),
),
ast.ClassicalAssignment(
lvalue=ast.Identifier(name="i2"),
op=ast.AssignmentOperator["+="],
rvalue=ast.IntegerLiteral(value=1),
),
ast.ExpressionStatement(
expression=ast.FunctionCall(
name=ast.Identifier(name="set_phase"),
arguments=[ast.Identifier(name="f1"), ast.IntegerLiteral(value=0)],
)
),
ast.ExpressionStatement(
expression=ast.FunctionCall(
name=ast.Identifier(name="set_phase"),
arguments=[ast.Identifier(name="f2"), ast.IntegerLiteral(value=0)],
)
),
ast.QuantumGate(
modifiers=[],
name=ast.Identifier(name="X"),
arguments=[],
qubits=[ast.Identifier(name="q1")],
duration=None,
),
ast.QuantumGate(
modifiers=[],
name=ast.Identifier(name="X"),
arguments=[],
qubits=[ast.Identifier(name="q2")],
duration=None,
),
]

# testing variables before calling to_ast
assert prog.declared_vars == declared_vars
assert list(prog.undeclared_vars.keys()) == undeclared_vars
assert prog._state.body == statement_ast

assert prog.to_qasm() == expected
_check_respects_type_hints(prog)

# testing variables after calling to_ast, checking mutations
assert prog.declared_vars == declared_vars
assert list(prog.undeclared_vars.keys()) == undeclared_vars
assert prog._state.body == statement_ast


def test_discrete_waveform():
port = PortVar("port")
Expand Down