Skip to content

Fixes #4774: Add error handling for system prompt #4863

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 63 additions & 12 deletions core/llm/llms/Gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
CompletionOptions,
LLMOptions,
MessagePart,
TextMessagePart,
ToolCallDelta,
} from "../../index.js";
import { findLast } from "../../util/findLast.js";
Expand Down Expand Up @@ -69,21 +70,71 @@ class Gemini extends BaseLLM {
}
}

public removeSystemMessage(messages: ChatMessage[]): ChatMessage[] {
// should be public for use within VertexAI
const msgs = [...messages];

if (msgs[0]?.role === "system") {
const sysMsg = msgs.shift()?.content;
// @ts-ignore
if (msgs[0]?.role === "user") {
// @ts-ignore
msgs[0].content = `System message - follow these instructions in every response: ${sysMsg}\n\n---\n\n${msgs[0].content}`;
/**
 * Removes a leading system message and folds its text into the first
 * following user message, for backends that take no separate system role.
 *
 * The input array and the message objects inside it are never mutated: the
 * merged user message is a fresh object, so calling this method repeatedly on
 * the same history does not prepend the system prefix more than once.
 *
 * Kept public so it can be reused from VertexAI.
 *
 * @param messages Array of chat messages, optionally starting with a system message
 * @returns A new array without the leading system message. If the message
 *   that followed it has role "user", that message is replaced by a copy
 *   whose content is prefixed with the system instructions. If the first
 *   message is not a system message, a shallow copy of the input is returned.
 */
public removeSystemMessage(messages: ChatMessage[]): ChatMessage[] {
  const [systemMessage, ...remainingMessages] = messages;

  // Nothing to strip: empty history or no leading system message.
  if (systemMessage?.role !== "system") {
    return [...messages];
  }

  const isTextPart = (part: MessagePart): part is TextMessagePart =>
    part.type === "text";

  // The content is inspected at runtime because callers have been seen to
  // pass shapes other than a plain string (parts array, single part object).
  const rawSystemContent: unknown = systemMessage.content;
  let systemContent = "";
  if (typeof rawSystemContent === "string") {
    systemContent = rawSystemContent;
  } else if (Array.isArray(rawSystemContent)) {
    // Only text parts contribute to the instructions; other parts are ignored.
    systemContent = (rawSystemContent as MessagePart[])
      .filter(isTextPart)
      .map((part) => part.text)
      .join(" ");
  } else if (rawSystemContent && typeof rawSystemContent === "object") {
    systemContent = (rawSystemContent as TextMessagePart).text || "";
  }

  // Without a user message directly after the system message there is
  // nothing to merge into; the system message is simply dropped.
  const firstRemaining = remainingMessages[0];
  if (firstRemaining?.role !== "user") {
    return remainingMessages;
  }

  const prefix = `System message - follow these instructions in every response: ${systemContent}\n\n---\n\n`;
  const rawUserContent: unknown = firstRemaining.content;
  let mergedContent: string | MessagePart[];

  if (typeof rawUserContent === "string") {
    mergedContent = prefix + rawUserContent;
  } else if (Array.isArray(rawUserContent)) {
    const parts = rawUserContent as MessagePart[];
    const firstTextIndex = parts.findIndex(isTextPart);
    mergedContent =
      firstTextIndex === -1
        ? // No text part to prepend to: add the instructions as their own part.
          [...parts, { type: "text", text: prefix }]
        : // Copy the array and rebuild only the first text part.
          parts.map((part, index) =>
            index === firstTextIndex && isTextPart(part)
              ? { ...part, text: prefix + part.text }
              : part,
          );
  } else if (rawUserContent && typeof rawUserContent === "object") {
    // A single part object is normalised to a one-element parts array.
    mergedContent = [
      {
        type: "text",
        text: prefix + ((rawUserContent as TextMessagePart).text || ""),
      },
    ];
  } else {
    // Missing content: leave the user message untouched.
    return remainingMessages;
  }

  // Replace the user message with a copy instead of editing it in place.
  remainingMessages[0] = { ...firstRemaining, content: mergedContent };
  return remainingMessages;
}


protected async *_streamChat(
messages: ChatMessage[],
Expand Down
Loading