Skip to content

Commit 9c05056

Browse files
committed
linting
1 parent 5ce1374 commit 9c05056

File tree

7 files changed

+58
-78
lines changed

7 files changed

+58
-78
lines changed

weave/flow/scorer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from weave.trace.isinstance import weave_isinstance
1515
from weave.trace.op import Op, OpCallError, as_op, is_op
1616
from weave.trace.op_caller import async_call_op
17-
from weave.trace.weave_client import Call, sanitize_object_name
1817
from weave.trace.vals import WeaveObject
18+
from weave.trace.weave_client import Call, sanitize_object_name
1919

2020

2121
class Scorer(Object):

weave/scorers/json_scorer.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import json
22
from typing import Any
3-
from weave.trace.objectify import register_object
4-
from weave.flow.scorer import BuiltInScorer
3+
54
import weave
5+
from weave.flow.scorer import BuiltInScorer
6+
from weave.trace.objectify import register_object
7+
68

79
@register_object
810
class ValidJSONScorer(BuiltInScorer):

weave/trace/weave_client.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import dataclasses
4-
from datetime import datetime, timezone
54
import json
65
import logging
76
import os
@@ -11,6 +10,7 @@
1110
import time
1211
from collections.abc import Iterator, Sequence
1312
from concurrent.futures import Future
13+
from datetime import datetime, timezone
1414
from functools import lru_cache
1515
from typing import (
1616
TYPE_CHECKING,
@@ -1471,9 +1471,7 @@ def add_cost(
14711471
llm_id: str,
14721472
prompt_token_cost: float,
14731473
completion_token_cost: float,
1474-
effective_date: datetime | None = datetime.now(
1475-
timezone.utc
1476-
),
1474+
effective_date: datetime | None = datetime.now(timezone.utc),
14771475
prompt_token_cost_unit: str | None = "USD",
14781476
completion_token_cost_unit: str | None = "USD",
14791477
provider_id: str | None = "default",

weave/trace_server/clickhouse_trace_server_batched.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
)
9494
from weave.trace_server.file_storage_uris import FileStorageURI
9595
from weave.trace_server.ids import generate_id
96+
from weave.trace_server.kafka import KafkaProducer
9697
from weave.trace_server.llm_completion import (
9798
get_custom_provider_info,
9899
lite_llm_completion,
@@ -137,8 +138,6 @@
137138
extract_refs_from_values,
138139
str_digest,
139140
)
140-
from weave.trace_server.kafka import KafkaProducer
141-
142141

143142
logger = logging.getLogger(__name__)
144143
logger.setLevel(logging.INFO)

weave/trace_server/environment.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
from typing import Optional
33

4-
54
# Kafka Settings
65

76

weave/trace_server/kafka.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
1-
import json
21
import socket
3-
from typing import Iterator
4-
from confluent_kafka import Consumer as ConfluentKafkaConsumer, Producer as ConfluentKafkaProducer
2+
3+
from confluent_kafka import Consumer as ConfluentKafkaConsumer
4+
from confluent_kafka import Producer as ConfluentKafkaProducer
55

66
from weave.trace_server import trace_server_interface as tsi
77
from weave.trace_server.environment import wf_kafka_broker_host, wf_kafka_broker_port
88

9-
109
CALL_ENDED_TOPIC = "weave.call_ended"
1110

1211

1312
class KafkaProducer(ConfluentKafkaProducer):
14-
1513
@classmethod
1614
def from_env(cls) -> "KafkaProducer":
1715
conf = {
@@ -29,7 +27,6 @@ def produce_call_end(self, call_end: tsi.EndedCallSchemaForInsert) -> None:
2927

3028

3129
class KafkaConsumer(ConfluentKafkaConsumer):
32-
3330
@classmethod
3431
def from_env(cls) -> "KafkaConsumer":
3532
conf = {
@@ -40,7 +37,7 @@ def from_env(cls) -> "KafkaConsumer":
4037
}
4138
consumer = cls(conf)
4239
return consumer
43-
40+
4441

4542
def _make_broker_host() -> str:
4643
return f"{wf_kafka_broker_host()}:{wf_kafka_broker_port()}"

weave/workers/weave_scorer.py

+46-61
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import asyncio
22
import inspect
33
import logging
4+
import os
45
from datetime import datetime
56
from typing import Any
6-
import sentry_sdk
7-
import os
87

9-
# This import is used to register built-in scorers so they can be deserialized from the DB
10-
import weave.scorers # noqa: F401
8+
import sentry_sdk
119
from confluent_kafka import KafkaError, Message
1210
from tenacity import (
1311
before_log,
@@ -16,6 +14,9 @@
1614
stop_after_attempt,
1715
wait_fixed,
1816
)
17+
18+
# This import is used to register built-in scorers so they can be deserialized from the DB
19+
import weave.scorers # noqa: F401
1920
from weave.flow.monitor import Monitor
2021
from weave.flow.scorer import Scorer, get_scorer_attributes
2122
from weave.trace.box import box
@@ -56,10 +57,10 @@ def get_trace_server() -> TraceServerInterface:
5657

5758

5859
# This should be cached to avoid hitting ClickHouse for each ended call.
59-
def get_active_monitors(project_id: str) -> list[tuple[Monitor, str]]:
60-
"""
61-
Returns active monitors for a given project.
62-
"""
60+
def get_active_monitors(
61+
project_id: str,
62+
) -> list[tuple[Monitor, InternalObjectRef, str | None]]:
63+
"""Returns active monitors for a given project."""
6364
obj_query = tsi.ObjQueryReq(
6465
project_id=project_id,
6566
filter=tsi.ObjectVersionFilter(
@@ -97,9 +98,7 @@ def get_active_monitors(project_id: str) -> list[tuple[Monitor, str]]:
9798

9899

99100
def resolve_scorer_refs(scorer_ref_uris: list[str], project_id: str) -> list[Scorer]:
100-
"""
101-
Resolves scorer references to Scorer objects.
102-
"""
101+
"""Resolves scorer references to Scorer objects."""
103102
server = get_trace_server()
104103

105104
scorer_refs = [
@@ -118,7 +117,7 @@ def resolve_scorer_refs(scorer_ref_uris: list[str], project_id: str) -> list[Sco
118117
for scorer, scorer_ref in zip(scorers, scorer_refs):
119118
scorer.__dict__["internal_ref"] = scorer_ref
120119

121-
return scorers
120+
return scorers # type: ignore
122121

123122

124123
class CallNotWrittenError(Exception):
@@ -133,11 +132,11 @@ class CallNotWrittenError(Exception):
133132
before=before_log(logger, logging.INFO),
134133
)
135134
def get_filtered_call(
136-
op_names: list[str], query: tsi.Query | None, ended_call: tsi.EndedCallSchemaForInsert
135+
op_names: list[str],
136+
query: tsi.Query | None,
137+
ended_call: tsi.EndedCallSchemaForInsert,
137138
) -> Call | None:
138-
"""
139-
Looks up the call based on a monitor's call filter.
140-
"""
139+
"""Looks up the call based on a monitor's call filter."""
141140
server = get_trace_server()
142141

143142
# We do this two-step querying to circumvent the absence of write->read consistency in ClickHouse.
@@ -169,29 +168,27 @@ def get_filtered_call(
169168

170169
if len(calls) == 0:
171170
logger.warning("No matching calls found for call id %s", ended_call.id)
172-
return
171+
return None
173172

174173
if len(calls) > 1:
175174
logger.warning("Multiple calls found for call id %s", ended_call.id)
176-
return
175+
return None
177176

178177
call = calls[0]
179178

180179
if not call.ended_at:
181-
return
180+
return None
182181

183182
if call.exception:
184-
return
183+
return None
185184

186185
logger.info("Found call %s", call.id)
187186

188187
return build_client_call(call)
189188

190189

191190
def build_client_call(server_call: tsi.CallSchema) -> Call:
192-
"""
193-
Converts a server call to a client call.
194-
"""
191+
"""Converts a server call to a client call."""
195192
server = get_trace_server()
196193

197194
return Call(
@@ -217,10 +214,8 @@ async def process_monitor(
217214
monitor_internal_ref: InternalObjectRef,
218215
ended_call: tsi.EndedCallSchemaForInsert,
219216
wb_user_id: str,
220-
):
221-
"""
222-
Actually apply the monitor's scorers for an ended call.
223-
"""
217+
) -> None:
218+
"""Actually apply the monitor's scorers for an ended call."""
224219
if (call := get_filtered_call(monitor.op_names, monitor.query, ended_call)) is None:
225220
return
226221

@@ -284,9 +279,7 @@ def _do_score_call(scorer: Scorer, call: Call, project_id: str) -> tuple[str, An
284279

285280

286281
def _get_score_call(score_call_id: str, project_id: str) -> Call:
287-
"""
288-
Gets a score call from the DB.
289-
"""
282+
"""Gets a score call from the DB."""
290283
server = get_trace_server()
291284

292285
call_req = tsi.CallsQueryReq(
@@ -304,26 +297,24 @@ async def apply_scorer(
304297
call: Call,
305298
project_id: str,
306299
wb_user_id: str,
307-
):
308-
"""
309-
Actually apply the scorer to the call.
310-
"""
300+
) -> None:
301+
"""Actually apply the scorer to the call."""
311302
score_call_id, result = _do_score_call(scorer, call, project_id)
312303

313304
score_call = _get_score_call(score_call_id, project_id)
314305

315-
call_ref = InternalCallRef(project_id=project_id, id=call.id)
316-
score_call_ref = InternalCallRef(project_id=project_id, id=score_call.id)
306+
call_ref = InternalCallRef(project_id=project_id, id=call.id) # type: ignore
307+
score_call_ref = InternalCallRef(project_id=project_id, id=score_call.id) # type: ignore
317308

318-
results_json = to_json(result, project_id, None)
309+
results_json = to_json(result, project_id, None) # type: ignore
319310
payload = {"output": results_json}
320311

321312
server = get_trace_server()
322313

323314
feedback_req = FeedbackCreateReq(
324315
project_id=project_id,
325316
weave_ref=call_ref.uri(),
326-
feedback_type=RUNNABLE_FEEDBACK_TYPE_PREFIX + "." + scorer.name,
317+
feedback_type=RUNNABLE_FEEDBACK_TYPE_PREFIX + "." + scorer.name, # type: ignore
327318
payload=payload,
328319
runnable_ref=scorer.__dict__["internal_ref"].uri(),
329320
call_ref=score_call_ref.uri(),
@@ -334,13 +325,13 @@ async def apply_scorer(
334325
server.feedback_create(feedback_req)
335326

336327

337-
async def process_ended_call(ended_call: tsi.EndedCallSchemaForInsert):
328+
async def process_ended_call(ended_call: tsi.EndedCallSchemaForInsert) -> None:
338329
project_id = ended_call.project_id
339330

340331
active_monitors_ref_user_ids = get_active_monitors(project_id)
341332

342333
for monitor, monitor_internal_ref, wb_user_id in active_monitors_ref_user_ids:
343-
await process_monitor(monitor, monitor_internal_ref, ended_call, wb_user_id)
334+
await process_monitor(monitor, monitor_internal_ref, ended_call, wb_user_id) # type: ignore
344335

345336

346337
def _call_processor_done_callback(
@@ -355,15 +346,15 @@ def _call_processor_done_callback(
355346
task.result()
356347
consumer.commit(msg)
357348
except Exception as e:
358-
logger.error(f"Error processing message: {e.__class__.__name__} {e}")
349+
logger.exception(
350+
f"Error processing message: {e.__class__.__name__} {e}", exc_info=e
351+
)
359352

360353

361354
async def process_kafka_message(
362355
msg: Message, consumer: KafkaConsumer, call_processors: set[asyncio.Task]
363356
) -> bool:
364-
"""
365-
Process a single Kafka message and create a task for it.
366-
"""
357+
"""Process a single Kafka message and create a task for it."""
367358
try:
368359
ended_call = tsi.EndedCallSchemaForInsert.model_validate_json(
369360
msg.value().decode("utf-8")
@@ -376,34 +367,30 @@ async def process_kafka_message(
376367
lambda t: _call_processor_done_callback(t, msg, consumer, call_processors)
377368
)
378369

379-
return True
370+
return True # noqa: TRY300
380371
except Exception as e:
381-
logger.error("Error processing message: %s", e)
372+
logger.exception("Error processing message: %s", e, exc_info=e)
382373
return False
383374

384375

385376
async def handle_kafka_errors(msg: Message) -> bool:
386-
"""
387-
Handle Kafka-specific errors.
388-
"""
377+
"""Handle Kafka-specific errors."""
389378
if msg.error():
390379
if msg.error().code() == KafkaError._PARTITION_EOF:
391380
logger.error(
392381
"%% %s [%d] reached end at offset %d\n"
393382
% (msg.topic(), msg.partition(), msg.offset())
394383
)
395384
else:
396-
logger.error("Kafka error: %s", msg.error())
385+
logger.exception("Kafka error: %s", msg.error())
397386

398387
return False
399388

400389
return True
401390

402391

403392
async def cleanup_tasks(call_processors: set[asyncio.Task]) -> set[asyncio.Task]:
404-
"""
405-
Clean up completed tasks and wait for pending ones.
406-
"""
393+
"""Clean up completed tasks and wait for pending ones."""
407394
call_processors = {t for t in call_processors if not t.done()}
408395

409396
if call_processors:
@@ -415,12 +402,10 @@ async def cleanup_tasks(call_processors: set[asyncio.Task]) -> set[asyncio.Task]
415402
return set()
416403

417404

418-
async def run_consumer():
419-
"""
420-
This is the main loop consuming the ended calls from the Kafka topic.
421-
"""
405+
async def run_consumer() -> None:
406+
"""This is the main loop consuming the ended calls from the Kafka topic."""
422407
consumer = KafkaConsumer.from_env()
423-
call_processors = set()
408+
call_processors: set[asyncio.Task] = set()
424409

425410
consumer.subscribe([CALL_ENDED_TOPIC])
426411
logger.info("Subscribed to %s", CALL_ENDED_TOPIC)
@@ -442,15 +427,15 @@ async def run_consumer():
442427
consumer.close()
443428

444429

445-
def init_sentry():
430+
def init_sentry() -> None:
446431
sentry_sdk.init(
447432
dsn="https://[email protected]/4509210797277185",
448433
environment=os.environ.get("WEAVE_SENTRY_ENV", "dev"),
449-
release=weave.version.VERSION
434+
release=weave.version.VERSION,
450435
)
451436

452437

453-
async def main():
438+
async def main() -> None:
454439
init_sentry()
455440

456441
await run_consumer()

0 commit comments

Comments
 (0)