Skip to content

Commit 1fefe16

Browse files
authored
Add validation on surface area (#284)
1 parent 228659d commit 1fefe16

File tree

4 files changed

+201
-5
lines changed

4 files changed

+201
-5
lines changed

resonate/models/encoder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

3-
from typing import Protocol
3+
from typing import Protocol, runtime_checkable
44

55

6+
@runtime_checkable
67
class Encoder[I, O](Protocol):
78
def encode(self, obj: I, /) -> O: ...
89
def decode(self, obj: O, /) -> I: ...

resonate/models/retry_policy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

3-
from typing import Protocol
3+
from typing import Protocol, runtime_checkable
44

55

6+
@runtime_checkable
67
class RetryPolicy(Protocol):
78
def next(self, attempt: int) -> float | None: ...

resonate/options.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,13 @@
55
from typing import TYPE_CHECKING, Any
66

77
from resonate.encoders import HeaderEncoder, JsonEncoder, JsonPickleEncoder, NoopEncoder, PairEncoder
8+
from resonate.models.encoder import Encoder
9+
from resonate.models.retry_policy import RetryPolicy
810
from resonate.retry_policies import Exponential, Never
911

1012
if TYPE_CHECKING:
1113
from collections.abc import Callable
1214

13-
from resonate.models.encoder import Encoder
14-
from resonate.models.retry_policy import RetryPolicy
15-
1615

1716
@dataclass(frozen=True)
1817
class Options:
@@ -28,6 +27,46 @@ class Options:
2827
version: int = 0
2928

3029
def __post_init__(self) -> None:
30+
if not isinstance(self.durable, bool):
31+
msg = f"durable must be `bool`, got {type(self.durable).__name__}"
32+
raise TypeError(msg)
33+
34+
if self.encoder is not None and not isinstance(self.encoder, Encoder):
35+
msg = f"encoder must be `Encoder | None`, got {type(self.encoder).__name__}"
36+
raise TypeError(msg)
37+
38+
if self.id is not None and not isinstance(self.id, str):
39+
msg = f"id must be `str | None`, got {type(self.id).__name__}"
40+
raise TypeError(msg)
41+
42+
if self.idempotency_key is not None and not (isinstance(self.idempotency_key, str) or callable(self.idempotency_key)):
43+
msg = f"idempotency_key must be `Callable | str | None`, got {type(self.idempotency_key).__name__}"
44+
raise TypeError(msg)
45+
46+
if not isinstance(self.non_retryable_exceptions, tuple):
47+
msg = f"non_retryable_exceptions must be `tuple`, got {type(self.non_retryable_exceptions).__name__}"
48+
raise TypeError(msg)
49+
50+
if not (isinstance(self.retry_policy, RetryPolicy) or callable(self.retry_policy)):
51+
msg = f"retry_policy must be `Callable | RetryPolicy | None`, got {type(self.retry_policy).__name__}"
52+
raise TypeError(msg)
53+
54+
if not isinstance(self.target, str):
55+
msg = f"target must be `str`, got {type(self.target).__name__}"
56+
raise TypeError(msg)
57+
58+
if not isinstance(self.tags, dict):
59+
msg = f"tags must be `dict`, got {type(self.tags).__name__}"
60+
raise TypeError(msg)
61+
62+
if not isinstance(self.timeout, int | float):
63+
msg = f"timeout must be `float`, got {type(self.timeout).__name__}"
64+
raise TypeError(msg)
65+
66+
if not isinstance(self.version, int):
67+
msg = f"version must be `int`, got {type(self.version).__name__}"
68+
raise TypeError(msg)
69+
3170
if not (self.version >= 0):
3271
msg = "version must be greater than or equal to zero"
3372
raise ValueError(msg)

resonate/resonate.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,36 @@ def local(
8888
dependencies: Dependencies | None = None,
8989
log_level: int = logging.INFO,
9090
) -> Resonate:
91+
# pid
92+
if pid is not None and not isinstance(pid, str):
93+
msg = f"pid must be `str | None`, got {type(pid).__name__}"
94+
raise TypeError(msg)
95+
96+
# ttl
97+
if not isinstance(ttl, int):
98+
msg = f"ttl must be `int`, got {type(ttl).__name__}"
99+
raise TypeError(msg)
100+
101+
# group
102+
if not isinstance(group, str):
103+
msg = f"group must be `str`, got {type(group).__name__}"
104+
raise TypeError(msg)
105+
106+
# registry
107+
if registry is not None and not isinstance(registry, Registry):
108+
msg = f"registry must be `Registry | None`, got {type(registry).__name__}"
109+
raise TypeError(msg)
110+
111+
# dependencies
112+
if dependencies is not None and not isinstance(dependencies, Dependencies):
113+
msg = f"dependencies must be `Dependencies | None`, got {type(dependencies).__name__}"
114+
raise TypeError(msg)
115+
116+
# log_level
117+
if not isinstance(log_level, int):
118+
msg = f"log_level must be `int`, got {type(log_level).__name__}"
119+
raise TypeError(msg)
120+
91121
pid = pid or uuid.uuid4().hex
92122
store = LocalStore()
93123

@@ -115,6 +145,51 @@ def remote(
115145
dependencies: Dependencies | None = None,
116146
log_level: int = logging.INFO,
117147
) -> Resonate:
148+
# host
149+
if host is not None and not isinstance(host, str):
150+
msg = f"host must be `str | None`, got {type(host).__name__}"
151+
raise TypeError(msg)
152+
153+
# store_port
154+
if store_port is not None and not isinstance(store_port, str):
155+
msg = f"store_port must be `str | None`, got {type(store_port).__name__}"
156+
raise TypeError(msg)
157+
158+
# message_source_port
159+
if message_source_port is not None and not isinstance(message_source_port, str):
160+
msg = f"message_source_port must be `str | None`, got {type(message_source_port).__name__}"
161+
raise TypeError(msg)
162+
163+
# pid
164+
if pid is not None and not isinstance(pid, str):
165+
msg = f"pid must be `str | None`, got {type(pid).__name__}"
166+
raise TypeError(msg)
167+
168+
# ttl
169+
if not isinstance(ttl, int):
170+
msg = f"ttl must be `int`, got {type(ttl).__name__}"
171+
raise TypeError(msg)
172+
173+
# group
174+
if not isinstance(group, str):
175+
msg = f"group must be `str`, got {type(group).__name__}"
176+
raise TypeError(msg)
177+
178+
# registry
179+
if registry is not None and not isinstance(registry, Registry):
180+
msg = f"registry must be `Registry | None`, got {type(registry).__name__}"
181+
raise TypeError(msg)
182+
183+
# dependencies
184+
if dependencies is not None and not isinstance(dependencies, Dependencies):
185+
msg = f"dependencies must be `Dependencies | None`, got {type(dependencies).__name__}"
186+
raise TypeError(msg)
187+
188+
# log_level
189+
if not isinstance(log_level, int):
190+
msg = f"log_level must be `int`, got {type(log_level).__name__}"
191+
raise TypeError(msg)
192+
118193
pid = pid or uuid.uuid4().hex
119194

120195
return cls(
@@ -186,7 +261,19 @@ def register[**P, R](
186261
name: str | None = None,
187262
version: int = 1,
188263
) -> Function[P, R] | Callable[[Callable[Concatenate[Context, P], R]], Function[P, R]]:
264+
if name is not None and not isinstance(name, str):
265+
msg = f"name must be `str | None`, got {type(name).__name__}"
266+
raise TypeError(msg)
267+
268+
if not isinstance(version, int):
269+
msg = f"version must be `int`, got {type(version).__name__}"
270+
raise TypeError(msg)
271+
189272
def wrapper(func: Callable[..., Any]) -> Function[P, R]:
273+
if not callable(func):
274+
msg = "func must be Callable"
275+
raise TypeError(msg)
276+
190277
if isinstance(func, Function):
191278
func = func.func
192279

@@ -221,6 +308,23 @@ def run[**P, R](
221308
*args: P.args,
222309
**kwargs: P.kwargs,
223310
) -> Handle[R]:
311+
# id
312+
if not isinstance(id, str):
313+
msg = f"id must be `str`, got {type(id).__name__}"
314+
raise TypeError(msg)
315+
# func
316+
if not (callable(func) or isinstance(func, str)):
317+
msg = f"func must be `Callable | str`, got {type(func).__name__}"
318+
raise TypeError(msg)
319+
# tuple
320+
if not isinstance(args, tuple):
321+
msg = f"args must be `tuple`, got {type(args).__name__}"
322+
raise TypeError(args)
323+
# dict
324+
if not isinstance(kwargs, dict):
325+
msg = f"kwargs must be `dict`, got {type(kwargs).__name__}"
326+
raise TypeError(kwargs)
327+
224328
self.start()
225329
future = Future[R]()
226330

@@ -253,6 +357,23 @@ def rpc[**P, R](
253357
*args: P.args,
254358
**kwargs: P.kwargs,
255359
) -> Handle[R]:
360+
# id
361+
if not isinstance(id, str):
362+
msg = f"id must be `str`, got {type(id).__name__}"
363+
raise TypeError(msg)
364+
# func
365+
if not (callable(func) or isinstance(func, str)):
366+
msg = f"func must be `Callable | str`, got {type(func).__name__}"
367+
raise TypeError(msg)
368+
# tuple
369+
if not isinstance(args, tuple):
370+
msg = f"args must be `tuple`, got {type(args).__name__}"
371+
raise TypeError(args)
372+
# dict
373+
if not isinstance(kwargs, dict):
374+
msg = f"kwargs must be `dict`, got {type(kwargs).__name__}"
375+
raise TypeError(kwargs)
376+
256377
self.start()
257378
future = Future[R]()
258379

@@ -267,13 +388,23 @@ def rpc[**P, R](
267388
return Handle(future)
268389

269390
def get(self, id: str) -> Handle[Any]:
391+
# id
392+
if not isinstance(id, str):
393+
msg = f"id must be `str`, got {type(id).__name__}"
394+
raise TypeError(msg)
395+
270396
self.start()
271397
future = Future()
272398

273399
self._bridge.get(id, self._opts, future)
274400
return Handle(future)
275401

276402
def set_dependency(self, name: str, obj: Any) -> None:
403+
# name
404+
if not isinstance(name, str):
405+
msg = f"name must be `str`, got {type(name).__name__}"
406+
raise TypeError(msg)
407+
277408
self._dependencies.add(name, obj)
278409

279410

@@ -313,6 +444,10 @@ def time(self) -> Time:
313444
return self._time
314445

315446
def get_dependency[T](self, key: str, default: T = None) -> Any | T:
447+
if not isinstance(key, str):
448+
msg = f"key must be `str`, got {type(key).__name__}"
449+
raise TypeError(msg)
450+
316451
return self._dependencies.get(key, default)
317452

318453
def lfi[**P, R](
@@ -424,6 +559,10 @@ def typesafe(self, cmd: LFI | RFI | LFC | RFC | Promise) -> Generator[LFI | RFI
424559
return (yield cmd)
425560

426561
def sleep(self, secs: float) -> RFC[None]:
562+
if not isinstance(secs, int | float):
563+
msg = f"secs must be `float`, got {type(secs).__name__}"
564+
raise TypeError(msg)
565+
427566
return RFC(Sleep(self._next(), secs))
428567

429568
def promise(
@@ -435,6 +574,22 @@ def promise(
435574
data: Any = None,
436575
tags: dict[str, str] | None = None,
437576
) -> RFI:
577+
if id is not None and not isinstance(id, str):
578+
msg = f"id must be `str | None`, got {type(id).__name__}"
579+
raise TypeError(msg)
580+
581+
if timeout is not None and isinstance(timeout, int | float):
582+
msg = f"timeout must be `float`, got {type(timeout).__name__}"
583+
raise TypeError(msg)
584+
585+
if idempotency_key is not None and not isinstance(idempotency_key, str):
586+
msg = f"idempotency_key must be `str | None`, got {type(idempotency_key).__name__}"
587+
raise TypeError(msg)
588+
589+
if tags is not None and not isinstance(tags, dict):
590+
msg = f"tags must be `dict | None`, got {type(tags).__name__}"
591+
raise TypeError(tags)
592+
438593
default_id = self._next()
439594
id = id or default_id
440595

0 commit comments

Comments
 (0)