Skip to content

Commit e122ef8

Browse files
refactor: specify types for msg_forwarder and msg in ValidationRequest and related methods
Signed-off-by: varun-r-mallya <[email protected]>
1 parent 8d1e5ff commit e122ef8

File tree

1 file changed

+19
-23
lines changed

1 file changed

+19
-23
lines changed

libp2p/pubsub/validation_throttler.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
from collections.abc import (
2-
Awaitable,
32
Callable,
43
)
54
from dataclasses import dataclass
65
from enum import Enum
7-
import inspect
86
import logging
97
from typing import (
10-
Any,
118
NamedTuple,
9+
cast,
1210
)
1311

1412
import trio
1513

16-
from libp2p.custom_types import (
17-
ValidatorFn,
14+
from libp2p.custom_types import AsyncValidatorFn, ValidatorFn
15+
from libp2p.peer.id import (
16+
ID,
17+
)
18+
19+
from .pb import (
20+
rpc_pb2,
1821
)
1922

2023
logger = logging.getLogger("libp2p.pubsub.validation")
@@ -32,9 +35,8 @@ class ValidationRequest:
3235
"""Request for message validation"""
3336

3437
validators: list["TopicValidator"]
35-
# TODO: Use a more specific type for msg_forwarder
36-
msg_forwarder: Any # peer ID
37-
msg: Any # message object
38+
msg_forwarder: ID # peer ID
39+
msg: rpc_pb2.Message # message object
3840
result_callback: Callable[[ValidationResult, Exception | None], None]
3941

4042

@@ -109,8 +111,8 @@ def create_topic_validator(
109111
async def submit_validation(
110112
self,
111113
validators: list[TopicValidator],
112-
msg_forwarder: Any,
113-
msg: Any,
114+
msg_forwarder: ID,
115+
msg: rpc_pb2.Message,
114116
result_callback: Callable[[ValidationResult, Exception | None], None],
115117
) -> bool:
116118
"""
@@ -211,7 +213,7 @@ async def _validate_message(self, request: ValidationRequest) -> ValidationResul
211213
return ValidationResult.ACCEPT
212214

213215
async def _validate_async_validators(
214-
self, validators: list[TopicValidator], msg_forwarder: Any, msg: Any
216+
self, validators: list[TopicValidator], msg_forwarder: ID, msg: rpc_pb2.Message
215217
) -> ValidationResult:
216218
"""Handle async validators with proper throttling"""
217219
if len(validators) == 1:
@@ -268,7 +270,7 @@ async def run_validator(validator: TopicValidator, index: int) -> None:
268270
return ValidationResult.IGNORE
269271

270272
async def _validate_single_async_validator(
271-
self, validator: TopicValidator, msg_forwarder: Any, msg: Any
273+
self, validator: TopicValidator, msg_forwarder: ID, msg: rpc_pb2.Message
272274
) -> ValidationResult:
273275
"""Validate with a single async validator"""
274276
# Apply per-topic throttling
@@ -286,20 +288,14 @@ async def _validate_single_async_validator(
286288

287289
try:
288290
# Apply timeout if configured
289-
result: bool | Awaitable[bool]
291+
result: bool
290292
if validator.timeout:
291293
with trio.fail_after(validator.timeout):
292-
func = validator.validator
293-
if inspect.iscoroutinefunction(func):
294-
result = await func(msg_forwarder, msg)
295-
else:
296-
result = func(msg_forwarder, msg)
297-
else:
298-
func = validator.validator
299-
if inspect.iscoroutinefunction(func):
294+
func = cast(AsyncValidatorFn, validator.validator)
300295
result = await func(msg_forwarder, msg)
301-
else:
302-
result = func(msg_forwarder, msg)
296+
else:
297+
func = cast(AsyncValidatorFn, validator.validator)
298+
result = await func(msg_forwarder, msg)
303299

304300
return ValidationResult.ACCEPT if result else ValidationResult.REJECT
305301

0 commit comments

Comments
 (0)