@@ -224,14 +224,45 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
     assert expected.lower() in "".join(streamed_content)


+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:chat_completion:streaming_01",
+        "inference:chat_completion:streaming_02",
+    ],
+)
+def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
+    tc = TestCase(test_case)
+    question = tc["question"]
+    expected = tc["expected"]
+
+    response = compat_client.chat.completions.create(
+        model=text_model_id,
+        messages=[{"role": "user", "content": question}],
+        stream=True,
+        timeout=120,  # Increase timeout to 2 minutes for large conversation history
+        n=2,
+    )
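+    # With n=2, chunks for the two requested completions may arrive interleaved
+    # in a single stream; each chunk's choice carries an index identifying the
+    # completion it belongs to, so content is accumulated per choice.index.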
+    streamed_content = {}
+    for chunk in response:
+        for choice in chunk.choices:
+            if choice.delta.content:
+                streamed_content[choice.index] = (
+                    streamed_content.get(choice.index, "") + choice.delta.content.lower().strip()
+                )
+    assert len(streamed_content) == 2
+    for i, content in streamed_content.items():
+        assert expected.lower() in content, f"Choice {i}: Expected {expected.lower()} in {content}"
+
+
 @pytest.mark.parametrize(
     "stream",
     [
         True,
         False,
     ],
 )
-@pytest.mark.skip(reason="Very flaky, keeps failing on CI")
 def test_inference_store(openai_client, client_with_models, text_model_id, stream):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     client = openai_client
@@ -254,7 +285,8 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
         for chunk in response:
             if response_id is None:
                 response_id = chunk.id
-            content += chunk.choices[0].delta.content
+            if chunk.choices[0].delta.content:
+                content += chunk.choices[0].delta.content
     else:
         response_id = response.id
         content = response.choices[0].message.content
@@ -264,8 +296,8 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea

     retrieved_response = client.chat.completions.retrieve(response_id)
     assert retrieved_response.id == response_id
-    assert retrieved_response.input_messages[0]["content"] == message
-    assert retrieved_response.choices[0].message.content == content
+    assert retrieved_response.input_messages[0]["content"] == message, retrieved_response
+    assert retrieved_response.choices[0].message.content == content, retrieved_response


 @pytest.mark.parametrize(
@@ -275,7 +307,7 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
         False,
     ],
 )
-@pytest.mark.skip(reason="Very flaky, tool calling really wacky on CI")
+# @pytest.mark.skip(reason="Very flaky, tool calling really wacky on CI")
 def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     client = openai_client
@@ -313,7 +345,9 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
         for chunk in response:
             if response_id is None:
                 response_id = chunk.id
-            content += chunk.choices[0].delta.content
+            if delta := chunk.choices[0].delta:
+                if delta.content:
+                    content += delta.content
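+            # (some chunks carry an empty delta, or a delta with no content,
+            # e.g. a final chunk that only sets finish_reason or a tool-call
+            # delta, so both levels are guarded)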
     else:
         response_id = response.id
         content = response.choices[0].message.content
@@ -325,4 +359,4 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
     assert retrieved_response.id == response_id
     assert retrieved_response.input_messages[0]["content"] == message
     assert retrieved_response.choices[0].message.tool_calls[0].function.name == "get_weather"
-    assert retrieved_response.choices[0].message.tool_calls[0].function.arguments == '{"city":"Tokyo"}'
+    assert retrieved_response.choices[0].message.tool_calls[0].function.arguments == '{"city": "Tokyo"}'
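Why the expected arguments string gained a space: a plausible explanation (an assumption, not stated in the diff) is that the store re-serializes tool-call arguments with json.dumps, whose default separators put a space after each colon. A minimal sketch:

```python
import json

# json.dumps defaults to separators=(", ", ": "), so arguments that are
# round-tripped through it come back as '{"city": "Tokyo"}', not the
# compact '{"city":"Tokyo"}' form.
assert json.dumps({"city": "Tokyo"}) == '{"city": "Tokyo"}'
```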