Skip to content

Commit 6bba33e

Browse files
committed
Merge branch 'v3' into v3-viam
2 parents 4885dc4 + cef1db8 commit 6bba33e

File tree

3 files changed

+102
-6
lines changed

3 files changed

+102
-6
lines changed

peerconnection.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2099,6 +2099,7 @@ func (pc *PeerConnection) close(shouldGracefullyClose bool) error {
20992099
}
21002100
if shouldGracefullyClose && !alreadyGracefullyClosed {
21012101
defer close(pc.isGracefulClosedDone)
2102+
pc.ops.GracefulClose()
21022103
}
21032104

21042105
// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #3)
@@ -2155,8 +2156,6 @@ func (pc *PeerConnection) close(shouldGracefullyClose bool) error {
21552156
pc.updateConnectionState(pc.ICEConnectionState(), pc.dtlsTransport.State())
21562157

21572158
if shouldGracefullyClose {
2158-
pc.ops.GracefulClose()
2159-
21602159
// note that it isn't canon to stop gracefully
21612160
pc.sctpTransport.lock.Lock()
21622161
for _, d := range pc.sctpTransport.dataChannels {

sctptransport.go

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ type SCTPTransport struct {
4545
// OnStateChange func()
4646

4747
onErrorHandler func(error)
48+
onCloseHandler func(error)
4849

4950
sctpAssociation *sctp.Association
5051
onDataChannelHandler func(*DataChannel)
@@ -174,6 +175,7 @@ func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) {
174175
dataChannels = append(dataChannels, dc.dataChannel)
175176
}
176177
r.lock.RUnlock()
178+
177179
ACCEPT:
178180
for {
179181
dc, err := datachannel.Accept(a, &datachannel.Config{
@@ -183,6 +185,9 @@ ACCEPT:
183185
if !errors.Is(err, io.EOF) {
184186
r.log.Errorf("Failed to accept data channel: %v", err)
185187
r.onError(err)
188+
r.onClose(err)
189+
} else {
190+
r.onClose(nil)
186191
}
187192
return
188193
}
@@ -230,9 +235,14 @@ ACCEPT:
230235
MaxRetransmits: maxRetransmits,
231236
}, r, r.api.settingEngine.LoggerFactory.NewLogger("ortc"))
232237
if err != nil {
238+
// This data channel is invalid. Close it and log an error.
239+
if err1 := dc.Close(); err1 != nil {
240+
r.log.Errorf("Failed to close invalid data channel: %v", err1)
241+
}
233242
r.log.Errorf("Failed to accept data channel: %v", err)
234243
r.onError(err)
235-
return
244+
// We've received a datachannel with invalid configuration. We can still receive other datachannels.
245+
continue ACCEPT
236246
}
237247

238248
<-r.onDataChannel(rtcDC)
@@ -249,8 +259,7 @@ ACCEPT:
249259
}
250260
}
251261

252-
// OnError sets an event handler which is invoked when
253-
// the SCTP connection error occurs.
262+
// OnError sets an event handler which is invoked when the SCTP Association errors.
254263
func (r *SCTPTransport) OnError(f func(err error)) {
255264
r.lock.Lock()
256265
defer r.lock.Unlock()
@@ -267,6 +276,23 @@ func (r *SCTPTransport) onError(err error) {
267276
}
268277
}
269278

279+
// OnClose sets an event handler which is invoked when the SCTP Association closes.
280+
func (r *SCTPTransport) OnClose(f func(err error)) {
281+
r.lock.Lock()
282+
defer r.lock.Unlock()
283+
r.onCloseHandler = f
284+
}
285+
286+
func (r *SCTPTransport) onClose(err error) {
287+
r.lock.RLock()
288+
handler := r.onCloseHandler
289+
r.lock.RUnlock()
290+
291+
if handler != nil {
292+
go handler(err)
293+
}
294+
}
295+
270296
// OnDataChannel sets an event handler which is invoked when a data
271297
// channel message arrives from a remote peer.
272298
func (r *SCTPTransport) OnDataChannel(f func(*DataChannel)) {

sctptransport_test.go

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@
66

77
package webrtc
88

9-
import "testing"
9+
import (
10+
"bytes"
11+
"testing"
12+
"time"
13+
14+
"github.com/stretchr/testify/require"
15+
)
1016

1117
func TestGenerateDataChannelID(t *testing.T) {
1218
sctpTransportWithChannels := func(ids []uint16) *SCTPTransport {
@@ -48,3 +54,68 @@ func TestGenerateDataChannelID(t *testing.T) {
4854
}
4955
}
5056
}
57+
58+
func TestSCTPTransportOnClose(t *testing.T) {
59+
offerPC, answerPC, err := newPair()
60+
require.NoError(t, err)
61+
62+
defer closePairNow(t, offerPC, answerPC)
63+
64+
answerPC.OnDataChannel(func(dc *DataChannel) {
65+
dc.OnMessage(func(_ DataChannelMessage) {
66+
if err1 := dc.Send([]byte("hello")); err1 != nil {
67+
t.Error("failed to send message")
68+
}
69+
})
70+
})
71+
72+
recvMsg := make(chan struct{}, 1)
73+
offerPC.OnConnectionStateChange(func(state PeerConnectionState) {
74+
if state == PeerConnectionStateConnected {
75+
defer func() {
76+
offerPC.OnConnectionStateChange(nil)
77+
}()
78+
79+
dc, createErr := offerPC.CreateDataChannel(expectedLabel, nil)
80+
if createErr != nil {
81+
t.Errorf("Failed to create a PC pair for testing")
82+
return
83+
}
84+
dc.OnMessage(func(msg DataChannelMessage) {
85+
if !bytes.Equal(msg.Data, []byte("hello")) {
86+
t.Error("invalid msg received")
87+
}
88+
recvMsg <- struct{}{}
89+
})
90+
dc.OnOpen(func() {
91+
if err1 := dc.Send([]byte("hello")); err1 != nil {
92+
t.Error("failed to send initial msg", err1)
93+
}
94+
})
95+
}
96+
})
97+
98+
err = signalPair(offerPC, answerPC)
99+
require.NoError(t, err)
100+
101+
select {
102+
case <-recvMsg:
103+
case <-time.After(5 * time.Second):
104+
t.Fatal("timed out")
105+
}
106+
107+
// setup SCTP OnClose callback
108+
ch := make(chan error, 1)
109+
answerPC.SCTP().OnClose(func(err error) {
110+
ch <- err
111+
})
112+
113+
err = offerPC.Close() // This will trigger sctp onclose callback on remote
114+
require.NoError(t, err)
115+
116+
select {
117+
case <-ch:
118+
case <-time.After(5 * time.Second):
119+
t.Fatal("timed out")
120+
}
121+
}

0 commit comments

Comments
 (0)