1
1
import asyncio
2
2
import inspect
3
3
import logging
4
+ import os
4
5
from datetime import datetime
5
6
from typing import Any
6
- import sentry_sdk
7
- import os
8
7
9
- # This import is used to register built-in scorers so they can be deserialized from the DB
10
- import weave .scorers # noqa: F401
8
+ import sentry_sdk
11
9
from confluent_kafka import KafkaError , Message
12
10
from tenacity import (
13
11
before_log ,
16
14
stop_after_attempt ,
17
15
wait_fixed ,
18
16
)
17
+
18
+ # This import is used to register built-in scorers so they can be deserialized from the DB
19
+ import weave .scorers # noqa: F401
19
20
from weave .flow .monitor import Monitor
20
21
from weave .flow .scorer import Scorer , get_scorer_attributes
21
22
from weave .trace .box import box
@@ -56,10 +57,10 @@ def get_trace_server() -> TraceServerInterface:
56
57
57
58
58
59
# This should be cached to avoid hitting ClickHouse for each ended call.
59
- def get_active_monitors (project_id : str ) -> list [ tuple [ Monitor , str ]]:
60
- """
61
- Returns active monitors for a given project.
62
- """
60
+ def get_active_monitors (
61
+ project_id : str ,
62
+ ) -> list [ tuple [ Monitor , InternalObjectRef , str | None ]]:
63
+ """Returns active monitors for a given project."""
63
64
obj_query = tsi .ObjQueryReq (
64
65
project_id = project_id ,
65
66
filter = tsi .ObjectVersionFilter (
@@ -97,9 +98,7 @@ def get_active_monitors(project_id: str) -> list[tuple[Monitor, str]]:
97
98
98
99
99
100
def resolve_scorer_refs (scorer_ref_uris : list [str ], project_id : str ) -> list [Scorer ]:
100
- """
101
- Resolves scorer references to Scorer objects.
102
- """
101
+ """Resolves scorer references to Scorer objects."""
103
102
server = get_trace_server ()
104
103
105
104
scorer_refs = [
@@ -118,7 +117,7 @@ def resolve_scorer_refs(scorer_ref_uris: list[str], project_id: str) -> list[Sco
118
117
for scorer , scorer_ref in zip (scorers , scorer_refs ):
119
118
scorer .__dict__ ["internal_ref" ] = scorer_ref
120
119
121
- return scorers
120
+ return scorers # type: ignore
122
121
123
122
124
123
class CallNotWrittenError (Exception ):
@@ -133,11 +132,11 @@ class CallNotWrittenError(Exception):
133
132
before = before_log (logger , logging .INFO ),
134
133
)
135
134
def get_filtered_call (
136
- op_names : list [str ], query : tsi .Query | None , ended_call : tsi .EndedCallSchemaForInsert
135
+ op_names : list [str ],
136
+ query : tsi .Query | None ,
137
+ ended_call : tsi .EndedCallSchemaForInsert ,
137
138
) -> Call | None :
138
- """
139
- Looks up the call based on a monitor's call filter.
140
- """
139
+ """Looks up the call based on a monitor's call filter."""
141
140
server = get_trace_server ()
142
141
143
142
# We do this two-step querying to circumvent the absence of write->read consistency in ClickHouse.
@@ -169,29 +168,27 @@ def get_filtered_call(
169
168
170
169
if len (calls ) == 0 :
171
170
logger .warning ("No matching calls found for call id %s" , ended_call .id )
172
- return
171
+ return None
173
172
174
173
if len (calls ) > 1 :
175
174
logger .warning ("Multiple calls found for call id %s" , ended_call .id )
176
- return
175
+ return None
177
176
178
177
call = calls [0 ]
179
178
180
179
if not call .ended_at :
181
- return
180
+ return None
182
181
183
182
if call .exception :
184
- return
183
+ return None
185
184
186
185
logger .info ("Found call %s" , call .id )
187
186
188
187
return build_client_call (call )
189
188
190
189
191
190
def build_client_call (server_call : tsi .CallSchema ) -> Call :
192
- """
193
- Converts a server call to a client call.
194
- """
191
+ """Converts a server call to a client call."""
195
192
server = get_trace_server ()
196
193
197
194
return Call (
@@ -217,10 +214,8 @@ async def process_monitor(
217
214
monitor_internal_ref : InternalObjectRef ,
218
215
ended_call : tsi .EndedCallSchemaForInsert ,
219
216
wb_user_id : str ,
220
- ):
221
- """
222
- Actually apply the monitor's scorers for an ended call.
223
- """
217
+ ) -> None :
218
+ """Actually apply the monitor's scorers for an ended call."""
224
219
if (call := get_filtered_call (monitor .op_names , monitor .query , ended_call )) is None :
225
220
return
226
221
@@ -284,9 +279,7 @@ def _do_score_call(scorer: Scorer, call: Call, project_id: str) -> tuple[str, An
284
279
285
280
286
281
def _get_score_call (score_call_id : str , project_id : str ) -> Call :
287
- """
288
- Gets a score call from the DB.
289
- """
282
+ """Gets a score call from the DB."""
290
283
server = get_trace_server ()
291
284
292
285
call_req = tsi .CallsQueryReq (
@@ -304,26 +297,24 @@ async def apply_scorer(
304
297
call : Call ,
305
298
project_id : str ,
306
299
wb_user_id : str ,
307
- ):
308
- """
309
- Actually apply the scorer to the call.
310
- """
300
+ ) -> None :
301
+ """Actually apply the scorer to the call."""
311
302
score_call_id , result = _do_score_call (scorer , call , project_id )
312
303
313
304
score_call = _get_score_call (score_call_id , project_id )
314
305
315
- call_ref = InternalCallRef (project_id = project_id , id = call .id )
316
- score_call_ref = InternalCallRef (project_id = project_id , id = score_call .id )
306
+ call_ref = InternalCallRef (project_id = project_id , id = call .id ) # type: ignore
307
+ score_call_ref = InternalCallRef (project_id = project_id , id = score_call .id ) # type: ignore
317
308
318
- results_json = to_json (result , project_id , None )
309
+ results_json = to_json (result , project_id , None ) # type: ignore
319
310
payload = {"output" : results_json }
320
311
321
312
server = get_trace_server ()
322
313
323
314
feedback_req = FeedbackCreateReq (
324
315
project_id = project_id ,
325
316
weave_ref = call_ref .uri (),
326
- feedback_type = RUNNABLE_FEEDBACK_TYPE_PREFIX + "." + scorer .name ,
317
+ feedback_type = RUNNABLE_FEEDBACK_TYPE_PREFIX + "." + scorer .name , # type: ignore
327
318
payload = payload ,
328
319
runnable_ref = scorer .__dict__ ["internal_ref" ].uri (),
329
320
call_ref = score_call_ref .uri (),
@@ -334,13 +325,13 @@ async def apply_scorer(
334
325
server .feedback_create (feedback_req )
335
326
336
327
337
async def process_ended_call(ended_call: tsi.EndedCallSchemaForInsert) -> None:
    """Apply the monitors returned for the call's project to an ended call.

    Monitors are looked up with ``get_active_monitors`` using the call's
    ``project_id`` and are processed one at a time, in the order returned.
    """
    for monitor, monitor_ref, user_id in get_active_monitors(ended_call.project_id):
        await process_monitor(monitor, monitor_ref, ended_call, user_id)  # type: ignore
344
335
345
336
346
337
def _call_processor_done_callback (
@@ -355,15 +346,15 @@ def _call_processor_done_callback(
355
346
task .result ()
356
347
consumer .commit (msg )
357
348
except Exception as e :
358
- logger .error (f"Error processing message: { e .__class__ .__name__ } { e } " )
349
+ logger .exception (
350
+ f"Error processing message: { e .__class__ .__name__ } { e } " , exc_info = e
351
+ )
359
352
360
353
361
354
async def process_kafka_message (
362
355
msg : Message , consumer : KafkaConsumer , call_processors : set [asyncio .Task ]
363
356
) -> bool :
364
- """
365
- Process a single Kafka message and create a task for it.
366
- """
357
+ """Process a single Kafka message and create a task for it."""
367
358
try :
368
359
ended_call = tsi .EndedCallSchemaForInsert .model_validate_json (
369
360
msg .value ().decode ("utf-8" )
@@ -376,34 +367,30 @@ async def process_kafka_message(
376
367
lambda t : _call_processor_done_callback (t , msg , consumer , call_processors )
377
368
)
378
369
379
- return True
370
+ return True # noqa: TRY300
380
371
except Exception as e :
381
- logger .error ("Error processing message: %s" , e )
372
+ logger .exception ("Error processing message: %s" , e , exc_info = e )
382
373
return False
383
374
384
375
385
376
async def handle_kafka_errors(msg: Message) -> bool:
    """Report whether a polled Kafka message is free of errors.

    Args:
        msg: The message to inspect.

    Returns:
        True when ``msg.error()`` is falsy and the message can be processed;
        False when the message carries an error, which is logged here.
    """
    error = msg.error()
    if not error:
        return True

    if error.code() == KafkaError._PARTITION_EOF:
        # Arguments are handed to the logger lazily; the rendered text is
        # the same as with eager %-formatting.
        logger.error(
            "%% %s [%d] reached end at offset %d\n",
            msg.topic(),
            msg.partition(),
            msg.offset(),
        )
    else:
        # No exception is being handled at this point, so logger.exception()
        # would only append a meaningless "NoneType: None" traceback.
        logger.error("Kafka error: %s", error)

    return False
401
390
402
391
403
392
async def cleanup_tasks (call_processors : set [asyncio .Task ]) -> set [asyncio .Task ]:
404
- """
405
- Clean up completed tasks and wait for pending ones.
406
- """
393
+ """Clean up completed tasks and wait for pending ones."""
407
394
call_processors = {t for t in call_processors if not t .done ()}
408
395
409
396
if call_processors :
@@ -415,12 +402,10 @@ async def cleanup_tasks(call_processors: set[asyncio.Task]) -> set[asyncio.Task]
415
402
return set ()
416
403
417
404
418
- async def run_consumer ():
419
- """
420
- This is the main loop consuming the ended calls from the Kafka topic.
421
- """
405
+ async def run_consumer () -> None :
406
+ """This is the main loop consuming the ended calls from the Kafka topic."""
422
407
consumer = KafkaConsumer .from_env ()
423
- call_processors = set ()
408
+ call_processors : set [ asyncio . Task ] = set ()
424
409
425
410
consumer .subscribe ([CALL_ENDED_TOPIC ])
426
411
logger .info ("Subscribed to %s" , CALL_ENDED_TOPIC )
@@ -442,15 +427,15 @@ async def run_consumer():
442
427
consumer .close ()
443
428
444
429
445
def init_sentry() -> None:
    """Initialize the Sentry SDK.

    The environment tag is read from the ``WEAVE_SENTRY_ENV`` environment
    variable (``"dev"`` when unset) and the release is
    ``weave.version.VERSION``.
    """
    # NOTE(review): the DSN literal looks redacted/garbled in this view of the
    # file; confirm it against the real source before relying on it.
    sentry_env = os.environ.get("WEAVE_SENTRY_ENV", "dev")
    sentry_sdk.init(
        dsn="https://[email protected] /4509210797277185",
        environment=sentry_env,
        release=weave.version.VERSION,
    )
451
436
452
437
453
async def main() -> None:
    """Entry point: set up Sentry, then run the consumer."""
    # Sentry is initialized before any consuming starts.
    init_sentry()

    await run_consumer()
0 commit comments