Skip to content

Commit 24f41b2

Browse files
committed
feat(exasol): implement cov/corr
1 parent e20bdad commit 24f41b2

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

ibis/backends/sql/compilers/exasol.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ class ExasolCompiler(SQLGlotCompiler):
4444
ops.ArrayUnion,
4545
ops.ArrayZip,
4646
ops.BitwiseNot,
47-
ops.Covariance,
4847
ops.CumeDist,
4948
ops.DateAdd,
5049
ops.DateSub,
@@ -120,6 +119,20 @@ def visit_NonNullLiteral(self, op, *, value, dtype):
120119
def visit_Date(self, op, *, arg):
121120
return self.cast(arg, dt.date)
122121

122+
def visit_Correlation(self, op, *, left, right, how, where):
123+
if how == "sample":
124+
raise com.UnsupportedOperationError(
125+
"Exasol only implements `pop` correlation coefficient"
126+
)
127+
128+
if (left_type := op.left.dtype).is_boolean():
129+
left = self.cast(left, dt.Int32(nullable=left_type.nullable))
130+
131+
if (right_type := op.right.dtype).is_boolean():
132+
right = self.cast(right, dt.Int32(nullable=right_type.nullable))
133+
134+
return self.agg.corr(left, right, where=where)
135+
123136
def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
124137
if where is not None:
125138
arg = self.if_(where, arg, NULL)

ibis/backends/tests/test_aggregation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,7 +1013,7 @@ def test_quantile(
10131013
raises=com.OperationNotDefinedError,
10141014
),
10151015
pytest.mark.notyet(
1016-
["postgres", "duckdb", "snowflake", "risingwave"],
1016+
["postgres", "duckdb", "snowflake", "risingwave", "exasol"],
10171017
raises=com.UnsupportedOperationError,
10181018
reason="backend only implements population correlation coefficient",
10191019
),
@@ -1114,7 +1114,7 @@ def test_quantile(
11141114
),
11151115
],
11161116
)
1117-
@pytest.mark.notimpl(["mssql", "exasol"], raises=com.OperationNotDefinedError)
1117+
@pytest.mark.notimpl(["mssql"], raises=com.OperationNotDefinedError)
11181118
def test_corr_cov(
11191119
con,
11201120
batting,

0 commit comments

Comments
 (0)