Skip to content

Commit e8f884d

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: IndexConfig - use TreeAhConfig as default algorithm_config.
FUTURE_COPYBARA_INTEGRATE_REVIEW=#3966 from googleapis:release-please--branches--main 98e236c PiperOrigin-RevId: 647342595
1 parent ef5aeda commit e8f884d

File tree

1 file changed

+97
-95
lines changed
  • vertexai/resources/preview/feature_store

1 file changed

+97
-95
lines changed

vertexai/resources/preview/feature_store/utils.py

+97-95
Original file line numberDiff line numberDiff line change
@@ -18,159 +18,161 @@
1818
import abc
1919
from dataclasses import dataclass
2020
import enum
21-
import proto
22-
from typing_extensions import override
2321
from typing import Any, Dict, List, Optional
2422
from google.cloud.aiplatform.compat.types import (
2523
feature_online_store_service as fos_service,
2624
)
25+
import proto
26+
from typing_extensions import override
2727

2828

2929
def get_feature_online_store_name(online_store_name: str) -> str:
30-
"""Extract Feature Online Store's name from FeatureView's full resource name.
30+
"""Extract Feature Online Store's name from FeatureView's full resource name.
3131
32-
Args:
33-
online_store_name: Full resource name is projects/project_number/
32+
Args:
33+
online_store_name: Full resource name is projects/project_number/
3434
locations/us-central1/featureOnlineStores/fos_name/featureViews/fv_name
3535
36-
Returns:
37-
str: feature online store name.
38-
"""
39-
arr = online_store_name.split("/")
40-
return arr[5]
36+
Returns:
37+
str: feature online store name.
38+
"""
39+
arr = online_store_name.split("/")
40+
return arr[5]
4141

4242

4343
class PublicEndpointNotFoundError(RuntimeError):
44-
"""Public endpoint has not been created yet."""
44+
"""Public endpoint has not been created yet."""
4545

4646

4747
@dataclass
4848
class FeatureViewBigQuerySource:
49-
uri: str
50-
entity_id_columns: List[str]
49+
uri: str
50+
entity_id_columns: List[str]
5151

5252

5353
@dataclass
5454
class FeatureViewReadResponse:
55-
_response: fos_service.FetchFeatureValuesResponse
55+
_response: fos_service.FetchFeatureValuesResponse
5656

57-
def __init__(self, response: fos_service.FetchFeatureValuesResponse):
58-
self._response = response
57+
def __init__(self, response: fos_service.FetchFeatureValuesResponse):
58+
self._response = response
5959

60-
def to_dict(self) -> Dict[str, Any]:
61-
return proto.Message.to_dict(self._response.key_values)
60+
def to_dict(self) -> Dict[str, Any]:
61+
return proto.Message.to_dict(self._response.key_values)
6262

63-
def to_proto(self) -> fos_service.FetchFeatureValuesResponse:
64-
return self._response
63+
def to_proto(self) -> fos_service.FetchFeatureValuesResponse:
64+
return self._response
6565

6666

6767
@dataclass
6868
class SearchNearestEntitiesResponse:
69-
_response: fos_service.SearchNearestEntitiesResponse
69+
_response: fos_service.SearchNearestEntitiesResponse
7070

71-
def __init__(self, response: fos_service.SearchNearestEntitiesResponse):
72-
self._response = response
71+
def __init__(self, response: fos_service.SearchNearestEntitiesResponse):
72+
self._response = response
7373

74-
def to_dict(self) -> Dict[str, Any]:
75-
return proto.Message.to_dict(self._response.nearest_neighbors)
74+
def to_dict(self) -> Dict[str, Any]:
75+
return proto.Message.to_dict(self._response.nearest_neighbors)
7676

77-
def to_proto(self) -> fos_service.SearchNearestEntitiesResponse:
78-
return self._response
77+
def to_proto(self) -> fos_service.SearchNearestEntitiesResponse:
78+
return self._response
7979

8080

8181
class DistanceMeasureType(enum.Enum):
82-
"""The distance measure used in nearest neighbor search."""
82+
"""The distance measure used in nearest neighbor search."""
8383

84-
DISTANCE_MEASURE_TYPE_UNSPECIFIED = 0
85-
# Euclidean (L_2) Distance.
86-
SQUARED_L2_DISTANCE = 1
87-
# Cosine Distance. Defined as 1 - cosine similarity.
88-
COSINE_DISTANCE = 2
89-
# Dot Product Distance. Defined as a negative of the dot product.
90-
DOT_PRODUCT_DISTANCE = 3
84+
DISTANCE_MEASURE_TYPE_UNSPECIFIED = 0
85+
# Euclidean (L_2) Distance.
86+
SQUARED_L2_DISTANCE = 1
87+
# Cosine Distance. Defined as 1 - cosine similarity.
88+
COSINE_DISTANCE = 2
89+
# Dot Product Distance. Defined as a negative of the dot product.
90+
DOT_PRODUCT_DISTANCE = 3
9191

9292

9393
class AlgorithmConfig(abc.ABC):
94-
"""Base class for configuration options for matching algorithm."""
94+
"""Base class for configuration options for matching algorithm."""
9595

96-
def as_dict(self) -> Dict:
97-
"""Returns the configuration as a dictionary.
96+
def as_dict(self) -> Dict:
97+
"""Returns the configuration as a dictionary.
9898
99-
Returns:
100-
Dict[str, Any]
101-
"""
102-
pass
99+
Returns:
100+
Dict[str, Any]
101+
"""
102+
pass
103103

104104

105105
@dataclass
106106
class TreeAhConfig(AlgorithmConfig):
107-
"""Configuration options for using the tree-AH algorithm (Shallow tree + Asymmetric Hashing).
108-
Please refer to this paper for more details: https://arxiv.org/abs/1908.10396
107+
"""Configuration options for using the tree-AH algorithm (Shallow tree + Asymmetric Hashing).
109108
110-
Args:
111-
leaf_node_embedding_count (int):
112-
Optional. Number of embeddings on each leaf node. The default value is 1000 if not set.
113-
"""
109+
Please refer to this paper for more details: https://arxiv.org/abs/1908.10396
114110
115-
leaf_node_embedding_count: Optional[int] = None
111+
Args:
112+
leaf_node_embedding_count (int): Optional. Number of embeddings on each
113+
leaf node. The default value is 1000 if not set.
114+
"""
116115

117-
@override
118-
def as_dict(self) -> Dict:
119-
return {"leaf_node_embedding_count": self.leaf_node_embedding_count}
116+
leaf_node_embedding_count: Optional[int] = None
117+
118+
@override
119+
def as_dict(self) -> Dict:
120+
return {"leaf_node_embedding_count": self.leaf_node_embedding_count}
120121

121122

122123
@dataclass
123124
class BruteForceConfig(AlgorithmConfig):
124-
"""Configuration options for using brute force search.
125-
It simply implements the standard linear search in the database for
126-
each query.
127-
"""
125+
"""Configuration options for using brute force search.
128126
129-
@override
130-
def as_dict(self) -> Dict[str, Any]:
131-
return {"bruteForceConfig": {}}
127+
It simply implements the standard linear search in the database for each
128+
query.
129+
"""
130+
131+
@override
132+
def as_dict(self) -> Dict[str, Any]:
133+
return {"bruteForceConfig": {}}
132134

133135

134136
@dataclass
135137
class IndexConfig:
136-
"""Configuration options for the Vertex FeatureView for embedding."""
137-
138-
embedding_column: str
139-
dimensions: int
140-
algorithm_config: AlgorithmConfig
141-
filter_columns: Optional[List[str]] = None
142-
crowding_column: Optional[str] = None
143-
distance_measure_type: Optional[DistanceMeasureType] = None
144-
145-
def as_dict(self) -> Dict[str, Any]:
146-
"""Returns the configuration as a dictionary.
147-
148-
Returns:
149-
Dict[str, Any]
150-
"""
151-
config = {
152-
"embedding_column": self.embedding_column,
153-
"embedding_dimension": self.dimensions,
154-
}
155-
if self.distance_measure_type is not None:
156-
config["distance_measure_type"] = self.distance_measure_type.value
157-
if self.filter_columns is not None:
158-
config["filter_columns"] = self.filter_columns
159-
if self.crowding_column is not None:
160-
config["crowding_column"] = self.crowding_column
161-
162-
if isinstance(self.algorithm_config, TreeAhConfig):
163-
config["tree_ah_config"] = self.algorithm_config.as_dict()
164-
else:
165-
config["brute_force_config"] = self.algorithm_config.as_dict()
166-
return config
138+
"""Configuration options for the Vertex FeatureView for embedding."""
139+
140+
embedding_column: str
141+
dimensions: int
142+
algorithm_config: AlgorithmConfig = TreeAhConfig()
143+
filter_columns: Optional[List[str]] = None
144+
crowding_column: Optional[str] = None
145+
distance_measure_type: Optional[DistanceMeasureType] = None
146+
147+
def as_dict(self) -> Dict[str, Any]:
148+
"""Returns the configuration as a dictionary.
149+
150+
Returns:
151+
Dict[str, Any]
152+
"""
153+
config = {
154+
"embedding_column": self.embedding_column,
155+
"embedding_dimension": self.dimensions,
156+
}
157+
if self.distance_measure_type is not None:
158+
config["distance_measure_type"] = self.distance_measure_type.value
159+
if self.filter_columns is not None:
160+
config["filter_columns"] = self.filter_columns
161+
if self.crowding_column is not None:
162+
config["crowding_column"] = self.crowding_column
163+
164+
if isinstance(self.algorithm_config, TreeAhConfig):
165+
config["tree_ah_config"] = self.algorithm_config.as_dict()
166+
else:
167+
config["brute_force_config"] = self.algorithm_config.as_dict()
168+
return config
167169

168170

169171
@dataclass
170172
class FeatureGroupBigQuerySource:
171-
"""BigQuery source for the Feature Group."""
173+
"""BigQuery source for the Feature Group."""
172174

173-
# The URI for the BigQuery table/view.
174-
uri: str
175-
# The entity ID columns. If not specified, defaults to ['entity_id'].
176-
entity_id_columns: Optional[List[str]] = None
175+
# The URI for the BigQuery table/view.
176+
uri: str
177+
# The entity ID columns. If not specified, defaults to ['entity_id'].
178+
entity_id_columns: Optional[List[str]] = None

0 commit comments

Comments
 (0)