Skip to content

Update PyQrack.run to match RFC #225

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 38 additions & 37 deletions src/bloqade/pyqrack/target.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, TypeVar, ParamSpec
from typing import Any, TypeVar, Iterable
from dataclasses import field, dataclass

from kirin import ir
Expand All @@ -13,9 +13,6 @@
)
from bloqade.analysis.address import AnyAddress, AddressAnalysis

Params = ParamSpec("Params")
RetType = TypeVar("RetType")


@dataclass
class PyQrack:
Expand All @@ -36,7 +33,9 @@
{**_default_pyqrack_args(), **self.pyqrack_options}
)

def _get_interp(self, mt: ir.Method[Params, RetType]):
RetType = TypeVar("RetType")

def _get_interp(self, mt: ir.Method[..., RetType]):
if self.dynamic_qubits:

options = self.pyqrack_options.copy()
Expand Down Expand Up @@ -64,49 +63,51 @@

def run(
self,
mt: ir.Method[Params, RetType],
*args: Params.args,
**kwargs: Params.kwargs,
) -> RetType:
mt: ir.Method[..., RetType],
*,
shots: int = 1,
args: tuple[Any, ...] = (),
kwargs: dict[str, Any] = {},
return_iterator: bool = False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the use case of return_iterator?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This goes back to the discussion of memory conservation, if you return a register that will mean the full state will persist in memory if you store the simulator object in a list whereas if you return an iterator you have the option of only using the simulator object for that iteration and the gc can clean it up. Also, we get this for free since we need to loop for each shot anyway its just a slight difference in implementation.

) -> RetType | list[RetType] | Iterable[RetType]:
"""Run the given kernel method on the PyQrack simulator.

Args
mt (Method):
The kernel method to run.
shots (int):
The number of shots to run the simulation for.
Defaults to 1.
args (tuple[Any, ...]):
Positional arguments to pass to the kernel method.
Defaults to ().
kwargs (dict[str, Any]):
Keyword arguments to pass to the kernel method.
Defaults to {}.
return_iterator (bool):
Whether to return an iterator that yields results for each shot.
Defaults to False. if False, a list of results is returned.

Returns
The result of the kernel method, if any.

"""
fold = Fold(mt.dialects)
fold(mt)
return self._get_interp(mt).run(mt, args, kwargs)

def multi_run(
self,
mt: ir.Method[Params, RetType],
_shots: int,
*args: Params.args,
**kwargs: Params.kwargs,
) -> List[RetType]:
"""Run the given kernel method on the PyQrack `_shots` times, caching analysis results.

Args
mt (Method):
The kernel method to run.
_shots (int):
The number of times to run the kernel method.

Returns
List of results of the kernel method, one for each shot.
RetType | list[RetType] | Iterable[RetType]:
The result of the simulation. If `return_iterator` is True,
an iterator that yields results for each shot is returned.
Otherwise, a list of results is returned if `shots > 1`, or
a single result is returned if `shots == 1`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if returning different data type here is a good idea. should we always return list or Iterator[RetType]? Maybe Iterable[RetType]?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a typo It should be Iterable

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One option is we just always return the Iterable and then just let people capture those values into a list by default:

result = list(device.run(...))


"""
fold = Fold(mt.dialects)
fold(mt)

interpreter = self._get_interp(mt)
batched_results = []
for _ in range(_shots):
batched_results.append(interpreter.run(mt, args, kwargs))

return batched_results
def run_shots():
for _ in range(shots):
yield interpreter.run(mt, args, kwargs)

if shots == 1:
return interpreter.run(mt, args, kwargs)
elif return_iterator:
return run_shots()

Check warning on line 111 in src/bloqade/pyqrack/target.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/pyqrack/target.py#L111

Added line #L111 was not covered by tests
else:
return list(run_shots())
5 changes: 2 additions & 3 deletions test/pyqrack/runtime/test_dyn_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ def ghz(n: int):
for i in range(1, n):
qasm2.cx(q[0], q[i])

for i in range(n):
qasm2.measure(q[i], c[i])
qasm2.measure(q, c)

return c

Expand All @@ -27,6 +26,6 @@ def ghz(n: int):

N = 20

result = target.multi_run(ghz, 100, N)
result = target.run(ghz, shots=100, args=(N,))
result = Counter("".join(str(int(bit)) for bit in bits) for bits in result)
assert result.keys() == {"0" * N, "1" * N}