Skip to content

Commit 23f3b85

Browse files
woodsp-ibmCryoris
andauthored
Update gradient logic for Qiskit Rust circuit data implementation (#188)
Co-authored-by: Julien Gacon <[email protected]>
1 parent bf5e903 commit 23f3b85

File tree

6 files changed

+64
-20
lines changed

6 files changed

+64
-20
lines changed

qiskit_algorithms/gradients/reverse/derive_circuit.py

+37-13
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import itertools
1717
from collections.abc import Sequence
1818

19-
from qiskit.circuit import QuantumCircuit, Parameter, Gate
19+
from qiskit.circuit import QuantumCircuit, Parameter, Gate, ParameterExpression
2020
from qiskit.circuit.library import RXGate, RYGate, RZGate, CRXGate, CRYGate, CRZGate
2121

2222

@@ -90,7 +90,7 @@ def gradient_lookup(gate: Gate) -> list[tuple[complex, QuantumCircuit]]:
9090

9191

9292
def derive_circuit(
93-
circuit: QuantumCircuit, parameter: Parameter
93+
circuit: QuantumCircuit, parameter: Parameter, check: bool = True
9494
) -> Sequence[tuple[complex, QuantumCircuit]]:
9595
"""Return the analytic gradient expression of the input circuit wrt. a single parameter.
9696
@@ -114,6 +114,8 @@ def derive_circuit(
114114
Args:
115115
circuit: The quantum circuit to derive.
116116
parameter: The parameter with respect to which we derive.
117+
check: If ``True`` (default) check that the parameter is valid and that no product
118+
rule is required.
117119
118120
Returns:
119121
A list of ``(coeff, gradient_circuit)`` tuples.
@@ -124,16 +126,31 @@ def derive_circuit(
124126
NotImplementedError: If a non-unique parameter is added, as the product rule is not yet
125127
supported in this function.
126128
"""
127-
# this is added as useful user-warning, since sometimes ``ParameterExpression``s are
128-
# passed around instead of ``Parameter``s
129-
if not isinstance(parameter, Parameter):
130-
raise ValueError(f"parameter must be of type Parameter, not {type(parameter)}.")
131-
132-
if parameter not in circuit.parameters:
133-
raise ValueError(f"The parameter {parameter} is not in this circuit.")
134-
135-
if len(circuit._parameter_table[parameter]) > 1:
136-
raise NotImplementedError("No product rule support yet, circuit parameters must be unique.")
129+
if check:
130+
# this is added as useful user-warning, since sometimes ``ParameterExpression``s are
131+
# passed around instead of ``Parameter``s
132+
if not isinstance(parameter, Parameter):
133+
raise ValueError(f"parameter must be of type Parameter, not {type(parameter)}.")
134+
135+
if parameter not in circuit.parameters:
136+
raise ValueError(f"The parameter {parameter} is not in this circuit.")
137+
138+
# check uniqueness
139+
seen_parameters: set[Parameter] = set()
140+
for instruction in circuit.data:
141+
# get parameters in the current operation
142+
new_parameters = set()
143+
for p in instruction.operation.params:
144+
if isinstance(p, ParameterExpression):
145+
new_parameters.update(p.parameters)
146+
147+
if duplicates := seen_parameters.intersection(new_parameters):
148+
raise NotImplementedError(
149+
"Product rule is not supported, circuit parameters must be unique, but "
150+
f"{duplicates} are duplicated."
151+
)
152+
153+
seen_parameters.update(new_parameters)
137154

138155
summands, op_context = [], []
139156
for i, op in enumerate(circuit.data):
@@ -151,7 +168,14 @@ def derive_circuit(
151168
c = complex(1)
152169
for i, term in enumerate(product_rule_term):
153170
c *= term[0]
154-
summand_circuit.data.append([term[1], *op_context[i]])
171+
# Qiskit changed the format of the stored value. The newer Qiskit has this internal
172+
# method to go from the older (legacy) format to new. This logic may need updating
173+
# at some point if this internal method goes away.
174+
if hasattr(summand_circuit.data, "_resolve_legacy_value"):
175+
value = summand_circuit.data._resolve_legacy_value(term[1], *op_context[i])
176+
summand_circuit.data.append(value)
177+
else:
178+
summand_circuit.data.append([term[1], *op_context[i]])
155179
gradient += [(c, summand_circuit.copy())]
156180

157181
return gradient

qiskit_algorithms/gradients/reverse/reverse_gradient.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# This code is part of a Qiskit project.
22
#
3-
# (C) Copyright IBM 2022, 2023.
3+
# (C) Copyright IBM 2022, 2024.
44
#
55
# This code is licensed under the Apache License, Version 2.0. You may
66
# obtain a copy of this license in the LICENSE.txt file in the root directory
@@ -144,7 +144,8 @@ def _run_unique(
144144
parameter_j = paramlist[j][0]
145145

146146
# get the analytic gradient d U_j / d p_j and bind the gate
147-
deriv = derive_circuit(unitary_j, parameter_j)
147+
# we skip the check since we know the circuit has unique, valid parameters
148+
deriv = derive_circuit(unitary_j, parameter_j, check=False)
148149
for _, gate in deriv:
149150
bind(gate, parameter_binds, inplace=True)
150151

qiskit_algorithms/gradients/reverse/reverse_qgt.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# This code is part of a Qiskit project.
22
#
3-
# (C) Copyright IBM 2023.
3+
# (C) Copyright IBM 2023, 2024.
44
#
55
# This code is licensed under the Apache License, Version 2.0. You may
66
# obtain a copy of this license in the LICENSE.txt file in the root directory
@@ -131,7 +131,8 @@ def _run_unique(
131131
# Note: We currently only support gates with a single parameter -- which is reflected
132132
# in self.SUPPORTED_GATES -- but generally we could also support gates with multiple
133133
# parameters per gate. This is the reason for the second 0-index.
134-
deriv = derive_circuit(unitaries[0], paramlist[0][0])
134+
# We skip the check since we know the circuit has unique, valid parameters.
135+
deriv = derive_circuit(unitaries[0], paramlist[0][0], check=False)
135136
for _, gate in deriv:
136137
bind(gate, parameter_binds, inplace=True)
137138

@@ -149,7 +150,7 @@ def _run_unique(
149150
phi = psi.copy()
150151

151152
# get the analytic gradient d U_j / d p_j and apply it
152-
deriv = derive_circuit(unitaries[j], paramlist[j][0])
153+
deriv = derive_circuit(unitaries[j], paramlist[j][0], check=False)
153154

154155
for _, gate in deriv:
155156
bind(gate, parameter_binds, inplace=True)
@@ -170,7 +171,7 @@ def _run_unique(
170171
lam = lam.evolve(bound_unitaries[i].inverse())
171172

172173
# get the gradient d U_i / d p_i and apply it
173-
deriv = derive_circuit(unitaries[i], paramlist[i][0])
174+
deriv = derive_circuit(unitaries[i], paramlist[i][0], check=False)
174175
for _, gate in deriv:
175176
bind(gate, parameter_binds, inplace=True)
176177

qiskit_algorithms/gradients/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# This code is part of a Qiskit project.
22
#
3-
# (C) Copyright IBM 2022, 2023.
3+
# (C) Copyright IBM 2022, 2024.
44
#
55
# This code is licensed under the Apache License, Version 2.0. You may
66
# obtain a copy of this license in the LICENSE.txt file in the root directory
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
other:
3+
- |
4+
Aspects of the gradients internal implementation, which manipulate circuits more
5+
directly, have been updated now that circuit data is being handled by Rust so it's
6+
compatible with the former Python way as well as the new Qiskit Rust implementation.

test/gradients/test_estimator_gradient.py

+12
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,18 @@ def operations_callback(op):
512512
with self.subTest(msg="assert result is correct"):
513513
self.assertAlmostEqual(result.gradients[0].item(), expect, places=5)
514514

515+
def test_product_rule_check(self):
516+
"""Test product rule check."""
517+
p = Parameter("p")
518+
qc = QuantumCircuit(1)
519+
qc.rx(p, 0)
520+
qc.ry(p, 0)
521+
522+
from qiskit_algorithms.gradients.reverse.derive_circuit import derive_circuit
523+
524+
with self.assertRaises(NotImplementedError):
525+
_ = derive_circuit(qc, p)
526+
515527

516528
if __name__ == "__main__":
517529
unittest.main()

0 commit comments

Comments
 (0)