
Commit dd23b83

feat (#1447): add StreamMessage.batch_headers attr to provide access to whole batch messages headers
1 parent 9545f9c commit dd23b83

File tree: 17 files changed, +297 -74 lines


faststream/broker/message.py

Lines changed: 2 additions & 0 deletions

@@ -6,6 +6,7 @@
     TYPE_CHECKING,
     Any,
     Generic,
+    List,
     Optional,
     Sequence,
     Tuple,
@@ -38,6 +39,7 @@ class StreamMessage(Generic[MsgType]):
 
     body: Union[bytes, Any]
     headers: "AnyDict" = field(default_factory=dict)
+    batch_headers: List["AnyDict"] = field(default_factory=list)
     path: "AnyDict" = field(default_factory=dict)
 
     content_type: Optional[str] = None
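
For context, a minimal sketch of how the new attribute could be read from a batch handler. This is an illustration, not part of the commit: the broker URL, topic name, and header key are invented, and `Context("message")` is used here to reach the parsed `StreamMessage`.

```python
from faststream import Context, FastStream
from faststream.kafka import KafkaBroker

broker = KafkaBroker("localhost:9092")  # assumed local broker
app = FastStream(broker)


@broker.subscriber("demo-topic", batch=True)  # hypothetical topic name
async def handle(
    body: list[str],
    message=Context("message"),  # StreamMessage built by the batch parser
):
    # `headers` still carries the first message's headers (unchanged behavior),
    # while `batch_headers` now exposes one dict per message in the batch.
    for headers in message.batch_headers:
        print(headers.get("trace-id"))  # hypothetical header key
```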

faststream/broker/publisher/proto.py

Lines changed: 1 addition & 2 deletions

@@ -57,8 +57,7 @@ class PublisherProto(
     _producer: Optional["ProducerProto"]
 
     @abstractmethod
-    def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None:
-        ...
+    def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None: ...
 
     @staticmethod
     @abstractmethod

faststream/broker/subscriber/proto.py

Lines changed: 1 addition & 2 deletions

@@ -36,8 +36,7 @@ class SubscriberProto(
     _producer: Optional["ProducerProto"]
 
     @abstractmethod
-    def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None:
-        ...
+    def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None: ...
 
     @staticmethod
     @abstractmethod

faststream/cli/docs/app.py

Lines changed: 2 additions & 4 deletions

@@ -45,8 +45,7 @@ def serve(
         ),
     ),
     is_factory: bool = typer.Option(
-        False,
-        "--factory", help="Treat APP as an application factory"
+        False, "--factory", help="Treat APP as an application factory"
     ),
 ) -> None:
     """Serve project AsyncAPI schema."""
@@ -109,8 +108,7 @@ def gen(
         ),
     ),
     is_factory: bool = typer.Option(
-        False,
-        "--factory", help="Treat APP as an application factory"
+        False, "--factory", help="Treat APP as an application factory"
     ),
 ) -> None:
     """Generate project AsyncAPI schema."""

faststream/cli/main.py

Lines changed: 1 addition & 2 deletions

@@ -210,8 +210,7 @@ def publish(
     message: str = typer.Argument(..., help="Message to be published"),
     rpc: bool = typer.Option(False, help="Enable RPC mode and system output"),
     is_factory: bool = typer.Option(
-        False,
-        "--factory", help="Treat APP as an application factory"
+        False, "--factory", help="Treat APP as an application factory"
     ),
 ) -> None:
     """Publish a message using the specified broker in a FastStream application.

faststream/kafka/parser.py

Lines changed: 14 additions & 3 deletions

@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 from faststream.broker.message import decode_message, gen_cor_id
 from faststream.kafka.message import FAKE_CONSUMER, KafkaMessage
@@ -39,13 +39,24 @@ async def parse_message_batch(
         message: Tuple["ConsumerRecord", ...],
     ) -> "StreamMessage[Tuple[ConsumerRecord, ...]]":
         """Parses a batch of messages from a Kafka consumer."""
+        body: List[Any] = []
+        batch_headers: List[Dict[str, str]] = []
+
         first = message[0]
         last = message[-1]
-        headers = {i: j.decode() for i, j in first.headers}
+
+        for m in message:
+            body.append(m.value)
+            batch_headers.append({i: j.decode() for i, j in m.headers})
+
+        headers = next(iter(batch_headers), {})
+
         handler: Optional["LogicSubscriber[Any]"] = context.get_local("handler_")
+
         return KafkaMessage(
-            body=[m.value for m in message],
+            body=body,
             headers=headers,
+            batch_headers=batch_headers,
            reply_to=headers.get("reply_to", ""),
            content_type=headers.get("content-type"),
            message_id=f"{first.offset}-{last.offset}-{first.timestamp}",
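
The per-record header decoding is the interesting part of this hunk: aiokafka exposes `ConsumerRecord.headers` as a sequence of `(key, bytes)` pairs, so each record yields its own decoded dict. A standalone sketch of that transformation (the sample header values are invented):

```python
# Each element mimics one ConsumerRecord.headers value: (key, raw-bytes) pairs.
records_headers = [
    [("content-type", b"application/json"), ("trace-id", b"abc-123")],
    [("content-type", b"text/plain")],
]

# One decoded header dict per record, matching what batch_headers collects.
batch_headers = [
    {key: value.decode() for key, value in headers}
    for headers in records_headers
]

print(batch_headers)
# [{'content-type': 'application/json', 'trace-id': 'abc-123'},
#  {'content-type': 'text/plain'}]
```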

faststream/nats/parser.py

Lines changed: 15 additions & 7 deletions

@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
 
 from faststream.broker.message import StreamMessage, decode_message, gen_cor_id
 from faststream.nats.message import NatsBatchMessage, NatsMessage
@@ -102,19 +102,27 @@ async def parse_batch(
         self,
         message: List["Msg"],
     ) -> "StreamMessage[List[Msg]]":
-        if first_msg := next(iter(message), None):
-            path = self.get_path(first_msg.subject)
-            headers = first_msg.headers
+        body: List[bytes] = []
+        batch_headers: List[Dict[str, str]] = []
+
+        if message:
+            path = self.get_path(message[0].subject)
+
+            for m in message:
+                batch_headers.append(m.headers or {})
+                body.append(m.data)
 
         else:
             path = None
-            headers = None
+
+        headers = next(iter(batch_headers), {})
 
         return NatsBatchMessage(
             raw_message=message,
-            body=[m.data for m in message],
+            body=body,
             path=path or {},
-            headers=headers or {},
+            headers=headers,
+            batch_headers=batch_headers,
         )
 
     async def decode_batch(
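
Both the Kafka and NATS batch parsers now fall back to an empty dict for the top-level `headers` when the batch is empty, via `next(iter(batch_headers), {})`. A tiny illustration of that idiom:

```python
batch_headers: list[dict[str, str]] = []

# next() with a default never raises StopIteration, so an empty batch
# yields {} instead of the previous None.
headers = next(iter(batch_headers), {})
assert headers == {}

batch_headers = [{"reply_to": "inbox.1"}, {"reply_to": "inbox.2"}]
headers = next(iter(batch_headers), {})
assert headers == {"reply_to": "inbox.1"}  # first message's headers win
```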

faststream/rabbit/broker/broker.py

Lines changed: 2 additions & 2 deletions

@@ -118,7 +118,7 @@ def __init__(
             Doc(
                 "raise an :class:`aio_pika.exceptions.DeliveryError`"
                 "when mandatory message will be returned"
-            )
+            ),
         ] = False,
         # broker args
         max_consumers: Annotated[
@@ -345,7 +345,7 @@ async def connect(  # type: ignore[override]
             Doc(
                 "raise an :class:`aio_pika.exceptions.DeliveryError`"
                 "when mandatory message will be returned"
-            )
+            ),
         ] = Parameter.empty,
     ) -> "RobustConnection":
         """Connect broker object to RabbitMQ.

faststream/rabbit/fastapi/router.py

Lines changed: 1 addition & 1 deletion

@@ -114,7 +114,7 @@ def __init__(
             Doc(
                 "raise an :class:`aio_pika.exceptions.DeliveryError`"
                 "when mandatory message will be returned"
-            )
+            ),
         ] = False,
         # broker args
         max_consumers: Annotated[

faststream/redis/parser.py

Lines changed: 48 additions & 17 deletions

@@ -1,7 +1,7 @@
 from typing import (
     TYPE_CHECKING,
     Any,
-    Dict,
+    List,
     Mapping,
     Optional,
     Sequence,
@@ -136,22 +136,27 @@ async def parse_message(
         self,
         message: Mapping[str, Any],
     ) -> "StreamMessage[Mapping[str, Any]]":
-        data, headers = self._parse_data(message)
+        data, headers, batch_headers = self._parse_data(message)
+
         id_ = gen_cor_id()
+
         return self.msg_class(
             raw_message=message,
             body=data,
             path=self.get_path(message),
             headers=headers,
+            batch_headers=batch_headers,
             reply_to=headers.get("reply_to", ""),
             content_type=headers.get("content-type"),
             message_id=headers.get("message_id", id_),
             correlation_id=headers.get("correlation_id", id_),
         )
 
     @staticmethod
-    def _parse_data(message: Mapping[str, Any]) -> Tuple[bytes, "AnyDict"]:
-        return RawMessage.parse(message["data"])
+    def _parse_data(
+        message: Mapping[str, Any],
+    ) -> Tuple[bytes, "AnyDict", List["AnyDict"]]:
+        return (*RawMessage.parse(message["data"]), [])
 
     def get_path(self, message: Mapping[str, Any]) -> "AnyDict":
         if (
@@ -183,42 +188,68 @@ class RedisBatchListParser(SimpleParser):
     msg_class = RedisBatchListMessage
 
     @staticmethod
-    def _parse_data(message: Mapping[str, Any]) -> Tuple[bytes, "AnyDict"]:
-        data = [_decode_batch_body_item(x) for x in message["data"]]
+    def _parse_data(
+        message: Mapping[str, Any],
+    ) -> Tuple[bytes, "AnyDict", List["AnyDict"]]:
+        body: List[Any] = []
+        batch_headers: List["AnyDict"] = []
+
+        for x in message["data"]:
+            msg_data, msg_headers = _decode_batch_body_item(x)
+            body.append(msg_data)
+            batch_headers.append(msg_headers)
+
+        first_msg_headers = next(iter(batch_headers), {})
+
         return (
-            dump_json(i[0] for i in data),
+            dump_json(body),
             {
-                **data[0][1],
-                "content-type": ContentTypes.json,
+                **first_msg_headers,
+                "content-type": ContentTypes.json.value,
             },
+            batch_headers,
         )
 
 
 class RedisStreamParser(SimpleParser):
     msg_class = RedisStreamMessage
 
     @classmethod
-    def _parse_data(cls, message: Mapping[str, Any]) -> Tuple[bytes, "AnyDict"]:
+    def _parse_data(
+        cls, message: Mapping[str, Any]
+    ) -> Tuple[bytes, "AnyDict", List["AnyDict"]]:
         data = message["data"]
-        return RawMessage.parse(data.get(bDATA_KEY) or dump_json(data))
+        return (*RawMessage.parse(data.get(bDATA_KEY) or dump_json(data)), [])
 
 
 class RedisBatchStreamParser(SimpleParser):
     msg_class = RedisBatchStreamMessage
 
     @staticmethod
-    def _parse_data(message: Mapping[str, Any]) -> Tuple[bytes, "AnyDict"]:
-        data = [_decode_batch_body_item(x.get(bDATA_KEY, x)) for x in message["data"]]
+    def _parse_data(
+        message: Mapping[str, Any],
+    ) -> Tuple[bytes, "AnyDict", List["AnyDict"]]:
+        body: List[Any] = []
+        batch_headers: List["AnyDict"] = []
+
+        for x in message["data"]:
+            msg_data, msg_headers = _decode_batch_body_item(x.get(bDATA_KEY, x))
+            body.append(msg_data)
+            batch_headers.append(msg_headers)
+
+        first_msg_headers = next(iter(batch_headers), {})
+
         return (
-            dump_json(i[0] for i in data),
+            dump_json(body),
             {
-                **data[0][1],
-                "content-type": ContentTypes.json,
+                **first_msg_headers,
+                "content-type": ContentTypes.json.value,
             },
+            batch_headers,
         )
 
 
-def _decode_batch_body_item(msg_content: bytes) -> Tuple[Any, Dict[str, str]]:
+def _decode_batch_body_item(msg_content: bytes) -> Tuple[Any, "AnyDict"]:
     msg_body, headers = RawMessage.parse(msg_content)
     try:
         return json_loads(msg_body), headers
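
The non-batch Redis parsers keep returning a two-element parse result and simply pad it with an empty `batch_headers` list via tuple unpacking, e.g. `(*RawMessage.parse(...), [])`. A sketch of that pattern with a stand-in parse function (names here are illustrative, not the library's internals):

```python
from typing import Any, Dict, Tuple


def parse_single(raw: bytes) -> Tuple[bytes, Dict[str, Any]]:
    # Stand-in for RawMessage.parse: returns (body, headers).
    return raw, {"content-type": "text/plain"}


# Widen the return shape: splat the 2-tuple and append an empty batch_headers.
body, headers, batch_headers = (*parse_single(b"hi"), [])

assert body == b"hi"
assert batch_headers == []  # only the batch parsers fill this in
```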

faststream/types.py

Lines changed: 1 addition & 7 deletions

@@ -63,22 +63,16 @@ class StandardDataclass(Protocol):
     """Protocol to check type is dataclass."""
 
     __dataclass_fields__: ClassVar[Dict[str, Any]]
-    __dataclass_params__: ClassVar[Any]
-    __post_init__: ClassVar[Callable[..., None]]
-
-    def __init__(self, *args: object, **kwargs: object) -> None:
-        """Interface method."""
-        ...
 
 
 BaseSendableMessage: TypeAlias = Union[
     JsonDecodable,
     Decimal,
     datetime,
-    None,
     StandardDataclass,
     SendableTable,
     SendableArray,
+    None,
 ]
 
 try:
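
The slimmed-down `StandardDataclass` protocol keeps only `__dataclass_fields__`, which every dataclass defines, so structural typing still works. A quick sketch of the idea (the `Point` class and `field_names` helper are arbitrary examples, not part of the codebase):

```python
from dataclasses import dataclass
from typing import Any, ClassVar, Dict, List, Protocol


class StandardDataclass(Protocol):
    """Protocol to check type is dataclass."""

    __dataclass_fields__: ClassVar[Dict[str, Any]]


@dataclass
class Point:  # arbitrary example dataclass
    x: int
    y: int


def field_names(obj: StandardDataclass) -> List[str]:
    # Any dataclass instance satisfies the protocol structurally.
    return list(obj.__dataclass_fields__)


print(field_names(Point(1, 2)))  # ['x', 'y']
```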

tests/brokers/base/middlewares.py

Lines changed: 5 additions & 1 deletion

@@ -271,7 +271,11 @@ async def handler(m):
         mock.end.assert_called_once()
 
     async def test_add_global_middleware(
-        self, event: asyncio.Event, queue: str, mock: Mock, raw_broker,
+        self,
+        event: asyncio.Event,
+        queue: str,
+        mock: Mock,
+        raw_broker,
     ):
         class mid(BaseMiddleware):  # noqa: N801
             async def on_receive(self):

tests/brokers/base/publish.py

Lines changed: 37 additions & 1 deletion

@@ -1,4 +1,5 @@
 import asyncio
+from dataclasses import asdict, dataclass
 from datetime import datetime
 from typing import Any, ClassVar, Dict, List, Tuple
 from unittest.mock import Mock
@@ -7,7 +8,7 @@
 import pytest
 from pydantic import BaseModel
 
-from faststream._compat import model_to_json
+from faststream._compat import dump_json, model_to_json
 from faststream.annotations import Logger
 from faststream.broker.core.usecase import BrokerUsecase
 
@@ -16,6 +17,11 @@ class SimpleModel(BaseModel):
     r: str
 
 
+@dataclass
+class SimpleDataclass:
+    r: str
+
+
 now = datetime.now()
 
 
@@ -55,6 +61,12 @@ def pub_broker(self, full_broker):
            1.0,
            id="float->float",
        ),
+        pytest.param(
+            1,
+            float,
+            1.0,
+            id="int->float",
+        ),
        pytest.param(
            False,
            bool,
@@ -103,6 +115,30 @@ def pub_broker(self, full_broker):
            SimpleModel(r="hello!"),
            id="dict->model",
        ),
+        pytest.param(
+            dump_json(asdict(SimpleDataclass(r="hello!"))),
+            SimpleDataclass,
+            SimpleDataclass(r="hello!"),
+            id="bytes->dataclass",
+        ),
+        pytest.param(
+            SimpleDataclass(r="hello!"),
+            SimpleDataclass,
+            SimpleDataclass(r="hello!"),
+            id="dataclass->dataclass",
+        ),
+        pytest.param(
+            SimpleDataclass(r="hello!"),
+            dict,
+            {"r": "hello!"},
+            id="dataclass->dict",
+        ),
+        pytest.param(
+            {"r": "hello!"},
+            SimpleDataclass,
+            SimpleDataclass(r="hello!"),
+            id="dict->dataclass",
+        ),
     ),
 )
 async def test_serialize(
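
The new parametrized cases exercise dataclass serialization; conceptually they rely on a JSON round trip like the following (sketched with the stdlib `json` module rather than FastStream's internal helpers):

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class SimpleDataclass:
    r: str


# bytes -> dataclass: the JSON-encoded dict is decoded back into the dataclass.
payload = json.dumps(asdict(SimpleDataclass(r="hello!"))).encode()
restored = SimpleDataclass(**json.loads(payload))
assert restored == SimpleDataclass(r="hello!")

# dataclass -> dict: asdict() gives the plain-dict representation.
assert asdict(SimpleDataclass(r="hello!")) == {"r": "hello!"}
```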
