Skip to content

Commit 1d334fa

Browse files
authored
Preserve bit locations through pickle (#13981)
* Preserve bit locations through pickle Previously when we were pickling a DAGCircuit the bit locations fields were forgotten about. So when we loaded a DAGCircuit from a pickle the bit locations were empty. This would cause any call to find_bit() to raise an error because there was no entry for any of the bits in the dag. This commit fixes this by reconstructing by including the bit locations fields in the object state returned by __getstate__ and populating the fields from the provided state in __setstate__/ Fixes #13976 * Support pickle for BitLocations * Adjust pickle impl * Make BitLocations importable from Python so pickle can use it
1 parent f86d8b4 commit 1d334fa

File tree

3 files changed

+95
-1
lines changed

3 files changed

+95
-1
lines changed

crates/circuit/src/dag_circuit.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,13 +346,25 @@ fn reject_new_register(reg: &Bound<PyAny>) -> PyResult<()> {
346346

347347
#[pyclass(module = "qiskit._accelerate.circuit")]
348348
#[derive(Clone, Debug)]
349-
struct BitLocations {
349+
pub(crate) struct BitLocations {
350350
#[pyo3(get)]
351351
index: usize,
352352
#[pyo3(get)]
353353
registers: Py<PyList>,
354354
}
355355

356+
#[pymethods]
357+
impl BitLocations {
358+
#[new]
359+
fn new(index: usize, registers: Py<PyList>) -> Self {
360+
Self { index, registers }
361+
}
362+
363+
fn __getnewargs__(&self, py: Python) -> (usize, Py<PyList>) {
364+
(self.index, self.registers.clone_ref(py))
365+
}
366+
}
367+
356368
#[derive(Copy, Clone, Debug)]
357369
enum DAGVarType {
358370
Input = 0,
@@ -528,6 +540,8 @@ impl DAGCircuit {
528540
out_dict.set_item("qregs", self.qregs.clone_ref(py))?;
529541
out_dict.set_item("cregs", self.cregs.clone_ref(py))?;
530542
out_dict.set_item("global_phase", self.global_phase.clone())?;
543+
out_dict.set_item("qubit_locations", self.qubit_locations.clone_ref(py))?;
544+
out_dict.set_item("clbit_locations", self.clbit_locations.clone_ref(py))?;
531545
out_dict.set_item(
532546
"qubit_io_map",
533547
self.qubit_io_map
@@ -617,6 +631,8 @@ impl DAGCircuit {
617631
self.qregs = dict_state.get_item("qregs")?.unwrap().extract()?;
618632
self.cregs = dict_state.get_item("cregs")?.unwrap().extract()?;
619633
self.global_phase = dict_state.get_item("global_phase")?.unwrap().extract()?;
634+
self.qubit_locations = dict_state.get_item("qubit_locations")?.unwrap().extract()?;
635+
self.clbit_locations = dict_state.get_item("clbit_locations")?.unwrap().extract()?;
620636
self.op_names = dict_state.get_item("op_name")?.unwrap().extract()?;
621637
self.vars_by_type = dict_state.get_item("vars_by_type")?.unwrap().extract()?;
622638
let binding = dict_state.get_item("vars_info")?.unwrap();

crates/circuit/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ impl From<Clbit> for BitType {
125125
pub fn circuit(m: &Bound<PyModule>) -> PyResult<()> {
126126
m.add_class::<circuit_data::CircuitData>()?;
127127
m.add_class::<circuit_instruction::CircuitInstruction>()?;
128+
m.add_class::<dag_circuit::BitLocations>()?;
128129
m.add_class::<dag_circuit::DAGCircuit>()?;
129130
m.add_class::<dag_node::DAGNode>()?;
130131
m.add_class::<dag_node::DAGInNode>()?;

test/python/dagcircuit/test_dagcircuit.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from __future__ import annotations
1616

1717
from collections import Counter
18+
import io
19+
import copy
20+
import pickle
1821
import unittest
1922

2023
from ddt import ddt, data
@@ -161,6 +164,80 @@ def test_dag_get_qubits(self):
161164
],
162165
)
163166

167+
def test_pickle_bit_locations_with_reg(self):
168+
"""Test bit locations preserved through pickle."""
169+
dag = DAGCircuit()
170+
qr = QuantumRegister(2, "qr")
171+
cr = ClassicalRegister(1, "cr")
172+
dag.add_qreg(qr)
173+
dag.add_creg(cr)
174+
self.assertEqual(dag.find_bit(dag.qubits[1]).index, 1)
175+
self.assertEqual(dag.find_bit(dag.qubits[1]).registers, [(qr, 1)])
176+
self.assertEqual(dag.find_bit(dag.clbits[0]).index, 0)
177+
self.assertEqual(dag.find_bit(dag.clbits[0]).registers, [(cr, 0)])
178+
with io.BytesIO() as buf:
179+
pickle.dump(dag, buf)
180+
buf.seek(0)
181+
output = pickle.load(buf)
182+
self.assertEqual(output.find_bit(output.qubits[1]).index, 1)
183+
self.assertEqual(output.find_bit(output.qubits[1]).registers, [(qr, 1)])
184+
self.assertEqual(output.find_bit(output.clbits[0]).index, 0)
185+
self.assertEqual(output.find_bit(output.clbits[0]).registers, [(cr, 0)])
186+
187+
def test_deepcopy_bit_locations_with_reg(self):
188+
"""Test bit locations preserved through pickle."""
189+
dag = DAGCircuit()
190+
qr = QuantumRegister(2, "qr")
191+
cr = ClassicalRegister(1, "cr")
192+
dag.add_qreg(qr)
193+
dag.add_creg(cr)
194+
self.assertEqual(dag.find_bit(dag.qubits[1]).index, 1)
195+
self.assertEqual(dag.find_bit(dag.qubits[1]).registers, [(qr, 1)])
196+
self.assertEqual(dag.find_bit(dag.clbits[0]).index, 0)
197+
self.assertEqual(dag.find_bit(dag.clbits[0]).registers, [(cr, 0)])
198+
output = copy.deepcopy(dag)
199+
self.assertEqual(output.find_bit(output.qubits[1]).index, 1)
200+
self.assertEqual(output.find_bit(output.qubits[1]).registers, [(qr, 1)])
201+
self.assertEqual(output.find_bit(output.clbits[0]).index, 0)
202+
self.assertEqual(output.find_bit(output.clbits[0]).registers, [(cr, 0)])
203+
204+
def test_pickle_bit_locations_with_no_reg(self):
205+
"""Test bit locations preserved through pickle."""
206+
dag = DAGCircuit()
207+
qubits = [Qubit(), Qubit()]
208+
clbits = [Clbit()]
209+
dag.add_qubits(qubits)
210+
dag.add_clbits(clbits)
211+
self.assertEqual(dag.find_bit(dag.qubits[1]).index, 1)
212+
self.assertEqual(dag.find_bit(dag.qubits[1]).registers, [])
213+
self.assertEqual(dag.find_bit(dag.clbits[0]).index, 0)
214+
self.assertEqual(dag.find_bit(dag.clbits[0]).registers, [])
215+
with io.BytesIO() as buf:
216+
pickle.dump(dag, buf)
217+
buf.seek(0)
218+
output = pickle.load(buf)
219+
self.assertEqual(output.find_bit(output.qubits[1]).index, 1)
220+
self.assertEqual(output.find_bit(output.qubits[1]).registers, [])
221+
self.assertEqual(output.find_bit(output.clbits[0]).index, 0)
222+
self.assertEqual(output.find_bit(output.clbits[0]).registers, [])
223+
224+
def test_deepcopy_bit_locations_with_no_reg(self):
225+
"""Test bit locations preserved through pickle."""
226+
dag = DAGCircuit()
227+
qubits = [Qubit(), Qubit()]
228+
clbits = [Clbit()]
229+
dag.add_qubits(qubits)
230+
dag.add_clbits(clbits)
231+
self.assertEqual(dag.find_bit(dag.qubits[1]).index, 1)
232+
self.assertEqual(dag.find_bit(dag.qubits[1]).registers, [])
233+
self.assertEqual(dag.find_bit(dag.clbits[0]).index, 0)
234+
self.assertEqual(dag.find_bit(dag.clbits[0]).registers, [])
235+
output = copy.deepcopy(dag)
236+
self.assertEqual(output.find_bit(output.qubits[1]).index, 1)
237+
self.assertEqual(output.find_bit(output.qubits[1]).registers, [])
238+
self.assertEqual(output.find_bit(output.clbits[0]).index, 0)
239+
self.assertEqual(output.find_bit(output.clbits[0]).registers, [])
240+
164241
def test_add_reg_duplicate(self):
165242
"""add_qreg with the same register twice is not allowed."""
166243
dag = DAGCircuit()

0 commit comments

Comments
 (0)