Skip to content

Commit ae818fb

Browse files
cpcloudkszucs
authored andcommitted
feat(pyspark): implement covariance and correlation
1 parent 335f6ba commit ae818fb

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

ibis/backends/pyspark/compiler.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,44 @@ def compile_variance(t, expr, scope, timecontext, context=None, **kwargs):
690690
)
691691

692692

693+
@compiles(ops.Covariance)
694+
def compile_covariance(t, expr, scope, timecontext, context=None, **kwargs):
695+
op = expr.op()
696+
how = op.how
697+
698+
fn = {"sample": F.covar_samp, "pop": F.covar_pop}[how]
699+
700+
pyspark_double_type = ibis_dtype_to_spark_dtype(dtypes.double)
701+
expr = op.__class__(
702+
left=op.left.cast(pyspark_double_type),
703+
right=op.right.cast(pyspark_double_type),
704+
how=how,
705+
where=op.where,
706+
).to_expr()
707+
return compile_aggregator(
708+
t, expr, scope, timecontext, fn=fn, context=context
709+
)
710+
711+
712+
@compiles(ops.Correlation)
713+
def compile_correlation(t, expr, scope, timecontext, context=None, **kwargs):
714+
op = expr.op()
715+
716+
if (how := op.how) == "pop":
717+
raise ValueError("PySpark only implements sample correlation")
718+
719+
pyspark_double_type = ibis_dtype_to_spark_dtype(dtypes.double)
720+
expr = op.__class__(
721+
left=op.left.cast(pyspark_double_type),
722+
right=op.right.cast(pyspark_double_type),
723+
how=how,
724+
where=op.where,
725+
).to_expr()
726+
return compile_aggregator(
727+
t, expr, scope, timecontext, fn=F.corr, context=context
728+
)
729+
730+
693731
@compiles(ops.Arbitrary)
694732
def compile_arbitrary(t, expr, scope, timecontext, context=None, **kwargs):
695733
how = expr.op().how

ibis/backends/tests/test_aggregation.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,6 @@ def test_aggregate_grouped(
200200
"impala",
201201
"mysql",
202202
"pandas",
203-
"postgres",
204-
"pyspark",
205203
"sqlite",
206204
]
207205
)
@@ -220,8 +218,6 @@ def test_aggregate_grouped(
220218
"impala",
221219
"mysql",
222220
"pandas",
223-
"postgres",
224-
"pyspark",
225221
"sqlite",
226222
]
227223
)

0 commit comments

Comments
 (0)