Skip to content

Add group and remove unicast/anycast from resonate constructor #198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 24, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion resonate/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def listen(self, cmd: Listen, futures: tuple[Future, Future]) -> None:

def start(self) -> None:
if not self._messages_thread.is_alive():
self._message_src.start(MesgQueueAdapter(self._mq), self._pid)
self._message_src.start(MesgQueueAdapter(self._mq))
self._messages_thread.start()

if not self._bridge_thread.is_alive():
Expand Down
69 changes: 43 additions & 26 deletions resonate/message_sources/poller.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from resonate.models.message import InvokeMesg, NotifyMesg, ResumeMesg

if TYPE_CHECKING:
from requests.models import Response

from resonate.models.encoder import Encoder
from resonate.models.message import Mesg
from resonate.models.message_source import MessageQ
Expand All @@ -20,27 +22,41 @@
class Poller:
def __init__(
self,
group: str,
id: str,
url: str | None = None,
group: str = "default",
timeout: int | None = None,
encoder: Encoder[Any, str] | None = None,
) -> None:
self._group = group
self._id = id
self._url = url or os.getenv("RESONATE_MSG_SRC_URL", "http://localhost:8002")
self.group = group
self._timeout = timeout
self._encoder = encoder or JsonEncoder()
self._thread: Thread | None = None
self._timeout = timeout

def start(self, cq: MessageQ, pid: str) -> None:
@property
def url(self) -> str:
return f"{self._url}/{self._group}/{self._id}"

@property
def unicast(self) -> str:
return f"poll://{self._group}/{self._id}"

@property
def anycast(self) -> str:
return f"poll://{self._group}/{self._id}"

def start(self, mq: MessageQ) -> None:
if self._thread is not None:
return

self._thread = Thread(name="poller-thread", target=self.loop, args=(cq, pid), daemon=True)
self._thread = Thread(name="poller-thread", target=self.loop, args=(mq,), daemon=True)
self._thread.start()

def stop(self) -> None:
# TODO(avillega): Couldn't come up with a nice way of stoping this thread
# iter_lines is blocking and request.get is also blockig, this makes it so
# iter_lines is blocking and request.get is also blocking, this makes it so
# the only way to stop it is waiting for a timeout on the request itself
# which could never happen.
#
Expand All @@ -49,40 +65,41 @@ def stop(self) -> None:
# could be that it never gets called when running tests for example.
pass

def url(self, pid: str) -> str:
return f"{self._url}/{self.group}/{pid}"

def loop(self, cq: MessageQ, pid: str) -> None:
def loop(self, mq: MessageQ) -> None:
while True:
try:
url = self.url(pid)
with requests.get(url, stream=True, timeout=self._timeout) as res:
with requests.get(self.url, stream=True, timeout=self._timeout) as res:
res.raise_for_status()

for line in res.iter_lines(chunk_size=None, decode_unicode=True):
assert isinstance(line, str), "line must be a string"
if msg := self._process_line(line):
cq.enqueue(msg)
if msg := self._step(res):
mq.enqueue(msg)

except requests.exceptions.Timeout:
logger.warning("Polling request timed out for group %s. Retrying...", self.group)
logger.warning("Polling request timed out for group %s. Retrying...", self._group)
time.sleep(1)
continue
except requests.exceptions.RequestException as e:
logger.warning("Polling network error for group %s: %s. Retrying after delay...", self.group, str(e))
logger.warning("Polling network error for group %s: %s. Retrying after delay...", self._group, str(e))
time.sleep(1)
continue
except Exception as e:
logger.warning("Unexpected error in poller loop for group %s: %s", self.group, e)
logger.warning("Unexpected error in poller loop for group %s: %s", self._group, e)
break

def step(self, cq: MessageQ, pid: str) -> None:
with requests.get(self.url(pid), stream=True, timeout=self._timeout) as res:
for line in res.iter_lines(chunk_size=None, decode_unicode=True):
assert isinstance(line, str), "line must be a string"
if msg := self._process_line(line):
cq.enqueue(msg)
break
def step(self) -> list[Mesg]:
with requests.get(self.url, stream=True, timeout=self._timeout) as res:
msg = self._step(res)
assert msg

return [msg]

def _step(self, res: Response) -> Mesg | None:
for line in res.iter_lines(chunk_size=None, decode_unicode=True):
assert isinstance(line, str), "line must be a string"
if msg := self._process_line(line):
return msg

return None

def _process_line(self, line: str) -> Mesg | None:
if not line:
Expand Down
12 changes: 10 additions & 2 deletions resonate/models/message_source.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Protocol
from typing import TYPE_CHECKING, Any, Protocol

if TYPE_CHECKING:
from resonate.models.message import Mesg


class MessageSource(Protocol):
def start(self, cq: MessageQ, pid: str) -> None: ...
def __init__(self, group: str, id: str, *args: Any, **kwargs: Any) -> None: ...

@property
def unicast(self) -> str: ...
@property
def anycast(self) -> str: ...

def start(self, mq: MessageQ) -> None: ...
def stop(self) -> None: ...
def step(self) -> list[Mesg]: ...


class MessageQ(Protocol):
Expand Down
35 changes: 20 additions & 15 deletions resonate/resonate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections.abc import Callable
from concurrent.futures import Future
from inspect import isgeneratorfunction
from types import NoneType
from typing import TYPE_CHECKING, Any, Concatenate, overload

from resonate.bridge import Bridge
Expand All @@ -30,53 +31,57 @@
from resonate.models.encoder import Encoder
from resonate.models.message_source import MessageSource
from resonate.models.retry_policy import RetryPolicy
from resonate.models.store import Store
from resonate.models.store import PromiseStore, Store, TaskStore


class Resonate:
def __init__(
self,
*,
url: str | None = None,
group: str = "default",
pid: str | None = None,
ttl: int = 10,
opts: Options | None = None,
anycast: str | None = None,
unicast: str | None = None,
store: Store | None = None,
message_source: MessageSource | None = None,
encoder: Encoder[Any, str | None] | None = None,
registry: Registry | None = None,
dependencies: Dependencies | None = None,
) -> None:
assert not isinstance(store, (NoneType, LocalStore)) or message_source is None

self._started = False

self._group = group
self._pid = pid or uuid.uuid4().hex
self._opts = opts or Options()

self._registry = registry or Registry()
self._dependencies = dependencies or Dependencies()

self.store = store or LocalStore() if url is None else RemoteStore(url)
assert not isinstance(self.store, LocalStore) or message_source is None

message_source = message_source or self.store.as_msg_source() if isinstance(self.store, LocalStore) else Poller()

# TODO(dfarr): grab default addresses from message source
self._unicast = unicast or f"poll://default/{self._pid}"
self._anycast = anycast or f"poll://default/{self._pid}"
self._store = store or LocalStore() if url is None else RemoteStore(url)
self._message_source = message_source or self._store.message_source(self._group, self._pid) if isinstance(self._store, LocalStore) else Poller(self._group, self._pid)

self._bridge = Bridge(
ctx=lambda id, info: Context(id, info, self._opts, self._registry, self._dependencies),
pid=self._pid,
ttl=ttl,
anycast=self._anycast,
unicast=self._unicast,
store=self.store,
message_source=message_source,
store=self._store,
message_source=self._message_source,
anycast=self._message_source.anycast,
unicast=self._message_source.unicast,
registry=self._registry,
)

@property
def promises(self) -> PromiseStore:
return self._store.promises

@property
def tasks(self) -> TaskStore:
return self._store.tasks

def start(self) -> None:
if not self._started:
self._bridge.start()
Expand Down
44 changes: 30 additions & 14 deletions resonate/stores/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,22 @@

class LocalStore:
def __init__(self, encoder: Encoder[Any, str | None] | None = None, clock: Clock | None = None) -> None:
self.encoder = encoder or ChainEncoder(JsonEncoder(), Base64Encoder())

self._promises: dict[str, DurablePromiseRecord] = {}
self._tasks: dict[str, TaskRecord] = {}
self._routers: list[Router] = [TagRouter()]

self._encoder = encoder or ChainEncoder(JsonEncoder(), Base64Encoder())
self._clock = clock or WallClock()

@property
def encoder(self) -> Encoder[Any, str | None]:
return self._encoder

@property
def promises(self) -> LocalPromiseStore:
return LocalPromiseStore(
self,
self.encoder,
self._encoder,
self._promises,
self._tasks,
self._routers,
Expand All @@ -46,7 +50,7 @@ def promises(self) -> LocalPromiseStore:
def tasks(self) -> LocalTaskStore:
return LocalTaskStore(
self,
self.encoder,
self._encoder,
self._promises,
self._tasks,
self._clock,
Expand All @@ -55,8 +59,8 @@ def tasks(self) -> LocalTaskStore:
def add_router(self, router: Router) -> None:
self._routers.append(router)

def as_msg_source(self) -> MessageSource:
return _LocalMessageSource(self)
def message_source(self, group: str, id: str) -> MessageSource:
return _LocalMessageSource(group, id, self)

def step(self) -> list[tuple[str, Mesg]]:
messages: list[tuple[str, Mesg]] = []
Expand Down Expand Up @@ -766,19 +770,29 @@ def ikey_match(left: str | None, right: str | None) -> bool:


class _LocalMessageSource:
def __init__(self, local_store: LocalStore) -> None:
self._store = local_store
def __init__(self, group: str, id: str, store: LocalStore) -> None:
self._group = group
self._id = id
self._store = store
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()

def start(self, cq: MessageQ, pid: str) -> None:
@property
def unicast(self) -> str:
return f"poll://{self._group}/{self._id}"

@property
def anycast(self) -> str:
return f"poll://{self._group}/{self._id}"

def start(self, mq: MessageQ) -> None:
if self._thread is not None:
return

self._stop_event.clear()
self._thread = threading.Thread(
target=self._loop,
args=(cq,),
args=(mq,),
name="local_msg_source",
daemon=True,
)
Expand All @@ -791,9 +805,11 @@ def stop(self) -> None:
self._thread = None
self._stop_event.clear()

def _loop(self, cq: MessageQ) -> None:
def _loop(self, mq: MessageQ) -> None:
while not self._stop_event.is_set():
msgs = self._store.step()
for _, msg in msgs:
cq.enqueue(msg)
for msg in self.step():
mq.enqueue(msg)
self._stop_event.wait(0.1)

def step(self) -> list[Mesg]:
return [m for _, m in self._store.step()]
38 changes: 38 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
from __future__ import annotations

import logging
import os
import random
import sys
from typing import TYPE_CHECKING

import pytest

from resonate.message_sources.poller import Poller
from resonate.stores.local import LocalStore
from resonate.stores.remote import RemoteStore

if TYPE_CHECKING:
from resonate.models.message_source import MessageSource
from resonate.models.store import Store


def pytest_configure() -> None:
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
Expand All @@ -16,6 +26,9 @@ def pytest_addoption(parser: pytest.Parser) -> None:
parser.addoption("--steps", action="store")


# DST fixtures


@pytest.fixture
def seed(request: pytest.FixtureRequest) -> str:
seed = request.config.getoption("--seed")
Expand All @@ -37,3 +50,28 @@ def steps(request: pytest.FixtureRequest) -> int:
pass

return 10000


# Store fixtures

stores: list[Store] = [LocalStore()]

if "RESONATE_STORE_URL" in os.environ:
stores.append(RemoteStore(os.environ["RESONATE_STORE_URL"]))


@pytest.fixture(scope="module", params=stores)
def store(request: pytest.FixtureRequest) -> Store:
return request.param


@pytest.fixture(scope="module")
def message_source(store: Store) -> MessageSource:
match store:
case LocalStore():
return store.message_source(group="default", id="test")
case RemoteStore():
return Poller(group="default", id="test", timeout=2)
case _:
msg = "Unknown store type"
raise ValueError(msg)
Loading
Loading