Show warning message when last user input gets pruned #4816

Open · wants to merge 1 commit into base: main
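In short, this PR has the webview compile (and, when necessary, prune) the chat history through the core before streaming, and shows a warning when the most recent user message had to be truncated. Below is a condensed TypeScript sketch of that flow, drawn from the diff that follows; `ideMessenger`, `dispatch`, `defaultModel`, and `streamAborter` are assumed to be in scope as they are in streamNormalInput.ts, and tool options are omitted.

// Sketch only; condensed from gui/src/redux/thunks/streamNormalInput.ts below.
const res = await ideMessenger.request("llm/compileChat", {
  title: defaultModel.title,
  messages,
  options: {},
});
if (res.status === "error") {
  throw new Error(res.error);
}

const { compiledChatMessages, lastMessageTruncated } = res.content;
if (lastMessageTruncated) {
  // Surface the truncation instead of silently dropping context.
  dispatch(
    setWarningMessage(
      "The provided context items are too large. They have been truncated to fit within the model's context length.",
    ),
  );
}

// Stream with the already-compiled history. `precompiled: true` tells
// BaseLLM.streamChat to skip compiling (and pruning) a second time.
const gen = ideMessenger.llmStreamChat(
  {
    completionOptions: {},
    title: defaultModel.title,
    messages: compiledChatMessages,
    messageOptions: { precompiled: true },
  },
  streamAborter.signal,
);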
12 changes: 12 additions & 0 deletions core/core.ts
@@ -474,6 +474,18 @@ export class Core {
}
});

on("llm/compileChat", async (msg) => {
const { title: modelName, messages, options } = msg.data;
const model = await this.configHandler.llmFromTitle(modelName);

const { compiledChatMessages, lastMessageTruncated } =
model.compileChatMessages(options, messages);
return {
compiledChatMessages,
lastMessageTruncated,
};
});

on("llm/streamChat", (msg) =>
llmStreamChat(
this.configHandler,
10 changes: 10 additions & 0 deletions core/index.d.ts
@@ -132,6 +132,7 @@ export interface ILLM extends LLMOptions {
messages: ChatMessage[],
signal: AbortSignal,
options?: LLMFullCompletionOptions,
messageOptions?: MessageOptions,
): AsyncGenerator<ChatMessage, PromptLog>;

chat(
@@ -140,6 +141,11 @@ export interface ILLM extends LLMOptions {
options?: LLMFullCompletionOptions,
): Promise<ChatMessage>;

compileChatMessages(
options: LLMFullCompletionOptions,
messages: ChatMessage[],
): { compiledChatMessages: ChatMessage[]; lastMessageTruncated: boolean };

embed(chunks: string[]): Promise<number[][]>;

rerank(query: string, chunks: Chunk[]): Promise<number[]>;
@@ -1388,3 +1394,7 @@ export interface TerminalOptions {
reuseTerminal?: boolean;
terminalName?: string;
}

export interface MessageOptions {
precompiled: boolean;
}
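The new `MessageOptions` type exists so that messages already compiled via `llm/compileChat` are not compiled (and pruned) a second time when streaming. A minimal sketch of the intended pairing, per the `BaseLLM.streamChat` change in core/llm/index.ts below; `model`, `signal`, and `completionOptions` are placeholder bindings for illustration:

// Compile once, inspect the truncation flag, then stream the result as-is.
const { compiledChatMessages, lastMessageTruncated } =
  model.compileChatMessages(completionOptions, messages);

const gen = model.streamChat(
  compiledChatMessages,
  signal,
  completionOptions,
  { precompiled: true }, // skip re-compilation inside streamChat
);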
14 changes: 9 additions & 5 deletions core/llm/countTokens.ts
@@ -251,13 +251,15 @@ function pruneChatHistory(
chatHistory: ChatMessage[],
contextLength: number,
tokensForCompletion: number,
): ChatMessage[] {
): { prunedChatHistory: ChatMessage[]; lastMessageTruncated: boolean } {
let totalTokens =
tokensForCompletion +
chatHistory.reduce((acc, message) => {
return acc + countChatMessageTokens(modelName, message);
}, 0);

let lastMessageTruncated = false;

// 0. Prune any messages that take up more than 1/3 of the context length
const zippedMessagesAndTokens: [ChatMessage, number][] = [];

@@ -363,8 +365,10 @@ function pruneChatHistory(
tokensForCompletion,
);
totalTokens = contextLength;
lastMessageTruncated = true;
}
return chatHistory;

return { prunedChatHistory: chatHistory, lastMessageTruncated };
}

function messageIsEmpty(message: ChatMessage): boolean {
@@ -527,7 +531,7 @@ function compileChatMessages({
functions: any[] | undefined;
systemMessage: string | undefined;
rules: Rule[];
}): ChatMessage[] {
}): { compiledChatMessages: ChatMessage[]; lastMessageTruncated: boolean } {
let msgsCopy = msgs
? msgs
.map((msg) => ({ ...msg }))
@@ -581,7 +585,7 @@ function compileChatMessages({
}
}

const history = pruneChatHistory(
const { prunedChatHistory: history, lastMessageTruncated } = pruneChatHistory(
modelName,
msgsCopy,
contextLength,
@@ -595,7 +599,7 @@ function compileChatMessages({

const flattenedHistory = flattenMessages(history);

return flattenedHistory;
return { compiledChatMessages: flattenedHistory, lastMessageTruncated };
}

export {
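pruneChatHistory now reports whether the most recent message itself had to be cut down, rather than returning only the pruned list. A minimal sketch of the new contract; argument names follow the call site above, and the trailing `tokensForCompletion` argument is an assumption based on the function signature:

const { prunedChatHistory, lastMessageTruncated } = pruneChatHistory(
  modelName,
  msgsCopy,
  contextLength,
  tokensForCompletion,
);
// Per the diff, lastMessageTruncated is set only in the branch where the
// final message is truncated in place to fit the remaining context.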
22 changes: 21 additions & 1 deletion core/llm/index.ts
@@ -18,6 +18,7 @@ import {
ILLM,
LLMFullCompletionOptions,
LLMOptions,
MessageOptions,
ModelCapability,
ModelInstaller,
PromptLog,
@@ -738,6 +739,15 @@ export abstract class BaseLLM implements ILLM {
return { role: "assistant" as const, content: completion };
}

compileChatMessages(
options: LLMFullCompletionOptions,
messages: ChatMessage[],
) {
let { completionOptions } = this._parseCompletionOptions(options);
completionOptions = this._modifyCompletionOptions(completionOptions);
return this._compileChatMessages(completionOptions, messages);
}

protected modifyChatBody(
body: ChatCompletionCreateParams,
): ChatCompletionCreateParams {
@@ -762,13 +772,23 @@
_messages: ChatMessage[],
signal: AbortSignal,
options: LLMFullCompletionOptions = {},
messageOptions?: MessageOptions,
): AsyncGenerator<ChatMessage, PromptLog> {
let { completionOptions, logEnabled } =
this._parseCompletionOptions(options);

completionOptions = this._modifyCompletionOptions(completionOptions);

const messages = this._compileChatMessages(completionOptions, _messages);
const { precompiled } = messageOptions ?? {};

let messages = _messages;
if (!precompiled) {
const { compiledChatMessages } = this._compileChatMessages(
completionOptions,
_messages,
);
messages = compiledChatMessages;
}

const prompt = this.templateMessages
? this.templateMessages(messages)
10 changes: 8 additions & 2 deletions core/llm/streamChat.ts
@@ -23,8 +23,13 @@ export async function* llmStreamChat(
void TTS.kill();
}

const { title, legacySlashCommandData, completionOptions, messages } =
msg.data;
const {
title,
legacySlashCommandData,
completionOptions,
messages,
messageOptions,
} = msg.data;

const model = await configHandler.llmFromTitle(title);

@@ -113,6 +118,7 @@
messages,
new AbortController().signal,
completionOptions,
messageOptions,
);
let next = await gen.next();
while (!next.done) {
10 changes: 10 additions & 0 deletions core/protocol/core.ts
@@ -20,6 +20,7 @@ import type {
FileSymbolMap,
IdeSettings,
LLMFullCompletionOptions,
MessageOptions,
ModelDescription,
PromptLog,
RangeInFile,
@@ -115,12 +116,21 @@ export type ToCoreFromIdeOrWebviewProtocol = {
},
string,
];
"llm/compileChat": [
{
title: string;
messages: ChatMessage[];
options: LLMFullCompletionOptions;
},
{ compiledChatMessages: ChatMessage[]; lastMessageTruncated: boolean },
];
"llm/listModels": [{ title: string }, string[] | undefined];
"llm/streamChat": [
{
messages: ChatMessage[];
completionOptions: LLMFullCompletionOptions;
title: string;
messageOptions: MessageOptions;
legacySlashCommandData?: {
command: SlashCommandDescription;
input: string;
1 change: 1 addition & 0 deletions core/protocol/passThrough.ts
@@ -35,6 +35,7 @@ export const WEBVIEW_TO_CORE_PASS_THROUGH: (keyof ToCoreFromWebviewProtocol)[] =
"autocomplete/cancel",
"autocomplete/accept",
"tts/kill",
"llm/compileChat",
"llm/complete",
"llm/streamChat",
"llm/listModels",
@@ -101,6 +101,7 @@ class MessageTypes {
"autocomplete/cancel",
"autocomplete/accept",
"tts/kill",
"llm/compileChat",
"llm/complete",
"llm/streamChat",
"llm/listModels",
10 changes: 10 additions & 0 deletions gui/src/pages/gui/Chat.tsx
@@ -132,6 +132,9 @@ export function Chat() {
const hasDismissedExploreDialog = useAppSelector(
(state) => state.ui.hasDismissedExploreDialog,
);
const warningMessage = useAppSelector(
(state) => state.session.warningMessage,
);

useEffect(() => {
// Cmd + Backspace to delete current step
@@ -394,6 +397,13 @@
</ErrorBoundary>
</div>
))}
{warningMessage.length > 0 && (
<div className="relative m-2 flex justify-center rounded-md border border-solid border-red-600 bg-transparent p-4">
<p className="thread-message text-red-500">
{`Warning: ${warningMessage}`}
</p>
</div>
)}
</StepsDiv>
<div className={"relative"}>
{toolCallState?.status === "generated" && <ToolCallButtons />}
7 changes: 7 additions & 0 deletions gui/src/redux/slices/sessionSlice.ts
@@ -56,6 +56,7 @@ type SessionState = {
};
activeToolStreamId?: [string, string];
newestToolbarPreviewForInput: Record<string, string>;
warningMessage: string;
};

function isCodeToEditEqual(a: CodeToEdit, b: CodeToEdit) {
@@ -94,6 +95,7 @@ const initialState: SessionState = {
},
lastSessionId: undefined,
newestToolbarPreviewForInput: {},
warningMessage: "",
};

export const sessionSlice = createSlice({
@@ -234,6 +236,7 @@ export const sessionSlice = createSlice({
deleteMessage: (state, action: PayloadAction<number>) => {
// Deletes the current assistant message and the previous user message
state.history.splice(action.payload - 1, 2);
state.warningMessage = "";
},
updateHistoryItemAtIndex: (
state,
@@ -679,6 +682,9 @@ export const sessionSlice = createSlice({
state.newestToolbarPreviewForInput[payload.inputId] =
payload.contextItemId;
},
setWarningMessage: (state, action: PayloadAction<string>) => {
state.warningMessage = action.payload;
},
},
selectors: {
selectIsGatheringContext: (state) => {
@@ -780,6 +786,7 @@
deleteSessionMetadata,
setNewestToolbarPreviewForInput,
cycleMode,
setWarningMessage,
} = sessionSlice.actions;

export const {
34 changes: 32 additions & 2 deletions gui/src/redux/thunks/streamNormalInput.ts
@@ -9,7 +9,8 @@ import {
addPromptCompletionPair,
selectUseTools,
setToolGenerated,
streamUpdate
setWarningMessage,
streamUpdate,
} from "../slices/sessionSlice";
import { ThunkApiType } from "../store";
import { callTool } from "./callTool";
@@ -39,6 +40,34 @@ export const streamNormalInput = createAsyncThunk<
}
const includeTools = useTools && modelSupportsTools(defaultModel);

const res = await extra.ideMessenger.request("llm/compileChat", {
title: defaultModel.title,
messages,
options: includeTools
? {
tools: state.config.config.tools.filter(
(tool) =>
toolSettings[tool.function.name] !== "disabled" &&
toolGroupSettings[tool.group] !== "exclude",
),
}
: {},
});

if (res.status === "error") {
throw new Error(res.error);
}

const { compiledChatMessages, lastMessageTruncated } = res.content;

if (lastMessageTruncated) {
dispatch(
setWarningMessage(
"The provided context items are too large. They have been truncated to fit within the model's context length.",
),
);
}

// Send request
const gen = extra.ideMessenger.llmStreamChat(
{
@@ -52,8 +81,9 @@
}
: {},
title: defaultModel.title,
messages,
messages: compiledChatMessages,
legacySlashCommandData,
messageOptions: { precompiled: true },
},
streamAborter.signal,
);
Expand Down