Skip to content

Commit 909fa95

Browse files
Merge pull request #204 from aderouineau/handshaketimeout
handshake timeout
2 parents d137aad + 2326046 commit 909fa95

File tree

3 files changed

+64
-16
lines changed

3 files changed

+64
-16
lines changed

Diff for: conn.go

+21-13
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@ import (
99
type serverConn struct {
1010
net.Conn
1111

12-
idleTimeout time.Duration
13-
maxDeadline time.Time
14-
closeCanceler context.CancelFunc
12+
idleTimeout time.Duration
13+
handshakeDeadline time.Time
14+
maxDeadline time.Time
15+
closeCanceler context.CancelFunc
1516
}
1617

1718
func (c *serverConn) Write(p []byte) (n int, err error) {
18-
c.updateDeadline()
19+
if c.idleTimeout > 0 {
20+
c.updateDeadline()
21+
}
1922
n, err = c.Conn.Write(p)
2023
if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
2124
c.closeCanceler()
@@ -24,7 +27,9 @@ func (c *serverConn) Write(p []byte) (n int, err error) {
2427
}
2528

2629
func (c *serverConn) Read(b []byte) (n int, err error) {
27-
c.updateDeadline()
30+
if c.idleTimeout > 0 {
31+
c.updateDeadline()
32+
}
2833
n, err = c.Conn.Read(b)
2934
if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
3035
c.closeCanceler()
@@ -41,15 +46,18 @@ func (c *serverConn) Close() (err error) {
4146
}
4247

4348
func (c *serverConn) updateDeadline() {
44-
switch {
45-
case c.idleTimeout > 0:
49+
deadline := c.maxDeadline
50+
51+
if !c.handshakeDeadline.IsZero() && (deadline.IsZero() || c.handshakeDeadline.Before(deadline)) {
52+
deadline = c.handshakeDeadline
53+
}
54+
55+
if c.idleTimeout > 0 {
4656
idleDeadline := time.Now().Add(c.idleTimeout)
47-
if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() {
48-
c.Conn.SetDeadline(idleDeadline)
49-
return
57+
if deadline.IsZero() || idleDeadline.Before(deadline) {
58+
deadline = idleDeadline
5059
}
51-
fallthrough
52-
default:
53-
c.Conn.SetDeadline(c.maxDeadline)
5460
}
61+
62+
c.Conn.SetDeadline(deadline)
5563
}

Diff for: server.go

+9-3
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,9 @@ type Server struct {
5252

5353
ConnectionFailedCallback ConnectionFailedCallback // callback to report connection failures
5454

55-
IdleTimeout time.Duration // connection timeout when no activity, none if empty
56-
MaxTimeout time.Duration // absolute connection timeout, none if empty
55+
HandshakeTimeout time.Duration // connection timeout until successful handshake, none if empty
56+
IdleTimeout time.Duration // connection timeout when no activity, none if empty
57+
MaxTimeout time.Duration // absolute connection timeout, none if empty
5758

5859
// ChannelHandlers allow overriding the built-in session handlers or provide
5960
// extensions to the protocol, such as tcpip forwarding. By default only the
@@ -290,6 +291,10 @@ func (srv *Server) HandleConn(newConn net.Conn) {
290291
if srv.MaxTimeout > 0 {
291292
conn.maxDeadline = time.Now().Add(srv.MaxTimeout)
292293
}
294+
if srv.HandshakeTimeout > 0 {
295+
conn.handshakeDeadline = time.Now().Add(srv.HandshakeTimeout)
296+
}
297+
conn.updateDeadline()
293298
defer conn.Close()
294299
sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx))
295300
if err != nil {
@@ -298,7 +303,8 @@ func (srv *Server) HandleConn(newConn net.Conn) {
298303
}
299304
return
300305
}
301-
306+
conn.handshakeDeadline = time.Time{}
307+
conn.updateDeadline()
302308
srv.trackConn(sshConn, true)
303309
defer srv.trackConn(sshConn, false)
304310

Diff for: server_test.go

+34
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"context"
66
"io"
7+
"net"
78
"testing"
89
"time"
910
)
@@ -124,3 +125,36 @@ func TestServerClose(t *testing.T) {
124125
return
125126
}
126127
}
128+
129+
func TestServerHandshakeTimeout(t *testing.T) {
130+
l := newLocalListener()
131+
132+
s := &Server{
133+
HandshakeTimeout: time.Millisecond,
134+
}
135+
go func() {
136+
if err := s.Serve(l); err != nil {
137+
t.Error(err)
138+
}
139+
}()
140+
141+
conn, err := net.Dial("tcp", l.Addr().String())
142+
if err != nil {
143+
t.Fatal(err)
144+
}
145+
defer conn.Close()
146+
147+
ch := make(chan struct{})
148+
go func() {
149+
defer close(ch)
150+
io.Copy(io.Discard, conn)
151+
}()
152+
153+
select {
154+
case <-ch:
155+
return
156+
case <-time.After(time.Second):
157+
t.Fatal("client connection was not force-closed")
158+
return
159+
}
160+
}

0 commit comments

Comments
 (0)