Skip to content

Commit adf73ed

Browse files
committed
WIP: feat: Add rate limiting
Signed-off-by: Manuel Rüger <[email protected]>
1 parent fb7682a commit adf73ed

File tree

7 files changed

+73
-26
lines changed

7 files changed

+73
-26
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ require (
99
github.com/prometheus/common v0.64.0
1010
golang.org/x/crypto v0.39.0
1111
golang.org/x/sync v0.15.0
12+
golang.org/x/time v0.12.0
1213
gopkg.in/yaml.v2 v2.4.0
1314
)
1415

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
6161
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
6262
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
6363
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
64+
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
65+
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
6466
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
6567
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
6668
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

web/handler.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"sync"
2525

2626
"golang.org/x/crypto/bcrypt"
27+
"golang.org/x/time/rate"
2728
)
2829

2930
// extraHTTPHeaders is a map of HTTP headers that can be added to HTTP
@@ -80,6 +81,7 @@ type webHandler struct {
8081
handler http.Handler
8182
logger *slog.Logger
8283
cache *cache
84+
limiter *rate.Limiter
8385
// bcryptMtx is there to ensure that bcrypt.CompareHashAndPassword is run
8486
// only once in parallel as this is CPU intensive.
8587
bcryptMtx sync.Mutex
@@ -93,6 +95,11 @@ func (u *webHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
9395
return
9496
}
9597

98+
if u.limiter != nil && !u.limiter.Allow() {
99+
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
100+
return
101+
}
102+
96103
// Configure http headers.
97104
for k, v := range c.HTTPConfig.Header {
98105
w.Header().Set(k, v)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
rate_limiter_config:
2+
rate: 1
3+
burst: 1
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
rate_limiter_config:
2+
rate: 0
3+
burst: 0

web/tls_config.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
"github.com/mdlayher/vsock"
3232
config_util "github.com/prometheus/common/config"
3333
"golang.org/x/sync/errgroup"
34+
"golang.org/x/time/rate"
3435
"gopkg.in/yaml.v2"
3536
)
3637

@@ -40,9 +41,10 @@ var (
4041
)
4142

4243
type Config struct {
43-
TLSConfig TLSConfig `yaml:"tls_server_config"`
44-
HTTPConfig HTTPConfig `yaml:"http_server_config"`
45-
Users map[string]config_util.Secret `yaml:"basic_auth_users"`
44+
TLSConfig TLSConfig `yaml:"tls_server_config"`
45+
HTTPConfig HTTPConfig `yaml:"http_server_config"`
46+
RateLimiterConfig RateLimiterConfig `yaml:"rate_limiter_config"`
47+
Users map[string]config_util.Secret `yaml:"basic_auth_users"`
4648
}
4749

4850
type TLSConfig struct {
@@ -109,6 +111,11 @@ type HTTPConfig struct {
109111
Header map[string]string `yaml:"headers,omitempty"`
110112
}
111113

114+
type RateLimiterConfig struct {
115+
Burst int `yaml:"burst"`
116+
Rate int `yaml:"rate"`
117+
}
118+
112119
func getConfig(configPath string) (*Config, error) {
113120
content, err := os.ReadFile(configPath)
114121
if err != nil {
@@ -366,11 +373,19 @@ func Serve(l net.Listener, server *http.Server, flags *FlagConfig, logger *slog.
366373
return err
367374
}
368375

376+
var limiter *rate.Limiter
377+
// Setup Rate Limiter
378+
if c.RateLimiterConfig.Rate != 0 && c.RateLimiterConfig.Burst != 0 {
379+
limiter = rate.NewLimiter(rate.Limit(c.RateLimiterConfig.Rate), c.RateLimiterConfig.Burst)
380+
logger.Info("Rate Limiter is enabled.", "burst", c.RateLimiterConfig.Burst, "rate", c.RateLimiterConfig.Rate)
381+
}
382+
369383
server.Handler = &webHandler{
370384
tlsConfigPath: tlsConfigPath,
371385
logger: logger,
372386
handler: handler,
373387
cache: newCache(),
388+
limiter: limiter,
374389
}
375390

376391
config, err := ConfigToTLSConfig(&c.TLSConfig)

web/tls_config_test.go

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ type TestInputs struct {
9898
Username string
9999
Password string
100100
ClientCertificate string
101+
Requests int
101102
}
102103

103104
func TestYAMLFiles(t *testing.T) {
@@ -364,6 +365,19 @@ func TestServerBehaviour(t *testing.T) {
364365
ClientCertificate: "client2_selfsigned",
365366
ExpectedError: ErrorMap["Invalid client cert"],
366367
},
368+
{
369+
Name: "valid rate limiter that doesn't block",
370+
YAMLConfigPath: "testdata/web_config_rate_limiter_nonblocking.yaml",
371+
UseTLSClient: false,
372+
ExpectedError: nil,
373+
},
374+
{
375+
Name: "valid rate limiter with a capacity of one",
376+
YAMLConfigPath: "testdata/web_config_rate_limiter_capacity_one.yaml",
377+
UseTLSClient: false,
378+
Requests: 100,
379+
ExpectedError: nil,
380+
},
367381
}
368382
for _, testInputs := range testTables {
369383
t.Run(testInputs.Name, testInputs.Test)
@@ -515,33 +529,35 @@ func (test *TestInputs) Test(t *testing.T) {
515529
}
516530
go func() {
517531
time.Sleep(250 * time.Millisecond)
518-
r, err := ClientConnection()
519-
if err != nil {
520-
recordConnectionError(err)
521-
return
522-
}
532+
for i := 0; i <= test.Requests; i++ {
533+
r, err := ClientConnection()
534+
if err != nil {
535+
recordConnectionError(err)
536+
return
537+
}
523538

524-
if test.ActualCipher != 0 {
525-
if r.TLS.CipherSuite != test.ActualCipher {
526-
recordConnectionError(
527-
fmt.Errorf("bad cipher suite selected. Expected: %s, got: %s",
528-
tls.CipherSuiteName(test.ActualCipher),
529-
tls.CipherSuiteName(r.TLS.CipherSuite),
530-
),
531-
)
539+
if test.ActualCipher != 0 {
540+
if r.TLS.CipherSuite != test.ActualCipher {
541+
recordConnectionError(
542+
fmt.Errorf("bad cipher suite selected. Expected: %s, got: %s",
543+
tls.CipherSuiteName(test.ActualCipher),
544+
tls.CipherSuiteName(r.TLS.CipherSuite),
545+
),
546+
)
547+
}
532548
}
533-
}
534549

535-
body, err := io.ReadAll(r.Body)
536-
if err != nil {
537-
recordConnectionError(err)
538-
return
539-
}
540-
if string(body) != "Hello World!" {
541-
recordConnectionError(errors.New(string(body)))
542-
return
550+
body, err := io.ReadAll(r.Body)
551+
if err != nil {
552+
recordConnectionError(err)
553+
return
554+
}
555+
if string(body) != "Hello World!" {
556+
recordConnectionError(errors.New(string(body)))
557+
return
558+
}
559+
recordConnectionError(nil)
543560
}
544-
recordConnectionError(nil)
545561
}()
546562
err := <-errorChannel
547563
if test.isCorrectError(err) == false {

0 commit comments

Comments
 (0)