Skip to content
This repository was archived by the owner on Dec 7, 2021. It is now read-only.

Commit 521d505

Browse files
Make QGAN run primarily on circuits (#1341)
* make quantum generator use circuits primarily * add test and deprecation warning * add reno * fix lint Co-authored-by: Manoel Marques <[email protected]>
1 parent 3810857 commit 521d505

File tree

4 files changed

+83
-86
lines changed

4 files changed

+83
-86
lines changed

qiskit/aqua/components/neural_networks/generative_network.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,12 @@ def set_seed(self, seed):
4343
raise NotImplementedError()
4444

4545
@abstractmethod
46-
def get_output(self, quantum_instance, qc_state_in, params, shots):
46+
def get_output(self, quantum_instance, params, shots):
4747
"""
4848
Apply quantum/classical neural network to given input and get the respective output
4949
5050
Args:
5151
quantum_instance (QuantumInstance): Quantum Instance, used to run the generator circuit.
52-
qc_state_in (QuantumCircuit or vector): corresponding to the network input state
5352
params (numpy.ndarray): parameters which should be used to run the generator,
5453
if None use self._params
5554
shots (int): if not None use a number of shots that is different from the number

qiskit/aqua/components/neural_networks/quantum_generator.py

Lines changed: 48 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,17 @@
1313
"""Quantum Generator."""
1414

1515
from typing import Optional, List, Union, Dict, Any
16+
import warnings
1617
from copy import deepcopy
1718
import numpy as np
1819

1920
from qiskit import QuantumRegister, ClassicalRegister, QuantumCircuit
2021
from qiskit.circuit.library import TwoLocal
2122
from qiskit.aqua import aqua_globals
2223
from qiskit.aqua.components.optimizers import ADAM
23-
from qiskit.aqua.components.uncertainty_models import \
24-
UniformDistribution, MultivariateUniformDistribution
2524
from qiskit.aqua.components.uncertainty_models import UnivariateVariationalDistribution, \
2625
MultivariateVariationalDistribution
27-
from qiskit.aqua import AquaError
2826
from qiskit.aqua.components.neural_networks.generative_network import GenerativeNetwork
29-
from qiskit.aqua.components.initial_states import Custom
3027

3128
# pylint: disable=invalid-name
3229

@@ -72,61 +69,32 @@ def __init__(self,
7269
self._bounds = bounds
7370
self._num_qubits = num_qubits
7471
self.generator_circuit = generator_circuit
75-
if self.generator_circuit is None:
76-
entangler_map = []
77-
if np.sum(num_qubits) > 2:
78-
for i in range(int(np.sum(num_qubits))):
79-
entangler_map.append([i, int(np.mod(i + 1, np.sum(num_qubits)))])
80-
else:
81-
if np.sum(num_qubits) > 1:
82-
entangler_map.append([0, 1])
83-
84-
if len(num_qubits) > 1:
85-
num_qubits = list(map(int, num_qubits))
86-
low = bounds[:, 0].tolist()
87-
high = bounds[:, 1].tolist()
88-
init_dist = MultivariateUniformDistribution(num_qubits, low=low, high=high)
89-
q = QuantumRegister(sum(num_qubits))
90-
qc = QuantumCircuit(q)
91-
init_dist.build(qc, q)
92-
init_distribution = Custom(num_qubits=sum(num_qubits), circuit=qc)
93-
# Set variational form
94-
var_form = TwoLocal(sum(num_qubits), 'ry', 'cz', reps=1,
95-
initial_state=init_distribution,
96-
entanglement=entangler_map)
97-
if init_params is None:
98-
init_params = aqua_globals.random.random(var_form.num_parameters) * 2 * 1e-2
99-
# Set generator circuit
100-
self.generator_circuit = MultivariateVariationalDistribution(num_qubits, var_form,
101-
init_params,
102-
low=low, high=high)
103-
else:
104-
init_dist = UniformDistribution(sum(num_qubits), low=bounds[0], high=bounds[1])
105-
q = QuantumRegister(sum(num_qubits), name='q')
106-
qc = QuantumCircuit(q)
107-
init_dist.build(qc, q)
108-
init_distribution = Custom(num_qubits=sum(num_qubits), circuit=qc)
109-
var_form = TwoLocal(sum(num_qubits), 'ry', 'cz', reps=1,
110-
initial_state=init_distribution,
111-
entanglement=entangler_map)
112-
if init_params is None:
113-
init_params = aqua_globals.random.random(var_form.num_parameters) * 2 * 1e-2
114-
# Set generator circuit
115-
self.generator_circuit = UnivariateVariationalDistribution(
116-
int(np.sum(num_qubits)), var_form, init_params, low=bounds[0], high=bounds[1])
117-
118-
if len(num_qubits) > 1:
119-
if isinstance(self.generator_circuit, MultivariateVariationalDistribution):
120-
pass
121-
else:
122-
raise AquaError('Set multivariate variational distribution '
123-
'to represent multivariate data')
72+
if generator_circuit is None:
73+
circuit = QuantumCircuit(sum(num_qubits))
74+
circuit.h(circuit.qubits)
75+
var_form = TwoLocal(sum(num_qubits), 'ry', 'cz', reps=1, entanglement='circular')
76+
circuit.compose(var_form, inplace=True)
77+
78+
# Set generator circuit
79+
self.generator_circuit = circuit
80+
81+
if isinstance(generator_circuit, (UnivariateVariationalDistribution,
82+
MultivariateVariationalDistribution)):
83+
warnings.warn('Passing a UnivariateVariationalDistribution or MultivariateVariational'
84+
'Distribution is as ``generator_circuit`` is deprecated as of Aqua 0.8.0 '
85+
'and the support will be removed no earlier than 3 months after the '
86+
'release data. You should pass as QuantumCircuit instead.',
87+
DeprecationWarning, stacklevel=2)
88+
self._free_parameters = generator_circuit._var_form_params
89+
self.generator_circuit = generator_circuit._var_form
12490
else:
125-
if isinstance(self.generator_circuit, UnivariateVariationalDistribution):
126-
pass
127-
else:
128-
raise AquaError('Set univariate variational distribution '
129-
'to represent univariate data')
91+
self._free_parameters = list(self.generator_circuit.parameters)
92+
93+
if init_params is None:
94+
init_params = aqua_globals.random.random(self.generator_circuit.num_parameters) * 2e-2
95+
96+
self._bound_parameters = init_params
97+
13098
# Set optimizer for updating the generator network
13199
self._optimizer = ADAM(maxiter=1, tol=1e-6, lr=1e-3, beta_1=0.7,
132100
beta_2=0.99, noise_factor=1e-6,
@@ -192,26 +160,28 @@ def construct_circuit(self, params=None):
192160
Construct generator circuit.
193161
194162
Args:
195-
params (numpy.ndarray): parameters which should be used to run the generator,
196-
if None use self._params
163+
params (list | dict): parameters which should be used to run the generator.
197164
198165
Returns:
199166
Instruction: construct the quantum circuit and return as gate
200167
"""
201-
202-
q = QuantumRegister(sum(self._num_qubits), name='q')
203-
qc = QuantumCircuit(q)
204168
if params is None:
205-
self.generator_circuit.build(qc=qc, q=q)
206-
else:
207-
generator_circuit_copy = deepcopy(self.generator_circuit)
208-
generator_circuit_copy.params = params
209-
generator_circuit_copy.build(qc=qc, q=q)
169+
return self.generator_circuit
210170

211-
# return qc.copy(name='qc')
212-
return qc.to_instruction()
171+
if isinstance(params, (list, np.ndarray)):
172+
params = dict(zip(self._free_parameters, params))
213173

214-
def get_output(self, quantum_instance, qc_state_in=None, params=None, shots=None):
174+
return self.generator_circuit.assign_parameters(params)
175+
# self.generator_circuit.build(qc=qc, q=q)
176+
# else:
177+
# generator_circuit_copy = deepcopy(self.generator_circuit)
178+
# generator_circuit_copy.params = params
179+
# generator_circuit_copy.build(qc=qc, q=q)
180+
181+
# # return qc.copy(name='qc')
182+
# return qc.to_instruction()
183+
184+
def get_output(self, quantum_instance, params=None, shots=None):
215185
"""
216186
Get classical data samples from the generator.
217187
Running the quantum generator circuit results in a quantum state.
@@ -222,7 +192,6 @@ def get_output(self, quantum_instance, qc_state_in=None, params=None, shots=None
222192
Args:
223193
quantum_instance (QuantumInstance): Quantum Instance, used to run the generator
224194
circuit.
225-
qc_state_in (QuantumCircuit): deprecated
226195
params (numpy.ndarray): array or None, parameters which should
227196
be used to run the generator, if None use self._params
228197
shots (int): if not None use a number of shots that is different from the
@@ -234,6 +203,8 @@ def get_output(self, quantum_instance, qc_state_in=None, params=None, shots=None
234203
instance_shots = quantum_instance.run_config.shots
235204
q = QuantumRegister(sum(self._num_qubits), name='q')
236205
qc = QuantumCircuit(q)
206+
if params is None:
207+
params = self._bound_parameters
237208
qc.append(self.construct_circuit(params), q)
238209
if quantum_instance.is_statevector:
239210
pass
@@ -277,7 +248,7 @@ def get_output(self, quantum_instance, qc_state_in=None, params=None, shots=None
277248
temp.append(self._data_grid[int(bin_rep)])
278249
generated_samples.append(temp)
279250

280-
self.generator_circuit._probabilities = generated_samples_weights
251+
# self.generator_circuit._probabilities = generated_samples_weights
281252
if shots is not None:
282253
# Restore the initial quantum_instance configuration
283254
quantum_instance.set_config(shots=instance_shots)
@@ -347,13 +318,13 @@ def train(self, quantum_instance=None, shots=None):
347318
self._optimizer._maxiter = 1
348319
self._optimizer._t = 0
349320
objective = self._get_objective_function(quantum_instance, self._discriminator)
350-
self.generator_circuit.params, loss, _ = self._optimizer.optimize(
351-
num_vars=len(self.generator_circuit.params),
321+
self._bound_parameters, loss, _ = self._optimizer.optimize(
322+
num_vars=len(self._bound_parameters),
352323
objective_function=objective,
353-
initial_point=self.generator_circuit.params
324+
initial_point=self._bound_parameters
354325
)
355326

356327
self._ret['loss'] = loss
357-
self._ret['params'] = self.generator_circuit.params
328+
self._ret['params'] = self._bound_parameters
358329

359330
return self._ret
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
---
2+
features:
3+
- |
4+
Support passing ``QuantumCircuit`` objects as generator circuits into
5+
the ``QuantumGenerator``.
6+
deprecations:
7+
- |
8+
Deprecate the ``UnivariateVariationalDistribution`` and
9+
``MultivariateVariationalDistribution`` as input
10+
to the ``QuantumGenerator``. Instead, plain ``QuantumCircuit`` objects can
11+
be used.

test/aqua/test_qgan.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,23 @@
1212

1313
"""Test the QGAN algorithm."""
1414

15+
import unittest
16+
import warnings
1517
from test.aqua import QiskitAquaTestCase
18+
from ddt import ddt, data
1619

17-
import unittest
1820
from qiskit import QuantumCircuit, QuantumRegister
1921
from qiskit.circuit.library import RealAmplitudes
2022
from qiskit.aqua.components.uncertainty_models import (UniformDistribution,
2123
UnivariateVariationalDistribution)
2224
from qiskit.aqua.algorithms import QGAN
23-
from qiskit.aqua import aqua_globals, QuantumInstance
25+
from qiskit.aqua import aqua_globals, QuantumInstance, MissingOptionalLibraryError
2426
from qiskit.aqua.components.initial_states import Custom
2527
from qiskit.aqua.components.neural_networks import NumPyDiscriminator, PyTorchDiscriminator
2628
from qiskit import BasicAer
2729

2830

31+
@ddt
2932
class TestQGAN(QiskitAquaTestCase):
3033
"""Test the QGAN algorithm."""
3134

@@ -83,14 +86,24 @@ def setUp(self):
8386
# Set variational form
8487
var_form = RealAmplitudes(sum(num_qubits), reps=1, initial_state=init_distribution,
8588
entanglement=entangler_map)
86-
self.generator_circuit = UnivariateVariationalDistribution(sum(num_qubits), var_form,
89+
self.generator_circuit = var_form
90+
warnings.filterwarnings('ignore', category=DeprecationWarning)
91+
self.generator_factory = UnivariateVariationalDistribution(sum(num_qubits), var_form,
8792
init_params,
8893
low=self._bounds[0],
8994
high=self._bounds[1])
95+
warnings.filterwarnings('always', category=DeprecationWarning)
9096

91-
def test_sample_generation(self):
97+
@data('circuit', 'factory')
98+
def test_sample_generation(self, circuit_type):
9299
"""Test sample generation."""
93-
self.qgan.set_generator(generator_circuit=self.generator_circuit)
100+
if circuit_type == 'factory':
101+
warnings.filterwarnings('ignore', category=DeprecationWarning)
102+
self.qgan.set_generator(generator_circuit=self.generator_factory)
103+
warnings.filterwarnings('always', category=DeprecationWarning)
104+
else:
105+
self.qgan.set_generator(generator_circuit=self.generator_circuit)
106+
94107
_, weights_statevector = self.qgan._generator.get_output(self.qi_statevector, shots=100)
95108
samples_qasm, weights_qasm = self.qgan._generator.get_output(self.qi_qasm, shots=100)
96109
samples_qasm, weights_qasm = zip(*sorted(zip(samples_qasm, weights_qasm)))
@@ -99,7 +112,10 @@ def test_sample_generation(self):
99112

100113
def test_qgan_training(self):
101114
"""Test QGAN training."""
115+
warnings.filterwarnings('ignore', category=DeprecationWarning)
102116
self.qgan.set_generator(generator_circuit=self.generator_circuit)
117+
warnings.filterwarnings('always', category=DeprecationWarning)
118+
103119
trained_statevector = self.qgan.run(self.qi_statevector)
104120
trained_qasm = self.qgan.run(self.qi_qasm)
105121
self.assertAlmostEqual(trained_qasm['rel_entr'], trained_statevector['rel_entr'], delta=0.1)
@@ -131,8 +147,8 @@ def test_qgan_training_run_algo_torch(self):
131147
seed_transpiler=aqua_globals.random_seed))
132148
self.assertAlmostEqual(trained_qasm['rel_entr'],
133149
trained_statevector['rel_entr'], delta=0.1)
134-
except Exception as ex: # pylint: disable=broad-except
135-
self.skipTest(str(ex))
150+
except MissingOptionalLibraryError:
151+
self.skipTest('pytorch not installed, skipping test')
136152

137153
def test_qgan_training_run_algo_numpy(self):
138154
"""Test QGAN training using a NumPy discriminator."""

0 commit comments

Comments
 (0)