Skip to content

Commit 5ceed05

Browse files
SalemJordenSalem Boylandtswast
authored
feat: add Model.transform_columns property (#1661)
--------- Co-authored-by: Salem Boyland <[email protected]> Co-authored-by: Tim Swast <[email protected]>
1 parent faa50b9 commit 5ceed05

File tree

3 files changed

+140
-1
lines changed

3 files changed

+140
-1
lines changed

google/cloud/bigquery/model.py

+71
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
"""Define resources for the BigQuery ML Models API."""
1818

19+
from __future__ import annotations # type: ignore
20+
1921
import copy
2022
import datetime
2123
import typing
@@ -184,6 +186,21 @@ def feature_columns(self) -> Sequence[standard_sql.StandardSqlField]:
184186
standard_sql.StandardSqlField.from_api_repr(column) for column in resource
185187
]
186188

189+
@property
190+
def transform_columns(self) -> Sequence[TransformColumn]:
191+
"""The input feature columns that were used to train this model.
192+
The output transform columns used to train this model.
193+
194+
See REST API:
195+
https://cloud.google.com/bigquery/docs/reference/rest/v2/models#transformcolumn
196+
197+
Read-only.
198+
"""
199+
resources: Sequence[Dict[str, Any]] = typing.cast(
200+
Sequence[Dict[str, Any]], self._properties.get("transformColumns", [])
201+
)
202+
return [TransformColumn(resource) for resource in resources]
203+
187204
@property
188205
def label_columns(self) -> Sequence[standard_sql.StandardSqlField]:
189206
"""Label columns that were used to train this model.
@@ -434,6 +451,60 @@ def __repr__(self):
434451
)
435452

436453

454+
class TransformColumn:
455+
"""TransformColumn represents a transform column feature.
456+
457+
See
458+
https://cloud.google.com/bigquery/docs/reference/rest/v2/models#transformcolumn
459+
460+
Args:
461+
resource:
462+
A dictionary representing a transform column feature.
463+
"""
464+
465+
def __init__(self, resource: Dict[str, Any]):
466+
self._properties = resource
467+
468+
@property
469+
def name(self) -> Optional[str]:
470+
"""Name of the column."""
471+
return self._properties.get("name")
472+
473+
@property
474+
def type_(self) -> Optional[standard_sql.StandardSqlDataType]:
475+
"""Data type of the column after the transform.
476+
477+
Returns:
478+
Optional[google.cloud.bigquery.standard_sql.StandardSqlDataType]:
479+
Data type of the column.
480+
"""
481+
type_json = self._properties.get("type")
482+
if type_json is None:
483+
return None
484+
return standard_sql.StandardSqlDataType.from_api_repr(type_json)
485+
486+
@property
487+
def transform_sql(self) -> Optional[str]:
488+
"""The SQL expression used in the column transform."""
489+
return self._properties.get("transformSql")
490+
491+
@classmethod
492+
def from_api_repr(cls, resource: Dict[str, Any]) -> "TransformColumn":
493+
"""Constructs a transform column feature given its API representation
494+
495+
Args:
496+
resource:
497+
Transform column feature representation from the API
498+
499+
Returns:
500+
Transform column feature parsed from ``resource``.
501+
"""
502+
this = cls({})
503+
resource = copy.deepcopy(resource)
504+
this._properties = resource
505+
return this
506+
507+
437508
def _model_arg_to_model_ref(value, default_project=None):
438509
"""Helper to convert a string or Model to ModelReference.
439510

mypy.ini

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[mypy]
2-
python_version = 3.6
2+
python_version = 3.8
33
namespace_packages = True

tests/unit/model/test_model.py

+68
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818

1919
import pytest
2020

21+
2122
import google.cloud._helpers
23+
import google.cloud.bigquery.model
2224

2325
KMS_KEY_NAME = "projects/1/locations/us/keyRings/1/cryptoKeys/1"
2426

@@ -136,6 +138,7 @@ def test_from_api_repr(target_class):
136138
google.cloud._helpers._rfc3339_to_datetime(got.training_runs[2]["startTime"])
137139
== expiration_time
138140
)
141+
assert got.transform_columns == []
139142

140143

141144
def test_from_api_repr_w_minimal_resource(target_class):
@@ -293,6 +296,71 @@ def test_feature_columns(object_under_test):
293296
assert object_under_test.feature_columns == expected
294297

295298

299+
def test_from_api_repr_w_transform_columns(target_class):
300+
resource = {
301+
"modelReference": {
302+
"projectId": "my-project",
303+
"datasetId": "my_dataset",
304+
"modelId": "my_model",
305+
},
306+
"transformColumns": [
307+
{
308+
"name": "transform_name",
309+
"type": {"typeKind": "INT64"},
310+
"transformSql": "transform_sql",
311+
}
312+
],
313+
}
314+
got = target_class.from_api_repr(resource)
315+
assert len(got.transform_columns) == 1
316+
transform_column = got.transform_columns[0]
317+
assert isinstance(transform_column, google.cloud.bigquery.model.TransformColumn)
318+
assert transform_column.name == "transform_name"
319+
320+
321+
def test_transform_column_name():
322+
transform_columns = google.cloud.bigquery.model.TransformColumn(
323+
{"name": "is_female"}
324+
)
325+
assert transform_columns.name == "is_female"
326+
327+
328+
def test_transform_column_transform_sql():
329+
transform_columns = google.cloud.bigquery.model.TransformColumn(
330+
{"transformSql": "is_female"}
331+
)
332+
assert transform_columns.transform_sql == "is_female"
333+
334+
335+
def test_transform_column_type():
336+
transform_columns = google.cloud.bigquery.model.TransformColumn(
337+
{"type": {"typeKind": "BOOL"}}
338+
)
339+
assert transform_columns.type_.type_kind == "BOOL"
340+
341+
342+
def test_transform_column_type_none():
343+
transform_columns = google.cloud.bigquery.model.TransformColumn({})
344+
assert transform_columns.type_ is None
345+
346+
347+
def test_transform_column_from_api_repr_with_unknown_properties():
348+
transform_column = google.cloud.bigquery.model.TransformColumn.from_api_repr(
349+
{
350+
"name": "is_female",
351+
"type": {"typeKind": "BOOL"},
352+
"transformSql": "is_female",
353+
"test": "one",
354+
}
355+
)
356+
assert transform_column._properties == {
357+
"name": "is_female",
358+
"type": {"typeKind": "BOOL"},
359+
"transformSql": "is_female",
360+
"test": "one",
361+
}
362+
363+
296364
def test_label_columns(object_under_test):
297365
from google.cloud.bigquery import standard_sql
298366

0 commit comments

Comments
 (0)