From 1842dd00012d0872c4a3493fa82bb5c741634fa7 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Mon, 21 Apr 2025 10:24:10 -0700 Subject: [PATCH 1/5] perf(weave): add pre group by ref filtering --- .../calls_query_builder.py | 52 +++++++++++++++++-- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/weave/trace_server/calls_query_builder/calls_query_builder.py b/weave/trace_server/calls_query_builder/calls_query_builder.py index a8044c3d88aa..397d91822299 100644 --- a/weave/trace_server/calls_query_builder/calls_query_builder.py +++ b/weave/trace_server/calls_query_builder/calls_query_builder.py @@ -716,12 +716,19 @@ def _as_sql_base_format( # call starts before grouping, creating orphan call ends. By conditioning # on `NOT any(started_at) is NULL`, we filter out orphaned call ends, ensuring # all rows returned at least have a call start. - op_name_sql = process_op_name_filter_to_conditions( + op_name_sql = process_op_name_filter_to_sql( self.hardcoded_filter, pb, table_alias, ) - trace_id_sql = process_trace_id_filter_to_conditions( + trace_id_sql = process_trace_id_filter_to_sql( + self.hardcoded_filter, + pb, + table_alias, + ) + # ref filters also have group by filters, because output_refs exist on the + # call end parts. 
+ ref_filter_opt_sql = process_ref_filters_to_sql( self.hardcoded_filter, pb, table_alias, @@ -1077,7 +1084,7 @@ def process_operand(operand: "tsi_query.Operand") -> str: ) -def process_op_name_filter_to_conditions( +def process_op_name_filter_to_sql( hardcoded_filter: Optional[HardCodedFilter], param_builder: ParamBuilder, table_alias: str, @@ -1128,7 +1135,7 @@ def process_op_name_filter_to_conditions( return " AND " + combine_conditions(or_conditions, "OR") -def process_trace_id_filter_to_conditions( +def process_trace_id_filter_to_sql( hardcoded_filter: Optional[HardCodedFilter], param_builder: ParamBuilder, table_alias: str, @@ -1159,6 +1166,43 @@ def process_trace_id_filter_to_conditions( return f" AND ({trace_cond} OR {trace_id_field_sql} IS NULL)" +def process_ref_filters_to_sql( + hardcoded_filter: Optional[HardCodedFilter], + param_builder: ParamBuilder, + table_alias: str, +) -> str: + """Adds a ref filter optimization to the query. + + To be used before group by. This filter is NOT guaranteed to return + the correct results, as it can operate on call ends (output_refs) so it + should be used in addition to the existing ref filters after group by + generated in process_calls_filter_to_conditions.""" + if hardcoded_filter is None or ( + not hardcoded_filter.filter.output_refs + and not hardcoded_filter.filter.input_refs + ): + return "" + + def process_ref_filter(field_name: str) -> str: + field = get_field_by_name(field_name) + if not isinstance(field, CallsMergedAggField): + raise TypeError(f"{field_name} is not an aggregate field") + + field_sql = field.as_sql(param_builder, table_alias, use_agg_fn=False) + param = param_builder.add_param(hardcoded_filter.filter.output_refs) + ref_filter_sql = f"hasAny({field_sql}, {param_slot(param, 'Array(String)')})" + return f"({ref_filter_sql} OR length({field_sql}) = 0)" + + ref_filters = [] + if hardcoded_filter.filter.output_refs: + ref_filters.append(process_ref_filter("output_refs")) + + if 
hardcoded_filter.filter.input_refs: + ref_filters.append(process_ref_filter("input_refs")) + + return combine_conditions(ref_filters, "AND") + + def process_calls_filter_to_conditions( filter: tsi.CallsFilter, param_builder: ParamBuilder, From d9a8ad0d1253d1bbaddd528dba18472271117b65 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Mon, 21 Apr 2025 11:28:40 -0700 Subject: [PATCH 2/5] unittest --- .../trace_server/test_calls_query_builder.py | 36 +++++++++++++++++++ .../calls_query_builder.py | 11 ++++-- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/tests/trace_server/test_calls_query_builder.py b/tests/trace_server/test_calls_query_builder.py index 69555b56115b..ff8879641efc 100644 --- a/tests/trace_server/test_calls_query_builder.py +++ b/tests/trace_server/test_calls_query_builder.py @@ -2034,6 +2034,42 @@ def test_trace_id_filter_eq(): ) +def test_input_output_refs_filter(): + cq = CallsQuery(project_id="project") + cq.add_field("id") + cq.hardcoded_filter = HardCodedFilter( + filter={ + "input_refs": ["weave-trace-internal:///%"], + "output_refs": ["weave-trace-internal:///%"], + } + ) + assert_sql( + cq, + """ + SELECT + calls_merged.id AS id + FROM calls_merged + WHERE calls_merged.project_id = {pb_4:String} + AND ((hasAny(calls_merged.output_refs, {pb_2:Array(String)}) + OR length(calls_merged.output_refs) = 0) + AND (hasAny(calls_merged.input_refs, {pb_3:Array(String)}) + OR length(calls_merged.input_refs) = 0)) + GROUP BY (calls_merged.project_id, calls_merged.id) + HAVING (((any(calls_merged.deleted_at) IS NULL)) + AND ((NOT ((any(calls_merged.started_at) IS NULL)))) + AND (((hasAny(array_concat_agg(calls_merged.input_refs), {pb_0:Array(String)})) + AND (hasAny(array_concat_agg(calls_merged.output_refs), {pb_1:Array(String)}))))) + """, + { + "pb_4": "project", + "pb_0": ["weave-trace-internal:///%"], + "pb_1": ["weave-trace-internal:///%"], + "pb_2": ["weave-trace-internal:///%"], + "pb_3": ["weave-trace-internal:///%"], + }, + ) + + def 
test_filter_length_validation(): """Test that filter length validation works""" pb = ParamBuilder() diff --git a/weave/trace_server/calls_query_builder/calls_query_builder.py b/weave/trace_server/calls_query_builder/calls_query_builder.py index 397d91822299..00fe6211aa53 100644 --- a/weave/trace_server/calls_query_builder/calls_query_builder.py +++ b/weave/trace_server/calls_query_builder/calls_query_builder.py @@ -825,6 +825,7 @@ def _as_sql_base_format( {op_name_sql} {trace_id_sql} {str_filter_opt_sql} + {ref_filter_opt_sql} GROUP BY (calls_merged.project_id, calls_merged.id) {having_filter_sql} {order_by_sql} @@ -1191,7 +1192,7 @@ def process_ref_filter(field_name: str) -> str: field_sql = field.as_sql(param_builder, table_alias, use_agg_fn=False) param = param_builder.add_param(hardcoded_filter.filter.output_refs) ref_filter_sql = f"hasAny({field_sql}, {param_slot(param, 'Array(String)')})" - return f"({ref_filter_sql} OR length({field_sql}) = 0)" + return f"{ref_filter_sql} OR length({field_sql}) = 0" ref_filters = [] if hardcoded_filter.filter.output_refs: @@ -1200,7 +1201,10 @@ def process_ref_filter(field_name: str) -> str: if hardcoded_filter.filter.input_refs: ref_filters.append(process_ref_filter("input_refs")) - return combine_conditions(ref_filters, "AND") + if not ref_filters: + return "" + + return " AND " + combine_conditions(ref_filters, "AND") def process_calls_filter_to_conditions( @@ -1214,6 +1218,9 @@ def process_calls_filter_to_conditions( """ conditions: list[str] = [] + # The pre-groupby ref optimization only narrows the rows scanned: it keeps 
+ # rows with empty ref arrays (call starts have no output_refs), so these + # post-groupby ref filters are still required for correct results. if filter.input_refs: assert_parameter_length_less_than_max("input_refs", len(filter.input_refs)) conditions.append( From 7ec615e253db41b8e7e09d6c3599866045515a41 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Mon, 21 Apr 2025 11:54:38 -0700 Subject: [PATCH 3/5] fix --- .../trace_server/test_calls_query_builder.py | 20 +++++++++---------- .../calls_query_builder.py | 15 ++++++++------ 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/tests/trace_server/test_calls_query_builder.py b/tests/trace_server/test_calls_query_builder.py index ff8879641efc..476d9560dc67 100644 --- a/tests/trace_server/test_calls_query_builder.py +++ b/tests/trace_server/test_calls_query_builder.py @@ -2039,8 +2039,8 @@ def test_input_output_refs_filter(): cq.add_field("id") cq.hardcoded_filter = HardCodedFilter( filter={ - "input_refs": ["weave-trace-internal:///%"], - "output_refs": ["weave-trace-internal:///%"], + "input_refs": ["weave-trace-internal:///222222222222%"], + "output_refs": ["weave-trace-internal:///111111111111%"], } ) assert_sql( @@ -2050,10 +2050,10 @@ def test_input_output_refs_filter(): calls_merged.id AS id FROM calls_merged WHERE calls_merged.project_id = {pb_4:String} - AND ((hasAny(calls_merged.output_refs, {pb_2:Array(String)}) - OR length(calls_merged.output_refs) = 0) - AND (hasAny(calls_merged.input_refs, {pb_3:Array(String)}) - OR length(calls_merged.input_refs) = 0)) + AND ((hasAny(calls_merged.input_refs, {pb_2:Array(String)}) + OR length(calls_merged.input_refs) = 0) + AND (hasAny(calls_merged.output_refs, {pb_3:Array(String)}) + OR length(calls_merged.output_refs) = 0)) GROUP BY (calls_merged.project_id, calls_merged.id) HAVING (((any(calls_merged.deleted_at) IS NULL)) AND ((NOT ((any(calls_merged.started_at) IS NULL)))) @@ -2062,10 +2062,10 @@ def test_input_output_refs_filter(): """, { "pb_4": "project", - "pb_0": ["weave-trace-internal:///%"], - "pb_1": 
["weave-trace-internal:///%"], - "pb_2": ["weave-trace-internal:///%"], - "pb_3": ["weave-trace-internal:///%"], + "pb_0": ["weave-trace-internal:///222222222222%"], + "pb_1": ["weave-trace-internal:///111111111111%"], + "pb_2": ["weave-trace-internal:///222222222222%"], + "pb_3": ["weave-trace-internal:///111111111111%"], }, ) diff --git a/weave/trace_server/calls_query_builder/calls_query_builder.py b/weave/trace_server/calls_query_builder/calls_query_builder.py index 00fe6211aa53..c5a46afcecb6 100644 --- a/weave/trace_server/calls_query_builder/calls_query_builder.py +++ b/weave/trace_server/calls_query_builder/calls_query_builder.py @@ -1184,22 +1184,25 @@ def process_ref_filters_to_sql( ): return "" - def process_ref_filter(field_name: str) -> str: + def process_ref_filter(field_name: str, refs: list[str]) -> str: field = get_field_by_name(field_name) if not isinstance(field, CallsMergedAggField): raise TypeError(f"{field_name} is not an aggregate field") field_sql = field.as_sql(param_builder, table_alias, use_agg_fn=False) - param = param_builder.add_param(hardcoded_filter.filter.output_refs) + param = param_builder.add_param(refs) ref_filter_sql = f"hasAny({field_sql}, {param_slot(param, 'Array(String)')})" return f"{ref_filter_sql} OR length({field_sql}) = 0" ref_filters = [] - if hardcoded_filter.filter.output_refs: - ref_filters.append(process_ref_filter("output_refs")) - if hardcoded_filter.filter.input_refs: - ref_filters.append(process_ref_filter("input_refs")) + ref_filters.append( + process_ref_filter("input_refs", hardcoded_filter.filter.input_refs) + ) + if hardcoded_filter.filter.output_refs: + ref_filters.append( + process_ref_filter("output_refs", hardcoded_filter.filter.output_refs) + ) if not ref_filters: return "" From 561f86164f46803ad29042dae1dd52f84ccfeb8e Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Wed, 23 Apr 2025 18:14:27 -0700 Subject: [PATCH 4/5] end to end tests --- tests/trace/test_weave_client.py | 60 
++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/tests/trace/test_weave_client.py b/tests/trace/test_weave_client.py index 3eb36c8228a2..cbecfda734fc 100644 --- a/tests/trace/test_weave_client.py +++ b/tests/trace/test_weave_client.py @@ -3276,3 +3276,63 @@ def test_calls_query_with_non_uuidv7_ids(client): assert call_ids[5] == uuidv7_calls[1].id assert call_ids[6] == uuidv7_calls[2].id assert call_ids[7] == uuidv7_calls[3].id + + +def test_filter_calls_by_ref(client): + obj = {"a": 1} + ref = client.save(obj, "obj").ref + ref2 = client.save(obj, "obj2").ref + ref3 = client.save(obj, "obj3").ref + + @weave.op + def log_obj(ref: str): + return { + "ref2": ref2, + "ref3": ref3, + } + + log_obj(ref) + + calls = client.get_calls() + assert len(calls) == 1 + assert calls[0].inputs["ref"] == obj + assert calls[0].output["ref2"] == obj + + # now query by filtering for input ref + calls = client.get_calls(filter={"input_refs": [ref.uri()]}) + assert len(calls) == 1 + assert calls[0].inputs["ref"] == obj + assert calls[0].output["ref2"] == obj + + # now query by filtering for output ref + calls = client.get_calls(filter={"output_refs": [ref2.uri()]}) + assert len(calls) == 1 + assert calls[0].inputs["ref"] == obj + assert calls[0].output["ref2"] == obj + + # filter by both input and output ref + calls = client.get_calls( + filter={"input_refs": [ref.uri()], "output_refs": [ref2.uri()]} + ) + assert len(calls) == 1 + assert calls[0].inputs["ref"] == obj + assert calls[0].output["ref2"] == obj + + # filter by the wrong ref + calls = client.get_calls(filter={"input_refs": [ref2.uri()]}) + assert len(calls) == 0 + + # filter by the wrong ref + calls = client.get_calls(filter={"output_refs": [ref.uri()]}) + assert len(calls) == 0 + + # filter by duplicate refs + calls = client.get_calls(filter={"input_refs": [ref.uri(), ref.uri()]}) + assert len(calls) == 1 + assert calls[0].inputs["ref"] == obj + assert calls[0].output["ref2"] == obj + + # filter 
by empty refs, this is ambiguously defined, currently we treat + # this as "no filter" + calls = client.get_calls(filter={"input_refs": [], "output_refs": []}) + assert len(calls) == 1 From e4e60fd4f00557f5029b8f7d1c977036952f5ff1 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Wed, 23 Apr 2025 18:21:51 -0700 Subject: [PATCH 5/5] add double --- tests/trace/test_weave_client.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/trace/test_weave_client.py b/tests/trace/test_weave_client.py index cbecfda734fc..7dd3b0dffa16 100644 --- a/tests/trace/test_weave_client.py +++ b/tests/trace/test_weave_client.py @@ -3318,6 +3318,13 @@ def log_obj(ref: str): assert calls[0].inputs["ref"] == obj assert calls[0].output["ref2"] == obj + # filter by two output refs + calls = client.get_calls(filter={"output_refs": [ref2.uri(), ref3.uri()]}) + assert len(calls) == 1 + assert calls[0].inputs["ref"] == obj + assert calls[0].output["ref2"] == obj + assert calls[0].output["ref3"] == obj + # filter by the wrong ref calls = client.get_calls(filter={"input_refs": [ref2.uri()]}) assert len(calls) == 0