Skip to content

Commit 529a3a2

Browse files
cpcloud authored and gforsyth committed
feat(snowflake): support udf arguments for reading from staged files
1 parent 45ee391 commit 529a3a2

File tree

4 files changed

+122
-26
lines changed

4 files changed

+122
-26
lines changed

.github/workflows/ibis-backends-cloud.yml

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,6 @@ jobs:
7878
run: just download-data
7979

8080
- uses: google-github-actions/auth@v2
81-
if: matrix.backend.name == 'bigquery'
8281
with:
8382
credentials_json: ${{ secrets.GCP_CREDENTIALS }}
8483

ibis/backends/snowflake/__init__.py

Lines changed: 42 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -294,29 +294,53 @@ def _get_udf_source(self, udf_node: ops.ScalarUDF):
294294
for name, arg in zip(udf_node.argnames, udf_node.args)
295295
)
296296
return_type = self._compile_type(udf_node.dtype)
297-
source = textwrap.dedent(inspect.getsource(udf_node.__func__)).strip()
298-
source = "\n".join(
299-
line for line in source.splitlines() if not line.startswith("@udf")
297+
lines, _ = inspect.getsourcelines(udf_node.__func__)
298+
source = textwrap.dedent(
299+
"".join(
300+
itertools.dropwhile(
301+
lambda line: not line.lstrip().startswith("def "), lines
302+
)
303+
)
304+
).strip()
305+
306+
config = udf_node.__config__
307+
308+
preamble_lines = [*self._UDF_PREAMBLE_LINES]
309+
310+
if imports := config.get("imports"):
311+
preamble_lines.append(f"IMPORTS = ({', '.join(map(repr, imports))})")
312+
313+
packages = "({})".format(
314+
", ".join(map(repr, ("pandas", *config.get("packages", ()))))
300315
)
316+
preamble_lines.append(f"PACKAGES = {packages}")
317+
301318
return dict(
302319
source=source,
303320
name=name,
304-
signature=signature,
305-
return_type=return_type,
306-
comment=f"Generated by ibis {ibis.__version__} using Python {platform.python_version()}",
307-
version=".".join(
308-
map(str, min(sys.version_info[:2], self._latest_udf_python_version))
321+
preamble="\n".join(preamble_lines).format(
322+
name=name,
323+
signature=signature,
324+
return_type=return_type,
325+
comment=f"Generated by ibis {ibis.__version__} using Python {platform.python_version()}",
326+
version=".".join(
327+
map(str, min(sys.version_info[:2], self._latest_udf_python_version))
328+
),
309329
),
310330
)
311331

332+
# SQL header clauses shared by the generated Snowflake Python UDF
# definitions. The lines are copied into a list, optionally extended with
# extra clauses (e.g. IMPORTS / PACKAGES), joined with newlines and then
# rendered with ``str.format``. Placeholders: ``name``, ``signature``,
# ``return_type``, ``version`` and ``comment``.
_UDF_PREAMBLE_LINES = (
    "CREATE OR REPLACE TEMPORARY FUNCTION {name}({signature})",
    "RETURNS {return_type}",
    "LANGUAGE PYTHON",
    "IMMUTABLE",
    "RUNTIME_VERSION = '{version}'",
    "COMMENT = '{comment}'",
)
340+
312341
def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> str:
313342
return """\
314-
CREATE OR REPLACE TEMPORARY FUNCTION {name}({signature})
315-
RETURNS {return_type}
316-
LANGUAGE PYTHON
317-
IMMUTABLE
318-
RUNTIME_VERSION = '{version}'
319-
COMMENT = '{comment}'
343+
{preamble}
320344
HANDLER = '{name}'
321345
AS $$
322346
from __future__ import annotations
@@ -327,14 +351,8 @@ def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> str:
327351
$$""".format(**self._get_udf_source(udf_node))
328352

329353
def _compile_pandas_udf(self, udf_node: ops.ScalarUDF) -> str:
330-
return """\
331-
CREATE OR REPLACE TEMPORARY FUNCTION {name}({signature})
332-
RETURNS {return_type}
333-
LANGUAGE PYTHON
334-
IMMUTABLE
335-
RUNTIME_VERSION = '{version}'
336-
COMMENT = '{comment}'
337-
PACKAGES = ('pandas')
354+
template = """\
355+
{preamble}
338356
HANDLER = 'wrapper'
339357
AS $$
340358
from __future__ import annotations
@@ -349,7 +367,8 @@ def _compile_pandas_udf(self, udf_node: ops.ScalarUDF) -> str:
349367
@_snowflake.vectorized(input=pd.DataFrame)
350368
def wrapper(df):
351369
return {name}(*(col for _, col in df.items()))
352-
$$""".format(**self._get_udf_source(udf_node))
370+
$$"""
371+
return template.format(**self._get_udf_source(udf_node))
353372

354373
def to_pyarrow(
355374
self,

ibis/backends/snowflake/tests/conftest.py

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,10 @@
22

33
import concurrent.futures
44
import os
5+
import tempfile
6+
from pathlib import Path
57
from typing import TYPE_CHECKING, Any
8+
from urllib.request import urlretrieve
69

710
import pyarrow.parquet as pq
811
import pyarrow_hotfix # noqa: F401
@@ -17,8 +20,6 @@
1720
from ibis.formats.pyarrow import PyArrowSchema
1821

1922
if TYPE_CHECKING:
20-
from pathlib import Path
21-
2223
from ibis.backends.base import BaseBackend
2324

2425

@@ -115,9 +116,22 @@ def _load_data(self, **_: Any) -> None:
115116
CREATE SCHEMA IF NOT EXISTS {dbschema};
116117
USE SCHEMA {dbschema};
117118
CREATE TEMP STAGE ibis_testing;
119+
CREATE STAGE IF NOT EXISTS models;
118120
{self.script_dir.joinpath("snowflake.sql").read_text()}"""
119121
)
120122

123+
with tempfile.TemporaryDirectory() as d:
124+
path, _ = urlretrieve(
125+
"https://storage.googleapis.com/ibis-testing-data/model.joblib",
126+
os.path.join(d, "model.joblib"),
127+
)
128+
129+
assert os.path.exists(path)
130+
assert os.path.getsize(path) > 0
131+
132+
with con.begin() as c:
133+
c.exec_driver_sql(f"PUT {Path(path).as_uri()} @MODELS")
134+
121135
with con.begin() as c:
122136
# not much we can do to make this faster, but running these in
123137
# multiple threads seems to save about 2x

ibis/backends/snowflake/tests/test_udf.py

Lines changed: 64 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -92,3 +92,67 @@ def test_builtin_agg_udf(con):
9292
expected = c.exec_driver_sql(query).cursor.fetch_pandas_all()
9393

9494
tm.assert_frame_equal(result, expected)
95+
96+
97+
def test_xgboost_model(con):
    """Score the DIAMONDS table with a staged XGBoost model via a pandas UDF.

    The UDF declares ``joblib`` and ``xgboost`` as packages and imports
    ``model.joblib`` from the ``MODELS`` stage, so that file must already
    have been uploaded to the stage before this test runs.

    The test only checks the shape of the result: it is non-empty, it has a
    ``predicted_price`` column, and it has as many rows as the table.
    """
    import ibis
    from ibis import _

    @udf.scalar.pandas(
        packages=("joblib", "xgboost"), imports=("@MODELS/model.joblib",)
    )
    def predict_price(
        carat_scaled: float, cut_encoded: int, color_encoded: int, clarity_encoded: int
    ) -> int:
        import sys

        import joblib
        import pandas as pd

        # Per the Snowflake docs, staged IMPORTS are made available in the
        # directory named by this interpreter option.
        import_dir = sys._xoptions.get("snowflake_import_directory")
        model = joblib.load(f"{import_dir}model.joblib")
        df = pd.concat(
            [carat_scaled, cut_encoded, color_encoded, clarity_encoded], axis=1
        )
        # Column names are assigned explicitly before predicting;
        # presumably they match the names the model was trained with.
        df.columns = ["CARAT_SCALED", "CUT_ENCODED", "COLOR_ENCODED", "CLARITY_ENCODED"]
        return model.predict(df)

    def cases(value, mapping):
        """This should really be a top-level function or method."""
        expr = ibis.case()
        for k, v in mapping.items():
            expr = expr.when(value == k, v)
        return expr.end()

    def ordinal_codes(labels):
        """Map each label to its 1-based position in ``labels``."""
        return {label: code for code, label in enumerate(labels, start=1)}

    cut_codes = ordinal_codes(("Fair", "Good", "Very Good", "Premium", "Ideal"))
    color_codes = ordinal_codes("DEFGHIJ")
    clarity_codes = ordinal_codes(
        ("I1", "IF", "SI1", "SI2", "VS1", "VS2", "VVS1", "VVS2")
    )

    diamonds = con.tables.DIAMONDS
    expr = diamonds.mutate(
        predicted_price=predict_price(
            (_.carat - _.carat.mean()) / _.carat.std(),
            cases(_.cut, cut_codes),
            cases(_.color, color_codes),
            cases(_.clarity, clarity_codes),
        )
    )

    df = expr.execute()

    assert not df.empty
    assert "predicted_price" in df.columns
    assert len(df) == diamonds.count().execute()

0 commit comments

Comments
 (0)