
Commit 01ccfaa

fix: chat completion with more than one choice
# What does this PR do?

## Test Plan
1 parent 7105a25 commit 01ccfaa
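
For context, a minimal sketch of the scenario the commit title refers to: a streaming chat completion that asks for more than one choice, whose per-choice text has to be reassembled by choice.index. This sketch is not part of the commit; the base_url and model id are placeholders for whatever OpenAI-compatible Llama Stack endpoint and model you actually run.

# Not from this commit: placeholder endpoint and model id, illustrative only.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

stream = client.chat.completions.create(
    model="my-text-model",  # placeholder
    messages=[{"role": "user", "content": "Name a planet."}],
    n=2,          # ask for two choices in a single request
    stream=True,
)

# Chunks from different choices interleave, so content is accumulated per choice.index.
content: dict[int, str] = {}
for chunk in stream:
    for choice in chunk.choices:
        if choice.delta.content:
            content[choice.index] = content.get(choice.index, "") + choice.delta.content

for index, text in sorted(content.items()):
    print(index, text)

The new integration test in tests/integration/inference/test_openai_completion.py below exercises this same shape with n=2.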

File tree: 2 files changed (+43, -10 lines)

llama_stack/providers/utils/inference/openai_compat.py

Lines changed: 2 additions & 3 deletions

@@ -1377,6 +1377,7 @@ async def openai_chat_completion(
         outstanding_responses = []
         # "n" is the number of completions to generate per prompt
         n = n or 1
+        print(f"n: {n}")
         for _i in range(0, n):
             response = self.chat_completion(
                 model_id=model,
@@ -1402,9 +1403,8 @@ async def _process_stream_response(
         outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
     ):
         id = f"chatcmpl-{uuid.uuid4()}"
-        for outstanding_response in outstanding_responses:
+        for i, outstanding_response in enumerate(outstanding_responses):
             response = await outstanding_response
-            i = 0
             async for chunk in response:
                 event = chunk.event
                 finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
@@ -1459,7 +1459,6 @@ async def _process_stream_response(
                         model=model,
                         object="chat.completion.chunk",
                     )
-            i = i + 1

     async def _process_non_stream_response(
         self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
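
The heart of the fix above is replacing the hand-maintained counter (the deleted `i = 0` / `i = i + 1`, which was reset inside the loop over outstanding responses, so chunk indices did not line up with choices when n > 1) with `enumerate()`, which pins one stable index to each response's whole chunk stream. A self-contained sketch of that pattern, using toy stand-ins rather than the Llama Stack types:

# Toy stand-ins, not the Llama Stack classes: shows the enumerate() indexing pattern.
import asyncio
from collections.abc import AsyncIterator


async def make_response(text: str):
    # Stand-in for self.chat_completion(...): an awaitable that resolves to an
    # async iterator of content chunks.
    async def stream() -> AsyncIterator[str]:
        for part in (text[: len(text) // 2], text[len(text) // 2 :]):
            yield part

    return stream()


async def merge_choices(outstanding_responses):
    # enumerate() ties the index to the response, not to how many chunks it yields.
    for i, outstanding_response in enumerate(outstanding_responses):
        response = await outstanding_response
        async for chunk in response:
            yield {"index": i, "delta": chunk}


async def main():
    outstanding = [make_response("hello"), make_response("world")]
    async for labeled in merge_choices(outstanding):
        print(labeled)  # index 0 for "hello" chunks, index 1 for "world" chunks


asyncio.run(main())

Tying the index to the response rather than to a running chunk counter keeps it correct even when a response yields zero or many chunks.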

tests/integration/inference/test_openai_completion.py

Lines changed: 41 additions & 7 deletions

@@ -224,14 +224,45 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, text_model_id, test_case):
     assert expected.lower() in "".join(streamed_content)


+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:chat_completion:streaming_01",
+        "inference:chat_completion:streaming_02",
+    ],
+)
+def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
+    tc = TestCase(test_case)
+    question = tc["question"]
+    expected = tc["expected"]
+
+    response = compat_client.chat.completions.create(
+        model=text_model_id,
+        messages=[{"role": "user", "content": question}],
+        stream=True,
+        timeout=120,  # Increase timeout to 2 minutes for large conversation history,
+        n=2,
+    )
+    streamed_content = {}
+    for chunk in response:
+        for choice in chunk.choices:
+            if choice.delta.content:
+                streamed_content[choice.index] = (
+                    streamed_content.get(choice.index, "") + choice.delta.content.lower().strip()
+                )
+    assert len(streamed_content) == 2
+    for i, content in streamed_content.items():
+        assert expected.lower() in content, f"Choice {i}: Expected {expected.lower()} in {content}"
+
+
 @pytest.mark.parametrize(
     "stream",
     [
         True,
         False,
     ],
 )
-@pytest.mark.skip(reason="Very flaky, keeps failing on CI")
 def test_inference_store(openai_client, client_with_models, text_model_id, stream):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     client = openai_client
@@ -254,7 +285,8 @@ def test_inference_store(openai_client, client_with_models, text_model_id, stream):
         for chunk in response:
             if response_id is None:
                 response_id = chunk.id
-            content += chunk.choices[0].delta.content
+            if chunk.choices[0].delta.content:
+                content += chunk.choices[0].delta.content
     else:
         response_id = response.id
         content = response.choices[0].message.content
@@ -264,8 +296,8 @@ def test_inference_store(openai_client, client_with_models, text_model_id, stream):

     retrieved_response = client.chat.completions.retrieve(response_id)
     assert retrieved_response.id == response_id
-    assert retrieved_response.input_messages[0]["content"] == message
-    assert retrieved_response.choices[0].message.content == content
+    assert retrieved_response.input_messages[0]["content"] == message, retrieved_response
+    assert retrieved_response.choices[0].message.content == content, retrieved_response


 @pytest.mark.parametrize(
@@ -275,7 +307,7 @@ def test_inference_store(openai_client, client_with_models, text_model_id, stream):
         False,
     ],
 )
-@pytest.mark.skip(reason="Very flaky, tool calling really wacky on CI")
+# @pytest.mark.skip(reason="Very flaky, tool calling really wacky on CI")
 def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     client = openai_client
@@ -313,7 +345,9 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
         for chunk in response:
             if response_id is None:
                 response_id = chunk.id
-            content += chunk.choices[0].delta.content
+            if delta := chunk.choices[0].delta:
+                if delta.content:
+                    content += delta.content
     else:
         response_id = response.id
         content = response.choices[0].message.content
@@ -325,4 +359,4 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
     assert retrieved_response.id == response_id
     assert retrieved_response.input_messages[0]["content"] == message
     assert retrieved_response.choices[0].message.tool_calls[0].function.name == "get_weather"
-    assert retrieved_response.choices[0].message.tool_calls[0].function.arguments == '{"city":"Tokyo"}'
+    assert retrieved_response.choices[0].message.tool_calls[0].function.arguments == '{"city": "Tokyo"}'
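
The small guards added to the inference-store tests above handle chunks whose delta carries no content, typically the final chunk that only reports a finish reason. A toy illustration of the failure they avoid, using hypothetical chunk objects rather than the real client types:

# Hypothetical chunk shapes, just to show why the `if ... delta.content:` guard
# matters: the final streamed chunk usually has delta.content set to None.
from dataclasses import dataclass


@dataclass
class Delta:
    content: str | None


@dataclass
class Choice:
    delta: Delta


chunks = [Choice(Delta("Tok")), Choice(Delta("yo")), Choice(Delta(None))]

content = ""
for choice in chunks:
    # Without the guard, `content += None` raises TypeError on the last chunk.
    if choice.delta.content:
        content += choice.delta.content

print(content)  # Tokyo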
