-
Notifications
You must be signed in to change notification settings - Fork 18
Exclusive access task #1025
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Exclusive access task #1025
Changes from 10 commits
27d0653
b9d5db8
b2eb517
9b9e20b
590c517
b56194b
7dcb7af
aaa575d
a460d23
74c53fa
2927372
1870b89
2f98ecf
06b090c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -0,0 +1,347 @@ | ||||
import os | ||||
import abc | ||||
import uuid | ||||
import json | ||||
import re | ||||
|
||||
from beartype.typing import Dict | ||||
|
||||
from bloqade.analog.task.base import Geometry, CustomRemoteTaskABC | ||||
from bloqade.analog.builder.typing import ParamType | ||||
from bloqade.analog.submission.ir.parallel import ParallelDecoder | ||||
from bloqade.analog.submission.ir.task_results import ( | ||||
QuEraTaskResults, | ||||
QuEraTaskStatusCode, | ||||
) | ||||
from bloqade.analog.submission.ir.task_specification import QuEraTaskSpecification | ||||
from requests import Response, request, get | ||||
from bloqade.analog.serialize import Serializer | ||||
from bloqade.analog.builder.base import ParamType | ||||
|
||||
|
||||
|
||||
class HTTPHandlerABC: | ||||
@abc.abstractmethod | ||||
def submit_task_via_zapier(task_ir: QuEraTaskSpecification, task_id: str): | ||||
"""Submit a task and add task_id to the task fields for querying later. | ||||
|
||||
args: | ||||
task_ir: The task to be submitted. | ||||
task_id: The task id to be added to the task fields. | ||||
|
||||
returns | ||||
response: The response from the Zapier webhook. used for error handling | ||||
|
||||
""" | ||||
... | ||||
|
||||
@abc.abstractmethod | ||||
def query_task_status(task_id: str): | ||||
"""Query the task status from the AirTable. | ||||
|
||||
args: | ||||
task_id: The task id to be queried. | ||||
|
||||
returns | ||||
response: The response from the AirTable. used for error handling | ||||
|
||||
""" | ||||
... | ||||
|
||||
@abc.abstractmethod | ||||
def fetch_results(task_id: str): | ||||
"""Fetch the task results from the AirTable. | ||||
|
||||
args: | ||||
task_id: The task id to be queried. | ||||
|
||||
returns | ||||
response: The response from the AirTable. used for error handling | ||||
|
||||
""" | ||||
|
||||
... | ||||
|
||||
|
||||
def convert_preview_to_download(preview_url): | ||||
# help function to convert the googledrive preview URL to download URL | ||||
# Only used in http handler | ||||
match = re.search(r"/d/([^/]+)/", preview_url) | ||||
if not match: | ||||
raise ValueError("Invalid preview URL format") | ||||
file_id = match.group(1) | ||||
return f"https://drive.usercontent.google.com/download?id={file_id}&export=download" | ||||
|
||||
|
||||
class HTTPHandler(HTTPHandlerABC): | ||||
def __init__(self, zapier_webhook_url: str = None, | ||||
zapier_webhook_key: str = None, | ||||
vercel_api_url: str = None): | ||||
self.zapier_webhook_url = zapier_webhook_url or os.environ["ZAPIER_WEBHOOK_URL"] | ||||
self.zapier_webhook_key = zapier_webhook_key or os.environ["ZAPIER_WEBHOOK_KEY"] | ||||
self.verrcel_api_url = vercel_api_url or os.environ["VERCEL_API_URL"] | ||||
|
||||
def submit_task_via_zapier(self, task_ir: QuEraTaskSpecification, task_id: str, task_note: str): | ||||
# implement http request logic to submit task via Zapier | ||||
request_options = dict( | ||||
params={"key": self.zapier_webhook_key, "note": task_id}) | ||||
|
||||
# for metadata, task_ir in self._compile_single(shots, use_experimental, args): | ||||
json_request_body = task_ir.json(exclude_none=True, exclude_unset=True) | ||||
|
||||
request_options.update(data=json_request_body) | ||||
response = request("POST", self.zapier_webhook_url, **request_options) | ||||
|
||||
if response.status_code == 200: | ||||
response_data = response.json() | ||||
submit_status = response_data.get("status", None) | ||||
return submit_status | ||||
else: | ||||
print( | ||||
f"HTTP request failed with status code: {response.status_code}") | ||||
print("HTTP responce: ", response.text) | ||||
return "Failed" | ||||
|
||||
def query_task_status(self, task_id: str): | ||||
response = request( | ||||
"GET", | ||||
self.verrcel_api_url, | ||||
params={ | ||||
"searchPattern": task_id, | ||||
"magicToken": self.zapier_webhook_key, | ||||
"useRegex": False, | ||||
}, | ||||
) | ||||
if response.status_code != 200: | ||||
return "Not Found" | ||||
response_data = response.json() | ||||
# Get "matched" from the response | ||||
matches = response_data.get("matches", None) | ||||
# The return is a list of dictionaries | ||||
# Verify if the list contains only one element | ||||
if matches is None: | ||||
print("No task found with the given ID.") | ||||
return "Failed" | ||||
elif len(matches) > 1: | ||||
print("Multiple tasks found with the given ID.") | ||||
return "Failed" | ||||
|
||||
# Extract the status from the first dictionary | ||||
status = matches[0].get("status") | ||||
return status | ||||
|
||||
|
||||
def fetch_results(self, task_id: str): | ||||
response = request( | ||||
"GET", | ||||
self.verrcel_api_url, | ||||
params={ | ||||
"searchPattern": task_id, | ||||
"magicToken": self.zapier_webhook_key, | ||||
"useRegex": False, | ||||
}, | ||||
) | ||||
if response.status_code != 200: | ||||
print( | ||||
f"HTTP request failed with status code: {response.status_code}") | ||||
print("HTTP responce: ", response.text) | ||||
return None | ||||
|
||||
response_data = response.json() | ||||
# Get "matched" from the response | ||||
matches = response_data.get("matches", None) | ||||
# The return is a list of dictionaries | ||||
# Verify if the list contains only one element | ||||
if matches is None: | ||||
print("No task found with the given ID.") | ||||
return None | ||||
elif len(matches) > 1: | ||||
print("Multiple tasks found with the given ID.") | ||||
return None | ||||
record = matches[0] | ||||
if record.get("status") == "Completed": | ||||
googledoc = record.get("resultsFileUrl") | ||||
|
||||
# convert the preview URL to download URL | ||||
googledoc = convert_preview_to_download( | ||||
googledoc) | ||||
res = get(googledoc) | ||||
res.raise_for_status() | ||||
data = res.json() | ||||
|
||||
task_results = QuEraTaskResults(**data) | ||||
return task_results | ||||
|
||||
|
||||
class TestHTTPHandler(HTTPHandlerABC): | ||||
pass | ||||
|
||||
@Serializer.register | ||||
class ExclusiveRemoteTask(CustomRemoteTaskABC): | ||||
def __init__( | ||||
self, | ||||
task_ir: QuEraTaskSpecification, | ||||
metadata: Dict[str, ParamType], | ||||
parallel_decoder: ParallelDecoder | None, | ||||
http_handler: HTTPHandlerABC | None = HTTPHandler(), | ||||
task_id: str = None, | ||||
task_result_ir: QuEraTaskResults = None, | ||||
): | ||||
self._http_handler = http_handler | ||||
self._task_ir = task_ir | ||||
self._metadata = metadata | ||||
self._parallel_decoder = parallel_decoder | ||||
float_sites = list( | ||||
map(lambda x: (float(x[0]), float(x[1])), task_ir.lattice.sites) | ||||
) | ||||
self._geometry = Geometry( | ||||
float_sites, task_ir.lattice.filling, parallel_decoder | ||||
) | ||||
self._task_id = task_id | ||||
self._task_result_ir = task_result_ir | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As per our discussion 1 on 1, use a dataclass to clean up this implementation, Also you should use default_factories for the default values instead of instances of objects that are mutable. |
||||
|
||||
@classmethod | ||||
def from_compile_results(cls, task_ir, metadata, parallel_decoder): | ||||
return cls( | ||||
task_ir=task_ir, | ||||
metadata=metadata, | ||||
parallel_decoder=parallel_decoder, | ||||
) | ||||
|
||||
def _submit(self, force: bool = False) -> "ExclusiveRemoteTask": | ||||
if not force: | ||||
if self._task_id is not None: | ||||
raise ValueError( | ||||
"the task is already submitted with %s" % (self._task_id) | ||||
) | ||||
self._task_id = str(uuid.uuid4()) | ||||
if self._http_handler.submit_task_via_zapier(self._task_ir, self._task_id, None) == "success": | ||||
self._task_result_ir = QuEraTaskResults( | ||||
task_status=QuEraTaskStatusCode.Accepted) | ||||
else: | ||||
self._task_result_ir = QuEraTaskResults( | ||||
task_status=QuEraTaskStatusCode.Failed) | ||||
print(self.task_result_ir) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
removing debug print statements. |
||||
return self | ||||
|
||||
def fetch(self): | ||||
if self.task_result_ir.task_status is QuEraTaskStatusCode.Unsubmitted: | ||||
raise ValueError("Task ID not found.") | ||||
|
||||
if self.task_result_ir.task_status in [ | ||||
QuEraTaskStatusCode.Completed, | ||||
QuEraTaskStatusCode.Partial, | ||||
QuEraTaskStatusCode.Failed, | ||||
QuEraTaskStatusCode.Unaccepted, | ||||
QuEraTaskStatusCode.Cancelled, | ||||
]: | ||||
return self | ||||
|
||||
status = self.status() | ||||
if status in [QuEraTaskStatusCode.Completed, QuEraTaskStatusCode.Partial]: | ||||
self.task_result_ir = self._http_handler.fetch_results( | ||||
self.task_id) | ||||
else: | ||||
self.task_result_ir = QuEraTaskResults(task_status=status) | ||||
|
||||
return self | ||||
|
||||
def pull(self): | ||||
# Please avoid using this method, it's blocking and the waiting time is hours long | ||||
# Throw an error saying this is not supported | ||||
raise NotImplementedError( | ||||
"Pulling is not supported. Please use fetch() instead." | ||||
) | ||||
|
||||
def cancel(self): | ||||
# This is not supported | ||||
raise NotImplementedError( | ||||
"Cancelling is not supported." | ||||
) | ||||
|
||||
def status(self) -> QuEraTaskStatusCode: | ||||
if self._task_id is None: | ||||
return QuEraTaskStatusCode.Unsubmitted | ||||
res = self._http_handler.query_task_status(self._task_id) | ||||
#print("Query task status: ", res) | ||||
if res == "Failed": | ||||
raise ValueError("Query task status failed.") | ||||
elif res == "Submitted": | ||||
return QuEraTaskStatusCode.Enqueued | ||||
# TODO: please add all possible status | ||||
elif res == "Completed": | ||||
return QuEraTaskStatusCode.Completed | ||||
elif res == "Running": | ||||
# Not covered by test | ||||
return QuEraTaskStatusCode.Executing | ||||
else: | ||||
return self.task_result_ir.task_status | ||||
|
||||
def _result_exists(self): | ||||
if self.task_result_ir is None: | ||||
return False | ||||
else: | ||||
if self.task_result_ir.task_status == QuEraTaskStatusCode.Completed: | ||||
return True | ||||
else: | ||||
return False | ||||
|
||||
def result(self): | ||||
if self._task_result_ir is None: | ||||
raise ValueError("Task result not found.") | ||||
return self._task_result_ir | ||||
|
||||
@property | ||||
def metadata(self): | ||||
return self._metadata | ||||
|
||||
@property | ||||
def geometry(self): | ||||
return self._geometry | ||||
|
||||
@property | ||||
def task_ir(self): | ||||
return self._task_ir | ||||
|
||||
@property | ||||
def task_id(self) -> str: | ||||
assert isinstance(self._task_id, str), "Task ID is not set" | ||||
return self._task_id | ||||
|
||||
@property | ||||
def task_result_ir(self): | ||||
return self._task_result_ir | ||||
|
||||
@property | ||||
def parallel_decoder(self): | ||||
return self._parallel_decoder | ||||
|
||||
@task_result_ir.setter | ||||
def task_result_ir(self, task_result_ir: QuEraTaskResults): | ||||
self._task_result_ir = task_result_ir | ||||
|
||||
|
||||
@ExclusiveRemoteTask.set_serializer | ||||
|
||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
removing a space here. |
||||
def _serialze(obj: ExclusiveRemoteTask) -> Dict[str, ParamType]: | ||||
return { | ||||
"task_id": obj.task_id or None, | ||||
"task_ir": obj.task_ir.dict(by_alias=True, exclude_none=True), | ||||
"metadata": obj.metadata, | ||||
"parallel_decoder": ( | ||||
obj.parallel_decoder.dict() if obj.parallel_decoder else None | ||||
), | ||||
"task_result_ir": obj.task_result_ir.dict() if obj.task_result_ir else None, | ||||
} | ||||
|
||||
|
||||
@ExclusiveRemoteTask.set_deserializer | ||||
def _deserializer(d: Dict[str, any]) -> ExclusiveRemoteTask: | ||||
# TODO: Not tested, once it's done, resolve the DEBUG flag | ||||
d["task_ir"] = QuEraTaskSpecification(**d["task_ir"]) | ||||
d["parallel_decoder"] = ( | ||||
ParallelDecoder(**d["parallel_decoder"] | ||||
) if d["parallel_decoder"] else None | ||||
) | ||||
return ExclusiveRemoteTask(**d) | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This shouldn't be deleted, this is needed to get things working with the Report object.