Skip to content

[test] Update mock server for Swift 6 compliance #463

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jan 22, 2025
6 changes: 3 additions & 3 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ let package = Package(
.library(name: "AWSLambdaTesting", targets: ["AWSLambdaTesting"]),
],
dependencies: [
.package(url: "https://github.com/apple/swift-nio.git", from: "2.76.0"),
.package(url: "https://github.com/apple/swift-nio.git", from: "2.77.0"),
.package(url: "https://github.com/apple/swift-log.git", from: "1.5.4"),
],
targets: [
Expand Down Expand Up @@ -89,11 +89,11 @@ let package = Package(
.executableTarget(
name: "MockServer",
dependencies: [
.product(name: "Logging", package: "swift-log"),
.product(name: "NIOHTTP1", package: "swift-nio"),
.product(name: "NIOCore", package: "swift-nio"),
.product(name: "NIOPosix", package: "swift-nio"),
],
swiftSettings: [.swiftLanguageMode(.v5)]
]
),
]
)
298 changes: 298 additions & 0 deletions Sources/MockServer/MockHTTPServer.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,298 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftAWSLambdaRuntime open source project
//
// Copyright (c) 2017-2025 Apple Inc. and the SwiftAWSLambdaRuntime project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import Logging
import NIOCore
import NIOHTTP1
import NIOPosix
import Synchronization

// for UUID and Date
#if canImport(FoundationEssentials)
import FoundationEssentials
#else
import Foundation
#endif

@main
struct HttpServer {
/// The server's host. (default: 127.0.0.1)
private let host: String
/// The server's port. (default: 7000)
private let port: Int
/// The server's event loop group. (default: MultiThreadedEventLoopGroup.singleton)
private let eventLoopGroup: MultiThreadedEventLoopGroup
/// the mode. Are we mocking a server for a Lambda function that expects a String or a JSON document? (default: string)
private let mode: Mode
/// the number of connections this server must accept before shutting down (default: 1)
private let maxInvocations: Int
/// the logger (control verbosity with LOG_LEVEL environment variable)
private let logger: Logger

static func main() async throws {
var log = Logger(label: "MockServer")
log.logLevel = env("LOG_LEVEL").flatMap(Logger.Level.init) ?? .info

let server = HttpServer(
host: env("HOST") ?? "127.0.0.1",
port: env("PORT").flatMap(Int.init) ?? 7000,
eventLoopGroup: .singleton,
mode: env("MODE").flatMap(Mode.init) ?? .string,
maxInvocations: env("MAX_INVOCATIONS").flatMap(Int.init) ?? 1,
logger: log
)
try await server.run()
}

/// This method starts the server and handles one unique incoming connections
/// The Lambda function will send two HTTP requests over this connection: one for the next invocation and one for the response.
private func run() async throws {
let channel = try await ServerBootstrap(group: self.eventLoopGroup)
.serverChannelOption(.backlog, value: 256)
.serverChannelOption(.socketOption(.so_reuseaddr), value: 1)
.childChannelOption(.maxMessagesPerRead, value: 1)
.bind(
host: self.host,
port: self.port
) { channel in
channel.eventLoop.makeCompletedFuture {

try channel.pipeline.syncOperations.configureHTTPServerPipeline(
withErrorHandling: true
)

return try NIOAsyncChannel(
wrappingChannelSynchronously: channel,
configuration: NIOAsyncChannel.Configuration(
inboundType: HTTPServerRequestPart.self,
outboundType: HTTPServerResponsePart.self
)
)
}
}

logger.info(
"Server started and listening",
metadata: [
"host": "\(channel.channel.localAddress?.ipAddress?.debugDescription ?? "")",
"port": "\(channel.channel.localAddress?.port ?? 0)",
"maxInvocations": "\(self.maxInvocations)",
]
)

// This counter is used to track the number of incoming connections.
// This mock servers accepts n TCP connection then shutdowns
let connectionCounter = SharedCounter(maxValue: self.maxInvocations)

// We are handling each incoming connection in a separate child task. It is important
// to use a discarding task group here which automatically discards finished child tasks.
// A normal task group retains all child tasks and their outputs in memory until they are
// consumed by iterating the group or by exiting the group. Since, we are never consuming
// the results of the group we need the group to automatically discard them; otherwise, this
// would result in a memory leak over time.
try await withThrowingDiscardingTaskGroup { group in
try await channel.executeThenClose { inbound in
for try await connectionChannel in inbound {

let counter = connectionCounter.current()
logger.trace("Handling new connection", metadata: ["connectionNumber": "\(counter)"])

group.addTask {
await self.handleConnection(channel: connectionChannel)
logger.trace("Done handling connection", metadata: ["connectionNumber": "\(counter)"])
}

if connectionCounter.increment() {
logger.info(
"Maximum number of connections reached, shutting down after current connection",
metadata: ["maxConnections": "\(self.maxInvocations)"]
)
break // this causes the server to shutdown after handling the connection
}
}
}
}
logger.info("Server shutting down")
}

/// This method handles a single connection by responsing hard coded value to a Lambda function request.
/// It handles two requests: one for the next invocation and one for the response.
/// when the maximum number of requests is reached, it closes the connection.
private func handleConnection(
channel: NIOAsyncChannel<HTTPServerRequestPart, HTTPServerResponsePart>
) async {

var requestHead: HTTPRequestHead!
var requestBody: ByteBuffer?

// each Lambda invocation results in TWO HTTP requests (next and response)
let requestCount = SharedCounter(maxValue: 2)

// Note that this method is non-throwing and we are catching any error.
// We do this since we don't want to tear down the whole server when a single connection
// encounters an error.
do {
try await channel.executeThenClose { inbound, outbound in
for try await inboundData in inbound {
let requestNumber = requestCount.current()
logger.trace("Handling request", metadata: ["requestNumber": "\(requestNumber)"])

if case .head(let head) = inboundData {
logger.trace("Received request head", metadata: ["head": "\(head)"])
requestHead = head
}
if case .body(let body) = inboundData {
logger.trace("Received request body", metadata: ["body": "\(body)"])
requestBody = body
}
if case .end(let end) = inboundData {
logger.trace("Received request end", metadata: ["end": "\(String(describing: end))"])

precondition(requestHead != nil, "Received .end without .head")
let (responseStatus, responseHeaders, responseBody) = self.processRequest(
requestHead: requestHead,
requestBody: requestBody
)

try await self.sendResponse(
responseStatus: responseStatus,
responseHeaders: responseHeaders,
responseBody: responseBody,
outbound: outbound
)

requestHead = nil

if requestCount.increment() {
logger.info(
"Maximum number of requests reached, closing this connection",
metadata: ["maxRequest": "2"]
)
break // this finishes handiling request on this connection
}
}
}
}
} catch {
logger.error("Hit error: \(error)")
}
}
/// This function process the requests and return an hard-coded response (string or JSON depending on the mode).
/// We ignore the requestBody.
private func processRequest(
requestHead: HTTPRequestHead,
requestBody: ByteBuffer?
) -> (HTTPResponseStatus, [(String, String)], String) {
var responseStatus: HTTPResponseStatus = .ok
var responseBody: String = ""
var responseHeaders: [(String, String)] = []

logger.trace(
"Processing request",
metadata: ["VERB": "\(requestHead.method)", "URI": "\(requestHead.uri)"]
)

if requestHead.uri.hasSuffix("/next") {
responseStatus = .accepted

let requestId = UUID().uuidString
switch self.mode {
case .string:
responseBody = "\"Seb\"" // must be a valid JSON document
case .json:
responseBody = "{ \"name\": \"Seb\", \"age\" : 52 }"
}
let deadline = Int64(Date(timeIntervalSinceNow: 60).timeIntervalSince1970 * 1000)
responseHeaders = [
(AmazonHeaders.requestID, requestId),
(AmazonHeaders.invokedFunctionARN, "arn:aws:lambda:us-east-1:123456789012:function:custom-runtime"),
(AmazonHeaders.traceID, "Root=1-5bef4de7-ad49b0e87f6ef6c87fc2e700;Parent=9a9197af755a6419;Sampled=1"),
(AmazonHeaders.deadline, String(deadline)),
]
} else if requestHead.uri.hasSuffix("/response") {
responseStatus = .accepted
} else if requestHead.uri.hasSuffix("/error") {
responseStatus = .ok
} else {
responseStatus = .notFound
}
logger.trace("Returning response: \(responseStatus), \(responseHeaders), \(responseBody)")
return (responseStatus, responseHeaders, responseBody)
}

private func sendResponse(
responseStatus: HTTPResponseStatus,
responseHeaders: [(String, String)],
responseBody: String,
outbound: NIOAsyncChannelOutboundWriter<HTTPServerResponsePart>
) async throws {
var headers = HTTPHeaders(responseHeaders)
headers.add(name: "Content-Length", value: "\(responseBody.utf8.count)")
headers.add(name: "KeepAlive", value: "timeout=1, max=2")

logger.trace("Writing response head")
try await outbound.write(
HTTPServerResponsePart.head(
HTTPResponseHead(
version: .init(major: 1, minor: 1), // use HTTP 1.1 it keeps connection alive between requests
status: responseStatus,
headers: headers
)
)
)
logger.trace("Writing response body")
try await outbound.write(HTTPServerResponsePart.body(.byteBuffer(ByteBuffer(string: responseBody))))
logger.trace("Writing response end")
try await outbound.write(HTTPServerResponsePart.end(nil))
}

private enum Mode: String {
case string
case json
}

private static func env(_ name: String) -> String? {
guard let value = getenv(name) else {
return nil
}
return String(cString: value)
}

private enum AmazonHeaders {
static let requestID = "Lambda-Runtime-Aws-Request-Id"
static let traceID = "Lambda-Runtime-Trace-Id"
static let clientContext = "X-Amz-Client-Context"
static let cognitoIdentity = "X-Amz-Cognito-Identity"
static let deadline = "Lambda-Runtime-Deadline-Ms"
static let invokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn"
}

private final class SharedCounter: Sendable {
private let counterMutex = Mutex<Int>(0)
private let maxValue: Int

init(maxValue: Int) {
self.maxValue = maxValue
}
func current() -> Int {
counterMutex.withLock { $0 }
}
func increment() -> Bool {
counterMutex.withLock {
$0 += 1
return $0 >= maxValue
}
}
}
}
Loading
Loading