Skip to content

Commit 17207fc

Browse files
IlyaGulyaRPRX
authored andcommitted
WireGuard: Improve config error handling; Prevent panic in case of errors during server initialization (#4566)
#4566 (comment)
1 parent 52a2c63 commit 17207fc

File tree

4 files changed

+83
-28
lines changed

4 files changed

+83
-28
lines changed

Diff for: infra/conf/wireguard.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
6767
var err error
6868
config.SecretKey, err = ParseWireGuardKey(c.SecretKey)
6969
if err != nil {
70-
return nil, err
70+
return nil, errors.New("invalid WireGuard secret key: %w", err)
7171
}
7272

7373
if c.Address == nil {
@@ -126,6 +126,10 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
126126
func ParseWireGuardKey(str string) (string, error) {
127127
var err error
128128

129+
if str == "" {
130+
return "", errors.New("key must not be empty")
131+
}
132+
129133
if len(str)%2 == 0 {
130134
_, err = hex.DecodeString(str)
131135
if err == nil {

Diff for: infra/conf/xray.go

+20-20
Original file line numberDiff line numberDiff line change
@@ -241,14 +241,14 @@ func (c *InboundDetourConfig) Build() (*core.InboundHandlerConfig, error) {
241241
}
242242
rawConfig, err := inboundConfigLoader.LoadWithID(settings, c.Protocol)
243243
if err != nil {
244-
return nil, errors.New("failed to load inbound detour config.").Base(err)
244+
return nil, errors.New("failed to load inbound detour config for protocol ", c.Protocol).Base(err)
245245
}
246246
if dokodemoConfig, ok := rawConfig.(*DokodemoConfig); ok {
247247
receiverSettings.ReceiveOriginalDestination = dokodemoConfig.Redirect
248248
}
249249
ts, err := rawConfig.(Buildable).Build()
250250
if err != nil {
251-
return nil, err
251+
return nil, errors.New("failed to build inbound handler for protocol ", c.Protocol).Base(err)
252252
}
253253

254254
return &core.InboundHandlerConfig{
@@ -303,15 +303,15 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
303303
if c.StreamSetting != nil {
304304
ss, err := c.StreamSetting.Build()
305305
if err != nil {
306-
return nil, err
306+
return nil, errors.New("failed to build stream settings for outbound detour").Base(err)
307307
}
308308
senderSettings.StreamSettings = ss
309309
}
310310

311311
if c.ProxySettings != nil {
312312
ps, err := c.ProxySettings.Build()
313313
if err != nil {
314-
return nil, errors.New("invalid outbound detour proxy settings.").Base(err)
314+
return nil, errors.New("invalid outbound detour proxy settings").Base(err)
315315
}
316316
if ps.TransportLayerProxy {
317317
if senderSettings.StreamSettings != nil {
@@ -331,7 +331,7 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
331331
if c.MuxSettings != nil {
332332
ms, err := c.MuxSettings.Build()
333333
if err != nil {
334-
return nil, errors.New("failed to build Mux config.").Base(err)
334+
return nil, errors.New("failed to build Mux config").Base(err)
335335
}
336336
senderSettings.MultiplexSettings = ms
337337
}
@@ -342,11 +342,11 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
342342
}
343343
rawConfig, err := outboundConfigLoader.LoadWithID(settings, c.Protocol)
344344
if err != nil {
345-
return nil, errors.New("failed to parse to outbound detour config.").Base(err)
345+
return nil, errors.New("failed to load outbound detour config for protocol ", c.Protocol).Base(err)
346346
}
347347
ts, err := rawConfig.(Buildable).Build()
348348
if err != nil {
349-
return nil, err
349+
return nil, errors.New("failed to build outbound handler for protocol ", c.Protocol).Base(err)
350350
}
351351

352352
return &core.OutboundHandlerConfig{
@@ -490,7 +490,7 @@ func (c *Config) Override(o *Config, fn string) {
490490
// Build implements Buildable.
491491
func (c *Config) Build() (*core.Config, error) {
492492
if err := PostProcessConfigureFile(c); err != nil {
493-
return nil, err
493+
return nil, errors.New("failed to post-process configuration file").Base(err)
494494
}
495495

496496
config := &core.Config{
@@ -504,21 +504,21 @@ func (c *Config) Build() (*core.Config, error) {
504504
if c.API != nil {
505505
apiConf, err := c.API.Build()
506506
if err != nil {
507-
return nil, err
507+
return nil, errors.New("failed to build API configuration").Base(err)
508508
}
509509
config.App = append(config.App, serial.ToTypedMessage(apiConf))
510510
}
511511
if c.Metrics != nil {
512512
metricsConf, err := c.Metrics.Build()
513513
if err != nil {
514-
return nil, err
514+
return nil, errors.New("failed to build metrics configuration").Base(err)
515515
}
516516
config.App = append(config.App, serial.ToTypedMessage(metricsConf))
517517
}
518518
if c.Stats != nil {
519519
statsConf, err := c.Stats.Build()
520520
if err != nil {
521-
return nil, err
521+
return nil, errors.New("failed to build stats configuration").Base(err)
522522
}
523523
config.App = append(config.App, serial.ToTypedMessage(statsConf))
524524
}
@@ -536,55 +536,55 @@ func (c *Config) Build() (*core.Config, error) {
536536
if c.RouterConfig != nil {
537537
routerConfig, err := c.RouterConfig.Build()
538538
if err != nil {
539-
return nil, err
539+
return nil, errors.New("failed to build routing configuration").Base(err)
540540
}
541541
config.App = append(config.App, serial.ToTypedMessage(routerConfig))
542542
}
543543

544544
if c.DNSConfig != nil {
545545
dnsApp, err := c.DNSConfig.Build()
546546
if err != nil {
547-
return nil, errors.New("failed to parse DNS config").Base(err)
547+
return nil, errors.New("failed to build DNS configuration").Base(err)
548548
}
549549
config.App = append(config.App, serial.ToTypedMessage(dnsApp))
550550
}
551551

552552
if c.Policy != nil {
553553
pc, err := c.Policy.Build()
554554
if err != nil {
555-
return nil, err
555+
return nil, errors.New("failed to build policy configuration").Base(err)
556556
}
557557
config.App = append(config.App, serial.ToTypedMessage(pc))
558558
}
559559

560560
if c.Reverse != nil {
561561
r, err := c.Reverse.Build()
562562
if err != nil {
563-
return nil, err
563+
return nil, errors.New("failed to build reverse configuration").Base(err)
564564
}
565565
config.App = append(config.App, serial.ToTypedMessage(r))
566566
}
567567

568568
if c.FakeDNS != nil {
569569
r, err := c.FakeDNS.Build()
570570
if err != nil {
571-
return nil, err
571+
return nil, errors.New("failed to build fake DNS configuration").Base(err)
572572
}
573573
config.App = append([]*serial.TypedMessage{serial.ToTypedMessage(r)}, config.App...)
574574
}
575575

576576
if c.Observatory != nil {
577577
r, err := c.Observatory.Build()
578578
if err != nil {
579-
return nil, err
579+
return nil, errors.New("failed to build observatory configuration").Base(err)
580580
}
581581
config.App = append(config.App, serial.ToTypedMessage(r))
582582
}
583583

584584
if c.BurstObservatory != nil {
585585
r, err := c.BurstObservatory.Build()
586586
if err != nil {
587-
return nil, err
587+
return nil, errors.New("failed to build burst observatory configuration").Base(err)
588588
}
589589
config.App = append(config.App, serial.ToTypedMessage(r))
590590
}
@@ -602,7 +602,7 @@ func (c *Config) Build() (*core.Config, error) {
602602
for _, rawInboundConfig := range inbounds {
603603
ic, err := rawInboundConfig.Build()
604604
if err != nil {
605-
return nil, err
605+
return nil, errors.New("failed to build inbound config with tag ", rawInboundConfig.Tag).Base(err)
606606
}
607607
config.Inbound = append(config.Inbound, ic)
608608
}
@@ -616,7 +616,7 @@ func (c *Config) Build() (*core.Config, error) {
616616
for _, rawOutboundConfig := range outbounds {
617617
oc, err := rawOutboundConfig.Build()
618618
if err != nil {
619-
return nil, err
619+
return nil, errors.New("failed to build outbound config with tag ", rawOutboundConfig.Tag).Base(err)
620620
}
621621
config.Outbound = append(config.Outbound, oc)
622622
}

Diff for: proxy/wireguard/gvisortun/tun.go

+6-7
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"fmt"
1111
"net/netip"
1212
"os"
13+
"sync"
1314
"syscall"
1415

1516
"golang.zx2c4.com/wireguard/tun"
@@ -33,6 +34,7 @@ type netTun struct {
3334
incomingPacket chan *buffer.View
3435
mtu int
3536
hasV4, hasV6 bool
37+
closeOnce sync.Once
3638
}
3739

3840
type Net netTun
@@ -174,18 +176,15 @@ func (tun *netTun) Flush() error {
174176

175177
// Close implements tun.Device
176178
func (tun *netTun) Close() error {
177-
tun.stack.RemoveNIC(1)
179+
tun.closeOnce.Do(func() {
180+
tun.stack.RemoveNIC(1)
178181

179-
if tun.events != nil {
180182
close(tun.events)
181-
}
182183

183-
tun.ep.Close()
184+
tun.ep.Close()
184185

185-
if tun.incomingPacket != nil {
186186
close(tun.incomingPacket)
187-
}
188-
187+
})
189188
return nil
190189
}
191190

Diff for: proxy/wireguard/server_test.go

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package wireguard_test
2+
3+
import (
4+
"context"
5+
"github.com/stretchr/testify/assert"
6+
"runtime/debug"
7+
"testing"
8+
9+
"github.com/xtls/xray-core/core"
10+
"github.com/xtls/xray-core/proxy/wireguard"
11+
)
12+
13+
// TestWireGuardServerInitializationError verifies that an error during TUN initialization
14+
// (triggered by an empty SecretKey) in the WireGuard server does not cause a panic and returns an error instead.
15+
func TestWireGuardServerInitializationError(t *testing.T) {
16+
// Create a minimal core instance with default features
17+
config := &core.Config{}
18+
instance, err := core.New(config)
19+
if err != nil {
20+
t.Fatalf("Failed to create core instance: %v", err)
21+
}
22+
// Set the Xray instance in the context
23+
ctx := context.WithValue(context.Background(), core.XrayKey(1), instance)
24+
25+
// Define the server configuration with an empty SecretKey to trigger error
26+
conf := &wireguard.DeviceConfig{
27+
IsClient: false,
28+
Endpoint: []string{"10.0.0.1/32"},
29+
Mtu: 1420,
30+
SecretKey: "", // Empty SecretKey to trigger error
31+
Peers: []*wireguard.PeerConfig{
32+
{
33+
PublicKey: "some_public_key",
34+
AllowedIps: []string{"10.0.0.2/32"},
35+
},
36+
},
37+
}
38+
39+
// Use defer to catch any panic and fail the test explicitly
40+
defer func() {
41+
if r := recover(); r != nil {
42+
t.Errorf("TUN initialization panicked: %v", r)
43+
debug.PrintStack()
44+
}
45+
}()
46+
47+
// Attempt to initialize the WireGuard server
48+
_, err = wireguard.NewServer(ctx, conf)
49+
50+
// Check that an error is returned
51+
assert.ErrorContains(t, err, "failed to set private_key: hex string does not fit the slice")
52+
}

0 commit comments

Comments
 (0)