diff --git a/core/core.ts b/core/core.ts
index f948491fc1..8f50d90e7f 100644
--- a/core/core.ts
+++ b/core/core.ts
@@ -474,6 +474,18 @@ export class Core {
       }
     });
 
+    on("llm/compileChat", async (msg) => {
+      const { title: modelName, messages, options } = msg.data;
+      const model = await this.configHandler.llmFromTitle(modelName);
+
+      const { compiledChatMessages, lastMessageTruncated } =
+        model.compileChatMessages(options, messages);
+      return {
+        compiledChatMessages,
+        lastMessageTruncated,
+      };
+    });
+
     on("llm/streamChat", (msg) =>
       llmStreamChat(
         this.configHandler,
diff --git a/core/index.d.ts b/core/index.d.ts
index dd96d16995..48cefeb6ca 100644
--- a/core/index.d.ts
+++ b/core/index.d.ts
@@ -132,6 +132,7 @@ export interface ILLM extends LLMOptions {
     messages: ChatMessage[],
     signal: AbortSignal,
     options?: LLMFullCompletionOptions,
+    messageOptions?: MessageOptions,
   ): AsyncGenerator<ChatMessage, PromptLog>;
 
   chat(
@@ -140,6 +141,11 @@
     options?: LLMFullCompletionOptions,
   ): Promise<ChatMessage>;
 
+  compileChatMessages(
+    options: LLMFullCompletionOptions,
+    messages: ChatMessage[],
+  ): { compiledChatMessages: ChatMessage[]; lastMessageTruncated: boolean };
+
   embed(chunks: string[]): Promise<number[][]>;
 
   rerank(query: string, chunks: Chunk[]): Promise<number[]>;
@@ -1388,3 +1394,7 @@ export interface TerminalOptions {
   reuseTerminal?: boolean;
   terminalName?: string;
 }
+
+export interface MessageOptions {
+  precompiled: boolean;
+}
diff --git a/core/llm/countTokens.ts b/core/llm/countTokens.ts
index 2fc01f9b8b..ac1978fff0 100644
--- a/core/llm/countTokens.ts
+++ b/core/llm/countTokens.ts
@@ -251,13 +251,15 @@ function pruneChatHistory(
   chatHistory: ChatMessage[],
   contextLength: number,
   tokensForCompletion: number,
-): ChatMessage[] {
+): { prunedChatHistory: ChatMessage[]; lastMessageTruncated: boolean } {
   let totalTokens =
     tokensForCompletion +
     chatHistory.reduce((acc, message) => {
       return acc + countChatMessageTokens(modelName, message);
     }, 0);
 
+  let lastMessageTruncated = false;
+
   // 0. Prune any messages that take up more than 1/3 of the context length
 
   const zippedMessagesAndTokens: [ChatMessage, number][] = [];
@@ -363,8 +365,10 @@ function pruneChatHistory(
       tokensForCompletion,
     );
     totalTokens = contextLength;
+    lastMessageTruncated = true;
   }
-  return chatHistory;
+
+  return { prunedChatHistory: chatHistory, lastMessageTruncated };
 }
 
 function messageIsEmpty(message: ChatMessage): boolean {
@@ -527,7 +531,7 @@ function compileChatMessages({
   functions: any[] | undefined;
   systemMessage: string | undefined;
   rules: Rule[];
-}): ChatMessage[] {
+}): { compiledChatMessages: ChatMessage[]; lastMessageTruncated: boolean } {
   let msgsCopy = msgs
     ? msgs
         .map((msg) => ({ ...msg }))
@@ -581,7 +585,7 @@
     }
   }
 
-  const history = pruneChatHistory(
+  const { prunedChatHistory: history, lastMessageTruncated } = pruneChatHistory(
     modelName,
     msgsCopy,
     contextLength,
@@ -595,7 +599,7 @@
 
   const flattenedHistory = flattenMessages(history);
 
-  return flattenedHistory;
+  return { compiledChatMessages: flattenedHistory, lastMessageTruncated };
 }
 
 export {
diff --git a/core/llm/index.ts b/core/llm/index.ts
index cc7d7123cf..c7aaaf59b8 100644
--- a/core/llm/index.ts
+++ b/core/llm/index.ts
@@ -18,6 +18,7 @@ import {
   ILLM,
   LLMFullCompletionOptions,
   LLMOptions,
+  MessageOptions,
   ModelCapability,
   ModelInstaller,
   PromptLog,
@@ -738,6 +739,15 @@ export abstract class BaseLLM implements ILLM {
     return { role: "assistant" as const, content: completion };
   }
 
+  compileChatMessages(
+    options: LLMFullCompletionOptions,
+    messages: ChatMessage[],
+  ) {
+    let { completionOptions } = this._parseCompletionOptions(options);
+    completionOptions = this._modifyCompletionOptions(completionOptions);
+    return this._compileChatMessages(completionOptions, messages);
+  }
+
   protected modifyChatBody(
     body: ChatCompletionCreateParams,
   ): ChatCompletionCreateParams {
@@ -762,13 +772,23 @@ export abstract class BaseLLM implements ILLM {
     _messages: ChatMessage[],
     signal: AbortSignal,
     options: LLMFullCompletionOptions = {},
+    messageOptions?: MessageOptions,
   ): AsyncGenerator<ChatMessage, PromptLog> {
     let { completionOptions, logEnabled } =
       this._parseCompletionOptions(options);
 
     completionOptions = this._modifyCompletionOptions(completionOptions);
 
-    const messages = this._compileChatMessages(completionOptions, _messages);
+    const { precompiled } = messageOptions ?? {};
+
+    let messages = _messages;
+    if (!precompiled) {
+      const { compiledChatMessages } = this._compileChatMessages(
+        completionOptions,
+        _messages,
+      );
+      messages = compiledChatMessages;
+    }
 
     const prompt = this.templateMessages
       ? this.templateMessages(messages)
diff --git a/core/llm/streamChat.ts b/core/llm/streamChat.ts
index 9b1ea87b3a..17332ce1dc 100644
--- a/core/llm/streamChat.ts
+++ b/core/llm/streamChat.ts
@@ -23,8 +23,13 @@ export async function* llmStreamChat(
     void TTS.kill();
   }
 
-  const { title, legacySlashCommandData, completionOptions, messages } =
-    msg.data;
+  const {
+    title,
+    legacySlashCommandData,
+    completionOptions,
+    messages,
+    messageOptions,
+  } = msg.data;
 
   const model = await configHandler.llmFromTitle(title);
@@ -113,6 +118,7 @@ export async function* llmStreamChat(
       messages,
       new AbortController().signal,
       completionOptions,
+      messageOptions,
     );
     let next = await gen.next();
     while (!next.done) {
diff --git a/core/protocol/core.ts b/core/protocol/core.ts
index 0213cd0e55..f043942e15 100644
--- a/core/protocol/core.ts
+++ b/core/protocol/core.ts
@@ -20,6 +20,7 @@ import type {
   FileSymbolMap,
   IdeSettings,
   LLMFullCompletionOptions,
+  MessageOptions,
   ModelDescription,
   PromptLog,
   RangeInFile,
@@ -115,12 +116,21 @@
     },
     string,
   ];
+  "llm/compileChat": [
+    {
+      title: string;
+      messages: ChatMessage[];
+      options: LLMFullCompletionOptions;
+    },
+    { compiledChatMessages: ChatMessage[]; lastMessageTruncated: boolean },
+  ];
   "llm/listModels": [{ title: string }, string[] | undefined];
   "llm/streamChat": [
     {
       messages: ChatMessage[];
       completionOptions: LLMFullCompletionOptions;
       title: string;
+      messageOptions: MessageOptions;
       legacySlashCommandData?: {
         command: SlashCommandDescription;
         input: string;
diff --git a/core/protocol/passThrough.ts b/core/protocol/passThrough.ts
index 272a6d148b..cd22784156 100644
--- a/core/protocol/passThrough.ts
+++ b/core/protocol/passThrough.ts
@@ -35,6 +35,7 @@ export const WEBVIEW_TO_CORE_PASS_THROUGH: (keyof ToCoreFromWebviewProtocol)[] =
     "autocomplete/cancel",
     "autocomplete/accept",
     "tts/kill",
+    "llm/compileChat",
     "llm/complete",
     "llm/streamChat",
     "llm/listModels",
diff --git a/extensions/intellij/src/main/kotlin/com/github/continuedev/continueintellijextension/constants/MessageTypes.kt b/extensions/intellij/src/main/kotlin/com/github/continuedev/continueintellijextension/constants/MessageTypes.kt
index 60d2d3b338..083dbf9c8a 100644
--- a/extensions/intellij/src/main/kotlin/com/github/continuedev/continueintellijextension/constants/MessageTypes.kt
+++ b/extensions/intellij/src/main/kotlin/com/github/continuedev/continueintellijextension/constants/MessageTypes.kt
@@ -101,6 +101,7 @@ class MessageTypes {
             "autocomplete/cancel",
             "autocomplete/accept",
             "tts/kill",
+            "llm/compileChat",
             "llm/complete",
             "llm/streamChat",
             "llm/listModels",
diff --git a/gui/src/pages/gui/Chat.tsx b/gui/src/pages/gui/Chat.tsx
index 682671cbb1..06a3aa4b45 100644
--- a/gui/src/pages/gui/Chat.tsx
+++ b/gui/src/pages/gui/Chat.tsx
@@ -132,6 +132,9 @@ export function Chat() {
   const hasDismissedExploreDialog = useAppSelector(
     (state) => state.ui.hasDismissedExploreDialog,
   );
+  const warningMessage = useAppSelector(
+    (state) => state.session.warningMessage,
+  );
 
   useEffect(() => {
     // Cmd + Backspace to delete current step
@@ -394,6 +397,13 @@ export function Chat() {
           ))}
+          {warningMessage.length > 0 && (
+            <div>
+              <span>
+                {`Warning: ${warningMessage}`}
+              </span>
+            </div>
+          )}
 
           {toolCallState?.status === "generated" && }
diff --git a/gui/src/redux/slices/sessionSlice.ts b/gui/src/redux/slices/sessionSlice.ts
index 0906fdd451..c426855131 100644
--- a/gui/src/redux/slices/sessionSlice.ts
+++ b/gui/src/redux/slices/sessionSlice.ts
@@ -56,6 +56,7 @@ type SessionState = {
   };
   activeToolStreamId?: [string, string];
   newestToolbarPreviewForInput: Record<string, string>;
+  warningMessage: string;
 };
 
 function isCodeToEditEqual(a: CodeToEdit, b: CodeToEdit) {
@@ -94,6 +95,7 @@ const initialState: SessionState = {
   },
   lastSessionId: undefined,
   newestToolbarPreviewForInput: {},
+  warningMessage: "",
 };
 
 export const sessionSlice = createSlice({
@@ -234,6 +236,7 @@ export const sessionSlice = createSlice({
     deleteMessage: (state, action: PayloadAction<number>) => {
       // Deletes the current assistant message and the previous user message
       state.history.splice(action.payload - 1, 2);
+      state.warningMessage = "";
     },
     updateHistoryItemAtIndex: (
      state,
@@ -679,6 +682,9 @@
       state.newestToolbarPreviewForInput[payload.inputId] =
         payload.contextItemId;
     },
+    setWarningMessage: (state, action: PayloadAction<string>) => {
+      state.warningMessage = action.payload;
+    },
   },
   selectors: {
     selectIsGatheringContext: (state) => {
@@ -780,6 +786,7 @@
   deleteSessionMetadata,
   setNewestToolbarPreviewForInput,
   cycleMode,
+  setWarningMessage,
 } = sessionSlice.actions;
 
 export const {
diff --git a/gui/src/redux/thunks/streamNormalInput.ts b/gui/src/redux/thunks/streamNormalInput.ts
index afc80fbe46..05e9a9ac7b 100644
--- a/gui/src/redux/thunks/streamNormalInput.ts
+++ b/gui/src/redux/thunks/streamNormalInput.ts
@@ -9,7 +9,8 @@ import {
   addPromptCompletionPair,
   selectUseTools,
   setToolGenerated,
-  streamUpdate
+  setWarningMessage,
+  streamUpdate,
 } from "../slices/sessionSlice";
 import { ThunkApiType } from "../store";
 import { callTool } from "./callTool";
@@ -39,6 +40,34 @@ export const streamNormalInput = createAsyncThunk<
   }
 
   const includeTools = useTools && modelSupportsTools(defaultModel);
+  const res = await extra.ideMessenger.request("llm/compileChat", {
+    title: defaultModel.title,
+    messages,
+    options: includeTools
+      ? {
+          tools: state.config.config.tools.filter(
+            (tool) =>
+              toolSettings[tool.function.name] !== "disabled" &&
+              toolGroupSettings[tool.group] !== "exclude",
+          ),
+        }
+      : {},
+  });
+
+  if (res.status === "error") {
+    throw new Error(res.error);
+  }
+
+  const { compiledChatMessages, lastMessageTruncated } = res.content;
+
+  if (lastMessageTruncated) {
+    dispatch(
+      setWarningMessage(
+        "The provided context items are too large. They have been truncated to fit within the model's context length.",
+      ),
+    );
+  }
+
   // Send request
   const gen = extra.ideMessenger.llmStreamChat(
     {
@@ -52,8 +81,9 @@
           }
         : {},
       title: defaultModel.title,
-      messages,
+      messages: compiledChatMessages,
      legacySlashCommandData,
+      messageOptions: { precompiled: true },
     },
     streamAborter.signal,
   );
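
For reference, below is a minimal sketch (not part of the patch) of how a webview consumer might drive the new llm/compileChat round trip before streaming, mirroring the streamNormalInput thunk above. The compileThenStream and showWarning names and the loose ideMessenger typing are illustrative placeholders; the request/response shapes, the res.status/res.content handling, and the precompiled flag come from the protocol changes in this diff.

import type { ChatMessage, LLMFullCompletionOptions } from "core";

async function compileThenStream(
  ideMessenger: any, // the GUI's IdeMessenger; typed loosely for this sketch
  title: string,
  messages: ChatMessage[],
  completionOptions: LLMFullCompletionOptions,
  showWarning: (message: string) => void,
) {
  // 1. Ask the core to prune/compile the history against the model's context length.
  const res = await ideMessenger.request("llm/compileChat", {
    title,
    messages,
    options: completionOptions,
  });
  if (res.status === "error") {
    throw new Error(res.error);
  }
  const { compiledChatMessages, lastMessageTruncated } = res.content;

  // 2. Surface truncation to the user instead of silently dropping context.
  if (lastMessageTruncated) {
    showWarning(
      "The provided context items are too large. They have been truncated to fit within the model's context length.",
    );
  }

  // 3. Stream with the already-compiled messages; precompiled: true tells
  //    BaseLLM.streamChat to skip its own _compileChatMessages pass.
  return ideMessenger.llmStreamChat(
    {
      title,
      completionOptions,
      messages: compiledChatMessages,
      messageOptions: { precompiled: true },
    },
    new AbortController().signal,
  );
}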