
Commit 1ae552b

Merge branch 'carma' of https://github.com/ywx649999311/tinygp into ywx649999311-carma
2 parents fada5dd + f7a843a commit 1ae552b

File tree: 2 files changed (+298, -68 lines)


src/tinygp/kernels/quasisep.py

Lines changed: 225 additions & 65 deletions
@@ -24,10 +24,11 @@
     "Matern52",
     "Cosine",
     "CARMA",
+    "carma",
 ]
 
 from abc import ABCMeta, abstractmethod
-from typing import Any, Optional, Union
+from typing import Any, Optional, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -39,6 +40,8 @@
 from tinygp.solvers.quasisep.core import DiagQSM, StrictLowerTriQSM, SymmQSM
 from tinygp.solvers.quasisep.general import GeneralQSM
 
+eta = 1e-20  # avoid nan
+
 
 class Quasisep(Kernel, metaclass=ABCMeta):
     """The base class for all quasiseparable kernels
@@ -643,21 +646,27 @@ class CARMA(Quasisep):
 
     .. code-block:: python
 
-        kernel = CARMA.init(alpha=..., beta=..., sigma=...)
+        kernel = CARMA.init(alpha=..., beta=...)
+
+    .. note::
+        To fit a CARMA model with p > 2, the :func:`from_fpoly` method needs to
+        be used to construct a valid model.
     """
     alpha: JAXArray
     beta: JAXArray
     sigma: JAXArray
-    roots: JAXArray
-    proj: JAXArray
-    proj_inv: JAXArray
-    stn: JAXArray
+    arroots: JAXArray
+    acf: JAXArray
+    real_mask: JAXArray
+    complex_mask: JAXArray
+    complex_select: JAXArray
+    obsmodel: JAXArray
 
     @classmethod
     def init(
         cls, alpha: JAXArray, beta: JAXArray, sigma: Optional[JAXArray] = None
     ) -> "CARMA":
-        r"""Construct a CARMA kernel
+        r"""Construct a CARMA kernel using the alpha, beta parameters
 
         Args:
             alpha: The parameter :math:`\alpha` in the definition above. This
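As a quick orientation for the docstring change above: the updated example drops the sigma argument and builds the kernel from alpha and beta alone. A minimal sketch of the call (the import path matches this file; the numeric values are invented for illustration, not taken from the commit):

    import jax.numpy as jnp
    from tinygp.kernels.quasisep import CARMA

    # Hypothetical CARMA(2, 1) kernel: alpha has p = 2 entries, beta has q + 1 = 2 entries
    kernel = CARMA.init(alpha=jnp.array([1.1, 2.2]), beta=jnp.array([1.0, 0.5]))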
@@ -674,78 +683,229 @@ def init(
         p = alpha.shape[0]
         assert beta.shape[0] <= p
 
-        # We find the roots of the autoregressive polynomial as a means to find
-        # the eigendecomposition of the design matrix.
-        alpha_ext = jnp.append(alpha, 1.0)
-        roots = jnp.roots(alpha_ext[::-1])
-        proj = roots[:, None] ** jnp.arange(p)[None, :]
-        proj_inv = jnp.linalg.inv(proj)
-
-        # Compute the stationary covariance - there is almost certainly a more
-        # elegant way, but this works! I worked this out kind of by trial and
-        # error using sympy. There is a lot of known structure in the P_inf
-        # matrix that can be exploited to "simplify" this calculation.
-        # Specifically, there are only `p` degrees of freedom, and P_inf has the
-        # following structure:
-        #
-        # P_inf = [
-        #     [ p0   0  -p1   0   p2 ]
-        #     [  0   p1   0  -p2   0 ]
-        #     [-p1   0   p2   0  -p3 ]
-        #     [  0  -p2   0   p3   0 ]
-        #     [ p2   0  -p3   0   p4 ]
-        # ]
-        #
-        # Using this structure, we get can solve the usual:
-        #
-        #   A @ P + P @ A.T + L @ L.T = 0
-        #
-        # for `P`, and we get something like the following. Kelly et al. (2104)
-        # also have an expression for this (their V_{ij}), but I prefer to use
-        # this since it is probably roughly just as fast to compute, and it is
-        # strictly real-valued.
-        f = 2 * ((np.arange(2 * p) // 2) % 2) - 1
-        x = f * jnp.append(alpha_ext, jnp.zeros(p - 1))
-        params = jnp.stack([np.roll(x, k)[::2] for k in range(p)], axis=0)
-        params = jnp.linalg.solve(
-            params, 0.5 * sigma**2 * jnp.eye(p, 1, k=-p + 1)
-        )[:, 0]
-        stn_ = []
-        for j in range(p):
-            stn_.append([jnp.zeros(()) for _ in range(p)])
-            for n, k in enumerate(range(j - 2, -1, -2)):
-                stn_[-1][k] = (2 * (n % 2) - 1) * params[j - n - 1]
-            for n, k in enumerate(range(j, p, 2)):
-                stn_[-1][k] = (1 - 2 * (n % 2)) * params[n + j]
-        stn = jnp.array(list(map(jnp.stack, stn_)))
+        # find acf
+        arroots = CARMA.roots(jnp.append(alpha, 1.0))
+        acf = CARMA.carma_acf(arroots, alpha, beta * sigma)
+        # masks for selecting entries in matrixes
+        real_mask = jnp.where(arroots.imag == 0.0, jnp.ones(p), jnp.zeros(p))
+        complex_mask = -real_mask + 1
+        complex_idx = jnp.cumsum(-real_mask + 1) * complex_mask
+        complex_select = complex_mask * complex_idx % 2
+
+        # compute obsmodel
+        om_real = jnp.sqrt(jnp.abs(acf.real))
+        a, b, c, d = (
+            2 * acf.real * complex_mask,
+            2 * acf.imag * complex_mask,
+            -arroots.real * complex_mask,
+            -arroots.imag * complex_mask,
+        )
+        c2 = jnp.square(c)
+        d2 = jnp.square(d)
+        s2 = c2 + d2
+        h2_2 = d2 * (a * c - b * d) / (2 * c * s2 + eta * real_mask)
+        h2 = jnp.sqrt(h2_2)
+        h1 = (c * h2 - jnp.sqrt(a * d2 - s2 * h2_2)) / (d + eta * real_mask)
+        om_complex = jnp.array([h1, h2])
+        obsmodel = (om_real * real_mask) + jnp.ravel(om_complex)[
+            ::2
+        ] * complex_mask
 
         return cls(
-            sigma=sigma,
             alpha=alpha,
             beta=beta,
-            roots=roots,
-            proj=proj,
-            proj_inv=proj_inv,
-            stn=stn,
+            sigma=sigma,
+            arroots=arroots,
+            acf=acf,
+            real_mask=real_mask,
+            complex_mask=complex_mask,
+            complex_select=complex_select,
+            obsmodel=obsmodel,
+        )
+
+    @classmethod
+    def from_fpoly(
+        cls, alpha_fpoly: JAXArray, beta_fpoly: JAXArray, beta_mult: JAXArray
+    ) -> "CARMA":
+        """Construct a CARMA kernel using the roots of the characteristic polynomials
+
+        The roots can be re-parameterized as the coefficients of a product
+        of quadratic equations each with the second-order term set to 1. The
+        input for this constructor are said coefficients. See Equation 30 in
+        the paper linked above for a reference.
+
+        Args:
+            alpha_fpoly: The coefficients of the auto-regressive quadratic
+                equations corresponding to the alpha parameters.
+            beta_fpoly: The coefficients of the moving-average quadratic
+                equations corresponding to the beta parameters.
+            beta_mult: Equivalent to beta[-1] used in the init constructor.
+        """
+
+        alpha_fpoly = jnp.atleast_1d(alpha_fpoly)
+        beta_fpoly = jnp.atleast_1d(beta_fpoly)
+        beta_mult = jnp.atleast_1d(beta_mult)
+
+        alpha = CARMA.fpoly2poly(jnp.append(alpha_fpoly, jnp.array([1.0])))[
+            :-1
+        ]
+        beta = CARMA.fpoly2poly(jnp.append(beta_fpoly, beta_mult))
+
+        return CARMA.init(alpha, beta)
+
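One property worth noting about this parameterization (see the Equation 30 reference in the docstring): keeping the quadratic and linear coefficients positive keeps the AR roots in the left half-plane, which appears to be what makes from_fpoly convenient for fitting p > 2 models. A minimal sketch of a call, with invented coefficient values:

    import jax.numpy as jnp
    from tinygp.kernels.quasisep import CARMA

    # Hypothetical CARMA(3, 1): the AR polynomial factors as
    #     (z**2 + a2 * z + a1) * (z + a3)
    # and the MA polynomial as beta_mult * (z + b1).
    kernel = CARMA.from_fpoly(
        alpha_fpoly=jnp.array([1.2, 0.7, 0.5]),  # [a1, a2, a3]
        beta_fpoly=jnp.array([0.3]),             # [b1]
        beta_mult=jnp.array([0.5]),
    )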
+    @staticmethod
+    @jax.jit
+    def roots(poly_coeffs: JAXArray) -> JAXArray:
+        roots = jnp.roots(poly_coeffs[::-1], strip_zeros=False)
+        return roots[jnp.argsort(roots.real)]
+
+    @staticmethod
+    @jax.jit
+    def fpoly2poly(fpoly_coeffs: JAXArray) -> JAXArray:
+        """Expand the factorized characteristic polynomial"""
+
+        size = fpoly_coeffs.shape[0] - 1
+        remain = size % 2
+        nPair = size // 2
+        mult_f = fpoly_coeffs[
+            -1:
+        ]  # The coeff of highest order term in the output
+
+        poly = jax.lax.cond(
+            remain == 1,
+            lambda x: jnp.array([1.0, x]),
+            lambda x: jnp.array([0.0, 1.0]),
+            fpoly_coeffs[-2],
         )
+        poly = poly[-remain + 1 :]
+
+        for p in jnp.arange(nPair):
+            poly = jnp.convolve(
+                poly,
+                jnp.append(
+                    jnp.array([fpoly_coeffs[p * 2], fpoly_coeffs[p * 2 + 1]]),
+                    jnp.ones((1,)),
+                )[::-1],
+            )
+
+        # the returned is low->high following Kelly+14
+        return poly[::-1] * mult_f
+
+    @staticmethod
+    def poly2fpoly(poly_coeffs: JAXArray) -> Tuple[JAXArray, JAXArray]:
+        """Factorize a polynomial into product of quadratic equations"""
+
+        fpoly = jnp.empty((0))
+        mult_f = poly_coeffs[-1]
+        roots = CARMA.roots(poly_coeffs / mult_f)
+        odd = bool(len(roots) & 0x1)
+
+        rootsComp = roots[roots.imag != 0]
+        rootsReal = roots[roots.imag == 0]
+        nCompPair = len(rootsComp) // 2
+        nRealPair = len(rootsReal) // 2
+
+        for i in range(nCompPair):
+            root1 = rootsComp[i]
+            root2 = rootsComp[i + 1]
+            fpoly = jnp.append(fpoly, (root1 * root2).real)
+            fpoly = jnp.append(fpoly, -(root1.real + root2.real))
+
+        for i in range(nRealPair):
+            root1 = rootsReal[i]
+            root2 = rootsReal[i + 1]
+            fpoly = jnp.append(fpoly, (root1 * root2).real)
+            fpoly = jnp.append(fpoly, -(root1.real + root2.real))
+
+        if odd:
+            fpoly = jnp.append(fpoly, -rootsReal[-1].real)
+
+        return fpoly, jnp.array(mult_f)
+
+    @staticmethod
+    def carma_acf(
+        arroots: JAXArray, arparam: JAXArray, maparam: JAXArray
+    ) -> JAXArray:
+        """Get ACVF coefficients given CARMA parameters
+
+        Args:
+            arroots (array(complex)): AR roots in a numpy array
+            arparam (array(float)): AR parameters in a numpy array
+            maparam (array(float)): MA parameters in a numpy array
+        Returns:
+            array(complex): ACVF coefficients, each element correspond to a root.
+        """
+        arparam = jnp.atleast_1d(arparam)
+        maparam = jnp.atleast_1d(maparam)
+        p = arparam.shape[0]
+        q = maparam.shape[0] - 1
+        sigma = maparam[0]
+
+        # MA param into Kelly's notation
+        maparam = maparam / sigma
+
+        # init acf product terms
+        num_left = jnp.zeros(p, dtype=jnp.complex128)
+        num_right = jnp.zeros(p, dtype=jnp.complex128)
+        denom = -2 * arroots.real + jnp.zeros_like(arroots) * 1j
+
+        for k in range(q + 1):
+            num_left += maparam[k] * jnp.power(arroots, k)
+            num_right += maparam[k] * jnp.power(jnp.negative(arroots), k)
+
+        root_idx = jnp.arange(p)
+        for j in range(1, p):
+            root_k = arroots[jnp.roll(root_idx, j)]
+            denom *= (root_k - arroots) * (jnp.conj(root_k) + arroots)
+
+        return sigma**2 * num_left * num_right / denom
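As a cross-reference for carma_acf: the numerator/denominator products above appear to implement the per-root autocovariance coefficients of Kelly et al. (2014), which (quoting that expression from memory) take the form

    A_k = \sigma^2 \frac{\left[\sum_{l=0}^{q} \beta_l r_k^l\right]\left[\sum_{l=0}^{q} \beta_l (-r_k)^l\right]}{-2\,\mathrm{Re}(r_k)\,\prod_{l \neq k} (r_l - r_k)(r_l^* + r_k)},
    \qquad
    R(\tau) = \sum_{k=1}^{p} A_k\, e^{r_k |\tau|},

where the r_k are the AR roots; the function returns the complex coefficients A_k, one per root.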
 
     def design_matrix(self) -> JAXArray:
-        p = self.alpha.shape[0]
-        return jnp.concatenate((jnp.eye(p - 1, p, k=1), -self.alpha[None]))
+        dm_real = jnp.diag(self.arroots.real * self.real_mask)
+        dm_complex_diag = jnp.diag(self.arroots.real * self.complex_mask)
+        dm_complex_u = jnp.diag(
+            (self.arroots.imag * self.complex_select)[:-1], k=1
+        )
+
+        return dm_real + dm_complex_diag + -dm_complex_u.T + dm_complex_u
 
     def stationary_covariance(self) -> JAXArray:
-        return self.stn
+        p = self.acf.shape[0]
+        diag = jnp.diag(
+            jnp.where(self.acf.real > 0, jnp.ones(p), -jnp.ones(p))
+        )
+        diag_complex = jnp.diag(
+            (
+                2
+                * jnp.square(-self.arroots.real)
+                / jnp.square(-self.arroots.imag + eta)
+            )
+            * jnp.roll(self.complex_select, 1)
+            * self.complex_mask
+        )
+        c_over_d = self.arroots.real / (self.arroots.imag + eta)
+        sc_complex_u = jnp.diag((-c_over_d * self.complex_select)[:-1], k=1)
+
+        return diag + diag_complex + sc_complex_u + sc_complex_u.T
 
     def observation_model(self, X: JAXArray) -> JAXArray:
-        return jnp.append(
-            self.beta, jnp.zeros(self.alpha.shape[0] - self.beta.shape[0])
-        )
+        return self.obsmodel
 
     def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
         dt = X2 - X1
-        return (
-            self.proj_inv @ (jnp.exp(self.roots * dt)[:, None] * self.proj)
-        ).real
+        c = -self.arroots.real
+        d = -self.arroots.imag
+        decay = jnp.exp(-c * dt)
+        sin = jnp.sin(d * dt)
+
+        tm_real = jnp.diag(decay * self.real_mask)
+        tm_complex_diag = jnp.diag(decay * jnp.cos(d * dt) * self.complex_mask)
+        tm_complex_u = jnp.diag(
+            (decay * sin * self.complex_select)[:-1],
+            k=1,
+        )
+
+        return tm_real + tm_complex_diag + -tm_complex_u.T + tm_complex_u
 
 
 def _prod_helper(a1: JAXArray, a2: JAXArray) -> JAXArray:
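Putting the pieces together, a minimal end-to-end sketch of using the new kernel through tinygp's GaussianProcess(kernel, X, diag=...) / log_probability interface; the time grid, data, and parameter values below are invented for illustration:

    import jax
    import jax.numpy as jnp
    from tinygp import GaussianProcess
    from tinygp.kernels.quasisep import CARMA

    # carma_acf allocates complex128 arrays, so double precision is the safe choice
    jax.config.update("jax_enable_x64", True)

    t = jnp.linspace(0.0, 10.0, 100)  # quasiseparable kernels expect sorted inputs
    y = jnp.sin(t)                    # stand-in observations
    yerr = 0.1

    kernel = CARMA.from_fpoly(
        alpha_fpoly=jnp.array([1.2, 0.7, 0.5]),
        beta_fpoly=jnp.array([0.3]),
        beta_mult=jnp.array([0.5]),
    )
    gp = GaussianProcess(kernel, t, diag=yerr**2)
    print(gp.log_probability(y))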
