@@ -22,38 +22,25 @@ pub trait ChatCompletionStream: Sync + Send {
22
22
) -> Result < ChatCompletionResponseStream , OpenAIError > ;
23
23
}
24
24
25
- #[ derive( Clone ) ]
26
- pub enum OpenAIRequestFieldEnum {
27
- PresencePenalty ,
28
- User ,
29
- }
30
-
31
25
#[ derive( Builder , Clone ) ]
32
26
pub struct ExtendedOpenAIConfig {
27
+ #[ builder( default ) ]
28
+ kind : String ,
29
+
33
30
base : OpenAIConfig ,
34
31
35
32
#[ builder( setter( into) ) ]
36
33
model_name : String ,
37
34
38
35
#[ builder( setter( into) ) ]
39
36
supported_models : Option < Vec < String > > ,
40
-
41
- #[ builder( default ) ]
42
- fields_to_remove : Vec < OpenAIRequestFieldEnum > ,
43
37
}
44
38
45
39
impl ExtendedOpenAIConfig {
46
40
pub fn builder ( ) -> ExtendedOpenAIConfigBuilder {
47
41
ExtendedOpenAIConfigBuilder :: default ( )
48
42
}
49
43
50
- pub fn mistral_fields_to_remove ( ) -> Vec < OpenAIRequestFieldEnum > {
51
- vec ! [
52
- OpenAIRequestFieldEnum :: PresencePenalty ,
53
- OpenAIRequestFieldEnum :: User ,
54
- ]
55
- }
56
-
57
44
fn process_request (
58
45
& self ,
59
46
mut request : CreateChatCompletionRequest ,
@@ -70,21 +57,33 @@ impl ExtendedOpenAIConfig {
70
57
}
71
58
}
72
59
73
- for field in & self . fields_to_remove {
74
- match field {
75
- OpenAIRequestFieldEnum :: PresencePenalty => {
76
- request. presence_penalty = None ;
77
- }
78
- OpenAIRequestFieldEnum :: User => {
79
- request. user = None ;
80
- }
60
+ match self . kind . as_str ( ) {
61
+ "mistral/chat" => {
62
+ request. presence_penalty = None ;
63
+ request. user = None ;
64
+ }
65
+ "openai/chat" => {
66
+ request = process_request_openai ( request) ;
81
67
}
68
+ _ => { }
82
69
}
83
70
84
71
request
85
72
}
86
73
}
87
74
75
+ fn process_request_openai ( request : CreateChatCompletionRequest ) -> CreateChatCompletionRequest {
76
+ let mut request = request;
77
+
78
+ // Check for specific O-series model prefixes
79
+ if request. model . starts_with ( "o1" ) || request. model . starts_with ( "o3-mini" ) {
80
+ request. presence_penalty = None ;
81
+ request. frequency_penalty = None ;
82
+ }
83
+
84
+ request
85
+ }
86
+
88
87
impl async_openai_alt:: config:: Config for ExtendedOpenAIConfig {
89
88
fn headers ( & self ) -> reqwest:: header:: HeaderMap {
90
89
self . base . headers ( )
@@ -132,13 +131,15 @@ impl ChatCompletionStream for async_openai_alt::Client<async_openai_alt::config:
132
131
& self ,
133
132
request : CreateChatCompletionRequest ,
134
133
) -> Result < CreateChatCompletionResponse , OpenAIError > {
134
+ let request = process_request_openai ( request) ;
135
135
self . chat ( ) . create ( request) . await
136
136
}
137
137
138
138
async fn chat_stream (
139
139
& self ,
140
140
request : CreateChatCompletionRequest ,
141
141
) -> Result < ChatCompletionResponseStream , OpenAIError > {
142
+ let request = process_request_openai ( request) ;
142
143
self . chat ( ) . create_stream ( request) . await
143
144
}
144
145
}
0 commit comments