|
import abc
from dataclasses import dataclass, field
import enum
from typing import Any, Dict, List, Optional

from google.cloud.aiplatform.compat.types import (
    feature_online_store_service as fos_service,
)
import proto
from typing_extensions import override
27 | 27 |
|
28 | 28 |
|
def get_feature_online_store_name(online_store_name: str) -> str:
    """Return the Feature Online Store segment of a FeatureView resource name.

    Args:
        online_store_name: Full FeatureView resource name, of the form
            projects/project_number/locations/us-central1/
            featureOnlineStores/fos_name/featureViews/fv_name

    Returns:
        str: The sixth "/"-separated segment of the name (``fos_name`` in
        the form shown above).
    """
    # Index 5 is the segment that follows "featureOnlineStores".
    return online_store_name.split("/")[5]
41 | 41 |
|
42 | 42 |
|
class PublicEndpointNotFoundError(RuntimeError):
    """Signals that the public endpoint has not been created yet."""
45 | 45 |
|
46 | 46 |
|
@dataclass
class FeatureViewBigQuerySource:
    """Pairs a source URI with the names of its entity ID columns."""

    # Source URI (presumably a BigQuery table/view URI, going by the class name).
    uri: str
    # Names of the entity ID columns; required, no default.
    entity_id_columns: List[str]
51 | 51 |
|
52 | 52 |
|
@dataclass
class FeatureViewReadResponse:
    """Thin wrapper around a ``FetchFeatureValuesResponse`` proto."""

    _response: fos_service.FetchFeatureValuesResponse

    # An explicit __init__ is defined here, so callers pass the proto as
    # ``response`` rather than ``_response``.
    def __init__(self, response: fos_service.FetchFeatureValuesResponse):
        """Store the response proto being wrapped."""
        self._response = response

    def to_dict(self) -> Dict[str, Any]:
        """Return the response's ``key_values`` converted to a dictionary."""
        key_values = self._response.key_values
        return proto.Message.to_dict(key_values)

    def to_proto(self) -> fos_service.FetchFeatureValuesResponse:
        """Return the wrapped response proto unchanged."""
        return self._response
65 | 65 |
|
66 | 66 |
|
@dataclass
class SearchNearestEntitiesResponse:
    """Thin wrapper around a ``SearchNearestEntitiesResponse`` proto."""

    _response: fos_service.SearchNearestEntitiesResponse

    # An explicit __init__ is defined here, so callers pass the proto as
    # ``response`` rather than ``_response``.
    def __init__(self, response: fos_service.SearchNearestEntitiesResponse):
        """Store the response proto being wrapped."""
        self._response = response

    def to_dict(self) -> Dict[str, Any]:
        """Return the response's ``nearest_neighbors`` as a dictionary."""
        neighbors = self._response.nearest_neighbors
        return proto.Message.to_dict(neighbors)

    def to_proto(self) -> fos_service.SearchNearestEntitiesResponse:
        """Return the wrapped response proto unchanged."""
        return self._response
79 | 79 |
|
80 | 80 |
|
class DistanceMeasureType(enum.Enum):
    """Distance measures available for nearest neighbor search."""

    # No distance measure specified.
    DISTANCE_MEASURE_TYPE_UNSPECIFIED = 0

    # Euclidean (L_2) distance.
    SQUARED_L2_DISTANCE = 1

    # Cosine distance, defined as 1 - cosine similarity.
    COSINE_DISTANCE = 2

    # Dot product distance, defined as the negative of the dot product.
    DOT_PRODUCT_DISTANCE = 3
91 | 91 |
|
92 | 92 |
|
class AlgorithmConfig(abc.ABC):
    """Base class for configuration options for matching algorithm."""

    # Declared abstract so that a subclass which forgets to override it fails
    # at instantiation instead of silently returning None from as_dict().
    @abc.abstractmethod
    def as_dict(self) -> Dict[str, Any]:
        """Returns the configuration as a dictionary.

        Concrete algorithm configurations must override this method.

        Returns:
            Dict[str, Any]
        """
103 | 103 |
|
104 | 104 |
|
@dataclass
class TreeAhConfig(AlgorithmConfig):
    """Options for the tree-AH algorithm (Shallow tree + Asymmetric Hashing).

    Please refer to this paper for more details: https://arxiv.org/abs/1908.10396

    Args:
        leaf_node_embedding_count (int): Optional. Number of embeddings on
            each leaf node. The default value is 1000 if not set.
    """

    leaf_node_embedding_count: Optional[int] = None

    @override
    def as_dict(self) -> Dict:
        """Return the tree-AH options as a dictionary keyed by field name."""
        count = self.leaf_node_embedding_count
        return {"leaf_node_embedding_count": count}
120 | 121 |
|
121 | 122 |
|
@dataclass
class BruteForceConfig(AlgorithmConfig):
    """Options for brute force search.

    Brute force search simply performs the standard linear search in the
    database for each query.
    """

    @override
    def as_dict(self) -> Dict[str, Any]:
        """Return a dictionary with a single, empty ``bruteForceConfig`` entry."""
        empty_options: Dict[str, Any] = {}
        return {"bruteForceConfig": empty_options}
132 | 134 |
|
133 | 135 |
|
@dataclass
class IndexConfig:
    """Configuration options for the Vertex FeatureView for embedding.

    Attributes:
        embedding_column: Emitted as "embedding_column" by `as_dict`.
        dimensions: Emitted as "embedding_dimension" by `as_dict`.
        algorithm_config: Matching algorithm options. Defaults to a new
            `TreeAhConfig` for each instance.
        filter_columns: Optional. Emitted as "filter_columns" when set.
        crowding_column: Optional. Emitted as "crowding_column" when set.
        distance_measure_type: Optional. Its enum value is emitted as
            "distance_measure_type" when set.
    """

    embedding_column: str
    dimensions: int
    # A default_factory is required here: a plain `TreeAhConfig()` default is
    # an unhashable dataclass instance, which Python 3.11+ rejects with
    # ValueError at class-definition time and which older versions would
    # share between every IndexConfig instance.
    algorithm_config: AlgorithmConfig = field(default_factory=TreeAhConfig)
    filter_columns: Optional[List[str]] = None
    crowding_column: Optional[str] = None
    distance_measure_type: Optional[DistanceMeasureType] = None

    def as_dict(self) -> Dict[str, Any]:
        """Returns the configuration as a dictionary.

        Optional fields left as None are omitted. The algorithm options are
        stored under "tree_ah_config" when `algorithm_config` is a
        `TreeAhConfig`, and under "brute_force_config" otherwise.

        Returns:
            Dict[str, Any]
        """
        config: Dict[str, Any] = {
            "embedding_column": self.embedding_column,
            "embedding_dimension": self.dimensions,
        }
        if self.distance_measure_type is not None:
            # Serialize the enum member as its integer value.
            config["distance_measure_type"] = self.distance_measure_type.value
        if self.filter_columns is not None:
            config["filter_columns"] = self.filter_columns
        if self.crowding_column is not None:
            config["crowding_column"] = self.crowding_column

        if isinstance(self.algorithm_config, TreeAhConfig):
            config["tree_ah_config"] = self.algorithm_config.as_dict()
        else:
            # Any algorithm config that is not TreeAhConfig lands here.
            config["brute_force_config"] = self.algorithm_config.as_dict()
        return config
167 | 169 |
|
168 | 170 |
|
@dataclass
class FeatureGroupBigQuerySource:
    """Describes the BigQuery source backing a Feature Group."""

    # URI of the BigQuery table or view.
    uri: str

    # Entity ID columns. When left unspecified, defaults to ['entity_id'].
    entity_id_columns: Optional[List[str]] = None
0 commit comments