@@ -9,7 +9,10 @@ use axum::{
9
9
use axum_extra:: TypedHeader ;
10
10
use futures:: { Stream , StreamExt } ;
11
11
use hyper:: StatusCode ;
12
- use tabby_common:: axum:: MaybeUser ;
12
+ use tabby_common:: {
13
+ api:: event:: { Event as LoggerEvent , EventLogger } ,
14
+ axum:: MaybeUser ,
15
+ } ;
13
16
use tabby_inference:: ChatCompletionStream ;
14
17
use tracing:: { error, instrument, warn} ;
15
18
@@ -32,17 +35,23 @@ pub async fn chat_completions_utoipa(_request: Json<serde_json::Value>) -> Statu
32
35
unimplemented ! ( )
33
36
}
34
37
38
+ pub struct ChatState {
39
+ pub chat_completion : Arc < dyn ChatCompletionStream > ,
40
+ pub logger : Arc < dyn EventLogger > ,
41
+ }
42
+
35
43
#[ instrument( skip( state, request) ) ]
36
44
pub async fn chat_completions (
37
- State ( state) : State < Arc < dyn ChatCompletionStream > > ,
45
+ State ( state) : State < Arc < ChatState > > ,
38
46
TypedHeader ( MaybeUser ( user) ) : TypedHeader < MaybeUser > ,
39
47
Json ( mut request) : Json < async_openai_alt:: types:: CreateChatCompletionRequest > ,
40
48
) -> Result < Sse < impl Stream < Item = Result < Event , anyhow:: Error > > > , StatusCode > {
41
49
if let Some ( user) = user {
42
50
request. user . replace ( user) ;
43
51
}
52
+ let user = request. user . clone ( ) ;
44
53
45
- let s = match state. chat_stream ( request) . await {
54
+ let s = match state. chat_completion . chat_stream ( request) . await {
46
55
Ok ( s) => s,
47
56
Err ( err) => {
48
57
warn ! ( "Error happens during chat completion: {}" , err) ;
@@ -71,5 +80,7 @@ pub async fn chat_completions(
71
80
}
72
81
} ;
73
82
83
+ state. logger . log ( user, LoggerEvent :: ChatCompletion { } ) ;
84
+
74
85
Ok ( Sse :: new ( s) . keep_alive ( KeepAlive :: default ( ) ) )
75
86
}
0 commit comments