24
24
"Matern52" ,
25
25
"Cosine" ,
26
26
"CARMA" ,
27
+ "carma" ,
27
28
]
28
29
29
30
from abc import ABCMeta , abstractmethod
30
- from typing import Any , Optional , Union
31
+ from typing import Any , Optional , Tuple , Union
31
32
32
33
import jax
33
34
import jax .numpy as jnp
39
40
from tinygp .solvers .quasisep .core import DiagQSM , StrictLowerTriQSM , SymmQSM
40
41
from tinygp .solvers .quasisep .general import GeneralQSM
41
42
43
# Tiny positive offset added to denominators in the CARMA kernel code below so
# that entries whose denominator is exactly zero do not evaluate to nan.
eta = 1e-20
44
+
42
45
43
46
class Quasisep (Kernel , metaclass = ABCMeta ):
44
47
"""The base class for all quasiseparable kernels
@@ -643,21 +646,27 @@ class CARMA(Quasisep):
643
646
644
647
.. code-block:: python
645
648
646
- kernel = CARMA.init(alpha=..., beta=..., sigma=...)
649
+ kernel = CARMA.init(alpha=..., beta=...)
650
+
651
+ .. note::
652
+ To fit a CARMA model with p > 2, the :func:`from_fpoly` method needs to
653
+ be used to construct a valid model.
647
654
"""
648
655
alpha : JAXArray
649
656
beta : JAXArray
650
657
sigma : JAXArray
651
- roots : JAXArray
652
- proj : JAXArray
653
- proj_inv : JAXArray
654
- stn : JAXArray
658
+ arroots : JAXArray
659
+ acf : JAXArray
660
+ real_mask : JAXArray
661
+ complex_mask : JAXArray
662
+ complex_select : JAXArray
663
+ obsmodel : JAXArray
655
664
656
665
@classmethod
657
666
def init (
658
667
cls , alpha : JAXArray , beta : JAXArray , sigma : Optional [JAXArray ] = None
659
668
) -> "CARMA" :
660
- r"""Construct a CARMA kernel
669
+ r"""Construct a CARMA kernel using the alpha, beta parameters
661
670
662
671
Args:
663
672
alpha: The parameter :math:`\alpha` in the definition above. This
@@ -674,78 +683,229 @@ def init(
674
683
p = alpha .shape [0 ]
675
684
assert beta .shape [0 ] <= p
676
685
677
- # We find the roots of the autoregressive polynomial as a means to find
678
- # the eigendecomposition of the design matrix.
679
- alpha_ext = jnp .append (alpha , 1.0 )
680
- roots = jnp .roots (alpha_ext [::- 1 ])
681
- proj = roots [:, None ] ** jnp .arange (p )[None , :]
682
- proj_inv = jnp .linalg .inv (proj )
683
-
684
- # Compute the stationary covariance - there is almost certainly a more
685
- # elegant way, but this works! I worked this out kind of by trial and
686
- # error using sympy. There is a lot of known structure in the P_inf
687
- # matrix that can be exploited to "simplify" this calculation.
688
- # Specifically, there are only `p` degrees of freedom, and P_inf has the
689
- # following structure:
690
- #
691
- # P_inf = [
692
- # [ p0 0 -p1 0 p2 ]
693
- # [ 0 p1 0 -p2 0 ]
694
- # [-p1 0 p2 0 -p3 ]
695
- # [ 0 -p2 0 p3 0 ]
696
- # [ p2 0 -p3 0 p4 ]
697
- # ]
698
- #
699
- # Using this structure, we get can solve the usual:
700
- #
701
- # A @ P + P @ A.T + L @ L.T = 0
702
- #
703
- # for `P`, and we get something like the following. Kelly et al. (2104)
704
- # also have an expression for this (their V_{ij}), but I prefer to use
705
- # this since it is probably roughly just as fast to compute, and it is
706
- # strictly real-valued.
707
- f = 2 * ((np .arange (2 * p ) // 2 ) % 2 ) - 1
708
- x = f * jnp .append (alpha_ext , jnp .zeros (p - 1 ))
709
- params = jnp .stack ([np .roll (x , k )[::2 ] for k in range (p )], axis = 0 )
710
- params = jnp .linalg .solve (
711
- params , 0.5 * sigma ** 2 * jnp .eye (p , 1 , k = - p + 1 )
712
- )[:, 0 ]
713
- stn_ = []
714
- for j in range (p ):
715
- stn_ .append ([jnp .zeros (()) for _ in range (p )])
716
- for n , k in enumerate (range (j - 2 , - 1 , - 2 )):
717
- stn_ [- 1 ][k ] = (2 * (n % 2 ) - 1 ) * params [j - n - 1 ]
718
- for n , k in enumerate (range (j , p , 2 )):
719
- stn_ [- 1 ][k ] = (1 - 2 * (n % 2 )) * params [n + j ]
720
- stn = jnp .array (list (map (jnp .stack , stn_ )))
686
+ # find acf
687
+ arroots = CARMA .roots (jnp .append (alpha , 1.0 ))
688
+ acf = CARMA .carma_acf (arroots , alpha , beta * sigma )
689
+ # masks for selecting entries in matrixes
690
+ real_mask = jnp .where (arroots .imag == 0.0 , jnp .ones (p ), jnp .zeros (p ))
691
+ complex_mask = - real_mask + 1
692
+ complex_idx = jnp .cumsum (- real_mask + 1 ) * complex_mask
693
+ complex_select = complex_mask * complex_idx % 2
694
+
695
+ # compute obsmodel
696
+ om_real = jnp .sqrt (jnp .abs (acf .real ))
697
+ a , b , c , d = (
698
+ 2 * acf .real * complex_mask ,
699
+ 2 * acf .imag * complex_mask ,
700
+ - arroots .real * complex_mask ,
701
+ - arroots .imag * complex_mask ,
702
+ )
703
+ c2 = jnp .square (c )
704
+ d2 = jnp .square (d )
705
+ s2 = c2 + d2
706
+ h2_2 = d2 * (a * c - b * d ) / (2 * c * s2 + eta * real_mask )
707
+ h2 = jnp .sqrt (h2_2 )
708
+ h1 = (c * h2 - jnp .sqrt (a * d2 - s2 * h2_2 )) / (d + eta * real_mask )
709
+ om_complex = jnp .array ([h1 , h2 ])
710
+ obsmodel = (om_real * real_mask ) + jnp .ravel (om_complex )[
711
+ ::2
712
+ ] * complex_mask
721
713
722
714
return cls (
723
- sigma = sigma ,
724
715
alpha = alpha ,
725
716
beta = beta ,
726
- roots = roots ,
727
- proj = proj ,
728
- proj_inv = proj_inv ,
729
- stn = stn ,
717
+ sigma = sigma ,
718
+ arroots = arroots ,
719
+ acf = acf ,
720
+ real_mask = real_mask ,
721
+ complex_mask = complex_mask ,
722
+ complex_select = complex_select ,
723
+ obsmodel = obsmodel ,
724
+ )
725
+
726
@classmethod
def from_fpoly(
    cls, alpha_fpoly: JAXArray, beta_fpoly: JAXArray, beta_mult: JAXArray
) -> "CARMA":
    """Construct a CARMA kernel using the roots of the characteristic polynomials

    The roots can be re-parameterized as the coefficients of a product
    of quadratic equations each with the second-order term set to 1. The
    input for this constructor are said coefficients. See Equation 30 in
    the paper linked above for a reference.

    The factorized coefficients are expanded with :func:`fpoly2poly` and
    the result is forwarded to :func:`init`; ``sigma`` is left at the
    default of :func:`init`.

    Args:
        alpha_fpoly: The coefficients of the auto-regressive quadratic
            equations corresponding to the alpha parameters.
        beta_fpoly: The coefficients of the moving-average quadratic
            equations corresponding to the beta parameters.
        beta_mult: Equivalent to beta[-1] used in the init constructor.

    Returns:
        A kernel of type ``cls`` (so subclasses get an instance of
        themselves rather than of the base ``CARMA`` class).
    """
    alpha_fpoly = jnp.atleast_1d(alpha_fpoly)
    beta_fpoly = jnp.atleast_1d(beta_fpoly)
    beta_mult = jnp.atleast_1d(beta_mult)

    # The AR polynomial is monic: expand it with a multiplier of 1 and drop
    # the trailing (highest-order) coefficient, which `init` does not take.
    alpha = cls.fpoly2poly(jnp.append(alpha_fpoly, jnp.array([1.0])))[:-1]
    beta = cls.fpoly2poly(jnp.append(beta_fpoly, beta_mult))

    return cls.init(alpha, beta)
755
+
756
@staticmethod
@jax.jit
def roots(poly_coeffs: JAXArray) -> JAXArray:
    """Return the roots of a polynomial, ordered by increasing real part

    Args:
        poly_coeffs: Polynomial coefficients ordered from the lowest to the
            highest power.
    """
    # jnp.roots expects the highest power first.
    high_to_low = poly_coeffs[::-1]
    found = jnp.roots(high_to_low, strip_zeros=False)
    by_real_part = jnp.argsort(found.real)
    return found[by_real_part]
761
+
762
@staticmethod
@jax.jit
def fpoly2poly(fpoly_coeffs: JAXArray) -> JAXArray:
    """Expand the factorized characteristic polynomial

    Args:
        fpoly_coeffs: Coefficients of the factorized polynomial. Entries
            ``[2k]`` and ``[2k + 1]`` are the constant and linear
            coefficients of the k-th monic quadratic factor; when the number
            of factor coefficients is odd, entry ``[-2]`` is the constant
            term of one extra monic linear factor; entry ``[-1]`` is the
            overall multiplier (the highest-order coefficient of the output).

    Returns:
        The expanded polynomial coefficients, ordered from the lowest to the
        highest power.
    """
    # Everything derived from the shape is a static Python int under jit, so
    # ordinary Python control flow can be used instead of lax.cond / traced
    # loop indices.
    n_factor_coeffs = fpoly_coeffs.shape[0] - 1
    has_linear_factor = n_factor_coeffs % 2 == 1
    n_pairs = n_factor_coeffs // 2
    mult_f = fpoly_coeffs[-1:]  # The coeff of highest order term in the output

    # Seed the product (highest power first) with the linear factor if there
    # is one, otherwise with the constant polynomial 1.
    if has_linear_factor:
        poly = jnp.array([1.0, fpoly_coeffs[-2]])
    else:
        poly = jnp.array([1.0])

    for k in range(n_pairs):
        # [c0, c1, 1] reversed -> s^2 + c1 * s + c0, highest power first
        quadratic = jnp.append(
            fpoly_coeffs[2 * k : 2 * k + 2], jnp.ones((1,))
        )[::-1]
        poly = jnp.convolve(poly, quadratic)

    # the returned is low->high following Kelly+14
    return poly[::-1] * mult_f
793
+
794
@staticmethod
def poly2fpoly(poly_coeffs: JAXArray) -> Tuple[JAXArray, JAXArray]:
    """Factorize a polynomial into product of quadratic equations

    Args:
        poly_coeffs: Polynomial coefficients ordered from the lowest to the
            highest power.

    Returns:
        A tuple ``(fpoly, mult_f)``. ``fpoly`` holds, for every pair of
        roots, the constant and linear coefficients of the monic quadratic
        they form (complex pairs first, then real pairs), followed by the
        constant term of a monic linear factor when the number of roots is
        odd. ``mult_f`` is the highest-order coefficient of the input.
    """
    fpoly = jnp.empty((0))
    mult_f = poly_coeffs[-1]
    # CARMA.roots returns the roots sorted by real part, so the two members
    # of a complex-conjugate pair are adjacent.
    roots = CARMA.roots(poly_coeffs / mult_f)
    has_unpaired_root = bool(len(roots) & 0x1)

    complex_roots = roots[roots.imag != 0]
    real_roots = roots[roots.imag == 0]

    for group in (complex_roots, real_roots):
        for k in range(len(group) // 2):
            # Step through the roots two at a time so that every root is
            # used in exactly one quadratic factor.
            root1 = group[2 * k]
            root2 = group[2 * k + 1]
            fpoly = jnp.append(fpoly, (root1 * root2).real)
            fpoly = jnp.append(fpoly, -(root1.real + root2.real))

    if has_unpaired_root:
        # The leftover root is the last real one; it forms (s - root).
        fpoly = jnp.append(fpoly, -real_roots[-1].real)

    return fpoly, jnp.array(mult_f)
824
+
825
@staticmethod
def carma_acf(
    arroots: JAXArray, arparam: JAXArray, maparam: JAXArray
) -> JAXArray:
    """Get ACVF coefficients given CARMA parameters

    Args:
        arroots (array(complex)): AR roots in a numpy array
        arparam (array(float)): AR parameters in a numpy array
        maparam (array(float)): MA parameters in a numpy array
    Returns:
        array(complex): ACVF coefficients, each element correspond to a root.
    """
    ar_coeffs = jnp.atleast_1d(arparam)
    ma_coeffs = jnp.atleast_1d(maparam)
    n_roots = ar_coeffs.shape[0]
    n_ma = ma_coeffs.shape[0]

    # MA param into Kelly's notation: the leading coefficient is factored
    # out and re-applied as an overall amplitude at the end.
    amplitude = ma_coeffs[0]
    ma_unit = ma_coeffs / amplitude

    # MA polynomial evaluated at +root and at -root
    at_root = jnp.zeros(n_roots, dtype=jnp.complex128)
    at_neg_root = jnp.zeros(n_roots, dtype=jnp.complex128)
    for order in range(n_ma):
        at_root += ma_unit[order] * jnp.power(arroots, order)
        at_neg_root += ma_unit[order] * jnp.power(jnp.negative(arroots), order)

    # Denominator: -2 Re(r_i) times a product over every other root r_k,
    # visited by cyclically shifting the root array.
    denominator = -2 * arroots.real + jnp.zeros_like(arroots) * 1j
    for shift in range(1, n_roots):
        other = jnp.roll(arroots, shift)
        denominator *= (other - arroots) * (jnp.conj(other) + arroots)

    return amplitude**2 * at_root * at_neg_root / denominator
731
862
732
863
def design_matrix(self) -> JAXArray:
    """Build the design matrix from the AR roots

    The real parts of the roots go on the diagonal. The imaginary part of
    each root selected by ``complex_select`` is placed on the first
    super-diagonal and its negative on the first sub-diagonal.
    """
    real_part = self.arroots.real
    imag_part = self.arroots.imag

    upper = jnp.diag((imag_part * self.complex_select)[:-1], k=1)
    diagonal = jnp.diag(real_part * self.real_mask) + jnp.diag(
        real_part * self.complex_mask
    )

    return diagonal - upper.T + upper
735
871
736
872
def stationary_covariance(self) -> JAXArray:
    """Build the stationary covariance matrix from the AR roots

    The diagonal holds +1 where the real part of ``acf`` is positive and -1
    otherwise. For complex roots, ``2 * c**2 / d**2`` (with ``c`` and ``d``
    the negated real and imaginary parts of the root) is added on the
    diagonal entry following each root selected by ``complex_select``, and
    ``-c / d`` is placed symmetrically on the first off-diagonals at the
    selected roots.
    """
    p = self.acf.shape[0]
    signs = jnp.where(self.acf.real > 0, jnp.ones(p), -jnp.ones(p))

    # Zero the numerators of real-root entries and add `eta` only to their
    # (zero) denominators, so those entries evaluate to exactly 0. Dividing
    # by `eta**2` and masking afterwards overflows to inf in single
    # precision, and inf * 0 is nan.
    c = -self.arroots.real * self.complex_mask
    d = -self.arroots.imag * self.complex_mask
    guard = eta * self.real_mask

    diag_complex = jnp.diag(
        (2 * jnp.square(c) / (jnp.square(d) + guard))
        * jnp.roll(self.complex_select, 1)
        * self.complex_mask
    )
    c_over_d = c / (d + guard)
    sc_complex_u = jnp.diag((-c_over_d * self.complex_select)[:-1], k=1)

    return jnp.diag(signs) + diag_complex + sc_complex_u + sc_complex_u.T
738
890
739
891
def observation_model(self, X: JAXArray) -> JAXArray:
    """Return the stored observation vector

    Args:
        X: Input coordinate; not used, the stored ``obsmodel`` is returned
            for every input.
    """
    observation_vector = self.obsmodel
    return observation_vector
743
893
744
894
def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
    """Build the transition matrix between two input coordinates

    With ``dt = X2 - X1`` and each root written as ``-(c + i d)``, the
    diagonal holds ``exp(-c dt)`` for real roots and
    ``exp(-c dt) cos(d dt)`` for complex roots. ``exp(-c dt) sin(d dt)`` of
    each root selected by ``complex_select`` is placed on the first
    super-diagonal and its negative on the first sub-diagonal.

    Args:
        X1: The starting coordinate.
        X2: The ending coordinate.
    """
    lag = X2 - X1
    damping = -self.arroots.real
    frequency = -self.arroots.imag

    envelope = jnp.exp(-damping * lag)
    phase = frequency * lag

    upper = jnp.diag(
        (envelope * jnp.sin(phase) * self.complex_select)[:-1],
        k=1,
    )
    diagonal = jnp.diag(envelope * self.real_mask) + jnp.diag(
        envelope * jnp.cos(phase) * self.complex_mask
    )

    return diagonal - upper.T + upper
749
909
750
910
751
911
def _prod_helper (a1 : JAXArray , a2 : JAXArray ) -> JAXArray :
0 commit comments