Skip to content

Commit a9f5630

Browse files
authored
Random methods on resonate ctx (#209)
1 parent 3a9c361 commit a9f5630

File tree

3 files changed

+39
-9
lines changed

3 files changed

+39
-9
lines changed

resonate/dependencies.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@ def __init__(self) -> None:
1010
def add(self, key: str, obj: Any) -> None:
1111
self._deps[key] = obj
1212

13-
def get(self, key: str) -> Any:
14-
return self._deps[key]
13+
def get[T](self, key: str, default: T) -> Any | T:
14+
return self._deps.get(key, default)

resonate/resonate.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,32 @@ def promises(self) -> PromiseStore:
202202
return self._store.promises
203203

204204

205+
class Random:
206+
def __init__(self, ctx: Context) -> None:
207+
self.ctx = ctx
208+
209+
def random(self) -> LFC:
210+
return self.ctx.lfc(lambda _: self.ctx.get_dependency("resonate:random", random).random())
211+
212+
def betavariate(self, alpha: float, beta: float) -> LFC:
213+
return self.ctx.lfc(lambda _: self.ctx.get_dependency("resonate:random", random).betavariate(alpha, beta))
214+
215+
def randrange(self, start: int, stop: int | None = None, step: int = 1) -> LFC:
216+
return self.ctx.lfc(lambda _: self.ctx.get_dependency("resonate:random", random).randrange(start, stop, step))
217+
218+
def randint(self, a: int, b: int) -> LFC:
219+
return self.ctx.lfc(lambda _: self.ctx.get_dependency("resonate:random", random).randint(a, b))
220+
221+
def getrandbits(self, k: int) -> LFC:
222+
return self.ctx.lfc(lambda _: self.ctx.get_dependency("resonate:random", random).getrandbits(k))
223+
224+
def triangular(self, low: float = 0, high: float = 1, mode: float | None = None) -> LFC:
225+
return self.ctx.lfc(lambda _: self.ctx.get_dependency("resonate:random", random).triangular(low, high, mode))
226+
227+
def expovariate(self, lambd: float = 1) -> LFC:
228+
return self.ctx.lfc(lambda _: self.ctx.get_dependency("resonate:random", random).expovariate(lambd))
229+
230+
205231
# Context
206232
class Context:
207233
def __init__(self, id: str, info: Info, opts: Options, registry: Registry, dependencies: Dependencies) -> None:
@@ -211,6 +237,7 @@ def __init__(self, id: str, info: Info, opts: Options, registry: Registry, depen
211237
self._registry = registry
212238
self._dependencies = dependencies
213239
self._counter = 0
240+
self._random = Random(self)
214241

215242
@property
216243
def id(self) -> str:
@@ -220,6 +247,10 @@ def id(self) -> str:
220247
def info(self) -> Info:
221248
return self._info
222249

250+
@property
251+
def random(self) -> Random:
252+
return self._random
253+
223254
@overload
224255
def lfi[**P, R](self, func: Callable[Concatenate[Context, P], Generator[Any, Any, R]], *args: P.args, **kwargs: P.kwargs) -> LFI: ...
225256
@overload
@@ -285,11 +316,8 @@ def promise(self, data: Any = None, headers: dict[str, str] | None = None) -> RF
285316
self._counter += 1
286317
return RFI(f"{self.id}.{self._counter}", Base(data, headers))
287318

288-
def random(self, a: int, b: int) -> LFC:
289-
return self.lfc(lambda _, a, b: random.randint(a, b), a, b)
290-
291-
def get_dependency(self, name: str) -> Any:
292-
return self._dependencies.get(name)
319+
def get_dependency[T](self, key: str, default: T = None) -> Any | T:
320+
return self._dependencies.get(key, default)
293321

294322
def _lfi_func(self, f: str | Callable) -> tuple[Callable, int, dict[int, Callable] | None]:
295323
match f:

tests/test_bridge.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ def add_one(ctx: Context, n: int) -> int:
105105

106106

107107
def get_dependency(ctx: Context) -> int:
108-
return ctx.get_dependency("foo") + 1
108+
dep = ctx.get_dependency("foo")
109+
assert dep is not None
110+
return dep + 1
109111

110112

111113
def rfi_add_one_by_name(ctx: Context, n: int) -> Generator[Any, Any, int]:
@@ -123,7 +125,7 @@ def hitl(ctx: Context, id: str | None) -> Generator[Yieldable, Any, int]:
123125

124126

125127
def random_generation(ctx: Context) -> Generator[Yieldable, Any, float]:
126-
return (yield ctx.random(0, 10))
128+
return (yield ctx.random.randint(0, 10))
127129

128130

129131
def get_stores_config() -> list[tuple[Store, MessageSource | None]]:

0 commit comments

Comments
 (0)