
Commit 9954ddb

[Fix] modify sagemaker llm (#12274)
1 parent b218df6 commit 9954ddb

File tree

1 file changed: +108 -52 lines changed
  • api/core/model_runtime/model_providers/sagemaker/llm


api/core/model_runtime/model_providers/sagemaker/llm/llm.py

+108 -52
@@ -1,6 +1,5 @@
 import json
 import logging
-import re
 from collections.abc import Generator, Iterator
 from typing import Any, Optional, Union, cast

@@ -132,58 +131,115 @@ def _handle_chat_stream_response(
         """
         handle stream chat generate response
         """
+
+        class ChunkProcessor:
+            def __init__(self):
+                self.buffer = bytearray()
+
+            def try_decode_chunk(self, chunk: bytes) -> list[dict]:
+                """Try to decode complete JSON objects from the chunk"""
+                self.buffer.extend(chunk)
+                results = []
+
+                while True:
+                    try:
+                        start = self.buffer.find(b"{")
+                        if start == -1:
+                            self.buffer.clear()
+                            break
+
+                        bracket_count = 0
+                        end = start
+
+                        for i in range(start, len(self.buffer)):
+                            if self.buffer[i] == ord("{"):
+                                bracket_count += 1
+                            elif self.buffer[i] == ord("}"):
+                                bracket_count -= 1
+                                if bracket_count == 0:
+                                    end = i + 1
+                                    break
+
+                        if bracket_count != 0:
+                            # JSON is incomplete; wait for more data
+                            if start > 0:
+                                self.buffer = self.buffer[start:]
+                            break
+
+                        json_bytes = self.buffer[start:end]
+                        try:
+                            data = json.loads(json_bytes)
+                            results.append(data)
+                            self.buffer = self.buffer[end:]
+                        except json.JSONDecodeError:
+                            self.buffer = self.buffer[start + 1 :]
+
+                    except Exception as e:
+                        logger.debug(f"Warning: Error processing chunk ({str(e)})")
+                        if start > 0:
+                            self.buffer = self.buffer[start:]
+                        break
+
+                return results
+
         full_response = ""
-        buffer = ""
-        for chunk_bytes in resp:
-            buffer += chunk_bytes.decode("utf-8")
-            last_idx = 0
-            for match in re.finditer(r"^data:\s*(.+?)(\n\n)", buffer):
-                try:
-                    data = json.loads(match.group(1).strip())
-                    last_idx = match.span()[1]
-
-                    if "content" in data["choices"][0]["delta"]:
-                        chunk_content = data["choices"][0]["delta"]["content"]
-                        assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[])
-
-                        if data["choices"][0]["finish_reason"] is not None:
-                            temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[])
-                            prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
-                            completion_tokens = self._num_tokens_from_messages(
-                                messages=[temp_assistant_prompt_message], tools=[]
-                            )
-                            usage = self._calc_response_usage(
-                                model=model,
-                                credentials=credentials,
-                                prompt_tokens=prompt_tokens,
-                                completion_tokens=completion_tokens,
-                            )
-
-                            yield LLMResultChunk(
-                                model=model,
-                                prompt_messages=prompt_messages,
-                                system_fingerprint=None,
-                                delta=LLMResultChunkDelta(
-                                    index=0,
-                                    message=assistant_prompt_message,
-                                    finish_reason=data["choices"][0]["finish_reason"],
-                                    usage=usage,
-                                ),
-                            )
-                        else:
-                            yield LLMResultChunk(
-                                model=model,
-                                prompt_messages=prompt_messages,
-                                system_fingerprint=None,
-                                delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message),
-                            )
-
-                        full_response += chunk_content
-                except (json.JSONDecodeError, KeyError, IndexError) as e:
-                    logger.info("json parse exception, content: {}".format(match.group(1).strip()))
-                    pass
-
-            buffer = buffer[last_idx:]
+        processor = ChunkProcessor()
+
+        try:
+            for chunk in resp:
+                json_objects = processor.try_decode_chunk(chunk)
+
+                for data in json_objects:
+                    if data.get("choices"):
+                        choice = data["choices"][0]
+
+                        if "delta" in choice and "content" in choice["delta"]:
+                            chunk_content = choice["delta"]["content"]
+                            assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[])
+
+                            if choice.get("finish_reason") is not None:
+                                temp_assistant_prompt_message = AssistantPromptMessage(
+                                    content=full_response, tool_calls=[]
+                                )
+
+                                prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
+                                completion_tokens = self._num_tokens_from_messages(
+                                    messages=[temp_assistant_prompt_message], tools=[]
+                                )
+
+                                usage = self._calc_response_usage(
+                                    model=model,
+                                    credentials=credentials,
+                                    prompt_tokens=prompt_tokens,
+                                    completion_tokens=completion_tokens,
+                                )
+
+                                yield LLMResultChunk(
+                                    model=model,
+                                    prompt_messages=prompt_messages,
+                                    system_fingerprint=None,
+                                    delta=LLMResultChunkDelta(
+                                        index=0,
+                                        message=assistant_prompt_message,
+                                        finish_reason=choice["finish_reason"],
+                                        usage=usage,
+                                    ),
+                                )
+                            else:
+                                yield LLMResultChunk(
+                                    model=model,
+                                    prompt_messages=prompt_messages,
+                                    system_fingerprint=None,
+                                    delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message),
+                                )
+
+                            full_response += chunk_content
+
+        except Exception as e:
+            raise
+
+        if not full_response:
+            logger.warning("No content received from stream response")
 
     def _invoke(
         self,
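
The substantive change is the decode strategy: the removed code regex-matched "data: ...\n\n" SSE frames out of a decoded text buffer, while the new ChunkProcessor accumulates raw bytes and extracts each balanced {...} region before handing it to json.loads, so decoding no longer depends on that framing and tolerates JSON objects split across stream chunks. A minimal standalone sketch of the brace-matching pattern follows; JsonChunkBuffer and feed are illustrative names for this note, not the committed ChunkProcessor API.

import json


class JsonChunkBuffer:
    """Accumulate raw stream bytes and return each complete top-level JSON object."""

    def __init__(self) -> None:
        self.buffer = bytearray()

    def feed(self, chunk: bytes) -> list[dict]:
        self.buffer.extend(chunk)
        results: list[dict] = []
        while True:
            start = self.buffer.find(b"{")
            if start == -1:
                # No opening brace left: nothing decodable remains in the buffer.
                self.buffer.clear()
                break
            depth = 0
            end = -1
            for i in range(start, len(self.buffer)):
                if self.buffer[i] == ord("{"):
                    depth += 1
                elif self.buffer[i] == ord("}"):
                    depth -= 1
                    if depth == 0:
                        end = i + 1
                        break
            if end == -1:
                # Braces not yet balanced: keep the tail and wait for the next chunk.
                self.buffer = self.buffer[start:]
                break
            try:
                results.append(json.loads(self.buffer[start:end]))
                self.buffer = self.buffer[end:]
            except json.JSONDecodeError:
                # Balanced braces but invalid JSON (e.g. a brace inside a string
                # literal): skip one byte and rescan.
                self.buffer = self.buffer[start + 1 :]
        return results


# A payload split at an arbitrary byte boundary is decoded once it completes.
buf = JsonChunkBuffer()
print(buf.feed(b'{"choices": [{"delta": {"con'))    # [] -- object still incomplete
print(buf.feed(b'tent": "Hi"}}]}{"choices": []}'))  # both objects decoded

The sketch skips one byte and rescans when a balanced slice still fails json.loads, mirroring the commit's start + 1 recovery path.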

0 commit comments
