diff --git a/qiskit/primitives/containers/data_bin.py b/qiskit/primitives/containers/data_bin.py index 21252d2ef906..cb6efcc4f42c 100644 --- a/qiskit/primitives/containers/data_bin.py +++ b/qiskit/primitives/containers/data_bin.py @@ -103,9 +103,6 @@ def __init__(self, *, shape: ShapeInput = (), **data): def __len__(self): return len(self._data) - def __setattr__(self, *_): - raise NotImplementedError - def __repr__(self): vals = [f"{name}={_value_repr(val)}" for name, val in self.items()] if self.ndim: diff --git a/qiskit/primitives/primitive_job.py b/qiskit/primitives/primitive_job.py index 64ab8d016095..bc316cfde61d 100644 --- a/qiskit/primitives/primitive_job.py +++ b/qiskit/primitives/primitive_job.py @@ -35,6 +35,8 @@ def __init__(self, function, *args, **kwargs): super().__init__(str(uuid.uuid4())) self._future = None self._function = function + self._result = None + self._status = None self._args = args self._kwargs = kwargs @@ -46,19 +48,36 @@ def _submit(self): self._future = executor.submit(self._function, *self._args, **self._kwargs) executor.shutdown(wait=False) + def __getstate__(self): + _ = self.result() + _ = self.status() + state = self.__dict__.copy() + state["_future"] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._future = None + def result(self) -> ResultT: - self._check_submitted() - return self._future.result() + if self._result is None: + self._check_submitted() + self._result = self._future.result() + return self._result def status(self) -> JobStatus: - self._check_submitted() - if self._future.running(): - return JobStatus.RUNNING - elif self._future.cancelled(): - return JobStatus.CANCELLED - elif self._future.done() and self._future.exception() is None: - return JobStatus.DONE - return JobStatus.ERROR + if self._status is None: + self._check_submitted() + if self._future.running(): + # we should not store status running because it is not completed + return JobStatus.RUNNING + elif self._future.cancelled(): + self._status = JobStatus.CANCELLED + elif self._future.done() and self._future.exception() is None: + self._status = JobStatus.DONE + else: + self._status = JobStatus.ERROR + return self._status def _check_submitted(self): if self._future is None: diff --git a/releasenotes/notes/serialize-primitive-job-aa97d0bf8221ea99.yaml b/releasenotes/notes/serialize-primitive-job-aa97d0bf8221ea99.yaml new file mode 100644 index 000000000000..4d9a08b9802e --- /dev/null +++ b/releasenotes/notes/serialize-primitive-job-aa97d0bf8221ea99.yaml @@ -0,0 +1,5 @@ +--- +features_primitives: + - | + To make :class:`.PrimitiveJob` serializable, :class:`.DataBin` has been + updated to be pickleable. As a result, :class:`.PrimitiveResult` is now also pickleable. diff --git a/test/python/primitives/test_primitive_job.py b/test/python/primitives/test_primitive_job.py new file mode 100644 index 000000000000..f6cb7479f057 --- /dev/null +++ b/test/python/primitives/test_primitive_job.py @@ -0,0 +1,51 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2025. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Tests for PrimitiveJob.""" + +import pickle +from test import QiskitTestCase + +import numpy as np +from ddt import data, ddt + +from qiskit import QuantumCircuit +from qiskit.primitives import PrimitiveJob, StatevectorSampler + + +@ddt +class TestPrimitiveJob(QiskitTestCase): + """Tests PrimitiveJob.""" + + @data(1, 2, 3) + def test_serialize(self, size): + """Test serialize.""" + n = 2 + qc = QuantumCircuit(n) + qc.h(range(n)) + qc.measure_all() + sampler = StatevectorSampler() + job = sampler.run([qc] * size) + obj = pickle.dumps(job) + job2 = pickle.loads(obj) + self.assertIsInstance(job2, PrimitiveJob) + self.assertEqual(job.job_id(), job2.job_id()) + self.assertEqual(job.status(), job2.status()) + self.assertEqual(job.metadata, job2.metadata) + result = job.result() + result2 = job2.result() + self.assertEqual(result.metadata, result2.metadata) + self.assertEqual(len(result), len(result2)) + for sampler_pub in result: + self.assertEqual(sampler_pub.metadata, sampler_pub.metadata) + self.assertEqual(sampler_pub.data.keys(), sampler_pub.data.keys()) + np.testing.assert_allclose(sampler_pub.join_data().array, sampler_pub.join_data().array)