@@ -9,11 +9,9 @@ import (
9
9
"strings"
10
10
)
11
11
12
- var defaultHeaders = []string {
13
- "True-Client-IP" , // Cloudflare Enterprise plan
14
- "X-Real-IP" ,
15
- "X-Forwarded-For" ,
16
- }
12
+ var trueClientIP = http .CanonicalHeaderKey ("True-Client-IP" )
13
+ var xForwardedFor = http .CanonicalHeaderKey ("X-Forwarded-For" )
14
+ var xRealIP = http .CanonicalHeaderKey ("X-Real-IP" )
17
15
18
16
// RealIP is a middleware that sets a http.Request's RemoteAddr to the results
19
17
// of parsing either the True-Client-IP, X-Real-IP or the X-Forwarded-For headers
@@ -32,7 +30,7 @@ var defaultHeaders = []string{
32
30
// how you're using RemoteAddr, vulnerable to an attack of some sort).
33
31
func RealIP (h http.Handler ) http.Handler {
34
32
fn := func (w http.ResponseWriter , r * http.Request ) {
35
- if rip := getRealIP ( r , defaultHeaders ); rip != "" {
33
+ if rip := realIP ( r ); rip != "" {
36
34
r .RemoteAddr = rip
37
35
}
38
36
h .ServeHTTP (w , r )
@@ -41,33 +39,22 @@ func RealIP(h http.Handler) http.Handler {
41
39
return http .HandlerFunc (fn )
42
40
}
43
41
44
- // RealIPFromHeaders is a middleware that sets a http.Request's RemoteAddr to the results
45
- // of parsing the custom headers.
46
- //
47
- // usage:
48
- // r.Use(RealIPFromHeaders("CF-Connecting-IP"))
49
- func RealIPFromHeaders (headers ... string ) func (http.Handler ) http.Handler {
50
- f := func (h http.Handler ) http.Handler {
51
- fn := func (w http.ResponseWriter , r * http.Request ) {
52
- if rip := getRealIP (r , headers ); rip != "" {
53
- r .RemoteAddr = rip
54
- }
55
- h .ServeHTTP (w , r )
56
- }
57
- return http .HandlerFunc (fn )
58
- }
59
- return f
60
- }
42
+ func realIP (r * http.Request ) string {
43
+ var ip string
61
44
62
- func getRealIP ( r * http. Request , headers [] string ) string {
63
- for _ , header := range headers {
64
- if ip := r .Header .Get (header ); ip != "" {
65
- ips := strings . Split ( ip , "," )
66
- if ips [ 0 ] == "" || net . ParseIP ( ips [ 0 ]) == nil {
67
- continue
68
- }
69
- return ips [ 0 ]
45
+ if tcip := r . Header . Get ( trueClientIP ); tcip != "" {
46
+ ip = tcip
47
+ } else if xrip := r .Header .Get (xRealIP ); xrip != "" {
48
+ ip = xrip
49
+ } else if xff := r . Header . Get ( xForwardedFor ); xff != "" {
50
+ i := strings . Index ( xff , "," )
51
+ if i == - 1 {
52
+ i = len ( xff )
70
53
}
54
+ ip = xff [:i ]
55
+ }
56
+ if ip == "" || net .ParseIP (ip ) == nil {
57
+ return ""
71
58
}
72
- return ""
59
+ return ip
73
60
}
0 commit comments