Skip to content

Commit 6d31caf

Browse files
deorth-kkuyuhan6665
authored andcommitted
Fixing tcp connestions leak
- always use HandshakeContext instead of Handshake - pickup dailer dropped ctx - patch reality.UConn with close timeout as well - rename HandshakeContextAddress to HandshakeAddressContext
1 parent 5ea1315 commit 6d31caf

File tree

7 files changed

+38
-18
lines changed

7 files changed

+38
-18
lines changed

proxy/dokodemo/dokodemo.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ func (d *DokodemoDoor) policy() policy.Session {
7171
return p
7272
}
7373

74-
type hasHandshakeAddress interface {
75-
HandshakeAddress() net.Address
74+
type hasHandshakeAddressContext interface {
75+
HandshakeAddressContext(ctx context.Context) net.Address
7676
}
7777

7878
// Process implements proxy.Inbound.
@@ -89,8 +89,8 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
8989
if outbound := session.OutboundFromContext(ctx); outbound != nil && outbound.Target.IsValid() {
9090
dest = outbound.Target
9191
destinationOverridden = true
92-
} else if handshake, ok := conn.(hasHandshakeAddress); ok {
93-
addr := handshake.HandshakeAddress()
92+
} else if handshake, ok := conn.(hasHandshakeAddressContext); ok {
93+
addr := handshake.HandshakeAddressContext(ctx)
9494
if addr != nil {
9595
dest.Address = addr
9696
destinationOverridden = true

proxy/http/client.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, u
308308

309309
nextProto := ""
310310
if tlsConn, ok := iConn.(*tls.Conn); ok {
311-
if err := tlsConn.Handshake(); err != nil {
311+
if err := tlsConn.HandshakeContext(ctx); err != nil {
312312
rawConn.Close()
313313
return nil, err
314314
}

transport/internet/http/dialer.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
8787
} else {
8888
cn = tls.Client(pconn, tlsConfig).(*tls.Conn)
8989
}
90-
if err := cn.Handshake(); err != nil {
90+
if err := cn.HandshakeContext(ctx); err != nil {
9191
newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
9292
return nil, err
9393
}

transport/internet/tcp/dialer.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
2424
tlsConfig := config.GetTLSConfig(tls.WithDestination(dest))
2525
if fingerprint := tls.GetFingerprint(config.Fingerprint); fingerprint != nil {
2626
conn = tls.UClient(conn, tlsConfig, fingerprint)
27-
if err := conn.(*tls.UConn).Handshake(); err != nil {
27+
if err := conn.(*tls.UConn).HandshakeContext(ctx); err != nil {
2828
return nil, err
2929
}
3030
} else {

transport/internet/tls/grpc.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ func (c *grpcUtls) ClientHandshake(ctx context.Context, authority string, rawCon
6565
conn := UClient(rawConn, cfg, c.fingerprint).(*UConn)
6666
errChannel := make(chan error, 1)
6767
go func() {
68-
errChannel <- conn.Handshake()
68+
errChannel <- conn.HandshakeContext(ctx)
6969
close(errChannel)
7070
}()
7171
select {

transport/internet/tls/tls.go

+28-8
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package tls
22

33
import (
4+
"context"
45
"crypto/rand"
56
"crypto/tls"
67
"math/big"
8+
"time"
79

810
utls "github.com/refraction-networking/utls"
911
"github.com/xtls/xray-core/common/buf"
@@ -14,7 +16,7 @@ import (
1416

1517
type Interface interface {
1618
net.Conn
17-
Handshake() error
19+
HandshakeContext(ctx context.Context) error
1820
VerifyHostname(host string) error
1921
NegotiatedProtocol() (name string, mutual bool)
2022
}
@@ -25,15 +27,25 @@ type Conn struct {
2527
*tls.Conn
2628
}
2729

30+
const tlsCloseTimeout = 250 * time.Millisecond
31+
32+
func (c *Conn) Close() error {
33+
timer := time.AfterFunc(tlsCloseTimeout, func() {
34+
c.Conn.NetConn().Close()
35+
})
36+
defer timer.Stop()
37+
return c.Conn.Close()
38+
}
39+
2840
func (c *Conn) WriteMultiBuffer(mb buf.MultiBuffer) error {
2941
mb = buf.Compact(mb)
3042
mb, err := buf.WriteMultiBuffer(c, mb)
3143
buf.ReleaseMulti(mb)
3244
return err
3345
}
3446

35-
func (c *Conn) HandshakeAddress() net.Address {
36-
if err := c.Handshake(); err != nil {
47+
func (c *Conn) HandshakeAddressContext(ctx context.Context) net.Address {
48+
if err := c.HandshakeContext(ctx); err != nil {
3749
return nil
3850
}
3951
state := c.ConnectionState()
@@ -64,8 +76,16 @@ type UConn struct {
6476
*utls.UConn
6577
}
6678

67-
func (c *UConn) HandshakeAddress() net.Address {
68-
if err := c.Handshake(); err != nil {
79+
func (c *UConn) Close() error {
80+
timer := time.AfterFunc(tlsCloseTimeout, func() {
81+
c.Conn.NetConn().Close()
82+
})
83+
defer timer.Stop()
84+
return c.Conn.Close()
85+
}
86+
87+
func (c *UConn) HandshakeAddressContext(ctx context.Context) net.Address {
88+
if err := c.HandshakeContext(ctx); err != nil {
6989
return nil
7090
}
7191
state := c.ConnectionState()
@@ -77,7 +97,7 @@ func (c *UConn) HandshakeAddress() net.Address {
7797

7898
// WebsocketHandshake basically calls UConn.Handshake inside it but it will only send
7999
// http/1.1 in its ALPN.
80-
func (c *UConn) WebsocketHandshake() error {
100+
func (c *UConn) WebsocketHandshakeContext(ctx context.Context) error {
81101
// Build the handshake state. This will apply every variable of the TLS of the
82102
// fingerprint in the UConn
83103
if err := c.BuildHandshakeState(); err != nil {
@@ -99,7 +119,7 @@ func (c *UConn) WebsocketHandshake() error {
99119
if err := c.BuildHandshakeState(); err != nil {
100120
return err
101121
}
102-
return c.Handshake()
122+
return c.HandshakeContext(ctx)
103123
}
104124

105125
func (c *UConn) NegotiatedProtocol() (name string, mutual bool) {
@@ -118,7 +138,7 @@ func copyConfig(c *tls.Config) *utls.Config {
118138
ServerName: c.ServerName,
119139
InsecureSkipVerify: c.InsecureSkipVerify,
120140
VerifyPeerCertificate: c.VerifyPeerCertificate,
121-
KeyLogWriter: c.KeyLogWriter,
141+
KeyLogWriter: c.KeyLogWriter,
122142
}
123143
}
124144

transport/internet/websocket/dialer.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in
9696
}
9797
// TLS and apply the handshake
9898
cn := tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
99-
if err := cn.WebsocketHandshake(); err != nil {
99+
if err := cn.WebsocketHandshakeContext(ctx); err != nil {
100100
newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
101101
return nil, err
102102
}
@@ -147,7 +147,7 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in
147147
header.Set("Sec-WebSocket-Protocol", base64.RawURLEncoding.EncodeToString(ed))
148148
}
149149

150-
conn, resp, err := dialer.Dial(uri, header)
150+
conn, resp, err := dialer.DialContext(ctx, uri, header)
151151
if err != nil {
152152
var reason string
153153
if resp != nil {

0 commit comments

Comments
 (0)