Skip to content

Commit ccc9733

Browse files
authored
Merge branch 'main' into fix-ollama-empty-text-streaming
2 parents 5f75d12 + c583de7 commit ccc9733

File tree

2 files changed

+172
-130
lines changed

2 files changed

+172
-130
lines changed

pydantic_evals/pydantic_evals/otel/span_tree.py

+151-126
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import re
44
from collections.abc import Iterator, Mapping
55
from datetime import datetime, timedelta, timezone
6-
from functools import partial
76
from textwrap import indent
87
from typing import TYPE_CHECKING, Any, Callable
98

@@ -24,6 +23,11 @@ class SpanQuery(TypedDict, total=False):
2423
All fields are optional and combined with AND logic by default.
2524
"""
2625

26+
# These fields are ordered to match the implementation of SpanNode.matches_query for easy review.
27+
# * Individual span conditions come first because these are generally the cheapest to evaluate
28+
# * Logical combinations come next because they may just be combinations of individual span conditions
29+
# * Related-span conditions come last because they may require the most work to evaluate
30+
2731
# Individual span conditions
2832
## Name conditions
2933
name_equals: str
@@ -43,26 +47,35 @@ class SpanQuery(TypedDict, total=False):
4347
and_: list[SpanQuery]
4448
or_: list[SpanQuery]
4549

46-
# Descendant conditions
50+
# Related-span conditions
51+
## Ancestor conditions
52+
min_depth: int # depth is equivalent to ancestor count; roots have depth 0
53+
max_depth: int
54+
some_ancestor_has: SpanQuery
55+
all_ancestors_have: SpanQuery
56+
no_ancestor_has: SpanQuery
57+
58+
## Child conditions
59+
min_child_count: int
60+
max_child_count: int
4761
some_child_has: SpanQuery
4862
all_children_have: SpanQuery
4963
no_child_has: SpanQuery
50-
min_child_count: int
51-
max_child_count: int
5264

65+
## Descendant conditions
66+
min_descendant_count: int
67+
max_descendant_count: int
5368
some_descendant_has: SpanQuery
5469
all_descendants_have: SpanQuery
5570
no_descendant_has: SpanQuery
5671

57-
# Ancestor conditions
58-
some_ancestor_has: SpanQuery
59-
all_ancestors_have: SpanQuery
60-
no_ancestor_has: SpanQuery
61-
6272

6373
class SpanNode:
6474
"""A node in the span tree; provides references to parents/children for easy traversal and queries."""
6575

76+
# -------------------------------------------------------------------------
77+
# Construction
78+
# -------------------------------------------------------------------------
6679
def __init__(self, span: ReadableSpan):
6780
self._span = span
6881
# If a span has no context, it's going to cause problems. We may need to add improved handling of this scenario.
@@ -71,6 +84,14 @@ def __init__(self, span: ReadableSpan):
7184
self.parent: SpanNode | None = None
7285
self.children_by_id: dict[int, SpanNode] = {} # note: we rely on insertion order to determine child order
7386

87+
def add_child(self, child: SpanNode) -> None:
88+
"""Attach a child node to this node's list of children."""
89+
self.children_by_id[child.span_id] = child
90+
child.parent = self
91+
92+
# -------------------------------------------------------------------------
93+
# Utility properties
94+
# -------------------------------------------------------------------------
7495
@property
7596
def children(self) -> list[SpanNode]:
7697
return list(self.children_by_id.values())
@@ -134,11 +155,6 @@ def attributes(self) -> Mapping[str, AttributeValue]:
134155
# nesting etc. This just exposes the JSON-serialized version, but doing more would be difficult.
135156
return self._span.attributes or {}
136157

137-
def add_child(self, child: SpanNode) -> None:
138-
"""Attach a child node to this node's list of children."""
139-
self.children_by_id[child.span_id] = child
140-
child.parent = self
141-
142158
# -------------------------------------------------------------------------
143159
# Child queries
144160
# -------------------------------------------------------------------------
@@ -155,8 +171,7 @@ def any_child(self, predicate: SpanQuery | SpanPredicate) -> bool:
155171
return self.first_child(predicate) is not None
156172

157173
def _filter_children(self, predicate: SpanQuery | SpanPredicate) -> Iterator[SpanNode]:
158-
predicate = _as_predicate(predicate)
159-
return (child for child in self.children if predicate(child))
174+
return (child for child in self.children if child.matches(predicate))
160175

161176
# -------------------------------------------------------------------------
162177
# Descendant queries (DFS)
@@ -174,11 +189,10 @@ def any_descendant(self, predicate: SpanQuery | SpanPredicate) -> bool:
174189
return self.first_descendant(predicate) is not None
175190

176191
def _filter_descendants(self, predicate: SpanQuery | SpanPredicate) -> Iterator[SpanNode]:
177-
predicate = _as_predicate(predicate)
178192
stack = list(self.children)
179193
while stack:
180194
node = stack.pop()
181-
if predicate(node):
195+
if node.matches(predicate):
182196
yield node
183197
stack.extend(node.children)
184198

@@ -198,19 +212,123 @@ def any_ancestor(self, predicate: SpanQuery | SpanPredicate) -> bool:
198212
return self.first_ancestor(predicate) is not None
199213

200214
def _filter_ancestors(self, predicate: SpanQuery | SpanPredicate) -> Iterator[SpanNode]:
201-
predicate = _as_predicate(predicate)
202215
node = self.parent
203216
while node:
204-
if predicate(node):
217+
if node.matches(predicate):
205218
yield node
206219
node = node.parent
207220

208221
# -------------------------------------------------------------------------
209222
# Query matching
210223
# -------------------------------------------------------------------------
211-
def matches(self, query: SpanQuery) -> bool:
212-
"""Check if the span node matches the query conditions."""
213-
return _matches(self, query)
224+
def matches(self, query: SpanQuery | SpanPredicate) -> bool:
225+
"""Check if the span node matches the query conditions or predicate."""
226+
if callable(query):
227+
return query(self)
228+
229+
return self._matches_query(query)
230+
231+
def _matches_query(self, query: SpanQuery) -> bool: # noqa C901
232+
"""Check if the span matches the query conditions."""
233+
# Logical combinations
234+
if or_ := query.get('or_'):
235+
if len(query) > 1:
236+
raise ValueError("Cannot combine 'or_' conditions with other conditions at the same level")
237+
return any(self._matches_query(q) for q in or_)
238+
if not_ := query.get('not_'):
239+
if self._matches_query(not_):
240+
return False
241+
if and_ := query.get('and_'):
242+
results = [self._matches_query(q) for q in and_]
243+
if not all(results):
244+
return False
245+
# At this point, all existing ANDs and no existing ORs have passed, so it comes down to this condition
246+
247+
# Name conditions
248+
if (name_equals := query.get('name_equals')) and self.name != name_equals:
249+
return False
250+
if (name_contains := query.get('name_contains')) and name_contains not in self.name:
251+
return False
252+
if (name_matches_regex := query.get('name_matches_regex')) and not re.match(name_matches_regex, self.name):
253+
return False
254+
255+
# Attribute conditions
256+
if (has_attributes := query.get('has_attributes')) and not all(
257+
self.attributes.get(key) == value for key, value in has_attributes.items()
258+
):
259+
return False
260+
if (has_attributes_keys := query.get('has_attribute_keys')) and not all(
261+
key in self.attributes for key in has_attributes_keys
262+
):
263+
return False
264+
265+
# Timing conditions
266+
if (min_duration := query.get('min_duration')) is not None:
267+
if not isinstance(min_duration, timedelta):
268+
min_duration = timedelta(seconds=min_duration)
269+
if self.duration < min_duration:
270+
return False
271+
if (max_duration := query.get('max_duration')) is not None:
272+
if not isinstance(max_duration, timedelta):
273+
max_duration = timedelta(seconds=max_duration)
274+
if self.duration > max_duration:
275+
return False
276+
277+
# Ancestor conditions
278+
if (min_depth := query.get('min_depth')) and len(self.ancestors) < min_depth:
279+
return False
280+
if (max_depth := query.get('max_depth')) and len(self.ancestors) > max_depth:
281+
return False
282+
if (some_ancestor_has := query.get('some_ancestor_has')) and not any(
283+
ancestor._matches_query(some_ancestor_has) for ancestor in self.ancestors
284+
):
285+
return False
286+
if (all_ancestors_have := query.get('all_ancestors_have')) and not all(
287+
ancestor._matches_query(all_ancestors_have) for ancestor in self.ancestors
288+
):
289+
return False
290+
if (no_ancestor_has := query.get('no_ancestor_has')) and any(
291+
ancestor._matches_query(no_ancestor_has) for ancestor in self.ancestors
292+
):
293+
return False
294+
295+
# Children conditions
296+
if (min_child_count := query.get('min_child_count')) and len(self.children) < min_child_count:
297+
return False
298+
if (max_child_count := query.get('max_child_count')) and len(self.children) > max_child_count:
299+
return False
300+
if (some_child_has := query.get('some_child_has')) and not any(
301+
child._matches_query(some_child_has) for child in self.children
302+
):
303+
return False
304+
if (all_children_have := query.get('all_children_have')) and not all(
305+
child._matches_query(all_children_have) for child in self.children
306+
):
307+
return False
308+
if (no_child_has := query.get('no_child_has')) and any(
309+
child._matches_query(no_child_has) for child in self.children
310+
):
311+
return False
312+
313+
# Descendant conditions
314+
if (min_descendant_count := query.get('min_descendant_count')) and len(self.descendants) < min_descendant_count:
315+
return False
316+
if (max_descendant_count := query.get('max_descendant_count')) and len(self.descendants) > max_descendant_count:
317+
return False
318+
if (some_descendant_has := query.get('some_descendant_has')) and not any(
319+
descendant._matches_query(some_descendant_has) for descendant in self.descendants
320+
):
321+
return False
322+
if (all_descendants_have := query.get('all_descendants_have')) and not all(
323+
descendant._matches_query(all_descendants_have) for descendant in self.descendants
324+
):
325+
return False
326+
if (no_descendant_has := query.get('no_descendant_has')) and any(
327+
descendant._matches_query(no_descendant_has) for descendant in self.descendants
328+
):
329+
return False
330+
331+
return True
214332

215333
# -------------------------------------------------------------------------
216334
# String representation
@@ -279,6 +397,9 @@ class SpanTree:
279397
You can then search or iterate the tree to make your assertions (using DFS for traversal).
280398
"""
281399

400+
# -------------------------------------------------------------------------
401+
# Construction
402+
# -------------------------------------------------------------------------
282403
def __init__(self, spans: list[ReadableSpan] | None = None):
283404
self.nodes_by_id: dict[int, SpanNode] = {}
284405
self.roots: list[SpanNode] = []
@@ -314,6 +435,9 @@ def _rebuild_tree(self):
314435
if parent_ctx is None or parent_ctx.span_id not in self.nodes_by_id:
315436
self.roots.append(node)
316437

438+
# -------------------------------------------------------------------------
439+
# Node filtering and iteration
440+
# -------------------------------------------------------------------------
317441
def find(self, predicate: SpanQuery | SpanPredicate) -> list[SpanNode]:
318442
"""Find all nodes in the entire tree that match the predicate, scanning from each root in DFS order."""
319443
return list(self._filter(predicate))
@@ -327,15 +451,17 @@ def any(self, predicate: SpanQuery | SpanPredicate) -> bool:
327451
return self.first(predicate) is not None
328452

329453
def _filter(self, predicate: SpanQuery | SpanPredicate) -> Iterator[SpanNode]:
330-
predicate = _as_predicate(predicate)
331454
for node in self:
332-
if predicate(node):
455+
if node.matches(predicate):
333456
yield node
334457

335458
def __iter__(self) -> Iterator[SpanNode]:
336459
"""Return an iterator over all nodes in the tree."""
337460
return iter(self.nodes_by_id.values())
338461

462+
# -------------------------------------------------------------------------
463+
# String representation
464+
# -------------------------------------------------------------------------
339465
def repr_xml(
340466
self,
341467
include_children: bool = True,
@@ -371,104 +497,3 @@ def __str__(self):
371497

372498
def __repr__(self):
373499
return self.repr_xml()
374-
375-
376-
def _as_predicate(query: SpanQuery | SpanPredicate) -> Callable[[SpanNode], bool]:
377-
"""Convert a SpanQuery into a callable predicate that can be used in SpanTree.find_first, etc."""
378-
if callable(query):
379-
return query
380-
381-
return partial(_matches, query=query)
382-
383-
384-
def _matches(span: SpanNode, query: SpanQuery) -> bool: # noqa C901
385-
"""Check if the span matches the query conditions."""
386-
# Logical combinations
387-
if or_ := query.get('or_'):
388-
if len(query) > 1:
389-
raise ValueError("Cannot combine 'or_' conditions with other conditions at the same level")
390-
return any(_matches(span, q) for q in or_)
391-
if not_ := query.get('not_'):
392-
if _matches(span, not_):
393-
return False
394-
if and_ := query.get('and_'):
395-
results = [_matches(span, q) for q in and_]
396-
if not all(results):
397-
return False
398-
# At this point, all existing ANDs and no existing ORs have passed, so it comes down to this condition
399-
400-
# Name conditions
401-
if (name_equals := query.get('name_equals')) and span.name != name_equals:
402-
return False
403-
if (name_contains := query.get('name_contains')) and name_contains not in span.name:
404-
return False
405-
if (name_matches_regex := query.get('name_matches_regex')) and not re.match(name_matches_regex, span.name):
406-
return False
407-
408-
# Attribute conditions
409-
if (has_attributes := query.get('has_attributes')) and not all(
410-
span.attributes.get(key) == value for key, value in has_attributes.items()
411-
):
412-
return False
413-
if (has_attributes_keys := query.get('has_attribute_keys')) and not all(
414-
key in span.attributes for key in has_attributes_keys
415-
):
416-
return False
417-
418-
# Timing conditions
419-
if (min_duration := query.get('min_duration')) is not None and span.duration is not None: # pyright: ignore[reportUnnecessaryComparison]
420-
if not isinstance(min_duration, timedelta):
421-
min_duration = timedelta(seconds=min_duration)
422-
if span.duration < min_duration:
423-
return False
424-
if (max_duration := query.get('max_duration')) is not None and span.duration is not None: # pyright: ignore[reportUnnecessaryComparison]
425-
if not isinstance(max_duration, timedelta):
426-
max_duration = timedelta(seconds=max_duration)
427-
if span.duration > max_duration:
428-
return False
429-
430-
# Children conditions
431-
if (min_child_count := query.get('min_child_count')) and len(span.children) < min_child_count:
432-
return False
433-
if (max_child_count := query.get('max_child_count')) and len(span.children) > max_child_count:
434-
return False
435-
if (some_child_has := query.get('some_child_has')) and not any(
436-
_matches(child, some_child_has) for child in span.children
437-
):
438-
return False
439-
if (all_children_have := query.get('all_children_have')) and not all(
440-
_matches(child, all_children_have) for child in span.children
441-
):
442-
return False
443-
if (no_child_has := query.get('no_child_has')) and any(_matches(child, no_child_has) for child in span.children):
444-
return False
445-
446-
# Descendant conditions
447-
if (some_descendant_has := query.get('some_descendant_has')) and not any(
448-
_matches(child, some_descendant_has) for child in span.descendants
449-
):
450-
return False
451-
if (all_descendants_have := query.get('all_descendants_have')) and not all(
452-
_matches(child, all_descendants_have) for child in span.descendants
453-
):
454-
return False
455-
if (no_descendant_has := query.get('no_descendant_has')) and any(
456-
_matches(child, no_descendant_has) for child in span.descendants
457-
):
458-
return False
459-
460-
# Ancestor conditions
461-
if (some_ancestor_has := query.get('some_ancestor_has')) and not any(
462-
_matches(ancestor, some_ancestor_has) for ancestor in span.ancestors
463-
):
464-
return False
465-
if (all_ancestors_have := query.get('all_ancestors_have')) and not all(
466-
_matches(ancestor, all_ancestors_have) for ancestor in span.ancestors
467-
):
468-
return False
469-
if (no_ancestor_has := query.get('no_ancestor_has')) and any(
470-
_matches(ancestor, no_ancestor_has) for ancestor in span.ancestors
471-
):
472-
return False
473-
474-
return True

0 commit comments

Comments
 (0)