diff --git a/Cargo.lock b/Cargo.lock index 0df1ad641..5462292bc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,6 +20,27 @@ dependencies = [ "winapi", ] +[[package]] +name = "async-stream" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "171374e7e3b2504e0e5236e3b59260560f9fe94bfe9ac39ba5e4e929c5590625" +dependencies = [ + "async-stream-impl", + "futures-core", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "648ed8c8d2ce5409ccd57453d9d1b214b342a0d69376a6feda1fd6cae3299308" +dependencies = [ + "proc-macro2 1.0.24", + "quote 1.0.9", + "syn 1.0.64", +] + [[package]] name = "atty" version = "0.2.14" @@ -37,6 +58,12 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b671c8fb71b457dd4ae18c4ba1e59aa81793daacc361d82fcd410cef0d491875" +[[package]] +name = "autocfg" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" + [[package]] name = "base64" version = "0.9.3" @@ -97,6 +124,12 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7c3dd8985a7111efc5c80b44e23ecdd8c007de8ade3b96595387e812b957cf5" +[[package]] +name = "bytes" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8" + [[package]] name = "cbc" version = "0.1.2" @@ -219,6 +252,100 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a06f77d526c1a601b7c4cdd98f54b5eaabffc14d5f2f0296febdc7f357c6d3ba" +[[package]] +name = "futures" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a12aa0eb539080d55c3f2d45a67c3b58b6b0773c1a3ca2dfec66d58c97fd66ca" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5da6ba8c3bb3c165d3c7319fc1cc8304facf1fb8db99c5de877183c08a273888" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d1c26957f23603395cd326b0ffe64124b818f4449552f960d815cfba83a53d" + +[[package]] +name = "futures-executor" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45025be030969d763025784f7f355043dc6bc74093e4ecc5000ca4dc50d8745c" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "522de2a0fe3e380f1bc577ba0474108faf3f6b18321dbf60b3b9c39a75073377" + +[[package]] +name = "futures-macro" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18e4a4b95cea4b4ccbcf1c5675ca7c4ee4e9e75eb79944d07defde18068f79bb" +dependencies = [ + "autocfg 1.0.1", + "proc-macro-hack", + "proc-macro2 1.0.24", + "quote 1.0.9", + "syn 1.0.64", +] + +[[package]] +name = "futures-sink" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36ea153c13024fe480590b3e3d4cad89a0cfacecc24577b68f86c6ced9c2bc11" + +[[package]] +name = "futures-task" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d3d00f4eddb73e498a54394f228cd55853bdf059259e8e7bc6e69d408892e99" + +[[package]] +name = "futures-util" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36568465210a3a6ee45e1f165136d68671471a501e632e9a98d96872222b5481" +dependencies = [ + "autocfg 1.0.1", + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "proc-macro-hack", + "proc-macro-nested", + "slab", +] + [[package]] name = "generic-array" version = "0.14.6" @@ -322,9 +449,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.62" +version = "0.2.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34fcd2c08d2f832f376f4173a231990fa5aef4e99fb569867318a227ef4c06ba" +checksum = "99227334921fae1a979cf0bfdfcc6b3e5ce376ef57e16fb6fb3ea2ed6095f80c" [[package]] name = "libloading" @@ -376,6 +503,7 @@ checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08" name = "mbedtls" version = "0.9.0" dependencies = [ + "async-stream", "bit-vec", "bitflags", "byteorder", @@ -383,6 +511,7 @@ dependencies = [ "cc", "cfg-if 1.0.0", "chrono", + "futures", "hex", "hyper", "libc", @@ -396,6 +525,8 @@ dependencies = [ "serde_cbor", "serde_derive", "spin", + "tokio", + "tracing", "yasna", ] @@ -429,6 +560,18 @@ dependencies = [ "log 0.3.9", ] +[[package]] +name = "mio" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9" +dependencies = [ + "libc", + "log 0.4.8", + "wasi", + "windows-sys", +] + [[package]] name = "nom" version = "5.1.2" @@ -445,7 +588,7 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "343b3df15c945a59e72aae31e89a7cfc9e11850e96d4fde6fed5e3c7c8d9c887" dependencies = [ - "autocfg", + "autocfg 0.1.6", "num-integer", "num-traits", ] @@ -456,7 +599,7 @@ version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b85e541ef8255f6cf42bbfe4ef361305c6c135d10919ecc26126c4e5ae94bc09" dependencies = [ - "autocfg", + "autocfg 0.1.6", "num-traits", ] @@ -466,7 +609,7 @@ version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ba9a427cfca2be13aa6f6403b0b7e7368fe982bfa16fccc450ce74c46cd9b32" dependencies = [ - "autocfg", + "autocfg 0.1.6", ] [[package]] @@ -497,12 +640,36 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "31010dd2e1ac33d5b46a5b413495239882813e0369f8ed8a5e266f173602f831" +[[package]] +name = "pin-project-lite" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d31d11c69a6b52a174b42bdc0c30e5e11670f90788b2c471c31c1d17d449443" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "pkg-config" version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72d5370d90f49f70bd033c3d75e87fc529fbfff9d6f7cccef07d6170079d91ea" +[[package]] +name = "proc-macro-hack" +version = "0.5.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" + +[[package]] +name = "proc-macro-nested" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc881b2c22681370c6a780e47af9840ef841837bc98118431d4e1868bd0c1086" + [[package]] name = "proc-macro2" version = "0.4.30" @@ -611,9 +778,9 @@ checksum = "b5eb417147ba9860a96cfe72a0b93bf88fee1744b5636ec99ab20c1aa9376581" [[package]] name = "rs-libc" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b434763aff74b924c33af0ce3a3791c7c5ff8fb431773061dde30447e2fb77f0" +checksum = "914c985b921cf571d950d17ca33221ed54fed3c2001a329ee6fd5b15dd433260" dependencies = [ "cc", ] @@ -673,6 +840,22 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42a568c8f2cd051a4d283bd6eb0343ac214c1b0f1ac19f93e1175b2dee38c73d" +[[package]] +name = "slab" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c307a32c1c5c437f38c7fd45d753050587732ba8628319fbdf12a7e289ccc590" + +[[package]] +name = "socket2" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "spin" version = "0.4.10" @@ -760,6 +943,67 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" +[[package]] +name = "tokio" +version = "1.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c51a52ed6686dd62c320f9b89299e9dfb46f730c7a48e635c19f21d116cb1439" +dependencies = [ + "bytes", + "libc", + "memchr", + "mio", + "num_cpus", + "once_cell", + "pin-project-lite", + "socket2", + "tokio-macros", + "winapi", +] + +[[package]] +name = "tokio-macros" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d266c00fde287f55d3f1c3e96c500c362a2b8c695076ec180f27918820bc6df8" +dependencies = [ + "proc-macro2 1.0.24", + "quote 1.0.9", + "syn 1.0.64", +] + +[[package]] +name = "tracing" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2ba9ab62b7d6497a8638dfda5e5c4fb3b2d5a7fca4118f2b96151c8ef1a437e" +dependencies = [ + "cfg-if 1.0.0", + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98863d0dd09fa59a1b79c6750ad80dbda6b75f4e71c437a6a1a8cb91a8bcbd77" +dependencies = [ + "proc-macro2 1.0.24", + "quote 1.0.9", + "syn 1.0.64", +] + +[[package]] +name = "tracing-core" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46125608c26121c81b0c6d693eab5a420e416da7e43c426d2e8f7df8da8a3acf" +dependencies = [ + "lazy_static", +] + [[package]] name = "traitobject" version = "0.1.0" @@ -858,6 +1102,12 @@ version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5a972e5669d67ba988ce3dc826706fb0a8b01471c088cb0b6110b805cc36aed" +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + [[package]] name = "which" version = "3.0.0" @@ -869,9 +1119,9 @@ dependencies = [ [[package]] name = "winapi" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8093091eeb260906a183e6ae1abdba2ef5ef2257a21801128899c3fc699229c6" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" dependencies = [ "winapi-i686-pc-windows-gnu", "winapi-x86_64-pc-windows-gnu", @@ -898,6 +1148,72 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + [[package]] name = "yasna" version = "0.2.2" diff --git a/ct.sh b/ct.sh index c3689b25f..7ef9d91f3 100755 --- a/ct.sh +++ b/ct.sh @@ -33,6 +33,7 @@ if [ "$TRAVIS_RUST_VERSION" == "stable" ] || [ "$TRAVIS_RUST_VERSION" == "beta" cargo test --features pkcs12 --target $TARGET cargo test --features pkcs12_rc2 --target $TARGET cargo test --features dsa --target $TARGET + cargo test --test async_session --features=async-rt --target $TARGET # If zlib is installed, test the zlib feature if [ -n "$ZLIB_INSTALLED" ]; then diff --git a/mbedtls-sys/Cargo.toml b/mbedtls-sys/Cargo.toml index e49ade9cd..2e21c9497 100644 --- a/mbedtls-sys/Cargo.toml +++ b/mbedtls-sys/Cargo.toml @@ -42,9 +42,8 @@ quote = "1.0.9" # * strstr/strlen/strncpy/strncmp/strcmp/snprintf # * memmove/memcpy/memcmp/memset # * rand/printf (used only for self tests. optionally use custom_printf) -default = ["std", "debug", "threading", "zlib", "time", "aesni", "padlock", "legacy_protocols"] -std = ["debug"] # deprecated automatic enabling of debug, can be removed on major version bump -debug = [] +default = ["std", "threading", "zlib", "time", "aesni", "padlock", "legacy_protocols"] +std = [] custom_printf = [] custom_has_support = [] aes_alt = [] diff --git a/mbedtls/Cargo.toml b/mbedtls/Cargo.toml index df79b5e2a..e4fe7403a 100644 --- a/mbedtls/Cargo.toml +++ b/mbedtls/Cargo.toml @@ -29,13 +29,14 @@ bit-vec = { version = "0.5", optional = true } cbc = { version = "0.1.2", optional = true } rc2 = { version = "0.8.1", optional = true } cfg-if = "1.0.0" +tokio = { version = "1.16.1", optional = true } [target.x86_64-fortanix-unknown-sgx.dependencies] rs-libc = "0.2.0" chrono = "0.4" [dependencies.mbedtls-sys-auto] -version = "2.25.0" +version = "2.28.0" default-features = false features = ["custom_printf", "trusted_cert_callback", "threading"] path = "../mbedtls-sys" @@ -47,6 +48,9 @@ serde_cbor = "0.6" hex = "0.3" matches = "0.1.8" hyper = { version = "0.10.16", default-features = false } +async-stream = "0.3.0" +futures = "0.3" +tracing = "0.1" [build-dependencies] cc = "1.0" @@ -55,7 +59,7 @@ cc = "1.0" # Features are documented in the README default = ["std", "aesni", "time", "padlock"] std = ["byteorder/std", "mbedtls-sys-auto/std", "serde/std", "yasna"] -debug = ["mbedtls-sys-auto/debug"] +debug = [] no_std_deps = ["spin", "serde/alloc"] force_aesni_support = ["mbedtls-sys-auto/custom_has_support", "mbedtls-sys-auto/aes_alt", "aesni"] mpi_force_c_code = ["mbedtls-sys-auto/mpi_force_c_code"] @@ -68,6 +72,8 @@ dsa = ["std", "yasna", "num-bigint", "bit-vec"] pkcs12 = ["std", "yasna"] pkcs12_rc2 = ["pkcs12", "rc2", "cbc"] legacy_protocols = ["mbedtls-sys-auto/legacy_protocols"] +async = ["std", "tokio","tokio/net","tokio/io-util", "tokio/macros"] +async-rt = ["async", "tokio/rt", "tokio/sync", "tokio/rt-multi-thread"] [[example]] name = "client" @@ -100,3 +106,8 @@ required-features = ["std"] [[test]] name = "hyper" required-features = ["std"] + +[[test]] +name = "async_session" +path = "tests/async_session.rs" +required-features = ["async-rt"] diff --git a/mbedtls/src/pk/mod.rs b/mbedtls/src/pk/mod.rs index b71daffef..52135e4e9 100644 --- a/mbedtls/src/pk/mod.rs +++ b/mbedtls/src/pk/mod.rs @@ -163,7 +163,7 @@ define!( // B. Verifying thread safety. // // 1. Calls towards the specific Pk implementation are done via function pointers. -// +// // - Example call towards Pk: // ../../../mbedtls-sys/vendor/library/ssl_srv.c:3707 - mbedtls_pk_decrypt( private_key, p, len, ... // - This calls a generic function pointer via: @@ -174,7 +174,7 @@ define!( // - The function pointers are defined via function: // ../../../mbedtls-sys/vendor/crypto/library/pk.c:115 - mbedtls_pk_info_from_type // - They are as follows: mbedtls_rsa_info / mbedtls_eckey_info / mbedtls_ecdsa_info -// - These are defined in: +// - These are defined in: // ../../../mbedtls-sys/vendor/crypto/library/pk_wrap.c:196 // // C. Checking types one by one. @@ -222,7 +222,7 @@ define!( // mbedtls_ecp_mul_restartable: ../../../mbedtls-sys/vendor/crypto/library/ecp.c:2351 // MBEDTLS_ECP_INTERNAL_ALT is not defined. (otherwise it might not be safe depending on ecp_init/ecp_free) ../../../mbedtls-sys/build/config.rs:131 // Passes as const to: mbedtls_ecp_check_privkey / mbedtls_ecp_check_pubkey / mbedtls_ecp_get_type( grp -// +// // - Ignored due to not defined: ecdsa_verify_rs_wrap, ecdsa_sign_rs_wrap, ecdsa_rs_alloc, ecdsa_rs_free // (Undefined - MBEDTLS_ECP_RESTARTABLE - ../../../mbedtls-sys/build/config.rs:173) // @@ -927,7 +927,7 @@ impl Pk { if hash.len() == 0 || sig.len() == 0 { return Err(Error::PkBadInputData) } - + unsafe { pk_verify( &mut self.inner, @@ -1297,7 +1297,7 @@ iy6KC991zzvaWY/Ys+q/84Afqa+0qJKQnPuy/7F5GkVdQA/lfbhi let mut dummy_sig = []; assert_eq!(pk.sign(digest, data, &mut dummy_sig, &mut crate::test_support::rand::test_rng()).unwrap_err(), Error::PkBadInputData); assert_eq!(pk.sign(digest, &[], &mut signature, &mut crate::test_support::rand::test_rng()).unwrap_err(), Error::PkBadInputData); - + assert_eq!(pk.sign_deterministic(digest, data, &mut dummy_sig, &mut crate::test_support::rand::test_rng()).unwrap_err(), Error::PkBadInputData); assert_eq!(pk.sign_deterministic(digest, &[], &mut signature, &mut crate::test_support::rand::test_rng()).unwrap_err(), Error::PkBadInputData); diff --git a/mbedtls/src/ssl/async_utils.rs b/mbedtls/src/ssl/async_utils.rs new file mode 100644 index 000000000..63ea35eae --- /dev/null +++ b/mbedtls/src/ssl/async_utils.rs @@ -0,0 +1,143 @@ +/* Copyright (c) Fortanix, Inc. + * + * Licensed under the GNU General Public License, version 2 or the Apache License, Version + * 2.0 , at your + * option. This file may not be copied, modified, or distributed except + * according to those terms. */ + +#![cfg(all(feature = "std", feature = "async"))] + +use std::cell::Cell; +use std::ptr::null_mut; +use std::rc::Rc; +use std::task::{Context as TaskContext, Poll}; + + +#[cfg(feature = "std")] +use std::io::{Error as IoError, Result as IoResult, ErrorKind as IoErrorKind}; + + +#[derive(Clone)] +pub struct ErasedContext(Rc>); + +unsafe impl Send for ErasedContext {} + +impl ErasedContext { + pub fn new() -> Self { + Self(Rc::new(Cell::new(null_mut()))) + } + + pub unsafe fn get(&self) -> Option<&mut TaskContext<'_>> { + let ptr = self.0.get(); + if ptr.is_null() { + None + } else { + Some(&mut *(ptr as *mut _)) + } + } + + pub fn set(&self, cx: &mut TaskContext<'_>) { + self.0.set(cx as *mut _ as *mut ()); + } + + pub fn clear(&self) { + self.0.set(null_mut()); + } +} + +// mbedtls_ssl_write() has some weird semantics w.r.t non-blocking I/O: +// +// > When this function returns MBEDTLS_ERR_SSL_WANT_WRITE/READ, it must be +// > called later **with the same arguments**, until it returns a value greater +// > than or equal to 0. When the function returns MBEDTLS_ERR_SSL_WANT_WRITE +// > there may be some partial data in the output buffer, however this is not +// > yet sent. +// +// WriteTracker is used to ensure we pass the same data in that scenario. +// +// Reference: +// https://tls.mbed.org/api/ssl_8h.html#a5bbda87d484de82df730758b475f32e5 +pub struct WriteTracker { + pending: Option>, +} + +struct DigestAndLen { + #[cfg(debug_assertions)] + digest: [u8; 20], // SHA-1 + len: usize, +} + +impl WriteTracker { + fn new() -> Self { + WriteTracker { + pending: None, + } + } + + #[cfg(debug_assertions)] + fn digest(buf: &[u8]) -> [u8; 20] { + use crate::hash::{Md, Type}; + let mut out = [0u8; 20]; + let res = Md::hash(Type::Sha1, buf, &mut out[..]); + assert_eq!(res, Ok(out.len())); + out + } + + pub fn adjust_buf<'a>(&self, buf: &'a [u8]) -> IoResult<&'a [u8]> { + match self.pending.as_ref() { + None => Ok(buf), + Some(pending) => { + if pending.len <= buf.len() { + let buf = &buf[..pending.len]; + + // We only do this check in debug mode since it's an expensive check. + #[cfg(debug_assertions)] + if Self::digest(buf) == pending.digest { + return Ok(buf); + } + + #[cfg(not(debug_assertions))] + return Ok(buf); + } + Err(IoError::new( + IoErrorKind::Other, + "mbedtls expects the same data if the previous call to poll_write() returned Poll::Pending" + )) + }, + } + } + + pub fn post_write(&mut self, buf: &[u8], res: &Poll>) { + match res { + &Poll::Pending => { + if self.pending.is_none() { + self.pending = Some(Box::new(DigestAndLen { + #[cfg(debug_assertions)] + digest: Self::digest(buf), + len: buf.len(), + })); + } + }, + _ => { + self.pending = None; + } + } + } +} + +pub struct IoAdapter { + pub inner: S, + pub ecx: ErasedContext, + pub write_tracker: WriteTracker, +} + +impl IoAdapter { + pub fn new(stream: S) -> Self { + Self { + inner: stream, + ecx: ErasedContext::new(), + write_tracker: WriteTracker::new(), + } + } +} diff --git a/mbedtls/src/ssl/context.rs b/mbedtls/src/ssl/context.rs index 8ce2d6574..bf75c4581 100644 --- a/mbedtls/src/ssl/context.rs +++ b/mbedtls/src/ssl/context.rs @@ -14,6 +14,16 @@ use { std::sync::Arc, }; +#[cfg(all(feature = "std", feature = "async"))] +use { + std::io::ErrorKind as IoErrorKind, + std::marker::Unpin, + std::pin::Pin, + std::task::{Context as TaskContext, Poll}, + tokio::io::{AsyncRead, AsyncWrite, ReadBuf}, + crate::ssl::async_utils::IoAdapter, +}; + use mbedtls_sys::types::raw_types::{c_int, c_uchar, c_void}; use mbedtls_sys::types::size_t; use mbedtls_sys::*; @@ -186,7 +196,7 @@ define!( struct HandshakeContext { handshake_ca_cert: Option>>, handshake_crl: Option>, - + handshake_cert: Vec>>, handshake_pk: Vec>, }; @@ -200,10 +210,10 @@ define!( pub struct Context { // Base structure used in SNI callback where we cannot determine the io type. inner: HandshakeContext, - + // config is used read-only for multiple contexts and is immutable once configured. - config: Arc, - + config: Arc, + // Must be held in heap and pointer to it as pointer is sent to MbedSSL and can't be re-allocated. io: Option>, @@ -230,7 +240,7 @@ impl<'a, T> Into<*mut ssl_context> for &'a mut Context { impl Context { pub fn new(config: Arc) -> Self { let mut inner = ssl_context::default(); - + unsafe { ssl_init(&mut inner); ssl_setup(&mut inner, (&*config).into()); @@ -241,7 +251,7 @@ impl Context { inner, handshake_ca_cert: None, handshake_crl: None, - + handshake_cert: vec![], handshake_pk: vec![], }, @@ -368,7 +378,7 @@ impl Context { pub fn config(&self) -> &Arc { &self.config } - + pub fn close(&mut self) { unsafe { ssl_close_notify(self.into()); @@ -376,15 +386,15 @@ impl Context { self.io = None; } } - + pub fn io(&self) -> Option<&T> { self.io.as_ref().map(|v| &**v) } - + pub fn io_mut(&mut self) -> Option<&mut T> { self.io.as_mut().map(|v| &mut **v) } - + /// Return the minor number of the negotiated TLS version pub fn minor_version(&self) -> i32 { self.handle().minor_ver @@ -416,7 +426,7 @@ impl Context { // Session specific functions - + /// Return the 16-bit ciphersuite identifier. /// All assigned ciphersuites are listed by the IANA in /// @@ -424,7 +434,7 @@ impl Context { if self.handle().session.is_null() { return Err(Error::SslBadInputData); } - + Ok(unsafe { self.handle().session.as_ref().unwrap().ciphersuite as u16 }) } @@ -561,12 +571,12 @@ impl HandshakeContext { self.handshake_ca_cert = None; self.handshake_crl = None; } - + pub fn set_authmode(&mut self, am: AuthMode) -> Result<()> { if self.inner.handshake as *const _ == ::core::ptr::null() { return Err(Error::SslBadInputData); } - + unsafe { ssl_set_hs_authmode(self.into(), am as i32) } Ok(()) } @@ -620,6 +630,233 @@ impl HandshakeContext { } } +#[cfg(all(feature = "std", feature = "async"))] +pub type AsyncContext = Context>; + +#[cfg(all(feature = "std", feature = "async"))] +pub trait IoAsyncCallback { + unsafe extern "C" fn call_recv_async(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int where Self: Sized; + unsafe extern "C" fn call_send_async(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int where Self: Sized; +} + +#[cfg(all(feature = "std", feature = "async"))] +impl IoAsyncCallback for IoAdapter { + unsafe extern "C" fn call_recv_async(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int { + let len = if len > (c_int::max_value() as size_t) { + c_int::max_value() as size_t + } else { + len + }; + + let adapter = &mut *(user_data as *mut IoAdapter); + + if let Some(cx) = adapter.ecx.get() { + let mut buf = ReadBuf::new(::core::slice::from_raw_parts_mut(data, len)); + let stream = Pin::new(&mut adapter.inner); + + match stream.poll_read(cx, &mut buf) { + Poll::Ready(Ok(())) => buf.filled().len() as c_int, + Poll::Ready(Err(_)) => ::mbedtls_sys::ERR_NET_RECV_FAILED, + Poll::Pending => ::mbedtls_sys::ERR_SSL_WANT_READ, + } + } else { + ::mbedtls_sys::ERR_NET_RECV_FAILED + } + } + + unsafe extern "C" fn call_send_async(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int { + let len = if len > (c_int::max_value() as size_t) { + c_int::max_value() as size_t + } else { + len + }; + + let adapter = &mut *(user_data as *mut IoAdapter); + + if let Some(cx) = adapter.ecx.get() { + let stream = Pin::new(&mut adapter.inner); + + match stream.poll_write(cx, ::core::slice::from_raw_parts(data, len)) { + Poll::Ready(Ok(i)) => i as c_int, + Poll::Ready(Err(_)) => ::mbedtls_sys::ERR_NET_RECV_FAILED, + Poll::Pending => ::mbedtls_sys::ERR_SSL_WANT_WRITE, + } + } else { + ::mbedtls_sys::ERR_NET_RECV_FAILED + } + } +} + +#[cfg(all(feature = "std", feature = "async"))] +struct HandshakeFuture<'a, T>(&'a mut Context::>); + +#[cfg(all(feature = "std", feature = "async"))] +impl std::future::Future for HandshakeFuture<'_, T> { + type Output = Result<()>; + fn poll(mut self: Pin<&mut Self>, ctx: &mut TaskContext) -> std::task::Poll { + self.0.io_mut().ok_or(Error::NetInvalidContext)? + .ecx.set(ctx); + + let result = match self.0.handshake() { + Err(Error::SslWantRead) | + Err(Error::SslWantWrite) => { + Poll::Pending + }, + Err(e) => Poll::Ready(Err(e)), + Ok(()) => Poll::Ready(Ok(())) + }; + + self.0.io_mut().map(|v| v.ecx.clear()); + + result + } +} + +#[cfg(all(feature = "std", feature = "async"))] +impl AsyncContext { + pub async fn accept_async(config: Arc, io: T, hostname: Option<&str>) -> IoResult> { + let mut context = Self::new(config); + context.establish_async(io, hostname).await.map_err(|e| crate::private::error_to_io_error(e))?; + Ok(context) + } + + pub async fn establish_async(&mut self, io: T, hostname: Option<&str>) -> Result<()> { + unsafe { + let mut io = Box::new(IoAdapter::new(io)); + + ssl_session_reset(self.into()).into_result()?; + self.set_hostname(hostname)?; + + let ptr = &mut *io as *mut _ as *mut c_void; + ssl_set_bio( + self.into(), + ptr, + Some(IoAdapter::::call_send_async), + Some(IoAdapter::::call_recv_async), + None, + ); + + self.io = Some(io); + self.inner.reset_handshake(); + } + + HandshakeFuture(self).await + } +} + +#[cfg(all(feature = "std", feature = "async"))] +impl AsyncRead for Context> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut TaskContext<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + + if self.handle().session.is_null() { + return Poll::Ready(Err(IoError::new(IoErrorKind::Other, "stream has been shutdown"))); + } + + self.io_mut().ok_or(IoError::new(IoErrorKind::Other, "stream has been shutdown"))? + .ecx.set(cx); + + let result = match unsafe { ssl_read((&mut *self).into(), buf.initialize_unfilled().as_mut_ptr(), buf.initialize_unfilled().len()).into_result() } { + Err(Error::SslPeerCloseNotify) => Poll::Ready(Ok(())), + Err(Error::SslWantRead) => Poll::Pending, + Err(e) => Poll::Ready(Err(crate::private::error_to_io_error(e))), + Ok(i) => { + buf.advance(i as usize); + Poll::Ready(Ok(())) + } + }; + + self.io_mut().ok_or(IoError::new(IoErrorKind::Other, "stream has been shutdown"))? + .ecx.clear(); + + result + } +} + +#[cfg(all(feature = "std", feature = "async"))] +impl AsyncWrite for Context> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut TaskContext<'_>, + buf: &[u8], + ) -> Poll> { + + if self.handle().session.is_null() { + return Poll::Ready(Err(IoError::new(IoErrorKind::Other, "stream has been shutdown"))); + } + + let buf = { + let io = self.io_mut().ok_or(IoError::new(IoErrorKind::Other, "stream has been shutdown"))?; + io.ecx.set(cx); + io.write_tracker.adjust_buf(buf) + }?; + + + self.io_mut().ok_or(IoError::new(IoErrorKind::Other, "stream has been shutdown"))? + .ecx.set(cx); + + let result = match unsafe { ssl_write((&mut *self).into(), buf.as_ptr(), buf.len()).into_result() } { + Err(Error::SslPeerCloseNotify) => Poll::Ready(Ok(0)), + Err(Error::SslWantWrite) => Poll::Pending, + Err(e) => Poll::Ready(Err(crate::private::error_to_io_error(e))), + Ok(i) => Poll::Ready(Ok(i as usize)) + }; + + let io = self.io_mut().ok_or(IoError::new(IoErrorKind::Other, "stream has been shutdown"))?; + + io.ecx.clear(); + io.write_tracker.post_write(buf, &result); + + cx.waker().clone().wake(); + + result + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { + // We can only flush the actual IO here. + // To flush mbedtls we need writes with the same buffer until complete. + let io = &mut self.io_mut().ok_or(IoError::new(IoErrorKind::Other, "stream has been shutdown"))? + .inner; + let stream = Pin::new(io); + stream.poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { + if self.handle().session.is_null() { + return Poll::Ready(Err(IoError::new(IoErrorKind::Other, "stream has been shutdown"))); + } + + self.io_mut().ok_or(IoError::new(IoErrorKind::Other, "stream has been shutdown"))? + .ecx.set(cx); + + let result = match unsafe { ssl_close_notify((&mut *self).into()).into_result() } { + Err(Error::SslWantRead) | + Err(Error::SslWantWrite) => Poll::Pending, + Err(e) => { + unsafe { ssl_set_bio((&mut *self).into(), ::core::ptr::null_mut(), None, None, None); } + self.io = None; + Poll::Ready(Err(crate::private::error_to_io_error(e))) + } + Ok(0) => { + unsafe { ssl_set_bio((&mut *self).into(), ::core::ptr::null_mut(), None, None, None); } + self.io = None; + Poll::Ready(Ok(())) + } + Ok(v) => { + unsafe { ssl_set_bio((&mut *self).into(), ::core::ptr::null_mut(), None, None, None); } + self.io = None; + Poll::Ready(Err(IoError::new(IoErrorKind::Other, format!("unexpected result from ssl_close_notify: {}", v)))) + } + }; + + self.io_mut().map(|v| v.ecx.clear()); + result + } +} + #[cfg(test)] mod tests { #[cfg(feature = "std")] @@ -627,7 +864,7 @@ mod tests { use crate::ssl::context::{HandshakeContext, Context}; use crate::tests::TestTrait; - + #[test] fn handshakecontext_sync() { assert!(!TestTrait::::new().impls_trait(), "HandshakeContext must be !Sync"); @@ -643,7 +880,7 @@ mod tests { unimplemented!() } } - + #[cfg(feature = "std")] impl Write for NonSendStream { fn write(&mut self, _: &[u8]) -> IoResult { @@ -665,7 +902,7 @@ mod tests { unimplemented!() } } - + #[cfg(feature = "std")] impl Write for SendStream { fn write(&mut self, _: &[u8]) -> IoResult { diff --git a/mbedtls/src/ssl/mod.rs b/mbedtls/src/ssl/mod.rs index 1bfc078cf..40ebd8007 100644 --- a/mbedtls/src/ssl/mod.rs +++ b/mbedtls/src/ssl/mod.rs @@ -11,6 +11,7 @@ pub mod config; pub mod context; pub mod cookie; pub mod ticket; +pub mod async_utils; #[doc(inline)] pub use self::ciphersuites::CipherSuite; @@ -22,3 +23,6 @@ pub use self::context::Context; pub use self::cookie::CookieContext; #[doc(inline)] pub use self::ticket::TicketContext; +#[cfg(all(feature = "std", feature = "async"))] +#[doc(inline)] +pub use self::context::AsyncContext; diff --git a/mbedtls/src/wrapper_macros.rs b/mbedtls/src/wrapper_macros.rs index 8a3d916ff..379f844ec 100644 --- a/mbedtls/src/wrapper_macros.rs +++ b/mbedtls/src/wrapper_macros.rs @@ -109,6 +109,32 @@ macro_rules! define_enum { } macro_rules! define_struct { + { define_custom $(#[$m:meta])* struct $name:ident $(lifetime $l:tt)* inner $inner:ident members $($(#[$mm:meta])* $member:ident: $member_type:ty,)* } => { + as_item!( + #[allow(dead_code)] + $(#[$m])* + pub struct $name<$($l)*> { + $($(#[$mm])* $member: $member_type,)* + } + ); + + as_item!( + #[allow(dead_code)] + impl<$($l)*> $name<$($l)*> { + pub(crate) fn handle(&self) -> &::mbedtls_sys::$inner { + self.inner.handle() + } + + pub(crate) fn handle_mut(&mut self) -> &mut ::mbedtls_sys::$inner { + self.inner.handle_mut() + } + } + ); + + as_item!( + unsafe impl<$($l)*> Send for $name<$($l)*> {} + ); + }; { define $(#[$m:meta])* struct $name:ident $(lifetime $l:tt)* inner $inner:ident members $($(#[$mm:meta])* $member:ident: $member_type:ty,)* } => { as_item!( #[allow(dead_code)] diff --git a/mbedtls/tests/async_session.rs b/mbedtls/tests/async_session.rs new file mode 100644 index 000000000..98b113a78 --- /dev/null +++ b/mbedtls/tests/async_session.rs @@ -0,0 +1,294 @@ +/* Copyright (c) Fortanix, Inc. + * + * Licensed under the GNU General Public License, version 2 or the Apache License, Version + * 2.0 , at your + * option. This file may not be copied, modified, or distributed except + * according to those terms. */ + +#![cfg(not(target_env = "sgx"))] +extern crate mbedtls; + +use std::sync::Arc; +use std::pin::Pin; +use std::future::Future; + +use mbedtls::pk::Pk; +use mbedtls::rng::CtrDrbg; +use mbedtls::ssl::config::{Endpoint, Preset, Transport}; +use mbedtls::ssl::{Config, Context, Version}; +use mbedtls::x509::{Certificate, VerifyError}; +use mbedtls::Error; +use mbedtls::Result as TlsResult; + +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +mod support; +use support::entropy::entropy_new; +use support::keys; + +use mbedtls::ssl::async_utils::IoAdapter; + +async fn client( + conn: TcpStream, + min_version: Version, + max_version: Version, + exp_version: Option) -> TlsResult<()> { + + let entropy = Arc::new(entropy_new()); + let rng = Arc::new(CtrDrbg::new(entropy, None)?); + let cacert = Arc::new(Certificate::from_pem_multiple(keys::ROOT_CA_CERT.as_bytes())?); + let expected_flags = VerifyError::empty(); + #[cfg(feature = "time")] + let expected_flags = expected_flags | VerifyError::CERT_EXPIRED; + { + let verify_callback = move |crt: &Certificate, depth: i32, verify_flags: &mut VerifyError| { + + match (crt.subject().unwrap().as_str(), depth, &verify_flags) { + ("CN=RootCA", 1, _) => (), + (keys::EXPIRED_CERT_SUBJECT, 0, flags) => assert_eq!(**flags, expected_flags), + _ => assert!(false), + }; + + verify_flags.remove(VerifyError::CERT_EXPIRED); //we check the flags at the end, + //so removing this flag here prevents the connections from failing with VerifyError + Ok(()) + }; + let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default); + config.set_rng(rng); + config.set_verify_callback(verify_callback); + config.set_ca_list(cacert, None); + config.set_min_version(min_version)?; + config.set_max_version(max_version)?; + let mut ctx = Context::new(Arc::new(config)); + + match ctx.establish_async(conn, None).await { + Ok(()) => { + assert_eq!(ctx.version(), exp_version.unwrap()); + } + Err(e) => { + match e { + Error::SslBadHsProtocolVersion => {assert!(exp_version.is_none())}, + Error::SslFatalAlertMessage => {}, + e => panic!("Unexpected error {}", e), + }; + return Ok(()); + } + }; + + let ciphersuite = ctx.ciphersuite().unwrap(); + ctx + .write_all(format!("Client2Server {:4x}", ciphersuite).as_bytes()) + .await + .unwrap(); + let mut buf = [0u8; 13 + 4 + 1]; + ctx.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, format!("Server2Client {:4x}", ciphersuite).as_bytes()); + } // drop verify_callback, releasing borrow of verify_args + Ok(()) +} + +async fn server( + conn: TcpStream, + min_version: Version, + max_version: Version, + exp_version: Option, +) -> TlsResult<()> { + let entropy = entropy_new(); + let rng = Arc::new(CtrDrbg::new(Arc::new(entropy), None)?); + let cert = Arc::new(Certificate::from_pem_multiple(keys::EXPIRED_CERT.as_bytes())?); + let key = Arc::new(Pk::from_private_key(keys::EXPIRED_KEY.as_bytes(), None)?); + let mut config = Config::new(Endpoint::Server, Transport::Stream, Preset::Default); + config.set_rng(rng); + config.set_min_version(min_version)?; + config.set_max_version(max_version)?; + config.push_cert(cert, key)?; + let mut ctx = Context::new(Arc::new(config)); + + match ctx.establish_async(conn, None).await { + Ok(()) => { + assert_eq!(ctx.version(), exp_version.unwrap()); + } + Err(e) => { + match e { + // client just closes connection instead of sending alert + Error::NetSendFailed => {assert!(exp_version.is_none())}, + Error::SslBadHsProtocolVersion => {}, + e => panic!("Unexpected error {}", e), + }; + return Ok(()); + } + }; + + //assert_eq!(ctx.get_alpn_protocol().unwrap().unwrap(), None); + let ciphersuite = ctx.ciphersuite().unwrap(); + ctx + .write_all(format!("Server2Client {:4x}", ciphersuite).as_bytes()) + .await + .unwrap(); + let mut buf = [0u8; 13 + 1 + 4]; + ctx.read_exact(&mut buf).await.unwrap(); + + assert_eq!(&buf, format!("Client2Server {:4x}", ciphersuite).as_bytes()); + Ok(()) +} + +async fn with_client(conn: TcpStream, f: F) -> R +where + F: FnOnce(Context>) -> Pin + Send>>, +{ + let entropy = Arc::new(entropy_new()); + let rng = Arc::new(CtrDrbg::new(entropy, None).unwrap()); + let cacert = Arc::new(Certificate::from_pem_multiple(keys::ROOT_CA_CERT.as_bytes()).unwrap()); + + let verify_callback = move |_crt: &Certificate, _depth: i32, verify_flags: &mut VerifyError| { + verify_flags.remove(VerifyError::CERT_EXPIRED); + Ok(()) + }; + + let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default); + config.set_rng(rng); + config.set_verify_callback(verify_callback); + config.set_ca_list(cacert, None); + + let mut ctx = Context::new(Arc::new(config)); + ctx.establish_async(conn, None).await.unwrap(); + + f(ctx).await +} + +async fn with_server(conn: TcpStream, f: F) -> R +where + F: FnOnce(Context>) -> Pin + Send>>, +{ + let entropy = Arc::new(entropy_new()); + let rng = Arc::new(CtrDrbg::new(entropy, None).unwrap()); + let cert = Arc::new(Certificate::from_pem_multiple(keys::EXPIRED_CERT.as_bytes()).unwrap()); + let key = Arc::new(Pk::from_private_key(keys::EXPIRED_KEY.as_bytes(), None).unwrap()); + + let mut config = Config::new(Endpoint::Server, Transport::Stream, Preset::Default); + config.set_rng(rng); + config.push_cert(cert, key).unwrap(); + let mut ctx = Context::new(Arc::new(config)); + + ctx.establish_async(conn, None).await.unwrap(); + + f(ctx).await +} + +#[cfg(unix)] +mod test { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + #[tokio::test] + async fn asyncsession_client_server_test() { + use mbedtls::ssl::Version; + + #[derive(Copy,Clone)] + struct TestConfig { + min_c: Version, + max_c: Version, + min_s: Version, + max_s: Version, + exp_ver: Option, + } + + impl TestConfig { + pub fn new(min_c: Version, max_c: Version, min_s: Version, max_s: Version, exp_ver: Option) -> Self { + TestConfig { min_c, max_c, min_s, max_s, exp_ver } + } + } + + let test_configs = [ + TestConfig::new(Version::Ssl3, Version::Ssl3, Version::Ssl3, Version::Ssl3, Some(Version::Ssl3)), + TestConfig::new(Version::Ssl3, Version::Tls1_2, Version::Ssl3, Version::Ssl3, Some(Version::Ssl3)), + TestConfig::new(Version::Tls1_0, Version::Tls1_0, Version::Tls1_0, Version::Tls1_0, Some(Version::Tls1_0)), + TestConfig::new(Version::Tls1_1, Version::Tls1_1, Version::Tls1_1, Version::Tls1_1, Some(Version::Tls1_1)), + TestConfig::new(Version::Tls1_2, Version::Tls1_2, Version::Tls1_2, Version::Tls1_2, Some(Version::Tls1_2)), + TestConfig::new(Version::Tls1_0, Version::Tls1_2, Version::Tls1_0, Version::Tls1_2, Some(Version::Tls1_2)), + TestConfig::new(Version::Tls1_2, Version::Tls1_2, Version::Tls1_0, Version::Tls1_2, Some(Version::Tls1_2)), + TestConfig::new(Version::Tls1_0, Version::Tls1_1, Version::Tls1_2, Version::Tls1_2, None) + ]; + + for config in &test_configs { + let min_c = config.min_c; + let max_c = config.max_c; + let min_s = config.min_s; + let max_s = config.max_s; + let exp_ver = config.exp_ver; + + if (max_c < Version::Tls1_2 || max_s < Version::Tls1_2) && !cfg!(feature = "legacy_protocols") { + continue; + } + + let (c, s) = crate::support::net::create_tcp_pair_async().unwrap(); + let c = tokio::spawn(super::client(c, min_c, max_c, exp_ver.clone())); + let s = tokio::spawn(super::server(s, min_s, max_s, exp_ver)); + + c.await.unwrap().unwrap(); + s.await.unwrap().unwrap(); + } + } + + #[tokio::test] + async fn asyncsession_shutdown1() { + let (c, s) = crate::support::net::create_tcp_pair_async().unwrap(); + + let c = tokio::spawn(super::with_client(c, |mut session| Box::pin(async move { + session.shutdown().await.unwrap(); + }))); + + let s = tokio::spawn(super::with_server(s, |mut session| Box::pin(async move { + let mut buf = [0u8; 1]; + match session.read(&mut buf).await { + Ok(0) | Err(_) => {} + _ => panic!("expected no data"), + } + }))); + + c.await.unwrap(); + s.await.unwrap(); + } + + #[tokio::test] + async fn asyncsession_shutdown2() { + let (c, s) = crate::support::net::create_tcp_pair_async().unwrap(); + + let c = tokio::spawn(super::with_client(c, |mut session| Box::pin(async move { + let mut buf = [0u8; 5]; + session.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"hello"); + match session.read(&mut buf).await { + Ok(0) | Err(_) => {} + _ => panic!("expected no data"), + } + }))); + + let s = tokio::spawn(super::with_server(s, |mut session| Box::pin(async move { + session.write_all(b"hello").await.unwrap(); + session.shutdown().await.unwrap(); + }))); + + c.await.unwrap(); + s.await.unwrap(); + } + + #[tokio::test] + async fn asyncsession_shutdown3() { + let (c, s) = crate::support::net::create_tcp_pair_async().unwrap(); + + let c = tokio::spawn(super::with_client(c, |mut session| Box::pin(async move { + session.shutdown().await + }))); + + let s = tokio::spawn(super::with_server(s, |mut session| Box::pin(async move { + session.shutdown().await + }))); + + match (c.await.unwrap(), s.await.unwrap()) { + (Err(_), Err(_)) => panic!("at least one should succeed"), + _ => {} + } + } +} diff --git a/mbedtls/tests/ssl_conf_ca_cb.rs b/mbedtls/tests/ssl_conf_ca_cb.rs index 880a699bc..c490f6bb2 100644 --- a/mbedtls/tests/ssl_conf_ca_cb.rs +++ b/mbedtls/tests/ssl_conf_ca_cb.rs @@ -18,15 +18,13 @@ use mbedtls::pk::Pk; use mbedtls::rng::CtrDrbg; use mbedtls::ssl::config::{Endpoint, Preset, Transport}; use mbedtls::ssl::{Config, Context}; -use mbedtls::x509::{Certificate}; +use mbedtls::x509::Certificate; use mbedtls::Result as TlsResult; use mbedtls::ssl::config::CaCallback; mod support; use support::entropy::entropy_new; -use mbedtls::alloc::{List as MbedtlsList}; - fn client(conn: TcpStream, ca_callback: F) -> TlsResult<()> where F: CaCallback + Send + 'static, @@ -60,9 +58,9 @@ mod test { use std::thread; use crate::support::net::create_tcp_pair; use crate::support::keys; - use mbedtls::x509::{Certificate}; use mbedtls::Error; - + use mbedtls::alloc::List as MbedtlsList; + // This callback should accept any valid self-signed certificate fn self_signed_ca_callback(child: &MbedtlsList) -> TlsResult> { Ok(child.clone()) diff --git a/mbedtls/tests/support/net.rs b/mbedtls/tests/support/net.rs index f061b32af..9446d7f7f 100644 --- a/mbedtls/tests/support/net.rs +++ b/mbedtls/tests/support/net.rs @@ -26,3 +26,14 @@ pub fn create_tcp_pair() -> IoResult<(TcpStream, TcpStream)> { } } } + +#[cfg(feature = "tokio")] +pub fn create_tcp_pair_async() -> IoResult<(tokio::net::TcpStream, tokio::net::TcpStream)> { + let (c, s) = create_tcp_pair()?; + c.set_nonblocking(true)?; + s.set_nonblocking(true)?; + Ok(( + tokio::net::TcpStream::from_std(c)?, + tokio::net::TcpStream::from_std(s)?, + )) +}