1
1
from collections .abc import (
2
- Awaitable ,
3
2
Callable ,
4
3
)
5
4
from dataclasses import dataclass
6
5
from enum import Enum
7
- import inspect
8
6
import logging
9
7
from typing import (
10
- Any ,
11
8
NamedTuple ,
9
+ cast ,
12
10
)
13
11
14
12
import trio
15
13
16
- from libp2p .custom_types import (
17
- ValidatorFn ,
14
+ from libp2p .custom_types import AsyncValidatorFn , ValidatorFn
15
+ from libp2p .peer .id import (
16
+ ID ,
17
+ )
18
+
19
+ from .pb import (
20
+ rpc_pb2 ,
18
21
)
19
22
20
23
logger = logging .getLogger ("libp2p.pubsub.validation" )
@@ -32,9 +35,8 @@ class ValidationRequest:
32
35
"""Request for message validation"""
33
36
34
37
validators : list ["TopicValidator" ]
35
- # TODO: Use a more specific type for msg_forwarder
36
- msg_forwarder : Any # peer ID
37
- msg : Any # message object
38
+ msg_forwarder : ID # peer ID
39
+ msg : rpc_pb2 .Message # message object
38
40
result_callback : Callable [[ValidationResult , Exception | None ], None ]
39
41
40
42
@@ -109,8 +111,8 @@ def create_topic_validator(
109
111
async def submit_validation (
110
112
self ,
111
113
validators : list [TopicValidator ],
112
- msg_forwarder : Any ,
113
- msg : Any ,
114
+ msg_forwarder : ID ,
115
+ msg : rpc_pb2 . Message ,
114
116
result_callback : Callable [[ValidationResult , Exception | None ], None ],
115
117
) -> bool :
116
118
"""
@@ -211,7 +213,7 @@ async def _validate_message(self, request: ValidationRequest) -> ValidationResul
211
213
return ValidationResult .ACCEPT
212
214
213
215
async def _validate_async_validators (
214
- self , validators : list [TopicValidator ], msg_forwarder : Any , msg : Any
216
+ self , validators : list [TopicValidator ], msg_forwarder : ID , msg : rpc_pb2 . Message
215
217
) -> ValidationResult :
216
218
"""Handle async validators with proper throttling"""
217
219
if len (validators ) == 1 :
@@ -268,7 +270,7 @@ async def run_validator(validator: TopicValidator, index: int) -> None:
268
270
return ValidationResult .IGNORE
269
271
270
272
async def _validate_single_async_validator (
271
- self , validator : TopicValidator , msg_forwarder : Any , msg : Any
273
+ self , validator : TopicValidator , msg_forwarder : ID , msg : rpc_pb2 . Message
272
274
) -> ValidationResult :
273
275
"""Validate with a single async validator"""
274
276
# Apply per-topic throttling
@@ -286,20 +288,14 @@ async def _validate_single_async_validator(
286
288
287
289
try :
288
290
# Apply timeout if configured
289
- result : bool | Awaitable [ bool ]
291
+ result : bool
290
292
if validator .timeout :
291
293
with trio .fail_after (validator .timeout ):
292
- func = validator .validator
293
- if inspect .iscoroutinefunction (func ):
294
- result = await func (msg_forwarder , msg )
295
- else :
296
- result = func (msg_forwarder , msg )
297
- else :
298
- func = validator .validator
299
- if inspect .iscoroutinefunction (func ):
294
+ func = cast (AsyncValidatorFn , validator .validator )
300
295
result = await func (msg_forwarder , msg )
301
- else :
302
- result = func (msg_forwarder , msg )
296
+ else :
297
+ func = cast (AsyncValidatorFn , validator .validator )
298
+ result = await func (msg_forwarder , msg )
303
299
304
300
return ValidationResult .ACCEPT if result else ValidationResult .REJECT
305
301
0 commit comments