[Fix] revert sagemaker llm to support model hub #12378

Merged
merged 1 commit on Jan 6, 2025
160 changes: 52 additions & 108 deletions api/core/model_runtime/model_providers/sagemaker/llm/llm.py
@@ -1,5 +1,6 @@
import json
import logging
import re
from collections.abc import Generator, Iterator
from typing import Any, Optional, Union, cast

@@ -131,115 +132,58 @@ def _handle_chat_stream_response(
"""
Handle the streaming chat generation response.
"""

class ChunkProcessor:
def __init__(self):
self.buffer = bytearray()

def try_decode_chunk(self, chunk: bytes) -> list[dict]:
"""尝试从chunk中解码出完整的JSON对象"""
self.buffer.extend(chunk)
results = []

while True:
try:
start = self.buffer.find(b"{")
if start == -1:
self.buffer.clear()
break

bracket_count = 0
end = start

for i in range(start, len(self.buffer)):
if self.buffer[i] == ord("{"):
bracket_count += 1
elif self.buffer[i] == ord("}"):
bracket_count -= 1
if bracket_count == 0:
end = i + 1
break

if bracket_count != 0:
# JSON is incomplete; wait for more data
if start > 0:
self.buffer = self.buffer[start:]
break

json_bytes = self.buffer[start:end]
try:
data = json.loads(json_bytes)
results.append(data)
self.buffer = self.buffer[end:]
except json.JSONDecodeError:
self.buffer = self.buffer[start + 1 :]

except Exception as e:
logger.debug(f"Warning: Error processing chunk ({str(e)})")
if start > 0:
self.buffer = self.buffer[start:]
break

return results

full_response = ""
processor = ChunkProcessor()

try:
for chunk in resp:
json_objects = processor.try_decode_chunk(chunk)

for data in json_objects:
if data.get("choices"):
choice = data["choices"][0]

if "delta" in choice and "content" in choice["delta"]:
chunk_content = choice["delta"]["content"]
assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[])

if choice.get("finish_reason") is not None:
temp_assistant_prompt_message = AssistantPromptMessage(
content=full_response, tool_calls=[]
)

prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(
messages=[temp_assistant_prompt_message], tools=[]
)

usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)

yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=None,
delta=LLMResultChunkDelta(
index=0,
message=assistant_prompt_message,
finish_reason=choice["finish_reason"],
usage=usage,
),
)
else:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=None,
delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message),
)

full_response += chunk_content

except Exception as e:
raise

if not full_response:
logger.warning("No content received from stream response")
buffer = ""
for chunk_bytes in resp:
buffer += chunk_bytes.decode("utf-8")
last_idx = 0
for match in re.finditer(r"^data:\s*(.+?)(\n\n)", buffer):
try:
data = json.loads(match.group(1).strip())
last_idx = match.span()[1]

if "content" in data["choices"][0]["delta"]:
chunk_content = data["choices"][0]["delta"]["content"]
assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[])

if data["choices"][0]["finish_reason"] is not None:
temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[])
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(
messages=[temp_assistant_prompt_message], tools=[]
)
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)

yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=None,
delta=LLMResultChunkDelta(
index=0,
message=assistant_prompt_message,
finish_reason=data["choices"][0]["finish_reason"],
usage=usage,
),
)
else:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=None,
delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message),
)

full_response += chunk_content
except (json.JSONDecodeError, KeyError, IndexError) as e:
logger.info("json parse exception, content: {}".format(match.group(1).strip()))
pass

buffer = buffer[last_idx:]

def _invoke(
self,
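For reference, the approach this revert restores buffers the decoded stream as text, extracts every complete `data: {...}\n\n` event with a regex, and keeps any trailing partial event in the buffer for the next chunk. Below is a minimal, self-contained sketch of that pattern, not the PR's code: the function name `iter_sse_deltas`, the sample payloads, and the use of `re.MULTILINE` (so every complete event in the buffer is consumed in one pass) are illustrative assumptions.

import json
import re
from collections.abc import Iterable, Iterator


def iter_sse_deltas(chunks: Iterable[bytes]) -> Iterator[str]:
    """Yield the delta content of each complete `data:` event in a chunked byte stream."""
    buffer = ""
    for chunk in chunks:
        buffer += chunk.decode("utf-8")
        last_idx = 0
        for match in re.finditer(r"^data:\s*(.+?)\n\n", buffer, re.MULTILINE):
            try:
                data = json.loads(match.group(1).strip())
                last_idx = match.span()[1]
                delta = data["choices"][0]["delta"]
                if "content" in delta:
                    yield delta["content"]
            except (json.JSONDecodeError, KeyError, IndexError):
                continue  # skip events that do not parse
        # drop everything fully consumed; a partial event stays for the next chunk
        buffer = buffer[last_idx:]


# Two events split across a chunk boundary still parse once the second chunk arrives.
sample_chunks = [
    b'data: {"choices": [{"delta": {"content": "Hel"}, "finish_reason": null}]}\n\ndata: {"choi',
    b'ces": [{"delta": {"content": "lo"}, "finish_reason": "stop"}]}\n\n',
]
print("".join(iter_sse_deltas(sample_chunks)))  # -> Hello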