Skip to content

Commit 1968022

Browse files
committed
feat: support multi model
1 parent 7b9dff6 commit 1968022

File tree

16 files changed

+162
-112
lines changed

16 files changed

+162
-112
lines changed

Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
FROM alpine:3
22

3+
EXPOSE 8080
34
COPY ./bin/azure-openai-proxy /usr/bin
45

56
ENTRYPOINT ["/usr/bin/azure-openai-proxy"]

apis/chat.go

Lines changed: 0 additions & 43 deletions
This file was deleted.

apis/vars.go

Lines changed: 0 additions & 9 deletions
This file was deleted.

azure/init.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package azure
2+
3+
import (
4+
"log"
5+
"net/url"
6+
"os"
7+
"regexp"
8+
"strings"
9+
10+
"github.com/stulzq/azure-openai-proxy/constant"
11+
)
12+
13+
const (
14+
AuthHeaderKey = "api-key"
15+
)
16+
17+
var (
18+
AzureOpenAIEndpoint = ""
19+
AzureOpenAIEndpointParse *url.URL
20+
21+
AzureOpenAIAPIVer = ""
22+
23+
AzureOpenAIModelMapper = map[string]string{
24+
"gpt-3.5-turbo": "gpt-35-turbo",
25+
"gpt-3.5-turbo-0301": "gpt-35-turbo-0301",
26+
}
27+
fallbackModelMapper = regexp.MustCompile(`[.:]`)
28+
)
29+
30+
func init() {
31+
AzureOpenAIAPIVer = os.Getenv(constant.ENV_AZURE_OPENAI_API_VER)
32+
AzureOpenAIEndpoint = os.Getenv(constant.ENV_AZURE_OPENAI_ENDPOINT)
33+
34+
if AzureOpenAIAPIVer == "" {
35+
AzureOpenAIAPIVer = "2023-03-15-preview"
36+
}
37+
38+
var err error
39+
AzureOpenAIEndpointParse, err = url.Parse(AzureOpenAIEndpoint)
40+
if err != nil {
41+
log.Fatal("parse AzureOpenAIEndpoint error: ", err)
42+
}
43+
44+
if v := os.Getenv(constant.ENV_AZURE_OPENAI_MODEL_MAPPER); v != "" {
45+
for _, pair := range strings.Split(v, ",") {
46+
info := strings.Split(pair, "=")
47+
if len(info) != 2 {
48+
log.Fatalf("error parsing %s, invalid value %s", constant.ENV_AZURE_OPENAI_MODEL_MAPPER, pair)
49+
}
50+
51+
AzureOpenAIModelMapper[info[0]] = info[1]
52+
}
53+
}
54+
55+
log.Println("AzureOpenAIAPIVer: ", AzureOpenAIAPIVer)
56+
log.Println("AzureOpenAIEndpoint: ", AzureOpenAIEndpoint)
57+
log.Println("AzureOpenAIModelMapper: ", AzureOpenAIModelMapper)
58+
}

azure/proxy.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package azure
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"github.com/bytedance/sonic"
7+
"github.com/pkg/errors"
8+
"github.com/stulzq/azure-openai-proxy/util"
9+
"io"
10+
"log"
11+
"net/http"
12+
"net/http/httputil"
13+
"path"
14+
"strings"
15+
16+
"github.com/gin-gonic/gin"
17+
)
18+
19+
// Proxy Azure OpenAI
20+
func Proxy(c *gin.Context) {
21+
// improve performance some code from https://github.com/diemus/azure-openai-proxy/blob/main/pkg/azure/proxy.go
22+
director := func(req *http.Request) {
23+
if req.Body == nil {
24+
util.SendError(c, errors.New("request body is empty"))
25+
return
26+
}
27+
body, _ := io.ReadAll(req.Body)
28+
req.Body = io.NopCloser(bytes.NewBuffer(body))
29+
30+
// get model from body
31+
model, err := sonic.Get(body, "model")
32+
if err != nil {
33+
util.SendError(c, errors.Wrap(err, "get model error"))
34+
return
35+
}
36+
37+
deployment, err := model.String()
38+
if err != nil {
39+
util.SendError(c, errors.Wrap(err, "get deployment error"))
40+
return
41+
}
42+
deployment = GetDeploymentByModel(deployment)
43+
44+
// get auth token from header
45+
rawToken := req.Header.Get("Authorization")
46+
token := strings.TrimPrefix(rawToken, "Bearer ")
47+
req.Header.Set(AuthHeaderKey, token)
48+
req.Header.Del("Authorization")
49+
50+
originURL := req.URL.String()
51+
req.Host = AzureOpenAIEndpointParse.Host
52+
req.URL.Scheme = AzureOpenAIEndpointParse.Scheme
53+
req.URL.Host = AzureOpenAIEndpointParse.Host
54+
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), strings.Replace(req.URL.Path, "/v1/", "/", 1))
55+
req.URL.RawPath = req.URL.EscapedPath()
56+
57+
query := req.URL.Query()
58+
query.Add("api-version", AzureOpenAIAPIVer)
59+
req.URL.RawQuery = query.Encode()
60+
61+
log.Printf("proxying request [%s] %s -> %s", model, originURL, req.URL.String())
62+
}
63+
64+
proxy := &httputil.ReverseProxy{Director: director}
65+
proxy.ServeHTTP(c.Writer, c.Request)
66+
67+
// https://github.com/Chanzhaoyu/chatgpt-web/issues/831
68+
if c.Writer.Header().Get("Content-Type") == "text/event-stream" {
69+
if _, err := c.Writer.Write([]byte{'\n'}); err != nil {
70+
log.Printf("rewrite response error: %v", err)
71+
}
72+
}
73+
}
74+
75+
func GetDeploymentByModel(model string) string {
76+
if v, ok := AzureOpenAIModelMapper[model]; ok {
77+
return v
78+
}
79+
80+
return fallbackModelMapper.ReplaceAllString(model, "")
81+
}

build.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
set -e
44

5-
VERSION=v1.0.0
5+
VERSION=v1.1.0
66

77
rm -rf bin
88

cmd/main.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"github.com/gin-gonic/gin"
66
"github.com/pkg/errors"
7-
"github.com/stulzq/azure-openai-proxy/openai"
87
"log"
98
"net/http"
109
"os"
@@ -13,8 +12,6 @@ import (
1312
)
1413

1514
func main() {
16-
openai.Init()
17-
1815
gin.SetMode(gin.ReleaseMode)
1916
r := gin.Default()
2017
registerRoute(r)

cmd/router.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ package main
22

33
import (
44
"github.com/gin-gonic/gin"
5-
"github.com/stulzq/azure-openai-proxy/apis"
5+
"github.com/stulzq/azure-openai-proxy/azure"
66
)
77

8+
// registerRoute registers all routes
89
func registerRoute(r *gin.Engine) {
9-
r.POST("/v1/chat/completions", apis.ChatCompletions)
10+
// https://platform.openai.com/docs/api-reference
11+
r.Any("*path", azure.Proxy)
1012
}

constant/env.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package constant
22

33
const (
4-
ENV_AZURE_OPENAI_ENDPOINT = "AZURE_OPENAI_ENDPOINT"
5-
ENV_AZURE_OPENAI_API_VER = "AZURE_OPENAI_API_VER"
6-
ENV_AZURE_OPENAI_DEPLOY = "AZURE_OPENAI_DEPLOY"
4+
ENV_AZURE_OPENAI_ENDPOINT = "AZURE_OPENAI_ENDPOINT"
5+
ENV_AZURE_OPENAI_API_VER = "AZURE_OPENAI_API_VER"
6+
ENV_AZURE_OPENAI_MODEL_MAPPER = "AZURE_OPENAI_MODEL_MAPPER"
77
)

go.mod

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ require (
3535
github.com/quic-go/qtls-go1-19 v0.2.0 // indirect
3636
github.com/quic-go/qtls-go1-20 v0.1.0 // indirect
3737
github.com/quic-go/quic-go v0.32.0 // indirect
38+
github.com/tidwall/gjson v1.14.4 // indirect
39+
github.com/tidwall/match v1.1.1 // indirect
40+
github.com/tidwall/pretty v1.2.0 // indirect
3841
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
3942
github.com/ugorji/go/codec v1.2.11 // indirect
4043
golang.org/x/arch v0.3.0 // indirect

go.sum

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,12 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
9595
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
9696
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
9797
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
98+
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
99+
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
100+
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
101+
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
102+
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
103+
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
98104
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
99105
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
100106
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=

openai/chat.go

Lines changed: 0 additions & 11 deletions
This file was deleted.

openai/init.go

Lines changed: 0 additions & 20 deletions
This file was deleted.

openai/vars.go

Lines changed: 0 additions & 17 deletions
This file was deleted.

apis/tools.go renamed to util/response_err.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
package apis
1+
package util
22

3-
import "github.com/gin-gonic/gin"
3+
import (
4+
"github.com/gin-gonic/gin"
5+
)
46

57
func SendError(c *gin.Context, err error) {
68
c.JSON(500, ApiResponse{

apis/types.go renamed to util/types.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package apis
1+
package util
22

33
type ApiResponse struct {
44
Error ErrorDescription `json:"error"`

0 commit comments

Comments
 (0)