diff --git a/.changes/cc154f8b-62ba-4ab5-8da5-84d1b5fe0d9f.json b/.changes/cc154f8b-62ba-4ab5-8da5-84d1b5fe0d9f.json new file mode 100644 index 00000000000..1f6faf97e9a --- /dev/null +++ b/.changes/cc154f8b-62ba-4ab5-8da5-84d1b5fe0d9f.json @@ -0,0 +1,8 @@ +{ + "id": "cc154f8b-62ba-4ab5-8da5-84d1b5fe0d9f", + "type": "feature", + "description": "Added MD5 checksum validation for SQS message operations", + "issues": [ + "awslabs/aws-sdk-kotlin#222" + ] +} \ No newline at end of file diff --git a/aws-runtime/aws-config/common/src/aws/sdk/kotlin/runtime/config/profile/AwsProfile.kt b/aws-runtime/aws-config/common/src/aws/sdk/kotlin/runtime/config/profile/AwsProfile.kt index 1f69799c31e..439b2bab712 100644 --- a/aws-runtime/aws-config/common/src/aws/sdk/kotlin/runtime/config/profile/AwsProfile.kt +++ b/aws-runtime/aws-config/common/src/aws/sdk/kotlin/runtime/config/profile/AwsProfile.kt @@ -212,15 +212,29 @@ public inline fun > AwsProfile.getEnumOrNull(key: String, su it.name.equals(value, ignoreCase = true) } ?: throw ConfigurationException( buildString { - append(key) - append(" '") - append(value) - append("' is not supported, should be one of: ") + append("$key '$value' is not supported, should be one of: ") enumValues().joinTo(this) { it.name.lowercase() } }, ) } +/** + * Parse a config value as an enum set. + */ +@InternalSdkApi +public inline fun > AwsProfile.getEnumSetOrNull(key: String, subKey: String? = null): Set? = + getOrNull(key, subKey)?.split(",")?.map { it -> + val value = it.trim() + enumValues().firstOrNull { enumValue -> + enumValue.name.equals(value, ignoreCase = true) + } ?: throw ConfigurationException( + buildString { + append("$key '$value' is not supported, should be one of: ") + enumValues().joinTo(this) { it.name.lowercase() } + }, + ) + }?.toSet() + internal fun AwsProfile.getUrlOrNull(key: String, subKey: String? = null): Url? = getOrNull(key, subKey)?.let { try { diff --git a/codegen/aws-sdk-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/customization/sqs/SqsMd5ChecksumValidationIntegration.kt b/codegen/aws-sdk-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/customization/sqs/SqsMd5ChecksumValidationIntegration.kt new file mode 100644 index 00000000000..9c74f6cb81e --- /dev/null +++ b/codegen/aws-sdk-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/customization/sqs/SqsMd5ChecksumValidationIntegration.kt @@ -0,0 +1,150 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.sdk.kotlin.codegen.customization.sqs + +import aws.sdk.kotlin.codegen.ServiceClientCompanionObjectWriter +import aws.sdk.kotlin.codegen.sdkId +import software.amazon.smithy.kotlin.codegen.KotlinSettings +import software.amazon.smithy.kotlin.codegen.core.CodegenContext +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.kotlin.codegen.integration.AppendingSectionWriter +import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration +import software.amazon.smithy.kotlin.codegen.integration.SectionWriterBinding +import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes +import software.amazon.smithy.kotlin.codegen.model.buildSymbol +import software.amazon.smithy.kotlin.codegen.model.expectShape +import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator +import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolMiddleware +import software.amazon.smithy.kotlin.codegen.rendering.util.ConfigProperty +import software.amazon.smithy.kotlin.codegen.rendering.util.ConfigPropertyType +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ServiceShape + +/** + * Register interceptor to handle SQS message MD5 checksum validation. + */ +class SqsMd5ChecksumValidationIntegration : KotlinIntegration { + override fun enabledForService(model: Model, settings: KotlinSettings): Boolean = + model.expectShape(settings.service).sdkId.lowercase() == "sqs" + + companion object { + val ValidationEnabledProp: ConfigProperty = ConfigProperty { + name = "checksumValidationEnabled" + symbol = buildSymbol { + name = "ValidationEnabled" + namespace = "aws.sdk.kotlin.services.sqs.internal" + nullable = false + } + propertyType = ConfigPropertyType.Custom( + render = { prop, writer -> + writer.write("public val #1L: #2T = builder.#1L ?: #2T.NEVER", prop.propertyName, prop.symbol) + }, + renderBuilder = { prop, writer -> + prop.documentation?.let(writer::dokka) + writer.write("public var #L: #T? = null", prop.propertyName, prop.symbol) + writer.write("") + }, + ) + documentation = """ + Specifies when MD5 checksum validation should be performed for SQS messages. This controls the automatic + calculation and validation of checksums during message operations. + + Valid values: + - `ALWAYS` - Checksums are calculated and validated for both sending and receiving operations + (SendMessage, SendMessageBatch, and ReceiveMessage) + - `WHEN_SENDING` - Checksums are only calculated and validated during send operations + (SendMessage and SendMessageBatch) + - `WHEN_RECEIVING` - Checksums are only calculated and validated during receive operations + (ReceiveMessage) + - `NEVER` (default) - No checksum calculation or validation is performed + """.trimIndent() + // TODO: MD5 checksum validation is temporarily disabled. Change default to ALWAYS in v1.5 + } + + private val validationScope = buildSymbol { + name = "ValidationScope" + namespace = "aws.sdk.kotlin.services.sqs.internal" + nullable = false + } + + val ValidationScopeProp: ConfigProperty = ConfigProperty { + name = "checksumValidationScopes" + symbol = KotlinTypes.Collections.set(validationScope) + propertyType = ConfigPropertyType.Custom( + render = { prop, writer -> + writer.write("public val #1L: #2T = builder.#1L ?: #3T.entries.toSet()", prop.propertyName, prop.symbol, validationScope) + }, + renderBuilder = { prop, writer -> + prop.documentation?.let(writer::dokka) + writer.write("public var #L: #T? = null", prop.propertyName, prop.symbol) + writer.write("") + }, + ) + documentation = """ + Specifies which parts of an SQS message should undergo MD5 checksum validation. This configuration + accepts a set of validation scopes that determine which message components to validate. + + Valid values: + - `MESSAGE_ATTRIBUTES` - Validates checksums for message attributes + - `MESSAGE_SYSTEM_ATTRIBUTES` - Validates checksums for message system attributes + (Note: Not available for ReceiveMessage operations as SQS does not calculate checksums for + system attributes during message receipt) + - `MESSAGE_BODY` - Validates checksums for the message body + + Default: All three scopes (`MESSAGE_ATTRIBUTES`, `MESSAGE_SYSTEM_ATTRIBUTES`, `MESSAGE_BODY`) + """.trimIndent() + } + } + + override fun additionalServiceConfigProps(ctx: CodegenContext): List = + listOf( + ValidationEnabledProp, + ValidationScopeProp, + ) + + override val sectionWriters: List + get() = listOf( + SectionWriterBinding( + ServiceClientCompanionObjectWriter.FinalizeEnvironmentalConfig, + finalizeSqsConfigWriter, + ), + ) + + // add SQS-specific config finalization + private val finalizeSqsConfigWriter = AppendingSectionWriter { writer -> + val finalizeSqsConfig = buildSymbol { + name = "finalizeSqsConfig" + namespace = "aws.sdk.kotlin.services.sqs.internal" + } + writer.write("#T(builder, sharedConfig)", finalizeSqsConfig) + } + + override fun customizeMiddleware( + ctx: ProtocolGenerator.GenerationContext, + resolved: List, + ): List = resolved + listOf(SqsMd5ChecksumValidationMiddleware) +} + +internal object SqsMd5ChecksumValidationMiddleware : ProtocolMiddleware { + override fun isEnabledFor(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Boolean = when (op.id.name) { + "ReceiveMessage", + "SendMessage", + "SendMessageBatch", + -> true + else -> false + } + + override val name: String = "SqsMd5ChecksumValidationInterceptor" + + override fun render(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) { + val symbol = buildSymbol { + name = this@SqsMd5ChecksumValidationMiddleware.name + namespace = "aws.sdk.kotlin.services.sqs" + } + + writer.write("op.interceptors.add(#T(config.checksumValidationEnabled, config.checksumValidationScopes))", symbol) + } +} diff --git a/codegen/aws-sdk-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration b/codegen/aws-sdk-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration index 7786616c0a6..369fe2ccfc1 100644 --- a/codegen/aws-sdk-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration +++ b/codegen/aws-sdk-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration @@ -49,3 +49,4 @@ aws.sdk.kotlin.codegen.smoketests.SmokeTestsDenyListIntegration aws.sdk.kotlin.codegen.smoketests.testing.SmokeTestSuccessHttpEngineIntegration aws.sdk.kotlin.codegen.smoketests.testing.SmokeTestFailHttpEngineIntegration aws.sdk.kotlin.codegen.customization.AwsQueryModeCustomization +aws.sdk.kotlin.codegen.customization.sqs.SqsMd5ChecksumValidationIntegration \ No newline at end of file diff --git a/codegen/aws-sdk-codegen/src/test/kotlin/aws/sdk/kotlin/codegen/customization/sqs/SqsMd5ChecksumValidationIntegrationTest.kt b/codegen/aws-sdk-codegen/src/test/kotlin/aws/sdk/kotlin/codegen/customization/sqs/SqsMd5ChecksumValidationIntegrationTest.kt new file mode 100644 index 00000000000..0c0e32eced5 --- /dev/null +++ b/codegen/aws-sdk-codegen/src/test/kotlin/aws/sdk/kotlin/codegen/customization/sqs/SqsMd5ChecksumValidationIntegrationTest.kt @@ -0,0 +1,52 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.sdk.kotlin.codegen.customization.sqs + +import aws.sdk.kotlin.codegen.testutil.model +import org.junit.jupiter.api.Test +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator +import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolMiddleware +import software.amazon.smithy.kotlin.codegen.test.defaultSettings +import software.amazon.smithy.kotlin.codegen.test.newTestContext +import software.amazon.smithy.model.shapes.OperationShape +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue +import kotlin.test.fail + +class SqsMd5ChecksumValidationIntegrationTest { + object FooMiddleware : ProtocolMiddleware { + override val name: String = "FooMiddleware" + override fun render(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) = + fail("Unexpected call to `FooMiddleware.render`") + } + + @Test + fun testNotExpectedForNonSqsModel() { + val model = model("NotSqs") + val actual = SqsMd5ChecksumValidationIntegration().enabledForService(model, model.defaultSettings()) + + assertFalse(actual) + } + + @Test + fun testExpectedForSqsModel() { + val model = model("Sqs") + val actual = SqsMd5ChecksumValidationIntegration().enabledForService(model, model.defaultSettings()) + + assertTrue(actual) + } + + @Test + fun testMiddlewareAddition() { + val model = model("Sqs") + val preexistingMiddleware = listOf(FooMiddleware) + val ctx = model.newTestContext("Sqs") + val actual = SqsMd5ChecksumValidationIntegration().customizeMiddleware(ctx.generationCtx, preexistingMiddleware) + + assertEquals(listOf(FooMiddleware, SqsMd5ChecksumValidationMiddleware), actual) + } +} diff --git a/services/sqs/common/src/aws.sdk.kotlin.services.sqs/SqsMd5ChecksumValidationInterceptor.kt b/services/sqs/common/src/aws.sdk.kotlin.services.sqs/SqsMd5ChecksumValidationInterceptor.kt new file mode 100644 index 00000000000..4adbe3aa5c2 --- /dev/null +++ b/services/sqs/common/src/aws.sdk.kotlin.services.sqs/SqsMd5ChecksumValidationInterceptor.kt @@ -0,0 +1,342 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.sdk.kotlin.services.sqs + +import aws.sdk.kotlin.runtime.ClientException +import aws.sdk.kotlin.services.sqs.internal.ValidationEnabled +import aws.sdk.kotlin.services.sqs.internal.ValidationScope +import aws.sdk.kotlin.services.sqs.model.* +import aws.smithy.kotlin.runtime.client.ResponseInterceptorContext +import aws.smithy.kotlin.runtime.collections.AttributeKey +import aws.smithy.kotlin.runtime.hashing.md5 +import aws.smithy.kotlin.runtime.http.interceptors.ChecksumMismatchException +import aws.smithy.kotlin.runtime.http.interceptors.HttpInterceptor +import aws.smithy.kotlin.runtime.http.request.HttpRequest +import aws.smithy.kotlin.runtime.http.response.HttpResponse +import aws.smithy.kotlin.runtime.io.SdkBuffer +import aws.smithy.kotlin.runtime.telemetry.logging.Logger +import aws.smithy.kotlin.runtime.telemetry.logging.error +import aws.smithy.kotlin.runtime.telemetry.logging.logger +import aws.smithy.kotlin.runtime.util.asyncLazy +import kotlinx.coroutines.runBlocking +import kotlin.coroutines.coroutineContext + +private const val STRING_TYPE_FIELD_INDEX: Byte = 1 +private const val BINARY_TYPE_FIELD_INDEX: Byte = 2 +private const val STRING_LIST_TYPE_FIELD_INDEX: Byte = 3 +private const val BINARY_LIST_TYPE_FIELD_INDEX: Byte = 4 + +/** + * Interceptor that validates MD5 checksums for SQS message operations. + * + * This interceptor performs client-side validation of MD5 checksums returned by SQS to ensure + * message integrity during transmission. It validates the following components: + * - Message body + * - Message attributes + * - Message system attributes + * + * The validation behavior can be configured using: + * - [checksumValidationEnabled] - Controls when validation occurs (`ALWAYS`, `WHEN_SENDING`, `WHEN_RECEIVING`, `NEVER`) + * - [checksumValidationScopes] - Specifies which message components to validate + * + * Supported operations: + * - SendMessage + * - SendMessageBatch + * - ReceiveMessage + */ +@OptIn(ExperimentalStdlibApi::class) +public class SqsMd5ChecksumValidationInterceptor( + private val validationEnabled: ValidationEnabled, + private val validationScopes: Set, +) : HttpInterceptor { + public companion object { + private val isMd5Available = asyncLazy { + try { + "MD5".encodeToByteArray().md5() + true + } catch (e: Exception) { + coroutineContext.error(e) { + "MD5 checksums are not available (likely due to FIPS mode). Checksum validation will be disabled." + } + false + } + } + } + + override fun readAfterExecution(context: ResponseInterceptorContext) { + if (validationEnabled == ValidationEnabled.NEVER || runBlocking { !isMd5Available.get() }) return + + val logger = context.executionContext.coroutineContext.logger() + + val request = context.request + + context.response.getOrNull()?.let { response -> + when (request) { + is SendMessageRequest -> { + if (validationEnabled == ValidationEnabled.WHEN_RECEIVING) return + + val sendMessageResponse = response as SendMessageResponse + sendMessageOperationMd5Check(request, sendMessageResponse, logger) + } + + is ReceiveMessageRequest -> { + if (validationEnabled == ValidationEnabled.WHEN_SENDING) return + + val receiveMessageResponse = response as ReceiveMessageResponse + receiveMessageResultMd5Check(receiveMessageResponse, logger) + } + + is SendMessageBatchRequest -> { + if (validationEnabled == ValidationEnabled.WHEN_RECEIVING) return + + val sendMessageBatchResponse = response as SendMessageBatchResponse + sendMessageBatchOperationMd5Check(request, sendMessageBatchResponse, logger) + } + } + } + + // Sets validation flag in execution context for e2e test assertions + val checksumValidated: AttributeKey = AttributeKey("checksumValidated") + context.executionContext[checksumValidated] = true + } + + private fun sendMessageOperationMd5Check( + sendMessageRequest: SendMessageRequest, + sendMessageResponse: SendMessageResponse, + logger: Logger, + ) { + if (validationScopes.contains(ValidationScope.MESSAGE_BODY)) { + val messageBodyMd5Returned = sendMessageResponse.md5OfMessageBody + val messageBodySent = sendMessageRequest.messageBody + + if (!messageBodyMd5Returned.isNullOrEmpty() && !messageBodySent.isNullOrEmpty()) { + logger.debug { "Validating message body MD5 checksum for SendMessage" } + + val clientSideBodyMd5 = calculateMessageBodyMd5(messageBodySent) + + validateMd5(clientSideBodyMd5, messageBodyMd5Returned) + + logger.debug { "Message body MD5 checksum for SendMessage validated" } + } + } + + if (validationScopes.contains(ValidationScope.MESSAGE_ATTRIBUTES)) { + val messageAttrMd5Returned = sendMessageResponse.md5OfMessageAttributes + val messageAttrSent = sendMessageRequest.messageAttributes + + if (!messageAttrMd5Returned.isNullOrEmpty() && !messageAttrSent.isNullOrEmpty()) { + logger.debug { "Validating message attribute MD5 checksum for SendMessage" } + + val clientSideAttrMd5 = calculateMessageAttributesMd5(messageAttrSent) + + validateMd5(clientSideAttrMd5, messageAttrMd5Returned) + + logger.debug { "Message attribute MD5 checksum for SendMessage validated" } + } + } + + if (validationScopes.contains(ValidationScope.MESSAGE_SYSTEM_ATTRIBUTES)) { + val messageSysAttrMD5Returned = sendMessageResponse.md5OfMessageSystemAttributes + val messageSysAttrSent = sendMessageRequest.messageSystemAttributes + + if (!messageSysAttrMD5Returned.isNullOrEmpty() && !messageSysAttrSent.isNullOrEmpty()) { + logger.debug { "Validating message system attribute MD5 checksum for SendMessage" } + + val clientSideSysAttrMd5 = calculateMessageSystemAttributesMd5(messageSysAttrSent) + + validateMd5(clientSideSysAttrMd5, messageSysAttrMD5Returned) + + logger.debug { "Message system attribute MD5 checksum for SendMessage validated" } + } + } + } + + private fun receiveMessageResultMd5Check(receiveMessageResponse: ReceiveMessageResponse, logger: Logger) { + receiveMessageResponse.messages?.forEach { messageReceived -> + if (validationScopes.contains(ValidationScope.MESSAGE_BODY)) { + val messageBodyMd5Returned = messageReceived.md5OfBody + val messageBodyReturned = messageReceived.body + + if (!messageBodyMd5Returned.isNullOrEmpty() && !messageBodyReturned.isNullOrEmpty()) { + logger.debug { "Validating message body MD5 checksum for ReceiveMessage" } + + val clientSideBodyMd5 = calculateMessageBodyMd5(messageBodyReturned) + + validateMd5(clientSideBodyMd5, messageBodyMd5Returned) + + logger.debug { "Message body MD5 checksum for ReceiveMessage validated" } + } + } + + if (validationScopes.contains(ValidationScope.MESSAGE_ATTRIBUTES)) { + val messageAttrMd5Returned = messageReceived.md5OfMessageAttributes + val messageAttrReturned = messageReceived.messageAttributes + + if (!messageAttrMd5Returned.isNullOrEmpty() && !messageAttrReturned.isNullOrEmpty()) { + logger.debug { "Validating message attribute MD5 checksum for ReceiveMessage" } + + val clientSideAttrMd5 = calculateMessageAttributesMd5(messageAttrReturned) + + validateMd5(clientSideAttrMd5, messageAttrMd5Returned) + + logger.debug { "Message attribute MD5 checksum for ReceiveMessage validated" } + } + } + } + } + + private fun sendMessageBatchOperationMd5Check( + sendMessageBatchRequest: SendMessageBatchRequest, + sendMessageBatchResponse: SendMessageBatchResponse, + logger: Logger, + ) { + val idToRequestEntry = sendMessageBatchRequest + .entries + .orEmpty() + .associateBy { it.id } + + for (entry in sendMessageBatchResponse.successful) { + if (validationScopes.contains(ValidationScope.MESSAGE_BODY)) { + val messageBodyMd5Returned = entry.md5OfMessageBody + val messageBodySent = idToRequestEntry[entry.id]?.messageBody + + if (!messageBodyMd5Returned.isNullOrEmpty() && !messageBodySent.isNullOrEmpty()) { + logger.debug { "Validating message body MD5 checksum for SendMessageBatch: ${entry.messageId}" } + + val clientSideBodyMd5 = calculateMessageBodyMd5(messageBodySent) + + validateMd5(clientSideBodyMd5, messageBodyMd5Returned) + + logger.debug { "Message body MD5 checksum for SendMessageBatch: ${entry.messageId} validated" } + } + } + + if (validationScopes.contains(ValidationScope.MESSAGE_ATTRIBUTES)) { + val messageAttrMD5Returned = entry.md5OfMessageAttributes + val messageAttrSent = idToRequestEntry[entry.id]?.messageAttributes + + if (!messageAttrMD5Returned.isNullOrEmpty() && !messageAttrSent.isNullOrEmpty()) { + logger.debug { "Validating message attribute MD5 checksum for SendMessageBatch: ${entry.messageId}" } + + val clientSideAttrMd5 = calculateMessageAttributesMd5(messageAttrSent) + + validateMd5(clientSideAttrMd5, messageAttrMD5Returned) + + logger.debug { "Message attribute MD5 checksum for SendMessageBatch: ${entry.messageId} validated" } + } + } + + if (validationScopes.contains(ValidationScope.MESSAGE_SYSTEM_ATTRIBUTES)) { + val messageSysAttrMD5Returned = entry.md5OfMessageSystemAttributes + val messageSysAttrSent = idToRequestEntry[entry.id]?.messageSystemAttributes + + if (!messageSysAttrMD5Returned.isNullOrEmpty() && !messageSysAttrSent.isNullOrEmpty()) { + logger.debug { "Validating message system attribute MD5 checksum for SendMessageBatch: ${entry.messageId}" } + + val clientSideSysAttrMd5 = calculateMessageSystemAttributesMd5(messageSysAttrSent) + + validateMd5(clientSideSysAttrMd5, messageSysAttrMD5Returned) + + logger.debug { "Message system attribute MD5 checksum for SendMessageBatch: ${entry.messageId} validated" } + } + } + } + } + + private fun validateMd5(clientSideMd5: String, md5Returned: String) { + if (clientSideMd5 != md5Returned) { + throw ChecksumMismatchException("Checksum mismatch. Expected $clientSideMd5 but was $md5Returned") + } + } + + private fun calculateMessageBodyMd5(messageBody: String) = + messageBody.encodeToByteArray().md5().toHexString() + + /** + * Calculates the MD5 digest for message attributes according to SQS specifications. + * https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-message-metadata.html#sqs-attributes-md5-message-digest-calculation + */ + private fun calculateMessageAttributesMd5(messageAttributes: Map): String { + val buffer = SdkBuffer() + + messageAttributes + .entries + .sortedBy { (attributeName, _) -> attributeName } + .forEach { (attributeName, attributeValue) -> + updateLengthAndBytes(buffer, attributeName) + updateLengthAndBytes(buffer, attributeValue.dataType) + when { + attributeValue.stringValue != null -> updateForStringType(buffer, attributeValue.stringValue) + attributeValue.binaryValue != null -> updateForBinaryType(buffer, attributeValue.binaryValue) + !attributeValue.stringListValues.isNullOrEmpty() -> updateForStringListType(buffer, attributeValue.stringListValues) + !attributeValue.binaryListValues.isNullOrEmpty() -> updateForBinaryListType(buffer, attributeValue.binaryListValues) + else -> throw ClientException("No value type found for attribute $attributeName") + } + } + + val payload = buffer.readByteArray() + return payload.md5().toHexString() + } + + private fun calculateMessageSystemAttributesMd5( + messageSystemAttributes: Map, + ): String { + val buffer = SdkBuffer() + + messageSystemAttributes + .entries + .sortedBy { (systemAttributeName, _) -> systemAttributeName.value } + .forEach { (systemAttributeName, systemAttributeValue) -> + updateLengthAndBytes(buffer, systemAttributeName.value) + updateLengthAndBytes(buffer, systemAttributeValue.dataType) + when { + systemAttributeValue.stringValue != null -> updateForStringType(buffer, systemAttributeValue.stringValue) + systemAttributeValue.binaryValue != null -> updateForBinaryType(buffer, systemAttributeValue.binaryValue) + !systemAttributeValue.stringListValues.isNullOrEmpty() -> updateForStringListType(buffer, systemAttributeValue.stringListValues) + !systemAttributeValue.binaryListValues.isNullOrEmpty() -> updateForBinaryListType(buffer, systemAttributeValue.binaryListValues) + else -> throw ClientException("No value type found for system attribute $systemAttributeName") + } + } + + val payload = buffer.readByteArray() + return payload.md5().toHexString() + } + + private fun updateForStringType(buffer: SdkBuffer, value: String) { + buffer.writeByte(STRING_TYPE_FIELD_INDEX) + updateLengthAndBytes(buffer, value) + } + + private fun updateForBinaryType(buffer: SdkBuffer, value: ByteArray) { + buffer.writeByte(BINARY_TYPE_FIELD_INDEX) + updateLengthAndBytes(buffer, value) + } + + private fun updateForStringListType(buffer: SdkBuffer, values: List) { + buffer.writeByte(STRING_LIST_TYPE_FIELD_INDEX) + values.forEach { value -> + updateLengthAndBytes(buffer, value) + } + } + + private fun updateForBinaryListType(buffer: SdkBuffer, values: List) { + buffer.writeByte(BINARY_LIST_TYPE_FIELD_INDEX) + values.forEach { value -> + updateLengthAndBytes(buffer, value) + } + } + + private fun updateLengthAndBytes(buffer: SdkBuffer, stringValue: String) = + updateLengthAndBytes(buffer, stringValue.encodeToByteArray()) + + /** + * Update the digest using a sequence of bytes that consists of the length (in 4 bytes) of the + * input binaryValue and all the bytes it contains. + */ + private fun updateLengthAndBytes(buffer: SdkBuffer, binaryValue: ByteArray) { + buffer.writeInt(binaryValue.size) + buffer.write(binaryValue) + } +} diff --git a/services/sqs/common/src/aws.sdk.kotlin.services.sqs/internal/FinalizeSqsConfig.kt b/services/sqs/common/src/aws.sdk.kotlin.services.sqs/internal/FinalizeSqsConfig.kt new file mode 100644 index 00000000000..a176cc8bb48 --- /dev/null +++ b/services/sqs/common/src/aws.sdk.kotlin.services.sqs/internal/FinalizeSqsConfig.kt @@ -0,0 +1,34 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.sdk.kotlin.services.sqs.internal + +import aws.sdk.kotlin.runtime.config.profile.* +import aws.sdk.kotlin.services.sqs.SqsClient +import aws.smithy.kotlin.runtime.config.resolve +import aws.smithy.kotlin.runtime.util.LazyAsyncValue +import aws.smithy.kotlin.runtime.util.PlatformProvider + +internal suspend fun finalizeSqsConfig( + builder: SqsClient.Builder, + sharedConfig: LazyAsyncValue, + provider: PlatformProvider = PlatformProvider.System, +) { + val activeProfile = sharedConfig.get().activeProfile + builder.config.checksumValidationEnabled = builder.config.checksumValidationEnabled + ?: SqsSetting.checksumValidationEnabled.resolve(provider) + ?: activeProfile.checksumValidationEnabled + ?: ValidationEnabled.NEVER // TODO: MD5 checksum validation is temporarily disabled. Set default to ALWAYS in v1.5 + + builder.config.checksumValidationScopes = builder.config.checksumValidationScopes + ?: SqsSetting.checksumValidationScopes.resolve(provider) + ?: activeProfile.checksumValidationScopes + ?: ValidationScope.entries.toSet() +} + +private val AwsProfile.checksumValidationEnabled: ValidationEnabled? + get() = getEnumOrNull("sqs_checksum_validation_enabled") + +private val AwsProfile.checksumValidationScopes: Set? + get() = getEnumSetOrNull("sqs_checksum_validation_scope") diff --git a/services/sqs/common/src/aws.sdk.kotlin.services.sqs/internal/SqsSetting.kt b/services/sqs/common/src/aws.sdk.kotlin.services.sqs/internal/SqsSetting.kt new file mode 100644 index 00000000000..bb97697b11b --- /dev/null +++ b/services/sqs/common/src/aws.sdk.kotlin.services.sqs/internal/SqsSetting.kt @@ -0,0 +1,49 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.sdk.kotlin.services.sqs.internal + +import aws.smithy.kotlin.runtime.config.* + +/** + * SQS specific system settings + */ +internal object SqsSetting { + /** + * Configure when MD5 checksum validation is performed for SQS operations. + * + * Can be configured using: + * - System property: aws.SqsChecksumValidationEnabled + * - Environment variable: AWS_SQS_CHECKSUM_VALIDATION_ENABLED + * + * Valid values: + * - ALWAYS - Validates checksums for both sending and receiving operations + * - WHEN_SENDING - Validates checksums only when sending messages + * - WHEN_RECEIVING - Validates checksums only when receiving messages + * - NEVER (default) - Disables checksum validation + * + * Note: Value matching is case-insensitive when configured via environment variables. + */ + public val checksumValidationEnabled: EnvironmentSetting = + enumEnvSetting("aws.SqsChecksumValidationEnabled", "AWS_SQS_CHECKSUM_VALIDATION_ENABLED") + + /** + * Configure the scope of checksum validation for SQS operations. + * + * Can be configured using: + * - System property: aws.SqsChecksumValidationScope + * - Environment variable: AWS_SQS_CHECKSUM_VALIDATION_SCOPE + * + * Valid values are comma-separated combinations of: + * - MESSAGE_BODY: Validate message body checksums + * - MESSAGE_ATTRIBUTES: Validate message attribute checksums + * - SYSTEM_ATTRIBUTES: Validate system attribute checksums + * + * Example: "MESSAGE_BODY,MESSAGE_ATTRIBUTES" + * + * If not specified, defaults to validating all scopes. + */ + public val checksumValidationScopes: EnvironmentSetting?> = + enumSetEnvSetting("aws.SqsChecksumValidationScope", "AWS_SQS_CHECKSUM_VALIDATION_SCOPE") +} diff --git a/services/sqs/common/src/aws.sdk.kotlin.services.sqs/internal/ValidationConfig.kt b/services/sqs/common/src/aws.sdk.kotlin.services.sqs/internal/ValidationConfig.kt new file mode 100644 index 00000000000..bc03181bab1 --- /dev/null +++ b/services/sqs/common/src/aws.sdk.kotlin.services.sqs/internal/ValidationConfig.kt @@ -0,0 +1,47 @@ +package aws.sdk.kotlin.services.sqs.internal + +/** + * Controls when MD5 checksum validation is performed for SQS operations. + * + * This configuration determines under which conditions checksums will be automatically + * calculated and validated for SQS message operations. + * + * Valid values: + * - `ALWAYS` - Validates checksums for both sending and receiving operations + * (SendMessage, SendMessageBatch, and ReceiveMessage) + * - `WHEN_SENDING` - Validates checksums only when sending messages + * (SendMessage and SendMessageBatch) + * - `WHEN_RECEIVING` - Validates checksums only when receiving messages + * (ReceiveMessage) + * - `NEVER` - Disables checksum validation completely + * + * Default: `NEVER` + */ +// TODO: MD5 checksum validation is temporarily disabled. Change default to ALWAYS in v1.5 +public enum class ValidationEnabled { + ALWAYS, + WHEN_SENDING, + WHEN_RECEIVING, + NEVER, +} + +/** + * Specifies which parts of an SQS message should undergo MD5 checksum validation. + * + * This configuration determines which components of a message will be validated + * when checksum validation is enabled. + * + * Valid values: + * - `MESSAGE_ATTRIBUTES` - Validates checksums for message attributes + * - `MESSAGE_SYSTEM_ATTRIBUTES` - Validates checksums for message system attributes + * (Note: Not available for ReceiveMessage operations as SQS does not calculate + * checksums for system attributes during message receipt) + * - `MESSAGE_BODY` - Validates checksums for the message body + * + * Default: All scopes enabled + */ +public enum class ValidationScope { + MESSAGE_ATTRIBUTES, + MESSAGE_SYSTEM_ATTRIBUTES, + MESSAGE_BODY, +} diff --git a/services/sqs/e2eTest/src/SqsMd5ChecksumValidationTest.kt b/services/sqs/e2eTest/src/SqsMd5ChecksumValidationTest.kt new file mode 100644 index 00000000000..0ca8719db13 --- /dev/null +++ b/services/sqs/e2eTest/src/SqsMd5ChecksumValidationTest.kt @@ -0,0 +1,239 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.sdk.kotlin.e2etest + +import aws.sdk.kotlin.e2etest.SqsTestUtils.DEFAULT_REGION +import aws.sdk.kotlin.e2etest.SqsTestUtils.TEST_MESSAGE_ATTRIBUTES_NAME +import aws.sdk.kotlin.e2etest.SqsTestUtils.TEST_MESSAGE_ATTRIBUTES_VALUE +import aws.sdk.kotlin.e2etest.SqsTestUtils.TEST_MESSAGE_BODY +import aws.sdk.kotlin.e2etest.SqsTestUtils.TEST_MESSAGE_SYSTEM_ATTRIBUTES_VALUE +import aws.sdk.kotlin.e2etest.SqsTestUtils.TEST_QUEUE_PREFIX +import aws.sdk.kotlin.e2etest.SqsTestUtils.buildSendMessageBatchRequestEntry +import aws.sdk.kotlin.e2etest.SqsTestUtils.deleteQueueAndAllMessages +import aws.sdk.kotlin.e2etest.SqsTestUtils.getTestQueueUrl +import aws.sdk.kotlin.services.sqs.SqsClient +import aws.sdk.kotlin.services.sqs.internal.ValidationEnabled +import aws.sdk.kotlin.services.sqs.model.* +import aws.smithy.kotlin.runtime.client.ResponseInterceptorContext +import aws.smithy.kotlin.runtime.collections.AttributeKey +import aws.smithy.kotlin.runtime.collections.get +import aws.smithy.kotlin.runtime.hashing.md5 +import aws.smithy.kotlin.runtime.http.interceptors.ChecksumMismatchException +import aws.smithy.kotlin.runtime.http.interceptors.HttpInterceptor +import aws.smithy.kotlin.runtime.http.request.HttpRequest +import aws.smithy.kotlin.runtime.http.response.HttpResponse +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.* +import org.junit.jupiter.api.Assertions.assertNotNull + +/** + * Tests for Sqs MD5 checksum validation + */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class SqsMd5ChecksumValidationTest { + // An interceptor that set wrong md5 checksums in SQS response + @OptIn(ExperimentalStdlibApi::class) + private val wrongChecksumInterceptor = object : HttpInterceptor { + override suspend fun modifyBeforeCompletion( + context: ResponseInterceptorContext, + ): Result { + val wrongMd5ofMessageBody = "wrong message md5".encodeToByteArray().md5().toHexString() + val wrongMd5ofMessageAttribute = "wrong attribute md5".encodeToByteArray().md5().toHexString() + val wrongMd5ofMessageSystemAttribute = "wrong system attribute md5".encodeToByteArray().md5().toHexString() + + when (val response = context.response.getOrNull()) { + is SendMessageResponse -> { + val modifiedResponse = response.copy { + md5OfMessageAttributes = wrongMd5ofMessageAttribute + } + return Result.success(modifiedResponse) + } + is ReceiveMessageResponse -> { + val modifiedMessages = response.messages?.map { message -> + message.copy { + md5OfBody = wrongMd5ofMessageBody + } + } + + val modifiedResponse = ReceiveMessageResponse { + messages = modifiedMessages + } + return Result.success(modifiedResponse) + } + is SendMessageBatchResponse -> { + val modifiedEntries = response.successful.map { entry -> + entry.copy { + md5OfMessageSystemAttributes = wrongMd5ofMessageSystemAttribute + } + } + + val modifiedResponse = SendMessageBatchResponse { + successful = modifiedEntries + failed = response.failed + } + return Result.success(modifiedResponse) + } + } + return context.response + } + } + + // An interceptor that checks if the SQS md5 checksum was validated + private val checksumValidationAssertionInterceptor = object : HttpInterceptor { + private val supportedOperations = setOf( + "SendMessage", + "SendMessageBatch", + "ReceiveMessage", + ) + + override fun readAfterExecution(context: ResponseInterceptorContext) { + val operationName = context.executionContext.attributes[AttributeKey("aws.smithy.kotlin#OperationName")] as String + + if (operationName !in supportedOperations) { + return + } + + assertNotNull(context.executionContext.attributes[AttributeKey("checksumValidated")]) + + val isChecksumValidated = context.executionContext.attributes[AttributeKey("checksumValidated")] as Boolean + + assert(isChecksumValidated) + } + } + + private lateinit var correctChecksumClient: SqsClient + + private lateinit var wrongChecksumClient: SqsClient + + private lateinit var testQueueUrl: String + + @BeforeAll + private fun setUp(): Unit = runBlocking { + correctChecksumClient = SqsClient.fromEnvironment { + region = DEFAULT_REGION + checksumValidationEnabled = ValidationEnabled.ALWAYS + interceptors += checksumValidationAssertionInterceptor + } + wrongChecksumClient = SqsClient.fromEnvironment { + region = DEFAULT_REGION + checksumValidationEnabled = ValidationEnabled.ALWAYS + interceptors += wrongChecksumInterceptor + } + testQueueUrl = getTestQueueUrl(correctChecksumClient, TEST_QUEUE_PREFIX) + } + + @AfterAll + private fun cleanUp(): Unit = runBlocking { + deleteQueueAndAllMessages(correctChecksumClient, testQueueUrl) + correctChecksumClient.close() + wrongChecksumClient.close() + } + + @Test + fun testSendMessage(): Unit = runBlocking { + correctChecksumClient.sendMessage( + SendMessageRequest { + queueUrl = testQueueUrl + messageBody = TEST_MESSAGE_BODY + messageAttributes = mapOf( + TEST_MESSAGE_ATTRIBUTES_NAME to MessageAttributeValue { + dataType = "String" + stringValue = TEST_MESSAGE_ATTRIBUTES_VALUE + }, + TEST_MESSAGE_ATTRIBUTES_NAME to MessageAttributeValue { + dataType = "Binary" + binaryValue = TEST_MESSAGE_ATTRIBUTES_VALUE.toByteArray() + }, + ) + messageSystemAttributes = mapOf( + MessageSystemAttributeNameForSends.AwsTraceHeader to MessageSystemAttributeValue { + dataType = "String" + stringValue = TEST_MESSAGE_SYSTEM_ATTRIBUTES_VALUE + }, + ) + }, + ) + } + + @Test + fun testReceiveMessage(): Unit = runBlocking { + correctChecksumClient.receiveMessage( + ReceiveMessageRequest { + queueUrl = testQueueUrl + maxNumberOfMessages = 1 + messageAttributeNames = listOf(TEST_MESSAGE_ATTRIBUTES_NAME) + messageSystemAttributeNames = listOf(MessageSystemAttributeName.AwsTraceHeader) + }, + ) + } + + @Test + fun testSendMessageBatch(): Unit = runBlocking { + val entries = (1..5).map { batchId -> + buildSendMessageBatchRequestEntry(batchId) + } + + correctChecksumClient.sendMessageBatch( + SendMessageBatchRequest { + queueUrl = testQueueUrl + this.entries = entries + }, + ) + } + + @Test + fun testSendMessageWithWrongChecksum(): Unit = runBlocking { + assertThrows { + wrongChecksumClient.sendMessage( + SendMessageRequest { + queueUrl = testQueueUrl + messageBody = TEST_MESSAGE_BODY + messageAttributes = mapOf( + TEST_MESSAGE_ATTRIBUTES_NAME to MessageAttributeValue { + dataType = "String" + stringValue = TEST_MESSAGE_ATTRIBUTES_VALUE + }, + ) + messageSystemAttributes = mapOf( + MessageSystemAttributeNameForSends.AwsTraceHeader to MessageSystemAttributeValue { + dataType = "String" + stringValue = TEST_MESSAGE_SYSTEM_ATTRIBUTES_VALUE + }, + ) + }, + ) + } + } + + @Test + fun testReceiveMessageWithWrongChecksum(): Unit = runBlocking { + assertThrows { + wrongChecksumClient.receiveMessage( + ReceiveMessageRequest { + queueUrl = testQueueUrl + maxNumberOfMessages = 1 + messageAttributeNames = listOf(TEST_MESSAGE_ATTRIBUTES_NAME) + messageSystemAttributeNames = listOf(MessageSystemAttributeName.AwsTraceHeader) + }, + ) + } + } + + @Test + fun testSendMessageBatchWithWrongChecksum(): Unit = runBlocking { + val entries = (1..5).map { batchId -> + buildSendMessageBatchRequestEntry(batchId) + } + + assertThrows { + wrongChecksumClient.sendMessageBatch( + SendMessageBatchRequest { + queueUrl = testQueueUrl + this.entries = entries + }, + ) + } + } +} diff --git a/services/sqs/e2eTest/src/SqsTestUtils.kt b/services/sqs/e2eTest/src/SqsTestUtils.kt new file mode 100644 index 00000000000..3f8b0a35236 --- /dev/null +++ b/services/sqs/e2eTest/src/SqsTestUtils.kt @@ -0,0 +1,92 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.sdk.kotlin.e2etest + +import aws.sdk.kotlin.services.sqs.SqsClient +import aws.sdk.kotlin.services.sqs.createQueue +import aws.sdk.kotlin.services.sqs.model.* +import aws.sdk.kotlin.services.sqs.paginators.listQueuesPaginated +import aws.sdk.kotlin.services.sqs.paginators.queueUrls +import kotlinx.coroutines.flow.firstOrNull +import kotlinx.coroutines.withTimeout +import java.util.* +import kotlin.time.Duration.Companion.seconds + +object SqsTestUtils { + const val DEFAULT_REGION = "us-west-2" + + const val TEST_QUEUE_PREFIX = "sqs-test-queue-" + + const val TEST_MESSAGE_BODY = "Hello World" + const val TEST_MESSAGE_ATTRIBUTES_NAME = "TestAttribute" + const val TEST_MESSAGE_ATTRIBUTES_VALUE = "TestAttributeValue" + const val TEST_MESSAGE_SYSTEM_ATTRIBUTES_VALUE = "TestSystemAttributeValue" + + suspend fun getTestQueueUrl(client: SqsClient, prefix: String): String = + getQueueUrlWithPrefix(client, prefix) + + private suspend fun getQueueUrlWithPrefix(client: SqsClient, prefix: String): String = withTimeout(60.seconds) { + var matchingQueueUrl = client + .listQueuesPaginated { queueNamePrefix = prefix } + .queueUrls() + .firstOrNull() + + if (matchingQueueUrl == null) { + matchingQueueUrl = prefix + UUID.randomUUID() + println("Creating SQS queue: $matchingQueueUrl") + + client.createQueue { + queueName = matchingQueueUrl + } + } else { + println("Using existing SQS queue: $matchingQueueUrl") + } + + matchingQueueUrl + } + + suspend fun deleteQueueAndAllMessages(client: SqsClient, queueUrl: String) { + try { + println("Purging SQS queue: $queueUrl") + + client.purgeQueue( + PurgeQueueRequest { + this.queueUrl = queueUrl + }, + ) + + println("Queue purged successfully.") + + println("Deleting SQS queue: $queueUrl") + + client.deleteQueue( + DeleteQueueRequest { + this.queueUrl = queueUrl + }, + ) + + println("Queue deleted successfully.") + } catch (e: SqsException) { + println("Error during delete SQS queue: ${e.message}") + } + } + + fun buildSendMessageBatchRequestEntry(batchId: Int): SendMessageBatchRequestEntry = SendMessageBatchRequestEntry { + id = batchId.toString() + messageBody = TEST_MESSAGE_BODY + batchId + messageAttributes = mapOf( + TEST_MESSAGE_ATTRIBUTES_NAME to MessageAttributeValue { + dataType = "String" + stringValue = TEST_MESSAGE_ATTRIBUTES_VALUE + batchId + }, + ) + messageSystemAttributes = mapOf( + MessageSystemAttributeNameForSends.AwsTraceHeader to MessageSystemAttributeValue { + dataType = "String" + stringValue = TEST_MESSAGE_SYSTEM_ATTRIBUTES_VALUE + batchId + }, + ) + } +}