Skip to content

Commit c59063a

Browse files
oscar-wallisOkuyanBogadeclanmillaradekusar-drl
authored
Bug/607 fidelity statevector kernel cannot be pickled (#778)
* Made fidelity_statevector_kernel picklable Added a new param to store cache size and a custom __getstate__ and __setstate__ to handle removing and re-initliasing the lru cache during pickle/unpickling respectively. * updated notes * name changes * spell corrections * Updated description * Added unittest for pickling * Spelling changes * Making error messages clearer * Spelling -_- * Update releasenotes/notes/fix-fid_statevector_kernel-pickling-b7fa2b13a15ec9c6.yaml Co-authored-by: Declan Millar <[email protected]> * Update .gitignore Co-authored-by: Declan Millar <[email protected]> * Update test/kernels/test_fidelity_statevector_kernel.py Co-authored-by: Declan Millar <[email protected]> * Update qiskit_machine_learning/kernels/fidelity_statevector_kernel.py Co-authored-by: Declan Millar <[email protected]> * Update test/kernels/test_fidelity_statevector_kernel.py Co-authored-by: Declan Millar <[email protected]> * Update qiskit_machine_learning/kernels/fidelity_statevector_kernel.py Co-authored-by: Declan Millar <[email protected]> * Added Any class --------- Co-authored-by: M. Emre Sahin <[email protected]> Co-authored-by: Declan Millar <[email protected]> Co-authored-by: Anton Dekusar <[email protected]>
1 parent 2f49e9e commit c59063a

File tree

3 files changed

+67
-4
lines changed

3 files changed

+67
-4
lines changed

qiskit_machine_learning/kernels/fidelity_statevector_kernel.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# This code is part of a Qiskit project.
22
#
3-
# (C) Copyright IBM 2023.
3+
# (C) Copyright IBM 2023, 2024.
44
#
55
# This code is licensed under the Apache License, Version 2.0. You may
66
# obtain a copy of this license in the LICENSE.txt file in the root directory
@@ -14,7 +14,7 @@
1414
from __future__ import annotations
1515

1616
from functools import lru_cache
17-
from typing import Type, TypeVar
17+
from typing import Type, TypeVar, Any
1818

1919
import numpy as np
2020

@@ -97,7 +97,7 @@ def __init__(
9797
self._auto_clear_cache = auto_clear_cache
9898
self._shots = shots
9999
self._enforce_psd = enforce_psd
100-
100+
self._cache_size = cache_size
101101
# Create the statevector cache at the instance level.
102102
self._get_statevector = lru_cache(maxsize=cache_size)(self._get_statevector_)
103103

@@ -160,3 +160,12 @@ def clear_cache(self):
160160
"""Clear the statevector cache."""
161161
# pylint: disable=no-member
162162
self._get_statevector.cache_clear()
163+
164+
def __getstate__(self) -> dict[str, Any]:
165+
kernel = dict(self.__dict__)
166+
kernel["_get_statevector"] = None
167+
return kernel
168+
169+
def __setstate__(self, kernel: dict[str, Any]):
170+
self.__dict__ = kernel
171+
self._get_statevector = lru_cache(maxsize=self._cache_size)(self._get_statevector_)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
fixes:
3+
- |
4+
Fixed a bug where :class:`.FidelityStatevectorKernel` threw an error when pickled.

test/kernels/test_fidelity_statevector_kernel.py

+51-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# This code is part of a Qiskit project.
22
#
3-
# (C) Copyright IBM 2023.
3+
# (C) Copyright IBM 2023, 2024.
44
#
55
# This code is licensed under the Apache License, Version 2.0. You may
66
# obtain a copy of this license in the LICENSE.txt file in the root directory
@@ -15,6 +15,7 @@
1515

1616
import functools
1717
import itertools
18+
import pickle
1819
import sys
1920
import unittest
2021

@@ -343,6 +344,55 @@ def test_properties(self):
343344
self.assertEqual(qc, kernel.feature_map)
344345
self.assertEqual(1, kernel.num_features)
345346

347+
def test_pickling(self):
348+
"""Test that the kernel can be pickled correctly and without error."""
349+
# Compares original kernel with copies made using pickle module and get & set state directly
350+
qc = QuantumCircuit(1)
351+
qc.ry(Parameter("w"), 0)
352+
kernel1 = FidelityStatevectorKernel(feature_map=qc)
353+
354+
pickled_obj = pickle.dumps(kernel1)
355+
kernel2 = pickle.loads(pickled_obj)
356+
357+
kernel3 = FidelityStatevectorKernel()
358+
kernel3.__setstate__(kernel1.__getstate__())
359+
360+
with self.subTest("Pickle fail, kernels are not the same type"):
361+
self.assertEqual(type(kernel1), type(kernel2))
362+
363+
with self.subTest("Pickle fail, kernels are not the same type"):
364+
self.assertEqual(type(kernel1), type(kernel3))
365+
366+
with self.subTest("Pickle fail, kernels are not unique objects"):
367+
self.assertNotEqual(kernel1, kernel2)
368+
369+
with self.subTest("Pickle fail, kernels are not unique objects"):
370+
self.assertNotEqual(kernel1, kernel3)
371+
372+
with self.subTest("Pickle fail, caches are not the same type"):
373+
self.assertEqual(type(kernel1._get_statevector), type(kernel2._get_statevector))
374+
375+
with self.subTest("Pickle fail, caches are not the same type"):
376+
self.assertEqual(type(kernel1._get_statevector), type(kernel3._get_statevector))
377+
378+
# Remove cache to check dict properties are otherwise identical.
379+
# - caches are never identical as they have different RAM locations.
380+
kernel1.__dict__["_get_statevector"] = None
381+
kernel2.__dict__["_get_statevector"] = None
382+
kernel3.__dict__["_get_statevector"] = None
383+
384+
# Confirm changes were made.
385+
with self.subTest("Pickle fail, caches have not been removed from kernels"):
386+
self.assertEqual(kernel1._get_statevector, None)
387+
self.assertEqual(kernel2._get_statevector, None)
388+
self.assertEqual(kernel3._get_statevector, None)
389+
390+
with self.subTest("Pickle fail, properties of kernels (bar cache) are not identical"):
391+
self.assertEqual(kernel1.__dict__, kernel2.__dict__)
392+
393+
with self.subTest("Pickle fail, properties of kernels (bar cache) are not identical"):
394+
self.assertEqual(kernel1.__dict__, kernel3.__dict__)
395+
346396

347397
@ddt
348398
class TestStatevectorKernelDuplicates(QiskitMachineLearningTestCase):

0 commit comments

Comments
 (0)