Skip to content

Commit 364e735

Browse files
authored
Merge pull request #209 from esa/main
main -> develop (Release 0.4.1)
2 parents 954c0ba + 87da956 commit 364e735

File tree

11 files changed

+119
-54
lines changed

11 files changed

+119
-54
lines changed

.github/workflows/autoblack.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ jobs:
1010
runs-on: ubuntu-latest
1111
steps:
1212
- uses: actions/checkout@v1
13-
- name: Set up Python 3.8
13+
- name: Set up Python 3.11
1414
uses: actions/setup-python@v1
1515
with:
16-
python-version: 3.8
16+
python-version: 3.11
1717
- name: Install Black
18-
run: pip install black
18+
run: pip install black==24.4.2
1919
- name: Run black --check .
2020
run: black --check .

.github/workflows/run_tests.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ jobs:
4545
# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
4646
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
4747
- name: provision-with-micromamba
48-
uses: mamba-org/provision-with-micromamba@main
48+
uses: mamba-org/setup-micromamba@v1
4949
with:
5050
environment-file: environment_all_backends.yml
5151
environment-name: torchquad
@@ -62,6 +62,7 @@ jobs:
6262
- name: pytest coverage comment
6363
uses: MishaKav/pytest-coverage-comment@main
6464
if: github.event_name == 'pull_request'
65+
continue-on-error: true
6566
with:
6667
pytest-coverage-path: ./torchquad/tests/pytest-coverage.txt
6768
title: Coverage Report

.readthedocs.yml

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,17 @@
55
# Required
66
version: 2
77

8+
build:
9+
os: "ubuntu-22.04"
10+
tools:
11+
python: "mambaforge-22.9"
12+
813
# Build documentation in the docs/ directory with Sphinx
914
sphinx:
1015
configuration: docs/source/conf.py
1116

1217
# Optionally build your docs in additional formats such as PDF
1318
formats: all
1419

15-
# Optionally set the version of Python and requirements required to build your docs
16-
python:
17-
version: 3.8
18-
install:
19-
- method: setuptools
20-
path: .
21-
2220
conda:
23-
environment: rtd_environment.yml
21+
environment: rtd_environment.yml

docs/source/tutorial.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -724,11 +724,11 @@ sample points for both functions:
724724
725725
# Integrate the first integrand with the sample points
726726
function_values, _ = integrator.evaluate_integrand(integrand1, grid_points)
727-
integral1 = integrator.calculate_result(function_values, dim, n_per_dim, hs)
727+
integral1 = integrator.calculate_result(function_values, dim, n_per_dim, hs, integration_domain)
728728
729729
# Integrate the second integrand with the same sample points
730730
function_values, _ = integrator.evaluate_integrand(integrand2, grid_points)
731-
integral2 = integrator.calculate_result(function_values, dim, n_per_dim, hs)
731+
integral2 = integrator.calculate_result(function_values, dim, n_per_dim, hs, integration_domain)
732732
733733
print(f"Quadrature results: {integral1}, {integral2}")
734734
@@ -745,7 +745,7 @@ As an example, here we evaluate a similar integrand many times for different val
745745
.. code:: ipython3
746746
747747
def parametrized_integrand(x, a, b):
748-
return torch.sqrt(torch.cos(torch.sin((a + b) * x)))
748+
return torch.sqrt(torch.cos(torch.sin((a + b) * x)))
749749
750750
a_params = torch.arange(40)
751751
b_params = torch.arange(10, 20)

torchquad/integration/boole.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77

88
class Boole(NewtonCotes):
9-
109
"""Boole's rule. See https://en.wikipedia.org/wiki/Newton%E2%80%93Cotes_formulas#Closed_Newton%E2%80%93Cotes_formulas ."""
1110

1211
def __init__(self):

torchquad/integration/grid_integrator.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
_linspace_with_grads,
88
expand_func_values_and_squeeze_integral,
99
_setup_integration_domain,
10+
_torch_trace_without_warnings,
1011
)
1112

1213

@@ -208,8 +209,6 @@ def compiled_integrate(fn, integration_domain):
208209
elif backend == "torch":
209210
# Torch requires explicit tracing with example inputs.
210211
def do_compile(example_integrand):
211-
import torch
212-
213212
# Define traceable first and third steps
214213
def step1(integration_domain):
215214
grid_points, hs, n_per_dim = self.calculate_grid(
@@ -218,7 +217,7 @@ def step1(integration_domain):
218217
return (
219218
grid_points,
220219
hs,
221-
torch.Tensor([n_per_dim]),
220+
anp.array([n_per_dim], like="torch"),
222221
) # n_per_dim is constant
223222

224223
dim = int(integration_domain.shape[0])
@@ -229,7 +228,7 @@ def step3(function_values, hs, integration_domain):
229228
)
230229

231230
# Trace the first step
232-
step1 = torch.jit.trace(step1, (integration_domain,))
231+
step1 = _torch_trace_without_warnings(step1, (integration_domain,))
233232

234233
# Get example input for the third step
235234
grid_points, hs, n_per_dim = step1(integration_domain)
@@ -241,15 +240,7 @@ def step3(function_values, hs, integration_domain):
241240
)
242241

243242
# Trace the third step
244-
# Avoid the warnings about a .grad attribute access of a
245-
# non-leaf Tensor
246-
if hs.requires_grad:
247-
hs = hs.detach()
248-
hs.requires_grad = True
249-
if function_values.requires_grad:
250-
function_values = function_values.detach()
251-
function_values.requires_grad = True
252-
step3 = torch.jit.trace(
243+
step3 = _torch_trace_without_warnings(
253244
step3, (function_values, hs, integration_domain)
254245
)
255246

torchquad/integration/monte_carlo.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
from loguru import logger
44

55
from .base_integrator import BaseIntegrator
6-
from .utils import _setup_integration_domain, expand_func_values_and_squeeze_integral
6+
from .utils import (
7+
_setup_integration_domain,
8+
expand_func_values_and_squeeze_integral,
9+
_torch_trace_without_warnings,
10+
)
711
from .rng import RNG
812

913

@@ -195,8 +199,6 @@ def compiled_integrate(fn, integration_domain):
195199
elif backend == "torch":
196200
# Torch requires explicit tracing with example inputs.
197201
def do_compile(example_integrand):
198-
import torch
199-
200202
# Define traceable first and third steps
201203
def step1(integration_domain):
202204
return self.calculate_sample_points(
@@ -206,7 +208,9 @@ def step1(integration_domain):
206208
step3 = self.calculate_result
207209

208210
# Trace the first step (which is non-deterministic)
209-
step1 = torch.jit.trace(step1, (integration_domain,), check_trace=False)
211+
step1 = _torch_trace_without_warnings(
212+
step1, (integration_domain,), check_trace=False
213+
)
210214

211215
# Get example input for the third step
212216
sample_points = step1(integration_domain)
@@ -215,12 +219,9 @@ def step1(integration_domain):
215219
)
216220

217221
# Trace the third step
218-
if function_values.requires_grad:
219-
# Avoid the warning about a .grad attribute access of a
220-
# non-leaf Tensor
221-
function_values = function_values.detach()
222-
function_values.requires_grad = True
223-
step3 = torch.jit.trace(step3, (function_values, integration_domain))
222+
step3 = _torch_trace_without_warnings(
223+
step3, (function_values, integration_domain)
224+
)
224225

225226
# Define a compiled integrate function
226227
def compiled_integrate(fn, integration_domain):

torchquad/integration/simpson.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77

88
class Simpson(NewtonCotes):
9-
109
"""Simpson's rule. See https://en.wikipedia.org/wiki/Newton%E2%80%93Cotes_formulas#Closed_Newton%E2%80%93Cotes_formulas ."""
1110

1211
def __init__(self):

torchquad/integration/utils.py

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Utility functions for the integrator implementations including extensions for
33
autoray, which are registered when importing this file
44
"""
5+
56
import sys
67
from pathlib import Path
78

@@ -193,20 +194,11 @@ def _check_integration_domain(integration_domain):
193194
raise ValueError("integration_domain.shape[0] needs to be 1 or larger.")
194195
if num_bounds != 2:
195196
raise ValueError("integration_domain must have 2 values per boundary")
196-
# Skip the values check if an integrator.integrate method is JIT
197-
# compiled with JAX
198-
if any(
199-
nam in type(integration_domain).__name__ for nam in ["Jaxpr", "JVPTracer"]
200-
):
201-
return dim
202-
boundaries_are_invalid = (
203-
anp.min(integration_domain[:, 1] - integration_domain[:, 0]) < 0.0
204-
)
205-
# Skip the values check if an integrator.integrate method is
206-
# compiled with tensorflow.function
207-
if type(boundaries_are_invalid).__name__ == "Tensor":
197+
# The boundary values check does not work if the code is JIT compiled
198+
# with JAX or TensorFlow.
199+
if _is_compiling(integration_domain):
208200
return dim
209-
if boundaries_are_invalid:
201+
if anp.min(integration_domain[:, 1] - integration_domain[:, 0]) < 0.0:
210202
raise ValueError("integration_domain has invalid boundary values")
211203
return dim
212204

@@ -261,3 +253,49 @@ def wrap(*args, **kwargs):
261253
return f(*args, **kwargs)
262254

263255
return wrap
256+
257+
258+
def _is_compiling(x):
259+
"""
260+
Check if code is currently being compiled with PyTorch, JAX or TensorFlow
261+
262+
Args:
263+
x (backend tensor): A tensor currently used for computations
264+
Returns:
265+
bool: True if code is currently being compiled, False otherwise
266+
"""
267+
backend = infer_backend(x)
268+
if backend == "jax":
269+
return any(nam in type(x).__name__ for nam in ["Jaxpr", "JVPTracer"])
270+
if backend == "torch":
271+
import torch
272+
273+
if hasattr(torch.jit, "is_tracing"):
274+
# We ignore torch.jit.is_scripting() since we do not support
275+
# compilation to TorchScript
276+
return torch.jit.is_tracing()
277+
# torch.jit.is_tracing() is unavailable below PyTorch version 1.11.0
278+
return type(x.shape[0]).__name__ == "Tensor"
279+
if backend == "tensorflow":
280+
import tensorflow as tf
281+
282+
if hasattr(tf, "is_symbolic_tensor"):
283+
return tf.is_symbolic_tensor(x)
284+
# tf.is_symbolic_tensor() is unavailable below TensorFlow version 2.13.0
285+
return type(x).__name__ == "Tensor"
286+
return False
287+
288+
289+
def _torch_trace_without_warnings(*args, **kwargs):
290+
"""Execute `torch.jit.trace` on the passed arguments and hide tracer warnings
291+
292+
PyTorch can show warnings about traces being potentially incorrect because
293+
the Python3 control flow is not completely recorded.
294+
This function can be used to hide the warnings in situations where they are
295+
false positives.
296+
"""
297+
import torch
298+
299+
with warnings.catch_warnings():
300+
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
301+
return torch.jit.trace(*args, **kwargs)

torchquad/tests/integration_test_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(
4444
"""
4545
self.integration_dim = integration_dim
4646
self.expected_result = expected_result
47-
if type(integrand_dims) == int or hasattr(integrand_dims, "__len__"):
47+
if type(integrand_dims) is int or hasattr(integrand_dims, "__len__"):
4848
self.integrand_dims = integrand_dims
4949
else:
5050
ValueError(

torchquad/tests/utils_integration_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
_linspace_with_grads,
1313
_add_at_indices,
1414
_setup_integration_domain,
15+
_is_compiling,
1516
)
1617
from utils.set_precision import set_precision
1718
from utils.enable_cuda import enable_cuda
@@ -196,11 +197,48 @@ def test_setup_integration_domain():
196197
_run_tests_with_all_backends(_run_setup_integration_domain_tests)
197198

198199

200+
def _run_is_compiling_tests(dtype_name, backend):
201+
"""
202+
Test _is_compiling with the given dtype and numerical backend
203+
"""
204+
dtype = to_backend_dtype(dtype_name, like=backend)
205+
x = anp.array([[0.0, 1.0], [1.0, 2.0]], dtype=dtype, like=backend)
206+
assert not _is_compiling(
207+
x
208+
), f"_is_compiling has a false positive with backend {backend}"
209+
210+
def check_compiling(x):
211+
assert _is_compiling(
212+
x
213+
), f"_is_compiling has a false negative with backend {backend}"
214+
return x
215+
216+
if backend == "jax":
217+
import jax
218+
219+
jax.jit(check_compiling)(x)
220+
elif backend == "torch":
221+
import torch
222+
223+
torch.jit.trace(check_compiling, (x,), check_trace=False)(x)
224+
elif backend == "tensorflow":
225+
import tensorflow as tf
226+
227+
tf.function(check_compiling, jit_compile=True)(x)
228+
tf.function(check_compiling, jit_compile=False)(x)
229+
230+
231+
def test_is_compiling():
232+
"""Test _is_compiling with all possible configurations"""
233+
_run_tests_with_all_backends(_run_is_compiling_tests)
234+
235+
199236
if __name__ == "__main__":
200237
try:
201238
# used to run this test individually
202239
test_linspace_with_grads()
203240
test_add_at_indices()
204241
test_setup_integration_domain()
242+
test_is_compiling()
205243
except KeyboardInterrupt:
206244
pass

0 commit comments

Comments
 (0)