Skip to content

Commit 462990b

Browse files
authored
fix(chat): process openai request based on model name, drop penalty if is reasoning models (#4049)
1 parent 2f626b9 commit 462990b

File tree

2 files changed

+26
-28
lines changed

2 files changed

+26
-28
lines changed

crates/http-api-bindings/src/chat/mod.rs

+1-4
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,10 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
3333
let mut builder = ExtendedOpenAIConfig::builder();
3434
builder
3535
.base(config)
36+
.kind(model.kind.clone())
3637
.supported_models(model.supported_models.clone())
3738
.model_name(model.model_name.as_deref().expect("Model name is required"));
3839

39-
if model.kind == "mistral/chat" {
40-
builder.fields_to_remove(ExtendedOpenAIConfig::mistral_fields_to_remove());
41-
}
42-
4340
Box::new(
4441
async_openai_alt::Client::with_config(
4542
builder.build().expect("Failed to build config"),

crates/tabby-inference/src/chat.rs

+25-24
Original file line numberDiff line numberDiff line change
@@ -22,38 +22,25 @@ pub trait ChatCompletionStream: Sync + Send {
2222
) -> Result<ChatCompletionResponseStream, OpenAIError>;
2323
}
2424

25-
#[derive(Clone)]
26-
pub enum OpenAIRequestFieldEnum {
27-
PresencePenalty,
28-
User,
29-
}
30-
3125
#[derive(Builder, Clone)]
3226
pub struct ExtendedOpenAIConfig {
27+
#[builder(default)]
28+
kind: String,
29+
3330
base: OpenAIConfig,
3431

3532
#[builder(setter(into))]
3633
model_name: String,
3734

3835
#[builder(setter(into))]
3936
supported_models: Option<Vec<String>>,
40-
41-
#[builder(default)]
42-
fields_to_remove: Vec<OpenAIRequestFieldEnum>,
4337
}
4438

4539
impl ExtendedOpenAIConfig {
4640
pub fn builder() -> ExtendedOpenAIConfigBuilder {
4741
ExtendedOpenAIConfigBuilder::default()
4842
}
4943

50-
pub fn mistral_fields_to_remove() -> Vec<OpenAIRequestFieldEnum> {
51-
vec![
52-
OpenAIRequestFieldEnum::PresencePenalty,
53-
OpenAIRequestFieldEnum::User,
54-
]
55-
}
56-
5744
fn process_request(
5845
&self,
5946
mut request: CreateChatCompletionRequest,
@@ -70,21 +57,33 @@ impl ExtendedOpenAIConfig {
7057
}
7158
}
7259

73-
for field in &self.fields_to_remove {
74-
match field {
75-
OpenAIRequestFieldEnum::PresencePenalty => {
76-
request.presence_penalty = None;
77-
}
78-
OpenAIRequestFieldEnum::User => {
79-
request.user = None;
80-
}
60+
match self.kind.as_str() {
61+
"mistral/chat" => {
62+
request.presence_penalty = None;
63+
request.user = None;
64+
}
65+
"openai/chat" => {
66+
request = process_request_openai(request);
8167
}
68+
_ => {}
8269
}
8370

8471
request
8572
}
8673
}
8774

75+
fn process_request_openai(request: CreateChatCompletionRequest) -> CreateChatCompletionRequest {
76+
let mut request = request;
77+
78+
// Check for specific O-series model prefixes
79+
if request.model.starts_with("o1") || request.model.starts_with("o3-mini") {
80+
request.presence_penalty = None;
81+
request.frequency_penalty = None;
82+
}
83+
84+
request
85+
}
86+
8887
impl async_openai_alt::config::Config for ExtendedOpenAIConfig {
8988
fn headers(&self) -> reqwest::header::HeaderMap {
9089
self.base.headers()
@@ -132,13 +131,15 @@ impl ChatCompletionStream for async_openai_alt::Client<async_openai_alt::config:
132131
&self,
133132
request: CreateChatCompletionRequest,
134133
) -> Result<CreateChatCompletionResponse, OpenAIError> {
134+
let request = process_request_openai(request);
135135
self.chat().create(request).await
136136
}
137137

138138
async fn chat_stream(
139139
&self,
140140
request: CreateChatCompletionRequest,
141141
) -> Result<ChatCompletionResponseStream, OpenAIError> {
142+
let request = process_request_openai(request);
142143
self.chat().create_stream(request).await
143144
}
144145
}

0 commit comments

Comments
 (0)