Commit 1e798b7
Adding equinox as dependency and switching all baseclasses to modules (#200)
* Adding equinox as dependency
* adding news
* Fixing tutorial syntax
* Add comment about old init function
* Refactoring tests to handle float32 precision
* reverting CARMA init
* reverting CARMA init
* Xfail polynomial kernel tests
* adding tests for kernel pytrees
1 parent d17ae68 commit 1e798b7

35 files changed: +588 −563 lines
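The core pattern this commit applies across the codebase, as a minimal sketch (the class below is a stand-in for illustration, not code from the diff): hand-written __init__ methods that assign attributes are replaced by dataclass-style field declarations on an equinox.Module subclass, which also registers each instance as a JAX pytree.

import equinox as eqx
import jax

# Before: a plain Python class with manual attribute assignment.
#
# class MyKernel(tinygp.kernels.Kernel):
#     def __init__(self, scale):
#         self.scale = scale

# After: declared fields on an equinox.Module subclass; equinox
# generates __init__ and registers the class as a JAX pytree.
class MyKernel(eqx.Module):
    scale: jax.Array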

.github/workflows/tests.yml

Lines changed: 10 additions & 0 deletions
@@ -15,9 +15,17 @@ jobs:
       matrix:
         python-version: ["3.9", "3.10", "3.11"]
         nox-session: ["test"]
+        x64: ["1"]
         include:
+          - python-version: "3.10"
+            nox-session: "test"
+            x64: "0"
+          - python-version: "3.10"
+            nox-session: "comparison"
+            x64: "1"
           - python-version: "3.10"
             nox-session: "doctest"
+            x64: "1"

       steps:
         - name: Checkout
@@ -36,6 +44,8 @@ jobs:
         run: |
           python -m nox --non-interactive --error-on-missing-interpreter \
             --session ${{ matrix.nox-session }} --python ${{ matrix.python-version }}
+        env:
+          JAX_ENABLE_X64: ${{ matrix.x64 }}

   build:
     runs-on: ubuntu-latest
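For context on the new x64 axis of the test matrix: JAX_ENABLE_X64 toggles JAX's default floating-point width, which is why the suite gains an explicit float32 leg (the "Refactoring tests to handle float32 precision" commit item). A minimal illustration, assuming the variable is set in the environment before Python starts:

import jax.numpy as jnp

x = jnp.zeros(3)
print(x.dtype)  # float64 when JAX_ENABLE_X64=1, float32 otherwise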

docs/tutorials/derivative.ipynb

Lines changed: 8 additions & 5 deletions
@@ -105,8 +105,7 @@
     "\n",
     "\n",
     "class DerivativeKernel(tinygp.kernels.Kernel):\n",
-    "    def __init__(self, kernel):\n",
-    "        self.kernel = kernel\n",
+    "    kernel: tinygp.kernels.Kernel\n",
     "\n",
     "    def evaluate(self, X1, X2):\n",
     "        t1, d1 = X1\n",
@@ -301,6 +300,10 @@
     "        shape as ``coeff_prim``.\n",
     "    \"\"\"\n",
     "\n",
+    "    kernel: tinygp.kernels.Kernel\n",
+    "    coeff_prim: jax.Array\n",
+    "    coeff_deriv: jax.Array\n",
+    "\n",
     "    def __init__(self, kernel, coeff_prim, coeff_deriv):\n",
     "        self.kernel = kernel\n",
     "        self.coeff_prim, self.coeff_deriv = jnp.broadcast_arrays(\n",
@@ -497,7 +500,7 @@
    "hash": "d20ea8a315da34b3e8fab0dbd7b542a0ef3c8cf12937343660e6bc10a20768e3"
   },
   "kernelspec": {
-   "display_name": "Python 3.9.9 ('tinygp')",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -511,9 +514,9 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.9"
+   "version": "3.10.6"
   }
  },
 "nbformat": 4,
- "nbformat_minor": 2
+ "nbformat_minor": 4
 }
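The second hunk keeps a hand-written __init__ next to the new field declarations (hence the "Add comment about old init function" commit item). A minimal sketch of that pattern with a hypothetical module, relying on equinox's rule that a custom __init__ must assign every declared field:

import equinox as eqx
import jax
import jax.numpy as jnp

class LatentKernel(eqx.Module):
    # Declared fields; the custom __init__ below may still transform
    # its inputs, as long as it assigns each declared field.
    coeff_prim: jax.Array
    coeff_deriv: jax.Array

    def __init__(self, coeff_prim, coeff_deriv):
        self.coeff_prim, self.coeff_deriv = jnp.broadcast_arrays(
            jnp.asarray(coeff_prim), jnp.asarray(coeff_deriv)
        )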

docs/tutorials/kernels.ipynb

Lines changed: 4 additions & 5 deletions
@@ -54,10 +54,9 @@
     "\n",
     "\n",
     "class SpectralMixture(tinygp.kernels.Kernel):\n",
-    "    def __init__(self, weight, scale, freq):\n",
-    "        self.weight = jnp.atleast_1d(weight)\n",
-    "        self.scale = jnp.atleast_1d(scale)\n",
-    "        self.freq = jnp.atleast_1d(freq)\n",
+    "    weight: jax.Array\n",
+    "    scale: jax.Array\n",
+    "    freq: jax.Array\n",
     "\n",
     "    def evaluate(self, X1, X2):\n",
     "        tau = jnp.atleast_1d(jnp.abs(X1 - X2))[..., None]\n",
@@ -210,7 +209,7 @@
 ],
 "metadata": {
  "kernelspec": {
-  "display_name": "tinygp",
+  "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
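One consequence of the rewrite above: equinox now generates __init__ from the field declarations, so the jnp.atleast_1d coercion that used to run inside __init__ has to happen at the call site instead. A hypothetical construction under that assumption (the parameter values here are made up):

import jax.numpy as jnp

kernel = SpectralMixture(
    weight=jnp.atleast_1d([1.0, 0.5]),
    scale=jnp.atleast_1d([0.3, 1.2]),
    freq=jnp.atleast_1d([0.8, 2.0]),
)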

news/188.bugfix

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-Fixed use of `jnp.roots` and `np.roll` to make CARMA kernel jit-compliant
+Fixed use of `jnp.roots` and `np.roll` to make CARMA kernel jit-compliant.

news/200.feature

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Switched all base classes to `equinox.Module <https://docs.kidger.site/equinox/api/module/module/>`_ objects to simplify dataclass handling.

noxfile.py

Lines changed: 6 additions & 0 deletions
@@ -9,6 +9,12 @@
 @nox.session(python=PYTHON_VERSIONS)
 def test(session: nox.Session) -> None:
     session.install(".[test]")
+    session.run("pytest", *session.posargs)
+
+
+@nox.session(python=PYTHON_VERSIONS)
+def comparison(session: nox.Session) -> None:
+    session.install(".[test,comparison]")
     session.run("pytest", *session.posargs, env={"JAX_ENABLE_X64": "1"})

pyproject.toml

Lines changed: 3 additions & 2 deletions
@@ -15,10 +15,11 @@ classifiers = [
     "Programming Language :: Python :: 3",
 ]
 dynamic = ["version"]
-dependencies = ["jax", "jaxlib"]
+dependencies = ["jax", "jaxlib", "equinox"]

 [project.optional-dependencies]
-test = ["pytest", "george", "celerite"]
+test = ["pytest"]
+comparison = ["george", "celerite"]
 docs = [
     "sphinx-book-theme",
     "myst-nb",

src/tinygp/gp.py

Lines changed: 25 additions & 10 deletions
@@ -11,21 +11,24 @@
     NamedTuple,
 )

+import equinox as eqx
 import jax
 import jax.numpy as jnp
+import numpy as np

 from tinygp import kernels, means
 from tinygp.helpers import JAXArray
 from tinygp.kernels.quasisep import Quasisep
 from tinygp.noise import Diagonal, Noise
 from tinygp.solvers import DirectSolver, QuasisepSolver
 from tinygp.solvers.quasisep.core import SymmQSM
+from tinygp.solvers.solver import Solver

 if TYPE_CHECKING:
     from tinygp.numpyro_support import TinyDistribution


-class GaussianProcess:
+class GaussianProcess(eqx.Module):
     """An interface for designing a Gaussian Process regression model

     Args:
@@ -50,14 +53,23 @@ class GaussianProcess:
         algebra.
     """

+    num_data: int = eqx.field(static=True)
+    dtype: np.dtype = eqx.field(static=True)
+    kernel: kernels.Kernel
+    X: JAXArray
+    mean_function: means.MeanBase
+    mean: JAXArray
+    noise: Noise
+    solver: Solver
+
     def __init__(
         self,
         kernel: kernels.Kernel,
         X: JAXArray,
         *,
         diag: JAXArray | None = None,
         noise: Noise | None = None,
-        mean: Callable[[JAXArray], JAXArray] | JAXArray | None = None,
+        mean: means.MeanBase | Callable[[JAXArray], JAXArray] | JAXArray | None = None,
         solver: Any | None = None,
         mean_value: JAXArray | None = None,
         covariance_value: Any | None = None,
@@ -66,7 +78,7 @@ def __init__(
         self.kernel = kernel
         self.X = X

-        if callable(mean):
+        if isinstance(mean, means.MeanBase):
             self.mean_function = mean
         elif mean is None:
             self.mean_function = means.Mean(jnp.zeros(()))
@@ -76,7 +88,7 @@ def __init__(
             mean_value = jax.vmap(self.mean_function)(self.X)
         self.num_data = mean_value.shape[0]
         self.dtype = mean_value.dtype
-        self.loc = self.mean = mean_value
+        self.mean = mean_value
         if self.mean.ndim != 1:
             raise ValueError(
                 "Invalid mean shape: " f"expected ndim = 1, got ndim={self.mean.ndim}"
@@ -92,14 +104,18 @@ def __init__(
                 solver = QuasisepSolver
             else:
                 solver = DirectSolver
-        self.solver = solver.init(
+        self.solver = solver(
             kernel,
             self.X,
             self.noise,
             covariance=covariance_value,
             **solver_kwargs,
         )

+    @property
+    def loc(self) -> JAXArray:
+        return self.mean
+
     @property
     def variance(self) -> JAXArray:
         return self.solver.variance()
@@ -209,7 +225,6 @@ def condition(

     @partial(
         jax.jit,
-        static_argnums=(0,),
         static_argnames=("include_mean", "return_var", "return_cov"),
     )
     def predict(
@@ -281,7 +296,7 @@ def numpyro_dist(self, **kwargs: Any) -> TinyDistribution:

         return TinyDistribution(self, **kwargs)

-    @partial(jax.jit, static_argnums=(0, 2))
+    @partial(jax.jit, static_argnums=(2,))
     def _sample(
         self,
         key: jax.random.KeyArray,
@@ -296,16 +311,16 @@ def _sample(
             self.solver.dot_triangular(normal_samples), 0, -1
         )

-    @partial(jax.jit, static_argnums=0)
+    @jax.jit
     def _compute_log_prob(self, alpha: JAXArray) -> JAXArray:
         loglike = -0.5 * jnp.sum(jnp.square(alpha)) - self.solver.normalization()
         return jnp.where(jnp.isfinite(loglike), loglike, -jnp.inf)

-    @partial(jax.jit, static_argnums=0)
+    @jax.jit
     def _get_alpha(self, y: JAXArray) -> JAXArray:
         return self.solver.solve_triangular(y - self.loc)

-    @partial(jax.jit, static_argnums=(0, 3))
+    @partial(jax.jit, static_argnums=(3,))
     def _condition(
         self,
         y: JAXArray,
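Why the static_argnums=(0,) entries disappear throughout these hunks: as an eqx.Module, GaussianProcess is itself a pytree, so jax.jit can trace self like any other argument instead of treating it as a static hashable value. A minimal sketch of the same idea with a hypothetical module (not tinygp's API):

import equinox as eqx
import jax
import jax.numpy as jnp

class Affine(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    @jax.jit  # `self` is a pytree, so no static_argnums=(0,) is needed
    def __call__(self, x):
        return self.weight * x + self.bias

model = Affine(weight=jnp.ones(()), bias=jnp.zeros(()))
print(model(jnp.arange(3.0)))  # [0. 1. 2.]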

src/tinygp/helpers.py

Lines changed: 7 additions & 62 deletions
@@ -2,73 +2,18 @@

 __all__ = ["JAXArray", "dataclass", "field"]

-import dataclasses
-from typing import Any, Callable, TypeVar, Union
+from typing import Any

+import equinox as eqx
 import jax
-import jax.numpy as jnp
-import numpy as np

-JAXArray = Union[np.ndarray, jnp.ndarray]
+JAXArray = jax.Array

-# This section is based closely on the implementation in flax:
-#
-# https://github.com/google/flax/blob/b60f7f45b90f8fc42a88b1639c9cc88a40b298d3/flax/struct.py
-#
-# This decorator is interpreted by static analysis tools as a hint
-# that a decorator or metaclass causes dataclass-like behavior.
-# See https://github.com/microsoft/pyright/blob/main/specs/dataclass_transforms.md
-# for more information about the __dataclass_transform__ magic.
-_T = TypeVar("_T")

+# The following is just for backwards compatibility since tinygp used to provide a
+# custom dataclass implementation
+field = eqx.field

-def __dataclass_transform__(
-    *,
-    eq_default: bool = True,
-    order_default: bool = False,
-    kw_only_default: bool = False,
-    field_descriptors: tuple[type | Callable[..., Any], ...] = (()),
-) -> Callable[[_T], _T]:
-    # If used within a stub file, the following implementation can be
-    # replaced with "...".
-    return lambda a: a

-
-@__dataclass_transform__()
 def dataclass(clz: type[Any]) -> type[Any]:
-    data_clz: Any = dataclasses.dataclass(frozen=True)(clz)
-    meta_fields = []
-    data_fields = []
-    for name, field_info in data_clz.__dataclass_fields__.items():
-        is_pytree_node = field_info.metadata.get("pytree_node", True)
-        if is_pytree_node:
-            data_fields.append(name)
-        else:
-            meta_fields.append(name)
-
-    def replace(self: Any, **updates: _T) -> _T:
-        return dataclasses.replace(self, **updates)
-
-    data_clz.replace = replace
-
-    def iterate_clz(x: Any) -> tuple[tuple[Any, ...], tuple[Any, ...]]:
-        meta = tuple(getattr(x, name) for name in meta_fields)
-        data = tuple(getattr(x, name) for name in data_fields)
-        return data, meta
-
-    def clz_from_iterable(meta: tuple[Any, ...], data: tuple[Any, ...]) -> Any:
-        meta_args = tuple(zip(meta_fields, meta))
-        data_args = tuple(zip(data_fields, data))
-        kwargs = dict(meta_args + data_args)
-        return data_clz(**kwargs)
-
-    jax.tree_util.register_pytree_node(data_clz, iterate_clz, clz_from_iterable)
-
-    # Hack to make this class act as a tuple when unpacked
-    data_clz.iter_elems = lambda self: iterate_clz(self)[0].__iter__()
-
-    return data_clz
-
-
-def field(pytree_node: bool = True, **kwargs: Any) -> Any:
-    return dataclasses.field(metadata={"pytree_node": pytree_node}, **kwargs)
+    return clz
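The practical mapping from the old helper to equinox, sketched with a hypothetical class: a field that used to be marked pytree_node=False corresponds to eqx.field(static=True) (as used for num_data and dtype in GaussianProcess above), making it hashable metadata rather than a traced leaf.

import equinox as eqx
import jax
import jax.numpy as jnp

class Example(eqx.Module):
    value: jax.Array                    # pytree leaf: traced by jit/grad
    size: int = eqx.field(static=True)  # static metadata: hashed, not traced

ex = Example(value=jnp.ones(3), size=3)
print(jax.tree_util.tree_leaves(ex))  # only `value` appears as a leaf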
