Skip to content

Commit 3a84464

Browse files
authored
Fix Concurrency issues with StateChange hooks and add/removing channels (#70)
* Made StateChangeCallbacks read-only and handling adding elements to them * Channels are read-only * Channel and Push hooks moved to read only list with copy * only run on trusty
1 parent 2d5c52f commit 3a84464

File tree

5 files changed

+171
-41
lines changed

5 files changed

+171
-41
lines changed

.travis.yml

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
language: java
2+
dist: trusty
23
after_success:
34
- bash <(curl -s https://codecov.io/bash)
45
jdk:

src/main/kotlin/org/phoenixframework/Push.kt

+4-8
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class Push(
4545
var timeoutTask: DispatchWorkItem? = null
4646

4747
/** Hooks into a Push. Where .receive("ok", callback(Payload)) are stored */
48-
var receiveHooks: MutableMap<String, MutableList<((message: Message) -> Unit)>> = HashMap()
48+
var receiveHooks: MutableMap<String, List<((message: Message) -> Unit)>> = HashMap()
4949

5050
/** True if the Push has been sent */
5151
var sent: Boolean = false
@@ -93,13 +93,9 @@ class Push(
9393
// If the message has already be received, pass it to the callback
9494
receivedMessage?.let { if (hasReceived(status)) callback(it) }
9595

96-
if (receiveHooks[status] == null) {
97-
// Create a new array of hooks if no previous hook is associated with status
98-
receiveHooks[status] = arrayListOf(callback)
99-
} else {
100-
// A previous hook for this status already exists. Just append the new hook
101-
receiveHooks[status]?.add(callback)
102-
}
96+
// If a previous hook for this status already exists. Just append the new hook. If not, then
97+
// create a new array of hooks if no previous hook is associated with status
98+
receiveHooks[status] = receiveHooks[status]?.copyAndAdd(callback) ?: arrayListOf(callback)
10399

104100
return this
105101
}

src/main/kotlin/org/phoenixframework/Socket.kt

+50-18
Original file line numberDiff line numberDiff line change
@@ -33,28 +33,61 @@ import java.util.concurrent.TimeUnit
3333
typealias Payload = Map<String, Any>
3434

3535
/** Data class that holds callbacks assigned to the socket */
36-
internal data class StateChangeCallbacks(
37-
val open: MutableList<() -> Unit> = ArrayList(),
38-
val close: MutableList<() -> Unit> = ArrayList(),
39-
val error: MutableList<(Throwable, Response?) -> Unit> = ArrayList(),
40-
val message: MutableList<(Message) -> Unit> = ArrayList()
41-
) {
36+
internal class StateChangeCallbacks {
37+
38+
var open: List<() -> Unit> = ArrayList()
39+
private set
40+
var close: List<() -> Unit> = ArrayList()
41+
private set
42+
var error: List<(Throwable, Response?) -> Unit> = ArrayList()
43+
private set
44+
var message: List<(Message) -> Unit> = ArrayList()
45+
private set
46+
47+
/** Safely adds an onOpen callback */
48+
fun onOpen(callback: () -> Unit) {
49+
this.open = this.open.copyAndAdd(callback)
50+
}
51+
52+
/** Safely adds an onClose callback */
53+
fun onClose(callback: () -> Unit) {
54+
this.close = this.close.copyAndAdd(callback)
55+
}
56+
57+
/** Safely adds an onError callback */
58+
fun onError(callback: (Throwable, Response?) -> Unit) {
59+
this.error = this.error.copyAndAdd(callback)
60+
}
61+
62+
/** Safely adds an onMessage callback */
63+
fun onMessage(callback: (Message) -> Unit) {
64+
this.message = this.message.copyAndAdd(callback)
65+
}
66+
4267
/** Clears all stored callbacks */
4368
fun release() {
44-
open.clear()
45-
close.clear()
46-
error.clear()
47-
message.clear()
69+
open = emptyList()
70+
close = emptyList()
71+
error = emptyList()
72+
message = emptyList()
4873
}
4974
}
5075

76+
/** Converts the List to a MutableList, adds the value, and then returns as a read-only List */
77+
fun <T> List<T>.copyAndAdd(value: T): List<T> {
78+
val temp = this.toMutableList()
79+
temp.add(value)
80+
81+
return temp
82+
}
83+
84+
5185
/** RFC 6455: indicates a normal closure */
5286
const val WS_CLOSE_NORMAL = 1000
5387

5488
/** RFC 6455: indicates that the connection was closed abnormally */
5589
const val WS_CLOSE_ABNORMAL = 1006
5690

57-
5891
/**
5992
* Connects to a Phoenix Server
6093
*/
@@ -125,7 +158,7 @@ class Socket(
125158
internal val stateChangeCallbacks: StateChangeCallbacks = StateChangeCallbacks()
126159

127160
/** Collection of unclosed channels created by the Socket */
128-
internal var channels: MutableList<Channel> = ArrayList()
161+
internal var channels: List<Channel> = ArrayList()
129162

130163
/** Buffers messages that need to be sent once the socket has connected */
131164
internal var sendBuffer: MutableList<() -> Unit> = ArrayList()
@@ -250,19 +283,19 @@ class Socket(
250283
}
251284

252285
fun onOpen(callback: (() -> Unit)) {
253-
this.stateChangeCallbacks.open.add(callback)
286+
this.stateChangeCallbacks.onOpen(callback)
254287
}
255288

256289
fun onClose(callback: () -> Unit) {
257-
this.stateChangeCallbacks.close.add(callback)
290+
this.stateChangeCallbacks.onClose(callback)
258291
}
259292

260293
fun onError(callback: (Throwable, Response?) -> Unit) {
261-
this.stateChangeCallbacks.error.add(callback)
294+
this.stateChangeCallbacks.onError(callback)
262295
}
263296

264297
fun onMessage(callback: (Message) -> Unit) {
265-
this.stateChangeCallbacks.message.add(callback)
298+
this.stateChangeCallbacks.onMessage(callback)
266299
}
267300

268301
fun removeAllCallbacks() {
@@ -271,7 +304,7 @@ class Socket(
271304

272305
fun channel(topic: String, params: Payload = mapOf()): Channel {
273306
val channel = Channel(topic, params, this)
274-
this.channels.add(channel)
307+
this.channels = this.channels.copyAndAdd(channel)
275308

276309
return channel
277310
}
@@ -282,7 +315,6 @@ class Socket(
282315
// that does not contain the channel that was removed.
283316
this.channels = channels
284317
.filter { it.joinRef != channel.joinRef }
285-
.toMutableList()
286318
}
287319

288320
//------------------------------------------------------------------------------

src/test/kotlin/org/phoenixframework/SocketTest.kt

+108-15
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import org.junit.jupiter.api.Nested
1919
import org.junit.jupiter.api.Test
2020
import org.mockito.Mock
2121
import org.mockito.MockitoAnnotations
22+
import org.phoenixframework.utilities.copyAndRemove
2223
import java.net.URL
2324
import java.util.concurrent.TimeUnit
2425

@@ -397,7 +398,6 @@ class SocketTest {
397398
channel1.join().trigger("ok", emptyMap())
398399
channel2.join().trigger("ok", emptyMap())
399400

400-
401401
var chan1Called = false
402402
channel1.onError { chan1Called = true }
403403

@@ -756,8 +756,8 @@ class SocketTest {
756756
val spy = spy(channel)
757757

758758
// Use the spy instance instead of the Channel instance
759-
socket.channels.remove(channel)
760-
socket.channels.add(spy)
759+
socket.channels = socket.channels.copyAndRemove(channel)
760+
socket.channels = socket.channels.copyAndAdd(spy)
761761

762762
spy.join()
763763
assertThat(spy.state).isEqualTo(Channel.State.JOINING)
@@ -772,8 +772,8 @@ class SocketTest {
772772
val spy = spy(channel)
773773

774774
// Use the spy instance instead of the Channel instance
775-
socket.channels.remove(channel)
776-
socket.channels.add(spy)
775+
socket.channels = socket.channels.copyAndRemove(channel)
776+
socket.channels = socket.channels.copyAndAdd(spy)
777777

778778
spy.join().trigger("ok", emptyMap())
779779

@@ -789,8 +789,8 @@ class SocketTest {
789789
val spy = spy(channel)
790790

791791
// Use the spy instance instead of the Channel instance
792-
socket.channels.remove(channel)
793-
socket.channels.add(spy)
792+
socket.channels = socket.channels.copyAndRemove(channel)
793+
socket.channels = socket.channels.copyAndAdd(spy)
794794

795795
spy.join().trigger("ok", emptyMap())
796796
spy.leave()
@@ -828,8 +828,8 @@ class SocketTest {
828828
val spy = spy(channel)
829829

830830
// Use the spy instance instead of the Channel instance
831-
socket.channels.remove(channel)
832-
socket.channels.add(spy)
831+
socket.channels = socket.channels.copyAndRemove(channel)
832+
socket.channels = socket.channels.copyAndAdd(spy)
833833

834834
spy.join()
835835
assertThat(spy.state).isEqualTo(Channel.State.JOINING)
@@ -844,8 +844,8 @@ class SocketTest {
844844
val spy = spy(channel)
845845

846846
// Use the spy instance instead of the Channel instance
847-
socket.channels.remove(channel)
848-
socket.channels.add(spy)
847+
socket.channels = socket.channels.copyAndRemove(channel)
848+
socket.channels = socket.channels.copyAndAdd(spy)
849849

850850
spy.join().trigger("ok", emptyMap())
851851

@@ -861,8 +861,8 @@ class SocketTest {
861861
val spy = spy(channel)
862862

863863
// Use the spy instance instead of the Channel instance
864-
socket.channels.remove(channel)
865-
socket.channels.add(spy)
864+
socket.channels = socket.channels.copyAndRemove(channel)
865+
socket.channels = socket.channels.copyAndAdd(spy)
866866

867867
spy.join().trigger("ok", emptyMap())
868868
spy.leave()
@@ -886,8 +886,8 @@ class SocketTest {
886886
val otherChannel = mock<Channel>()
887887
whenever(otherChannel.isMember(any())).thenReturn(false)
888888

889-
socket.channels.add(targetChannel)
890-
socket.channels.add(otherChannel)
889+
socket.channels = socket.channels.copyAndAdd(targetChannel)
890+
socket.channels = socket.channels.copyAndRemove(otherChannel)
891891

892892
val rawMessage =
893893
"{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"status\":\"ok\"}"
@@ -923,4 +923,97 @@ class SocketTest {
923923
/* End OnConnectionMessage */
924924
}
925925

926+
927+
@Nested
928+
@DisplayName("ConcurrentModificationException")
929+
inner class ConcurrentModificationExceptionTests {
930+
931+
@Test
932+
internal fun `onOpen does not throw`() {
933+
var oneCalled = 0
934+
var twoCalled = 0
935+
socket.onOpen {
936+
socket.onOpen { twoCalled += 1 }
937+
oneCalled += 1
938+
}
939+
940+
socket.onConnectionOpened()
941+
assertThat(oneCalled).isEqualTo(1)
942+
assertThat(twoCalled).isEqualTo(0)
943+
944+
socket.onConnectionOpened()
945+
assertThat(oneCalled).isEqualTo(2)
946+
assertThat(twoCalled).isEqualTo(1)
947+
}
948+
949+
@Test
950+
internal fun `onClose does not throw`() {
951+
var oneCalled = 0
952+
var twoCalled = 0
953+
socket.onClose {
954+
socket.onClose { twoCalled += 1 }
955+
oneCalled += 1
956+
}
957+
958+
socket.onConnectionClosed(1000)
959+
assertThat(oneCalled).isEqualTo(1)
960+
assertThat(twoCalled).isEqualTo(0)
961+
962+
socket.onConnectionClosed(1001)
963+
assertThat(oneCalled).isEqualTo(2)
964+
assertThat(twoCalled).isEqualTo(1)
965+
}
966+
967+
@Test
968+
internal fun `onError does not throw`() {
969+
var oneCalled = 0
970+
var twoCalled = 0
971+
socket.onError { _, _->
972+
socket.onError { _, _ -> twoCalled += 1 }
973+
oneCalled += 1
974+
}
975+
976+
socket.onConnectionError(Throwable(), null)
977+
assertThat(oneCalled).isEqualTo(1)
978+
assertThat(twoCalled).isEqualTo(0)
979+
980+
socket.onConnectionError(Throwable(), null)
981+
assertThat(oneCalled).isEqualTo(2)
982+
assertThat(twoCalled).isEqualTo(1)
983+
}
984+
985+
@Test
986+
internal fun `onMessage does not throw`() {
987+
var oneCalled = 0
988+
var twoCalled = 0
989+
socket.onMessage {
990+
socket.onMessage { twoCalled += 1 }
991+
oneCalled += 1
992+
}
993+
994+
socket.onConnectionMessage("{\"status\":\"ok\"}")
995+
assertThat(oneCalled).isEqualTo(1)
996+
assertThat(twoCalled).isEqualTo(0)
997+
998+
socket.onConnectionMessage("{\"status\":\"ok\"}")
999+
assertThat(oneCalled).isEqualTo(2)
1000+
assertThat(twoCalled).isEqualTo(1)
1001+
}
1002+
1003+
@Test
1004+
internal fun `does not throw when adding channel`() {
1005+
var oneCalled = 0
1006+
socket.onOpen {
1007+
val channel = socket.channel("foo")
1008+
oneCalled += 1
1009+
}
1010+
1011+
socket.onConnectionOpened()
1012+
assertThat(oneCalled).isEqualTo(1)
1013+
}
1014+
1015+
/* End ConcurrentModificationExceptionTests */
1016+
}
1017+
1018+
9261019
}

src/test/kotlin/org/phoenixframework/utilities/TestUtilities.kt

+8
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,12 @@ import org.phoenixframework.Channel
55

66
fun Channel.getBindings(event: String): List<Binding> {
77
return bindings.toList().filter { it.event == event }
8+
}
9+
10+
/** Converts the List to a MutableList, removes the value, and then returns as a read-only List */
11+
fun <T> List<T>.copyAndRemove(value: T): List<T> {
12+
val temp = this.toMutableList()
13+
temp.remove(value)
14+
15+
return temp
816
}

0 commit comments

Comments
 (0)