|
4 | 4 | import functools
|
5 | 5 | from functools import partial
|
6 | 6 | from operator import add, mul, sub
|
7 |
| -from typing import Any, Literal, Mapping |
| 7 | +from typing import Any, Callable, Literal, Mapping |
8 | 8 |
|
9 | 9 | import sqlglot as sg
|
10 | 10 |
|
@@ -180,6 +180,39 @@ def _not_all(op, **kw):
|
180 | 180 | return translate_val(ops.Not(ops.All(op.arg)), **kw)
|
181 | 181 |
|
182 | 182 |
|
| 183 | +def _quantiles(quantiles_translator_func: Callable, func_name: str): |
| 184 | + def translate(op, **kw): |
| 185 | + quantile = quantiles_translator_func(op.quantile) |
| 186 | + args = [_sql(translate_val(op.arg, **kw))] |
| 187 | + func = func_name |
| 188 | + |
| 189 | + if (where := op.where) is not None: |
| 190 | + args.append(_sql(translate_val(where, **kw))) |
| 191 | + func += "If" |
| 192 | + |
| 193 | + return f"{func}({quantile})({', '.join(args)})" |
| 194 | + |
| 195 | + return translate |
| 196 | + |
| 197 | + |
| 198 | +@translate_val.register(ops.Quantile) |
| 199 | +def _quantile(op, **kw): |
| 200 | + def quantiles_translator_func(quantiles): |
| 201 | + return _sql(translate_val(quantiles, **kw)) |
| 202 | + |
| 203 | + return _quantiles(quantiles_translator_func, func_name="quantile")(op, **kw) |
| 204 | + |
| 205 | + |
| 206 | +@translate_val.register(ops.MultiQuantile) |
| 207 | +def _multi_quantile(op, **kw): |
| 208 | + def quantiles_translator_func(quantiles): |
| 209 | + if not isinstance(quantiles, ops.Literal): |
| 210 | + raise TypeError("ClickHouse quantile only accepts a list of Python floats") |
| 211 | + return ", ".join(map(str, quantiles.value)) |
| 212 | + |
| 213 | + return _quantiles(quantiles_translator_func, func_name="quantiles")(op, **kw) |
| 214 | + |
| 215 | + |
183 | 216 | def _agg_variance_like(func):
|
184 | 217 | variants = {"sample": f"{func}Samp", "pop": f"{func}Pop"}
|
185 | 218 |
|
|
0 commit comments