Skip to content

Commit 244e37b

Browse files
authored
Custom Remote Task Interface (#1022)
* implement interface * adding custom , routes for CustomSubmission using new .
1 parent 135aa3a commit 244e37b

File tree

3 files changed

+220
-7
lines changed

3 files changed

+220
-7
lines changed

src/bloqade/analog/ir/routine/quera.py

+137-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import json
22
import time
3+
from typing import TypeVar
34
from collections import OrderedDict, namedtuple
45

56
from beartype import beartype
67
from requests import Response, request
78
from beartype.typing import Any, Dict, List, Tuple, Union, Optional, NamedTuple
89
from pydantic.v1.dataclasses import dataclass
910

11+
from bloqade.analog.task.base import CustomRemoteTaskABC
1012
from bloqade.analog.task.batch import RemoteBatch
1113
from bloqade.analog.task.quera import QuEraTask
1214
from bloqade.analog.builder.typing import LiteralType
@@ -49,7 +51,7 @@ def custom(self) -> "CustomSubmissionRoutine":
4951

5052
@dataclass(frozen=True, config=__pydantic_dataclass_config__)
5153
class CustomSubmissionRoutine(RoutineBase):
52-
def _compile(
54+
def _compile_single(
5355
self,
5456
shots: int,
5557
use_experimental: bool = False,
@@ -150,7 +152,7 @@ def submit(
150152
)
151153

152154
out = []
153-
for metadata, task_ir in self._compile(shots, use_experimental, args):
155+
for metadata, task_ir in self._compile_single(shots, use_experimental, args):
154156
json_request_body = json_body_template.format(
155157
task_ir=task_ir.json(exclude_none=True, exclude_unset=True)
156158
)
@@ -161,6 +163,139 @@ def submit(
161163

162164
return out
163165

166+
RemoteTaskType = TypeVar("RemoteTaskType", bound=CustomRemoteTaskABC)
167+
168+
def _compile_custom_batch(
169+
self,
170+
shots: int,
171+
RemoteTask: type[RemoteTaskType],
172+
use_experimental: bool = False,
173+
args: Tuple[LiteralType, ...] = (),
174+
name: Optional[str] = None,
175+
) -> RemoteBatch:
176+
from bloqade.analog.submission.capabilities import get_capabilities
177+
from bloqade.analog.compiler.passes.hardware import (
178+
assign_circuit,
179+
analyze_channels,
180+
generate_ahs_code,
181+
generate_quera_ir,
182+
validate_waveforms,
183+
canonicalize_circuit,
184+
)
185+
186+
if not issubclass(RemoteTask, CustomRemoteTaskABC):
187+
raise TypeError(f"{RemoteTask} must be a subclass of CustomRemoteTaskABC.")
188+
189+
circuit, params = self.circuit, self.params
190+
capabilities = get_capabilities(use_experimental)
191+
192+
tasks = OrderedDict()
193+
194+
for task_number, batch_params in enumerate(params.batch_assignments(*args)):
195+
assignments = {**batch_params, **params.static_params}
196+
final_circuit, metadata = assign_circuit(circuit, assignments)
197+
198+
level_couplings = analyze_channels(final_circuit)
199+
final_circuit = canonicalize_circuit(final_circuit, level_couplings)
200+
201+
validate_waveforms(level_couplings, final_circuit)
202+
ahs_components = generate_ahs_code(
203+
capabilities, level_couplings, final_circuit
204+
)
205+
206+
task_ir = generate_quera_ir(ahs_components, shots).discretize(capabilities)
207+
208+
tasks[task_number] = RemoteTask.from_compile_results(
209+
task_ir,
210+
metadata,
211+
ahs_components.lattice_data.parallel_decoder,
212+
)
213+
214+
batch = RemoteBatch(source=self.source, tasks=tasks, name=name)
215+
216+
return batch
217+
218+
@beartype
219+
def run_async(
220+
self,
221+
shots: int,
222+
RemoteTask: type[RemoteTaskType],
223+
args: Tuple[LiteralType, ...] = (),
224+
name: Optional[str] = None,
225+
use_experimental: bool = False,
226+
shuffle: bool = False,
227+
**kwargs,
228+
) -> RemoteBatch:
229+
"""
230+
Compile to a RemoteBatch, which contain
231+
QuEra backend specific tasks,
232+
and run_async through QuEra service.
233+
234+
Args:
235+
shots (int): number of shots
236+
args (Tuple): additional arguments
237+
name (str): custom name of the batch
238+
shuffle (bool): shuffle the order of jobs
239+
240+
Return:
241+
RemoteBatch
242+
243+
"""
244+
batch = self._compile_custom_batch(
245+
shots, RemoteTask, use_experimental, args, name
246+
)
247+
batch._submit(shuffle, **kwargs)
248+
return batch
249+
250+
@beartype
251+
def run(
252+
self,
253+
shots: int,
254+
RemoteTask: type[RemoteTaskType],
255+
args: Tuple[LiteralType, ...] = (),
256+
name: Optional[str] = None,
257+
use_experimental: bool = False,
258+
shuffle: bool = False,
259+
**kwargs,
260+
) -> RemoteBatch:
261+
"""Run the custom task and return the result.
262+
263+
Args:
264+
shots (int): number of shots
265+
RemoteTask (type): type of the remote task, must subclass of CustomRemoteTaskABC
266+
args (Tuple): additional arguments for remaining un
267+
name (str): name of the batch object
268+
shuffle (bool): shuffle the order of jobs
269+
"""
270+
if not callable(getattr(RemoteTask, "pull", None)):
271+
raise TypeError(
272+
f"{RemoteTask} must have a `pull` method for executing `run`."
273+
)
274+
275+
batch = self.run_async(
276+
shots, RemoteTask, args, name, use_experimental, shuffle, **kwargs
277+
)
278+
batch.pull()
279+
return batch
280+
281+
@beartype
282+
def __call__(
283+
self,
284+
*args: LiteralType,
285+
RemoteTask: type[RemoteTaskType] | None = None,
286+
shots: int = 1,
287+
name: Optional[str] = None,
288+
use_experimental: bool = False,
289+
shuffle: bool = False,
290+
**kwargs,
291+
) -> RemoteBatch:
292+
if RemoteTask is None:
293+
raise ValueError("RemoteTask must be provided for custom submission.")
294+
295+
return self.run(
296+
shots, RemoteTask, args, name, use_experimental, shuffle, **kwargs
297+
)
298+
164299

165300
@dataclass(frozen=True, config=__pydantic_dataclass_config__)
166301
class QuEraHardwareRoutine(RoutineBase):

src/bloqade/analog/task/base.py

+74
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import abc
12
import datetime
23
from typing import Any
34
from numbers import Number
@@ -12,11 +13,13 @@
1213

1314
from bloqade.analog.serialize import Serializer
1415
from bloqade.analog.visualization import display_report
16+
from bloqade.analog.builder.typing import ParamType
1517
from bloqade.analog.submission.ir.parallel import ParallelDecoder
1618
from bloqade.analog.submission.ir.task_results import (
1719
QuEraTaskResults,
1820
QuEraTaskStatusCode,
1921
)
22+
from bloqade.analog.submission.ir.task_specification import QuEraTaskSpecification
2023

2124

2225
@Serializer.register
@@ -97,6 +100,77 @@ def _result_exists(self) -> bool:
97100
raise NotImplementedError
98101

99102

103+
class CustomRemoteTaskABC(abc.ABC):
104+
105+
@classmethod
106+
@abc.abstractmethod
107+
def from_compile_results(
108+
cls,
109+
task_ir: QuEraTaskSpecification,
110+
metadata: Dict[str, ParamType],
111+
parallel_decoder: Optional[ParallelDecoder],
112+
): ...
113+
114+
@property
115+
@abc.abstractmethod
116+
def geometry(self) -> Geometry: ...
117+
118+
@property
119+
@abc.abstractmethod
120+
def parallel_decoder(self) -> ParallelDecoder: ...
121+
122+
@property
123+
@abc.abstractmethod
124+
def metadata(self) -> Dict[str, ParamType]: ...
125+
126+
@property
127+
def task_result_ir(self) -> QuEraTaskResults | None:
128+
if not hasattr(self, "_task_result_ir"):
129+
self._task_result_ir = QuEraTaskResults(
130+
task_status=QuEraTaskStatusCode.Unaccepted
131+
)
132+
133+
if self._result_exists():
134+
self._task_result_ir = self.result()
135+
136+
return self._task_result_ir
137+
138+
@task_result_ir.setter
139+
def set_task_result(self, task_result):
140+
self._task_result_ir = task_result
141+
142+
@property
143+
@abc.abstractmethod
144+
def task_id(self) -> str:
145+
pass
146+
147+
@property
148+
@abc.abstractmethod
149+
def task_ir(self) -> QuEraTaskSpecification: ...
150+
151+
@abc.abstractmethod
152+
def result(self) -> QuEraTaskResults: ...
153+
154+
@abc.abstractmethod
155+
def _result_exists(self) -> bool: ...
156+
157+
@abc.abstractmethod
158+
def fetch(self) -> None: ...
159+
160+
def status(self) -> QuEraTaskStatusCode:
161+
if self._result_exists():
162+
return self.result().task_status
163+
else:
164+
raise RuntimeError("Result does not exist yet.")
165+
166+
@abc.abstractmethod
167+
def _submit(self): ...
168+
169+
def submit(self, force: bool = False):
170+
if not self._result_exists() or force:
171+
self._submit()
172+
173+
100174
class LocalTask(Task):
101175
"""`Task` to use for local executions for simulation purposes.."""
102176

src/bloqade/analog/task/batch.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from beartype.typing import Any, Dict, List, Union, Optional
1818

1919
from bloqade.analog.serialize import Serializer
20-
from bloqade.analog.task.base import Report
20+
from bloqade.analog.task.base import Report, CustomRemoteTaskABC
2121
from bloqade.analog.task.quera import QuEraTask
2222
from bloqade.analog.task.braket import BraketTask
2323
from bloqade.analog.builder.base import Builder
@@ -333,7 +333,11 @@ def _deserialize(obj: dict) -> BatchErrors:
333333
@Serializer.register
334334
class RemoteBatch(Serializable, Filter):
335335
source: Builder
336-
tasks: Union[OrderedDict[int, QuEraTask], OrderedDict[int, BraketTask]]
336+
tasks: Union[
337+
OrderedDict[int, QuEraTask],
338+
OrderedDict[int, BraketTask],
339+
OrderedDict[int, CustomRemoteTaskABC],
340+
]
337341
name: Optional[str] = None
338342

339343
class SubmissionException(Exception):
@@ -442,10 +446,10 @@ def tasks_metric(self) -> pd.DataFrame:
442446
# offline, non-blocking
443447
tid = []
444448
data = []
445-
for int, task in self.tasks.items():
446-
tid.append(int)
449+
for task_num, task in self.tasks.items():
450+
tid.append(task_num)
447451

448-
dat = [None, None, None]
452+
dat: list[int | str | None] = [None, None, None]
449453
dat[0] = task.task_id
450454
if task.task_result_ir is not None:
451455
dat[1] = task.task_result_ir.task_status.name

0 commit comments

Comments
 (0)