@@ -55,13 +55,15 @@ class ProgramState:
55
55
"""
56
56
57
57
def __init__ (self ) -> None :
58
- self .body : list [ast .Statement ] = []
58
+ self .body : list [ast .Statement | ast . Pragma ] = []
59
59
self .if_clause : Optional [ast .BranchingStatement ] = None
60
60
self .annotations : list [ast .Annotation ] = []
61
61
62
62
def add_if_clause (self , condition : ast .Expression , if_clause : list [ast .Statement ]) -> None :
63
+ if_clause_annotations , self .annotations = self .annotations , []
63
64
self .finalize_if_clause ()
64
65
self .if_clause = ast .BranchingStatement (condition , if_clause , [])
66
+ self .if_clause .annotations = if_clause_annotations
65
67
66
68
def add_else_clause (self , else_clause : list [ast .Statement ]) -> None :
67
69
if self .if_clause is None :
@@ -74,12 +76,15 @@ def finalize_if_clause(self) -> None:
74
76
if_clause , self .if_clause = self .if_clause , None
75
77
self .add_statement (if_clause )
76
78
77
- def add_statement (self , stmt : ast .Statement ) -> None :
78
- assert isinstance (stmt , ast .Statement )
79
- self .finalize_if_clause ()
80
- if self .annotations :
79
+ def add_statement (self , stmt : ast .Statement | ast .Pragma ) -> None :
80
+ # This function accepts Statement and Pragma even though
81
+ # it seems to conflict with the definition of ast.Program.
82
+ # Issue raised in https://github.com/openqasm/openqasm/issues/468
83
+ assert isinstance (stmt , (ast .Statement , ast .Pragma ))
84
+ if isinstance (stmt , ast .Statement ) and self .annotations :
81
85
stmt .annotations = self .annotations + list (stmt .annotations )
82
86
self .annotations = []
87
+ self .finalize_if_clause ()
83
88
self .body .append (stmt )
84
89
85
90
@@ -457,6 +462,13 @@ def measure(
457
462
)
458
463
return self
459
464
465
+ def pragma (self , command : str ) -> Program :
466
+ """Add a pragma instruction."""
467
+ if len (self .stack ) != 1 :
468
+ raise RuntimeError ("Pragmas must be global" )
469
+ self ._add_statement (ast .Pragma (command ))
470
+ return self
471
+
460
472
def _do_assignment (self , var : AstConvertible , op : str , value : AstConvertible ) -> None :
461
473
"""Helper function for variable assignment operations."""
462
474
if isinstance (var , classical_types .DurationVar ):
@@ -537,7 +549,9 @@ def visit_SubroutineDefinition(
537
549
node .body = self .process_statement_list (node .body )
538
550
self .generic_visit (node , context )
539
551
540
- def process_statement_list (self , statements : list [ast .Statement ]) -> list [ast .Statement ]:
552
+ def process_statement_list (
553
+ self , statements : list [ast .Statement | ast .Pragma ]
554
+ ) -> list [ast .Statement | ast .Pragma ]:
541
555
new_list = []
542
556
cal_stmts = []
543
557
for stmt in statements :
0 commit comments