Skip to content

Commit 095bea2

Browse files
authored
fix: enforce bq SchemaField field_type and mode using feature value_type (#1019)
* samples: add feature store samples * fix: force bq has a data type for temp table before ingestion * Revert "samples: add feature store samples" This reverts commit 24ece4d. * fix: double to float64 * fix: add job_config for repeated data type * fix: remove print * fix: bq_schema and tests * fix: add unit tests for get_bq_schema and ic tests for string array ingestion validation * fix compat service init misplace fs version * fix: unit tests by adding assert for bq schema field mock
1 parent 09c2e8a commit 095bea2

File tree

5 files changed

+155
-28
lines changed

5 files changed

+155
-28
lines changed

google/cloud/aiplatform/compat/services/__init__.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@
8787
# v1
8888
dataset_service_client_v1,
8989
endpoint_service_client_v1,
90-
featurestore_online_serving_service_client_v1beta1,
91-
featurestore_service_client_v1beta1,
90+
featurestore_online_serving_service_client_v1,
91+
featurestore_service_client_v1,
9292
job_service_client_v1,
9393
metadata_service_client_v1,
9494
model_service_client_v1,
@@ -99,8 +99,8 @@
9999
# v1beta1
100100
dataset_service_client_v1beta1,
101101
endpoint_service_client_v1beta1,
102-
featurestore_online_serving_service_client_v1,
103-
featurestore_service_client_v1,
102+
featurestore_online_serving_service_client_v1beta1,
103+
featurestore_service_client_v1beta1,
104104
job_service_client_v1beta1,
105105
model_service_client_v1beta1,
106106
pipeline_service_client_v1beta1,

google/cloud/aiplatform/featurestore/entity_type.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -1238,6 +1238,17 @@ def ingest_from_df(
12381238
)
12391239

12401240
self.wait()
1241+
1242+
feature_source_fields = feature_source_fields or {}
1243+
bq_schema = []
1244+
for feature_id in feature_ids:
1245+
feature_field_name = feature_source_fields.get(feature_id, feature_id)
1246+
feature_value_type = self.get_feature(feature_id).to_dict()["valueType"]
1247+
bq_schema_field = self._get_bq_schema_field(
1248+
feature_field_name, feature_value_type
1249+
)
1250+
bq_schema.append(bq_schema_field)
1251+
12411252
entity_type_name_components = self._parse_resource_name(self.resource_name)
12421253
featurestore_id, entity_type_id = (
12431254
entity_type_name_components["featurestore"],
@@ -1260,8 +1271,20 @@ def ingest_from_df(
12601271
temp_bq_dataset = bigquery_client.create_dataset(temp_bq_dataset)
12611272

12621273
try:
1274+
1275+
parquet_options = bigquery.format_options.ParquetOptions()
1276+
parquet_options.enable_list_inference = True
1277+
1278+
job_config = bigquery.LoadJobConfig(
1279+
schema=bq_schema,
1280+
source_format=bigquery.SourceFormat.PARQUET,
1281+
parquet_options=parquet_options,
1282+
)
1283+
12631284
job = bigquery_client.load_table_from_dataframe(
1264-
dataframe=df_source, destination=temp_bq_table_id
1285+
dataframe=df_source,
1286+
destination=temp_bq_table_id,
1287+
job_config=job_config,
12651288
)
12661289
job.result()
12671290

@@ -1281,6 +1304,32 @@ def ingest_from_df(
12811304

12821305
return entity_type_obj
12831306

1307+
@staticmethod
1308+
def _get_bq_schema_field(
1309+
name: str, feature_value_type: str
1310+
) -> bigquery.SchemaField:
1311+
"""Helper method to get BigQuery Schema Field.
1312+
1313+
Args:
1314+
name (str):
1315+
Required. The name of the schema field, which can be either the feature_id,
1316+
or the field_name in BigQuery for the feature if different than the feature_id.
1317+
feature_value_type (str):
1318+
Required. The feature value_type.
1319+
1320+
Returns:
1321+
bigquery.SchemaField: bigquery.SchemaField
1322+
"""
1323+
bq_data_type = utils.featurestore_utils.FEATURE_STORE_VALUE_TYPE_TO_BQ_DATA_TYPE_MAP[
1324+
feature_value_type
1325+
]
1326+
bq_schema_field = bigquery.SchemaField(
1327+
name=name,
1328+
field_type=bq_data_type["field_type"],
1329+
mode=bq_data_type.get("mode") or "NULLABLE",
1330+
)
1331+
return bq_schema_field
1332+
12841333
@staticmethod
12851334
def _instantiate_featurestore_online_client(
12861335
location: Optional[str] = None,

google/cloud/aiplatform/utils/featurestore_utils.py

+12
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,18 @@
3333

3434
_FEATURE_VALUE_TYPE_UNSPECIFIED = "VALUE_TYPE_UNSPECIFIED"
3535

36+
FEATURE_STORE_VALUE_TYPE_TO_BQ_DATA_TYPE_MAP = {
37+
"BOOL": {"field_type": "BOOL"},
38+
"BOOL_ARRAY": {"field_type": "BOOL", "mode": "REPEATED"},
39+
"DOUBLE": {"field_type": "FLOAT64"},
40+
"DOUBLE_ARRAY": {"field_type": "FLOAT64", "mode": "REPEATED"},
41+
"INT64": {"field_type": "INT64"},
42+
"INT64_ARRAY": {"field_type": "INT64", "mode": "REPEATED"},
43+
"STRING": {"field_type": "STRING"},
44+
"STRING_ARRAY": {"field_type": "STRING", "mode": "REPEATED"},
45+
"BYTES": {"field_type": "BYTES"},
46+
}
47+
3648

3749
def validate_id(resource_id: str) -> None:
3850
"""Validates feature store resource ID pattern.

tests/system/aiplatform/test_featurestore.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def test_batch_create_features(self, shared_state):
219219

220220
movie_feature_configs = {
221221
_TEST_MOVIE_TITLE_FEATURE_ID: {"value_type": "STRING"},
222-
_TEST_MOVIE_GENRES_FEATURE_ID: {"value_type": "STRING"},
222+
_TEST_MOVIE_GENRES_FEATURE_ID: {"value_type": "STRING_ARRAY"},
223223
_TEST_MOVIE_AVERAGE_RATING_FEATURE_ID: {"value_type": "DOUBLE"},
224224
}
225225

@@ -277,14 +277,14 @@ def test_ingest_feature_values_from_df_using_feature_time_column_and_online_read
277277
"movie_id": "movie_01",
278278
"average_rating": 4.9,
279279
"title": "The Shawshank Redemption",
280-
"genres": "Drama",
280+
"genres": ["Drama"],
281281
"update_time": "2021-08-20 20:44:11.094375+00:00",
282282
},
283283
{
284284
"movie_id": "movie_02",
285285
"average_rating": 4.2,
286286
"title": "The Shining",
287-
"genres": "Horror",
287+
"genres": ["Horror"],
288288
"update_time": "2021-08-20 20:44:11.094375+00:00",
289289
},
290290
],
@@ -312,13 +312,13 @@ def test_ingest_feature_values_from_df_using_feature_time_column_and_online_read
312312
"movie_id": "movie_01",
313313
"average_rating": 4.9,
314314
"title": "The Shawshank Redemption",
315-
"genres": "Drama",
315+
"genres": ["Drama"],
316316
},
317317
{
318318
"movie_id": "movie_02",
319319
"average_rating": 4.2,
320320
"title": "The Shining",
321-
"genres": "Horror",
321+
"genres": ["Horror"],
322322
},
323323
]
324324
expected_movie_entity_views_df_after_ingest = pd.DataFrame(
@@ -350,13 +350,13 @@ def test_ingest_feature_values_from_df_using_feature_time_datetime_and_online_re
350350
"movie_id": "movie_03",
351351
"average_rating": 4.5,
352352
"title": "Cinema Paradiso",
353-
"genres": "Romance",
353+
"genres": ["Romance"],
354354
},
355355
{
356356
"movie_id": "movie_04",
357357
"average_rating": 4.6,
358358
"title": "The Dark Knight",
359-
"genres": "Action",
359+
"genres": ["Action"],
360360
},
361361
],
362362
columns=["movie_id", "average_rating", "title", "genres"],

tests/unit/aiplatform/test_featurestores.py

+82-16
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@
114114
}
115115

116116
_TEST_FEATURE_VALUE_TYPE = _TEST_INT_TYPE
117+
_TEST_FEATURE_VALUE_TYPE_BQ_FIELD_TYPE = "INT64"
118+
_TEST_FEATURE_VALUE_TYPE_BQ_MODE = "NULLABLE"
117119

118120
_ARRAY_FEATURE_VALUE_TYPE_TO_GCA_TYPE_MAP = {
119121
_TEST_BOOL_ARR_TYPE: gca_types.BoolArray,
@@ -211,6 +213,9 @@
211213
"my_feature_id_1": {"value_type": _TEST_FEATURE_VALUE_TYPE_STR},
212214
}
213215

216+
_TEST_IMPORTING_FEATURE_ID = "my_feature_id_1"
217+
_TEST_IMPORTING_FEATURE_SOURCE_FIELD = "my_feature_id_1_source_field"
218+
214219
_TEST_IMPORTING_FEATURE_IDS = ["my_feature_id_1"]
215220

216221
_TEST_IMPORTING_FEATURE_SOURCE_FIELDS = {
@@ -363,22 +368,22 @@ def bq_init_dataset_mock(bq_dataset_mock):
363368

364369

365370
@pytest.fixture
366-
def bq_create_dataset_mock(bq_init_client_mock):
367-
with patch.object(bigquery.Client, "create_dataset") as bq_create_dataset_mock:
371+
def bq_create_dataset_mock(bq_client_mock):
372+
with patch.object(bq_client_mock, "create_dataset") as bq_create_dataset_mock:
368373
yield bq_create_dataset_mock
369374

370375

371376
@pytest.fixture
372-
def bq_load_table_from_dataframe_mock(bq_init_client_mock):
377+
def bq_load_table_from_dataframe_mock(bq_client_mock):
373378
with patch.object(
374-
bigquery.Client, "load_table_from_dataframe"
379+
bq_client_mock, "load_table_from_dataframe"
375380
) as bq_load_table_from_dataframe_mock:
376381
yield bq_load_table_from_dataframe_mock
377382

378383

379384
@pytest.fixture
380-
def bq_delete_dataset_mock(bq_init_client_mock):
381-
with patch.object(bigquery.Client, "delete_dataset") as bq_delete_dataset_mock:
385+
def bq_delete_dataset_mock(bq_client_mock):
386+
with patch.object(bq_client_mock, "delete_dataset") as bq_delete_dataset_mock:
382387
yield bq_delete_dataset_mock
383388

384389

@@ -396,16 +401,29 @@ def bqs_init_client_mock(bqs_client_mock):
396401

397402

398403
@pytest.fixture
399-
def bqs_create_read_session(bqs_init_client_mock):
404+
def bqs_create_read_session(bqs_client_mock):
400405
with patch.object(
401-
bigquery_storage.BigQueryReadClient, "create_read_session"
406+
bqs_client_mock, "create_read_session"
402407
) as bqs_create_read_session:
403408
read_session_proto = gcbqs_stream.ReadSession()
404409
read_session_proto.streams = [gcbqs_stream.ReadStream()]
405410
bqs_create_read_session.return_value = read_session_proto
406411
yield bqs_create_read_session
407412

408413

414+
@pytest.fixture
415+
def bq_schema_field_mock():
416+
mock = MagicMock(bigquery.SchemaField)
417+
yield mock
418+
419+
420+
@pytest.fixture
421+
def bq_init_schema_field_mock(bq_schema_field_mock):
422+
with patch.object(bigquery, "SchemaField") as bq_init_schema_field_mock:
423+
bq_init_schema_field_mock.return_value = bq_schema_field_mock
424+
yield bq_init_schema_field_mock
425+
426+
409427
# All Featurestore Mocks
410428
@pytest.fixture
411429
def get_featurestore_mock():
@@ -1672,14 +1690,19 @@ def test_ingest_from_gcs_with_invalid_gcs_source_type(self):
16721690

16731691
@pytest.mark.usefixtures(
16741692
"get_entity_type_mock",
1693+
"get_feature_mock",
16751694
"bq_init_client_mock",
16761695
"bq_init_dataset_mock",
16771696
"bq_create_dataset_mock",
1678-
"bq_load_table_from_dataframe_mock",
16791697
"bq_delete_dataset_mock",
16801698
)
16811699
@patch("uuid.uuid4", uuid_mock)
1682-
def test_ingest_from_df_using_column(self, import_feature_values_mock):
1700+
def test_ingest_from_df_using_column(
1701+
self,
1702+
import_feature_values_mock,
1703+
bq_load_table_from_dataframe_mock,
1704+
bq_init_schema_field_mock,
1705+
):
16831706

16841707
aiplatform.init(project=_TEST_PROJECT)
16851708

@@ -1701,7 +1724,7 @@ def test_ingest_from_df_using_column(self, import_feature_values_mock):
17011724
f"{expecte_temp_bq_dataset_id}.{_TEST_ENTITY_TYPE_ID}"
17021725
)
17031726

1704-
true_import_feature_values_request = gca_featurestore_service.ImportFeatureValuesRequest(
1727+
expected_import_feature_values_request = gca_featurestore_service.ImportFeatureValuesRequest(
17051728
entity_type=_TEST_ENTITY_TYPE_NAME,
17061729
feature_specs=[
17071730
gca_featurestore_service.ImportFeatureValuesRequest.FeatureSpec(
@@ -1714,20 +1737,32 @@ def test_ingest_from_df_using_column(self, import_feature_values_mock):
17141737
feature_time_field=_TEST_FEATURE_TIME_FIELD,
17151738
)
17161739

1740+
bq_init_schema_field_mock.assert_called_once_with(
1741+
name=_TEST_IMPORTING_FEATURE_SOURCE_FIELD,
1742+
field_type=_TEST_FEATURE_VALUE_TYPE_BQ_FIELD_TYPE,
1743+
mode=_TEST_FEATURE_VALUE_TYPE_BQ_MODE,
1744+
)
1745+
17171746
import_feature_values_mock.assert_called_once_with(
1718-
request=true_import_feature_values_request, metadata=_TEST_REQUEST_METADATA,
1747+
request=expected_import_feature_values_request,
1748+
metadata=_TEST_REQUEST_METADATA,
17191749
)
17201750

17211751
@pytest.mark.usefixtures(
17221752
"get_entity_type_mock",
1753+
"get_feature_mock",
17231754
"bq_init_client_mock",
17241755
"bq_init_dataset_mock",
17251756
"bq_create_dataset_mock",
1726-
"bq_load_table_from_dataframe_mock",
17271757
"bq_delete_dataset_mock",
17281758
)
17291759
@patch("uuid.uuid4", uuid_mock)
1730-
def test_ingest_from_df_using_datetime(self, import_feature_values_mock):
1760+
def test_ingest_from_df_using_datetime(
1761+
self,
1762+
import_feature_values_mock,
1763+
bq_load_table_from_dataframe_mock,
1764+
bq_init_schema_field_mock,
1765+
):
17311766
aiplatform.init(project=_TEST_PROJECT)
17321767

17331768
my_entity_type = aiplatform.EntityType(entity_type_name=_TEST_ENTITY_TYPE_NAME)
@@ -1752,7 +1787,7 @@ def test_ingest_from_df_using_datetime(self, import_feature_values_mock):
17521787
timestamp_proto = timestamp_pb2.Timestamp()
17531788
timestamp_proto.FromDatetime(_TEST_FEATURE_TIME_DATETIME)
17541789

1755-
true_import_feature_values_request = gca_featurestore_service.ImportFeatureValuesRequest(
1790+
expected_import_feature_values_request = gca_featurestore_service.ImportFeatureValuesRequest(
17561791
entity_type=_TEST_ENTITY_TYPE_NAME,
17571792
feature_specs=[
17581793
gca_featurestore_service.ImportFeatureValuesRequest.FeatureSpec(
@@ -1765,8 +1800,39 @@ def test_ingest_from_df_using_datetime(self, import_feature_values_mock):
17651800
feature_time=timestamp_proto,
17661801
)
17671802

1803+
bq_init_schema_field_mock.assert_called_once_with(
1804+
name=_TEST_IMPORTING_FEATURE_SOURCE_FIELD,
1805+
field_type=_TEST_FEATURE_VALUE_TYPE_BQ_FIELD_TYPE,
1806+
mode=_TEST_FEATURE_VALUE_TYPE_BQ_MODE,
1807+
)
1808+
17681809
import_feature_values_mock.assert_called_once_with(
1769-
request=true_import_feature_values_request, metadata=_TEST_REQUEST_METADATA,
1810+
request=expected_import_feature_values_request,
1811+
metadata=_TEST_REQUEST_METADATA,
1812+
)
1813+
1814+
@pytest.mark.parametrize(
1815+
"feature_value_type, expected_field_type, expected_mode",
1816+
[
1817+
("BOOL", "BOOL", "NULLABLE"),
1818+
("BOOL_ARRAY", "BOOL", "REPEATED"),
1819+
("DOUBLE", "FLOAT64", "NULLABLE"),
1820+
("DOUBLE_ARRAY", "FLOAT64", "REPEATED"),
1821+
("INT64", "INT64", "NULLABLE"),
1822+
("INT64_ARRAY", "INT64", "REPEATED"),
1823+
("STRING", "STRING", "NULLABLE"),
1824+
("STRING_ARRAY", "STRING", "REPEATED"),
1825+
("BYTES", "BYTES", "NULLABLE"),
1826+
],
1827+
)
1828+
def test_get_bq_schema_field(
1829+
self, feature_value_type, expected_field_type, expected_mode
1830+
):
1831+
expected_bq_schema_field = bigquery.SchemaField(
1832+
name=_TEST_FEATURE_ID, field_type=expected_field_type, mode=expected_mode,
1833+
)
1834+
assert expected_bq_schema_field == aiplatform.EntityType._get_bq_schema_field(
1835+
name=_TEST_FEATURE_ID, feature_value_type=feature_value_type
17701836
)
17711837

17721838
@pytest.mark.usefixtures("get_entity_type_mock", "get_feature_mock")

0 commit comments

Comments
 (0)