
Commit cd53251

elastufka, ilan-gold, and gomezzz authored
Gaussian quadrature (#141)
* basic version of gauss-legendre
* fstrings for my sanity
* fstrings for my sanity
* weights and points multidimensional
* transform xi,wi correctly
* basic version of gauss-legendre
* fstrings for my sanity
* fstrings for my sanity
* weights and points multidimensional
* transform xi,wi correctly
* let function to integrate accept args, c.f. scipy.nquad
* any edits
* add numpy import
* autoray
* add Gaussian quadrature methods
* fix import
* change anp.inf to numpy.inf
* fix interval transformation and clean up
* make sure tensors are on same device
* make sure tensors are on same device, part 2
* make sure tensors are on same device, part 3
* make sure tensors are on same device, part 4
* make sure tensors are on same device, part 5
* add special import
* add tests to /tests
* run autopep8, add docstring
* (feat): cache for roots.
* (feat): refactor out grid integration procedure
* (feat): gaussian integration refactored, some tests passing
* (fix): scaling constant
* (chore): higher dim integrals testing
* (feat): weights correct for multi-dim integrands.
* (fix): correct number of arguments.
* (fix): remove non-legendre tests.
* (fix): import GaussLegendre
* (fix): ensure grid and weights are correct type
* (style): docstrings.
* (fix): default `grid_func`
* (fix): `_cached_poi...` returns tuple, not ndarray
* (fix): propagate `backend` correctly.
* (chore): export base types
* (feat): add jit for gaussian
* (feat): backward diff
* (fix): env issue
* Fixed tests badge
* (chore): cleanup
* (fix): `intergal` -> `integral`
* (chore): add tutorial
* (fix): change to `argnums` to work around decorator
* (fix): add fix from other PR
* (feat): add (broken) tests for gauss jit
* (chore): remove unused import
* (fix): use `item` for `N` when `jit` with `jax`
* (fix): `domain` for jit gauss `calculate_result`
* (chore): `black`
* (chore): erroneous diff
* (chore): remove erroneous print
* (fix): correct comment
* (fix): clean up gaussian tests
* (chore): add comments.
* (chore): formatting
* (fix): error of 1D integral
* (fix): increase bounds.

---------

Co-authored-by: ilan-gold <[email protected]>
Co-authored-by: Pablo Gómez <[email protected]>
1 parent 64a0188 commit cd53251

14 files changed: +678 −248 lines

docs/source/tutorial.rst

Lines changed: 36 additions & 0 deletions
@@ -756,4 +756,40 @@ Now let's see how to do this a bit more simply, and in a way that provides signf
 
 torch.all(torch.isclose(result_vectorized, result)) # True!
 
+Custom Integrators
+------------------
+
+It is of course possible to extend our provided Integrators, perhaps for a special class of functions or for a new algorithm.
+
+.. code:: ipython3
+
+    class GaussHermite(Gaussian):
+        """Gauss-Hermite quadrature rule in torch, for integrals of the form :math:`\\int_{-\\infty}^{+\\infty} e^{-x^{2}} f(x) dx`. It will correctly integrate
+        polynomials of degree :math:`2n - 1` or less over the interval
+        :math:`[-\\infty, \\infty]` with weight function :math:`w(x) = e^{-x^2}`. See https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature
+        """
+
+        def __init__(self):
+            super().__init__()
+            self.name = "Gauss-Hermite"
+            self.root_fn = scipy.special.roots_hermite
+            self.default_integration_domain = [[-1 * numpy.inf, numpy.inf]]
+            self.wrapper_func = None
+
+        @staticmethod
+        def _apply_composite_rule(cur_dim_areas, dim, hs, domain):
+            """Apply "composite" rule for Gaussian integrals.
+            cur_dim_areas will contain the areas per dimension
+            """
+            # We collapse dimension by dimension
+            for _ in range(dim):
+                cur_dim_areas = anp.sum(cur_dim_areas, axis=len(cur_dim_areas.shape) - 1)
+            return cur_dim_areas
+
+    gh = GaussHermite()
+    integral = gh.integrate(lambda x: 1 - x, dim=1, N=200)  # integral from -inf to inf of np.exp(-(x**2))*(1-x)
+    # Computed integral was 1.7724538509055168.
+    # analytic result = sqrt(pi)
+
 
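The same pattern extends to other weight functions. As an illustration only (not part of this commit, and assuming the same imports as the snippet above: scipy, numpy, anp, and the Gaussian base class), a hypothetical Gauss-Laguerre integrator for integrals of the form ∫_0^∞ e^{-x} f(x) dx can swap in scipy.special.roots_laguerre as the root function; everything else follows the GaussHermite sketch:

    class GaussLaguerre(Gaussian):
        """Hypothetical Gauss-Laguerre rule for integrals of the form
        int_0^inf exp(-x) f(x) dx. Illustration only."""

        def __init__(self):
            super().__init__()
            self.name = "Gauss-Laguerre"
            self.root_fn = scipy.special.roots_laguerre  # nodes/weights for the e^{-x} weight
            self.default_integration_domain = [[0, numpy.inf]]
            self.wrapper_func = None

        @staticmethod
        def _apply_composite_rule(cur_dim_areas, dim, hs, domain):
            # As in GaussHermite above: no interval rescaling, only collapse dimensions.
            for _ in range(dim):
                cur_dim_areas = anp.sum(cur_dim_areas, axis=len(cur_dim_areas.shape) - 1)
            return cur_dim_areas

    gl = GaussLaguerre()
    integral = gl.integrate(lambda x: x, dim=1, N=100)  # integral from 0 to inf of exp(-x)*x
    # analytic result = 1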

torchquad/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -9,6 +9,9 @@
 from .integration.simpson import Simpson
 from .integration.boole import Boole
 from .integration.vegas import VEGAS
+from .integration.gaussian import GaussLegendre
+from .integration.grid_integrator import GridIntegrator
+from .integration.base_integrator import BaseIntegrator
 
 from .integration.rng import RNG
 
@@ -22,12 +25,15 @@
 from .utils.deployment_test import _deployment_test
 
 __all__ = [
+    "GridIntegrator",
+    "BaseIntegrator",
     "IntegrationGrid",
     "MonteCarlo",
     "Trapezoid",
     "Simpson",
     "Boole",
     "VEGAS",
+    "GaussLegendre",
     "RNG",
     "plot_convergence",
     "plot_runtime",

torchquad/integration/base_integrator.py

Lines changed: 27 additions & 6 deletions
@@ -29,30 +29,41 @@ def integrate(self):
             NotImplementedError("This is an abstract base class. Should not be called.")
         )
 
-    def _eval(self, points):
+    def _eval(self, points, weights=None, args=None):
         """Call evaluate_integrand to evaluate self._fn function at the passed points and update self._nr_of_fevals
 
         Args:
             points (backend tensor): Integration points
+            weights (backend tensor, optional): Integration weights. Defaults to None.
+            args (list or tuple, optional): Any arguments required by the function. Defaults to None.
         """
-        result, num_points = self.evaluate_integrand(self._fn, points)
+        result, num_points = self.evaluate_integrand(
+            self._fn, points, weights=weights, args=args
+        )
         self._nr_of_fevals += num_points
         return result
 
     @staticmethod
-    def evaluate_integrand(fn, points):
+    def evaluate_integrand(fn, points, weights=None, args=None):
         """Evaluate the integrand function at the passed points
 
         Args:
             fn (function): Integrand function
             points (backend tensor): Integration points
+            weights (backend tensor, optional): Integration weights. Defaults to None.
+            args (list or tuple, optional): Any arguments required by the function. Defaults to None.
 
         Returns:
             backend tensor: Integrand function output
             int: Number of evaluated points
         """
         num_points = points.shape[0]
-        result = fn(points)
+
+        if args is None:
+            args = ()
+
+        result = fn(points, *args)
+
         if infer_backend(result) != infer_backend(points):
             warnings.warn(
                 "The passed function's return value has a different numerical backend than the passed points. Will try to convert. Note that this may be slow as it results in memory transfers between CPU and GPU, if torchquad uses the GPU."
@@ -67,17 +78,27 @@ def evaluate_integrand(fn, points):
                 f"where first dimension matches length of passed elements. "
             )
 
+        if weights is not None:
+            if (
+                len(result.shape) > 1
+            ):  # if the integrand is multi-dimensional, we need to reshape/repeat weights so they can be broadcast in the *=
+                integrand_shape = anp.array(
+                    result.shape[1:], like=infer_backend(points)
+                )
+                weights = anp.repeat(
+                    anp.expand_dims(weights, axis=1), anp.prod(integrand_shape)
+                ).reshape((weights.shape[0], *(integrand_shape)))
+            result *= weights
+
         return result, num_points
 
     @staticmethod
     def _check_inputs(dim=None, N=None, integration_domain=None):
         """Used to check input validity
-
         Args:
             dim (int, optional): Dimensionality of function to integrate. Defaults to None.
             N (int, optional): Total number of integration points. Defaults to None.
             integration_domain (list or backend tensor, optional): Integration domain, e.g. [[0,1],[1,2]]. Defaults to None.
-
         Raises:
             ValueError: if inputs are not compatible with each other.
         """

torchquad/integration/boole.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ def integrate(self, fn, dim, N=None, integration_domain=None, backend=None):
         return super().integrate(fn, dim, N, integration_domain, backend)
 
     @staticmethod
-    def _apply_composite_rule(cur_dim_areas, dim, hs):
+    def _apply_composite_rule(cur_dim_areas, dim, hs, domain):
         """Apply composite Boole quadrature.
         cur_dim_areas will contain the areas per dimension
         """

torchquad/integration/gaussian.py

Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
+import numpy
+from autoray import numpy as anp
+from .grid_integrator import GridIntegrator
+
+
+class Gaussian(GridIntegrator):
+    """Gaussian quadrature methods inherit from this. Default behaviour is Gauss-Legendre quadrature on [-1,1]."""
+
+    def __init__(self):
+        super().__init__()
+        self.name = "Gauss-Legendre"
+        self.root_fn = numpy.polynomial.legendre.leggauss
+        self.root_args = ()
+        self.default_integration_domain = [[-1, 1]]
+        self.transform_interval = True
+        self._cache = {}
+
+    def integrate(self, fn, dim, N=8, integration_domain=None, backend=None):
+        """Integrates the passed function on the passed domain using Gaussian quadrature.
+
+        Args:
+            fn (func): The function to integrate over.
+            dim (int): Dimensionality of the integration domain.
+            N (int, optional): Total number of sample points to use for the integration. Defaults to 8.
+            integration_domain (list or backend tensor, optional): Integration domain, e.g. [[-1,1],[0,1]]. Defaults to [-1,1]^dim. It also determines the numerical backend if possible.
+            backend (string, optional): Numerical backend. This argument is ignored if the backend can be inferred from integration_domain. Defaults to the backend from the latest call to set_up_backend or "torch" for backwards compatibility.
+
+        Returns:
+            backend-specific number: Integral value
+        """
+        return super().integrate(fn, dim, N, integration_domain, backend)
+
+    def _weights(self, N, dim, backend, requires_grad=False):
+        """Return the weights, broadcast across the dimensions, generated from the polynomial of choice.
+
+        Args:
+            N (int): number of nodes
+            dim (int): number of dimensions
+            backend (string): which backend array to return
+            requires_grad (bool, optional): whether the weights should track gradients (torch backend only). Defaults to False.
+
+        Returns:
+            backend tensor: the weights
+        """
+        weights = anp.array(self._cached_points_and_weights(N)[1], like=backend)
+        if backend == "torch":
+            weights.requires_grad = requires_grad
+            return anp.prod(
+                anp.array(
+                    anp.stack(
+                        list(anp.meshgrid(*([weights] * dim))), like=backend, dim=0
+                    )
+                ),
+                axis=0,
+            ).ravel()
+        else:
+            return anp.prod(
+                anp.meshgrid(*([weights] * dim), like=backend), axis=0
+            ).ravel()
+
+    def _roots(self, N, backend, requires_grad=False):
+        """Return the roots generated from the polynomial of choice.
+
+        Args:
+            N (int): number of nodes
+            backend (string): which backend array to return
+            requires_grad (bool, optional): whether the roots should track gradients. Defaults to False.
+
+        Returns:
+            backend tensor: the roots
+        """
+        roots = anp.array(self._cached_points_and_weights(N)[0], like=backend)
+        if requires_grad:
+            roots.requires_grad = True
+        return roots
+
+    @property
+    def _grid_func(self):
+        """
+        Function for generating a grid to be integrated over, i.e. the polynomial roots, resized to the domain.
+        """
+
+        def f(a, b, N, requires_grad, backend=None):
+            return self._resize_roots(a, b, self._roots(N, backend, requires_grad))
+
+        return f
+
+    def _resize_roots(self, a, b, roots):  # scale from [-1,1] to [a,b]
+        """Resize the roots based on the domain [a,b]. The base class returns them unchanged.
+
+        Args:
+            a (backend tensor): lower bound
+            b (backend tensor): upper bound
+            roots (backend tensor): polynomial nodes
+
+        Returns:
+            backend tensor: rescaled roots
+        """
+        return roots
+
+    # credit for the idea https://github.com/scipy/scipy/blob/dde50595862a4f9cede24b5d1c86935c30f1f88a/scipy/integrate/_quadrature.py#L72
+    def _cached_points_and_weights(self, N):
+        """Wrap the calls to get weights/roots in a cache.
+
+        Args:
+            N (int): number of nodes to return
+
+        Returns:
+            tuple: nodes and weights
+        """
+        root_args = (N, *self.root_args)
+        if not isinstance(N, int):
+            if hasattr(N, "item"):
+                root_args = (N.item(), *self.root_args)
+            else:
+                raise NotImplementedError(
+                    f"N {N} is not an int and lacks an `item` method"
+                )
+        if root_args in self._cache:
+            return self._cache[root_args]
+        self._cache[root_args] = self.root_fn(*root_args)
+        return self._cache[root_args]
+
+    @staticmethod
+    def _apply_composite_rule(cur_dim_areas, dim, hs, domain):
+        """Apply "composite" rule for gaussian integrals.
+
+        cur_dim_areas will contain the areas per dimension
+        """
+        # We collapse dimension by dimension, rescaling each by half its interval width
+        for cur_dim in range(dim):
+            cur_dim_areas = (
+                0.5
+                * (domain[cur_dim][1] - domain[cur_dim][0])
+                * anp.sum(cur_dim_areas, axis=len(cur_dim_areas.shape) - 1)
+            )
+        return cur_dim_areas
+
+
+class GaussLegendre(Gaussian):
+    """Gauss-Legendre quadrature rule in torch. See https://en.wikipedia.org/wiki/Gaussian_quadrature#Gauss%E2%80%93Legendre_quadrature.
+
+    Examples
+    --------
+    >>> gl = torchquad.GaussLegendre()
+    >>> integral = gl.integrate(lambda x: np.sin(x), dim=1, N=101, integration_domain=[[0, 5]])  # integral from 0 to 5 of np.sin(x)
+    |TQ-INFO| Computed integral was 0.7163378000259399  # analytic result = 1-np.cos(5)
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def _resize_roots(self, a, b, roots):  # scale from [-1,1] to [a,b]
+        return ((b - a) / 2) * roots + ((a + b) / 2)
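
As a quick sanity check of the new integrator (an illustrative sketch, not from the commit): an n-point Gauss-Legendre rule is exact for polynomials of degree up to 2n − 1, so five nodes suffice for a degree-9 monomial:

    import torchquad

    torchquad.set_up_backend("torch", data_type="float64")

    gl = torchquad.GaussLegendre()
    # 5 nodes => exact for polynomials up to degree 2*5 - 1 = 9
    integral = gl.integrate(
        lambda x: x[:, 0] ** 9, dim=1, N=5, integration_domain=[[0.0, 1.0]]
    )
    print(integral)  # ~0.1000, analytic result = 1/10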
