Skip to content

fix: patch broker within testbroker context only #1619

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .codespell-whitelist.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
dependant
unsecure
socio-economic
socio-economic
7 changes: 6 additions & 1 deletion .github/workflows/docs_update-references.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ jobs:
cache-dependency-path: pyproject.toml
- name: Install Dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: pip install -e ".[dev]"
shell: bash
# should install with `-e`
run: |
set -ux
python -m pip install uv
uv pip install --system -e ".[dev]"
- name: Run build docs
run: bash scripts/build-docs.sh
- name: Commit
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pr_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
run: |
set -ux
python -m pip install uv
uv pip install --system -e ".[lint]"
uv pip install --system ".[lint]"

- name: Run ruff
shell: bash
Expand Down
4 changes: 2 additions & 2 deletions .secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@
"filename": "docs/docs/en/release.md",
"hashed_secret": "35675e68f4b5af7b995d9205ad0fc43842f16450",
"is_verified": false,
"line_number": 1325,
"line_number": 1423,
"is_secret": false
}
],
Expand Down Expand Up @@ -163,5 +163,5 @@
}
]
},
"generated_at": "2024-06-10T09:56:52Z"
"generated_at": "2024-07-23T21:38:30Z"
}
2 changes: 1 addition & 1 deletion docs/docs/en/kafka/Subscriber/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ async def base_handler(
level: str = Path(),
):
...
```
```
27 changes: 27 additions & 0 deletions faststream/nats/subscriber/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,9 @@ async def _create_subscription( # type: ignore[override]
connection: "Client",
) -> None:
"""Create NATS subscription and start consume task."""
if self.subscription:
return

self.subscription = await connection.subscribe(
subject=self.clear_subject,
queue=self.queue,
Expand Down Expand Up @@ -495,6 +498,9 @@ async def _create_subscription( # type: ignore[override]
connection: "Client",
) -> None:
"""Create NATS subscription and start consume task."""
if self.subscription:
return

self.start_consume_task()

self.subscription = await connection.subscribe(
Expand Down Expand Up @@ -576,6 +582,9 @@ async def _create_subscription( # type: ignore[override]
connection: "JetStreamContext",
) -> None:
"""Create NATS subscription and start consume task."""
if self.subscription:
return

self.subscription = await connection.subscribe(
subject=self.clear_subject,
queue=self.queue,
Expand Down Expand Up @@ -636,6 +645,9 @@ async def _create_subscription( # type: ignore[override]
connection: "JetStreamContext",
) -> None:
"""Create NATS subscription and start consume task."""
if self.subscription:
return

self.start_consume_task()

self.subscription = await connection.subscribe(
Expand Down Expand Up @@ -698,6 +710,9 @@ async def _create_subscription( # type: ignore[override]
connection: "JetStreamContext",
) -> None:
"""Create NATS subscription and start consume task."""
if self.subscription:
return

self.subscription = await connection.pull_subscribe(
subject=self.clear_subject,
config=self.config,
Expand Down Expand Up @@ -775,6 +790,9 @@ async def _create_subscription( # type: ignore[override]
connection: "JetStreamContext",
) -> None:
"""Create NATS subscription and start consume task."""
if self.subscription:
return

self.start_consume_task()

self.subscription = await connection.pull_subscribe(
Expand Down Expand Up @@ -841,6 +859,9 @@ async def _create_subscription( # type: ignore[override]
connection: "JetStreamContext",
) -> None:
"""Create NATS subscription and start consume task."""
if self.subscription:
return

self.subscription = await connection.pull_subscribe(
subject=self.clear_subject,
config=self.config,
Expand Down Expand Up @@ -905,6 +926,9 @@ async def _create_subscription( # type: ignore[override]
*,
connection: "KVBucketDeclarer",
) -> None:
if self.subscription:
return

bucket = await connection.create_key_value(
bucket=self.kv_watch.name,
declare=self.kv_watch.declare,
Expand Down Expand Up @@ -1012,6 +1036,9 @@ async def _create_subscription( # type: ignore[override]
*,
connection: "OSBucketDeclarer",
) -> None:
if self.subscription:
return

self.bucket = await connection.create_object_store(
bucket=self.subject,
declare=self.obj_watch.declare,
Expand Down
20 changes: 15 additions & 5 deletions faststream/rabbit/testing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import TYPE_CHECKING, Any, Optional, Union
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Generator, Optional, Union
from unittest import mock
from unittest.mock import AsyncMock

import aiormq
Expand Down Expand Up @@ -34,10 +36,18 @@ class TestRabbitBroker(TestBroker[RabbitBroker]):
"""A class to test RabbitMQ brokers."""

@classmethod
def _patch_test_broker(cls, broker: RabbitBroker) -> None:
broker._channel = AsyncMock()
broker.declarer = AsyncMock()
super()._patch_test_broker(broker)
@contextmanager
def _patch_broker(cls, broker: RabbitBroker) -> Generator[None, None, None]:
with mock.patch.object(
broker,
"_channel",
new_callable=AsyncMock,
), mock.patch.object(
broker,
"declarer",
new_callable=AsyncMock,
), super()._patch_broker(broker):
yield

@staticmethod
async def _fake_connect(broker: RabbitBroker, *args: Any, **kwargs: Any) -> None:
Expand Down
17 changes: 16 additions & 1 deletion faststream/redis/subscriber/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,15 @@ async def start( # type: ignore[override]
self,
*args: Any,
) -> None:
if self.task:
return

await super().start()

start_signal = anyio.Event()
self.task = asyncio.create_task(self._consume(*args, start_signal=start_signal))
self.task = asyncio.create_task(
self._consume(*args, start_signal=start_signal)
)

with anyio.fail_after(3.0):
await start_signal.wait()
Expand Down Expand Up @@ -253,6 +258,9 @@ def get_log_context(

@override
async def start(self) -> None:
if self.subscription:
return

assert self._client, "You should setup subscriber at first." # nosec B101

self.subscription = psub = self._client.pubsub()
Expand Down Expand Up @@ -352,6 +360,9 @@ async def _consume( # type: ignore[override]

@override
async def start(self) -> None:
if self.task:
return

assert self._client, "You should setup subscriber at first." # nosec B101
await super().start(self._client)

Expand Down Expand Up @@ -512,7 +523,11 @@ def get_log_context(

@override
async def start(self) -> None:
if self.task:
return

assert self._client, "You should setup subscriber at first." # nosec B101

client = self._client

self.extra_watcher_options.update(
Expand Down
55 changes: 38 additions & 17 deletions faststream/testing/broker.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,35 @@
import warnings
from abc import abstractmethod
from contextlib import asynccontextmanager
from contextlib import asynccontextmanager, contextmanager
from functools import partial
from types import MethodType
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Generator,
Generic,
Optional,
Tuple,
Type,
TypeVar,
)
from unittest.mock import AsyncMock, MagicMock
from unittest import mock
from unittest.mock import MagicMock

from faststream.broker.core.usecase import BrokerUsecase
from faststream.broker.message import StreamMessage, decode_message, encode_message
from faststream.broker.middlewares.logging import CriticalLogMiddleware
from faststream.broker.wrapper.call import HandlerCallWrapper
from faststream.testing.app import TestApp
from faststream.utils.ast import is_contains_context_name
from faststream.utils.functions import timeout_scope
from faststream.utils.functions import sync_fake_context, timeout_scope

if TYPE_CHECKING:
from types import TracebackType

from faststream.broker.subscriber.proto import SubscriberProto
from faststream.broker.types import BrokerMiddleware


Broker = TypeVar("Broker", bound=BrokerUsecase[Any, Any])


Expand Down Expand Up @@ -113,22 +113,43 @@ async def __aexit__(self, *args: Any) -> None:
async def _create_ctx(self) -> AsyncGenerator[Broker, None]:
if self.with_real:
self._fake_start(self.broker)
context = sync_fake_context()
else:
self._patch_test_broker(self.broker)
context = self._patch_broker(self.broker)

async with self.broker:
try:
if not self.connect_only:
await self.broker.start()
yield self.broker
finally:
self._fake_close(self.broker)
with context:
async with self.broker:
try:
if not self.connect_only:
await self.broker.start()
yield self.broker
finally:
self._fake_close(self.broker)

@classmethod
def _patch_test_broker(cls, broker: Broker) -> None:
broker.start = AsyncMock(wraps=partial(cls._fake_start, broker)) # type: ignore[method-assign]
broker._connect = MethodType(cls._fake_connect, broker) # type: ignore[method-assign]
broker.close = AsyncMock() # type: ignore[method-assign]
@contextmanager
def _patch_broker(cls, broker: Broker) -> Generator[None, None, None]:
with mock.patch.object(
broker,
"start",
wraps=partial(cls._fake_start, broker),
), mock.patch.object(
broker,
"_connect",
wraps=partial(cls._fake_connect, broker),
), mock.patch.object(
broker,
"close",
), mock.patch.object(
broker,
"_connection",
new=None,
), mock.patch.object(
broker,
"_producer",
new=None,
):
yield

@classmethod
def _fake_start(cls, broker: Broker, *args: Any, **kwargs: Any) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/brokers/base/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,5 @@ async def test_ping_timeout(self, settings):
kwargs = self.get_broker_args(settings)
broker = self.broker("wrong_url")
await broker.connect(**kwargs)
assert not await broker.ping(timeout=0.00001)
assert not await broker.ping(timeout=1e-24)
await broker.close()
Loading
Loading