diff --git a/chat.go b/chat.go
index 0f91d481c..9c179c2e2 100644
--- a/chat.go
+++ b/chat.go
@@ -232,16 +232,21 @@ type ChatCompletionRequest struct {
     MaxTokens int `json:"max_tokens,omitempty"`
     // MaxCompletionTokens An upper bound for the number of tokens that can be generated for a completion,
     // including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning
-    MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
-    Temperature float32 `json:"temperature,omitempty"`
-    TopP float32 `json:"top_p,omitempty"`
-    N int `json:"n,omitempty"`
-    Stream bool `json:"stream,omitempty"`
-    Stop []string `json:"stop,omitempty"`
-    PresencePenalty float32 `json:"presence_penalty,omitempty"`
-    ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"`
-    Seed *int `json:"seed,omitempty"`
-    FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
+    MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
+
+    // Deprecated: Use TemperatureOpt instead. When TemperatureOpt is set, Temperature is ignored
+    // regardless of its value. Otherwise (if TemperatureOpt is nil), Temperature is used when
+    // non-zero.
+    Temperature float32 `json:"-"`
+    TemperatureOpt *float32 `json:"temperature,omitempty"`
+    TopP float32 `json:"top_p,omitempty"`
+    N int `json:"n,omitempty"`
+    Stream bool `json:"stream,omitempty"`
+    Stop []string `json:"stop,omitempty"`
+    PresencePenalty float32 `json:"presence_penalty,omitempty"`
+    ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"`
+    Seed *int `json:"seed,omitempty"`
+    FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
     // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string.
     // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}`
     // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias
@@ -277,6 +282,52 @@ type ChatCompletionRequest struct {
     Prediction *Prediction `json:"prediction,omitempty"`
 }
 
+func (r *ChatCompletionRequest) UnmarshalJSON(data []byte) error {
+    type plainChatCompletionRequest ChatCompletionRequest
+    if err := json.Unmarshal(data, (*plainChatCompletionRequest)(r)); err != nil {
+        return err
+    }
+    if r.TemperatureOpt != nil {
+        if *r.TemperatureOpt == 0 {
+            // Explicit zero. This can only be represented in the TemperatureOpt field, so
+            // we need to preserve it.
+            // We still link r.TemperatureOpt to r.Temperature, such that legacy code modifying
+            // temperature after unmarshaling will continue to work correctly.
+            r.Temperature = 0
+            r.TemperatureOpt = &r.Temperature
+        } else {
+            // Non-zero temperature. This can be represented in the legacy field, and in order
+            // to minimize incompatibilities, we use the legacy field exclusively.
+            // New code should use `GetTemperature()` to retrieve the temperature, and explicitly
+            // setting TemperatureOpt will still be respected.
+            r.Temperature = *r.TemperatureOpt
+            r.TemperatureOpt = nil
+        }
+    } else {
+        r.Temperature = 0
+    }
+    return nil
+}
+
+func (r ChatCompletionRequest) MarshalJSON() ([]byte, error) {
+    type plainChatCompletionRequest ChatCompletionRequest
+    plainR := plainChatCompletionRequest(r)
+    if plainR.TemperatureOpt == nil && plainR.Temperature != 0 {
+        plainR.TemperatureOpt = &plainR.Temperature
+    }
+    return json.Marshal(&plainR)
+}
+
+func (r *ChatCompletionRequest) GetTemperature() *float32 {
+    if r.TemperatureOpt != nil {
+        return r.TemperatureOpt
+    }
+    if r.Temperature != 0 {
+        return &r.Temperature
+    }
+    return nil
+}
+
 type StreamOptions struct {
     // If set, an additional chunk will be streamed before the data: [DONE] message.
     // The usage field on this chunk shows the token usage statistics for the entire request,
diff --git a/chat_test.go b/chat_test.go
index 514706c96..30cecab52 100644
--- a/chat_test.go
+++ b/chat_test.go
@@ -123,6 +123,23 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
             },
             expectedError: openai.ErrReasoningModelLimitationsOther,
         },
+        {
+            name: "set_temperature_unsupported_new",
+            in: openai.ChatCompletionRequest{
+                MaxCompletionTokens: 1000,
+                Model: openai.O1Mini,
+                Messages: []openai.ChatCompletionMessage{
+                    {
+                        Role: openai.ChatMessageRoleUser,
+                    },
+                    {
+                        Role: openai.ChatMessageRoleAssistant,
+                    },
+                },
+                TemperatureOpt: &[]float32{2}[0],
+            },
+            expectedError: openai.ErrReasoningModelLimitationsOther,
+        },
         {
             name: "set_top_unsupported",
             in: openai.ChatCompletionRequest{
@@ -946,3 +963,94 @@ func TestFinishReason(t *testing.T) {
         }
     }
 }
+
+func TestTemperature(t *testing.T) {
+    tests := []struct {
+        name string
+        in openai.ChatCompletionRequest
+        expectedTemperature *float32
+    }{
+        {
+            name: "not_set",
+            in: openai.ChatCompletionRequest{},
+            expectedTemperature: nil,
+        },
+        {
+            name: "set_legacy",
+            in: openai.ChatCompletionRequest{
+                Temperature: 0.5,
+            },
+            expectedTemperature: &[]float32{0.5}[0],
+        },
+        {
+            name: "set_new",
+            in: openai.ChatCompletionRequest{
+                TemperatureOpt: &[]float32{0.5}[0],
+            },
+            expectedTemperature: &[]float32{0.5}[0],
+        },
+        {
+            name: "set_both",
+            in: openai.ChatCompletionRequest{
+                Temperature: 0.4,
+                TemperatureOpt: &[]float32{0.5}[0],
+            },
+            expectedTemperature: &[]float32{0.5}[0],
+        },
+        {
+            name: "set_new_explicit_zero",
+            in: openai.ChatCompletionRequest{
+                TemperatureOpt: &[]float32{0}[0],
+            },
+            expectedTemperature: &[]float32{0}[0],
+        },
+    }
+
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            data, err := json.Marshal(tt.in)
+            checks.NoError(t, err, "failed to marshal request to JSON")
+
+            var req openai.ChatCompletionRequest
+            err = json.Unmarshal(data, &req)
+            checks.NoError(t, err, "failed to unmarshal request from JSON")
+
+            temp := req.GetTemperature()
+            if tt.expectedTemperature == nil {
+                if temp != nil {
+                    t.Error("expected temperature to be nil")
+                }
+            } else {
+                if temp == nil {
+                    t.Error("expected temperature to be set")
+                } else if *tt.expectedTemperature != *temp {
+                    t.Errorf("expected temperature to be %v but was %v", *tt.expectedTemperature, *temp)
+                }
+            }
+        })
+    }
+}
+
+func TestTemperature_ModifyLegacyAfterUnmarshal(t *testing.T) {
+    req := openai.ChatCompletionRequest{
+        TemperatureOpt: &[]float32{0.5}[0],
+    }
+
+    data, err := json.Marshal(req)
+    checks.NoError(t, err, "failed to marshal request to JSON")
+
+    var req2 openai.ChatCompletionRequest
+    err = json.Unmarshal(data, &req2)
+    checks.NoError(t, err, "failed to unmarshal request from JSON")
+
+    if temp := req2.GetTemperature(); temp == nil || *temp != 0.5 {
+        t.Errorf("expected temperature to be 0.5 but was %v", temp)
+    }
+
+    // Modify the legacy temperature field
+    req2.Temperature = 0.4
+
+    if temp := req2.GetTemperature(); temp == nil || *temp != 0.4 {
+        t.Errorf("expected temperature to be 0.4 but was %v", temp)
+    }
+}
diff --git a/reasoning_validator.go b/reasoning_validator.go
index 2910b1395..36a06ce97 100644
--- a/reasoning_validator.go
+++ b/reasoning_validator.go
@@ -61,7 +61,7 @@ func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletion
     if request.LogProbs {
         return ErrReasoningModelLimitationsLogprobs
     }
-    if request.Temperature > 0 && request.Temperature != 1 {
+    if temp := request.GetTemperature(); temp != nil && *temp != 1 {
         return ErrReasoningModelLimitationsOther
     }
     if request.TopP > 0 && request.TopP != 1 {
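
For context, a minimal usage sketch of the new optional temperature field follows. It is not part of the diff: the import path and the main wrapper are assumptions, and it exercises only the Temperature, TemperatureOpt, and GetTemperature() additions shown above.

package main

import (
    "encoding/json"
    "fmt"

    openai "github.com/sashabaranov/go-openai" // assumed import path
)

func main() {
    // Legacy field: a zero value means "unset", so no temperature key is serialized.
    legacy := openai.ChatCompletionRequest{Temperature: 0}
    data, _ := json.Marshal(legacy)
    fmt.Println(string(data)) // no "temperature" key

    // New pointer field: an explicit zero is representable and survives a round trip.
    zero := float32(0)
    explicit := openai.ChatCompletionRequest{TemperatureOpt: &zero}
    data, _ = json.Marshal(explicit)
    fmt.Println(string(data)) // contains "temperature":0

    var decoded openai.ChatCompletionRequest
    _ = json.Unmarshal(data, &decoded)
    if temp := decoded.GetTemperature(); temp != nil {
        fmt.Println("decoded temperature:", *temp) // prints 0
    }
}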