Skip to content

Commit beae31a

Browse files
committed
Async Timer, ExceptionTracker, and InprogressTracker implementations.
Signed-off-by: Cameron Lee <[email protected]>
1 parent 29f5307 commit beae31a

File tree

2 files changed

+40
-17
lines changed

2 files changed

+40
-17
lines changed

prometheus_client/context_managers.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from timeit import default_timer
33
from types import TracebackType
44
from typing import Any, Callable, Optional, Type, TYPE_CHECKING, TypeVar
5+
from inspect import iscoroutinefunction
56

67
if sys.version_info >= (3, 8, 0):
78
from typing import Literal
@@ -26,10 +27,16 @@ def __exit__(self, typ: Optional[Type[BaseException]], value: Optional[BaseExcep
2627
self._counter.inc()
2728
return False
2829

30+
2931
def __call__(self, f: "F") -> "F":
30-
def wrapped(func, *args, **kwargs):
31-
with self:
32-
return func(*args, **kwargs)
32+
if iscoroutinefunction(f):
33+
async def wrapped(func, *args, **kwargs):
34+
with self:
35+
return await func(*args, **kwargs)
36+
else:
37+
def wrapped(func, *args, **kwargs):
38+
with self:
39+
return func(*args, **kwargs)
3340

3441
return decorate(f, wrapped)
3542

@@ -45,9 +52,14 @@ def __exit__(self, typ, value, traceback):
4552
self._gauge.dec()
4653

4754
def __call__(self, f):
48-
def wrapped(func, *args, **kwargs):
49-
with self:
50-
return func(*args, **kwargs)
55+
if iscoroutinefunction(f):
56+
async def wrapped(func, *args, **kwargs):
57+
with self:
58+
return await func(*args, **kwargs)
59+
else:
60+
def wrapped(func, *args, **kwargs):
61+
with self:
62+
return func(*args, **kwargs)
5163

5264
return decorate(f, wrapped)
5365

@@ -74,10 +86,17 @@ def labels(self, *args, **kw):
7486
self._metric = self._metric.labels(*args, **kw)
7587

7688
def __call__(self, f):
77-
def wrapped(func, *args, **kwargs):
78-
# Obtaining new instance of timer every time
79-
# ensures thread safety and reentrancy.
80-
with self._new_timer():
81-
return func(*args, **kwargs)
89+
if iscoroutinefunction(f):
90+
# If an async function was decorated, we support that by
91+
# producing an async function and awaiting the original.
92+
async def wrapped(func, *args, **kwargs):
93+
with self._new_timer():
94+
return await func(*args, **kwargs)
95+
else:
96+
def wrapped(func, *args, **kwargs):
97+
# Obtaining new instance of timer every time
98+
# ensures thread safety and reentrancy.
99+
with self._new_timer():
100+
return func(*args, **kwargs)
82101

83102
return decorate(f, wrapped)

tests/test_core.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ def assert_not_observable(fn, *args, **kwargs):
2929
assert False, "Did not raise a 'missing label values' exception"
3030

3131

32+
def assert_between(lower, value, upper, msg=""):
33+
assert lower <= value <= upper, "%s is not between %s and %s%s" % (value, lower, upper, " : %s" % msg if msg else "")
34+
35+
3236
class TestCounter(aiounittest.AsyncTestCase):
3337
def setUp(self):
3438
self.registry = CollectorRegistry()
@@ -215,7 +219,7 @@ def f():
215219
self.assertEqual(([], None, None, None), getargspec(f))
216220

217221
f()
218-
self.assertTrue(0.05 <= self.registry.get_sample_value('g') <= 0.1)
222+
assert_between(0.05, self.registry.get_sample_value('g'), 0.1)
219223

220224
async def test_time_async_function_decorator(self):
221225
self.assertEqual(0, self.registry.get_sample_value('g'))
@@ -227,7 +231,7 @@ async def f():
227231
self.assertEqual(([], None, None, None), getargspec(f))
228232

229233
await f()
230-
self.assertTrue(0.05 <= self.registry.get_sample_value('g') <= 0.1)
234+
assert_between(0.05, self.registry.get_sample_value('g'), 0.1)
231235

232236
def test_function_decorator_multithread(self):
233237
self.assertEqual(0, self.registry.get_sample_value('g'))
@@ -307,7 +311,7 @@ def f():
307311

308312
f()
309313
self.assertEqual(1, self.registry.get_sample_value('s_count'))
310-
self.assertTrue(.05 < self.registry.get_sample_value('s_sum') < 0.1)
314+
assert_between(.05, self.registry.get_sample_value('s_sum'), 0.1)
311315

312316
async def test_async_function_decorator(self):
313317
self.assertEqual(0, self.registry.get_sample_value('s_count'))
@@ -320,7 +324,7 @@ async def f():
320324

321325
await f()
322326
self.assertEqual(1, self.registry.get_sample_value('s_count'))
323-
self.assertTrue(.05 < self.registry.get_sample_value('s_sum') < 0.1)
327+
assert_between(.05, self.registry.get_sample_value('s_sum'), 0.1)
324328

325329
def test_function_decorator_multithread(self):
326330
self.assertEqual(0, self.registry.get_sample_value('s_count'))
@@ -475,7 +479,7 @@ def f():
475479
f()
476480
self.assertEqual(1, self.registry.get_sample_value('h_count'))
477481
self.assertEqual(1, self.registry.get_sample_value('h_bucket', {'le': '+Inf'}))
478-
self.assertTrue(.05 < self.registry.get_sample_value('h_sum') < 0.1)
482+
assert_between(.05, self.registry.get_sample_value('h_sum'), 0.1)
479483

480484
async def test_async_function_decorator(self):
481485
self.assertEqual(0, self.registry.get_sample_value('h_count'))
@@ -490,7 +494,7 @@ async def f():
490494
await f()
491495
self.assertEqual(1, self.registry.get_sample_value('h_count'))
492496
self.assertEqual(1, self.registry.get_sample_value('h_bucket', {'le': '+Inf'}))
493-
self.assertTrue(.05 < self.registry.get_sample_value('h_sum') < 0.1)
497+
assert_between(.05, self.registry.get_sample_value('h_sum'), 0.1)
494498

495499
def test_function_decorator_multithread(self):
496500
self.assertEqual(0, self.registry.get_sample_value('h_count'))

0 commit comments

Comments
 (0)