From aa18217f0846319d009764a710ab37dc178a0662 Mon Sep 17 00:00:00 2001 From: Michael Pollmeier Date: Thu, 4 May 2023 08:58:21 +0200 Subject: [PATCH 1/5] server-mode-rewrite: squashed slf4j-simple for stage WIP refactor WebServiceWithWebSocket: style reimplement server mode * previously it was a very flakey since it just piped input/output streams and would fail for a random 'println' in the code or inherited code * It pretended to handle concurrent requests which was a lie. Now we're programmatically interacting with the ReplDriver * No more manual thread handling, killing of the process etc. refactor for reuse and DRY: server.ReplDriver and ReplDriver refactor: shuffle classes around more refactorings, including optional server auth handling fix and refactor ReplServerTests add test for predefined code, refactor some of the test code test synchronous api readme refactor EmbeddedReplTests: fixup --- README.md | 36 ++ build.sbt | 5 +- core/src/main/scala/replpp/Config.scala | 8 +- core/src/main/scala/replpp/ReplDriver.scala | 41 +- .../main/scala/replpp/ReplDriverBase.scala | 67 +++ core/src/main/scala/replpp/package.scala | 2 +- .../replpp/server/EmbeddedReplTests.scala | 39 +- .../scala/replpp/server/EmbeddedRepl.scala | 145 +++--- .../main/scala/replpp/server/ReplDriver.scala | 44 -- .../main/scala/replpp/server/ReplServer.scala | 184 ++----- .../scala/replpp/server/UserRunnable.scala | 72 --- .../server/WebServiceWithWebSocket.scala | 96 ++++ .../scala/replpp/server/ReplServerTests.scala | 482 ++++++++++-------- 13 files changed, 594 insertions(+), 627 deletions(-) create mode 100644 core/src/main/scala/replpp/ReplDriverBase.scala delete mode 100644 server/src/main/scala/replpp/server/ReplDriver.scala delete mode 100644 server/src/main/scala/replpp/server/UserRunnable.scala create mode 100644 server/src/main/scala/replpp/server/WebServiceWithWebSocket.scala diff --git a/README.md b/README.md index ab5fa9f..95fd09c 100644 --- a/README.md +++ b/README.md @@ -264,7 +264,43 @@ The prefix is arbitrary and is only used to specify several credentials in a sin ./scala-repl-pp --server curl http://localhost:8080/query-sync -X POST -d '{"query": "val foo = 42"}' +# {"success":true,"stdout":"val foo: Int = 42\n",...} + curl http://localhost:8080/query-sync -X POST -d '{"query": "val bar = foo + 1"}' +# {"success":true,"stdout":"val bar: Int = 43\n",...} + +curl http://localhost:8080/query-sync -X POST -d '{"query":"println(\"OMG remote code execution!!1!\")"}' +# {"success":true,"stdout":"",...}% +``` + +Predef code works with server as well: +``` +echo val foo = 99 > foo.sc +./scala-repl-pp --server --predef foo.sc + +curl -XPOST http://localhost:8080/query-sync -d '{"query":"val baz = foo + 1"}' +# {"success":true,"stdout":"val baz: Int = 100\n",...} +``` + +There's also has an asynchronous mode: +``` +./scala-repl-pp --server + +curl http://localhost:8080/query -X POST -d '{"query": "val baz = 93"}' +# {"success":true,"uuid":"e2640fcb-3193-4386-8e05-914b639c3184"}% + +curl http://localhost:8080/result/e2640fcb-3193-4386-8e05-914b639c3184 +{"success":true,"uuid":"e2640fcb-3193-4386-8e05-914b639c3184","stdout":"val baz: Int = 93\n"}% +``` + +And there's even a websocket channel that allows you to get notified when the query has finished. For more details and other use cases check out [ReplServerTests.scala](server/src/test/scala/replpp/server/ReplServerTests.scala) + +Server-specific configuration options as per `scala-repl-pp --help`: +``` +--server-host Hostname on which to expose the REPL server +--server-port Port on which to expose the REPL server +--server-auth-username Basic auth username for the REPL server +--server-auth-password Basic auth password for the REPL server ``` ## Embed into your own project diff --git a/build.sbt b/build.sbt index b660007..346794c 100644 --- a/build.sbt +++ b/build.sbt @@ -38,7 +38,10 @@ lazy val server = project.in(file("server")) lazy val all = project.in(file("all")) .dependsOn(core, server) .enablePlugins(JavaAppPackaging) - .settings(name := "scala-repl-pp-all") + .settings( + name := "scala-repl-pp-all", + libraryDependencies += "org.slf4j" % "slf4j-simple" % "2.0.7" % Optional, + ) ThisBuild / libraryDependencies ++= Seq( "org.scalatest" %% "scalatest" % ScalaTestVersion % Test, diff --git a/core/src/main/scala/replpp/Config.scala b/core/src/main/scala/replpp/Config.scala index 74a79c7..186914f 100644 --- a/core/src/main/scala/replpp/Config.scala +++ b/core/src/main/scala/replpp/Config.scala @@ -24,8 +24,8 @@ case class Config( server: Boolean = false, serverHost: String = "localhost", serverPort: Int = 8080, - serverAuthUsername: String = "", - serverAuthPassword: String = "", + serverAuthUsername: Option[String] = None, + serverAuthPassword: Option[String] = None, ) { /** inverse of `Config.parse` */ lazy val asJavaArgs: Seq[String] = { @@ -154,11 +154,11 @@ object Config { .text("Port on which to expose the REPL server") opt[String]("server-auth-username") - .action((x, c) => c.copy(serverAuthUsername = x)) + .action((x, c) => c.copy(serverAuthUsername = Option(x))) .text("Basic auth username for the REPL server") opt[String]("server-auth-password") - .action((x, c) => c.copy(serverAuthPassword = x)) + .action((x, c) => c.copy(serverAuthPassword = Option(x))) .text("Basic auth password for the REPL server") help("help") diff --git a/core/src/main/scala/replpp/ReplDriver.scala b/core/src/main/scala/replpp/ReplDriver.scala index 9d151f8..b764772 100644 --- a/core/src/main/scala/replpp/ReplDriver.scala +++ b/core/src/main/scala/replpp/ReplDriver.scala @@ -28,7 +28,7 @@ class ReplDriver(args: Array[String], greeting: Option[String], prompt: String, maxHeight: Option[Int] = None, - classLoader: Option[ClassLoader] = None) extends dotty.tools.repl.ReplDriver(args, out, classLoader) { + classLoader: Option[ClassLoader] = None) extends ReplDriverBase(args, out, classLoader) { /** Run REPL with `state` until `:quit` command found * Main difference to the 'original': different greeting, trap Ctrl-c @@ -82,43 +82,4 @@ class ReplDriver(args: Array[String], terminal.readLine(completer).linesIterator } - private def interpretInput(lines: IterableOnce[String], state: State, currentFile: Path): State = { - val parsedLines = Seq.newBuilder[String] - var resultingState = state - - def handleImportFileDirective(line: String) = { - val linesBeforeUsingFileDirective = parsedLines.result() - parsedLines.clear() - if (linesBeforeUsingFileDirective.nonEmpty) { - // interpret everything until here - val parseResult = parseInput(linesBeforeUsingFileDirective, resultingState) - resultingState = interpret(parseResult)(using resultingState) - } - - // now read and interpret the given file - val pathStr = line.trim.drop(UsingDirectives.FileDirective.length) - val path = resolveFile(currentFile, pathStr) - val linesFromFile = util.linesFromFile(path) - println(s"> importing $path (${linesFromFile.size} lines)") - resultingState = interpretInput(linesFromFile, resultingState, path) - } - - for (line <- lines.iterator) { - if (line.trim.startsWith(UsingDirectives.FileDirective)) - handleImportFileDirective(line) - else - parsedLines.addOne(line) - } - - val parseResult = parseInput(parsedLines.result(), resultingState) - resultingState = interpret(parseResult)(using resultingState) - resultingState - } - - private def parseInput(lines: IterableOnce[String], state: State): ParseResult = - parseInput(lines.iterator.mkString(lineSeparator), state) - - private def parseInput(input: String, state: State): ParseResult = - ParseResult(input)(using state) - } diff --git a/core/src/main/scala/replpp/ReplDriverBase.scala b/core/src/main/scala/replpp/ReplDriverBase.scala new file mode 100644 index 0000000..a4a1261 --- /dev/null +++ b/core/src/main/scala/replpp/ReplDriverBase.scala @@ -0,0 +1,67 @@ +package replpp + +import dotty.tools.MainGenericCompiler.classpathSeparator +import dotty.tools.dotc.Run +import dotty.tools.dotc.ast.{Positioned, tpd, untpd} +import dotty.tools.dotc.classpath.{AggregateClassPath, ClassPathFactory} +import dotty.tools.dotc.config.{Feature, JavaPlatform, Platform} +import dotty.tools.dotc.core.Comments.{ContextDoc, ContextDocstrings} +import dotty.tools.dotc.core.Contexts.{Context, ContextBase, ContextState, FreshContext, ctx, explore} +import dotty.tools.dotc.core.{Contexts, MacroClassLoader, Mode, TyperState} +import dotty.tools.io.{AbstractFile, ClassPath, ClassRepresentation} +import dotty.tools.repl.* +import org.jline.reader.* +import org.slf4j.{Logger, LoggerFactory} + +import java.io.PrintStream +import java.lang.System.lineSeparator +import java.net.URL +import java.nio.file.Path +import javax.naming.InitialContext +import scala.annotation.tailrec +import scala.collection.mutable +import scala.jdk.CollectionConverters.* +import scala.util.{Failure, Success, Try} + +abstract class ReplDriverBase(args: Array[String], out: PrintStream, classLoader: Option[ClassLoader]) + extends dotty.tools.repl.ReplDriver(args, out, classLoader) { + + protected def interpretInput(lines: IterableOnce[String], state: State, currentFile: Path): State = { + val parsedLines = Seq.newBuilder[String] + var currentState = state + + def handleImportFileDirective(line: String) = { + val linesBeforeUsingFileDirective = parsedLines.result() + parsedLines.clear() + if (linesBeforeUsingFileDirective.nonEmpty) { + // interpret everything until here + val parseResult = parseInput(linesBeforeUsingFileDirective, currentState) + currentState = interpret(parseResult)(using currentState) + } + + // now read and interpret the given file + val pathStr = line.trim.drop(UsingDirectives.FileDirective.length) + val path = resolveFile(currentFile, pathStr) + val linesFromFile = util.linesFromFile(path) + println(s"> importing $path (${linesFromFile.size} lines)") + currentState = interpretInput(linesFromFile, currentState, path) + } + + for (line <- lines.iterator) { + if (line.trim.startsWith(UsingDirectives.FileDirective)) + handleImportFileDirective(line) + else + parsedLines.addOne(line) + } + + val parseResult = parseInput(parsedLines.result(), currentState) + interpret(parseResult)(using currentState) + } + + private def parseInput(lines: IterableOnce[String], state: State): ParseResult = + parseInput(lines.iterator.mkString(lineSeparator), state) + + private def parseInput(input: String, state: State): ParseResult = + ParseResult(input)(using state) + +} diff --git a/core/src/main/scala/replpp/package.scala b/core/src/main/scala/replpp/package.scala index 03bbea6..77fc053 100644 --- a/core/src/main/scala/replpp/package.scala +++ b/core/src/main/scala/replpp/package.scala @@ -80,7 +80,7 @@ package object replpp { def allPredefCode(config: Config): String = allPredefLines(config).mkString(lineSeparator) - private def allPredefLines(config: Config): Seq[String] = { + def allPredefLines(config: Config): Seq[String] = { val resultLines = Seq.newBuilder[String] val visited = mutable.Set.empty[Path] diff --git a/server/src/it/scala/replpp/server/EmbeddedReplTests.scala b/server/src/it/scala/replpp/server/EmbeddedReplTests.scala index 26859b4..7efde6a 100644 --- a/server/src/it/scala/replpp/server/EmbeddedReplTests.scala +++ b/server/src/it/scala/replpp/server/EmbeddedReplTests.scala @@ -2,8 +2,8 @@ package replpp.server import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec - -import java.util.concurrent.Semaphore +import scala.concurrent.Await +import scala.concurrent.duration.Duration /** Moved to IntegrationTests, because of some strange interaction with ReplServerTests: * if EmbeddedReplTests would run *before* ReplServerTests, the latter would stall (forever) @@ -12,34 +12,21 @@ import java.util.concurrent.Semaphore */ class EmbeddedReplTests extends AnyWordSpec with Matchers { - "start and shutdown without hanging" in { - val shell = new EmbeddedRepl() - shell.start() - shell.shutdown() - } - "execute commands synchronously" in { - val shell = new EmbeddedRepl() - shell.start() + val repl = new EmbeddedRepl() - shell.query("val x = 0").out shouldBe "val x: Int = 0\n" - shell.query("x + 1").out shouldBe "val res1: Int = 1\n" + repl.query("val x = 0").output shouldBe "val x: Int = 0\n" + repl.query("x + 1").output shouldBe "val res0: Int = 1\n" - shell.shutdown() + repl.shutdown() } - "execute a command asynchronously" in { - val shell = new EmbeddedRepl() - val mutex = new Semaphore(0) - shell.start() - var resultOut = "uninitialized" - shell.queryAsync("val x = 0") { result => - resultOut = result.out - mutex.release() - } - mutex.acquire() - resultOut shouldBe "val x: Int = 0\n" - shell.shutdown() - } + "execute a command asynchronously" in { + val repl = new EmbeddedRepl() + val (uuid, futureResult) = repl.queryAsync("val x = 0") + val result = Await.result(futureResult, Duration.Inf) + result shouldBe "val x: Int = 0\n" + repl.shutdown() + } } diff --git a/server/src/main/scala/replpp/server/EmbeddedRepl.scala b/server/src/main/scala/replpp/server/EmbeddedRepl.scala index 12d9eda..6567d9e 100644 --- a/server/src/main/scala/replpp/server/EmbeddedRepl.scala +++ b/server/src/main/scala/replpp/server/EmbeddedRepl.scala @@ -1,110 +1,89 @@ package replpp.server +import dotty.tools.dotc.config.Printers.config import dotty.tools.repl.State import org.slf4j.{Logger, LoggerFactory} +import replpp.{Config, ReplDriverBase, pwd} -import java.io.{BufferedReader, InputStream, InputStreamReader, PipedInputStream, PipedOutputStream, PrintStream, PrintWriter} +import java.io.* +import java.nio.charset.StandardCharsets import java.util.UUID -import java.util.concurrent.{BlockingQueue, LinkedBlockingQueue, Semaphore} +import java.util.concurrent.{BlockingQueue, Executors, LinkedBlockingQueue, Semaphore} +import scala.concurrent.duration.Duration +import scala.concurrent.impl.Promise +import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutorService, Future} +import scala.util.{Failure, Success} -/** Result of executing a query, containing in particular output received on standard out. */ -case class QueryResult(out: String, uuid: UUID) extends HasUUID - -trait HasUUID { - def uuid: UUID -} - -private[server] case class Job(uuid: UUID, query: String, observer: QueryResult => Unit) - -class EmbeddedRepl(predefCode: String = "", verbose: Boolean = false) { +class EmbeddedRepl(predefLines: IterableOnce[String] = Seq.empty) { private val logger: Logger = LoggerFactory.getLogger(getClass) - val jobQueue: BlockingQueue[Job] = new LinkedBlockingQueue[Job]() - - val (inStream, toStdin) = pipePair() - val (fromStdout, outStream) = pipePair() - - val writer = new PrintWriter(toStdin) - val reader = new BufferedReader(new InputStreamReader(fromStdout)) + /** repl and compiler output ends up in this replOutputStream */ + private val replOutputStream = new ByteArrayOutputStream() + + private val replDriver: ReplDriver = { + val inheritedClasspath = System.getProperty("java.class.path") + val compilerArgs = Array( + "-classpath", inheritedClasspath, + "-explain", // verbose scalac error messages + "-deprecation", + "-color", "never" + ) + new ReplDriver(compilerArgs, new PrintStream(replOutputStream), classLoader = None) + } - val userThread = new Thread(new UserRunnable(jobQueue, writer, reader, verbose)) + private var state: State = { + val state = replDriver.execute(predefLines)(using replDriver.initialState) + val output = readAndResetReplOutputStream() + if (output.nonEmpty) + logger.info(output) + state + } - val shellThread = new Thread( - new Runnable { - override def run(): Unit = { - val inheritedClasspath = System.getProperty("java.class.path") - val compilerArgs = Array( - "-classpath", inheritedClasspath, - "-explain", // verbose scalac error messages - "-deprecation", - "-color", "never" - ) + private val singleThreadedJobExecutor: ExecutionContextExecutorService = + ExecutionContext.fromExecutorService(Executors.newSingleThreadExecutor()) - val replDriver = new ReplDriver(compilerArgs, inStream, new PrintStream(outStream)) - val initialState: State = replDriver.initialState - val state: State = - if (verbose) { - println(predefCode) - replDriver.run(predefCode)(using initialState) - } else { - replDriver.runQuietly(predefCode)(using initialState) - } + /** Execute `inputLines` in REPL (in single threaded ExecutorService) and provide Future for result callback */ + def queryAsync(code: String): (UUID, Future[String]) = + queryAsync(code.linesIterator) - replDriver.runUntilQuit() - } - }) + /** Execute `inputLines` in REPL (in single threaded ExecutorService) and provide Future for result callback */ + def queryAsync(inputLines: IterableOnce[String]): (UUID, Future[String]) = { + val uuid = UUID.randomUUID() + val future = Future { + state = replDriver.execute(inputLines)(using state) + readAndResetReplOutputStream() + } (using singleThreadedJobExecutor) - private def pipePair(): (PipedInputStream, PipedOutputStream) = { - val out = new PipedOutputStream() - val in = new PipedInputStream() - in.connect(out) - (in, out) + (uuid, future) } - def start(): Unit = { - shellThread.start() - userThread.start() + private def readAndResetReplOutputStream(): String = { + val result = replOutputStream.toString(StandardCharsets.UTF_8) + replOutputStream.reset() + result } - /** Submit query `q` to shell and call `observer` when the result is ready. - */ - def queryAsync(q: String)(observer: QueryResult => Unit): UUID = { - val uuid = UUID.randomUUID() - jobQueue.add(Job(uuid, q, observer)) - uuid - } + /** Submit query to the repl, await and return results. */ + def query(code: String): QueryResult = + query(code.linesIterator) - /** Submit query `q` to the shell and return result. - */ - def query(q: String): QueryResult = { - val mutex = new Semaphore(0) - var result: QueryResult = null - queryAsync(q) { r => - result = r - mutex.release() - } - mutex.acquire() - result + /** Submit query to the repl, await and return results. */ + def query(inputLines: IterableOnce[String]): QueryResult = { + val (uuid, futureResult) = queryAsync(inputLines) + val result = Await.result(futureResult, Duration.Inf) + QueryResult(result, uuid, success = true) } /** Shutdown the embedded shell and associated threads. */ def shutdown(): Unit = { - logger.info("Trying to shutdown shell and writer thread") - shutdownShellThread() - logger.info("Shell terminated gracefully") - shutdownWriterThread() - logger.info("Writer thread terminated gracefully") - - def shutdownWriterThread(): Unit = { - jobQueue.add(Job(null, null, null)) - userThread.join() - } - def shutdownShellThread(): Unit = { - writer.println(":exit") - writer.close() - shellThread.join() - } + logger.info("shutting down") + singleThreadedJobExecutor.shutdown() } +} +class ReplDriver(args: Array[String], out: PrintStream, classLoader: Option[ClassLoader]) + extends ReplDriverBase(args, out, classLoader) { + def execute(inputLines: IterableOnce[String])(using state: State = initialState): State = + interpretInput(inputLines, state, pwd) } diff --git a/server/src/main/scala/replpp/server/ReplDriver.scala b/server/src/main/scala/replpp/server/ReplDriver.scala deleted file mode 100644 index 18d2f1e..0000000 --- a/server/src/main/scala/replpp/server/ReplDriver.scala +++ /dev/null @@ -1,44 +0,0 @@ -package replpp.server - -import dotty.tools.dotc.core.Contexts.{Context, ContextBase, ContextState, FreshContext, ctx} -import dotty.tools.repl.{AbstractFileClassLoader, CollectTopLevelImports, Newline, ParseResult, Parsed, Quit, State} - -import java.io.{BufferedReader, InputStream, InputStreamReader, PrintStream} -import scala.annotation.tailrec - -class ReplDriver(compilerArgs: Array[String], - in: InputStream, - out: PrintStream = scala.Console.out, - classLoader: Option[ClassLoader] = None) extends dotty.tools.repl.ReplDriver(compilerArgs, out, classLoader) { - val reader = new BufferedReader(new InputStreamReader(in)) - - /** Run REPL with `state` until `:quit` command found - * Main difference to the 'original': different greeting, trap Ctrl-c - */ - override def runUntilQuit(using initialState: State = initialState)(): State = { - /** Blockingly read a line, getting back a parse result */ - def readLine(state: State): ParseResult = { - given Context = state.context - - try { - val line = reader.readLine() - ParseResult(line)(using state) - } catch { - case e => - e.printStackTrace() - println(s"caught exception `$e` with msg=`${e.getMessage}` -> exiting") - Quit - } - } - - @tailrec def loop(using state: State)(): State = { - val res = readLine(state) - if (res == Quit) state - else loop(using interpret(res))() - } - - runBody { - loop(using initialState)() - } - } -} diff --git a/server/src/main/scala/replpp/server/ReplServer.scala b/server/src/main/scala/replpp/server/ReplServer.scala index 28a22cb..342f63c 100644 --- a/server/src/main/scala/replpp/server/ReplServer.scala +++ b/server/src/main/scala/replpp/server/ReplServer.scala @@ -1,73 +1,81 @@ package replpp.server import cask.model.{Request, Response} -import cask.model.Response.Raw -import cask.router.Result -import java.util.concurrent.ConcurrentHashMap -import java.util.{Base64, UUID} -import replpp.{Config, allPredefCode} +import org.slf4j.{Logger, LoggerFactory} +import replpp.{Config, allPredefLines} import ujson.Obj +import java.io.{PrintWriter, StringWriter} +import java.util.{Base64, UUID} +import scala.util.{Failure, Success, Try} + +/** Result of executing a query, containing in particular output received on standard out. */ +case class QueryResult(output: String, uuid: UUID, success: Boolean) extends HasUUID + object ReplServer { + protected val logger: Logger = LoggerFactory.getLogger(getClass) def startHttpServer(config: Config): Unit = { - val predef = allPredefCode(config) - val embeddedRepl = new EmbeddedRepl(predef, replpp.verboseEnabled(config)) - embeddedRepl.start() + val authenticationMaybe = for { + username <- config.serverAuthUsername + password <-config.serverAuthPassword + } yield UsernamePasswordAuth(username, password) + + val embeddedRepl = new EmbeddedRepl(allPredefLines(config)) Runtime.getRuntime.addShutdownHook(new Thread(() => { - println("Shutting down server...") + logger.info("Shutting down server...") embeddedRepl.shutdown() })) - val server = new ReplServer( - embeddedRepl, - config.serverHost, - config.serverPort, - config.serverAuthUsername, - config.serverAuthPassword - ) - println("Starting REPL server ...") + val server = new ReplServer(embeddedRepl, config.serverHost, config.serverPort, authenticationMaybe) + logger.info("Starting REPL server ...") try { server.main(Array.empty) } catch { case _: java.net.BindException => - println(s"Could not bind socket on port ${config.serverPort} - exiting.") + logger.error(s"Could not bind socket on port ${config.serverPort} - exiting.") embeddedRepl.shutdown() System.exit(1) case e: Throwable => - println("Unhandled exception thrown while attempting to start server: ") - println(e.getMessage) - println("Exiting.") + logger.error("Unhandled exception thrown while attempting to start server - exiting", e) embeddedRepl.shutdown() System.exit(1) } } - } class ReplServer(repl: EmbeddedRepl, - serverHost: String, - serverPort: Int, - serverAuthUsername: String = "", - serverAuthPassword: String = "" -) extends WebServiceWithWebSocket[QueryResult](serverHost, serverPort, serverAuthUsername, serverAuthPassword) { + host: String, + port: Int, + authenticationMaybe: Option[UsernamePasswordAuth] = None) + extends WebServiceWithWebSocket[QueryResult](host, port, authenticationMaybe) { @cask.websocket("/connect") override def handler(): cask.WebsocketResult = super.handler() @basicAuth() @cask.get("/result/:uuidParam") - override def getResult(uuidParam: String)(isAuthorized: Boolean): Response[Obj] = - super.getResult(uuidParam)(isAuthorized) + override def getResult(uuidParam: String)(isAuthorized: Boolean): Response[Obj] = { + val response = super.getResult(uuidParam)(isAuthorized) + logger.debug(s"GET /result/$uuidParam: statusCode=${response.statusCode}") + response + } @basicAuth() @cask.postJson("/query") def postQuery(query: String)(isAuthorized: Boolean): Response[Obj] = { if (!isAuthorized) unauthorizedResponse else { - val uuid = repl.queryAsync(query) { result => - returnResult(result) + val (uuid, resultFuture) = repl.queryAsync(query.linesIterator) + logger.debug(s"query[uuid=$uuid, length=${query.length}]: submitted to queue") + resultFuture.onComplete { + case Success(output) => + logger.debug(s"query[uuid=$uuid]: got result (length=${output.length})") + returnResult(QueryResult(output, uuid, success = true)) + case Failure(exception) => + logger.info(s"query[uuid=$uuid] failed with $exception") + returnResult(QueryResult(render(exception), uuid, success = false)) } Response(ujson.Obj("success" -> true, "uuid" -> uuid.toString), 200) } @@ -78,116 +86,22 @@ class ReplServer(repl: EmbeddedRepl, def postQuerySimple(query: String)(isAuthorized: Boolean): Response[Obj] = { if (!isAuthorized) unauthorizedResponse else { - val result = repl.query(query) - Response(ujson.Obj("success" -> true, "out" -> result.out, "uuid" -> result.uuid.toString), 200) + logger.debug(s"POST /query-sync query.length=${query.length}") + val result = repl.query(query.linesIterator) + logger.debug(s"query-sync: got result: length=${result.output.length}") + Response(ujson.Obj("success" -> true, "stdout" -> result.output, "uuid" -> result.uuid.toString), 200) } } override def resultToJson(result: QueryResult, success: Boolean): Obj = { - ujson.Obj("success" -> success, "uuid" -> result.uuid.toString, "stdout" -> result.out) - } - - initialize() -} - -abstract class WebServiceWithWebSocket[T <: HasUUID]( - serverHost: String, - serverPort: Int, - serverAuthUsername: String = "", - serverAuthPassword: String = "" -) extends cask.MainRoutes { - - class basicAuth extends cask.RawDecorator { - - def wrapFunction(ctx: Request, delegate: Delegate): Result[Raw] = { - val authString = requestToAuthString(ctx) - val Array(user, password): Array[String] = authStringToUserAndPwd(authString) - val isAuthorized = - if (serverAuthUsername == "" && serverAuthPassword == "") - true - else - user == serverAuthUsername && password == serverAuthPassword - delegate(Map("isAuthorized" -> isAuthorized)) - } - - private def requestToAuthString(ctx: Request): String = { - try { - val authHeader = ctx.exchange.getRequestHeaders.get("authorization").getFirst - val strippedHeader = authHeader.replaceFirst("Basic ", "") - new String(Base64.getDecoder.decode(strippedHeader)) - } catch { - case _: Exception => "" - } - } - - private def authStringToUserAndPwd(authString: String): Array[String] = { - val split = authString.split(":") - if (split.length == 2) { - Array(split(0), split(1)) - } else { - Array("", "") - } - } + ujson.Obj("success" -> success, "uuid" -> result.uuid.toString, "stdout" -> result.output) } - override def port: Int = serverPort - - override def host: String = serverHost - - var openConnections = Set.empty[cask.WsChannelActor] - val resultMap = new ConcurrentHashMap[UUID, (T, Boolean)]() - val unauthorizedResponse: Response[Obj] = Response(ujson.Obj(), 401, headers = Seq("WWW-Authenticate" -> "Basic")) - - def handler(): cask.WebsocketResult = { - cask.WsHandler { connection => - // TODO this can't be called from scala3 because it's using scala2 macros... - connection.send(cask.Ws.Text("connected")) - openConnections += connection - cask.WsActor { - case cask.Ws.Error(e) => - println("Connection error: " + e.getMessage) - openConnections -= connection - case cask.Ws.Close(_, _) | cask.Ws.ChannelClosed() => - println("Connection closed.") - openConnections -= connection - } - } + private def render(throwable: Throwable): String = { + val sw = new StringWriter + throwable.printStackTrace(new PrintWriter(sw)) + throwable.getMessage() + System.lineSeparator() + sw.toString() } - def getResult(uuidParam: String)(isAuthorized: Boolean): Response[Obj] = { - val res = if (!isAuthorized) { - unauthorizedResponse - } else { - val uuid = - try { - UUID.fromString(uuidParam) - } catch { - case _: IllegalArgumentException => null - } - val finalRes = if (uuid == null) { - ujson.Obj("success" -> false, "err" -> "UUID parameter is incorrectly formatted") - } else { - val resFromMap = resultMap.remove(uuid) - if (resFromMap == null) { - ujson.Obj("success" -> false, "err" -> "No result found for specified UUID") - } else { - resultToJson(resFromMap._1, resFromMap._2) - } - } - Response(finalRes, 200) - } - res - } - - def returnResult(result: T): Unit = { - resultMap.put(result.uuid, (result, true)) - openConnections.foreach { connection => - connection.send(cask.Ws.Text(result.uuid.toString)) - } - Response(ujson.Obj("success" -> true, "uuid" -> result.uuid.toString), 200) - } - - def resultToJson(result: T, b: Boolean): Obj - initialize() -} +} \ No newline at end of file diff --git a/server/src/main/scala/replpp/server/UserRunnable.scala b/server/src/main/scala/replpp/server/UserRunnable.scala deleted file mode 100644 index 2c675b3..0000000 --- a/server/src/main/scala/replpp/server/UserRunnable.scala +++ /dev/null @@ -1,72 +0,0 @@ -package replpp.server - -import java.io.{BufferedReader, PrintWriter} -import java.lang.System.lineSeparator -import java.util.UUID -import java.util.concurrent.BlockingQueue -import org.slf4j.{Logger, LoggerFactory} -import scala.util.Try - -class UserRunnable(queue: BlockingQueue[Job], writer: PrintWriter, reader: BufferedReader, verbose: Boolean = false) - extends Runnable { - private val logger = LoggerFactory.getLogger(classOf[UserRunnable]) - private val endMarker = """.*END: ([0-9a-f\-]+)""".r - - override def run(): Unit = { - try { - var terminate = false - while (!(terminate && queue.isEmpty)) { - val job = queue.take() - if (isTerminationMarker(job)) { - terminate = true - } else { - if (verbose) println(s"executing: $job") - sendQueryToEmbeddedRepl(job) - val stdoutPair = stdOutUpToMarker() - val stdOutput = stdoutPair.get - val result = QueryResult(stdOutput, job.uuid) - if (verbose) println(s"result: $result") - job.observer(result) - } - } - } catch { - case _: InterruptedException => - logger.info("Interrupted WriterThread") - } - logger.debug("WriterThread terminated gracefully") - } - - private def isTerminationMarker(job: Job): Boolean = { - job.uuid == null && job.query == null - } - - private def sendQueryToEmbeddedRepl(job: Job): Unit = { - writer.println(job.query.trim) - writer.println(s""""END: ${job.uuid}"""") - writer.flush() - } - - private def stdOutUpToMarker(): Option[String] = { - var currentOutput: String = "" - var line = reader.readLine() - while (line != null) { - if (line.nonEmpty) { - val uuid = uuidFromLine(line) - if (uuid.isEmpty) { - currentOutput += line + lineSeparator - } else { - return Some(currentOutput) - } - } - line = reader.readLine() - } - None - } - - private def uuidFromLine(line: String): Iterator[UUID] = { - endMarker.findAllIn(line).matchData.flatMap { m => - Try { UUID.fromString(m.group(1)) }.toOption - } - } - -} diff --git a/server/src/main/scala/replpp/server/WebServiceWithWebSocket.scala b/server/src/main/scala/replpp/server/WebServiceWithWebSocket.scala new file mode 100644 index 0000000..21b9d6c --- /dev/null +++ b/server/src/main/scala/replpp/server/WebServiceWithWebSocket.scala @@ -0,0 +1,96 @@ +package replpp.server + +import cask.model.Response.Raw +import cask.model.{Request, Response} +import cask.router.Result +import org.slf4j.{Logger, LoggerFactory} +import ujson.Obj + +import java.util.concurrent.ConcurrentHashMap +import java.util.{Base64, UUID} +import scala.util.{Failure, Success, Try} + +trait HasUUID { def uuid: UUID } + +case class UsernamePasswordAuth(username: String, password: String) + +abstract class WebServiceWithWebSocket[T <: HasUUID]( + override val host: String, + override val port: Int, + authenticationMaybe: Option[UsernamePasswordAuth] = None) extends cask.MainRoutes { + protected val logger: Logger = LoggerFactory.getLogger(getClass) + + class basicAuth extends cask.RawDecorator { + def wrapFunction(request: Request, delegate: Delegate): Result[Raw] = { + val isAuthorized = authenticationMaybe match { + case None => true // no authorization required + case Some(requiredAuth) => + parseAuthentication(request) match { + case None => false // no authentication provided + case Some(providedAuth) => providedAuth == requiredAuth + } + } + delegate(Map("isAuthorized" -> isAuthorized)) + } + + private def parseAuthentication(request: Request): Option[UsernamePasswordAuth] = + Try { + val authHeader = request.exchange.getRequestHeaders.get("authorization").getFirst + val strippedHeader = authHeader.replaceFirst("Basic ", "") + val authString = new String(Base64.getDecoder.decode(strippedHeader)) + authString.split(":", 2) match { + case Array(username, password) => Some(UsernamePasswordAuth(username, password)) + case _ => None + } + }.toOption.flatten + } + + private var openConnections = Set.empty[cask.WsChannelActor] + private val resultMap = new ConcurrentHashMap[UUID, (T, Boolean)]() + protected val unauthorizedResponse = Response(ujson.Obj(), 401, headers = Seq("WWW-Authenticate" -> "Basic")) + + def handler(): cask.WebsocketResult = { + cask.WsHandler { connection => + connection.send(cask.Ws.Text("connected")) + openConnections += connection + cask.WsActor { + case cask.Ws.Error(e) => + logger.error("Connection error: " + e.getMessage) + openConnections -= connection + case cask.Ws.Close(_, _) | cask.Ws.ChannelClosed() => + logger.debug("Connection closed.") + openConnections -= connection + } + } + } + + def getResult(uuidParam: String)(isAuthorized: Boolean): Response[Obj] = { + if (!isAuthorized) { + unauthorizedResponse + } else { + Try(UUID.fromString(uuidParam)) match { + case Success(uuid) if !resultMap.containsKey(uuid) => + Response(ujson.Obj("success" -> false, "err" -> "No result (yet?) found for specified UUID"), 200) + case Success(uuid) => + val (result, success) = resultMap.remove(uuid) + Response(resultToJson(result, success), 200) + case Failure(_) => + Response(ujson.Obj("success" -> false, "err" -> "UUID parameter is incorrectly formatted"), 200) + } + } + } + + def returnResult(result: T): Unit = { + resultMap.put(result.uuid, (result, true)) + openConnections.foreach { connection => + connection.send(cask.Ws.Text(result.uuid.toString)) + } + Response(ujson.Obj("success" -> true, "uuid" -> result.uuid.toString), 200) + } + + def resultToJson(result: T, success: Boolean): Obj + + initialize() +} + + diff --git a/server/src/test/scala/replpp/server/ReplServerTests.scala b/server/src/test/scala/replpp/server/ReplServerTests.scala index 7efe56a..8aa5f76 100644 --- a/server/src/test/scala/replpp/server/ReplServerTests.scala +++ b/server/src/test/scala/replpp/server/ReplServerTests.scala @@ -1,9 +1,10 @@ package replpp.server -import cask.util.Logger.Console._ +import cask.util.Logger.Console.* import castor.Context.Simple.global import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec +import replpp.Config import requests.RequestFailedException import ujson.Value.Value @@ -11,33 +12,13 @@ import java.net.URLEncoder import java.util.UUID import java.util.concurrent.locks.{Lock, ReentrantLock} import scala.collection.mutable.ListBuffer -import scala.concurrent._ -import scala.concurrent.duration._ +import scala.concurrent.* +import scala.concurrent.duration.* class ReplServerTests extends AnyWordSpec with Matchers { private val ValidBasicAuthHeaderVal: String = "Basic dXNlcm5hbWU6cGFzc3dvcmQ=" private val DefaultPromiseAwaitTimeout: FiniteDuration = Duration(10, SECONDS) - private def postQuery(host: String, query: String, authHeaderVal: String = ValidBasicAuthHeaderVal): Value = { - val postResponse = requests.post( - s"$host/query", - data = ujson.Obj("query" -> query).toString, - headers = Seq("authorization" -> authHeaderVal) - ) - val res = - if (postResponse.bytes.length > 0) - ujson.read(postResponse.bytes) - else - ujson.Obj() - res - } - - private def getResponse(host: String, uuidParam: String, authHeaderVal: String = ValidBasicAuthHeaderVal): Value = { - val uri = s"$host/result/${URLEncoder.encode(uuidParam, "utf-8")}" - val getResponse = requests.get(uri, headers = Seq("authorization" -> authHeaderVal)) - ujson.read(getResponse.bytes) - } - /** These tests happen to fail on github actions for the windows runner with the following output: WARNING: Unable to * create a system terminal, creating a dumb terminal (enable debug logging for more information) Apr 21, 2022 * 3:08:54 PM org.jboss.threads.Version INFO: JBoss Threads version 3.1.0.Final Apr 21, 2022 3:08:55 PM @@ -56,146 +37,141 @@ class ReplServerTests extends AnyWordSpec with Matchers { info("tests were cancelled because github actions windows doesn't support them for some unknown reason...") } else { - "allow websocket connections to the `/connect` endpoint" in Fixture() { host => - val wsMsgPromise = scala.concurrent.Promise[String]() - cask.util.WsClient.connect(s"$host/connect") { case cask.Ws.Text(msg) => - wsMsgPromise.success(msg) - } - val wsMsg = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) - wsMsg shouldBe "connected" - } - - "allow posting a simple query without any websocket connections established" in Fixture() { host => - val postQueryResponse = postQuery(host, "1") - postQueryResponse.obj.keySet should contain("success") - val UUIDResponse = postQueryResponse("uuid").str - UUIDResponse should not be empty - postQueryResponse("success").bool shouldBe true - } - - "disallow posting a query when request headers do not include a valid authentication value" in Fixture() { host => - assertThrows[RequestFailedException] { - postQuery(host, "1", authHeaderVal = "Basic b4df00d") + "asynchronous api" should { + "allow websocket connections to the `/connect` endpoint" in Fixture() { url => + val wsMsgPromise = scala.concurrent.Promise[String]() + cask.util.WsClient.connect(s"$url/connect") { case cask.Ws.Text(msg) => + wsMsgPromise.success(msg) + } + val wsMsg = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) + wsMsg shouldBe "connected" } - } - "return a valid JSON response when trying to retrieve the result of a query without a connection" in Fixture() { - host => - val postQueryResponse = postQuery(host, "1") - postQueryResponse.obj.keySet should contain("uuid") + "allow posting a simple query without any websocket connections established" in Fixture() { url => + val postQueryResponse = postQueryAsync(url, "1") + postQueryResponse.obj.keySet should contain("success") val UUIDResponse = postQueryResponse("uuid").str - val getResultResponse = getResponse(host, UUIDResponse) - getResultResponse.obj.keySet should contain("success") - getResultResponse.obj.keySet should contain("err") - getResultResponse("success").bool shouldBe false - getResultResponse("err").str.length should not be 0 - } + UUIDResponse should not be empty + postQueryResponse("success").bool shouldBe true + } - "allow fetching the result of a completed query using its UUID" in Fixture() { host => - val wsMsgPromise = scala.concurrent.Promise[String]() - val connectedPromise = scala.concurrent.Promise[String]() - cask.util.WsClient.connect(s"$host/connect") { - case cask.Ws.Text(msg) if msg == "connected" => - connectedPromise.success(msg) - case cask.Ws.Text(msg) => - wsMsgPromise.success(msg) + "disallow posting a query when request headers do not include a valid authentication value" in Fixture() { url => + assertThrows[RequestFailedException] { + postQueryAsync(url, "1", authHeaderVal = "Basic b4df00d") + } } - Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) - val postQueryResponse = postQuery(host, "1") - val queryUUID = postQueryResponse("uuid").str - queryUUID.length should not be 0 - - val queryResultWSMessage = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) - queryResultWSMessage.length should not be 0 - - val getResultResponse = getResponse(host, queryUUID) - getResultResponse.obj.keySet should contain("success") - getResultResponse("uuid").str shouldBe queryResultWSMessage - getResultResponse("stdout").str shouldBe "val res0: Int = 1\n" - } - "disallow fetching the result of a completed query with an invalid auth header" in Fixture() { host => - val wsMsgPromise = scala.concurrent.Promise[String]() - val connectedPromise = scala.concurrent.Promise[String]() - cask.util.WsClient.connect(s"$host/connect") { - case cask.Ws.Text(msg) if msg == "connected" => - connectedPromise.success(msg) - case cask.Ws.Text(msg) => - wsMsgPromise.success(msg) + "return a valid JSON response when trying to retrieve the result of a query without a connection" in Fixture() { + url => + val postQueryResponse = postQueryAsync(url, "val x = 10") + postQueryResponse.obj.keySet should contain("uuid") + val UUIDResponse = postQueryResponse("uuid").str + val response = getResponse(url, UUIDResponse) + response.obj.keySet should contain("success") + response.obj.keySet should contain("err") + response("success").bool shouldBe false + response("err").str.length should not be 0 } - Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) - val postQueryResponse = postQuery(host, "1") - val queryUUID = postQueryResponse("uuid").str - queryUUID.length should not be 0 - val queryResultWSMessage = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) - queryResultWSMessage.length should not be 0 + "allow fetching the result of a completed query using its UUID" in Fixture() { url => + val wsMsgPromise = scala.concurrent.Promise[String]() + val connectedPromise = scala.concurrent.Promise[String]() + cask.util.WsClient.connect(s"$url/connect") { + case cask.Ws.Text(msg) if msg == "connected" => + connectedPromise.success(msg) + case cask.Ws.Text(msg) => + wsMsgPromise.success(msg) + } + Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) + val postQueryResponse = postQueryAsync(url, "1") + val queryUUID = postQueryResponse("uuid").str + queryUUID.length should not be 0 - assertThrows[RequestFailedException] { - getResponse(host, queryUUID, "Basic b4df00d") + val queryResultWSMessage = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) + queryResultWSMessage.length should not be 0 + + val getResultResponse = getResponse(url, queryUUID) + getResultResponse.obj.keySet should contain("success") + getResultResponse("uuid").str shouldBe queryResultWSMessage + getResultResponse("stdout").str shouldBe "val res0: Int = 1\n" } - } - "write a well-formatted message to a websocket connection when a query has finished evaluation" in Fixture() { - host => + "use predefined code" in Fixture("val foo = 40") { url => val wsMsgPromise = scala.concurrent.Promise[String]() val connectedPromise = scala.concurrent.Promise[String]() - cask.util.WsClient.connect(s"$host/connect") { + cask.util.WsClient.connect(s"$url/connect") { case cask.Ws.Text(msg) if msg == "connected" => connectedPromise.success(msg) case cask.Ws.Text(msg) => wsMsgPromise.success(msg) } Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) - - val postQueryResponse = postQuery(host, "1") + val postQueryResponse = postQueryAsync(url, "val bar = foo + 2") val queryUUID = postQueryResponse("uuid").str queryUUID.length should not be 0 val queryResultWSMessage = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) queryResultWSMessage.length should not be 0 - val getResultResponse = getResponse(host, queryUUID) + val getResultResponse = getResponse(url, queryUUID) getResultResponse.obj.keySet should contain("success") - getResultResponse.obj.keySet should contain("stdout") - getResultResponse.obj.keySet should not contain "err" getResultResponse("uuid").str shouldBe queryResultWSMessage - getResultResponse("stdout").str shouldBe "val res0: Int = 1\n" - } + getResultResponse("stdout").str shouldBe "val bar: Int = 42\n" + } - "write a well-formatted message to a websocket connection when a query failed evaluation" in Fixture() { host => - val wsMsgPromise = scala.concurrent.Promise[String]() - val connectedPromise = scala.concurrent.Promise[String]() - cask.util.WsClient.connect(s"$host/connect") { - case cask.Ws.Text(msg) if msg == "connected" => - connectedPromise.success(msg) - case cask.Ws.Text(msg) => - wsMsgPromise.success(msg) + "disallow fetching the result of a completed query with an invalid auth header" in Fixture() { url => + val wsMsgPromise = scala.concurrent.Promise[String]() + val connectedPromise = scala.concurrent.Promise[String]() + cask.util.WsClient.connect(s"$url/connect") { + case cask.Ws.Text(msg) if msg == "connected" => + connectedPromise.success(msg) + case cask.Ws.Text(msg) => + wsMsgPromise.success(msg) + } + Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) + val postQueryResponse = postQueryAsync(url, "1") + val queryUUID = postQueryResponse("uuid").str + queryUUID.length should not be 0 + + val queryResultWSMessage = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) + queryResultWSMessage.length should not be 0 + + assertThrows[RequestFailedException] { + getResponse(url, queryUUID, "Basic b4df00d") + } } - Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) - val postQueryResponse = postQuery(host, "if else for loop soup // i.e., an invalid Ammonite query") - val queryUUID = postQueryResponse("uuid").str - queryUUID.length should not be 0 + "write a well-formatted message to a websocket connection when a query has finished evaluation" in Fixture() { + url => + val wsMsgPromise = scala.concurrent.Promise[String]() + val connectedPromise = scala.concurrent.Promise[String]() + cask.util.WsClient.connect(s"$url/connect") { + case cask.Ws.Text(msg) if msg == "connected" => + connectedPromise.success(msg) + case cask.Ws.Text(msg) => + wsMsgPromise.success(msg) + } + Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) - val wsMsg = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) - wsMsg.length should not be 0 + val postQueryResponse = postQueryAsync(url, "1") + val queryUUID = postQueryResponse("uuid").str + queryUUID.length should not be 0 - val resp = getResponse(host, queryUUID) - resp.obj.keySet should contain("success") - resp.obj.keySet should contain("stdout") - resp.obj.keySet should not contain "err" + val queryResultWSMessage = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) + queryResultWSMessage.length should not be 0 - resp("success").bool shouldBe true - resp("uuid").str shouldBe wsMsg - resp("stdout").str.length should not be 0 - } + val getResultResponse = getResponse(url, queryUUID) + getResultResponse.obj.keySet should contain("success") + getResultResponse.obj.keySet should contain("stdout") + getResultResponse.obj.keySet should not contain "err" + getResultResponse("uuid").str shouldBe queryResultWSMessage + getResultResponse("stdout").str shouldBe "val res0: Int = 1\n" + } - "write a well-formatted message to a websocket connection when a query containing an invalid char is submitted" in Fixture() { - host => + "write a well-formatted message to a websocket connection when a query failed evaluation" in Fixture() { url => val wsMsgPromise = scala.concurrent.Promise[String]() val connectedPromise = scala.concurrent.Promise[String]() - cask.util.WsClient.connect(s"$host/connect") { + cask.util.WsClient.connect(s"$url/connect") { case cask.Ws.Text(msg) if msg == "connected" => connectedPromise.success(msg) case cask.Ws.Text(msg) => @@ -203,143 +179,207 @@ class ReplServerTests extends AnyWordSpec with Matchers { } Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) - val postQueryResponse = postQuery(host, "@1") + val postQueryResponse = postQueryAsync(url, "if else for loop soup // i.e., an invalid query") val queryUUID = postQueryResponse("uuid").str queryUUID.length should not be 0 val wsMsg = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) wsMsg.length should not be 0 - val resp = getResponse(host, queryUUID) + val resp = getResponse(url, queryUUID) resp.obj.keySet should contain("success") resp.obj.keySet should contain("stdout") + resp.obj.keySet should not contain "err" resp("success").bool shouldBe true resp("uuid").str shouldBe wsMsg resp("stdout").str.length should not be 0 - } - - "receive error when attempting to retrieve result with invalid uuid" in Fixture() { host => - val connectedPromise = scala.concurrent.Promise[String]() - cask.util.WsClient.connect(s"$host/connect") { case cask.Ws.Text(msg) => - connectedPromise.success(msg) } - Await.result(connectedPromise.future, Duration(1, SECONDS)) - val getResultResponse = getResponse(host, UUID.randomUUID().toString) - getResultResponse.obj.keySet should contain("success") - getResultResponse.obj.keySet should contain("err") - getResultResponse("success").bool shouldBe false - } - "return a valid JSON response when calling /result with incorrectly-formatted UUID parameter" in Fixture() { host => - val connectedPromise = scala.concurrent.Promise[String]() - cask.util.WsClient.connect(s"$host/connect") { case cask.Ws.Text(msg) => - connectedPromise.success(msg) + "write a well-formatted message to a websocket connection when a query containing an invalid char is submitted" in Fixture() { + url => + val wsMsgPromise = scala.concurrent.Promise[String]() + val connectedPromise = scala.concurrent.Promise[String]() + cask.util.WsClient.connect(s"$url/connect") { + case cask.Ws.Text(msg) if msg == "connected" => + connectedPromise.success(msg) + case cask.Ws.Text(msg) => + wsMsgPromise.success(msg) + } + Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) + + val postQueryResponse = postQueryAsync(url, "@1") + val queryUUID = postQueryResponse("uuid").str + queryUUID.length should not be 0 + + val wsMsg = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) + wsMsg.length should not be 0 + + val resp = getResponse(url, queryUUID) + resp.obj.keySet should contain("success") + resp.obj.keySet should contain("stdout") + + resp("success").bool shouldBe true + resp("uuid").str shouldBe wsMsg + resp("stdout").str.length should not be 0 } - Await.result(connectedPromise.future, Duration(1, SECONDS)) - val getResultResponse = getResponse(host, "INCORRECTLY_FORMATTED_UUID_PARAM") - getResultResponse.obj.keySet should contain("success") - getResultResponse.obj.keySet should contain("err") - getResultResponse("success").bool shouldBe false - getResultResponse("err").str.length should not equal 0 - } - "return websocket responses for all queries when posted quickly in a large number" in Fixture() { host => - val numQueries = 10 - val correctNumberOfUUIDsReceived = scala.concurrent.Promise[String]() - val wsUUIDs = ListBuffer[String]() + "receive error when attempting to retrieve result with invalid uuid" in Fixture() { url => + val connectedPromise = scala.concurrent.Promise[String]() + cask.util.WsClient.connect(s"$url/connect") { case cask.Ws.Text(msg) => + connectedPromise.success(msg) + } + Await.result(connectedPromise.future, Duration(1, SECONDS)) + val getResultResponse = getResponse(url, UUID.randomUUID().toString) + getResultResponse.obj.keySet should contain("success") + getResultResponse.obj.keySet should contain("err") + getResultResponse("success").bool shouldBe false + } - val rtl: Lock = new ReentrantLock() - val connectedPromise = scala.concurrent.Promise[String]() - cask.util.WsClient.connect(s"$host/connect") { case cask.Ws.Text(msg) => - if (msg == "connected") { + "return a valid JSON response when calling /result with incorrectly-formatted UUID parameter" in Fixture() { url => + val connectedPromise = scala.concurrent.Promise[String]() + cask.util.WsClient.connect(s"$url/connect") { case cask.Ws.Text(msg) => connectedPromise.success(msg) - } else { - rtl.lock() - try { - wsUUIDs += msg - } finally { - rtl.unlock() - if (wsUUIDs.size == numQueries) { - correctNumberOfUUIDsReceived.success("") - } - } } + Await.result(connectedPromise.future, Duration(1, SECONDS)) + val getResultResponse = getResponse(url, "INCORRECTLY_FORMATTED_UUID_PARAM") + getResultResponse.obj.keySet should contain("success") + getResultResponse.obj.keySet should contain("err") + getResultResponse("success").bool shouldBe false + getResultResponse("err").str.length should not equal 0 } - Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) - val postQueriesResponseUUIDs = - for (_ <- 1 to numQueries) yield { - val postQueryResponse = postQuery(host, "1") - postQueryResponse("uuid").str + "return websocket responses for all queries when posted quickly in a large number" in Fixture() { url => + val numQueries = 10 + val correctNumberOfUUIDsReceived = scala.concurrent.Promise[String]() + val wsUUIDs = ListBuffer[String]() + + val rtl: Lock = new ReentrantLock() + val connectedPromise = scala.concurrent.Promise[String]() + cask.util.WsClient.connect(s"$url/connect") { case cask.Ws.Text(msg) => + if (msg == "connected") { + connectedPromise.success(msg) + } else { + rtl.lock() + try { + wsUUIDs += msg + } finally { + rtl.unlock() + if (wsUUIDs.size == numQueries) { + correctNumberOfUUIDsReceived.success("") + } + } + } } + Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) - Await.result(correctNumberOfUUIDsReceived.future, DefaultPromiseAwaitTimeout * numQueries.toLong) - wsUUIDs.toSet should be(postQueriesResponseUUIDs.toSet) - } + val postQueriesResponseUUIDs = + for (_ <- 1 to numQueries) yield { + val postQueryResponse = postQueryAsync(url, "1") + postQueryResponse("uuid").str + } - "return websocket responses for all queries when some are invalid" in Fixture() { host => - val queries = List("1", "1 + 1", "open(", "open)", "open{", "open}") - val correctNumberOfUUIDsReceived = scala.concurrent.Promise[String]() - val wsUUIDs = ListBuffer[String]() - val connectedPromise = scala.concurrent.Promise[String]() + Await.result(correctNumberOfUUIDsReceived.future, DefaultPromiseAwaitTimeout * numQueries.toLong) + wsUUIDs.toSet should be(postQueriesResponseUUIDs.toSet) + } - val rtl: Lock = new ReentrantLock() - cask.util.WsClient.connect(s"$host/connect") { case cask.Ws.Text(msg) => - if (msg == "connected") { - connectedPromise.success(msg) - } else { - rtl.lock() - try { - wsUUIDs += msg - } finally { - rtl.unlock() - if (wsUUIDs.size == queries.size) { - correctNumberOfUUIDsReceived.success("") + "return websocket responses for all queries when some are invalid" in Fixture() { url => + val queries = List("1", "1 + 1", "open(", "open)", "open{", "open}") + val correctNumberOfUUIDsReceived = scala.concurrent.Promise[String]() + val wsUUIDs = ListBuffer[String]() + val connectedPromise = scala.concurrent.Promise[String]() + + val rtl: Lock = new ReentrantLock() + cask.util.WsClient.connect(s"$url/connect") { case cask.Ws.Text(msg) => + if (msg == "connected") { + connectedPromise.success(msg) + } else { + rtl.lock() + try { + wsUUIDs += msg + } finally { + rtl.unlock() + if (wsUUIDs.size == queries.size) { + correctNumberOfUUIDsReceived.success("") + } } } } + Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) + + val postQueriesResponseUUIDs = queries.map { query => + postQueryAsync(url, query)("uuid").str + } + Await.result(correctNumberOfUUIDsReceived.future, DefaultPromiseAwaitTimeout * queries.size.toLong) + wsUUIDs.toSet should be(postQueriesResponseUUIDs.toSet) + } + } + + "synchronous api" should { + "work for simple case" in Fixture() { url => + val response = postQuerySync(url, "1") + response.obj.keySet should contain("success") + response("stdout").str shouldBe "val res0: Int = 1\n" } - Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) - - val postQueriesResponseUUIDs = { - queries - .map(q => { - val res = postQuery(host, q) - res("uuid").str - }) + + "using predef code" in Fixture("val predefCode = 2") { url => + val response = postQuerySync(url, "val foo = predefCode + 40") + response.obj.keySet should contain("success") + response("stdout").str shouldBe "val foo: Int = 42\n" + } + + "fail for invalid auth" in Fixture() { url => + assertThrows[RequestFailedException] { + postQuerySync(url, "1", authHeaderVal = "Basic b4df00d") + } } - Await.result(correctNumberOfUUIDsReceived.future, DefaultPromiseAwaitTimeout * queries.size.toLong) - wsUUIDs.toSet should be(postQueriesResponseUUIDs.toSet) } } + + private def postQueryAsync(baseUrl: String, query: String, authHeaderVal: String = ValidBasicAuthHeaderVal): Value = + postQuery(s"$baseUrl/query", query, authHeaderVal) + + private def postQuerySync(baseUrl: String, query: String, authHeaderVal: String = ValidBasicAuthHeaderVal): Value = + postQuery(s"$baseUrl/query-sync", query, authHeaderVal) + + private def postQuery(endpoint: String, query: String, authHeaderVal: String): Value = { + val postResponse = requests.post( + endpoint, + data = ujson.Obj("query" -> query).toString, + headers = Seq("authorization" -> authHeaderVal) + ) + if (postResponse.bytes.length > 0) + ujson.read(postResponse.bytes) + else + ujson.Obj() + } + + private def getResponse(url: String, uuidParam: String, authHeaderVal: String = ValidBasicAuthHeaderVal): Value = { + val uri = s"$url/result/${URLEncoder.encode(uuidParam, "utf-8")}" + val getResponse = requests.get(uri, headers = Seq("authorization" -> authHeaderVal)) + ujson.read(getResponse.bytes) + } + } object Fixture { - def apply[T]()(f: String => T): T = { - val embeddedRepl = new EmbeddedRepl() - embeddedRepl.start() + def apply[T](predefCode: String = "")(urlToResult: String => T): T = { + val embeddedRepl = new EmbeddedRepl(predefLines = predefCode.linesIterator) val host = "localhost" val port = 8081 - val authUsername = "username" - val authPassword = "password" - val httpEndpoint = "http://" + host + ":" + port.toString - val replServer = new ReplServer(embeddedRepl, host, port, authUsername, authPassword) + val replServer = new ReplServer(embeddedRepl, host, port, Some(UsernamePasswordAuth("username", "password"))) val server = io.undertow.Undertow.builder .addHttpListener(replServer.port, replServer.host) .setHandler(replServer.defaultHandler) .build server.start() - val res = - try { - f(httpEndpoint) - } - finally { - server.stop() - embeddedRepl.shutdown() - } - res + try { + urlToResult(s"http://$host:$port") + } finally { + server.stop() + embeddedRepl.shutdown() + } } } From e45b09bdef6216cfcea32d2fceb6288d195d652a Mon Sep 17 00:00:00 2001 From: Michael Pollmeier Date: Mon, 8 May 2023 09:34:25 +0200 Subject: [PATCH 2/5] move EmbeddedReplTests to regular (non-integration) tests --- .github/workflows/pr.yml | 2 +- .github/workflows/release.yml | 2 +- build.sbt | 4 +--- .../{it => test}/scala/replpp/server/EmbeddedReplTests.scala | 0 4 files changed, 3 insertions(+), 5 deletions(-) rename server/src/{it => test}/scala/replpp/server/EmbeddedReplTests.scala (100%) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 54849af..dc8554d 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -18,4 +18,4 @@ jobs: ~/.sbt ~/.coursier key: ${{ runner.os }}-sbt-${{ hashfiles('**/build.sbt') }} - - run: sbt test IntegrationTest/test + - run: sbt test diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8610a1b..9541b87 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -26,7 +26,7 @@ jobs: ~/.sbt ~/.coursier key: ${{ runner.os }}-sbt-${{ hashfiles('**/build.sbt') }} - - run: sbt test IntegrationTest/test ciReleaseTagNextVersion ciReleaseSonatype + - run: sbt test ciReleaseTagNextVersion ciReleaseSonatype env: SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }} SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }} diff --git a/build.sbt b/build.sbt index 346794c..d02f66c 100644 --- a/build.sbt +++ b/build.sbt @@ -22,16 +22,14 @@ lazy val core = project.in(file("core")).settings( lazy val server = project.in(file("server")) .dependsOn(core) - .configs(IntegrationTest) .settings( name := "scala-repl-pp-server", - Defaults.itSettings, fork := true, // important: otherwise we run into classloader issues libraryDependencies ++= Seq( "com.lihaoyi" %% "cask" % "0.8.3", "org.slf4j" % "slf4j-simple" % "2.0.7" % Optional, "com.lihaoyi" %% "requests" % "0.8.0" % Test, - "org.scalatest" %% "scalatest" % ScalaTestVersion % "it", + "org.scalatest" %% "scalatest" % ScalaTestVersion, ) ) diff --git a/server/src/it/scala/replpp/server/EmbeddedReplTests.scala b/server/src/test/scala/replpp/server/EmbeddedReplTests.scala similarity index 100% rename from server/src/it/scala/replpp/server/EmbeddedReplTests.scala rename to server/src/test/scala/replpp/server/EmbeddedReplTests.scala From c54cad7b6d853939bff59b6acc9e2bd5725504b7 Mon Sep 17 00:00:00 2001 From: Michael Pollmeier Date: Mon, 8 May 2023 09:36:11 +0200 Subject: [PATCH 3/5] check if the strange windows build error still exists --- .../scala/replpp/server/ReplServerTests.scala | 445 +++++++++--------- 1 file changed, 213 insertions(+), 232 deletions(-) diff --git a/server/src/test/scala/replpp/server/ReplServerTests.scala b/server/src/test/scala/replpp/server/ReplServerTests.scala index 8aa5f76..7a33daa 100644 --- a/server/src/test/scala/replpp/server/ReplServerTests.scala +++ b/server/src/test/scala/replpp/server/ReplServerTests.scala @@ -19,84 +19,112 @@ class ReplServerTests extends AnyWordSpec with Matchers { private val ValidBasicAuthHeaderVal: String = "Basic dXNlcm5hbWU6cGFzc3dvcmQ=" private val DefaultPromiseAwaitTimeout: FiniteDuration = Duration(10, SECONDS) - /** These tests happen to fail on github actions for the windows runner with the following output: WARNING: Unable to - * create a system terminal, creating a dumb terminal (enable debug logging for more information) Apr 21, 2022 - * 3:08:54 PM org.jboss.threads.Version INFO: JBoss Threads version 3.1.0.Final Apr 21, 2022 3:08:55 PM - * io.undertow.server.HttpServerExchange endExchange ERROR: UT005090: Unexpected failure - * java.lang.NoClassDefFoundError: Could not initialize class org.xnio.channels.Channels at - * io.undertow.io.UndertowOutputStream.close(UndertowOutputStream.java:348) - * - * This happens for both windows 2019 and 2022, and isn't reproducable elsewhere. Explicitly adding a dependency on - * `org.jboss.xnio/xnio-api` didn't help, as well as other debug attempts. So we gave up and disabled this - * specifically for github actions' windows runner. - */ - val isGithubActions = scala.util.Properties.envOrElse("GITHUB_ACTIONS", "false").toLowerCase == "true" - val isWindows = scala.util.Properties.isWin - - if (isGithubActions && isWindows) { - info("tests were cancelled because github actions windows doesn't support them for some unknown reason...") - } else { - - "asynchronous api" should { - "allow websocket connections to the `/connect` endpoint" in Fixture() { url => - val wsMsgPromise = scala.concurrent.Promise[String]() - cask.util.WsClient.connect(s"$url/connect") { case cask.Ws.Text(msg) => - wsMsgPromise.success(msg) - } - val wsMsg = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) - wsMsg shouldBe "connected" + "asynchronous api" should { + "allow websocket connections to the `/connect` endpoint" in Fixture() { url => + val wsMsgPromise = scala.concurrent.Promise[String]() + cask.util.WsClient.connect(s"$url/connect") { case cask.Ws.Text(msg) => + wsMsgPromise.success(msg) } + val wsMsg = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) + wsMsg shouldBe "connected" + } - "allow posting a simple query without any websocket connections established" in Fixture() { url => - val postQueryResponse = postQueryAsync(url, "1") - postQueryResponse.obj.keySet should contain("success") - val UUIDResponse = postQueryResponse("uuid").str - UUIDResponse should not be empty - postQueryResponse("success").bool shouldBe true + "allow posting a simple query without any websocket connections established" in Fixture() { url => + val postQueryResponse = postQueryAsync(url, "1") + postQueryResponse.obj.keySet should contain("success") + val UUIDResponse = postQueryResponse("uuid").str + UUIDResponse should not be empty + postQueryResponse("success").bool shouldBe true + } + + "disallow posting a query when request headers do not include a valid authentication value" in Fixture() { url => + assertThrows[RequestFailedException] { + postQueryAsync(url, "1", authHeaderVal = "Basic b4df00d") } + } - "disallow posting a query when request headers do not include a valid authentication value" in Fixture() { url => - assertThrows[RequestFailedException] { - postQueryAsync(url, "1", authHeaderVal = "Basic b4df00d") - } + "return a valid JSON response when trying to retrieve the result of a query without a connection" in Fixture() { + url => + val postQueryResponse = postQueryAsync(url, "val x = 10") + postQueryResponse.obj.keySet should contain("uuid") + val UUIDResponse = postQueryResponse("uuid").str + val response = getResponse(url, UUIDResponse) + response.obj.keySet should contain("success") + response.obj.keySet should contain("err") + response("success").bool shouldBe false + response("err").str.length should not be 0 + } + + "allow fetching the result of a completed query using its UUID" in Fixture() { url => + val wsMsgPromise = scala.concurrent.Promise[String]() + val connectedPromise = scala.concurrent.Promise[String]() + cask.util.WsClient.connect(s"$url/connect") { + case cask.Ws.Text(msg) if msg == "connected" => + connectedPromise.success(msg) + case cask.Ws.Text(msg) => + wsMsgPromise.success(msg) } + Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) + val postQueryResponse = postQueryAsync(url, "1") + val queryUUID = postQueryResponse("uuid").str + queryUUID.length should not be 0 + + val queryResultWSMessage = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) + queryResultWSMessage.length should not be 0 + + val getResultResponse = getResponse(url, queryUUID) + getResultResponse.obj.keySet should contain("success") + getResultResponse("uuid").str shouldBe queryResultWSMessage + getResultResponse("stdout").str shouldBe "val res0: Int = 1\n" + } - "return a valid JSON response when trying to retrieve the result of a query without a connection" in Fixture() { - url => - val postQueryResponse = postQueryAsync(url, "val x = 10") - postQueryResponse.obj.keySet should contain("uuid") - val UUIDResponse = postQueryResponse("uuid").str - val response = getResponse(url, UUIDResponse) - response.obj.keySet should contain("success") - response.obj.keySet should contain("err") - response("success").bool shouldBe false - response("err").str.length should not be 0 + "use predefined code" in Fixture("val foo = 40") { url => + val wsMsgPromise = scala.concurrent.Promise[String]() + val connectedPromise = scala.concurrent.Promise[String]() + cask.util.WsClient.connect(s"$url/connect") { + case cask.Ws.Text(msg) if msg == "connected" => + connectedPromise.success(msg) + case cask.Ws.Text(msg) => + wsMsgPromise.success(msg) } + Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) + val postQueryResponse = postQueryAsync(url, "val bar = foo + 2") + val queryUUID = postQueryResponse("uuid").str + queryUUID.length should not be 0 + + val queryResultWSMessage = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) + queryResultWSMessage.length should not be 0 + + val getResultResponse = getResponse(url, queryUUID) + getResultResponse.obj.keySet should contain("success") + getResultResponse("uuid").str shouldBe queryResultWSMessage + getResultResponse("stdout").str shouldBe "val bar: Int = 42\n" + } - "allow fetching the result of a completed query using its UUID" in Fixture() { url => - val wsMsgPromise = scala.concurrent.Promise[String]() - val connectedPromise = scala.concurrent.Promise[String]() - cask.util.WsClient.connect(s"$url/connect") { - case cask.Ws.Text(msg) if msg == "connected" => - connectedPromise.success(msg) - case cask.Ws.Text(msg) => - wsMsgPromise.success(msg) - } - Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) - val postQueryResponse = postQueryAsync(url, "1") - val queryUUID = postQueryResponse("uuid").str - queryUUID.length should not be 0 + "disallow fetching the result of a completed query with an invalid auth header" in Fixture() { url => + val wsMsgPromise = scala.concurrent.Promise[String]() + val connectedPromise = scala.concurrent.Promise[String]() + cask.util.WsClient.connect(s"$url/connect") { + case cask.Ws.Text(msg) if msg == "connected" => + connectedPromise.success(msg) + case cask.Ws.Text(msg) => + wsMsgPromise.success(msg) + } + Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) + val postQueryResponse = postQueryAsync(url, "1") + val queryUUID = postQueryResponse("uuid").str + queryUUID.length should not be 0 - val queryResultWSMessage = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) - queryResultWSMessage.length should not be 0 + val queryResultWSMessage = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) + queryResultWSMessage.length should not be 0 - val getResultResponse = getResponse(url, queryUUID) - getResultResponse.obj.keySet should contain("success") - getResultResponse("uuid").str shouldBe queryResultWSMessage - getResultResponse("stdout").str shouldBe "val res0: Int = 1\n" + assertThrows[RequestFailedException] { + getResponse(url, queryUUID, "Basic b4df00d") } + } - "use predefined code" in Fixture("val foo = 40") { url => + "write a well-formatted message to a websocket connection when a query has finished evaluation" in Fixture() { + url => val wsMsgPromise = scala.concurrent.Promise[String]() val connectedPromise = scala.concurrent.Promise[String]() cask.util.WsClient.connect(s"$url/connect") { @@ -106,7 +134,8 @@ class ReplServerTests extends AnyWordSpec with Matchers { wsMsgPromise.success(msg) } Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) - val postQueryResponse = postQueryAsync(url, "val bar = foo + 2") + + val postQueryResponse = postQueryAsync(url, "1") val queryUUID = postQueryResponse("uuid").str queryUUID.length should not be 0 @@ -115,60 +144,42 @@ class ReplServerTests extends AnyWordSpec with Matchers { val getResultResponse = getResponse(url, queryUUID) getResultResponse.obj.keySet should contain("success") + getResultResponse.obj.keySet should contain("stdout") + getResultResponse.obj.keySet should not contain "err" getResultResponse("uuid").str shouldBe queryResultWSMessage - getResultResponse("stdout").str shouldBe "val bar: Int = 42\n" - } + getResultResponse("stdout").str shouldBe "val res0: Int = 1\n" + } - "disallow fetching the result of a completed query with an invalid auth header" in Fixture() { url => - val wsMsgPromise = scala.concurrent.Promise[String]() - val connectedPromise = scala.concurrent.Promise[String]() - cask.util.WsClient.connect(s"$url/connect") { - case cask.Ws.Text(msg) if msg == "connected" => - connectedPromise.success(msg) - case cask.Ws.Text(msg) => - wsMsgPromise.success(msg) - } - Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) - val postQueryResponse = postQueryAsync(url, "1") - val queryUUID = postQueryResponse("uuid").str - queryUUID.length should not be 0 + "write a well-formatted message to a websocket connection when a query failed evaluation" in Fixture() { url => + val wsMsgPromise = scala.concurrent.Promise[String]() + val connectedPromise = scala.concurrent.Promise[String]() + cask.util.WsClient.connect(s"$url/connect") { + case cask.Ws.Text(msg) if msg == "connected" => + connectedPromise.success(msg) + case cask.Ws.Text(msg) => + wsMsgPromise.success(msg) + } + Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) - val queryResultWSMessage = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) - queryResultWSMessage.length should not be 0 + val postQueryResponse = postQueryAsync(url, "if else for loop soup // i.e., an invalid query") + val queryUUID = postQueryResponse("uuid").str + queryUUID.length should not be 0 - assertThrows[RequestFailedException] { - getResponse(url, queryUUID, "Basic b4df00d") - } - } + val wsMsg = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) + wsMsg.length should not be 0 - "write a well-formatted message to a websocket connection when a query has finished evaluation" in Fixture() { - url => - val wsMsgPromise = scala.concurrent.Promise[String]() - val connectedPromise = scala.concurrent.Promise[String]() - cask.util.WsClient.connect(s"$url/connect") { - case cask.Ws.Text(msg) if msg == "connected" => - connectedPromise.success(msg) - case cask.Ws.Text(msg) => - wsMsgPromise.success(msg) - } - Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) + val resp = getResponse(url, queryUUID) + resp.obj.keySet should contain("success") + resp.obj.keySet should contain("stdout") + resp.obj.keySet should not contain "err" - val postQueryResponse = postQueryAsync(url, "1") - val queryUUID = postQueryResponse("uuid").str - queryUUID.length should not be 0 - - val queryResultWSMessage = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) - queryResultWSMessage.length should not be 0 - - val getResultResponse = getResponse(url, queryUUID) - getResultResponse.obj.keySet should contain("success") - getResultResponse.obj.keySet should contain("stdout") - getResultResponse.obj.keySet should not contain "err" - getResultResponse("uuid").str shouldBe queryResultWSMessage - getResultResponse("stdout").str shouldBe "val res0: Int = 1\n" - } + resp("success").bool shouldBe true + resp("uuid").str shouldBe wsMsg + resp("stdout").str.length should not be 0 + } - "write a well-formatted message to a websocket connection when a query failed evaluation" in Fixture() { url => + "write a well-formatted message to a websocket connection when a query containing an invalid char is submitted" in Fixture() { + url => val wsMsgPromise = scala.concurrent.Promise[String]() val connectedPromise = scala.concurrent.Promise[String]() cask.util.WsClient.connect(s"$url/connect") { @@ -179,7 +190,7 @@ class ReplServerTests extends AnyWordSpec with Matchers { } Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) - val postQueryResponse = postQueryAsync(url, "if else for loop soup // i.e., an invalid query") + val postQueryResponse = postQueryAsync(url, "@1") val queryUUID = postQueryResponse("uuid").str queryUUID.length should not be 0 @@ -189,149 +200,119 @@ class ReplServerTests extends AnyWordSpec with Matchers { val resp = getResponse(url, queryUUID) resp.obj.keySet should contain("success") resp.obj.keySet should contain("stdout") - resp.obj.keySet should not contain "err" resp("success").bool shouldBe true resp("uuid").str shouldBe wsMsg resp("stdout").str.length should not be 0 - } - - "write a well-formatted message to a websocket connection when a query containing an invalid char is submitted" in Fixture() { - url => - val wsMsgPromise = scala.concurrent.Promise[String]() - val connectedPromise = scala.concurrent.Promise[String]() - cask.util.WsClient.connect(s"$url/connect") { - case cask.Ws.Text(msg) if msg == "connected" => - connectedPromise.success(msg) - case cask.Ws.Text(msg) => - wsMsgPromise.success(msg) - } - Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) - - val postQueryResponse = postQueryAsync(url, "@1") - val queryUUID = postQueryResponse("uuid").str - queryUUID.length should not be 0 - - val wsMsg = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout) - wsMsg.length should not be 0 - - val resp = getResponse(url, queryUUID) - resp.obj.keySet should contain("success") - resp.obj.keySet should contain("stdout") + } - resp("success").bool shouldBe true - resp("uuid").str shouldBe wsMsg - resp("stdout").str.length should not be 0 + "receive error when attempting to retrieve result with invalid uuid" in Fixture() { url => + val connectedPromise = scala.concurrent.Promise[String]() + cask.util.WsClient.connect(s"$url/connect") { case cask.Ws.Text(msg) => + connectedPromise.success(msg) } + Await.result(connectedPromise.future, Duration(1, SECONDS)) + val getResultResponse = getResponse(url, UUID.randomUUID().toString) + getResultResponse.obj.keySet should contain("success") + getResultResponse.obj.keySet should contain("err") + getResultResponse("success").bool shouldBe false + } - "receive error when attempting to retrieve result with invalid uuid" in Fixture() { url => - val connectedPromise = scala.concurrent.Promise[String]() - cask.util.WsClient.connect(s"$url/connect") { case cask.Ws.Text(msg) => - connectedPromise.success(msg) - } - Await.result(connectedPromise.future, Duration(1, SECONDS)) - val getResultResponse = getResponse(url, UUID.randomUUID().toString) - getResultResponse.obj.keySet should contain("success") - getResultResponse.obj.keySet should contain("err") - getResultResponse("success").bool shouldBe false - } - - "return a valid JSON response when calling /result with incorrectly-formatted UUID parameter" in Fixture() { url => - val connectedPromise = scala.concurrent.Promise[String]() - cask.util.WsClient.connect(s"$url/connect") { case cask.Ws.Text(msg) => - connectedPromise.success(msg) - } - Await.result(connectedPromise.future, Duration(1, SECONDS)) - val getResultResponse = getResponse(url, "INCORRECTLY_FORMATTED_UUID_PARAM") - getResultResponse.obj.keySet should contain("success") - getResultResponse.obj.keySet should contain("err") - getResultResponse("success").bool shouldBe false - getResultResponse("err").str.length should not equal 0 + "return a valid JSON response when calling /result with incorrectly-formatted UUID parameter" in Fixture() { url => + val connectedPromise = scala.concurrent.Promise[String]() + cask.util.WsClient.connect(s"$url/connect") { case cask.Ws.Text(msg) => + connectedPromise.success(msg) } + Await.result(connectedPromise.future, Duration(1, SECONDS)) + val getResultResponse = getResponse(url, "INCORRECTLY_FORMATTED_UUID_PARAM") + getResultResponse.obj.keySet should contain("success") + getResultResponse.obj.keySet should contain("err") + getResultResponse("success").bool shouldBe false + getResultResponse("err").str.length should not equal 0 + } - "return websocket responses for all queries when posted quickly in a large number" in Fixture() { url => - val numQueries = 10 - val correctNumberOfUUIDsReceived = scala.concurrent.Promise[String]() - val wsUUIDs = ListBuffer[String]() + "return websocket responses for all queries when posted quickly in a large number" in Fixture() { url => + val numQueries = 10 + val correctNumberOfUUIDsReceived = scala.concurrent.Promise[String]() + val wsUUIDs = ListBuffer[String]() - val rtl: Lock = new ReentrantLock() - val connectedPromise = scala.concurrent.Promise[String]() - cask.util.WsClient.connect(s"$url/connect") { case cask.Ws.Text(msg) => - if (msg == "connected") { - connectedPromise.success(msg) - } else { - rtl.lock() - try { - wsUUIDs += msg - } finally { - rtl.unlock() - if (wsUUIDs.size == numQueries) { - correctNumberOfUUIDsReceived.success("") - } + val rtl: Lock = new ReentrantLock() + val connectedPromise = scala.concurrent.Promise[String]() + cask.util.WsClient.connect(s"$url/connect") { case cask.Ws.Text(msg) => + if (msg == "connected") { + connectedPromise.success(msg) + } else { + rtl.lock() + try { + wsUUIDs += msg + } finally { + rtl.unlock() + if (wsUUIDs.size == numQueries) { + correctNumberOfUUIDsReceived.success("") } } } - Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) + } + Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) - val postQueriesResponseUUIDs = - for (_ <- 1 to numQueries) yield { - val postQueryResponse = postQueryAsync(url, "1") - postQueryResponse("uuid").str - } + val postQueriesResponseUUIDs = + for (_ <- 1 to numQueries) yield { + val postQueryResponse = postQueryAsync(url, "1") + postQueryResponse("uuid").str + } - Await.result(correctNumberOfUUIDsReceived.future, DefaultPromiseAwaitTimeout * numQueries.toLong) - wsUUIDs.toSet should be(postQueriesResponseUUIDs.toSet) - } + Await.result(correctNumberOfUUIDsReceived.future, DefaultPromiseAwaitTimeout * numQueries.toLong) + wsUUIDs.toSet should be(postQueriesResponseUUIDs.toSet) + } - "return websocket responses for all queries when some are invalid" in Fixture() { url => - val queries = List("1", "1 + 1", "open(", "open)", "open{", "open}") - val correctNumberOfUUIDsReceived = scala.concurrent.Promise[String]() - val wsUUIDs = ListBuffer[String]() - val connectedPromise = scala.concurrent.Promise[String]() + "return websocket responses for all queries when some are invalid" in Fixture() { url => + val queries = List("1", "1 + 1", "open(", "open)", "open{", "open}") + val correctNumberOfUUIDsReceived = scala.concurrent.Promise[String]() + val wsUUIDs = ListBuffer[String]() + val connectedPromise = scala.concurrent.Promise[String]() - val rtl: Lock = new ReentrantLock() - cask.util.WsClient.connect(s"$url/connect") { case cask.Ws.Text(msg) => - if (msg == "connected") { - connectedPromise.success(msg) - } else { - rtl.lock() - try { - wsUUIDs += msg - } finally { - rtl.unlock() - if (wsUUIDs.size == queries.size) { - correctNumberOfUUIDsReceived.success("") - } + val rtl: Lock = new ReentrantLock() + cask.util.WsClient.connect(s"$url/connect") { case cask.Ws.Text(msg) => + if (msg == "connected") { + connectedPromise.success(msg) + } else { + rtl.lock() + try { + wsUUIDs += msg + } finally { + rtl.unlock() + if (wsUUIDs.size == queries.size) { + correctNumberOfUUIDsReceived.success("") } } } - Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) + } + Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout) - val postQueriesResponseUUIDs = queries.map { query => - postQueryAsync(url, query)("uuid").str - } - Await.result(correctNumberOfUUIDsReceived.future, DefaultPromiseAwaitTimeout * queries.size.toLong) - wsUUIDs.toSet should be(postQueriesResponseUUIDs.toSet) + val postQueriesResponseUUIDs = queries.map { query => + postQueryAsync(url, query)("uuid").str } + Await.result(correctNumberOfUUIDsReceived.future, DefaultPromiseAwaitTimeout * queries.size.toLong) + wsUUIDs.toSet should be(postQueriesResponseUUIDs.toSet) } + } - "synchronous api" should { - "work for simple case" in Fixture() { url => - val response = postQuerySync(url, "1") - response.obj.keySet should contain("success") - response("stdout").str shouldBe "val res0: Int = 1\n" - } + "synchronous api" should { + "work for simple case" in Fixture() { url => + val response = postQuerySync(url, "1") + response.obj.keySet should contain("success") + response("stdout").str shouldBe "val res0: Int = 1\n" + } - "using predef code" in Fixture("val predefCode = 2") { url => - val response = postQuerySync(url, "val foo = predefCode + 40") - response.obj.keySet should contain("success") - response("stdout").str shouldBe "val foo: Int = 42\n" - } + "using predef code" in Fixture("val predefCode = 2") { url => + val response = postQuerySync(url, "val foo = predefCode + 40") + response.obj.keySet should contain("success") + response("stdout").str shouldBe "val foo: Int = 42\n" + } - "fail for invalid auth" in Fixture() { url => - assertThrows[RequestFailedException] { - postQuerySync(url, "1", authHeaderVal = "Basic b4df00d") - } + "fail for invalid auth" in Fixture() { url => + assertThrows[RequestFailedException] { + postQuerySync(url, "1", authHeaderVal = "Basic b4df00d") } } } From 6ea470a7160b059c5ff403a5a6b944293d479ef1 Mon Sep 17 00:00:00 2001 From: Michael Pollmeier Date: Mon, 8 May 2023 09:39:07 +0200 Subject: [PATCH 4/5] ci: run on windows/mac/linux --- .github/workflows/pr.yml | 5 ++++- .github/workflows/release.yml | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index dc8554d..7d966b8 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -2,7 +2,10 @@ name: pr on: pull_request jobs: pr: - runs-on: ubuntu-22.04 + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] steps: - uses: actions/checkout@v3 with: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9541b87..21cab28 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -6,7 +6,10 @@ on: tags: ["*"] jobs: release: - runs-on: ubuntu-22.04 + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] steps: - uses: actions/checkout@v3 with: From 0952060384d50218e51aed4fea0c9496b553dc3d Mon Sep 17 00:00:00 2001 From: Michael Pollmeier Date: Mon, 8 May 2023 09:43:00 +0200 Subject: [PATCH 5/5] rename job --- .github/workflows/pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 7d966b8..733df45 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -1,7 +1,7 @@ name: pr on: pull_request jobs: - pr: + test: runs-on: ${{ matrix.os }} strategy: matrix: