@@ -2254,6 +2254,7 @@ def _create_prediction_request(
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         stop_sequences: Optional[List[str]] = None,
+        candidate_count: Optional[int] = None,
     ) -> _PredictionRequest:
         """Creates a code generation prediction request.
 
@@ -2263,7 +2264,7 @@ def _create_prediction_request(
             max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
             temperature: Controls the randomness of predictions. Range: [0, 1].
             stop_sequences: Customized stop sequences to stop the decoding process.
-
+            candidate_count: Number of response candidates to return.
 
         Returns:
             A `TextGenerationResponse` object that contains the text produced by the model.
@@ -2285,6 +2286,9 @@ def _create_prediction_request(
         if stop_sequences:
             prediction_parameters["stopSequences"] = stop_sequences
 
+        if candidate_count is not None:
+            prediction_parameters["candidateCount"] = candidate_count
+
         return _PredictionRequest(instance=instance, parameters=prediction_parameters)
 
     def predict(
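For reference, a minimal sketch of the parameters payload this hunk builds, with illustrative argument values. The `stopSequences` and `candidateCount` keys are taken from the diff itself; the `maxOutputTokens` and `temperature` keys are my assumption about the surrounding, unchanged code:

```python
# Illustrative only: what prediction_parameters would contain when every
# optional argument is supplied. Values are made up for the example.
prediction_parameters = {
    "maxOutputTokens": 128,     # assumed key for max_output_tokens
    "temperature": 0.2,         # assumed key for temperature
    "stopSequences": ["\n\n"],  # from the hunk above
    "candidateCount": 3,        # new: emitted only when candidate_count is not None
}
```

Because the new field is guarded by `if candidate_count is not None`, callers that never pass `candidate_count` keep sending exactly the same request as before.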
@@ -2295,6 +2299,7 @@ def predict(
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         stop_sequences: Optional[List[str]] = None,
+        candidate_count: Optional[int] = None,
     ) -> "TextGenerationResponse":
         """Gets model response for a single prompt.
 
@@ -2304,23 +2309,26 @@ def predict(
             max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
             temperature: Controls the randomness of predictions. Range: [0, 1].
             stop_sequences: Customized stop sequences to stop the decoding process.
+            candidate_count: Number of response candidates to return.
 
         Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `MultiCandidateTextGenerationResponse` object that contains the
+            text produced by the model.
         """
         prediction_request = self._create_prediction_request(
             prefix=prefix,
             suffix=suffix,
             max_output_tokens=max_output_tokens,
             temperature=temperature,
             stop_sequences=stop_sequences,
+            candidate_count=candidate_count,
         )
 
         prediction_response = self._endpoint.predict(
             instances=[prediction_request.instance],
             parameters=prediction_request.parameters,
         )
-        return _parse_text_generation_model_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(prediction_response)
 
     async def predict_async(
         self,
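A hedged usage sketch of the updated `predict` (project, model name, prompt, and count are illustrative, not taken from this PR). Note that the return annotation stays `"TextGenerationResponse"` even though the docstring now promises a `MultiCandidateTextGenerationResponse`; presumably the latter subclasses the former:

```python
import vertexai
from vertexai.language_models import CodeGenerationModel

vertexai.init(project="my-project", location="us-central1")  # illustrative

model = CodeGenerationModel.from_pretrained("code-bison@001")
response = model.predict(
    prefix="def reverse_string(s):",
    candidate_count=3,  # ask for three completions in a single call
)

# Assuming MultiCandidateTextGenerationResponse keeps .text as the first
# candidate and exposes the full list via .candidates:
for candidate in response.candidates:
    print(candidate.text)
```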
@@ -2330,6 +2338,7 @@ async def predict_async(
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         stop_sequences: Optional[List[str]] = None,
+        candidate_count: Optional[int] = None,
     ) -> "TextGenerationResponse":
         """Asynchronously gets model response for a single prompt.
 
@@ -2339,23 +2348,26 @@ async def predict_async(
             max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
             temperature: Controls the randomness of predictions. Range: [0, 1].
             stop_sequences: Customized stop sequences to stop the decoding process.
+            candidate_count: Number of response candidates to return.
 
         Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `MultiCandidateTextGenerationResponse` object that contains the
+            text produced by the model.
         """
         prediction_request = self._create_prediction_request(
             prefix=prefix,
             suffix=suffix,
             max_output_tokens=max_output_tokens,
             temperature=temperature,
             stop_sequences=stop_sequences,
+            candidate_count=candidate_count,
         )
 
         prediction_response = await self._endpoint.predict_async(
             instances=[prediction_request.instance],
             parameters=prediction_request.parameters,
         )
-        return _parse_text_generation_model_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(prediction_response)
 
     def predict_streaming(
         self,
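The async path mirrors the synchronous one and ends at the same multi-candidate parser. A sketch of calling it, using the same illustrative names as above:

```python
import asyncio

from vertexai.language_models import CodeGenerationModel

async def main() -> None:
    model = CodeGenerationModel.from_pretrained("code-bison@001")
    response = await model.predict_async(
        prefix="def fibonacci(n):",
        candidate_count=2,  # two completions, one round trip
    )
    for candidate in response.candidates:
        print(candidate.text)

asyncio.run(main())
```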