|
4 | 4 | CompletionOptions,
|
5 | 5 | LLMOptions,
|
6 | 6 | MessagePart,
|
| 7 | + TextMessagePart, |
7 | 8 | ToolCallDelta,
|
8 | 9 | } from "../../index.js";
|
9 | 10 | import { findLast } from "../../util/findLast.js";
|
@@ -69,21 +70,71 @@ class Gemini extends BaseLLM {
|
69 | 70 | }
|
70 | 71 | }
|
71 | 72 |
|
72 |
| - public removeSystemMessage(messages: ChatMessage[]): ChatMessage[] { |
73 |
| - // should be public for use within VertexAI |
74 |
| - const msgs = [...messages]; |
75 |
| - |
76 |
| - if (msgs[0]?.role === "system") { |
77 |
| - const sysMsg = msgs.shift()?.content; |
78 |
| - // @ts-ignore |
79 |
| - if (msgs[0]?.role === "user") { |
80 |
| - // @ts-ignore |
81 |
| - msgs[0].content = `System message - follow these instructions in every response: ${sysMsg}\n\n---\n\n${msgs[0].content}`; |
| 73 | +/** |
| 74 | + * Removes the system message and merges it with the next user message if present. |
| 75 | + * @param messages Array of chat messages |
| 76 | + * @returns Modified array with system message merged into user message if applicable |
| 77 | + */ |
| 78 | +public removeSystemMessage(messages: ChatMessage[]): ChatMessage[] { |
| 79 | + // If no messages or first message isn't system, return copy of original messages |
| 80 | + if (messages.length === 0 || messages[0]?.role !== "system") { |
| 81 | + return [...messages]; |
| 82 | + } |
| 83 | + |
| 84 | + // Extract system message |
| 85 | + const systemMessage : ChatMessage = messages[0]; |
| 86 | + |
| 87 | + // Extract system content based on its type |
| 88 | + let systemContent = ""; |
| 89 | + if (typeof systemMessage.content === "string") { |
| 90 | + systemContent = systemMessage.content; |
| 91 | + } else if (Array.isArray(systemMessage.content)) { |
| 92 | + const contentArray : Array<MessagePart> = systemMessage.content as Array<MessagePart>; |
| 93 | + const concatenatedText = contentArray |
| 94 | + .filter(part => part.type === "text") |
| 95 | + .map(part => part.text) |
| 96 | + .join(" "); |
| 97 | + systemContent = concatenatedText ? concatenatedText : ""; |
| 98 | + } else if (systemMessage.content && typeof systemMessage.content === "object") { |
| 99 | + const typedContent = systemMessage.content as TextMessagePart; |
| 100 | + systemContent = typedContent?.text || ""; |
| 101 | + } |
| 102 | + |
| 103 | + // Create new array without the system message |
| 104 | + const remainingMessages : ChatMessage[] = messages.slice(1); |
| 105 | + |
| 106 | + // Check if there's a user message to merge with |
| 107 | + if (remainingMessages.length > 0 && remainingMessages[0].role === "user") { |
| 108 | + const userMessage : ChatMessage = remainingMessages[0]; |
| 109 | + const prefix = `System message - follow these instructions in every response: ${systemContent}\n\n---\n\n`; |
| 110 | + |
| 111 | + // Merge based on user content type |
| 112 | + if (typeof userMessage.content === "string") { |
| 113 | + userMessage.content = prefix + userMessage.content; |
| 114 | + } else if (Array.isArray(userMessage.content)) { |
| 115 | + const contentArray : Array<MessagePart> = userMessage.content as Array<MessagePart>; |
| 116 | + const textPart = contentArray.find(part => part.type === "text") as TextMessagePart | undefined; |
| 117 | + |
| 118 | + if (textPart) { |
| 119 | + textPart.text = prefix + textPart.text; |
| 120 | + } else { |
| 121 | + userMessage.content.push({ |
| 122 | + type: "text", |
| 123 | + text: prefix |
| 124 | + } as TextMessagePart); |
82 | 125 | }
|
| 126 | + } else if (userMessage.content && typeof userMessage.content === "object") { |
| 127 | + const typedContent = userMessage.content as TextMessagePart; |
| 128 | + userMessage.content = [{ |
| 129 | + type: "text", |
| 130 | + text: prefix + (typedContent.text || "") |
| 131 | + } as TextMessagePart]; |
83 | 132 | }
|
84 |
| - |
85 |
| - return msgs; |
86 | 133 | }
|
| 134 | + |
| 135 | + return remainingMessages; |
| 136 | +} |
| 137 | + |
87 | 138 |
|
88 | 139 | protected async *_streamChat(
|
89 | 140 | messages: ChatMessage[],
|
|
0 commit comments