@@ -2,9 +2,7 @@ package httprateredis
2
2
3
3
import (
4
4
"context"
5
- "errors"
6
5
"fmt"
7
- "net"
8
6
"os"
9
7
"path/filepath"
10
8
"strconv"
@@ -40,8 +38,13 @@ func NewRedisLimitCounter(cfg *Config) (*redisCounter, error) {
40
38
cfg .PrefixKey = "httprate"
41
39
}
42
40
if cfg .FallbackTimeout == 0 {
43
- // Activate local in-memory fallback fairly quickly, as this would slow down all requests.
44
- cfg .FallbackTimeout = 100 * time .Millisecond
41
+ if cfg .FallbackDisabled {
42
+ cfg .FallbackTimeout = time .Second
43
+ } else {
44
+ // Activate local in-memory fallback fairly quickly,
45
+ // so we don't slow down incoming requests too much.
46
+ cfg .FallbackTimeout = 100 * time .Millisecond
47
+ }
45
48
}
46
49
47
50
rc := & redisCounter {
@@ -54,10 +57,10 @@ func NewRedisLimitCounter(cfg *Config) (*redisCounter, error) {
54
57
if cfg .Client == nil {
55
58
maxIdle , maxActive := cfg .MaxIdle , cfg .MaxActive
56
59
if maxIdle < 1 {
57
- maxIdle = 20
60
+ maxIdle = 5
58
61
}
59
62
if maxActive < 1 {
60
- maxActive = 50
63
+ maxActive = 10
61
64
}
62
65
63
66
rc .client = redis .NewClient (& redis.Options {
@@ -107,13 +110,8 @@ func (c *redisCounter) IncrementBy(key string, currentWindow time.Time, amount i
107
110
return c .fallbackCounter .IncrementBy (key , currentWindow , amount )
108
111
}
109
112
defer func () {
110
- if err != nil {
111
- // On redis network error, fallback to local in-memory counter.
112
- var netErr net.Error
113
- if errors .As (err , & netErr ) || errors .Is (err , redis .ErrClosed ) {
114
- c .fallback ()
115
- err = c .fallbackCounter .IncrementBy (key , currentWindow , amount )
116
- }
113
+ if c .shouldFallback (err ) {
114
+ err = c .fallbackCounter .IncrementBy (key , currentWindow , amount )
117
115
}
118
116
}()
119
117
}
@@ -147,13 +145,8 @@ func (c *redisCounter) Get(key string, currentWindow, previousWindow time.Time)
147
145
return c .fallbackCounter .Get (key , currentWindow , previousWindow )
148
146
}
149
147
defer func () {
150
- if err != nil {
151
- // On redis network error, fallback to local in-memory counter.
152
- var netErr net.Error
153
- if errors .As (err , & netErr ) || errors .Is (err , redis .ErrClosed ) {
154
- c .fallback ()
155
- curr , prev , err = c .fallbackCounter .Get (key , currentWindow , previousWindow )
156
- }
148
+ if c .shouldFallback (err ) {
149
+ curr , prev , err = c .fallbackCounter .Get (key , currentWindow , previousWindow )
157
150
}
158
151
}()
159
152
}
@@ -189,25 +182,34 @@ func (c *redisCounter) IsFallbackActivated() bool {
189
182
return c .fallbackActivated .Load ()
190
183
}
191
184
192
- func (c * redisCounter ) fallback () {
193
- // Activate the in-memory counter fallback, unless activated by some other goroutine.
194
- fallbackAlreadyActivated := c .fallbackActivated .Swap (true )
195
- if fallbackAlreadyActivated {
196
- return
185
+ func (c * redisCounter ) Close () error {
186
+ return c .client .Close ()
187
+ }
188
+
189
+ func (c * redisCounter ) shouldFallback (err error ) bool {
190
+ if err == nil {
191
+ return false
192
+ }
193
+
194
+ // Activate the local in-memory counter fallback, unless activated by some other goroutine.
195
+ alreadyActivated := c .fallbackActivated .Swap (true )
196
+ if ! alreadyActivated {
197
+ go c .reconnect ()
197
198
}
198
199
199
- go c . reconnect ()
200
+ return true
200
201
}
201
202
202
203
func (c * redisCounter ) reconnect () {
203
204
// Try to re-connect to redis every 200ms.
204
205
for {
206
+ time .Sleep (200 * time .Millisecond )
207
+
205
208
err := c .client .Ping (context .Background ()).Err ()
206
209
if err == nil {
207
210
c .fallbackActivated .Store (false )
208
211
return
209
212
}
210
- time .Sleep (200 * time .Millisecond )
211
213
}
212
214
}
213
215
0 commit comments