Skip to content

Commit 9f9e974

Browse files
authored
Merge pull request #2 from neandertech/fix-race-condition
Adds a function to manually open the channel
2 parents 8bc0e75 + 3d83715 commit 9f9e974

File tree

3 files changed

+34
-12
lines changed

3 files changed

+34
-12
lines changed

fs2/src/jsonrpclib/fs2/FS2Channel.scala

+18-8
Original file line numberDiff line numberDiff line change
@@ -14,37 +14,43 @@ import jsonrpclib.internals.MessageDispatcher
1414
import jsonrpclib.internals._
1515

1616
import scala.util.Try
17+
import _root_.fs2.concurrent.SignallingRef
1718

1819
trait FS2Channel[F[_]] extends Channel[F] {
1920
def withEndpoint(endpoint: Endpoint[F])(implicit F: Functor[F]): Resource[F, Unit] =
2021
Resource.make(mountEndpoint(endpoint))(_ => unmountEndpoint(endpoint.method))
2122

2223
def withEndpoints(endpoint: Endpoint[F], rest: Endpoint[F]*)(implicit F: Monad[F]): Resource[F, Unit] =
2324
(endpoint :: rest.toList).traverse_(withEndpoint)
25+
26+
def open: Resource[F, Unit]
27+
def openStream: Stream[F, Unit]
2428
}
2529

2630
object FS2Channel {
2731

2832
def lspCompliant[F[_]: Concurrent](
2933
byteStream: Stream[F, Byte],
3034
byteSink: Pipe[F, Byte, Nothing],
31-
startingEndpoints: List[Endpoint[F]] = List.empty,
3235
bufferSize: Int = 512
3336
): Stream[F, FS2Channel[F]] = internals.LSP.writeSink(byteSink, bufferSize).flatMap { sink =>
34-
apply[F](internals.LSP.readStream(byteStream), sink, startingEndpoints)
37+
apply[F](internals.LSP.readStream(byteStream), sink)
3538
}
3639

3740
def apply[F[_]: Concurrent](
3841
payloadStream: Stream[F, Payload],
39-
payloadSink: Payload => F[Unit],
40-
startingEndpoints: List[Endpoint[F]] = List.empty[Endpoint[F]]
42+
payloadSink: Payload => F[Unit]
4143
): Stream[F, FS2Channel[F]] = {
42-
val endpointsMap = startingEndpoints.map(ep => ep.method -> ep).toMap
4344
for {
4445
supervisor <- Stream.resource(Supervisor[F])
45-
ref <- Ref[F].of(State[F](Map.empty, endpointsMap, 0)).toStream
46-
impl = new Impl(payloadSink, ref, supervisor)
47-
_ <- Stream(()).concurrently(payloadStream.evalMap(impl.handleReceivedPayload))
46+
ref <- Ref[F].of(State[F](Map.empty, Map.empty, 0)).toStream
47+
isOpen <- SignallingRef[F].of(false).toStream
48+
awaitingSink = isOpen.waitUntil(identity) >> payloadSink(_: Payload)
49+
impl = new Impl(awaitingSink, ref, isOpen, supervisor)
50+
_ <- Stream(()).concurrently {
51+
// Gatekeeping the pull until the channel is actually marked as open
52+
payloadStream.pauseWhen(isOpen.map(b => !b)).evalMap(impl.handleReceivedPayload)
53+
}
4854
} yield impl
4955
}
5056

@@ -72,6 +78,7 @@ object FS2Channel {
7278
private class Impl[F[_]](
7379
private val sink: Payload => F[Unit],
7480
private val state: Ref[F, FS2Channel.State[F]],
81+
private val isOpen: SignallingRef[F, Boolean],
7582
supervisor: Supervisor[F]
7683
)(implicit F: Concurrent[F])
7784
extends MessageDispatcher[F]
@@ -88,6 +95,9 @@ object FS2Channel {
8895

8996
def unmountEndpoint(method: String): F[Unit] = state.update(_.removeEndpoint(method))
9097

98+
def open: Resource[F, Unit] = Resource.make[F, Unit](isOpen.set(true))(_ => isOpen.set(false))
99+
def openStream: Stream[F, Unit] = Stream.resource(open)
100+
91101
protected def background[A](fa: F[A]): F[Unit] = supervisor.supervise(fa).void
92102
protected def reportError(params: Option[Payload], error: ProtocolError, method: String): F[Unit] = ???
93103
protected def getEndpoint(method: String): F[Option[Endpoint[F]]] = state.get.map(_.endpoints.get(method))

fs2/src/jsonrpclib/fs2/package.scala

+6
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,19 @@ package jsonrpclib
33
import _root_.fs2.Stream
44
import cats.MonadThrow
55
import cats.Monad
6+
import cats.effect.kernel.Resource
7+
import cats.effect.kernel.MonadCancel
68

79
package object fs2 {
810

911
private[jsonrpclib] implicit class EffectOps[F[_], A](private val fa: F[A]) extends AnyVal {
1012
def toStream: Stream[F, A] = Stream.eval(fa)
1113
}
1214

15+
private[jsonrpclib] implicit class ResourceOps[F[_], A](private val fa: Resource[F, A]) extends AnyVal {
16+
def asStream(implicit F: MonadCancel[F, Throwable]): Stream[F, A] = Stream.resource(fa)
17+
}
18+
1319
implicit def catsMonadic[F[_]: MonadThrow]: Monadic[F] = new Monadic[F] {
1420
def doFlatMap[A, B](fa: F[A])(f: A => F[B]): F[B] = Monad[F].flatMap(fa)(f)
1521

fs2/test/src/jsonrpclib/fs2/FS2ChannelSpec.scala

+10-4
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ object FS2ChannelSpec extends SimpleIOSuite {
2121
}
2222

2323
def testRes(name: TestName)(run: Stream[IO, Expectations]): Unit =
24-
test(name)(run.compile.lastOrError)
24+
test(name)(run.compile.lastOrError.timeout(10.second))
2525

2626
testRes("Round trip") {
2727
val endpoint: Endpoint[IO] = Endpoint[IO]("inc").simple((int: IntWrapper) => IO(IntWrapper(int.int + 1)))
@@ -31,8 +31,10 @@ object FS2ChannelSpec extends SimpleIOSuite {
3131
stdin <- Queue.bounded[IO, Payload](10).toStream
3232
serverSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdin), stdout.offer)
3333
clientSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdout), stdin.offer)
34-
_ <- Stream.resource(serverSideChannel.withEndpoint(endpoint))
34+
_ <- serverSideChannel.withEndpoint(endpoint).asStream
3535
remoteFunction = clientSideChannel.simpleStub[IntWrapper, IntWrapper]("inc")
36+
_ <- serverSideChannel.open.asStream
37+
_ <- clientSideChannel.open.asStream
3638
result <- remoteFunction(IntWrapper(1)).toStream
3739
} yield {
3840
expect.same(result, IntWrapper(2))
@@ -44,9 +46,11 @@ object FS2ChannelSpec extends SimpleIOSuite {
4446
for {
4547
stdout <- Queue.bounded[IO, Payload](10).toStream
4648
stdin <- Queue.bounded[IO, Payload](10).toStream
47-
_ <- FS2Channel[IO](Stream.fromQueueUnterminated(stdin), stdout.offer)
49+
serverSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdin), stdout.offer)
4850
clientSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdout), stdin.offer)
4951
remoteFunction = clientSideChannel.simpleStub[IntWrapper, IntWrapper]("inc")
52+
_ <- serverSideChannel.open.asStream
53+
_ <- clientSideChannel.open.asStream
5054
result <- remoteFunction(IntWrapper(1)).attempt.toStream
5155
} yield {
5256
expect.same(result, Left(ErrorPayload(-32601, "Method inc not found", None)))
@@ -65,8 +69,10 @@ object FS2ChannelSpec extends SimpleIOSuite {
6569
stdin <- Queue.bounded[IO, Payload](10).toStream
6670
serverSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdin), payload => stdout.offer(payload))
6771
clientSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdout), payload => stdin.offer(payload))
68-
_ <- Stream.resource(serverSideChannel.withEndpoint(endpoint))
72+
_ <- serverSideChannel.withEndpoint(endpoint).asStream
6973
remoteFunction = clientSideChannel.simpleStub[IntWrapper, IntWrapper]("inc")
74+
_ <- serverSideChannel.open.asStream
75+
_ <- clientSideChannel.open.asStream
7076
timedResults <- (1 to 10).toList.map(IntWrapper(_)).parTraverse(remoteFunction).timed.toStream
7177
} yield {
7278
val (time, results) = timedResults

0 commit comments

Comments
 (0)