Skip to content

Commit b732f97

Browse files
authored
Merge branch 'master' into sampling_speedup
2 parents 0e6ea9d + 6e6caf3 commit b732f97

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

thewalrus/symplectic.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,18 @@ def expand(S, modes, N):
8585
array: the resulting :math:`2N\times 2N` Symplectic matrix
8686
"""
8787
M = S.shape[0] // 2
88-
S2 = np.identity(2 * N, dtype=S.dtype)
88+
S2 = (
89+
np.identity(2 * N, dtype=S.dtype)
90+
if not issparse(S)
91+
else sparse_identity(2 * N, dtype=S.dtype, format="csr")
92+
)
8993

90-
if issparse(S):
94+
if issparse(S) and isinstance(S, (coo_array, dia_array, bsr_array)):
9195
# cast to sparse matrix that supports slicing and indexing
92-
S2 = sparse_identity(2 * N, dtype=S.dtype, format="csr")
93-
if isinstance(S, (coo_array, dia_array, bsr_array)):
94-
warnings.warn(
95-
"Unsupported sparse matrix type, returning a Compressed Sparse Row (CSR) matrix."
96-
)
97-
S = csr_array(S)
96+
warnings.warn(
97+
"Unsupported sparse matrix type, returning a Compressed Sparse Row (CSR) matrix."
98+
)
99+
S = csr_array(S)
98100

99101
w = np.array([modes]) if isinstance(modes, int) else np.array(modes)
100102

0 commit comments

Comments
 (0)