Skip to content

Add validation on surface area #284

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion resonate/models/encoder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

from typing import Protocol
from typing import Protocol, runtime_checkable


@runtime_checkable
class Encoder[I, O](Protocol):
def encode(self, obj: I, /) -> O: ...
def decode(self, obj: O, /) -> I: ...
3 changes: 2 additions & 1 deletion resonate/models/retry_policy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from typing import Protocol
from typing import Protocol, runtime_checkable


@runtime_checkable
class RetryPolicy(Protocol):
def next(self, attempt: int) -> float | None: ...
45 changes: 42 additions & 3 deletions resonate/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
from typing import TYPE_CHECKING, Any

from resonate.encoders import HeaderEncoder, JsonEncoder, JsonPickleEncoder, NoopEncoder, PairEncoder
from resonate.models.encoder import Encoder
from resonate.models.retry_policy import RetryPolicy
from resonate.retry_policies import Exponential, Never

if TYPE_CHECKING:
from collections.abc import Callable

from resonate.models.encoder import Encoder
from resonate.models.retry_policy import RetryPolicy


@dataclass(frozen=True)
class Options:
Expand All @@ -28,6 +27,46 @@ class Options:
version: int = 0

def __post_init__(self) -> None:
if not isinstance(self.durable, bool):
msg = f"durable must be `bool`, got {type(self.durable).__name__}"
raise TypeError(msg)

if self.encoder is not None and not isinstance(self.encoder, Encoder):
msg = f"encoder must be `Encoder | None`, got {type(self.encoder).__name__}"
raise TypeError(msg)

if self.id is not None and not isinstance(self.id, str):
msg = f"id must be `str | None`, got {type(self.id).__name__}"
raise TypeError(msg)

if self.idempotency_key is not None and not (isinstance(self.idempotency_key, str) or callable(self.idempotency_key)):
msg = f"idempotency_key must be `Callable | str | None`, got {type(self.idempotency_key).__name__}"
raise TypeError(msg)

if not isinstance(self.non_retryable_exceptions, tuple):
msg = f"non_retryable_exceptions must be `tuple`, got {type(self.non_retryable_exceptions).__name__}"
raise TypeError(msg)

if not (isinstance(self.retry_policy, RetryPolicy) or callable(self.retry_policy)):
msg = f"retry_policy must be `Callable | RetryPolicy | None`, got {type(self.retry_policy).__name__}"
raise TypeError(msg)

if not isinstance(self.target, str):
msg = f"target must be `str`, got {type(self.target).__name__}"
raise TypeError(msg)

if not isinstance(self.tags, dict):
msg = f"tags must be `dict`, got {type(self.tags).__name__}"
raise TypeError(msg)

if not isinstance(self.timeout, int | float):
msg = f"timeout must be `float`, got {type(self.timeout).__name__}"
raise TypeError(msg)

if not isinstance(self.version, int):
msg = f"version must be `int`, got {type(self.version).__name__}"
raise TypeError(msg)

if not (self.version >= 0):
msg = "version must be greater than or equal to zero"
raise ValueError(msg)
Expand Down
155 changes: 155 additions & 0 deletions resonate/resonate.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,36 @@ def local(
dependencies: Dependencies | None = None,
log_level: int = logging.INFO,
) -> Resonate:
# pid
if pid is not None and not isinstance(pid, str):
msg = f"pid must be `str | None`, got {type(pid).__name__}"
raise TypeError(msg)

# ttl
if not isinstance(ttl, int):
msg = f"ttl must be `int`, got {type(ttl).__name__}"
raise TypeError(msg)

# group
if not isinstance(group, str):
msg = f"group must be `str`, got {type(group).__name__}"
raise TypeError(msg)

# registry
if registry is not None and not isinstance(registry, Registry):
msg = f"registry must be `Registry | None`, got {type(registry).__name__}"
raise TypeError(msg)

# dependencies
if dependencies is not None and not isinstance(dependencies, Dependencies):
msg = f"dependencies must be `Dependencies | None`, got {type(dependencies).__name__}"
raise TypeError(msg)

# log_level
if not isinstance(log_level, int):
msg = f"log_level must be `int`, got {type(log_level).__name__}"
raise TypeError(msg)

pid = pid or uuid.uuid4().hex
store = LocalStore()

Expand Down Expand Up @@ -115,6 +145,51 @@ def remote(
dependencies: Dependencies | None = None,
log_level: int = logging.INFO,
) -> Resonate:
# host
if host is not None and not isinstance(host, str):
msg = f"host must be `str | None`, got {type(host).__name__}"
raise TypeError(msg)

# store_port
if store_port is not None and not isinstance(store_port, str):
msg = f"store_port must be `str | None`, got {type(store_port).__name__}"
raise TypeError(msg)

# message_source_port
if message_source_port is not None and not isinstance(message_source_port, str):
msg = f"message_source_port must be `str | None`, got {type(message_source_port).__name__}"
raise TypeError(msg)

# pid
if pid is not None and not isinstance(pid, str):
msg = f"pid must be `str | None`, got {type(pid).__name__}"
raise TypeError(msg)

# ttl
if not isinstance(ttl, int):
msg = f"ttl must be `int`, got {type(ttl).__name__}"
raise TypeError(msg)

# group
if not isinstance(group, str):
msg = f"group must be `str`, got {type(group).__name__}"
raise TypeError(msg)

# registry
if registry is not None and not isinstance(registry, Registry):
msg = f"registry must be `Registry | None`, got {type(registry).__name__}"
raise TypeError(msg)

# dependencies
if dependencies is not None and not isinstance(dependencies, Dependencies):
msg = f"dependencies must be `Dependencies | None`, got {type(dependencies).__name__}"
raise TypeError(msg)

# log_level
if not isinstance(log_level, int):
msg = f"log_level must be `int`, got {type(log_level).__name__}"
raise TypeError(msg)

pid = pid or uuid.uuid4().hex

return cls(
Expand Down Expand Up @@ -186,7 +261,19 @@ def register[**P, R](
name: str | None = None,
version: int = 1,
) -> Function[P, R] | Callable[[Callable[Concatenate[Context, P], R]], Function[P, R]]:
if name is not None and not isinstance(name, str):
msg = f"name must be `str | None`, got {type(name).__name__}"
raise TypeError(msg)

if not isinstance(version, int):
msg = f"version must be `int`, got {type(version).__name__}"
raise TypeError(msg)

def wrapper(func: Callable[..., Any]) -> Function[P, R]:
if not callable(func):
msg = "func must be Callable"
raise TypeError(msg)

if isinstance(func, Function):
func = func.func

Expand Down Expand Up @@ -221,6 +308,23 @@ def run[**P, R](
*args: P.args,
**kwargs: P.kwargs,
) -> Handle[R]:
# id
if not isinstance(id, str):
msg = f"id must be `str`, got {type(id).__name__}"
raise TypeError(msg)
# func
if not (callable(func) or isinstance(func, str)):
msg = f"func must be `Callable | str`, got {type(func).__name__}"
raise TypeError(msg)
# tuple
if not isinstance(args, tuple):
msg = f"args must be `tuple`, got {type(args).__name__}"
raise TypeError(args)
# dict
if not isinstance(kwargs, dict):
msg = f"kwargs must be `dict`, got {type(kwargs).__name__}"
raise TypeError(kwargs)

self.start()
future = Future[R]()

Expand Down Expand Up @@ -253,6 +357,23 @@ def rpc[**P, R](
*args: P.args,
**kwargs: P.kwargs,
) -> Handle[R]:
# id
if not isinstance(id, str):
msg = f"id must be `str`, got {type(id).__name__}"
raise TypeError(msg)
# func
if not (callable(func) or isinstance(func, str)):
msg = f"func must be `Callable | str`, got {type(func).__name__}"
raise TypeError(msg)
# tuple
if not isinstance(args, tuple):
msg = f"args must be `tuple`, got {type(args).__name__}"
raise TypeError(args)
# dict
if not isinstance(kwargs, dict):
msg = f"kwargs must be `dict`, got {type(kwargs).__name__}"
raise TypeError(kwargs)

self.start()
future = Future[R]()

Expand All @@ -267,13 +388,23 @@ def rpc[**P, R](
return Handle(future)

def get(self, id: str) -> Handle[Any]:
# id
if not isinstance(id, str):
msg = f"id must be `str`, got {type(id).__name__}"
raise TypeError(msg)

self.start()
future = Future()

self._bridge.get(id, self._opts, future)
return Handle(future)

def set_dependency(self, name: str, obj: Any) -> None:
# name
if not isinstance(name, str):
msg = f"name must be `str`, got {type(name).__name__}"
raise TypeError(msg)

self._dependencies.add(name, obj)


Expand Down Expand Up @@ -313,6 +444,10 @@ def time(self) -> Time:
return self._time

def get_dependency[T](self, key: str, default: T = None) -> Any | T:
if not isinstance(key, str):
msg = f"key must be `str`, got {type(key).__name__}"
raise TypeError(msg)

return self._dependencies.get(key, default)

def lfi[**P, R](
Expand Down Expand Up @@ -424,6 +559,10 @@ def typesafe(self, cmd: LFI | RFI | LFC | RFC | Promise) -> Generator[LFI | RFI
return (yield cmd)

def sleep(self, secs: float) -> RFC[None]:
if not isinstance(secs, int | float):
msg = f"secs must be `float`, got {type(secs).__name__}"
raise TypeError(msg)

return RFC(Sleep(self._next(), secs))

def promise(
Expand All @@ -435,6 +574,22 @@ def promise(
data: Any = None,
tags: dict[str, str] | None = None,
) -> RFI:
if id is not None and not isinstance(id, str):
msg = f"id must be `str | None`, got {type(id).__name__}"
raise TypeError(msg)

if timeout is not None and isinstance(timeout, int | float):
msg = f"timeout must be `float`, got {type(timeout).__name__}"
raise TypeError(msg)

if idempotency_key is not None and not isinstance(idempotency_key, str):
msg = f"idempotency_key must be `str | None`, got {type(idempotency_key).__name__}"
raise TypeError(msg)

if tags is not None and not isinstance(tags, dict):
msg = f"tags must be `dict | None`, got {type(tags).__name__}"
raise TypeError(tags)

default_id = self._next()
id = id or default_id

Expand Down