From ccfca83d4a59a00a2b0d665941475c33920c89f5 Mon Sep 17 00:00:00 2001 From: "switch.li" <switch.li@dbappsecurity.com.cn> Date: Fri, 14 Feb 2025 15:27:12 +0800 Subject: [PATCH 1/4] support deepseek-r1: reasoning_content --- chat.go | 71 +++++++++++++++++++++++++++----------------------- chat_stream.go | 11 ++++---- 2 files changed, 45 insertions(+), 37 deletions(-) diff --git a/chat.go b/chat.go index 995860c40..a08766527 100644 --- a/chat.go +++ b/chat.go @@ -98,6 +98,9 @@ type ChatCompletionMessage struct { Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart + // supported by deepseek-reasoner https://api-docs.deepseek.com/ + ReasoningContent string `json:"reasoning_content,omitempty"` + // This property isn't in the official documentation, but it's in // the documentation for the official library for python: // - https://github.com/openai/openai-python/blob/main/chatml.md @@ -119,41 +122,44 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { } if len(m.MultiContent) > 0 { msg := struct { - Role string `json:"role"` - Content string `json:"-"` - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart `json:"content,omitempty"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"-"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"content,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }(m) return json.Marshal(msg) } msg := struct { - Role string `json:"role"` - Content string `json:"content,omitempty"` - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart `json:"-"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"content,omitempty"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"-"` + ReasoningContent string `json:"reasoning_content,omitempty"` + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }(m) return json.Marshal(msg) } func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { msg := struct { - Role string `json:"role"` - Content string `json:"content,omitempty"` - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"content,omitempty"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart + ReasoningContent string `json:"reasoning_content,omitempty"` + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string 
`json:"tool_call_id,omitempty"` }{} if err := json.Unmarshal(bs, &msg); err == nil { @@ -161,14 +167,15 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { return nil } multiMsg := struct { - Role string `json:"role"` - Content string - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart `json:"content"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"content"` + ReasoningContent string `json:"reasoning_content,omitempty"` + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }{} if err := json.Unmarshal(bs, &multiMsg); err != nil { return err diff --git a/chat_stream.go b/chat_stream.go index 525b4457a..06af5f928 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -6,11 +6,12 @@ import ( ) type ChatCompletionStreamChoiceDelta struct { - Content string `json:"content,omitempty"` - Role string `json:"role,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - Refusal string `json:"refusal,omitempty"` + Content string `json:"content,omitempty"` + Role string `json:"role,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + Refusal string `json:"refusal,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` } type ChatCompletionStreamChoiceLogprobs struct { From b37589b6b52ac8fc810028bbda51d8c556a60045 Mon Sep 17 00:00:00 2001 From: mel2oo <l977631253@gmail.com> Date: Wed, 23 Apr 2025 14:22:09 +0800 Subject: [PATCH 2/4] =?UTF-8?q?feat:=20GD=E3=80=81Lora?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- chat.go | 11 ++++++++ common.go | 2 ++ examples/chatbot/main.go | 54 ++++++++++++++++++++++++++++------------ 3 files changed, 51 insertions(+), 16 deletions(-) diff --git a/chat.go b/chat.go index a08766527..f0f9a7708 100644 --- a/chat.go +++ b/chat.go @@ -270,6 +270,17 @@ type ChatCompletionRequest struct { ReasoningEffort string `json:"reasoning_effort,omitempty"` // Metadata to store with the completion. 
Metadata map[string]string `json:"metadata,omitempty"` + + // reasoning for deepseek + Reasoning bool `json:"reasoning,omitempty"` + + // GD + GuidedChoice []string `json:"guided_choice,omitempty"` + GuidedRegex string `json:"guided_regex,omitempty"` + GuidedJson string `json:"guided_json,omitempty"` + + // LoraType + LoraType string `json:"lora_type,omitempty"` } type StreamOptions struct { diff --git a/common.go b/common.go index 8cc7289c0..087f1d58c 100644 --- a/common.go +++ b/common.go @@ -7,6 +7,8 @@ type Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` TotalTokens int `json:"total_tokens"` + PromptCacheHitTokens int `json:"prompt_cache_hit_tokens"` + PromptCacheMissTokens int `json:"prompt_cache_miss_tokens"` PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details"` CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details"` } diff --git a/examples/chatbot/main.go b/examples/chatbot/main.go index ad41e957d..2d776eab9 100644 --- a/examples/chatbot/main.go +++ b/examples/chatbot/main.go @@ -1,19 +1,27 @@ package main import ( - "bufio" "context" + "crypto/tls" "fmt" - "os" + "net/http" "github.com/sashabaranov/go-openai" ) func main() { - client := openai.NewClient(os.Getenv("OPENAI_API_KEY")) + config := openai.DefaultConfig("sk-xxxx") + config.BaseURL = "https://10.20.152.76:30002/v1" + config.HTTPClient = &http.Client{ + Transport: &http.Transport{ + DisableKeepAlives: true, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + } + client := openai.NewClientWithConfig(config) req := openai.ChatCompletionRequest{ - Model: openai.GPT3Dot5Turbo, + Model: "HengNao-v4", Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleSystem, @@ -24,19 +32,33 @@ func main() { fmt.Println("Conversation") fmt.Println("---------------------") fmt.Print("> ") - s := bufio.NewScanner(os.Stdin) - for s.Scan() { - req.Messages = append(req.Messages, openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleUser, - Content: s.Text(), - }) - resp, err := client.CreateChatCompletion(context.Background(), req) + // s := bufio.NewScanner(os.Stdin) + // for s.Scan() { + req.Messages = append(req.Messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleUser, + Content: "你好", + }) + resp, err := client.CreateChatCompletion(context.Background(), req) + if err != nil { + fmt.Printf("ChatCompletion error: %v\n", err) + // continue + } + fmt.Printf("%s\n\n", resp.Choices[0].Message.Content) + req.Messages = append(req.Messages, resp.Choices[0].Message) + fmt.Print("> ") + // } + + stream, err := client.CreateChatCompletionStream(context.Background(), req) + if err != nil { + return + } + + for { + evt, err := stream.Recv() if err != nil { - fmt.Printf("ChatCompletion error: %v\n", err) - continue + return } - fmt.Printf("%s\n\n", resp.Choices[0].Message.Content) - req.Messages = append(req.Messages, resp.Choices[0].Message) - fmt.Print("> ") + + fmt.Printf("%s", evt.Choices[0].Delta.Content) } } From 4e1f139b6f0b7a1effe2e831f1b464ac323924f1 Mon Sep 17 00:00:00 2001 From: mel2oo <l977631253@gmail.com> Date: Thu, 24 Apr 2025 09:48:59 +0800 Subject: [PATCH 3/4] feat: support hugging face model --- chat.go | 7 ++++++- chat_stream.go | 3 ++- client.go | 6 ++++++ completion.go | 12 +++++++++++- stream.go | 3 ++- 5 files changed, 27 insertions(+), 4 deletions(-) diff --git a/chat.go b/chat.go index f0f9a7708..9b5820864 100644 --- a/chat.go +++ b/chat.go @@ -281,6 +281,10 @@ type 
ChatCompletionRequest struct { // LoraType LoraType string `json:"lora_type,omitempty"` + + // For Hugging Face models + TopK int `json:"top_k,omitempty"` + RepetitionPenalty float32 `json:"repetition_penalty,omitempty"` } type StreamOptions struct { @@ -399,6 +403,7 @@ type ChatCompletionResponse struct { func (c *Client) CreateChatCompletion( ctx context.Context, request ChatCompletionRequest, + opts ...requestOption, ) (response ChatCompletionResponse, err error) { if request.Stream { err = ErrChatCompletionStreamNotSupported @@ -420,7 +425,7 @@ func (c *Client) CreateChatCompletion( ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), - withBody(request), + append(opts, withBody(request))..., ) if err != nil { return diff --git a/chat_stream.go b/chat_stream.go index 06af5f928..6eefd4857 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -73,6 +73,7 @@ type ChatCompletionStream struct { func (c *Client) CreateChatCompletionStream( ctx context.Context, request ChatCompletionRequest, + opts ...requestOption, ) (stream *ChatCompletionStream, err error) { urlSuffix := chatCompletionsSuffix if !checkEndpointSupportsModel(urlSuffix, request.Model) { @@ -90,7 +91,7 @@ func (c *Client) CreateChatCompletionStream( ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), - withBody(request), + append(opts, withBody(request))..., ) if err != nil { return nil, err diff --git a/client.go b/client.go index cef375348..8132f6bd9 100644 --- a/client.go +++ b/client.go @@ -96,6 +96,12 @@ func withBetaAssistantVersion(version string) requestOption { } } +func WithHeader(key, value string) requestOption { + return func(ro *requestOptions) { + ro.header.Set(key, value) + } +} + func (c *Client) newRequest(ctx context.Context, method, url string, setters ...requestOption) (*http.Request, error) { // Default Options args := &requestOptions{ diff --git a/completion.go b/completion.go index 015fa2a9f..d54fe4ab0 100644 --- a/completion.go +++ b/completion.go @@ -194,6 +194,15 @@ type CompletionRequest struct { Temperature float32 `json:"temperature,omitempty"` TopP float32 `json:"top_p,omitempty"` User string `json:"user,omitempty"` + + // GD + GuidedChoice []string `json:"guided_choice,omitempty"` + GuidedRegex string `json:"guided_regex,omitempty"` + GuidedJson string `json:"guided_json,omitempty"` + + // For Hugging Face models + TopK int `json:"top_k,omitempty"` + RepetitionPenalty float32 `json:"repetition_penalty,omitempty"` } // CompletionChoice represents one of possible completions. 
@@ -232,6 +241,7 @@ type CompletionResponse struct { func (c *Client) CreateCompletion( ctx context.Context, request CompletionRequest, + opts ...requestOption, ) (response CompletionResponse, err error) { if request.Stream { err = ErrCompletionStreamNotSupported @@ -253,7 +263,7 @@ func (c *Client) CreateCompletion( ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), - withBody(request), + append(opts, withBody(request))..., ) if err != nil { return diff --git a/stream.go b/stream.go index a61c7c970..8e59b7484 100644 --- a/stream.go +++ b/stream.go @@ -21,6 +21,7 @@ type CompletionStream struct { func (c *Client) CreateCompletionStream( ctx context.Context, request CompletionRequest, + opts ...requestOption, ) (stream *CompletionStream, err error) { urlSuffix := "/completions" if !checkEndpointSupportsModel(urlSuffix, request.Model) { @@ -38,7 +39,7 @@ func (c *Client) CreateCompletionStream( ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), - withBody(request), + append(opts, withBody(request))..., ) if err != nil { return nil, err From c6c099152103f0866499dee14d7d2a812e95bffa Mon Sep 17 00:00:00 2001 From: mel2oo <l977631253@gmail.com> Date: Thu, 24 Apr 2025 10:14:19 +0800 Subject: [PATCH 4/4] feat: add http header --- chat.go | 10 +++++----- chat_stream.go | 3 +-- client.go | 10 ++++------ completion.go | 10 +++++----- config.go | 4 ++-- examples/chatbot/main.go | 6 ++++++ stream.go | 3 +-- 7 files changed, 24 insertions(+), 22 deletions(-) diff --git a/chat.go b/chat.go index 9b5820864..c50506185 100644 --- a/chat.go +++ b/chat.go @@ -275,9 +275,10 @@ type ChatCompletionRequest struct { Reasoning bool `json:"reasoning,omitempty"` // GD - GuidedChoice []string `json:"guided_choice,omitempty"` - GuidedRegex string `json:"guided_regex,omitempty"` - GuidedJson string `json:"guided_json,omitempty"` + GuidedChoice []string `json:"guided_choice,omitempty"` + GuidedRegex string `json:"guided_regex,omitempty"` + GuidedJson string `json:"guided_json,omitempty"` + GuidedGrammar string `json:"guided_grammar,omitempty"` // LoraType LoraType string `json:"lora_type,omitempty"` @@ -403,7 +404,6 @@ type ChatCompletionResponse struct { func (c *Client) CreateChatCompletion( ctx context.Context, request ChatCompletionRequest, - opts ...requestOption, ) (response ChatCompletionResponse, err error) { if request.Stream { err = ErrChatCompletionStreamNotSupported @@ -425,7 +425,7 @@ func (c *Client) CreateChatCompletion( ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), - append(opts, withBody(request))..., + withBody(request), ) if err != nil { return diff --git a/chat_stream.go b/chat_stream.go index 6eefd4857..06af5f928 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -73,7 +73,6 @@ type ChatCompletionStream struct { func (c *Client) CreateChatCompletionStream( ctx context.Context, request ChatCompletionRequest, - opts ...requestOption, ) (stream *ChatCompletionStream, err error) { urlSuffix := chatCompletionsSuffix if !checkEndpointSupportsModel(urlSuffix, request.Model) { @@ -91,7 +90,7 @@ func (c *Client) CreateChatCompletionStream( ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), - append(opts, withBody(request))..., + withBody(request), ) if err != nil { return nil, err diff --git a/client.go b/client.go index 8132f6bd9..a130163e8 100644 --- a/client.go +++ b/client.go @@ -96,12 +96,6 @@ func withBetaAssistantVersion(version string) requestOption { } } -func WithHeader(key, value string) requestOption { - return 
func(ro *requestOptions) { - ro.header.Set(key, value) - } -} - func (c *Client) newRequest(ctx context.Context, method, url string, setters ...requestOption) (*http.Request, error) { // Default Options args := &requestOptions{ @@ -206,6 +200,10 @@ func (c *Client) setCommonHeaders(req *http.Request) { if c.config.OrgID != "" { req.Header.Set("OpenAI-Organization", c.config.OrgID) } + + for k, v := range c.config.HTTPHeaderSets { + req.Header[k] = v + } } func isFailureStatusCode(resp *http.Response) bool { diff --git a/completion.go b/completion.go index d54fe4ab0..84970613e 100644 --- a/completion.go +++ b/completion.go @@ -196,9 +196,10 @@ type CompletionRequest struct { User string `json:"user,omitempty"` // GD - GuidedChoice []string `json:"guided_choice,omitempty"` - GuidedRegex string `json:"guided_regex,omitempty"` - GuidedJson string `json:"guided_json,omitempty"` + GuidedChoice []string `json:"guided_choice,omitempty"` + GuidedRegex string `json:"guided_regex,omitempty"` + GuidedJson string `json:"guided_json,omitempty"` + GuidedGrammar string `json:"guided_grammar,omitempty"` // For Hugging Face models TopK int `json:"top_k,omitempty"` @@ -241,7 +242,6 @@ type CompletionResponse struct { func (c *Client) CreateCompletion( ctx context.Context, request CompletionRequest, - opts ...requestOption, ) (response CompletionResponse, err error) { if request.Stream { err = ErrCompletionStreamNotSupported @@ -263,7 +263,7 @@ func (c *Client) CreateCompletion( ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), - append(opts, withBody(request))..., + withBody(request), ) if err != nil { return diff --git a/config.go b/config.go index 4788ba62a..cbc4fc546 100644 --- a/config.go +++ b/config.go @@ -44,8 +44,8 @@ type ClientConfig struct { AssistantVersion string AzureModelMapperFunc func(model string) string // replace model to azure deployment name func HTTPClient HTTPDoer - - EmptyMessagesLimit uint + HTTPHeaderSets http.Header + EmptyMessagesLimit uint } func DefaultConfig(authToken string) ClientConfig { diff --git a/examples/chatbot/main.go b/examples/chatbot/main.go index 2d776eab9..6154eb1f6 100644 --- a/examples/chatbot/main.go +++ b/examples/chatbot/main.go @@ -12,12 +12,18 @@ import ( func main() { config := openai.DefaultConfig("sk-xxxx") config.BaseURL = "https://10.20.152.76:30002/v1" + config.HTTPHeaderSets = http.Header{ + "123": []string{"Bearer sk-xxxx"}, + "abc": []string{"application/json"}, + "efg": []string{"application/json"}, + } config.HTTPClient = &http.Client{ Transport: &http.Transport{ DisableKeepAlives: true, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, }, } + client := openai.NewClientWithConfig(config) req := openai.ChatCompletionRequest{ diff --git a/stream.go b/stream.go index 8e59b7484..a61c7c970 100644 --- a/stream.go +++ b/stream.go @@ -21,7 +21,6 @@ type CompletionStream struct { func (c *Client) CreateCompletionStream( ctx context.Context, request CompletionRequest, - opts ...requestOption, ) (stream *CompletionStream, err error) { urlSuffix := "/completions" if !checkEndpointSupportsModel(urlSuffix, request.Model) { @@ -39,7 +38,7 @@ func (c *Client) CreateCompletionStream( ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), - append(opts, withBody(request))..., + withBody(request), ) if err != nil { return nil, err
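
Usage sketch (not part of the patches above): taken together, the series adds ReasoningContent to messages and stream deltas (patch 1), guided-decoding and LoRA request fields (patch 2), Hugging Face-style sampling knobs and per-request options (patch 3, later narrowed), and config-level extra headers via HTTPHeaderSets (patch 4). The example below assumes a build with all four patches applied; the API key, base URL, model name, header name, and prompt are placeholders, not values taken from the patches.

package main

import (
	"context"
	"fmt"
	"net/http"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	// Placeholder credentials and endpoint; substitute real values.
	config := openai.DefaultConfig("sk-placeholder")
	config.BaseURL = "https://example.internal/v1"
	// Extra headers from patch 4 (HTTPHeaderSets) are copied onto every request.
	config.HTTPHeaderSets = http.Header{
		"X-Custom-Header": []string{"example"},
	}
	client := openai.NewClientWithConfig(config)

	req := openai.ChatCompletionRequest{
		Model: "deepseek-reasoner", // any model name the target server accepts
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Is the sky blue? Answer yes or no."},
		},
		// Guided decoding fields from patch 2; only meaningful if the backend
		// (e.g. a vLLM-style server) understands them.
		GuidedChoice: []string{"yes", "no"},
		// Hugging Face-style sampling parameters from patch 3.
		TopK:              40,
		RepetitionPenalty: 1.1,
	}

	resp, err := client.CreateChatCompletion(context.Background(), req)
	if err != nil {
		fmt.Printf("ChatCompletion error: %v\n", err)
		return
	}
	if len(resp.Choices) == 0 {
		fmt.Println("no choices returned")
		return
	}
	// ReasoningContent (patch 1) is populated only by reasoning models such as
	// deepseek-reasoner; it stays empty for other models.
	fmt.Println("reasoning:", resp.Choices[0].Message.ReasoningContent)
	fmt.Println("answer:", resp.Choices[0].Message.Content)
}

For streaming, the same field is surfaced incrementally on ChatCompletionStreamChoiceDelta.ReasoningContent, so a consumer would read Delta.ReasoningContent alongside Delta.Content in its Recv loop.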