Skip to content

Commit c0626fe

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: GenAI - Add model routing config to sdk
PiperOrigin-RevId: 676391359
1 parent c326aa5 commit c0626fe

File tree

2 files changed

+125
-0
lines changed

2 files changed

+125
-0
lines changed

tests/system/vertexai/test_generative_models.py

+16
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
GEMINI_VISION_MODEL_NAME = "gemini-1.0-pro-vision"
3737
GEMINI_15_MODEL_NAME = "gemini-1.5-pro-preview-0409"
3838
GEMINI_15_PRO_MODEL_NAME = "gemini-1.5-pro-001"
39+
SMART_ROUTER_NAME = "smart-router-001"
3940

4041
STAGING_API_ENDPOINT = "STAGING_ENDPOINT"
4142
PROD_API_ENDPOINT = "PROD_ENDPOINT"
@@ -494,6 +495,21 @@ def test_generate_content_function_calling(self, api_endpoint_env_name):
494495

495496
assert summary
496497

498+
def test_generate_content_model_router(self, gapi_endpoint_env_name):
499+
model = generative_models.GenerativeModel(SMART_ROUTER_NAME)
500+
response = model.generate_content(
501+
contents="Why is sky blue?",
502+
generation_config=generative_models.GenerationConfig(
503+
temperature=0,
504+
routing_config=generative_models.GenerationConfig.RoutingConfig(
505+
routing_config=generative_models.GenerationConfig.RoutingConfig.AutoRoutingMode(
506+
model_routing_preference=generative_models.GenerationConfig.RoutingConfig.AutoRoutingMode.ModelRoutingPreference.BALANCED,
507+
),
508+
),
509+
),
510+
)
511+
assert response.text
512+
497513
def test_chat_automatic_function_calling(self, api_endpoint_env_name):
498514
get_current_weather_func = generative_models.FunctionDeclaration.from_func(
499515
get_current_weather

vertexai/generative_models/_generative_models.py

+109
Original file line numberDiff line numberDiff line change
@@ -1576,6 +1576,7 @@ def __init__(
15761576
response_mime_type: Optional[str] = None,
15771577
response_schema: Optional[Dict[str, Any]] = None,
15781578
seed: Optional[int] = None,
1579+
routing_config: Optional["RoutingConfig"] = None,
15791580
):
15801581
r"""Constructs a GenerationConfig object.
15811582
@@ -1601,6 +1602,7 @@ def __init__(
16011602
response type, otherwise the behavior is undefined.
16021603
response_schema: Output response schema of the genreated candidate text. Only valid when
16031604
response_mime_type is application/json.
1605+
routing_config: Model routing preference set in the request.
16041606
16051607
Usage:
16061608
```
@@ -1636,6 +1638,10 @@ def __init__(
16361638
response_schema=raw_schema,
16371639
seed=seed,
16381640
)
1641+
if routing_config is not None:
1642+
self._raw_generation_config.routing_config = (
1643+
routing_config._gapic_routing_config
1644+
)
16391645

16401646
@classmethod
16411647
def _from_gapic(
@@ -1659,6 +1665,109 @@ def to_dict(self) -> Dict[str, Any]:
16591665
def __repr__(self) -> str:
16601666
return self._raw_generation_config.__repr__()
16611667

1668+
class RoutingConfig:
1669+
r"""The configuration for model router requests.
1670+
1671+
The routing config is either one of the two nested classes:
1672+
- AutoRoutingMode: Automated routing.
1673+
- ManualRoutingMode: Manual routing.
1674+
1675+
Usage:
1676+
- AutoRoutingMode:
1677+
1678+
```
1679+
routing_config=generative_models.RoutingConfig(
1680+
routing_config=generative_models.RoutingConfig.AutoRoutingMode(
1681+
model_routing_preference=generative_models.RoutingConfig.AutoRoutingMode.ModelRoutingPreference.BALANCED,
1682+
),
1683+
)
1684+
```
1685+
- ManualRoutingMode:
1686+
1687+
```
1688+
routing_config=generative_models.RoutingConfig(
1689+
routing_config=generative_models.RoutingConfig.ManutalRoutingMode(
1690+
model_name="gemini-1.5-pro-001",
1691+
),
1692+
)
1693+
```
1694+
"""
1695+
1696+
def __init__(
1697+
self,
1698+
*,
1699+
routing_config: Union[
1700+
"GenerationConfig.RoutingConfig.AutoRoutingMode",
1701+
"GenerationConfig.RoutingConfig.ManualRoutingMode",
1702+
],
1703+
):
1704+
if isinstance(routing_config, self.AutoRoutingMode):
1705+
self._gapic_routing_config = (
1706+
gapic_content_types.GenerationConfig.RoutingConfig(
1707+
auto_mode=routing_config._gapic_auto_mode
1708+
)
1709+
)
1710+
else:
1711+
self._gapic_routing_config = (
1712+
gapic_content_types.GenerationConfig.RoutingConfig(
1713+
manual_mode=routing_config._gapic_manual_mode
1714+
)
1715+
)
1716+
1717+
def __repr__(self):
1718+
return self._gapic_routing_config.__repr__()
1719+
1720+
class AutoRoutingMode:
1721+
r"""When automated routing is specified, the routing will be
1722+
determined by the routing model predicted quality and customer provided
1723+
model routing preference.
1724+
"""
1725+
1726+
ModelRoutingPreference = (
1727+
gapic_content_types.GenerationConfig.RoutingConfig.AutoRoutingMode.ModelRoutingPreference
1728+
)
1729+
1730+
def __init__(
1731+
self,
1732+
*,
1733+
model_routing_preference: "GenerationConfig.RoutingConfig.AutoRoutingMode.ModelRoutingPreference",
1734+
):
1735+
r"""AutoRouingMode constructor
1736+
1737+
Args:
1738+
model_routing_preference: Model routing preference for the routing request.
1739+
"""
1740+
self._gapic_auto_mode = (
1741+
gapic_content_types.GenerationConfig.RoutingConfig.AutoRoutingMode(
1742+
model_routing_preference=model_routing_preference
1743+
)
1744+
)
1745+
1746+
def __repr__(self):
1747+
return self._gapic_auto_mode.__repr__()
1748+
1749+
class ManualRoutingMode:
1750+
r"""When manual routing is set, the specified model will be used
1751+
directly.
1752+
"""
1753+
1754+
def __init__(
1755+
self,
1756+
*,
1757+
model_name: str,
1758+
):
1759+
r"""ManualRoutingMode constructor
1760+
1761+
Args:
1762+
model_name: The model to use. Only public LLM model names and those that are supported by the router are allowed.
1763+
"""
1764+
self._gapic_manual_mode = gapic_content_types.GenerationConfig.RoutingConfig.ManualRoutingMode(
1765+
model_name=model_name
1766+
)
1767+
1768+
def __repr__(self):
1769+
return self._gapic_manual_mode.__repr__()
1770+
16621771

16631772
class Tool:
16641773
r"""A collection of functions that the model may use to generate response.

0 commit comments

Comments
 (0)