114
114
}
115
115
116
116
_TEST_FEATURE_VALUE_TYPE = _TEST_INT_TYPE
117
+ _TEST_FEATURE_VALUE_TYPE_BQ_FIELD_TYPE = "INT64"
118
+ _TEST_FEATURE_VALUE_TYPE_BQ_MODE = "NULLABLE"
117
119
118
120
_ARRAY_FEATURE_VALUE_TYPE_TO_GCA_TYPE_MAP = {
119
121
_TEST_BOOL_ARR_TYPE : gca_types .BoolArray ,
211
213
"my_feature_id_1" : {"value_type" : _TEST_FEATURE_VALUE_TYPE_STR },
212
214
}
213
215
216
+ _TEST_IMPORTING_FEATURE_ID = "my_feature_id_1"
217
+ _TEST_IMPORTING_FEATURE_SOURCE_FIELD = "my_feature_id_1_source_field"
218
+
214
219
_TEST_IMPORTING_FEATURE_IDS = ["my_feature_id_1" ]
215
220
216
221
_TEST_IMPORTING_FEATURE_SOURCE_FIELDS = {
@@ -363,22 +368,22 @@ def bq_init_dataset_mock(bq_dataset_mock):
363
368
364
369
365
370
@pytest .fixture
366
- def bq_create_dataset_mock (bq_init_client_mock ):
367
- with patch .object (bigquery . Client , "create_dataset" ) as bq_create_dataset_mock :
371
+ def bq_create_dataset_mock (bq_client_mock ):
372
+ with patch .object (bq_client_mock , "create_dataset" ) as bq_create_dataset_mock :
368
373
yield bq_create_dataset_mock
369
374
370
375
371
376
@pytest .fixture
372
- def bq_load_table_from_dataframe_mock (bq_init_client_mock ):
377
+ def bq_load_table_from_dataframe_mock (bq_client_mock ):
373
378
with patch .object (
374
- bigquery . Client , "load_table_from_dataframe"
379
+ bq_client_mock , "load_table_from_dataframe"
375
380
) as bq_load_table_from_dataframe_mock :
376
381
yield bq_load_table_from_dataframe_mock
377
382
378
383
379
384
@pytest .fixture
380
- def bq_delete_dataset_mock (bq_init_client_mock ):
381
- with patch .object (bigquery . Client , "delete_dataset" ) as bq_delete_dataset_mock :
385
+ def bq_delete_dataset_mock (bq_client_mock ):
386
+ with patch .object (bq_client_mock , "delete_dataset" ) as bq_delete_dataset_mock :
382
387
yield bq_delete_dataset_mock
383
388
384
389
@@ -396,16 +401,29 @@ def bqs_init_client_mock(bqs_client_mock):
396
401
397
402
398
403
@pytest .fixture
399
- def bqs_create_read_session (bqs_init_client_mock ):
404
+ def bqs_create_read_session (bqs_client_mock ):
400
405
with patch .object (
401
- bigquery_storage . BigQueryReadClient , "create_read_session"
406
+ bqs_client_mock , "create_read_session"
402
407
) as bqs_create_read_session :
403
408
read_session_proto = gcbqs_stream .ReadSession ()
404
409
read_session_proto .streams = [gcbqs_stream .ReadStream ()]
405
410
bqs_create_read_session .return_value = read_session_proto
406
411
yield bqs_create_read_session
407
412
408
413
414
+ @pytest .fixture
415
+ def bq_schema_field_mock ():
416
+ mock = MagicMock (bigquery .SchemaField )
417
+ yield mock
418
+
419
+
420
+ @pytest .fixture
421
+ def bq_init_schema_field_mock (bq_schema_field_mock ):
422
+ with patch .object (bigquery , "SchemaField" ) as bq_init_schema_field_mock :
423
+ bq_init_schema_field_mock .return_value = bq_schema_field_mock
424
+ yield bq_init_schema_field_mock
425
+
426
+
409
427
# All Featurestore Mocks
410
428
@pytest .fixture
411
429
def get_featurestore_mock ():
@@ -1672,14 +1690,19 @@ def test_ingest_from_gcs_with_invalid_gcs_source_type(self):
1672
1690
1673
1691
@pytest .mark .usefixtures (
1674
1692
"get_entity_type_mock" ,
1693
+ "get_feature_mock" ,
1675
1694
"bq_init_client_mock" ,
1676
1695
"bq_init_dataset_mock" ,
1677
1696
"bq_create_dataset_mock" ,
1678
- "bq_load_table_from_dataframe_mock" ,
1679
1697
"bq_delete_dataset_mock" ,
1680
1698
)
1681
1699
@patch ("uuid.uuid4" , uuid_mock )
1682
- def test_ingest_from_df_using_column (self , import_feature_values_mock ):
1700
+ def test_ingest_from_df_using_column (
1701
+ self ,
1702
+ import_feature_values_mock ,
1703
+ bq_load_table_from_dataframe_mock ,
1704
+ bq_init_schema_field_mock ,
1705
+ ):
1683
1706
1684
1707
aiplatform .init (project = _TEST_PROJECT )
1685
1708
@@ -1701,7 +1724,7 @@ def test_ingest_from_df_using_column(self, import_feature_values_mock):
1701
1724
f"{ expecte_temp_bq_dataset_id } .{ _TEST_ENTITY_TYPE_ID } "
1702
1725
)
1703
1726
1704
- true_import_feature_values_request = gca_featurestore_service .ImportFeatureValuesRequest (
1727
+ expected_import_feature_values_request = gca_featurestore_service .ImportFeatureValuesRequest (
1705
1728
entity_type = _TEST_ENTITY_TYPE_NAME ,
1706
1729
feature_specs = [
1707
1730
gca_featurestore_service .ImportFeatureValuesRequest .FeatureSpec (
@@ -1714,20 +1737,32 @@ def test_ingest_from_df_using_column(self, import_feature_values_mock):
1714
1737
feature_time_field = _TEST_FEATURE_TIME_FIELD ,
1715
1738
)
1716
1739
1740
+ bq_init_schema_field_mock .assert_called_once_with (
1741
+ name = _TEST_IMPORTING_FEATURE_SOURCE_FIELD ,
1742
+ field_type = _TEST_FEATURE_VALUE_TYPE_BQ_FIELD_TYPE ,
1743
+ mode = _TEST_FEATURE_VALUE_TYPE_BQ_MODE ,
1744
+ )
1745
+
1717
1746
import_feature_values_mock .assert_called_once_with (
1718
- request = true_import_feature_values_request , metadata = _TEST_REQUEST_METADATA ,
1747
+ request = expected_import_feature_values_request ,
1748
+ metadata = _TEST_REQUEST_METADATA ,
1719
1749
)
1720
1750
1721
1751
@pytest .mark .usefixtures (
1722
1752
"get_entity_type_mock" ,
1753
+ "get_feature_mock" ,
1723
1754
"bq_init_client_mock" ,
1724
1755
"bq_init_dataset_mock" ,
1725
1756
"bq_create_dataset_mock" ,
1726
- "bq_load_table_from_dataframe_mock" ,
1727
1757
"bq_delete_dataset_mock" ,
1728
1758
)
1729
1759
@patch ("uuid.uuid4" , uuid_mock )
1730
- def test_ingest_from_df_using_datetime (self , import_feature_values_mock ):
1760
+ def test_ingest_from_df_using_datetime (
1761
+ self ,
1762
+ import_feature_values_mock ,
1763
+ bq_load_table_from_dataframe_mock ,
1764
+ bq_init_schema_field_mock ,
1765
+ ):
1731
1766
aiplatform .init (project = _TEST_PROJECT )
1732
1767
1733
1768
my_entity_type = aiplatform .EntityType (entity_type_name = _TEST_ENTITY_TYPE_NAME )
@@ -1752,7 +1787,7 @@ def test_ingest_from_df_using_datetime(self, import_feature_values_mock):
1752
1787
timestamp_proto = timestamp_pb2 .Timestamp ()
1753
1788
timestamp_proto .FromDatetime (_TEST_FEATURE_TIME_DATETIME )
1754
1789
1755
- true_import_feature_values_request = gca_featurestore_service .ImportFeatureValuesRequest (
1790
+ expected_import_feature_values_request = gca_featurestore_service .ImportFeatureValuesRequest (
1756
1791
entity_type = _TEST_ENTITY_TYPE_NAME ,
1757
1792
feature_specs = [
1758
1793
gca_featurestore_service .ImportFeatureValuesRequest .FeatureSpec (
@@ -1765,8 +1800,39 @@ def test_ingest_from_df_using_datetime(self, import_feature_values_mock):
1765
1800
feature_time = timestamp_proto ,
1766
1801
)
1767
1802
1803
+ bq_init_schema_field_mock .assert_called_once_with (
1804
+ name = _TEST_IMPORTING_FEATURE_SOURCE_FIELD ,
1805
+ field_type = _TEST_FEATURE_VALUE_TYPE_BQ_FIELD_TYPE ,
1806
+ mode = _TEST_FEATURE_VALUE_TYPE_BQ_MODE ,
1807
+ )
1808
+
1768
1809
import_feature_values_mock .assert_called_once_with (
1769
- request = true_import_feature_values_request , metadata = _TEST_REQUEST_METADATA ,
1810
+ request = expected_import_feature_values_request ,
1811
+ metadata = _TEST_REQUEST_METADATA ,
1812
+ )
1813
+
1814
+ @pytest .mark .parametrize (
1815
+ "feature_value_type, expected_field_type, expected_mode" ,
1816
+ [
1817
+ ("BOOL" , "BOOL" , "NULLABLE" ),
1818
+ ("BOOL_ARRAY" , "BOOL" , "REPEATED" ),
1819
+ ("DOUBLE" , "FLOAT64" , "NULLABLE" ),
1820
+ ("DOUBLE_ARRAY" , "FLOAT64" , "REPEATED" ),
1821
+ ("INT64" , "INT64" , "NULLABLE" ),
1822
+ ("INT64_ARRAY" , "INT64" , "REPEATED" ),
1823
+ ("STRING" , "STRING" , "NULLABLE" ),
1824
+ ("STRING_ARRAY" , "STRING" , "REPEATED" ),
1825
+ ("BYTES" , "BYTES" , "NULLABLE" ),
1826
+ ],
1827
+ )
1828
+ def test_get_bq_schema_field (
1829
+ self , feature_value_type , expected_field_type , expected_mode
1830
+ ):
1831
+ expected_bq_schema_field = bigquery .SchemaField (
1832
+ name = _TEST_FEATURE_ID , field_type = expected_field_type , mode = expected_mode ,
1833
+ )
1834
+ assert expected_bq_schema_field == aiplatform .EntityType ._get_bq_schema_field (
1835
+ name = _TEST_FEATURE_ID , feature_value_type = feature_value_type
1770
1836
)
1771
1837
1772
1838
@pytest .mark .usefixtures ("get_entity_type_mock" , "get_feature_mock" )
0 commit comments