Skip to content

Commit 8d9341a

Browse files
Merge pull request #4863 from mdelder/feature/4714-vertexai-system-prompt-2
Fixes #4774: Add error handling for system prompt
2 parents 699d9af + bb5e8b6 commit 8d9341a

File tree

1 file changed

+75
-10
lines changed

1 file changed

+75
-10
lines changed

core/llm/llms/Gemini.ts

+75-10
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import {
44
CompletionOptions,
55
LLMOptions,
66
MessagePart,
7+
TextMessagePart,
78
ToolCallDelta,
89
} from "../../index.js";
910
import { findLast } from "../../util/findLast.js";
@@ -69,20 +70,84 @@ class Gemini extends BaseLLM {
6970
}
7071
}
7172

73+
/**
74+
* Removes the system message and merges it with the next user message if present.
75+
* @param messages Array of chat messages
76+
* @returns Modified array with system message merged into user message if applicable
77+
*/
7278
public removeSystemMessage(messages: ChatMessage[]): ChatMessage[] {
73-
// should be public for use within VertexAI
74-
const msgs = [...messages];
75-
76-
if (msgs[0]?.role === "system") {
77-
const sysMsg = msgs.shift()?.content;
78-
// @ts-ignore
79-
if (msgs[0]?.role === "user") {
80-
// @ts-ignore
81-
msgs[0].content = `System message - follow these instructions in every response: ${sysMsg}\n\n---\n\n${msgs[0].content}`;
79+
// If no messages or first message isn't system, return copy of original messages
80+
if (messages.length === 0 || messages[0]?.role !== "system") {
81+
return [...messages];
82+
}
83+
84+
// Extract system message
85+
const systemMessage: ChatMessage = messages[0];
86+
87+
// Extract system content based on its type
88+
let systemContent = "";
89+
90+
if (typeof systemMessage.content === "string") {
91+
systemContent = systemMessage.content;
92+
} else if (Array.isArray(systemMessage.content)) {
93+
const contentArray: Array<MessagePart> =
94+
systemMessage.content as Array<MessagePart>;
95+
96+
const concatenatedText = contentArray
97+
.filter((part): part is TextMessagePart => part.type === "text")
98+
.map((part) => part.text)
99+
.join(" ");
100+
101+
systemContent = concatenatedText ? concatenatedText : "";
102+
} else if (
103+
systemMessage.content &&
104+
typeof systemMessage.content === "object"
105+
) {
106+
const typedContent = systemMessage.content as TextMessagePart;
107+
systemContent = typedContent?.text || "";
108+
}
109+
110+
// Create new array without the system message
111+
const remainingMessages: ChatMessage[] = messages.slice(1);
112+
113+
// Check if there's a user message to merge with
114+
if (remainingMessages.length > 0 && remainingMessages[0].role === "user") {
115+
const userMessage: ChatMessage = remainingMessages[0];
116+
const prefix = `System message - follow these instructions in every response: ${systemContent}\n\n---\n\n`;
117+
118+
// Merge based on user content type
119+
if (typeof userMessage.content === "string") {
120+
userMessage.content = prefix + userMessage.content;
121+
} else if (Array.isArray(userMessage.content)) {
122+
const contentArray: Array<MessagePart> =
123+
userMessage.content as Array<MessagePart>;
124+
const textPart = contentArray.find((part) => part.type === "text") as
125+
| TextMessagePart
126+
| undefined;
127+
128+
if (textPart) {
129+
textPart.text = prefix + textPart.text;
130+
} else {
131+
userMessage.content.push({
132+
type: "text",
133+
text: prefix,
134+
} as TextMessagePart);
135+
}
136+
} else if (
137+
userMessage.content &&
138+
typeof userMessage.content === "object"
139+
) {
140+
const typedContent = userMessage.content as TextMessagePart;
141+
userMessage.content = [
142+
{
143+
type: "text",
144+
text: prefix + (typedContent.text || ""),
145+
} as TextMessagePart,
146+
];
82147
}
83148
}
84149

85-
return msgs;
150+
return remainingMessages;
86151
}
87152

88153
protected async *_streamChat(

0 commit comments

Comments
 (0)