|
4 | 4 | CompletionOptions,
|
5 | 5 | LLMOptions,
|
6 | 6 | MessagePart,
|
| 7 | + TextMessagePart, |
7 | 8 | ToolCallDelta,
|
8 | 9 | } from "../../index.js";
|
9 | 10 | import { findLast } from "../../util/findLast.js";
|
@@ -69,20 +70,84 @@ class Gemini extends BaseLLM {
|
69 | 70 | }
|
70 | 71 | }
|
71 | 72 |
|
| 73 | + /** |
| 74 | + * Removes the system message and merges it with the next user message if present. |
| 75 | + * @param messages Array of chat messages |
| 76 | + * @returns Modified array with system message merged into user message if applicable |
| 77 | + */ |
72 | 78 | public removeSystemMessage(messages: ChatMessage[]): ChatMessage[] {
|
73 |
| - // should be public for use within VertexAI |
74 |
| - const msgs = [...messages]; |
75 |
| - |
76 |
| - if (msgs[0]?.role === "system") { |
77 |
| - const sysMsg = msgs.shift()?.content; |
78 |
| - // @ts-ignore |
79 |
| - if (msgs[0]?.role === "user") { |
80 |
| - // @ts-ignore |
81 |
| - msgs[0].content = `System message - follow these instructions in every response: ${sysMsg}\n\n---\n\n${msgs[0].content}`; |
| 79 | + // If no messages or first message isn't system, return copy of original messages |
| 80 | + if (messages.length === 0 || messages[0]?.role !== "system") { |
| 81 | + return [...messages]; |
| 82 | + } |
| 83 | + |
| 84 | + // Extract system message |
| 85 | + const systemMessage: ChatMessage = messages[0]; |
| 86 | + |
| 87 | + // Extract system content based on its type |
| 88 | + let systemContent = ""; |
| 89 | + |
| 90 | + if (typeof systemMessage.content === "string") { |
| 91 | + systemContent = systemMessage.content; |
| 92 | + } else if (Array.isArray(systemMessage.content)) { |
| 93 | + const contentArray: Array<MessagePart> = |
| 94 | + systemMessage.content as Array<MessagePart>; |
| 95 | + |
| 96 | + const concatenatedText = contentArray |
| 97 | + .filter((part): part is TextMessagePart => part.type === "text") |
| 98 | + .map((part) => part.text) |
| 99 | + .join(" "); |
| 100 | + |
| 101 | + systemContent = concatenatedText ? concatenatedText : ""; |
| 102 | + } else if ( |
| 103 | + systemMessage.content && |
| 104 | + typeof systemMessage.content === "object" |
| 105 | + ) { |
| 106 | + const typedContent = systemMessage.content as TextMessagePart; |
| 107 | + systemContent = typedContent?.text || ""; |
| 108 | + } |
| 109 | + |
| 110 | + // Create new array without the system message |
| 111 | + const remainingMessages: ChatMessage[] = messages.slice(1); |
| 112 | + |
| 113 | + // Check if there's a user message to merge with |
| 114 | + if (remainingMessages.length > 0 && remainingMessages[0].role === "user") { |
| 115 | + const userMessage: ChatMessage = remainingMessages[0]; |
| 116 | + const prefix = `System message - follow these instructions in every response: ${systemContent}\n\n---\n\n`; |
| 117 | + |
| 118 | + // Merge based on user content type |
| 119 | + if (typeof userMessage.content === "string") { |
| 120 | + userMessage.content = prefix + userMessage.content; |
| 121 | + } else if (Array.isArray(userMessage.content)) { |
| 122 | + const contentArray: Array<MessagePart> = |
| 123 | + userMessage.content as Array<MessagePart>; |
| 124 | + const textPart = contentArray.find((part) => part.type === "text") as |
| 125 | + | TextMessagePart |
| 126 | + | undefined; |
| 127 | + |
| 128 | + if (textPart) { |
| 129 | + textPart.text = prefix + textPart.text; |
| 130 | + } else { |
| 131 | + userMessage.content.push({ |
| 132 | + type: "text", |
| 133 | + text: prefix, |
| 134 | + } as TextMessagePart); |
| 135 | + } |
| 136 | + } else if ( |
| 137 | + userMessage.content && |
| 138 | + typeof userMessage.content === "object" |
| 139 | + ) { |
| 140 | + const typedContent = userMessage.content as TextMessagePart; |
| 141 | + userMessage.content = [ |
| 142 | + { |
| 143 | + type: "text", |
| 144 | + text: prefix + (typedContent.text || ""), |
| 145 | + } as TextMessagePart, |
| 146 | + ]; |
82 | 147 | }
|
83 | 148 | }
|
84 | 149 |
|
85 |
| - return msgs; |
| 150 | + return remainingMessages; |
86 | 151 | }
|
87 | 152 |
|
88 | 153 | protected async *_streamChat(
|
|
0 commit comments