From b4e6a50f539cd1060f1aebeb956a8309dec267a2 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Mon, 8 Jul 2024 16:26:48 +0000 Subject: [PATCH 01/59] deps: update to ibis-framework 9.x and newer sqlglot --- bigframes/functions/remote_function.py | 6 +- setup.py | 5 +- .../ibis/backends/bigquery/compiler.py | 78 +++++++------------ .../ibis/backends/bigquery/registry.py | 5 +- 4 files changed, 36 insertions(+), 58 deletions(-) diff --git a/bigframes/functions/remote_function.py b/bigframes/functions/remote_function.py index c1878b6c31..d7f99fe618 100644 --- a/bigframes/functions/remote_function.py +++ b/bigframes/functions/remote_function.py @@ -1048,7 +1048,8 @@ def try_delattr(attr): node = ibis.udf.scalar.builtin( func, name=rf_name, - schema=f"{dataset_ref.project}.{dataset_ref.dataset_id}", + catalog=dataset_ref.dataset_id, + database=dataset_ref.project, signature=(ibis_signature.input_types, ibis_signature.output_type), ) func.bigframes_cloud_function = ( @@ -1120,7 +1121,8 @@ def func(*ignored_args, **ignored_kwargs): node = ibis.udf.scalar.builtin( func, name=routine_ref.routine_id, - schema=f"{routine_ref.project}.{routine_ref.dataset_id}", + database=routine_ref.project, + catalog=routine_ref.dataset_id, signature=(ibis_signature.input_types, ibis_signature.output_type), ) func.bigframes_remote_function = str(routine_ref) # type: ignore diff --git a/setup.py b/setup.py index 79baf1fb23..b95196b128 100644 --- a/setup.py +++ b/setup.py @@ -47,9 +47,8 @@ "google-cloud-iam >=2.12.1", "google-cloud-resource-manager >=1.10.3", "google-cloud-storage >=2.0.0", - "ibis-framework[bigquery] >=8.0.0,<9.0.0dev", + "ibis-framework[bigquery] >=9.0.0,<10.0.0dev", "jellyfish >=0.8.9", - # TODO: Relax upper bound once we have fixed `system_prerelease` tests. "pandas >=1.5.0", "pyarrow >=8.0.0", "pydata-google-auth >=1.8.2", @@ -59,7 +58,7 @@ # Keep sqlglot versions in sync with ibis-framework. This avoids problems # where the incorrect version of sqlglot is installed, such as # https://github.com/googleapis/python-bigquery-dataframes/issues/315 - "sqlglot >=20.8.0,<=20.11", + "sqlglot >=23.4,<25.2", "tabulate >= 0.9", "ipywidgets >=7.7.1", "humanize >= 4.6.0", diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py index 414f0a7c81..08c48165df 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py @@ -3,57 +3,33 @@ from __future__ import annotations -import re - -from ibis.backends.base.sql import compiler as sql_compiler -import ibis.backends.bigquery.compiler -from ibis.backends.bigquery.datatypes import BigQueryType -import ibis.expr.datatypes as dt -import ibis.expr.operations as ops - -_NAME_REGEX = re.compile(r'[^!"$()*,./;?@[\\\]^`{}~\n]+') -_EXACT_NAME_REGEX = re.compile(f"^{_NAME_REGEX.pattern}$") - - -class BigQueryTableSetFormatter(sql_compiler.TableSetFormatter): - def _quote_identifier(self, name): - """Restore 6.x version of identifier quoting. - - 7.x uses sqlglot which as of December 2023 doesn't know about the - extended unicode names for BigQuery yet. - """ - if _EXACT_NAME_REGEX.match(name) is not None: - return name - return f"`{name}`" - - def _format_in_memory_table(self, op): - """Restore 6.x version of InMemoryTable. - - BigQuery DataFrames explicitly uses InMemoryTable only when we know - the data is small enough to embed in SQL. 
- """ - schema = op.schema - names = schema.names - types = schema.types - - raw_rows = [] - for row in op.data.to_frame().itertuples(index=False): - raw_row = ", ".join( - f"{self._translate(lit)} AS {name}" - for lit, name in zip( - map(ops.Literal, row, types), map(self._quote_identifier, names) - ) - ) - raw_rows.append(f"STRUCT({raw_row})") - array_type = BigQueryType.from_ibis(dt.Array(op.schema.as_struct())) - - return f"UNNEST({array_type}[{', '.join(raw_rows)}])" +import ibis.backends.bigquery.compiler as bq_compiler +import sqlglot as sg +import sqlglot.expressions as sge + + +class BigQueryCompiler(bq_compiler.BigQueryCompiler): + def visit_InMemoryTable(self, op, *, name, schema, data): + # Avoid creating temp tables for small data, which is how memtable is + # used in BigQuery DataFrames. Implementation from: + # https://github.com/ibis-project/ibis/blob/efa6fb72bf4c790450d00a926d7bd809dade5902/ibis/backends/druid/compiler.py#L95 + tuples = data.to_frame().itertuples(index=False) + quoted = self.quoted + columns = [sg.column(col, quoted=quoted) for col in schema.names] + expr = sge.Values( + expressions=[ + sge.Tuple(expressions=tuple(map(sge.convert, row))) for row in tuples + ], + alias=sge.TableAlias( + this=sg.to_identifier(name, quoted=quoted), + columns=columns, + ), + ) + return sg.select(*columns).from_(expr) + + def visit_FirstNonNullValue(self, op): + pass # Override implementation. -ibis.backends.bigquery.compiler.BigQueryTableSetFormatter._quote_identifier = ( - BigQueryTableSetFormatter._quote_identifier -) -ibis.backends.bigquery.compiler.BigQueryTableSetFormatter._format_in_memory_table = ( - BigQueryTableSetFormatter._format_in_memory_table -) +bq_compiler.BigQueryCompiler.visit_InMemoryTable = BigQueryCompiler.visit_InMemoryTable diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py b/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py index ecef2115e5..8212f96e6f 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py @@ -2,7 +2,8 @@ """Module to convert from Ibis expression to SQL string.""" import bigframes_vendored.ibis.expr.operations as vendored_ibis_ops -from ibis.backends.bigquery.registry import OPERATION_REGISTRY + +# from ibis.backends.bigquery.registry import OPERATION_REGISTRY import ibis.expr.operations.reductions as ibis_reductions @@ -69,4 +70,4 @@ def _array_aggregate(translator, op: vendored_ibis_ops.ArrayAggregate): vendored_ibis_ops.ArrayAggregate: _array_aggregate, # type:ignore } -OPERATION_REGISTRY.update(patched_ops) +# OPERATION_REGISTRY.update(patched_ops) From f1ce09de4bf25313c7e1e76025b7c6a3383c57a7 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Mon, 8 Jul 2024 21:00:49 +0000 Subject: [PATCH 02/59] update sqlglot and ibis --- setup.py | 2 +- testing/constraints-3.9.txt | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index b95196b128..114b01493d 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,7 @@ # Keep sqlglot versions in sync with ibis-framework. 
This avoids problems # where the incorrect version of sqlglot is installed, such as # https://github.com/googleapis/python-bigquery-dataframes/issues/315 - "sqlglot >=23.4,<25.2", + "sqlglot >=23.6.3,<25.2", "tabulate >= 0.9", "ipywidgets >=7.7.1", "humanize >= 4.6.0", diff --git a/testing/constraints-3.9.txt b/testing/constraints-3.9.txt index 5a76698576..6590191835 100644 --- a/testing/constraints-3.9.txt +++ b/testing/constraints-3.9.txt @@ -12,7 +12,7 @@ google-cloud-bigquery-connection==1.12.0 google-cloud-iam==2.12.1 google-cloud-resource-manager==1.10.3 google-cloud-storage==2.0.0 -ibis-framework==8.0.0 +ibis-framework==9.0.0 jellyfish==0.8.9 pandas==1.5.0 pyarrow==8.0.0 @@ -20,7 +20,7 @@ pydata-google-auth==1.8.2 requests==2.27.1 scikit-learn==1.2.2 sqlalchemy==1.4 -sqlglot==20.8.0 +sqlglot==23.6.3 tabulate==0.9 ipywidgets==7.7.1 humanize==4.6.0 From d224a52fdbc743942ef7e2bb35ba089f3c7cd132 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Mon, 8 Jul 2024 21:05:00 +0000 Subject: [PATCH 03/59] bump minimum pandas --- setup.py | 2 +- testing/constraints-3.9.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 114b01493d..250e0e4b3b 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,7 @@ "google-cloud-storage >=2.0.0", "ibis-framework[bigquery] >=9.0.0,<10.0.0dev", "jellyfish >=0.8.9", - "pandas >=1.5.0", + "pandas >=1.5.3", "pyarrow >=8.0.0", "pydata-google-auth >=1.8.2", "requests >=2.27.1", diff --git a/testing/constraints-3.9.txt b/testing/constraints-3.9.txt index 6590191835..2945cd1b54 100644 --- a/testing/constraints-3.9.txt +++ b/testing/constraints-3.9.txt @@ -14,7 +14,7 @@ google-cloud-resource-manager==1.10.3 google-cloud-storage==2.0.0 ibis-framework==9.0.0 jellyfish==0.8.9 -pandas==1.5.0 +pandas==1.5.3 pyarrow==8.0.0 pydata-google-auth==1.8.2 requests==2.27.1 From 28b6a310ad161cdd9b363cd8913bbb0a24d2af0f Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Mon, 8 Jul 2024 21:07:04 +0000 Subject: [PATCH 04/59] bump pyarrow --- setup.py | 2 +- testing/constraints-3.9.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 250e0e4b3b..c970c40b88 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,7 @@ "ibis-framework[bigquery] >=9.0.0,<10.0.0dev", "jellyfish >=0.8.9", "pandas >=1.5.3", - "pyarrow >=8.0.0", + "pyarrow >=10.0.1", "pydata-google-auth >=1.8.2", "requests >=2.27.1", "scikit-learn >=1.2.2", diff --git a/testing/constraints-3.9.txt b/testing/constraints-3.9.txt index 2945cd1b54..c5976690f2 100644 --- a/testing/constraints-3.9.txt +++ b/testing/constraints-3.9.txt @@ -15,7 +15,7 @@ google-cloud-storage==2.0.0 ibis-framework==9.0.0 jellyfish==0.8.9 pandas==1.5.3 -pyarrow==8.0.0 +pyarrow==10.0.1 pydata-google-auth==1.8.2 requests==2.27.1 scikit-learn==1.2.2 From bb68b2b58038679a928eacfd7c5989a59ec07686 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Tue, 9 Jul 2024 15:05:57 +0000 Subject: [PATCH 05/59] fix bfill and ffill --- tests/system/conftest.py | 4 ++-- .../ibis/backends/bigquery/compiler.py | 14 ++++++++++++-- .../ibis/backends/bigquery/registry.py | 12 ------------ 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/tests/system/conftest.py b/tests/system/conftest.py index df4ff9aff0..673f6c09d5 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -29,7 +29,7 @@ import google.cloud.functions_v2 as functions_v2 import google.cloud.resourcemanager_v3 as resourcemanager_v3 import google.cloud.storage as storage # type: ignore -import ibis.backends.base 
+import ibis.backends import numpy as np import pandas as pd import pytest @@ -99,7 +99,7 @@ def bigquery_client_tokyo(session_tokyo: bigframes.Session) -> bigquery.Client: @pytest.fixture(scope="session") -def ibis_client(session: bigframes.Session) -> ibis.backends.base.BaseBackend: +def ibis_client(session: bigframes.Session) -> ibis.backends.BaseBackend: return session.ibis_client diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py index 08c48165df..8064090a4d 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py @@ -27,9 +27,19 @@ def visit_InMemoryTable(self, op, *, name, schema, data): ) return sg.select(*columns).from_(expr) - def visit_FirstNonNullValue(self, op): - pass + def visit_FirstNonNullValue(self, op, *, arg): + return sge.IgnoreNulls(this=sge.FirstValue(this=arg)) + + def visit_LastNonNullValue(self, op, *, arg): + return sge.IgnoreNulls(this=sge.LastValue(this=arg)) # Override implementation. +# We monkeypatch individual methods because the class might have already been imported in other modules. bq_compiler.BigQueryCompiler.visit_InMemoryTable = BigQueryCompiler.visit_InMemoryTable +bq_compiler.BigQueryCompiler.visit_FirstNonNullValue = ( + BigQueryCompiler.visit_FirstNonNullValue +) +bq_compiler.BigQueryCompiler.visit_LastNonNullValue = ( + BigQueryCompiler.visit_LastNonNullValue +) diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py b/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py index 8212f96e6f..9c879172b5 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py @@ -13,16 +13,6 @@ def _approx_quantiles(translator, op: vendored_ibis_ops.ApproximateMultiQuantile return f"APPROX_QUANTILES({arg}, {num_bins})" -def _first_non_null_value(translator, op: vendored_ibis_ops.FirstNonNullValue): - arg = translator.translate(op.arg) - return f"FIRST_VALUE({arg} IGNORE NULLS)" - - -def _last_non_null_value(translator, op: vendored_ibis_ops.LastNonNullValue): - arg = translator.translate(op.arg) - return f"LAST_VALUE({arg} IGNORE NULLS)" - - def _to_json_string(translator, op: vendored_ibis_ops.ToJsonString): arg = translator.translate(op.arg) return f"TO_JSON_STRING({arg})" @@ -61,8 +51,6 @@ def _array_aggregate(translator, op: vendored_ibis_ops.ArrayAggregate): patched_ops = { vendored_ibis_ops.ApproximateMultiQuantile: _approx_quantiles, # type:ignore - vendored_ibis_ops.FirstNonNullValue: _first_non_null_value, # type:ignore - vendored_ibis_ops.LastNonNullValue: _last_non_null_value, # type:ignore vendored_ibis_ops.ToJsonString: _to_json_string, # type:ignore vendored_ibis_ops.GenerateArray: _generate_array, # type:ignore vendored_ibis_ops.SafeCastToDatetime: _safe_cast_to_datetime, # type:ignore From d5622b607472893dc7be82f04256f961504b436c Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Tue, 9 Jul 2024 15:49:49 +0000 Subject: [PATCH 06/59] nearly implement describe --- bigframes/core/compile/aggregate_compiler.py | 20 ++++++--- tests/system/small/test_ibis.py | 44 ------------------- .../ibis/expr/operations/reductions.py | 13 +----- 3 files changed, 15 insertions(+), 62 deletions(-) delete mode 100644 tests/system/small/test_ibis.py diff --git a/bigframes/core/compile/aggregate_compiler.py b/bigframes/core/compile/aggregate_compiler.py index 
58973b10eb..97628d9261 100644 --- a/bigframes/core/compile/aggregate_compiler.py +++ b/bigframes/core/compile/aggregate_compiler.py @@ -13,7 +13,7 @@ # limitations under the License. import functools import typing -from typing import cast, Optional +from typing import cast, List, Optional import bigframes_vendored.ibis.expr.operations as vendored_ibis_ops import ibis @@ -31,6 +31,17 @@ scalar_compiler = scalar_compilers.scalar_op_compiler +# TODO(swast): We can remove this if ibis adds general approx_quantile +# See: https://github.com/ibis-project/ibis/issues/9541 +@ibis.udf.agg.builtin +def approx_quantiles(expression: float, number) -> List[float]: + """APPROX_QUANTILES + + https://cloud.google.com/bigquery/docs/reference/standard-sql/approximate_aggregate_functions#approx_quantiles + """ + return [] # pragma: NO COVER + + def compile_aggregate( aggregate: ex.Aggregation, bindings: typing.Dict[str, ibis_types.Value], @@ -176,15 +187,12 @@ def _( column: ibis_types.NumericColumn, window=None, ) -> ibis_types.NumericValue: - # PERCENTILE_CONT has very few allowed windows. For example, "window - # framing clause is not allowed for analytic function percentile_cont". + # APPROX_QUANTILES has very few allowed windows. if window is not None: raise NotImplementedError( f"Approx Quartiles with windowing is not supported. {constants.FEEDBACK_LINK}" ) - value = vendored_ibis_ops.ApproximateMultiQuantile( - column, num_bins=4 # type: ignore - ).to_expr()[op.quartile] + value = approx_quantiles(column, 4)[op.quartile] # type: ignore return cast(ibis_types.NumericValue, value) diff --git a/tests/system/small/test_ibis.py b/tests/system/small/test_ibis.py deleted file mode 100644 index e2648d1eba..0000000000 --- a/tests/system/small/test_ibis.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for monkeypatched ibis code.""" - -import bigframes_vendored.ibis.expr.operations as vendored_ibis_ops -import ibis.expr.types as ibis_types - -import bigframes - - -def test_approximate_quantiles(session: bigframes.Session, scalars_table_id: str): - num_bins = 3 - ibis_client = session.ibis_client - project, dataset, table_id = scalars_table_id.split(".") - ibis_table: ibis_types.Table = ibis_client.table( # type: ignore - table_id, - schema=dataset, - database=project, - ) - ibis_column: ibis_types.NumericColumn = ibis_table["int64_col"] - quantiles: ibis_types.ArrayScalar = vendored_ibis_ops.ApproximateMultiQuantile( - ibis_column, # type: ignore - num_bins=num_bins, # type: ignore - ).to_expr() - value = quantiles[1] - num_edges = quantiles.length() - - sql = ibis_client.compile(value) - num_edges_result = num_edges.to_pandas() - - assert "APPROX_QUANTILES" in sql - assert num_edges_result == num_bins + 1 diff --git a/third_party/bigframes_vendored/ibis/expr/operations/reductions.py b/third_party/bigframes_vendored/ibis/expr/operations/reductions.py index bd971e408a..a428c73449 100644 --- a/third_party/bigframes_vendored/ibis/expr/operations/reductions.py +++ b/third_party/bigframes_vendored/ibis/expr/operations/reductions.py @@ -9,17 +9,6 @@ from ibis.expr.operations.reductions import Filterable, Reduction -class ApproximateMultiQuantile(Filterable, Reduction): - """Calculate (approximately) evenly-spaced quantiles. - - See: https://cloud.google.com/bigquery/docs/reference/standard-sql/approximate_aggregate_functions#approx_quantiles - """ - - arg: ibis_ops_core.Value - num_bins: ibis_ops_core.Value[dt.Int64] - dtype = dt.Array(dt.float64) - - class ArrayAggregate(Filterable, Reduction): """ Collects the elements of this expression into an ordered array. 
Similar to @@ -34,4 +23,4 @@ def dtype(self): return dt.Array(self.arg.dtype) -__all__ = ["ApproximateMultiQuantile", "ArrayAggregate"] +__all__ = ["ArrayAggregate"] From 3596edd13af36a6c8e6068a24bcd1a803a2931eb Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Tue, 9 Jul 2024 15:58:45 +0000 Subject: [PATCH 07/59] remove remaining reference to vendored_ibis_ops.ApproximateMultiQuantile --- .../bigframes_vendored/ibis/backends/bigquery/registry.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py b/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py index 9c879172b5..25cf6983a9 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py @@ -7,12 +7,6 @@ import ibis.expr.operations.reductions as ibis_reductions -def _approx_quantiles(translator, op: vendored_ibis_ops.ApproximateMultiQuantile): - arg = translator.translate(op.arg) - num_bins = translator.translate(op.num_bins) - return f"APPROX_QUANTILES({arg}, {num_bins})" - - def _to_json_string(translator, op: vendored_ibis_ops.ToJsonString): arg = translator.translate(op.arg) return f"TO_JSON_STRING({arg})" @@ -50,7 +44,6 @@ def _array_aggregate(translator, op: vendored_ibis_ops.ArrayAggregate): patched_ops = { - vendored_ibis_ops.ApproximateMultiQuantile: _approx_quantiles, # type:ignore vendored_ibis_ops.ToJsonString: _to_json_string, # type:ignore vendored_ibis_ops.GenerateArray: _generate_array, # type:ignore vendored_ibis_ops.SafeCastToDatetime: _safe_cast_to_datetime, # type:ignore From 32c2ab6006f42c6831c78f24974f96788d8c49fd Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Tue, 9 Jul 2024 16:22:55 +0000 Subject: [PATCH 08/59] support ToJsonString --- .../bigframes_vendored/ibis/backends/bigquery/compiler.py | 4 ++++ .../bigframes_vendored/ibis/backends/bigquery/registry.py | 6 ------ third_party/bigframes_vendored/ibis/expr/operations/json.py | 2 ++ 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py index 8064090a4d..5af0e84769 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py @@ -33,6 +33,9 @@ def visit_FirstNonNullValue(self, op, *, arg): def visit_LastNonNullValue(self, op, *, arg): return sge.IgnoreNulls(this=sge.LastValue(this=arg)) + def visit_ToJsonString(self, op, *, arg): + return self.f.to_json_string(arg) + # Override implementation. # We monkeypatch individual methods because the class might have already been imported in other modules. 
@@ -43,3 +46,4 @@ def visit_LastNonNullValue(self, op, *, arg): bq_compiler.BigQueryCompiler.visit_LastNonNullValue = ( BigQueryCompiler.visit_LastNonNullValue ) +bq_compiler.BigQueryCompiler.visit_ToJsonString = BigQueryCompiler.visit_ToJsonString diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py b/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py index 25cf6983a9..a981d0a9ef 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py @@ -7,11 +7,6 @@ import ibis.expr.operations.reductions as ibis_reductions -def _to_json_string(translator, op: vendored_ibis_ops.ToJsonString): - arg = translator.translate(op.arg) - return f"TO_JSON_STRING({arg})" - - def _generate_array(translator, op: vendored_ibis_ops.GenerateArray): arg = translator.translate(op.arg) return f"GENERATE_ARRAY(0, {arg})" @@ -44,7 +39,6 @@ def _array_aggregate(translator, op: vendored_ibis_ops.ArrayAggregate): patched_ops = { - vendored_ibis_ops.ToJsonString: _to_json_string, # type:ignore vendored_ibis_ops.GenerateArray: _generate_array, # type:ignore vendored_ibis_ops.SafeCastToDatetime: _safe_cast_to_datetime, # type:ignore ibis_reductions.Quantile: _quantile, # type:ignore diff --git a/third_party/bigframes_vendored/ibis/expr/operations/json.py b/third_party/bigframes_vendored/ibis/expr/operations/json.py index 1eb0554137..ea1f766a71 100644 --- a/third_party/bigframes_vendored/ibis/expr/operations/json.py +++ b/third_party/bigframes_vendored/ibis/expr/operations/json.py @@ -5,5 +5,7 @@ import ibis.expr.operations.core as ibis_ops_core +# TODO(swast): Remove once supported upstream. +# See: https://github.com/ibis-project/ibis/issues/9542 class ToJsonString(ibis_ops_core.Unary): dtype = dt.string From 5e0c1e7ff7d0600669880e64dd92f947983ddb01 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Tue, 9 Jul 2024 19:30:52 +0000 Subject: [PATCH 09/59] partial support for quantile --- bigframes/core/compile/compiled.py | 34 +++++++++++-------- .../ibis/backends/bigquery/compiler.py | 14 ++++++++ .../ibis/backends/bigquery/registry.py | 18 ---------- .../ibis/expr/operations/arrays.py | 9 ----- 4 files changed, 33 insertions(+), 42 deletions(-) diff --git a/bigframes/core/compile/compiled.py b/bigframes/core/compile/compiled.py index cc601744c1..8a74b672d3 100644 --- a/bigframes/core/compile/compiled.py +++ b/bigframes/core/compile/compiled.py @@ -19,7 +19,6 @@ import typing from typing import Collection, Literal, Optional, Sequence -import bigframes_vendored.ibis.expr.operations as vendored_ibis_ops import ibis import ibis.backends.bigquery as ibis_bigquery import ibis.common.deferred # type: ignore @@ -55,6 +54,13 @@ op_compiler = op_compilers.scalar_op_compiler +# TODO(swast): remove once ibis.range is more efficient. +# See: https://github.com/ibis-project/ibis/issues/8892 +@ibis.udf.scalar.builtin +def generate_array(start_expression, end_expression) -> list[int]: + return [] # pragma: NO COVER + + class BaseIbisIR(abc.ABC): """Implementation detail, contains common logic between ordered and unordered IR""" @@ -404,16 +410,15 @@ def explode(self, column_ids: typing.Sequence[str]) -> UnorderedIR: # The offset array ensures null represents empty arrays after unnesting. 
offset_array_id = bigframes.core.guid.generate_guid("offset_array_") offset_array = ( - vendored_ibis_ops.GenerateArray( + ibis.range( + 0, ibis.greatest( 0, ibis.least( - *[table[column_id].length() - 1 for column_id in column_ids] + *[table[column_id].length() for column_id in column_ids] ), - ) - ) - .to_expr() - .name(offset_array_id), + ), + ).name(offset_array_id), ) table_w_offset_array = table.select( offset_array, @@ -709,16 +714,15 @@ def explode(self, column_ids: typing.Sequence[str]) -> OrderedIR: offset_array_id = bigframes.core.guid.generate_guid("offset_array_") offset_array = ( - vendored_ibis_ops.GenerateArray( + ibis.range( + 0, ibis.greatest( 0, ibis.least( - *[table[column_id].length() - 1 for column_id in column_ids] + *[table[column_id].length() for column_id in column_ids] ), - ) - ) - .to_expr() - .name(offset_array_id), + ), + ).name(offset_array_id), ) table_w_offset_array = table.select( offset_array, @@ -826,7 +830,7 @@ def project_window_op( clauses = [] if op.skips_nulls and not never_skip_nulls: - clauses.append((column.isnull(), ibis.NA)) + clauses.append((column.isnull(), ibis.null())) if window_spec.min_periods: if op.skips_nulls: # Most operations do not count NULL values towards min_periods @@ -847,7 +851,7 @@ def project_window_op( clauses.append( ( observation_count < ibis_types.literal(window_spec.min_periods), - ibis.NA, + ibis.null(), ) ) if clauses: diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py index 5af0e84769..263303287f 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py @@ -4,11 +4,19 @@ from __future__ import annotations import ibis.backends.bigquery.compiler as bq_compiler +import ibis.backends.sql.compiler as sql_compiler +import ibis.expr.operations.reductions as ibis_reductions import sqlglot as sg import sqlglot.expressions as sge class BigQueryCompiler(bq_compiler.BigQueryCompiler): + UNSUPPORTED_OPS = bq_compiler.BigQueryCompiler.UNSUPPORTED_OPS = tuple( + op + for op in bq_compiler.BigQueryCompiler.UNSUPPORTED_OPS + if op != ibis_reductions.Quantile + ) + def visit_InMemoryTable(self, op, *, name, schema, data): # Avoid creating temp tables for small data, which is how memtable is # used in BigQuery DataFrames. Implementation from: @@ -47,3 +55,9 @@ def visit_ToJsonString(self, op, *, arg): BigQueryCompiler.visit_LastNonNullValue ) bq_compiler.BigQueryCompiler.visit_ToJsonString = BigQueryCompiler.visit_ToJsonString + +# TODO(swast): sqlglot base implementation appears to work fine for the bigquery backend, at least in our windowed contexts. 
See: ISSUE NUMBER +bq_compiler.BigQueryCompiler.UNSUPPORTED_OPS = BigQueryCompiler.UNSUPPORTED_OPS +bq_compiler.BigQueryCompiler.visit_Quantile = ( + sql_compiler.SQLGlotCompiler.visit_Quantile +) diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py b/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py index a981d0a9ef..b219bd1a3a 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py @@ -3,26 +3,12 @@ import bigframes_vendored.ibis.expr.operations as vendored_ibis_ops -# from ibis.backends.bigquery.registry import OPERATION_REGISTRY -import ibis.expr.operations.reductions as ibis_reductions - - -def _generate_array(translator, op: vendored_ibis_ops.GenerateArray): - arg = translator.translate(op.arg) - return f"GENERATE_ARRAY(0, {arg})" - def _safe_cast_to_datetime(translator, op: vendored_ibis_ops.SafeCastToDatetime): arg = translator.translate(op.arg) return f"SAFE_CAST({arg} AS DATETIME)" -def _quantile(translator, op: ibis_reductions.Quantile): - arg = translator.translate(op.arg) - quantile = translator.translate(op.quantile) - return f"PERCENTILE_CONT({arg}, {quantile})" - - def _array_aggregate(translator, op: vendored_ibis_ops.ArrayAggregate): """This method provides the same functionality as the collect() method in Ibis, with the added capability of ordering the results using order_by. @@ -39,10 +25,6 @@ def _array_aggregate(translator, op: vendored_ibis_ops.ArrayAggregate): patched_ops = { - vendored_ibis_ops.GenerateArray: _generate_array, # type:ignore vendored_ibis_ops.SafeCastToDatetime: _safe_cast_to_datetime, # type:ignore - ibis_reductions.Quantile: _quantile, # type:ignore vendored_ibis_ops.ArrayAggregate: _array_aggregate, # type:ignore } - -# OPERATION_REGISTRY.update(patched_ops) diff --git a/third_party/bigframes_vendored/ibis/expr/operations/arrays.py b/third_party/bigframes_vendored/ibis/expr/operations/arrays.py index a0ad915a9b..627cca2765 100644 --- a/third_party/bigframes_vendored/ibis/expr/operations/arrays.py +++ b/third_party/bigframes_vendored/ibis/expr/operations/arrays.py @@ -5,14 +5,5 @@ from ibis.expr.operations.core import Unary -class GenerateArray(Unary): - """ - Generates an array of values, similar to ibis.range(), but with simpler and - more efficient SQL generation. 
- """ - - dtype = dt.Array(dt.int64) - - class SafeCastToDatetime(Unary): dtype = dt.Timestamp(timezone=None) From d877261dac9f4910e39f9d2c75b564a6cb3d1646 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Fri, 12 Jul 2024 21:00:38 +0000 Subject: [PATCH 10/59] fix inmemorytable --- .../ibis/backends/bigquery/compiler.py | 31 ++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py index 263303287f..536930b355 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py @@ -5,6 +5,7 @@ import ibis.backends.bigquery.compiler as bq_compiler import ibis.backends.sql.compiler as sql_compiler +import ibis.backends.sql.datatypes as sql_datatypes import ibis.expr.operations.reductions as ibis_reductions import sqlglot as sg import sqlglot.expressions as sge @@ -19,21 +20,43 @@ class BigQueryCompiler(bq_compiler.BigQueryCompiler): def visit_InMemoryTable(self, op, *, name, schema, data): # Avoid creating temp tables for small data, which is how memtable is - # used in BigQuery DataFrames. Implementation from: + # used in BigQuery DataFrames. Inspired by: # https://github.com/ibis-project/ibis/blob/efa6fb72bf4c790450d00a926d7bd809dade5902/ibis/backends/druid/compiler.py#L95 tuples = data.to_frame().itertuples(index=False) quoted = self.quoted columns = [sg.column(col, quoted=quoted) for col in schema.names] - expr = sge.Values( + expr = sge.Unnest( expressions=[ - sge.Tuple(expressions=tuple(map(sge.convert, row))) for row in tuples + sge.DataType( + this=sge.DataType.Type.ARRAY, + expressions=[ + # TODO: Data types and names from schema. + sge.DataType( + this=sge.DataType.Type.STRUCT, + expressions=[ + sge.ColumnDef( + this=sge.to_identifier(field, quoted=self.quoted), + kind=sql_datatypes.SqlglotType.from_ibis(type_), + ) + for field, type_ in zip(schema.names, schema.types) + ], + nested=True, + ) + ], + nested=True, + values=[ + sge.Tuple(expressions=tuple(map(sge.convert, row))) + for row in tuples + ], + ), ], alias=sge.TableAlias( this=sg.to_identifier(name, quoted=quoted), columns=columns, ), ) - return sg.select(*columns).from_(expr) + # return expr + return sg.select(sge.Star()).from_(expr) def visit_FirstNonNullValue(self, op, *, arg): return sge.IgnoreNulls(this=sge.FirstValue(this=arg)) From 847f459ffad0494b5ee7f8a8e719db2d25607360 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Mon, 15 Jul 2024 16:36:55 +0000 Subject: [PATCH 11/59] fixed Series.explode --- bigframes/core/compile/compiled.py | 43 ++++++++++-------------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/bigframes/core/compile/compiled.py b/bigframes/core/compile/compiled.py index d13a5a9888..16a6af1b5c 100644 --- a/bigframes/core/compile/compiled.py +++ b/bigframes/core/compile/compiled.py @@ -56,13 +56,6 @@ op_compiler = op_compilers.scalar_op_compiler -# TODO(swast): remove once ibis.range is more efficient. 
-# See: https://github.com/ibis-project/ibis/issues/8892 -@ibis.udf.scalar.builtin -def generate_array(start_expression, end_expression) -> list[int]: - return [] # pragma: NO COVER - - class BaseIbisIR(abc.ABC): """Implementation detail, contains common logic between ordered and unordered IR""" @@ -411,17 +404,13 @@ def explode(self, column_ids: typing.Sequence[str]) -> UnorderedIR: # The offset array ensures null represents empty arrays after unnesting. offset_array_id = bigframes.core.guid.generate_guid("offset_array_") - offset_array = ( - ibis.range( - 0, - ibis.greatest( - 0, - ibis.least( - *[table[column_id].length() for column_id in column_ids] - ), - ), - ).name(offset_array_id), - ) + offset_array = ibis.range( + 0, + ibis.greatest( + 1, # We always want at least 1 element to fill in NULLs for empty arrays. + ibis.least(*[table[column_id].length() for column_id in column_ids]), + ), + ).name(offset_array_id) table_w_offset_array = table.select( offset_array, *self._column_names, @@ -719,17 +708,13 @@ def explode(self, column_ids: typing.Sequence[str]) -> OrderedIR: table = self._to_ibis_expr(ordering_mode="unordered", expose_hidden_cols=True) offset_array_id = bigframes.core.guid.generate_guid("offset_array_") - offset_array = ( - ibis.range( - 0, - ibis.greatest( - 0, - ibis.least( - *[table[column_id].length() for column_id in column_ids] - ), - ), - ).name(offset_array_id), - ) + offset_array = ibis.range( + 0, + ibis.greatest( + 1, # We always want at least 1 element to fill in NULLs for empty arrays. + ibis.least(*[table[column_id].length() for column_id in column_ids]), + ), + ).name(offset_array_id) table_w_offset_array = table.select( offset_array, *self._column_names, From 6a0bedcef8cf42a9cca75322cc79dc30798801df Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Mon, 15 Jul 2024 21:26:59 +0000 Subject: [PATCH 12/59] nearly fix to_datetime --- bigframes/core/compile/scalar_op_compiler.py | 2 +- tests/system/small/test_pandas.py | 21 +++++++++++++++++++ .../ibis/backends/bigquery/compiler.py | 8 +++++-- .../ibis/backends/bigquery/registry.py | 6 ------ .../ibis/expr/operations/__init__.py | 1 - .../ibis/expr/operations/arrays.py | 9 -------- 6 files changed, 28 insertions(+), 19 deletions(-) delete mode 100644 third_party/bigframes_vendored/ibis/expr/operations/arrays.py diff --git a/bigframes/core/compile/scalar_op_compiler.py b/bigframes/core/compile/scalar_op_compiler.py index 0bc9f2e370..6be67d58e7 100644 --- a/bigframes/core/compile/scalar_op_compiler.py +++ b/bigframes/core/compile/scalar_op_compiler.py @@ -832,7 +832,7 @@ def isin_op_impl(x: ibis_types.Value, op: ops.IsInOp): @scalar_op_compiler.register_unary_op(ops.ToDatetimeOp, pass_op=True) def to_datetime_op_impl(x: ibis_types.Value, op: ops.ToDatetimeOp): if x.type() == ibis_dtypes.str: - return vendored_ibis_ops.SafeCastToDatetime(x).to_expr() + return x.try_cast(ibis_dtypes.Timestamp(None)) else: # Numerical inputs. 
if op.format: diff --git a/tests/system/small/test_pandas.py b/tests/system/small/test_pandas.py index 30ffaa8a7d..8bc5fc05a8 100644 --- a/tests/system/small/test_pandas.py +++ b/tests/system/small/test_pandas.py @@ -546,6 +546,27 @@ def test_to_datetime_scalar(arg, utc, unit, format): assert bf_result == pd_result +@pytest.mark.parametrize( + ("arg", "utc", "format"), + [ + ("not-a-datetime", False, None), + ("not-a-timestamp", True, None), + ("not-matching-format", False, "%Y-%m-%d"), + ], +) +def test_to_datetime_scalar_invalid(arg, utc, format): + bf_result = bpd.to_datetime(arg, utc=utc, format=format) + pd_result = pd.to_datetime( + arg, + utc=utc, + format=format, + # Convert invalid values to a NULL marker, similar to BigQuery SAFE_CAST. + errors="coerce", + ) + + assert bf_result == pd_result + + @pytest.mark.parametrize( ("arg", "utc", "unit", "format"), [ diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py index 536930b355..80c11a68b1 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py @@ -30,7 +30,6 @@ def visit_InMemoryTable(self, op, *, name, schema, data): sge.DataType( this=sge.DataType.Type.ARRAY, expressions=[ - # TODO: Data types and names from schema. sge.DataType( this=sge.DataType.Type.STRUCT, expressions=[ @@ -45,7 +44,12 @@ def visit_InMemoryTable(self, op, *, name, schema, data): ], nested=True, values=[ - sge.Tuple(expressions=tuple(map(sge.convert, row))) + sge.Tuple( + expressions=tuple( + self.visit_Literal(None, value=value, dtype=type_) + for value, type_ in zip(row, schema.types) + ) + ) for row in tuples ], ), diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py b/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py index b219bd1a3a..2dcfa11aa8 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py @@ -4,11 +4,6 @@ import bigframes_vendored.ibis.expr.operations as vendored_ibis_ops -def _safe_cast_to_datetime(translator, op: vendored_ibis_ops.SafeCastToDatetime): - arg = translator.translate(op.arg) - return f"SAFE_CAST({arg} AS DATETIME)" - - def _array_aggregate(translator, op: vendored_ibis_ops.ArrayAggregate): """This method provides the same functionality as the collect() method in Ibis, with the added capability of ordering the results using order_by. 
@@ -25,6 +20,5 @@ def _array_aggregate(translator, op: vendored_ibis_ops.ArrayAggregate): patched_ops = { - vendored_ibis_ops.SafeCastToDatetime: _safe_cast_to_datetime, # type:ignore vendored_ibis_ops.ArrayAggregate: _array_aggregate, # type:ignore } diff --git a/third_party/bigframes_vendored/ibis/expr/operations/__init__.py b/third_party/bigframes_vendored/ibis/expr/operations/__init__.py index 3ae5fc10e4..2c2efe528d 100644 --- a/third_party/bigframes_vendored/ibis/expr/operations/__init__.py +++ b/third_party/bigframes_vendored/ibis/expr/operations/__init__.py @@ -2,6 +2,5 @@ from __future__ import annotations from bigframes_vendored.ibis.expr.operations.analytic import * # noqa: F401 F403 -from bigframes_vendored.ibis.expr.operations.arrays import * # noqa: F401 F403 from bigframes_vendored.ibis.expr.operations.json import * # noqa: F401 F403 from bigframes_vendored.ibis.expr.operations.reductions import * # noqa: F401 F403 diff --git a/third_party/bigframes_vendored/ibis/expr/operations/arrays.py b/third_party/bigframes_vendored/ibis/expr/operations/arrays.py deleted file mode 100644 index 627cca2765..0000000000 --- a/third_party/bigframes_vendored/ibis/expr/operations/arrays.py +++ /dev/null @@ -1,9 +0,0 @@ -# Contains code from https://github.com/ibis-project/ibis/blob/master/ibis/expr/operations/arrays.py -from __future__ import annotations - -import ibis.expr.datatypes as dt -from ibis.expr.operations.core import Unary - - -class SafeCastToDatetime(Unary): - dtype = dt.Timestamp(timezone=None) From 9d56ee78a98496307150e927761665182816badd Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Mon, 15 Jul 2024 21:37:04 +0000 Subject: [PATCH 13/59] remove tests I added --- tests/system/small/test_pandas.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/tests/system/small/test_pandas.py b/tests/system/small/test_pandas.py index 8bc5fc05a8..30ffaa8a7d 100644 --- a/tests/system/small/test_pandas.py +++ b/tests/system/small/test_pandas.py @@ -546,27 +546,6 @@ def test_to_datetime_scalar(arg, utc, unit, format): assert bf_result == pd_result -@pytest.mark.parametrize( - ("arg", "utc", "format"), - [ - ("not-a-datetime", False, None), - ("not-a-timestamp", True, None), - ("not-matching-format", False, "%Y-%m-%d"), - ], -) -def test_to_datetime_scalar_invalid(arg, utc, format): - bf_result = bpd.to_datetime(arg, utc=utc, format=format) - pd_result = pd.to_datetime( - arg, - utc=utc, - format=format, - # Convert invalid values to a NULL marker, similar to BigQuery SAFE_CAST. 
- errors="coerce", - ) - - assert bf_result == pd_result - - @pytest.mark.parametrize( ("arg", "utc", "unit", "format"), [ From fc84cb857326838e14779675690cb5886e0e50a9 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Tue, 16 Jul 2024 16:20:24 +0000 Subject: [PATCH 14/59] patch for python 3.9 support --- .../ibis/backends/bigquery/compiler.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py index 80c11a68b1..a7c891ef1c 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py @@ -12,10 +12,14 @@ class BigQueryCompiler(bq_compiler.BigQueryCompiler): - UNSUPPORTED_OPS = bq_compiler.BigQueryCompiler.UNSUPPORTED_OPS = tuple( - op - for op in bq_compiler.BigQueryCompiler.UNSUPPORTED_OPS - if op != ibis_reductions.Quantile + UNSUPPORTED_OPS = ( + tuple( + op + for op in bq_compiler.BigQueryCompiler.UNSUPPORTED_OPS + if op != ibis_reductions.Quantile + ) + if hasattr(bq_compiler.BigQueryCompiler, "UNSUPPORTED_OPS") + else () ) def visit_InMemoryTable(self, op, *, name, schema, data): From de32335605e2b8f93b4aa47b38d48a7ccd92fc6e Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Tue, 16 Jul 2024 16:32:05 +0000 Subject: [PATCH 15/59] fix unit tests --- bigframes/functions/remote_function.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bigframes/functions/remote_function.py b/bigframes/functions/remote_function.py index 77c427f940..cb3212d69a 100644 --- a/bigframes/functions/remote_function.py +++ b/bigframes/functions/remote_function.py @@ -1179,7 +1179,8 @@ def try_delattr(attr): node = ibis.udf.scalar.builtin( func, name=rf_name, - schema=f"{dataset_ref.project}.{dataset_ref.dataset_id}", + database=dataset_ref.project, + catalog=dataset_ref.dataset_id, signature=(ibis_signature.input_types, ibis_signature.output_type), ) func.bigframes_cloud_function = ( From e7dd60fd6c9e48d56d58ee86149af770003ab824 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Wed, 17 Jul 2024 16:43:45 +0000 Subject: [PATCH 16/59] fix explode with time type --- .../ibis/backends/bigquery/compiler.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py index a7c891ef1c..916761d908 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py @@ -6,6 +6,8 @@ import ibis.backends.bigquery.compiler as bq_compiler import ibis.backends.sql.compiler as sql_compiler import ibis.backends.sql.datatypes as sql_datatypes +import ibis.common.exceptions as com +from ibis.common.temporal import IntervalUnit import ibis.expr.operations.reductions as ibis_reductions import sqlglot as sg import sqlglot.expressions as sge @@ -66,6 +68,40 @@ def visit_InMemoryTable(self, op, *, name, schema, data): # return expr return sg.select(sge.Star()).from_(expr) + def visit_NonNullLiteral(self, op, *, value, dtype): + # Patch from https://github.com/ibis-project/ibis/pull/9610 to support ibis 9.0.0 and 9.1.0 + if dtype.is_inet() or dtype.is_macaddr(): + return sge.convert(str(value)) + elif dtype.is_timestamp(): + funcname = "DATETIME" if dtype.timezone is None else "TIMESTAMP" + return self.f.anon[funcname](value.isoformat()) + elif 
dtype.is_date(): + return self.f.date_from_parts(value.year, value.month, value.day) + elif dtype.is_time(): + time = self.f.time_from_parts(value.hour, value.minute, value.second) + if micros := value.microsecond: + # bigquery doesn't support `time(12, 34, 56.789101)`, AKA a + # float seconds specifier, so add any non-zero micros to the + # time value + return sge.TimeAdd( + this=time, expression=sge.convert(micros), unit=self.v.MICROSECOND + ) + return time + elif dtype.is_binary(): + return sge.Cast( + this=sge.convert(value.hex()), + to=sge.DataType(this=sge.DataType.Type.BINARY), + format=sge.convert("HEX"), + ) + elif dtype.is_interval(): + if dtype.unit == IntervalUnit.NANOSECOND: + raise com.UnsupportedOperationError( + "BigQuery does not support nanosecond intervals" + ) + elif dtype.is_uuid(): + return sge.convert(str(value)) + return None + def visit_FirstNonNullValue(self, op, *, arg): return sge.IgnoreNulls(this=sge.FirstValue(this=arg)) @@ -79,6 +115,9 @@ def visit_ToJsonString(self, op, *, arg): # Override implementation. # We monkeypatch individual methods because the class might have already been imported in other modules. bq_compiler.BigQueryCompiler.visit_InMemoryTable = BigQueryCompiler.visit_InMemoryTable +bq_compiler.BigQueryCompiler.visit_NonNullLiteral = ( + BigQueryCompiler.visit_NonNullLiteral +) bq_compiler.BigQueryCompiler.visit_FirstNonNullValue = ( BigQueryCompiler.visit_FirstNonNullValue ) From 016a203ed6067d433412ad16ebfb22f0a90f9d4c Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Wed, 17 Jul 2024 22:08:54 +0000 Subject: [PATCH 17/59] fix array_agg --- .../ibis/backends/bigquery/__init__.py | 1 - .../ibis/backends/bigquery/compiler.py | 29 +++++++++++++++++++ .../ibis/backends/bigquery/registry.py | 24 --------------- 3 files changed, 29 insertions(+), 25 deletions(-) delete mode 100644 third_party/bigframes_vendored/ibis/backends/bigquery/registry.py diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/bigquery/__init__.py index 1d2d05a741..ee24a71446 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/__init__.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/__init__.py @@ -1,3 +1,2 @@ # Import all sub-modules to monkeypatch everything. import bigframes_vendored.ibis.backends.bigquery.compiler # noqa -import bigframes_vendored.ibis.backends.bigquery.registry # noqa diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py index 916761d908..a49e04a2b1 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py @@ -4,6 +4,7 @@ from __future__ import annotations import ibis.backends.bigquery.compiler as bq_compiler +from ibis.backends.sql.compiler import NULL import ibis.backends.sql.compiler as sql_compiler import ibis.backends.sql.datatypes as sql_datatypes import ibis.common.exceptions as com @@ -102,6 +103,29 @@ def visit_NonNullLiteral(self, op, *, value, dtype): return sge.convert(str(value)) return None + # Custom operators. + + def visit_ArrayAggregate(self, op, *, arg, order_by, where): + if where is not None: + arg = self.if_(where, arg, NULL) + + if len(order_by) > 0: + expr = sge.Order( + this=arg, + expressions=[ + # Avoid adding NULLS FIRST / NULLS LAST in SQL, which is + # unsupported in ARRAY_AGG by reconstructing the node. 
+ sge.Ordered( + this=order_column.this, + desc=order_column.desc, + ) + for order_column in order_by + ], + ) + else: + expr = arg + return sge.IgnoreNulls(this=sge.ArrayAgg(this=expr)) + def visit_FirstNonNullValue(self, op, *, arg): return sge.IgnoreNulls(this=sge.FirstValue(this=arg)) @@ -118,6 +142,11 @@ def visit_ToJsonString(self, op, *, arg): bq_compiler.BigQueryCompiler.visit_NonNullLiteral = ( BigQueryCompiler.visit_NonNullLiteral ) + +# Custom operators. +bq_compiler.BigQueryCompiler.visit_ArrayAggregate = ( + BigQueryCompiler.visit_ArrayAggregate +) bq_compiler.BigQueryCompiler.visit_FirstNonNullValue = ( BigQueryCompiler.visit_FirstNonNullValue ) diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py b/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py deleted file mode 100644 index 2dcfa11aa8..0000000000 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/registry.py +++ /dev/null @@ -1,24 +0,0 @@ -# Contains code from https://github.com/ibis-project/ibis/blob/master/ibis/backends/bigquery/registry.py -"""Module to convert from Ibis expression to SQL string.""" - -import bigframes_vendored.ibis.expr.operations as vendored_ibis_ops - - -def _array_aggregate(translator, op: vendored_ibis_ops.ArrayAggregate): - """This method provides the same functionality as the collect() method in Ibis, with - the added capability of ordering the results using order_by. - https://github.com/ibis-project/ibis/issues/9170 - """ - arg = translator.translate(op.arg) - - order_by_sql = "" - if len(op.order_by) > 0: - order_by = ", ".join([translator.translate(column) for column in op.order_by]) - order_by_sql = f"ORDER BY {order_by}" - - return f"ARRAY_AGG({arg} IGNORE NULLS {order_by_sql})" - - -patched_ops = { - vendored_ibis_ops.ArrayAggregate: _array_aggregate, # type:ignore -} From 00163fe66030101889a83f29b95f11a80f440776 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Wed, 17 Jul 2024 22:13:41 +0000 Subject: [PATCH 18/59] fix array_agg for asc order --- .../bigframes_vendored/ibis/backends/bigquery/compiler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py index a49e04a2b1..09ddb42346 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py @@ -117,7 +117,8 @@ def visit_ArrayAggregate(self, op, *, arg, order_by, where): # unsupported in ARRAY_AGG by reconstructing the node. 
sge.Ordered( this=order_column.this, - desc=order_column.desc, + desc=order_column.desc is True, + nulls_first=True, ) for order_column in order_by ], From 129cfaec23a2026d7a98fc450d457c37cc0ecd01 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Wed, 17 Jul 2024 23:17:10 +0000 Subject: [PATCH 19/59] actually fix array_agg --- .../ibis/backends/bigquery/compiler.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py index 09ddb42346..b7af3f1158 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py @@ -4,7 +4,6 @@ from __future__ import annotations import ibis.backends.bigquery.compiler as bq_compiler -from ibis.backends.sql.compiler import NULL import ibis.backends.sql.compiler as sql_compiler import ibis.backends.sql.datatypes as sql_datatypes import ibis.common.exceptions as com @@ -106,26 +105,20 @@ def visit_NonNullLiteral(self, op, *, value, dtype): # Custom operators. def visit_ArrayAggregate(self, op, *, arg, order_by, where): - if where is not None: - arg = self.if_(where, arg, NULL) - if len(order_by) > 0: expr = sge.Order( this=arg, expressions=[ # Avoid adding NULLS FIRST / NULLS LAST in SQL, which is - # unsupported in ARRAY_AGG by reconstructing the node. - sge.Ordered( - this=order_column.this, - desc=order_column.desc is True, - nulls_first=True, - ) + # unsupported in ARRAY_AGG by reconstructing the node as + # plain SQL text. + f"({order_column.args['this'].sql(dialect='bigquery')}) {'DESC' if order_column.args.get('desc') else 'ASC'}" for order_column in order_by ], ) else: expr = arg - return sge.IgnoreNulls(this=sge.ArrayAgg(this=expr)) + return sge.IgnoreNulls(this=self.agg.array_agg(expr, where=where)) def visit_FirstNonNullValue(self, op, *, arg): return sge.IgnoreNulls(this=sge.FirstValue(this=arg)) From fdfc503ba203ce5ff73cb7b9a00a317caa95a992 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Wed, 17 Jul 2024 23:33:47 +0000 Subject: [PATCH 20/59] fix remote function --- bigframes/functions/remote_function.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bigframes/functions/remote_function.py b/bigframes/functions/remote_function.py index cb3212d69a..57db204248 100644 --- a/bigframes/functions/remote_function.py +++ b/bigframes/functions/remote_function.py @@ -1179,8 +1179,8 @@ def try_delattr(attr): node = ibis.udf.scalar.builtin( func, name=rf_name, - database=dataset_ref.project, - catalog=dataset_ref.dataset_id, + catalog=dataset_ref.project, + database=dataset_ref.dataset_id, signature=(ibis_signature.input_types, ibis_signature.output_type), ) func.bigframes_cloud_function = ( @@ -1276,8 +1276,8 @@ def func(*ignored_args, **ignored_kwargs): node = ibis.udf.scalar.builtin( func, name=routine_ref.routine_id, - database=routine_ref.project, - catalog=routine_ref.dataset_id, + catalog=routine_ref.project, + database=routine_ref.dataset_id, signature=(ibis_signature.input_types, ibis_signature.output_type), ) func.bigframes_remote_function = str(routine_ref) # type: ignore From 6ac497a311b58b023939ed18253260c924f50f85 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Thu, 18 Jul 2024 15:21:15 +0000 Subject: [PATCH 21/59] fix in-memory nullable integer compilation --- .../bigframes_vendored/ibis/backends/bigquery/compiler.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff 
--git a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py index b7af3f1158..f0da1a8d95 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py @@ -52,7 +52,12 @@ def visit_InMemoryTable(self, op, *, name, schema, data): values=[ sge.Tuple( expressions=tuple( - self.visit_Literal(None, value=value, dtype=type_) + # In-memory nullable integers can get stored as floats. + sge.convert( + int(value) + if value is not None and type_.is_integer() + else value + ) for value, type_ in zip(row, schema.types) ) ) From bbb7615a446eb05d3cce5f2c1b0f9e1bb0c1d7ca Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Thu, 18 Jul 2024 22:27:26 +0000 Subject: [PATCH 22/59] fix test_df_construct_pandas_default on Python 3.9 --- .../ibis/backends/bigquery/compiler.py | 46 +++++++++++++++---- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py index f0da1a8d95..7087994ba9 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py @@ -3,9 +3,9 @@ from __future__ import annotations +import bigframes_vendored.ibis.backends.bigquery.datatypes as bq_datatypes import ibis.backends.bigquery.compiler as bq_compiler import ibis.backends.sql.compiler as sql_compiler -import ibis.backends.sql.datatypes as sql_datatypes import ibis.common.exceptions as com from ibis.common.temporal import IntervalUnit import ibis.expr.operations.reductions as ibis_reductions @@ -13,6 +13,41 @@ import sqlglot.expressions as sge +def _convert(value, ibis_type): + if value is None: + return sge.Null() + + # Older versions of SQLGlot can't support time literals in convert(). + if ibis_type.is_time(): + return sge.TimeAdd( + this=sge.TimeFromParts( + hour=sge.convert(value.hour), + min=sge.convert(value.minute), + sec=sge.convert(value.second), + ), + expression=sge.convert(value.microsecond), + unit=sge.Var(this="MICROSECOND"), + ) + + # Older versions of SQLGlot don't distinguish DATETIME from TIMESTAMP in convert(). + if ibis_type.is_timestamp(): + if ibis_type.timezone == "UTC": + return sge.cast( + sge.convert(value.isoformat()), sge.DataType.Type.TIMESTAMPTZ + ) + else: + return sge.cast( + sge.convert(value.strftime("%Y-%m-%d %H:%M:%S.%f")), + sge.DataType.Type.TIMESTAMP, + ) + + # In-memory nullable integers can get stored as floats. + if ibis_type.is_integer(): + value = int(value) + + return sge.convert(value) + + class BigQueryCompiler(bq_compiler.BigQueryCompiler): UNSUPPORTED_OPS = ( tuple( @@ -41,7 +76,7 @@ def visit_InMemoryTable(self, op, *, name, schema, data): expressions=[ sge.ColumnDef( this=sge.to_identifier(field, quoted=self.quoted), - kind=sql_datatypes.SqlglotType.from_ibis(type_), + kind=bq_datatypes.BigQueryType.from_ibis(type_), ) for field, type_ in zip(schema.names, schema.types) ], @@ -52,12 +87,7 @@ def visit_InMemoryTable(self, op, *, name, schema, data): values=[ sge.Tuple( expressions=tuple( - # In-memory nullable integers can get stored as floats. 
- sge.convert( - int(value) - if value is not None and type_.is_integer() - else value - ) + _convert(value, type_) for value, type_ in zip(row, schema.types) ) ) From a4c49cd7c4e8c813399f12c95a46c89db439742c Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Fri, 19 Jul 2024 12:33:04 +0000 Subject: [PATCH 23/59] fix ShiftOp windows --- bigframes/core/block_transforms.py | 7 ++--- bigframes/core/blocks.py | 18 ++++++------ bigframes/core/groupby/__init__.py | 10 +++---- bigframes/dataframe.py | 28 +++++++++---------- bigframes/series.py | 44 ++++++++++-------------------- 5 files changed, 43 insertions(+), 64 deletions(-) diff --git a/bigframes/core/block_transforms.py b/bigframes/core/block_transforms.py index eaee2e2cc0..020725ca60 100644 --- a/bigframes/core/block_transforms.py +++ b/bigframes/core/block_transforms.py @@ -386,10 +386,9 @@ def value_counts( def pct_change(block: blocks.Block, periods: int = 1) -> blocks.Block: column_labels = block.column_labels - window_spec = windows.rows( - preceding=periods if periods > 0 else None, - following=-periods if periods < 0 else None, - ) + + # Window framing clause is not allowed for analytic function lag. + window_spec = windows.unbound() original_columns = block.value_columns block, shift_columns = block.multi_apply_window_op( diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py index c2bf20076a..21fa457e30 100644 --- a/bigframes/core/blocks.py +++ b/bigframes/core/blocks.py @@ -50,7 +50,7 @@ import bigframes.core.sql as sql import bigframes.core.tree_properties as tree_properties import bigframes.core.utils as utils -import bigframes.core.window_spec as window_specs +import bigframes.core.window_spec as windows import bigframes.dtypes import bigframes.exceptions import bigframes.features @@ -876,7 +876,7 @@ def multi_apply_window_op( self, columns: typing.Sequence[str], op: agg_ops.WindowOp, - window_spec: window_specs.WindowSpec, + window_spec: windows.WindowSpec, *, skip_null_groups: bool = False, never_skip_nulls: bool = False, @@ -935,7 +935,7 @@ def apply_window_op( self, column: str, op: agg_ops.WindowOp, - window_spec: window_specs.WindowSpec, + window_spec: windows.WindowSpec, *, result_label: Label = None, skip_null_groups: bool = False, @@ -1456,7 +1456,7 @@ def grouped_head( value_columns: typing.Sequence[str], n: int, ): - window_spec = window_specs.cumulative_rows(grouping_keys=tuple(by_column_ids)) + window_spec = windows.cumulative_rows(grouping_keys=tuple(by_column_ids)) block, result_id = self.apply_window_op( value_columns[0], @@ -2383,11 +2383,8 @@ def _is_monotonic( if op_name in self._stats_cache[column_name]: return self._stats_cache[column_name][op_name] - period = 1 - window = window_specs.rows( - preceding=period, - following=None, - ) + # Window framing clause is not allowed for analytic function lag. 
+ window_spec = windows.unbound() # any NaN value means not monotonic block, last_notna_id = self.apply_unary_op(column_ids[0], ops.notnull_op) @@ -2398,10 +2395,11 @@ def _is_monotonic( ) # loop over all columns to check monotonicity + period = 1 last_result_id = None for column_id in column_ids[::-1]: block, lag_result_id = block.apply_window_op( - column_id, agg_ops.ShiftOp(period), window + column_id, agg_ops.ShiftOp(period), window_spec ) block, strict_monotonic_id = block.apply_binary_op( column_id, lag_result_id, ops.gt_op if increasing else ops.lt_op diff --git a/bigframes/core/groupby/__init__.py b/bigframes/core/groupby/__init__.py index 11a5d43ba0..bf878f730a 100644 --- a/bigframes/core/groupby/__init__.py +++ b/bigframes/core/groupby/__init__.py @@ -254,10 +254,9 @@ def cumprod(self, *args, **kwargs) -> df.DataFrame: @validations.requires_strict_ordering() def shift(self, periods=1) -> series.Series: - window = window_specs.rows( + # Window framing clause is not allowed for analytic function lag. + window = window_specs.unbound( grouping_keys=tuple(self._by_col_ids), - preceding=periods if periods > 0 else None, - following=-periods if periods < 0 else None, ) return self._apply_window_op(agg_ops.ShiftOp(periods), window=window) @@ -685,10 +684,9 @@ def cumcount(self, *args, **kwargs) -> series.Series: @validations.requires_strict_ordering() def shift(self, periods=1) -> series.Series: """Shift index by desired number of periods.""" - window = window_specs.rows( + # Window framing clause is not allowed for analytic function lag. + window = window_specs.unbound( grouping_keys=tuple(self._by_col_ids), - preceding=periods if periods > 0 else None, - following=-periods if periods < 0 else None, ) return self._apply_window_op(agg_ops.ShiftOp(periods), window=window) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 4dcc4414ed..23c3a5a255 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -64,7 +64,7 @@ import bigframes.core.utils as utils import bigframes.core.validations as validations import bigframes.core.window -import bigframes.core.window_spec as window_spec +import bigframes.core.window_spec as windows import bigframes.dtypes import bigframes.exceptions import bigframes.formatting_helpers as formatter @@ -1964,12 +1964,12 @@ def replace( @validations.requires_strict_ordering() def ffill(self, *, limit: typing.Optional[int] = None) -> DataFrame: - window = window_spec.rows(preceding=limit, following=0) + window = windows.rows(preceding=limit, following=0) return self._apply_window_op(agg_ops.LastNonNullOp(), window) @validations.requires_strict_ordering() def bfill(self, *, limit: typing.Optional[int] = None) -> DataFrame: - window = window_spec.rows(preceding=0, following=limit) + window = windows.rows(preceding=0, following=limit) return self._apply_window_op(agg_ops.FirstNonNullOp(), window) def isin(self, values) -> DataFrame: @@ -2676,7 +2676,7 @@ def _perform_join_by_index( @validations.requires_strict_ordering() def rolling(self, window: int, min_periods=None) -> bigframes.core.window.Window: # To get n size window, need current row and n-1 preceding rows. 
- window_def = window_spec.rows( + window_def = windows.rows( preceding=window - 1, following=0, min_periods=min_periods or window ) return bigframes.core.window.Window( @@ -2685,7 +2685,7 @@ def rolling(self, window: int, min_periods=None) -> bigframes.core.window.Window @validations.requires_strict_ordering() def expanding(self, min_periods: int = 1) -> bigframes.core.window.Window: - window = window_spec.cumulative_rows(min_periods=min_periods) + window = windows.cumulative_rows(min_periods=min_periods) return bigframes.core.window.Window( self._block, window, self._block.value_columns ) @@ -2796,7 +2796,7 @@ def cumsum(self): raise ValueError("All values must be numeric to apply cumsum.") return self._apply_window_op( agg_ops.sum_op, - window_spec.cumulative_rows(), + windows.cumulative_rows(), ) @validations.requires_strict_ordering() @@ -2809,34 +2809,32 @@ def cumprod(self) -> DataFrame: raise ValueError("All values must be numeric to apply cumsum.") return self._apply_window_op( agg_ops.product_op, - window_spec.cumulative_rows(), + windows.cumulative_rows(), ) @validations.requires_strict_ordering() def cummin(self) -> DataFrame: return self._apply_window_op( agg_ops.min_op, - window_spec.cumulative_rows(), + windows.cumulative_rows(), ) @validations.requires_strict_ordering() def cummax(self) -> DataFrame: return self._apply_window_op( agg_ops.max_op, - window_spec.cumulative_rows(), + windows.cumulative_rows(), ) @validations.requires_strict_ordering() def shift(self, periods: int = 1) -> DataFrame: - window = window_spec.rows( - preceding=periods if periods > 0 else None, - following=-periods if periods < 0 else None, - ) + # Window framing clause is not allowed for analytic function lag. + window = windows.unbound() return self._apply_window_op(agg_ops.ShiftOp(periods), window) @validations.requires_strict_ordering() def diff(self, periods: int = 1) -> DataFrame: - window = window_spec.rows( + window = windows.rows( preceding=periods if periods > 0 else None, following=-periods if periods < 0 else None, ) @@ -2851,7 +2849,7 @@ def pct_change(self, periods: int = 1) -> DataFrame: def _apply_window_op( self, op: agg_ops.WindowOp, - window_spec: window_spec.WindowSpec, + window_spec: windows.WindowSpec, ): block, result_ids = self._block.multi_apply_window_op( self._block.value_columns, diff --git a/bigframes/series.py b/bigframes/series.py index c325783e96..029ffa6439 100644 --- a/bigframes/series.py +++ b/bigframes/series.py @@ -45,7 +45,7 @@ import bigframes.core.utils as utils import bigframes.core.validations as validations import bigframes.core.window -import bigframes.core.window_spec +import bigframes.core.window_spec as windows import bigframes.dataframe import bigframes.dtypes import bigframes.formatting_helpers as formatter @@ -462,13 +462,11 @@ def case_when(self, caselist) -> Series: @validations.requires_strict_ordering() def cumsum(self) -> Series: - return self._apply_window_op( - agg_ops.sum_op, bigframes.core.window_spec.cumulative_rows() - ) + return self._apply_window_op(agg_ops.sum_op, windows.cumulative_rows()) @validations.requires_strict_ordering() def ffill(self, *, limit: typing.Optional[int] = None) -> Series: - window = bigframes.core.window_spec.rows(preceding=limit, following=0) + window = windows.rows(preceding=limit, following=0) return self._apply_window_op(agg_ops.LastNonNullOp(), window) pad = ffill @@ -476,38 +474,30 @@ def ffill(self, *, limit: typing.Optional[int] = None) -> Series: @validations.requires_strict_ordering() def bfill(self, *, 
limit: typing.Optional[int] = None) -> Series: - window = bigframes.core.window_spec.rows(preceding=0, following=limit) + window = windows.rows(preceding=0, following=limit) return self._apply_window_op(agg_ops.FirstNonNullOp(), window) @validations.requires_strict_ordering() def cummax(self) -> Series: - return self._apply_window_op( - agg_ops.max_op, bigframes.core.window_spec.cumulative_rows() - ) + return self._apply_window_op(agg_ops.max_op, windows.cumulative_rows()) @validations.requires_strict_ordering() def cummin(self) -> Series: - return self._apply_window_op( - agg_ops.min_op, bigframes.core.window_spec.cumulative_rows() - ) + return self._apply_window_op(agg_ops.min_op, windows.cumulative_rows()) @validations.requires_strict_ordering() def cumprod(self) -> Series: - return self._apply_window_op( - agg_ops.product_op, bigframes.core.window_spec.cumulative_rows() - ) + return self._apply_window_op(agg_ops.product_op, windows.cumulative_rows()) @validations.requires_strict_ordering() def shift(self, periods: int = 1) -> Series: - window = bigframes.core.window_spec.rows( - preceding=periods if periods > 0 else None, - following=-periods if periods < 0 else None, - ) - return self._apply_window_op(agg_ops.ShiftOp(periods), window) + # Window framing clause is not allowed for analytic function lag. + window_spec = windows.unbound() + return self._apply_window_op(agg_ops.ShiftOp(periods), window_spec) @validations.requires_strict_ordering() def diff(self, periods: int = 1) -> Series: - window = bigframes.core.window_spec.rows( + window = windows.rows( preceding=periods if periods > 0 else None, following=-periods if periods < 0 else None, ) @@ -1044,7 +1034,7 @@ def mode(self) -> Series: block, max_value_count_col_id = block.apply_window_op( value_count_col_id, agg_ops.max_op, - window_spec=bigframes.core.window_spec.unbound(), + window_spec=windows.unbound(), ) block, is_mode_col_id = block.apply_binary_op( value_count_col_id, @@ -1277,9 +1267,7 @@ def _apply_aggregation( ) -> Any: return self._block.get_stat(self._value_column, op) - def _apply_window_op( - self, op: agg_ops.WindowOp, window_spec: bigframes.core.window_spec.WindowSpec - ): + def _apply_window_op(self, op: agg_ops.WindowOp, window_spec: windows.WindowSpec): block = self._block block, result_id = block.apply_window_op( self._value_column, op, window_spec=window_spec, result_label=self.name @@ -1336,7 +1324,7 @@ def sort_index(self, *, axis=0, ascending=True, na_position="last") -> Series: @validations.requires_strict_ordering() def rolling(self, window: int, min_periods=None) -> bigframes.core.window.Window: # To get n size window, need current row and n-1 preceding rows. 
- window_spec = bigframes.core.window_spec.rows( + window_spec = windows.rows( preceding=window - 1, following=0, min_periods=min_periods or window ) return bigframes.core.window.Window( @@ -1345,9 +1333,7 @@ def rolling(self, window: int, min_periods=None) -> bigframes.core.window.Window @validations.requires_strict_ordering() def expanding(self, min_periods: int = 1) -> bigframes.core.window.Window: - window_spec = bigframes.core.window_spec.cumulative_rows( - min_periods=min_periods - ) + window_spec = windows.cumulative_rows(min_periods=min_periods) return bigframes.core.window.Window( self._block, window_spec, self._block.value_columns, is_series=True ) From 616c99fcdc57f2a09b6cd6c8bca2565275e0f534 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Wed, 24 Jul 2024 17:50:51 +0000 Subject: [PATCH 24/59] fix inf to SQL by treating values as literal in in memory table --- .../ibis/backends/bigquery/compiler.py | 37 +------------------ 1 file changed, 1 insertion(+), 36 deletions(-) diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py index 7087994ba9..8097ce8952 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py @@ -13,41 +13,6 @@ import sqlglot.expressions as sge -def _convert(value, ibis_type): - if value is None: - return sge.Null() - - # Older versions of SQLGlot can't support time literals in convert(). - if ibis_type.is_time(): - return sge.TimeAdd( - this=sge.TimeFromParts( - hour=sge.convert(value.hour), - min=sge.convert(value.minute), - sec=sge.convert(value.second), - ), - expression=sge.convert(value.microsecond), - unit=sge.Var(this="MICROSECOND"), - ) - - # Older versions of SQLGlot don't distinguish DATETIME from TIMESTAMP in convert(). - if ibis_type.is_timestamp(): - if ibis_type.timezone == "UTC": - return sge.cast( - sge.convert(value.isoformat()), sge.DataType.Type.TIMESTAMPTZ - ) - else: - return sge.cast( - sge.convert(value.strftime("%Y-%m-%d %H:%M:%S.%f")), - sge.DataType.Type.TIMESTAMP, - ) - - # In-memory nullable integers can get stored as floats. 
- if ibis_type.is_integer(): - value = int(value) - - return sge.convert(value) - - class BigQueryCompiler(bq_compiler.BigQueryCompiler): UNSUPPORTED_OPS = ( tuple( @@ -87,7 +52,7 @@ def visit_InMemoryTable(self, op, *, name, schema, data): values=[ sge.Tuple( expressions=tuple( - _convert(value, type_) + self.visit_Literal(None, value=value, dtype=type_) for value, type_ in zip(row, schema.types) ) ) From f266e340a9b9aba4654daaf305653cfb3789696f Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Wed, 7 Aug 2024 15:13:40 +0000 Subject: [PATCH 25/59] fix unit tests for ibis-framework 9.2.0 --- setup.py | 2 +- testing/constraints-3.11.txt | 3 +++ .../ibis/backends/bigquery/compiler.py | 18 ++++++++++++++++-- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index c970c40b88..18bbe79a5c 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ "google-cloud-iam >=2.12.1", "google-cloud-resource-manager >=1.10.3", "google-cloud-storage >=2.0.0", - "ibis-framework[bigquery] >=9.0.0,<10.0.0dev", + "ibis-framework[bigquery] >=9.0.0,<=9.2.0", "jellyfish >=0.8.9", "pandas >=1.5.3", "pyarrow >=10.0.1", diff --git a/testing/constraints-3.11.txt b/testing/constraints-3.11.txt index e69de29bb2..60ac0af60f 100644 --- a/testing/constraints-3.11.txt +++ b/testing/constraints-3.11.txt @@ -0,0 +1,3 @@ +# Some internal modules have moved, +# so make sure we test on all ibis-framework 9.x versions. +ibis-framework==9.1.0 diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py index 8097ce8952..26eb314aaa 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py @@ -4,14 +4,28 @@ from __future__ import annotations import bigframes_vendored.ibis.backends.bigquery.datatypes as bq_datatypes -import ibis.backends.bigquery.compiler as bq_compiler -import ibis.backends.sql.compiler as sql_compiler import ibis.common.exceptions as com from ibis.common.temporal import IntervalUnit import ibis.expr.operations.reductions as ibis_reductions import sqlglot as sg import sqlglot.expressions as sge +# The compilers moved between ibis-framework 9.1 and 9.2. 
+try: + # ibis-framework 9.0 & 9.1 + import ibis.backends.bigquery.compiler as bq_compiler +except ImportError: + # ibis-framework 9.2 + import ibis.backends.sql.compilers.bigquery as bq_compiler + + +try: + # ibis-framework 9.0 & 9.1 + import ibis.backends.sql.compiler as sql_compiler +except ImportError: + # ibis-framework 9.2 + import ibis.backends.sql.compilers.base as sql_compiler + class BigQueryCompiler(bq_compiler.BigQueryCompiler): UNSUPPORTED_OPS = ( From 7a035855c9be9a7902c9bb7af48dfbac1cf26b30 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Wed, 7 Aug 2024 15:22:37 +0000 Subject: [PATCH 26/59] fix Python 3.10 unit tests by syncing deps --- testing/constraints-3.10.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/testing/constraints-3.10.txt b/testing/constraints-3.10.txt index 5782b03a2f..b11ab5a88d 100644 --- a/testing/constraints-3.10.txt +++ b/testing/constraints-3.10.txt @@ -3,10 +3,10 @@ google-auth==2.27.0 ipykernel==5.5.6 ipython==7.34.0 notebook==6.5.5 -pandas==2.0.3 -pandas-stubs==2.0.3.230814 +pandas==2.1.4 +pandas-stubs==2.1.4.231227 portpicker==1.5.2 -requests==2.31.0 +requests==2.32.3 tornado==6.3.3 absl-py==1.4.0 debugpy==1.6.6 From 387fbd9c7564dff46890730232ce96aefa1d15f4 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 13 Aug 2024 18:36:07 +0000 Subject: [PATCH 27/59] fixing remote function after merge --- bigframes/functions/_remote_function_session.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bigframes/functions/_remote_function_session.py b/bigframes/functions/_remote_function_session.py index 0ab19ca353..fd71dea8b5 100644 --- a/bigframes/functions/_remote_function_session.py +++ b/bigframes/functions/_remote_function_session.py @@ -503,7 +503,8 @@ def try_delattr(attr): node = ibis.udf.scalar.builtin( func, name=rf_name, - schema=f"{dataset_ref.project}.{dataset_ref.dataset_id}", + catalog=dataset_ref.project, + database=dataset_ref.dataset_id, signature=(ibis_signature.input_types, ibis_signature.output_type), ) func.bigframes_cloud_function = ( From 232f2f9bcaed2addb99a055ddc64ecc11f28197c Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Sat, 17 Aug 2024 00:04:16 +0000 Subject: [PATCH 28/59] fix visit_NonNullLiteral for int types --- .../ibis/backends/bigquery/compiler.py | 47 ++++++++++--------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py index 26eb314aaa..46e81ed440 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py @@ -3,6 +3,7 @@ from __future__ import annotations +import numpy as np import bigframes_vendored.ibis.backends.bigquery.datatypes as bq_datatypes import ibis.common.exceptions as com from ibis.common.temporal import IntervalUnit @@ -45,33 +46,33 @@ def visit_InMemoryTable(self, op, *, name, schema, data): tuples = data.to_frame().itertuples(index=False) quoted = self.quoted columns = [sg.column(col, quoted=quoted) for col in schema.names] + array_expr = sge.DataType( + this=sge.DataType.Type.STRUCT, + expressions=[ + sge.ColumnDef( + this=sge.to_identifier(field, quoted=self.quoted), + kind=bq_datatypes.BigQueryType.from_ibis(type_), + ) + for field, type_ in zip(schema.names, schema.types) + ], + nested=True, + ) + array_values = [ + sge.Tuple( + expressions=tuple( + self.visit_Literal(None, value=value, dtype=type_) + for value, type_ 
in zip(row, schema.types) + ) + ) + for row in tuples + ] expr = sge.Unnest( expressions=[ sge.DataType( this=sge.DataType.Type.ARRAY, - expressions=[ - sge.DataType( - this=sge.DataType.Type.STRUCT, - expressions=[ - sge.ColumnDef( - this=sge.to_identifier(field, quoted=self.quoted), - kind=bq_datatypes.BigQueryType.from_ibis(type_), - ) - for field, type_ in zip(schema.names, schema.types) - ], - nested=True, - ) - ], + expressions=[array_expr], nested=True, - values=[ - sge.Tuple( - expressions=tuple( - self.visit_Literal(None, value=value, dtype=type_) - for value, type_ in zip(row, schema.types) - ) - ) - for row in tuples - ], + values=array_values, ), ], alias=sge.TableAlias( @@ -114,6 +115,8 @@ def visit_NonNullLiteral(self, op, *, value, dtype): ) elif dtype.is_uuid(): return sge.convert(str(value)) + elif dtype.is_int64(): + return sge.convert(np.int64(value)) return None # Custom operators. From b9d6826f31fea697060db3c84b5d0f538909069c Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Mon, 19 Aug 2024 22:19:22 +0000 Subject: [PATCH 29/59] visit_WindowFunction to fix s.median() method --- .../ibis/backends/bigquery/compiler.py | 49 +++++++++++++++++-- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py index 46e81ed440..db48766a57 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py @@ -3,11 +3,11 @@ from __future__ import annotations -import numpy as np import bigframes_vendored.ibis.backends.bigquery.datatypes as bq_datatypes import ibis.common.exceptions as com from ibis.common.temporal import IntervalUnit import ibis.expr.operations.reductions as ibis_reductions +import numpy as np import sqlglot as sg import sqlglot.expressions as sge @@ -146,6 +146,46 @@ def visit_LastNonNullValue(self, op, *, arg): def visit_ToJsonString(self, op, *, arg): return self.f.to_json_string(arg) + def visit_Quantile(self, op, *, arg, quantile, where): + return sge.PercentileCont(this=arg, expression=quantile) + + def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by): + # Patch for https://github.com/ibis-project/ibis/issues/9872 + if start is None and end is None: + spec = None + else: + if start is None: + start = {} + if end is None: + end = {} + + start_value = start.get("value", "UNBOUNDED") + start_side = start.get("side", "PRECEDING") + end_value = end.get("value", "UNBOUNDED") + end_side = end.get("side", "FOLLOWING") + + if getattr(start_value, "this", None) == "0": + start_value = "CURRENT ROW" + start_side = None + + if getattr(end_value, "this", None) == "0": + end_value = "CURRENT ROW" + end_side = None + + spec = sge.WindowSpec( + kind=how.upper(), + start=start_value, + start_side=start_side, + end=end_value, + end_side=end_side, + over="OVER", + ) + spec = self._minimize_spec(op.start, op.end, spec) + + order = sge.Order(expressions=order_by) if order_by else None + + return sge.Window(this=func, partition_by=group_by, order=order, spec=spec) + # Override implementation. # We monkeypatch individual methods because the class might have already been imported in other modules. 
@@ -165,9 +205,10 @@ def visit_ToJsonString(self, op, *, arg): BigQueryCompiler.visit_LastNonNullValue ) bq_compiler.BigQueryCompiler.visit_ToJsonString = BigQueryCompiler.visit_ToJsonString +bq_compiler.BigQueryCompiler.visit_Quantile = BigQueryCompiler.visit_Quantile +bq_compiler.BigQueryCompiler.visit_WindowFunction = ( + BigQueryCompiler.visit_WindowFunction +) # TODO(swast): sqlglot base implementation appears to work fine for the bigquery backend, at least in our windowed contexts. See: ISSUE NUMBER bq_compiler.BigQueryCompiler.UNSUPPORTED_OPS = BigQueryCompiler.UNSUPPORTED_OPS -bq_compiler.BigQueryCompiler.visit_Quantile = ( - sql_compiler.SQLGlotCompiler.visit_Quantile -) From 79c8f680dd790d49c92b635f889e191edb0e6cd4 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Mon, 19 Aug 2024 22:21:57 +0000 Subject: [PATCH 30/59] fix lint --- .../bigframes_vendored/ibis/backends/bigquery/compiler.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py index db48766a57..68723927e9 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py @@ -20,14 +20,6 @@ import ibis.backends.sql.compilers.bigquery as bq_compiler -try: - # ibis-framework 9.0 & 9.1 - import ibis.backends.sql.compiler as sql_compiler -except ImportError: - # ibis-framework 9.2 - import ibis.backends.sql.compilers.base as sql_compiler - - class BigQueryCompiler(bq_compiler.BigQueryCompiler): UNSUPPORTED_OPS = ( tuple( From a44a7452b89355fae2c6c831180a997301f283ca Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 20 Aug 2024 04:56:28 +0000 Subject: [PATCH 31/59] fix s.diff with window --- bigframes/core/compile/compiled.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bigframes/core/compile/compiled.py b/bigframes/core/compile/compiled.py index 103a376633..6ce2e08692 100644 --- a/bigframes/core/compile/compiled.py +++ b/bigframes/core/compile/compiled.py @@ -1293,9 +1293,10 @@ def _ibis_window_from_spec( bounds.preceding, bounds.following, how="range" ) if isinstance(bounds, RowsWindowBounds): - window = window.preceding_following( - bounds.preceding, bounds.following, how="rows" - ) + if bounds.preceding is not None and bounds.following is not None: + window = window.preceding_following( + bounds.preceding, bounds.following, how="rows" + ) else: raise ValueError(f"unrecognized window bounds {bounds}") return window From d7089ca7844cf88530c1e046b89bf6c72cca9c58 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 20 Aug 2024 17:39:46 +0000 Subject: [PATCH 32/59] fix mypy --- bigframes/operations/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bigframes/operations/__init__.py b/bigframes/operations/__init__.py index fb333d7a53..138596e86f 100644 --- a/bigframes/operations/__init__.py +++ b/bigframes/operations/__init__.py @@ -852,7 +852,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT # Just parameterless unary ops for now # TODO: Parameter mappings -NUMPY_TO_OP: typing.Final = { +NUMPY_TO_OP: dict[np.ufunc, UnaryOp] = { np.sin: sin_op, np.cos: cos_op, np.tan: tan_op, @@ -877,7 +877,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT } -NUMPY_TO_BINOP: typing.Final = { +NUMPY_TO_BINOP: dict[np.ufunc, UnaryOp] = { np.add: add_op, np.subtract: sub_op, np.multiply: mul_op, 
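
Note on the monkeypatching approach used by the vendored compiler patches above: assignments such as `bq_compiler.BigQueryCompiler.visit_Quantile = BigQueryCompiler.visit_Quantile` patch individual attributes onto ibis's existing class instead of replacing the class object, because, as the in-tree comment says, other modules may already hold a reference to `BigQueryCompiler` from an earlier import. A minimal sketch of why attribute-level patching still reaches those earlier references; the names below are illustrative placeholders, not bigframes or ibis code:

    class Compiler:
        def visit_op(self):
            return "upstream behavior"

    # Simulates another module that imported the class before the patch ran.
    earlier_reference = Compiler

    def patched_visit_op(self):
        return "patched behavior"

    # Patching the attribute on the class is visible through every existing
    # reference to it, whereas rebinding the name `Compiler` to a subclass
    # would only affect code that imports the name after the rebinding.
    Compiler.visit_op = patched_visit_op
    assert earlier_reference().visit_op() == "patched behavior"
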
From 133e05391ae6256bfd90b7a36549e4ba5fd42434 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 20 Aug 2024 21:57:30 +0000 Subject: [PATCH 33/59] patch visit_And to fix is_monotonic methods --- bigframes/core/blocks.py | 5 ++--- bigframes/core/compile/aggregate_compiler.py | 4 ++-- .../ibis/backends/bigquery/compiler.py | 12 ++++++++++++ 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py index 4ac8e572da..8816bb1beb 100644 --- a/bigframes/core/blocks.py +++ b/bigframes/core/blocks.py @@ -2414,8 +2414,8 @@ def _is_monotonic( if op_name in self._stats_cache[column_name]: return self._stats_cache[column_name][op_name] - # Window framing clause is not allowed for analytic function lag. - window_spec = windows.unbound() + period = 1 + window_spec = windows.rows(preceding=period, following=None) # any NaN value means not monotonic block, last_notna_id = self.apply_unary_op(column_ids[0], ops.notnull_op) @@ -2426,7 +2426,6 @@ def _is_monotonic( ) # loop over all columns to check monotonicity - period = 1 last_result_id = None for column_id in column_ids[::-1]: block, lag_result_id = block.apply_window_op( diff --git a/bigframes/core/compile/aggregate_compiler.py b/bigframes/core/compile/aggregate_compiler.py index 97628d9261..0516eb01a2 100644 --- a/bigframes/core/compile/aggregate_compiler.py +++ b/bigframes/core/compile/aggregate_compiler.py @@ -525,7 +525,7 @@ def _( result = _is_true(column).all() return cast( ibis_types.BooleanScalar, - _apply_window_if_present(result, window).fillna(ibis_types.literal(True)), + _apply_window_if_present(result, window).fill_null(ibis_types.literal(True)), ) @@ -539,7 +539,7 @@ def _( result = _is_true(column).any() return cast( ibis_types.BooleanScalar, - _apply_window_if_present(result, window).fillna(ibis_types.literal(False)), + _apply_window_if_present(result, window).fill_null(ibis_types.literal(False)), ) diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py index 68723927e9..4f20779c0e 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py @@ -19,6 +19,13 @@ # ibis-framework 9.2 import ibis.backends.sql.compilers.bigquery as bq_compiler +try: + # ibis-framework 9.0 & 9.1 + import ibis.backends.sql.compiler as sql_compiler +except ImportError: + # ibis-framework 9.2 + import ibis.backends.sql.compilers.base as sql_compiler + class BigQueryCompiler(bq_compiler.BigQueryCompiler): UNSUPPORTED_OPS = ( @@ -178,6 +185,10 @@ def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by) return sge.Window(this=func, partition_by=group_by, order=order, spec=spec) + @sql_compiler.parenthesize_inputs + def visit_And(self, op, *, left, right): + return sge.And(this=sge.Paren(this=left), expression=sge.Paren(this=right)) + # Override implementation. # We monkeypatch individual methods because the class might have already been imported in other modules. @@ -201,6 +212,7 @@ def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by) bq_compiler.BigQueryCompiler.visit_WindowFunction = ( BigQueryCompiler.visit_WindowFunction ) +bq_compiler.BigQueryCompiler.visit_And = BigQueryCompiler.visit_And # TODO(swast): sqlglot base implementation appears to work fine for the bigquery backend, at least in our windowed contexts. 
See: ISSUE NUMBER bq_compiler.BigQueryCompiler.UNSUPPORTED_OPS = BigQueryCompiler.UNSUPPORTED_OPS From ded23346a5c747f59c0da3c1619211b8fcdfad8d Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 21 Aug 2024 16:51:19 +0000 Subject: [PATCH 34/59] fix mypy and fillna warning --- bigframes/core/compile/single_column.py | 2 +- bigframes/operations/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bigframes/core/compile/single_column.py b/bigframes/core/compile/single_column.py index 9b621c9c79..460f3c0f8f 100644 --- a/bigframes/core/compile/single_column.py +++ b/bigframes/core/compile/single_column.py @@ -170,4 +170,4 @@ def value_to_join_key(value: ibis_types.Value): """Converts nullable values to non-null string SQL will not match null keys together - but pandas does.""" if not value.type().is_string(): value = value.cast(ibis_dtypes.str) - return value.fillna(ibis_types.literal("$NULL_SENTINEL$")) + return value.fill_null(ibis_types.literal("$NULL_SENTINEL$")) diff --git a/bigframes/operations/__init__.py b/bigframes/operations/__init__.py index 138596e86f..6acca6e21e 100644 --- a/bigframes/operations/__init__.py +++ b/bigframes/operations/__init__.py @@ -877,7 +877,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT } -NUMPY_TO_BINOP: dict[np.ufunc, UnaryOp] = { +NUMPY_TO_BINOP: dict[np.ufunc, BinaryOp] = { np.add: add_op, np.subtract: sub_op, np.multiply: mul_op, From 2edaa9d5044bac7decfd5a81bcdb0414cb78f4f4 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 21 Aug 2024 18:00:10 +0000 Subject: [PATCH 35/59] undo window changes for test_series_autocorr --- bigframes/series.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/bigframes/series.py b/bigframes/series.py index bb59ba4075..8f5296e114 100644 --- a/bigframes/series.py +++ b/bigframes/series.py @@ -488,8 +488,10 @@ def cumprod(self) -> Series: @validations.requires_ordering() def shift(self, periods: int = 1) -> Series: - # Window framing clause is not allowed for analytic function lag. 
- window_spec = windows.unbound() + window_spec = windows.rows( + preceding=periods if periods > 0 else None, + following=-periods if periods < 0 else None, + ) return self._apply_window_op(agg_ops.ShiftOp(periods), window_spec) @validations.requires_ordering() From 8f6165fa626e5baf0365cc22714b2a59a457988d Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 21 Aug 2024 18:22:42 +0000 Subject: [PATCH 36/59] undo fill_null because it was missed at 9.0 version --- bigframes/core/compile/aggregate_compiler.py | 4 ++-- bigframes/core/compile/single_column.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bigframes/core/compile/aggregate_compiler.py b/bigframes/core/compile/aggregate_compiler.py index 0516eb01a2..97628d9261 100644 --- a/bigframes/core/compile/aggregate_compiler.py +++ b/bigframes/core/compile/aggregate_compiler.py @@ -525,7 +525,7 @@ def _( result = _is_true(column).all() return cast( ibis_types.BooleanScalar, - _apply_window_if_present(result, window).fill_null(ibis_types.literal(True)), + _apply_window_if_present(result, window).fillna(ibis_types.literal(True)), ) @@ -539,7 +539,7 @@ def _( result = _is_true(column).any() return cast( ibis_types.BooleanScalar, - _apply_window_if_present(result, window).fill_null(ibis_types.literal(False)), + _apply_window_if_present(result, window).fillna(ibis_types.literal(False)), ) diff --git a/bigframes/core/compile/single_column.py b/bigframes/core/compile/single_column.py index 460f3c0f8f..9b621c9c79 100644 --- a/bigframes/core/compile/single_column.py +++ b/bigframes/core/compile/single_column.py @@ -170,4 +170,4 @@ def value_to_join_key(value: ibis_types.Value): """Converts nullable values to non-null string SQL will not match null keys together - but pandas does.""" if not value.type().is_string(): value = value.cast(ibis_dtypes.str) - return value.fill_null(ibis_types.literal("$NULL_SENTINEL$")) + return value.fillna(ibis_types.literal("$NULL_SENTINEL$")) From e423f89d4e0b674a103ffe25773c317d9abf6ee3 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Fri, 23 Aug 2024 21:18:43 +0000 Subject: [PATCH 37/59] vendor more of ibis for python 3.9 compatibility --- bigframes/core/compile/compiled.py | 2 +- noxfile.py | 3 +- setup.py | 2 +- testing/constraints-3.12.txt | 3 + .../ibis/backends/bigquery/__init__.py | 2 - .../ibis/backends/bigquery/backend.py | 1265 +++++++++++++ .../ibis/backends/bigquery/compiler.py | 218 --- .../ibis/backends/sql/__init__.py | 0 .../ibis/backends/sql/compilers/__init__.py | 5 + .../ibis/backends/sql/compilers/base.py | 1648 +++++++++++++++++ .../sql/compilers/bigquery/__init__.py | 1149 ++++++++++++ .../ibis/backends/sql/rewrites.py | 514 +++++ .../bigframes_vendored/ibis/expr/rewrites.py | 382 ++++ 13 files changed, 4970 insertions(+), 223 deletions(-) create mode 100644 third_party/bigframes_vendored/ibis/backends/bigquery/backend.py delete mode 100644 third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py create mode 100644 third_party/bigframes_vendored/ibis/backends/sql/__init__.py create mode 100644 third_party/bigframes_vendored/ibis/backends/sql/compilers/__init__.py create mode 100644 third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py create mode 100644 third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py create mode 100644 third_party/bigframes_vendored/ibis/backends/sql/rewrites.py create mode 100644 third_party/bigframes_vendored/ibis/expr/rewrites.py diff --git a/bigframes/core/compile/compiled.py 
b/bigframes/core/compile/compiled.py index 6ce2e08692..0982f90c61 100644 --- a/bigframes/core/compile/compiled.py +++ b/bigframes/core/compile/compiled.py @@ -19,9 +19,9 @@ import typing from typing import Collection, Literal, Optional, Sequence +import bigframes_vendored.ibis.backends.bigquery.backend as ibis_bigquery import google.cloud.bigquery import ibis -import ibis.backends.bigquery as ibis_bigquery import ibis.backends.bigquery.datatypes import ibis.common.deferred # type: ignore import ibis.expr.datatypes as ibis_dtypes diff --git a/noxfile.py b/noxfile.py index 9ed85290fa..5e6298adbb 100644 --- a/noxfile.py +++ b/noxfile.py @@ -59,8 +59,9 @@ UNIT_TEST_EXTRAS: List[str] = [] UNIT_TEST_EXTRAS_BY_PYTHON: Dict[str, List[str]] = {} +# There are 4 different ibis-framework 9.x versions we want to test against. # 3.10 is needed for Windows tests. -SYSTEM_TEST_PYTHON_VERSIONS = ["3.9", "3.10", "3.12"] +SYSTEM_TEST_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12"] SYSTEM_TEST_STANDARD_DEPENDENCIES = [ "jinja2", "mock", diff --git a/setup.py b/setup.py index 18bbe79a5c..ddaeec141c 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ "google-cloud-iam >=2.12.1", "google-cloud-resource-manager >=1.10.3", "google-cloud-storage >=2.0.0", - "ibis-framework[bigquery] >=9.0.0,<=9.2.0", + "ibis-framework[bigquery] >=9.0.0,<=9.3.0", "jellyfish >=0.8.9", "pandas >=1.5.3", "pyarrow >=10.0.1", diff --git a/testing/constraints-3.12.txt b/testing/constraints-3.12.txt index e69de29bb2..dbbb5a2d88 100644 --- a/testing/constraints-3.12.txt +++ b/testing/constraints-3.12.txt @@ -0,0 +1,3 @@ +# Some internal modules have moved, +# so make sure we test on all ibis-framework 9.x versions. +ibis-framework==9.2.0 diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/bigquery/__init__.py index ee24a71446..e69de29bb2 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/__init__.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/__init__.py @@ -1,2 +0,0 @@ -# Import all sub-modules to monkeypatch everything. 
-import bigframes_vendored.ibis.backends.bigquery.compiler # noqa diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py b/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py new file mode 100644 index 0000000000..5c1ec85cec --- /dev/null +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py @@ -0,0 +1,1265 @@ +# Contains code from https://github.com/ibis-project/ibis/blob/main/ibis/backends/bigquery/__init__.py + +"""BigQuery public API.""" + +from __future__ import annotations + +import concurrent.futures +import contextlib +import glob +import os +import re +from typing import Any, Optional, TYPE_CHECKING + +import bigframes_vendored.ibis.backends.sql.compilers as sc +import google.api_core.exceptions +import google.auth.credentials +import google.cloud.bigquery as bq +import google.cloud.bigquery_storage_v1 as bqstorage +import ibis +from ibis import util +from ibis.backends import CanCreateDatabase, CanCreateSchema +from ibis.backends.bigquery.client import ( + bigquery_param, + parse_project_and_dataset, + rename_partitioned_column, + schema_from_bigquery_table, +) +from ibis.backends.bigquery.datatypes import BigQuerySchema +from ibis.backends.sql import SQLBackend +from ibis.backends.sql.datatypes import BigQueryType +import ibis.common.exceptions as com +import ibis.expr.operations as ops +import ibis.expr.schema as sch +import ibis.expr.types as ir +import pydata_google_auth +from pydata_google_auth import cache +import sqlglot as sg +import sqlglot.expressions as sge + +if TYPE_CHECKING: + from collections.abc import Iterable, Mapping + from pathlib import Path + from urllib.parse import ParseResult + + import pandas as pd + import polars as pl + import pyarrow as pa + + +SCOPES = ["https://www.googleapis.com/auth/bigquery"] +EXTERNAL_DATA_SCOPES = [ + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/drive", +] +CLIENT_ID = "546535678771-gvffde27nd83kfl6qbrnletqvkdmsese.apps.googleusercontent.com" +CLIENT_SECRET = "iU5ohAF2qcqrujegE3hQ1cPt" # noqa: S105 + + +def _create_user_agent(application_name: str) -> str: + user_agent = [] + + if application_name: + user_agent.append(application_name) + + user_agent_default_template = f"ibis/{ibis.__version__}" + user_agent.append(user_agent_default_template) + + return " ".join(user_agent) + + +def _create_client_info(application_name): + from google.api_core.client_info import ClientInfo + + return ClientInfo(user_agent=_create_user_agent(application_name)) + + +def _create_client_info_gapic(application_name): + from google.api_core.gapic_v1.client_info import ClientInfo + + return ClientInfo(user_agent=_create_user_agent(application_name)) + + +_MEMTABLE_PATTERN = re.compile( + r"^_?ibis_(?:[A-Za-z_][A-Za-z_0-9]*)_memtable_[a-z0-9]{26}$" +) + + +def _qualify_memtable( + node: sge.Expression, *, dataset: str | None, project: str | None +) -> sge.Expression: + """Add a BigQuery dataset and project to memtable references.""" + if isinstance(node, sge.Table) and _MEMTABLE_PATTERN.match(node.name) is not None: + node.args["db"] = dataset + node.args["catalog"] = project + # make sure to quote table location + node = _force_quote_table(node) + return node + + +def _remove_null_ordering_from_unsupported_window( + node: sge.Expression, +) -> sge.Expression: + """Remove null ordering in window frame clauses not supported by BigQuery. 
+ + BigQuery has only partial support for NULL FIRST/LAST in RANGE windows so + we remove it from any window frame clause that doesn't support it. + + Here's the support matrix: + + ✅ sum(x) over (order by y desc nulls last) + 🚫 sum(x) over (order by y asc nulls last) + ✅ sum(x) over (order by y asc nulls first) + 🚫 sum(x) over (order by y desc nulls first) + """ + if isinstance(node, sge.Window): + order = node.args.get("order") + if order is not None: + for key in order.args["expressions"]: + kargs = key.args + if kargs.get("desc") is True and kargs.get("nulls_first", False): + kargs["nulls_first"] = False + elif kargs.get("desc") is False and not kargs.setdefault( + "nulls_first", True + ): + kargs["nulls_first"] = True + return node + + +def _force_quote_table(table: sge.Table) -> sge.Table: + """Force quote all the parts of a bigquery path. + + The BigQuery identifier quoting semantics are bonkers + https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#identifiers + + my-table is OK, but not mydataset.my-table + + mytable-287 is OK, but not mytable-287a + + Just quote everything. + """ + for key in ("this", "db", "catalog"): + if (val := table.args[key]) is not None: + if isinstance(val, sg.exp.Identifier) and not val.quoted: + val.args["quoted"] = True + else: + table.args[key] = sg.to_identifier(val, quoted=True) + return table + + +class Backend(SQLBackend, CanCreateDatabase, CanCreateSchema): + name = "bigquery" + compiler = sc.bigquery.compiler + supports_in_memory_tables = True + supports_python_udfs = False + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.__session_dataset: bq.DatasetReference | None = None + self._query_cache.lookup = lambda name: self.table( + name, + database=(self._session_dataset.project, self._session_dataset.dataset_id), + ).op() + + @property + def _session_dataset(self): + if self.__session_dataset is None: + self.__session_dataset = self._make_session() + return self.__session_dataset + + def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: + raw_name = op.name + + session_dataset = self._session_dataset + project = session_dataset.project + dataset = session_dataset.dataset_id + + table_ref = bq.TableReference(session_dataset, raw_name) + try: + self.client.get_table(table_ref) + except google.api_core.exceptions.NotFound: + table_id = sg.table( + raw_name, db=dataset, catalog=project, quoted=False + ).sql(dialect=self.name) + bq_schema = BigQuerySchema.from_ibis(op.schema) + load_job = self.client.load_table_from_dataframe( + op.data.to_frame(), + table_id, + job_config=bq.LoadJobConfig( + # fail if the table already exists and contains data + write_disposition=bq.WriteDisposition.WRITE_EMPTY, + schema=bq_schema, + ), + ) + load_job.result() + + def _read_file( + self, + path: str | Path, + *, + table_name: str | None = None, + job_config: bq.LoadJobConfig, + ) -> ir.Table: + self._make_session() + + if table_name is None: + table_name = util.gen_name(f"bq_read_{job_config.source_format}") + + table_ref = self._session_dataset.table(table_name) + + database = self._session_dataset.dataset_id + catalog = self._session_dataset.project + + # drop the table if it exists + # + # we could do this with write_disposition = WRITE_TRUNCATE but then the + # concurrent append jobs aren't possible + # + # dropping the table first means all write_dispositions can be + # WRITE_APPEND + self.drop_table(table_name, database=(catalog, database), force=True) + + if os.path.isdir(path): + raise 
NotImplementedError("Reading from a directory is not supported.") + elif str(path).startswith("gs://"): + load_job = self.client.load_table_from_uri( + path, table_ref, job_config=job_config + ) + load_job.result() + else: + + def load(file: str) -> None: + with open(file, mode="rb") as f: + load_job = self.client.load_table_from_file( + f, table_ref, job_config=job_config + ) + load_job.result() + + job_config.write_disposition = bq.WriteDisposition.WRITE_APPEND + + with concurrent.futures.ThreadPoolExecutor() as executor: + for fut in concurrent.futures.as_completed( + executor.submit(load, file) for file in glob.glob(str(path)) + ): + fut.result() + + return self.table(table_name, database=(catalog, database)) + + def read_parquet( + self, path: str | Path, table_name: str | None = None, **kwargs: Any + ): + """Read Parquet data into a BigQuery table. + + Parameters + ---------- + path + Path to a Parquet file on GCS or the local filesystem. Globs are supported. + table_name + Optional table name + kwargs + Additional keyword arguments passed to `google.cloud.bigquery.LoadJobConfig`. + + Returns + ------- + Table + An Ibis table expression + + """ + return self._read_file( + path, + table_name=table_name, + job_config=bq.LoadJobConfig( + source_format=bq.SourceFormat.PARQUET, **kwargs + ), + ) + + def read_csv( + self, path: str | Path, table_name: str | None = None, **kwargs: Any + ) -> ir.Table: + """Read CSV data into a BigQuery table. + + Parameters + ---------- + path + Path to a CSV file on GCS or the local filesystem. Globs are supported. + table_name + Optional table name + kwargs + Additional keyword arguments passed to + `google.cloud.bigquery.LoadJobConfig`. + + Returns + ------- + Table + An Ibis table expression + + """ + job_config = bq.LoadJobConfig( + source_format=bq.SourceFormat.CSV, + autodetect=True, + skip_leading_rows=1, + **kwargs, + ) + return self._read_file(path, table_name=table_name, job_config=job_config) + + def read_json( + self, path: str | Path, table_name: str | None = None, **kwargs: Any + ) -> ir.Table: + """Read newline-delimited JSON data into a BigQuery table. + + Parameters + ---------- + path + Path to a newline-delimited JSON file on GCS or the local + filesystem. Globs are supported. + table_name + Optional table name + kwargs + Additional keyword arguments passed to + `google.cloud.bigquery.LoadJobConfig`. + + Returns + ------- + Table + An Ibis table expression + + """ + job_config = bq.LoadJobConfig( + source_format=bq.SourceFormat.NEWLINE_DELIMITED_JSON, + autodetect=True, + **kwargs, + ) + return self._read_file(path, table_name=table_name, job_config=job_config) + + def _from_url(self, url: ParseResult, **kwargs): + return self.connect( + project_id=url.netloc or kwargs.get("project_id", [""])[0], + dataset_id=url.path[1:] or kwargs.get("dataset_id", [""])[0], + **kwargs, + ) + + def do_connect( + self, + project_id: str | None = None, + dataset_id: str = "", + credentials: google.auth.credentials.Credentials | None = None, + application_name: str | None = None, + auth_local_webserver: bool = True, + auth_external_data: bool = False, + auth_cache: str = "default", + partition_column: str | None = "PARTITIONTIME", + client: bq.Client | None = None, + storage_client: bqstorage.BigQueryReadClient | None = None, + location: str | None = None, + ) -> Backend: + """Create a `Backend` for use with Ibis. + + Parameters + ---------- + project_id + A BigQuery project id. 
+ dataset_id + A dataset id that lives inside of the project indicated by + `project_id`. + credentials + Optional credentials. + application_name + A string identifying your application to Google API endpoints. + auth_local_webserver + Use a local webserver for the user authentication. Binds a + webserver to an open port on localhost between 8080 and 8089, + inclusive, to receive authentication token. If not set, defaults to + False, which requests a token via the console. + auth_external_data + Authenticate using additional scopes required to `query external + data sources + `_, + such as Google Sheets, files in Google Cloud Storage, or files in + Google Drive. If not set, defaults to False, which requests the + default BigQuery scopes. + auth_cache + Selects the behavior of the credentials cache. + + `'default'`` + Reads credentials from disk if available, otherwise + authenticates and caches credentials to disk. + + `'reauth'`` + Authenticates and caches credentials to disk. + + `'none'`` + Authenticates and does **not** cache credentials. + + Defaults to `'default'`. + partition_column + Identifier to use instead of default `_PARTITIONTIME` partition + column. Defaults to `'PARTITIONTIME'`. + client + A `Client` from the `google.cloud.bigquery` package. If not + set, one is created using the `project_id` and `credentials`. + storage_client + A `BigQueryReadClient` from the + `google.cloud.bigquery_storage_v1` package. If not set, one is + created using the `project_id` and `credentials`. + location + Default location for BigQuery objects. + + Returns + ------- + Backend + An instance of the BigQuery backend. + + """ + default_project_id = client.project if client is not None else project_id + + # Only need `credentials` to create a `client` and + # `storage_client`, so only one or the other needs to be set. + if (client is None or storage_client is None) and credentials is None: + scopes = SCOPES + if auth_external_data: + scopes = EXTERNAL_DATA_SCOPES + + if auth_cache == "default": + credentials_cache = cache.ReadWriteCredentialsCache( + filename="ibis.json" + ) + elif auth_cache == "reauth": + credentials_cache = cache.WriteOnlyCredentialsCache( + filename="ibis.json" + ) + elif auth_cache == "none": + credentials_cache = cache.NOOP + else: + raise ValueError( + f"Got unexpected value for auth_cache = '{auth_cache}'. " + "Expected one of 'default', 'reauth', or 'none'." 
+ ) + + credentials, default_project_id = pydata_google_auth.default( + scopes, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + credentials_cache=credentials_cache, + use_local_webserver=auth_local_webserver, + ) + + project_id = project_id or default_project_id + + ( + self.data_project, + self.billing_project, + self.dataset, + ) = parse_project_and_dataset(project_id, dataset_id) + + if client is not None: + self.client = client + else: + self.client = bq.Client( + project=self.billing_project, + credentials=credentials, + client_info=_create_client_info(application_name), + location=location, + ) + + if self.client.default_query_job_config is None: + self.client.default_query_job_config = bq.QueryJobConfig() + + self.client.default_query_job_config.use_legacy_sql = False + self.client.default_query_job_config.allow_large_results = True + + if storage_client is not None: + self.storage_client = storage_client + else: + self.storage_client = bqstorage.BigQueryReadClient( + credentials=credentials, + client_info=_create_client_info_gapic(application_name), + ) + + self.partition_column = partition_column + + @util.experimental + @classmethod + def from_connection( + cls, + client: bq.Client, + partition_column: str | None = "PARTITIONTIME", + storage_client: bqstorage.BigQueryReadClient | None = None, + dataset_id: str = "", + ) -> Backend: + """Create a BigQuery `Backend` from an existing `Client`. + + Parameters + ---------- + client + A `Client` from the `google.cloud.bigquery` package. + partition_column + Identifier to use instead of default `_PARTITIONTIME` partition + column. Defaults to `'PARTITIONTIME'`. + storage_client + A `BigQueryReadClient` from the `google.cloud.bigquery_storage_v1` + package. + dataset_id + A dataset id that lives inside of the project attached to `client`. 
+ """ + return ibis.bigquery.connect( + client=client, + partition_column=partition_column, + storage_client=storage_client, + dataset_id=dataset_id, + ) + + def disconnect(self) -> None: + self.client.close() + + def _parse_project_and_dataset(self, dataset) -> tuple[str, str]: + if isinstance(dataset, sge.Table): + dataset = dataset.sql(self.dialect) + if not dataset and not self.dataset: + raise ValueError("Unable to determine BigQuery dataset.") + project, _, dataset = parse_project_and_dataset( + self.billing_project, + dataset or f"{self.data_project}.{self.dataset}", + ) + return project, dataset + + @property + def project_id(self): + return self.data_project + + @property + def dataset_id(self): + return self.dataset + + def create_database( + self, + name: str, + catalog: str | None = None, + force: bool = False, + collate: str | None = None, + **options: Any, + ) -> None: + properties = [ + sge.Property(this=sg.to_identifier(name), value=sge.convert(value)) + for name, value in (options or {}).items() + ] + + if collate is not None: + properties.append( + sge.CollateProperty(this=sge.convert(collate), default=True) + ) + + stmt = sge.Create( + kind="SCHEMA", + this=sg.table(name, db=catalog), + exists=force, + properties=sge.Properties(expressions=properties), + ) + + self.raw_sql(stmt.sql(self.name)) + + def drop_database( + self, + name: str, + catalog: str | None = None, + force: bool = False, + cascade: bool = False, + ) -> None: + """Drop a BigQuery dataset.""" + stmt = sge.Drop( + kind="SCHEMA", + this=sg.table(name, db=catalog), + exists=force, + cascade=cascade, + ) + + self.raw_sql(stmt.sql(self.name)) + + def table( + self, name: str, database: str | None = None, schema: str | None = None + ) -> ir.Table: + table_loc = self._warn_and_create_table_loc(database, schema) + table = sg.parse_one(f"`{name}`", into=sge.Table, read=self.name) + + # Bigquery, unlike other backends, had existing support for specifying + # table hierarchy in the table name, e.g. con.table("dataset.table_name") + # so here we have an extra layer of disambiguation to handle. 
+ + # Default `catalog` to None unless we've parsed it out of the database/schema kwargs + # Raise if there are path specifications in both the name and as a kwarg + catalog = table_loc.args["catalog"] # args access will return None, not '' + if table.catalog: + if table_loc.catalog: + raise com.IbisInputError( + "Cannot specify catalog both in the table name and as an argument" + ) + else: + catalog = table.catalog + + # Default `db` to None unless we've parsed it out of the database/schema kwargs + db = table_loc.args["db"] # args access will return None, not '' + if table.db: + if table_loc.db: + raise com.IbisInputError( + "Cannot specify database both in the table name and as an argument" + ) + else: + db = table.db + + database = ( + sg.table(None, db=db, catalog=catalog, quoted=False).sql(dialect=self.name) + or None + ) + + project, dataset = self._parse_project_and_dataset(database) + + bq_table = self.client.get_table( + bq.TableReference( + bq.DatasetReference(project=project, dataset_id=dataset), + table.name, + ) + ) + + node = ops.DatabaseTable( + table.name, + # https://cloud.google.com/bigquery/docs/querying-wildcard-tables#filtering_selected_tables_using_table_suffix + schema=schema_from_bigquery_table(bq_table, wildcard=table.name[-1] == "*"), + source=self, + namespace=ops.Namespace(database=dataset, catalog=project), + ) + table_expr = node.to_expr() + return rename_partitioned_column(table_expr, bq_table, self.partition_column) + + def _make_session(self) -> tuple[str, str]: + if (client := getattr(self, "client", None)) is not None: + job_config = bq.QueryJobConfig(use_query_cache=False) + query = client.query( + "SELECT 1", job_config=job_config, project=self.billing_project + ) + query.result() + + return bq.DatasetReference( + project=query.destination.project, + dataset_id=query.destination.dataset_id, + ) + return None + + def _get_schema_using_query(self, query: str) -> sch.Schema: + job = self.client.query( + query, + job_config=bq.QueryJobConfig(dry_run=True, use_query_cache=False), + project=self.billing_project, + ) + return BigQuerySchema.to_ibis(job.schema) + + def raw_sql(self, query: str, params=None, page_size: int | None = None): + query_parameters = [ + bigquery_param( + param.type(), + value, + ( + param.get_name() + if not isinstance(op := param.op(), ops.Alias) + else op.arg.name + ), + ) + for param, value in (params or {}).items() + ] + with contextlib.suppress(AttributeError): + query = query.sql(self.dialect) + + job_config = bq.job.QueryJobConfig(query_parameters=query_parameters or []) + return self.client.query_and_wait( + query, + job_config=job_config, + project=self.billing_project, + page_size=page_size, + ) + + @property + def current_catalog(self) -> str: + return self.data_project + + @property + def current_database(self) -> str | None: + return self.dataset + + def compile( + self, + expr: ir.Expr, + limit: str | None = None, + params=None, + pretty: bool = True, + **kwargs: Any, + ): + """Compile an Ibis expression to a SQL string.""" + session_dataset = self._session_dataset + query = self.compiler.to_sqlglot( + expr, + limit=limit, + params=params, + session_dataset_id=getattr(session_dataset, "dataset_id", None), + session_project=getattr(session_dataset, "project", None), + **kwargs, + ) + queries = query if isinstance(query, list) else [query] + sql = ";\n".join(query.sql(self.dialect, pretty=pretty) for query in queries) + self._log(sql) + return sql + + def execute(self, expr, params=None, limit="default", **kwargs): + 
"""Compile and execute the given Ibis expression. + + Compile and execute Ibis expression using this backend client + interface, returning results in-memory in the appropriate object type + + Parameters + ---------- + expr + Ibis expression to execute + limit + Retrieve at most this number of values/rows. Overrides any limit + already set on the expression. + params + Query parameters + kwargs + Extra arguments specific to the backend + + Returns + ------- + pd.DataFrame | pd.Series | scalar + Output from execution + + """ + from ibis.backends.bigquery.converter import BigQueryPandasData + + self._run_pre_execute_hooks(expr) + + schema = expr.as_table().schema() - ibis.schema({"_TABLE_SUFFIX": "string"}) + + sql = self.compile(expr, limit=limit, params=params, **kwargs) + self._log(sql) + query = self.raw_sql(sql, params=params, **kwargs) + + arrow_t = query.to_arrow( + progress_bar_type=None, bqstorage_client=self.storage_client + ) + + result = BigQueryPandasData.convert_table( + arrow_t.to_pandas(timestamp_as_object=True), schema + ) + + return expr.__pandas_result__(result, schema=schema) + + def insert( + self, + table_name: str, + obj: pd.DataFrame | ir.Table | list | dict, + schema: str | None = None, + database: str | None = None, + overwrite: bool = False, + ): + """Insert data into a table. + + Parameters + ---------- + table_name + The name of the table to which data needs will be inserted + obj + The source data or expression to insert + schema + The name of the schema that the table is located in + database + Name of the attached database that the table is located in. + overwrite + If `True` then replace existing contents of table + + """ + table_loc = self._warn_and_create_table_loc(database, schema) + catalog, db = self._to_catalog_db_tuple(table_loc) + if catalog is None: + catalog = self.current_catalog + if db is None: + db = self.current_database + + return super().insert( + table_name, + obj, + database=(catalog, db), + overwrite=overwrite, + ) + + def to_pyarrow( + self, + expr: ir.Expr, + *, + params: Mapping[ir.Scalar, Any] | None = None, + limit: int | str | None = None, + **kwargs: Any, + ) -> pa.Table: + self._import_pyarrow() + self._register_in_memory_tables(expr) + sql = self.compile(expr, limit=limit, params=params, **kwargs) + self._log(sql) + query = self.raw_sql(sql, params=params, **kwargs) + table = query.to_arrow( + progress_bar_type=None, bqstorage_client=self.storage_client + ) + table = table.rename_columns(list(expr.as_table().schema().names)) + return expr.__pyarrow_result__(table) + + def to_pyarrow_batches( + self, + expr: ir.Expr, + *, + params: Mapping[ir.Scalar, Any] | None = None, + limit: int | str | None = None, + chunk_size: int = 1_000_000, + **kwargs: Any, + ): + pa = self._import_pyarrow() + + schema = expr.as_table().schema() + + self._register_in_memory_tables(expr) + sql = self.compile(expr, limit=limit, params=params, **kwargs) + self._log(sql) + query = self.raw_sql(sql, params=params, page_size=chunk_size, **kwargs) + batch_iter = query.to_arrow_iterable(bqstorage_client=self.storage_client) + return pa.ipc.RecordBatchReader.from_batches(schema.to_pyarrow(), batch_iter) + + def _gen_udf_name(self, name: str, schema: Optional[str]) -> str: + func = ".".join(filter(None, (schema, name))) + if "." 
in func: + return ".".join(f"`{part}`" for part in func.split(".")) + return func + + def get_schema( + self, + name, + *, + catalog: str | None = None, + database: str | None = None, + ): + table_ref = bq.TableReference( + bq.DatasetReference( + project=catalog or self.data_project, + dataset_id=database or self.current_database, + ), + name, + ) + return schema_from_bigquery_table( + self.client.get_table(table_ref), + # https://cloud.google.com/bigquery/docs/querying-wildcard-tables#filtering_selected_tables_using_table_suffix + wildcard=name[-1] == "*", + ) + + def list_databases( + self, like: str | None = None, catalog: str | None = None + ) -> list[str]: + results = [ + dataset.dataset_id + for dataset in self.client.list_datasets( + project=catalog if catalog is not None else self.data_project + ) + ] + return self._filter_with_like(results, like) + + def list_tables( + self, + like: str | None = None, + database: tuple[str, str] | str | None = None, + schema: str | None = None, + ) -> list[str]: + """List the tables in the database. + + Parameters + ---------- + like + A pattern to use for listing tables. + database + The database location to perform the list against. + + By default uses the current `dataset` (`self.current_database`) and + `project` (`self.current_catalog`). + + To specify a table in a separate BigQuery dataset, you can pass in the + dataset and project as a string `"dataset.project"`, or as a tuple of + strings `("dataset", "project")`. + + ::: {.callout-note} + ## Ibis does not use the word `schema` to refer to database hierarchy. + + A collection of tables is referred to as a `database`. + A collection of `database` is referred to as a `catalog`. + + These terms are mapped onto the corresponding features in each + backend (where available), regardless of whether the backend itself + uses the same terminology. + ::: + schema + [deprecated] The schema (dataset) inside `database` to perform the list against. + """ + table_loc = self._warn_and_create_table_loc(database, schema) + + project, dataset = self._parse_project_and_dataset(table_loc) + dataset_ref = bq.DatasetReference(project, dataset) + result = [table.table_id for table in self.client.list_tables(dataset_ref)] + return self._filter_with_like(result, like) + + def set_database(self, name): + self.data_project, self.dataset = self._parse_project_and_dataset(name) + + @property + def version(self): + return bq.__version__ + + def create_table( + self, + name: str, + obj: ir.Table + | pd.DataFrame + | pa.Table + | pl.DataFrame + | pl.LazyFrame + | None = None, + *, + schema: ibis.Schema | None = None, + database: str | None = None, + temp: bool = False, + overwrite: bool = False, + default_collate: str | None = None, + partition_by: str | None = None, + cluster_by: Iterable[str] | None = None, + options: Mapping[str, Any] | None = None, + ) -> ir.Table: + """Create a table in BigQuery. + + Parameters + ---------- + name + Name of the table to create + obj + The data with which to populate the table; optional, but one of `obj` + or `schema` must be specified + schema + The schema of the table to create; optional, but one of `obj` or + `schema` must be specified + database + The BigQuery *dataset* in which to create the table; optional + temp + Whether the table is temporary + overwrite + If `True`, replace the table if it already exists, otherwise fail if + the table exists + default_collate + Default collation for string columns. 
See BigQuery's documentation + for more details: https://cloud.google.com/bigquery/docs/reference/standard-sql/collation-concepts + partition_by + Partition the table by the given expression. See BigQuery's documentation + for more details: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#partition_expression + cluster_by + List of columns to cluster the table by. See BigQuery's documentation + for more details: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#clustering_column_list + options + BigQuery-specific table options; see the BigQuery documentation for + details: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#table_option_list + + Returns + ------- + Table + The table that was just created + + """ + if obj is None and schema is None: + raise com.IbisError("One of the `schema` or `obj` parameter is required") + + if isinstance(obj, ir.Table) and schema is not None: + if not schema.equals(obj.schema()): + raise com.IbisTypeError( + "Provided schema and Ibis table schema are incompatible. Please " + "align the two schemas, or provide only one of the two arguments." + ) + + project_id, dataset = self._parse_project_and_dataset(database) + + properties = [] + + if default_collate is not None: + properties.append( + sge.CollateProperty(this=sge.convert(default_collate), default=True) + ) + + if partition_by is not None: + properties.append( + sge.PartitionedByProperty( + this=sge.Tuple( + expressions=list(map(sg.to_identifier, partition_by)) + ) + ) + ) + + if cluster_by is not None: + properties.append( + sge.Cluster(expressions=list(map(sg.to_identifier, cluster_by))) + ) + + properties.extend( + sge.Property(this=sg.to_identifier(name), value=sge.convert(value)) + for name, value in (options or {}).items() + ) + + if obj is not None and not isinstance(obj, ir.Table): + obj = ibis.memtable(obj, schema=schema) + + if obj is not None: + self._register_in_memory_tables(obj) + + if temp: + dataset = self._session_dataset.dataset_id + if database is not None: + raise com.IbisInputError("Cannot specify database for temporary table") + database = self._session_dataset.project + else: + dataset = database or self.current_database + + try: + table = sg.parse_one(name, into=sge.Table, read="bigquery") + except sg.ParseError: + table = sg.table( + name, + db=dataset, + catalog=project_id, + quoted=self.compiler.quoted, + ) + else: + if table.args["db"] is None: + table.args["db"] = dataset + + if table.args["catalog"] is None: + table.args["catalog"] = project_id + + table = _force_quote_table(table) + + column_defs = [ + sge.ColumnDef( + this=sg.to_identifier(name, quoted=self.compiler.quoted), + kind=BigQueryType.from_ibis(typ), + constraints=( + None + if typ.nullable or typ.is_array() + else [sge.ColumnConstraint(kind=sge.NotNullColumnConstraint())] + ), + ) + for name, typ in (schema or {}).items() + ] + + stmt = sge.Create( + kind="TABLE", + this=sge.Schema(this=table, expressions=column_defs or None), + replace=overwrite, + properties=sge.Properties(expressions=properties), + expression=None if obj is None else self.compile(obj), + ) + + sql = stmt.sql(self.name) + + self.raw_sql(sql) + return self.table(table.name, database=(table.catalog, table.db)) + + def drop_table( + self, + name: str, + *, + schema: str | None = None, + database: tuple[str | str] | str | None = None, + force: bool = False, + ) -> None: + table_loc = self._warn_and_create_table_loc(database, schema) + 
catalog, db = self._to_catalog_db_tuple(table_loc) + stmt = sge.Drop( + kind="TABLE", + this=sg.table( + name, + db=db or self.current_database, + catalog=catalog or self.billing_project, + ), + exists=force, + ) + self.raw_sql(stmt.sql(self.name)) + + def create_view( + self, + name: str, + obj: ir.Table, + *, + schema: str | None = None, + database: str | None = None, + overwrite: bool = False, + ) -> ir.Table: + table_loc = self._warn_and_create_table_loc(database, schema) + catalog, db = self._to_catalog_db_tuple(table_loc) + + stmt = sge.Create( + kind="VIEW", + this=sg.table( + name, + db=db or self.current_database, + catalog=catalog or self.billing_project, + ), + expression=self.compile(obj), + replace=overwrite, + ) + self._register_in_memory_tables(obj) + self.raw_sql(stmt.sql(self.name)) + return self.table(name, database=(catalog, database)) + + def drop_view( + self, + name: str, + *, + schema: str | None = None, + database: str | None = None, + force: bool = False, + ) -> None: + table_loc = self._warn_and_create_table_loc(database, schema) + catalog, db = self._to_catalog_db_tuple(table_loc) + + stmt = sge.Drop( + kind="VIEW", + this=sg.table( + name, + db=db or self.current_database, + catalog=catalog or self.billing_project, + ), + exists=force, + ) + self.raw_sql(stmt.sql(self.name)) + + def _load_into_cache(self, name, expr): + self.create_table(name, expr, schema=expr.schema(), temp=True) + + def _clean_up_cached_table(self, name): + self.drop_table( + name, + database=(self._session_dataset.project, self._session_dataset.dataset_id), + force=True, + ) + + def _register_udfs(self, expr: ir.Expr) -> None: + """No op because UDFs made with CREATE TEMPORARY FUNCTION must be followed by a query.""" + + @contextlib.contextmanager + def _safe_raw_sql(self, *args, **kwargs): + yield self.raw_sql(*args, **kwargs) + + # TODO: remove when the schema kwarg is removed + def _warn_and_create_table_loc(self, database=None, schema=None): + if schema is not None: + self._warn_schema() + if database is not None and schema is not None: + if isinstance(database, str): + table_loc = f"{database}.{schema}" + elif isinstance(database, tuple): + table_loc = database + schema + elif schema is not None: + table_loc = schema + elif database is not None: + table_loc = database + else: + table_loc = None + + table_loc = self._to_sqlglot_table(table_loc) + + if table_loc is not None: + if (sg_cat := table_loc.args["catalog"]) is not None: + sg_cat.args["quoted"] = False + if (sg_db := table_loc.args["db"]) is not None: + sg_db.args["quoted"] = False + + return table_loc + + +def compile(expr, params=None, **kwargs): + """Compile an expression for BigQuery.""" + backend = Backend() + return backend.compile(expr, params=params, **kwargs) + + +def connect( + project_id: str | None = None, + dataset_id: str = "", + credentials: google.auth.credentials.Credentials | None = None, + application_name: str | None = None, + auth_local_webserver: bool = False, + auth_external_data: bool = False, + auth_cache: str = "default", + partition_column: str | None = "PARTITIONTIME", +) -> Backend: + """Create a :class:`Backend` for use with Ibis. + + Parameters + ---------- + project_id + A BigQuery project id. + dataset_id + A dataset id that lives inside of the project indicated by + `project_id`. + credentials + Optional credentials. + application_name + A string identifying your application to Google API endpoints. + auth_local_webserver + Use a local webserver for the user authentication. 
Binds a + webserver to an open port on localhost between 8080 and 8089, + inclusive, to receive authentication token. If not set, defaults + to False, which requests a token via the console. + auth_external_data + Authenticate using additional scopes required to `query external + data sources + `_, + such as Google Sheets, files in Google Cloud Storage, or files in + Google Drive. If not set, defaults to False, which requests the + default BigQuery scopes. + auth_cache + Selects the behavior of the credentials cache. + + `'default'`` + Reads credentials from disk if available, otherwise + authenticates and caches credentials to disk. + + `'reauth'`` + Authenticates and caches credentials to disk. + + `'none'`` + Authenticates and does **not** cache credentials. + + Defaults to `'default'`. + partition_column + Identifier to use instead of default `_PARTITIONTIME` partition + column. Defaults to `'PARTITIONTIME'`. + + Returns + ------- + Backend + An instance of the BigQuery backend + + """ + backend = Backend() + return backend.connect( + project_id=project_id, + dataset_id=dataset_id, + credentials=credentials, + application_name=application_name, + auth_local_webserver=auth_local_webserver, + auth_external_data=auth_external_data, + auth_cache=auth_cache, + partition_column=partition_column, + ) + + +__all__ = [ + "Backend", + "compile", + "connect", +] diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py b/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py deleted file mode 100644 index 4f20779c0e..0000000000 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/compiler.py +++ /dev/null @@ -1,218 +0,0 @@ -# Contains code from https://github.com/ibis-project/ibis/blob/master/ibis/backends/bigquery/compiler.py -"""Module to convert from Ibis expression to SQL string.""" - -from __future__ import annotations - -import bigframes_vendored.ibis.backends.bigquery.datatypes as bq_datatypes -import ibis.common.exceptions as com -from ibis.common.temporal import IntervalUnit -import ibis.expr.operations.reductions as ibis_reductions -import numpy as np -import sqlglot as sg -import sqlglot.expressions as sge - -# The compilers moved between ibis-framework 9.1 and 9.2. -try: - # ibis-framework 9.0 & 9.1 - import ibis.backends.bigquery.compiler as bq_compiler -except ImportError: - # ibis-framework 9.2 - import ibis.backends.sql.compilers.bigquery as bq_compiler - -try: - # ibis-framework 9.0 & 9.1 - import ibis.backends.sql.compiler as sql_compiler -except ImportError: - # ibis-framework 9.2 - import ibis.backends.sql.compilers.base as sql_compiler - - -class BigQueryCompiler(bq_compiler.BigQueryCompiler): - UNSUPPORTED_OPS = ( - tuple( - op - for op in bq_compiler.BigQueryCompiler.UNSUPPORTED_OPS - if op != ibis_reductions.Quantile - ) - if hasattr(bq_compiler.BigQueryCompiler, "UNSUPPORTED_OPS") - else () - ) - - def visit_InMemoryTable(self, op, *, name, schema, data): - # Avoid creating temp tables for small data, which is how memtable is - # used in BigQuery DataFrames. 
Inspired by: - # https://github.com/ibis-project/ibis/blob/efa6fb72bf4c790450d00a926d7bd809dade5902/ibis/backends/druid/compiler.py#L95 - tuples = data.to_frame().itertuples(index=False) - quoted = self.quoted - columns = [sg.column(col, quoted=quoted) for col in schema.names] - array_expr = sge.DataType( - this=sge.DataType.Type.STRUCT, - expressions=[ - sge.ColumnDef( - this=sge.to_identifier(field, quoted=self.quoted), - kind=bq_datatypes.BigQueryType.from_ibis(type_), - ) - for field, type_ in zip(schema.names, schema.types) - ], - nested=True, - ) - array_values = [ - sge.Tuple( - expressions=tuple( - self.visit_Literal(None, value=value, dtype=type_) - for value, type_ in zip(row, schema.types) - ) - ) - for row in tuples - ] - expr = sge.Unnest( - expressions=[ - sge.DataType( - this=sge.DataType.Type.ARRAY, - expressions=[array_expr], - nested=True, - values=array_values, - ), - ], - alias=sge.TableAlias( - this=sg.to_identifier(name, quoted=quoted), - columns=columns, - ), - ) - # return expr - return sg.select(sge.Star()).from_(expr) - - def visit_NonNullLiteral(self, op, *, value, dtype): - # Patch from https://github.com/ibis-project/ibis/pull/9610 to support ibis 9.0.0 and 9.1.0 - if dtype.is_inet() or dtype.is_macaddr(): - return sge.convert(str(value)) - elif dtype.is_timestamp(): - funcname = "DATETIME" if dtype.timezone is None else "TIMESTAMP" - return self.f.anon[funcname](value.isoformat()) - elif dtype.is_date(): - return self.f.date_from_parts(value.year, value.month, value.day) - elif dtype.is_time(): - time = self.f.time_from_parts(value.hour, value.minute, value.second) - if micros := value.microsecond: - # bigquery doesn't support `time(12, 34, 56.789101)`, AKA a - # float seconds specifier, so add any non-zero micros to the - # time value - return sge.TimeAdd( - this=time, expression=sge.convert(micros), unit=self.v.MICROSECOND - ) - return time - elif dtype.is_binary(): - return sge.Cast( - this=sge.convert(value.hex()), - to=sge.DataType(this=sge.DataType.Type.BINARY), - format=sge.convert("HEX"), - ) - elif dtype.is_interval(): - if dtype.unit == IntervalUnit.NANOSECOND: - raise com.UnsupportedOperationError( - "BigQuery does not support nanosecond intervals" - ) - elif dtype.is_uuid(): - return sge.convert(str(value)) - elif dtype.is_int64(): - return sge.convert(np.int64(value)) - return None - - # Custom operators. - - def visit_ArrayAggregate(self, op, *, arg, order_by, where): - if len(order_by) > 0: - expr = sge.Order( - this=arg, - expressions=[ - # Avoid adding NULLS FIRST / NULLS LAST in SQL, which is - # unsupported in ARRAY_AGG by reconstructing the node as - # plain SQL text. 
- f"({order_column.args['this'].sql(dialect='bigquery')}) {'DESC' if order_column.args.get('desc') else 'ASC'}" - for order_column in order_by - ], - ) - else: - expr = arg - return sge.IgnoreNulls(this=self.agg.array_agg(expr, where=where)) - - def visit_FirstNonNullValue(self, op, *, arg): - return sge.IgnoreNulls(this=sge.FirstValue(this=arg)) - - def visit_LastNonNullValue(self, op, *, arg): - return sge.IgnoreNulls(this=sge.LastValue(this=arg)) - - def visit_ToJsonString(self, op, *, arg): - return self.f.to_json_string(arg) - - def visit_Quantile(self, op, *, arg, quantile, where): - return sge.PercentileCont(this=arg, expression=quantile) - - def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by): - # Patch for https://github.com/ibis-project/ibis/issues/9872 - if start is None and end is None: - spec = None - else: - if start is None: - start = {} - if end is None: - end = {} - - start_value = start.get("value", "UNBOUNDED") - start_side = start.get("side", "PRECEDING") - end_value = end.get("value", "UNBOUNDED") - end_side = end.get("side", "FOLLOWING") - - if getattr(start_value, "this", None) == "0": - start_value = "CURRENT ROW" - start_side = None - - if getattr(end_value, "this", None) == "0": - end_value = "CURRENT ROW" - end_side = None - - spec = sge.WindowSpec( - kind=how.upper(), - start=start_value, - start_side=start_side, - end=end_value, - end_side=end_side, - over="OVER", - ) - spec = self._minimize_spec(op.start, op.end, spec) - - order = sge.Order(expressions=order_by) if order_by else None - - return sge.Window(this=func, partition_by=group_by, order=order, spec=spec) - - @sql_compiler.parenthesize_inputs - def visit_And(self, op, *, left, right): - return sge.And(this=sge.Paren(this=left), expression=sge.Paren(this=right)) - - -# Override implementation. -# We monkeypatch individual methods because the class might have already been imported in other modules. -bq_compiler.BigQueryCompiler.visit_InMemoryTable = BigQueryCompiler.visit_InMemoryTable -bq_compiler.BigQueryCompiler.visit_NonNullLiteral = ( - BigQueryCompiler.visit_NonNullLiteral -) - -# Custom operators. -bq_compiler.BigQueryCompiler.visit_ArrayAggregate = ( - BigQueryCompiler.visit_ArrayAggregate -) -bq_compiler.BigQueryCompiler.visit_FirstNonNullValue = ( - BigQueryCompiler.visit_FirstNonNullValue -) -bq_compiler.BigQueryCompiler.visit_LastNonNullValue = ( - BigQueryCompiler.visit_LastNonNullValue -) -bq_compiler.BigQueryCompiler.visit_ToJsonString = BigQueryCompiler.visit_ToJsonString -bq_compiler.BigQueryCompiler.visit_Quantile = BigQueryCompiler.visit_Quantile -bq_compiler.BigQueryCompiler.visit_WindowFunction = ( - BigQueryCompiler.visit_WindowFunction -) -bq_compiler.BigQueryCompiler.visit_And = BigQueryCompiler.visit_And - -# TODO(swast): sqlglot base implementation appears to work fine for the bigquery backend, at least in our windowed contexts. 
See: ISSUE NUMBER -bq_compiler.BigQueryCompiler.UNSUPPORTED_OPS = BigQueryCompiler.UNSUPPORTED_OPS diff --git a/third_party/bigframes_vendored/ibis/backends/sql/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/__init__.py new file mode 100644 index 0000000000..3a2cc8b51a --- /dev/null +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/__init__.py @@ -0,0 +1,5 @@ +import bigframes_vendored.ibis.backends.sql.compilers.bigquery as bigquery + +__all__ = [ + "bigquery", +] diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py new file mode 100644 index 0000000000..5003c3e4cd --- /dev/null +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py @@ -0,0 +1,1648 @@ +# Contains code from https://github.com/ibis-project/ibis/blob/main/ibis/backends/sql/compilers/base.py + +from __future__ import annotations + +import abc +import calendar +from functools import partial, reduce +import itertools +import math +import operator +import string +from typing import Any, ClassVar, TYPE_CHECKING + +from bigframes_vendored.ibis.backends.sql.rewrites import ( + add_one_to_nth_value_input, + add_order_by_to_empty_ranking_window_functions, + empty_in_values_right_side, + FirstValue, + LastValue, + lower_bucket, + lower_capitalize, + lower_sample, + one_to_zero_index, + sqlize, +) +from bigframes_vendored.ibis.expr.rewrites import lower_stringslice +import ibis.common.exceptions as com +import ibis.common.patterns as pats +from ibis.config import options +import ibis.expr.datatypes as dt +import ibis.expr.operations as ops +from ibis.expr.operations.udf import InputType +from public import public +import sqlglot as sg +import sqlglot.expressions as sge + +try: + from sqlglot.expressions import Alter +except ImportError: + from sqlglot.expressions import AlterTable +else: + + def AlterTable(*args, kind="TABLE", **kwargs): + return Alter(*args, kind=kind, **kwargs) + + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Mapping + + from ibis.backends.sql.datatypes import SqlglotType + import ibis.expr.schema as sch + import ibis.expr.types as ir + + +def get_leaf_classes(op): + for child_class in op.__subclasses__(): + if not child_class.__subclasses__(): + yield child_class + else: + yield from get_leaf_classes(child_class) + + +ALL_OPERATIONS = frozenset(get_leaf_classes(ops.Node)) + + +class AggGen: + """A descriptor for compiling aggregate functions. + + Common cases can be handled by setting configuration flags, + special cases should override the `aggregate` method directly. + + Parameters + ---------- + supports_filter + Whether the backend supports a FILTER clause in the aggregate. + Defaults to False. + supports_order_by + Whether the backend supports an ORDER BY clause in (relevant) + aggregates. Defaults to False. 
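+
+    As an illustrative sketch (not tied to any particular backend): with the
+    default flags, a call such as ``self.agg.sum(arg, where=pred)`` falls
+    back to wrapping the argument, roughly ``sum(IF(pred, arg, NULL))``,
+    whereas a compiler configured with ``AggGen(supports_filter=True)``
+    emits something like ``sum(arg) FILTER (WHERE pred)`` instead.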
+ """ + + class _Accessor: + """An internal type to handle getattr/getitem access.""" + + __slots__ = ("handler", "compiler") + + def __init__(self, handler: Callable, compiler: SQLGlotCompiler): + self.handler = handler + self.compiler = compiler + + def __getattr__(self, name: str) -> Callable: + return partial(self.handler, self.compiler, name) + + __getitem__ = __getattr__ + + __slots__ = ("supports_filter", "supports_order_by") + + def __init__( + self, *, supports_filter: bool = False, supports_order_by: bool = False + ): + self.supports_filter = supports_filter + self.supports_order_by = supports_order_by + + def __get__(self, instance, owner=None): + if instance is None: + return self + + return AggGen._Accessor(self.aggregate, instance) + + def aggregate( + self, + compiler: SQLGlotCompiler, + name: str, + *args: Any, + where: Any = None, + order_by: tuple = (), + ): + """Compile the specified aggregate. + + Parameters + ---------- + compiler + The backend's compiler. + name + The aggregate name (e.g. `"sum"`). + args + Any arguments to pass to the aggregate. + where + An optional column filter to apply before performing the aggregate. + order_by + Optional ordering keys to use to order the rows before performing + the aggregate. + """ + func = compiler.f[name] + + if order_by and not self.supports_order_by: + raise com.UnsupportedOperationError( + "ordering of order-sensitive aggregations via `order_by` is " + f"not supported for the {compiler.dialect} backend" + ) + + if where is not None and not self.supports_filter: + args = tuple(compiler.if_(where, arg, NULL) for arg in args) + + if order_by and self.supports_order_by: + *rest, last = args + out = func(*rest, sge.Order(this=last, expressions=order_by)) + else: + out = func(*args) + + if where is not None and self.supports_filter: + out = sge.Filter(this=out, expression=sge.Where(this=where)) + + return out + + +class VarGen: + __slots__ = () + + def __getattr__(self, name: str) -> sge.Var: + return sge.Var(this=name) + + def __getitem__(self, key: str) -> sge.Var: + return sge.Var(this=key) + + +class AnonymousFuncGen: + __slots__ = () + + def __getattr__(self, name: str) -> Callable[..., sge.Anonymous]: + return lambda *args: sge.Anonymous( + this=name, expressions=list(map(sge.convert, args)) + ) + + def __getitem__(self, key: str) -> Callable[..., sge.Anonymous]: + return getattr(self, key) + + +class FuncGen: + __slots__ = ("namespace", "anon", "copy") + + def __init__(self, namespace: str | None = None, copy: bool = False) -> None: + self.namespace = namespace + self.anon = AnonymousFuncGen() + self.copy = copy + + def __getattr__(self, name: str) -> Callable[..., sge.Func]: + name = ".".join(filter(None, (self.namespace, name))) + return lambda *args, **kwargs: sg.func( + name, *map(sge.convert, args), **kwargs, copy=self.copy + ) + + def __getitem__(self, key: str) -> Callable[..., sge.Func]: + return getattr(self, key) + + def array(self, *args: Any) -> sge.Array: + if not args: + return sge.Array(expressions=[]) + + first, *rest = args + + if isinstance(first, sge.Select): + assert ( + not rest + ), "only one argument allowed when `first` is a select statement" + + return sge.Array(expressions=list(map(sge.convert, (first, *rest)))) + + def tuple(self, *args: Any) -> sge.Anonymous: + return self.anon.tuple(*args) + + def exists(self, query: sge.Expression) -> sge.Exists: + return sge.Exists(this=query) + + def concat(self, *args: Any) -> sge.Concat: + return sge.Concat(expressions=list(map(sge.convert, args))) + + 
def map(self, keys: Iterable, values: Iterable) -> sge.Map: + return sge.Map(keys=keys, values=values) + + +class ColGen: + __slots__ = ("table",) + + def __init__(self, table: str | None = None) -> None: + self.table = table + + def __getattr__(self, name: str) -> sge.Column: + return sg.column(name, table=self.table, copy=False) + + def __getitem__(self, key: str) -> sge.Column: + return sg.column(key, table=self.table, copy=False) + + +C = ColGen() +F = FuncGen() +NULL = sge.Null() +FALSE = sge.false() +TRUE = sge.true() +STAR = sge.Star() + + +def parenthesize_inputs(f): + """Decorate a translation rule to parenthesize inputs.""" + + def wrapper(self, op, *, left, right): + return f( + self, + op, + left=self._add_parens(op.left, left), + right=self._add_parens(op.right, right), + ) + + return wrapper + + +@public +class SQLGlotCompiler(abc.ABC): + __slots__ = "f", "v" + + agg = AggGen() + """A generator for handling aggregate functions""" + + rewrites: tuple[type[pats.Replace], ...] = ( + empty_in_values_right_side, + add_order_by_to_empty_ranking_window_functions, + one_to_zero_index, + add_one_to_nth_value_input, + ) + """A sequence of rewrites to apply to the expression tree before compilation.""" + + no_limit_value: sge.Null | None = None + """The value to use to indicate no limit.""" + + quoted: bool = True + """Whether to always quote identifiers.""" + + copy_func_args: bool = False + """Whether to copy function arguments when generating SQL.""" + + supports_qualify: bool = False + """Whether the backend supports the QUALIFY clause.""" + + NAN: ClassVar[sge.Expression] = sge.Cast( + this=sge.convert("NaN"), to=sge.DataType(this=sge.DataType.Type.DOUBLE) + ) + """Backend's NaN literal.""" + + POS_INF: ClassVar[sge.Expression] = sge.Cast( + this=sge.convert("Inf"), to=sge.DataType(this=sge.DataType.Type.DOUBLE) + ) + """Backend's positive infinity literal.""" + + NEG_INF: ClassVar[sge.Expression] = sge.Cast( + this=sge.convert("-Inf"), to=sge.DataType(this=sge.DataType.Type.DOUBLE) + ) + """Backend's negative infinity literal.""" + + EXTRA_SUPPORTED_OPS: tuple[type[ops.Node], ...] = ( + ops.Project, + ops.Filter, + ops.Sort, + ops.WindowFunction, + ) + """A tuple of ops classes that are supported, but don't have explicit + `visit_*` methods (usually due to being handled by rewrite rules). Used by + `has_operation`""" + + UNSUPPORTED_OPS: tuple[type[ops.Node], ...] 
= () + """Tuple of operations the backend doesn't support.""" + + LOWERED_OPS: dict[type[ops.Node], pats.Replace | None] = { + ops.Bucket: lower_bucket, + ops.Capitalize: lower_capitalize, + ops.Sample: lower_sample, + ops.StringSlice: lower_stringslice, + } + """A mapping from an operation class to either a rewrite rule for rewriting that + operation to one composed of lower-level operations ("lowering"), or `None` to + remove an existing rewrite rule for that operation added in a base class""" + + SIMPLE_OPS = { + ops.Abs: "abs", + ops.Acos: "acos", + ops.All: "bool_and", + ops.Any: "bool_or", + ops.ApproxCountDistinct: "approx_distinct", + ops.ArgMax: "max_by", + ops.ArgMin: "min_by", + ops.ArrayContains: "array_contains", + ops.ArrayFlatten: "flatten", + ops.ArrayLength: "array_size", + ops.ArraySort: "array_sort", + ops.ArrayStringJoin: "array_to_string", + ops.Asin: "asin", + ops.Atan2: "atan2", + ops.Atan: "atan", + ops.Cos: "cos", + ops.Cot: "cot", + ops.Count: "count", + ops.CumeDist: "cume_dist", + ops.Date: "date", + ops.DateFromYMD: "datefromparts", + ops.Degrees: "degrees", + ops.DenseRank: "dense_rank", + ops.Exp: "exp", + FirstValue: "first_value", + ops.GroupConcat: "group_concat", + ops.IfElse: "if", + ops.IsInf: "isinf", + ops.IsNan: "isnan", + ops.JSONGetItem: "json_extract", + ops.LPad: "lpad", + LastValue: "last_value", + ops.Levenshtein: "levenshtein", + ops.Ln: "ln", + ops.Log10: "log", + ops.Log2: "log2", + ops.Lowercase: "lower", + ops.Map: "map", + ops.Median: "median", + ops.MinRank: "rank", + ops.NTile: "ntile", + ops.NthValue: "nth_value", + ops.NullIf: "nullif", + ops.PercentRank: "percent_rank", + ops.Pi: "pi", + ops.Power: "pow", + ops.RPad: "rpad", + ops.Radians: "radians", + ops.RegexSearch: "regexp_like", + ops.RegexSplit: "regexp_split", + ops.Repeat: "repeat", + ops.Reverse: "reverse", + ops.RowNumber: "row_number", + ops.Sign: "sign", + ops.Sin: "sin", + ops.Sqrt: "sqrt", + ops.StartsWith: "starts_with", + ops.StrRight: "right", + ops.StringAscii: "ascii", + ops.StringContains: "contains", + ops.StringLength: "length", + ops.StringReplace: "replace", + ops.StringSplit: "split", + ops.StringToDate: "str_to_date", + ops.StringToTimestamp: "str_to_time", + ops.Tan: "tan", + ops.Translate: "translate", + ops.Unnest: "explode", + ops.Uppercase: "upper", + } + + BINARY_INFIX_OPS = ( + # Binary operations + ops.Add, + ops.Subtract, + ops.Multiply, + ops.Divide, + ops.Modulus, + ops.Power, + # Comparisons + ops.GreaterEqual, + ops.Greater, + ops.LessEqual, + ops.Less, + ops.Equals, + ops.NotEquals, + # Boolean comparisons + ops.And, + ops.Or, + ops.Xor, + # Bitwise business + ops.BitwiseLeftShift, + ops.BitwiseRightShift, + ops.BitwiseAnd, + ops.BitwiseOr, + ops.BitwiseXor, + # Time arithmetic + ops.DateAdd, + ops.DateSub, + ops.DateDiff, + ops.TimestampAdd, + ops.TimestampSub, + ops.TimestampDiff, + # Interval Marginalia + ops.IntervalAdd, + ops.IntervalMultiply, + ops.IntervalSubtract, + ) + + NEEDS_PARENS = BINARY_INFIX_OPS + (ops.IsNull,) + + # Constructed dynamically in `__init_subclass__` from their respective + # UPPERCASE values to handle inheritance, do not modify directly here. 
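+    # For example (illustrative only): a subclass declaring
+    # ``LOWERED_OPS = {ops.Sample: None}`` removes the inherited lowering
+    # rule for ``ops.Sample``, while ``LOWERED_OPS = {ops.Bucket: lower_bucket}``
+    # (re)installs one; ``__init_subclass__`` folds these UPPERCASE
+    # declarations into the lowercase ``lowered_ops`` and
+    # ``extra_supported_ops`` attributes below.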
+ extra_supported_ops: ClassVar[frozenset[type[ops.Node]]] = frozenset() + lowered_ops: ClassVar[dict[type[ops.Node], pats.Replace]] = {} + + def __init__(self) -> None: + self.f = FuncGen(copy=self.__class__.copy_func_args) + self.v = VarGen() + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + def methodname(op: type) -> str: + assert isinstance(type(op), type), type(op) + return f"visit_{op.__name__}" + + def make_impl(op, target_name): + assert isinstance(type(op), type), type(op) + + if issubclass(op, ops.Reduction): + + def impl( + self, _, *, _name: str = target_name, where, order_by=(), **kw + ): + return self.agg[_name](*kw.values(), where=where, order_by=order_by) + + else: + + def impl(self, _, *, _name: str = target_name, **kw): + return self.f[_name](*kw.values()) + + return impl + + for op, target_name in cls.SIMPLE_OPS.items(): + setattr(cls, methodname(op), make_impl(op, target_name)) + + # unconditionally raise an exception for unsupported operations + # + # these *must* be defined after SIMPLE_OPS to handle compilers that + # subclass other compilers + for op in cls.UNSUPPORTED_OPS: + # change to visit_Unsupported in a follow up + # TODO: handle geoespatial ops as a separate case? + setattr(cls, methodname(op), cls.visit_Undefined) + + # raise on any remaining unsupported operations + for op in ALL_OPERATIONS: + name = methodname(op) + if not hasattr(cls, name): + setattr(cls, name, cls.visit_Undefined) + + # Amend `lowered_ops` and `extra_supported_ops` using their + # respective UPPERCASE classvar values. + extra_supported_ops = set(cls.extra_supported_ops) + lowered_ops = dict(cls.lowered_ops) + extra_supported_ops.update(cls.EXTRA_SUPPORTED_OPS) + for op_cls, rewrite in cls.LOWERED_OPS.items(): + if rewrite is not None: + lowered_ops[op_cls] = rewrite + extra_supported_ops.add(op_cls) + else: + lowered_ops.pop(op_cls, None) + extra_supported_ops.discard(op_cls) + cls.lowered_ops = lowered_ops + cls.extra_supported_ops = frozenset(extra_supported_ops) + + @property + @abc.abstractmethod + def dialect(self) -> str: + """Backend dialect.""" + + @property + @abc.abstractmethod + def type_mapper(self) -> type[SqlglotType]: + """The type mapper for the backend.""" + + def _compile_builtin_udf(self, udf_node: ops.ScalarUDF) -> None: # noqa: B027 + """No-op.""" + + def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> None: + raise NotImplementedError( + f"Python UDFs are not supported in the {self.dialect} backend" + ) + + def _compile_pyarrow_udf(self, udf_node: ops.ScalarUDF) -> None: + raise NotImplementedError( + f"PyArrow UDFs are not supported in the {self.dialect} backend" + ) + + def _compile_pandas_udf(self, udf_node: ops.ScalarUDF) -> str: + raise NotImplementedError( + f"pandas UDFs are not supported in the {self.dialect} backend" + ) + + # Concrete API + + def if_(self, condition, true, false: sge.Expression | None = None) -> sge.If: + return sge.If( + this=sge.convert(condition), + true=sge.convert(true), + false=None if false is None else sge.convert(false), + ) + + def cast(self, arg, to: dt.DataType) -> sge.Cast: + return sge.Cast( + this=sge.convert(arg), to=self.type_mapper.from_ibis(to), copy=False + ) + + def _prepare_params(self, params): + result = {} + for param, value in params.items(): + node = param.op() + if isinstance(node, ops.Alias): + node = node.arg + result[node] = value + return result + + def to_sqlglot( + self, + expr: ir.Expr, + *, + limit: str | None = None, + params: Mapping[ir.Expr, Any] | None = 
None, + ): + import ibis + + table_expr = expr.as_table() + + if limit == "default": + limit = ibis.options.sql.default_limit + if limit is not None: + table_expr = table_expr.limit(limit) + + if params is None: + params = {} + + sql = self.translate(table_expr.op(), params=params) + assert not isinstance(sql, sge.Subquery) + + if isinstance(sql, sge.Table): + sql = sg.select(STAR, copy=False).from_(sql, copy=False) + + assert not isinstance(sql, sge.Subquery) + return sql + + def translate(self, op, *, params: Mapping[ir.Value, Any]) -> sge.Expression: + """Translate an ibis operation to a sqlglot expression. + + Parameters + ---------- + op + An ibis operation + params + A mapping of expressions to concrete values + compiler + An instance of SQLGlotCompiler + translate_rel + Relation node translator + translate_val + Value node translator + + Returns + ------- + sqlglot.expressions.Expression + A sqlglot expression + + """ + # substitute parameters immediately to avoid having to define a + # ScalarParameter translation rule + params = self._prepare_params(params) + if self.lowered_ops: + op = op.replace(reduce(operator.or_, self.lowered_ops.values())) + op, ctes = sqlize( + op, + params=params, + rewrites=self.rewrites, + fuse_selects=options.sql.fuse_selects, + ) + + aliases = {} + counter = itertools.count() + + def fn(node, _, **kwargs): + result = self.visit_node(node, **kwargs) + + # if it's not a relation then we don't need to do anything special + if node is op or not isinstance(node, ops.Relation): + return result + + # alias ops.Views to their explicitly assigned name otherwise generate + alias = node.name if isinstance(node, ops.View) else f"t{next(counter)}" + aliases[node] = alias + + alias = sg.to_identifier(alias, quoted=self.quoted) + if isinstance(result, sge.Subquery): + return result.as_(alias, quoted=self.quoted) + else: + try: + return result.subquery(alias, copy=False) + except AttributeError: + return result.as_(alias, quoted=self.quoted) + + # apply translate rules in topological order + results = op.map(fn) + + # get the root node as a sqlglot select statement + out = results[op] + if isinstance(out, sge.Table): + out = sg.select(STAR, copy=False).from_(out, copy=False) + elif isinstance(out, sge.Subquery): + out = out.this + + # add cte definitions to the select statement + for cte in ctes: + alias = sg.to_identifier(aliases[cte], quoted=self.quoted) + out = out.with_( + alias, as_=results[cte].this, dialect=self.dialect, copy=False + ) + + return out + + def visit_node(self, op: ops.Node, **kwargs): + if isinstance(op, ops.ScalarUDF): + return self.visit_ScalarUDF(op, **kwargs) + elif isinstance(op, ops.AggUDF): + return self.visit_AggUDF(op, **kwargs) + else: + method = getattr(self, f"visit_{type(op).__name__}", None) + if method is not None: + return method(op, **kwargs) + else: + raise com.OperationNotDefinedError( + f"No translation rule for {type(op).__name__}" + ) + + def visit_Field(self, op, *, rel, name): + return sg.column( + self._gen_valid_name(name), table=rel.alias_or_name, quoted=self.quoted + ) + + def visit_Cast(self, op, *, arg, to): + from_ = op.arg.dtype + + if from_.is_integer() and to.is_interval(): + return self._make_interval(arg, to.unit) + + return self.cast(arg, to) + + def visit_ScalarSubquery(self, op, *, rel): + return rel.this.subquery(copy=False) + + def visit_Alias(self, op, *, arg, name): + return arg + + def visit_Literal(self, op, *, value, dtype): + """Compile a literal value. 
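+
+        As a concrete illustration of the dispatch described below (assuming
+        the default implementations), an expression such as
+        ``ibis.literal(5)`` reaches this method with a non-null integer value
+        and, via ``visit_DefaultLiteral``, compiles to ``sge.convert(5)``,
+        i.e. the SQL literal ``5``.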
+ + This is the default implementation for compiling literal values. + + Most backends should not need to override this method unless they want + to handle NULL literals as well as every other type of non-null literal + including integers, floating point numbers, decimals, strings, etc. + + The logic here is: + + 1. If the value is None and the type is nullable, return NULL + 1. If the value is None and the type is not nullable, raise an error + 1. Call `visit_NonNullLiteral` method. + 1. If the previous returns `None`, call `visit_DefaultLiteral` method + else return the result of the previous step. + """ + if value is None: + if dtype.nullable: + return NULL if dtype.is_null() else self.cast(NULL, dtype) + raise com.UnsupportedOperationError( + f"Unsupported NULL for non-nullable type: {dtype!r}" + ) + else: + result = self.visit_NonNullLiteral(op, value=value, dtype=dtype) + if result is None: + return self.visit_DefaultLiteral(op, value=value, dtype=dtype) + return result + + def visit_NonNullLiteral(self, op, *, value, dtype): + """Compile a non-null literal differently than the default implementation. + + Most backends should implement this, but only when they need to handle + some non-null literal differently than the default implementation + (`visit_DefaultLiteral`). + + Return `None` from an override of this method to fall back to + `visit_DefaultLiteral`. + """ + return self.visit_DefaultLiteral(op, value=value, dtype=dtype) + + def visit_DefaultLiteral(self, op, *, value, dtype): + """Compile a literal with a non-null value. + + This is the default implementation for compiling non-null literals. + + Most backends should not need to override this method unless they want + to handle compiling every kind of non-null literal value. + """ + if dtype.is_integer(): + return sge.convert(value) + elif dtype.is_floating(): + if math.isnan(value): + return self.NAN + elif math.isinf(value): + return self.POS_INF if value > 0 else self.NEG_INF + return sge.convert(value) + elif dtype.is_decimal(): + return self.cast(str(value), dtype) + elif dtype.is_interval(): + return sge.Interval( + this=sge.convert(str(value)), + unit=sge.Var(this=dtype.resolution.upper()), + ) + elif dtype.is_boolean(): + return sge.Boolean(this=bool(value)) + elif dtype.is_string(): + return sge.convert(value) + elif dtype.is_inet() or dtype.is_macaddr(): + return sge.convert(str(value)) + elif dtype.is_timestamp() or dtype.is_time(): + return self.cast(value.isoformat(), dtype) + elif dtype.is_date(): + return self.f.datefromparts(value.year, value.month, value.day) + elif dtype.is_array(): + value_type = dtype.value_type + return self.f.array( + *( + self.visit_Literal( + ops.Literal(v, value_type), value=v, dtype=value_type + ) + for v in value + ) + ) + elif dtype.is_map(): + key_type = dtype.key_type + keys = self.f.array( + *( + self.visit_Literal( + ops.Literal(k, key_type), value=k, dtype=key_type + ) + for k in value.keys() + ) + ) + + value_type = dtype.value_type + values = self.f.array( + *( + self.visit_Literal( + ops.Literal(v, value_type), value=v, dtype=value_type + ) + for v in value.values() + ) + ) + + return self.f.map(keys, values) + elif dtype.is_struct(): + items = [ + self.visit_Literal( + ops.Literal(v, field_dtype), value=v, dtype=field_dtype + ).as_(k, quoted=self.quoted) + for field_dtype, (k, v) in zip(dtype.types, value.items()) + ] + return sge.Struct.from_arg_list(items) + elif dtype.is_uuid(): + return self.cast(str(value), dtype) + elif dtype.is_geospatial(): + args = [value.wkt] + 
if (srid := dtype.srid) is not None: + args.append(srid) + return self.f.st_geomfromtext(*args) + + raise NotImplementedError(f"Unsupported type: {dtype!r}") + + def visit_BitwiseNot(self, op, *, arg): + return sge.BitwiseNot(this=arg) + + ### Mathematical Calisthenics + + def visit_E(self, op): + return self.f.exp(1) + + def visit_Log(self, op, *, arg, base): + if base is None: + return self.f.ln(arg) + elif str(base) in ("2", "10"): + return self.f[f"log{base}"](arg) + else: + return self.f.ln(arg) / self.f.ln(base) + + def visit_Clip(self, op, *, arg, lower, upper): + if upper is not None: + arg = self.if_(arg.is_(NULL), arg, self.f.least(upper, arg)) + + if lower is not None: + arg = self.if_(arg.is_(NULL), arg, self.f.greatest(lower, arg)) + + return arg + + def visit_FloorDivide(self, op, *, left, right): + return self.cast(self.f.floor(left / right), op.dtype) + + def visit_Ceil(self, op, *, arg): + return self.cast(self.f.ceil(arg), op.dtype) + + def visit_Floor(self, op, *, arg): + return self.cast(self.f.floor(arg), op.dtype) + + def visit_Round(self, op, *, arg, digits): + if digits is not None: + return sge.Round(this=arg, decimals=digits) + return sge.Round(this=arg) + + ### Random Noise + + def visit_RandomScalar(self, op, **kwargs): + return self.f.rand() + + def visit_RandomUUID(self, op, **kwargs): + return self.f.uuid() + + ### Dtype Dysmorphia + + def visit_TryCast(self, op, *, arg, to): + return sge.TryCast(this=arg, to=self.type_mapper.from_ibis(to)) + + ### Comparator Conundrums + + def visit_Between(self, op, *, arg, lower_bound, upper_bound): + return sge.Between(this=arg, low=lower_bound, high=upper_bound) + + def visit_Negate(self, op, *, arg): + return -sge.paren(arg, copy=False) + + def visit_Not(self, op, *, arg): + if isinstance(arg, sge.Filter): + return sge.Filter( + this=sg.not_(arg.this, copy=False), expression=arg.expression + ) + return sg.not_(sge.paren(arg, copy=False)) + + ### Timey McTimeFace + + def visit_Time(self, op, *, arg): + return self.cast(arg, to=dt.time) + + def visit_TimestampNow(self, op): + return sge.CurrentTimestamp() + + def visit_DateNow(self, op): + return sge.CurrentDate() + + def visit_Strftime(self, op, *, arg, format_str): + return sge.TimeToStr(this=arg, format=format_str) + + def visit_ExtractEpochSeconds(self, op, *, arg): + return self.f.epoch(self.cast(arg, dt.timestamp)) + + def visit_ExtractYear(self, op, *, arg): + return self.f.extract(self.v.year, arg) + + def visit_ExtractMonth(self, op, *, arg): + return self.f.extract(self.v.month, arg) + + def visit_ExtractDay(self, op, *, arg): + return self.f.extract(self.v.day, arg) + + def visit_ExtractDayOfYear(self, op, *, arg): + return self.f.extract(self.v.dayofyear, arg) + + def visit_ExtractQuarter(self, op, *, arg): + return self.f.extract(self.v.quarter, arg) + + def visit_ExtractWeekOfYear(self, op, *, arg): + return self.f.extract(self.v.week, arg) + + def visit_ExtractHour(self, op, *, arg): + return self.f.extract(self.v.hour, arg) + + def visit_ExtractMinute(self, op, *, arg): + return self.f.extract(self.v.minute, arg) + + def visit_ExtractSecond(self, op, *, arg): + return self.f.extract(self.v.second, arg) + + def visit_TimestampTruncate(self, op, *, arg, unit): + unit_mapping = { + "Y": "year", + "Q": "quarter", + "M": "month", + "W": "week", + "D": "day", + "h": "hour", + "m": "minute", + "s": "second", + "ms": "ms", + "us": "us", + } + + if (raw_unit := unit_mapping.get(unit.short)) is None: + raise com.UnsupportedOperationError( + f"Unsupported truncate 
unit {unit.short!r}" + ) + + return self.f.date_trunc(raw_unit, arg) + + def visit_DateTruncate(self, op, *, arg, unit): + return self.visit_TimestampTruncate(op, arg=arg, unit=unit) + + def visit_TimeTruncate(self, op, *, arg, unit): + return self.visit_TimestampTruncate(op, arg=arg, unit=unit) + + def visit_DayOfWeekIndex(self, op, *, arg): + return (self.f.dayofweek(arg) + 6) % 7 + + def visit_DayOfWeekName(self, op, *, arg): + # day of week number is 0-indexed + # Sunday == 0 + # Saturday == 6 + return sge.Case( + this=(self.f.dayofweek(arg) + 6) % 7, + ifs=list(itertools.starmap(self.if_, enumerate(calendar.day_name))), + ) + + def _make_interval(self, arg, unit): + return sge.Interval(this=arg, unit=self.v[unit.singular]) + + def visit_IntervalFromInteger(self, op, *, arg, unit): + return self._make_interval(arg, unit) + + ### String Instruments + def visit_Strip(self, op, *, arg): + return self.f.trim(arg, string.whitespace) + + def visit_RStrip(self, op, *, arg): + return self.f.rtrim(arg, string.whitespace) + + def visit_LStrip(self, op, *, arg): + return self.f.ltrim(arg, string.whitespace) + + def visit_Substring(self, op, *, arg, start, length): + if isinstance(op.length, ops.Literal) and (value := op.length.value) < 0: + raise com.IbisInputError( + f"Length parameter must be a non-negative value; got {value}" + ) + start += 1 + start = self.if_(start >= 1, start, start + self.f.length(arg)) + if length is None: + return self.f.substring(arg, start) + return self.f.substring(arg, start, length) + + def visit_StringFind(self, op, *, arg, substr, start, end): + if end is not None: + raise com.UnsupportedOperationError( + "String find doesn't support `end` argument" + ) + + if start is not None: + arg = self.f.substr(arg, start + 1) + pos = self.f.strpos(arg, substr) + return self.if_(pos > 0, pos + start, 0) + + return self.f.strpos(arg, substr) + + def visit_RegexReplace(self, op, *, arg, pattern, replacement): + return self.f.regexp_replace(arg, pattern, replacement, "g") + + def visit_StringConcat(self, op, *, arg): + return self.f.concat(*arg) + + def visit_StringJoin(self, op, *, sep, arg): + return self.f.concat_ws(sep, *arg) + + def visit_StringSQLLike(self, op, *, arg, pattern, escape): + return arg.like(pattern) + + def visit_StringSQLILike(self, op, *, arg, pattern, escape): + return arg.ilike(pattern) + + ### NULL PLAYER CHARACTER + def visit_IsNull(self, op, *, arg): + return arg.is_(NULL) + + def visit_NotNull(self, op, *, arg): + return arg.is_(sg.not_(NULL, copy=False)) + + def visit_InValues(self, op, *, value, options): + return value.isin(*options) + + ### Counting + + def visit_CountDistinct(self, op, *, arg, where): + return self.agg.count(sge.Distinct(expressions=[arg]), where=where) + + def visit_CountDistinctStar(self, op, *, arg, where): + return self.agg.count(sge.Distinct(expressions=[STAR]), where=where) + + def visit_CountStar(self, op, *, arg, where): + return self.agg.count(STAR, where=where) + + def visit_Sum(self, op, *, arg, where): + if op.arg.dtype.is_boolean(): + arg = self.cast(arg, dt.int32) + return self.agg.sum(arg, where=where) + + def visit_Mean(self, op, *, arg, where): + if op.arg.dtype.is_boolean(): + arg = self.cast(arg, dt.int32) + return self.agg.avg(arg, where=where) + + def visit_Min(self, op, *, arg, where): + if op.arg.dtype.is_boolean(): + return self.agg.bool_and(arg, where=where) + return self.agg.min(arg, where=where) + + def visit_Max(self, op, *, arg, where): + if op.arg.dtype.is_boolean(): + return self.agg.bool_or(arg, 
where=where) + return self.agg.max(arg, where=where) + + ### Stats + + def visit_VarianceStandardDevCovariance(self, op, *, how, where, **kw): + hows = {"sample": "samp", "pop": "pop"} + funcs = { + ops.Variance: "var", + ops.StandardDev: "stddev", + ops.Covariance: "covar", + } + + args = [] + + for oparg, arg in zip(op.args, kw.values()): + if (arg_dtype := oparg.dtype).is_boolean(): + arg = self.cast(arg, dt.Int32(nullable=arg_dtype.nullable)) + args.append(arg) + + funcname = f"{funcs[type(op)]}_{hows[how]}" + return self.agg[funcname](*args, where=where) + + visit_Variance = ( + visit_StandardDev + ) = visit_Covariance = visit_VarianceStandardDevCovariance + + def visit_SimpleCase(self, op, *, base=None, cases, results, default): + return sge.Case( + this=base, ifs=list(map(self.if_, cases, results)), default=default + ) + + visit_SearchedCase = visit_SimpleCase + + def visit_ExistsSubquery(self, op, *, rel): + select = rel.this.select(1, append=False) + return self.f.exists(select) + + def visit_InSubquery(self, op, *, rel, needle): + query = rel.this + if not isinstance(query, sge.Select): + query = sg.select(STAR).from_(query) + return needle.isin(query=query) + + def visit_Array(self, op, *, exprs): + return self.f.array(*exprs) + + def visit_StructColumn(self, op, *, names, values): + return sge.Struct.from_arg_list( + [value.as_(name, quoted=self.quoted) for name, value in zip(names, values)] + ) + + def visit_StructField(self, op, *, arg, field): + return sge.Dot(this=arg, expression=sg.to_identifier(field, quoted=self.quoted)) + + def visit_IdenticalTo(self, op, *, left, right): + return sge.NullSafeEQ(this=left, expression=right) + + def visit_Greatest(self, op, *, arg): + return self.f.greatest(*arg) + + def visit_Least(self, op, *, arg): + return self.f.least(*arg) + + def visit_Coalesce(self, op, *, arg): + return self.f.coalesce(*arg) + + ### Ordering and window functions + + def visit_SortKey(self, op, *, expr, ascending: bool, nulls_first: bool): + return sge.Ordered(this=expr, desc=not ascending, nulls_first=nulls_first) + + def visit_ApproxMedian(self, op, *, arg, where): + return self.agg.approx_quantile(arg, 0.5, where=where) + + def visit_WindowBoundary(self, op, *, value, preceding): + # TODO: bit of a hack to return a dict, but there's no sqlglot expression + # that corresponds to _only_ this information + return {"value": value, "side": "preceding" if preceding else "following"} + + def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by): + if start is None: + start = {} + if end is None: + end = {} + + start_value = start.get("value", "UNBOUNDED") + start_side = start.get("side", "PRECEDING") + end_value = end.get("value", "UNBOUNDED") + end_side = end.get("side", "FOLLOWING") + + if getattr(start_value, "this", None) == "0": + start_value = "CURRENT ROW" + start_side = None + + if getattr(end_value, "this", None) == "0": + end_value = "CURRENT ROW" + end_side = None + + spec = sge.WindowSpec( + kind=how.upper(), + start=start_value, + start_side=start_side, + end=end_value, + end_side=end_side, + over="OVER", + ) + order = sge.Order(expressions=order_by) if order_by else None + + spec = self._minimize_spec(op.start, op.end, spec) + + return sge.Window(this=func, partition_by=group_by, order=order, spec=spec) + + @staticmethod + def _minimize_spec(start, end, spec): + return spec + + def visit_LagLead(self, op, *, arg, offset, default): + args = [arg] + + if default is not None: + if offset is None: + offset = 1 + + args.append(offset) 
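+            # NOTE: offset is appended before the default so the rendered
+            # call keeps the conventional argument order, e.g.
+            # ``LAG(arg, offset, default)``.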
+ args.append(default) + elif offset is not None: + args.append(offset) + + return self.f[type(op).__name__.lower()](*args) + + visit_Lag = visit_Lead = visit_LagLead + + def visit_Argument(self, op, *, name: str, shape, dtype): + return sg.to_identifier(op.param) + + def visit_RowID(self, op, *, table): + return sg.column( + op.name, table=table.alias_or_name, quoted=self.quoted, copy=False + ) + + # TODO(kszucs): this should be renamed to something UDF related + def __sql_name__(self, op: ops.ScalarUDF | ops.AggUDF) -> str: + # for builtin functions use the exact function name, otherwise use the + # generated name to handle the case of redefinition + funcname = ( + op.__func_name__ + if op.__input_type__ == InputType.BUILTIN + else type(op).__name__ + ) + + # not actually a table, but easier to quote individual namespace + # components this way + namespace = op.__udf_namespace__ + return sg.table(funcname, db=namespace.database, catalog=namespace.catalog).sql( + self.dialect + ) + + def visit_ScalarUDF(self, op, **kw): + return self.f[self.__sql_name__(op)](*kw.values()) + + def visit_AggUDF(self, op, *, where, **kw): + return self.agg[self.__sql_name__(op)](*kw.values(), where=where) + + def visit_TimestampDelta(self, op, *, part, left, right): + # dialect is necessary due to sqlglot's default behavior + # of `part` coming last + return sge.DateDiff( + this=left, expression=right, unit=part, dialect=self.dialect + ) + + visit_TimeDelta = visit_DateDelta = visit_TimestampDelta + + def visit_TimestampBucket(self, op, *, arg, interval, offset): + origin = self.f.cast("epoch", self.type_mapper.from_ibis(dt.timestamp)) + if offset is not None: + origin += offset + return self.f.time_bucket(interval, arg, origin) + + def visit_ArrayConcat(self, op, *, arg): + return sge.ArrayConcat(this=arg[0], expressions=list(arg[1:])) + + ## relations + + @staticmethod + def _gen_valid_name(name: str) -> str: + """Generate a valid name for a value expression. + + Override this method if the dialect has restrictions on valid + identifiers even when quoted. + + See the BigQuery backend's implementation for an example. 
+ """ + return name + + def _cleanup_names(self, exprs: Mapping[str, sge.Expression]): + """Compose `_gen_valid_name` and `_dedup_name` to clean up names in projections.""" + + for name, value in exprs.items(): + name = self._gen_valid_name(name) + if isinstance(value, sge.Column) and name == value.name: + # don't alias columns that are already named the same as their alias + yield value + else: + yield value.as_(name, quoted=self.quoted, copy=False) + + def visit_Select( + self, op, *, parent, selections, predicates, qualified=False, sort_keys=False + ): + # if we've constructed a useless projection return the parent relation + if not (selections or predicates or qualified or sort_keys): + return parent + + result = parent + + if selections: + # if there are `qualify` predicates then sqlglot adds a hidden + # column to implement the functionality if the dialect doesn't + # support it + # + # using STAR in that case would lead to an extra column, so in that + # case we have to spell out the columns + if op.is_star_selection() and (not qualified or self.supports_qualify): + fields = [STAR] + else: + fields = self._cleanup_names(selections) + result = sg.select(*fields, copy=False).from_(result, copy=False) + + if predicates: + result = result.where(*predicates, copy=False) + + if qualified: + result = result.qualify(*qualified, copy=False) + + if sort_keys: + result = result.order_by(*sort_keys, copy=False) + + return result + + def visit_DummyTable(self, op, *, values): + return sg.select(*self._cleanup_names(values), copy=False) + + def visit_UnboundTable( + self, op, *, name: str, schema: sch.Schema, namespace: ops.Namespace + ) -> sg.Table: + return sg.table( + name, db=namespace.database, catalog=namespace.catalog, quoted=self.quoted + ) + + def visit_InMemoryTable( + self, op, *, name: str, schema: sch.Schema, data + ) -> sg.Table: + return sg.table(name, quoted=self.quoted) + + def visit_DatabaseTable( + self, + op, + *, + name: str, + schema: sch.Schema, + source: Any, + namespace: ops.Namespace, + ) -> sg.Table: + return sg.table( + name, db=namespace.database, catalog=namespace.catalog, quoted=self.quoted + ) + + def visit_SelfReference(self, op, *, parent, identifier): + return parent + + visit_JoinReference = visit_SelfReference + + def visit_JoinChain(self, op, *, first, rest, values): + result = sg.select(*self._cleanup_names(values), copy=False).from_( + first, copy=False + ) + + for link in rest: + if isinstance(link, sge.Alias): + link = link.this + result = result.join(link, copy=False) + return result + + def visit_JoinLink(self, op, *, how, table, predicates): + sides = { + "inner": None, + "left": "left", + "right": "right", + "semi": "left", + "anti": "left", + "cross": None, + "outer": "full", + "asof": "asof", + "any_left": "left", + "any_inner": None, + "positional": None, + } + kinds = { + "any_left": "any", + "any_inner": "any", + "asof": "left", + "inner": "inner", + "left": "outer", + "right": "outer", + "semi": "semi", + "anti": "anti", + "cross": "cross", + "outer": "outer", + "positional": "positional", + } + assert predicates or how in { + "cross", + "positional", + }, "expected non-empty predicates when not a cross join" + on = sg.and_(*predicates) if predicates else None + return sge.Join(this=table, side=sides[how], kind=kinds[how], on=on) + + @staticmethod + def _generate_groups(groups): + return map(sge.convert, range(1, len(groups) + 1)) + + def visit_Aggregate(self, op, *, parent, groups, metrics): + sel = sg.select( + *self._cleanup_names(groups), 
*self._cleanup_names(metrics), copy=False + ).from_(parent, copy=False) + + if groups: + sel = sel.group_by(*self._generate_groups(groups.values()), copy=False) + + return sel + + @classmethod + def _add_parens(cls, op, sg_expr): + if isinstance(op, cls.NEEDS_PARENS): + return sge.paren(sg_expr, copy=False) + return sg_expr + + def visit_Union(self, op, *, left, right, distinct): + if isinstance(left, (sge.Table, sge.Subquery)): + left = sg.select(STAR, copy=False).from_(left, copy=False) + + if isinstance(right, (sge.Table, sge.Subquery)): + right = sg.select(STAR, copy=False).from_(right, copy=False) + + return sg.union( + left.args.get("this", left), + right.args.get("this", right), + distinct=distinct, + copy=False, + ) + + def visit_Intersection(self, op, *, left, right, distinct): + if isinstance(left, (sge.Table, sge.Subquery)): + left = sg.select(STAR, copy=False).from_(left, copy=False) + + if isinstance(right, (sge.Table, sge.Subquery)): + right = sg.select(STAR, copy=False).from_(right, copy=False) + + return sg.intersect( + left.args.get("this", left), + right.args.get("this", right), + distinct=distinct, + copy=False, + ) + + def visit_Difference(self, op, *, left, right, distinct): + if isinstance(left, (sge.Table, sge.Subquery)): + left = sg.select(STAR, copy=False).from_(left, copy=False) + + if isinstance(right, (sge.Table, sge.Subquery)): + right = sg.select(STAR, copy=False).from_(right, copy=False) + + return sg.except_( + left.args.get("this", left), + right.args.get("this", right), + distinct=distinct, + copy=False, + ) + + def visit_Limit(self, op, *, parent, n, offset): + # push limit/offset into subqueries + if isinstance(parent, sge.Subquery) and parent.this.args.get("limit") is None: + result = parent.this.copy() + alias = parent.alias + else: + result = sg.select(STAR, copy=False).from_(parent, copy=False) + alias = None + + if isinstance(n, int): + result = result.limit(n, copy=False) + elif n is not None: + result = result.limit( + sg.select(n, copy=False).from_(parent, copy=False).subquery(copy=False), + copy=False, + ) + else: + assert n is None, n + if self.no_limit_value is not None: + result = result.limit(self.no_limit_value, copy=False) + + assert offset is not None, "offset is None" + + if not isinstance(offset, int): + skip = offset + skip = ( + sg.select(skip, copy=False) + .from_(parent, copy=False) + .subquery(copy=False) + ) + elif not offset: + if alias is not None: + return result.subquery(alias, copy=False) + return result + else: + skip = offset + + result = result.offset(skip, copy=False) + if alias is not None: + return result.subquery(alias, copy=False) + return result + + def visit_Distinct(self, op, *, parent): + return ( + sg.select(STAR, copy=False).distinct(copy=False).from_(parent, copy=False) + ) + + def visit_CTE(self, op, *, parent): + return sg.table(parent.alias_or_name, quoted=self.quoted) + + def visit_View(self, op, *, child, name: str): + if isinstance(child, sge.Table): + child = sg.select(STAR, copy=False).from_(child, copy=False) + else: + child = child.copy() + + if isinstance(child, sge.Subquery): + return child.as_(name, quoted=self.quoted) + else: + try: + return child.subquery(name, copy=False) + except AttributeError: + return child.as_(name, quoted=self.quoted) + + def visit_SQLStringView(self, op, *, query: str, child, schema): + return sg.parse_one(query, read=self.dialect) + + def visit_SQLQueryResult(self, op, *, query, schema, source): + return sg.parse_one(query, dialect=self.dialect).subquery(copy=False) + + 
def visit_RegexExtract(self, op, *, arg, pattern, index): + return self.f.regexp_extract(arg, pattern, index, dialect=self.dialect) + + @parenthesize_inputs + def visit_Add(self, op, *, left, right): + return sge.Add(this=left, expression=right) + + visit_DateAdd = visit_TimestampAdd = visit_IntervalAdd = visit_Add + + @parenthesize_inputs + def visit_Subtract(self, op, *, left, right): + return sge.Sub(this=left, expression=right) + + visit_DateSub = ( + visit_DateDiff + ) = ( + visit_TimestampSub + ) = visit_TimestampDiff = visit_IntervalSubtract = visit_Subtract + + @parenthesize_inputs + def visit_Multiply(self, op, *, left, right): + return sge.Mul(this=left, expression=right) + + visit_IntervalMultiply = visit_Multiply + + @parenthesize_inputs + def visit_Divide(self, op, *, left, right): + return sge.Div(this=left, expression=right) + + @parenthesize_inputs + def visit_Modulus(self, op, *, left, right): + return sge.Mod(this=left, expression=right) + + @parenthesize_inputs + def visit_Power(self, op, *, left, right): + return sge.Pow(this=left, expression=right) + + @parenthesize_inputs + def visit_GreaterEqual(self, op, *, left, right): + return sge.GTE(this=left, expression=right) + + @parenthesize_inputs + def visit_Greater(self, op, *, left, right): + return sge.GT(this=left, expression=right) + + @parenthesize_inputs + def visit_LessEqual(self, op, *, left, right): + return sge.LTE(this=left, expression=right) + + @parenthesize_inputs + def visit_Less(self, op, *, left, right): + return sge.LT(this=left, expression=right) + + @parenthesize_inputs + def visit_Equals(self, op, *, left, right): + return sge.EQ(this=left, expression=right) + + @parenthesize_inputs + def visit_NotEquals(self, op, *, left, right): + return sge.NEQ(this=left, expression=right) + + @parenthesize_inputs + def visit_And(self, op, *, left, right): + return sge.And(this=left, expression=right) + + @parenthesize_inputs + def visit_Or(self, op, *, left, right): + return sge.Or(this=left, expression=right) + + @parenthesize_inputs + def visit_Xor(self, op, *, left, right): + return sge.Xor(this=left, expression=right) + + @parenthesize_inputs + def visit_BitwiseLeftShift(self, op, *, left, right): + return sge.BitwiseLeftShift(this=left, expression=right) + + @parenthesize_inputs + def visit_BitwiseRightShift(self, op, *, left, right): + return sge.BitwiseRightShift(this=left, expression=right) + + @parenthesize_inputs + def visit_BitwiseAnd(self, op, *, left, right): + return sge.BitwiseAnd(this=left, expression=right) + + @parenthesize_inputs + def visit_BitwiseOr(self, op, *, left, right): + return sge.BitwiseOr(this=left, expression=right) + + @parenthesize_inputs + def visit_BitwiseXor(self, op, *, left, right): + return sge.BitwiseXor(this=left, expression=right) + + def visit_Undefined(self, op, **_): + raise com.OperationNotDefinedError( + f"Compilation rule for {type(op).__name__!r} operation is not defined" + ) + + def visit_Unsupported(self, op, **_): + raise com.UnsupportedOperationError( + f"{type(op).__name__!r} operation is not supported in the {self.dialect} backend" + ) + + def visit_DropColumns(self, op, *, parent, columns_to_drop): + # the generated query will be huge for wide tables + # + # TODO: figure out a way to produce an IR that only contains exactly + # what is used + parent_alias = parent.alias_or_name + quoted = self.quoted + columns_to_keep = ( + sg.column(column, table=parent_alias, quoted=quoted) + for column in op.schema.names + ) + return 
sg.select(*columns_to_keep).from_(parent) + + def add_query_to_expr(self, *, name: str, table: ir.Table, query: str) -> str: + dialect = self.dialect + + compiled_ibis_expr = self.to_sqlglot(table) + + # pull existing CTEs from the compiled Ibis expression and combine them + # with the new query + parsed = reduce( + lambda parsed, cte: parsed.with_(cte.args["alias"], as_=cte.args["this"]), + compiled_ibis_expr.ctes, + sg.parse_one(query, read=dialect), + ) + + # remove all ctes from the compiled expression, since they're now in + # our larger expression + compiled_ibis_expr.args.pop("with", None) + + # add the new str query as a CTE + parsed = parsed.with_( + sg.to_identifier(name, quoted=self.quoted), as_=compiled_ibis_expr + ) + + # generate the SQL string + return parsed.sql(dialect) + + +# `__init_subclass__` is uncalled for subclasses - we manually call it here to +# autogenerate the base class implementations as well. +SQLGlotCompiler.__init_subclass__() diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py new file mode 100644 index 0000000000..1e015c75dc --- /dev/null +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py @@ -0,0 +1,1149 @@ +# Contains code from https://github.com/ibis-project/ibis/blob/main/ibis/backends/sql/compilers/bigquery/__init__.py +"""Module to convert from Ibis expression to SQL string.""" + +from __future__ import annotations + +import decimal +import math +import re +from typing import Any, TYPE_CHECKING + +import bigframes_vendored.ibis.backends.bigquery.datatypes as bq_datatypes +from bigframes_vendored.ibis.backends.sql.compilers.base import ( + AggGen, + NULL, + SQLGlotCompiler, + STAR, +) +import bigframes_vendored.ibis.backends.sql.compilers.base as sql_compiler +from ibis import util +from ibis.backends.sql.datatypes import BigQueryType, BigQueryUDFType +from ibis.backends.sql.rewrites import ( + exclude_unsupported_window_frame_from_ops, + exclude_unsupported_window_frame_from_rank, + exclude_unsupported_window_frame_from_row_number, +) +import ibis.common.exceptions as com +from ibis.common.temporal import DateUnit, IntervalUnit, TimestampUnit, TimeUnit +import ibis.expr.datatypes as dt +import ibis.expr.operations as ops +import sqlglot as sg +from sqlglot.dialects import BigQuery +import sqlglot.expressions as sge + +if TYPE_CHECKING: + from collections.abc import Mapping + + import ibis.expr.types as ir + + +_NAME_REGEX = re.compile(r'[^!"$()*,./;?@[\\\]^`{}~\n]+') + + +_MEMTABLE_PATTERN = re.compile( + r"^_?ibis_(?:[A-Za-z_][A-Za-z_0-9]*)_memtable_[a-z0-9]{26}$" +) + + +def _qualify_memtable( + node: sge.Expression, *, dataset: str | None, project: str | None +) -> sge.Expression: + """Add a BigQuery dataset and project to memtable references.""" + if isinstance(node, sge.Table) and _MEMTABLE_PATTERN.match(node.name) is not None: + node.args["db"] = dataset + node.args["catalog"] = project + # make sure to quote table location + node = _force_quote_table(node) + return node + + +def _remove_null_ordering_from_unsupported_window( + node: sge.Expression, +) -> sge.Expression: + """Remove null ordering in window frame clauses not supported by BigQuery. + + BigQuery has only partial support for NULL FIRST/LAST in RANGE windows so + we remove it from any window frame clause that doesn't support it. 
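+    (Concretely, the rewrite flips the ``nulls_first`` flag on offending sort
+    keys so that each one uses a combination BigQuery accepts.)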
+ + Here's the support matrix: + + ✅ sum(x) over (order by y desc nulls last) + 🚫 sum(x) over (order by y asc nulls last) + ✅ sum(x) over (order by y asc nulls first) + 🚫 sum(x) over (order by y desc nulls first) + """ + if isinstance(node, sge.Window): + order = node.args.get("order") + if order is not None: + for key in order.args["expressions"]: + kargs = key.args + if kargs.get("desc") is True and kargs.get("nulls_first", False): + kargs["nulls_first"] = False + elif kargs.get("desc") is False and not kargs.setdefault( + "nulls_first", True + ): + kargs["nulls_first"] = True + return node + + +def _force_quote_table(table: sge.Table) -> sge.Table: + """Force quote all the parts of a bigquery path. + + The BigQuery identifier quoting semantics are bonkers + https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#identifiers + + my-table is OK, but not mydataset.my-table + + mytable-287 is OK, but not mytable-287a + + Just quote everything. + """ + for key in ("this", "db", "catalog"): + if (val := table.args[key]) is not None: + if isinstance(val, sg.exp.Identifier) and not val.quoted: + val.args["quoted"] = True + else: + table.args[key] = sg.to_identifier(val, quoted=True) + return table + + +class BigQueryCompiler(SQLGlotCompiler): + dialect = BigQuery + type_mapper = BigQueryType + udf_type_mapper = BigQueryUDFType + + agg = AggGen(supports_order_by=True) + + rewrites = ( + exclude_unsupported_window_frame_from_ops, + exclude_unsupported_window_frame_from_row_number, + exclude_unsupported_window_frame_from_rank, + *SQLGlotCompiler.rewrites, + ) + + supports_qualify = True + + UNSUPPORTED_OPS = ( + ops.DateDiff, + ops.ExtractAuthority, + ops.ExtractUserInfo, + ops.FindInSet, + ops.Median, + ops.RegexSplit, + ops.RowID, + ops.TimestampDiff, + ) + + NAN = sge.Cast( + this=sge.convert("NaN"), to=sge.DataType(this=sge.DataType.Type.DOUBLE) + ) + POS_INF = sge.Cast( + this=sge.convert("Infinity"), to=sge.DataType(this=sge.DataType.Type.DOUBLE) + ) + NEG_INF = sge.Cast( + this=sge.convert("-Infinity"), to=sge.DataType(this=sge.DataType.Type.DOUBLE) + ) + + SIMPLE_OPS = { + ops.Arbitrary: "any_value", + ops.StringAscii: "ascii", + ops.BitAnd: "bit_and", + ops.BitOr: "bit_or", + ops.BitXor: "bit_xor", + ops.DateFromYMD: "date", + ops.Divide: "ieee_divide", + ops.EndsWith: "ends_with", + ops.GeoArea: "st_area", + ops.GeoAsBinary: "st_asbinary", + ops.GeoAsText: "st_astext", + ops.GeoAzimuth: "st_azimuth", + ops.GeoBuffer: "st_buffer", + ops.GeoCentroid: "st_centroid", + ops.GeoContains: "st_contains", + ops.GeoCoveredBy: "st_coveredby", + ops.GeoCovers: "st_covers", + ops.GeoDWithin: "st_dwithin", + ops.GeoDifference: "st_difference", + ops.GeoDisjoint: "st_disjoint", + ops.GeoDistance: "st_distance", + ops.GeoEndPoint: "st_endpoint", + ops.GeoEquals: "st_equals", + ops.GeoGeometryType: "st_geometrytype", + ops.GeoIntersection: "st_intersection", + ops.GeoIntersects: "st_intersects", + ops.GeoLength: "st_length", + ops.GeoMaxDistance: "st_maxdistance", + ops.GeoNPoints: "st_numpoints", + ops.GeoPerimeter: "st_perimeter", + ops.GeoPoint: "st_geogpoint", + ops.GeoPointN: "st_pointn", + ops.GeoStartPoint: "st_startpoint", + ops.GeoTouches: "st_touches", + ops.GeoUnaryUnion: "st_union_agg", + ops.GeoUnion: "st_union", + ops.GeoWithin: "st_within", + ops.GeoX: "st_x", + ops.GeoY: "st_y", + ops.Hash: "farm_fingerprint", + ops.IsInf: "is_inf", + ops.IsNan: "is_nan", + ops.Log10: "log10", + ops.LPad: "lpad", + ops.RPad: "rpad", + ops.Levenshtein: "edit_distance", + ops.Modulus: "mod", + 
ops.RegexReplace: "regexp_replace", + ops.RegexSearch: "regexp_contains", + ops.Time: "time", + ops.TimeFromHMS: "time_from_parts", + ops.TimestampNow: "current_timestamp", + ops.ExtractHost: "net.host", + } + + def to_sqlglot( + self, + expr: ir.Expr, + *, + limit: str | None = None, + params: Mapping[ir.Expr, Any] | None = None, + session_dataset_id: str | None = None, + session_project: str | None = None, + ) -> Any: + """Compile an Ibis expression. + + Parameters + ---------- + expr + Ibis expression + limit + For expressions yielding result sets; retrieve at most this number + of values/rows. Overrides any limit already set on the expression. + params + Named unbound parameters + session_dataset_id + Optional dataset ID to qualify memtable references. + session_project + Optional project ID to qualify memtable references. + + Returns + ------- + Any + The output of compilation. The type of this value depends on the + backend. + + """ + sql = super().to_sqlglot(expr, limit=limit, params=params) + + table_expr = expr.as_table() + geocols = getattr(table_expr.schema(), "geospatial", None) + + result = sql.transform( + _qualify_memtable, + dataset=session_dataset_id, + project=session_project, + ).transform(_remove_null_ordering_from_unsupported_window) + + if geocols: + # if there are any geospatial columns, we have to convert them to WKB, + # so interactive mode knows how to display them + # + # by default bigquery returns data to python as WKT, and there's really + # no point in supporting both if we don't need to. + quoted = self.quoted + result = sg.select( + sge.Star( + replace=[ + self.f.st_asbinary(sg.column(col, quoted=quoted)).as_( + col, quoted=quoted + ) + for col in geocols + ] + ) + ).from_(result.subquery()) + + sources = [] + + for udf_node in table_expr.op().find(ops.ScalarUDF): + compile_func = getattr( + self, f"_compile_{udf_node.__input_type__.name.lower()}_udf" + ) + if sql := compile_func(udf_node): + sources.append(sql) + + if not sources: + return result + + sources.append(result) + return sources + + @staticmethod + def _minimize_spec(start, end, spec): + if ( + start is None + and isinstance(getattr(end, "value", None), ops.Literal) + and end.value.value == 0 + and end.following + ): + return None + return spec + + def visit_BoundingBox(self, op, *, arg): + name = type(op).__name__[len("Geo") :].lower() + return sge.Dot( + this=self.f.st_boundingbox(arg), expression=sg.to_identifier(name) + ) + + visit_GeoXMax = visit_GeoXMin = visit_GeoYMax = visit_GeoYMin = visit_BoundingBox + + def visit_GeoSimplify(self, op, *, arg, tolerance, preserve_collapsed): + if ( + not isinstance(op.preserve_collapsed, ops.Literal) + or op.preserve_collapsed.value + ): + raise com.UnsupportedOperationError( + "BigQuery simplify does not support preserving collapsed geometries, " + "pass preserve_collapsed=False" + ) + return self.f.st_simplify(arg, tolerance) + + def visit_ApproxMedian(self, op, *, arg, where): + return self.agg.approx_quantiles(arg, 2, where=where)[self.f.offset(1)] + + def visit_Pi(self, op): + return self.f.acos(-1) + + def visit_E(self, op): + return self.f.exp(1) + + def visit_TimeDelta(self, op, *, left, right, part): + return self.f.time_diff(left, right, part, dialect=self.dialect) + + def visit_DateDelta(self, op, *, left, right, part): + return self.f.date_diff(left, right, part, dialect=self.dialect) + + def visit_TimestampDelta(self, op, *, left, right, part): + left_tz = op.left.dtype.timezone + right_tz = op.right.dtype.timezone + + if left_tz is None 
and right_tz is None: + return self.f.datetime_diff(left, right, part) + elif left_tz is not None and right_tz is not None: + return self.f.timestamp_diff(left, right, part) + + raise com.UnsupportedOperationError( + "timestamp difference with mixed timezone/timezoneless values is not implemented" + ) + + def visit_GroupConcat(self, op, *, arg, sep, where, order_by): + if where is not None: + arg = self.if_(where, arg, NULL) + + if order_by: + sep = sge.Order(this=sep, expressions=order_by) + + return sge.GroupConcat(this=arg, separator=sep) + + def visit_ApproxQuantile(self, op, *, arg, quantile, where): + if not isinstance(op.quantile, ops.Literal): + raise com.UnsupportedOperationError( + "quantile must be a literal in BigQuery" + ) + + # BigQuery syntax is `APPROX_QUANTILES(col, resolution)` to return + # `resolution + 1` quantiles array. To handle this, we compute the + # resolution ourselves then restructure the output array as needed. + # To avoid excessive resolution we arbitrarily cap it at 100,000 - + # since these are approximate quantiles anyway this seems fine. + quantiles = util.promote_list(op.quantile.value) + fracs = [decimal.Decimal(str(q)).as_integer_ratio() for q in quantiles] + resolution = min(math.lcm(*(den for _, den in fracs)), 100_000) + indices = [(num * resolution) // den for num, den in fracs] + + if where is not None: + arg = self.if_(where, arg, NULL) + + if not op.arg.dtype.is_floating(): + arg = self.cast(arg, dt.float64) + + array = self.f.approx_quantiles( + arg, sge.IgnoreNulls(this=sge.convert(resolution)) + ) + if isinstance(op, ops.ApproxQuantile): + return array[indices[0]] + + if indices == list(range(resolution + 1)): + return array + else: + return sge.Array(expressions=[array[i] for i in indices]) + + visit_ApproxMultiQuantile = visit_ApproxQuantile + + def visit_FloorDivide(self, op, *, left, right): + return self.cast(self.f.floor(self.f.ieee_divide(left, right)), op.dtype) + + def visit_Log2(self, op, *, arg): + return self.f.log(arg, 2, dialect=self.dialect) + + def visit_Log(self, op, *, arg, base): + if base is None: + return self.f.ln(arg) + return self.f.log(arg, base, dialect=self.dialect) + + def visit_ArrayRepeat(self, op, *, arg, times): + start = step = 1 + array_length = self.f.array_length(arg) + stop = self.f.greatest(times, 0) * array_length + i = sg.to_identifier("i") + idx = self.f.coalesce( + self.f.nullif(self.f.mod(i, array_length), 0), array_length + ) + series = self.f.generate_array(start, stop, step) + return self.f.array( + sg.select(arg[self.f.safe_ordinal(idx)]).from_(self._unnest(series, as_=i)) + ) + + def visit_NthValue(self, op, *, arg, nth): + if not isinstance(op.nth, ops.Literal): + raise com.UnsupportedOperationError( + f"BigQuery `nth` must be a literal; got {type(op.nth)}" + ) + return self.f.nth_value(arg, nth) + + def visit_StrRight(self, op, *, arg, nchars): + return self.f.substr(arg, -self.f.least(self.f.length(arg), nchars)) + + def visit_StringJoin(self, op, *, arg, sep): + return self.f.array_to_string(self.f.array(*arg), sep) + + def visit_DayOfWeekIndex(self, op, *, arg): + return self.f.mod(self.f.extract(self.v.dayofweek, arg) + 5, 7) + + def visit_DayOfWeekName(self, op, *, arg): + return self.f.initcap(sge.Cast(this=arg, to="STRING FORMAT 'DAY'")) + + def visit_StringToTimestamp(self, op, *, arg, format_str): + if (timezone := op.dtype.timezone) is not None: + return self.f.parse_timestamp(format_str, arg, timezone) + return self.f.parse_datetime(format_str, arg) + + def 
visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): + if where is not None and include_null: + raise com.UnsupportedOperationError( + "Combining `include_null=True` and `where` is not supported " + "by bigquery" + ) + out = self.agg.array_agg(arg, where=where, order_by=order_by) + if not include_null: + out = sge.IgnoreNulls(this=out) + return out + + def _neg_idx_to_pos(self, arg, idx): + return self.if_(idx < 0, self.f.array_length(arg) + idx, idx) + + def visit_ArraySlice(self, op, *, arg, start, stop): + index = sg.to_identifier("bq_arr_slice") + cond = [index >= self._neg_idx_to_pos(arg, start)] + + if stop is not None: + cond.append(index < self._neg_idx_to_pos(arg, stop)) + + el = sg.to_identifier("el") + return self.f.array( + sg.select(el).from_(self._unnest(arg, as_=el, offset=index)).where(*cond) + ) + + def visit_ArrayIndex(self, op, *, arg, index): + return arg[self.f.safe_offset(index)] + + def visit_ArrayContains(self, op, *, arg, other): + name = sg.to_identifier(util.gen_name("bq_arr_contains")) + return sge.Exists( + this=sg.select(sge.convert(1)) + .from_(self._unnest(arg, as_=name)) + .where(name.eq(other)) + ) + + def visit_StringContains(self, op, *, haystack, needle): + return self.f.strpos(haystack, needle) > 0 + + def visti_StringFind(self, op, *, arg, substr, start, end): + if start is not None: + raise NotImplementedError( + "`start` not implemented for BigQuery string find" + ) + if end is not None: + raise NotImplementedError("`end` not implemented for BigQuery string find") + return self.f.strpos(arg, substr) + + def visit_TimestampFromYMDHMS( + self, op, *, year, month, day, hours, minutes, seconds + ): + return self.f.anon.DATETIME(year, month, day, hours, minutes, seconds) + + def visit_NonNullLiteral(self, op, *, value, dtype): + if dtype.is_inet() or dtype.is_macaddr(): + return sge.convert(str(value)) + elif dtype.is_timestamp(): + funcname = "DATETIME" if dtype.timezone is None else "TIMESTAMP" + return self.f.anon[funcname](value.isoformat()) + elif dtype.is_date(): + return self.f.date_from_parts(value.year, value.month, value.day) + elif dtype.is_time(): + time = self.f.time_from_parts(value.hour, value.minute, value.second) + if micros := value.microsecond: + # bigquery doesn't support `time(12, 34, 56.789101)`, AKA a + # float seconds specifier, so add any non-zero micros to the + # time value + return sge.TimeAdd( + this=time, expression=sge.convert(micros), unit=self.v.MICROSECOND + ) + return time + elif dtype.is_binary(): + return sge.Cast( + this=sge.convert(value.hex()), + to=sge.DataType(this=sge.DataType.Type.BINARY), + format=sge.convert("HEX"), + ) + elif dtype.is_interval(): + if dtype.unit == IntervalUnit.NANOSECOND: + raise com.UnsupportedOperationError( + "BigQuery does not support nanosecond intervals" + ) + elif dtype.is_uuid(): + return sge.convert(str(value)) + return None + + def visit_IntervalFromInteger(self, op, *, arg, unit): + if unit == IntervalUnit.NANOSECOND: + raise com.UnsupportedOperationError( + "BigQuery does not support nanosecond intervals" + ) + return sge.Interval(this=arg, unit=self.v[unit.singular]) + + def visit_Strftime(self, op, *, arg, format_str): + arg_dtype = op.arg.dtype + if arg_dtype.is_timestamp(): + if (timezone := arg_dtype.timezone) is None: + return self.f.format_datetime(format_str, arg) + else: + return self.f.format_timestamp(format_str, arg, timezone) + elif arg_dtype.is_date(): + return self.f.format_date(format_str, arg) + else: + assert arg_dtype.is_time(), arg_dtype + 
return self.f.format_time(format_str, arg) + + def visit_IntervalMultiply(self, op, *, left, right): + unit = self.v[op.left.dtype.resolution.upper()] + return sge.Interval(this=self.f.extract(unit, left) * right, unit=unit) + + def visit_TimestampFromUNIX(self, op, *, arg, unit): + unit = op.unit + if unit == TimestampUnit.SECOND: + return self.f.timestamp_seconds(arg) + elif unit == TimestampUnit.MILLISECOND: + return self.f.timestamp_millis(arg) + elif unit == TimestampUnit.MICROSECOND: + return self.f.timestamp_micros(arg) + elif unit == TimestampUnit.NANOSECOND: + return self.f.timestamp_micros( + self.cast(self.f.round(arg / 1_000), dt.int64) + ) + else: + raise com.UnsupportedOperationError(f"Unit not supported: {unit}") + + def visit_Cast(self, op, *, arg, to): + from_ = op.arg.dtype + if from_.is_timestamp() and to.is_integer(): + return self.f.unix_micros(arg) + elif from_.is_integer() and to.is_timestamp(): + return self.f.timestamp_seconds(arg) + elif from_.is_interval() and to.is_integer(): + if from_.unit in { + IntervalUnit.WEEK, + IntervalUnit.QUARTER, + IntervalUnit.NANOSECOND, + }: + raise com.UnsupportedOperationError( + f"BigQuery does not allow extracting date part `{from_.unit}` from intervals" + ) + return self.f.extract(self.v[to.resolution.upper()], arg) + elif from_.is_floating() and to.is_integer(): + return self.cast(self.f.trunc(arg), dt.int64) + return super().visit_Cast(op, arg=arg, to=to) + + def visit_JSONGetItem(self, op, *, arg, index): + return arg[index] + + def visit_UnwrapJSONString(self, op, *, arg): + return self.f.anon["safe.string"](arg) + + def visit_UnwrapJSONInt64(self, op, *, arg): + return self.f.anon["safe.int64"](arg) + + def visit_UnwrapJSONFloat64(self, op, *, arg): + return self.f.anon["safe.float64"](arg) + + def visit_UnwrapJSONBoolean(self, op, *, arg): + return self.f.anon["safe.bool"](arg) + + def visit_ExtractEpochSeconds(self, op, *, arg): + return self.f.unix_seconds(arg) + + def visit_ExtractWeekOfYear(self, op, *, arg): + return self.f.extract(self.v.isoweek, arg) + + def visit_ExtractIsoYear(self, op, *, arg): + return self.f.extract(self.v.isoyear, arg) + + def visit_ExtractMillisecond(self, op, *, arg): + return self.f.extract(self.v.millisecond, arg) + + def visit_ExtractMicrosecond(self, op, *, arg): + return self.f.extract(self.v.microsecond, arg) + + def visit_TimestampTruncate(self, op, *, arg, unit): + if unit == IntervalUnit.NANOSECOND: + raise com.UnsupportedOperationError( + f"BigQuery does not support truncating {op.arg.dtype} values to unit {unit!r}" + ) + elif unit == IntervalUnit.WEEK: + unit = "WEEK(MONDAY)" + else: + unit = unit.name + return self.f.timestamp_trunc(arg, self.v[unit], dialect=self.dialect) + + def visit_DateTruncate(self, op, *, arg, unit): + if unit == DateUnit.WEEK: + unit = "WEEK(MONDAY)" + else: + unit = unit.name + return self.f.date_trunc(arg, self.v[unit], dialect=self.dialect) + + def visit_TimeTruncate(self, op, *, arg, unit): + if unit == TimeUnit.NANOSECOND: + raise com.UnsupportedOperationError( + f"BigQuery does not support truncating {op.arg.dtype} values to unit {unit!r}" + ) + else: + unit = unit.name + return self.f.time_trunc(arg, self.v[unit], dialect=self.dialect) + + def _nullifzero(self, step, zero, step_dtype): + if step_dtype.is_interval(): + return self.if_(step.eq(zero), NULL, step) + return self.f.nullif(step, zero) + + def _zero(self, dtype): + if dtype.is_interval(): + return self.f.make_interval() + return sge.convert(0) + + def _sign(self, value, dtype): + if 
dtype.is_interval(): + zero = self._zero(dtype) + return sge.Case( + ifs=[ + self.if_(value < zero, -1), + self.if_(value.eq(zero), 0), + self.if_(value > zero, 1), + ], + default=NULL, + ) + return self.f.sign(value) + + def _make_range(self, func, start, stop, step, step_dtype): + step_sign = self._sign(step, step_dtype) + delta_sign = self._sign(stop - start, step_dtype) + zero = self._zero(step_dtype) + nullifzero = self._nullifzero(step, zero, step_dtype) + condition = sg.and_(sg.not_(nullifzero.is_(NULL)), step_sign.eq(delta_sign)) + gen_array = func(start, stop, step) + name = sg.to_identifier(util.gen_name("bq_arr_range")) + inner = ( + sg.select(name) + .from_(self._unnest(gen_array, as_=name)) + .where(name.neq(stop)) + ) + return self.if_(condition, self.f.array(inner), self.f.array()) + + def visit_IntegerRange(self, op, *, start, stop, step): + return self._make_range(self.f.generate_array, start, stop, step, op.step.dtype) + + def visit_TimestampRange(self, op, *, start, stop, step): + if op.start.dtype.timezone is None or op.stop.dtype.timezone is None: + raise com.IbisTypeError( + "Timestamps without timezone values are not supported when generating timestamp ranges" + ) + return self._make_range( + self.f.generate_timestamp_array, start, stop, step, op.step.dtype + ) + + def visit_First(self, op, *, arg, where, order_by, include_null): + if where is not None: + arg = self.if_(where, arg, NULL) + if include_null: + raise com.UnsupportedOperationError( + "Combining `include_null=True` and `where` is not supported " + "by bigquery" + ) + + if order_by: + arg = sge.Order(this=arg, expressions=order_by) + + if not include_null: + arg = sge.IgnoreNulls(this=arg) + + array = self.f.array_agg(sge.Limit(this=arg, expression=sge.convert(1))) + return array[self.f.safe_offset(0)] + + def visit_Last(self, op, *, arg, where, order_by, include_null): + if where is not None: + arg = self.if_(where, arg, NULL) + if include_null: + raise com.UnsupportedOperationError( + "Combining `include_null=True` and `where` is not supported " + "by bigquery" + ) + + if order_by: + arg = sge.Order(this=arg, expressions=order_by) + + if not include_null: + arg = sge.IgnoreNulls(this=arg) + + array = self.f.array_reverse(self.f.array_agg(arg)) + return array[self.f.safe_offset(0)] + + def visit_ArrayFilter(self, op, *, arg, body, param): + return self.f.array( + sg.select(param).from_(self._unnest(arg, as_=param)).where(body) + ) + + def visit_ArrayMap(self, op, *, arg, body, param): + return self.f.array(sg.select(body).from_(self._unnest(arg, as_=param))) + + def visit_ArrayZip(self, op, *, arg): + lengths = [self.f.array_length(arr) - 1 for arr in arg] + idx = sg.to_identifier(util.gen_name("bq_arr_idx")) + indices = self._unnest( + self.f.generate_array(0, self.f.greatest(*lengths)), as_=idx + ) + struct_fields = [ + arr[self.f.safe_offset(idx)].as_(name) + for name, arr in zip(op.dtype.value_type.names, arg) + ] + return self.f.array( + sge.Select(kind="STRUCT", expressions=struct_fields).from_(indices) + ) + + def visit_ArrayPosition(self, op, *, arg, other): + name = sg.to_identifier(util.gen_name("bq_arr")) + idx = sg.to_identifier(util.gen_name("bq_arr_idx")) + unnest = self._unnest(arg, as_=name, offset=idx) + return self.f.coalesce( + sg.select(idx + 1).from_(unnest).where(name.eq(other)).limit(1).subquery(), + 0, + ) + + def _unnest(self, expression, *, as_, offset=None): + alias = sge.TableAlias(columns=[sg.to_identifier(as_)]) + return sge.Unnest(expressions=[expression], alias=alias, 
offset=offset) + + def visit_ArrayRemove(self, op, *, arg, other): + name = sg.to_identifier(util.gen_name("bq_arr")) + unnest = self._unnest(arg, as_=name) + both_null = sg.and_(name.is_(NULL), other.is_(NULL)) + cond = sg.or_(name.neq(other), both_null) + return self.f.array(sg.select(name).from_(unnest).where(cond)) + + def visit_ArrayDistinct(self, op, *, arg): + name = util.gen_name("bq_arr") + return self.f.array( + sg.select(name).distinct().from_(self._unnest(arg, as_=name)) + ) + + def visit_ArraySort(self, op, *, arg): + name = util.gen_name("bq_arr") + return self.f.array( + sg.select(name).from_(self._unnest(arg, as_=name)).order_by(name) + ) + + def visit_ArrayUnion(self, op, *, left, right): + lname = util.gen_name("bq_arr_left") + rname = util.gen_name("bq_arr_right") + lhs = sg.select(lname).from_(self._unnest(left, as_=lname)) + rhs = sg.select(rname).from_(self._unnest(right, as_=rname)) + return self.f.array(sg.union(lhs, rhs, distinct=True)) + + def visit_ArrayIntersect(self, op, *, left, right): + lname = util.gen_name("bq_arr_left") + rname = util.gen_name("bq_arr_right") + lhs = sg.select(lname).from_(self._unnest(left, as_=lname)) + rhs = sg.select(rname).from_(self._unnest(right, as_=rname)) + return self.f.array(sg.intersect(lhs, rhs, distinct=True)) + + def visit_RegexExtract(self, op, *, arg, pattern, index): + matches = self.f.regexp_contains(arg, pattern) + nonzero_index_replace = self.f.regexp_replace( + arg, + self.f.concat(".*?", pattern, ".*"), + self.f.concat("\\", self.cast(index, dt.string)), + ) + zero_index_replace = self.f.regexp_replace( + arg, self.f.concat(".*?", self.f.concat("(", pattern, ")"), ".*"), "\\1" + ) + extract = self.if_(index.eq(0), zero_index_replace, nonzero_index_replace) + return self.if_(matches, extract, NULL) + + def visit_TimestampAddSub(self, op, *, left, right): + if not isinstance(right, sge.Interval): + raise com.OperationNotDefinedError( + "BigQuery does not support non-literals on the right side of timestamp add/subtract" + ) + if (unit := op.right.dtype.unit) == IntervalUnit.NANOSECOND: + raise com.UnsupportedOperationError( + f"BigQuery does not allow binary operation {type(op).__name__} with " + f"INTERVAL offset {unit}" + ) + + opname = type(op).__name__[len("Timestamp") :] + funcname = f"TIMESTAMP_{opname.upper()}" + return self.f.anon[funcname](left, right) + + visit_TimestampAdd = visit_TimestampSub = visit_TimestampAddSub + + def visit_DateAddSub(self, op, *, left, right): + if not isinstance(right, sge.Interval): + raise com.OperationNotDefinedError( + "BigQuery does not support non-literals on the right side of date add/subtract" + ) + if not (unit := op.right.dtype.unit).is_date(): + raise com.UnsupportedOperationError( + f"BigQuery does not allow binary operation {type(op).__name__} with " + f"INTERVAL offset {unit}" + ) + opname = type(op).__name__[len("Date") :] + funcname = f"DATE_{opname.upper()}" + return self.f.anon[funcname](left, right) + + visit_DateAdd = visit_DateSub = visit_DateAddSub + + def visit_Covariance(self, op, *, left, right, how, where): + if where is not None: + left = self.if_(where, left, NULL) + right = self.if_(where, right, NULL) + + if op.left.dtype.is_boolean(): + left = self.cast(left, dt.int64) + + if op.right.dtype.is_boolean(): + right = self.cast(right, dt.int64) + + how = op.how[:4].upper() + assert how in ("POP", "SAMP"), 'how not in ("POP", "SAMP")' + return self.agg[f"COVAR_{how}"](left, right, where=where) + + def visit_Correlation(self, op, *, left, right, how, 
where): + if how == "sample": + raise ValueError(f"Correlation with how={how!r} is not supported.") + + if where is not None: + left = self.if_(where, left, NULL) + right = self.if_(where, right, NULL) + + if op.left.dtype.is_boolean(): + left = self.cast(left, dt.int64) + + if op.right.dtype.is_boolean(): + right = self.cast(right, dt.int64) + + return self.agg.corr(left, right, where=where) + + def visit_TypeOf(self, op, *, arg): + return self._pudf("typeof", arg) + + def visit_Xor(self, op, *, left, right): + return sg.or_(sg.and_(left, sg.not_(right)), sg.and_(sg.not_(left), right)) + + def visit_HashBytes(self, op, *, arg, how): + if how not in ("md5", "sha1", "sha256", "sha512"): + raise NotImplementedError(how) + return self.f[how](arg) + + @staticmethod + def _gen_valid_name(name: str) -> str: + return "_".join(map(str.strip, _NAME_REGEX.findall(name))) or "tmp" + + def visit_CountStar(self, op, *, arg, where): + if where is not None: + return self.f.countif(where) + return self.f.count(STAR) + + def visit_CountDistinctStar(self, op, *, where, arg): + # Bigquery does not support count(distinct a,b,c) or count(distinct (a, b, c)) + # as expressions must be "groupable": + # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#group_by_grouping_item + # + # Instead, convert the entire expression to a string + # SELECT COUNT(DISTINCT concat(to_json_string(a), to_json_string(b))) + # This works with an array of datatypes which generates a unique string + # https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#json_encodings + row = sge.Concat( + expressions=[ + self.f.to_json_string(sg.column(x, quoted=self.quoted)) + for x in op.arg.schema.keys() + ] + ) + if where is not None: + row = self.if_(where, row, NULL) + return self.f.count(sge.Distinct(expressions=[row])) + + def visit_Degrees(self, op, *, arg): + return self._pudf("degrees", arg) + + def visit_Radians(self, op, *, arg): + return self._pudf("radians", arg) + + def visit_CountDistinct(self, op, *, arg, where): + if where is not None: + arg = self.if_(where, arg, NULL) + return self.f.count(sge.Distinct(expressions=[arg])) + + def visit_RandomUUID(self, op, **kwargs): + return self.f.generate_uuid() + + def visit_ExtractFile(self, op, *, arg): + return self._pudf("cw_url_extract_file", arg) + + def visit_ExtractFragment(self, op, *, arg): + return self._pudf("cw_url_extract_fragment", arg) + + def visit_ExtractPath(self, op, *, arg): + return self._pudf("cw_url_extract_path", arg) + + def visit_ExtractProtocol(self, op, *, arg): + return self._pudf("cw_url_extract_protocol", arg) + + def visit_ExtractQuery(self, op, *, arg, key): + if key is not None: + return self._pudf("cw_url_extract_parameter", arg, key) + else: + return self._pudf("cw_url_extract_query", arg) + + def _pudf(self, name, *args): + name = sg.table(name, db="persistent_udfs", catalog="bigquery-public-data").sql( + self.dialect + ) + return self.f[name](*args) + + def visit_DropColumns(self, op, *, parent, columns_to_drop): + quoted = self.quoted + excludes = [sg.column(column, quoted=quoted) for column in columns_to_drop] + star = sge.Star(**{"except": excludes}) + table = sg.to_identifier(parent.alias_or_name, quoted=quoted) + column = sge.Column(this=star, table=table) + return sg.select(column).from_(parent) + + def visit_TableUnnest( + self, op, *, parent, column, offset: str | None, keep_empty: bool + ): + quoted = self.quoted + + column_alias = sg.to_identifier( + util.gen_name("table_unnest_column"), 
quoted=quoted + ) + + selcols = [] + + table = sg.to_identifier(parent.alias_or_name, quoted=quoted) + + opname = op.column.name + overlaps_with_parent = opname in op.parent.schema + computed_column = column_alias.as_(opname, quoted=quoted) + + # replace the existing column if the unnested column hasn't been + # renamed + # + # e.g., table.unnest("x") + if overlaps_with_parent: + selcols.append( + sge.Column(this=sge.Star(replace=[computed_column]), table=table) + ) + else: + selcols.append(sge.Column(this=STAR, table=table)) + selcols.append(computed_column) + + if offset is not None: + offset = sg.to_identifier(offset, quoted=quoted) + selcols.append(offset) + + unnest = sge.Unnest( + expressions=[column], + alias=sge.TableAlias(columns=[column_alias]), + offset=offset, + ) + return ( + sg.select(*selcols) + .from_(parent) + .join(unnest, join_type="CROSS" if not keep_empty else "LEFT") + ) + + def visit_TimestampBucket(self, op, *, arg, interval, offset): + arg_dtype = op.arg.dtype + if arg_dtype.timezone is not None: + funcname = "timestamp" + else: + funcname = "datetime" + + func = self.f[f"{funcname}_bucket"] + + origin = sge.convert("1970-01-01") + if offset is not None: + origin = self.f.anon[f"{funcname}_add"](origin, offset) + + return func(arg, interval, origin) + + def _array_reduction(self, *, arg, reduction): + name = sg.to_identifier(util.gen_name(f"bq_arr_{reduction}")) + return ( + sg.select(self.f[reduction](name)) + .from_(self._unnest(arg, as_=name)) + .subquery() + ) + + def visit_ArrayMin(self, op, *, arg): + return self._array_reduction(arg=arg, reduction="min") + + def visit_ArrayMax(self, op, *, arg): + return self._array_reduction(arg=arg, reduction="max") + + def visit_ArraySum(self, op, *, arg): + return self._array_reduction(arg=arg, reduction="sum") + + def visit_ArrayMean(self, op, *, arg): + return self._array_reduction(arg=arg, reduction="avg") + + def visit_ArrayAny(self, op, *, arg): + return self._array_reduction(arg=arg, reduction="logical_or") + + def visit_ArrayAll(self, op, *, arg): + return self._array_reduction(arg=arg, reduction="logical_and") + + # Customized ops for bigframes + + def visit_InMemoryTable(self, op, *, name, schema, data): + # Avoid creating temp tables for small data, which is how memtable is + # used in BigQuery DataFrames. 
Inspired by: + # https://github.com/ibis-project/ibis/blob/efa6fb72bf4c790450d00a926d7bd809dade5902/ibis/backends/druid/compiler.py#L95 + tuples = data.to_frame().itertuples(index=False) + quoted = self.quoted + columns = [sg.column(col, quoted=quoted) for col in schema.names] + array_expr = sge.DataType( + this=sge.DataType.Type.STRUCT, + expressions=[ + sge.ColumnDef( + this=sge.to_identifier(field, quoted=self.quoted), + kind=bq_datatypes.BigQueryType.from_ibis(type_), + ) + for field, type_ in zip(schema.names, schema.types) + ], + nested=True, + ) + array_values = [ + sge.Tuple( + expressions=tuple( + self.visit_Literal(None, value=value, dtype=type_) + for value, type_ in zip(row, schema.types) + ) + ) + for row in tuples + ] + expr = sge.Unnest( + expressions=[ + sge.DataType( + this=sge.DataType.Type.ARRAY, + expressions=[array_expr], + nested=True, + values=array_values, + ), + ], + alias=sge.TableAlias( + this=sg.to_identifier(name, quoted=quoted), + columns=columns, + ), + ) + # return expr + return sg.select(sge.Star()).from_(expr) + + def visit_ArrayAggregate(self, op, *, arg, order_by, where): + if len(order_by) > 0: + expr = sge.Order( + this=arg, + expressions=[ + # Avoid adding NULLS FIRST / NULLS LAST in SQL, which is + # unsupported in ARRAY_AGG by reconstructing the node as + # plain SQL text. + f"({order_column.args['this'].sql(dialect='bigquery')}) {'DESC' if order_column.args.get('desc') else 'ASC'}" + for order_column in order_by + ], + ) + else: + expr = arg + return sge.IgnoreNulls(this=self.agg.array_agg(expr, where=where)) + + def visit_FirstNonNullValue(self, op, *, arg): + return sge.IgnoreNulls(this=sge.FirstValue(this=arg)) + + def visit_LastNonNullValue(self, op, *, arg): + return sge.IgnoreNulls(this=sge.LastValue(this=arg)) + + def visit_ToJsonString(self, op, *, arg): + return self.f.to_json_string(arg) + + def visit_Quantile(self, op, *, arg, quantile, where): + return sge.PercentileCont(this=arg, expression=quantile) + + def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by): + # Patch for https://github.com/ibis-project/ibis/issues/9872 + if start is None and end is None: + spec = None + else: + if start is None: + start = {} + if end is None: + end = {} + + start_value = start.get("value", "UNBOUNDED") + start_side = start.get("side", "PRECEDING") + end_value = end.get("value", "UNBOUNDED") + end_side = end.get("side", "FOLLOWING") + + if getattr(start_value, "this", None) == "0": + start_value = "CURRENT ROW" + start_side = None + + if getattr(end_value, "this", None) == "0": + end_value = "CURRENT ROW" + end_side = None + + spec = sge.WindowSpec( + kind=how.upper(), + start=start_value, + start_side=start_side, + end=end_value, + end_side=end_side, + over="OVER", + ) + spec = self._minimize_spec(op.start, op.end, spec) + + order = sge.Order(expressions=order_by) if order_by else None + + return sge.Window(this=func, partition_by=group_by, order=order, spec=spec) + + @sql_compiler.parenthesize_inputs + def visit_And(self, op, *, left, right): + return sge.And(this=sge.Paren(this=left), expression=sge.Paren(this=right)) + + +compiler = BigQueryCompiler() diff --git a/third_party/bigframes_vendored/ibis/backends/sql/rewrites.py b/third_party/bigframes_vendored/ibis/backends/sql/rewrites.py new file mode 100644 index 0000000000..26c03c3752 --- /dev/null +++ b/third_party/bigframes_vendored/ibis/backends/sql/rewrites.py @@ -0,0 +1,514 @@ +"""Lower the ibis expression graph to a SQL-like relational algebra.""" + +from 
__future__ import annotations + +from collections.abc import Mapping +from functools import reduce +import operator +from typing import Any, TYPE_CHECKING + +from ibis.common.annotations import attribute +from ibis.common.collections import FrozenDict # noqa: TCH001 +from ibis.common.deferred import var +import ibis.common.exceptions as com +from ibis.common.graph import Graph +from ibis.common.patterns import InstanceOf, Object, Pattern, replace +from ibis.common.typing import VarTuple # noqa: TCH001 +import ibis.expr.datatypes as dt +import ibis.expr.operations as ops +from ibis.expr.rewrites import d, p, replace_parameter +from ibis.expr.schema import Schema +from public import public +import toolz + +if TYPE_CHECKING: + from collections.abc import Sequence + +x = var("x") +y = var("y") + + +@public +class CTE(ops.Relation): + """Common table expression.""" + + parent: ops.Relation + + @attribute + def schema(self): + return self.parent.schema + + @attribute + def values(self): + return self.parent.values + + +@public +class Select(ops.Relation): + """Relation modelled after SQL's SELECT statement.""" + + parent: ops.Relation + selections: FrozenDict[str, ops.Value] = {} + predicates: VarTuple[ops.Value[dt.Boolean]] = () + qualified: VarTuple[ops.Value[dt.Boolean]] = () + sort_keys: VarTuple[ops.SortKey] = () + + def is_star_selection(self): + return tuple(self.values.items()) == tuple(self.parent.fields.items()) + + @attribute + def values(self): + return self.selections + + @attribute + def schema(self): + return Schema({k: v.dtype for k, v in self.selections.items()}) + + +@public +class FirstValue(ops.Analytic): + """Retrieve the first element.""" + + arg: ops.Column[dt.Any] + + @attribute + def dtype(self): + return self.arg.dtype + + +@public +class LastValue(ops.Analytic): + """Retrieve the last element.""" + + arg: ops.Column[dt.Any] + + @attribute + def dtype(self): + return self.arg.dtype + + +# TODO(kszucs): there is a better strategy to rewrite the relational operations +# to Select nodes by wrapping the leaf nodes in a Select node and then merging +# Project, Filter, Sort, etc. incrementally into the Select node. This way we +# can have tighter control over simplification logic. 
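+#
+# For example (illustrative):
+#
+#     t.select(t.a).filter(t.a > 0).order_by(t.a)
+#
+# lowers to nested ``Select`` nodes, which ``merge_select_select`` below may
+# fuse into a single Select carrying selections, predicates, and sort keys.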
+ + +@replace(p.Project) +def project_to_select(_, **kwargs): + """Convert a Project node to a Select node.""" + return Select(_.parent, selections=_.values) + + +def partition_predicates(predicates): + qualified = [] + unqualified = [] + + for predicate in predicates: + if predicate.find(ops.WindowFunction, filter=ops.Value): + qualified.append(predicate) + else: + unqualified.append(predicate) + + return unqualified, qualified + + +@replace(p.Filter) +def filter_to_select(_, **kwargs): + """Convert a Filter node to a Select node.""" + predicates, qualified = partition_predicates(_.predicates) + return Select( + _.parent, selections=_.values, predicates=predicates, qualified=qualified + ) + + +@replace(p.Sort) +def sort_to_select(_, **kwargs): + """Convert a Sort node to a Select node.""" + return Select(_.parent, selections=_.values, sort_keys=_.keys) + + +if hasattr(p, "DropColumns"): + + @replace(p.DropColumns) + def drop_columns_to_select(_, **kwargs): + """Convert a DropColumns node to a Select node.""" + # if we're dropping fewer than 50% of the parent table's columns then the + # compiled query will likely be smaller than if we list everything *NOT* + # being dropped + if len(_.columns_to_drop) < len(_.schema) // 2: + return _ + return Select(_.parent, selections=_.values) + + +if hasattr(p, "FillNull"): + + @replace(p.FillNull) + def fill_null_to_select(_, **kwargs): + """Rewrite FillNull to a Select node.""" + if isinstance(_.replacements, Mapping): + mapping = _.replacements + else: + mapping = { + name: _.replacements + for name, type in _.parent.schema.items() + if type.nullable + } + + if not mapping: + return _.parent + + selections = {} + for name in _.parent.schema.names: + col = ops.Field(_.parent, name) + if (value := mapping.get(name)) is not None: + col = ops.Alias(ops.Coalesce((col, value)), name) + selections[name] = col + + return Select(_.parent, selections=selections) + + +if hasattr(p, "DropNull"): + + @replace(p.DropNull) + def drop_null_to_select(_, **kwargs): + """Rewrite DropNull to a Select node.""" + if _.subset is None: + columns = [ops.Field(_.parent, name) for name in _.parent.schema.names] + else: + columns = _.subset + + if columns: + preds = [ + reduce( + ops.And if _.how == "any" else ops.Or, + [ops.NotNull(c) for c in columns], + ) + ] + elif _.how == "all": + preds = [ops.Literal(False, dtype=dt.bool)] + else: + return _.parent + + return Select(_.parent, selections=_.values, predicates=tuple(preds)) + + +@replace(p.WindowFunction(p.First | p.Last)) +def first_to_firstvalue(_, **kwargs): + """Convert a First or Last node to a FirstValue or LastValue node.""" + if _.func.where is not None: + raise com.UnsupportedOperationError( + f"`{type(_.func).__name__.lower()}` with `where` is unsupported " + "in a window function" + ) + klass = FirstValue if isinstance(_.func, ops.First) else LastValue + return _.copy(func=klass(_.func.arg)) + + +def complexity(node): + """Assign a complexity score to a node. + + Subsequent projections can be merged into a single projection by replacing + the fields referenced in the outer projection with the computed expressions + from the inner projection. This inlining can result in very complex value + expressions depending on the projections. In order to prevent excessive + inlining, we assign a complexity score to each node. + + The complexity score assigns 1 to each value expression and adds up in the + tree hierarchy unless there is a Field node where we don't add up the + complexity of the referenced relation. 
This way we treat fields kind of like + reusable variables considering them less complex than they were inlined. + """ + + def accum(node, *args): + if isinstance(node, ops.Field): + return 1 + else: + return 1 + sum(args) + + return node.map_nodes(accum)[node] + + +@replace(Object(Select, Object(Select))) +def merge_select_select(_, **kwargs): + """Merge subsequent Select relations into one. + + This rewrites eliminates `_.parent` by merging the outer and the inner + `predicates`, `sort_keys` and keeping the outer `selections`. All selections + from the inner Select are inlined into the outer Select. + """ + # don't merge if either the outer or the inner select has window functions + blocking = ( + ops.WindowFunction, + ops.ExistsSubquery, + ops.InSubquery, + ops.Unnest, + ops.Impure, + ) + if _.find_below(blocking, filter=ops.Value): + return _ + if _.parent.find_below(blocking, filter=ops.Value): + return _ + + subs = {ops.Field(_.parent, k): v for k, v in _.parent.values.items()} + selections = {k: v.replace(subs, filter=ops.Value) for k, v in _.selections.items()} + + predicates = tuple(p.replace(subs, filter=ops.Value) for p in _.predicates) + unique_predicates = toolz.unique(_.parent.predicates + predicates) + + qualified = tuple(p.replace(subs, filter=ops.Value) for p in _.qualified) + unique_qualified = toolz.unique(_.parent.qualified + qualified) + + sort_keys = tuple(s.replace(subs, filter=ops.Value) for s in _.sort_keys) + sort_key_exprs = {s.expr for s in sort_keys} + parent_sort_keys = tuple( + k for k in _.parent.sort_keys if k.expr not in sort_key_exprs + ) + unique_sort_keys = sort_keys + parent_sort_keys + + result = Select( + _.parent.parent, + selections=selections, + predicates=unique_predicates, + qualified=unique_qualified, + sort_keys=unique_sort_keys, + ) + return result if complexity(result) <= complexity(_) else _ + + +def extract_ctes(node: ops.Relation) -> set[ops.Relation]: + cte_types = (Select, ops.Aggregate, ops.JoinChain, ops.Set, ops.Limit, ops.Sample) + dont_count = (ops.Field, ops.CountStar, ops.CountDistinctStar) + + g = Graph.from_bfs(node, filter=~InstanceOf(dont_count)) + result = set() + for op, dependents in g.invert().items(): + if isinstance(op, ops.View) or ( + len(dependents) > 1 and isinstance(op, cte_types) + ): + result.add(op) + + return result + + +def sqlize( + node: ops.Node, + params: Mapping[ops.ScalarParameter, Any], + rewrites: Sequence[Pattern] = (), + fuse_selects: bool = True, +) -> tuple[ops.Node, list[ops.Node]]: + """Lower the ibis expression graph to a SQL-like relational algebra. + + Parameters + ---------- + node + The root node of the expression graph. + params + A mapping of scalar parameters to their values. + rewrites + Supplementary rewrites to apply to the expression graph. + fuse_selects + Whether to merge subsequent Select nodes into one where possible. + + Returns + ------- + Tuple of the rewritten expression graph and a list of CTEs. 
+ + """ + assert isinstance(node, ops.Relation) + + # apply the backend specific rewrites + if rewrites: + node = node.replace(reduce(operator.or_, rewrites)) + + # lower the expression graph to a SQL-like relational algebra + context = {"params": params} + replacements = ( + replace_parameter | project_to_select | filter_to_select | sort_to_select + ) + + if hasattr(p, "FillNull"): + replacements = replacements | fill_null_to_select + + if hasattr(p, "DropNull"): + replacements = replacements | drop_null_to_select + + if hasattr(p, "DropColumns"): + replacements = replacements | drop_columns_to_select + + replacements = replacements | first_to_firstvalue + sqlized = node.replace( + replacements, + context=context, + ) + + # squash subsequent Select nodes into one + if fuse_selects: + simplified = sqlized.replace(merge_select_select) + else: + simplified = sqlized + + # extract common table expressions while wrapping them in a CTE node + ctes = extract_ctes(simplified) + + def wrap(node, _, **kwargs): + new = node.__recreate__(kwargs) + return CTE(new) if node in ctes else new + + result = simplified.replace(wrap) + ctes = reversed([cte.parent for cte in result.find(CTE)]) + + return result, ctes + + +# supplemental rewrites selectively used on a per-backend basis + + +@replace(p.WindowFunction(func=p.NTile(y), order_by=())) +def add_order_by_to_empty_ranking_window_functions(_, **kwargs): + """Add an ORDER BY clause to rank window functions that don't have one.""" + return _.copy(order_by=(y,)) + + +"""Replace checks against an empty right side with `False`.""" +empty_in_values_right_side = p.InValues(options=()) >> d.Literal(False, dtype=dt.bool) + + +@replace( + p.WindowFunction(p.RankBase | p.NTile) + | p.StringFind + | p.FindInSet + | p.ArrayPosition +) +def one_to_zero_index(_, **kwargs): + """Subtract one from one-index functions.""" + return ops.Subtract(_, 1) + + +@replace(ops.NthValue) +def add_one_to_nth_value_input(_, **kwargs): + if isinstance(_.nth, ops.Literal): + nth = ops.Literal(_.nth.value + 1, dtype=_.nth.dtype) + else: + nth = ops.Add(_.nth, 1) + return _.copy(nth=nth) + + +@replace(p.WindowFunction(order_by=())) +def rewrite_empty_order_by_window(_, **kwargs): + return _.copy(order_by=(ops.NULL,)) + + +@replace(p.WindowFunction(p.RowNumber | p.NTile)) +def exclude_unsupported_window_frame_from_row_number(_, **kwargs): + return ops.Subtract(_.copy(start=None, end=0), 1) + + +@replace(p.WindowFunction(p.MinRank | p.DenseRank, start=None)) +def exclude_unsupported_window_frame_from_rank(_, **kwargs): + return ops.Subtract( + _.copy(start=None, end=0, order_by=_.order_by or (ops.NULL,)), 1 + ) + + +@replace( + p.WindowFunction( + p.Lag | p.Lead | p.PercentRank | p.CumeDist | p.Any | p.All, start=None + ) +) +def exclude_unsupported_window_frame_from_ops(_, **kwargs): + return _.copy(start=None, end=0, order_by=_.order_by or (ops.NULL,)) + + +# Rewrite rules for lowering a high-level operation into one composed of more +# primitive operations. 
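+#
+# For example, ``lower_log2`` below rewrites ``ops.Log2(x)`` into
+# ``ops.Log(x, base=2)`` so that a backend only needs a compilation rule for
+# the generic ``Log`` operation.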
+ + +@replace(p.Log2) +def lower_log2(_, **kwargs): + """Rewrite `log2` as `log`.""" + return ops.Log(_.arg, base=2) + + +@replace(p.Log10) +def lower_log10(_, **kwargs): + """Rewrite `log10` as `log`.""" + return ops.Log(_.arg, base=10) + + +@replace(p.Bucket) +def lower_bucket(_, **kwargs): + """Rewrite `Bucket` as `SearchedCase`.""" + cases = [] + results = [] + + if _.closed == "left": + l_cmp = ops.LessEqual + r_cmp = ops.Less + else: + l_cmp = ops.Less + r_cmp = ops.LessEqual + + user_num_buckets = len(_.buckets) - 1 + + bucket_id = 0 + if _.include_under: + if user_num_buckets > 0: + cmp = ops.Less if _.close_extreme else r_cmp + else: + cmp = ops.LessEqual if _.closed == "right" else ops.Less + cases.append(cmp(_.arg, _.buckets[0])) + results.append(bucket_id) + bucket_id += 1 + + for j, (lower, upper) in enumerate(zip(_.buckets, _.buckets[1:])): + if _.close_extreme and ( + (_.closed == "right" and j == 0) + or (_.closed == "left" and j == (user_num_buckets - 1)) + ): + cases.append( + ops.And(ops.LessEqual(lower, _.arg), ops.LessEqual(_.arg, upper)) + ) + results.append(bucket_id) + else: + cases.append(ops.And(l_cmp(lower, _.arg), r_cmp(_.arg, upper))) + results.append(bucket_id) + bucket_id += 1 + + if _.include_over: + if user_num_buckets > 0: + cmp = ops.Less if _.close_extreme else l_cmp + else: + cmp = ops.Less if _.closed == "right" else ops.LessEqual + + cases.append(cmp(_.buckets[-1], _.arg)) + results.append(bucket_id) + bucket_id += 1 + + return ops.SearchedCase( + cases=tuple(cases), results=tuple(results), default=ops.NULL + ) + + +@replace(p.Capitalize) +def lower_capitalize(_, **kwargs): + """Rewrite Capitalize in terms of substring, concat, upper, and lower.""" + first = ops.Uppercase(ops.Substring(_.arg, start=0, length=1)) + # use length instead of length - 1 to avoid backends complaining about + # asking for negative length + # + # there are at most length - 1 characters, so asking for length is fine + rest = ops.Lowercase(ops.Substring(_.arg, start=1, length=ops.StringLength(_.arg))) + return ops.StringConcat((first, rest)) + + +@replace(p.Sample) +def lower_sample(_, **kwargs): + """Rewrite Sample as `t.filter(random() <= fraction)`. + + Errors as unsupported if a `seed` is specified. 
+ """ + if _.seed is not None: + raise com.UnsupportedOperationError( + "`Table.sample` with a random seed is unsupported" + ) + return ops.Filter(_.parent, (ops.LessEqual(ops.RandomScalar(), _.fraction),)) diff --git a/third_party/bigframes_vendored/ibis/expr/rewrites.py b/third_party/bigframes_vendored/ibis/expr/rewrites.py new file mode 100644 index 0000000000..a5ffcae8ee --- /dev/null +++ b/third_party/bigframes_vendored/ibis/expr/rewrites.py @@ -0,0 +1,382 @@ +# Contains code from https://github.com/ibis-project/ibis/blob/main/ibis/expr/rewrites.py + +"""Some common rewrite functions to be shared between backends.""" + +from __future__ import annotations + +from collections import defaultdict + +from ibis.common.collections import FrozenDict # noqa: TCH001 +from ibis.common.deferred import _, deferred, Item, var +from ibis.common.exceptions import ExpressionError, IbisInputError +from ibis.common.graph import Node as Traversable +from ibis.common.graph import traverse +from ibis.common.grounds import Concrete +from ibis.common.patterns import Check, pattern, replace +from ibis.common.typing import VarTuple # noqa: TCH001 +import ibis.expr.operations as ops +from ibis.util import Namespace, promote_list +import toolz + +p = Namespace(pattern, module=ops) +d = Namespace(deferred, module=ops) + + +x = var("x") +y = var("y") +name = var("name") + + +class DerefMap(Concrete, Traversable): + """Trace and replace fields from earlier relations in the hierarchy. + + In order to provide a nice user experience, we need to allow expressions + from earlier relations in the hierarchy. Consider the following example: + + t = ibis.table([('a', 'int64'), ('b', 'string')], name='t') + t1 = t.select([t.a, t.b]) + t2 = t1.filter(t.a > 0) # note that not t1.a is referenced here + t3 = t2.select(t.a) # note that not t2.a is referenced here + + However the relational operations in the IR are strictly enforcing that + the expressions are referencing the immediate parent only. So we need to + track fields upwards the hierarchy to replace `t.a` with `t1.a` and `t2.a` + in the example above. This is called dereferencing. + + Whether we can treat or not a field of a relation semantically equivalent + with a field of an earlier relation in the hierarchy depends on the + `.values` mapping of the relation. Leaf relations, like `t` in the example + above, have an empty `.values` mapping, so we cannot dereference fields + from them. On the other hand a projection, like `t1` in the example above, + has a `.values` mapping like `{'a': t.a, 'b': t.b}`, so we can deduce that + `t1.a` is semantically equivalent with `t.a` and so on. + """ + + """The relations we want the values to point to.""" + rels: VarTuple[ops.Relation] + + """Substitution mapping from values of earlier relations to the fields of `rels`.""" + subs: FrozenDict[ops.Value, ops.Field] + + """Ambiguous field references.""" + ambigs: FrozenDict[ops.Value, VarTuple[ops.Value]] + + @classmethod + def from_targets(cls, rels, extra=None): + """Create a dereference map from a list of target relations. + + Usually a single relation is passed except for joins where multiple + relations are involved. + + Parameters + ---------- + rels : list of ops.Relation + The target relations to dereference to. + extra : dict, optional + Extra substitutions to be added to the dereference map. 
+ + Returns + ------- + DerefMap + """ + rels = promote_list(rels) + mapping = defaultdict(dict) + for rel in rels: + for field in rel.fields.values(): + for value, distance in cls.backtrack(field): + mapping[value][field] = distance + + subs, ambigs = {}, {} + for from_, to in mapping.items(): + mindist = min(to.values()) + minkeys = [k for k, v in to.items() if v == mindist] + # if all the closest fields are from the same relation, then we + # can safely substitute them and we pick the first one arbitrarily + if all(minkeys[0].relations == k.relations for k in minkeys): + subs[from_] = minkeys[0] + else: + ambigs[from_] = minkeys + + if extra is not None: + subs.update(extra) + + return cls(rels, subs, ambigs) + + @classmethod + def backtrack(cls, value): + """Backtrack the field in the relation hierarchy. + + The field is traced back until no modification is made, so only follow + ops.Field nodes not arbitrary values. + + Parameters + ---------- + value : ops.Value + The value to backtrack. + + Yields + ------ + tuple[ops.Field, int] + The value node and the distance from the original value. + """ + distance = 0 + # track down the field in the hierarchy until no modification + # is made so only follow ops.Field nodes not arbitrary values; + while isinstance(value, ops.Field): + yield value, distance + value = value.rel.values.get(value.name) + distance += 1 + if ( + value is not None + and value.relations + and not value.find(ops.Impure, filter=ops.Value) + ): + yield value, distance + + def dereference(self, value): + """Dereference a value to the target relations. + + Also check for ambiguous field references. If a field reference is found + which is marked as ambiguous, then raise an error. + + Parameters + ---------- + value : ops.Value + The value to dereference. + + Returns + ------- + ops.Value + The dereferenced value. + """ + ambigs = value.find(lambda x: x in self.ambigs, filter=ops.Value) + if ambigs: + raise IbisInputError( + f"Ambiguous field reference {ambigs!r} in expression {value!r}" + ) + return value.replace(self.subs, filter=ops.Value) + + +def flatten_predicates(node): + """Yield the expressions corresponding to the `And` nodes of a predicate. 
+ + Examples + -------- + >>> import ibis + >>> t = ibis.table([("a", "int64"), ("b", "string")], name="t") + >>> filt = (t.a == 1) & (t.b == "foo") + >>> predicates = flatten_predicates(filt.op()) + >>> len(predicates) + 2 + >>> predicates[0].to_expr().name("left") + r0 := UnboundTable: t + a int64 + b string + left: r0.a == 1 + >>> predicates[1].to_expr().name("right") + r0 := UnboundTable: t + a int64 + b string + right: r0.b == 'foo' + + """ + + def predicate(node): + if isinstance(node, ops.And): + # proceed and don't yield the node + return True, None + else: + # halt and yield the node + return False, node + + return list(traverse(predicate, node)) + + +@replace(p.Field(p.JoinChain)) +def peel_join_field(_): + return _.rel.values[_.name] + + +@replace(p.ScalarParameter) +def replace_parameter(_, params, **kwargs): + """Replace scalar parameters with their values.""" + return ops.Literal(value=params[_], dtype=_.dtype) + + +@replace(p.StringSlice) +def lower_stringslice(_, **kwargs): + """Rewrite StringSlice in terms of Substring.""" + if _.end is None: + return ops.Substring(_.arg, start=_.start) + if _.start is None: + return ops.Substring(_.arg, start=0, length=_.end) + if ( + isinstance(_.start, ops.Literal) + and isinstance(_.start.value, int) + and isinstance(_.end, ops.Literal) + and isinstance(_.end.value, int) + ): + # optimization for constant values + length = _.end.value - _.start.value + else: + length = ops.Subtract(_.end, _.start) + return ops.Substring(_.arg, start=_.start, length=length) + + +@replace(p.Analytic) +def wrap_analytic(_, **__): + # Wrap analytic functions in a window function + return ops.WindowFunction(_) + + +@replace(p.Reduction) +def project_wrap_reduction(_, rel): + # Query all the tables that the reduction depends on + if _.relations == {rel}: + # The reduction is fully originating from the `rel`, so turn + # it into a window function of `rel` + return ops.WindowFunction(_) + else: + # 1. The reduction doesn't depend on any table, constructed from + # scalar values, so turn it into a scalar subquery. + # 2. The reduction is originating from `rel` and other tables, + # so this is a correlated scalar subquery. + # 3. The reduction is originating entirely from other tables, + # so this is an uncorrelated scalar subquery. + return ops.ScalarSubquery(_.to_expr().as_table()) + + +def rewrite_project_input(value, relation): + # we need to detect reductions which are either turned into window functions + # or scalar subqueries depending on whether they are originating from the + # relation + return value.replace( + wrap_analytic | project_wrap_reduction, + filter=p.Value & ~p.WindowFunction, + context={"rel": relation}, + ) + + +ReductionLike = p.Reduction | p.Field(p.Aggregate(groups={})) + + +@replace(ReductionLike) +def filter_wrap_reduction(_): + # Wrap reductions or fields referencing an aggregation without a group by - + # which are scalar fields - in a scalar subquery. In the latter case we + # use the reduction value from the aggregation. + if isinstance(_, ops.Field): + value = _.rel.values[_.name] + else: + value = _ + return ops.ScalarSubquery(value.to_expr().as_table()) + + +def rewrite_filter_input(value): + return value.replace( + wrap_analytic | filter_wrap_reduction, filter=p.Value & ~p.WindowFunction + ) + + +@replace(p.Analytic | p.Reduction) +def window_wrap_reduction(_, window): + # Wrap analytic and reduction functions in a window function. Used in the + # value.over() API. 
+ return ops.WindowFunction( + _, + how=window.how, + start=window.start, + end=window.end, + group_by=window.groupings, + order_by=window.orderings, + ) + + +@replace(p.WindowFunction) +def window_merge_frames(_, window): + # Merge window frames, used in the value.over() and groupby.select() APIs. + if _.how != window.how: + raise ExpressionError( + f"Unable to merge {_.how} window with {window.how} window" + ) + elif _.start and window.start and _.start != window.start: + raise ExpressionError( + "Unable to merge windows with conflicting `start` boundary" + ) + elif _.end and window.end and _.end != window.end: + raise ExpressionError("Unable to merge windows with conflicting `end` boundary") + + start = _.start or window.start + end = _.end or window.end + group_by = tuple(toolz.unique(_.group_by + window.groupings)) + + order_keys = {} + for sort_key in window.orderings + _.order_by: + order_keys[sort_key.expr] = sort_key.ascending, sort_key.nulls_first + + order_by = ( + ops.SortKey(expr, ascending=ascending, nulls_first=nulls_first) + for expr, (ascending, nulls_first) in order_keys.items() + ) + return _.copy(start=start, end=end, group_by=group_by, order_by=order_by) + + +def rewrite_window_input(value, window): + context = {"window": window} + # if self is a reduction or analytic function, wrap it in a window function + node = value.replace( + window_wrap_reduction, + filter=p.Value & ~p.WindowFunction, + context=context, + ) + # if self is already a window function, merge the existing window frame + # with the requested window frame + return node.replace(window_merge_frames, filter=p.Value, context=context) + + +# TODO(kszucs): schema comparison should be updated to not distinguish between +# different column order +@replace(p.Project(y @ p.Relation) & Check(_.schema == y.schema)) +def complete_reprojection(_, y): + # TODO(kszucs): this could be moved to the pattern itself but not sure how + # to express it, especially in a shorter way then the following check + for name in _.schema: + if _.values[name] != ops.Field(y, name): + return _ + return y + + +@replace(p.Project(y @ p.Project)) +def subsequent_projects(_, y): + rule = p.Field(y, name) >> Item(y.values, name) + values = {k: v.replace(rule, filter=ops.Value) for k, v in _.values.items()} + return ops.Project(y.parent, values) + + +@replace(p.Filter(y @ p.Filter)) +def subsequent_filters(_, y): + rule = p.Field(y, name) >> d.Field(y.parent, name) + preds = tuple(v.replace(rule, filter=ops.Value) for v in _.predicates) + return ops.Filter(y.parent, y.predicates + preds) + + +@replace(p.Filter(y @ p.Project)) +def reorder_filter_project(_, y): + rule = p.Field(y, name) >> Item(y.values, name) + preds = tuple(v.replace(rule, filter=ops.Value) for v in _.predicates) + + inner = ops.Filter(y.parent, preds) + rule = p.Field(y.parent, name) >> d.Field(inner, name) + projs = {k: v.replace(rule, filter=ops.Value) for k, v in y.values.items()} + + return ops.Project(inner, projs) + + +def simplify(node): + # TODO(kszucs): add a utility to the graph module to do rewrites in multiple + # passes after each other + node = node.replace(reorder_filter_project) + node = node.replace(reorder_filter_project) + node = node.replace(subsequent_projects | subsequent_filters) + node = node.replace(complete_reprojection) + return node From 46c6e252b04e92444ea7c054cd8779be8e6422c1 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Fri, 23 Aug 2024 21:26:13 +0000 Subject: [PATCH 38/59] add default arg for nulls_first for python 3.9 support --- 
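The change below gives `visit_SortKey` a default for `nulls_first`, apparently because some call paths on Python 3.9 omit the keyword. A rough sketch of the node that method builds, using sqlglot's `Ordered` expression directly (the column name is illustrative):

    import sqlglot as sg
    import sqlglot.expressions as sge

    # Mirrors the body of visit_SortKey: descending order, NULLS LAST by default.
    key = sge.Ordered(this=sg.column("created_at"), desc=True, nulls_first=False)
    print(key.sql(dialect="bigquery"))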
.../bigframes_vendored/ibis/backends/sql/compilers/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py index 5003c3e4cd..48052ce2c2 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py @@ -1121,7 +1121,7 @@ def visit_Coalesce(self, op, *, arg): ### Ordering and window functions - def visit_SortKey(self, op, *, expr, ascending: bool, nulls_first: bool): + def visit_SortKey(self, op, *, expr, ascending: bool, nulls_first: bool = False): return sge.Ordered(this=expr, desc=not ascending, nulls_first=nulls_first) def visit_ApproxMedian(self, op, *, arg, where): From cedc6fc41445bbd87abc8819a88a8cc2df25009c Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Fri, 23 Aug 2024 21:54:01 +0000 Subject: [PATCH 39/59] restore integer conversion --- .../ibis/backends/sql/compilers/bigquery/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py index 1e015c75dc..eb858a4bc9 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py @@ -27,6 +27,7 @@ from ibis.common.temporal import DateUnit, IntervalUnit, TimestampUnit, TimeUnit import ibis.expr.datatypes as dt import ibis.expr.operations as ops +import numpy as np import sqlglot as sg from sqlglot.dialects import BigQuery import sqlglot.expressions as sge @@ -511,6 +512,9 @@ def visit_NonNullLiteral(self, op, *, value, dtype): ) elif dtype.is_uuid(): return sge.convert(str(value)) + + elif dtype.is_int64(): + return sge.convert(np.int64(value)) return None def visit_IntervalFromInteger(self, op, *, arg, unit): From 16f787833940a0bae049b4a8808dbcbafd9b4d2d Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 29 Aug 2024 22:06:44 +0000 Subject: [PATCH 40/59] fix window tests: diff, duplicated, shift --- bigframes/core/compile/compiled.py | 2 +- bigframes/core/groupby/__init__.py | 4 +--- bigframes/core/window_spec.py | 1 - bigframes/dataframe.py | 12 ++++-------- bigframes/series.py | 12 +++--------- tests/system/small/operations/test_strings.py | 2 +- 6 files changed, 10 insertions(+), 23 deletions(-) diff --git a/bigframes/core/compile/compiled.py b/bigframes/core/compile/compiled.py index 0982f90c61..9ebf748a6e 100644 --- a/bigframes/core/compile/compiled.py +++ b/bigframes/core/compile/compiled.py @@ -1293,7 +1293,7 @@ def _ibis_window_from_spec( bounds.preceding, bounds.following, how="range" ) if isinstance(bounds, RowsWindowBounds): - if bounds.preceding is not None and bounds.following is not None: + if bounds.preceding is not None or bounds.following is not None: window = window.preceding_following( bounds.preceding, bounds.following, how="rows" ) diff --git a/bigframes/core/groupby/__init__.py b/bigframes/core/groupby/__init__.py index da3101502c..e84a81573d 100644 --- a/bigframes/core/groupby/__init__.py +++ b/bigframes/core/groupby/__init__.py @@ -687,7 +687,7 @@ def cumcount(self, *args, **kwargs) -> series.Series: def shift(self, periods=1) -> series.Series: """Shift index by desired number of periods.""" # Window framing clause is not allowed for analytic function lag. 
- window = window_specs.unbound( + window = window_specs.rows( grouping_keys=tuple(self._by_col_ids), ) return self._apply_window_op(agg_ops.ShiftOp(periods), window=window) @@ -696,8 +696,6 @@ def shift(self, periods=1) -> series.Series: def diff(self, periods=1) -> series.Series: window = window_specs.rows( grouping_keys=tuple(self._by_col_ids), - preceding=periods if periods > 0 else None, - following=-periods if periods < 0 else None, ) return self._apply_window_op(agg_ops.DiffOp(periods), window=window) diff --git a/bigframes/core/window_spec.py b/bigframes/core/window_spec.py index f011e2848d..3d80afea5a 100644 --- a/bigframes/core/window_spec.py +++ b/bigframes/core/window_spec.py @@ -70,7 +70,6 @@ def rows( Returns: WindowSpec """ - assert (preceding is not None) or (following is not None) bounds = RowsWindowBounds(preceding=preceding, following=following) return WindowSpec( grouping_keys=grouping_keys, diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index a5960d9e53..dba45a9d46 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -2822,17 +2822,13 @@ def cummax(self) -> DataFrame: @validations.requires_ordering() def shift(self, periods: int = 1) -> DataFrame: - # Window framing clause is not allowed for analytic function lag. - window = windows.unbound() - return self._apply_window_op(agg_ops.ShiftOp(periods), window) + window_spec = windows.rows() + return self._apply_window_op(agg_ops.ShiftOp(periods), window_spec) @validations.requires_ordering() def diff(self, periods: int = 1) -> DataFrame: - window = windows.rows( - preceding=periods if periods > 0 else None, - following=-periods if periods < 0 else None, - ) - return self._apply_window_op(agg_ops.DiffOp(periods), window) + window_spec = windows.rows() + return self._apply_window_op(agg_ops.DiffOp(periods), window_spec) @validations.requires_ordering() def pct_change(self, periods: int = 1) -> DataFrame: diff --git a/bigframes/series.py b/bigframes/series.py index 07ee69da6f..0502e042e2 100644 --- a/bigframes/series.py +++ b/bigframes/series.py @@ -487,19 +487,13 @@ def cumprod(self) -> Series: @validations.requires_ordering() def shift(self, periods: int = 1) -> Series: - window_spec = windows.rows( - preceding=periods if periods > 0 else None, - following=-periods if periods < 0 else None, - ) + window_spec = windows.rows() return self._apply_window_op(agg_ops.ShiftOp(periods), window_spec) @validations.requires_ordering() def diff(self, periods: int = 1) -> Series: - window = windows.rows( - preceding=periods if periods > 0 else None, - following=-periods if periods < 0 else None, - ) - return self._apply_window_op(agg_ops.DiffOp(periods), window) + window_spec = windows.rows() + return self._apply_window_op(agg_ops.DiffOp(periods), window_spec) @validations.requires_ordering() def pct_change(self, periods: int = 1) -> Series: diff --git a/tests/system/small/operations/test_strings.py b/tests/system/small/operations/test_strings.py index 3191adf920..15e8512317 100644 --- a/tests/system/small/operations/test_strings.py +++ b/tests/system/small/operations/test_strings.py @@ -634,7 +634,7 @@ def test_getitem_w_array(index): def test_getitem_w_struct_array(): - if packaging.version.Version(pd.__version__) <= packaging.version.Version("1.5.0"): + if packaging.version.Version(pd.__version__) <= packaging.version.Version("1.5.3"): pytest.skip("https://github.com/googleapis/python-bigquery/issues/1992") pa_struct = pa.struct( From 6be3a3883e6a99ba6d1fe6e21f96d368b950fd78 Mon Sep 17 00:00:00 2001 
From: Chelsea Lin Date: Fri, 30 Aug 2024 18:39:10 +0000 Subject: [PATCH 41/59] fixing ibis parenthesize_inputs bugs and related tests --- bigframes/core/blocks.py | 2 +- .../bigframes_vendored/ibis/backends/sql/compilers/base.py | 5 ++++- .../ibis/backends/sql/compilers/bigquery/__init__.py | 4 ---- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py index 8816bb1beb..abe4fe62c5 100644 --- a/bigframes/core/blocks.py +++ b/bigframes/core/blocks.py @@ -2415,7 +2415,7 @@ def _is_monotonic( return self._stats_cache[column_name][op_name] period = 1 - window_spec = windows.rows(preceding=period, following=None) + window_spec = windows.rows() # any NaN value means not monotonic block, last_notna_id = self.apply_unary_op(column_ids[0], ops.notnull_op) diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py index 48052ce2c2..3fa472344b 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py @@ -1385,7 +1385,10 @@ def visit_Aggregate(self, op, *, parent, groups, metrics): @classmethod def _add_parens(cls, op, sg_expr): - if isinstance(op, cls.NEEDS_PARENS): + # Patch for https://github.com/ibis-project/ibis/issues/9975 + if isinstance(op, cls.NEEDS_PARENS) or ( + isinstance(op, ops.Alias) and isinstance(op.arg, cls.NEEDS_PARENS) + ): return sge.paren(sg_expr, copy=False) return sg_expr diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py index eb858a4bc9..9a5f889c86 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py @@ -1145,9 +1145,5 @@ def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by) return sge.Window(this=func, partition_by=group_by, order=order, spec=spec) - @sql_compiler.parenthesize_inputs - def visit_And(self, op, *, left, right): - return sge.And(this=sge.Paren(this=left), expression=sge.Paren(this=right)) - compiler = BigQueryCompiler() From 90cae7c1e1f1f79dc69a999f83ccc02d02c648fd Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Fri, 30 Aug 2024 19:01:17 +0000 Subject: [PATCH 42/59] fixing lint --- .../ibis/backends/sql/compilers/bigquery/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py index 9a5f889c86..9a709404fd 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py @@ -15,7 +15,6 @@ SQLGlotCompiler, STAR, ) -import bigframes_vendored.ibis.backends.sql.compilers.base as sql_compiler from ibis import util from ibis.backends.sql.datatypes import BigQueryType, BigQueryUDFType from ibis.backends.sql.rewrites import ( From 82c04898c2d100cadf7c94eba4ca722f9f8f7881 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 3 Sep 2024 23:30:56 +0000 Subject: [PATCH 43/59] disable test_query_complexity_error --- tests/system/small/test_dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/system/small/test_dataframe.py 
b/tests/system/small/test_dataframe.py index ddcf044911..e1141f9a0b 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -4567,7 +4567,7 @@ def test_recursion_limit(scalars_df_index): scalars_df_index = scalars_df_index + 4 scalars_df_index.to_pandas() - +@pytest.mark.skipif(reason="Skip until query complexity error can be reliably triggered") def test_query_complexity_error(scalars_df_index): # This test requires automatic caching/query decomposition to be turned off bf_df = scalars_df_index From b3f52c9889fee7b9aedac4452727a2fbf47535ba Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 4 Sep 2024 00:54:24 +0000 Subject: [PATCH 44/59] fix doctest np.int64(0) upgrades --- tests/system/small/test_dataframe.py | 5 ++++- .../bigframes_vendored/pandas/core/frame.py | 4 ++-- .../bigframes_vendored/pandas/core/series.py | 18 +++++++++--------- .../sklearn/metrics/_classification.py | 4 ++-- .../sklearn/metrics/_ranking.py | 8 ++++---- .../sklearn/metrics/_regression.py | 4 ++-- 6 files changed, 23 insertions(+), 20 deletions(-) diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index e1141f9a0b..7a2de03529 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -4567,7 +4567,10 @@ def test_recursion_limit(scalars_df_index): scalars_df_index = scalars_df_index + 4 scalars_df_index.to_pandas() -@pytest.mark.skipif(reason="Skip until query complexity error can be reliably triggered") + +@pytest.mark.skipif( + reason="Skip until query complexity error can be reliably triggered" +) def test_query_complexity_error(scalars_df_index): # This test requires automatic caching/query decomposition to be turned off bf_df = scalars_df_index diff --git a/third_party/bigframes_vendored/pandas/core/frame.py b/third_party/bigframes_vendored/pandas/core/frame.py index 10565a2552..03a9e94d64 100644 --- a/third_party/bigframes_vendored/pandas/core/frame.py +++ b/third_party/bigframes_vendored/pandas/core/frame.py @@ -6515,12 +6515,12 @@ def at(self): Get value at specified row/column pair >>> df.at[4, 'B'] - 2 + np.int64(2) Get value within a series >>> df.loc[5].at['B'] - 4 + np.int64(4) Returns: bigframes.core.indexers.AtDataFrameIndexer: Indexers object. diff --git a/third_party/bigframes_vendored/pandas/core/series.py b/third_party/bigframes_vendored/pandas/core/series.py index a30ed9cd92..c83db865bc 100644 --- a/third_party/bigframes_vendored/pandas/core/series.py +++ b/third_party/bigframes_vendored/pandas/core/series.py @@ -592,7 +592,7 @@ def agg(self, func): dtype: Int64 >>> s.agg('min') - 1 + np.int64(1) >>> s.agg(['min', 'max']) min 1 @@ -3080,7 +3080,7 @@ def max( 1 3 dtype: Int64 >>> s.max() - 3 + np.int64(3) Calculating the max of a Series containing ``NA`` values: @@ -3091,7 +3091,7 @@ def max( 2 dtype: Int64 >>> s.max() - 3 + np.int64(3) Returns: scalar: Scalar. @@ -3576,10 +3576,10 @@ def argmax(self): dtype: Float64 >>> s.argmax() - 2 + np.int64(2) >>> s.argmin() - 0 + np.int64(0) The maximum cereal calories is the third element and the minimum cereal calories is the first element, since series is zero-indexed. @@ -3612,10 +3612,10 @@ def argmin(self): dtype: Float64 >>> s.argmax() - 2 + np.int64(2) >>> s.argmin() - 0 + np.int64(0) The maximum cereal calories is the third element and the minimum cereal calories is the first element, since series is zero-indexed. 
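These doctest updates track NumPy 2.0's scalar repr change: reductions and element lookups now print with their type wrapper instead of a bare Python literal. A quick check of the behaviour the expected outputs assume (requires numpy >= 2.0):

    import numpy as np

    print(repr(np.int64(2)))      # np.int64(2)
    print(repr(np.float64(0.5)))  # np.float64(0.5)
    print(repr(np.bool_(True)))   # np.True_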
@@ -4066,7 +4066,7 @@ def at(self): Get value at specified row label >>> s.at['B'] - 2 + np.int64(2) Returns: @@ -4314,7 +4314,7 @@ def __getitem__(self, indexer): >>> s = bpd.Series([15, 30, 45]) >>> s[1] - 30 + np.int64(30) >>> s[0:2] 0 15 1 30 diff --git a/third_party/bigframes_vendored/sklearn/metrics/_classification.py b/third_party/bigframes_vendored/sklearn/metrics/_classification.py index 8e8b2c1952..1459d0d050 100644 --- a/third_party/bigframes_vendored/sklearn/metrics/_classification.py +++ b/third_party/bigframes_vendored/sklearn/metrics/_classification.py @@ -36,13 +36,13 @@ def accuracy_score(y_true, y_pred, normalize=True) -> float: >>> y_pred = bpd.DataFrame([0, 1, 2, 3]) >>> accuracy_score = bigframes.ml.metrics.accuracy_score(y_true, y_pred) >>> accuracy_score - 0.5 + np.float64(0.5) If False, return the number of correctly classified samples: >>> accuracy_score = bigframes.ml.metrics.accuracy_score(y_true, y_pred, normalize=False) >>> accuracy_score - 2 + np.float64(2) Args: y_true (Series or DataFrame of shape (n_samples,)): diff --git a/third_party/bigframes_vendored/sklearn/metrics/_ranking.py b/third_party/bigframes_vendored/sklearn/metrics/_ranking.py index dee8b350c0..7b97526de2 100644 --- a/third_party/bigframes_vendored/sklearn/metrics/_ranking.py +++ b/third_party/bigframes_vendored/sklearn/metrics/_ranking.py @@ -37,7 +37,7 @@ def auc(x, y) -> float: >>> y = bpd.DataFrame([2, 3, 4, 5]) >>> auc = bigframes.ml.metrics.auc(x, y) >>> auc - 3.5 + np.float64(3.5) The input can be Series: @@ -47,7 +47,7 @@ def auc(x, y) -> float: ... ) >>> auc = bigframes.ml.metrics.auc(df["x"], df["y"]) >>> auc - 3.5 + np.float64(3.5) Args: @@ -77,7 +77,7 @@ def roc_auc_score(y_true, y_score) -> float: >>> y_score = bpd.DataFrame([0.1, 0.4, 0.35, 0.8, 0.65, 0.9, 0.5, 0.3, 0.6, 0.45]) >>> roc_auc_score = bigframes.ml.metrics.roc_auc_score(y_true, y_score) >>> roc_auc_score - 0.625 + np.float64(0.625) The input can be Series: @@ -87,7 +87,7 @@ def roc_auc_score(y_true, y_score) -> float: ... 
) >>> roc_auc_score = bigframes.ml.metrics.roc_auc_score(df["y_true"], df["y_score"]) >>> roc_auc_score - 0.625 + np.float64(0.625) Args: y_true (Series or DataFrame of shape (n_samples,)): diff --git a/third_party/bigframes_vendored/sklearn/metrics/_regression.py b/third_party/bigframes_vendored/sklearn/metrics/_regression.py index c3e579bd29..56f78c6d0b 100644 --- a/third_party/bigframes_vendored/sklearn/metrics/_regression.py +++ b/third_party/bigframes_vendored/sklearn/metrics/_regression.py @@ -52,7 +52,7 @@ def r2_score(y_true, y_pred, force_finite=True) -> float: >>> y_pred = bpd.DataFrame([2.5, 0.0, 2, 8]) >>> r2_score = bigframes.ml.metrics.r2_score(y_true, y_pred) >>> r2_score - 0.9486081370449679 + np.float64(0.9486081370449679) Args: y_true (Series or DataFrame of shape (n_samples,)): @@ -79,7 +79,7 @@ def mean_squared_error(y_true, y_pred) -> float: >>> y_pred = bpd.DataFrame([2.5, 0.0, 2, 8]) >>> mse = bigframes.ml.metrics.mean_squared_error(y_true, y_pred) >>> mse - 0.375 + np.float64(0.375) Args: y_true (Series or DataFrame of shape (n_samples,)): From c77120e922532a4966b8ccc109dd9c55e4a2cd65 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 4 Sep 2024 05:16:21 +0000 Subject: [PATCH 45/59] fix doctest np.int64(0) upgrades more --- .../bigframes_vendored/pandas/core/frame.py | 10 ++--- .../bigframes_vendored/pandas/core/series.py | 40 +++++++++---------- .../sklearn/metrics/_classification.py | 2 +- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/third_party/bigframes_vendored/pandas/core/frame.py b/third_party/bigframes_vendored/pandas/core/frame.py index 03a9e94d64..9c72773037 100644 --- a/third_party/bigframes_vendored/pandas/core/frame.py +++ b/third_party/bigframes_vendored/pandas/core/frame.py @@ -540,7 +540,7 @@ def to_dict( >>> df = bpd.DataFrame({'col1': [1, 2], 'col2': [3, 4]}) >>> df.to_dict() - {'col1': {0: 1, 1: 2}, 'col2': {0: 3, 1: 4}} + {'col1': {np.int64(0): 1, np.int64(1): 2}, 'col2': {np.int64(0): 3, np.int64(1): 4}} You can specify the return orientation. @@ -1769,7 +1769,7 @@ def iterrows(self): ... }) >>> index, row = next(df.iterrows()) >>> index - 0 + np.int64(0) >>> row A 1 B 4 @@ -1790,7 +1790,7 @@ def itertuples(self, index: bool = True, name: str | None = "Pandas"): ... 'B': [4, 5, 6], ... }) >>> next(df.itertuples(name="Pair")) - Pair(Index=0, A=1, B=4) + Pair(Index=np.int64(0), A=np.int64(1), B=np.int64(4)) Args: index (bool, default True): @@ -6482,12 +6482,12 @@ def iat(self): Get value at specified row/column pair >>> df.iat[1, 2] - 1 + np.int64(1) Get value within a series >>> df.loc[0].iat[1] - 2 + np.int64(2) Returns: bigframes.core.indexers.IatDataFrameIndexer: Indexers object. 
diff --git a/third_party/bigframes_vendored/pandas/core/series.py b/third_party/bigframes_vendored/pandas/core/series.py index c83db865bc..e31656dd6e 100644 --- a/third_party/bigframes_vendored/pandas/core/series.py +++ b/third_party/bigframes_vendored/pandas/core/series.py @@ -205,7 +205,7 @@ def hasnans(self) -> bool: 3 dtype: Float64 >>> s.hasnans - True + np.True_ Returns: bool @@ -626,7 +626,7 @@ def count(self): 2 dtype: Float64 >>> s.count() - 2 + np.int64(2) Returns: int or Series (if level specified): Number of non-null values in the @@ -834,12 +834,12 @@ def corr(self, other, method="pearson", min_periods=None) -> float: >>> s1 = bpd.Series([.2, .0, .6, .2]) >>> s2 = bpd.Series([.3, .6, .0, .1]) >>> s1.corr(s2) - -0.8510644963469901 + np.float64(-0.8510644963469901) >>> s1 = bpd.Series([1, 2, 3], index=[0, 1, 2]) >>> s2 = bpd.Series([1, 2, 3], index=[2, 1, 0]) >>> s1.corr(s2) - -1.0 + np.float64(-1.0) Args: other (Series): @@ -870,9 +870,9 @@ def autocorr(self, lag: int = 1) -> float: >>> s = bpd.Series([0.25, 0.5, 0.2, -0.05]) >>> s.autocorr() # doctest: +ELLIPSIS - 0.10355... + np.float64(0.10355263309024067) >>> s.autocorr(lag=2) - -1.0 + np.float64(-1.0) If the Pearson correlation is not well defined, then 'NaN' is returned. @@ -951,12 +951,12 @@ def dot(self, other) -> Series | np.ndarray: >>> s = bpd.Series([0, 1, 2, 3]) >>> other = bpd.Series([-1, 2, -3, 4]) >>> s.dot(other) - 8 + np.int64(8) You can also use the operator ``@`` for the dot product: >>> s @ other - 8 + np.int64(8) Args: other (Series): @@ -3120,7 +3120,7 @@ def min( 1 3 dtype: Int64 >>> s.min() - 1 + np.int64(1) Calculating the min of a Series containing ``NA`` values: @@ -3131,7 +3131,7 @@ def min( 2 dtype: Int64 >>> s.min() - 1 + np.int64(1) Returns: scalar: Scalar. @@ -3207,7 +3207,7 @@ def sum(self): 1 3 dtype: Int64 >>> s.sum() - 4 + np.int64(4) Calculating the sum of a Series containing ``NA`` values: @@ -3218,7 +3218,7 @@ def sum(self): 2 dtype: Int64 >>> s.sum() - 4 + np.int64(4) Returns: scalar: Scalar. @@ -3241,7 +3241,7 @@ def mean(self): 1 3 dtype: Int64 >>> s.mean() - 2.0 + np.float64(2.0) Calculating the mean of a Series containing ``NA`` values: @@ -3252,7 +3252,7 @@ def mean(self): 2 dtype: Int64 >>> s.mean() - 2.0 + np.float64(2.0) Returns: scalar: Scalar. @@ -3285,7 +3285,7 @@ def quantile( >>> bpd.options.display.progress_bar = None >>> s = bpd.Series([1, 2, 3, 4]) >>> s.quantile(.5) - 2.5 + np.float64(2.5) >>> s.quantile([.25, .5, .75]) 0.25 1.75 0.5 2.5 @@ -3887,11 +3887,11 @@ def is_monotonic_increasing(self) -> bool: >>> s = bpd.Series([1, 2, 2]) >>> s.is_monotonic_increasing - True + np.True_ >>> s = bpd.Series([3, 2, 1]) >>> s.is_monotonic_increasing - False + np.False_ Returns: bool: Boolean. @@ -3910,11 +3910,11 @@ def is_monotonic_decreasing(self) -> bool: >>> s = bpd.Series([3, 2, 2, 1]) >>> s.is_monotonic_decreasing - True + np.True_ >>> s = bpd.Series([1, 2, 3]) >>> s.is_monotonic_decreasing - False + np.False_ Returns: bool: Boolean. @@ -4041,7 +4041,7 @@ def iat(self): Get value at specified row number >>> s.iat[1] - 2 + np.int64(2) Returns: bigframes.core.indexers.IatSeriesIndexer: Indexers object. 
diff --git a/third_party/bigframes_vendored/sklearn/metrics/_classification.py b/third_party/bigframes_vendored/sklearn/metrics/_classification.py index 1459d0d050..c1a909e849 100644 --- a/third_party/bigframes_vendored/sklearn/metrics/_classification.py +++ b/third_party/bigframes_vendored/sklearn/metrics/_classification.py @@ -42,7 +42,7 @@ def accuracy_score(y_true, y_pred, normalize=True) -> float: >>> accuracy_score = bigframes.ml.metrics.accuracy_score(y_true, y_pred, normalize=False) >>> accuracy_score - np.float64(2) + np.int64(2) Args: y_true (Series or DataFrame of shape (n_samples,)): From 3d105844791d3c1bedc30e35eaf022ab271955c0 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 4 Sep 2024 05:26:25 +0000 Subject: [PATCH 46/59] fix groupby diff --- bigframes/core/groupby/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/bigframes/core/groupby/__init__.py b/bigframes/core/groupby/__init__.py index e84a81573d..5413a21651 100644 --- a/bigframes/core/groupby/__init__.py +++ b/bigframes/core/groupby/__init__.py @@ -263,10 +263,9 @@ def shift(self, periods=1) -> series.Series: @validations.requires_ordering() def diff(self, periods=1) -> series.Series: + # Window framing clause is not allowed for analytic function lag. window = window_specs.rows( grouping_keys=tuple(self._by_col_ids), - preceding=periods if periods > 0 else None, - following=-periods if periods < 0 else None, ) return self._apply_window_op(agg_ops.DiffOp(periods), window=window) From fe7aa817076779a57c835d62f3ac0dc57d1409a9 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 5 Sep 2024 18:17:31 +0000 Subject: [PATCH 47/59] addressing system-3.12/doctest issues related to numpy 2.1.1 --- bigframes/ml/preprocessing.py | 15 +++++++++++++-- tests/system/small/test_numpy.py | 7 +++++++ tests/system/small/test_series.py | 9 +++++++-- .../bigframes_vendored/pandas/core/series.py | 2 +- 4 files changed, 28 insertions(+), 5 deletions(-) diff --git a/bigframes/ml/preprocessing.py b/bigframes/ml/preprocessing.py index 13d2041ef3..28ce908223 100644 --- a/bigframes/ml/preprocessing.py +++ b/bigframes/ml/preprocessing.py @@ -19,6 +19,7 @@ import typing from typing import cast, Iterable, List, Literal, Optional, Union +import numpy as np import bigframes_vendored.sklearn.preprocessing._data import bigframes_vendored.sklearn.preprocessing._discretization @@ -305,8 +306,18 @@ def _compile_to_sql( array_split_points = {} if self.strategy == "uniform": for column in columns: - min_value = X[column].min() - max_value = X[column].max() + column_min = X[column].min() + if np.issubdtype(column_min.dtype, np.floating): + min_value = column_min.item() + else: + min_value = column_min + + column_max = X[column].max() + if np.issubdtype(column_max.dtype, np.floating): + max_value = column_max.item() + else: + max_value = column_max + bin_size = (max_value - min_value) / self.n_bins array_split_points[column] = [ min_value + i * bin_size for i in range(self.n_bins - 1) diff --git a/tests/system/small/test_numpy.py b/tests/system/small/test_numpy.py index 8f62d9628c..37a707b9d0 100644 --- a/tests/system/small/test_numpy.py +++ b/tests/system/small/test_numpy.py @@ -70,6 +70,13 @@ def test_df_ufuncs(scalars_dfs, opname): ).to_pandas() pd_result = getattr(np, opname)(scalars_pandas_df[["float64_col", "int64_col"]]) + # In NumPy versions 2 and later, `np.floor` and `np.ceil` now produce integer + # outputs for the "int64_col" column. 
+ if opname in ["floor", "ceil"] and isinstance( + pd_result["int64_col"].dtypes, pd.Int64Dtype + ): + pd_result["int64_col"] = pd_result["int64_col"].astype(pd.Float64Dtype()) + pd.testing.assert_frame_equal(bf_result, pd_result) diff --git a/tests/system/small/test_series.py b/tests/system/small/test_series.py index 7458187a82..0ceae7787f 100644 --- a/tests/system/small/test_series.py +++ b/tests/system/small/test_series.py @@ -2346,8 +2346,13 @@ def test_value_counts(scalars_dfs, kwargs): scalars_df, scalars_pandas_df = scalars_dfs col_name = "int64_too" - bf_result = scalars_df[col_name].value_counts(**kwargs).to_pandas() - pd_result = scalars_pandas_df[col_name].value_counts(**kwargs) + # Pandas `value_counts` can produce non-deterministic results with tied counts. + # Remove duplicates to enforce a consistent output. + s = scalars_df[col_name].drop(0) + pd_s = scalars_pandas_df[col_name].drop(0) + + bf_result = s.value_counts(**kwargs).to_pandas() + pd_result = pd_s.value_counts(**kwargs) pd.testing.assert_series_equal( bf_result, diff --git a/third_party/bigframes_vendored/pandas/core/series.py b/third_party/bigframes_vendored/pandas/core/series.py index e31656dd6e..a6363e3285 100644 --- a/third_party/bigframes_vendored/pandas/core/series.py +++ b/third_party/bigframes_vendored/pandas/core/series.py @@ -878,7 +878,7 @@ def autocorr(self, lag: int = 1) -> float: >>> s = bpd.Series([1, 0, 0, 0]) >>> s.autocorr() - nan + np.float64(nan) Args: lag (int, default 1): From 7ff77f35deef6fbd55f84392933e6a52f036b249 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 5 Sep 2024 21:55:07 +0000 Subject: [PATCH 48/59] fix test_df_apply_axis_1_complex --- bigframes/ml/preprocessing.py | 9 +++++---- tests/system/large/test_remote_function.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/bigframes/ml/preprocessing.py b/bigframes/ml/preprocessing.py index 28ce908223..80e00a7497 100644 --- a/bigframes/ml/preprocessing.py +++ b/bigframes/ml/preprocessing.py @@ -19,7 +19,6 @@ import typing from typing import cast, Iterable, List, Literal, Optional, Union -import numpy as np import bigframes_vendored.sklearn.preprocessing._data import bigframes_vendored.sklearn.preprocessing._discretization @@ -307,13 +306,15 @@ def _compile_to_sql( if self.strategy == "uniform": for column in columns: column_min = X[column].min() - if np.issubdtype(column_min.dtype, np.floating): + column_max = X[column].max() + + # Use Python value rather than Numpy value to serialization. 
+ if hasattr(column_min, "item"): min_value = column_min.item() else: min_value = column_min - column_max = X[column].max() - if np.issubdtype(column_max.dtype, np.floating): + if hasattr(column_max, "item"): max_value = column_max.item() else: max_value = column_max diff --git a/tests/system/large/test_remote_function.py b/tests/system/large/test_remote_function.py index d6eefc1e31..b6e2baa503 100644 --- a/tests/system/large/test_remote_function.py +++ b/tests/system/large/test_remote_function.py @@ -1719,7 +1719,7 @@ def test_df_apply_axis_1_complex(session, pd_df): def serialize_row(row): custom = { - "name": row.name, + "name": row.name.item() if hasattr(row.name, "item") else row.name, "index": [idx for idx in row.index], "values": [ val.item() if hasattr(val, "item") else val for val in row.values From 00834c751fd1b756ddf88d78f159997fe72a0b21 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 10 Sep 2024 18:15:51 +0000 Subject: [PATCH 49/59] address compiler errors after merge --- .../ibis/backends/bigquery/backend.py | 2 +- .../ibis/backends/sql/compilers/base.py | 8 +- .../sql/compilers/bigquery/__init__.py | 85 ++----------------- 3 files changed, 10 insertions(+), 85 deletions(-) diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py b/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py index 106b372f03..61a227619e 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py @@ -11,8 +11,8 @@ import re from typing import Any, Optional, TYPE_CHECKING -import bigframes_vendored.ibis.backends.sql.compilers as sc from bigframes_vendored.ibis.backends.bigquery.datatypes import BigQueryType +import bigframes_vendored.ibis.backends.sql.compilers as sc import google.api_core.exceptions import google.auth.credentials import google.cloud.bigquery as bq diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py index e24b82a73f..4ba9a51618 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py @@ -610,7 +610,6 @@ def translate(self, op, *, params: Mapping[ir.Value, Any]) -> sge.Expression: op, params=params, rewrites=self.rewrites, - post_rewrites=self.post_rewrites, fuse_selects=options.sql.fuse_selects, ) @@ -1263,10 +1262,10 @@ def _cleanup_names(self, exprs: Mapping[str, sge.Expression]): yield value.as_(name, quoted=self.quoted, copy=False) def visit_Select( - self, op, *, parent, selections, predicates, qualified, sort_keys, distinct + self, op, *, parent, selections, predicates, qualified, sort_keys ): # if we've constructed a useless projection return the parent relation - if not (selections or predicates or qualified or sort_keys or distinct): + if not (selections or predicates or qualified or sort_keys): return parent result = parent @@ -1293,9 +1292,6 @@ def visit_Select( if sort_keys: result = result.order_by(*sort_keys, copy=False) - if distinct: - result = result.distinct() - return result def visit_DummyTable(self, op, *, values): diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py index 4ac189ae45..b664ad7333 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py +++ 
b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py @@ -8,28 +8,25 @@ import re from typing import Any, TYPE_CHECKING -from bigframes_vendored.ibis.backends.bigquery.datatypes import ( - BigQueryType, - BigQueryUDFType, -) +import bigframes_vendored.ibis.backends.bigquery.datatypes as bq_datatypes from bigframes_vendored.ibis.backends.sql.compilers.base import ( AggGen, NULL, SQLGlotCompiler, STAR, ) -from bigframes_vendored.ibis.backends.sql.rewrites import ( +from ibis import util +from ibis.backends.sql.datatypes import BigQueryType, BigQueryUDFType +from ibis.backends.sql.rewrites import ( exclude_unsupported_window_frame_from_ops, exclude_unsupported_window_frame_from_rank, exclude_unsupported_window_frame_from_row_number, - split_select_distinct_with_order_by, ) -from ibis import util -from ibis.backends.sql.compilers.bigquery.udf.core import PythonToJavaScriptTranslator import ibis.common.exceptions as com from ibis.common.temporal import DateUnit, IntervalUnit, TimestampUnit, TimeUnit import ibis.expr.datatypes as dt import ibis.expr.operations as ops +import numpy as np import sqlglot as sg from sqlglot.dialects import BigQuery import sqlglot.expressions as sge @@ -39,6 +36,7 @@ import ibis.expr.types as ir + _NAME_REGEX = re.compile(r'[^!"$()*,./;?@[\\\]^`{}~\n]+') @@ -63,12 +61,9 @@ def _remove_null_ordering_from_unsupported_window( node: sge.Expression, ) -> sge.Expression: """Remove null ordering in window frame clauses not supported by BigQuery. - BigQuery has only partial support for NULL FIRST/LAST in RANGE windows so we remove it from any window frame clause that doesn't support it. - Here's the support matrix: - ✅ sum(x) over (order by y desc nulls last) 🚫 sum(x) over (order by y asc nulls last) ✅ sum(x) over (order by y asc nulls first) @@ -90,14 +85,10 @@ def _remove_null_ordering_from_unsupported_window( def _force_quote_table(table: sge.Table) -> sge.Table: """Force quote all the parts of a bigquery path. - The BigQuery identifier quoting semantics are bonkers https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#identifiers - my-table is OK, but not mydataset.my-table - mytable-287 is OK, but not mytable-287a - Just quote everything. """ for key in ("this", "db", "catalog"): @@ -122,7 +113,6 @@ class BigQueryCompiler(SQLGlotCompiler): exclude_unsupported_window_frame_from_rank, *SQLGlotCompiler.rewrites, ) - post_rewrites = (split_select_distinct_with_order_by,) supports_qualify = True @@ -213,7 +203,6 @@ def to_sqlglot( session_project: str | None = None, ) -> Any: """Compile an Ibis expression. - Parameters ---------- expr @@ -227,18 +216,16 @@ def to_sqlglot( Optional dataset ID to qualify memtable references. session_project Optional project ID to qualify memtable references. - Returns ------- Any The output of compilation. The type of this value depends on the backend. 
- """ sql = super().to_sqlglot(expr, limit=limit, params=params) table_expr = expr.as_table() - geocols = table_expr.schema().geospatial + geocols = getattr(table_expr.schema(), "geospatial", None) result = sql.transform( _qualify_memtable, @@ -279,64 +266,6 @@ def to_sqlglot( sources.append(result) return sources - def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> sge.Create: - name = type(udf_node).__name__ - type_mapper = self.udf_type_mapper - - body = PythonToJavaScriptTranslator(udf_node.__func__).compile() - config = udf_node.__config__ - libraries = config.get("libraries", []) - - signature = [ - sge.ColumnDef( - this=sg.to_identifier(name, quoted=self.quoted), - kind=type_mapper.from_ibis(param.annotation.pattern.dtype), - ) - for name, param in udf_node.__signature__.parameters.items() - ] - - lines = ['"""'] - - if config.get("strict", True): - lines.append('"use strict";') - - lines += [ - body, - "", - f"return {udf_node.__func_name__}({', '.join(udf_node.argnames)});", - '"""', - ] - - func = sge.Create( - kind="FUNCTION", - this=sge.UserDefinedFunction( - this=sg.to_identifier(name), expressions=signature, wrapped=True - ), - # not exactly what I had in mind, but it works - # - # quoting is too simplistic to handle multiline strings - expression=sge.Var(this="\n".join(lines)), - exists=False, - properties=sge.Properties( - expressions=[ - sge.TemporaryProperty(), - sge.ReturnsProperty(this=type_mapper.from_ibis(udf_node.dtype)), - sge.StabilityProperty( - this="IMMUTABLE" if config.get("determinism") else "VOLATILE" - ), - sge.LanguageProperty(this=sg.to_identifier("js")), - ] - + [ - sge.Property( - this=sg.to_identifier("library"), value=self.f.array(*libraries) - ) - ] - * bool(libraries) - ), - ) - - return func - @staticmethod def _minimize_spec(start, end, spec): if ( From 822180a3030694ff5516e736c439d8087f796e50 Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Tue, 10 Sep 2024 18:19:34 +0000 Subject: [PATCH 50/59] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20?= =?UTF-8?q?post-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- .../bigframes_vendored/ibis/backends/sql/compilers/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py index 4ba9a51618..481cec397e 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py @@ -1261,9 +1261,7 @@ def _cleanup_names(self, exprs: Mapping[str, sge.Expression]): else: yield value.as_(name, quoted=self.quoted, copy=False) - def visit_Select( - self, op, *, parent, selections, predicates, qualified, sort_keys - ): + def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_keys): # if we've constructed a useless projection return the parent relation if not (selections or predicates or qualified or sort_keys): return parent From 6f054f4425cf42bbb0f14a65f24f00ba8857540c Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 10 Sep 2024 20:55:56 +0000 Subject: [PATCH 51/59] fix unit-test compile errors --- .../bigframes_vendored/ibis/backends/sql/compilers/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py 
b/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py index 481cec397e..cbd8e4e2d9 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py @@ -1124,7 +1124,7 @@ def visit_Coalesce(self, op, *, arg): ### Ordering and window functions - def visit_SortKey(self, op, *, expr, ascending: bool, nulls_first: bool): + def visit_SortKey(self, op, *, expr, ascending: bool, nulls_first: bool = False): return sge.Ordered(this=expr, desc=not ascending, nulls_first=nulls_first) def visit_ApproxMedian(self, op, *, arg, where): From c7167dbafb32c8c4d24bff88a457af80cfa8a9b6 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 10 Sep 2024 22:34:21 +0000 Subject: [PATCH 52/59] remove unused ibis codes --- .../ibis/backends/bigquery/backend.py | 76 ------------------- .../sql/compilers/bigquery/__init__.py | 47 +----------- 2 files changed, 1 insertion(+), 122 deletions(-) diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py b/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py index 61a227619e..08f260115d 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py @@ -81,50 +81,6 @@ def _create_client_info_gapic(application_name): return ClientInfo(user_agent=_create_user_agent(application_name)) -_MEMTABLE_PATTERN = re.compile( - r"^_?ibis_(?:[A-Za-z_][A-Za-z_0-9]*)_memtable_[a-z0-9]{26}$" -) - - -def _qualify_memtable( - node: sge.Expression, *, dataset: str | None, project: str | None -) -> sge.Expression: - """Add a BigQuery dataset and project to memtable references.""" - if isinstance(node, sge.Table) and _MEMTABLE_PATTERN.match(node.name) is not None: - node.args["db"] = dataset - node.args["catalog"] = project - # make sure to quote table location - node = _force_quote_table(node) - return node - - -def _remove_null_ordering_from_unsupported_window( - node: sge.Expression, -) -> sge.Expression: - """Remove null ordering in window frame clauses not supported by BigQuery. - - BigQuery has only partial support for NULL FIRST/LAST in RANGE windows so - we remove it from any window frame clause that doesn't support it. 
- - Here's the support matrix: - - ✅ sum(x) over (order by y desc nulls last) - 🚫 sum(x) over (order by y asc nulls last) - ✅ sum(x) over (order by y asc nulls first) - 🚫 sum(x) over (order by y desc nulls first) - """ - if isinstance(node, sge.Window): - order = node.args.get("order") - if order is not None: - for key in order.args["expressions"]: - kargs = key.args - if kargs.get("desc") is True and kargs.get("nulls_first", False): - kargs["nulls_first"] = False - elif kargs.get("desc") is False and not kargs.setdefault( - "nulls_first", True - ): - kargs["nulls_first"] = True - return node def _force_quote_table(table: sge.Table) -> sge.Table: @@ -167,32 +123,6 @@ def _session_dataset(self): self.__session_dataset = self._make_session() return self.__session_dataset - def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: - raw_name = op.name - - session_dataset = self._session_dataset - project = session_dataset.project - dataset = session_dataset.dataset_id - - table_ref = bq.TableReference(session_dataset, raw_name) - try: - self.client.get_table(table_ref) - except google.api_core.exceptions.NotFound: - table_id = sg.table( - raw_name, db=dataset, catalog=project, quoted=False - ).sql(dialect=self.name) - bq_schema = BigQuerySchema.from_ibis(op.schema) - load_job = self.client.load_table_from_dataframe( - op.data.to_frame(), - table_id, - job_config=bq.LoadJobConfig( - # fail if the table already exists and contains data - write_disposition=bq.WriteDisposition.WRITE_EMPTY, - schema=bq_schema, - ), - ) - load_job.result() - def _read_file( self, path: str | Path, @@ -797,7 +727,6 @@ def to_pyarrow( **kwargs: Any, ) -> pa.Table: self._import_pyarrow() - self._register_in_memory_tables(expr) sql = self.compile(expr, limit=limit, params=params, **kwargs) self._log(sql) query = self.raw_sql(sql, params=params, **kwargs) @@ -820,7 +749,6 @@ def to_pyarrow_batches( schema = expr.as_table().schema() - self._register_in_memory_tables(expr) sql = self.compile(expr, limit=limit, params=params, **kwargs) self._log(sql) query = self.raw_sql(sql, params=params, page_size=chunk_size, **kwargs) @@ -1013,9 +941,6 @@ def create_table( if obj is not None and not isinstance(obj, ir.Table): obj = ibis.memtable(obj, schema=schema) - if obj is not None: - self._register_in_memory_tables(obj) - if temp: dataset = self._session_dataset.dataset_id if database is not None: @@ -1111,7 +1036,6 @@ def create_view( expression=self.compile(obj), replace=overwrite, ) - self._register_in_memory_tables(obj) self.raw_sql(stmt.sql(self.name)) return self.table(name, database=(catalog, database)) diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py index b664ad7333..280e7ed1e6 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py @@ -52,54 +52,9 @@ def _qualify_memtable( if isinstance(node, sge.Table) and _MEMTABLE_PATTERN.match(node.name) is not None: node.args["db"] = dataset node.args["catalog"] = project - # make sure to quote table location - node = _force_quote_table(node) return node -def _remove_null_ordering_from_unsupported_window( - node: sge.Expression, -) -> sge.Expression: - """Remove null ordering in window frame clauses not supported by BigQuery. 
- BigQuery has only partial support for NULL FIRST/LAST in RANGE windows so - we remove it from any window frame clause that doesn't support it. - Here's the support matrix: - ✅ sum(x) over (order by y desc nulls last) - 🚫 sum(x) over (order by y asc nulls last) - ✅ sum(x) over (order by y asc nulls first) - 🚫 sum(x) over (order by y desc nulls first) - """ - if isinstance(node, sge.Window): - order = node.args.get("order") - if order is not None: - for key in order.args["expressions"]: - kargs = key.args - if kargs.get("desc") is True and kargs.get("nulls_first", False): - kargs["nulls_first"] = False - elif kargs.get("desc") is False and not kargs.setdefault( - "nulls_first", True - ): - kargs["nulls_first"] = True - return node - - -def _force_quote_table(table: sge.Table) -> sge.Table: - """Force quote all the parts of a bigquery path. - The BigQuery identifier quoting semantics are bonkers - https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#identifiers - my-table is OK, but not mydataset.my-table - mytable-287 is OK, but not mytable-287a - Just quote everything. - """ - for key in ("this", "db", "catalog"): - if (val := table.args[key]) is not None: - if isinstance(val, sg.exp.Identifier) and not val.quoted: - val.args["quoted"] = True - else: - table.args[key] = sg.to_identifier(val, quoted=True) - return table - - class BigQueryCompiler(SQLGlotCompiler): dialect = BigQuery type_mapper = BigQueryType @@ -231,7 +186,7 @@ def to_sqlglot( _qualify_memtable, dataset=session_dataset_id, project=session_project, - ).transform(_remove_null_ordering_from_unsupported_window) + ) if geocols: # if there are any geospatial columns, we have to convert them to WKB, From 555453b228e81cc67a744d3258081e8dde3f315d Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 10 Sep 2024 23:19:16 +0000 Subject: [PATCH 53/59] fix fillna deprecated warning --- bigframes/core/compile/aggregate_compiler.py | 16 +++++++++++----- bigframes/core/compile/default_ordering.py | 7 ++++++- bigframes/core/compile/scalar_op_compiler.py | 15 ++++++++++++--- bigframes/core/compile/single_column.py | 6 +++++- .../ibis/backends/bigquery/backend.py | 3 --- 5 files changed, 34 insertions(+), 13 deletions(-) diff --git a/bigframes/core/compile/aggregate_compiler.py b/bigframes/core/compile/aggregate_compiler.py index 97628d9261..6e1cc5ae43 100644 --- a/bigframes/core/compile/aggregate_compiler.py +++ b/bigframes/core/compile/aggregate_compiler.py @@ -521,11 +521,14 @@ def _( column: ibis_types.Column, window=None, ) -> ibis_types.BooleanValue: - # BQ will return null for empty column, result would be true in pandas. - result = _is_true(column).all() + # BQ will return null for empty column, result would be false in pandas. + result = _apply_window_if_present(_is_true(column).all(), window) + return cast( ibis_types.BooleanScalar, - _apply_window_if_present(result, window).fillna(ibis_types.literal(True)), + result.fill_null(ibis_types.literal(False)) + if hasattr(result, "fill_null") + else result.fillna(ibis_types.literal(False)), ) @@ -536,10 +539,13 @@ def _( window=None, ) -> ibis_types.BooleanValue: # BQ will return null for empty column, result would be false in pandas. 
- result = _is_true(column).any() + result = _apply_window_if_present(_is_true(column).any(), window) + return cast( ibis_types.BooleanScalar, - _apply_window_if_present(result, window).fillna(ibis_types.literal(False)), + result.fill_null(ibis_types.literal(False)) + if hasattr(result, "fill_null") + else result.fillna(ibis_types.literal(False)), ) diff --git a/bigframes/core/compile/default_ordering.py b/bigframes/core/compile/default_ordering.py index 7d7a41f742..a6b625caca 100644 --- a/bigframes/core/compile/default_ordering.py +++ b/bigframes/core/compile/default_ordering.py @@ -49,7 +49,12 @@ def _convert_to_nonnull_string(column: ibis_types.Column) -> ibis_types.StringVa # Needed for JSON, STRUCT and ARRAY datatypes result = vendored_ibis_ops.ToJsonString(column).to_expr() # type: ignore # Escape backslashes and use backslash as delineator - escaped = cast(ibis_types.StringColumn, result.fillna("")).replace("\\", "\\\\") # type: ignore + escaped = cast( + ibis_types.StringColumn, + result.fill_null("") if hasattr(result, "fill_null") else result.fillna(""), + ).replace( + "\\", "\\\\" + ) # type: ignore return cast(ibis_types.StringColumn, ibis.literal("\\")).concat(escaped) diff --git a/bigframes/core/compile/scalar_op_compiler.py b/bigframes/core/compile/scalar_op_compiler.py index 76b91abb28..6fe876f7f9 100644 --- a/bigframes/core/compile/scalar_op_compiler.py +++ b/bigframes/core/compile/scalar_op_compiler.py @@ -995,8 +995,14 @@ def eq_nulls_match_op( y: ibis_types.Value, ): """Variant of eq_op where nulls match each other. Only use where dtypes are known to be same.""" - left = x.cast(ibis_dtypes.str).fillna(ibis_types.literal("$NULL_SENTINEL$")) - right = y.cast(ibis_dtypes.str).fillna(ibis_types.literal("$NULL_SENTINEL$")) + literal = ibis_types.literal("$NULL_SENTINEL$") + if hasattr(x, "fill_null"): + left = x.cast(ibis_dtypes.str).fill_null(literal) + right = y.cast(ibis_dtypes.str).fill_null(literal) + else: + left = x.cast(ibis_dtypes.str).fillna(literal) + right = y.cast(ibis_dtypes.str).fillna(literal) + return left == right @@ -1379,7 +1385,10 @@ def fillna_op( x: ibis_types.Value, y: ibis_types.Value, ): - return x.fillna(typing.cast(ibis_types.Scalar, y)) + if hasattr(x, "fill_null"): + return x.fill_null(typing.cast(ibis_types.Scalar, y)) + else: + return x.fillna(typing.cast(ibis_types.Scalar, y)) @scalar_op_compiler.register_binary_op(ops.round_op) diff --git a/bigframes/core/compile/single_column.py b/bigframes/core/compile/single_column.py index 9b621c9c79..cd04b1fea1 100644 --- a/bigframes/core/compile/single_column.py +++ b/bigframes/core/compile/single_column.py @@ -170,4 +170,8 @@ def value_to_join_key(value: ibis_types.Value): """Converts nullable values to non-null string SQL will not match null keys together - but pandas does.""" if not value.type().is_string(): value = value.cast(ibis_dtypes.str) - return value.fillna(ibis_types.literal("$NULL_SENTINEL$")) + return ( + value.fill_null(ibis_types.literal("$NULL_SENTINEL$")) + if hasattr(value, "fill_null") + else value.fillna(ibis_types.literal("$NULL_SENTINEL$")) + ) diff --git a/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py b/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py index 08f260115d..d4d5156572 100644 --- a/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py +++ b/third_party/bigframes_vendored/ibis/backends/bigquery/backend.py @@ -8,7 +8,6 @@ import contextlib import glob import os -import re from typing import Any, Optional, TYPE_CHECKING from 
bigframes_vendored.ibis.backends.bigquery.datatypes import BigQueryType @@ -81,8 +80,6 @@ def _create_client_info_gapic(application_name): return ClientInfo(user_agent=_create_user_agent(application_name)) - - def _force_quote_table(table: sge.Table) -> sge.Table: """Force quote all the parts of a bigquery path. From 0762299ffb431b053393b2549fbc3e92fb2cbb82 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 11 Sep 2024 17:32:42 +0000 Subject: [PATCH 54/59] add _remove_null_ordering_from_unsupported_window back to fix test_precision_score etc ml tests --- .../sql/compilers/bigquery/__init__.py | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py index 280e7ed1e6..7cf4947383 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py @@ -55,6 +55,32 @@ def _qualify_memtable( return node +def _remove_null_ordering_from_unsupported_window( + node: sge.Expression, +) -> sge.Expression: + """Remove null ordering in window frame clauses not supported by BigQuery. + BigQuery has only partial support for NULL FIRST/LAST in RANGE windows so + we remove it from any window frame clause that doesn't support it. + Here's the support matrix: + ✅ sum(x) over (order by y desc nulls last) + 🚫 sum(x) over (order by y asc nulls last) + ✅ sum(x) over (order by y asc nulls first) + 🚫 sum(x) over (order by y desc nulls first) + """ + if isinstance(node, sge.Window): + order = node.args.get("order") + if order is not None: + for key in order.args["expressions"]: + kargs = key.args + if kargs.get("desc") is True and kargs.get("nulls_first", False): + kargs["nulls_first"] = False + elif kargs.get("desc") is False and not kargs.setdefault( + "nulls_first", True + ): + kargs["nulls_first"] = True + return node + + class BigQueryCompiler(SQLGlotCompiler): dialect = BigQuery type_mapper = BigQueryType @@ -186,7 +212,7 @@ def to_sqlglot( _qualify_memtable, dataset=session_dataset_id, project=session_project, - ) + ).transform(_remove_null_ordering_from_unsupported_window) if geocols: # if there are any geospatial columns, we have to convert them to WKB, From 4c905162f9694b6b4cb0c450efe29ea0949ed4ef Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 11 Sep 2024 19:30:22 +0000 Subject: [PATCH 55/59] fix is_monotonic_decreasing test --- bigframes/core/compile/aggregate_compiler.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/bigframes/core/compile/aggregate_compiler.py b/bigframes/core/compile/aggregate_compiler.py index 6e1cc5ae43..5d1b4530e1 100644 --- a/bigframes/core/compile/aggregate_compiler.py +++ b/bigframes/core/compile/aggregate_compiler.py @@ -523,12 +523,13 @@ def _( ) -> ibis_types.BooleanValue: # BQ will return null for empty column, result would be false in pandas. result = _apply_window_if_present(_is_true(column).all(), window) + literal = ibis_types.literal(True) return cast( ibis_types.BooleanScalar, - result.fill_null(ibis_types.literal(False)) + result.fill_null(literal) if hasattr(result, "fill_null") - else result.fillna(ibis_types.literal(False)), + else result.fillna(literal), ) @@ -540,12 +541,13 @@ def _( ) -> ibis_types.BooleanValue: # BQ will return null for empty column, result would be false in pandas. 
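
# The literal(True)/literal(False) fill values chosen in this commit follow
# pandas semantics for reductions over an empty boolean Series: all() is
# vacuously True and any() is False, whereas BigQuery's LOGICAL_AND /
# LOGICAL_OR aggregates return NULL when there are no rows. A standalone
# illustration of the pandas side, not part of the patch:
import pandas as pd

empty = pd.Series([], dtype="bool")
print(empty.all())  # True  -> the BigQuery NULL is filled with literal(True)
print(empty.any())  # False -> the BigQuery NULL is filled with literal(False)
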
result = _apply_window_if_present(_is_true(column).any(), window) + literal = ibis_types.literal(False) return cast( ibis_types.BooleanScalar, - result.fill_null(ibis_types.literal(False)) + result.fill_null(literal) if hasattr(result, "fill_null") - else result.fillna(ibis_types.literal(False)), + else result.fillna(literal), ) From d425aa9dadd76dbfdeab67dc25d4ab15d92e3f3d Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 12 Sep 2024 05:08:20 +0000 Subject: [PATCH 56/59] fix explode after merge --- bigframes/core/compile/compiled.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bigframes/core/compile/compiled.py b/bigframes/core/compile/compiled.py index fbb7a3e7be..cd00c98381 100644 --- a/bigframes/core/compile/compiled.py +++ b/bigframes/core/compile/compiled.py @@ -716,7 +716,7 @@ def explode(self, offsets: typing.Sequence[int]) -> OrderedIR: 0, ibis.greatest( 1, # We always want at least 1 element to fill in NULLs for empty arrays. - ibis.least(*[table[table.columns[offset]].length() - 1 for offset in offsets]), + ibis.least(*[table[column_id].length() for column_id in column_ids]), ), ).name(offset_array_id) table_w_offset_array = table.select( From fa46553052f7d94cc482e41ac927e87362a35d92 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 12 Sep 2024 17:39:05 +0000 Subject: [PATCH 57/59] fix numpy on remote function test --- tests/system/large/test_remote_function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/system/large/test_remote_function.py b/tests/system/large/test_remote_function.py index 1408d9c631..9503faa0b7 100644 --- a/tests/system/large/test_remote_function.py +++ b/tests/system/large/test_remote_function.py @@ -1727,7 +1727,7 @@ def test_df_apply_axis_1_complex(session, pd_df): def serialize_row(row): custom = { "name": row.name.item() if hasattr(row.name, "item") else row.name, - "index": [idx for idx in row.index], + "index": [idx.item() if hasattr(idx, "item") else idx for idx in row.index], "values": [ val.item() if hasattr(val, "item") else val for val in row.values ], From f17559688493a2db1eed1d71b49bbc0a340f8d9c Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Thu, 12 Sep 2024 17:41:32 +0000 Subject: [PATCH 58/59] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20?= =?UTF-8?q?post-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- tests/system/large/test_remote_function.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/system/large/test_remote_function.py b/tests/system/large/test_remote_function.py index 9503faa0b7..e224f65a01 100644 --- a/tests/system/large/test_remote_function.py +++ b/tests/system/large/test_remote_function.py @@ -1727,7 +1727,9 @@ def test_df_apply_axis_1_complex(session, pd_df): def serialize_row(row): custom = { "name": row.name.item() if hasattr(row.name, "item") else row.name, - "index": [idx.item() if hasattr(idx, "item") else idx for idx in row.index], + "index": [ + idx.item() if hasattr(idx, "item") else idx for idx in row.index + ], "values": [ val.item() if hasattr(val, "item") else val for val in row.values ], From f3a43b1cd07ddb9d07bb8579eb5e4c23de6489be Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Fri, 13 Sep 2024 20:38:16 +0000 Subject: [PATCH 59/59] ml numpy sql generations --- bigframes/ml/preprocessing.py | 15 ++------------- bigframes/ml/sql.py | 7 ++++++- 
tests/system/small/test_dataframe.py | 2 +- .../backends/sql/compilers/bigquery/__init__.py | 2 +- 4 files changed, 10 insertions(+), 16 deletions(-) diff --git a/bigframes/ml/preprocessing.py b/bigframes/ml/preprocessing.py index 80e00a7497..2c327f63f8 100644 --- a/bigframes/ml/preprocessing.py +++ b/bigframes/ml/preprocessing.py @@ -305,19 +305,8 @@ def _compile_to_sql( array_split_points = {} if self.strategy == "uniform": for column in columns: - column_min = X[column].min() - column_max = X[column].max() - - # Use Python value rather than Numpy value to serialization. - if hasattr(column_min, "item"): - min_value = column_min.item() - else: - min_value = column_min - - if hasattr(column_max, "item"): - max_value = column_max.item() - else: - max_value = column_max + min_value = X[column].min() + max_value = X[column].max() bin_size = (max_value - min_value) / self.n_bins array_split_points[column] = [ diff --git a/bigframes/ml/sql.py b/bigframes/ml/sql.py index d14627f590..a91ae78f16 100644 --- a/bigframes/ml/sql.py +++ b/bigframes/ml/sql.py @@ -124,7 +124,12 @@ def ml_bucketize( name: str, ) -> str: """Encode ML.BUCKETIZE for BQML""" - return f"""ML.BUCKETIZE({numeric_expr_sql}, {array_split_points}, FALSE) AS {name}""" + # Use Python value rather than Numpy value to serialization. + points = [ + point.item() if hasattr(point, "item") else point + for point in array_split_points + ] + return f"""ML.BUCKETIZE({numeric_expr_sql}, {points}, FALSE) AS {name}""" def ml_quantile_bucketize( self, diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index da29b79aef..9e046dc62e 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -4569,7 +4569,7 @@ def test_recursion_limit(scalars_df_index): @pytest.mark.skipif( - reason="Skip until query complexity error can be reliably triggered" + reason="b/366477265: Skip until query complexity error can be reliably triggered." ) def test_query_complexity_error(scalars_df_index): # This test requires automatic caching/query decomposition to be turned off diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py index 7cf4947383..3015991a26 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py @@ -437,7 +437,7 @@ def visit_ArrayContains(self, op, *, arg, other): def visit_StringContains(self, op, *, haystack, needle): return self.f.strpos(haystack, needle) > 0 - def visti_StringFind(self, op, *, arg, substr, start, end): + def visit_StringFind(self, op, *, arg, substr, start, end): if start is not None: raise NotImplementedError( "`start` not implemented for BigQuery string find"
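
# The ml_bucketize() change earlier in this commit converts NumPy scalars to
# plain Python values before the split points are formatted into SQL; with
# newer NumPy (2.x), repr(np.float64(1.5)) renders as "np.float64(1.5)" and
# would leak into the generated query. A standalone sketch of that conversion;
# the column name and split points below are made up for illustration:
import numpy as np

split_points = [np.float64(1.5), np.float64(3.0), 4.5]
points = [p.item() if hasattr(p, "item") else p for p in split_points]
sql = f"ML.BUCKETIZE(col_a, {points}, FALSE) AS bucketized_col_a"
print(sql)  # ML.BUCKETIZE(col_a, [1.5, 3.0, 4.5], FALSE) AS bucketized_col_a
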