Removing remnant legacy methods #608

Open · wants to merge 24 commits into base: develop
Changes from 3 commits

mrmustard/lab/circuits.py (2 changes: 1 addition & 1 deletion)
@@ -343,7 +343,7 @@ def component_to_str(comp: CircuitComponent) -> list[str]:
    param = comp.parameters.constants.get(name) or comp.parameters.variables.get(
        name
    )
-    new_values = math.atleast_1d(param.value)
+    new_values = math.atleast_nd(param.value, 1)
    if len(new_values) == 1 and cc_name not in control_gates:
        new_values = math.tile(new_values, (len(comp.modes),))
    values.append(math.asnumpy(new_values))
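
For context (an aside, not part of the diff): math.atleast_nd(x, 1) is the n-dimensional generalization that replaces the legacy atleast_1d helper. A minimal sketch of the equivalence, assuming the backend mirrors the JAX implementation shown in backend_jax.py below:

    import jax.numpy as jnp

    def atleast_nd(array, n, dtype=None):
        # Pad leading dimensions until the array has at least n of them.
        return jnp.array(array, ndmin=n, dtype=dtype)

    assert atleast_nd(5.0, 1).shape == (1,)          # scalar -> 1-d, like atleast_1d
    assert atleast_nd(jnp.ones(3), 1).shape == (3,)  # already 1-d: unchanged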

mrmustard/lab/states/dm.py (2 changes: 1 addition & 1 deletion)
@@ -704,7 +704,7 @@ def physical_stellar_decomposition_mixed(  # pylint: disable=too-many-statements
    R_c_transpose = math.einsum("...ij->...ji", R_c)

    Aphi_out = Am
-    gamma = np.linalg.pinv(R_c) @ R
+    gamma = math.pinv(R_c) @ R
    gamma_transpose = math.einsum("...ij->...ji", gamma)
    Aphi_in = gamma @ math.inv(Aphi_out - math.Xmat(M)) @ gamma_transpose + math.Xmat(M)
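
A note on the pinv change (an aside, not part of the diff): math.pinv dispatches to the active backend's pseudo-inverse instead of hard-wiring NumPy, so the surrounding computation stays traceable under JAX. A quick numerical check, assuming math.pinv matches jnp.linalg.pinv on the JAX backend:

    import numpy as np
    import jax.numpy as jnp

    R_c = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # hypothetical input
    # Both pseudo-inverses agree numerically; only the dispatch differs.
    assert np.allclose(np.linalg.pinv(R_c), jnp.linalg.pinv(jnp.asarray(R_c)), atol=1e-6)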

mrmustard/lab/utils.py (2 changes: 1 addition & 1 deletion)
@@ -66,7 +66,7 @@ def reshape_params(n_modes: int, **kwargs) -> Generator:
    names = list(kwargs.keys())
    vars = list(kwargs.values())

-    vars = [math.atleast_1d(var) for var in vars]
+    vars = [math.atleast_nd(var, 1) for var in vars]

    for i, var in enumerate(vars):
        if len(var) == 1:
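
The visible lines suggest reshape_params promotes each keyword argument to 1-d and then broadcasts length-1 values to one entry per mode; the hunk ends before the tiling step, so that part is an assumption. A standalone sketch of the pattern (names hypothetical):

    import jax.numpy as jnp

    def reshape_params_sketch(n_modes, **kwargs):
        for var in kwargs.values():
            var = jnp.array(var, ndmin=1)        # the atleast_nd(var, 1) step above
            if len(var) == 1:
                var = jnp.tile(var, (n_modes,))  # assumed: scalars repeat per mode
            yield var

    r, phi = reshape_params_sketch(3, r=0.5, phi=[0.1, 0.2, 0.3])
    # r -> [0.5, 0.5, 0.5]; phi is left as given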

mrmustard/math/backend_jax.py (104 changes: 1 addition & 103 deletions)
@@ -75,9 +75,6 @@ def abs(self, array: jnp.ndarray) -> jnp.ndarray:
    def all(self, array: jnp.ndarray) -> jnp.ndarray:
        return jnp.all(array)

-    def angle(self, array: jnp.ndarray) -> jnp.ndarray:
-        return jnp.angle(array)
-
    @jax.jit
    def any(self, array: jnp.ndarray) -> jnp.ndarray:
        return jnp.any(array)
@@ -115,25 +112,6 @@ def log(self, array: jnp.ndarray) -> jnp.ndarray:
    def atleast_nd(self, array: jnp.ndarray, n: int, dtype=None) -> jnp.ndarray:
        return jnp.array(array, ndmin=n, dtype=dtype)

-    @jax.jit
-    def block_diag(self, mat1: jnp.ndarray, mat2: jnp.ndarray) -> jnp.ndarray:
-        Za = self.zeros((mat1.shape[-2], mat2.shape[-1]), dtype=mat1.dtype)
-        Zb = self.zeros((mat2.shape[-2], mat1.shape[-1]), dtype=mat1.dtype)
-        return self.concat(
-            [self.concat([mat1, Za], axis=-1), self.concat([Zb, mat2], axis=-1)],
-            axis=-2,
-        )
-
-    def constraint_func(self, bounds: tuple[float | None, float | None]) -> Callable:
-        lower = -jnp.inf if bounds[0] is None else bounds[0]
-        upper = jnp.inf if bounds[1] is None else bounds[1]
-
-        @jax.jit
-        def constraint(x):
-            return jnp.clip(x, lower, upper)
-
-        return constraint
-
    @partial(jax.jit, static_argnames=["dtype"])
    def cast(self, array: jnp.ndarray, dtype=None) -> jnp.ndarray:
        if dtype is None:
@@ -159,45 +137,13 @@ def allclose(self, array1: jnp.ndarray, array2: jnp.ndarray, atol=1e-9, rtol=1e-
    def clip(self, array: jnp.ndarray, a_min: float, a_max: float) -> jnp.ndarray:
        return jnp.clip(array, a_min, a_max)

-    @jax.jit
-    def maximum(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
-        return jnp.maximum(a, b)
-
-    @jax.jit
-    def minimum(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
-        return jnp.minimum(a, b)
-
-    @jax.jit
-    def lgamma(self, array: jnp.ndarray) -> jnp.ndarray:
-        return jax.lax.lgamma(array)
-
-    @jax.jit
-    def conj(self, array: jnp.ndarray) -> jnp.ndarray:
-        return jnp.conj(array)
-
-    def pow(self, x: jnp.ndarray, y: float) -> jnp.ndarray:
-        return jnp.power(x, y)
-
-    def Categorical(self, probs: jnp.ndarray, name: str):  # pylint: disable=unused-argument
-        class Generator:
-            def __init__(self, probs):
-                self._probs = probs
-
-            def sample(self):
-                key = jax.random.PRNGKey(0)
-                idx = jnp.arange(len(self._probs))
-                return jax.random.choice(key, idx, p=self._probs / jnp.sum(self._probs))
-
-        return Generator(probs)
-
-    @partial(jax.jit, static_argnames=["k"])
-    def set_diag(self, array: jnp.ndarray, diag: jnp.ndarray, k: int) -> jnp.ndarray:
-        i = jnp.arange(0, array.shape[-2] - abs(k))
-        j = jnp.arange(abs(k), array.shape[-1])
-        i = jnp.where(k < 0, i - array.shape[-2] + abs(k), i)
-        j = jnp.where(k < 0, j - abs(k), j)
-        return array.at[..., i, j].set(diag)
-
    def new_variable(
        self,
        value: jnp.ndarray,
@@ -210,25 +156,14 @@ def new_variable(

    @jax.jit
    def outer(self, array1: jnp.ndarray, array2: jnp.ndarray) -> jnp.ndarray:
-        return jnp.tensordot(array1, array2, [[], []])
+        return self.tensordot(array1, array2, [[], []])

    @partial(jax.jit, static_argnames=["name", "dtype"])
    def new_constant(self, value, name: str, dtype=None):  # pylint: disable=unused-argument
        dtype = dtype or self.float64
        value = self.astensor(value, dtype)
        return value

-    @partial(jax.jit, static_argnames=["data_format", "padding"])
-    def convolution(
-        self,
-        array: jnp.ndarray,
-        filters: jnp.ndarray,
-        padding: str | None = None,
-        data_format="NWC",  # pylint: disable=unused-argument
-    ) -> jnp.ndarray:
-        padding = padding or "VALID"
-        return jax.lax.conv(array, filters, (1, 1), padding)
-
    def tile(self, array: jnp.ndarray, repeats: Sequence[int]) -> jnp.ndarray:
        return jnp.tile(array, repeats)
@@ -311,10 +246,6 @@ def eye_like(self, array: jnp.ndarray) -> jnp.ndarray:
    def from_backend(self, value) -> bool:
        return isinstance(value, jnp.ndarray)

-    @partial(jax.jit, static_argnames=["repeats", "axis"])
-    def repeat(self, array: jnp.ndarray, repeats: int, axis: int = None) -> jnp.ndarray:
-        return jnp.repeat(array, repeats, axis=axis)
-
    @partial(jax.jit, static_argnames=["axis"])
    def gather(self, array: jnp.ndarray, indices: jnp.ndarray, axis: int = 0) -> jnp.ndarray:
        return jnp.take(array, indices, axis=axis)
@@ -385,10 +316,6 @@ def real(self, array: jnp.ndarray) -> jnp.ndarray:
    def reshape(self, array: jnp.ndarray, shape: Sequence[int]) -> jnp.ndarray:
        return jnp.reshape(array, shape)

-    @partial(jax.jit, static_argnames=["decimals"])
-    def round(self, array: jnp.ndarray, decimals: int = 0) -> jnp.ndarray:
-        return jnp.round(array, decimals)
-
    @jax.jit
    def sin(self, array: jnp.ndarray) -> jnp.ndarray:
        return jnp.sin(array)
Expand Down Expand Up @@ -416,9 +343,6 @@ def stack(self, arrays: jnp.ndarray, axis: int = 0) -> jnp.ndarray:
def kron(self, tensor1: jnp.ndarray, tensor2: jnp.ndarray):
return jnp.kron(tensor1, tensor2)

def boolean_mask(self, tensor: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
return tensor[mask]

@partial(jax.jit, static_argnames=["axes"])
def sum(self, array: jnp.ndarray, axes: Sequence[int] = None):
return jnp.sum(array, axis=axes)
@@ -430,24 +354,6 @@ def norm(self, array: jnp.ndarray) -> jnp.ndarray:
    def map_fn(self, func, elements):
        return jax.vmap(func)(elements)

-    def MultivariateNormalTriL(self, loc: jnp.ndarray, scale_tril: jnp.ndarray, key: int = 0):
-        class Generator:
-            def __init__(self, mean, cov, key):
-                self._mean = mean
-                self._cov = cov
-                self._rng = jax.random.PRNGKey(key)
-
-            def sample(self, dtype=None):  # pylint: disable=unused-argument
-                fn = jax.random.multivariate_normal
-                ret = fn(self._rng, self._mean, self._cov)
-                return ret
-
-            def prob(self, x):
-                return jsp.stats.multivariate_normal.pdf(x, mean=self._mean, cov=self._cov)
-
-        scale_tril = scale_tril @ jnp.transpose(scale_tril)
-        return Generator(loc, scale_tril, key)
-
    def tensordot(self, a: jnp.ndarray, b: jnp.ndarray, axes: Sequence[int]) -> jnp.ndarray:
        return jnp.tensordot(a, b, axes)
@@ -466,14 +372,6 @@ def zeros(self, shape: Sequence[int], dtype=None) -> jnp.ndarray:
    def zeros_like(self, array: jnp.ndarray, dtype: str = "complex128") -> jnp.ndarray:
        return jnp.zeros_like(array, dtype=dtype)

-    @partial(jax.jit, static_argnames=["axis"])
-    def squeeze(self, tensor: jnp.ndarray, axis=None):
-        return jnp.squeeze(tensor, axis=axis)
-
-    @jax.jit
-    def cholesky(self, input: jnp.ndarray):
-        return jnp.linalg.cholesky(input)
-
    @staticmethod
    @jax.jit
    def eigh(tensor: jnp.ndarray) -> tuple:
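
An aside on the one addition in this file: tensordot over an empty list of axes contracts nothing, which is exactly an outer product, so outer can delegate to the class's own tensordot instead of calling jnp.tensordot directly. A quick check:

    import jax.numpy as jnp

    a = jnp.array([1.0, 2.0])
    b = jnp.array([3.0, 4.0, 5.0])
    # Contracting over no axes yields the outer product, shape (2, 3).
    assert jnp.allclose(jnp.tensordot(a, b, axes=[[], []]), jnp.outer(a, b))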