forked from googleapis/python-bigquery-dataframes
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcompose_custom_transformers.py
91 lines (70 loc) · 2.94 KB
/
compose_custom_transformers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import bigframes.pandas as bpd
from bigframes.ml.compose import CustomTransformer
from typing import List, Optional, Union, Dict
import re
class IdentityTransformer(CustomTransformer):
    """Custom transformer that passes a column through unchanged.

    The SQL produced for a column is just the column name, and parsing
    recovers the column label from such a bare-name expression.
    """

    # Transformer-type id; presumably the key CustomTransformer uses to find
    # this class again when parsing -- confirm against the base class.
    _CTID = "IDENT"

    # A bare column name: one letter followed by zero or more letters, digits
    # or underscores (case-insensitive).  The quantifier is "*" so that
    # single-character column names round-trip through compile/parse.
    IDENT_BQSQL_RX = re.compile(
        r"^(?P<colname>[a-z][a-z0-9_]*)$", flags=re.IGNORECASE
    )

    def custom_compile_to_sql(self, X: bpd.DataFrame, column: str) -> str:
        """Return the SQL expression for ``column``.

        Args:
            X: The input dataframe (not used by this transformer).
            column: Name of the column to transform.

        Returns:
            The column name itself.
        """
        return f"{column}"

    @classmethod
    def custom_parse_from_sql(
        cls, config: Optional[Union[Dict, List]], sql: str
    ) -> tuple[CustomTransformer, str]:
        """Rebuild the transformer and column label from a SQL expression.

        Args:
            config: Persistent config, if any (not used by this transformer).
            sql: SQL expression, expected to be a bare column name.

        Returns:
            A ``(transformer, column_label)`` tuple.

        Raises:
            ValueError: If ``sql`` is not a bare column name.
        """
        parsed = cls.IDENT_BQSQL_RX.match(sql)
        if parsed is None:
            # Fail with a readable message instead of dereferencing None.
            raise ValueError(
                f"Cannot parse {cls.__name__} from SQL expression: {sql!r}"
            )
        return cls(), parsed.group("colname")


CustomTransformer.register(IdentityTransformer)
class Length1Transformer(CustomTransformer):
    """Custom transformer mapping a column to ``LENGTH(column)``.

    NULL inputs are mapped to a configurable default value.  The default
    value is embedded in the generated SQL and is read back from the SQL
    text when parsing.
    """

    # Transformer-type id; presumably the key CustomTransformer uses to find
    # this class again when parsing -- confirm against the base class.
    _CTID = "LEN1"

    # Value substituted for NULL when no default_value was given.
    _DEFAULT_VALUE_DEFAULT = -1

    # Mirrors the expression built by custom_compile_to_sql: captures the
    # column name and the (optionally negative) integer default, and requires
    # the same column name inside LENGTH(...) via a backreference.
    LEN1_BQSQL_RX = re.compile(
        r"^CASE WHEN (?P<colname>[a-z][a-z0-9_]*) IS NULL THEN (?P<defaultvalue>[-]?[0-9]+) ELSE LENGTH[(](?P=colname)[)] END$",
        flags=re.IGNORECASE,
    )

    def __init__(self, default_value: Optional[int] = None):
        """Store the value to emit for NULL inputs.

        Args:
            default_value: Replacement for NULL; ``None`` selects
                ``_DEFAULT_VALUE_DEFAULT`` at compile time.
        """
        self.default_value = default_value

    def custom_compile_to_sql(self, X: bpd.DataFrame, column: str) -> str:
        """Return the SQL ``CASE`` expression computing the column length.

        Args:
            X: The input dataframe (not used by this transformer).
            column: Name of the column to transform.

        Returns:
            A ``CASE WHEN ... IS NULL THEN <default> ELSE LENGTH(...) END``
            expression for ``column``.
        """
        if self.default_value is not None:
            null_replacement = self.default_value
        else:
            null_replacement = Length1Transformer._DEFAULT_VALUE_DEFAULT
        return (
            f"CASE WHEN {column} IS NULL THEN {null_replacement} ELSE LENGTH({column}) END"
        )

    @classmethod
    def custom_parse_from_sql(
        cls, config: Optional[Union[Dict, List]], sql: str
    ) -> tuple[CustomTransformer, str]:
        """Rebuild the transformer and column label from a SQL expression.

        Args:
            config: Persistent config, if any (not used by this transformer;
                the default value is taken from ``sql``).
            sql: SQL expression in the form produced by
                ``custom_compile_to_sql``.

        Returns:
            A ``(transformer, column_label)`` tuple; the transformer carries
            the integer default value found in ``sql``.

        Raises:
            ValueError: If ``sql`` does not have the expected form.
        """
        parsed = cls.LEN1_BQSQL_RX.match(sql)
        if parsed is None:
            # Fail with a readable message instead of dereferencing None.
            raise ValueError(
                f"Cannot parse {cls.__name__} from SQL expression: {sql!r}"
            )
        default_value = int(parsed.group("defaultvalue"))
        return cls(default_value), parsed.group("colname")


CustomTransformer.register(Length1Transformer)
class Length2Transformer(CustomTransformer):
    """Custom transformer mapping a column to ``LENGTH(column)``.

    NULL inputs are mapped to a configurable default value.  Unlike a
    transformer that re-reads its settings from the SQL text, this one
    exposes the default value through ``get_persistent_config`` and takes it
    from the ``config`` argument when parsing; only the column name is
    extracted from the SQL.
    """

    # Transformer-type id; presumably the key CustomTransformer uses to find
    # this class again when parsing -- confirm against the base class.
    _CTID = "LEN2"

    # Value substituted for NULL when no default_value was given.
    _DEFAULT_VALUE_DEFAULT = -1

    # Only the column name after "CASE WHEN " is captured; the remainder of
    # the expression is accepted as-is.
    LEN2_BQSQL_RX = re.compile(
        r"^CASE WHEN (?P<colname>[a-z][a-z0-9_]*) .*$", flags=re.IGNORECASE
    )

    def __init__(self, default_value: Optional[int] = None):
        """Store the value to emit for NULL inputs.

        Args:
            default_value: Replacement for NULL; ``None`` selects
                ``_DEFAULT_VALUE_DEFAULT`` at compile time.
        """
        self.default_value = default_value

    def get_persistent_config(self, column: str) -> Optional[Union[Dict, List]]:
        """Return the config for this transformer.

        Args:
            column: Name of the column being transformed (not used here).

        Returns:
            A one-element list holding ``default_value`` (possibly ``None``);
            ``custom_parse_from_sql`` reads it back from index 0.
        """
        return [self.default_value]

    def custom_compile_to_sql(self, X: bpd.DataFrame, column: str) -> str:
        """Return the SQL ``CASE`` expression computing the column length.

        Args:
            X: The input dataframe (not used by this transformer).
            column: Name of the column to transform.

        Returns:
            A ``CASE WHEN ... IS NULL THEN <default> ELSE LENGTH(...) END``
            expression for ``column``.
        """
        if self.default_value is not None:
            null_replacement = self.default_value
        else:
            null_replacement = Length2Transformer._DEFAULT_VALUE_DEFAULT
        return (
            f"CASE WHEN {column} IS NULL THEN {null_replacement} ELSE LENGTH({column}) END"
        )

    @classmethod
    def custom_parse_from_sql(
        cls, config: Optional[Union[Dict, List]], sql: str
    ) -> tuple[CustomTransformer, str]:
        """Rebuild the transformer and column label from SQL plus config.

        Args:
            config: Persistent config as returned by
                ``get_persistent_config``; index 0 holds the default value.
                If it is ``None`` or empty, the transformer is created with
                ``default_value=None``.
            sql: SQL expression starting with ``CASE WHEN <column> ``.

        Returns:
            A ``(transformer, column_label)`` tuple.

        Raises:
            ValueError: If ``sql`` does not have the expected form.
        """
        parsed = cls.LEN2_BQSQL_RX.match(sql)
        if parsed is None:
            # Fail with a readable message instead of dereferencing None.
            raise ValueError(
                f"Cannot parse {cls.__name__} from SQL expression: {sql!r}"
            )
        # get default value from persistent_config; config is declared
        # Optional, so tolerate a missing/empty one instead of crashing.
        default_value = config[0] if config else None
        return cls(default_value), parsed.group("colname")


CustomTransformer.register(Length2Transformer)