
Commit 28a5328

antimonyGu, autofix-ci[bot], wsxiaoys, and liangfung authored
feat(ui): allow selecting model in answer engine (#3304)
* llm select FE sketch
* fetch model array & select model drop down
* finish model select function & to be polished
* fix selectedModel init value and style
* fix lint
* rename DropdownMenuItems
* refine feature toggle
* chore(answer): set chat model's name by selection
* refine model select dropdown position
* [autofix.ci] apply automated fixes
* properly fill model name
* rename selectedModelName to modelName
* uplift modelName state
* clean log
* format code
* fix lint
* [autofix.ci] apply automated fixes
* [autofix.ci] apply automated fixes (attempt 2/3)
* [autofix.ci] apply automated fixes (attempt 3/3)
* feat(ui): show tool bar in search page
* fix: fix ui test
* refactor(chart-ui): persist modelName and show modelName selection in followup
* handle selectedModel not in models api
* fix: throw warning when request model is not in supported_models
* fix: fix modelInfo?.chat not supported case and check request.mode is supported in BE
* fix: using warn! to print warning
* fix: fix ui lint
* update: adjust style and check if model is valid
* [autofix.ci] apply automated fixes
* [autofix.ci] apply automated fixes (attempt 2/3)

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Meng Zhang <[email protected]>
Co-authored-by: liangfung <[email protected]>
1 parent 8e29952 commit 28a5328

File tree: 17 files changed, +446 −147 lines

Diff for: Cargo.lock (+1)

Some generated files are not rendered by default.

Diff for: crates/http-api-bindings/src/chat/mod.rs (+1)

@@ -18,6 +18,7 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
     let mut builder = ExtendedOpenAIConfig::builder();
     builder
         .base(config)
+        .supported_models(model.supported_models.clone())
         .model_name(model.model_name.as_deref().expect("Model name is required"));

     if model.kind == "openai/chat" {

Diff for: crates/tabby-inference/Cargo.toml (+1)

@@ -19,3 +19,4 @@ trie-rs = "0.1.1"
 async-openai.workspace = true
 secrecy = "0.8"
 reqwest.workspace = true
+tracing.workspace = true

Diff for: crates/tabby-inference/src/chat.rs (+15 −1)

@@ -7,6 +7,7 @@ use async_openai::{
 };
 use async_trait::async_trait;
 use derive_builder::Builder;
+use tracing::warn;

 #[async_trait]
 pub trait ChatCompletionStream: Sync + Send {
@@ -34,6 +35,9 @@ pub struct ExtendedOpenAIConfig {
     #[builder(setter(into))]
     model_name: String,

+    #[builder(setter(into))]
+    supported_models: Option<Vec<String>>,
+
     #[builder(default)]
     fields_to_remove: Vec<OpenAIRequestFieldEnum>,
 }
@@ -54,7 +58,17 @@ impl ExtendedOpenAIConfig {
         &self,
         mut request: CreateChatCompletionRequest,
     ) -> CreateChatCompletionRequest {
-        request.model = self.model_name.clone();
+        if request.model.is_empty() {
+            request.model = self.model_name.clone();
+        } else if let Some(supported_models) = &self.supported_models {
+            if !supported_models.contains(&request.model) {
+                warn!(
+                    "Warning: {} model is not supported, falling back to {}",
+                    request.model, self.model_name
+                );
+                request.model = self.model_name.clone();
+            }
+        }

         for field in &self.fields_to_remove {
             match field {
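
The change above lets a chat request name its own model: an empty request.model falls back to the configured default, a model listed in supported_models is honored, anything else logs a warning and falls back, and when no supported_models list is configured the requested model is kept as-is. A minimal standalone Rust sketch of that rule (resolve_model is a hypothetical helper for illustration, not a function in tabby-inference):

// Hypothetical helper mirroring the fallback rule shown in the diff above.
fn resolve_model(requested: &str, default_model: &str, supported: Option<&[String]>) -> String {
    if requested.is_empty() {
        // No model selected by the caller: use the configured default.
        return default_model.to_string();
    }
    match supported {
        // Requested model is not in supported_models: warn (the real code uses
        // tracing::warn!) and fall back to the default.
        Some(models) if !models.iter().any(|m| m.as_str() == requested) => {
            default_model.to_string()
        }
        // Otherwise keep the caller's selection, including when no allow-list exists.
        _ => requested.to_string(),
    }
}

fn main() {
    let supported = vec!["small-model".to_string(), "large-model".to_string()];
    assert_eq!(resolve_model("", "default", Some(supported.as_slice())), "default");
    assert_eq!(resolve_model("large-model", "default", Some(supported.as_slice())), "large-model");
    assert_eq!(resolve_model("unknown", "default", Some(supported.as_slice())), "default");
}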

Diff for: ee/tabby-schema/graphql/schema.graphql (+3 −2)

@@ -134,7 +134,7 @@ input CreateMessageInput {

 input CreateThreadAndRunInput {
   thread: CreateThreadInput!
-  options: ThreadRunOptionsInput! = {codeQuery: null, debugOptions: null, docQuery: null, generateRelevantQuestions: false}
+  options: ThreadRunOptionsInput! = {codeQuery: null, debugOptions: null, docQuery: null, generateRelevantQuestions: false, modelName: null}
 }

 input CreateThreadInput {
@@ -144,7 +144,7 @@ input CreateThreadInput {
 input CreateThreadRunInput {
   threadId: ID!
   additionalUserMessage: CreateMessageInput!
-  options: ThreadRunOptionsInput! = {codeQuery: null, debugOptions: null, docQuery: null, generateRelevantQuestions: false}
+  options: ThreadRunOptionsInput! = {codeQuery: null, debugOptions: null, docQuery: null, generateRelevantQuestions: false, modelName: null}
 }

 input CreateUserGroupInput {
@@ -216,6 +216,7 @@ input ThreadRunDebugOptionsInput {
 }

 input ThreadRunOptionsInput {
+  modelName: String = null
   docQuery: DocQueryInput = null
   codeQuery: CodeQueryInput = null
   generateRelevantQuestions: Boolean! = false

Diff for: ee/tabby-schema/src/schema/thread/inputs.rs (+3)

@@ -70,6 +70,9 @@ fn validate_code_query_input(input: &CodeQueryInput) -> Result<(), ValidationErr

 #[derive(GraphQLInputObject, Validate, Default, Clone)]
 pub struct ThreadRunOptionsInput {
+    #[graphql(default)]
+    pub model_name: Option<String>,
+
     #[validate(nested)]
     #[graphql(default)]
     pub doc_query: Option<DocQueryInput>,
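
On the API side, the model selected in the UI flows through the new optional model_name field of ThreadRunOptionsInput. A brief sketch of filling the field from Rust, assuming the struct is exposed at tabby_schema::thread (the import path and helper name are assumptions, not taken from the diff):

// Sketch only: the import path is an assumption; the remaining fields rely on
// the struct's Default derive shown above, matching the GraphQL null/false defaults.
use tabby_schema::thread::ThreadRunOptionsInput;

fn options_for(model: Option<String>) -> ThreadRunOptionsInput {
    ThreadRunOptionsInput {
        // New in this commit: the model chosen in the UI; None keeps the
        // server-side default chat model.
        model_name: model,
        ..Default::default()
    }
}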

Diff for: ee/tabby-ui/app/(home)/page.tsx (+12)

@@ -9,8 +9,10 @@ import { useQuery } from 'urql'
 import { SESSION_STORAGE_KEY } from '@/lib/constants'
 import { useHealth } from '@/lib/hooks/use-health'
 import { useMe } from '@/lib/hooks/use-me'
+import { useSelectedModel } from '@/lib/hooks/use-models'
 import { useIsChatEnabled } from '@/lib/hooks/use-server-info'
 import { useStore } from '@/lib/hooks/use-store'
+import { updateSelectedModel } from '@/lib/stores/chat-actions'
 import {
   clearHomeScrollPosition,
   setHomeScrollPosition,
@@ -47,6 +49,8 @@ function MainPanel() {
   })
   const scrollY = useStore(useScrollStore, state => state.homePage)

+  const { selectedModel, isModelLoading, models } = useSelectedModel()
+
   // Prefetch the search page
   useEffect(() => {
     router.prefetch('/search')
@@ -69,6 +73,10 @@ function MainPanel() {
     resettingScroller.current = true
   }, [])

+  const handleSelectModel = (model: string) => {
+    updateSelectedModel(model)
+  }
+
   if (!healthInfo || !data?.me) return <></>

   const onSearch = (question: string, ctx?: ThreadRunContexts) => {
@@ -138,6 +146,10 @@ function MainPanel() {
            cleanAfterSearch={false}
            contextInfo={contextInfoData?.contextInfo}
            fetchingContextInfo={fetchingContextInfo}
+           modelName={selectedModel}
+           onModelSelect={handleSelectModel}
+           isModelLoading={isModelLoading}
+           models={models}
          />
        </AnimationWrapper>
      )}

Diff for: ee/tabby-ui/app/search/components/assistant-message-section.tsx (−2)

@@ -1,7 +1,5 @@
 'use client'

-import './search.css'
-
 import { MouseEventHandler, useContext, useMemo, useState } from 'react'
 import { zodResolver } from '@hookform/resolvers/zod'
 import DOMPurify from 'dompurify'

Diff for: ee/tabby-ui/app/search/components/search.css (−3)

This file was deleted.

Diff for: ee/tabby-ui/app/search/components/search.tsx (+47 −36)

@@ -8,18 +8,48 @@ import {
   useRef,
   useState
 } from 'react'
+import Link from 'next/link'
 import { useRouter } from 'next/navigation'
+import slugify from '@sindresorhus/slugify'
+import { compact, pick, some, uniq, uniqBy } from 'lodash-es'
 import { nanoid } from 'nanoid'
+import { ImperativePanelHandle } from 'react-resizable-panels'
+import { toast } from 'sonner'
+import { useQuery } from 'urql'

 import {
   ERROR_CODE_NOT_FOUND,
   SESSION_STORAGE_KEY,
   SLUG_TITLE_MAX_LENGTH
 } from '@/lib/constants'
 import { useEnableDeveloperMode } from '@/lib/experiment-flags'
+import { graphql } from '@/lib/gql/generates'
+import {
+  CodeQueryInput,
+  ContextInfo,
+  DocQueryInput,
+  InputMaybe,
+  Maybe,
+  Message,
+  Role
+} from '@/lib/gql/generates/graphql'
+import { useCopyToClipboard } from '@/lib/hooks/use-copy-to-clipboard'
 import { useCurrentTheme } from '@/lib/hooks/use-current-theme'
+import { useDebounceValue } from '@/lib/hooks/use-debounce'
 import { useLatest } from '@/lib/hooks/use-latest'
+import { useMe } from '@/lib/hooks/use-me'
+import { useSelectedModel } from '@/lib/hooks/use-models'
+import useRouterStuff from '@/lib/hooks/use-router-stuff'
 import { useIsChatEnabled } from '@/lib/hooks/use-server-info'
+import { ExtendedCombinedError, useThreadRun } from '@/lib/hooks/use-thread-run'
+import { updateSelectedModel } from '@/lib/stores/chat-actions'
+import { clearHomeScrollPosition } from '@/lib/stores/scroll-store'
+import { useMutation } from '@/lib/tabby/gql'
+import {
+  contextInfoQuery,
+  listThreadMessages,
+  listThreads
+} from '@/lib/tabby/query'
 import {
   AttachmentCodeItem,
   AttachmentDocItem,
@@ -46,6 +76,7 @@ import {
   ResizablePanelGroup
 } from '@/components/ui/resizable'
 import { ScrollArea } from '@/components/ui/scroll-area'
+import { Separator } from '@/components/ui/separator'
 import { ButtonScrollToBottom } from '@/components/button-scroll-to-bottom'
 import { ClientOnly } from '@/components/client-only'
 import { BANNER_HEIGHT, useShowDemoBanner } from '@/components/demo-banner'
@@ -54,39 +85,6 @@ import { ThemeToggle } from '@/components/theme-toggle'
 import { MyAvatar } from '@/components/user-avatar'
 import UserPanel from '@/components/user-panel'

-import './search.css'
-
-import Link from 'next/link'
-import slugify from '@sindresorhus/slugify'
-import { compact, pick, some, uniq, uniqBy } from 'lodash-es'
-import { ImperativePanelHandle } from 'react-resizable-panels'
-import { toast } from 'sonner'
-import { useQuery } from 'urql'
-
-import { graphql } from '@/lib/gql/generates'
-import {
-  CodeQueryInput,
-  ContextInfo,
-  DocQueryInput,
-  InputMaybe,
-  Maybe,
-  Message,
-  Role
-} from '@/lib/gql/generates/graphql'
-import { useCopyToClipboard } from '@/lib/hooks/use-copy-to-clipboard'
-import { useDebounceValue } from '@/lib/hooks/use-debounce'
-import { useMe } from '@/lib/hooks/use-me'
-import useRouterStuff from '@/lib/hooks/use-router-stuff'
-import { ExtendedCombinedError, useThreadRun } from '@/lib/hooks/use-thread-run'
-import { clearHomeScrollPosition } from '@/lib/stores/scroll-store'
-import { useMutation } from '@/lib/tabby/gql'
-import {
-  contextInfoQuery,
-  listThreadMessages,
-  listThreads
-} from '@/lib/tabby/query'
-import { Separator } from '@/components/ui/separator'
-
 import { AssistantMessageSection } from './assistant-message-section'
 import { DevPanel } from './dev-panel'
 import { MessagesSkeleton } from './messages-skeleton'
@@ -319,6 +317,8 @@ export function Search() {

   const isLoadingRef = useLatest(isLoading)

+  const { selectedModel, isModelLoading, models } = useSelectedModel()
+
   const currentMessageForDev = useMemo(() => {
     return messages.find(item => item.id === messageIdForDev)
   }, [messageIdForDev, messages])
@@ -376,6 +376,7 @@ export function Search() {
     if (initialMessage) {
       sessionStorage.removeItem(SESSION_STORAGE_KEY.SEARCH_INITIAL_MSG)
       sessionStorage.removeItem(SESSION_STORAGE_KEY.SEARCH_INITIAL_CONTEXTS)
+
       setIsReady(true)
       onSubmitSearch(initialMessage, initialThreadRunContext)
       return
@@ -571,7 +572,8 @@
       {
         generateRelevantQuestions: true,
         codeQuery,
-        docQuery
+        docQuery,
+        modelName: ctx?.modelName
       }
     )
   }
@@ -638,7 +640,8 @@
       threadRunOptions: {
         generateRelevantQuestions: true,
         codeQuery,
-        docQuery
+        docQuery,
+        modelName: selectedModel
       }
     })
   }
@@ -696,6 +699,10 @@
     )
   }

+  const onModelSelect = (model: string) => {
+    updateSelectedModel(model)
+  }
+
   const hasThreadError = useMemo(() => {
     if (!isReady || fetchingThread || !threadIdFromURL) return undefined
     if (threadError || !threadData?.threads?.edges?.length) {
@@ -867,10 +874,14 @@
                onSearch={onSubmitSearch}
                className="min-h-[5rem] lg:max-w-4xl"
                placeholder="Ask a follow up question"
-               isLoading={isLoading}
                isFollowup
+               isLoading={isLoading}
                contextInfo={contextInfoData?.contextInfo}
                fetchingContextInfo={fetchingContextInfo}
+               modelName={selectedModel}
+               onModelSelect={onModelSelect}
+               isModelLoading={isModelLoading}
+               models={models}
              />
           </div>
         )}
