Skip to content

Commit d7766a8

Browse files
committed
Fix #1: avoid restoring root net namespace, instead retire the joining thread
The problem is that setns() on /proc/self/ns/net might be priviledged operation if we don't own the namespace. Therefore, avoid. This hack calls runtime.LockOSThread, and does _not_ call UnlockOSThread, forcing golang to retire the os thread joining the unpriviledged namespace. Good enough.
1 parent 5fe4673 commit d7766a8

File tree

5 files changed

+60
-43
lines changed

5 files changed

+60
-43
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ require (
66
github.com/cenkalti/backoff v2.2.1+incompatible // indirect
77
github.com/golang/protobuf v1.3.2 // indirect
88
github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78
9+
github.com/wadey/gocovmerge v0.0.0-20160331181800-b5bfa59ec0ad // indirect
910
golang.org/x/sys v0.0.0-20200121082415-34d275377bf9
1011
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect
1112
gvisor.dev/gvisor v0.0.0-20200118174625-b75a6be0ea24

netns_utils.go

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,26 @@ import (
88
"gvisor.dev/gvisor/runsc/specutils"
99
)
1010

11-
func joinNetNS(nsPath string) (func(), error) {
12-
runtime.LockOSThread()
13-
restoreNS, err := specutils.ApplyNS(specs.LinuxNamespace{
14-
Type: specs.NetworkNamespace,
15-
Path: nsPath,
16-
})
17-
if err != nil {
18-
runtime.UnlockOSThread()
19-
return nil, fmt.Errorf("joining net namespace %q: %v", nsPath, err)
20-
}
21-
return func() {
22-
restoreNS()
23-
runtime.UnlockOSThread()
24-
}, nil
11+
func joinNetNS(nsPath string, run func()) error {
12+
ch := make(chan error, 2)
13+
go func() {
14+
runtime.LockOSThread()
15+
_, err := specutils.ApplyNS(specs.LinuxNamespace{
16+
Type: specs.NetworkNamespace,
17+
Path: nsPath,
18+
})
19+
if err != nil {
20+
runtime.UnlockOSThread()
21+
ch <- fmt.Errorf("joining net namespace %q: %v", nsPath, err)
22+
return
23+
}
24+
run()
25+
ch <- nil
26+
}()
27+
// Here is a big hack. Avoid restoring netns. Allow golang to
28+
// reap the thread, by not calling runtime.UnlockOSThread().
29+
// This will avoid any errors from restoreNS().
30+
31+
err := <-ch
32+
return err
2533
}

stack.go

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,43 +26,54 @@ import (
2626

2727
func GetTunTap(netNsPath string, ifName string) (int, bool, uint32, error) {
2828
var (
29-
restore func()
30-
err error
29+
err error
3130
)
32-
if netNsPath != "" {
33-
fmt.Fprintf(os.Stderr, "[.] Joininig netns %s\n", netNsPath)
34-
restore, err = joinNetNS(netNsPath)
35-
if err != nil {
36-
fmt.Fprintf(os.Stderr, "[!] Can't join netns %s: %s\n", netNsPath, err)
37-
return 0, false, 0, err
38-
}
39-
}
4031

41-
fmt.Fprintf(os.Stderr, "[.] Opening tun interface %s\n", ifName)
42-
mtu, err := rawfile.GetMTU(ifName)
43-
if err != nil {
44-
fmt.Fprintf(os.Stderr, "[!] GetMTU(%s) = %s\n", ifName, err)
45-
return 0, false, 0, err
32+
type tunState struct {
33+
fd int
34+
tapMode bool
35+
mtu uint32
36+
err error
4637
}
4738

48-
tapMode := false
39+
ch := make(chan tunState, 2)
40+
run := func() {
41+
fmt.Fprintf(os.Stderr, "[.] Opening tun interface %s\n", ifName)
42+
mtu, err := rawfile.GetMTU(ifName)
43+
if err != nil {
44+
fmt.Fprintf(os.Stderr, "[!] GetMTU(%s) = %s\n", ifName, err)
45+
ch <- tunState{err: err}
46+
return
47+
}
4948

50-
fd, err := tun.Open(ifName)
51-
if err != nil {
52-
tapMode = true
53-
fd, err = tun.OpenTAP(ifName)
49+
tapMode := false
50+
51+
fd, err := tun.Open(ifName)
5452
if err != nil {
55-
fmt.Fprintf(os.Stderr, "[!] open(%s) = %s\n", ifName, err)
56-
return 0, false, 0, err
53+
tapMode = true
54+
fd, err = tun.OpenTAP(ifName)
55+
if err != nil {
56+
fmt.Fprintf(os.Stderr, "[!] open(%s) = %s\n", ifName, err)
57+
ch <- tunState{err: err}
58+
return
59+
}
5760
}
61+
ch <- tunState{fd, tapMode, mtu, nil}
5862
}
5963

6064
if netNsPath != "" {
61-
fmt.Fprintf(os.Stderr, "[.] Restoring root netns\n")
62-
restore()
65+
fmt.Fprintf(os.Stderr, "[.] Joininig netns %s\n", netNsPath)
66+
err = joinNetNS(netNsPath, run)
67+
if err != nil {
68+
fmt.Fprintf(os.Stderr, "[!] Can't join netns %s: %s\n", netNsPath, err)
69+
return 0, false, 0, err
70+
}
71+
} else {
72+
run()
6373
}
6474

65-
return fd, tapMode, mtu, nil
75+
s := <-ch
76+
return s.fd, s.tapMode, s.mtu, s.err
6677
}
6778

6879
func NewStack() *stack.Stack {
@@ -206,7 +217,7 @@ func GonetDialTCP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip
206217
}
207218

208219
type GonetTCPConn struct {
209-
net.Conn
220+
*gonet.Conn
210221
ep tcpip.Endpoint
211222
}
212223

tests/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,6 @@ def assertTcpRefusedError(self, ip="127.0.0.1", port=0):
226226
def assertStartSync(self, p):
227227
self.assertIn("[.] Join", p.stderr_line())
228228
self.assertIn("[.] Opening tun", p.stderr_line())
229-
self.assertIn("[.] Restoring roo", p.stderr_line())
230229
self.assertIn("Started", p.stderr_line())
231230

232231
def assertListenLine(self, p, in_pattern):

tests/test_basic.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ def test_local_fwd_error_tcp(self):
3737
p = self.prun("-L %s -L %s" % (port, port))
3838
self.assertIn("[.] Join", p.stderr_line())
3939
self.assertIn("[.] Opening tun", p.stderr_line())
40-
self.assertIn("[.] Restoring roo", p.stderr_line())
4140
xport = self.assertListenLine(p, "local-fwd Local listen tcp://127.0.0.1")
4241
self.assertEqual(port, xport)
4342
# [!] Failed to listen on tcp://127.0.0.1:45295
@@ -49,7 +48,6 @@ def test_local_fwd_error_udp(self):
4948
p = self.prun("-L udp://%s -L udp://%s" % (port, port))
5049
self.assertIn("[.] Join", p.stderr_line())
5150
self.assertIn("[.] Opening tun", p.stderr_line())
52-
self.assertIn("[.] Restoring roo", p.stderr_line())
5351
xport = self.assertListenLine(p, "local-fwd Local listen udp://127.0.0.1")
5452
self.assertEqual(port, xport)
5553
# [!] Failed to listen on udp://127.0.0.1:45295

0 commit comments

Comments
 (0)