Skip to content

Refactor Runtime to support end-to-end testing #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@ kotlin-date-time = { module = "org.jetbrains.kotlinx:kotlinx-datetime", version.
kotlin-coroutines-test = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-test", version.ref = "coroutinest-test" }
junit = { group = "junit", name = "junit", version.ref = "junit" }
logback = { module = "ch.qos.logback:logback-classic", version.ref = "logback" }
ktor-server-core = { module = "io.ktor:ktor-server-core-jvm", version.ref = "ktor" }
ktor-server-netty = { module = "io.ktor:ktor-server-netty-jvm", version.ref = "ktor" }
ktor-server-tests = { module = "io.ktor:ktor-server-tests-jvm", version.ref = "ktor" }
ktor-client-core = { module = "io.ktor:ktor-client-core", version.ref = "ktor" }
ktor-client-mock = { module = "io.ktor:ktor-client-mock", version.ref = "ktor" }
ktor-client-cio = { module = "io.ktor:ktor-client-cio", version.ref = "ktor" }
ktor-client-curl = { module = "io.ktor:ktor-client-curl", version.ref = "ktor" }
ktor-client-logging = { module = "io.ktor:ktor-client-logging", version.ref = "ktor" }
Expand Down
16 changes: 9 additions & 7 deletions lambda-runtime/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import dev.mokkery.MockMode

plugins {
alias(libs.plugins.kotlin.multiplatform)
alias(libs.plugins.kotlin.serialization)
alias(libs.plugins.mokkery)
alias(libs.plugins.allopen)
alias(libs.plugins.kotlinx.resources)
}

kotlin {
Expand All @@ -29,15 +31,15 @@ kotlin {
}

nativeTest.dependencies {
implementation(projects.lambdaEvents)
implementation(libs.kotlin.test)
implementation(libs.kotlin.coroutines.test)
implementation(libs.ktor.client.mock)
implementation(libs.kotlinx.resources)
}
}
}

fun isTestingTask(name: String) = name.endsWith("Test")
val isTesting = gradle.startParameter.taskNames.any(::isTestingTask)

if (isTesting) allOpen {
annotation("kotlin.Metadata")
mokkery {
defaultMockMode.set(MockMode.autoUnit)
}

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.github.trueangle.knative.lambda.runtime.log
package io.github.trueangle.knative.lambda.runtime

internal fun Throwable.prettyPrint(includeStackTrace: Boolean = true) = buildString {
append("An exception occurred:\n")
Expand All @@ -9,4 +9,6 @@ internal fun Throwable.prettyPrint(includeStackTrace: Boolean = true) = buildStr
append("Stack Trace:\n")
append(stackTraceToString())
}
}
}

internal fun <T> unsafeLazy(initializer: () -> T): Lazy<T> = lazy(LazyThreadSafetyMode.NONE, initializer)
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,44 @@ import io.github.trueangle.knative.lambda.runtime.ReservedRuntimeEnvironmentVari
import kotlinx.cinterop.ExperimentalForeignApi
import kotlinx.cinterop.toKString
import platform.posix.getenv
import kotlin.system.exitProcess

@OptIn(ExperimentalForeignApi::class)
@PublishedApi
internal object LambdaEnvironment {
val FUNCTION_MEMORY_SIZE = getenv(AWS_LAMBDA_FUNCTION_MEMORY_SIZE)?.toKString()?.toIntOrNull() ?: 128
val LOG_GROUP_NAME: String = getenv(AWS_LAMBDA_LOG_GROUP_NAME)?.toKString().orEmpty()
val LOG_STREAM_NAME: String = getenv(AWS_LAMBDA_LOG_STREAM_NAME)?.toKString().orEmpty()
val LAMBDA_LOG_LEVEL: String? = getenv(AWS_LAMBDA_LOG_LEVEL)?.toKString()
val LAMBDA_LOG_FORMAT: String? = getenv(AWS_LAMBDA_LOG_FORMAT)?.toKString()
val FUNCTION_NAME: String = getenv(AWS_LAMBDA_FUNCTION_NAME)?.toKString().orEmpty()
val FUNCTION_VERSION: String = getenv(AWS_LAMBDA_FUNCTION_VERSION)?.toKString().orEmpty()

val RUNTIME_API: String = requireNotNull(getenv(AWS_LAMBDA_RUNTIME_API)?.toKString()) {
"Can't find AWS_LAMBDA_RUNTIME_API env variable"
internal open class LambdaEnvironment {
// open due to Mokkery limits
open fun terminate(): Nothing = exitProcess(1)

@OptIn(ExperimentalForeignApi::class)
@PublishedApi
internal companion object Variables {
val FUNCTION_MEMORY_SIZE by unsafeLazy {
getenv(AWS_LAMBDA_FUNCTION_MEMORY_SIZE)?.toKString()?.toIntOrNull() ?: 128
}
val LOG_GROUP_NAME by unsafeLazy {
getenv(AWS_LAMBDA_LOG_GROUP_NAME)?.toKString().orEmpty()
}
val LOG_STREAM_NAME by unsafeLazy {
getenv(AWS_LAMBDA_LOG_STREAM_NAME)?.toKString().orEmpty()
}
val LAMBDA_LOG_LEVEL by unsafeLazy {
getenv(AWS_LAMBDA_LOG_LEVEL)?.toKString() ?: "INFO"
}
val LAMBDA_LOG_FORMAT by unsafeLazy {
getenv(AWS_LAMBDA_LOG_FORMAT)?.toKString() ?: "TEXT"
}
val FUNCTION_NAME by unsafeLazy {
getenv(AWS_LAMBDA_FUNCTION_NAME)?.toKString().orEmpty()
}
val FUNCTION_VERSION by unsafeLazy {
getenv(AWS_LAMBDA_FUNCTION_VERSION)?.toKString().orEmpty()
}
val RUNTIME_API by unsafeLazy {
getenv(AWS_LAMBDA_RUNTIME_API)?.toKString()
}
}
}

private object ReservedRuntimeEnvironmentVariables {
internal object ReservedRuntimeEnvironmentVariables {
/**
* The handler location configured on the function.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,20 @@ package io.github.trueangle.knative.lambda.runtime
import io.github.trueangle.knative.lambda.runtime.LambdaEnvironmentException.NonRecoverableStateException
import io.github.trueangle.knative.lambda.runtime.api.Context
import io.github.trueangle.knative.lambda.runtime.api.LambdaClient
import io.github.trueangle.knative.lambda.runtime.api.LambdaClientImpl
import io.github.trueangle.knative.lambda.runtime.handler.LambdaBufferedHandler
import io.github.trueangle.knative.lambda.runtime.handler.LambdaHandler
import io.github.trueangle.knative.lambda.runtime.handler.LambdaStreamHandler
import io.github.trueangle.knative.lambda.runtime.log.KtorLogger
import io.github.trueangle.knative.lambda.runtime.log.LambdaLogger
import io.github.trueangle.knative.lambda.runtime.log.Log
import io.github.trueangle.knative.lambda.runtime.log.debug
import io.github.trueangle.knative.lambda.runtime.log.error
import io.github.trueangle.knative.lambda.runtime.log.fatal
import io.github.trueangle.knative.lambda.runtime.log.info
import io.github.trueangle.knative.lambda.runtime.log.warn
import io.ktor.client.HttpClient
import io.ktor.client.engine.HttpClientEngine
import io.ktor.client.engine.curl.Curl
import io.ktor.client.plugins.HttpTimeout
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
Expand All @@ -22,13 +30,20 @@ import io.ktor.utils.io.writeStringUtf8
import kotlinx.coroutines.runBlocking
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.json.Json
import kotlin.system.exitProcess

object LambdaRuntime {
@OptIn(ExperimentalSerializationApi::class)
internal val json = Json { explicitNulls = false }

private val httpClient = HttpClient(Curl) {
inline fun <reified I, reified O> run(crossinline initHandler: () -> LambdaHandler<I, O>) = runBlocking {
val curlHttpClient = createHttpClient(Curl.create())
val lambdaClient = LambdaClientImpl(curlHttpClient)

Runner(client = lambdaClient, log = Log).run(false, initHandler)
}

@PublishedApi
internal fun createHttpClient(engine: HttpClientEngine) = HttpClient(engine) {
install(HttpTimeout)
install(ContentNegotiation) { json(json) }
install(Logging) {
Expand All @@ -39,86 +54,93 @@ object LambdaRuntime {
filter { !it.headers.contains("Lambda-Runtime-Function-Response-Mode", "streaming") }
}
}
}

@PublishedApi
internal val client = LambdaClient(httpClient)

inline fun <reified I, reified O> run(crossinline initHandler: () -> LambdaHandler<I, O>) = runBlocking {
@PublishedApi
internal class Runner(
val client: LambdaClient,
val log: LambdaLogger,
val env: LambdaEnvironment = LambdaEnvironment()
) {
suspend inline fun <reified I, reified O> run(singleEventMode: Boolean = false, crossinline initHandler: () -> LambdaHandler<I, O>) {
val handler = try {
Log.info("Initializing Kotlin Native Lambda Runtime")
log.info("Initializing Kotlin Native Lambda Runtime")

initHandler()
} catch (e: Exception) {
Log.fatal(e)
log.fatal(e)

client.reportError(e.asInitError())
exitProcess(1)

env.terminate()
}

val handlerName = handler::class.simpleName
val inputTypeInfo = typeInfo<I>()
val outputTypeInfo = typeInfo<O>()

while (true) {
var shouldExit = false
while (!shouldExit) {
try {
Log.info("Runtime is ready for a new event")
log.info("Runtime is ready for a new event")

val (event, context) = client.retrieveNextEvent<I>(inputTypeInfo)

with(Log) {
with(log) {
setContext(context)

debug(event)
debug(context)
info("$handlerName invocation started")
}

Log.info("$handlerName invocation started")

if (handler is LambdaStreamHandler<I, *>) {
val response = streamingResponse { handler.handleRequest(event, it, context) }

Log.info("$handlerName started response streaming")
log.info("$handlerName started response streaming")

client.streamResponse(context, response)
} else {
handler as LambdaBufferedHandler<I, O>
val response = bufferedResponse(context) { handler.handleRequest(event, context) }

Log.info("$handlerName invocation completed")
Log.debug(response)
log.info("$handlerName invocation completed")
log.debug(response)

client.sendResponse(context, response, outputTypeInfo)
}
} catch (e: LambdaRuntimeException) {
Log.error(e)
log.error(e)

client.reportError(e)
} catch (e: LambdaEnvironmentException) {
when (e) {
is NonRecoverableStateException -> {
Log.fatal(e)
log.fatal(e)

exitProcess(1)
env.terminate()
}

else -> Log.error(e)
else -> log.error(e)
}
} catch (e: Throwable) {
Log.fatal(e)
log.fatal(e)

env.terminate()
}

exitProcess(1)
if (singleEventMode) {
shouldExit = singleEventMode
}
}
}
}

@PublishedApi
internal inline fun streamingResponse(crossinline handler: suspend (ByteWriteChannel) -> Unit) =
object : WriteChannelContent() {
inline fun streamingResponse(crossinline handler: suspend (ByteWriteChannel) -> Unit) = object : WriteChannelContent() {
override suspend fun writeTo(channel: ByteWriteChannel) {
try {
handler(channel)
} catch (e: Exception) {
Log.warn("Exception occurred on streaming: " + e.message)
log.warn("Exception occurred on streaming: " + e.message)

channel.writeStringUtf8(e.toTrailer())
}
Expand All @@ -128,9 +150,9 @@ internal inline fun streamingResponse(crossinline handler: suspend (ByteWriteCha
"Lambda-Runtime-Function-Error-Type: Runtime.StreamError\r\nLambda-Runtime-Function-Error-Body: ${stackTraceToString().encodeBase64()}\r\n"
}

@PublishedApi
internal inline fun <T, R> T.bufferedResponse(context: Context, block: T.() -> R): R = try {
block()
} catch (e: Exception) {
throw e.asHandlerError(context)
inline fun <T, R> T.bufferedResponse(context: Context, block: T.() -> R): R = try {
block()
} catch (e: Exception) {
throw e.asHandlerError(context)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,22 @@ import kotlin.time.Duration.Companion.minutes
import io.ktor.http.ContentType.Application.Json as ContentTypeJson

@PublishedApi
internal class LambdaClient(private val httpClient: HttpClient) {
private val invokeUrl = "http://${LambdaEnvironment.RUNTIME_API}/2018-06-01/runtime"
internal interface LambdaClient {
suspend fun <T> retrieveNextEvent(bodyType: TypeInfo): Pair<T, Context>
suspend fun <T> sendResponse(event: Context, body: T, bodyType: TypeInfo): HttpResponse
suspend fun streamResponse(event: Context, outgoingContent: OutgoingContent): HttpResponse
suspend fun reportError(error: LambdaRuntimeException)
}

@PublishedApi
internal class LambdaClientImpl(private val httpClient: HttpClient): LambdaClient {
private val baseUrl = requireNotNull(LambdaEnvironment.RUNTIME_API) {
"Can't find AWS_LAMBDA_RUNTIME_API env variable"
}
private val invokeUrl = "http://$baseUrl/2018-06-01/runtime"
private val requestTimeout = 15.minutes.inWholeMilliseconds

suspend fun <T> retrieveNextEvent(bodyType: TypeInfo): Pair<T, Context> {
override suspend fun <T> retrieveNextEvent(bodyType: TypeInfo): Pair<T, Context> {
val response = httpClient.get {
url("${invokeUrl}/invocation/next")

Expand All @@ -47,7 +58,7 @@ internal class LambdaClient(private val httpClient: HttpClient) {
return body to context
}

suspend fun <T> sendResponse(event: Context, body: T, bodyType: TypeInfo): HttpResponse {
override suspend fun <T> sendResponse(event: Context, body: T, bodyType: TypeInfo): HttpResponse {
val response = httpClient.post {
url("${invokeUrl}/invocation/${event.awsRequestId}/response")
contentType(ContentTypeJson)
Expand All @@ -64,7 +75,7 @@ internal class LambdaClient(private val httpClient: HttpClient) {
return validateResponse(response)
}

suspend fun streamResponse(event: Context, outgoingContent: OutgoingContent): HttpResponse {
override suspend fun streamResponse(event: Context, outgoingContent: OutgoingContent): HttpResponse {
val response = httpClient.post {
url("${invokeUrl}/invocation/${event.awsRequestId}/response")

Expand All @@ -89,13 +100,15 @@ internal class LambdaClient(private val httpClient: HttpClient) {
return response
}

suspend fun reportError(error: LambdaRuntimeException) {
val response = when (error) {
override suspend fun reportError(error: LambdaRuntimeException) {
when (error) {
is LambdaRuntimeException.Init -> sendInitError(error)
is LambdaRuntimeException.Invocation -> sendInvocationError(error)
}
is LambdaRuntimeException.Invocation -> {
val response = sendInvocationError(error)

validateResponse(response)
validateResponse(response)
}
}
}

private suspend fun sendInvocationError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@ package io.github.trueangle.knative.lambda.runtime.api.dto

import io.github.trueangle.knative.lambda.runtime.log.LogLevel
import kotlinx.serialization.Contextual
import kotlinx.serialization.KSerializer
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder

@Serializable
internal data class LogMessageDto<T>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package io.github.trueangle.knative.lambda.runtime.log

import io.github.trueangle.knative.lambda.runtime.api.Context
import io.github.trueangle.knative.lambda.runtime.api.dto.LogMessageDto
import io.github.trueangle.knative.lambda.runtime.prettyPrint
import io.ktor.util.reflect.TypeInfo
import kotlinx.datetime.Clock
import kotlinx.serialization.SerializationException
Expand Down
Loading
Loading