
Commit d4b0b46

fix: chat completion with more than one choice
# What does this PR do?

## Test Plan
1 parent 484abe3 commit d4b0b46
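For context, "more than one choice" refers to the OpenAI-compatible `n` parameter, which asks the server for several independent completions in a single request. A minimal request that exercises this path might look like the sketch below; the base URL, API key, and model name are placeholders, not values taken from this commit.

```python
from openai import OpenAI

# Placeholder endpoint and credentials; point these at any OpenAI-compatible
# llama-stack deployment.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.chat.completions.create(
    model="my-text-model",  # placeholder model id
    messages=[{"role": "user", "content": "Which planet do humans live on?"}],
    n=2,  # request two independent completions ("choices")
)

# Each choice carries its own index and message.
for choice in response.choices:
    print(choice.index, choice.message.content)
```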

File tree: 2 files changed (+53, -11 lines)


llama_stack/providers/utils/inference/openai_compat.py

Lines changed: 1 addition & 3 deletions
@@ -1402,9 +1402,8 @@ async def _process_stream_response(
         outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
     ):
         id = f"chatcmpl-{uuid.uuid4()}"
-        for outstanding_response in outstanding_responses:
+        for i, outstanding_response in enumerate(outstanding_responses):
             response = await outstanding_response
-            i = 0
             async for chunk in response:
                 event = chunk.event
                 finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
@@ -1459,7 +1458,6 @@ async def _process_stream_response(
                     model=model,
                     object="chat.completion.chunk",
                 )
-            i = i + 1

     async def _process_non_stream_response(
         self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
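To illustrate the change above: the choice index now comes from `enumerate` over the outstanding per-choice streams, so every chunk emitted for a given stream carries a stable index, instead of a counter that was reset and re-incremented inside the loop. The sketch below is a simplified stand-in (plain async generators rather than the actual llama-stack response types) showing the same indexing pattern.

```python
import asyncio


async def fake_choice_stream(text: str):
    # Stand-in for one outstanding per-choice response stream.
    for token in text.split():
        yield token


async def process_streams(outstanding_streams):
    # As in the fix: enumerate() ties each stream to a fixed choice index,
    # so chunks from the second stream are tagged with index 1, and so on.
    for i, stream in enumerate(outstanding_streams):
        async for token in stream:
            yield {"choice_index": i, "delta": token}


async def main():
    streams = [fake_choice_stream("first choice"), fake_choice_stream("second choice")]
    async for chunk in process_streams(streams):
        print(chunk["choice_index"], chunk["delta"])


asyncio.run(main())
```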

tests/integration/inference/test_openai_completion.py

Lines changed: 52 additions & 8 deletions
@@ -224,14 +224,50 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
     assert expected.lower() in "".join(streamed_content)


+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:chat_completion:streaming_01",
+        "inference:chat_completion:streaming_02",
+    ],
+)
+def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
+
+    provider = provider_from_model(client_with_models, text_model_id)
+    if provider.provider_type == "remote::ollama":
+        pytest.skip(f"Model {text_model_id} hosted by {provider.provider_type} doesn't support n > 1.")
+
+    tc = TestCase(test_case)
+    question = tc["question"]
+    expected = tc["expected"]
+
+    response = compat_client.chat.completions.create(
+        model=text_model_id,
+        messages=[{"role": "user", "content": question}],
+        stream=True,
+        timeout=120,  # Increase timeout to 2 minutes for large conversation history
+        n=2,
+    )
+    streamed_content = {}
+    for chunk in response:
+        for choice in chunk.choices:
+            if choice.delta.content:
+                streamed_content[choice.index] = (
+                    streamed_content.get(choice.index, "") + choice.delta.content.lower().strip()
+                )
+    assert len(streamed_content) == 2
+    for i, content in streamed_content.items():
+        assert expected.lower() in content, f"Choice {i}: Expected {expected.lower()} in {content}"
+
+
 @pytest.mark.parametrize(
     "stream",
     [
         True,
         False,
     ],
 )
-@pytest.mark.skip(reason="Very flaky, keeps failing on CI")
 def test_inference_store(openai_client, client_with_models, text_model_id, stream):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     client = openai_client
@@ -254,7 +290,8 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
         for chunk in response:
             if response_id is None:
                 response_id = chunk.id
-            content += chunk.choices[0].delta.content
+            if chunk.choices[0].delta.content:
+                content += chunk.choices[0].delta.content
     else:
         response_id = response.id
         content = response.choices[0].message.content
@@ -264,8 +301,8 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea

     retrieved_response = client.chat.completions.retrieve(response_id)
     assert retrieved_response.id == response_id
-    assert retrieved_response.input_messages[0]["content"] == message
-    assert retrieved_response.choices[0].message.content == content
+    assert retrieved_response.input_messages[0]["content"] == message, retrieved_response
+    assert retrieved_response.choices[0].message.content == content, retrieved_response


 @pytest.mark.parametrize(
@@ -275,7 +312,6 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
         False,
     ],
 )
-@pytest.mark.skip(reason="Very flaky, tool calling really wacky on CI")
 def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     client = openai_client
@@ -313,7 +349,9 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
         for chunk in response:
             if response_id is None:
                 response_id = chunk.id
-            content += chunk.choices[0].delta.content
+            if delta := chunk.choices[0].delta:
+                if delta.content:
+                    content += delta.content
     else:
         response_id = response.id
         content = response.choices[0].message.content
@@ -324,5 +362,11 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
     retrieved_response = client.chat.completions.retrieve(response_id)
     assert retrieved_response.id == response_id
     assert retrieved_response.input_messages[0]["content"] == message
-    assert retrieved_response.choices[0].message.tool_calls[0].function.name == "get_weather"
-    assert retrieved_response.choices[0].message.tool_calls[0].function.arguments == '{"city":"Tokyo"}'
+    tool_calls = retrieved_response.choices[0].message.tool_calls
+    # sometimes the model doesn't output tool calls, but we still want to test that the tool was called
+    if tool_calls:
+        assert len(tool_calls) == 1
+        assert tool_calls[0].function.name == "get_weather"
+        assert "tokyo" in tool_calls[0].function.arguments.lower()
+    else:
+        assert retrieved_response.choices[0].message.content == content
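From the client side, the new streaming test boils down to the pattern below: chunks for the different choices arrive on one stream, so text is grouped by `choice.index`, and `delta.content` is checked before concatenating because role-only and finish chunks carry no text. This is a hedged usage sketch; the endpoint, API key, and model name are placeholders, not part of this commit.

```python
from collections import defaultdict

from openai import OpenAI

# Placeholder endpoint and model; any OpenAI-compatible server applies.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

stream = client.chat.completions.create(
    model="my-text-model",  # placeholder model id
    messages=[{"role": "user", "content": "Name a primary color."}],
    n=2,
    stream=True,
)

# Group streamed fragments by choice index; guard against None deltas.
per_choice = defaultdict(str)
for chunk in stream:
    for choice in chunk.choices:
        if choice.delta.content:
            per_choice[choice.index] += choice.delta.content

for index, text in sorted(per_choice.items()):
    print(f"choice {index}: {text}")
```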
