From 0c66d4019d6fb23e246d1637d7919b5a059d11a3 Mon Sep 17 00:00:00 2001
From: Sebastien Stormacq
Date: Mon, 13 Jan 2025 12:47:54 +0100
Subject: [PATCH 01/15] update mock server for swift 6 compliance

---
 Package.swift                           |   6 +-
 Sources/MockServer/MockHTTPServer.swift | 285 ++++++++++++++++++++++++
 Sources/MockServer/main.swift           | 177 ---------------
 3 files changed, 288 insertions(+), 180 deletions(-)
 create mode 100644 Sources/MockServer/MockHTTPServer.swift
 delete mode 100644 Sources/MockServer/main.swift

diff --git a/Package.swift b/Package.swift
index d2c92fdc..96068884 100644
--- a/Package.swift
+++ b/Package.swift
@@ -17,7 +17,7 @@ let package = Package(
         .library(name: "AWSLambdaTesting", targets: ["AWSLambdaTesting"]),
     ],
     dependencies: [
-        .package(url: "https://github.com/apple/swift-nio.git", from: "2.76.0"),
+        .package(url: "https://github.com/apple/swift-nio.git", from: "2.77.0"),
         .package(url: "https://github.com/apple/swift-log.git", from: "1.5.4"),
     ],
     targets: [
@@ -89,11 +89,11 @@ let package = Package(
         .executableTarget(
             name: "MockServer",
             dependencies: [
+                .product(name: "Logging", package: "swift-log"),
                 .product(name: "NIOHTTP1", package: "swift-nio"),
                 .product(name: "NIOCore", package: "swift-nio"),
                 .product(name: "NIOPosix", package: "swift-nio"),
-            ],
-            swiftSettings: [.swiftLanguageMode(.v5)]
+            ]
         ),
     ]
 )
diff --git a/Sources/MockServer/MockHTTPServer.swift b/Sources/MockServer/MockHTTPServer.swift
new file mode 100644
index 00000000..a730de11
--- /dev/null
+++ b/Sources/MockServer/MockHTTPServer.swift
@@ -0,0 +1,285 @@
+//===----------------------------------------------------------------------===//
+//
+// This source file is part of the SwiftAWSLambdaRuntime open source project
+//
+// Copyright (c) 2017-2025 Apple Inc. and the SwiftAWSLambdaRuntime project authors
+// Licensed under Apache License v2.0
+//
+// See LICENSE.txt for license information
+// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+//===----------------------------------------------------------------------===//
+
+import Logging
+import NIOCore
+import NIOHTTP1
+import NIOPosix
+
+// for UUID and Date
+#if canImport(FoundationEssentials)
+import FoundationEssentials
+#else
+import Foundation
+#endif
+
+@main
+public class MockHttpServer {
+
+    public static func main() throws {
+        let server = MockHttpServer()
+        try server.start()
+    }
+
+    private func start() throws {
+        let host = env("HOST") ?? "127.0.0.1"
+        let port = env("PORT").flatMap(Int.init) ?? 7000
+        let mode = env("MODE").flatMap(Mode.init) ?? .string
+        var log = Logger(label: "MockServer")
+        log.logLevel = env("LOG_LEVEL").flatMap(Logger.Level.init) ??
.info + let logger = log + + let socketBootstrap = ServerBootstrap(group: MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)) + // Specify backlog and enable SO_REUSEADDR for the server itself + // .serverChannelOption(.backlog, value: 256) + .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) + // .childChannelOption(.maxMessagesPerRead, value: 1) + + // Set the handlers that are applied to the accepted Channels + .childChannelInitializer { channel in + channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { + channel.pipeline.addHandler(HTTPHandler(mode: mode, logger: logger)) + } + } + + let channel = try socketBootstrap.bind(host: host, port: port).wait() + logger.debug("Server started and listening on \(host):\(port)") + + // This will never return as we don't close the ServerChannel + try channel.closeFuture.wait() + } +} + +private final class HTTPHandler: ChannelInboundHandler { + public typealias InboundIn = HTTPServerRequestPart + public typealias OutboundOut = HTTPServerResponsePart + + private enum State { + case idle + case waitingForRequestBody + case sendingResponse + + mutating func requestReceived() { + precondition(self == .idle, "Invalid state for request received: \(self)") + self = .waitingForRequestBody + } + + mutating func requestComplete() { + precondition( + self == .waitingForRequestBody, + "Invalid state for request complete: \(self)" + ) + self = .sendingResponse + } + + mutating func responseComplete() { + precondition(self == .sendingResponse, "Invalid state for response complete: \(self)") + self = .idle + } + } + + private let logger: Logger + private let mode: Mode + + private var buffer: ByteBuffer! = nil + private var state: HTTPHandler.State = .idle + private var keepAlive = false + + private var requestHead: HTTPRequestHead? 
+ private var requestBodyBytes: Int = 0 + + init(mode: Mode, logger: Logger) { + self.mode = mode + self.logger = logger + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let reqPart = Self.unwrapInboundIn(data) + handle(context: context, request: reqPart) + } + + func channelReadComplete(context: ChannelHandlerContext) { + context.flush() + self.buffer.clear() + } + + func handlerAdded(context: ChannelHandlerContext) { + self.buffer = context.channel.allocator.buffer(capacity: 0) + } + + private func handle(context: ChannelHandlerContext, request: HTTPServerRequestPart) { + switch request { + case .head(let request): + logger.trace("Received request .head") + self.requestHead = request + self.requestBodyBytes = 0 + self.keepAlive = request.isKeepAlive + self.state.requestReceived() + case .body(buffer: var buf): + logger.trace("Received request .body") + self.requestBodyBytes += buf.readableBytes + self.buffer.writeBuffer(&buf) + case .end: + logger.trace("Received request .end") + self.state.requestComplete() + + precondition(requestHead != nil, "Received .end without .head") + let (responseStatus, responseHeaders, responseBody) = self.processRequest( + requestHead: self.requestHead!, + requestBody: self.buffer + ) + + self.buffer.clear() + self.buffer.writeString(responseBody) + + var headers = HTTPHeaders(responseHeaders) + headers.add(name: "Content-Length", value: "\(responseBody.utf8.count)") + + // write the response + context.write( + Self.wrapOutboundOut( + .head( + httpResponseHead( + request: self.requestHead!, + status: responseStatus, + headers: headers + ) + ) + ), + promise: nil + ) + context.write(Self.wrapOutboundOut(.body(.byteBuffer(self.buffer))), promise: nil) + self.completeResponse(context, trailers: nil, promise: nil) + } + } + + private func processRequest( + requestHead: HTTPRequestHead, + requestBody: ByteBuffer + ) -> (HTTPResponseStatus, [(String, String)], String) { + var responseStatus: HTTPResponseStatus = .ok + var responseBody: String = "" + var responseHeaders: [(String, String)] = [] + + logger.trace("Processing request for : \(requestHead) - \(requestBody.getString(at: 0, length: self.requestBodyBytes) ?? "")") + + if requestHead.uri.hasSuffix("/next") { + logger.trace("URI /next") + + responseStatus = .accepted + + let requestId = UUID().uuidString + switch self.mode { + case .string: + responseBody = "\"\(requestId)\"" // must be a valid JSON string + case .json: + responseBody = "{ \"body\": \"\(requestId)\" }" + } + let deadline = Int64(Date(timeIntervalSinceNow: 60).timeIntervalSince1970 * 1000) + responseHeaders = [ + // ("Connection", "close"), + (AmazonHeaders.requestID, requestId), + (AmazonHeaders.invokedFunctionARN, "arn:aws:lambda:us-east-1:123456789012:function:custom-runtime"), + (AmazonHeaders.traceID, "Root=1-5bef4de7-ad49b0e87f6ef6c87fc2e700;Parent=9a9197af755a6419;Sampled=1"), + (AmazonHeaders.deadline, String(deadline)), + ] + } else if requestHead.uri.hasSuffix("/response") { + logger.trace("URI /response") + responseStatus = .accepted + } else if requestHead.uri.hasSuffix("/error") { + logger.trace("URI /error") + responseStatus = .ok + } else { + logger.trace("Unknown URI : \(requestHead)") + responseStatus = .notFound + } + logger.trace("Returning response: \(responseStatus), \(responseHeaders), \(responseBody)") + return (responseStatus, responseHeaders, responseBody) + } + + private func completeResponse( + _ context: ChannelHandlerContext, + trailers: HTTPHeaders?, + promise: EventLoopPromise? 
+ ) { + self.state.responseComplete() + + let eventLoop = context.eventLoop + let loopBoundContext = NIOLoopBound(context, eventLoop: eventLoop) + + let promise = self.keepAlive ? promise : (promise ?? context.eventLoop.makePromise()) + if !self.keepAlive { + promise!.futureResult.whenComplete { (_: Result) in + let context = loopBoundContext.value + context.close(promise: nil) + } + } + + context.writeAndFlush(Self.wrapOutboundOut(.end(trailers)), promise: promise) + } + + private func httpResponseHead( + request: HTTPRequestHead, + status: HTTPResponseStatus, + headers: HTTPHeaders = HTTPHeaders() + ) -> HTTPResponseHead { + var head = HTTPResponseHead(version: request.version, status: status, headers: headers) + let connectionHeaders: [String] = head.headers[canonicalForm: "connection"].map { + $0.lowercased() + } + + if !connectionHeaders.contains("keep-alive") && !connectionHeaders.contains("close") { + // the user hasn't pre-set either 'keep-alive' or 'close', so we might need to add headers + + switch (request.isKeepAlive, request.version.major, request.version.minor) { + case (true, 1, 0): + // HTTP/1.0 and the request has 'Connection: keep-alive', we should mirror that + head.headers.add(name: "Connection", value: "keep-alive") + case (false, 1, let n) where n >= 1: + // HTTP/1.1 (or treated as such) and the request has 'Connection: close', we should mirror that + head.headers.add(name: "Connection", value: "close") + default: + // we should match the default or are dealing with some HTTP that we don't support, let's leave as is + () + } + } + return head + } + + private enum ServerError: Error { + case notReady + case cantBind + } + + private enum AmazonHeaders { + static let requestID = "Lambda-Runtime-Aws-Request-Id" + static let traceID = "Lambda-Runtime-Trace-Id" + static let clientContext = "X-Amz-Client-Context" + static let cognitoIdentity = "X-Amz-Cognito-Identity" + static let deadline = "Lambda-Runtime-Deadline-Ms" + static let invokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn" + } +} + +private enum Mode: String { + case string + case json +} + +private func env(_ name: String) -> String? { + guard let value = getenv(name) else { + return nil + } + return String(cString: value) +} diff --git a/Sources/MockServer/main.swift b/Sources/MockServer/main.swift deleted file mode 100644 index 1b8466f9..00000000 --- a/Sources/MockServer/main.swift +++ /dev/null @@ -1,177 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the SwiftAWSLambdaRuntime open source project -// -// Copyright (c) 2017-2018 Apple Inc. and the SwiftAWSLambdaRuntime project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -import Dispatch -import NIOCore -import NIOHTTP1 -import NIOPosix - -#if canImport(FoundationEssentials) -import FoundationEssentials -#else -import Foundation -#endif - -struct MockServer { - private let group: EventLoopGroup - private let host: String - private let port: Int - private let mode: Mode - - public init() { - self.group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount) - self.host = env("HOST") ?? "127.0.0.1" - self.port = env("PORT").flatMap(Int.init) ?? 7000 - self.mode = env("MODE").flatMap(Mode.init) ?? 
.string - } - - func start() throws { - let bootstrap = ServerBootstrap(group: group) - .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) - .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { _ in - channel.pipeline.addHandler(HTTPHandler(mode: self.mode)) - } - } - try bootstrap.bind(host: self.host, port: self.port).flatMap { channel -> EventLoopFuture in - guard let localAddress = channel.localAddress else { - return channel.eventLoop.makeFailedFuture(ServerError.cantBind) - } - print("\(self) started and listening on \(localAddress)") - return channel.eventLoop.makeSucceededFuture(()) - }.wait() - } -} - -final class HTTPHandler: ChannelInboundHandler { - public typealias InboundIn = HTTPServerRequestPart - public typealias OutboundOut = HTTPServerResponsePart - - private let mode: Mode - - private var pending = CircularBuffer<(head: HTTPRequestHead, body: ByteBuffer?)>() - - public init(mode: Mode) { - self.mode = mode - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let requestPart = unwrapInboundIn(data) - - switch requestPart { - case .head(let head): - self.pending.append((head: head, body: nil)) - case .body(var buffer): - var request = self.pending.removeFirst() - if request.body == nil { - request.body = buffer - } else { - request.body!.writeBuffer(&buffer) - } - self.pending.prepend(request) - case .end: - let request = self.pending.removeFirst() - self.processRequest(context: context, request: request) - } - } - - func processRequest(context: ChannelHandlerContext, request: (head: HTTPRequestHead, body: ByteBuffer?)) { - var responseStatus: HTTPResponseStatus - var responseBody: String? - var responseHeaders: [(String, String)]? - - if request.head.uri.hasSuffix("/next") { - let requestId = UUID().uuidString - responseStatus = .ok - switch self.mode { - case .string: - responseBody = requestId - case .json: - responseBody = "{ \"body\": \"\(requestId)\" }" - } - let deadline = Int64(Date(timeIntervalSinceNow: 60).timeIntervalSince1970 * 1000) - responseHeaders = [ - (AmazonHeaders.requestID, requestId), - (AmazonHeaders.invokedFunctionARN, "arn:aws:lambda:us-east-1:123456789012:function:custom-runtime"), - (AmazonHeaders.traceID, "Root=1-5bef4de7-ad49b0e87f6ef6c87fc2e700;Parent=9a9197af755a6419;Sampled=1"), - (AmazonHeaders.deadline, String(deadline)), - ] - } else if request.head.uri.hasSuffix("/response") { - responseStatus = .accepted - } else { - responseStatus = .notFound - } - self.writeResponse(context: context, status: responseStatus, headers: responseHeaders, body: responseBody) - } - - func writeResponse( - context: ChannelHandlerContext, - status: HTTPResponseStatus, - headers: [(String, String)]? = nil, - body: String? = nil - ) { - var headers = HTTPHeaders(headers ?? []) - headers.add(name: "content-length", value: "\(body?.utf8.count ?? 
0)") - let head = HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: status, headers: headers) - - context.write(wrapOutboundOut(.head(head))).whenFailure { error in - print("\(self) write error \(error)") - } - - if let b = body { - var buffer = context.channel.allocator.buffer(capacity: b.utf8.count) - buffer.writeString(b) - context.write(wrapOutboundOut(.body(.byteBuffer(buffer)))).whenFailure { error in - print("\(self) write error \(error)") - } - } - - context.writeAndFlush(wrapOutboundOut(.end(nil))).whenComplete { result in - if case .failure(let error) = result { - print("\(self) write error \(error)") - } - } - } -} - -enum ServerError: Error { - case notReady - case cantBind -} - -enum AmazonHeaders { - static let requestID = "Lambda-Runtime-Aws-Request-Id" - static let traceID = "Lambda-Runtime-Trace-Id" - static let clientContext = "X-Amz-Client-Context" - static let cognitoIdentity = "X-Amz-Cognito-Identity" - static let deadline = "Lambda-Runtime-Deadline-Ms" - static let invokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn" -} - -enum Mode: String { - case string - case json -} - -func env(_ name: String) -> String? { - guard let value = getenv(name) else { - return nil - } - return String(cString: value) -} - -// main -let server = MockServer() -try! server.start() -dispatchMain() From 888ed77ab5bae76acec6c2a55899f309d2210bbc Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Mon, 13 Jan 2025 18:36:31 +0100 Subject: [PATCH 02/15] apply swift format --- Sources/MockServer/MockHTTPServer.swift | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/Sources/MockServer/MockHTTPServer.swift b/Sources/MockServer/MockHTTPServer.swift index a730de11..0de58d0c 100644 --- a/Sources/MockServer/MockHTTPServer.swift +++ b/Sources/MockServer/MockHTTPServer.swift @@ -27,7 +27,7 @@ import Foundation @main public class MockHttpServer { - public static func main() throws { + public static func main() throws { let server = MockHttpServer() try server.start() } @@ -172,7 +172,9 @@ private final class HTTPHandler: ChannelInboundHandler { var responseBody: String = "" var responseHeaders: [(String, String)] = [] - logger.trace("Processing request for : \(requestHead) - \(requestBody.getString(at: 0, length: self.requestBodyBytes) ?? "")") + logger.trace( + "Processing request for : \(requestHead) - \(requestBody.getString(at: 0, length: self.requestBodyBytes) ?? 
"")" + ) if requestHead.uri.hasSuffix("/next") { logger.trace("URI /next") @@ -182,7 +184,7 @@ private final class HTTPHandler: ChannelInboundHandler { let requestId = UUID().uuidString switch self.mode { case .string: - responseBody = "\"\(requestId)\"" // must be a valid JSON string + responseBody = "\"\(requestId)\"" // must be a valid JSON string case .json: responseBody = "{ \"body\": \"\(requestId)\" }" } @@ -194,7 +196,7 @@ private final class HTTPHandler: ChannelInboundHandler { (AmazonHeaders.traceID, "Root=1-5bef4de7-ad49b0e87f6ef6c87fc2e700;Parent=9a9197af755a6419;Sampled=1"), (AmazonHeaders.deadline, String(deadline)), ] - } else if requestHead.uri.hasSuffix("/response") { + } else if requestHead.uri.hasSuffix("/response") { logger.trace("URI /response") responseStatus = .accepted } else if requestHead.uri.hasSuffix("/error") { From e7b7e6ccfee7abdf26c237a109f8355e5e6f8a2c Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Mon, 13 Jan 2025 18:50:45 +0100 Subject: [PATCH 03/15] simplify ByteBuffer to String --- Sources/MockServer/MockHTTPServer.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/MockServer/MockHTTPServer.swift b/Sources/MockServer/MockHTTPServer.swift index 0de58d0c..8e0e56fd 100644 --- a/Sources/MockServer/MockHTTPServer.swift +++ b/Sources/MockServer/MockHTTPServer.swift @@ -173,7 +173,7 @@ private final class HTTPHandler: ChannelInboundHandler { var responseHeaders: [(String, String)] = [] logger.trace( - "Processing request for : \(requestHead) - \(requestBody.getString(at: 0, length: self.requestBodyBytes) ?? "")" + "Processing request for : \(requestHead) - \(String(requestBody))" ) if requestHead.uri.hasSuffix("/next") { From 61ab17bafd8532dd5e3731905059e049a095a065 Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Mon, 13 Jan 2025 18:56:51 +0100 Subject: [PATCH 04/15] fix --- Sources/MockServer/MockHTTPServer.swift | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Sources/MockServer/MockHTTPServer.swift b/Sources/MockServer/MockHTTPServer.swift index 8e0e56fd..63fe3e72 100644 --- a/Sources/MockServer/MockHTTPServer.swift +++ b/Sources/MockServer/MockHTTPServer.swift @@ -42,9 +42,9 @@ public class MockHttpServer { let socketBootstrap = ServerBootstrap(group: MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)) // Specify backlog and enable SO_REUSEADDR for the server itself - // .serverChannelOption(.backlog, value: 256) + .serverChannelOption(.backlog, value: 256) .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) - // .childChannelOption(.maxMessagesPerRead, value: 1) + .childChannelOption(.maxMessagesPerRead, value: 1) // Set the handlers that are applied to the accepted Channels .childChannelInitializer { channel in @@ -173,7 +173,7 @@ private final class HTTPHandler: ChannelInboundHandler { var responseHeaders: [(String, String)] = [] logger.trace( - "Processing request for : \(requestHead) - \(String(requestBody))" + "Processing request for : \(requestHead) - \(String(buffer: requestBody))" ) if requestHead.uri.hasSuffix("/next") { From bc3a34d301f40bffce4c0ff501d863d0b73b82d8 Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Mon, 13 Jan 2025 19:18:37 +0100 Subject: [PATCH 05/15] remove unused code --- Sources/MockServer/MockHTTPServer.swift | 30 ------------------------- 1 file changed, 30 deletions(-) diff --git a/Sources/MockServer/MockHTTPServer.swift b/Sources/MockServer/MockHTTPServer.swift index 63fe3e72..a8a1663d 100644 --- 
a/Sources/MockServer/MockHTTPServer.swift +++ b/Sources/MockServer/MockHTTPServer.swift @@ -65,35 +65,10 @@ private final class HTTPHandler: ChannelInboundHandler { public typealias InboundIn = HTTPServerRequestPart public typealias OutboundOut = HTTPServerResponsePart - private enum State { - case idle - case waitingForRequestBody - case sendingResponse - - mutating func requestReceived() { - precondition(self == .idle, "Invalid state for request received: \(self)") - self = .waitingForRequestBody - } - - mutating func requestComplete() { - precondition( - self == .waitingForRequestBody, - "Invalid state for request complete: \(self)" - ) - self = .sendingResponse - } - - mutating func responseComplete() { - precondition(self == .sendingResponse, "Invalid state for response complete: \(self)") - self = .idle - } - } - private let logger: Logger private let mode: Mode private var buffer: ByteBuffer! = nil - private var state: HTTPHandler.State = .idle private var keepAlive = false private var requestHead: HTTPRequestHead? @@ -125,14 +100,12 @@ private final class HTTPHandler: ChannelInboundHandler { self.requestHead = request self.requestBodyBytes = 0 self.keepAlive = request.isKeepAlive - self.state.requestReceived() case .body(buffer: var buf): logger.trace("Received request .body") self.requestBodyBytes += buf.readableBytes self.buffer.writeBuffer(&buf) case .end: logger.trace("Received request .end") - self.state.requestComplete() precondition(requestHead != nil, "Received .end without .head") let (responseStatus, responseHeaders, responseBody) = self.processRequest( @@ -190,7 +163,6 @@ private final class HTTPHandler: ChannelInboundHandler { } let deadline = Int64(Date(timeIntervalSinceNow: 60).timeIntervalSince1970 * 1000) responseHeaders = [ - // ("Connection", "close"), (AmazonHeaders.requestID, requestId), (AmazonHeaders.invokedFunctionARN, "arn:aws:lambda:us-east-1:123456789012:function:custom-runtime"), (AmazonHeaders.traceID, "Root=1-5bef4de7-ad49b0e87f6ef6c87fc2e700;Parent=9a9197af755a6419;Sampled=1"), @@ -215,8 +187,6 @@ private final class HTTPHandler: ChannelInboundHandler { trailers: HTTPHeaders?, promise: EventLoopPromise? 
) { - self.state.responseComplete() - let eventLoop = context.eventLoop let loopBoundContext = NIOLoopBound(context, eventLoop: eventLoop) From ccdb45a2ac5c17526407099dcc8bbf86618f377a Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Tue, 14 Jan 2025 06:39:13 +0100 Subject: [PATCH 06/15] adjust payload to new examples --- Sources/MockServer/MockHTTPServer.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Sources/MockServer/MockHTTPServer.swift b/Sources/MockServer/MockHTTPServer.swift index a8a1663d..55468e67 100644 --- a/Sources/MockServer/MockHTTPServer.swift +++ b/Sources/MockServer/MockHTTPServer.swift @@ -157,9 +157,9 @@ private final class HTTPHandler: ChannelInboundHandler { let requestId = UUID().uuidString switch self.mode { case .string: - responseBody = "\"\(requestId)\"" // must be a valid JSON string + responseBody = "\"Seb\"" // must be a valid JSON document case .json: - responseBody = "{ \"body\": \"\(requestId)\" }" + responseBody = "{ \"name\": \"Seb\", \"age\" : 52 }" } let deadline = Int64(Date(timeIntervalSinceNow: 60).timeIntervalSince1970 * 1000) responseHeaders = [ From ae29289e75d1a32e4b75a33476d9fbff5e187b6a Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Tue, 14 Jan 2025 07:28:32 +0100 Subject: [PATCH 07/15] wip conformance to Swift 6 --- .../Lambda+LocalServer.swift | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift b/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift index a23ef1cf..5344543e 100644 --- a/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift +++ b/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift @@ -97,7 +97,8 @@ private enum LocalLambda { public typealias InboundIn = HTTPServerRequestPart public typealias OutboundOut = HTTPServerResponsePart - private var pending = CircularBuffer<(head: HTTPRequestHead, body: ByteBuffer?)>() + private var requestHead: HTTPRequestHead? + private var requestBody: ByteBuffer? 
private static var invocations = CircularBuffer() private static var invocationState = InvocationState.waitingForLambdaRequest @@ -110,23 +111,27 @@ private enum LocalLambda { self.invocationEndpoint = invocationEndpoint } + func handlerAdded(context: ChannelHandlerContext) { + self.requestBody = context.channel.allocator.buffer(capacity: 0) + } + func channelRead(context: ChannelHandlerContext, data: NIOAny) { let requestPart = unwrapInboundIn(data) switch requestPart { case .head(let head): - self.pending.append((head: head, body: nil)) - case .body(var buffer): - var request = self.pending.removeFirst() - if request.body == nil { - request.body = buffer - } else { - request.body!.writeBuffer(&buffer) - } - self.pending.prepend(request) + precondition(self.requestHead == nil, "received two HTTP heads") + precondition(self.requestBody != nil, "body buffer is not initialized") + self.requestHead = head + self.requestBody!.clear() + case .body(buffer: var buf): + precondition(self.requestHead != nil, "received HTTP body before head") + precondition(self.requestBody != nil, "body buffer is not initialized") + self.requestBody!.writeBuffer(&buf) case .end: - let request = self.pending.removeFirst() - self.processRequest(context: context, request: request) + precondition(self.requestHead != nil, "received HTTP end before head") + self.processRequest(context: context, request: (head: self.requestHead!, body: self.requestBody)) + self.requestHead = nil } } From eb1608b55b90a5f280d44e23f99e4ed0f618a3fc Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Wed, 15 Jan 2025 10:21:27 +0100 Subject: [PATCH 08/15] [wip] use NIOAsyncChannel --- Sources/MockServer/MockHTTPServer.swift | 330 +++++++++++++----------- 1 file changed, 180 insertions(+), 150 deletions(-) diff --git a/Sources/MockServer/MockHTTPServer.swift b/Sources/MockServer/MockHTTPServer.swift index 55468e67..34a923d1 100644 --- a/Sources/MockServer/MockHTTPServer.swift +++ b/Sources/MockServer/MockHTTPServer.swift @@ -16,6 +16,7 @@ import Logging import NIOCore import NIOHTTP1 import NIOPosix +import Synchronization // for UUID and Date #if canImport(FoundationEssentials) @@ -25,133 +26,169 @@ import Foundation #endif @main -public class MockHttpServer { - - public static func main() throws { - let server = MockHttpServer() - try server.start() - } +struct HttpServer { + /// The server's host. (default: 127.0.0.1) + private let host: String + /// The server's port. (default: 7000) + private let port: Int + /// The server's event loop group. (default: MultiThreadedEventLoopGroup.singleton) + private let eventLoopGroup: MultiThreadedEventLoopGroup + /// the mode. Are we mocking a server for a Lambda function that expects a String or a JSON document? (default: string) + private let mode: Mode + /// the number of connections this server must accept before shutting down (default: 1) + private let maxInvocations: Int + /// the logger (control verbosity with LOG_LEVEL environment variable) + private let logger: Logger - private func start() throws { - let host = env("HOST") ?? "127.0.0.1" - let port = env("PORT").flatMap(Int.init) ?? 7000 - let mode = env("MODE").flatMap(Mode.init) ?? .string + static func main() async throws { var log = Logger(label: "MockServer") log.logLevel = env("LOG_LEVEL").flatMap(Logger.Level.init) ?? 
.info - let logger = log - let socketBootstrap = ServerBootstrap(group: MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)) - // Specify backlog and enable SO_REUSEADDR for the server itself + let server = HttpServer( + host: env("HOST") ?? "127.0.0.1", + port: env("PORT").flatMap(Int.init) ?? 7000, + eventLoopGroup: .singleton, + mode: env("MODE").flatMap(Mode.init) ?? .string, + maxInvocations: env("MAX_INVOCATIONS").flatMap(Int.init) ?? 1, + logger: log + ) + try await server.run() + } + + /// This method starts the server and handles incoming connections. + private func run() async throws { + let channel = try await ServerBootstrap(group: self.eventLoopGroup) .serverChannelOption(.backlog, value: 256) .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) .childChannelOption(.maxMessagesPerRead, value: 1) + .bind( + host: self.host, + port: self.port + ) { channel in + channel.eventLoop.makeCompletedFuture { + + try channel.pipeline.syncOperations.configureHTTPServerPipeline( + withErrorHandling: true + ) - // Set the handlers that are applied to the accepted Channels - .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { - channel.pipeline.addHandler(HTTPHandler(mode: mode, logger: logger)) + return try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: NIOAsyncChannel.Configuration( + inboundType: HTTPServerRequestPart.self, + outboundType: HTTPServerResponsePart.self + ) + ) } } - let channel = try socketBootstrap.bind(host: host, port: port).wait() - logger.debug("Server started and listening on \(host):\(port)") - - // This will never return as we don't close the ServerChannel - try channel.closeFuture.wait() - } -} - -private final class HTTPHandler: ChannelInboundHandler { - public typealias InboundIn = HTTPServerRequestPart - public typealias OutboundOut = HTTPServerResponsePart - - private let logger: Logger - private let mode: Mode - - private var buffer: ByteBuffer! = nil - private var keepAlive = false - - private var requestHead: HTTPRequestHead? - private var requestBodyBytes: Int = 0 - - init(mode: Mode, logger: Logger) { - self.mode = mode - self.logger = logger - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let reqPart = Self.unwrapInboundIn(data) - handle(context: context, request: reqPart) - } - - func channelReadComplete(context: ChannelHandlerContext) { - context.flush() - self.buffer.clear() - } + logger.info( + "Server started and listening", + metadata: [ + "host": "\(channel.channel.localAddress?.ipAddress?.debugDescription ?? "")", + "port": "\(channel.channel.localAddress?.port ?? 0)", + ] + ) - func handlerAdded(context: ChannelHandlerContext) { - self.buffer = context.channel.allocator.buffer(capacity: 0) + // We are handling each incoming connection in a separate child task. It is important + // to use a discarding task group here which automatically discards finished child tasks. + // A normal task group retains all child tasks and their outputs in memory until they are + // consumed by iterating the group or by exiting the group. Since, we are never consuming + // the results of the group we need the group to automatically discard them; otherwise, this + // would result in a memory leak over time. 
+ try await withThrowingDiscardingTaskGroup { group in + try await channel.executeThenClose { inbound in + for try await connectionChannel in inbound { + logger.trace("Handling new connection") + logger.info( + "This mock server accepts only one connection, it will shutdown the server after handling the current connection." + ) + group.addTask { + await self.handleConnection(channel: connectionChannel) + logger.trace("Done handling connection") + } + break + } + } + } + logger.info("Server shutting down") } - private func handle(context: ChannelHandlerContext, request: HTTPServerRequestPart) { - switch request { - case .head(let request): - logger.trace("Received request .head") - self.requestHead = request - self.requestBodyBytes = 0 - self.keepAlive = request.isKeepAlive - case .body(buffer: var buf): - logger.trace("Received request .body") - self.requestBodyBytes += buf.readableBytes - self.buffer.writeBuffer(&buf) - case .end: - logger.trace("Received request .end") - - precondition(requestHead != nil, "Received .end without .head") - let (responseStatus, responseHeaders, responseBody) = self.processRequest( - requestHead: self.requestHead!, - requestBody: self.buffer - ) + /// This method handles a single connection by echoing back all inbound data. + private func handleConnection( + channel: NIOAsyncChannel + ) async { + + var requestHead: HTTPRequestHead! + var requestBody: ByteBuffer? + + // each Lambda invocation results in TWO HTTP requests (next and response) + let requestCount = RequestCounter(maxRequest: self.maxInvocations * 2) + + // Note that this method is non-throwing and we are catching any error. + // We do this since we don't want to tear down the whole server when a single connection + // encounters an error. + do { + try await channel.executeThenClose { inbound, outbound in + for try await inboundData in inbound { + let requestNumber = requestCount.current() + logger.trace("Handling request", metadata: ["requestNumber": "\(requestNumber)"]) + + if case .head(let head) = inboundData { + logger.trace("Received request head", metadata: ["head": "\(head)"]) + requestHead = head + } + if case .body(let body) = inboundData { + logger.trace("Received request body", metadata: ["body": "\(body)"]) + requestBody = body + } + if case .end(let end) = inboundData { + logger.trace("Received request end", metadata: ["end": "\(String(describing: end))"]) + + precondition(requestHead != nil, "Received .end without .head") + let (responseStatus, responseHeaders, responseBody) = self.processRequest( + requestHead: requestHead, + requestBody: requestBody + ) - self.buffer.clear() - self.buffer.writeString(responseBody) + try await self.sendResponse( + responseStatus: responseStatus, + responseHeaders: responseHeaders, + responseBody: responseBody, + outbound: outbound + ) - var headers = HTTPHeaders(responseHeaders) - headers.add(name: "Content-Length", value: "\(responseBody.utf8.count)") + requestHead = nil - // write the response - context.write( - Self.wrapOutboundOut( - .head( - httpResponseHead( - request: self.requestHead!, - status: responseStatus, - headers: headers - ) - ) - ), - promise: nil - ) - context.write(Self.wrapOutboundOut(.body(.byteBuffer(self.buffer))), promise: nil) - self.completeResponse(context, trailers: nil, promise: nil) + if requestCount.increment() { + logger.info( + "Maximum number of invocations reached, closing this connection", + metadata: ["maxInvocations": "\(self.maxInvocations)"] + ) + break + } + } + } + } + } catch { + logger.error("Hit error: 
\(error)") } } - + /// This function process the requests and return an hard-coded response (string or JSON depending on the mode). + /// We ignore the requestBody. private func processRequest( requestHead: HTTPRequestHead, - requestBody: ByteBuffer + requestBody: ByteBuffer? ) -> (HTTPResponseStatus, [(String, String)], String) { var responseStatus: HTTPResponseStatus = .ok var responseBody: String = "" var responseHeaders: [(String, String)] = [] logger.trace( - "Processing request for : \(requestHead) - \(String(buffer: requestBody))" + "Processing request", + metadata: ["VERB": "\(requestHead.method)", "URI": "\(requestHead.uri)"] ) if requestHead.uri.hasSuffix("/next") { - logger.trace("URI /next") - responseStatus = .accepted let requestId = UUID().uuidString @@ -169,64 +206,51 @@ private final class HTTPHandler: ChannelInboundHandler { (AmazonHeaders.deadline, String(deadline)), ] } else if requestHead.uri.hasSuffix("/response") { - logger.trace("URI /response") responseStatus = .accepted } else if requestHead.uri.hasSuffix("/error") { - logger.trace("URI /error") responseStatus = .ok } else { - logger.trace("Unknown URI : \(requestHead)") responseStatus = .notFound } logger.trace("Returning response: \(responseStatus), \(responseHeaders), \(responseBody)") return (responseStatus, responseHeaders, responseBody) } - private func completeResponse( - _ context: ChannelHandlerContext, - trailers: HTTPHeaders?, - promise: EventLoopPromise? - ) { - let eventLoop = context.eventLoop - let loopBoundContext = NIOLoopBound(context, eventLoop: eventLoop) - - let promise = self.keepAlive ? promise : (promise ?? context.eventLoop.makePromise()) - if !self.keepAlive { - promise!.futureResult.whenComplete { (_: Result) in - let context = loopBoundContext.value - context.close(promise: nil) - } - } - - context.writeAndFlush(Self.wrapOutboundOut(.end(trailers)), promise: promise) + private func sendResponse( + responseStatus: HTTPResponseStatus, + responseHeaders: [(String, String)], + responseBody: String, + outbound: NIOAsyncChannelOutboundWriter + ) async throws { + var headers = HTTPHeaders(responseHeaders) + headers.add(name: "Content-Length", value: "\(responseBody.utf8.count)") + + logger.trace("Writing response head") + try await outbound.write( + HTTPServerResponsePart.head( + HTTPResponseHead( + version: .init(major: 1, minor: 1), + status: responseStatus, + headers: headers + ) + ) + ) + logger.trace("Writing response body") + try await outbound.write(HTTPServerResponsePart.body(.byteBuffer(ByteBuffer(string: responseBody)))) + logger.trace("Writing response end") + try await outbound.write(HTTPServerResponsePart.end(nil)) } - private func httpResponseHead( - request: HTTPRequestHead, - status: HTTPResponseStatus, - headers: HTTPHeaders = HTTPHeaders() - ) -> HTTPResponseHead { - var head = HTTPResponseHead(version: request.version, status: status, headers: headers) - let connectionHeaders: [String] = head.headers[canonicalForm: "connection"].map { - $0.lowercased() - } - - if !connectionHeaders.contains("keep-alive") && !connectionHeaders.contains("close") { - // the user hasn't pre-set either 'keep-alive' or 'close', so we might need to add headers + private enum Mode: String { + case string + case json + } - switch (request.isKeepAlive, request.version.major, request.version.minor) { - case (true, 1, 0): - // HTTP/1.0 and the request has 'Connection: keep-alive', we should mirror that - head.headers.add(name: "Connection", value: "keep-alive") - case (false, 1, let n) where n >= 1: - 
// HTTP/1.1 (or treated as such) and the request has 'Connection: close', we should mirror that - head.headers.add(name: "Connection", value: "close") - default: - // we should match the default or are dealing with some HTTP that we don't support, let's leave as is - () - } + private static func env(_ name: String) -> String? { + guard let value = getenv(name) else { + return nil } - return head + return String(cString: value) } private enum ServerError: Error { @@ -242,16 +266,22 @@ private final class HTTPHandler: ChannelInboundHandler { static let deadline = "Lambda-Runtime-Deadline-Ms" static let invokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn" } -} -private enum Mode: String { - case string - case json -} + private final class RequestCounter: Sendable { + private let counterMutex = Mutex(0) + private let maxRequest: Int -private func env(_ name: String) -> String? { - guard let value = getenv(name) else { - return nil + init(maxRequest: Int) { + self.maxRequest = maxRequest + } + func current() -> Int { + counterMutex.withLock { $0 } + } + func increment() -> Bool { + counterMutex.withLock { + $0 += 1 + return $0 >= maxRequest + } + } } - return String(cString: value) } From 32dace6948727731871ca52d39f51a65e153c471 Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Wed, 15 Jan 2025 11:38:59 +0100 Subject: [PATCH 09/15] manage max number of connections and max number of request per connection --- Sources/MockServer/MockHTTPServer.swift | 52 ++++++++++++++++--------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/Sources/MockServer/MockHTTPServer.swift b/Sources/MockServer/MockHTTPServer.swift index 34a923d1..b8f18998 100644 --- a/Sources/MockServer/MockHTTPServer.swift +++ b/Sources/MockServer/MockHTTPServer.swift @@ -55,7 +55,8 @@ struct HttpServer { try await server.run() } - /// This method starts the server and handles incoming connections. + /// This method starts the server and handles one unique incoming connections + /// The Lambda function will send two HTTP requests over this connection: one for the next invocation and one for the response. private func run() async throws { let channel = try await ServerBootstrap(group: self.eventLoopGroup) .serverChannelOption(.backlog, value: 256) @@ -86,9 +87,14 @@ struct HttpServer { metadata: [ "host": "\(channel.channel.localAddress?.ipAddress?.debugDescription ?? "")", "port": "\(channel.channel.localAddress?.port ?? 0)", + "maxInvocations": "\(self.maxInvocations)", ] ) + // This counter is used to track the number of incoming connections. + // This mock servers accepts n TCP connection then shutdowns + let connectionCounter = SharedCounter(maxValue: self.maxInvocations) + // We are handling each incoming connection in a separate child task. It is important // to use a discarding task group here which automatically discards finished child tasks. // A normal task group retains all child tasks and their outputs in memory until they are @@ -98,22 +104,31 @@ struct HttpServer { try await withThrowingDiscardingTaskGroup { group in try await channel.executeThenClose { inbound in for try await connectionChannel in inbound { - logger.trace("Handling new connection") - logger.info( - "This mock server accepts only one connection, it will shutdown the server after handling the current connection." 
- ) + + let counter = connectionCounter.current() + logger.trace("Handling new connection", metadata: ["connectionNumber": "\(counter)"]) + group.addTask { await self.handleConnection(channel: connectionChannel) - logger.trace("Done handling connection") + logger.trace("Done handling connection", metadata: ["connectionNumber": "\(counter)"]) + } + + if connectionCounter.increment() { + logger.info( + "Maximum number of connections reached, shutting down after current connection", + metadata: ["maxConnections": "\(self.maxInvocations)"] + ) + break // this causes the server to shutdown after handling the connection } - break } } } logger.info("Server shutting down") } - /// This method handles a single connection by echoing back all inbound data. + /// This method handles a single connection by responsing hard coded value to a Lambda function request. + /// It handles two requests: one for the next invocation and one for the response. + /// when the maximum number of requests is reached, it closes the connection. private func handleConnection( channel: NIOAsyncChannel ) async { @@ -122,7 +137,7 @@ struct HttpServer { var requestBody: ByteBuffer? // each Lambda invocation results in TWO HTTP requests (next and response) - let requestCount = RequestCounter(maxRequest: self.maxInvocations * 2) + let requestCount = SharedCounter(maxValue: 2) // Note that this method is non-throwing and we are catching any error. // We do this since we don't want to tear down the whole server when a single connection @@ -161,10 +176,10 @@ struct HttpServer { if requestCount.increment() { logger.info( - "Maximum number of invocations reached, closing this connection", - metadata: ["maxInvocations": "\(self.maxInvocations)"] + "Maximum number of requests reached, closing this connection", + metadata: ["maxRequest": "2"] ) - break + break // this finishes handiling request on this connection } } } @@ -224,12 +239,13 @@ struct HttpServer { ) async throws { var headers = HTTPHeaders(responseHeaders) headers.add(name: "Content-Length", value: "\(responseBody.utf8.count)") + headers.add(name: "KeepAlive", value: "timeout=1, max=2") logger.trace("Writing response head") try await outbound.write( HTTPServerResponsePart.head( HTTPResponseHead( - version: .init(major: 1, minor: 1), + version: .init(major: 1, minor: 1), // use HTTP 1.1 it keeps connection alive between requests status: responseStatus, headers: headers ) @@ -267,12 +283,12 @@ struct HttpServer { static let invokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn" } - private final class RequestCounter: Sendable { + private final class SharedCounter: Sendable { private let counterMutex = Mutex(0) - private let maxRequest: Int + private let maxValue: Int - init(maxRequest: Int) { - self.maxRequest = maxRequest + init(maxValue: Int) { + self.maxValue = maxValue } func current() -> Int { counterMutex.withLock { $0 } @@ -280,7 +296,7 @@ struct HttpServer { func increment() -> Bool { counterMutex.withLock { $0 += 1 - return $0 >= maxRequest + return $0 >= maxValue } } } From 5d4bfa73d4d9cf99dfd43ffe58164c2099dbc25b Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Wed, 15 Jan 2025 23:45:58 +0100 Subject: [PATCH 10/15] LocalServer is compliant to Swift 6 concurrency --- .../Lambda+LocalServer.swift | 594 ++++++++++-------- 1 file changed, 344 insertions(+), 250 deletions(-) diff --git a/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift b/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift index 5344543e..bcddd199 100644 --- 
a/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift +++ b/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift @@ -19,13 +19,22 @@ import NIOConcurrencyHelpers import NIOCore import NIOHTTP1 import NIOPosix +import Synchronization // This functionality is designed for local testing hence being a #if DEBUG flag. + // For example: -// // try Lambda.withLocalServer { -// Lambda.run { (context: LambdaContext, event: String, callback: @escaping (Result) -> Void) in -// callback(.success("Hello, \(event)!")) +// try await LambdaRuntimeClient.withRuntimeClient( +// configuration: .init(ip: "127.0.0.1", port: 7000), +// eventLoop: self.eventLoop, +// logger: self.logger +// ) { runtimeClient in +// try await Lambda.runLoop( +// runtimeClient: runtimeClient, +// handler: handler, +// logger: self.logger +// ) // } // } extension Lambda { @@ -36,295 +45,380 @@ extension Lambda { /// - body: Code to run within the context of the mock server. Typically this would be a Lambda.run function call. /// /// - note: This API is designed strictly for local testing and is behind a DEBUG flag - static func withLocalServer( + static func withLocalServer( invocationEndpoint: String? = nil, - _ body: @escaping () async throws -> Value - ) async throws -> Value { - let server = LocalLambda.Server(invocationEndpoint: invocationEndpoint) - try await server.start().get() - defer { try! server.stop() } - return try await body() - } -} - -// MARK: - Local Mock Server - -private enum LocalLambda { - struct Server { - private let logger: Logger - private let group: EventLoopGroup - private let host: String - private let port: Int - private let invocationEndpoint: String - - init(invocationEndpoint: String?) { - var logger = Logger(label: "LocalLambdaServer") - logger.logLevel = .info - self.logger = logger - self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1) - self.host = "127.0.0.1" - self.port = 7000 - self.invocationEndpoint = invocationEndpoint ?? 
"/invoke" - } - - func start() -> EventLoopFuture { - let bootstrap = ServerBootstrap(group: group) - .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) - .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { _ in - channel.pipeline.addHandler( - HTTPHandler(logger: self.logger, invocationEndpoint: self.invocationEndpoint) - ) - } - } - return bootstrap.bind(host: self.host, port: self.port).flatMap { channel -> EventLoopFuture in - guard channel.localAddress != nil else { - return channel.eventLoop.makeFailedFuture(ServerError.cantBind) + _ body: @escaping () async throws -> Void + ) async throws { + + // launch the local server and wait for it to be started before running the body + try await withThrowingTaskGroup(of: Void.self) { group in + // this call will return when the server calls continuation.resume() + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + group.addTask { + try await LambdaHttpServer(invocationEndpoint: invocationEndpoint).start(continuation: continuation) } - self.logger.info( - "LocalLambdaServer started and listening on \(self.host):\(self.port), receiving events on \(self.invocationEndpoint)" - ) - return channel.eventLoop.makeSucceededFuture(()) } - } - - func stop() throws { - try self.group.syncShutdownGracefully() + // now that server is started, run the Lambda function itself + try await body() } } +} - final class HTTPHandler: ChannelInboundHandler { - public typealias InboundIn = HTTPServerRequestPart - public typealias OutboundOut = HTTPServerResponsePart - - private var requestHead: HTTPRequestHead? - private var requestBody: ByteBuffer? - - private static var invocations = CircularBuffer() - private static var invocationState = InvocationState.waitingForLambdaRequest - - private let logger: Logger - private let invocationEndpoint: String - - init(logger: Logger, invocationEndpoint: String) { - self.logger = logger - self.invocationEndpoint = invocationEndpoint - } +// MARK: - Local HTTP Server + +/// An HTTP server that behaves like the AWS Lambda service for local testing. +/// This server is used to simulate the AWS Lambda service for local testing but also to accept invocation requests from the lambda client. +/// +/// It accepts three types of requests from the Lambda function (through the LambdaRuntimeClient): +/// 1. GET /next - the lambda function polls this endpoint to get the next invocation request +/// 2. POST /:requestID/response - the lambda function posts the response to the invocation request +/// 3. POST /:requestID/error - the lambda function posts an error response to the invocation request +/// +/// It also accepts one type of request from the client invoking the lambda function: +/// 1. POST /invoke - the client posts the event to the lambda function +/// +/// This server passes the data received from /invoke POST request to the lambda function (GET /next) and then forwards the response back to the client. +private struct LambdaHttpServer { + private let logger: Logger + private let group: EventLoopGroup + private let host: String + private let port: Int + private let invocationEndpoint: String + + private let invocationPool = Pool() + private let responsePool = Pool() + + init(invocationEndpoint: String?) { + var logger = Logger(label: "LocalServer") + logger.logLevel = Lambda.env("LOG_LEVEL").flatMap(Logger.Level.init) ?? 
.info + self.logger = logger + self.group = MultiThreadedEventLoopGroup.singleton + self.host = "127.0.0.1" + self.port = 7000 + self.invocationEndpoint = invocationEndpoint ?? "/invoke" + } - func handlerAdded(context: ChannelHandlerContext) { - self.requestBody = context.channel.allocator.buffer(capacity: 0) - } + func start(continuation: CheckedContinuation) async throws { + let channel = try await ServerBootstrap(group: self.group) + .serverChannelOption(.backlog, value: 256) + .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) + .childChannelOption(.maxMessagesPerRead, value: 1) + .bind( + host: self.host, + port: self.port + ) { channel in + channel.eventLoop.makeCompletedFuture { + + try channel.pipeline.syncOperations.configureHTTPServerPipeline( + withErrorHandling: true + ) - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let requestPart = unwrapInboundIn(data) - - switch requestPart { - case .head(let head): - precondition(self.requestHead == nil, "received two HTTP heads") - precondition(self.requestBody != nil, "body buffer is not initialized") - self.requestHead = head - self.requestBody!.clear() - case .body(buffer: var buf): - precondition(self.requestHead != nil, "received HTTP body before head") - precondition(self.requestBody != nil, "body buffer is not initialized") - self.requestBody!.writeBuffer(&buf) - case .end: - precondition(self.requestHead != nil, "received HTTP end before head") - self.processRequest(context: context, request: (head: self.requestHead!, body: self.requestBody)) - self.requestHead = nil + return try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: NIOAsyncChannel.Configuration( + inboundType: HTTPServerRequestPart.self, + outboundType: HTTPServerResponsePart.self + ) + ) + } } - } - - func processRequest(context: ChannelHandlerContext, request: (head: HTTPRequestHead, body: ByteBuffer?)) { - - let eventLoop = context.eventLoop - let loopBoundContext = NIOLoopBound(context, eventLoop: eventLoop) - switch (request.head.method, request.head.uri) { - // this endpoint is called by the client invoking the lambda - case (.POST, let url) where url.hasSuffix(self.invocationEndpoint): - guard let work = request.body else { - return self.writeResponse(context: context, response: .init(status: .badRequest)) - } - let requestID = "\(DispatchTime.now().uptimeNanoseconds)" // FIXME: - let promise = context.eventLoop.makePromise(of: Response.self) - promise.futureResult.whenComplete { result in - let context = loopBoundContext.value - switch result { - case .failure(let error): - self.logger.error("invocation error: \(error)") - self.writeResponse(context: context, response: .init(status: .internalServerError)) - case .success(let response): - self.writeResponse(context: context, response: response) + // notify the caller that the server is started + continuation.resume() + logger.info( + "Server started and listening", + metadata: [ + "host": "\(channel.channel.localAddress?.ipAddress?.debugDescription ?? "")", + "port": "\(channel.channel.localAddress?.port ?? 0)", + ] + ) + + // We are handling each incoming connection in a separate child task. It is important + // to use a discarding task group here which automatically discards finished child tasks. + // A normal task group retains all child tasks and their outputs in memory until they are + // consumed by iterating the group or by exiting the group. 
Since, we are never consuming + // the results of the group we need the group to automatically discard them; otherwise, this + // would result in a memory leak over time. + try await withThrowingDiscardingTaskGroup { group in + try await channel.executeThenClose { inbound in + for try await connectionChannel in inbound { + + group.addTask { + logger.trace("Handling a new connection") + await self.handleConnection(channel: connectionChannel) + logger.trace("Done handling the connection") } } - let invocation = Invocation(requestID: requestID, request: work, responsePromise: promise) - switch Self.invocationState { - case .waitingForInvocation(let promise): - promise.succeed(invocation) - case .waitingForLambdaRequest, .waitingForLambdaResponse: - Self.invocations.append(invocation) - } + } + } + logger.info("Server shutting down") + } - // lambda invocation using the wrong http method - case (_, let url) where url.hasSuffix(self.invocationEndpoint): - self.writeResponse(context: context, status: .methodNotAllowed) - - // /next endpoint is called by the lambda polling for work - case (.GET, let url) where url.hasSuffix(Consts.getNextInvocationURLSuffix): - // check if our server is in the correct state - guard case .waitingForLambdaRequest = Self.invocationState else { - self.logger.error("invalid invocation state \(Self.invocationState)") - self.writeResponse(context: context, response: .init(status: .unprocessableEntity)) - return - } + /// This method handles individual TCP connections + private func handleConnection( + channel: NIOAsyncChannel + ) async { + + var requestHead: HTTPRequestHead! + var requestBody: ByteBuffer? + + // Note that this method is non-throwing and we are catching any error. + // We do this since we don't want to tear down the whole server when a single connection + // encounters an error. + do { + try await channel.executeThenClose { inbound, outbound in + for try await inboundData in inbound { + if case .head(let head) = inboundData { + requestHead = head + } + if case .body(let body) = inboundData { + requestBody = body + } + if case .end = inboundData { + precondition(requestHead != nil, "Received .end without .head") + // process the request + let response = try await self.processRequest( + head: requestHead, + body: requestBody + ) + // send the responses + try await self.sendResponse( + response: response, + outbound: outbound + ) - // pop the first task from the queue - switch Self.invocations.popFirst() { - case .none: - // if there is nothing in the queue, - // create a promise that we can fullfill when we get a new task - let promise = context.eventLoop.makePromise(of: Invocation.self) - promise.futureResult.whenComplete { result in - let context = loopBoundContext.value - switch result { - case .failure(let error): - self.logger.error("invocation error: \(error)") - self.writeResponse(context: context, status: .internalServerError) - case .success(let invocation): - Self.invocationState = .waitingForLambdaResponse(invocation) - self.writeResponse(context: context, response: invocation.makeResponse()) - } + requestHead = nil + requestBody = nil } - Self.invocationState = .waitingForInvocation(promise) - case .some(let invocation): - // if there is a task pending, we can immediately respond with it. 
- Self.invocationState = .waitingForLambdaResponse(invocation) - self.writeResponse(context: context, response: invocation.makeResponse()) } + } + } catch { + logger.error("Hit error: \(error)") + } + } - // :requestID/response endpoint is called by the lambda posting the response - case (.POST, let url) where url.hasSuffix(Consts.postResponseURLSuffix): - let parts = request.head.uri.split(separator: "/") - guard let requestID = parts.count > 2 ? String(parts[parts.count - 2]) : nil else { - // the request is malformed, since we were expecting a requestId in the path - return self.writeResponse(context: context, status: .badRequest) - } - guard case .waitingForLambdaResponse(let invocation) = Self.invocationState else { - // a response was send, but we did not expect to receive one - self.logger.error("invalid invocation state \(Self.invocationState)") - return self.writeResponse(context: context, status: .unprocessableEntity) - } - guard requestID == invocation.requestID else { - // the request's requestId is not matching the one we are expecting - self.logger.error( - "invalid invocation state request ID \(requestID) does not match expected \(invocation.requestID)" - ) - return self.writeResponse(context: context, status: .badRequest) - } + /// This function process the URI request sent by the client and by the Lambda function + /// + /// It enqueues the client invocation and iterate over the invocation queue when the Lambda function sends /next request + /// It answers the /:requestID/response and /:requestID/error requests sent by the Lambda function but do not process the body + /// + /// - Parameters: + /// - head: the HTTP request head + /// - body: the HTTP request body + /// - Throws: + /// - Returns: the response to send back to the client or the Lambda function + private func processRequest(head: HTTPRequestHead, body: ByteBuffer?) async throws -> LocalServerResponse { + + if let body { + self.logger.trace( + "Processing request", + metadata: ["URI": "\(head.method) \(head.uri)", "Body": "\(String(buffer: body))"] + ) + } else { + self.logger.trace("Processing request", metadata: ["URI": "\(head.method) \(head.uri)"]) + } - invocation.responsePromise.succeed(.init(status: .ok, body: request.body)) - self.writeResponse(context: context, status: .accepted) - Self.invocationState = .waitingForLambdaRequest + switch (head.method, head.uri) { - // :requestID/error endpoint is called by the lambda posting an error response - case (.POST, let url) where url.hasSuffix(Consts.postErrorURLSuffix): - let parts = request.head.uri.split(separator: "/") - guard let requestID = parts.count > 2 ? 
String(parts[parts.count - 2]) : nil else { - // the request is malformed, since we were expecting a requestId in the path - return self.writeResponse(context: context, status: .badRequest) - } - guard case .waitingForLambdaResponse(let invocation) = Self.invocationState else { - // a response was send, but we did not expect to receive one - self.logger.error("invalid invocation state \(Self.invocationState)") - return self.writeResponse(context: context, status: .unprocessableEntity) - } - guard requestID == invocation.requestID else { - // the request's requestId is not matching the one we are expecting - self.logger.error( - "invalid invocation state request ID \(requestID) does not match expected \(invocation.requestID)" + // + // client invocations + // + // client POST /invoke + case (.POST, let url) where url.hasSuffix(self.invocationEndpoint): + guard let body else { + return .init(status: .badRequest, headers: [], body: nil) + } + // we always accept the /invoke request and push them to the pool + let requestId = "\(DispatchTime.now().uptimeNanoseconds)" + logger.trace("/invoke received invocation", metadata: ["requestId": "\(requestId)"]) + await self.invocationPool.push(LocalServerInvocation(requestId: requestId, request: body)) + + for try await response in self.responsePool { + logger.trace( + "Received response to return to client", + metadata: ["requestId": "\(response.requestId ?? "")"] + ) + if response.requestId == requestId { + return response + } else { + logger.error( + "Received response for a different request id", + metadata: ["response requestId": "\(response.requestId ?? "")", "requestId": "\(requestId)"] ) - return self.writeResponse(context: context, status: .badRequest) } + } + // What todo when there is no more responses to process? + // This should not happen as the async iterator blocks until there is a response to process + fatalError("No more responses to process - the async for loop should not return") + + // client uses incorrect HTTP method + case (_, let url) where url.hasSuffix(self.invocationEndpoint): + return .init(status: .methodNotAllowed) + + // + // lambda invocations + // + + // /next endpoint is called by the lambda polling for work + // this call only returns when there is a task to give to the lambda function + case (.GET, let url) where url.hasSuffix(Consts.getNextInvocationURLSuffix): + + // pop the tasks from the queue, until there is no more to process + self.logger.trace("/next waiting for /invoke") + for try await invocation in self.invocationPool { + self.logger.trace("/next retrieved invocation", metadata: ["requestId": "\(invocation.requestId)"]) + // this stores the invocation request id into the response + return invocation.makeResponse(status: .accepted) + } + // What todo when there is no more tasks to process? + // This should not happen as the async iterator blocks until there is a task to process + fatalError("No more invocations to process - the async for loop should not return") + + // :requestID/response endpoint is called by the lambda posting the response + // we accept all requestID and we do not handle the body + case (.POST, let url) where url.hasSuffix(Consts.postResponseURLSuffix): + let parts = head.uri.split(separator: "/") + guard let requestId = parts.count > 2 ? 
String(parts[parts.count - 2]) : nil else { + // the request is malformed, since we were expecting a requestId in the path + return .init(status: .badRequest) + } + // enqueue the lambda function response to be served as response to the client /invoke + logger.trace("/:requestID/response received response", metadata: ["requestId": "\(requestId)"]) + await self.responsePool.push( + LocalServerResponse( + id: requestId, + status: .ok, + headers: [("Content-Type", "application/json")], + body: body + ) + ) - invocation.responsePromise.succeed(.init(status: .internalServerError, body: request.body)) - self.writeResponse(context: context, status: .accepted) - Self.invocationState = .waitingForLambdaRequest + // tell the Lambda function we accepted the response + return .init(id: requestId, status: .accepted) - // unknown call - default: - self.writeResponse(context: context, status: .notFound) + // :requestID/error endpoint is called by the lambda posting an error response + // we accept all requestID and we do not handle the body + case (.POST, let url) where url.hasSuffix(Consts.postErrorURLSuffix): + let parts = head.uri.split(separator: "/") + guard let _ = parts.count > 2 ? String(parts[parts.count - 2]) : nil else { + // the request is malformed, since we were expecting a requestId in the path + return .init(status: .badRequest) } - } + return .init(status: .ok) - func writeResponse(context: ChannelHandlerContext, status: HTTPResponseStatus) { - self.writeResponse(context: context, response: .init(status: status)) + // unknown call + default: + return .init(status: .notFound) } + } - func writeResponse(context: ChannelHandlerContext, response: Response) { - var headers = HTTPHeaders(response.headers ?? []) - headers.add(name: "content-length", value: "\(response.body?.readableBytes ?? 0)") - let head = HTTPResponseHead( - version: HTTPVersion(major: 1, minor: 1), - status: response.status, - headers: headers + private func sendResponse( + response: LocalServerResponse, + outbound: NIOAsyncChannelOutboundWriter + ) async throws { + var headers = HTTPHeaders(response.headers ?? []) + headers.add(name: "Content-Length", value: "\(response.body?.readableBytes ?? 0)") + + self.logger.trace("Writing response", metadata: ["requestId": "\(response.requestId ?? "")"]) + try await outbound.write( + HTTPServerResponsePart.head( + HTTPResponseHead( + version: .init(major: 1, minor: 1), // use HTTP 1.1 it keeps connection alive between requests + status: response.status, + headers: headers + ) ) + ) + if let body = response.body { + try await outbound.write(HTTPServerResponsePart.body(.byteBuffer(body))) + } + + try await outbound.write(HTTPServerResponsePart.end(nil)) + } + + /// A shared data structure to store the current invocation or response request and the continuation. + /// This data structure is shared between instances of the HTTPHandler + /// (one instance to serve requests from the Lambda function and one instance to serve requests from the client invoking the lambda function). + private final class Pool: AsyncSequence, AsyncIteratorProtocol, Sendable where T: Sendable { + typealias Element = T - context.write(wrapOutboundOut(.head(head))).whenFailure { error in - self.logger.error("\(self) write error \(error)") + private let _buffer = Mutex>(.init()) + private let _continuation = Mutex?>(nil) + + public func popFirst() async -> T? 
{ + self._buffer.withLock { $0.popFirst() } + } + + // if the iterator is waiting for an element, give it to it + // otherwise, enqueue the element + public func push(_ invocation: T) async { + if let continuation = self._continuation.withLock({ $0 }) { + self._continuation.withLock { $0 = nil } + continuation.resume(returning: invocation) + } else { + self._buffer.withLock { $0.append(invocation) } } + } - if let buffer = response.body { - context.write(wrapOutboundOut(.body(.byteBuffer(buffer)))).whenFailure { error in - self.logger.error("\(self) write error \(error)") - } + func next() async throws -> T? { + + // exit the async for loop if the task is cancelled + guard !Task.isCancelled else { + return nil } - context.writeAndFlush(wrapOutboundOut(.end(nil))).whenComplete { result in - if case .failure(let error) = result { - self.logger.error("\(self) write error \(error)") + if let element = await self.popFirst() { + return element + } else { + // we can't return nil if there is nothing to dequeue otherwise the async for loop will stop + // wait for an element to be enqueued + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + // store the continuation for later, when an element is enqueued + self._continuation.withLock { + $0 = continuation + } } } } - struct Response { - var status: HTTPResponseStatus = .ok - var headers: [(String, String)]? - var body: ByteBuffer? - } - - struct Invocation { - let requestID: String - let request: ByteBuffer - let responsePromise: EventLoopPromise - - func makeResponse() -> Response { - var response = Response() - response.body = self.request - // required headers - response.headers = [ - (AmazonHeaders.requestID, self.requestID), - ( - AmazonHeaders.invokedFunctionARN, - "arn:aws:lambda:us-east-1:\(Int16.random(in: Int16.min ... Int16.max)):function:custom-runtime" - ), - (AmazonHeaders.traceID, "Root=\(AmazonHeaders.generateXRayTraceID());Sampled=1"), - (AmazonHeaders.deadline, "\(DispatchWallTime.distantFuture.millisSinceEpoch)"), - ] - return response - } + func makeAsyncIterator() -> Pool { + self } + } - enum InvocationState { - case waitingForInvocation(EventLoopPromise) - case waitingForLambdaRequest - case waitingForLambdaResponse(Invocation) + private struct LocalServerResponse: Sendable { + let requestId: String? + let status: HTTPResponseStatus + let headers: [(String, String)]? + let body: ByteBuffer? + init(id: String? = nil, status: HTTPResponseStatus, headers: [(String, String)]? = nil, body: ByteBuffer? = nil) + { + self.requestId = id + self.status = status + self.headers = headers + self.body = body } } - enum ServerError: Error { - case notReady - case cantBind + private struct LocalServerInvocation: Sendable { + let requestId: String + let request: ByteBuffer + + func makeResponse(status: HTTPResponseStatus) -> LocalServerResponse { + + // required headers + let headers = [ + (AmazonHeaders.requestID, self.requestId), + ( + AmazonHeaders.invokedFunctionARN, + "arn:aws:lambda:us-east-1:\(Int16.random(in: Int16.min ... 
Int16.max)):function:custom-runtime" + ), + (AmazonHeaders.traceID, "Root=\(AmazonHeaders.generateXRayTraceID());Sampled=1"), + (AmazonHeaders.deadline, "\(DispatchWallTime.distantFuture.millisSinceEpoch)"), + ] + + return LocalServerResponse(id: self.requestId, status: status, headers: headers, body: self.request) + } } } #endif From 5b76105440d765d95224a3b35d80b42eba19b35c Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Wed, 15 Jan 2025 23:58:12 +0100 Subject: [PATCH 11/15] cleanup comments and fix typos --- .../Lambda+LocalServer.swift | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift b/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift index bcddd199..a3d98aa2 100644 --- a/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift +++ b/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift @@ -233,6 +233,7 @@ private struct LambdaHttpServer { logger.trace("/invoke received invocation", metadata: ["requestId": "\(requestId)"]) await self.invocationPool.push(LocalServerInvocation(requestId: requestId, request: body)) + // wait for the lambda function to process the request for try await response in self.responsePool { logger.trace( "Received response to return to client", @@ -245,6 +246,7 @@ private struct LambdaHttpServer { "Received response for a different request id", metadata: ["response requestId": "\(response.requestId ?? "")", "requestId": "\(requestId)"] ) + // should we return an error here ? Or crash as this is probably a programming error? } } // What todo when there is no more responses to process? @@ -263,11 +265,11 @@ private struct LambdaHttpServer { // this call only returns when there is a task to give to the lambda function case (.GET, let url) where url.hasSuffix(Consts.getNextInvocationURLSuffix): - // pop the tasks from the queue, until there is no more to process + // pop the tasks from the queue self.logger.trace("/next waiting for /invoke") for try await invocation in self.invocationPool { self.logger.trace("/next retrieved invocation", metadata: ["requestId": "\(invocation.requestId)"]) - // this stores the invocation request id into the response + // this call also stores the invocation requestId into the response return invocation.makeResponse(status: .accepted) } // What todo when there is no more tasks to process? @@ -275,7 +277,6 @@ private struct LambdaHttpServer { fatalError("No more invocations to process - the async for loop should not return") // :requestID/response endpoint is called by the lambda posting the response - // we accept all requestID and we do not handle the body case (.POST, let url) where url.hasSuffix(Consts.postResponseURLSuffix): let parts = head.uri.split(separator: "/") guard let requestId = parts.count > 2 ? String(parts[parts.count - 2]) : nil else { @@ -297,7 +298,7 @@ private struct LambdaHttpServer { return .init(id: requestId, status: .accepted) // :requestID/error endpoint is called by the lambda posting an error response - // we accept all requestID and we do not handle the body + // we accept all requestID and we do not handle the body, we just acknowledge the request case (.POST, let url) where url.hasSuffix(Consts.postErrorURLSuffix): let parts = head.uri.split(separator: "/") guard let _ = parts.count > 2 ? 
String(parts[parts.count - 2]) : nil else { @@ -323,7 +324,7 @@ private struct LambdaHttpServer { try await outbound.write( HTTPServerResponsePart.head( HTTPResponseHead( - version: .init(major: 1, minor: 1), // use HTTP 1.1 it keeps connection alive between requests + version: .init(major: 1, minor: 1), status: response.status, headers: headers ) @@ -336,7 +337,7 @@ private struct LambdaHttpServer { try await outbound.write(HTTPServerResponsePart.end(nil)) } - /// A shared data structure to store the current invocation or response request and the continuation. + /// A shared data structure to store the current invocation or response requests and the continuation objects. /// This data structure is shared between instances of the HTTPHandler /// (one instance to serve requests from the Lambda function and one instance to serve requests from the client invoking the lambda function). private final class Pool: AsyncSequence, AsyncIteratorProtocol, Sendable where T: Sendable { @@ -345,13 +346,15 @@ private struct LambdaHttpServer { private let _buffer = Mutex>(.init()) private let _continuation = Mutex?>(nil) + /// retrieve the first element from the buffer public func popFirst() async -> T? { self._buffer.withLock { $0.popFirst() } } - // if the iterator is waiting for an element, give it to it - // otherwise, enqueue the element + /// enqueue an element, or give it back immediately to the iterator if it is waiting for an element public func push(_ invocation: T) async { + // if the iterator is waiting for an element, give it to it + // otherwise, enqueue the element if let continuation = self._continuation.withLock({ $0 }) { self._continuation.withLock { $0 = nil } continuation.resume(returning: invocation) From 586fa819ef8f9d5f0a2ea4b17c05212106213157 Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Sat, 18 Jan 2025 12:58:44 +0100 Subject: [PATCH 12/15] WIP to remove Swift6 concurrency errors --- Package.swift | 4 ++-- Sources/AWSLambdaRuntime/Lambda+Codable.swift | 4 ++-- .../AWSLambdaRuntimeCore/Lambda+Codable.swift | 10 +++++----- .../AWSLambdaRuntimeCore/LambdaHandlers.swift | 20 +++++++++---------- .../LambdaRuntimeClient.swift | 15 ++++++++++---- 5 files changed, 30 insertions(+), 23 deletions(-) diff --git a/Package.swift b/Package.swift index d2c92fdc..6716f329 100644 --- a/Package.swift +++ b/Package.swift @@ -36,8 +36,8 @@ let package = Package( .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), .product(name: "NIOPosix", package: "swift-nio"), - ], - swiftSettings: [.swiftLanguageMode(.v5)] + ] + // swiftSettings: [.swiftLanguageMode(.v5)] ), .plugin( name: "AWSLambdaPackager", diff --git a/Sources/AWSLambdaRuntime/Lambda+Codable.swift b/Sources/AWSLambdaRuntime/Lambda+Codable.swift index fb6d2ca7..1132be5c 100644 --- a/Sources/AWSLambdaRuntime/Lambda+Codable.swift +++ b/Sources/AWSLambdaRuntime/Lambda+Codable.swift @@ -91,7 +91,7 @@ extension LambdaRuntime { public convenience init( decoder: JSONDecoder = JSONDecoder(), encoder: JSONEncoder = JSONEncoder(), - body: sending @escaping (Event, LambdaContext) async throws -> Output + body: @Sendable @escaping (Event, LambdaContext) async throws -> Output ) where Handler == LambdaCodableAdapter< @@ -116,7 +116,7 @@ extension LambdaRuntime { /// - Parameter decoder: The decoder object that will be used to decode the incoming `ByteBuffer` event into the generic `Event` type. `JSONDecoder()` used as default. 
public convenience init( decoder: JSONDecoder = JSONDecoder(), - body: sending @escaping (Event, LambdaContext) async throws -> Void + body: @Sendable @escaping (Event, LambdaContext) async throws -> Void ) where Handler == LambdaCodableAdapter< diff --git a/Sources/AWSLambdaRuntimeCore/Lambda+Codable.swift b/Sources/AWSLambdaRuntimeCore/Lambda+Codable.swift index 7a0a9a22..6e6890e1 100644 --- a/Sources/AWSLambdaRuntimeCore/Lambda+Codable.swift +++ b/Sources/AWSLambdaRuntimeCore/Lambda+Codable.swift @@ -16,7 +16,7 @@ import NIOCore /// The protocol a decoder must conform to so that it can be used with ``LambdaCodableAdapter`` to decode incoming /// `ByteBuffer` events. -public protocol LambdaEventDecoder { +public protocol LambdaEventDecoder: Sendable { /// Decode the `ByteBuffer` representing the received event into the generic `Event` type /// the handler will receive. /// - Parameters: @@ -28,7 +28,7 @@ public protocol LambdaEventDecoder { /// The protocol an encoder must conform to so that it can be used with ``LambdaCodableAdapter`` to encode the generic /// ``LambdaOutputEncoder/Output`` object into a `ByteBuffer`. -public protocol LambdaOutputEncoder { +public protocol LambdaOutputEncoder: Sendable { associatedtype Output /// Encode the generic type `Output` the handler has returned into a `ByteBuffer`. @@ -52,7 +52,7 @@ public struct LambdaHandlerAdapter< Event: Decodable, Output, Handler: LambdaHandler ->: LambdaWithBackgroundProcessingHandler where Handler.Event == Event, Handler.Output == Output { +>: Sendable, LambdaWithBackgroundProcessingHandler where Handler.Event == Event, Handler.Output == Output { @usableFromInline let handler: Handler /// Initializes an instance given a concrete handler. @@ -86,7 +86,7 @@ public struct LambdaCodableAdapter< Output, Decoder: LambdaEventDecoder, Encoder: LambdaOutputEncoder ->: StreamingLambdaHandler where Handler.Event == Event, Handler.Output == Output, Encoder.Output == Output { +>: Sendable, StreamingLambdaHandler where Handler.Event == Event, Handler.Output == Output, Encoder.Output == Output, Encoder: Sendable, Decoder: Sendable { @usableFromInline let handler: Handler @usableFromInline let encoder: Encoder @usableFromInline let decoder: Decoder @@ -139,7 +139,7 @@ public struct LambdaCodableAdapter< /// A ``LambdaResponseStreamWriter`` wrapper that conforms to ``LambdaResponseWriter``. public struct LambdaCodableResponseWriter: LambdaResponseWriter -where Output == Encoder.Output { +where Output == Encoder.Output, Encoder: Sendable { @usableFromInline let underlyingStreamWriter: Base @usableFromInline let encoder: Encoder diff --git a/Sources/AWSLambdaRuntimeCore/LambdaHandlers.swift b/Sources/AWSLambdaRuntimeCore/LambdaHandlers.swift index b76b453d..014d5f11 100644 --- a/Sources/AWSLambdaRuntimeCore/LambdaHandlers.swift +++ b/Sources/AWSLambdaRuntimeCore/LambdaHandlers.swift @@ -20,7 +20,7 @@ import NIOCore /// Background work can also be executed after returning the response. After closing the response stream by calling /// ``LambdaResponseStreamWriter/finish()`` or ``LambdaResponseStreamWriter/writeAndFinish(_:)``, /// the ``handle(_:responseWriter:context:)`` function is free to execute any background work. -public protocol StreamingLambdaHandler { +public protocol StreamingLambdaHandler: Sendable { /// The handler function -- implement the business logic of the Lambda function here. /// - Parameters: /// - event: The invocation's input data. 
@@ -45,7 +45,7 @@ public protocol StreamingLambdaHandler { /// A writer object to write the Lambda response stream into. The HTTP response is started lazily. /// before the first call to ``write(_:)`` or ``writeAndFinish(_:)``. -public protocol LambdaResponseStreamWriter { +public protocol LambdaResponseStreamWriter: Sendable { /// Write a response part into the stream. Bytes written are streamed continually. /// - Parameter buffer: The buffer to write. func write(_ buffer: ByteBuffer) async throws @@ -64,7 +64,7 @@ public protocol LambdaResponseStreamWriter { /// /// - note: This handler protocol does not support response streaming because the output has to be encoded prior to it being sent, e.g. it is not possible to encode a partial/incomplete JSON string. /// This protocol also does not support the execution of background work after the response has been returned -- the ``LambdaWithBackgroundProcessingHandler`` protocol caters for such use-cases. -public protocol LambdaHandler { +public protocol LambdaHandler: Sendable { /// Generic input type. /// The body of the request sent to Lambda will be decoded into this type for the handler to consume. associatedtype Event: Decodable @@ -86,7 +86,7 @@ public protocol LambdaHandler { /// ``LambdaResponseWriter``that is passed in as an argument, meaning that the /// ``LambdaWithBackgroundProcessingHandler/handle(_:outputWriter:context:)`` function is then /// free to implement any background work after the result has been sent to the AWS Lambda control plane. -public protocol LambdaWithBackgroundProcessingHandler { +public protocol LambdaWithBackgroundProcessingHandler: Sendable { /// Generic input type. /// The body of the request sent to Lambda will be decoded into this type for the handler to consume. associatedtype Event: Decodable @@ -110,7 +110,7 @@ public protocol LambdaWithBackgroundProcessingHandler { /// Used with ``LambdaWithBackgroundProcessingHandler``. /// A mechanism to "return" an output from ``LambdaWithBackgroundProcessingHandler/handle(_:outputWriter:context:)`` without the function needing to /// have a return type and exit at that point. This allows for background work to be executed _after_ a response has been sent to the AWS Lambda response endpoint. -public protocol LambdaResponseWriter { +public protocol LambdaResponseWriter: Sendable { associatedtype Output /// Sends the generic ``LambdaResponseWriter/Output`` object (representing the computed result of the handler) /// to the AWS Lambda response endpoint. @@ -150,17 +150,17 @@ public struct StreamingClosureHandler: StreamingLambdaHandler { /// A ``LambdaHandler`` conforming handler object that can be constructed with a closure. /// Allows for a handler to be defined in a clean manner, leveraging Swift's trailing closure syntax. public struct ClosureHandler: LambdaHandler { - let body: (Event, LambdaContext) async throws -> Output + let body: @Sendable (Event, LambdaContext) async throws -> Output /// Initialize with a closure handler over generic `Input` and `Output` types. /// - Parameter body: The handler function written as a closure. - public init(body: @escaping (Event, LambdaContext) async throws -> Output) where Output: Encodable { + public init(body: @Sendable @escaping (Event, LambdaContext) async throws -> Output) where Output: Encodable { self.body = body } /// Initialize with a closure handler over a generic `Input` type, and a `Void` `Output`. /// - Parameter body: The handler function written as a closure. 
- public init(body: @escaping (Event, LambdaContext) async throws -> Void) where Output == Void { + public init(body: @Sendable @escaping (Event, LambdaContext) async throws -> Void) where Output == Void { self.body = body } @@ -194,7 +194,7 @@ extension LambdaRuntime { >( encoder: Encoder, decoder: Decoder, - body: sending @escaping (Event, LambdaContext) async throws -> Output + body: @Sendable @escaping (Event, LambdaContext) async throws -> Output ) where Handler == LambdaCodableAdapter< @@ -220,7 +220,7 @@ extension LambdaRuntime { /// - body: The handler in the form of a closure. public convenience init( decoder: Decoder, - body: sending @escaping (Event, LambdaContext) async throws -> Void + body: @Sendable @escaping (Event, LambdaContext) async throws -> Void ) where Handler == LambdaCodableAdapter< diff --git a/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift b/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift index bbd16efa..2bb8b63a 100644 --- a/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift +++ b/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift @@ -12,6 +12,9 @@ // //===----------------------------------------------------------------------===// + +// TODO: rewrite for Swift 6 concurrency + import Logging import NIOCore import NIOHTTP1 @@ -140,6 +143,7 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol { } } + // FIXME: add support for graceful shutdown func nextInvocation() async throws -> (Invocation, Writer) { switch self.lambdaState { case .idle: @@ -336,12 +340,15 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol { case .disconnected, .connected: fatalError("Unexpected state: \(self.connectionState)") - case .connecting(let array): + // case .connecting(let array): + case .connecting: self.connectionState = .connected(channel, handler) defer { - for continuation in array { - continuation.resume(returning: handler) - } + // for continuation in array { + // // This causes an error in Swift 6 + // // 'self'-isolated 'handler' is passed as a 'sending' parameter; Uses in callee may race with later 'self'-isolated uses + // continuation.resume(returning: handler) + // } } return handler } From 132f9bd60ffb2bd2b81835b18c04ea5833696c22 Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Sat, 18 Jan 2025 13:03:56 +0100 Subject: [PATCH 13/15] apply swift-format --- Sources/AWSLambdaRuntimeCore/Lambda+Codable.swift | 3 ++- Sources/AWSLambdaRuntimeCore/LambdaHandlers.swift | 2 +- Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift | 5 ++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Sources/AWSLambdaRuntimeCore/Lambda+Codable.swift b/Sources/AWSLambdaRuntimeCore/Lambda+Codable.swift index 6e6890e1..1223052d 100644 --- a/Sources/AWSLambdaRuntimeCore/Lambda+Codable.swift +++ b/Sources/AWSLambdaRuntimeCore/Lambda+Codable.swift @@ -86,7 +86,8 @@ public struct LambdaCodableAdapter< Output, Decoder: LambdaEventDecoder, Encoder: LambdaOutputEncoder ->: Sendable, StreamingLambdaHandler where Handler.Event == Event, Handler.Output == Output, Encoder.Output == Output, Encoder: Sendable, Decoder: Sendable { +>: Sendable, StreamingLambdaHandler +where Handler.Event == Event, Handler.Output == Output, Encoder.Output == Output, Encoder: Sendable, Decoder: Sendable { @usableFromInline let handler: Handler @usableFromInline let encoder: Encoder @usableFromInline let decoder: Decoder diff --git a/Sources/AWSLambdaRuntimeCore/LambdaHandlers.swift b/Sources/AWSLambdaRuntimeCore/LambdaHandlers.swift index 014d5f11..94522e5a 
100644 --- a/Sources/AWSLambdaRuntimeCore/LambdaHandlers.swift +++ b/Sources/AWSLambdaRuntimeCore/LambdaHandlers.swift @@ -160,7 +160,7 @@ public struct ClosureHandler: LambdaHandler { /// Initialize with a closure handler over a generic `Input` type, and a `Void` `Output`. /// - Parameter body: The handler function written as a closure. - public init(body: @Sendable @escaping (Event, LambdaContext) async throws -> Void) where Output == Void { + public init(body: @Sendable @escaping (Event, LambdaContext) async throws -> Void) where Output == Void { self.body = body } diff --git a/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift b/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift index 2bb8b63a..823956a9 100644 --- a/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift +++ b/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift @@ -12,7 +12,6 @@ // //===----------------------------------------------------------------------===// - // TODO: rewrite for Swift 6 concurrency import Logging @@ -143,7 +142,7 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol { } } - // FIXME: add support for graceful shutdown + // FIXME: add support for graceful shutdown func nextInvocation() async throws -> (Invocation, Writer) { switch self.lambdaState { case .idle: @@ -345,7 +344,7 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol { self.connectionState = .connected(channel, handler) defer { // for continuation in array { - // // This causes an error in Swift 6 + // // This causes an error in Swift 6 // // 'self'-isolated 'handler' is passed as a 'sending' parameter; Uses in callee may race with later 'self'-isolated uses // continuation.resume(returning: handler) // } From 42881edda5b7abd195b7113716ff44046c783dc4 Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Sat, 18 Jan 2025 13:15:30 +0100 Subject: [PATCH 14/15] Fix example for the API change (Sendable) --- .../Sources/AuthorizerLambda/main.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Examples/APIGateway+LambdaAuthorizer/Sources/AuthorizerLambda/main.swift b/Examples/APIGateway+LambdaAuthorizer/Sources/AuthorizerLambda/main.swift index 60ea2b7b..3a0ff6f4 100644 --- a/Examples/APIGateway+LambdaAuthorizer/Sources/AuthorizerLambda/main.swift +++ b/Examples/APIGateway+LambdaAuthorizer/Sources/AuthorizerLambda/main.swift @@ -22,7 +22,7 @@ import AWSLambdaRuntime // This code is shown for the example only and is not used in this demo. // This code doesn't perform any type of token validation. It should be used as a reference only. let policyAuthorizerHandler: - (APIGatewayLambdaAuthorizerRequest, LambdaContext) async throws -> APIGatewayLambdaAuthorizerPolicyResponse = { + @Sendable (APIGatewayLambdaAuthorizerRequest, LambdaContext) async throws -> APIGatewayLambdaAuthorizerPolicyResponse = { (request: APIGatewayLambdaAuthorizerRequest, context: LambdaContext) in context.logger.debug("+++ Policy Authorizer called +++") @@ -57,7 +57,7 @@ let policyAuthorizerHandler: // // This code doesn't perform any type of token validation. It should be used as a reference only. 
let simpleAuthorizerHandler: - (APIGatewayLambdaAuthorizerRequest, LambdaContext) async throws -> APIGatewayLambdaAuthorizerSimpleResponse = { + @Sendable (APIGatewayLambdaAuthorizerRequest, LambdaContext) async throws -> APIGatewayLambdaAuthorizerSimpleResponse = { (_: APIGatewayLambdaAuthorizerRequest, context: LambdaContext) in context.logger.debug("+++ Simple Authorizer called +++") From 84b346a060305e0aa84f95a34d630857e98ac257 Mon Sep 17 00:00:00 2001 From: Sebastien Stormacq Date: Sat, 18 Jan 2025 17:03:35 +0100 Subject: [PATCH 15/15] swift format --- .../Sources/AuthorizerLambda/main.swift | 68 ++++++++++--------- 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/Examples/APIGateway+LambdaAuthorizer/Sources/AuthorizerLambda/main.swift b/Examples/APIGateway+LambdaAuthorizer/Sources/AuthorizerLambda/main.swift index 3a0ff6f4..87b7d2bd 100644 --- a/Examples/APIGateway+LambdaAuthorizer/Sources/AuthorizerLambda/main.swift +++ b/Examples/APIGateway+LambdaAuthorizer/Sources/AuthorizerLambda/main.swift @@ -22,34 +22,35 @@ import AWSLambdaRuntime // This code is shown for the example only and is not used in this demo. // This code doesn't perform any type of token validation. It should be used as a reference only. let policyAuthorizerHandler: - @Sendable (APIGatewayLambdaAuthorizerRequest, LambdaContext) async throws -> APIGatewayLambdaAuthorizerPolicyResponse = { - (request: APIGatewayLambdaAuthorizerRequest, context: LambdaContext) in + @Sendable (APIGatewayLambdaAuthorizerRequest, LambdaContext) async throws -> + APIGatewayLambdaAuthorizerPolicyResponse = { + (request: APIGatewayLambdaAuthorizerRequest, context: LambdaContext) in - context.logger.debug("+++ Policy Authorizer called +++") + context.logger.debug("+++ Policy Authorizer called +++") - // typically, this function will check the validity of the incoming token received in the request + // typically, this function will check the validity of the incoming token received in the request - // then it creates and returns a response - return APIGatewayLambdaAuthorizerPolicyResponse( - principalId: "John Appleseed", + // then it creates and returns a response + return APIGatewayLambdaAuthorizerPolicyResponse( + principalId: "John Appleseed", - // this policy allows the caller to invoke any API Gateway endpoint - policyDocument: .init(statement: [ - .init( - action: "execute-api:Invoke", - effect: .allow, - resource: "*" - ) + // this policy allows the caller to invoke any API Gateway endpoint + policyDocument: .init(statement: [ + .init( + action: "execute-api:Invoke", + effect: .allow, + resource: "*" + ) - ]), + ]), - // this is additional context we want to return to the caller - context: [ - "abc1": "xyz1", - "abc2": "xyz2", - ] - ) - } + // this is additional context we want to return to the caller + context: [ + "abc1": "xyz1", + "abc2": "xyz2", + ] + ) + } // // This is an example of a simple authorizer that always authorizes the request. @@ -57,21 +58,22 @@ let policyAuthorizerHandler: // // This code doesn't perform any type of token validation. It should be used as a reference only. 
let simpleAuthorizerHandler: - @Sendable (APIGatewayLambdaAuthorizerRequest, LambdaContext) async throws -> APIGatewayLambdaAuthorizerSimpleResponse = { - (_: APIGatewayLambdaAuthorizerRequest, context: LambdaContext) in + @Sendable (APIGatewayLambdaAuthorizerRequest, LambdaContext) async throws -> + APIGatewayLambdaAuthorizerSimpleResponse = { + (_: APIGatewayLambdaAuthorizerRequest, context: LambdaContext) in - context.logger.debug("+++ Simple Authorizer called +++") + context.logger.debug("+++ Simple Authorizer called +++") - // typically, this function will check the validity of the incoming token received in the request + // typically, this function will check the validity of the incoming token received in the request - return APIGatewayLambdaAuthorizerSimpleResponse( - // this is the authorization decision: yes or no - isAuthorized: true, + return APIGatewayLambdaAuthorizerSimpleResponse( + // this is the authorization decision: yes or no + isAuthorized: true, - // this is additional context we want to return to the caller - context: ["abc1": "xyz1"] - ) - } + // this is additional context we want to return to the caller + context: ["abc1": "xyz1"] + ) + } // create the runtime and start polling for new events. // in this demo we use the simple authorizer handler
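// For context, a minimal sketch of the wiring those comments describe, assuming the
// closure-based LambdaRuntime initializer shown earlier in this patch series and its
// run() entry point; the example's actual main.swift may differ slightly:
let runtime = LambdaRuntime(body: simpleAuthorizerHandler)
try await runtime.run()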