@@ -81,3 +81,66 @@ def test_text_embedding(self):
81
81
for embedding in embeddings :
82
82
vector = embedding .values
83
83
assert len (vector ) == 768
84
+
85
    def test_tuning(self, shared_state):
        """Test tuning, listing and loading models.

        End-to-end test: tunes a text-bison foundation model on a tiny toy
        dataset, records the created external resources in ``shared_state``
        for later cleanup, and verifies that both the base model and the
        resulting tuned model can serve a prediction.

        Args:
            shared_state: Mutable dict shared across the e2e test session;
                resources appended under the "resources" key are cleaned up
                by the test harness (presumably a fixture — not visible here).
        """
        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

        model = TextGenerationModel.from_pretrained("google/text-bison@001")

        # Local import: pandas is only needed to build the tuning dataset.
        import pandas

        # Minimal 10-row input/output dataset; the content is irrelevant —
        # we only exercise the tuning pipeline end to end.
        training_data = pandas.DataFrame(
            data=[
                {"input_text": "Input 0", "output_text": "Output 0"},
                {"input_text": "Input 1", "output_text": "Output 1"},
                {"input_text": "Input 2", "output_text": "Output 2"},
                {"input_text": "Input 3", "output_text": "Output 3"},
                {"input_text": "Input 4", "output_text": "Output 4"},
                {"input_text": "Input 5", "output_text": "Output 5"},
                {"input_text": "Input 6", "output_text": "Output 6"},
                {"input_text": "Input 7", "output_text": "Output 7"},
                {"input_text": "Input 8", "output_text": "Output 8"},
                {"input_text": "Input 9", "output_text": "Output 9"},
            ]
        )

        # A single train step keeps this e2e run as cheap as possible.
        # Tuning runs in europe-west4 while the tuned model is deployed in
        # us-central1 — presumably a service-side requirement; confirm
        # against the Vertex AI tuning docs.
        model.tune_model(
            training_data=training_data,
            train_steps=1,
            tuning_job_location="europe-west4",
            tuned_model_location="us-central1",
        )
        # According to the Pipelines design, external resources created by a pipeline
        # must not be modified or deleted. Otherwise caching will break next pipeline runs.
        shared_state.setdefault("resources", [])
        shared_state["resources"].append(model._endpoint)
        # Also track the Model resources deployed on the endpoint so the
        # harness can clean them up.
        shared_state["resources"].extend(
            aiplatform.Model(model_name=deployed_model.model)
            for deployed_model in model._endpoint.list_models()
        )
        # Deleting the Endpoint is a little less bad since the LLM SDK will recreate it, but it's not advised for the same reason.

        # Sanity check: the (now tuned-in-place) base model still serves.
        response = model.predict(
            "What is the best recipe for banana bread? Recipe:",
            max_output_tokens=128,
            temperature=0,
            top_p=1,
            top_k=5,
        )
        assert response.text

        # The tuning job above must have produced at least one tuned model.
        tuned_model_names = model.list_tuned_model_names()
        assert tuned_model_names
        tuned_model_name = tuned_model_names[0]

        # Load the tuned model by name and verify it serves predictions too.
        tuned_model = TextGenerationModel.get_tuned_model(tuned_model_name)

        tuned_model_response = tuned_model.predict(
            "What is the best recipe for banana bread? Recipe:",
            max_output_tokens=128,
            temperature=0,
            top_p=1,
            top_k=5,
        )
        assert tuned_model_response.text
0 commit comments