|
2 | 2 | import abc
|
3 | 3 | import uuid
|
4 | 4 | import re
|
| 5 | +import time |
5 | 6 |
|
6 | 7 | from beartype.typing import Dict
|
7 | 8 | from dataclasses import dataclass, field
|
@@ -245,12 +246,33 @@ def fetch(self):
|
245 | 246 |
|
246 | 247 | return self
|
247 | 248 |
|
248 |
| - def pull(self): |
249 |
| - # Please avoid using this method, it's blocking and the waiting time is hours long |
250 |
| - # Throw an error saying this is not supported |
251 |
| - raise NotImplementedError( |
252 |
| - "Pulling is not supported. Please use fetch() instead." |
253 |
| - ) |
| 249 | + def pull(self, poll_interval: float = 20): |
| 250 | + """ |
| 251 | + Blocking pull to get the task result. |
| 252 | + poll_interval is the time interval to poll the task status. |
| 253 | + Please ensure that it is relatively large, otherwise |
| 254 | + the server could get overloaded with queries. |
| 255 | + """ |
| 256 | + |
| 257 | + while True: |
| 258 | + if self._task_result_ir.task_status is QuEraTaskStatusCode.Unsubmitted: |
| 259 | + raise ValueError("Task ID not found.") |
| 260 | + |
| 261 | + if self._task_result_ir.task_status in [ |
| 262 | + QuEraTaskStatusCode.Completed, |
| 263 | + QuEraTaskStatusCode.Partial, |
| 264 | + QuEraTaskStatusCode.Failed, |
| 265 | + QuEraTaskStatusCode.Unaccepted, |
| 266 | + QuEraTaskStatusCode.Cancelled, |
| 267 | + ]: |
| 268 | + return self |
| 269 | + |
| 270 | + status = self.status() |
| 271 | + if status in [QuEraTaskStatusCode.Completed, QuEraTaskStatusCode.Partial]: |
| 272 | + self._task_result_ir = self._http_handler.fetch_results(self._task_id) |
| 273 | + return self |
| 274 | + |
| 275 | + time.sleep(poll_interval) |
254 | 276 |
|
255 | 277 | def cancel(self):
|
256 | 278 | # This is not supported
|
|
0 commit comments