diff --git a/chat.go b/chat.go
index 995860c40..c50506185 100644
--- a/chat.go
+++ b/chat.go
@@ -98,6 +98,9 @@ type ChatCompletionMessage struct {
 	Refusal      string `json:"refusal,omitempty"`
 	MultiContent []ChatMessagePart
 
+	// supported by deepseek-reasoner https://api-docs.deepseek.com/
+	ReasoningContent string `json:"reasoning_content,omitempty"`
+
 	// This property isn't in the official documentation, but it's in
 	// the documentation for the official library for python:
 	// - https://github.com/openai/openai-python/blob/main/chatml.md
@@ -119,41 +122,44 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
 	}
 	if len(m.MultiContent) > 0 {
 		msg := struct {
-			Role         string            `json:"role"`
-			Content      string            `json:"-"`
-			Refusal      string            `json:"refusal,omitempty"`
-			MultiContent []ChatMessagePart `json:"content,omitempty"`
-			Name         string            `json:"name,omitempty"`
-			FunctionCall *FunctionCall     `json:"function_call,omitempty"`
-			ToolCalls    []ToolCall        `json:"tool_calls,omitempty"`
-			ToolCallID   string            `json:"tool_call_id,omitempty"`
+			Role             string            `json:"role"`
+			Content          string            `json:"-"`
+			Refusal          string            `json:"refusal,omitempty"`
+			MultiContent     []ChatMessagePart `json:"content,omitempty"`
+			ReasoningContent string            `json:"reasoning_content,omitempty"`
+			Name             string            `json:"name,omitempty"`
+			FunctionCall     *FunctionCall     `json:"function_call,omitempty"`
+			ToolCalls        []ToolCall        `json:"tool_calls,omitempty"`
+			ToolCallID       string            `json:"tool_call_id,omitempty"`
 		}(m)
 		return json.Marshal(msg)
 	}
 
 	msg := struct {
-		Role         string            `json:"role"`
-		Content      string            `json:"content,omitempty"`
-		Refusal      string            `json:"refusal,omitempty"`
-		MultiContent []ChatMessagePart `json:"-"`
-		Name         string            `json:"name,omitempty"`
-		FunctionCall *FunctionCall     `json:"function_call,omitempty"`
-		ToolCalls    []ToolCall        `json:"tool_calls,omitempty"`
-		ToolCallID   string            `json:"tool_call_id,omitempty"`
+		Role             string            `json:"role"`
+		Content          string            `json:"content,omitempty"`
+		Refusal          string            `json:"refusal,omitempty"`
+		MultiContent     []ChatMessagePart `json:"-"`
+		ReasoningContent string            `json:"reasoning_content,omitempty"`
+		Name             string            `json:"name,omitempty"`
+		FunctionCall     *FunctionCall     `json:"function_call,omitempty"`
+		ToolCalls        []ToolCall        `json:"tool_calls,omitempty"`
+		ToolCallID       string            `json:"tool_call_id,omitempty"`
 	}(m)
 	return json.Marshal(msg)
 }
 
 func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
 	msg := struct {
-		Role         string `json:"role"`
-		Content      string `json:"content,omitempty"`
-		Refusal      string `json:"refusal,omitempty"`
-		MultiContent []ChatMessagePart
-		Name         string        `json:"name,omitempty"`
-		FunctionCall *FunctionCall `json:"function_call,omitempty"`
-		ToolCalls    []ToolCall    `json:"tool_calls,omitempty"`
-		ToolCallID   string        `json:"tool_call_id,omitempty"`
+		Role             string `json:"role"`
+		Content          string `json:"content,omitempty"`
+		Refusal          string `json:"refusal,omitempty"`
+		MultiContent     []ChatMessagePart
+		ReasoningContent string        `json:"reasoning_content,omitempty"`
+		Name             string        `json:"name,omitempty"`
+		FunctionCall     *FunctionCall `json:"function_call,omitempty"`
+		ToolCalls        []ToolCall    `json:"tool_calls,omitempty"`
+		ToolCallID       string        `json:"tool_call_id,omitempty"`
 	}{}
 
 	if err := json.Unmarshal(bs, &msg); err == nil {
@@ -161,14 +167,15 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
 		return nil
 	}
 	multiMsg := struct {
-		Role         string `json:"role"`
-		Content      string
-		Refusal      string            `json:"refusal,omitempty"`
-		MultiContent []ChatMessagePart `json:"content"`
-		Name         string            `json:"name,omitempty"`
-		FunctionCall *FunctionCall     `json:"function_call,omitempty"`
-		ToolCalls    []ToolCall        `json:"tool_calls,omitempty"`
-		ToolCallID   string            `json:"tool_call_id,omitempty"`
+		Role             string `json:"role"`
+		Content          string
+		Refusal          string            `json:"refusal,omitempty"`
+		MultiContent     []ChatMessagePart `json:"content"`
+		ReasoningContent string            `json:"reasoning_content,omitempty"`
+		Name             string            `json:"name,omitempty"`
+		FunctionCall     *FunctionCall     `json:"function_call,omitempty"`
+		ToolCalls        []ToolCall        `json:"tool_calls,omitempty"`
+		ToolCallID       string            `json:"tool_call_id,omitempty"`
 	}{}
 	if err := json.Unmarshal(bs, &multiMsg); err != nil {
 		return err
@@ -263,6 +270,22 @@ type ChatCompletionRequest struct {
 	ReasoningEffort string `json:"reasoning_effort,omitempty"`
 	// Metadata to store with the completion.
 	Metadata map[string]string `json:"metadata,omitempty"`
+
+	// Reasoning enables reasoning output for DeepSeek-style models.
+	Reasoning bool `json:"reasoning,omitempty"`
+
+	// Guided decoding parameters.
+	GuidedChoice  []string `json:"guided_choice,omitempty"`
+	GuidedRegex   string   `json:"guided_regex,omitempty"`
+	GuidedJSON    string   `json:"guided_json,omitempty"`
+	GuidedGrammar string   `json:"guided_grammar,omitempty"`
+
+	// LoRA adapter type.
+	LoraType string `json:"lora_type,omitempty"`
+
+	// Sampling parameters for Hugging Face-style models.
+	TopK              int     `json:"top_k,omitempty"`
+	RepetitionPenalty float32 `json:"repetition_penalty,omitempty"`
 }
 
 type StreamOptions struct {
diff --git a/chat_stream.go b/chat_stream.go
index 525b4457a..06af5f928 100644
--- a/chat_stream.go
+++ b/chat_stream.go
@@ -6,11 +6,12 @@ import (
 )
 
 type ChatCompletionStreamChoiceDelta struct {
-	Content      string        `json:"content,omitempty"`
-	Role         string        `json:"role,omitempty"`
-	FunctionCall *FunctionCall `json:"function_call,omitempty"`
-	ToolCalls    []ToolCall    `json:"tool_calls,omitempty"`
-	Refusal      string        `json:"refusal,omitempty"`
+	Content          string        `json:"content,omitempty"`
+	Role             string        `json:"role,omitempty"`
+	FunctionCall     *FunctionCall `json:"function_call,omitempty"`
+	ToolCalls        []ToolCall    `json:"tool_calls,omitempty"`
+	Refusal          string        `json:"refusal,omitempty"`
+	ReasoningContent string        `json:"reasoning_content,omitempty"`
 }
 
 type ChatCompletionStreamChoiceLogprobs struct {
diff --git a/client.go b/client.go
index cef375348..a130163e8 100644
--- a/client.go
+++ b/client.go
@@ -200,6 +200,10 @@ func (c *Client) setCommonHeaders(req *http.Request) {
 	if c.config.OrgID != "" {
 		req.Header.Set("OpenAI-Organization", c.config.OrgID)
 	}
+
+	for k, v := range c.config.HTTPHeaderSets {
+		req.Header[k] = v
+	}
 }
 
 func isFailureStatusCode(resp *http.Response) bool {
diff --git a/common.go b/common.go
index 8cc7289c0..087f1d58c 100644
--- a/common.go
+++ b/common.go
@@ -7,6 +7,8 @@ type Usage struct {
 	PromptTokens            int                      `json:"prompt_tokens"`
 	CompletionTokens        int                      `json:"completion_tokens"`
 	TotalTokens             int                      `json:"total_tokens"`
+	PromptCacheHitTokens    int                      `json:"prompt_cache_hit_tokens"`
+	PromptCacheMissTokens   int                      `json:"prompt_cache_miss_tokens"`
 	PromptTokensDetails     *PromptTokensDetails     `json:"prompt_tokens_details"`
 	CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details"`
 }
diff --git a/completion.go b/completion.go
index 015fa2a9f..84970613e 100644
--- a/completion.go
+++ b/completion.go
@@ -194,6 +194,16 @@ type CompletionRequest struct {
 	Temperature float32 `json:"temperature,omitempty"`
 	TopP        float32 `json:"top_p,omitempty"`
 	User        string  `json:"user,omitempty"`
+
+	// Guided decoding parameters.
+	GuidedChoice  []string `json:"guided_choice,omitempty"`
+	GuidedRegex   string   `json:"guided_regex,omitempty"`
+	GuidedJSON    string   `json:"guided_json,omitempty"`
+	GuidedGrammar string   `json:"guided_grammar,omitempty"`
+
+	// Sampling parameters for Hugging Face-style models.
+	TopK              int     `json:"top_k,omitempty"`
+	RepetitionPenalty float32 `json:"repetition_penalty,omitempty"`
 }
 
 // CompletionChoice represents one of possible completions.
diff --git a/config.go b/config.go
index 4788ba62a..cbc4fc546 100644
--- a/config.go
+++ b/config.go
@@ -44,8 +44,8 @@ type ClientConfig struct {
 	AssistantVersion     string
 	AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
 	HTTPClient           HTTPDoer
-
-	EmptyMessagesLimit uint
+	HTTPHeaderSets     http.Header
+	EmptyMessagesLimit uint
 }
 
 func DefaultConfig(authToken string) ClientConfig {
diff --git a/examples/chatbot/main.go b/examples/chatbot/main.go
index ad41e957d..6154eb1f6 100644
--- a/examples/chatbot/main.go
+++ b/examples/chatbot/main.go
@@ -1,19 +1,33 @@
 package main
 
 import (
-	"bufio"
 	"context"
+	"crypto/tls"
 	"fmt"
-	"os"
+	"net/http"
 
 	"github.com/sashabaranov/go-openai"
 )
 
 func main() {
-	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
+	config := openai.DefaultConfig("sk-xxxx")
+	config.BaseURL = "https://10.20.152.76:30002/v1"
+	config.HTTPHeaderSets = http.Header{
+		"123": []string{"Bearer sk-xxxx"},
+		"abc": []string{"application/json"},
+		"efg": []string{"application/json"},
+	}
+	config.HTTPClient = &http.Client{
+		Transport: &http.Transport{
+			DisableKeepAlives: true,
+			TLSClientConfig:   &tls.Config{InsecureSkipVerify: true},
+		},
+	}
+
+	client := openai.NewClientWithConfig(config)
 
 	req := openai.ChatCompletionRequest{
-		Model: openai.GPT3Dot5Turbo,
+		Model: "HengNao-v4",
 		Messages: []openai.ChatCompletionMessage{
 			{
 				Role:    openai.ChatMessageRoleSystem,
@@ -24,19 +38,36 @@ func main() {
 	fmt.Println("Conversation")
 	fmt.Println("---------------------")
 	fmt.Print("> ")
-	s := bufio.NewScanner(os.Stdin)
-	for s.Scan() {
-		req.Messages = append(req.Messages, openai.ChatCompletionMessage{
-			Role:    openai.ChatMessageRoleUser,
-			Content: s.Text(),
-		})
-		resp, err := client.CreateChatCompletion(context.Background(), req)
+	// s := bufio.NewScanner(os.Stdin)
+	// for s.Scan() {
+	req.Messages = append(req.Messages, openai.ChatCompletionMessage{
+		Role:    openai.ChatMessageRoleUser,
+		Content: "Hello",
+	})
+	resp, err := client.CreateChatCompletion(context.Background(), req)
+	if err != nil {
+		fmt.Printf("ChatCompletion error: %v\n", err)
+		return
+	}
+	fmt.Printf("%s\n\n", resp.Choices[0].Message.Content)
+	req.Messages = append(req.Messages, resp.Choices[0].Message)
+	fmt.Print("> ")
+	// }
+
+	stream, err := client.CreateChatCompletionStream(context.Background(), req)
+	if err != nil {
+		fmt.Printf("ChatCompletionStream error: %v\n", err)
+		return
+	}
+	defer stream.Close()
+
+	for {
+		evt, err := stream.Recv()
 		if err != nil {
-			fmt.Printf("ChatCompletion error: %v\n", err)
-			continue
+			// The stream ends with io.EOF.
+			return
 		}
-		fmt.Printf("%s\n\n", resp.Choices[0].Message.Content)
-		req.Messages = append(req.Messages, resp.Choices[0].Message)
-		fmt.Print("> ")
+
+		fmt.Printf("%s", evt.Choices[0].Delta.Content)
 	}
 }
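
For reviewers, a minimal usage sketch (not part of the patch) of the surface added above: HTTPHeaderSets on the client config, ReasoningContent on both full responses and stream deltas, the Reasoning toggle, and the prompt-cache usage counters. The API key, base URL, header name, and model name below are placeholders, and whether a given backend honors these fields depends on the server.

package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net/http"

	"github.com/sashabaranov/go-openai"
)

func main() {
	// Placeholder key and endpoint; any OpenAI-compatible server works here.
	config := openai.DefaultConfig("your-api-key")
	config.BaseURL = "https://your-server.example.com/v1"
	// Extra headers attached verbatim to every request via the new HTTPHeaderSets field.
	config.HTTPHeaderSets = http.Header{
		"X-Custom-Header": []string{"custom-value"},
	}
	client := openai.NewClientWithConfig(config)

	req := openai.ChatCompletionRequest{
		Model: "deepseek-reasoner", // placeholder; any model that returns reasoning_content
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello"},
		},
		Reasoning: true, // new toggle from this diff; semantics depend on the backend
	}

	// Non-streaming: reasoning text arrives in Message.ReasoningContent, and the
	// prompt-cache counters (DeepSeek-style backends) arrive in Usage.
	resp, err := client.CreateChatCompletion(context.Background(), req)
	if err != nil {
		fmt.Printf("ChatCompletion error: %v\n", err)
		return
	}
	fmt.Println("reasoning:", resp.Choices[0].Message.ReasoningContent)
	fmt.Println("answer:", resp.Choices[0].Message.Content)
	fmt.Println("cache hit/miss tokens:",
		resp.Usage.PromptCacheHitTokens, resp.Usage.PromptCacheMissTokens)

	// Streaming: reasoning text arrives incrementally in Delta.ReasoningContent.
	stream, err := client.CreateChatCompletionStream(context.Background(), req)
	if err != nil {
		fmt.Printf("ChatCompletionStream error: %v\n", err)
		return
	}
	defer stream.Close()
	for {
		evt, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			fmt.Printf("stream error: %v\n", err)
			return
		}
		if len(evt.Choices) > 0 {
			fmt.Print(evt.Choices[0].Delta.ReasoningContent)
			fmt.Print(evt.Choices[0].Delta.Content)
		}
	}
	fmt.Println()
}

Note that because ReasoningContent is marshaled with omitempty, appending a returned assistant message to req.Messages will echo the reasoning text back to the server; the DeepSeek documentation linked in the diff advises against resending reasoning_content, so callers may want to clear it first.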
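
And a sketch of the guided-decoding and sampling fields added to CompletionRequest, assuming a backend that accepts these extra body parameters (for example a vLLM- or TGI-style OpenAI-compatible server); the model name, prompt, and parameter values are illustrative.

package main

import (
	"context"
	"fmt"

	"github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your-api-key") // placeholder key

	req := openai.CompletionRequest{
		Model:  "your-model", // placeholder model name
		Prompt: "Answer yes or no: is Go statically typed?",

		// Guided decoding: constrain the output to one of the listed strings.
		GuidedChoice: []string{"yes", "no"},

		// Hugging Face-style sampling controls.
		TopK:              40,
		RepetitionPenalty: 1.1,
	}

	resp, err := client.CreateCompletion(context.Background(), req)
	if err != nil {
		fmt.Printf("Completion error: %v\n", err)
		return
	}
	fmt.Println(resp.Choices[0].Text)
}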