|
1 | 1 | # This code is part of a Qiskit project.
|
2 | 2 | #
|
3 |
| -# (C) Copyright IBM 2023. |
| 3 | +# (C) Copyright IBM 2023, 2024. |
4 | 4 | #
|
5 | 5 | # This code is licensed under the Apache License, Version 2.0. You may
|
6 | 6 | # obtain a copy of this license in the LICENSE.txt file in the root directory
|
|
15 | 15 |
|
16 | 16 | import functools
|
17 | 17 | import itertools
|
| 18 | +import pickle |
18 | 19 | import sys
|
19 | 20 | import unittest
|
20 | 21 |
|
@@ -343,6 +344,55 @@ def test_properties(self):
|
343 | 344 | self.assertEqual(qc, kernel.feature_map)
|
344 | 345 | self.assertEqual(1, kernel.num_features)
|
345 | 346 |
|
| 347 | + def test_pickling(self): |
| 348 | + """Test that the kernel can be pickled correctly and without error.""" |
| 349 | + # Compares original kernel with copies made using pickle module and get & set state directly |
| 350 | + qc = QuantumCircuit(1) |
| 351 | + qc.ry(Parameter("w"), 0) |
| 352 | + kernel1 = FidelityStatevectorKernel(feature_map=qc) |
| 353 | + |
| 354 | + pickled_obj = pickle.dumps(kernel1) |
| 355 | + kernel2 = pickle.loads(pickled_obj) |
| 356 | + |
| 357 | + kernel3 = FidelityStatevectorKernel() |
| 358 | + kernel3.__setstate__(kernel1.__getstate__()) |
| 359 | + |
| 360 | + with self.subTest("Pickle fail, kernels are not the same type"): |
| 361 | + self.assertEqual(type(kernel1), type(kernel2)) |
| 362 | + |
| 363 | + with self.subTest("Pickle fail, kernels are not the same type"): |
| 364 | + self.assertEqual(type(kernel1), type(kernel3)) |
| 365 | + |
| 366 | + with self.subTest("Pickle fail, kernels are not unique objects"): |
| 367 | + self.assertNotEqual(kernel1, kernel2) |
| 368 | + |
| 369 | + with self.subTest("Pickle fail, kernels are not unique objects"): |
| 370 | + self.assertNotEqual(kernel1, kernel3) |
| 371 | + |
| 372 | + with self.subTest("Pickle fail, caches are not the same type"): |
| 373 | + self.assertEqual(type(kernel1._get_statevector), type(kernel2._get_statevector)) |
| 374 | + |
| 375 | + with self.subTest("Pickle fail, caches are not the same type"): |
| 376 | + self.assertEqual(type(kernel1._get_statevector), type(kernel3._get_statevector)) |
| 377 | + |
| 378 | + # Remove cache to check dict properties are otherwise identical. |
| 379 | + # - caches are never identical as they have different RAM locations. |
| 380 | + kernel1.__dict__["_get_statevector"] = None |
| 381 | + kernel2.__dict__["_get_statevector"] = None |
| 382 | + kernel3.__dict__["_get_statevector"] = None |
| 383 | + |
| 384 | + # Confirm changes were made. |
| 385 | + with self.subTest("Pickle fail, caches have not been removed from kernels"): |
| 386 | + self.assertEqual(kernel1._get_statevector, None) |
| 387 | + self.assertEqual(kernel2._get_statevector, None) |
| 388 | + self.assertEqual(kernel3._get_statevector, None) |
| 389 | + |
| 390 | + with self.subTest("Pickle fail, properties of kernels (bar cache) are not identical"): |
| 391 | + self.assertEqual(kernel1.__dict__, kernel2.__dict__) |
| 392 | + |
| 393 | + with self.subTest("Pickle fail, properties of kernels (bar cache) are not identical"): |
| 394 | + self.assertEqual(kernel1.__dict__, kernel3.__dict__) |
| 395 | + |
346 | 396 |
|
347 | 397 | @ddt
|
348 | 398 | class TestStatevectorKernelDuplicates(QiskitMachineLearningTestCase):
|
|
0 commit comments