File tree 1 file changed +5
-8
lines changed
1 file changed +5
-8
lines changed Original file line number Diff line number Diff line change @@ -42,21 +42,17 @@ def _get_model_id_from_tuning_model_id(tuning_model_id: str) -> str:
42
42
"""Gets the base model ID for the model ID labels used the tuned models.
43
43
44
44
Args:
45
- tuning_model_id: The model ID used in tuning
45
+ tuning_model_id: The model ID used in tuning. E.g. `text-bison-001`
46
46
47
47
Returns:
48
48
The publisher model ID
49
49
50
50
Raises:
51
51
ValueError: If tuning model ID is unsupported
52
52
"""
53
- if tuning_model_id .startswith ("text-bison-" ):
54
- return tuning_model_id .replace (
55
- "text-bison-" , "publishers/google/models/text-bison@"
56
- )
57
- if "/" not in tuning_model_id :
58
- return "publishers/google/models/" + tuning_model_id
59
- return tuning_model_id
53
+ model_name , _ , version = tuning_model_id .rpartition ("-" )
54
+ # "publishers/google/models/text-bison@001"
55
+ return f"publishers/google/models/{ model_name } @{ version } "
60
56
61
57
62
58
class _LanguageModel (_model_garden_models ._ModelGardenModel ):
@@ -203,6 +199,7 @@ def tune_model(
203
199
tuned_model = job .result ()
204
200
# The UXR study attendees preferred to tune model in place
205
201
self ._endpoint = tuned_model ._endpoint
202
+ self ._endpoint_name = tuned_model ._endpoint_name
206
203
207
204
208
205
@dataclasses .dataclass
You can’t perform that action at this time.
0 commit comments