From 3101a0de08e503c60fe1d99d744c598ceaaf4265 Mon Sep 17 00:00:00 2001 From: Anthony Date: Wed, 18 Jun 2025 11:58:36 -0400 Subject: [PATCH 01/17] initial commit --- mrmustard/lab/circuits.py | 2 +- mrmustard/lab/states/dm.py | 2 +- mrmustard/lab/utils.py | 2 +- mrmustard/math/backend_jax.py | 110 +-- mrmustard/math/backend_manager.py | 380 +-------- mrmustard/math/backend_numpy.py | 135 +--- mrmustard/math/backend_tensorflow.py | 78 +- mrmustard/physics/bargmann_utils.py | 23 - mrmustard/physics/fock_utils.py | 127 +-- mrmustard/physics/gaussian.py | 44 -- mrmustard/physics/gaussian_integrals.py | 32 - .../test_b_to_ps.py | 5 +- tests/test_math/test_backend_manager.py | 129 ---- .../test_lattice/test_lattice_functions.py | 16 +- tests/test_physics/test_bargmann_utils.py | 18 - tests/test_physics/test_fock_utils.py | 26 - tests/test_physics/test_gaussian_integrals.py | 23 - tests/test_training/test_callbacks.py | 68 +- tests/test_training/test_opt_lab.py | 722 +++++++++--------- 19 files changed, 444 insertions(+), 1498 deletions(-) diff --git a/mrmustard/lab/circuits.py b/mrmustard/lab/circuits.py index 3f282ba1d..d828cd0db 100644 --- a/mrmustard/lab/circuits.py +++ b/mrmustard/lab/circuits.py @@ -343,7 +343,7 @@ def component_to_str(comp: CircuitComponent) -> list[str]: param = comp.parameters.constants.get(name) or comp.parameters.variables.get( name ) - new_values = math.atleast_1d(param.value) + new_values = math.atleast_nd(param.value, 1) if len(new_values) == 1 and cc_name not in control_gates: new_values = math.tile(new_values, (len(comp.modes),)) values.append(math.asnumpy(new_values)) diff --git a/mrmustard/lab/states/dm.py b/mrmustard/lab/states/dm.py index bc890eb5f..dc3b46275 100644 --- a/mrmustard/lab/states/dm.py +++ b/mrmustard/lab/states/dm.py @@ -704,7 +704,7 @@ def physical_stellar_decomposition_mixed( # pylint: disable=too-many-statements R_c_transpose = math.einsum("...ij->...ji", R_c) Aphi_out = Am - gamma = np.linalg.pinv(R_c) @ R + gamma = math.pinv(R_c) @ R gamma_transpose = math.einsum("...ij->...ji", gamma) Aphi_in = gamma @ math.inv(Aphi_out - math.Xmat(M)) @ gamma_transpose + math.Xmat(M) diff --git a/mrmustard/lab/utils.py b/mrmustard/lab/utils.py index 605da3d04..a07dce37d 100644 --- a/mrmustard/lab/utils.py +++ b/mrmustard/lab/utils.py @@ -66,7 +66,7 @@ def reshape_params(n_modes: int, **kwargs) -> Generator: names = list(kwargs.keys()) vars = list(kwargs.values()) - vars = [math.atleast_1d(var) for var in vars] + vars = [math.atleast_nd(var, 1) for var in vars] for i, var in enumerate(vars): if len(var) == 1: diff --git a/mrmustard/math/backend_jax.py b/mrmustard/math/backend_jax.py index cb80d2d13..245603b25 100644 --- a/mrmustard/math/backend_jax.py +++ b/mrmustard/math/backend_jax.py @@ -75,9 +75,6 @@ def abs(self, array: jnp.ndarray) -> jnp.ndarray: def all(self, array: jnp.ndarray) -> jnp.ndarray: return jnp.all(array) - def angle(self, array: jnp.ndarray) -> jnp.ndarray: - return jnp.angle(array) - @jax.jit def any(self, array: jnp.ndarray) -> jnp.ndarray: return jnp.any(array) @@ -115,25 +112,6 @@ def log(self, array: jnp.ndarray) -> jnp.ndarray: def atleast_nd(self, array: jnp.ndarray, n: int, dtype=None) -> jnp.ndarray: return jnp.array(array, ndmin=n, dtype=dtype) - @jax.jit - def block_diag(self, mat1: jnp.ndarray, mat2: jnp.ndarray) -> jnp.ndarray: - Za = self.zeros((mat1.shape[-2], mat2.shape[-1]), dtype=mat1.dtype) - Zb = self.zeros((mat2.shape[-2], mat1.shape[-1]), dtype=mat1.dtype) - return self.concat( - [self.concat([mat1, Za], axis=-1), 
self.concat([Zb, mat2], axis=-1)], - axis=-2, - ) - - def constraint_func(self, bounds: tuple[float | None, float | None]) -> Callable: - lower = -jnp.inf if bounds[0] is None else bounds[0] - upper = jnp.inf if bounds[1] is None else bounds[1] - - @jax.jit - def constraint(x): - return jnp.clip(x, lower, upper) - - return constraint - @partial(jax.jit, static_argnames=["dtype"]) def cast(self, array: jnp.ndarray, dtype=None) -> jnp.ndarray: if dtype is None: @@ -159,18 +137,6 @@ def allclose(self, array1: jnp.ndarray, array2: jnp.ndarray, atol=1e-9, rtol=1e- def clip(self, array: jnp.ndarray, a_min: float, a_max: float) -> jnp.ndarray: return jnp.clip(array, a_min, a_max) - @jax.jit - def maximum(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: - return jnp.maximum(a, b) - - @jax.jit - def minimum(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: - return jnp.minimum(a, b) - - @jax.jit - def lgamma(self, array: jnp.ndarray) -> jnp.ndarray: - return jax.lax.lgamma(array) - @jax.jit def conj(self, array: jnp.ndarray) -> jnp.ndarray: return jnp.conj(array) @@ -178,26 +144,6 @@ def conj(self, array: jnp.ndarray) -> jnp.ndarray: def pow(self, x: jnp.ndarray, y: float) -> jnp.ndarray: return jnp.power(x, y) - def Categorical(self, probs: jnp.ndarray, name: str): # pylint: disable=unused-argument - class Generator: - def __init__(self, probs): - self._probs = probs - - def sample(self): - key = jax.random.PRNGKey(0) - idx = jnp.arange(len(self._probs)) - return jax.random.choice(key, idx, p=self._probs / jnp.sum(self._probs)) - - return Generator(probs) - - @partial(jax.jit, static_argnames=["k"]) - def set_diag(self, array: jnp.ndarray, diag: jnp.ndarray, k: int) -> jnp.ndarray: - i = jnp.arange(0, array.shape[-2] - abs(k)) - j = jnp.arange(abs(k), array.shape[-1]) - i = jnp.where(k < 0, i - array.shape[-2] + abs(k), i) - j = jnp.where(k < 0, j - abs(k), j) - return array.at[..., i, j].set(diag) - def new_variable( self, value: jnp.ndarray, @@ -210,7 +156,7 @@ def new_variable( @jax.jit def outer(self, array1: jnp.ndarray, array2: jnp.ndarray) -> jnp.ndarray: - return jnp.tensordot(array1, array2, [[], []]) + return self.tensordot(array1, array2, [[], []]) @partial(jax.jit, static_argnames=["name", "dtype"]) def new_constant(self, value, name: str, dtype=None): # pylint: disable=unused-argument @@ -218,17 +164,6 @@ def new_constant(self, value, name: str, dtype=None): # pylint: disable=unused- value = self.astensor(value, dtype) return value - @partial(jax.jit, static_argnames=["data_format", "padding"]) - def convolution( - self, - array: jnp.ndarray, - filters: jnp.ndarray, - padding: str | None = None, - data_format="NWC", # pylint: disable=unused-argument - ) -> jnp.ndarray: - padding = padding or "VALID" - return jax.lax.conv(array, filters, (1, 1), padding) - def tile(self, array: jnp.ndarray, repeats: Sequence[int]) -> jnp.ndarray: return jnp.tile(array, repeats) @@ -311,10 +246,6 @@ def eye_like(self, array: jnp.ndarray) -> jnp.ndarray: def from_backend(self, value) -> bool: return isinstance(value, jnp.ndarray) - @partial(jax.jit, static_argnames=["repeats", "axis"]) - def repeat(self, array: jnp.ndarray, repeats: int, axis: int = None) -> jnp.ndarray: - return jnp.repeat(array, repeats, axis=axis) - @partial(jax.jit, static_argnames=["axis"]) def gather(self, array: jnp.ndarray, indices: jnp.ndarray, axis: int = 0) -> jnp.ndarray: return jnp.take(array, indices, axis=axis) @@ -339,12 +270,6 @@ def matmul(self, *matrices: jnp.ndarray) -> jnp.ndarray: mat = 
jnp.linalg.multi_dot(matrices) return mat - @partial(jax.jit, static_argnames=["old", "new"]) - def moveaxis( - self, array: jnp.ndarray, old: int | Sequence[int], new: int | Sequence[int] - ) -> jnp.ndarray: - return jnp.moveaxis(array, old, new) - def ones(self, shape: Sequence[int], dtype=None) -> jnp.ndarray: dtype = dtype or self.float64 return jnp.ones(shape, dtype=dtype) @@ -385,10 +310,6 @@ def real(self, array: jnp.ndarray) -> jnp.ndarray: def reshape(self, array: jnp.ndarray, shape: Sequence[int]) -> jnp.ndarray: return jnp.reshape(array, shape) - @partial(jax.jit, static_argnames=["decimals"]) - def round(self, array: jnp.ndarray, decimals: int = 0) -> jnp.ndarray: - return jnp.round(array, decimals) - @jax.jit def sin(self, array: jnp.ndarray) -> jnp.ndarray: return jnp.sin(array) @@ -416,9 +337,6 @@ def stack(self, arrays: jnp.ndarray, axis: int = 0) -> jnp.ndarray: def kron(self, tensor1: jnp.ndarray, tensor2: jnp.ndarray): return jnp.kron(tensor1, tensor2) - def boolean_mask(self, tensor: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray: - return tensor[mask] - @partial(jax.jit, static_argnames=["axes"]) def sum(self, array: jnp.ndarray, axes: Sequence[int] = None): return jnp.sum(array, axis=axes) @@ -430,24 +348,6 @@ def norm(self, array: jnp.ndarray) -> jnp.ndarray: def map_fn(self, func, elements): return jax.vmap(func)(elements) - def MultivariateNormalTriL(self, loc: jnp.ndarray, scale_tril: jnp.ndarray, key: int = 0): - class Generator: - def __init__(self, mean, cov, key): - self._mean = mean - self._cov = cov - self._rng = jax.random.PRNGKey(key) - - def sample(self, dtype=None): # pylint: disable=unused-argument - fn = jax.random.multivariate_normal - ret = fn(self._rng, self._mean, self._cov) - return ret - - def prob(self, x): - return jsp.stats.multivariate_normal.pdf(x, mean=self._mean, cov=self._cov) - - scale_tril = scale_tril @ jnp.transpose(scale_tril) - return Generator(loc, scale_tril, key) - def tensordot(self, a: jnp.ndarray, b: jnp.ndarray, axes: Sequence[int]) -> jnp.ndarray: return jnp.tensordot(a, b, axes) @@ -466,14 +366,6 @@ def zeros(self, shape: Sequence[int], dtype=None) -> jnp.ndarray: def zeros_like(self, array: jnp.ndarray, dtype: str = "complex128") -> jnp.ndarray: return jnp.zeros_like(array, dtype=dtype) - @partial(jax.jit, static_argnames=["axis"]) - def squeeze(self, tensor: jnp.ndarray, axis=None): - return jnp.squeeze(tensor, axis=axis) - - @jax.jit - def cholesky(self, input: jnp.ndarray): - return jnp.linalg.cholesky(input) - @staticmethod @jax.jit def eigh(tensor: jnp.ndarray) -> tuple: diff --git a/mrmustard/math/backend_manager.py b/mrmustard/math/backend_manager.py index 9f3461180..100152269 100644 --- a/mrmustard/math/backend_manager.py +++ b/mrmustard/math/backend_manager.py @@ -23,7 +23,6 @@ from jax.errors import TracerArrayConversionError import numpy as np -from scipy.special import binom from scipy.stats import ortho_group, unitary_group from ..utils.settings import settings @@ -90,7 +89,7 @@ def lazy_import(module_name: str): } -class BackendManager: # pylint: disable=too-many-public-methods, fixme +class BackendManager: r""" A class to manage the different backends supported by Mr Mustard. """ @@ -255,19 +254,6 @@ def allclose(self, array1: Tensor, array2: Tensor, atol=1e-9, rtol=1e-5) -> bool array2 = self.astensor(array2) return self._apply("allclose", (array1, array2, atol, rtol)) - def angle(self, array: Tensor) -> Tensor: - r""" - The complex phase of ``array``. 
- - Args: - array: The array to take the complex phase of. - - Returns: - The complex phase of ``array``. - """ - - return self._apply("angle", (array,)) - def any(self, array: Tensor) -> bool: r"""Returns ``True`` if any element of array is ``True``, ``False`` otherwise. @@ -291,7 +277,6 @@ def arange(self, start: int, limit: int = None, delta: int = 1, dtype: Any = Non Returns: The array of evenly spaced values. """ - # NOTE: is float64 by default return self._apply("arange", (start, limit, delta, dtype)) def asnumpy(self, tensor: Tensor) -> Tensor: @@ -330,45 +315,6 @@ def astensor(self, array: Tensor, dtype=None): """ return self._apply("astensor", (array, dtype)) - def atleast_1d(self, array: Tensor, dtype=None) -> Tensor: - r"""Returns an array with at least one dimension. - - Args: - array: The array to convert. - dtype: The data type of the array. If ``None``, the returned array - is of the same type as the given one. - - Returns: - The array with at least one dimension. - """ - return self._apply("atleast_nd", (array, 1, dtype)) - - def atleast_2d(self, array: Tensor, dtype=None) -> Tensor: - r"""Returns an array with at least two dimensions. - - Args: - array: The array to convert. - dtype: The data type of the array. If ``None``, the returned array - is of the same type as the given one. - - Returns: - The array with at least two dimensions. - """ - return self._apply("atleast_nd", (array, 2, dtype)) - - def atleast_3d(self, array: Tensor, dtype=None) -> Tensor: - r"""Returns an array with at least three dimensions. - - Args: - array: The array to convert. - dtype: The data type of the array. If ``None``, the returned array - is of the same type as the given one. - - Returns: - The array with at least three dimensions. - """ - return self._apply("atleast_nd", (array, 3, dtype)) - def atleast_nd(self, array: Tensor, n: int, dtype=None) -> Tensor: r"""Returns an array with at least n dimensions. Note that dimensions are prepended to meet the minimum number of dimensions. @@ -383,31 +329,6 @@ def atleast_nd(self, array: Tensor, n: int, dtype=None) -> Tensor: """ return self._apply("atleast_nd", (array, n, dtype)) - def block_diag(self, mat1: Matrix, mat2: Matrix) -> Matrix: - r"""Returns a block diagonal matrix from the given matrices. - - Args: - mat1: A matrix. - mat2: A matrix. - - Returns: - A block diagonal matrix from the given matrices. - """ - return self._apply("block_diag", (mat1, mat2)) - - def boolean_mask(self, tensor: Tensor, mask: Tensor) -> Tensor: - """ - Returns a tensor based on the truth value of the boolean mask. - - Args: - tensor: A tensor. - mask: A boolean mask. - - Returns: - A tensor based on the truth value of the boolean mask. - """ - return self._apply("boolean_mask", (tensor, mask)) - def block(self, blocks: list[list[Tensor]], axes=(-2, -1)) -> Tensor: r""" Returns a matrix made from the given blocks. @@ -496,44 +417,6 @@ def conj(self, array: Tensor) -> Tensor: """ return self._apply("conj", (array,)) - def constraint_func(self, bounds: tuple[float | None, float | None]) -> Callable | None: - r"""Returns a constraint function for the given bounds. - - A constraint function will clip the value to the interval given by the bounds. - - .. note:: - - The upper and/or lower bounds can be ``None``, in which case the constraint - function will not clip the value. - - Args: - bounds: The bounds of the constraint. - - Returns: - The constraint function. 
- """ - return self._apply("constraint_func", (bounds)) - - def convolution( - self, - array: Tensor, - filters: Tensor, - padding: str | None = None, - data_format="NWC", - ) -> Tensor: # TODO: remove strides and data_format? - r"""Performs a convolution on array with filters. - - Args: - array: The array to convolve. - filters: The filters to convolve with. - padding: The padding mode. - data_format: The data format of the array. - - Returns: - The convolved array. - """ - return self._apply("convolution", (array, filters, padding, data_format)) - def cos(self, array: Tensor) -> Tensor: r"""The cosine of an array. @@ -914,17 +797,6 @@ def is_trainable(self, tensor: Tensor) -> bool: """ return self._apply("is_trainable", (tensor,)) - def lgamma(self, x: Tensor) -> Tensor: - r"""The natural logarithm of the gamma function of ``x``. - - Args: - x: The array to take the natural logarithm of the gamma function of - - Returns: - The natural logarithm of the gamma function of ``x`` - """ - return self._apply("lgamma", (x,)) - def log(self, x: Tensor) -> Tensor: r"""The natural logarithm of ``x``. @@ -972,64 +844,6 @@ def matvec(self, a: Matrix, b: Vector) -> Tensor: """ return self._apply("matvec", (a, b)) - def maximum(self, a: Tensor, b: Tensor) -> Tensor: - r"""The element-wise maximum of ``a`` and ``b``. - - Args: - a: The first array to take the maximum of - b: The second array to take the maximum of - - Returns: - The element-wise maximum of ``a`` and ``b`` - """ - return self._apply( - "maximum", - ( - a, - b, - ), - ) - - def minimum(self, a: Tensor, b: Tensor) -> Tensor: - r"""The element-wise minimum of ``a`` and ``b``. - - Args: - a: The first array to take the minimum of - b: The second array to take the minimum of - - Returns: - The element-wise minimum of ``a`` and ``b`` - """ - return self._apply( - "minimum", - ( - a, - b, - ), - ) - - def moveaxis(self, array: Tensor, old: Tensor, new: Tensor) -> Tensor: - r""" - Moves the axes of an array to a new position. - - Args: - array: The array to move the axes of. - old: The old index position - new: The new index position - - - Returns: - The updated array - """ - return self._apply( - "moveaxis", - ( - array, - old, - new, - ), - ) - def new_variable( self, value: Tensor, @@ -1193,20 +1007,6 @@ def real(self, array: Tensor) -> Tensor: """ return self._apply("real", (array,)) - def repeat(self, array: Tensor, repeats: int, axis: int = None) -> Tensor: - """ - Repeats elements of a tensor along a specified axis. - - Args: - array: The input tensor. - repeats: The number of repetitions for each element. - axis: The axis along which to repeat values. If None, use the flattened input tensor. - - Returns: - The tensor with repeated elements. - """ - return self._apply("repeat", (array, repeats, axis)) - def reshape(self, array: Tensor, shape: Sequence[int]) -> Tensor: r"""The reshaped array. @@ -1220,31 +1020,6 @@ def reshape(self, array: Tensor, shape: Sequence[int]) -> Tensor: shape = (shape,) if isinstance(shape, int) else tuple(shape) return self._apply("reshape", (array, shape)) - def round(self, array: Tensor, decimals: int) -> Tensor: - r"""The array rounded to the nearest integer. - - Args: - array: The array to round - decimals: number of decimals to round to - - Returns: - The array rounded to the nearest integer - """ - return self._apply("round", (array, decimals)) - - def set_diag(self, array: Tensor, diag: Tensor, k: int) -> Tensor: - r"""The array with the diagonal set to ``diag``. 
- - Args: - array: The array to set the diagonal of - diag: The diagonal to set - k (int): diagonal to set - - Returns: - The array with the diagonal set to ``diag`` - """ - return self._apply("set_diag", (array, diag, k)) - def sin(self, array: Tensor) -> Tensor: r"""The sine of ``array``. @@ -1391,8 +1166,6 @@ def transpose(self, a: Tensor, perm: Sequence[int] = None): Returns: The transposed array """ - if a is None: - return None # TODO: remove and address None inputs where tranpose is used perm = tuple(perm) if perm is not None else None return self._apply("transpose", (a, perm)) @@ -1520,56 +1293,6 @@ def map_fn(self, fn: Callable, elements: Tensor) -> Tensor: """ return self._apply("map_fn", (fn, elements)) - def squeeze(self, tensor: Tensor, axis: list[int] | None = None) -> Tensor: - """Removes dimensions of size 1 from the shape of a tensor. - - Args: - tensor (Tensor): the tensor to squeeze - axis (Optional[List[int]]): if specified, only squeezes the - dimensions listed, defaults to [] - - Returns: - Tensor: tensor with one or more dimensions of size 1 removed - """ - return self._apply("squeeze", (tensor, axis)) - - def cholesky(self, input: Tensor) -> Tensor: - """Computes the Cholesky decomposition of square matrices. - - Args: - input (Tensor) - - Returns: - Tensor: tensor with the same type as input - """ - return self._apply("cholesky", (input,)) - - def Categorical(self, probs: Tensor, name: str): - """Categorical distribution over integers. - - Args: - probs: The unnormalized probabilities of a set of Categorical distributions. - name: The name prefixed to operations created by this class. - - Returns: - tfp.distributions.Categorical: instance of ``tfp.distributions.Categorical`` class - """ - return self._apply("Categorical", (probs, name)) - - def MultivariateNormalTriL(self, loc: Tensor, scale_tril: Tensor): - """Multivariate normal distribution on `R^k` and parameterized by a (batch of) length-k loc - vector (aka "mu") and a (batch of) k x k scale matrix; covariance = scale @ scale.T - where @ denotes matrix-multiplication. - - Args: - loc (Tensor): if this is set to None, loc is implicitly 0 - scale_tril: lower-triangular Tensor with non-zero diagonal elements - - Returns: - tfp.distributions.MultivariateNormalTriL: instance of ``tfp.distributions.MultivariateNormalTriL`` - """ - return self._apply("MultivariateNormalTriL", (loc, scale_tril)) - def DefaultEuclideanOptimizer(self): r"""Default optimizer for the Euclidean parameters.""" return self._apply("DefaultEuclideanOptimizer") @@ -1611,7 +1334,7 @@ def beamsplitter(self, theta: float, phi: float, shape: tuple[int, int, int, int """ return self._apply("beamsplitter", (theta, phi), {"shape": shape, "method": method}) - def squeezed(self, r: float, phi: float, shape: tuple[int, int]): # pragma: no cover + def squeezed(self, r: float, phi: float, shape: tuple[int, int]): r""" Creates a single mode squeezed state matrix using a numba-based fock lattice strategy. @@ -1658,7 +1381,8 @@ def dagger(self, array: Tensor) -> Tensor: return self.conj(self.transpose(array, perm=perm)) def unitary_to_orthogonal(self, U): - r"""Unitary to orthogonal mapping. + r""" + Unitary to orthogonal mapping. Args: U: The unitary matrix in ``U(n)`` @@ -1671,13 +1395,14 @@ def unitary_to_orthogonal(self, U): return self.block([[X, -Y], [Y, X]]) def random_symplectic(self, num_modes: int, max_r: float = 1.0) -> Tensor: - r"""A random symplectic matrix in ``Sp(2*num_modes)``. + r""" + A random symplectic matrix in ``Sp(2*num_modes)``. 
Squeezing is sampled uniformly from 0.0 to ``max_r`` (1.0 by default). """ if num_modes == 1: - W = np.exp(1j * 2 * np.pi * settings.rng.uniform(size=(1, 1))) - V = np.exp(1j * 2 * np.pi * settings.rng.uniform(size=(1, 1))) + W = self.exp(1j * 2 * np.pi * settings.rng.uniform(size=(1, 1))) + V = self.exp(1j * 2 * np.pi * settings.rng.uniform(size=(1, 1))) else: W = unitary_group.rvs(dim=num_modes, random_state=settings.rng) V = unitary_group.rvs(dim=num_modes, random_state=settings.rng) @@ -1689,45 +1414,26 @@ def random_symplectic(self, num_modes: int, max_r: float = 1.0) -> Tensor: @staticmethod def random_orthogonal(N: int) -> Tensor: - """A random orthogonal matrix in :math:`O(N)`.""" + r""" + A random orthogonal matrix in :math:`O(N)`. + """ if N == 1: return np.array([[1.0]]) return ortho_group.rvs(dim=N, random_state=settings.rng) def random_unitary(self, N: int) -> Tensor: - """a random unitary matrix in :math:`U(N)`""" + r""" + A random unitary matrix in :math:`U(N)`. + """ if N == 1: - return self.exp(1j * settings.rng.uniform(size=(1, 1))) + return np.exp(1j * settings.rng.uniform(size=(1, 1))) return unitary_group.rvs(dim=N, random_state=settings.rng) - def single_mode_to_multimode_vec(self, vec, num_modes: int): - r"""Apply the same 2-vector (i.e. single-mode) to a larger number of modes.""" - if vec.shape[-1] != 2: - raise ValueError("vec must be 2-dimensional (i.e. single-mode)") - x, y = vec[..., -2], vec[..., -1] - vec = self.concat( - [ - self.tile(self.astensor([x]), (num_modes,)), - self.tile(self.astensor([y]), (num_modes,)), - ], - axis=-1, - ) - return vec - - def single_mode_to_multimode_mat(self, mat: Tensor, num_modes: int): - r"""Apply the same :math:`2\times 2` matrix (i.e. single-mode) to a larger number of modes.""" - if mat.shape[-2:] != (2, 2): - raise ValueError("mat must be a single-mode (2x2) matrix") - mat = self.diag( - self.tile(self.expand_dims(mat, axis=-1), (1, 1, num_modes)), k=0 - ) # shape [2,2,N,N] - mat = self.reshape(self.transpose(mat, (0, 2, 1, 3)), [2 * num_modes, 2 * num_modes]) - return mat - @staticmethod @lru_cache() def Xmat(num_modes: int): - r"""The matrix :math:`X_n = \begin{bmatrix}0 & I_n\\ I_n & 0\end{bmatrix}.` + r""" + The matrix :math:`X_n = \begin{bmatrix}0 & I_n\\ I_n & 0\end{bmatrix}.` Args: num_modes (int): positive integer @@ -1779,58 +1485,6 @@ def all_diagonals(self, rho: Tensor, real: bool) -> Tensor: return self.reshape(diag, cutoffs) - def poisson(self, max_k: int, rate: Tensor) -> Tensor: - """Poisson distribution up to ``max_k``.""" - k = self.arange(max_k, dtype=rate.dtype) - return self.exp(k * self.log(rate + 1e-9) - rate - self.lgamma(k + 1.0)) - - def binomial_conditional_prob(self, success_prob: Tensor, dim_out: int, dim_in: int): - """:math:`P(out|in) = binom(in, out) * (1-success_prob)**(in-out) * success_prob**out`.""" - in_ = self.arange(dim_in)[None, :] - out_ = self.arange(dim_out)[:, None] - return ( - self.cast(binom(in_, out_), in_.dtype) - * self.pow(success_prob, out_) - * self.pow(1.0 - success_prob, self.maximum(in_ - out_, 0.0)) - ) - - def convolve_probs_1d(self, prob: Tensor, other_probs: list[Tensor]) -> Tensor: - """Convolution of a joint probability with a list of single-index probabilities.""" - - if prob.ndim > 3 or len(other_probs) > 3: - raise ValueError("cannot convolve arrays with more than 3 axes") - if not all((q.ndim == 1 for q in other_probs)): - raise ValueError("other_probs must contain 1d arrays") - if not all((len(q) == s for q, s in zip(other_probs, prob.shape))): - raise 
ValueError("The length of the 1d prob vectors must match shape of prob") - - q = other_probs[0] - for q_ in other_probs[1:]: - q = q[..., None] * q_[(None,) * q.ndim + (slice(None),)] - - return self.convolve_probs(prob, q) - - def convolve_probs(self, prob: Tensor, other: Tensor) -> Tensor: - r"""Convolve two probability distributions (up to 3D) with the same shape. - - Note that the output is not guaranteed to be a complete joint probability, - as it's computed only up to the dimension of the base probs. - """ - if prob.ndim > 3 or other.ndim > 3: - raise ValueError("cannot convolve arrays with more than 3 axes") - if not prob.shape == other.shape: - raise ValueError("prob and other must have the same shape") - - prob_padded = self.pad(prob, [(s - 1, 0) for s in other.shape]) - other_reversed = other[(slice(None, None, -1),) * other.ndim] - return self.convolution( - prob_padded[None, ..., None], - other_reversed[..., None, None], - data_format="N" - + ("HD"[: other.ndim - 1])[::-1] - + "WC", # TODO: rewrite this to be more readable (do we need it?) - )[0, ..., 0] - def euclidean_to_symplectic(self, S: Matrix, dS_euclidean: Matrix) -> Matrix: r"""Convert the Euclidean gradient to a Riemannian gradient on the tangent bundle of the symplectic manifold. diff --git a/mrmustard/math/backend_numpy.py b/mrmustard/math/backend_numpy.py index e24ed22c2..fb3a29a08 100644 --- a/mrmustard/math/backend_numpy.py +++ b/mrmustard/math/backend_numpy.py @@ -18,18 +18,14 @@ from __future__ import annotations -from math import lgamma as mlgamma from typing import Sequence, Callable from opt_einsum import contract import numpy as np -import scipy as sp -from scipy.signal import convolve2d as scipy_convolve2d from scipy.linalg import expm as scipy_expm from scipy.linalg import sqrtm as scipy_sqrtm from scipy.special import xlogy as scipy_xlogy -from scipy.stats import multivariate_normal from ..utils.settings import settings from .backend_base import BackendBase @@ -70,9 +66,6 @@ def all(self, array: np.ndarray) -> bool: def allclose(self, array1: np.array, array2: np.array, atol: float, rtol: float) -> bool: return np.allclose(array1, array2, atol=atol, rtol=rtol) - def angle(self, array: np.ndarray) -> np.ndarray: - return np.angle(array) - def any(self, array: np.ndarray) -> np.ndarray: return np.any(array) @@ -103,12 +96,6 @@ def broadcast_to(self, array: np.ndarray, shape: tuple[int]) -> np.ndarray: def broadcast_arrays(self, *arrays: list[np.ndarray]) -> list[np.ndarray]: return np.broadcast_arrays(*arrays) - def block_diag(self, *blocks: list[np.ndarray]) -> np.ndarray: - return sp.linalg.block_diag(*blocks) - - def boolean_mask(self, tensor: np.ndarray, mask: np.ndarray) -> np.ndarray: - return np.array([t for i, t in enumerate(tensor) if mask[i]]) - def cast(self, array: np.ndarray, dtype=None) -> np.ndarray: if dtype is None: return array @@ -129,59 +116,6 @@ def concat(self, values: list[np.ndarray], axis: int) -> np.ndarray: def conj(self, array: np.ndarray) -> np.ndarray: return np.conj(array) - def convolution( - self, - array: np.ndarray, - filters: np.ndarray, - padding: str = "VALID", - data_format: str | None = None, # pylint: disable=unused-argument - ) -> np.ndarray: - """Performs a 2D convolution operation similar to tf.nn.convolution. 
- - Args: - array: Input array of shape (batch, height, width, channels) - filters: Filter kernel of shape (kernel_height, kernel_width, in_channels, out_channels) - padding: String indicating the padding type ('VALID' or 'SAME') - data_format: Unused, kept for API compatibility - - Returns: - np.ndarray: Result of the convolution operation with shape (batch, new_height, new_width, out_channels) - """ - # Extract shapes - batch, _, _, _ = array.shape - kernel_h, kernel_w, _, out_channels = filters.shape - - # Reshape filter to 2D for convolution - filter_2d = filters[:, :, 0, 0] - - # For SAME padding, calculate padding sizes - if padding == "SAME": - pad_h = (kernel_h - 1) // 2 - pad_w = (kernel_w - 1) // 2 - array = np.pad( - array[:, :, :, 0], ((0, 0), (pad_h, pad_h), (pad_w, pad_w)), mode="constant" - ) - else: - array = array[:, :, :, 0] - - # Calculate output dimensions - out_height = array.shape[1] - kernel_h + 1 - out_width = array.shape[2] - kernel_w + 1 - - # Initialize output array - output = np.zeros((batch, out_height, out_width, out_channels)) - - # Perform convolution for each batch - for b in range(batch): - # Convolve using scipy's convolve2d which is more efficient than np.convolve for 2D - output[b, :, :, 0] = scipy_convolve2d( - array[b], - np.flip(np.flip(filter_2d, 0), 1), # Flip kernel for proper convolution - mode="valid", - ) - - return output - def cos(self, array: np.ndarray) -> np.ndarray: return np.cos(array) @@ -219,19 +153,6 @@ def diag_part(self, array: np.ndarray, k: int) -> np.ndarray: ret.flags.writeable = True return ret - def set_diag(self, array: np.ndarray, diag: np.ndarray, k: int) -> np.ndarray: - i = np.arange(0, array.shape[-2] - abs(k)) - if k < 0: - i -= array.shape[-2] - abs(k) - - j = np.arange(abs(k), array.shape[-1]) - if k < 0: - j -= abs(k) - - array[..., i, j] = diag - - return array - def einsum(self, string: str, *tensors, optimize: bool | str) -> np.ndarray: return contract(string, *tensors, optimize=optimize) @@ -265,9 +186,6 @@ def inv(self, tensor: np.ndarray) -> np.ndarray: def is_trainable(self, tensor: np.ndarray) -> bool: # pylint: disable=unused-argument return False - def lgamma(self, x: np.ndarray) -> np.ndarray: - return np.array([mlgamma(v) for v in x]) - def log(self, x: np.ndarray) -> np.ndarray: return np.log(x) @@ -283,17 +201,6 @@ def matmul(self, *matrices: np.ndarray) -> np.ndarray: def matvec(self, a: np.ndarray, b: np.ndarray) -> np.ndarray: return self.matmul(a, b[..., None])[..., 0] - def maximum(self, a: np.ndarray, b: np.ndarray) -> np.ndarray: - return np.maximum(a, b) - - def minimum(self, a: np.ndarray, b: np.ndarray) -> np.ndarray: - return np.minimum(a, b) - - def moveaxis( - self, array: np.ndarray, old: int | Sequence[int], new: int | Sequence[int] - ) -> np.ndarray: - return np.moveaxis(array, old, new) - def new_variable( self, value, @@ -333,7 +240,7 @@ def error_if( raise ValueError(msg) def outer(self, array1: np.ndarray, array2: np.ndarray) -> np.ndarray: - return np.tensordot(array1, array2, [[], []]) + return self.tensordot(array1, array2, [[], []]) def pad( self, @@ -365,12 +272,6 @@ def real(self, array: np.ndarray) -> np.ndarray: def reshape(self, array: np.ndarray, shape: Sequence[int]) -> np.ndarray: return np.reshape(array, shape) - def repeat(self, array: np.ndarray, repeats: int, axis: int = None) -> np.ndarray: - return np.repeat(array, repeats, axis=axis) - - def round(self, array: np.ndarray, decimals: int = 0) -> np.ndarray: - return np.round(array, decimals) - def sin(self, array: 
np.ndarray) -> np.ndarray: return np.sin(array) @@ -431,40 +332,6 @@ def map_fn(self, func, elements): # Is this done like this? return np.array([func(e) for e in elements]) - def squeeze(self, tensor, axis=None): - return np.squeeze(tensor, axis=axis) - - def cholesky(self, input: np.ndarray): - return np.linalg.cholesky(input) - - def Categorical(self, probs: np.ndarray, name: str): # pylint: disable=unused-argument - class Generator: - def __init__(self, probs): - self._probs = probs - - def sample(self): - idx = [i for i, _ in enumerate(probs)] - return np.random.choice(idx, p=probs / sum(probs)) - - return Generator(probs) - - def MultivariateNormalTriL(self, loc: np.ndarray, scale_tril: np.ndarray): - class Generator: - def __init__(self, mean, cov): - self._mean = mean - self._cov = cov - - def sample(self, dtype=None): # pylint: disable=unused-argument - fn = np.random.default_rng().multivariate_normal - ret = fn(self._mean, self._cov) - return ret - - def prob(self, x): - return multivariate_normal.pdf(x, mean=self._mean, cov=self._cov) - - scale_tril = scale_tril @ np.transpose(scale_tril) - return Generator(loc, scale_tril) - @staticmethod def eigvals(tensor: np.ndarray) -> np.ndarray: return np.linalg.eigvals(tensor) diff --git a/mrmustard/math/backend_tensorflow.py b/mrmustard/math/backend_tensorflow.py index 63440bca5..1f4eca58c 100644 --- a/mrmustard/math/backend_tensorflow.py +++ b/mrmustard/math/backend_tensorflow.py @@ -24,7 +24,6 @@ import numpy as np from semantic_version import Version -import tensorflow_probability as tfp os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import tensorflow as tf @@ -76,9 +75,6 @@ def all(self, array: tf.Tensor) -> tf.Tensor: def allclose(self, array1: np.array, array2: np.array, atol: float, rtol: float) -> bool: return tf.experimental.numpy.allclose(array1, array2, atol=atol, rtol=rtol) - def angle(self, array: tf.Tensor) -> tf.Tensor: - return tf.experimental.numpy.angle(array) - def any(self, array: tf.Tensor) -> tf.Tensor: return tf.math.reduce_any(array) @@ -100,14 +96,6 @@ def astensor(self, array: np.ndarray | tf.Tensor, dtype=None) -> tf.Tensor: def atleast_nd(self, array: tf.Tensor, n: int, dtype=None) -> tf.Tensor: return tf.experimental.numpy.array(array, ndmin=n, dtype=dtype) - def block_diag(self, mat1: tf.Tensor, mat2: tf.Tensor) -> tf.Tensor: - Za = self.zeros((mat1.shape[-2], mat2.shape[-1]), dtype=mat1.dtype) - Zb = self.zeros((mat2.shape[-2], mat1.shape[-1]), dtype=mat1.dtype) - return self.concat( - [self.concat([mat1, Za], axis=-1), self.concat([Zb, mat2], axis=-1)], - axis=-2, - ) - def broadcast_to(self, array: tf.Tensor, shape: tuple[int]) -> tf.Tensor: return tf.broadcast_to(array, shape) @@ -126,9 +114,6 @@ def broadcast_arrays(self, *arrays: list[tf.Tensor]) -> list[tf.Tensor]: # Broadcast each array to the common shape return [tf.broadcast_to(arr, broadcasted_shape) for arr in arrays] - def boolean_mask(self, tensor: tf.Tensor, mask: tf.Tensor) -> Tensor: - return tf.boolean_mask(tensor, mask) - def cast(self, array: tf.Tensor, dtype=None) -> tf.Tensor: if dtype is None: return array @@ -145,7 +130,8 @@ def concat(self, values: Sequence[tf.Tensor], axis: int) -> tf.Tensor: def conj(self, array: tf.Tensor) -> tf.Tensor: return tf.math.conj(array) - def constraint_func(self, bounds: tuple[float | None, float | None]) -> Callable | None: + @staticmethod + def constraint_func(bounds: tuple[float | None, float | None]) -> Callable | None: bounds = ( -np.inf if bounds[0] is None else bounds[0], np.inf if bounds[1] is None 
else bounds[1], @@ -159,18 +145,6 @@ def constraint(x): constraint = None return constraint - # pylint: disable=arguments-differ - @Autocast() - def convolution( - self, - array: tf.Tensor, - filters: tf.Tensor, - padding: str | None = None, - data_format="NWC", - ) -> tf.Tensor: - padding = padding or "VALID" - return tf.nn.convolution(array, filters=filters, padding=padding, data_format=data_format) - def cos(self, array: tf.Tensor) -> tf.Tensor: return tf.math.cos(array) @@ -238,9 +212,6 @@ def inv(self, tensor: tf.Tensor) -> tf.Tensor: def is_trainable(self, tensor: tf.Tensor) -> bool: return isinstance(tensor, tf.Variable) - def lgamma(self, x: tf.Tensor) -> tf.Tensor: - return tf.math.lgamma(x) - def log(self, x: tf.Tensor) -> tf.Tensor: return tf.math.log(x) @@ -258,19 +229,6 @@ def matvec(self, a: tf.Tensor, b: tf.Tensor) -> tf.Tensor: def make_complex(self, real: tf.Tensor, imag: tf.Tensor) -> tf.Tensor: return tf.complex(real, imag) - @Autocast() - def maximum(self, a: tf.Tensor, b: tf.Tensor) -> tf.Tensor: - return tf.maximum(a, b) - - @Autocast() - def minimum(self, a: tf.Tensor, b: tf.Tensor) -> tf.Tensor: - return tf.minimum(a, b) - - def moveaxis( - self, array: tf.Tensor, old: int | Sequence[int], new: int | Sequence[int] - ) -> tf.Tensor: - return tf.experimental.numpy.moveaxis(array, old, new) - def new_variable( self, value, @@ -304,7 +262,7 @@ def infinity_like(self, array: np.ndarray) -> np.ndarray: @Autocast() def outer(self, array1: tf.Tensor, array2: tf.Tensor) -> tf.Tensor: - return tf.tensordot(array1, array2, [[], []]) + return self.tensordot(array1, array2, [[], []]) def pad( self, @@ -317,7 +275,14 @@ def pad( @staticmethod def pinv(matrix: tf.Tensor) -> tf.Tensor: - return tf.linalg.pinv(matrix) + # need to handle complex case on our own + # https://stackoverflow.com/questions/60025950/tensorflow-pseudo-inverse-doesnt-work-for-complex-matrices + real_matrix = tf.math.real(matrix) + imag_matrix = tf.math.imag(matrix) + r0 = tf.linalg.pinv(real_matrix) @ imag_matrix + y11 = tf.linalg.pinv(imag_matrix @ r0 + real_matrix) + y10 = -r0 @ y11 + return tf.cast(tf.complex(y11, y10), dtype=matrix.dtype) @Autocast() def pow(self, x: tf.Tensor, y: float) -> tf.Tensor: @@ -335,15 +300,6 @@ def real(self, array: tf.Tensor) -> tf.Tensor: def reshape(self, array: tf.Tensor, shape: Sequence[int]) -> tf.Tensor: return tf.reshape(array, shape) - def repeat(self, array: tf.Tensor, repeats: int, axis: int = None) -> tf.Tensor: - return tf.repeat(array, repeats, axis=axis) - - def round(self, array: tf.Tensor, decimals: int = 0) -> tf.Tensor: - return tf.round(10**decimals * array) / 10**decimals - - def set_diag(self, array: tf.Tensor, diag: tf.Tensor, k: int) -> tf.Tensor: - return tf.linalg.set_diag(array, diag, k=k) - def sin(self, array: tf.Tensor) -> tf.Tensor: return tf.math.sin(array) @@ -401,18 +357,6 @@ def zeros_like(self, array: tf.Tensor) -> tf.Tensor: def map_fn(self, func, elements): return tf.map_fn(func, elements) - def squeeze(self, tensor, axis=None): - return tf.squeeze(tensor, axis=axis or []) - - def cholesky(self, input: Tensor): - return tf.linalg.cholesky(input) - - def Categorical(self, probs: Tensor, name: str): - return tfp.distributions.Categorical(probs=probs, name=name) - - def MultivariateNormalTriL(self, loc: Tensor, scale_tril: Tensor): - return tfp.distributions.MultivariateNormalTriL(loc=loc, scale_tril=scale_tril) - @staticmethod def eigh(tensor: tf.Tensor) -> Tensor: return tf.linalg.eigh(tensor) diff --git a/mrmustard/physics/bargmann_utils.py 
b/mrmustard/physics/bargmann_utils.py index 77651523e..3b55178c6 100644 --- a/mrmustard/physics/bargmann_utils.py +++ b/mrmustard/physics/bargmann_utils.py @@ -99,29 +99,6 @@ def wigner_to_bargmann_psi(cov, means): # NOTE: c for th psi is to calculated from the global phase formula. -def norm_ket(A, b, c): - r"""Calculates the l2 norm of a Ket with a representation given by the Bargmann triple A,b,c.""" - M = math.block([[math.conj(A), -math.eye_like(A)], [-math.eye_like(A), A]]) - B = math.concat([math.conj(b), b], 0) - norm_squared = ( - math.abs(c) ** 2 - * math.exp(-0.5 * math.sum(B * math.matvec(math.inv(M), B))) - / math.sqrt((-1) ** A.shape[-1] * math.det(M)) - ) - return math.real(math.sqrt(norm_squared)) - - -def trace_dm(A, b, c): - r"""Calculates the total trace of the density matrix with representation given by the Bargmann triple A,b,c.""" - M = A - math.Xmat(A.shape[-1] // 2) - trace = ( - c - * math.exp(-0.5 * math.sum(b * math.matvec(math.inv(M), b))) - / math.sqrt((-1) ** (A.shape[-1] // 2) * math.det(M)) - ) - return math.real(trace) - - def au2Symplectic(A): r""" helper for finding the Au of a unitary from its symplectic rep. diff --git a/mrmustard/physics/fock_utils.py b/mrmustard/physics/fock_utils.py index a70638aa8..03470fbc5 100644 --- a/mrmustard/physics/fock_utils.py +++ b/mrmustard/physics/fock_utils.py @@ -28,9 +28,7 @@ from tensorflow.python.framework.errors_impl import InvalidArgumentError from mrmustard import math, settings -from mrmustard.math.lattice import strategies from mrmustard.math.caching import tensor_int_cache -from mrmustard.math.jax_vjps import beamsplitter_jax, displacement_jax from mrmustard.utils.typing import Scalar, Tensor, Vector, Batch # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -56,14 +54,14 @@ def fock_state(n: int | Sequence[int], cutoffs: int | Sequence[int] | None = Non ValueError: If the number of cutoffs does not match the number of photon numbers. ValueError: If the photon numbers are larger than the corresponding cutoffs. """ - n = math.atleast_1d(n) + n = math.atleast_nd(n, 1) if cutoffs is None: cutoffs = list(n) elif isinstance(cutoffs, int): cutoffs = [cutoffs] * len(n) else: - cutoffs = math.atleast_1d(cutoffs) + cutoffs = math.atleast_nd(cutoffs, 1) if len(cutoffs) != len(n): msg = f"Expected ``len(cutoffs)={len(n)}`` but found ``{len(cutoffs)}``." @@ -79,56 +77,6 @@ def fock_state(n: int | Sequence[int], cutoffs: int | Sequence[int] | None = Non return array -def ket_to_dm(ket: Tensor) -> Tensor: - r"""Maps a ket to a density matrix. - - Args: - ket: the ket - - Returns: - Tensor: the density matrix - """ - return math.outer(ket, math.conj(ket)) - - -def ket_to_probs(ket: Tensor) -> Tensor: - r"""Maps a ket to probabilities. - - Args: - ket: the ket - - Returns: - Tensor: the probabilities vector - """ - return math.abs(ket) ** 2 - - -def dm_to_probs(dm: Tensor) -> Tensor: - r"""Extracts the diagonals of a density matrix. - - Args: - dm: the density matrix - - Returns: - Tensor: the probabilities vector - """ - return math.all_diagonals(dm, real=True) - - -def U_to_choi(U: Tensor, Udual: Tensor | None = None) -> Tensor: - r"""Converts a unitary transformation to a Choi tensor. - - Args: - U: the unitary transformation - Udual: the dual unitary transformation (optional, will use conj U if not provided) - - Returns: - Tensor: the Choi tensor. 
The index order is going to be :math:`[\mathrm{out}_l, \mathrm{in}_l, \mathrm{out}_r, \mathrm{in}_r]`
-        where :math:`\mathrm{in}_l` and :math:`\mathrm{in}_r` are to be contracted with the left and right indices of the density matrix.
-    """
-    return math.outer(U, math.conj(U) if Udual is None else Udual)
-
-
 def fidelity(dm_a, dm_b) -> Scalar:
     r"""Computes the fidelity between two states in Fock representation."""
     # Richard Jozsa (1994) Fidelity for Mixed Quantum States,
     # Journal of Modern Optics, 41:12, 2315-2323, DOI: 10.1080/09500349414552171
     sqrt_dm_a = math.sqrtm(dm_a)
     return math.abs(math.trace(math.sqrtm(math.matmul(sqrt_dm_a, dm_b, sqrt_dm_a))) ** 2)
 
 
-def number_means(tensor, is_dm: bool):
-    r"""Returns the mean of the number operator in each mode."""
-    probs = math.all_diagonals(tensor, real=True) if is_dm else math.abs(tensor) ** 2
-    modes = list(range(len(probs.shape)))
-    marginals = [math.sum(probs, axis=modes[:k] + modes[k + 1 :]) for k in range(len(modes))]
-    return math.astensor(
-        [
-            math.sum(marginal * math.arange(len(marginal), dtype=math.float64))
-            for marginal in marginals
-        ]
-    )
-
-
-def number_variances(tensor, is_dm: bool):
-    r"""Returns the variance of the number operator in each mode."""
-    probs = math.all_diagonals(tensor, real=True) if is_dm else math.abs(tensor) ** 2
-    modes = list(range(len(probs.shape)))
-    marginals = [math.sum(probs, axis=modes[:k] + modes[k + 1 :]) for k in range(len(modes))]
-    return math.astensor(
-        [
-            (
-                math.sum(marginal * math.arange(marginal.shape[0], dtype=marginal.dtype) ** 2)
-                - math.sum(marginal * math.arange(marginal.shape[0], dtype=marginal.dtype)) ** 2
-            )
-            for marginal in marginals
-        ]
-    )
-
-
-def validate_contraction_indices(in_idx, out_idx, M, name):
-    r"""Validates the indices used for the contraction of a tensor."""
-    if len(set(in_idx)) != len(in_idx):
-        raise ValueError(f"{name}_in_idx should not contain repeated indices.")
-    if len(set(out_idx)) != len(out_idx):
-        raise ValueError(f"{name}_out_idx should not contain repeated indices.")
-    if not set(range(M)).intersection(out_idx).issubset(set(in_idx)):
-        wrong_indices = set(range(M)).intersection(out_idx) - set(in_idx)
-        raise ValueError(
-            f"Indices {wrong_indices} in {name}_out_idx are trying to replace uncontracted indices."
-        )
-
-
 @tensor_int_cache
 def oscillator_eigenstate(q: Vector, cutoff: int) -> Tensor:
     r"""Harmonic oscillator eigenstate wavefunction `\psi_n(q) = <q|n>`.
def quadrature_distribution(
     return x, math.real(pdf)
 
 
-def sample_homodyne(state: Tensor, quadrature_angle: float = 0.0) -> tuple[float, float]:
-    r"""Given a single-mode state, it generates the pdf of :math:`\tr [ \rho |x_\phi><x_\phi| ]`."""
-    if len(state.shape) > 2:
-        raise ValueError(
-            "Input state has dimension {state.shape}. Make sure is either a single-mode ket or dm."
-        )
-
-    x, pdf = quadrature_distribution(state, quadrature_angle)
-    probs = pdf * (x[1] - x[0])
-
-    # draw a sample from the distribution
-    pdf = math.Categorical(probs=probs, name="homodyne_dist")
-    sample_idx = pdf.sample()
-    homodyne_sample = math.gather(x, sample_idx)
-    probability_sample = math.gather(probs, sample_idx)
-
-    return homodyne_sample, probability_sample
-
-
 def c_ps_matrix(m, n, alpha):
     """
     helper function for ``c_in_PS``.
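
Callers that relied on the one-line helpers removed from fock_utils.py above
(ket_to_dm, ket_to_probs, dm_to_probs) can inline the same expressions against
the backend interface this patch keeps. A minimal sketch, using only ``math``
methods that survive the patch (math.outer, math.conj, math.abs,
math.all_diagonals); the function names below are illustrative and not part of
the patch:

    from mrmustard import math

    def dm_from_ket(ket):
        # density matrix as the outer product |psi><psi| (was ket_to_dm)
        return math.outer(ket, math.conj(ket))

    def probs_from_ket(ket):
        # Fock probabilities as squared amplitudes (was ket_to_probs)
        return math.abs(ket) ** 2

    def probs_from_dm(dm):
        # real diagonal of the density matrix (was dm_to_probs)
        return math.all_diagonals(dm, real=True)
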
diff --git a/mrmustard/physics/gaussian.py b/mrmustard/physics/gaussian.py index c16fb2c04..9133424e9 100644 --- a/mrmustard/physics/gaussian.py +++ b/mrmustard/physics/gaussian.py @@ -42,27 +42,6 @@ def number_means(cov: Matrix, means: Vector) -> Vector: ) / (2 * settings.HBAR) -def number_cov(cov: Matrix, means: Vector) -> Matrix: - r"""Returns the photon number covariance matrix given a Wigner covariance matrix and a means vector. - - Args: - cov: the Wigner covariance matrix - means: the Wigner means vector - - Returns: - Matrix: the photon number covariance matrix - """ - N = means.shape[-1] // 2 - mCm = cov * means[:, None] * means[None, :] - dd = math.diag(math.diag_part(mCm[:N, :N] + mCm[N:, N:] + mCm[:N, N:] + mCm[N:, :N])) / ( - 2 * settings.HBAR**2 # TODO: sum(diag_part) is better than diag_part(sum) - ) - CC = (cov**2 + mCm) / (2 * settings.HBAR**2) - return ( - CC[:N, :N] + CC[N:, N:] + CC[:N, N:] + CC[N:, :N] + dd - 0.25 * math.eye(N, dtype=CC.dtype) - ) - - def purity(cov: Matrix) -> Scalar: r"""Returns the purity of the state with the given covariance matrix. @@ -165,26 +144,3 @@ def fidelity(mu1: Vector, cov1: Matrix, mu2: Vector, cov2: Matrix) -> float: _fidelity = f0 * math.exp((-1 / 2) * dot) # square of equation 95 return math.real(_fidelity) - - -def log_negativity(cov: Matrix) -> float: - r"""Returns the log_negativity of a Gaussian state. - - Reference: `https://arxiv.org/abs/quant-ph/0102117 `_ , Equation 57, 61. - - Args: - cov (Matrix): the covariance matrix - - Returns: - float: the log-negativity - """ - vals = symplectic_eigenvals(cov) - vals_filtered = math.boolean_mask( - vals, vals < 1.0 - ) # Get rid of terms that would lead to zero contribution. - if len(vals_filtered) > 0: - return -math.sum( - math.log(vals_filtered) / math.cast(math.log(2.0), dtype=vals_filtered.dtype) - ) - - return 0 diff --git a/mrmustard/physics/gaussian_integrals.py b/mrmustard/physics/gaussian_integrals.py index 446fadc61..a6a7be910 100644 --- a/mrmustard/physics/gaussian_integrals.py +++ b/mrmustard/physics/gaussian_integrals.py @@ -154,38 +154,6 @@ def join_Abc_real( return A12, b12, c12 -def reorder_abc(Abc: tuple, order: Sequence[int]): - r""" - Reorders the indices of the A matrix and b vector of an (A,b,c) triple. 
- - Arguments: - Abc: the ``(A,b,c)`` triple - order: the new order of the indices - - Returns: - The reordered ``(A,b,c)`` triple - """ - A, b, c = Abc - c = math.astensor(c) - order = list(order) - if len(order) == 0: - return A, b, c - batched = len(A.shape) == 3 and len(b.shape) == 2 and len(c.shape) > 0 - dim_poly = len(c.shape) - int(batched) - n = A.shape[-1] - dim_poly - - if len(order) != n: - raise ValueError(f"order must have length {n}, got {len(order)}") - - if any(i >= n or n < 0 for i in order): - raise ValueError(f"elements in `order` must be between 0 and {n-1}, got {order}") - order += list(range(len(order), len(order) + dim_poly)) - order = math.astensor(order) - A = math.gather(math.gather(A, order, axis=-1), order, axis=-2) - b = math.gather(b, order, axis=-1) - return A, b, c - - def join_Abc( Abc1: tuple[ComplexMatrix, ComplexVector, ComplexTensor], Abc2: tuple[ComplexMatrix, ComplexVector, ComplexTensor], diff --git a/tests/test_lab/test_circuit_components_utils/test_b_to_ps.py b/tests/test_lab/test_circuit_components_utils/test_b_to_ps.py index cfcd5439a..b25111e84 100644 --- a/tests/test_lab/test_circuit_components_utils/test_b_to_ps.py +++ b/tests/test_lab/test_circuit_components_utils/test_b_to_ps.py @@ -47,9 +47,8 @@ def test_application(self, hbar): vec = np.linspace(-4.5, 4.5, 100) wigner, _, _ = wigner_discretized(dm, vec, vec) - settings.HBAR = hbar - Wigner = (state >> BtoPS(0, s=0)).ansatz - settings.HBAR = 1.0 + with settings(HBAR=hbar): + Wigner = (state >> BtoPS(0, s=0)).ansatz X, Y = np.meshgrid(vec / np.sqrt(2 * settings.HBAR), vec / np.sqrt(2 * settings.HBAR)) assert math.allclose( diff --git a/tests/test_math/test_backend_manager.py b/tests/test_math/test_backend_manager.py index ee4d932b7..3fc133ee1 100644 --- a/tests/test_math/test_backend_manager.py +++ b/tests/test_math/test_backend_manager.py @@ -159,61 +159,6 @@ def test_astensor(self, t, l): exp = exp.numpy() assert math.allclose(res, exp) - @pytest.mark.parametrize("t", types) - @pytest.mark.parametrize("l", [l1, l3]) - def test_atleast_1d(self, t, l): - r""" - Tests the ``atleast_1d`` method. - """ - dtype = getattr(math, t, None) - arr = np.array(l) - - res = math.asnumpy(math.atleast_1d(arr, dtype=dtype)) - - exp = np.atleast_1d(arr) - if dtype: - np_dtype = getattr(np, t, None) - exp = exp.astype(np_dtype) - - assert math.allclose(res, exp) - - @pytest.mark.parametrize("t", types) - @pytest.mark.parametrize("l", [l1, l3]) - def test_atleast_2d(self, t, l): - r""" - Tests the ``atleast_2d`` method. - """ - dtype = getattr(math, t, None) - arr = np.array(l) - - res = math.asnumpy(math.atleast_2d(arr, dtype=dtype)) - - exp = np.atleast_2d(arr) - if dtype: - np_dtype = getattr(np, t, None) - exp = exp.astype(np_dtype) - - assert math.allclose(res, exp) - - @pytest.mark.parametrize("t", types) - @pytest.mark.parametrize("l", [l1, l3, l5]) - def test_atleast_3d(self, t, l): - r""" - Tests the ``atleast_3d`` method. 
- """ - dtype = getattr(math, t, None) - arr = np.array(l) - - res = math.asnumpy(math.atleast_3d(arr, dtype=dtype)) - - if arr.ndim == 1: - exp_shape = (1, 1) + arr.shape - elif arr.ndim == 2: - exp_shape = (1,) + arr.shape - else: - exp_shape = arr.shape - assert res.shape == exp_shape - @pytest.mark.parametrize("t", types) @pytest.mark.parametrize("l", [l1, l3, l5]) @pytest.mark.parametrize("n", [1, 2, 3]) @@ -232,16 +177,6 @@ def test_atleast_nd(self, t, l, n): exp_shape = arr.shape assert res.shape == exp_shape - def test_boolean_mask(self): - r""" - Tests the ``boolean_mask`` method. - """ - arr = math.astensor([1, 2, 3, 4]) - mask = math.astensor([True, False, True, True]) - res = math.boolean_mask(arr, mask) - exp = math.astensor([1, 3, 4]) - assert math.allclose(res, exp) - def test_block(self): r""" Tests the ``block`` method. @@ -258,16 +193,6 @@ def test_block(self): ) assert R.shape == (16, 16) - def test_block_diag(self): - r""" - Tests the ``block_diag`` method. - """ - I = math.ones(shape=(4, 4), dtype=math.complex128) - O = math.zeros(shape=(4, 4), dtype=math.complex128) - R = math.block_diag(I, 1j * I) - assert R.shape == (8, 8) - assert math.allclose(math.block([[I, O], [O, 1j * I]]), R) - def test_broadcast_arrays(self): r""" Tests the ``broadcast_arrays`` method. @@ -500,13 +425,6 @@ def test_is_trainable(self): if math.backend_name == "jax": assert not math.is_trainable(arr4) - def test_lgamma(self): - r""" - Tests the ``lgamma`` method. - """ - arr = np.array([1.0, 2.0, 3.0, 4.0]) - assert math.allclose(math.asnumpy(math.lgamma(arr)), math.lgamma(arr)) - def test_log(self): r""" Tests the ``log`` method. @@ -522,37 +440,6 @@ def test_make_complex(self): i = 2.0 assert math.asnumpy(math.make_complex(r, i)) == r + i * 1j - def test_maximum(self): - r""" - Tests the ``maximum`` method. - """ - arr1 = np.eye(3) - arr2 = 2 * np.eye(3) - res = math.asnumpy(math.maximum(arr1, arr2)) - assert math.allclose(res, arr2) - - def test_minimum(self): - r""" - Tests the ``minimum`` method. - """ - arr1 = np.eye(3) - arr2 = 2 * np.eye(3) - res = math.asnumpy(math.minimum(arr1, arr2)) - assert math.allclose(res, arr1) - - def test_moveaxis(self): - r""" - Tests the ``moveaxis`` method. - """ - arr1 = np.random.random(size=(1, 2, 3)) - arr2 = np.random.random(size=(2, 1, 3)) - arr2_moved = math.moveaxis(arr2, 0, 1) - assert math.allclose(arr1.shape, arr2_moved.shape) - - arr1_moved1 = math.moveaxis(arr1, 0, 1) - arr1_moved2 = math.moveaxis(arr1_moved1, 1, 0) - assert math.allclose(arr1, arr1_moved2) - @pytest.mark.parametrize("t", types) def test_new_variable(self, t): r""" @@ -649,14 +536,6 @@ def test_reshape(self): arr = math.reshape(arr, shape) assert arr.shape == shape - def test_set_diag(self): - r""" - Tests the ``set_diag`` method. - """ - arr = np.zeros(shape=(3, 3)) - diag = np.ones(shape=(3,)) - assert math.allclose(math.asnumpy(math.set_diag(arr, diag, 0)), np.eye(3)) - @pytest.mark.parametrize("l", lists) def test_sin(self, l): r""" @@ -716,14 +595,6 @@ def test_sum(self): res = math.asnumpy(math.sum(arr)) assert math.allclose(res, 12) - def test_categorical(self): - r""" - Tests the ``Categorical`` method. - """ - probs = np.array([1e-6 for _ in range(300)]) - results = [math.Categorical(probs, "") for _ in range(100)] - assert len(set(results)) > 1 - def test_displacement(self): r""" Tests the ``displacement`` method. 
diff --git a/tests/test_math/test_lattice/test_lattice_functions.py b/tests/test_math/test_lattice/test_lattice_functions.py index ffbb65e5b..c7a13f766 100644 --- a/tests/test_math/test_lattice/test_lattice_functions.py +++ b/tests/test_math/test_lattice/test_lattice_functions.py @@ -121,16 +121,16 @@ def test_sector_u(): def test_vanillaNumba_vs_binomial(): """Test that the vanilla method and the binomial method give the same result.""" - settings.SEED = 42 - A, b, c = Ket.random((0, 1)).bargmann_triple() - A, b, c = math.asnumpy(A), math.asnumpy(b), math.asnumpy(c) + with settings(SEED=42): + A, b, c = Ket.random((0, 1)).bargmann_triple() + A, b, c = math.asnumpy(A), math.asnumpy(b), math.asnumpy(c) - ket_vanilla = vanilla_numba(shape=(10, 10), A=A, b=b, c=c)[:5, :5] - ket_binomial = binomial(local_cutoffs=(5, 5), A=A, b=b, c=c, max_l2=0.9999, global_cutoff=12)[ - 0 - ][:5, :5] + ket_vanilla = vanilla_numba(shape=(10, 10), A=A, b=b, c=c)[:5, :5] + ket_binomial = binomial( + local_cutoffs=(5, 5), A=A, b=b, c=c, max_l2=0.9999, global_cutoff=12 + )[0][:5, :5] - assert np.allclose(ket_vanilla, ket_binomial) + assert np.allclose(ket_vanilla, ket_binomial) def test_vanilla_stable(): diff --git a/tests/test_physics/test_bargmann_utils.py b/tests/test_physics/test_bargmann_utils.py index 4722e919f..2e33091a2 100644 --- a/tests/test_physics/test_bargmann_utils.py +++ b/tests/test_physics/test_bargmann_utils.py @@ -21,9 +21,7 @@ from mrmustard.physics.bargmann_utils import ( XY_of_channel, au2Symplectic, - norm_ket, symplectic2Au, - trace_dm, wigner_to_bargmann_psi, wigner_to_bargmann_rho, ) @@ -51,22 +49,6 @@ def test_wigner_to_bargmann_rho(): assert np.allclose(c, c_exp) -def test_norm_ket(): - """Test that the norm of a ket is calculated correctly""" - - ket = Vacuum((0, 1)) >> Unitary.from_symplectic((0, 1), math.random_symplectic(2)) - A, b, c = ket.bargmann_triple() - assert np.isclose(norm_ket(A, b, c), ket.probability) - - -def test_trace_dm(): - """Test that the trace of a density matrix is calculated correctly""" - ket = Vacuum((0, 1, 2, 3)) >> Unitary.from_symplectic((0, 1, 2, 3), math.random_symplectic(4)) - dm = ket[0, 1] - A, b, c = dm.bargmann_triple() - assert np.allclose(trace_dm(A, b, c), dm.probability) - - def test_au2Symplectic(): """Tests the Au -> symplectic code; we check two simple examples""" # Beam splitter example diff --git a/tests/test_physics/test_fock_utils.py b/tests/test_physics/test_fock_utils.py index 23fbc7890..e2a68a584 100644 --- a/tests/test_physics/test_fock_utils.py +++ b/tests/test_physics/test_fock_utils.py @@ -161,29 +161,3 @@ def test_lossy_two_mode_squeezing(n_mean, phi, eta_0, eta_1): mean_1 = np.sum(n * ps1) assert np.allclose(mean_0, n_mean * eta_0, atol=1e-5) assert np.allclose(mean_1, n_mean * eta_1, atol=1e-5) - - -@given(x=st.floats(-1, 1), y=st.floats(-1, 1)) -def test_number_means(x, y): - """Tests the mean photon number.""" - ket = Coherent(0, x, y).fock_array(80) - dm = Coherent(0, x, y).dm().fock_array(80) - assert np.allclose(fock_utils.number_means(ket, False), x * x + y * y) - assert np.allclose(fock_utils.number_means(dm, True), x * x + y * y) - - -@given(x=st.floats(-1, 1), y=st.floats(-1, 1)) -def test_number_variances_coh(x, y): - """Tests the variance of the number operator.""" - assert np.allclose( - fock_utils.number_variances(Coherent(0, x, y).fock_array(80), False)[0], x * x + y * y - ) - assert np.allclose( - fock_utils.number_variances(Coherent(0, x, y).dm().fock_array(80), True)[0], x * x + y * y - ) - - -def 
test_number_variances_fock(): - """Tests the variance of the number operator in Fock.""" - assert np.allclose(fock_utils.number_variances(Number(0, 1).fock_array(100), False), 0) - assert np.allclose(fock_utils.number_variances(Number(0, 1).dm().fock_array(100), True), 0) diff --git a/tests/test_physics/test_gaussian_integrals.py b/tests/test_physics/test_gaussian_integrals.py index 9caf5537b..b8e442416 100644 --- a/tests/test_physics/test_gaussian_integrals.py +++ b/tests/test_physics/test_gaussian_integrals.py @@ -24,7 +24,6 @@ join_Abc, join_Abc_real, real_gaussian_integral, - reorder_abc, ) @@ -175,28 +174,6 @@ def test_join_Abc_batched_kron(): assert math.allclose(c, math.astensor([70, 700])) -def test_reorder_abc(): - """Test that the reorder_abc function works correctly""" - A = math.astensor([[1, 2], [2, 3]]) - b = math.astensor([4, 5]) - c = math.astensor(6) - same = reorder_abc((A, b, c), (0, 1)) - assert all(math.allclose(x, y) for x, y in zip(same, (A, b, c))) - flipped = reorder_abc((A, b, c), (1, 0)) - assert all(math.allclose(x, y) for x, y in zip(flipped, (A[::-1, :][:, ::-1], b[::-1], c))) - - A = math.astensor([[[1, 2, 3], [2, 4, 5], [3, 5, 6]]]) - b = math.astensor([[4, 5, 6]]) - c = math.astensor([[1, 2, 3]]) - same = reorder_abc((A, b, c), (0, 1)) - assert all(math.allclose(x, y) for x, y in zip(same, (A, b, c))) - flipped = reorder_abc((A, b, c), (1, 0)) - assert all( - math.allclose(x, y) - for x, y in zip(flipped, (A[:, (1, 0, 2), :][:, :, (1, 0, 2)], b[:, (1, 0, 2)], c)) - ) - - def test_complex_gaussian_integral_2_not_batched(): """Tests the ``complex_gaussian_integral_2`` method for non-batched inputs.""" A1, b1, c1 = triples.vacuum_state_Abc(2) diff --git a/tests/test_training/test_callbacks.py b/tests/test_training/test_callbacks.py index 37e7792ba..f34b23b1c 100644 --- a/tests/test_training/test_callbacks.py +++ b/tests/test_training/test_callbacks.py @@ -25,42 +25,42 @@ @pytest.mark.requires_backend("tensorflow") def test_tensorboard_callback(tmp_path): """Tests tensorboard callbacks on hong-ou-mandel optimization.""" - settings.SEED = 42 - i, k = 2, 3 - r = np.arcsinh(1.0) - state_in = Vacuum((0, 1, 2, 3)) - s2_0, s2_1, bs = ( - S2gate((0, 1), r=r, phi=0.0, phi_trainable=True), - S2gate((2, 3), r=r, phi=0.0, phi_trainable=True), - BSgate( - (1, 2), - theta=np.arccos(np.sqrt(k / (i + k))) + 0.1 * settings.rng.normal(), - phi=settings.rng.normal(), - theta_trainable=True, - phi_trainable=True, - ), - ) - circ = Circuit([state_in, s2_0, s2_1, bs]) - cutoff = 1 + i + k + with settings(SEED=42): + i, k = 2, 3 + r = np.arcsinh(1.0) + state_in = Vacuum((0, 1, 2, 3)) + s2_0, s2_1, bs = ( + S2gate((0, 1), r=r, phi=0.0, phi_trainable=True), + S2gate((2, 3), r=r, phi=0.0, phi_trainable=True), + BSgate( + (1, 2), + theta=np.arccos(np.sqrt(k / (i + k))) + 0.1 * settings.rng.normal(), + phi=settings.rng.normal(), + theta_trainable=True, + phi_trainable=True, + ), + ) + circ = Circuit([state_in, s2_0, s2_1, bs]) + cutoff = 1 + i + k - free_var = math.new_variable([1.1, -0.2], None, "free_var") + free_var = math.new_variable([1.1, -0.2], None, "free_var") - def cost_fn(): - return tf.abs( - circ.contract().fock_array((cutoff,) * 4)[i, 1, i + k - 1, k] - ) ** 2 + tf.reduce_sum(free_var**2) + def cost_fn(): + return tf.abs( + circ.contract().fock_array((cutoff,) * 4)[i, 1, i + k - 1, k] + ) ** 2 + tf.reduce_sum(free_var**2) - tbcb = TensorboardCallback( - steps_per_call=2, - root_logdir=tmp_path, - cost_converter=np.log10, - track_grads=True, - ) + tbcb = TensorboardCallback( 
+ steps_per_call=2, + root_logdir=tmp_path, + cost_converter=np.log10, + track_grads=True, + ) - opt = Optimizer(euclidean_lr=0.01) - opt.minimize(cost_fn, by_optimizing=[circ, free_var], max_steps=300, callbacks={"tb": tbcb}) + opt = Optimizer(euclidean_lr=0.01) + opt.minimize(cost_fn, by_optimizing=[circ, free_var], max_steps=300, callbacks={"tb": tbcb}) - assert np.allclose(np.cos(bs.parameters.theta.value) ** 2, k / (i + k), atol=1e-2) - assert tbcb.logdir.exists() - assert len(list(tbcb.writter_logdir.glob("events*"))) > 0 - assert len(opt.callback_history["tb"]) == (len(opt.opt_history) - 1) // tbcb.steps_per_call + assert np.allclose(np.cos(bs.parameters.theta.value) ** 2, k / (i + k), atol=1e-2) + assert tbcb.logdir.exists() + assert len(list(tbcb.writter_logdir.glob("events*"))) > 0 + assert len(opt.callback_history["tb"]) == (len(opt.opt_history) - 1) // tbcb.steps_per_call diff --git a/tests/test_training/test_opt_lab.py b/tests/test_training/test_opt_lab.py index 7902dd301..dd1ddf318 100644 --- a/tests/test_training/test_opt_lab.py +++ b/tests/test_training/test_opt_lab.py @@ -53,34 +53,34 @@ class TestOptimizer: @given(n=st.integers(0, 3)) def test_S2gate_coincidence_prob(self, n): """Testing the optimal probability of obtaining |n,n> from a two mode squeezed vacuum""" - settings.SEED = 40 - rng = tf.random.get_global_generator() - rng.reset_from_seed(settings.SEED) + with settings(SEED=40): + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) - S = TwoModeSqueezedVacuum( - (0, 1), r=abs(settings.rng.normal(loc=1.0, scale=0.1)), r_trainable=True - ) + S = TwoModeSqueezedVacuum( + (0, 1), r=abs(settings.rng.normal(loc=1.0, scale=0.1)), r_trainable=True + ) - def cost_fn(): - return -math.abs(S.fock_array((n + 1, n + 1))[n, n]) ** 2 + def cost_fn(): + return -math.abs(S.fock_array((n + 1, n + 1))[n, n]) ** 2 - def cb(optimizer, cost, trainables, **kwargs): # pylint: disable=unused-argument - return { - "cost": cost, - "lr": optimizer.learning_rate[update_euclidean], - "num_trainables": len(trainables), - } + def cb(optimizer, cost, trainables, **kwargs): # pylint: disable=unused-argument + return { + "cost": cost, + "lr": optimizer.learning_rate[update_euclidean], + "num_trainables": len(trainables), + } - opt = Optimizer(euclidean_lr=0.01) - opt.minimize(cost_fn, by_optimizing=[S], max_steps=300, callbacks=cb) + opt = Optimizer(euclidean_lr=0.01) + opt.minimize(cost_fn, by_optimizing=[S], max_steps=300, callbacks=cb) - expected = 1 / (n + 1) * (n / (n + 1)) ** n - assert np.allclose(-cost_fn(), expected, atol=1e-5) + expected = 1 / (n + 1) * (n / (n + 1)) ** n + assert np.allclose(-cost_fn(), expected, atol=1e-5) - cb_result = opt.callback_history.get("cb") - assert {res["num_trainables"] for res in cb_result} == {1} - assert {res["lr"] for res in cb_result} == {0.01} - assert [res["cost"] for res in cb_result] == opt.opt_history[1:] + cb_result = opt.callback_history.get("cb") + assert {res["num_trainables"] for res in cb_result} == {1} + assert {res["lr"] for res in cb_result} == {0.01} + assert [res["cost"] for res in cb_result] == opt.opt_history[1:] @given(i=st.integers(1, 5), k=st.integers(1, 5)) def test_hong_ou_mandel_optimizer(self, i, k): @@ -89,442 +89,450 @@ def test_hong_ou_mandel_optimizer(self, i, k): see Eq. 20 of https://journals.aps.org/prresearch/pdf/10.1103/PhysRevResearch.3.043065 which lacks a square root in the right hand side. 
""" - settings.SEED = 42 - rng = tf.random.get_global_generator() - rng.reset_from_seed(settings.SEED) - - r = np.arcsinh(1.0) - cutoff = 1 + i + k - - state = TwoModeSqueezedVacuum((0, 1), r=r, phi_trainable=True) - bs = BSgate( - (1, 2), - theta=np.arccos(np.sqrt(k / (i + k))) + 0.1 * settings.rng.normal(), - phi=settings.rng.normal(), - theta_trainable=True, - phi_trainable=True, - ) - circ = Circuit([state, state.on((2, 3)), bs]) + with settings(SEED=42): + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) + + r = np.arcsinh(1.0) + cutoff = 1 + i + k + + state = TwoModeSqueezedVacuum((0, 1), r=r, phi_trainable=True) + bs = BSgate( + (1, 2), + theta=np.arccos(np.sqrt(k / (i + k))) + 0.1 * settings.rng.normal(), + phi=settings.rng.normal(), + theta_trainable=True, + phi_trainable=True, + ) + circ = Circuit([state, state.on((2, 3)), bs]) - def cost_fn(): - return math.abs(circ.contract().fock_array((cutoff,) * 4)[i, 1, i + k - 1, k]) ** 2 - - opt = Optimizer(euclidean_lr=0.01) - opt.minimize( - cost_fn, - by_optimizing=[circ], - max_steps=300, - callbacks=[Callback(tag="null_cb", steps_per_call=3)], - ) - assert np.allclose(np.cos(bs.parameters.theta.value) ** 2, k / (i + k), atol=1e-2) - assert "null_cb" in opt.callback_history - assert len(opt.callback_history["null_cb"]) == (len(opt.opt_history) - 1) // 3 + def cost_fn(): + return math.abs(circ.contract().fock_array((cutoff,) * 4)[i, 1, i + k - 1, k]) ** 2 + + opt = Optimizer(euclidean_lr=0.01) + opt.minimize( + cost_fn, + by_optimizing=[circ], + max_steps=300, + callbacks=[Callback(tag="null_cb", steps_per_call=3)], + ) + assert np.allclose(np.cos(bs.parameters.theta.value) ** 2, k / (i + k), atol=1e-2) + assert "null_cb" in opt.callback_history + assert len(opt.callback_history["null_cb"]) == (len(opt.opt_history) - 1) // 3 def test_learning_two_mode_squeezing(self): """Finding the optimal beamsplitter transmission to make a pair of single photons""" - settings.SEED = 42 - rng = tf.random.get_global_generator() - rng.reset_from_seed(settings.SEED) - - state_in = Vacuum((0, 1)) - s_gate = Sgate( - 0, - r=abs(settings.rng.normal()), - phi=settings.rng.normal(), - r_trainable=True, - phi_trainable=True, - ) - bs_gate = BSgate( - (0, 1), - theta=settings.rng.normal(), - phi=settings.rng.normal(), - theta_trainable=True, - phi_trainable=True, - ) - circ = Circuit([state_in, s_gate, s_gate.on(1), bs_gate]) + with settings(SEED=42): + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) + + state_in = Vacuum((0, 1)) + s_gate = Sgate( + 0, + r=abs(settings.rng.normal()), + phi=settings.rng.normal(), + r_trainable=True, + phi_trainable=True, + ) + bs_gate = BSgate( + (0, 1), + theta=settings.rng.normal(), + phi=settings.rng.normal(), + theta_trainable=True, + phi_trainable=True, + ) + circ = Circuit([state_in, s_gate, s_gate.on(1), bs_gate]) - def cost_fn(): - amps = circ.contract().fock_array((2, 2)) - return -math.abs(amps[1, 1]) ** 2 + math.abs(amps[0, 1]) ** 2 + def cost_fn(): + amps = circ.contract().fock_array((2, 2)) + return -math.abs(amps[1, 1]) ** 2 + math.abs(amps[0, 1]) ** 2 - opt = Optimizer(euclidean_lr=0.05) + opt = Optimizer(euclidean_lr=0.05) - opt.minimize(cost_fn, by_optimizing=[circ], max_steps=300) - assert np.allclose(-cost_fn(), 0.25, atol=1e-5) + opt.minimize(cost_fn, by_optimizing=[circ], max_steps=300) + assert np.allclose(-cost_fn(), 0.25, atol=1e-5) def test_learning_two_mode_Ggate(self): """Finding the optimal Ggate to make a pair of single photons""" - 
settings.SEED = 42 - rng = tf.random.get_global_generator() - rng.reset_from_seed(settings.SEED) + with settings(SEED=42): + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) - G = GKet((0, 1), symplectic_trainable=True) + G = GKet((0, 1), symplectic_trainable=True) - def cost_fn(): - amps = G.fock_array((2, 2)) - return -math.abs(amps[1, 1]) ** 2 + math.abs(amps[0, 1]) ** 2 + def cost_fn(): + amps = G.fock_array((2, 2)) + return -math.abs(amps[1, 1]) ** 2 + math.abs(amps[0, 1]) ** 2 - opt = Optimizer(symplectic_lr=0.5, euclidean_lr=0.01) + opt = Optimizer(symplectic_lr=0.5, euclidean_lr=0.01) - opt.minimize(cost_fn, by_optimizing=[G], max_steps=500) - assert np.allclose(-cost_fn(), 0.25, atol=1e-4) + opt.minimize(cost_fn, by_optimizing=[G], max_steps=500) + assert np.allclose(-cost_fn(), 0.25, atol=1e-4) def test_learning_two_mode_Interferometer(self): """Finding the optimal Interferometer to make a pair of single photons""" - settings.SEED = 4 - rng = tf.random.get_global_generator() - rng.reset_from_seed(settings.SEED) - - state_in = Vacuum((0, 1)) - s_gate = Sgate( - 0, - r=settings.rng.normal() ** 2, - phi=settings.rng.normal(), - r_trainable=True, - phi_trainable=True, - ) - interferometer = Interferometer((0, 1), unitary_trainable=True) - circ = Circuit([state_in, s_gate, s_gate.on(1), interferometer]) + with settings(SEED=4): + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) + + state_in = Vacuum((0, 1)) + s_gate = Sgate( + 0, + r=settings.rng.normal() ** 2, + phi=settings.rng.normal(), + r_trainable=True, + phi_trainable=True, + ) + interferometer = Interferometer((0, 1), unitary_trainable=True) + circ = Circuit([state_in, s_gate, s_gate.on(1), interferometer]) - def cost_fn(): - amps = circ.contract().fock_array((2, 2)) - return -math.abs(amps[1, 1]) ** 2 + math.abs(amps[0, 1]) ** 2 + def cost_fn(): + amps = circ.contract().fock_array((2, 2)) + return -math.abs(amps[1, 1]) ** 2 + math.abs(amps[0, 1]) ** 2 - opt = Optimizer(unitary_lr=0.5, euclidean_lr=0.01) + opt = Optimizer(unitary_lr=0.5, euclidean_lr=0.01) - opt.minimize(cost_fn, by_optimizing=[circ], max_steps=1000) - assert np.allclose(-cost_fn(), 0.25, atol=1e-5) + opt.minimize(cost_fn, by_optimizing=[circ], max_steps=1000) + assert np.allclose(-cost_fn(), 0.25, atol=1e-5) def test_learning_two_mode_RealInterferometer(self): """Finding the optimal Interferometer to make a pair of single photons""" - settings.SEED = 2 - rng = tf.random.get_global_generator() - rng.reset_from_seed(settings.SEED) - - state_in = Vacuum((0, 1)) - s_gate0 = Sgate( - 0, - r=settings.rng.normal() ** 2, - phi=settings.rng.normal(), - r_trainable=True, - phi_trainable=True, - ) - s_gate1 = Sgate( - 1, - r=settings.rng.normal() ** 2, - phi=settings.rng.normal(), - r_trainable=True, - phi_trainable=True, - ) - r_inter = RealInterferometer((0, 1), orthogonal_trainable=True) - - circ = Circuit([state_in, s_gate0, s_gate1, r_inter]) + with settings(SEED=2): + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) + + state_in = Vacuum((0, 1)) + s_gate0 = Sgate( + 0, + r=settings.rng.normal() ** 2, + phi=settings.rng.normal(), + r_trainable=True, + phi_trainable=True, + ) + s_gate1 = Sgate( + 1, + r=settings.rng.normal() ** 2, + phi=settings.rng.normal(), + r_trainable=True, + phi_trainable=True, + ) + r_inter = RealInterferometer((0, 1), orthogonal_trainable=True) - def cost_fn(): - amps = circ.contract().fock_array((2, 2)) - return -math.abs(amps[1, 1]) ** 2 + math.abs(amps[0, 1]) 
** 2 + circ = Circuit([state_in, s_gate0, s_gate1, r_inter]) + + def cost_fn(): + amps = circ.contract().fock_array((2, 2)) + return -math.abs(amps[1, 1]) ** 2 + math.abs(amps[0, 1]) ** 2 - opt = Optimizer(orthogonal_lr=0.5, euclidean_lr=0.01) + opt = Optimizer(orthogonal_lr=0.5, euclidean_lr=0.01) - opt.minimize(cost_fn, by_optimizing=[circ], max_steps=1000) - assert np.allclose(-cost_fn(), 0.25, atol=1e-5) + opt.minimize(cost_fn, by_optimizing=[circ], max_steps=1000) + assert np.allclose(-cost_fn(), 0.25, atol=1e-5) def test_learning_four_mode_Interferometer(self): """Finding the optimal Interferometer to make a NOON state with N=2""" - settings.SEED = 4 - rng = tf.random.get_global_generator() - rng.reset_from_seed(settings.SEED) + with settings(SEED=4): + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) - solution_U = np.array( - [ - [ - -0.47541806 + 0.00045878j, - -0.41513474 - 0.27218387j, - -0.11065812 - 0.39556922j, - -0.29912017 + 0.51900235j, - ], + solution_U = np.array( [ - -0.05246398 + 0.5209089j, - -0.29650069 - 0.40653082j, - 0.57434638 - 0.04417284j, - 0.28230532 - 0.24738672j, - ], - [ - 0.28437557 + 0.08773767j, - 0.18377764 - 0.66496587j, - -0.5874942 - 0.19866946j, - 0.2010813 - 0.10210844j, - ], - [ - -0.63173183 - 0.11057324j, - -0.03468292 + 0.15245454j, - -0.25390362 - 0.2244298j, - 0.18706333 - 0.64375049j, - ], - ] - ) - perturbed = ( - Interferometer((0, 1, 2, 3), unitary=solution_U) - >> BSgate((0, 1), settings.rng.normal(scale=0.01)) - >> BSgate((2, 3), settings.rng.normal(scale=0.01)) - >> BSgate((1, 2), settings.rng.normal(scale=0.01)) - >> BSgate((0, 3), settings.rng.normal(scale=0.01)) - ) - X = perturbed.symplectic - perturbed_U = X[:4, :4] + 1j * X[4:, :4] - - state_in = Vacuum((0, 1, 2, 3)) - s_gate = Sgate( - 0, - r=settings.rng.normal(loc=np.arcsinh(1.0), scale=0.01), - r_trainable=True, - ) - interferometer = Interferometer((0, 1, 2, 3), unitary=perturbed_U, unitary_trainable=True) - - circ = Circuit([state_in, s_gate, s_gate.on(1), s_gate.on(2), s_gate.on(3), interferometer]) + [ + -0.47541806 + 0.00045878j, + -0.41513474 - 0.27218387j, + -0.11065812 - 0.39556922j, + -0.29912017 + 0.51900235j, + ], + [ + -0.05246398 + 0.5209089j, + -0.29650069 - 0.40653082j, + 0.57434638 - 0.04417284j, + 0.28230532 - 0.24738672j, + ], + [ + 0.28437557 + 0.08773767j, + 0.18377764 - 0.66496587j, + -0.5874942 - 0.19866946j, + 0.2010813 - 0.10210844j, + ], + [ + -0.63173183 - 0.11057324j, + -0.03468292 + 0.15245454j, + -0.25390362 - 0.2244298j, + 0.18706333 - 0.64375049j, + ], + ] + ) + perturbed = ( + Interferometer((0, 1, 2, 3), unitary=solution_U) + >> BSgate((0, 1), settings.rng.normal(scale=0.01)) + >> BSgate((2, 3), settings.rng.normal(scale=0.01)) + >> BSgate((1, 2), settings.rng.normal(scale=0.01)) + >> BSgate((0, 3), settings.rng.normal(scale=0.01)) + ) + X = perturbed.symplectic + perturbed_U = X[:4, :4] + 1j * X[4:, :4] + + state_in = Vacuum((0, 1, 2, 3)) + s_gate = Sgate( + 0, + r=settings.rng.normal(loc=np.arcsinh(1.0), scale=0.01), + r_trainable=True, + ) + interferometer = Interferometer( + (0, 1, 2, 3), unitary=perturbed_U, unitary_trainable=True + ) - def cost_fn(): - amps = circ.contract().fock_array((3, 3, 3, 3)) - return -math.abs((amps[1, 1, 2, 0] + amps[1, 1, 0, 2]) / np.sqrt(2)) ** 2 + circ = Circuit( + [state_in, s_gate, s_gate.on(1), s_gate.on(2), s_gate.on(3), interferometer] + ) - opt = Optimizer(unitary_lr=0.05) - opt.minimize(cost_fn, by_optimizing=[circ], max_steps=200) - assert np.allclose(-cost_fn(), 0.0625, 
atol=1e-5) + def cost_fn(): + amps = circ.contract().fock_array((3, 3, 3, 3)) + return -math.abs((amps[1, 1, 2, 0] + amps[1, 1, 0, 2]) / np.sqrt(2)) ** 2 + + opt = Optimizer(unitary_lr=0.05) + opt.minimize(cost_fn, by_optimizing=[circ], max_steps=200) + assert np.allclose(-cost_fn(), 0.0625, atol=1e-5) def test_learning_four_mode_RealInterferometer(self): """Finding the optimal Interferometer to make a NOON state with N=2""" - settings.SEED = 6 - rng = tf.random.get_global_generator() - rng.reset_from_seed(settings.SEED) - - solution_O = np.array( - [ - [0.5, -0.5, 0.5, 0.5], - [-0.5, -0.5, -0.5, 0.5], - [0.5, 0.5, -0.5, 0.5], - [0.5, -0.5, -0.5, -0.5], - ] - ) - pertubed = ( - RealInterferometer((0, 1, 2, 3), orthogonal=solution_O) - >> BSgate((0, 1), settings.rng.normal(scale=0.01)) - >> BSgate((2, 3), settings.rng.normal(scale=0.01)) - >> BSgate((1, 2), settings.rng.normal(scale=0.01)) - >> BSgate((0, 3), settings.rng.normal(scale=0.01)) - ) - perturbed_O = pertubed.symplectic[:4, :4] - - state_in = Vacuum((0, 1, 2, 3)) - s_gate0 = Sgate( - 0, - r=np.arcsinh(1.0) + settings.rng.normal(scale=0.01), - phi=settings.rng.normal(scale=0.01), - r_trainable=True, - phi_trainable=True, - ) - s_gate1 = Sgate( - 1, - r=np.arcsinh(1.0) + settings.rng.normal(scale=0.01), - phi=(np.pi / 2) + settings.rng.normal(scale=0.01), - r_trainable=True, - phi_trainable=True, - ) - s_gate2 = Sgate( - 2, - r=np.arcsinh(1.0) + settings.rng.normal(scale=0.01), - phi=-np.pi + settings.rng.normal(scale=0.01), - r_trainable=True, - phi_trainable=True, - ) - s_gate3 = Sgate( - 3, - r=np.arcsinh(1.0) + settings.rng.normal(scale=0.01), - phi=(-np.pi / 2) + settings.rng.normal(scale=0.01), - r_trainable=True, - phi_trainable=True, - ) - r_inter = RealInterferometer( - (0, 1, 2, 3), orthogonal=perturbed_O, orthogonal_trainable=True - ) - - circ = Circuit([state_in, s_gate0, s_gate1, s_gate2, s_gate3, r_inter]) + with settings(SEED=6): + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) - def cost_fn(): - amps = circ.contract().fock_array((2, 2, 3, 3)) - return -math.abs((amps[1, 1, 0, 2] + amps[1, 1, 2, 0]) / np.sqrt(2)) ** 2 + solution_O = np.array( + [ + [0.5, -0.5, 0.5, 0.5], + [-0.5, -0.5, -0.5, 0.5], + [0.5, 0.5, -0.5, 0.5], + [0.5, -0.5, -0.5, -0.5], + ] + ) + pertubed = ( + RealInterferometer((0, 1, 2, 3), orthogonal=solution_O) + >> BSgate((0, 1), settings.rng.normal(scale=0.01)) + >> BSgate((2, 3), settings.rng.normal(scale=0.01)) + >> BSgate((1, 2), settings.rng.normal(scale=0.01)) + >> BSgate((0, 3), settings.rng.normal(scale=0.01)) + ) + perturbed_O = pertubed.symplectic[:4, :4] + + state_in = Vacuum((0, 1, 2, 3)) + s_gate0 = Sgate( + 0, + r=np.arcsinh(1.0) + settings.rng.normal(scale=0.01), + phi=settings.rng.normal(scale=0.01), + r_trainable=True, + phi_trainable=True, + ) + s_gate1 = Sgate( + 1, + r=np.arcsinh(1.0) + settings.rng.normal(scale=0.01), + phi=(np.pi / 2) + settings.rng.normal(scale=0.01), + r_trainable=True, + phi_trainable=True, + ) + s_gate2 = Sgate( + 2, + r=np.arcsinh(1.0) + settings.rng.normal(scale=0.01), + phi=-np.pi + settings.rng.normal(scale=0.01), + r_trainable=True, + phi_trainable=True, + ) + s_gate3 = Sgate( + 3, + r=np.arcsinh(1.0) + settings.rng.normal(scale=0.01), + phi=(-np.pi / 2) + settings.rng.normal(scale=0.01), + r_trainable=True, + phi_trainable=True, + ) + r_inter = RealInterferometer( + (0, 1, 2, 3), orthogonal=perturbed_O, orthogonal_trainable=True + ) + + circ = Circuit([state_in, s_gate0, s_gate1, s_gate2, s_gate3, r_inter]) - opt = 
Optimizer() + def cost_fn(): + amps = circ.contract().fock_array((2, 2, 3, 3)) + return -math.abs((amps[1, 1, 0, 2] + amps[1, 1, 2, 0]) / np.sqrt(2)) ** 2 - opt.minimize(cost_fn, by_optimizing=[circ], max_steps=200) - assert np.allclose(-cost_fn(), 0.0625, atol=1e-5) + opt = Optimizer() + + opt.minimize(cost_fn, by_optimizing=[circ], max_steps=200) + assert np.allclose(-cost_fn(), 0.0625, atol=1e-5) def test_squeezing_hong_ou_mandel_optimizer(self): """Finding the optimal squeezing parameter to get Hong-Ou-Mandel dip in time see https://www.pnas.org/content/117/52/33107/tab-article-info """ - settings.SEED = 42 - rng = tf.random.get_global_generator() - rng.reset_from_seed(settings.SEED) + with settings(SEED=42): + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) - r = np.arcsinh(1.0) + r = np.arcsinh(1.0) - state_in = Vacuum((0, 1, 2, 3)) - S_01 = S2gate((0, 1), r=r, phi=0.0, phi_trainable=True) - S_23 = S2gate((2, 3), r=r, phi=0.0, phi_trainable=True) - S_12 = S2gate( - (1, 2), r=1.0, phi=settings.rng.normal(), r_trainable=True, phi_trainable=True - ) + state_in = Vacuum((0, 1, 2, 3)) + S_01 = S2gate((0, 1), r=r, phi=0.0, phi_trainable=True) + S_23 = S2gate((2, 3), r=r, phi=0.0, phi_trainable=True) + S_12 = S2gate( + (1, 2), r=1.0, phi=settings.rng.normal(), r_trainable=True, phi_trainable=True + ) - circ = Circuit([state_in, S_01, S_23, S_12]) + circ = Circuit([state_in, S_01, S_23, S_12]) - def cost_fn(): - return math.abs(circ.contract().fock_array((2, 2, 2, 2))[1, 1, 1, 1]) ** 2 + def cost_fn(): + return math.abs(circ.contract().fock_array((2, 2, 2, 2))[1, 1, 1, 1]) ** 2 - opt = Optimizer(euclidean_lr=0.001) - opt.minimize(cost_fn, by_optimizing=[circ], max_steps=300) - assert np.allclose(np.sinh(S_12.parameters.r.value) ** 2, 1, atol=1e-2) + opt = Optimizer(euclidean_lr=0.001) + opt.minimize(cost_fn, by_optimizing=[circ], max_steps=300) + assert np.allclose(np.sinh(S_12.parameters.r.value) ** 2, 1, atol=1e-2) def test_parameter_passthrough(self): """Same as the test above, but with param passthrough""" - settings.SEED = 42 - rng = tf.random.get_global_generator() - rng.reset_from_seed(settings.SEED) + with settings(SEED=42): + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) - r = np.arcsinh(1.0) - r_var = Variable(r, "r", (0.0, None)) - phi_var = Variable(settings.rng.normal(), "phi", (None, None)) + r = np.arcsinh(1.0) + r_var = Variable(r, "r", (0.0, None)) + phi_var = Variable(settings.rng.normal(), "phi", (None, None)) - state_in = Vacuum((0, 1, 2, 3)) - s2_gate0 = S2gate((0, 1), r=r, phi=0.0, phi_trainable=True) - s2_gate1 = S2gate((2, 3), r=r, phi=0.0, phi_trainable=True) - s2_gate2 = S2gate((1, 2), r=r_var, phi=phi_var) + state_in = Vacuum((0, 1, 2, 3)) + s2_gate0 = S2gate((0, 1), r=r, phi=0.0, phi_trainable=True) + s2_gate1 = S2gate((2, 3), r=r, phi=0.0, phi_trainable=True) + s2_gate2 = S2gate((1, 2), r=r_var, phi=phi_var) - circ = Circuit([state_in, s2_gate0, s2_gate1, s2_gate2]) + circ = Circuit([state_in, s2_gate0, s2_gate1, s2_gate2]) - def cost_fn(): - return math.abs(circ.contract().fock_array((2, 2, 2, 2))[1, 1, 1, 1]) ** 2 + def cost_fn(): + return math.abs(circ.contract().fock_array((2, 2, 2, 2))[1, 1, 1, 1]) ** 2 - opt = Optimizer(euclidean_lr=0.001) - opt.minimize(cost_fn, by_optimizing=[r_var, phi_var], max_steps=300) - assert np.allclose(np.sinh(r_var.value) ** 2, 1, atol=1e-2) + opt = Optimizer(euclidean_lr=0.001) + opt.minimize(cost_fn, by_optimizing=[r_var, phi_var], max_steps=300) + assert 
np.allclose(np.sinh(r_var.value) ** 2, 1, atol=1e-2) def test_making_thermal_state_as_one_half_two_mode_squeezed_vacuum(self): """Optimizes a Ggate on two modes so as to prepare a state with the same entropy and mean photon number as a thermal state""" - settings.SEED = 42 - rng = tf.random.get_global_generator() - rng.reset_from_seed(settings.SEED) + with settings(SEED=42): + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) - def thermal_entropy(nbar): - return -(nbar * np.log((nbar) / (1 + nbar)) - np.log(1 + nbar)) + def thermal_entropy(nbar): + return -(nbar * np.log((nbar) / (1 + nbar)) - np.log(1 + nbar)) - nbar = 1.4 - S_init = two_mode_squeezing(np.arcsinh(1.0), 0.0) - S = thermal_entropy(nbar) + nbar = 1.4 + S_init = two_mode_squeezing(np.arcsinh(1.0), 0.0) + S = thermal_entropy(nbar) - G = Ggate((0, 1), symplectic=S_init, symplectic_trainable=True) + G = Ggate((0, 1), symplectic=S_init, symplectic_trainable=True) - def cost_fn(): - state = Vacuum((0, 1)) >> G + def cost_fn(): + state = Vacuum((0, 1)) >> G - state0 = state[0] - state1 = state[1] + state0 = state[0] + state1 = state[1] - cov0, mean0, _ = state0.phase_space(s=0) - cov1, mean1, _ = state1.phase_space(s=0) + cov0, mean0, _ = state0.phase_space(s=0) + cov1, mean1, _ = state1.phase_space(s=0) - num_mean0 = number_means(cov0, mean0)[0] - num_mean1 = number_means(cov1, mean1)[0] + num_mean0 = number_means(cov0, mean0)[0] + num_mean1 = number_means(cov1, mean1)[0] - entropy = von_neumann_entropy(cov0) - return (num_mean0 - nbar) ** 2 + (entropy - S) ** 2 + (num_mean1 - nbar) ** 2 + entropy = von_neumann_entropy(cov0) + return (num_mean0 - nbar) ** 2 + (entropy - S) ** 2 + (num_mean1 - nbar) ** 2 - opt = Optimizer(symplectic_lr=0.1) - opt.minimize(cost_fn, by_optimizing=[G], max_steps=50) - S = math.asnumpy(G.parameters.symplectic.value) - cov = S @ S.T - assert np.allclose(cov, two_mode_squeezing(2 * np.arcsinh(np.sqrt(nbar)), 0.0)) + opt = Optimizer(symplectic_lr=0.1) + opt.minimize(cost_fn, by_optimizing=[G], max_steps=50) + S = math.asnumpy(G.parameters.symplectic.value) + cov = S @ S.T + assert np.allclose(cov, two_mode_squeezing(2 * np.arcsinh(np.sqrt(nbar)), 0.0)) def test_opt_backend_param(self): """Test the optimization of a backend parameter defined outside a gate.""" # rotated displaced squeezed state - settings.SEED = 42 - rng = tf.random.get_global_generator() - rng.reset_from_seed(settings.SEED) + with settings(SEED=42): + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) - rotation_angle = np.pi / 2 - target_state = SqueezedVacuum(0, r=1.0, phi=rotation_angle) + rotation_angle = np.pi / 2 + target_state = SqueezedVacuum(0, r=1.0, phi=rotation_angle) - # angle of rotation gate - r_angle = math.new_variable(0, bounds=(0, np.pi), name="r_angle") - # trainable squeezing - S = Sgate(0, r=0.1, r_trainable=True) + # angle of rotation gate + r_angle = math.new_variable(0, bounds=(0, np.pi), name="r_angle") + # trainable squeezing + S = Sgate(0, r=0.1, r_trainable=True) - def cost_fn_sympl(): - state_out = Vacuum(0) >> S >> Rgate(0, theta=r_angle) - return 1 - math.abs((state_out >> target_state.dual) ** 2) + def cost_fn_sympl(): + state_out = Vacuum(0) >> S >> Rgate(0, theta=r_angle) + return 1 - math.abs((state_out >> target_state.dual) ** 2) - opt = Optimizer(symplectic_lr=0.1, euclidean_lr=0.05) - opt.minimize(cost_fn_sympl, by_optimizing=[S, r_angle]) + opt = Optimizer(symplectic_lr=0.1, euclidean_lr=0.05) + opt.minimize(cost_fn_sympl, by_optimizing=[S, 
r_angle]) - assert np.allclose(math.asnumpy(r_angle), rotation_angle / 2, atol=1e-4) + assert np.allclose(math.asnumpy(r_angle), rotation_angle / 2, atol=1e-4) def test_dgate_optimization(self): """Test that Dgate is optimized correctly.""" - settings.SEED = 24 - rng = tf.random.get_global_generator() - rng.reset_from_seed(settings.SEED) + with settings(SEED=24): + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) - dgate = Dgate(0, x_trainable=True, y_trainable=True) - target_state = DisplacedSqueezed(0, r=0.0, x=0.1, y=0.2).fock_array((40,)) + dgate = Dgate(0, x_trainable=True, y_trainable=True) + target_state = DisplacedSqueezed(0, r=0.0, x=0.1, y=0.2).fock_array((40,)) - def cost_fn(): - state_out = Vacuum(0) >> dgate - return -math.abs(math.sum(math.conj(state_out.fock_array((40,))) * target_state)) ** 2 + def cost_fn(): + state_out = Vacuum(0) >> dgate + return ( + -math.abs(math.sum(math.conj(state_out.fock_array((40,))) * target_state)) ** 2 + ) - opt = Optimizer() - opt.minimize(cost_fn, by_optimizing=[dgate]) + opt = Optimizer() + opt.minimize(cost_fn, by_optimizing=[dgate]) - assert np.allclose(dgate.parameters.x.value, 0.1, atol=0.01) - assert np.allclose(dgate.parameters.y.value, 0.2, atol=0.01) + assert np.allclose(dgate.parameters.x.value, 0.1, atol=0.01) + assert np.allclose(dgate.parameters.y.value, 0.2, atol=0.01) def test_sgate_optimization(self): """Test that Sgate is optimized correctly.""" - settings.SEED = 25 - rng = tf.random.get_global_generator() - rng.reset_from_seed(settings.SEED) + with settings(SEED=25): + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) - sgate = Sgate(0, r=0.2, phi=0.1, r_trainable=True, phi_trainable=True) - target_state = SqueezedVacuum(0, r=0.1, phi=0.2).fock_array((40,)) + sgate = Sgate(0, r=0.2, phi=0.1, r_trainable=True, phi_trainable=True) + target_state = SqueezedVacuum(0, r=0.1, phi=0.2).fock_array((40,)) - def cost_fn(): - state_out = Vacuum(0) >> sgate + def cost_fn(): + state_out = Vacuum(0) >> sgate - return -math.abs(math.sum(math.conj(state_out.fock_array((40,))) * target_state)) ** 2 + return ( + -math.abs(math.sum(math.conj(state_out.fock_array((40,))) * target_state)) ** 2 + ) - opt = Optimizer() - opt.minimize(cost_fn, by_optimizing=[sgate]) + opt = Optimizer() + opt.minimize(cost_fn, by_optimizing=[sgate]) - assert np.allclose(sgate.parameters.r.value, 0.1, atol=0.01) - assert np.allclose(sgate.parameters.phi.value, 0.2, atol=0.01) + assert np.allclose(sgate.parameters.r.value, 0.1, atol=0.01) + assert np.allclose(sgate.parameters.phi.value, 0.2, atol=0.01) def test_bsgate_optimization(self): """Test that BSgate is optimized correctly.""" - settings.SEED = 25 - rng = tf.random.get_global_generator() - rng.reset_from_seed(settings.SEED) + with settings(SEED=25): + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) - bsgate = BSgate((0, 1), 0.05, 0.1, theta_trainable=True, phi_trainable=True) - target_gate = BSgate((0, 1), 0.1, 0.2).fock_array(40) + bsgate = BSgate((0, 1), 0.05, 0.1, theta_trainable=True, phi_trainable=True) + target_gate = BSgate((0, 1), 0.1, 0.2).fock_array(40) - def cost_fn(): - return -math.abs(math.sum(math.conj(bsgate.fock_array(40)) * target_gate)) ** 2 + def cost_fn(): + return -math.abs(math.sum(math.conj(bsgate.fock_array(40)) * target_gate)) ** 2 - opt = Optimizer() - opt.minimize(cost_fn, by_optimizing=[bsgate]) + opt = Optimizer() + opt.minimize(cost_fn, by_optimizing=[bsgate]) - assert 
np.allclose(bsgate.parameters.theta.value, 0.1, atol=0.01) - assert np.allclose(bsgate.parameters.phi.value, 0.2, atol=0.01) + assert np.allclose(bsgate.parameters.theta.value, 0.1, atol=0.01) + assert np.allclose(bsgate.parameters.phi.value, 0.2, atol=0.01) def test_squeezing_grad_from_fock(self): """Test that the gradient of a squeezing gate is computed from the fock representation.""" From a257d2fab4b6c9759d1bd21ad3d6e919f871492e Mon Sep 17 00:00:00 2001 From: Anthony Date: Wed, 18 Jun 2025 12:30:14 -0400 Subject: [PATCH 02/17] cf --- mrmustard/math/backend_manager.py | 2 +- tests/test_physics/test_bargmann_utils.py | 3 +-- tests/test_physics/test_fock_utils.py | 1 - 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/mrmustard/math/backend_manager.py b/mrmustard/math/backend_manager.py index 100152269..95b0960fe 100644 --- a/mrmustard/math/backend_manager.py +++ b/mrmustard/math/backend_manager.py @@ -89,7 +89,7 @@ def lazy_import(module_name: str): } -class BackendManager: +class BackendManager: # pylint: disable=too-many-public-methods r""" A class to manage the different backends supported by Mr Mustard. """ diff --git a/tests/test_physics/test_bargmann_utils.py b/tests/test_physics/test_bargmann_utils.py index 2e33091a2..234586086 100644 --- a/tests/test_physics/test_bargmann_utils.py +++ b/tests/test_physics/test_bargmann_utils.py @@ -16,8 +16,7 @@ import numpy as np -from mrmustard import math -from mrmustard.lab import DM, Channel, Dgate, Ket, Unitary, Vacuum +from mrmustard.lab import DM, Channel, Dgate, Ket from mrmustard.physics.bargmann_utils import ( XY_of_channel, au2Symplectic, diff --git a/tests/test_physics/test_fock_utils.py b/tests/test_physics/test_fock_utils.py index e2a68a584..46e6114af 100644 --- a/tests/test_physics/test_fock_utils.py +++ b/tests/test_physics/test_fock_utils.py @@ -27,7 +27,6 @@ Attenuator, BSgate, Coherent, - Number, S2gate, SqueezedVacuum, TwoModeSqueezedVacuum, From bcff422b75c18735e9ced0b17407e26f5366b2ba Mon Sep 17 00:00:00 2001 From: Anthony Date: Wed, 18 Jun 2025 12:47:12 -0400 Subject: [PATCH 03/17] rev moveaxis --- mrmustard/math/backend_jax.py | 6 ++++++ mrmustard/math/backend_manager.py | 19 +++++++++++++++++++ mrmustard/math/backend_numpy.py | 5 +++++ mrmustard/math/backend_tensorflow.py | 5 +++++ tests/test_math/test_backend_manager.py | 13 +++++++++++++ 5 files changed, 48 insertions(+) diff --git a/mrmustard/math/backend_jax.py b/mrmustard/math/backend_jax.py index 245603b25..e278c68a2 100644 --- a/mrmustard/math/backend_jax.py +++ b/mrmustard/math/backend_jax.py @@ -270,6 +270,12 @@ def matmul(self, *matrices: jnp.ndarray) -> jnp.ndarray: mat = jnp.linalg.multi_dot(matrices) return mat + @partial(jax.jit, static_argnames=["old", "new"]) + def moveaxis( + self, array: jnp.ndarray, old: int | Sequence[int], new: int | Sequence[int] + ) -> jnp.ndarray: + return jnp.moveaxis(array, old, new) + def ones(self, shape: Sequence[int], dtype=None) -> jnp.ndarray: dtype = dtype or self.float64 return jnp.ones(shape, dtype=dtype) diff --git a/mrmustard/math/backend_manager.py b/mrmustard/math/backend_manager.py index 95b0960fe..a9a80f266 100644 --- a/mrmustard/math/backend_manager.py +++ b/mrmustard/math/backend_manager.py @@ -844,6 +844,25 @@ def matvec(self, a: Matrix, b: Vector) -> Tensor: """ return self._apply("matvec", (a, b)) + def moveaxis(self, array: Tensor, old: Tensor, new: Tensor) -> Tensor: + r""" + Moves the axes of an array to a new position. + Args: + array: The array to move the axes of. 
+            old: The original position(s) of the axes to move.
+            new: The destination position(s) of the axes.
+        Returns:
+            The array with its axes moved.
+        """
+        return self._apply(
+            "moveaxis",
+            (
+                array,
+                old,
+                new,
+            ),
+        )
+
     def new_variable(
         self,
         value: Tensor,
diff --git a/mrmustard/math/backend_numpy.py b/mrmustard/math/backend_numpy.py
index fb3a29a08..644afd6ef 100644
--- a/mrmustard/math/backend_numpy.py
+++ b/mrmustard/math/backend_numpy.py
@@ -201,6 +201,11 @@ def matmul(self, *matrices: np.ndarray) -> np.ndarray:
     def matvec(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
         return self.matmul(a, b[..., None])[..., 0]
 
+    def moveaxis(
+        self, array: np.ndarray, old: int | Sequence[int], new: int | Sequence[int]
+    ) -> np.ndarray:
+        return np.moveaxis(array, old, new)
+
     def new_variable(
         self,
         value,
diff --git a/mrmustard/math/backend_tensorflow.py b/mrmustard/math/backend_tensorflow.py
index 1f4eca58c..3e340d278 100644
--- a/mrmustard/math/backend_tensorflow.py
+++ b/mrmustard/math/backend_tensorflow.py
@@ -229,6 +229,11 @@ def matvec(self, a: tf.Tensor, b: tf.Tensor) -> tf.Tensor:
     def make_complex(self, real: tf.Tensor, imag: tf.Tensor) -> tf.Tensor:
         return tf.complex(real, imag)
 
+    def moveaxis(
+        self, array: tf.Tensor, old: int | Sequence[int], new: int | Sequence[int]
+    ) -> tf.Tensor:
+        return tf.experimental.numpy.moveaxis(array, old, new)
+
     def new_variable(
         self,
         value,
diff --git a/tests/test_math/test_backend_manager.py b/tests/test_math/test_backend_manager.py
index 3fc133ee1..e70db9d2b 100644
--- a/tests/test_math/test_backend_manager.py
+++ b/tests/test_math/test_backend_manager.py
@@ -440,6 +440,19 @@ def test_make_complex(self):
         i = 2.0
         assert math.asnumpy(math.make_complex(r, i)) == r + i * 1j
 
+    def test_moveaxis(self):
+        r"""
+        Tests the ``moveaxis`` method.
+        """
+        arr1 = np.random.random(size=(1, 2, 3))
+        arr2 = np.random.random(size=(2, 1, 3))
+        arr2_moved = math.moveaxis(arr2, 0, 1)
+        assert math.allclose(arr1.shape, arr2_moved.shape)
+
+        arr1_moved1 = math.moveaxis(arr1, 0, 1)
+        arr1_moved2 = math.moveaxis(arr1_moved1, 1, 0)
+        assert math.allclose(arr1, arr1_moved2)
+
     @pytest.mark.parametrize("t", types)
     def test_new_variable(self, t):
         r"""

From 03301aa8a6e2d7b33cb917f0aea81c6033d46bf0 Mon Sep 17 00:00:00 2001
From: Anthony <125415978+apchytr@users.noreply.github.com>
Date: Wed, 18 Jun 2025 14:30:15 -0400
Subject: [PATCH 04/17] Revert random_unitary change

---
 mrmustard/math/backend_manager.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mrmustard/math/backend_manager.py b/mrmustard/math/backend_manager.py
index a9a80f266..1aae64db3 100644
--- a/mrmustard/math/backend_manager.py
+++ b/mrmustard/math/backend_manager.py
@@ -1445,7 +1445,7 @@ def random_unitary(self, N: int) -> Tensor:
         A random unitary matrix in :math:`U(N)`.
""" if N == 1: - return np.exp(1j * settings.rng.uniform(size=(1, 1))) + return self.exp(1j * settings.rng.uniform(size=(1, 1))) return unitary_group.rvs(dim=N, random_state=settings.rng) @staticmethod From 40e3b5d5a5a76987f9f489890e052be8f70c7f64 Mon Sep 17 00:00:00 2001 From: Anthony Date: Fri, 20 Jun 2025 08:41:09 -0400 Subject: [PATCH 05/17] rev angle --- mrmustard/math/backend_jax.py | 4 ++++ mrmustard/math/backend_manager.py | 12 ++++++++++++ mrmustard/math/backend_numpy.py | 3 +++ mrmustard/math/backend_tensorflow.py | 3 +++ tests/test_math/test_backend_manager.py | 7 +++++++ 5 files changed, 29 insertions(+) diff --git a/mrmustard/math/backend_jax.py b/mrmustard/math/backend_jax.py index e278c68a2..5bcc4b01a 100644 --- a/mrmustard/math/backend_jax.py +++ b/mrmustard/math/backend_jax.py @@ -75,6 +75,10 @@ def abs(self, array: jnp.ndarray) -> jnp.ndarray: def all(self, array: jnp.ndarray) -> jnp.ndarray: return jnp.all(array) + @jax.jit + def angle(self, array: jnp.ndarray) -> jnp.ndarray: + return jnp.angle(array) + @jax.jit def any(self, array: jnp.ndarray) -> jnp.ndarray: return jnp.any(array) diff --git a/mrmustard/math/backend_manager.py b/mrmustard/math/backend_manager.py index 1aae64db3..f2953a21a 100644 --- a/mrmustard/math/backend_manager.py +++ b/mrmustard/math/backend_manager.py @@ -254,6 +254,18 @@ def allclose(self, array1: Tensor, array2: Tensor, atol=1e-9, rtol=1e-5) -> bool array2 = self.astensor(array2) return self._apply("allclose", (array1, array2, atol, rtol)) + def angle(self, array: Tensor) -> Tensor: + r""" + The complex phase of ``array``. + + Args: + array: The array to take the complex phase of. + + Returns: + The complex phase of ``array``. + """ + return self._apply("angle", (array,)) + def any(self, array: Tensor) -> bool: r"""Returns ``True`` if any element of array is ``True``, ``False`` otherwise. diff --git a/mrmustard/math/backend_numpy.py b/mrmustard/math/backend_numpy.py index 644afd6ef..ac87ed1b7 100644 --- a/mrmustard/math/backend_numpy.py +++ b/mrmustard/math/backend_numpy.py @@ -66,6 +66,9 @@ def all(self, array: np.ndarray) -> bool: def allclose(self, array1: np.array, array2: np.array, atol: float, rtol: float) -> bool: return np.allclose(array1, array2, atol=atol, rtol=rtol) + def angle(self, array: np.ndarray) -> np.ndarray: + return np.angle(array) + def any(self, array: np.ndarray) -> np.ndarray: return np.any(array) diff --git a/mrmustard/math/backend_tensorflow.py b/mrmustard/math/backend_tensorflow.py index 3e340d278..a5e977236 100644 --- a/mrmustard/math/backend_tensorflow.py +++ b/mrmustard/math/backend_tensorflow.py @@ -75,6 +75,9 @@ def all(self, array: tf.Tensor) -> tf.Tensor: def allclose(self, array1: np.array, array2: np.array, atol: float, rtol: float) -> bool: return tf.experimental.numpy.allclose(array1, array2, atol=atol, rtol=rtol) + def angle(self, array: tf.Tensor) -> tf.Tensor: + return tf.experimental.numpy.angle(array) + def any(self, array: tf.Tensor) -> tf.Tensor: return tf.math.reduce_any(array) diff --git a/tests/test_math/test_backend_manager.py b/tests/test_math/test_backend_manager.py index 4a93edce6..f76967475 100644 --- a/tests/test_math/test_backend_manager.py +++ b/tests/test_math/test_backend_manager.py @@ -103,6 +103,13 @@ def test_allclose_error(self): with pytest.raises(ValueError, match="Incompatible shapes"): math.allclose(arr2, arr1) + def test_angle(self): + r""" + Tests the ``angle`` method. 
+ """ + arr = math.astensor([1.0 + 1.0j, 2.0 + 2.0j, 3.0 + 3.0j, 4.0 + 4.0j]) + assert math.allclose(math.asnumpy(math.angle(arr)), np.angle(arr)) + @pytest.mark.parametrize("l", lists) def test_any(self, l): r""" From f47de239beae3642efd2ceb8c9600c34a695e70c Mon Sep 17 00:00:00 2001 From: Anthony Date: Fri, 20 Jun 2025 08:47:49 -0400 Subject: [PATCH 06/17] rev maximum and minimum --- mrmustard/math/backend_jax.py | 8 ++++++ mrmustard/math/backend_manager.py | 38 +++++++++++++++++++++++++ mrmustard/math/backend_numpy.py | 6 ++++ mrmustard/math/backend_tensorflow.py | 8 ++++++ tests/test_math/test_backend_manager.py | 18 ++++++++++++ 5 files changed, 78 insertions(+) diff --git a/mrmustard/math/backend_jax.py b/mrmustard/math/backend_jax.py index 5bcc4b01a..1c52ebb3c 100644 --- a/mrmustard/math/backend_jax.py +++ b/mrmustard/math/backend_jax.py @@ -274,6 +274,14 @@ def matmul(self, *matrices: jnp.ndarray) -> jnp.ndarray: mat = jnp.linalg.multi_dot(matrices) return mat + @jax.jit + def maximum(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + return jnp.maximum(a, b) + + @jax.jit + def minimum(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + return jnp.minimum(a, b) + @partial(jax.jit, static_argnames=["old", "new"]) def moveaxis( self, array: jnp.ndarray, old: int | Sequence[int], new: int | Sequence[int] diff --git a/mrmustard/math/backend_manager.py b/mrmustard/math/backend_manager.py index f2953a21a..4c9582e28 100644 --- a/mrmustard/math/backend_manager.py +++ b/mrmustard/math/backend_manager.py @@ -856,6 +856,44 @@ def matvec(self, a: Matrix, b: Vector) -> Tensor: """ return self._apply("matvec", (a, b)) + def maximum(self, a: Tensor, b: Tensor) -> Tensor: + r""" + The element-wise maximum of ``a`` and ``b``. + + Args: + a: The first array to take the maximum of. + b: The second array to take the maximum of. + + Returns: + The element-wise maximum of ``a`` and ``b`` + """ + return self._apply( + "maximum", + ( + a, + b, + ), + ) + + def minimum(self, a: Tensor, b: Tensor) -> Tensor: + r""" + The element-wise minimum of ``a`` and ``b``. + + Args: + a: The first array to take the minimum of. + b: The second array to take the minimum of. + + Returns: + The element-wise minimum of ``a`` and ``b`` + """ + return self._apply( + "minimum", + ( + a, + b, + ), + ) + def moveaxis(self, array: Tensor, old: Tensor, new: Tensor) -> Tensor: r""" Moves the axes of an array to a new position. 
diff --git a/mrmustard/math/backend_numpy.py b/mrmustard/math/backend_numpy.py index ac87ed1b7..e6bcf2a9b 100644 --- a/mrmustard/math/backend_numpy.py +++ b/mrmustard/math/backend_numpy.py @@ -204,6 +204,12 @@ def matmul(self, *matrices: np.ndarray) -> np.ndarray: def matvec(self, a: np.ndarray, b: np.ndarray) -> np.ndarray: return self.matmul(a, b[..., None])[..., 0] + def maximum(self, a: np.ndarray, b: np.ndarray) -> np.ndarray: + return np.maximum(a, b) + + def minimum(self, a: np.ndarray, b: np.ndarray) -> np.ndarray: + return np.minimum(a, b) + def moveaxis( self, array: np.ndarray, old: int | Sequence[int], new: int | Sequence[int] ) -> np.ndarray: diff --git a/mrmustard/math/backend_tensorflow.py b/mrmustard/math/backend_tensorflow.py index a5e977236..3e6f2c5b4 100644 --- a/mrmustard/math/backend_tensorflow.py +++ b/mrmustard/math/backend_tensorflow.py @@ -232,6 +232,14 @@ def matvec(self, a: tf.Tensor, b: tf.Tensor) -> tf.Tensor: def make_complex(self, real: tf.Tensor, imag: tf.Tensor) -> tf.Tensor: return tf.complex(real, imag) + @Autocast() + def maximum(self, a: tf.Tensor, b: tf.Tensor) -> tf.Tensor: + return tf.maximum(a, b) + + @Autocast() + def minimum(self, a: tf.Tensor, b: tf.Tensor) -> tf.Tensor: + return tf.minimum(a, b) + def moveaxis( self, array: tf.Tensor, old: int | Sequence[int], new: int | Sequence[int] ) -> tf.Tensor: diff --git a/tests/test_math/test_backend_manager.py b/tests/test_math/test_backend_manager.py index f76967475..43ba79328 100644 --- a/tests/test_math/test_backend_manager.py +++ b/tests/test_math/test_backend_manager.py @@ -461,6 +461,24 @@ def test_moveaxis(self): arr1_moved2 = math.moveaxis(arr1_moved1, 1, 0) assert math.allclose(arr1, arr1_moved2) + def test_maximum(self): + r""" + Tests the ``maximum`` method. + """ + arr1 = np.eye(3) + arr2 = 2 * np.eye(3) + res = math.asnumpy(math.maximum(arr1, arr2)) + assert math.allclose(res, arr2) + + def test_minimum(self): + r""" + Tests the ``minimum`` method. 
+ """ + arr1 = np.eye(3) + arr2 = 2 * np.eye(3) + res = math.asnumpy(math.minimum(arr1, arr2)) + assert math.allclose(res, arr1) + @pytest.mark.parametrize("t", types) def test_new_variable(self, t): r""" From 53a3314d91af591bc840f29626a780cb02e292f8 Mon Sep 17 00:00:00 2001 From: Anthony Date: Fri, 20 Jun 2025 08:57:07 -0400 Subject: [PATCH 07/17] lgamma --- mrmustard/math/backend_jax.py | 4 ++++ mrmustard/math/backend_manager.py | 12 ++++++++++++ mrmustard/math/backend_numpy.py | 4 ++++ mrmustard/math/backend_tensorflow.py | 3 +++ tests/test_math/test_backend_manager.py | 8 ++++++++ 5 files changed, 31 insertions(+) diff --git a/mrmustard/math/backend_jax.py b/mrmustard/math/backend_jax.py index 1c52ebb3c..e0749daa8 100644 --- a/mrmustard/math/backend_jax.py +++ b/mrmustard/math/backend_jax.py @@ -265,6 +265,10 @@ def inv(self, tensor: jnp.ndarray) -> jnp.ndarray: def is_trainable(self, tensor: jnp.ndarray) -> bool: # pylint: disable=unused-argument return False + @jax.jit + def lgamma(self, array: jnp.ndarray) -> jnp.ndarray: + return jax.lax.lgamma(array) + @jax.jit def make_complex(self, real: jnp.ndarray, imag: jnp.ndarray) -> jnp.ndarray: return real + 1j * imag diff --git a/mrmustard/math/backend_manager.py b/mrmustard/math/backend_manager.py index 4c9582e28..dca352110 100644 --- a/mrmustard/math/backend_manager.py +++ b/mrmustard/math/backend_manager.py @@ -809,6 +809,18 @@ def is_trainable(self, tensor: Tensor) -> bool: """ return self._apply("is_trainable", (tensor,)) + def lgamma(self, x: Tensor) -> Tensor: + r""" + The natural logarithm of the gamma function of ``x``. + + Args: + x: The array to take the natural logarithm of the gamma function of. + + Returns: + The natural logarithm of the gamma function of ``x``. + """ + return self._apply("lgamma", (x,)) + def log(self, x: Tensor) -> Tensor: r"""The natural logarithm of ``x``. 
diff --git a/mrmustard/math/backend_numpy.py b/mrmustard/math/backend_numpy.py index e6bcf2a9b..f3ecd22d7 100644 --- a/mrmustard/math/backend_numpy.py +++ b/mrmustard/math/backend_numpy.py @@ -25,6 +25,7 @@ from scipy.linalg import expm as scipy_expm from scipy.linalg import sqrtm as scipy_sqrtm +from scipy.special import loggamma as scipy_loggamma from scipy.special import xlogy as scipy_xlogy from ..utils.settings import settings @@ -189,6 +190,9 @@ def inv(self, tensor: np.ndarray) -> np.ndarray: def is_trainable(self, tensor: np.ndarray) -> bool: # pylint: disable=unused-argument return False + def lgamma(self, x: np.ndarray) -> np.ndarray: + return scipy_loggamma(x) + def log(self, x: np.ndarray) -> np.ndarray: return np.log(x) diff --git a/mrmustard/math/backend_tensorflow.py b/mrmustard/math/backend_tensorflow.py index 3e6f2c5b4..628be1a30 100644 --- a/mrmustard/math/backend_tensorflow.py +++ b/mrmustard/math/backend_tensorflow.py @@ -215,6 +215,9 @@ def inv(self, tensor: tf.Tensor) -> tf.Tensor: def is_trainable(self, tensor: tf.Tensor) -> bool: return isinstance(tensor, tf.Variable) + def lgamma(self, x: tf.Tensor) -> tf.Tensor: + return tf.math.lgamma(x) + def log(self, x: tf.Tensor) -> tf.Tensor: return tf.math.log(x) diff --git a/tests/test_math/test_backend_manager.py b/tests/test_math/test_backend_manager.py index 43ba79328..37d7de955 100644 --- a/tests/test_math/test_backend_manager.py +++ b/tests/test_math/test_backend_manager.py @@ -21,6 +21,7 @@ import tensorflow as tf from jax import numpy as jnp from jax.errors import TracerArrayConversionError +from scipy.special import loggamma as scipy_loggamma from mrmustard import math @@ -433,6 +434,13 @@ def test_is_trainable(self): if math.backend_name == "jax": assert not math.is_trainable(arr4) + def test_lgamma(self): + r""" + Tests the ``lgamma`` method. + """ + arr = np.array([1.0, 2.0, 3.0, 4.0]) + assert math.allclose(math.asnumpy(math.lgamma(arr)), scipy_loggamma(arr)) + def test_log(self): r""" Tests the ``log`` method. From 22ec0befcd94f5f9f2f984611e01de4989f57cec Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 23 Jun 2025 11:25:18 -0400 Subject: [PATCH 08/17] nuke random.py --- tests/random.py | 433 ------------------------------------------------ 1 file changed, 433 deletions(-) diff --git a/tests/random.py b/tests/random.py index 4e9380eda..0f75a198d 100644 --- a/tests/random.py +++ b/tests/random.py @@ -13,57 +13,6 @@ # limitations under the License. 
import numpy as np -from hypothesis import strategies as st -from hypothesis.extra.numpy import arrays - -from mrmustard import settings -from mrmustard.lab import ( - Amplifier, - Attenuator, - BSgate, - CXgate, - CZgate, - Dgate, - DisplacedSqueezed, - GaussRandNoise, - Ggate, - Interferometer, - MZgate, - Pgate, - Rgate, - S2gate, - Sgate, - SqueezedVacuum, - Thermal, - Vacuum, -) - -# numbers -integer32bits = st.integers(min_value=0, max_value=2**31 - 1) -real = st.floats(allow_infinity=False, allow_nan=False) -positive = st.floats(min_value=0, exclude_min=True, allow_infinity=False, allow_nan=False) -negative = st.floats(max_value=0, exclude_max=True, allow_infinity=False, allow_nan=False) -real_not_zero = st.one_of(negative, positive) -small_float = st.floats(min_value=-0.1, max_value=0.1, allow_infinity=False, allow_nan=False) -medium_float = st.floats(min_value=-1.0, max_value=1.0, allow_infinity=False, allow_nan=False) -complex_nonzero = st.complex_numbers( - allow_infinity=False, allow_nan=False, min_magnitude=1e-9, max_magnitude=1e2 -) - -# physical parameters -nmodes = st.integers(min_value=1, max_value=10) -angle = st.floats(min_value=0, max_value=2 * np.pi) -r = st.floats(min_value=0, max_value=1.25, allow_infinity=False, allow_nan=False) -prob = st.floats(min_value=0, max_value=1, allow_infinity=False, allow_nan=False) -gain = st.floats(min_value=1, max_value=2, allow_infinity=False, allow_nan=False) - -# Complex number strategy -complex_number = st.complex_numbers( - min_magnitude=1e-9, max_magnitude=1, allow_infinity=False, allow_nan=False -) - -# Size strategy -size = st.integers(min_value=1, max_value=9) def Abc_triple(n: int, batch: tuple[int, ...] = ()): @@ -90,385 +39,3 @@ def Abc_triple(n: int, batch: tuple[int, ...] = ()): ) return A, b, c - - -@st.composite -def vector(draw, length): - r"""Return a vector of length `length`.""" - return draw(arrays(np.float64, (length,), elements=st.floats(min_value=-1.0, max_value=1.0))) - - -@st.composite -def list_of_ints(draw, N): - r"""Return a list of N unique integers between 0 and N-1.""" - return draw( - st.lists(st.integers(min_value=0, max_value=N), min_size=N, max_size=N, unique=True) - ) - - -@st.composite -def matrix(draw, rows, cols): - """Return a strategy for generating matrices of shape `rows` x `cols`.""" - elements = st.floats(allow_infinity=False, allow_nan=False, max_value=1e10, min_value=-1e10) - return draw(arrays(np.float64, (rows, cols), elements=elements)) - - -@st.composite -def complex_matrix(draw, rows, cols): - """Return a strategy for generating matrices of shape `rows` x `cols` with complex numbers.""" - max_abs_value = 1e10 - elements = st.complex_numbers( - min_magnitude=0, - max_magnitude=max_abs_value, - allow_infinity=False, - allow_nan=False, - ) - return draw(arrays(np.complex, (rows, cols), elements=elements)) - - -@st.composite -def complex_vector(draw, length=None): - """Return a strategy for generating vectors of length `length` with complex numbers.""" - elements = st.complex_numbers( - min_magnitude=0, max_magnitude=1, allow_infinity=False, allow_nan=False - ) - if length is None: - length = draw(st.integers(min_value=1, max_value=10)) - return draw(arrays(np.complex, (length,), elements=elements)) - - -def array_of_(strategy, minlen=0, maxlen=100): - r"""Return a strategy that returns an array of values from `strategy`.""" - return arrays( - shape=(st.integers(minlen, maxlen).example(),), - elements=strategy, - dtype=type(strategy.example()), - ) - - -def none_or_(strategy): - 
r"""Return a strategy that returns either None or a value from `strategy`.""" - return st.one_of(st.just(None), strategy) - - -# bounds -def bounds_check(t): - return t[0] < t[1] if t[0] is not None and t[1] is not None else True - - -angle_bounds = st.tuples(none_or_(angle), none_or_(angle)).filter(bounds_check) -positive_bounds = st.tuples(none_or_(positive), none_or_(positive)).filter(bounds_check) -real_bounds = st.tuples(none_or_(real), none_or_(real)).filter(bounds_check) -gain_bounds = st.tuples(none_or_(gain), none_or_(gain)).filter(bounds_check) -prob_bounds = st.tuples(none_or_(prob), none_or_(prob)).filter(bounds_check) - - -# gates -@st.composite -def random_Rgate(draw, trainable=False): - r"""Return a random Rgate.""" - return Rgate( - theta=draw(angle), - theta_bounds=draw(angle_bounds), - theta_trainable=trainable, - ) - - -@st.composite -def random_Sgate(draw, trainable=False): - r"""Return a random Sgate.""" - return Sgate( - 0, - r=draw(r), - phi=draw(angle), - r_bounds=draw(positive_bounds), - phi_bounds=draw(angle_bounds), - r_trainable=trainable, - phi_trainable=trainable, - ) - - -@st.composite -def random_Dgate(draw, trainable=False): - r"""Return a random Dgate.""" - x = draw(small_float) - y = draw(small_float) - return Dgate( - 0, - x=x, - y=y, - x_bounds=draw(real_bounds), - y_bounds=draw(real_bounds), - x_trainable=trainable, - y_trainable=trainable, - ) - - -@st.composite -def random_Pgate(draw, trainable=False): - r"""Return a random Pgate.""" - return Pgate( - 0, - shearing=draw(prob), - shearing_bounds=draw(prob_bounds), - shearing_trainable=trainable, - ) - - -@st.composite -def random_Attenuator(draw, trainable=False): - r"""Return a random Attenuator.""" - return Attenuator( - 0, - transmissivity=draw(prob), - transmissivity_bounds=draw(prob_bounds), - transmissivity_trainable=trainable, - ) - - -@st.composite -def random_Amplifier(draw, trainable=False): - r"""Return a random Amplifier.""" - return Amplifier( - 0, - gain=draw(gain), - gain_bounds=draw(gain_bounds), - gain_trainable=trainable, - ) - - -@st.composite -def random_GaussRandNoise(draw, trainable=False): - r"""Return a random GaussRandNoise.""" - settings.SEED = draw(integer32bits) - return GaussRandNoise( - 0, - Y_trainable=trainable, - ) - - -@st.composite -def random_S2gate(draw, trainable=False): - r"""Return a random S2gate.""" - return S2gate( - (0, 1), - r=draw(r), - phi=draw(angle), - r_bounds=draw(positive_bounds), - phi_bounds=draw(angle_bounds), - r_trainable=trainable, - phi_trainable=trainable, - ) - - -@st.composite -def random_CXgate(draw, trainable=False): - r"""Return a random CXgate.""" - return CXgate( - (0, 1), - s=draw(medium_float), - s_bounds=draw(real_bounds), - s_trainable=trainable, - ) - - -@st.composite -def random_CZgate(draw, trainable=False): - r"""Return a random CZgate.""" - return CZgate( - (0, 1), - s=draw(medium_float), - s_bounds=draw(real_bounds), - s_trainable=trainable, - ) - - -@st.composite -def random_BSgate(draw, trainable=False): - r"""Return a random BSgate.""" - return BSgate( - (0, 1), - theta=draw(angle), - phi=draw(angle), - theta_bounds=draw(angle_bounds), - phi_bounds=draw(angle_bounds), - theta_trainable=trainable, - phi_trainable=trainable, - ) - - -@st.composite -def random_MZgate(draw, trainable=False): - r"""Return a random MZgate.""" - return MZgate( - (0, 1), - phi_a=draw(angle), - phi_b=draw(angle), - phi_a_bounds=draw(angle_bounds), - phi_b_bounds=draw(angle_bounds), - phi_a_trainable=trainable, - phi_b_trainable=trainable, - 
internal=draw(st.booleans()), - ) - - -@st.composite -def random_Interferometer(draw, num_modes, trainable=False): - r"""Return a random Interferometer.""" - settings.SEED = draw(integer32bits) - return Interferometer(modes=list(range(num_modes)), unitary_trainable=trainable) - - -@st.composite -def random_Ggate(draw, num_modes, trainable=False): - r"""Return a random Ggate.""" - settings.SEED = draw(integer32bits) - return Ggate(modes=list(range(num_modes)), symplectic_trainable=trainable) - - -@st.composite -def single_mode_unitary_gate(draw): - r"""Return a random single mode unitary gate.""" - return draw( - st.one_of( - random_Rgate(), - random_Sgate(), - random_Dgate(), - random_Pgate(), - random_Interferometer(num_modes=1), # like Rgate - ) - ) - - -@st.composite -def single_mode_cv_channel(draw): - r"""Return a random single mode unitary gate.""" - return draw( - st.one_of( - random_Attenuator(), - random_Amplifier(), - random_GaussRandNoise(), - ) - ) - - -@st.composite -def two_mode_unitary_gate(draw): - r"""Return a random two mode unitary gate.""" - return draw( - st.one_of( - random_S2gate(), - random_BSgate(), - random_MZgate(), - random_CXgate(), - random_CZgate(), - random_Ggate(num_modes=2), - random_Interferometer(num_modes=2), - ) - ) - - -@st.composite -def n_mode_unitary_gate(draw, num_modes=None): - r"""Return a random n mode unitary gate.""" - return draw(st.one_of(random_Interferometer(num_modes), random_Ggate(num_modes))) - - -## states -@st.composite -def squeezed_vacuum(draw, num_modes): - r"""Return a random squeezed vacuum state.""" - r_vec = draw(array_of_(r, num_modes, num_modes)) - phi = draw(array_of_(angle, num_modes, num_modes)) - state = SqueezedVacuum(0, r=r_vec[0], phi=phi[0]) - for i in range(1, num_modes): - state = state >> SqueezedVacuum(i, r=r_vec[i], phi=phi[i]) - return state - - -@st.composite -def displacedsqueezed(draw, num_modes): - r"""Return a random displaced squeezed state.""" - r_vec = draw(array_of_(r, num_modes, num_modes)) - phi = draw(array_of_(angle, num_modes, num_modes)) - x = draw(array_of_(medium_float, num_modes, num_modes)) - y = draw(array_of_(medium_float, num_modes, num_modes)) - state = DisplacedSqueezed(0, r=r_vec[0], phi=phi[0], x=x[0], y=y[0]) - for i in range(1, num_modes): - state = state >> DisplacedSqueezed(i, r=r_vec[i], phi=phi[i], x=x[i], y=y[i]) - return state - - -@st.composite -def coherent(draw, num_modes): - r"""Return a random coherent state.""" - x = draw(array_of_(medium_float, num_modes, num_modes)) - y = draw(array_of_(medium_float, num_modes, num_modes)) - state = Coherent(0, x=x[0], y=y[0]) - for i in range(1, num_modes): - state = state >> Coherent(i, x=x[i], y=y[i]) - return state - - -@st.composite -def tmsv(draw, phi): - r"""Return a random two-mode squeezed vacuum state.""" - return TMSV((0, 1), r=draw(r), phi=draw(phi)) - - -@st.composite -def thermal(draw, num_modes): - r"""Return a random thermal state.""" - n_mean = draw(array_of_(r, num_modes, num_modes)) # using r here - state = Thermal(0, n_mean=n_mean[0]) - for i in range(1, num_modes): - state = state >> Thermal(i, n_mean=n_mean[i]) - return state - - -# generic states -@st.composite -def n_mode_separable_pure_state(draw, num_modes): - r"""Return a random n mode separable pure state.""" - return draw( - st.one_of( - squeezed_vacuum(num_modes), - displacedsqueezed(num_modes), - coherent(num_modes), - ) - ) - - -@st.composite -def n_mode_separable_mixed_state(draw, num_modes): - r"""Return a random n mode separable mixed state.""" - 
attenuator = Attenuator(draw(st.floats(min_value=0.2, max_value=0.9))) - return ( - draw( - st.one_of( - squeezed_vacuum(num_modes), - displacedsqueezed(num_modes), - coherent(num_modes), - thermal(num_modes), - ) - ) - >> attenuator - ) - - -@st.composite -def n_mode_pure_state(draw, num_modes=1): - r"""Return a random n mode pure state.""" - S = draw(random_Sgate(num_modes)) - I = draw(random_Interferometer(num_modes)) - D = draw(random_Dgate(num_modes)) - return Vacuum(num_modes) >> S >> I >> D - - -@st.composite -def n_mode_mixed_state(draw, num_modes=1): - r"""Return a random n mode pure state.""" - S = draw(random_Sgate(num_modes)) - I = draw(random_Interferometer(num_modes)) - D = draw(random_Dgate(num_modes)) - return Thermal([0.5] * num_modes) >> S >> I >> D From 2f2b3798e5fa9658d8ec061cc667410162142cee Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 24 Jun 2025 16:30:33 -0400 Subject: [PATCH 09/17] rem old dev container things --- .devcontainer/devcontainer.json | 45 --------------------------------- .devcontainer/postinstall.sh | 32 ----------------------- env.Dockerfile | 30 ---------------------- 3 files changed, 107 deletions(-) delete mode 100644 .devcontainer/devcontainer.json delete mode 100644 .devcontainer/postinstall.sh delete mode 100644 env.Dockerfile diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json deleted file mode 100644 index 4c84539cc..000000000 --- a/.devcontainer/devcontainer.json +++ /dev/null @@ -1,45 +0,0 @@ -{ - "name": "MrMustard 🌭", - "build": { - "dockerfile": "../env.Dockerfile", - "args": { - "PYTHON_VERSION": "${localEnv:MRMUSTARD_PYTHON_VERSION:3.10}" - } - }, - "postCreateCommand": "/bin/sh ./.devcontainer/postinstall.sh", - "customizations": { - "vscode": { - "extensions": [ - "ms-python.python", - "ms-python.vscode-pylance", - "ms-python.pylint", - "ms-toolsai.jupyter", - "GitHub.vscode-pull-request-github", - "mutantdino.resourcemonitor", - "njpwerner.autodocstring" - ], - "settings": { - "python.languageServer": "Pylance", - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true, - "python.testing.pytestArgs": [ - "tests" - ], - "python.analysis.typeCheckingMode": "off", - "python.defaultInterpreterPath": "/usr/local/bin/python", - "python.terminal.executeInFileDir": true, - "code-runner.fileDirectoryAsCwd": true, - "autoDocstring.docstringFormat": "google", - "terminal.integrated.defaultProfile.linux": "zsh", - "terminal.integrated.profiles.linux": { - "zsh": { - "path": "zsh" - } - } - } - } - }, - "remoteUser": "root", - "workspaceMount": "source=${localWorkspaceFolder},target=/mrmustard,type=bind", - "workspaceFolder": "/mrmustard" -} diff --git a/.devcontainer/postinstall.sh b/.devcontainer/postinstall.sh deleted file mode 100644 index d7d18fb8b..000000000 --- a/.devcontainer/postinstall.sh +++ /dev/null @@ -1,32 +0,0 @@ -#! /bin/sh - -apt-get update -y -apt-get -y -o Dpkg::Options::="--force-confold" --force-yes install --no-install-recommends zsh -apt-get -y install --no-install-recommends fonts-powerline locales toilet fortunes fortune-mod -apt-get clean && rm -rf /var/lib/apt/lists/* - -# install jupyter notebook widgets extension -pip install ipywidgets ipykernel - -# generate locale for zsh terminal agnoster theme -echo "en_US.UTF-8 UTF-8" > /etc/locale.gen && /usr/sbin/locale-gen -locale-gen en_US.UTF-8 - -# *** Install oh-my-zsh *** -# Make zsh the default shell -chsh -s $(which zsh) -# Disable marking untracked files -# under VCS as dirty. 
This makes repository status check for large repositories
-# much, much faster.
-git config --add oh-my-zsh.hide-dirty 1
-# Download and run installation script
-curl -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh | sh
-
-# Print welcome message
-clear
-echo "\n"
-/usr/games/fortune -w wisdom tao science songs-poems literature linuxcookie linux education art
-echo "\n"
-toilet --metal -f emboss -W MrMustard
-echo "\n"
-echo "This devcontainer is ready for development!"
diff --git a/env.Dockerfile b/env.Dockerfile
deleted file mode 100644
index fc9d0cbf1..000000000
--- a/env.Dockerfile
+++ /dev/null
@@ -1,30 +0,0 @@
-# *** Base *** #
-ARG PYTHON_VERSION
-FROM python:${PYTHON_VERSION} AS base
-
-ENV DEBIAN_FRONTEND=noninteractive
-
-# Update package lists and install curl
-RUN apt-get update \
-    && apt-get install -y curl \
-    && rm -rf /var/lib/apt/lists/*
-
-# Setup workdir
-WORKDIR /mrmustard
-COPY pyproject.toml .
-COPY uv.lock .
-
-# Install uv, add to path
-COPY --from=ghcr.io/astral-sh/uv:0.5.29 /uv /uvx /uv_bin/
-ENV PATH="${PATH}:/uv_bin"
-
-# Install all dependencies
-RUN uv venv -p python${PYTHON_VERSION}
-RUN uv sync --all-extras --group doc
-
-ENV DEBIAN_FRONTEND=dialog
-
-# Add source code, tests and configuration
-COPY . .
-
-CMD ["uv", "run", "python"]

From b3765639eed99a15df62c33f182f2f6c001ec95e Mon Sep 17 00:00:00 2001
From: Anthony
Date: Wed, 25 Jun 2025 09:17:30 -0400
Subject: [PATCH 10/17] rem COM

---
 mrmustard/lab/circuits.py | 2 +-
 pyproject.toml | 1 -
 tests/test_training/test_opt_lab.py | 14 ++++----------
 3 files changed, 5 insertions(+), 12 deletions(-)

diff --git a/mrmustard/lab/circuits.py b/mrmustard/lab/circuits.py
index 67a1a7d67..cf0a0f276 100644
--- a/mrmustard/lab/circuits.py
+++ b/mrmustard/lab/circuits.py
@@ -349,7 +349,7 @@ def component_to_str(comp: CircuitComponent) -> list[str]:
             param = comp.parameters.constants.get(name) or comp.parameters.variables.get(
                 name,
             )
-            new_values = math.atleast_nd(param.value, 1)
+            new_values = math.atleast_1d(param.value)
             if len(new_values) == 1 and cc_name not in control_gates:
                 new_values = math.tile(new_values, (len(comp.modes),))
             values.append(math.asnumpy(new_values))
diff --git a/pyproject.toml b/pyproject.toml
index e08549228..b6314f8b6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -146,7 +146,6 @@ select = [
     "BLE", # flake8-blind-except
     "B", # flake8-bugbear
     "A", # flake8-builtins
-    "COM", # flake8-commas
     "C4", # flake8-comprehensions
     "ISC", # flake8-implicit-str-concat
     "ICN", # flake8-import-conventions
diff --git a/tests/test_training/test_opt_lab.py b/tests/test_training/test_opt_lab.py
index f56ecaf8f..1683ef57f 100644
--- a/tests/test_training/test_opt_lab.py
+++ b/tests/test_training/test_opt_lab.py
@@ -495,11 +495,8 @@ def test_dgate_optimization(self):

         def cost_fn():
             state_out = Vacuum(0) >> dgate
-            return (
-                -(
-                    math.abs(math.sum(math.conj(state_out.fock_array((40,))) * target_state))
-                    ** 2
-                )
+            return -(
+                math.abs(math.sum(math.conj(state_out.fock_array((40,))) * target_state)) ** 2
             )

         opt = Optimizer()
@@ -520,11 +517,8 @@ def test_sgate_optimization(self):

         def cost_fn():
             state_out = Vacuum(0) >> sgate
-            return (
-                -(
-                    math.abs(math.sum(math.conj(state_out.fock_array((40,))) * target_state))
-                    ** 2
-                )
+            return -(
+                math.abs(math.sum(math.conj(state_out.fock_array((40,))) * target_state)) ** 2
             )

         opt = Optimizer()

From 9f9df99e49505c875a32555471727a61cc5bfdd0 Mon Sep 17 00:00:00 2001
From: Anthony
Date: Wed, 25 Jun 2025 09:18:11 -0400
Subject: [PATCH 11/17] trying

---
 tests/test_training/test_opt_lab.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/test_training/test_opt_lab.py b/tests/test_training/test_opt_lab.py
index 1683ef57f..04d5f62cb 100644
--- a/tests/test_training/test_opt_lab.py
+++ b/tests/test_training/test_opt_lab.py
@@ -59,9 +59,7 @@ def test_S2gate_coincidence_prob(self, n):
         rng.reset_from_seed(settings.SEED)

         S = TwoModeSqueezedVacuum(
-            (0, 1),
-            r=abs(settings.rng.normal(loc=1.0, scale=0.1)),
-            r_trainable=True,
+            (0, 1), r=abs(settings.rng.normal(loc=1.0, scale=0.1)), r_trainable=True
         )

         def cost_fn():

From f27e8a6fc5569af148366223a0399b444fd24907 Mon Sep 17 00:00:00 2001
From: Anthony
Date: Wed, 25 Jun 2025 09:18:32 -0400
Subject: [PATCH 12/17] rem pylint

---
 tests/test_training/test_opt_lab.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_training/test_opt_lab.py b/tests/test_training/test_opt_lab.py
index 04d5f62cb..944199adf 100644
--- a/tests/test_training/test_opt_lab.py
+++ b/tests/test_training/test_opt_lab.py
@@ -65,7 +65,7 @@ def test_S2gate_coincidence_prob(self, n):
         def cost_fn():
             return -(math.abs(S.fock_array((n + 1, n + 1))[n, n]) ** 2)

-        def cb(optimizer, cost, trainables, **kwargs):  # pylint: disable=unused-argument
+        def cb(optimizer, cost, trainables, **kwargs):
             return {
                 "cost": cost,
                 "lr": optimizer.learning_rate[update_euclidean],

From be91872c7022be9b8a54106136cc00e01133ae95 Mon Sep 17 00:00:00 2001
From: Anthony
Date: Wed, 25 Jun 2025 09:39:39 -0400
Subject: [PATCH 13/17] updates

---
 .dockerignore | 1 -
 pyproject.toml | 2 ++
 2 files changed, 2 insertions(+), 1 deletion(-)
 delete mode 100644 .dockerignore

diff --git a/.dockerignore b/.dockerignore
deleted file mode 100644
index 1d17dae13..000000000
--- a/.dockerignore
+++ /dev/null
@@ -1 +0,0 @@
-.venv
diff --git a/pyproject.toml b/pyproject.toml
index b6314f8b6..6572be265 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -146,6 +146,7 @@ select = [
     "BLE", # flake8-blind-except
     "B", # flake8-bugbear
     "A", # flake8-builtins
+    "COM", # flake8-commas
     "C4", # flake8-comprehensions
     "ISC", # flake8-implicit-str-concat
     "ICN", # flake8-import-conventions
@@ -165,6 +166,7 @@ ignore = [
     "E741",
     "F403",
     "B905",
+    "COM812",
     "PERF203",
     "PLR2004", "PLR0124", "PLR0913", "PLR0912", "PLR0915", "PLC0414", "PLW1641",
     "RUF012", "RUF001", "RUF002",

From 7f05e2eb42500d3c0c523fd24137ec37123e9cd8 Mon Sep 17 00:00:00 2001
From: Anthony
Date: Wed, 25 Jun 2025 09:40:13 -0400
Subject: [PATCH 14/17] bump patch cov

---
 .codecov.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.codecov.yml b/.codecov.yml
index d21559366..d1515b335 100644
--- a/.codecov.yml
+++ b/.codecov.yml
@@ -5,5 +5,5 @@ coverage:
       target: 89%
   patch:
     default:
-      target: 90%
+      target: 100%
       threshold: 0%

From 367dbc33abe6fc080282ad8154a524c32b4733d1 Mon Sep 17 00:00:00 2001
From: Anthony
Date: Wed, 25 Jun 2025 10:16:23 -0400
Subject: [PATCH 15/17] fix

---
 mrmustard/lab/circuits.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mrmustard/lab/circuits.py b/mrmustard/lab/circuits.py
index cf0a0f276..67a1a7d67 100644
--- a/mrmustard/lab/circuits.py
+++ b/mrmustard/lab/circuits.py
@@ -349,7 +349,7 @@ def component_to_str(comp: CircuitComponent) -> list[str]:
             param = comp.parameters.constants.get(name) or comp.parameters.variables.get(
                 name,
             )
-            new_values = math.atleast_1d(param.value)
+            new_values = math.atleast_nd(param.value, 1)
             if len(new_values) == 1 and cc_name not in control_gates:
                 new_values = math.tile(new_values, (len(comp.modes),))
             values.append(math.asnumpy(new_values))

From 6e5429dbdfb8fac0bfe40e47bc160556f65a38c1 Mon Sep 17 00:00:00 2001
From: Anthony
Date: Wed, 25 Jun 2025 10:55:22 -0400
Subject: [PATCH 16/17] cov?

---
 mrmustard/lab/utils.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mrmustard/lab/utils.py b/mrmustard/lab/utils.py
index 53d51ac01..8c603a8a8 100644
--- a/mrmustard/lab/utils.py
+++ b/mrmustard/lab/utils.py
@@ -72,8 +72,7 @@ def reshape_params(n_modes: int, **kwargs) -> Generator:
         if len(var) == 1:
             var = math.tile(var, (n_modes,))  # noqa: PLW2901
         elif len(var) != n_modes:
-            msg = f"Parameter {names[i]} has an incompatible shape."
-            raise ValueError(msg)
+            raise ValueError(f"Parameter {names[i]} has an incompatible shape.")
         yield var

From 41c3b6f5d1e59d08123faff801e7e55259d024f5 Mon Sep 17 00:00:00 2001
From: Anthony
Date: Wed, 25 Jun 2025 10:57:00 -0400
Subject: [PATCH 17/17] rev utils

---
 mrmustard/lab/utils.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/mrmustard/lab/utils.py b/mrmustard/lab/utils.py
index 8c603a8a8..2af3e2243 100644
--- a/mrmustard/lab/utils.py
+++ b/mrmustard/lab/utils.py
@@ -63,17 +63,14 @@ def reshape_params(n_modes: int, **kwargs) -> Generator:
         ValueError: If a parameter has a length which is neither equal to ``1``
             nor ``n_modes``.
     """
-    names = list(kwargs.keys())
-    variables = list(kwargs.values())
-
-    variables = [math.atleast_nd(var, 1) for var in variables]
-
-    for i, var in enumerate(variables):
-        if len(var) == 1:
-            var = math.tile(var, (n_modes,))  # noqa: PLW2901
-        elif len(var) != n_modes:
-            raise ValueError(f"Parameter {names[i]} has an incompatible shape.")
-        yield var
+    for name, val in kwargs.items():
+        val = math.atleast_nd(val, 1)  # noqa: PLW2901
+        if len(val) == 1:
+            val = math.tile(val, (n_modes,))  # noqa: PLW2901
+        elif len(val) != n_modes:
+            msg = f"Parameter {name} has an incompatible shape."
+            raise ValueError(msg)
+        yield val


 def shape_check(mat, vec, dim: int, name: str):
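
Note: PATCH 16/17 and PATCH 17/17 together settle the final shape of `reshape_params`. The sketch below mirrors that final diff so the broadcast-and-validate contract can be exercised standalone; plain NumPy stands in for `mrmustard.math` (assuming `math.atleast_nd(val, 1)` and `math.tile` act like their NumPy counterparts on 1-D input), and the name `reshape_params_sketch` is illustrative, not part of the library.

# Minimal standalone sketch of the reshape_params contract after PATCH 17/17.
# NumPy replaces mrmustard.math here; this is an illustration, not library code.
from collections.abc import Generator

import numpy as np


def reshape_params_sketch(n_modes: int, **kwargs) -> Generator:
    """Yield each keyword parameter broadcast to length ``n_modes``."""
    for name, val in kwargs.items():
        val = np.atleast_1d(val)  # math.atleast_nd(val, 1) in the diff above
        if len(val) == 1:
            val = np.tile(val, (n_modes,))  # a scalar parameter is tiled to every mode
        elif len(val) != n_modes:
            msg = f"Parameter {name} has an incompatible shape."
            raise ValueError(msg)
        yield val


# A scalar r is tiled across all three modes; a length-3 phi passes through as-is.
r, phi = reshape_params_sketch(3, r=0.5, phi=[0.1, 0.2, 0.3])
assert r.shape == (3,) and phi.shape == (3,)

Binding the message to `msg` before raising, which PATCH 17/17 restores after PATCH 16/17 inlined it, matches the flake8-errmsg style enforced by ruff's EM rules, assuming those rules are part of the project's lint configuration.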