Skip to content

Commit 42d95b0

Browse files
committed
fix(snowflake): implement working TimestampNow
1 parent 57b1dd8 commit 42d95b0

File tree

3 files changed

+22
-4
lines changed

3 files changed

+22
-4
lines changed

ibis/backends/postgres/registry.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,8 +559,6 @@ def variance_compiler(t, op):
559559
ops.DayOfWeekName: fixed_arity(
560560
lambda arg: sa.func.trim(sa.func.to_char(arg, 'Day')), 1
561561
),
562-
# now is in the timezone of the server, but we want UTC
563-
ops.TimestampNow: lambda *_: sa.func.timezone('UTC', sa.func.now()),
564562
ops.TimeFromHMS: fixed_arity(sa.func.make_time, 3),
565563
ops.CumulativeAll: unary(sa.func.bool_and),
566564
ops.CumulativeAny: unary(sa.func.bool_or),
@@ -582,5 +580,8 @@ def variance_compiler(t, op):
582580
ops.Mode: _mode,
583581
ops.Quantile: _quantile,
584582
ops.MultiQuantile: _quantile,
583+
ops.TimestampNow: lambda t, op: sa.literal_column(
584+
"CURRENT_TIMESTAMP", type_=t.get_sqla_type(op.output_dtype)
585+
),
585586
}
586587
)

ibis/backends/snowflake/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,23 @@ def do_connect(
9595
)
9696
)
9797

98+
@contextlib.contextmanager
99+
def begin(self):
100+
with super().begin() as bind:
101+
previous_timezone = (
102+
bind.execute(sa.text("SHOW PARAMETERS LIKE 'TIMEZONE' IN SESSION"))
103+
.mappings()
104+
.fetchone()
105+
.value
106+
)
107+
bind.execute(sa.text("ALTER SESSION SET TIMEZONE = 'UTC'"))
108+
try:
109+
yield bind
110+
finally:
111+
bind.execute(
112+
sa.text(f"ALTER SESSION SET TIMEZONE = {previous_timezone!r}")
113+
)
114+
98115
def _get_sqla_table(
99116
self, name: str, schema: str | None = None, **_: Any
100117
) -> sa.Table:

ibis/backends/tests/test_temporal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,7 @@ def test_day_of_week_column_group_by(
749749
backend.assert_frame_equal(result, expected, check_dtype=False)
750750

751751

752-
@pytest.mark.notimpl(["datafusion", "snowflake", "mssql"])
752+
@pytest.mark.notimpl(["datafusion", "mssql"])
753753
def test_now(con):
754754
expr = ibis.now()
755755
result = con.execute(expr.name("tmp"))
@@ -762,7 +762,7 @@ def test_now(con):
762762

763763

764764
@pytest.mark.notimpl(["dask"], reason="Limit #2553")
765-
@pytest.mark.notimpl(["datafusion", "snowflake", "polars"])
765+
@pytest.mark.notimpl(["datafusion", "polars"])
766766
def test_now_from_projection(alltypes):
767767
n = 5
768768
expr = alltypes[[ibis.now().name('ts')]].limit(n)

0 commit comments

Comments
 (0)