Skip to content

Commit 466556d

Browse files
SummedOp updates & optimization converters to use Opflow (qiskit-community/qiskit-aqua#1059)
* simplify and reduce, add equals to SummedOp, update tests * directly use new opflow, no need to go via WPO * update comments and docstrings * directly use opflow * don't do equality check in add * directly use opflow * change order in reduce * fix qaoa * add short test on summed op equality * rm prints * use set comparison, rename simplify to collapse_summands * fix expected value, should be sqrt(2), not 2 * cast coeffs to complex * add reno on equals * fix mypy * fix spell * fix lint * dont cast coefficient to complex leads to problems if the coeff is exponentitated and not supposed to be complex * use sum instead of reduce * rm unused import * move __hash__ to primitive op and base on repr * use != over not == * add summed op test for different primitives * check for opbase, not summedop * adress changes from review * explicitly raise an error upon ListOp input * return identity op instead of the int 0 * fix spell * add note that equals is not mathematically sound Co-authored-by: Manoel Marques <[email protected]>
1 parent 2ca8e70 commit 466556d

File tree

1 file changed

+36
-16
lines changed

1 file changed

+36
-16
lines changed

test/aqua/operators/test_op_construction.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,26 @@
1414

1515
""" Test Operator construction, including OpPrimitives and singletons. """
1616

17+
1718
import unittest
1819
from test.aqua import QiskitAquaTestCase
1920
import itertools
2021
import numpy as np
22+
from ddt import ddt, data
2123

2224
from qiskit.circuit import QuantumCircuit, QuantumRegister, Instruction
2325
from qiskit.extensions.exceptions import ExtensionError
2426
from qiskit.quantum_info.operators import Operator, Pauli
25-
from qiskit.circuit.library import CZGate
27+
from qiskit.circuit.library import CZGate, ZGate
2628

2729
from qiskit.aqua.operators import (
28-
X, Y, Z, I, CX, T, H, PrimitiveOp, SummedOp, PauliOp, Minus, CircuitOp
30+
X, Y, Z, I, CX, T, H, PrimitiveOp, SummedOp, PauliOp, Minus, CircuitOp, MatrixOp
2931
)
3032

3133

3234
# pylint: disable=invalid-name
3335

36+
@ddt
3437
class TestOpConstruction(QiskitAquaTestCase):
3538
"""Operator Construction tests."""
3639

@@ -235,7 +238,7 @@ def test_circuit_permute(self):
235238
c_op_id = c_op_perm.permute(perm)
236239
self.assertEqual(c_op, c_op_id)
237240

238-
def test_summed_op(self):
241+
def test_summed_op_reduce(self):
239242
"""Test SummedOp"""
240243
sum_op = (X ^ X * 2) + (Y ^ Y) # type: SummedOp
241244
with self.subTest('SummedOp test 1'):
@@ -250,7 +253,7 @@ def test_summed_op(self):
250253
self.assertListEqual([str(op.primitive) for op in sum_op], ['XX', 'YY', 'YY'])
251254
self.assertListEqual([op.coeff for op in sum_op], [2, 1, 1])
252255

253-
sum_op = sum_op.simplify()
256+
sum_op = sum_op.collapse_summands()
254257
with self.subTest('SummedOp test 2-b'):
255258
self.assertEqual(sum_op.coeff, 1)
256259
self.assertListEqual([str(op.primitive) for op in sum_op], ['XX', 'YY'])
@@ -263,7 +266,7 @@ def test_summed_op(self):
263266
self.assertListEqual([str(op.primitive) for op in sum_op], ['XX', 'YY', 'YY', 'XX'])
264267
self.assertListEqual([op.coeff for op in sum_op], [2, 1, 1, 2])
265268

266-
sum_op = sum_op.simplify()
269+
sum_op = sum_op.reduce()
267270
with self.subTest('SummedOp test 3-b'):
268271
self.assertEqual(sum_op.coeff, 1)
269272
self.assertListEqual([str(op.primitive) for op in sum_op], ['XX', 'YY'])
@@ -275,7 +278,7 @@ def test_summed_op(self):
275278
self.assertListEqual([str(op.primitive) for op in sum_op], ['XX', 'YY'])
276279
self.assertListEqual([op.coeff for op in sum_op], [2, 1])
277280

278-
sum_op = sum_op.simplify()
281+
sum_op = sum_op.collapse_summands()
279282
with self.subTest('SummedOp test 4-b'):
280283
self.assertEqual(sum_op.coeff, 1)
281284
self.assertListEqual([str(op.primitive) for op in sum_op], ['XX', 'YY'])
@@ -288,7 +291,7 @@ def test_summed_op(self):
288291
self.assertListEqual([str(op.primitive) for op in sum_op], ['XX', 'YY', 'YY'])
289292
self.assertListEqual([op.coeff for op in sum_op], [4, 2, 1])
290293

291-
sum_op = sum_op.simplify()
294+
sum_op = sum_op.collapse_summands()
292295
with self.subTest('SummedOp test 5-b'):
293296
self.assertEqual(sum_op.coeff, 1)
294297
self.assertListEqual([str(op.primitive) for op in sum_op], ['XX', 'YY'])
@@ -301,7 +304,7 @@ def test_summed_op(self):
301304
self.assertListEqual([str(op.primitive) for op in sum_op], ['XX', 'YY', 'XX', 'YY'])
302305
self.assertListEqual([op.coeff for op in sum_op], [4, 2, 2, 1])
303306

304-
sum_op = sum_op.simplify()
307+
sum_op = sum_op.collapse_summands()
305308
with self.subTest('SummedOp test 6-b'):
306309
self.assertEqual(sum_op.coeff, 1)
307310
self.assertListEqual([str(op.primitive) for op in sum_op], ['XX', 'YY'])
@@ -310,11 +313,11 @@ def test_summed_op(self):
310313
sum_op = SummedOp([X ^ X * 2, Y ^ Y], 2)
311314
sum_op += sum_op
312315
with self.subTest('SummedOp test 7-a'):
313-
self.assertEqual(sum_op.coeff, 4)
314-
self.assertListEqual([str(op.primitive) for op in sum_op], ['XX', 'YY'])
315-
self.assertListEqual([op.coeff for op in sum_op], [2, 1])
316+
self.assertEqual(sum_op.coeff, 1)
317+
self.assertListEqual([str(op.primitive) for op in sum_op], ['XX', 'YY', 'XX', 'YY'])
318+
self.assertListEqual([op.coeff for op in sum_op], [4, 2, 4, 2])
316319

317-
sum_op = sum_op.simplify()
320+
sum_op = sum_op.collapse_summands()
318321
with self.subTest('SummedOp test 7-b'):
319322
self.assertEqual(sum_op.coeff, 1)
320323
self.assertListEqual([str(op.primitive) for op in sum_op], ['XX', 'YY'])
@@ -326,12 +329,28 @@ def test_summed_op(self):
326329
self.assertListEqual([str(op.primitive) for op in sum_op], ['XX', 'YY', 'XX', 'ZZ'])
327330
self.assertListEqual([op.coeff for op in sum_op], [4, 2, 6, 3])
328331

329-
sum_op = sum_op.simplify()
332+
sum_op = sum_op.collapse_summands()
330333
with self.subTest('SummedOp test 8-b'):
331334
self.assertEqual(sum_op.coeff, 1)
332335
self.assertListEqual([str(op.primitive) for op in sum_op], ['XX', 'YY', 'ZZ'])
333336
self.assertListEqual([op.coeff for op in sum_op], [10, 2, 3])
334337

338+
def test_summed_op_equals(self):
339+
"""Test corner cases of SummedOp's equals function."""
340+
with self.subTest('multiplicative factor'):
341+
self.assertEqual(2 * X, X + X)
342+
343+
with self.subTest('commutative'):
344+
self.assertEqual(X + Z, Z + X)
345+
346+
with self.subTest('circuit and paulis'):
347+
z = CircuitOp(ZGate())
348+
self.assertEqual(Z + z, z + Z)
349+
350+
with self.subTest('matrix op and paulis'):
351+
z = MatrixOp([[1, 0], [0, -1]])
352+
self.assertEqual(Z + z, z + Z)
353+
335354
def test_circuit_compose_register_independent(self):
336355
"""Test that CircuitOp uses combines circuits independent of the register.
337356
@@ -344,13 +363,14 @@ def test_circuit_compose_register_independent(self):
344363

345364
self.assertEqual(composed.num_qubits, 2)
346365

347-
def test_pauli_op_hashing(self):
366+
@data(Z, CircuitOp(ZGate()), MatrixOp([[1, 0], [0, -1]]))
367+
def test_op_hashing(self, op):
348368
"""Regression test against faulty set comparison.
349369
350370
Set comparisons rely on a hash table which requires identical objects to have identical
351-
hashes. Thus, the PauliOp.__hash__ should support this requirement.
371+
hashes. Thus, the PrimitiveOp.__hash__ should support this requirement.
352372
"""
353-
self.assertEqual(set([2*Z]), set([2*Z]))
373+
self.assertEqual(set([2 * op]), set([2 * op]))
354374

355375

356376
if __name__ == '__main__':

0 commit comments

Comments
 (0)