diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index ad07878..91b274a 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -36,7 +36,7 @@ jobs: set -o pipefail curl https://wasmtime.dev/install.sh -sSf | bash - name: Test - run: CARGO_TARGET_WASM32_WASI_RUNNER="/home/runner/.wasmtime/bin/wasmtime --invoke _start --allow-unknown-exports" cargo test --target=wasm32-wasi --all-targets + run: CARGO_TARGET_WASM32_WASI_RUNNER="/home/runner/.wasmtime/bin/wasmtime --allow-unknown-exports" cargo test --target=wasm32-wasi --all-targets # Tests that our current minimum supported rust version compiles everything sucessfully min_rust: diff --git a/Cargo.toml b/Cargo.toml index b3c8cdc..ed36ec5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,18 @@ [workspace] members = [ - "scylla-bindgen", - "scylla-cql", + "examples", + "scylla-udf", + "scylla-udf-macros", + "tests", ] + +[workspace.package] +edition = "2021" +version = "0.0.1" +repository = "https://github.com/wmitros/scylla-rust-udf" +license = "MIT OR Apache-2.0" +rust-version = "1.66.1" + +[workspace.dependencies] +scylla-udf = { path = "scylla-udf" } +scylla-udf-macros = { path = "scylla-udf-macros" } diff --git a/README.md b/README.md index d89d6e2..a4fa8e2 100644 --- a/README.md +++ b/README.md @@ -1 +1,118 @@ -# Rust-utils-for-Scylla-UDFs \ No newline at end of file +# Rust helper library for Scylla UDFs + +## Usage + +### Prerequisites + +To use this helper library in Scylla you'll need: +* Standard library for Rust `wasm32-wasi` + * Can be added in rustup installations using `rustup target add wasm32-wasi` + * For non rustup setups, you can try following the steps at https://rustwasm.github.io/docs/wasm-pack/prerequisites/non-rustup-setups.html + * Also available as an rpm: `rust-std-static-wasm32-wasi` +* `wasm2wat` parser + * Available in many distributions in the `wabt` package + +### Compilation + +We recommend a setup with cargo. + +1. Start with a library package +``` +cargo new --lib +``` +2. Add the following lines to the Cargo.toml to set the crate-type to cdylib +``` +[lib] +crate-type = ["cdylib"] +``` +3. Implement your package, exporting Scylla UDFs using the `scylla_udf::export_udf` macro. +4. Build the package using the wasm32-wasi target: +``` +cargo build --target=wasm32-wasi +``` +5. Find the compiled `.wasm` binary. Let's assume it's `target/wasm32-wasi/debug/abc.wasm`. +6. (optional) Optimize the binary using `wasm-opt -O3 target/wasm32-wasi/debug/abc.wasm` (can be combined with using `cargo build --release` profile) +7. Translate the binary into `wat`: +``` +wasm2wat target/wasm32-wasi/debug/abc.wasm > target/wasm32/wasi/debug/abc.wat +``` + +### CQL Statement + +The resulting `target/wasm32/wasi/debug/abc.wat` code can now be used directly in a `CREATE FUNCTION` statement. The resulting code will most likely +contain `'` characters, so it may be necessary to first replace them with `''`, so that they're usable in a CQL string. + +For example, if you have an [Rust UDF](examples/commas.rs) that joins a list of words using commas, you can create a Scylla UDF using the following statement: +``` +CREATE FUNCTION commas(string list) CALLED ON NULL INPUT RETURNS text AS ' (module ...) ' +``` + + +## CQL Type Mapping + +The argument and return value types used in functions annotated with `#[export_udf]` must all map to CQL types used in the `CREATE FUNCTION` statements used in Scylla, according to the tables below. + +If the Scylla function is created with types that do not match the types used in the Rust function, calling the UDF will fail or produce arbitrary results. + +### Native types + +| CQL Type | Rust type | +| --------- | ----------------------------- | +| ASCII | String | +| BIGINT | i64 | +| BLOB | Vec\ | +| BOOLEAN | bool | +| COUNTER | scylla_udf::Counter | +| DATE | chrono::NaiveDate | +| DECIMAL | bigdecimal::Decimal | +| DOUBLE | f64 | +| DURATION | scylla_udf::CqlDuration | +| FLOAT | f32 | +| INET | std::net::IpAddr | +| INT | i32 | +| SMALLINT | i16 | +| TEXT | String | +| TIME | scylla_udf::Time | +| TIMESTAMP | scylla_udf::Timestamp | +| TIMEUUID | uuid::Uuid | +| TINYINT | i8 | +| UUID | uuid::Uuid | +| VARCHAR | String | +| VARINT | num_bigint::BigInt | + +### Collections + +If a CQL type `T` maps to Rust type `RustT`, you can use it as a collection parameter: + +| CQL Type | Rust type | +| ---------- | ------------------------------------------------------------------------------------- | +| LIST\ | Vec\ | +| MAP\ | std::collections::BTreeMap\, std::collections::HashMap\ | +| SET\ | Vec\, std::collections::BTreeSet\, std::collections::HashSet\ | + + +### Tuples + +If CQL types `T1`, `T2`, ... map to Rust types `RustT1`, `RustT2`, ..., you can use them in tuples: + +| CQL Type | Rust type | +| -------- | ---------------------------------- | +| TUPLE\ | (RustT1, RustT2, ...) | + +### Nulls + +If a CQL Value of type T, that's mapped to type RustT, may be a null (possible in non-`RETURNS NULL ON NULL INPUT` UDFs), +the type used in the Rust function should be Option\. + +## Contributing + +In general, try to follow the same rules as in https://github.com/scylladb/scylla-rust-driver/blob/main/CONTRIBUTING.md + +### Testing + +This crate is meant to be compiled to a `wasm32-wasi` target and ran in a WASM runtime. The tests that use WASM-specific code will most likely not succeed when executed in a different way (in particular, with a simple `cargo test` command). + +For example, if you have the [wasmtime](https://docs.wasmtime.dev/cli-install.html) runtime installed and in `PATH`, you can use the following command to run tests: +```text +CARGO_TARGET_WASM32_WASI_RUNNER="wasmtime --allow-unknown-exports" cargo test --target=wasm32-wasi +``` diff --git a/examples/Cargo.toml b/examples/Cargo.toml new file mode 100644 index 0000000..f79643e --- /dev/null +++ b/examples/Cargo.toml @@ -0,0 +1,65 @@ +[package] +name = "examples" +edition.workspace = true +version.workspace = true +repository.workspace = true +license.workspace = true +rust-version.workspace = true +publish = false + +[dependencies] +chrono = "0.4" +bigdecimal = "0.2.0" +num-bigint = "0.3" +scylla-udf = { workspace = true } +uuid = "1.0" + +[[example]] +name = "add" +path = "add.rs" +crate-type = ["cdylib"] + +[[example]] +name = "combine" +path = "combine.rs" +crate-type = ["cdylib"] + +[[example]] +name = "commas" +path = "commas.rs" +crate-type = ["cdylib"] + +[[example]] +name = "dbl" +path = "dbl.rs" +crate-type = ["cdylib"] + +[[example]] +name = "fib" +path = "fib.rs" +crate-type = ["cdylib"] + +[[example]] +name = "keys" +path = "keys.rs" +crate-type = ["cdylib"] + +[[example]] +name = "len" +path = "len.rs" +crate-type = ["cdylib"] + +[[example]] +name = "topn" +path = "topn.rs" +crate-type = ["cdylib"] + +[[example]] +name = "udt" +path = "udt.rs" +crate-type = ["cdylib"] + +[[example]] +name = "wordcount" +path = "wordcount.rs" +crate-type = ["cdylib"] diff --git a/examples/add.rs b/examples/add.rs new file mode 100644 index 0000000..a3fbe4a --- /dev/null +++ b/examples/add.rs @@ -0,0 +1,8 @@ +use scylla_udf::export_udf; + +type SmallInt = i16; + +#[export_udf] +fn add(i1: SmallInt, i2: SmallInt) -> SmallInt { + i1 + i2 +} diff --git a/examples/combine.rs b/examples/combine.rs new file mode 100644 index 0000000..106726d --- /dev/null +++ b/examples/combine.rs @@ -0,0 +1,50 @@ +use scylla_udf::{export_udf, CqlDuration, Time, Timestamp}; + +#[allow(clippy::too_many_arguments, clippy::type_complexity)] +#[export_udf] +fn combine( + b: bool, + blob: Vec, + date: chrono::NaiveDate, + bd: bigdecimal::BigDecimal, + dbl: f64, + cqldur: CqlDuration, + flt: f32, + int32: i32, + int64: i64, + s: String, + tstamp: Timestamp, + ip: std::net::IpAddr, + int16: i16, + int8: i8, + tim: Time, + uid: uuid::Uuid, + bi: num_bigint::BigInt, +) -> ( + ( + bool, + Vec, + chrono::NaiveDate, + bigdecimal::BigDecimal, + f64, + CqlDuration, + f32, + i32, + i64, + ), + ( + String, + Timestamp, + std::net::IpAddr, + i16, + i8, + Time, + uuid::Uuid, + num_bigint::BigInt, + ), +) { + ( + (b, blob, date, bd, dbl, cqldur, flt, int32, int64), + (s, tstamp, ip, int16, int8, tim, uid, bi), + ) +} diff --git a/examples/commas.rs b/examples/commas.rs new file mode 100644 index 0000000..5b88d7d --- /dev/null +++ b/examples/commas.rs @@ -0,0 +1,6 @@ +use scylla_udf::export_udf; + +#[export_udf] +fn commas(strings: Option>) -> Option { + strings.map(|strings| strings.join(", ")) +} diff --git a/examples/dbl.rs b/examples/dbl.rs new file mode 100644 index 0000000..c2f0a77 --- /dev/null +++ b/examples/dbl.rs @@ -0,0 +1,9 @@ +use scylla_udf::export_udf; + +#[export_udf] +fn dbl(s: String) -> String { + let mut newstr = String::new(); + newstr.push_str(&s); + newstr.push_str(&s); + newstr +} diff --git a/examples/fib.rs b/examples/fib.rs new file mode 100644 index 0000000..1413983 --- /dev/null +++ b/examples/fib.rs @@ -0,0 +1,16 @@ +use scylla_udf::*; + +#[export_newtype] +struct FibInputNumber(i32); + +#[export_newtype] +struct FibReturnNumber(i64); + +#[export_udf] +fn fib(i: FibInputNumber) -> FibReturnNumber { + FibReturnNumber(if i.0 <= 2 { + 1 + } else { + fib(FibInputNumber(i.0 - 1)).0 + fib(FibInputNumber(i.0 - 2)).0 + }) +} diff --git a/examples/keys.rs b/examples/keys.rs new file mode 100644 index 0000000..34256cb --- /dev/null +++ b/examples/keys.rs @@ -0,0 +1,6 @@ +use scylla_udf::export_udf; + +#[export_udf] +fn keys(map: std::collections::BTreeMap) -> Vec { + map.into_keys().collect() +} diff --git a/examples/len.rs b/examples/len.rs new file mode 100644 index 0000000..2a7dded --- /dev/null +++ b/examples/len.rs @@ -0,0 +1,6 @@ +use scylla_udf::export_udf; + +#[export_udf] +fn len(strings: std::collections::BTreeSet) -> i16 { + strings.len() as i16 +} diff --git a/examples/topn.rs b/examples/topn.rs new file mode 100644 index 0000000..d1927c0 --- /dev/null +++ b/examples/topn.rs @@ -0,0 +1,66 @@ +use scylla_udf::*; +use std::collections::BTreeSet; + +#[export_newtype] +struct StringLen(String); + +impl std::cmp::PartialEq for StringLen { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl std::cmp::Eq for StringLen {} + +impl std::cmp::PartialOrd for StringLen { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl std::cmp::Ord for StringLen { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + if self.0.len().cmp(&other.0.len()) == std::cmp::Ordering::Equal { + self.0.cmp(&other.0) + } else { + self.0.len().cmp(&other.0.len()) + } + } +} + +// Store the top N strings by length, without repetitions. +#[export_udf] +fn topn_row( + acc_tup: Option<(i32, BTreeSet)>, + v: Option, +) -> Option<(i32, BTreeSet)> { + if let Some((n, mut acc)) = acc_tup { + if let Some(v) = v { + acc.insert(v); + while acc.len() > n as usize { + acc.pop_first(); + } + } + Some((n, acc)) + } else { + None + } +} + +#[export_udf] +fn topn_reduce( + (n1, mut acc1): (i32, BTreeSet), + (n2, mut acc2): (i32, BTreeSet), +) -> (i32, BTreeSet) { + assert!(n1 == n2); + acc1.append(&mut acc2); + while acc1.len() > n1 as usize { + acc1.pop_first(); + } + (n1, acc1) +} + +#[export_udf] +fn topn_final((_, acc): (i32, BTreeSet)) -> BTreeSet { + acc +} diff --git a/examples/udt.rs b/examples/udt.rs new file mode 100644 index 0000000..1fa416d --- /dev/null +++ b/examples/udt.rs @@ -0,0 +1,19 @@ +use scylla_udf::*; + +#[export_udt] +struct Udt { + a: i32, + b: i32, + c: String, + d: String, +} + +#[export_udf] +fn udt(arg: Udt) -> Udt { + Udt { + a: arg.b, + b: arg.a, + c: arg.d, + d: arg.c, + } +} diff --git a/examples/wordcount.rs b/examples/wordcount.rs new file mode 100644 index 0000000..8088a44 --- /dev/null +++ b/examples/wordcount.rs @@ -0,0 +1,6 @@ +use scylla_udf::export_udf; + +#[export_udf] +fn wordcount(text: String) -> i32 { + text.split(' ').count() as i32 +} diff --git a/scylla-bindgen/Cargo.toml b/scylla-bindgen/Cargo.toml deleted file mode 100644 index 4b38ac8..0000000 --- a/scylla-bindgen/Cargo.toml +++ /dev/null @@ -1,7 +0,0 @@ -[package] -name = "scylla-bindgen" -version = "0.1.0" -edition = "2021" - -[lib] -crate-type = ["cdylib"] diff --git a/scylla-bindgen/src/lib.rs b/scylla-bindgen/src/lib.rs deleted file mode 100644 index 8b13789..0000000 --- a/scylla-bindgen/src/lib.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/scylla-cql/Cargo.toml b/scylla-cql/Cargo.toml deleted file mode 100644 index b299631..0000000 --- a/scylla-cql/Cargo.toml +++ /dev/null @@ -1,36 +0,0 @@ -[package] -name = "scylla-cql" -version = "0.0.3" -edition = "2021" -description = "CQL data types and primitives, for interacting with Scylla." -repository = "https://github.com/scylladb/scylla-rust-driver" -readme = "../README.md" -keywords = ["database", "scylla", "cql", "cassandra"] -categories = ["database"] -license = "MIT OR Apache-2.0" - -[dependencies] -scylla-macros = { version = "0.1.2", path = "../scylla-macros"} -byteorder = "1.3.4" -bytes = "1.0.1" -num_enum = "0.5" -tokio = { version = "1.12", features = ["io-util", "time"] } -secrecy = "0.7.0" -snap = "1.0" -uuid = "1.0" -thiserror = "1.0" -bigdecimal = "0.2.0" -num-bigint = "0.3" -chrono = "0.4" -lz4_flex = { version = "0.9.2" } -async-trait = "0.1.57" - -[dev-dependencies] -criterion = "0.3" - -[[bench]] -name = "benchmark" -harness = false - -[features] -secret = [] \ No newline at end of file diff --git a/scylla-cql/benches/benchmark.rs b/scylla-cql/benches/benchmark.rs deleted file mode 100644 index b8e9cbc..0000000 --- a/scylla-cql/benches/benchmark.rs +++ /dev/null @@ -1,53 +0,0 @@ -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; - -use scylla_cql::frame::request::Request; -use scylla_cql::frame::value::SerializedValues; -use scylla_cql::frame::value::ValueList; -use scylla_cql::frame::{request::query, Compression, SerializedRequest}; - -fn make_query<'a>(contents: &'a str, values: &'a SerializedValues) -> query::Query<'a> { - query::Query { - contents, - parameters: query::QueryParameters { - consistency: scylla_cql::Consistency::LocalQuorum, - serial_consistency: None, - values, - page_size: None, - paging_state: None, - timestamp: None, - }, - } -} - -fn serialized_request_make_bench(c: &mut Criterion) { - let mut group = c.benchmark_group("LZ4Compression.SerializedRequest"); - let query_args = [ - ("INSERT foo INTO ks.table_name (?)", &(1234,).serialized().unwrap()), - ("INSERT foo, bar, baz INTO ks.table_name (?, ?, ?)", &(1234, "a value", "i am storing a string").serialized().unwrap()), - ( - "INSERT foo, bar, baz, boop, blah INTO longer_keyspace.a_big_table_name (?, ?, ?, ?, 1000)", - &(1234, "a value", "i am storing a string", "dc0c8cd7-d954-47c1-8722-a857941c43fb").serialized().unwrap() - ), - ]; - let queries = query_args.map(|(q, v)| make_query(q, v)); - - for query in queries { - let query_size = query.to_bytes().unwrap().len(); - group.bench_with_input( - BenchmarkId::new("SerializedRequest::make", query_size), - &query, - |b, query| { - b.iter(|| { - let _ = criterion::black_box(SerializedRequest::make( - query, - Some(Compression::Lz4), - false, - )); - }) - }, - ); - } -} - -criterion_group!(benches, serialized_request_make_bench); -criterion_main!(benches); diff --git a/scylla-cql/src/errors.rs b/scylla-cql/src/errors.rs deleted file mode 100644 index d660a20..0000000 --- a/scylla-cql/src/errors.rs +++ /dev/null @@ -1,590 +0,0 @@ -//! This module contains various errors which can be returned by `scylla::Session` - -use crate::frame::frame_errors::{FrameError, ParseError}; -use crate::frame::protocol_features::ProtocolFeatures; -use crate::frame::types::LegacyConsistency; -use crate::frame::value::SerializeValuesError; -use bytes::Bytes; -use std::io::ErrorKind; -use std::sync::Arc; -use thiserror::Error; - -/// Error that occurred during query execution -#[derive(Error, Debug, Clone)] -pub enum QueryError { - /// Database sent a response containing some error with a message - #[error("Database returned an error: {0}, Error message: {1}")] - DbError(DbError, String), - - /// Caller passed an invalid query - #[error(transparent)] - BadQuery(#[from] BadQuery), - - /// Input/Output error has occurred, connection broken etc. - #[error("IO Error: {0}")] - IoError(Arc), - - /// Unexpected message received - #[error("Protocol Error: {0}")] - ProtocolError(&'static str), - - /// Invalid message received - #[error("Invalid message: {0}")] - InvalidMessage(String), - - /// Timeout error has occurred, function didn't complete in time. - #[error("Timeout Error")] - TimeoutError, - - #[error("Too many orphaned stream ids: {0}")] - TooManyOrphanedStreamIds(u16), - - #[error("Unable to allocate stream id")] - UnableToAllocStreamId, - - /// Client timeout occurred before any response arrived - #[error("Request timeout: {0}")] - RequestTimeout(String), -} - -/// An error sent from the database in response to a query -/// as described in the [specification](https://github.com/apache/cassandra/blob/5ed5e84613ef0e9664a774493db7d2604e3596e0/doc/native_protocol_v4.spec#L1029)\ -#[derive(Error, Debug, Clone, PartialEq, Eq)] -pub enum DbError { - /// The submitted query has a syntax error - #[error("The submitted query has a syntax error")] - SyntaxError, - - /// The query is syntactically correct but invalid - #[error("The query is syntactically correct but invalid")] - Invalid, - - /// Attempted to create a keyspace or a table that was already existing - #[error( - "Attempted to create a keyspace or a table that was already existing \ - (keyspace: {keyspace}, table: {table})" - )] - AlreadyExists { - /// Created keyspace name or name of the keyspace in which table was created - keyspace: String, - /// Name of the table created, in case of keyspace creation it's an empty string - table: String, - }, - - /// User defined function failed during execution - #[error( - "User defined function failed during execution \ - (keyspace: {keyspace}, function: {function}, arg_types: {arg_types:?})" - )] - FunctionFailure { - /// Keyspace of the failed function - keyspace: String, - /// Name of the failed function - function: String, - /// Types of arguments passed to the function - arg_types: Vec, - }, - - /// Authentication failed - bad credentials - #[error("Authentication failed - bad credentials")] - AuthenticationError, - - /// The logged user doesn't have the right to perform the query - #[error("The logged user doesn't have the right to perform the query")] - Unauthorized, - - /// The query is invalid because of some configuration issue - #[error("The query is invalid because of some configuration issue")] - ConfigError, - - /// Not enough nodes are alive to satisfy required consistency level - #[error( - "Not enough nodes are alive to satisfy required consistency level \ - (consistency: {consistency}, required: {required}, alive: {alive})" - )] - Unavailable { - /// Consistency level of the query - consistency: LegacyConsistency, - /// Number of nodes required to be alive to satisfy required consistency level - required: i32, - /// Found number of active nodes - alive: i32, - }, - - /// The request cannot be processed because the coordinator node is overloaded - #[error("The request cannot be processed because the coordinator node is overloaded")] - Overloaded, - - /// The coordinator node is still bootstrapping - #[error("The coordinator node is still bootstrapping")] - IsBootstrapping, - - /// Error during truncate operation - #[error("Error during truncate operation")] - TruncateError, - - /// Not enough nodes responded to the read request in time to satisfy required consistency level - #[error("Not enough nodes responded to the read request in time to satisfy required consistency level \ - (consistency: {consistency}, received: {received}, required: {required}, data_present: {data_present})")] - ReadTimeout { - /// Consistency level of the query - consistency: LegacyConsistency, - /// Number of nodes that responded to the read request - received: i32, - /// Number of nodes required to respond to satisfy required consistency level - required: i32, - /// Replica that was asked for data has responded - data_present: bool, - }, - - /// Not enough nodes responded to the write request in time to satisfy required consistency level - #[error("Not enough nodes responded to the write request in time to satisfy required consistency level \ - (consistency: {consistency}, received: {received}, required: {required}, write_type: {write_type})")] - WriteTimeout { - /// Consistency level of the query - consistency: LegacyConsistency, - /// Number of nodes that responded to the write request - received: i32, - /// Number of nodes required to respond to satisfy required consistency level - required: i32, - /// Type of write operation requested - write_type: WriteType, - }, - - /// A non-timeout error during a read request - #[error( - "A non-timeout error during a read request \ - (consistency: {consistency}, received: {received}, required: {required}, \ - numfailures: {numfailures}, data_present: {data_present})" - )] - ReadFailure { - /// Consistency level of the query - consistency: LegacyConsistency, - /// Number of nodes that responded to the read request - received: i32, - /// Number of nodes required to respond to satisfy required consistency level - required: i32, - /// Number of nodes that experience a failure while executing the request - numfailures: i32, - /// Replica that was asked for data has responded - data_present: bool, - }, - - /// A non-timeout error during a write request - #[error( - "A non-timeout error during a write request \ - (consistency: {consistency}, received: {received}, required: {required}, \ - numfailures: {numfailures}, write_type: {write_type}" - )] - WriteFailure { - /// Consistency level of the query - consistency: LegacyConsistency, - /// Number of nodes that responded to the read request - received: i32, - /// Number of nodes required to respond to satisfy required consistency level - required: i32, - /// Number of nodes that experience a failure while executing the request - numfailures: i32, - /// Type of write operation requested - write_type: WriteType, - }, - - /// Tried to execute a prepared statement that is not prepared. Driver should prepare it again - #[error( - "Tried to execute a prepared statement that is not prepared. Driver should prepare it again" - )] - Unprepared { - /// Statement id of the requested prepared query - statement_id: Bytes, - }, - - /// Internal server error. This indicates a server-side bug - #[error("Internal server error. This indicates a server-side bug")] - ServerError, - - /// Invalid protocol message received from the driver - #[error("Invalid protocol message received from the driver")] - ProtocolError, - - /// Rate limit was exceeded for a partition affected by the request. - /// (Scylla-specific) - /// TODO: Should this have a "Scylla" prefix? - #[error("Rate limit was exceeded for a partition affected by the request")] - RateLimitReached { - /// Type of the operation rejected by rate limiting. - op_type: OperationType, - /// Whether the operation was rate limited on the coordinator or not. - /// Writes rejected on the coordinator are guaranteed not to be applied - /// on any replica. - rejected_by_coordinator: bool, - }, - - /// Other error code not specified in the specification - #[error("Other error not specified in the specification. Error code: {0}")] - Other(i32), -} - -impl DbError { - pub fn code(&self, protocol_features: &ProtocolFeatures) -> i32 { - match self { - DbError::ServerError => 0x0000, - DbError::ProtocolError => 0x000A, - DbError::AuthenticationError => 0x0100, - DbError::Unavailable { - consistency: _, - required: _, - alive: _, - } => 0x1000, - DbError::Overloaded => 0x1001, - DbError::IsBootstrapping => 0x1002, - DbError::TruncateError => 0x1003, - DbError::WriteTimeout { - consistency: _, - received: _, - required: _, - write_type: _, - } => 0x1100, - DbError::ReadTimeout { - consistency: _, - received: _, - required: _, - data_present: _, - } => 0x1200, - DbError::ReadFailure { - consistency: _, - received: _, - required: _, - numfailures: _, - data_present: _, - } => 0x1300, - DbError::FunctionFailure { - keyspace: _, - function: _, - arg_types: _, - } => 0x1400, - DbError::WriteFailure { - consistency: _, - received: _, - required: _, - numfailures: _, - write_type: _, - } => 0x1500, - DbError::SyntaxError => 0x2000, - DbError::Unauthorized => 0x2100, - DbError::Invalid => 0x2200, - DbError::ConfigError => 0x2300, - DbError::AlreadyExists { - keyspace: _, - table: _, - } => 0x2400, - DbError::Unprepared { statement_id: _ } => 0x2500, - DbError::Other(code) => *code, - DbError::RateLimitReached { - op_type: _, - rejected_by_coordinator: _, - } => protocol_features.rate_limit_error.unwrap(), - } - } -} - -/// Type of the operation rejected by rate limiting -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum OperationType { - Read, - Write, - Other(u8), -} - -/// Type of write operation requested -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum WriteType { - /// Non-batched non-counter write - Simple, - /// Logged batch write. If this type is received, it means the batch log has been successfully written - /// (otherwise BatchLog type would be present) - Batch, - /// Unlogged batch. No batch log write has been attempted. - UnloggedBatch, - /// Counter write (batched or not) - Counter, - /// Timeout occurred during the write to the batch log when a logged batch was requested - BatchLog, - /// Timeout occurred during Compare And Set write/update - Cas, - /// Write involves VIEW update and failure to acquire local view(MV) lock for key within timeout - View, - /// Timeout occurred when a cdc_total_space_in_mb is exceeded when doing a write to data tracked by cdc - Cdc, - /// Other type not specified in the specification - Other(String), -} - -/// Error caused by caller creating an invalid query -#[derive(Error, Debug, Clone)] -#[error("Invalid query passed to Session")] -pub enum BadQuery { - /// Failed to serialize values passed to a query - values too big - #[error("Serializing values failed: {0} ")] - SerializeValuesError(#[from] SerializeValuesError), - - /// Serialized values are too long to compute partition key - #[error("Serialized values are too long to compute partition key! Length: {0}, Max allowed length: {1}")] - ValuesTooLongForKey(usize, usize), - - /// Passed invalid keyspace name to use - #[error("Passed invalid keyspace name to use: {0}")] - BadKeyspaceName(#[from] BadKeyspaceName), - - /// Other reasons of bad query - #[error("{0}")] - Other(String), -} - -/// Error that occurred during session creation -#[derive(Error, Debug, Clone)] -pub enum NewSessionError { - /// Failed to resolve hostname passed in Session creation - #[error("Couldn't resolve address: {0}")] - FailedToResolveAddress(String), - - /// List of known nodes passed to Session constructor is empty - /// There needs to be at least one node to connect to - #[error("Empty known nodes list")] - EmptyKnownNodesList, - - /// Database sent a response containing some error with a message - #[error("Database returned an error: {0}, Error message: {1}")] - DbError(DbError, String), - - /// Caller passed an invalid query - #[error(transparent)] - BadQuery(#[from] BadQuery), - - /// Input/Output error has occurred, connection broken etc. - #[error("IO Error: {0}")] - IoError(Arc), - - /// Unexpected message received - #[error("Protocol Error: {0}")] - ProtocolError(&'static str), - - /// Invalid message received - #[error("Invalid message: {0}")] - InvalidMessage(String), - - /// Timeout error has occurred, couldn't connect to node in time. - #[error("Timeout Error")] - TimeoutError, - - #[error("Too many orphaned stream ids: {0}")] - TooManyOrphanedStreamIds(u16), - - #[error("Unable to allocate stream id")] - UnableToAllocStreamId, - - /// Client timeout occurred before a response arrived for some query - /// during `Session` creation. - #[error("Client timeout: {0}")] - RequestTimeout(String), -} - -/// Invalid keyspace name given to `Session::use_keyspace()` -#[derive(Debug, Error, Clone)] -pub enum BadKeyspaceName { - /// Keyspace name is empty - #[error("Keyspace name is empty")] - Empty, - - /// Keyspace name too long, must be up to 48 characters - #[error("Keyspace name too long, must be up to 48 characters, found {1} characters. Bad keyspace name: '{0}'")] - TooLong(String, usize), - - /// Illegal character - only alphanumeric and underscores allowed. - #[error("Illegal character found: '{1}', only alphanumeric and underscores allowed. Bad keyspace name: '{0}'")] - IllegalCharacter(String, char), -} - -impl std::fmt::Display for WriteType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) - } -} - -impl From for QueryError { - fn from(io_error: std::io::Error) -> QueryError { - QueryError::IoError(Arc::new(io_error)) - } -} - -impl From for QueryError { - fn from(serialized_err: SerializeValuesError) -> QueryError { - QueryError::BadQuery(BadQuery::SerializeValuesError(serialized_err)) - } -} - -impl From for QueryError { - fn from(parse_error: ParseError) -> QueryError { - QueryError::InvalidMessage(format!("Error parsing message: {}", parse_error)) - } -} - -impl From for QueryError { - fn from(frame_error: FrameError) -> QueryError { - QueryError::InvalidMessage(format!("Frame error: {}", frame_error)) - } -} - -impl From for QueryError { - fn from(timer_error: tokio::time::error::Elapsed) -> QueryError { - QueryError::RequestTimeout(format!("{}", timer_error)) - } -} - -impl From for NewSessionError { - fn from(io_error: std::io::Error) -> NewSessionError { - NewSessionError::IoError(Arc::new(io_error)) - } -} - -impl From for NewSessionError { - fn from(query_error: QueryError) -> NewSessionError { - match query_error { - QueryError::DbError(e, msg) => NewSessionError::DbError(e, msg), - QueryError::BadQuery(e) => NewSessionError::BadQuery(e), - QueryError::IoError(e) => NewSessionError::IoError(e), - QueryError::ProtocolError(m) => NewSessionError::ProtocolError(m), - QueryError::InvalidMessage(m) => NewSessionError::InvalidMessage(m), - QueryError::TimeoutError => NewSessionError::TimeoutError, - QueryError::TooManyOrphanedStreamIds(ids) => { - NewSessionError::TooManyOrphanedStreamIds(ids) - } - QueryError::UnableToAllocStreamId => NewSessionError::UnableToAllocStreamId, - QueryError::RequestTimeout(msg) => NewSessionError::RequestTimeout(msg), - } - } -} - -impl From for QueryError { - fn from(keyspace_err: BadKeyspaceName) -> QueryError { - QueryError::BadQuery(BadQuery::BadKeyspaceName(keyspace_err)) - } -} - -impl QueryError { - /// Checks if this error indicates that a chosen source port/address cannot be bound. - /// This is caused by one of the following: - /// - The source address is already used by another socket, - /// - The source address is reserved and the process does not have sufficient privileges to use it. - pub fn is_address_unavailable_for_use(&self) -> bool { - if let QueryError::IoError(io_error) = self { - match io_error.kind() { - ErrorKind::AddrInUse | ErrorKind::PermissionDenied => return true, - _ => {} - } - } - - false - } -} - -impl From for OperationType { - fn from(operation_type: u8) -> OperationType { - match operation_type { - 0 => OperationType::Read, - 1 => OperationType::Write, - other => OperationType::Other(other), - } - } -} - -impl From<&str> for WriteType { - fn from(write_type_str: &str) -> WriteType { - match write_type_str { - "SIMPLE" => WriteType::Simple, - "BATCH" => WriteType::Batch, - "UNLOGGED_BATCH" => WriteType::UnloggedBatch, - "COUNTER" => WriteType::Counter, - "BATCH_LOG" => WriteType::BatchLog, - "CAS" => WriteType::Cas, - "VIEW" => WriteType::View, - "CDC" => WriteType::Cdc, - _ => WriteType::Other(write_type_str.to_string()), - } - } -} - -impl WriteType { - pub fn as_str(&self) -> &str { - match self { - WriteType::Simple => "SIMPLE", - WriteType::Batch => "BATCH", - WriteType::UnloggedBatch => "UNLOGGED_BATCH", - WriteType::Counter => "COUNTER", - WriteType::BatchLog => "BATCH_LOG", - WriteType::Cas => "CAS", - WriteType::View => "VIEW", - WriteType::Cdc => "CDC", - WriteType::Other(write_type) => write_type.as_str(), - } - } -} - -#[cfg(test)] -mod tests { - use super::{DbError, QueryError, WriteType}; - use crate::frame::types::{Consistency, LegacyConsistency}; - - #[test] - fn write_type_from_str() { - let test_cases: [(&str, WriteType); 9] = [ - ("SIMPLE", WriteType::Simple), - ("BATCH", WriteType::Batch), - ("UNLOGGED_BATCH", WriteType::UnloggedBatch), - ("COUNTER", WriteType::Counter), - ("BATCH_LOG", WriteType::BatchLog), - ("CAS", WriteType::Cas), - ("VIEW", WriteType::View), - ("CDC", WriteType::Cdc), - ("SOMEOTHER", WriteType::Other("SOMEOTHER".to_string())), - ]; - - for (write_type_str, expected_write_type) in &test_cases { - let write_type = WriteType::from(*write_type_str); - assert_eq!(write_type, *expected_write_type); - } - } - - // A test to check that displaying DbError and QueryError::DbError works as expected - // - displays error description - // - displays error parameters - // - displays error message - // - indented multiline strings don't cause whitespace gaps - #[test] - fn dberror_full_info() { - // Test that DbError::Unavailable is displayed correctly - let db_error = DbError::Unavailable { - consistency: LegacyConsistency::Regular(Consistency::Three), - required: 3, - alive: 2, - }; - - let db_error_displayed: String = format!("{}", db_error); - - let mut expected_dberr_msg = - "Not enough nodes are alive to satisfy required consistency level ".to_string(); - expected_dberr_msg += "(consistency: Three, required: 3, alive: 2)"; - - assert_eq!(db_error_displayed, expected_dberr_msg); - - // Test that QueryError::DbError::(DbError::Unavailable) is displayed correctly - let query_error = - QueryError::DbError(db_error, "a message about unavailable error".to_string()); - let query_error_displayed: String = format!("{}", query_error); - - let mut expected_querr_msg = "Database returned an error: ".to_string(); - expected_querr_msg += &expected_dberr_msg; - expected_querr_msg += ", Error message: a message about unavailable error"; - - assert_eq!(query_error_displayed, expected_querr_msg); - } -} diff --git a/scylla-cql/src/frame/frame_errors.rs b/scylla-cql/src/frame/frame_errors.rs deleted file mode 100644 index 403b6ab..0000000 --- a/scylla-cql/src/frame/frame_errors.rs +++ /dev/null @@ -1,48 +0,0 @@ -use super::response; -use crate::cql_to_rust::CqlTypeError; -use crate::frame::value::SerializeValuesError; -use thiserror::Error; - -#[derive(Error, Debug)] -pub enum FrameError { - #[error(transparent)] - Parse(#[from] ParseError), - #[error("Frame is compressed, but no compression negotiated for connection.")] - NoCompressionNegotiated, - #[error("Received frame marked as coming from a client")] - FrameFromClient, - #[error("Received frame marked as coming from the server")] - FrameFromServer, - #[error("Received a frame from version {0}, but only 4 is supported")] - VersionNotSupported(u8), - #[error("Connection was closed before body was read: missing {0} out of {1}")] - ConnectionClosed(usize, usize), - #[error("Frame decompression failed.")] - FrameDecompression, - #[error("Frame compression failed.")] - FrameCompression, - #[error(transparent)] - StdIoError(#[from] std::io::Error), - #[error("Unrecognized opcode{0}")] - TryFromPrimitiveError(#[from] num_enum::TryFromPrimitiveError), - #[error("Error compressing lz4 data {0}")] - Lz4CompressError(#[from] lz4_flex::block::CompressError), - #[error("Error decompressing lz4 data {0}")] - Lz4DecompressError(#[from] lz4_flex::block::DecompressError), -} - -#[derive(Error, Debug)] -pub enum ParseError { - #[error("Could not serialize frame: {0}")] - BadDataToSerialize(String), - #[error("Could not deserialize frame: {0}")] - BadIncomingData(String), - #[error(transparent)] - IoError(#[from] std::io::Error), - #[error("type not yet implemented, id: {0}")] - TypeNotImplemented(i16), - #[error(transparent)] - SerializeValuesError(#[from] SerializeValuesError), - #[error(transparent)] - CqlTypeError(#[from] CqlTypeError), -} diff --git a/scylla-cql/src/frame/mod.rs b/scylla-cql/src/frame/mod.rs deleted file mode 100644 index cd9c4d3..0000000 --- a/scylla-cql/src/frame/mod.rs +++ /dev/null @@ -1,289 +0,0 @@ -pub mod frame_errors; -pub mod protocol_features; -pub mod request; -pub mod response; -pub mod server_event_type; -pub mod types; -pub mod value; - -#[cfg(test)] -mod value_tests; - -use crate::frame::frame_errors::FrameError; -use bytes::{Buf, BufMut, Bytes}; -use tokio::io::{AsyncRead, AsyncReadExt}; -use uuid::Uuid; - -use std::convert::TryFrom; - -use request::Request; -use response::ResponseOpcode; - -const HEADER_SIZE: usize = 9; - -// Frame flags -pub const FLAG_COMPRESSION: u8 = 0x01; -pub const FLAG_TRACING: u8 = 0x02; -pub const FLAG_CUSTOM_PAYLOAD: u8 = 0x04; -pub const FLAG_WARNING: u8 = 0x08; - -// All of the Authenticators supported by Scylla -#[derive(Debug, PartialEq, Eq, Clone)] -pub enum Authenticator { - AllowAllAuthenticator, - PasswordAuthenticator, - CassandraPasswordAuthenticator, - CassandraAllowAllAuthenticator, - ScyllaTransitionalAuthenticator, -} - -/// The wire protocol compression algorithm. -#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug)] -pub enum Compression { - /// LZ4 compression algorithm. - Lz4, - /// Snappy compression algorithm. - Snappy, -} - -impl ToString for Compression { - fn to_string(&self) -> String { - match self { - Compression::Lz4 => "lz4".to_owned(), - Compression::Snappy => "snappy".to_owned(), - } - } -} - -pub struct SerializedRequest { - data: Vec, -} - -impl SerializedRequest { - pub fn make( - req: &R, - compression: Option, - tracing: bool, - ) -> Result { - let mut flags = 0; - let mut data = vec![0; HEADER_SIZE]; - - if let Some(compression) = compression { - flags |= FLAG_COMPRESSION; - let body = req.to_bytes()?; - compress_append(&body, compression, &mut data)?; - } else { - req.serialize(&mut data)?; - } - - if tracing { - flags |= FLAG_TRACING; - } - - data[0] = 4; // We only support version 4 for now - data[1] = flags; - // Leave space for the stream number - data[4] = R::OPCODE as u8; - - let req_size = (data.len() - HEADER_SIZE) as u32; - data[5..9].copy_from_slice(&req_size.to_be_bytes()); - - Ok(Self { data }) - } - - pub fn set_stream(&mut self, stream: i16) { - self.data[2..4].copy_from_slice(&stream.to_be_bytes()); - } - - pub fn get_data(&self) -> &[u8] { - &self.data[..] - } -} - -// Parts of the frame header which are not determined by the request/response type. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub struct FrameParams { - pub version: u8, - pub flags: u8, - pub stream: i16, -} - -impl Default for FrameParams { - fn default() -> Self { - Self { - version: 0x04, - flags: 0x00, - stream: 0, - } - } -} - -pub async fn read_response_frame( - reader: &mut (impl AsyncRead + Unpin), -) -> Result<(FrameParams, ResponseOpcode, Bytes), FrameError> { - let mut raw_header = [0u8; HEADER_SIZE]; - reader.read_exact(&mut raw_header[..]).await?; - - let mut buf = &raw_header[..]; - - // TODO: Validate version - let version = buf.get_u8(); - if version & 0x80 != 0x80 { - return Err(FrameError::FrameFromClient); - } - if version & 0x7F != 0x04 { - return Err(FrameError::VersionNotSupported(version & 0x7f)); - } - - let flags = buf.get_u8(); - let stream = buf.get_i16(); - - let frame_params = FrameParams { - version, - flags, - stream, - }; - - let opcode = ResponseOpcode::try_from(buf.get_u8())?; - - // TODO: Guard from frames that are too large - let length = buf.get_u32() as usize; - - let mut raw_body = Vec::with_capacity(length).limit(length); - while raw_body.has_remaining_mut() { - let n = reader.read_buf(&mut raw_body).await?; - if n == 0 { - // EOF, too early - return Err(FrameError::ConnectionClosed( - raw_body.remaining_mut(), - length, - )); - } - } - - Ok((frame_params, opcode, raw_body.into_inner().into())) -} - -pub struct ResponseBodyWithExtensions { - pub trace_id: Option, - pub warnings: Vec, - pub body: Bytes, -} - -pub fn parse_response_body_extensions( - flags: u8, - compression: Option, - mut body: Bytes, -) -> Result { - if flags & FLAG_COMPRESSION != 0 { - if let Some(compression) = compression { - body = decompress(&body, compression)?.into(); - } else { - return Err(FrameError::NoCompressionNegotiated); - } - } - - let trace_id = if flags & FLAG_TRACING != 0 { - let buf = &mut &*body; - let trace_id = types::read_uuid(buf)?; - body.advance(16); - Some(trace_id) - } else { - None - }; - - let warnings = if flags & FLAG_WARNING != 0 { - let body_len = body.len(); - let buf = &mut &*body; - let warnings = types::read_string_list(buf)?; - let buf_len = buf.len(); - body.advance(body_len - buf_len); - warnings - } else { - Vec::new() - }; - - if flags & FLAG_CUSTOM_PAYLOAD != 0 { - // TODO: Do something useful with the custom payload map - // For now, just skip it - let body_len = body.len(); - let buf = &mut &*body; - types::read_bytes_map(buf)?; - let buf_len = buf.len(); - body.advance(body_len - buf_len); - } - - Ok(ResponseBodyWithExtensions { - trace_id, - warnings, - body, - }) -} - -pub fn compress_append( - uncomp_body: &[u8], - compression: Compression, - out: &mut Vec, -) -> Result<(), FrameError> { - match compression { - Compression::Lz4 => { - let uncomp_len = uncomp_body.len() as u32; - let tmp = lz4_flex::compress(uncomp_body); - out.reserve_exact(std::mem::size_of::() + tmp.len()); - out.put_u32(uncomp_len); - out.extend_from_slice(&tmp[..]); - Ok(()) - } - Compression::Snappy => { - let old_size = out.len(); - out.resize(old_size + snap::raw::max_compress_len(uncomp_body.len()), 0); - let compressed_size = snap::raw::Encoder::new() - .compress(uncomp_body, &mut out[old_size..]) - .map_err(|_| FrameError::FrameCompression)?; - out.truncate(old_size + compressed_size); - Ok(()) - } - } -} - -pub fn decompress(mut comp_body: &[u8], compression: Compression) -> Result, FrameError> { - match compression { - Compression::Lz4 => { - let uncomp_len = comp_body.get_u32() as usize; - let uncomp_body = lz4_flex::decompress(comp_body, uncomp_len)?; - Ok(uncomp_body) - } - Compression::Snappy => snap::raw::Decoder::new() - .decompress_vec(comp_body) - .map_err(|_| FrameError::FrameDecompression), - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_lz4_compress() { - let mut out = Vec::from(&b"Hello"[..]); - let uncomp_body = b", World!"; - let compression = Compression::Lz4; - let expect = vec![ - 72, 101, 108, 108, 111, 0, 0, 0, 8, 128, 44, 32, 87, 111, 114, 108, 100, 33, - ]; - - compress_append(uncomp_body, compression, &mut out).unwrap(); - assert_eq!(expect, out); - } - - #[test] - fn test_lz4_decompress() { - let mut comp_body = Vec::new(); - let uncomp_body = "Hello, World!".repeat(100); - let compression = Compression::Lz4; - compress_append(uncomp_body.as_bytes(), compression, &mut comp_body).unwrap(); - let result = decompress(&comp_body[..], compression).unwrap(); - assert_eq!(32, comp_body.len()); - assert_eq!(uncomp_body.as_bytes(), result); - } -} diff --git a/scylla-cql/src/frame/protocol_features.rs b/scylla-cql/src/frame/protocol_features.rs deleted file mode 100644 index a3bc01b..0000000 --- a/scylla-cql/src/frame/protocol_features.rs +++ /dev/null @@ -1,63 +0,0 @@ -use std::collections::HashMap; - -const RATE_LIMIT_ERROR_EXTENSION: &str = "SCYLLA_RATE_LIMIT_ERROR"; -pub const SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION: &str = "SCYLLA_LWT_ADD_METADATA_MARK"; -pub const LWT_OPTIMIZATION_META_BIT_MASK_KEY: &str = "LWT_OPTIMIZATION_META_BIT_MASK"; -#[derive(Default, Clone, Copy, Debug, PartialEq, Eq)] -#[non_exhaustive] -pub struct ProtocolFeatures { - pub rate_limit_error: Option, - pub lwt_optimization_meta_bit_mask: Option, -} - -// TODO: Log information about options which failed to parse - -impl ProtocolFeatures { - pub fn parse_from_supported(supported: &HashMap>) -> Self { - Self { - rate_limit_error: Self::maybe_parse_rate_limit_error(supported), - lwt_optimization_meta_bit_mask: Self::maybe_parse_lwt_optimization_meta_bit_mask( - supported, - ), - } - } - - fn maybe_parse_rate_limit_error(supported: &HashMap>) -> Option { - let vals = supported.get(RATE_LIMIT_ERROR_EXTENSION)?; - let code_str = Self::get_cql_extension_field(vals.as_slice(), "ERROR_CODE")?; - code_str.parse::().ok() - } - - fn maybe_parse_lwt_optimization_meta_bit_mask( - supported: &HashMap>, - ) -> Option { - let vals = supported.get(SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION)?; - let mask_str = - Self::get_cql_extension_field(vals.as_slice(), LWT_OPTIMIZATION_META_BIT_MASK_KEY)?; - mask_str.parse::().ok() - } - - // Looks up a field which starts with `key=` and returns the rest - fn get_cql_extension_field<'a>(vals: &'a [String], key: &str) -> Option<&'a str> { - vals.iter() - .find_map(|v| v.as_str().strip_prefix(key)?.strip_prefix('=')) - } - - pub fn add_startup_options(&self, options: &mut HashMap) { - if self.rate_limit_error.is_some() { - options.insert(RATE_LIMIT_ERROR_EXTENSION.to_string(), String::new()); - } - if let Some(mask) = self.lwt_optimization_meta_bit_mask { - options.insert( - SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION.to_string(), - format!("{}={}", LWT_OPTIMIZATION_META_BIT_MASK_KEY, mask), - ); - } - } - - pub fn prepared_flags_contain_lwt_mark(&self, flags: u32) -> bool { - self.lwt_optimization_meta_bit_mask - .map(|mask| (flags & mask) == mask) - .unwrap_or(false) - } -} diff --git a/scylla-cql/src/frame/request/auth_response.rs b/scylla-cql/src/frame/request/auth_response.rs deleted file mode 100644 index 6193277..0000000 --- a/scylla-cql/src/frame/request/auth_response.rs +++ /dev/null @@ -1,18 +0,0 @@ -use crate::frame::frame_errors::ParseError; -use bytes::BufMut; - -use crate::frame::request::{Request, RequestOpcode}; -use crate::frame::types::write_bytes_opt; - -// Implements Authenticate Response -pub struct AuthResponse { - pub response: Option>, -} - -impl Request for AuthResponse { - const OPCODE: RequestOpcode = RequestOpcode::AuthResponse; - - fn serialize(&self, buf: &mut impl BufMut) -> Result<(), ParseError> { - write_bytes_opt(self.response.as_ref(), buf) - } -} diff --git a/scylla-cql/src/frame/request/batch.rs b/scylla-cql/src/frame/request/batch.rs deleted file mode 100644 index d7876c2..0000000 --- a/scylla-cql/src/frame/request/batch.rs +++ /dev/null @@ -1,128 +0,0 @@ -use crate::frame::{frame_errors::ParseError, value::BatchValuesIterator}; -use bytes::{BufMut, Bytes}; -use std::convert::TryInto; - -use crate::frame::{ - request::{Request, RequestOpcode}, - types, - value::BatchValues, -}; - -// Batch flags -const FLAG_WITH_SERIAL_CONSISTENCY: u8 = 0x10; -const FLAG_WITH_DEFAULT_TIMESTAMP: u8 = 0x20; - -pub struct Batch<'a, StatementsIter, Values> -where - StatementsIter: Iterator> + Clone, - Values: BatchValues, -{ - pub statements: StatementsIter, - pub statements_count: usize, - pub batch_type: BatchType, - pub consistency: types::Consistency, - pub serial_consistency: Option, - pub timestamp: Option, - pub values: Values, -} - -/// The type of a batch. -#[derive(Clone, Copy)] -pub enum BatchType { - Logged = 0, - Unlogged = 1, - Counter = 2, -} - -#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord)] -pub enum BatchStatement<'a> { - Query { text: &'a str }, - Prepared { id: &'a Bytes }, -} - -impl<'a, StatementsIter, Values> Request for Batch<'a, StatementsIter, Values> -where - StatementsIter: Iterator> + Clone, - Values: BatchValues, -{ - const OPCODE: RequestOpcode = RequestOpcode::Batch; - - fn serialize(&self, buf: &mut impl BufMut) -> Result<(), ParseError> { - // Serializing type of batch - buf.put_u8(self.batch_type as u8); - - // Serializing queries - types::write_short(self.statements_count.try_into()?, buf); - - let counts_mismatch_err = |n_values: usize, n_statements: usize| { - ParseError::BadDataToSerialize(format!( - "Length of provided values must be equal to number of batch statements \ - (got {n_values} values, {n_statements} statements)" - )) - }; - let mut n_serialized_statements = 0usize; - let mut value_lists = self.values.batch_values_iter(); - for (idx, statement) in self.statements.clone().enumerate() { - statement.serialize(buf)?; - value_lists - .write_next_to_request(buf) - .ok_or_else(|| counts_mismatch_err(idx, self.statements.clone().count()))??; - n_serialized_statements += 1; - } - if value_lists.skip_next().is_some() { - return Err(counts_mismatch_err( - std::iter::from_fn(|| value_lists.skip_next()).count() + 1, - n_serialized_statements, - )); - } - if n_serialized_statements != self.statements_count { - // We want to check this to avoid propagating an invalid construction of self.statements_count as a - // hard-to-debug silent fail - return Err(ParseError::BadDataToSerialize(format!( - "Invalid Batch constructed: not as many statements serialized as announced \ - (batch.statement_count: {announced_statement_count}, {n_serialized_statements}", - announced_statement_count = self.statements_count - ))); - } - - // Serializing consistency - types::write_consistency(self.consistency, buf); - - // Serializing flags - let mut flags = 0; - if self.serial_consistency.is_some() { - flags |= FLAG_WITH_SERIAL_CONSISTENCY; - } - if self.timestamp.is_some() { - flags |= FLAG_WITH_DEFAULT_TIMESTAMP; - } - - buf.put_u8(flags); - - if let Some(serial_consistency) = self.serial_consistency { - types::write_serial_consistency(serial_consistency, buf); - } - if let Some(timestamp) = self.timestamp { - types::write_long(timestamp, buf); - } - - Ok(()) - } -} - -impl BatchStatement<'_> { - fn serialize(&self, buf: &mut impl BufMut) -> Result<(), ParseError> { - match self { - BatchStatement::Query { text } => { - buf.put_u8(0); - types::write_long_string(text, buf)?; - } - BatchStatement::Prepared { id } => { - buf.put_u8(1); - types::write_short_bytes(&id[..], buf)?; - } - } - - Ok(()) - } -} diff --git a/scylla-cql/src/frame/request/execute.rs b/scylla-cql/src/frame/request/execute.rs deleted file mode 100644 index 5f8a35b..0000000 --- a/scylla-cql/src/frame/request/execute.rs +++ /dev/null @@ -1,25 +0,0 @@ -use crate::frame::frame_errors::ParseError; -use bytes::{BufMut, Bytes}; - -use crate::{ - frame::request::{query, Request, RequestOpcode}, - frame::types, -}; - -pub struct Execute<'a> { - pub id: Bytes, - pub parameters: query::QueryParameters<'a>, -} - -impl Request for Execute<'_> { - const OPCODE: RequestOpcode = RequestOpcode::Execute; - - fn serialize(&self, buf: &mut impl BufMut) -> Result<(), ParseError> { - // Serializing statement id - types::write_short_bytes(&self.id[..], buf)?; - - // Serializing params - self.parameters.serialize(buf)?; - Ok(()) - } -} diff --git a/scylla-cql/src/frame/request/mod.rs b/scylla-cql/src/frame/request/mod.rs deleted file mode 100644 index f09dd3c..0000000 --- a/scylla-cql/src/frame/request/mod.rs +++ /dev/null @@ -1,44 +0,0 @@ -pub mod auth_response; -pub mod batch; -pub mod execute; -pub mod options; -pub mod prepare; -pub mod query; -pub mod register; -pub mod startup; - -use crate::frame::frame_errors::ParseError; -use bytes::{BufMut, Bytes}; -use num_enum::TryFromPrimitive; - -pub use auth_response::AuthResponse; -pub use batch::Batch; -pub use options::Options; -pub use prepare::Prepare; -pub use query::Query; -pub use startup::Startup; - -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, TryFromPrimitive)] -#[repr(u8)] -pub enum RequestOpcode { - Startup = 0x01, - Options = 0x05, - Query = 0x07, - Prepare = 0x09, - Execute = 0x0A, - Register = 0x0B, - Batch = 0x0D, - AuthResponse = 0x0F, -} - -pub trait Request { - const OPCODE: RequestOpcode; - - fn serialize(&self, buf: &mut impl BufMut) -> Result<(), ParseError>; - - fn to_bytes(&self) -> Result { - let mut v = Vec::new(); - self.serialize(&mut v)?; - Ok(v.into()) - } -} diff --git a/scylla-cql/src/frame/request/options.rs b/scylla-cql/src/frame/request/options.rs deleted file mode 100644 index 5a0561e..0000000 --- a/scylla-cql/src/frame/request/options.rs +++ /dev/null @@ -1,14 +0,0 @@ -use crate::frame::frame_errors::ParseError; -use bytes::BufMut; - -use crate::frame::request::{Request, RequestOpcode}; - -pub struct Options; - -impl Request for Options { - const OPCODE: RequestOpcode = RequestOpcode::Options; - - fn serialize(&self, _buf: &mut impl BufMut) -> Result<(), ParseError> { - Ok(()) - } -} diff --git a/scylla-cql/src/frame/request/prepare.rs b/scylla-cql/src/frame/request/prepare.rs deleted file mode 100644 index 7141f43..0000000 --- a/scylla-cql/src/frame/request/prepare.rs +++ /dev/null @@ -1,20 +0,0 @@ -use crate::frame::frame_errors::ParseError; -use bytes::BufMut; - -use crate::{ - frame::request::{Request, RequestOpcode}, - frame::types, -}; - -pub struct Prepare<'a> { - pub query: &'a str, -} - -impl<'a> Request for Prepare<'a> { - const OPCODE: RequestOpcode = RequestOpcode::Prepare; - - fn serialize(&self, buf: &mut impl BufMut) -> Result<(), ParseError> { - types::write_long_string(self.query, buf)?; - Ok(()) - } -} diff --git a/scylla-cql/src/frame/request/query.rs b/scylla-cql/src/frame/request/query.rs deleted file mode 100644 index 7028672..0000000 --- a/scylla-cql/src/frame/request/query.rs +++ /dev/null @@ -1,110 +0,0 @@ -use crate::frame::frame_errors::ParseError; -use bytes::{BufMut, Bytes}; - -use crate::{ - frame::request::{Request, RequestOpcode}, - frame::types, - frame::value::SerializedValues, -}; - -// Query flags -// Unused flags are commented out so that they don't trigger warnings -const FLAG_VALUES: u8 = 0x01; -// const FLAG_SKIP_METADATA: u8 = 0x02; -const FLAG_PAGE_SIZE: u8 = 0x04; -const FLAG_WITH_PAGING_STATE: u8 = 0x08; -const FLAG_WITH_SERIAL_CONSISTENCY: u8 = 0x10; -const FLAG_WITH_DEFAULT_TIMESTAMP: u8 = 0x20; -const FLAG_WITH_NAMES_FOR_VALUES: u8 = 0x40; - -pub struct Query<'a> { - pub contents: &'a str, - pub parameters: QueryParameters<'a>, -} - -impl Request for Query<'_> { - const OPCODE: RequestOpcode = RequestOpcode::Query; - - fn serialize(&self, buf: &mut impl BufMut) -> Result<(), ParseError> { - types::write_long_string(self.contents, buf)?; - self.parameters.serialize(buf)?; - Ok(()) - } -} - -pub struct QueryParameters<'a> { - pub consistency: types::Consistency, - pub serial_consistency: Option, - pub timestamp: Option, - pub page_size: Option, - pub paging_state: Option, - pub values: &'a SerializedValues, -} - -impl Default for QueryParameters<'_> { - fn default() -> Self { - Self { - consistency: Default::default(), - serial_consistency: None, - timestamp: None, - page_size: None, - paging_state: None, - values: SerializedValues::EMPTY, - } - } -} - -impl QueryParameters<'_> { - pub fn serialize(&self, buf: &mut impl BufMut) -> Result<(), ParseError> { - types::write_consistency(self.consistency, buf); - - let mut flags = 0; - if !self.values.is_empty() { - flags |= FLAG_VALUES; - } - - if self.page_size.is_some() { - flags |= FLAG_PAGE_SIZE; - } - - if self.paging_state.is_some() { - flags |= FLAG_WITH_PAGING_STATE; - } - - if self.serial_consistency.is_some() { - flags |= FLAG_WITH_SERIAL_CONSISTENCY; - } - - if self.timestamp.is_some() { - flags |= FLAG_WITH_DEFAULT_TIMESTAMP; - } - - if self.values.has_names() { - flags |= FLAG_WITH_NAMES_FOR_VALUES; - } - - buf.put_u8(flags); - - if !self.values.is_empty() { - self.values.write_to_request(buf); - } - - if let Some(page_size) = self.page_size { - types::write_int(page_size, buf); - } - - if let Some(paging_state) = &self.paging_state { - types::write_bytes(paging_state, buf)?; - } - - if let Some(serial_consistency) = self.serial_consistency { - types::write_serial_consistency(serial_consistency, buf); - } - - if let Some(timestamp) = self.timestamp { - types::write_long(timestamp, buf); - } - - Ok(()) - } -} diff --git a/scylla-cql/src/frame/request/register.rs b/scylla-cql/src/frame/request/register.rs deleted file mode 100644 index 2c008b0..0000000 --- a/scylla-cql/src/frame/request/register.rs +++ /dev/null @@ -1,27 +0,0 @@ -use bytes::BufMut; - -use crate::frame::{ - frame_errors::ParseError, - request::{Request, RequestOpcode}, - server_event_type::EventType, - types, -}; - -pub struct Register { - pub event_types_to_register_for: Vec, -} - -impl Request for Register { - const OPCODE: RequestOpcode = RequestOpcode::Register; - - fn serialize(&self, buf: &mut impl BufMut) -> Result<(), ParseError> { - let event_types_list = self - .event_types_to_register_for - .iter() - .map(|event| event.to_string()) - .collect::>(); - - types::write_string_list(&event_types_list, buf)?; - Ok(()) - } -} diff --git a/scylla-cql/src/frame/request/startup.rs b/scylla-cql/src/frame/request/startup.rs deleted file mode 100644 index 5e75d85..0000000 --- a/scylla-cql/src/frame/request/startup.rs +++ /dev/null @@ -1,22 +0,0 @@ -use crate::frame::frame_errors::ParseError; -use bytes::BufMut; - -use std::collections::HashMap; - -use crate::{ - frame::request::{Request, RequestOpcode}, - frame::types, -}; - -pub struct Startup { - pub options: HashMap, -} - -impl Request for Startup { - const OPCODE: RequestOpcode = RequestOpcode::Startup; - - fn serialize(&self, buf: &mut impl BufMut) -> Result<(), ParseError> { - types::write_string_map(&self.options, buf)?; - Ok(()) - } -} diff --git a/scylla-cql/src/frame/response/authenticate.rs b/scylla-cql/src/frame/response/authenticate.rs deleted file mode 100644 index 5489a06..0000000 --- a/scylla-cql/src/frame/response/authenticate.rs +++ /dev/null @@ -1,44 +0,0 @@ -use crate::frame::frame_errors::ParseError; -use crate::frame::types; - -// Implements Authenticate message. -#[derive(Debug)] -pub struct Authenticate { - pub authenticator_name: String, -} - -impl Authenticate { - pub fn deserialize(buf: &mut &[u8]) -> Result { - let authenticator_name = types::read_string(buf)?.to_string(); - - Ok(Authenticate { authenticator_name }) - } -} - -#[derive(Debug)] -pub struct AuthSuccess { - pub success_message: Option>, -} - -impl AuthSuccess { - pub fn deserialize(buf: &mut &[u8]) -> Result { - let success_message = types::read_bytes_opt(buf)?.map(|b| b.to_owned()); - - Ok(AuthSuccess { success_message }) - } -} - -#[derive(Debug)] -pub struct AuthChallenge { - pub authenticate_message: Option>, -} - -impl AuthChallenge { - pub fn deserialize(buf: &mut &[u8]) -> Result { - let authenticate_message = types::read_bytes_opt(buf)?.map(|b| b.to_owned()); - - Ok(AuthChallenge { - authenticate_message, - }) - } -} diff --git a/scylla-cql/src/frame/response/cql_to_rust.rs b/scylla-cql/src/frame/response/cql_to_rust.rs deleted file mode 100644 index f09064b..0000000 --- a/scylla-cql/src/frame/response/cql_to_rust.rs +++ /dev/null @@ -1,737 +0,0 @@ -use super::result::{CqlValue, Row}; -use crate::frame::value::{Counter, CqlDuration}; -use bigdecimal::BigDecimal; -use chrono::{DateTime, Duration, NaiveDate, TimeZone, Utc}; -use num_bigint::BigInt; -use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; -use std::hash::Hash; -use std::net::IpAddr; -use thiserror::Error; -use uuid::Uuid; - -#[cfg(feature = "secret")] -use secrecy::{Secret, Zeroize}; - -#[derive(Error, Debug, Clone, PartialEq, Eq)] -pub enum FromRowError { - #[error("{err} in the column with index {column}")] - BadCqlVal { err: FromCqlValError, column: usize }, - #[error("Wrong row size: expected {expected}, actual {actual}")] - WrongRowSize { expected: usize, actual: usize }, -} - -#[derive(Error, Debug, PartialEq, Eq)] -pub enum CqlTypeError { - #[error("Invalid number of set elements: {0}")] - InvalidNumberOfElements(i32), -} - -/// This trait defines a way to convert CqlValue or `Option` into some rust type -// We can't use From trait because impl From> for String {...} -// is forbidden since neither From nor String are defined in this crate -pub trait FromCqlVal: Sized { - fn from_cql(cql_val: T) -> Result; -} - -#[derive(Error, Debug, Clone, PartialEq, Eq)] -pub enum FromCqlValError { - #[error("Bad CQL type")] - BadCqlType, - #[error("Value is null")] - ValIsNull, - #[error("Bad Value")] - BadVal, -} - -/// This trait defines a way to convert CQL Row into some rust type -pub trait FromRow: Sized { - fn from_row(row: Row) -> Result; -} - -// CqlValue can be converted to CqlValue -impl FromCqlVal for CqlValue { - fn from_cql(cql_val: CqlValue) -> Result { - Ok(cql_val) - } -} - -// Implement from_cql> for every type that has from_cql -// This tries to unwrap the option or fails with an error -impl> FromCqlVal> for T { - fn from_cql(cql_val_opt: Option) -> Result { - T::from_cql(cql_val_opt.ok_or(FromCqlValError::ValIsNull)?) - } -} - -// Implement from_cql> for Option for every type that has from_cql -// Value inside Option gets mapped from CqlValue to T -impl> FromCqlVal> for Option { - fn from_cql(cql_val_opt: Option) -> Result { - match cql_val_opt { - Some(CqlValue::Empty) => Ok(None), - Some(cql_val) => Ok(Some(T::from_cql(cql_val)?)), - None => Ok(None), - } - } -} -/// This macro implements FromCqlVal given a type and method of CqlValue that returns this type. -/// -/// It can be useful in client code in case you have an extension trait for CqlValue -/// and you would like to convert one of its methods into a FromCqlVal impl. -/// The conversion method must return an `Option`. `None` values will be -/// converted to `CqlValue::BadCqlType`. -/// -/// # Example -/// ``` -/// # use scylla_cql::frame::response::result::CqlValue; -/// # use scylla_cql::impl_from_cql_value_from_method; -/// struct MyBytes(Vec); -/// -/// trait CqlValueExt { -/// fn into_my_bytes(self) -> Option; -/// } -/// -/// impl CqlValueExt for CqlValue { -/// fn into_my_bytes(self) -> Option { -/// Some(MyBytes(self.into_blob()?)) -/// } -/// } -/// -/// impl_from_cql_value_from_method!(MyBytes, into_my_bytes); -/// ``` -#[macro_export] -macro_rules! impl_from_cql_value_from_method { - ($T:ty, $convert_func:ident) => { - impl - $crate::frame::response::cql_to_rust::FromCqlVal< - $crate::frame::response::result::CqlValue, - > for $T - { - fn from_cql( - cql_val: $crate::frame::response::result::CqlValue, - ) -> std::result::Result<$T, $crate::frame::response::cql_to_rust::FromCqlValError> - { - cql_val - .$convert_func() - .ok_or($crate::frame::response::cql_to_rust::FromCqlValError::BadCqlType) - } - } - }; -} - -impl_from_cql_value_from_method!(i32, as_int); // i32::from_cql -impl_from_cql_value_from_method!(i64, as_bigint); // i64::from_cql -impl_from_cql_value_from_method!(Counter, as_counter); // Counter::from_cql -impl_from_cql_value_from_method!(i16, as_smallint); // i16::from_cql -impl_from_cql_value_from_method!(BigInt, into_varint); // BigInt::from_cql -impl_from_cql_value_from_method!(i8, as_tinyint); // i8::from_cql -impl_from_cql_value_from_method!(NaiveDate, as_date); // NaiveDate::from_cql -impl_from_cql_value_from_method!(f32, as_float); // f32::from_cql -impl_from_cql_value_from_method!(f64, as_double); // f64::from_cql -impl_from_cql_value_from_method!(bool, as_boolean); // bool::from_cql -impl_from_cql_value_from_method!(String, into_string); // String::from_cql -impl_from_cql_value_from_method!(Vec, into_blob); // Vec::from_cql -impl_from_cql_value_from_method!(IpAddr, as_inet); // IpAddr::from_cql -impl_from_cql_value_from_method!(Uuid, as_uuid); // Uuid::from_cql -impl_from_cql_value_from_method!(BigDecimal, into_decimal); // BigDecimal::from_cql -impl_from_cql_value_from_method!(Duration, as_duration); // Duration::from_cql -impl_from_cql_value_from_method!(CqlDuration, as_cql_duration); // CqlDuration::from_cql - -impl FromCqlVal for crate::frame::value::Time { - fn from_cql(cql_val: CqlValue) -> Result { - match cql_val { - CqlValue::Time(d) => Ok(Self(d)), - _ => Err(FromCqlValError::BadCqlType), - } - } -} - -impl FromCqlVal for crate::frame::value::Timestamp { - fn from_cql(cql_val: CqlValue) -> Result { - match cql_val { - CqlValue::Timestamp(d) => Ok(Self(d)), - _ => Err(FromCqlValError::BadCqlType), - } - } -} - -impl FromCqlVal for DateTime { - fn from_cql(cql_val: CqlValue) -> Result { - let timestamp = cql_val.as_bigint().ok_or(FromCqlValError::BadCqlType)?; - match Utc.timestamp_millis_opt(timestamp) { - chrono::LocalResult::Single(datetime) => Ok(datetime), - _ => Err(FromCqlValError::BadVal), - } - } -} - -#[cfg(feature = "secret")] -impl + Zeroize> FromCqlVal for Secret { - fn from_cql(cql_val: CqlValue) -> Result { - Ok(Secret::new(FromCqlVal::from_cql(cql_val)?)) - } -} - -// Vec::from_cql -impl> FromCqlVal for Vec { - fn from_cql(cql_val: CqlValue) -> Result { - cql_val - .into_vec() - .ok_or(FromCqlValError::BadCqlType)? - .into_iter() - .map(T::from_cql) - .collect::, FromCqlValError>>() - } -} - -impl + Eq + Hash, T2: FromCqlVal> FromCqlVal - for HashMap -{ - fn from_cql(cql_val: CqlValue) -> Result { - let vec = cql_val.into_pair_vec().ok_or(FromCqlValError::BadCqlType)?; - let mut res = HashMap::with_capacity(vec.len()); - for (key, value) in vec { - res.insert(T1::from_cql(key)?, T2::from_cql(value)?); - } - Ok(res) - } -} - -impl + Eq + Hash> FromCqlVal for HashSet { - fn from_cql(cql_val: CqlValue) -> Result { - cql_val - .into_vec() - .ok_or(FromCqlValError::BadCqlType)? - .into_iter() - .map(T::from_cql) - .collect::, FromCqlValError>>() - } -} - -impl + Ord> FromCqlVal for BTreeSet { - fn from_cql(cql_val: CqlValue) -> Result { - cql_val - .into_vec() - .ok_or(FromCqlValError::BadCqlType)? - .into_iter() - .map(T::from_cql) - .collect::, FromCqlValError>>() - } -} - -impl + Ord, V: FromCqlVal> FromCqlVal - for BTreeMap -{ - fn from_cql(cql_val: CqlValue) -> Result { - let vec = cql_val.into_pair_vec().ok_or(FromCqlValError::BadCqlType)?; - let mut res = BTreeMap::new(); - for (key, value) in vec { - res.insert(K::from_cql(key)?, V::from_cql(value)?); - } - Ok(res) - } -} - -macro_rules! replace_expr { - ($_t:tt $sub:expr) => { - $sub - }; -} - -// This macro implements FromRow for tuple of types that have FromCqlVal -macro_rules! impl_tuple_from_row { - ( $($Ti:tt),+ ) => { - impl<$($Ti),+> FromRow for ($($Ti,)+) - where - $($Ti: FromCqlVal>),+ - { - fn from_row(row: Row) -> Result { - // From what I know, it is not possible yet to get the number of metavariable - // repetitions (https://github.com/rust-lang/lang-team/issues/28#issue-644523674) - // This is a workaround - let expected_len = <[()]>::len(&[$(replace_expr!(($Ti) ())),*]); - - if expected_len != row.columns.len() { - return Err(FromRowError::WrongRowSize { - expected: expected_len, - actual: row.columns.len(), - }); - } - let mut vals_iter = row.columns.into_iter().enumerate(); - - Ok(( - $( - { - let (col_ix, col_value) = vals_iter - .next() - .unwrap(); // vals_iter size is checked before this code is reached, - // so it is safe to unwrap - - - $Ti::from_cql(col_value) - .map_err(|e| FromRowError::BadCqlVal { - err: e, - column: col_ix, - })? - } - ,)+ - )) - } - } - } -} - -// Implement FromRow for tuples of size up to 16 -impl_tuple_from_row!(T1); -impl_tuple_from_row!(T1, T2); -impl_tuple_from_row!(T1, T2, T3); -impl_tuple_from_row!(T1, T2, T3, T4); -impl_tuple_from_row!(T1, T2, T3, T4, T5); -impl_tuple_from_row!(T1, T2, T3, T4, T5, T6); -impl_tuple_from_row!(T1, T2, T3, T4, T5, T6, T7); -impl_tuple_from_row!(T1, T2, T3, T4, T5, T6, T7, T8); -impl_tuple_from_row!(T1, T2, T3, T4, T5, T6, T7, T8, T9); -impl_tuple_from_row!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10); -impl_tuple_from_row!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11); -impl_tuple_from_row!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12); -impl_tuple_from_row!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13); -impl_tuple_from_row!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14); -impl_tuple_from_row!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15); -impl_tuple_from_row!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); - -macro_rules! impl_tuple_from_cql { - ( $($Ti:tt),+ ) => { - impl<$($Ti),+> FromCqlVal for ($($Ti,)+) - where - $($Ti: FromCqlVal>),+ - { - fn from_cql(cql_val: CqlValue) -> Result { - let tuple_fields = match cql_val { - CqlValue::Tuple(fields) => fields, - _ => return Err(FromCqlValError::BadCqlType) - }; - - let mut tuple_fields_iter = tuple_fields.into_iter(); - - Ok(( - $( - $Ti::from_cql(tuple_fields_iter.next().ok_or(FromCqlValError::BadCqlType) ?) ? - ,)+ - )) - } - } - } -} - -impl_tuple_from_cql!(T1); -impl_tuple_from_cql!(T1, T2); -impl_tuple_from_cql!(T1, T2, T3); -impl_tuple_from_cql!(T1, T2, T3, T4); -impl_tuple_from_cql!(T1, T2, T3, T4, T5); -impl_tuple_from_cql!(T1, T2, T3, T4, T5, T6); -impl_tuple_from_cql!(T1, T2, T3, T4, T5, T6, T7); -impl_tuple_from_cql!(T1, T2, T3, T4, T5, T6, T7, T8); -impl_tuple_from_cql!(T1, T2, T3, T4, T5, T6, T7, T8, T9); -impl_tuple_from_cql!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10); -impl_tuple_from_cql!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11); -impl_tuple_from_cql!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12); -impl_tuple_from_cql!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13); -impl_tuple_from_cql!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14); -impl_tuple_from_cql!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15); -impl_tuple_from_cql!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); - -#[cfg(test)] -mod tests { - use super::{CqlValue, FromCqlVal, FromCqlValError, FromRow, FromRowError, Row}; - use crate as scylla; - use crate::frame::value::Counter; - use crate::macros::FromRow; - use bigdecimal::BigDecimal; - use chrono::{Duration, NaiveDate}; - use num_bigint::{BigInt, ToBigInt}; - use std::collections::HashSet; - use std::net::{IpAddr, Ipv4Addr}; - use std::str::FromStr; - use uuid::Uuid; - - #[test] - fn i32_from_cql() { - assert_eq!(Ok(1234), i32::from_cql(CqlValue::Int(1234))); - } - - #[test] - fn bool_from_cql() { - assert_eq!(Ok(true), bool::from_cql(CqlValue::Boolean(true))); - assert_eq!(Ok(false), bool::from_cql(CqlValue::Boolean(false))); - } - - #[test] - fn floatingpoints_from_cql() { - let float: f32 = 2.13; - let double: f64 = 4.26; - assert_eq!(Ok(float), f32::from_cql(CqlValue::Float(float))); - assert_eq!(Ok(double), f64::from_cql(CqlValue::Double(double))); - } - - #[test] - fn i64_from_cql() { - assert_eq!(Ok(1234), i64::from_cql(CqlValue::BigInt(1234))); - } - - #[test] - fn i8_from_cql() { - assert_eq!(Ok(6), i8::from_cql(CqlValue::TinyInt(6))); - } - - #[test] - fn i16_from_cql() { - assert_eq!(Ok(16), i16::from_cql(CqlValue::SmallInt(16))); - } - - #[test] - fn string_from_cql() { - assert_eq!( - Ok("ascii_test".to_string()), - String::from_cql(CqlValue::Ascii("ascii_test".to_string())) - ); - assert_eq!( - Ok("text_test".to_string()), - String::from_cql(CqlValue::Text("text_test".to_string())) - ); - } - - #[test] - fn ip_addr_from_cql() { - let ip_addr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); - assert_eq!(Ok(ip_addr), IpAddr::from_cql(CqlValue::Inet(ip_addr))); - } - - #[test] - fn varint_from_cql() { - let big_int = 0.to_bigint().unwrap(); - assert_eq!( - Ok(big_int), - BigInt::from_cql(CqlValue::Varint(0.to_bigint().unwrap())) - ); - } - - #[test] - fn decimal_from_cql() { - let decimal = BigDecimal::from_str("123.4").unwrap(); - assert_eq!( - Ok(decimal.clone()), - BigDecimal::from_cql(CqlValue::Decimal(decimal)) - ); - } - - #[test] - fn counter_from_cql() { - let counter = Counter(1); - assert_eq!(Ok(counter), Counter::from_cql(CqlValue::Counter(counter))); - } - - #[test] - fn naive_date_from_cql() { - let unix_epoch: CqlValue = CqlValue::Date(2_u32.pow(31)); - assert_eq!( - Ok(NaiveDate::from_ymd_opt(1970, 1, 1).unwrap()), - NaiveDate::from_cql(unix_epoch) - ); - - let before_epoch: CqlValue = CqlValue::Date(2_u32.pow(31) - 30); - assert_eq!( - Ok(NaiveDate::from_ymd_opt(1969, 12, 2).unwrap()), - NaiveDate::from_cql(before_epoch) - ); - - let after_epoch: CqlValue = CqlValue::Date(2_u32.pow(31) + 30); - assert_eq!( - Ok(NaiveDate::from_ymd_opt(1970, 1, 31).unwrap()), - NaiveDate::from_cql(after_epoch) - ); - - let min_date: CqlValue = CqlValue::Date(0); - assert!(NaiveDate::from_cql(min_date).is_err()); - - let max_date: CqlValue = CqlValue::Date(u32::MAX); - assert!(NaiveDate::from_cql(max_date).is_err()); - } - - #[test] - fn duration_from_cql() { - let time_duration = Duration::nanoseconds(86399999999999); - assert_eq!( - time_duration, - Duration::from_cql(CqlValue::Time(time_duration)).unwrap(), - ); - - let timestamp_duration = Duration::milliseconds(i64::MIN); - assert_eq!( - timestamp_duration, - Duration::from_cql(CqlValue::Timestamp(timestamp_duration)).unwrap(), - ); - - let timestamp_i64 = 997; - assert_eq!( - timestamp_i64, - i64::from_cql(CqlValue::Timestamp(Duration::milliseconds(timestamp_i64))).unwrap() - ) - } - - #[test] - fn cql_duration_from_cql() { - use crate::frame::value::CqlDuration; - let cql_duration = CqlDuration { - months: 3, - days: 2, - nanoseconds: 1, - }; - assert_eq!( - cql_duration, - CqlDuration::from_cql(CqlValue::Duration(cql_duration)).unwrap(), - ); - } - - #[test] - fn time_from_cql() { - use crate::frame::value::Time; - let time_duration = Duration::nanoseconds(86399999999999); - assert_eq!( - time_duration, - Time::from_cql(CqlValue::Time(time_duration)).unwrap().0, - ); - } - - #[test] - fn timestamp_from_cql() { - use crate::frame::value::Timestamp; - let timestamp_duration = Duration::milliseconds(86399999999999); - assert_eq!( - timestamp_duration, - Timestamp::from_cql(CqlValue::Timestamp(timestamp_duration)) - .unwrap() - .0, - ); - } - - #[test] - fn datetime_from_cql() { - use chrono::{DateTime, Duration, Utc}; - let naivedatetime_utc = NaiveDate::from_ymd_opt(2022, 12, 31) - .unwrap() - .and_hms_opt(2, 0, 0) - .unwrap(); - let datetime_utc = DateTime::::from_utc(naivedatetime_utc, Utc); - - assert_eq!( - datetime_utc, - DateTime::::from_cql(CqlValue::Timestamp(Duration::milliseconds( - datetime_utc.timestamp_millis() - ))) - .unwrap() - ); - } - - #[test] - fn uuid_from_cql() { - let test_uuid: Uuid = Uuid::parse_str("8e14e760-7fa8-11eb-bc66-000000000001").unwrap(); - - assert_eq!( - test_uuid, - Uuid::from_cql(CqlValue::Uuid(test_uuid)).unwrap() - ); - - assert_eq!( - test_uuid, - Uuid::from_cql(CqlValue::Timeuuid(test_uuid)).unwrap() - ); - } - - #[test] - fn vec_from_cql() { - let cql_val = CqlValue::Set(vec![CqlValue::Int(1), CqlValue::Int(2), CqlValue::Int(3)]); - assert_eq!(Ok(vec![1, 2, 3]), Vec::::from_cql(cql_val)); - } - - #[test] - fn set_from_cql() { - let cql_val = CqlValue::Set(vec![ - CqlValue::Int(1), - CqlValue::Int(2), - CqlValue::Int(3), - CqlValue::Int(1), - CqlValue::Int(2), - CqlValue::Int(3), - ]); - assert_eq!( - Ok(vec![1, 2, 3]), - HashSet::::from_cql(cql_val).map(|value| { - let mut values = value.into_iter().collect::>(); - values.sort_unstable(); - values - }) - ); - } - - #[test] - fn tuple_from_row() { - let row = Row { - columns: vec![ - Some(CqlValue::Int(1)), - Some(CqlValue::Text("some_text".to_string())), - None, - ], - }; - - let (a, b, c) = <(i32, Option, Option)>::from_row(row).unwrap(); - assert_eq!(a, 1); - assert_eq!(b, Some("some_text".to_string())); - assert_eq!(c, None); - } - - #[test] - fn from_cql_null() { - assert_eq!(i32::from_cql(None), Err(FromCqlValError::ValIsNull)); - } - - #[test] - fn from_cql_wrong_type() { - assert_eq!( - i32::from_cql(CqlValue::BigInt(1234)), - Err(FromCqlValError::BadCqlType) - ); - } - - #[test] - fn from_cql_empty_value() { - assert_eq!( - i32::from_cql(CqlValue::Empty), - Err(FromCqlValError::BadCqlType) - ); - - assert_eq!(>::from_cql(Some(CqlValue::Empty)), Ok(None)); - } - - #[test] - fn from_row_null() { - let row = Row { - columns: vec![None], - }; - - assert_eq!( - <(i32,)>::from_row(row), - Err(FromRowError::BadCqlVal { - err: FromCqlValError::ValIsNull, - column: 0 - }) - ); - } - - #[test] - fn from_row_wrong_type() { - let row = Row { - columns: vec![Some(CqlValue::Int(1234))], - }; - - assert_eq!( - <(String,)>::from_row(row), - Err(FromRowError::BadCqlVal { - err: FromCqlValError::BadCqlType, - column: 0 - }) - ); - } - - #[test] - fn from_row_too_large() { - let row = Row { - columns: vec![Some(CqlValue::Int(1234)), Some(CqlValue::Int(1234))], - }; - - assert_eq!( - <(i32,)>::from_row(row), - Err(FromRowError::WrongRowSize { - expected: 1, - actual: 2 - }) - ); - } - - #[test] - fn from_row_too_short() { - let row = Row { - columns: vec![Some(CqlValue::Int(1234)), Some(CqlValue::Int(1234))], - }; - - assert_eq!( - <(i32, i32, i32)>::from_row(row), - Err(FromRowError::WrongRowSize { - expected: 3, - actual: 2 - }) - ); - } - - #[test] - fn struct_from_row() { - #[derive(FromRow)] - struct MyRow { - a: i32, - b: Option, - c: Option>, - } - - let row = Row { - columns: vec![ - Some(CqlValue::Int(16)), - None, - Some(CqlValue::Set(vec![CqlValue::Int(1), CqlValue::Int(2)])), - ], - }; - - let my_row: MyRow = MyRow::from_row(row).unwrap(); - - assert_eq!(my_row.a, 16); - assert_eq!(my_row.b, None); - assert_eq!(my_row.c, Some(vec![1, 2])); - } - - #[test] - fn struct_from_row_wrong_size() { - #[derive(FromRow, PartialEq, Eq, Debug)] - struct MyRow { - a: i32, - b: Option, - c: Option>, - } - - let too_short_row = Row { - columns: vec![Some(CqlValue::Int(16)), None], - }; - - let too_large_row = Row { - columns: vec![ - Some(CqlValue::Int(16)), - None, - Some(CqlValue::Set(vec![CqlValue::Int(1), CqlValue::Int(2)])), - Some(CqlValue::Set(vec![CqlValue::Int(1), CqlValue::Int(2)])), - ], - }; - - assert_eq!( - MyRow::from_row(too_short_row), - Err(FromRowError::WrongRowSize { - expected: 3, - actual: 2 - }) - ); - - assert_eq!( - MyRow::from_row(too_large_row), - Err(FromRowError::WrongRowSize { - expected: 3, - actual: 4 - }) - ); - } -} diff --git a/scylla-cql/src/frame/response/error.rs b/scylla-cql/src/frame/response/error.rs deleted file mode 100644 index bc526d0..0000000 --- a/scylla-cql/src/frame/response/error.rs +++ /dev/null @@ -1,399 +0,0 @@ -use crate::errors::{DbError, OperationType, QueryError, WriteType}; -use crate::frame::frame_errors::ParseError; -use crate::frame::protocol_features::ProtocolFeatures; -use crate::frame::types; -use byteorder::ReadBytesExt; -use bytes::Bytes; - -#[derive(Debug)] -pub struct Error { - pub error: DbError, - pub reason: String, -} - -impl Error { - pub fn deserialize(features: &ProtocolFeatures, buf: &mut &[u8]) -> Result { - let code = types::read_int(buf)?; - let reason = types::read_string(buf)?.to_owned(); - - let error: DbError = match code { - 0x0000 => DbError::ServerError, - 0x000A => DbError::ProtocolError, - 0x0100 => DbError::AuthenticationError, - 0x1000 => DbError::Unavailable { - consistency: types::read_consistency(buf)?, - required: types::read_int(buf)?, - alive: types::read_int(buf)?, - }, - 0x1001 => DbError::Overloaded, - 0x1002 => DbError::IsBootstrapping, - 0x1003 => DbError::TruncateError, - 0x1100 => DbError::WriteTimeout { - consistency: types::read_consistency(buf)?, - received: types::read_int(buf)?, - required: types::read_int(buf)?, - write_type: WriteType::from(types::read_string(buf)?), - }, - 0x1200 => DbError::ReadTimeout { - consistency: types::read_consistency(buf)?, - received: types::read_int(buf)?, - required: types::read_int(buf)?, - data_present: buf.read_u8()? != 0, - }, - 0x1300 => DbError::ReadFailure { - consistency: types::read_consistency(buf)?, - received: types::read_int(buf)?, - required: types::read_int(buf)?, - numfailures: types::read_int(buf)?, - data_present: buf.read_u8()? != 0, - }, - 0x1400 => DbError::FunctionFailure { - keyspace: types::read_string(buf)?.to_string(), - function: types::read_string(buf)?.to_string(), - arg_types: types::read_string_list(buf)?, - }, - 0x1500 => DbError::WriteFailure { - consistency: types::read_consistency(buf)?, - received: types::read_int(buf)?, - required: types::read_int(buf)?, - numfailures: types::read_int(buf)?, - write_type: WriteType::from(types::read_string(buf)?), - }, - 0x2000 => DbError::SyntaxError, - 0x2100 => DbError::Unauthorized, - 0x2200 => DbError::Invalid, - 0x2300 => DbError::ConfigError, - 0x2400 => DbError::AlreadyExists { - keyspace: types::read_string(buf)?.to_string(), - table: types::read_string(buf)?.to_string(), - }, - 0x2500 => DbError::Unprepared { - statement_id: Bytes::from(types::read_short_bytes(buf)?.to_owned()), - }, - code if Some(code) == features.rate_limit_error => DbError::RateLimitReached { - op_type: OperationType::from(buf.read_u8()?), - rejected_by_coordinator: buf.read_u8()? != 0, - }, - _ => DbError::Other(code), - }; - - Ok(Error { error, reason }) - } -} - -impl From for QueryError { - fn from(error: Error) -> QueryError { - QueryError::DbError(error.error, error.reason) - } -} - -#[cfg(test)] -mod tests { - use super::Error; - use crate::errors::{DbError, OperationType, WriteType}; - use crate::frame::protocol_features::ProtocolFeatures; - use crate::frame::types::LegacyConsistency; - use crate::Consistency; - use bytes::Bytes; - use std::convert::TryInto; - - // Serializes the beginning of an ERROR response - error code and message - // All custom data depending on the error type is appended after these bytes - fn make_error_request_bytes(error_code: i32, message: &str) -> Vec { - let mut bytes: Vec = Vec::new(); - let message_len: u16 = message.len().try_into().unwrap(); - - bytes.extend(error_code.to_be_bytes()); - bytes.extend(message_len.to_be_bytes()); - bytes.extend(message.as_bytes()); - - bytes - } - - // Tests deserialization of all errors without and additional data - #[test] - fn deserialize_simple_errors() { - let simple_error_mappings: [(i32, DbError); 11] = [ - (0x0000, DbError::ServerError), - (0x000A, DbError::ProtocolError), - (0x0100, DbError::AuthenticationError), - (0x1001, DbError::Overloaded), - (0x1002, DbError::IsBootstrapping), - (0x1003, DbError::TruncateError), - (0x2000, DbError::SyntaxError), - (0x2100, DbError::Unauthorized), - (0x2200, DbError::Invalid), - (0x2300, DbError::ConfigError), - (0x1234, DbError::Other(0x1234)), - ]; - - let features = ProtocolFeatures::default(); - - for (error_code, expected_error) in &simple_error_mappings { - let bytes: Vec = make_error_request_bytes(*error_code, "simple message"); - let error: Error = Error::deserialize(&features, &mut bytes.as_slice()).unwrap(); - assert_eq!(error.error, *expected_error); - assert_eq!(error.reason, "simple message"); - } - } - - #[test] - fn deserialize_unavailable() { - let features = ProtocolFeatures::default(); - - let mut bytes = make_error_request_bytes(0x1000, "message 2"); - bytes.extend(1_i16.to_be_bytes()); - bytes.extend(2_i32.to_be_bytes()); - bytes.extend(3_i32.to_be_bytes()); - - let error: Error = Error::deserialize(&features, &mut bytes.as_slice()).unwrap(); - - assert_eq!( - error.error, - DbError::Unavailable { - consistency: LegacyConsistency::Regular(Consistency::One), - required: 2, - alive: 3, - } - ); - assert_eq!(error.reason, "message 2"); - } - - #[test] - fn deserialize_write_timeout() { - let features = ProtocolFeatures::default(); - - let mut bytes = make_error_request_bytes(0x1100, "message 2"); - bytes.extend(0x0004_i16.to_be_bytes()); - bytes.extend((-5_i32).to_be_bytes()); - bytes.extend(100_i32.to_be_bytes()); - - let write_type_str = "SIMPLE"; - let write_type_str_len: u16 = write_type_str.len().try_into().unwrap(); - bytes.extend(write_type_str_len.to_be_bytes()); - bytes.extend(write_type_str.as_bytes()); - - let error: Error = Error::deserialize(&features, &mut bytes.as_slice()).unwrap(); - - assert_eq!( - error.error, - DbError::WriteTimeout { - consistency: LegacyConsistency::Regular(Consistency::Quorum), - received: -5, // Allow negative values when they don't make sense, it's better than crashing with ProtocolError - required: 100, - write_type: WriteType::Simple, - } - ); - assert_eq!(error.reason, "message 2"); - } - - #[test] - fn deserialize_read_timeout() { - let features = ProtocolFeatures::default(); - - let mut bytes = make_error_request_bytes(0x1200, "message 2"); - bytes.extend(0x0002_i16.to_be_bytes()); - bytes.extend(8_i32.to_be_bytes()); - bytes.extend(32_i32.to_be_bytes()); - bytes.push(0_u8); - - let error: Error = Error::deserialize(&features, &mut bytes.as_slice()).unwrap(); - - assert_eq!( - error.error, - DbError::ReadTimeout { - consistency: LegacyConsistency::Regular(Consistency::Two), - received: 8, - required: 32, - data_present: false, - } - ); - assert_eq!(error.reason, "message 2"); - } - - #[test] - fn deserialize_read_failure() { - let features = ProtocolFeatures::default(); - - let mut bytes = make_error_request_bytes(0x1300, "message 2"); - bytes.extend(0x0003_i16.to_be_bytes()); - bytes.extend(4_i32.to_be_bytes()); - bytes.extend(5_i32.to_be_bytes()); - bytes.extend(6_i32.to_be_bytes()); - bytes.push(123_u8); // Any non-zero value means data_present is true - - let error: Error = Error::deserialize(&features, &mut bytes.as_slice()).unwrap(); - - assert_eq!( - error.error, - DbError::ReadFailure { - consistency: LegacyConsistency::Regular(Consistency::Three), - received: 4, - required: 5, - numfailures: 6, - data_present: true, - } - ); - assert_eq!(error.reason, "message 2"); - } - - #[test] - fn deserialize_function_failure() { - let features = ProtocolFeatures::default(); - - let mut bytes = make_error_request_bytes(0x1400, "message 2"); - - let keyspace_name: &str = "keyspace_name"; - let keyspace_name_len: u16 = keyspace_name.len().try_into().unwrap(); - - let function_name: &str = "function_name"; - let function_name_len: u16 = function_name.len().try_into().unwrap(); - - let type1: &str = "type1"; - let type1_len: u16 = type1.len().try_into().unwrap(); - - let type2: &str = "type2"; - let type2_len: u16 = type1.len().try_into().unwrap(); - - bytes.extend(keyspace_name_len.to_be_bytes()); - bytes.extend(keyspace_name.as_bytes()); - bytes.extend(function_name_len.to_be_bytes()); - bytes.extend(function_name.as_bytes()); - bytes.extend(2_i16.to_be_bytes()); - bytes.extend(type1_len.to_be_bytes()); - bytes.extend(type1.as_bytes()); - bytes.extend(type2_len.to_be_bytes()); - bytes.extend(type2.as_bytes()); - - let error: Error = Error::deserialize(&features, &mut bytes.as_slice()).unwrap(); - - assert_eq!( - error.error, - DbError::FunctionFailure { - keyspace: "keyspace_name".to_string(), - function: "function_name".to_string(), - arg_types: vec!["type1".to_string(), "type2".to_string()] - } - ); - assert_eq!(error.reason, "message 2"); - } - - #[test] - fn deserialize_write_failure() { - let features = ProtocolFeatures::default(); - - let mut bytes = make_error_request_bytes(0x1500, "message 2"); - - bytes.extend(0x0000_i16.to_be_bytes()); - bytes.extend(2_i32.to_be_bytes()); - bytes.extend(4_i32.to_be_bytes()); - bytes.extend(8_i32.to_be_bytes()); - - let write_type_str = "COUNTER"; - let write_type_str_len: u16 = write_type_str.len().try_into().unwrap(); - bytes.extend(write_type_str_len.to_be_bytes()); - bytes.extend(write_type_str.as_bytes()); - - let error: Error = Error::deserialize(&features, &mut bytes.as_slice()).unwrap(); - - assert_eq!( - error.error, - DbError::WriteFailure { - consistency: LegacyConsistency::Regular(Consistency::Any), - received: 2, - required: 4, - numfailures: 8, - write_type: WriteType::Counter, - } - ); - assert_eq!(error.reason, "message 2"); - } - - #[test] - fn deserialize_already_exists() { - let features = ProtocolFeatures::default(); - - let mut bytes = make_error_request_bytes(0x2400, "message 2"); - - let keyspace_name: &str = "keyspace_name"; - let keyspace_name_len: u16 = keyspace_name.len().try_into().unwrap(); - - let table_name: &str = "table_name"; - let table_name_len: u16 = table_name.len().try_into().unwrap(); - - bytes.extend(keyspace_name_len.to_be_bytes()); - bytes.extend(keyspace_name.as_bytes()); - bytes.extend(table_name_len.to_be_bytes()); - bytes.extend(table_name.as_bytes()); - - let error: Error = Error::deserialize(&features, &mut bytes.as_slice()).unwrap(); - - assert_eq!( - error.error, - DbError::AlreadyExists { - keyspace: "keyspace_name".to_string(), - table: "table_name".to_string(), - } - ); - assert_eq!(error.reason, "message 2"); - } - - #[test] - fn deserialize_unprepared() { - let features = ProtocolFeatures::default(); - - let mut bytes = make_error_request_bytes(0x2500, "message 3"); - let statement_id = b"deadbeef"; - bytes.extend((statement_id.len() as i16).to_be_bytes()); - bytes.extend(statement_id); - - let error: Error = Error::deserialize(&features, &mut bytes.as_slice()).unwrap(); - - assert_eq!( - error.error, - DbError::Unprepared { - statement_id: Bytes::from_static(b"deadbeef") - } - ); - assert_eq!(error.reason, "message 3"); - } - - #[test] - fn deserialize_rate_limit_error() { - let features = ProtocolFeatures { - rate_limit_error: Some(0x4321), - ..Default::default() - }; - let mut bytes = make_error_request_bytes(0x4321, "message 1"); - bytes.extend([0u8]); // Read type - bytes.extend([1u8]); // Rejected by coordinator - let error = Error::deserialize(&features, &mut bytes.as_slice()).unwrap(); - - assert_eq!( - error.error, - DbError::RateLimitReached { - op_type: OperationType::Read, - rejected_by_coordinator: true, - } - ); - assert_eq!(error.reason, "message 1"); - - let features = ProtocolFeatures { - rate_limit_error: Some(0x8765), - ..Default::default() - }; - let mut bytes = make_error_request_bytes(0x8765, "message 2"); - bytes.extend([1u8]); // Write type - bytes.extend([0u8]); // Not rejected by coordinator - let error = Error::deserialize(&features, &mut bytes.as_slice()).unwrap(); - - assert_eq!( - error.error, - DbError::RateLimitReached { - op_type: OperationType::Write, - rejected_by_coordinator: false, - } - ); - assert_eq!(error.reason, "message 2"); - } -} diff --git a/scylla-cql/src/frame/response/event.rs b/scylla-cql/src/frame/response/event.rs deleted file mode 100644 index d738ad7..0000000 --- a/scylla-cql/src/frame/response/event.rs +++ /dev/null @@ -1,183 +0,0 @@ -use crate::frame::frame_errors::ParseError; -use crate::frame::server_event_type::EventType; -use crate::frame::types; -use std::net::SocketAddr; - -#[derive(Debug)] -pub enum Event { - TopologyChange(TopologyChangeEvent), - StatusChange(StatusChangeEvent), - SchemaChange(SchemaChangeEvent), -} - -#[derive(Debug)] -pub enum TopologyChangeEvent { - NewNode(SocketAddr), - RemovedNode(SocketAddr), -} - -#[derive(Debug)] -pub enum StatusChangeEvent { - Up(SocketAddr), - Down(SocketAddr), -} - -#[derive(Debug)] -pub enum SchemaChangeEvent { - KeyspaceChange { - change_type: SchemaChangeType, - keyspace_name: String, - }, - TableChange { - change_type: SchemaChangeType, - keyspace_name: String, - object_name: String, - }, - TypeChange { - change_type: SchemaChangeType, - keyspace_name: String, - type_name: String, - }, - FunctionChange { - change_type: SchemaChangeType, - keyspace_name: String, - function_name: String, - arguments: Vec, - }, - AggregateChange { - change_type: SchemaChangeType, - keyspace_name: String, - aggregate_name: String, - arguments: Vec, - }, -} - -#[derive(Debug)] -pub enum SchemaChangeType { - Created, - Updated, - Dropped, - Invalid, -} - -impl Event { - pub fn deserialize(buf: &mut &[u8]) -> Result { - let event_type: EventType = types::read_string(buf)?.parse()?; - match event_type { - EventType::TopologyChange => { - Ok(Self::TopologyChange(TopologyChangeEvent::deserialize(buf)?)) - } - EventType::StatusChange => Ok(Self::StatusChange(StatusChangeEvent::deserialize(buf)?)), - EventType::SchemaChange => Ok(Self::SchemaChange(SchemaChangeEvent::deserialize(buf)?)), - } - } -} - -impl SchemaChangeEvent { - pub fn deserialize(buf: &mut &[u8]) -> Result { - let type_of_change_string = types::read_string(buf)?; - let type_of_change = match type_of_change_string { - "CREATED" => SchemaChangeType::Created, - "UPDATED" => SchemaChangeType::Updated, - "DROPPED" => SchemaChangeType::Dropped, - _ => SchemaChangeType::Invalid, - }; - - let target = types::read_string(buf)?; - let keyspace_affected = types::read_string(buf)?.to_string(); - - match target { - "KEYSPACE" => Ok(Self::KeyspaceChange { - change_type: type_of_change, - keyspace_name: keyspace_affected, - }), - "TABLE" => { - let table_name = types::read_string(buf)?.to_string(); - Ok(Self::TableChange { - change_type: type_of_change, - keyspace_name: keyspace_affected, - object_name: table_name, - }) - } - "TYPE" => { - let changed_type = types::read_string(buf)?.to_string(); - Ok(Self::TypeChange { - change_type: type_of_change, - keyspace_name: keyspace_affected, - type_name: changed_type, - }) - } - "FUNCTION" => { - let function = types::read_string(buf)?.to_string(); - let number_of_arguments = types::read_short(buf)?; - - let mut argument_vector = Vec::with_capacity(number_of_arguments as usize); - - for _ in 0..number_of_arguments { - argument_vector.push(types::read_string(buf)?.to_string()); - } - - Ok(Self::FunctionChange { - change_type: type_of_change, - keyspace_name: keyspace_affected, - function_name: function, - arguments: argument_vector, - }) - } - "AGGREGATE" => { - let name = types::read_string(buf)?.to_string(); - let number_of_arguments = types::read_short(buf)?; - - let mut argument_vector = Vec::with_capacity(number_of_arguments as usize); - - for _ in 0..number_of_arguments { - argument_vector.push(types::read_string(buf)?.to_string()); - } - - Ok(Self::AggregateChange { - change_type: type_of_change, - keyspace_name: keyspace_affected, - aggregate_name: name, - arguments: argument_vector, - }) - } - - _ => Err(ParseError::BadIncomingData(format!( - "Invalid type of schema change ({}) in SchemaChangeEvent", - target - ))), - } - } -} - -impl TopologyChangeEvent { - pub fn deserialize(buf: &mut &[u8]) -> Result { - let type_of_change = types::read_string(buf)?; - let addr = types::read_inet(buf)?; - - match type_of_change { - "NEW_NODE" => Ok(Self::NewNode(addr)), - "REMOVED_NODE" => Ok(Self::RemovedNode(addr)), - _ => Err(ParseError::BadIncomingData(format!( - "Invalid type of change ({}) in TopologyChangeEvent", - type_of_change - ))), - } - } -} - -impl StatusChangeEvent { - pub fn deserialize(buf: &mut &[u8]) -> Result { - let type_of_change = types::read_string(buf)?; - let addr = types::read_inet(buf)?; - - match type_of_change { - "UP" => Ok(Self::Up(addr)), - "DOWN" => Ok(Self::Down(addr)), - _ => Err(ParseError::BadIncomingData(format!( - "Invalid type of status change ({}) in StatusChangeEvent", - type_of_change - ))), - } - } -} diff --git a/scylla-cql/src/frame/response/mod.rs b/scylla-cql/src/frame/response/mod.rs deleted file mode 100644 index c8c4ec1..0000000 --- a/scylla-cql/src/frame/response/mod.rs +++ /dev/null @@ -1,90 +0,0 @@ -pub mod authenticate; -pub mod cql_to_rust; -pub mod error; -pub mod event; -pub mod result; -pub mod supported; - -use crate::{errors::QueryError, frame::frame_errors::ParseError}; -use num_enum::TryFromPrimitive; - -use crate::frame::protocol_features::ProtocolFeatures; -pub use error::Error; -pub use supported::Supported; - -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, TryFromPrimitive)] -#[repr(u8)] -pub enum ResponseOpcode { - Error = 0x00, - Ready = 0x02, - Authenticate = 0x03, - Supported = 0x06, - Result = 0x08, - Event = 0x0C, - AuthChallenge = 0x0E, - AuthSuccess = 0x10, -} - -#[derive(Debug)] -pub enum Response { - Error(Error), - Ready, - Result(result::Result), - Authenticate(authenticate::Authenticate), - AuthSuccess(authenticate::AuthSuccess), - AuthChallenge(authenticate::AuthChallenge), - Supported(Supported), - Event(event::Event), -} - -impl Response { - pub fn deserialize( - features: &ProtocolFeatures, - opcode: ResponseOpcode, - buf: &mut &[u8], - ) -> Result { - let response = match opcode { - ResponseOpcode::Error => Response::Error(Error::deserialize(features, buf)?), - ResponseOpcode::Ready => Response::Ready, - ResponseOpcode::Authenticate => { - Response::Authenticate(authenticate::Authenticate::deserialize(buf)?) - } - ResponseOpcode::Supported => Response::Supported(Supported::deserialize(buf)?), - ResponseOpcode::Result => Response::Result(result::deserialize(buf)?), - ResponseOpcode::Event => Response::Event(event::Event::deserialize(buf)?), - ResponseOpcode::AuthChallenge => { - Response::AuthChallenge(authenticate::AuthChallenge::deserialize(buf)?) - } - ResponseOpcode::AuthSuccess => { - Response::AuthSuccess(authenticate::AuthSuccess::deserialize(buf)?) - } - }; - - Ok(response) - } - - pub fn into_non_error_response(self) -> Result { - Ok(match self { - Response::Error(err) => return Err(QueryError::from(err)), - Response::Ready => NonErrorResponse::Ready, - Response::Result(res) => NonErrorResponse::Result(res), - Response::Authenticate(auth) => NonErrorResponse::Authenticate(auth), - Response::AuthSuccess(auth_succ) => NonErrorResponse::AuthSuccess(auth_succ), - Response::AuthChallenge(auth_chal) => NonErrorResponse::AuthChallenge(auth_chal), - Response::Supported(sup) => NonErrorResponse::Supported(sup), - Response::Event(eve) => NonErrorResponse::Event(eve), - }) - } -} - -// A Response which can not be Response::Error -#[derive(Debug)] -pub enum NonErrorResponse { - Ready, - Result(result::Result), - Authenticate(authenticate::Authenticate), - AuthSuccess(authenticate::AuthSuccess), - AuthChallenge(authenticate::AuthChallenge), - Supported(Supported), - Event(event::Event), -} diff --git a/scylla-cql/src/frame/response/result.rs b/scylla-cql/src/frame/response/result.rs deleted file mode 100644 index 3608a1f..0000000 --- a/scylla-cql/src/frame/response/result.rs +++ /dev/null @@ -1,1441 +0,0 @@ -use crate::cql_to_rust::{FromRow, FromRowError}; -use crate::frame::response::event::SchemaChangeEvent; -use crate::frame::types::vint_decode; -use crate::frame::value::{Counter, CqlDuration}; -use crate::frame::{frame_errors::ParseError, types}; -use bigdecimal::BigDecimal; -use byteorder::{BigEndian, ReadBytesExt}; -use bytes::{Buf, Bytes}; -use chrono; -use chrono::prelude::*; -use num_bigint::BigInt; -use std::{ - convert::{TryFrom, TryInto}, - net::IpAddr, - result::Result as StdResult, - str, -}; -use uuid::Uuid; - -#[derive(Debug)] -pub struct SetKeyspace { - pub keyspace_name: String, -} - -#[derive(Debug)] -pub struct Prepared { - pub id: Bytes, - pub prepared_metadata: PreparedMetadata, - pub result_metadata: ResultMetadata, -} - -#[derive(Debug)] -pub struct SchemaChange { - pub event: SchemaChangeEvent, -} - -#[derive(Clone, Debug)] -pub struct TableSpec { - pub ks_name: String, - pub table_name: String, -} - -#[derive(Debug, Clone)] -pub enum ColumnType { - Custom(String), - Ascii, - Boolean, - Blob, - Counter, - Date, - Decimal, - Double, - Duration, - Float, - Int, - BigInt, - Text, - Timestamp, - Inet, - List(Box), - Map(Box, Box), - Set(Box), - UserDefinedType { - type_name: String, - keyspace: String, - field_types: Vec<(String, ColumnType)>, - }, - SmallInt, - TinyInt, - Time, - Timeuuid, - Tuple(Vec), - Uuid, - Varint, -} - -#[derive(Clone, Debug, PartialEq)] -pub enum CqlValue { - Ascii(String), - Boolean(bool), - Blob(Vec), - Counter(Counter), - Decimal(BigDecimal), - /// Days since -5877641-06-23 i.e. 2^31 days before unix epoch - /// Can be converted to chrono::NaiveDate (-262145-1-1 to 262143-12-31) using as_date - Date(u32), - Double(f64), - Duration(CqlDuration), - Empty, - Float(f32), - Int(i32), - BigInt(i64), - Text(String), - /// Milliseconds since unix epoch - Timestamp(chrono::Duration), - Inet(IpAddr), - List(Vec), - Map(Vec<(CqlValue, CqlValue)>), - Set(Vec), - UserDefinedType { - keyspace: String, - type_name: String, - /// Order of `fields` vector must match the order of fields as defined in the UDT. The - /// driver does not check it by itself, so incorrect data will be written if the order is - /// wrong. - fields: Vec<(String, Option)>, - }, - SmallInt(i16), - TinyInt(i8), - /// Nanoseconds since midnight - Time(chrono::Duration), - Timeuuid(Uuid), - Tuple(Vec>), - Uuid(Uuid), - Varint(BigInt), -} - -impl CqlValue { - pub fn as_ascii(&self) -> Option<&String> { - match self { - Self::Ascii(s) => Some(s), - _ => None, - } - } - - pub fn as_date(&self) -> Option { - // Days since -5877641-06-23 i.e. 2^31 days before unix epoch - let date_days: u32 = match self { - CqlValue::Date(days) => *days, - _ => return None, - }; - - // date_days is u32 then converted to i64 - // then we substract 2^31 - this can't panic - let days_since_epoch = - chrono::Duration::days(date_days.into()) - chrono::Duration::days(1 << 31); - - NaiveDate::from_ymd_opt(1970, 1, 1) - .unwrap() - .checked_add_signed(days_since_epoch) - } - - pub fn as_duration(&self) -> Option { - match self { - Self::Timestamp(i) => Some(*i), - Self::Time(i) => Some(*i), - _ => None, - } - } - - pub fn as_cql_duration(&self) -> Option { - match self { - Self::Duration(i) => Some(*i), - _ => None, - } - } - - pub fn as_counter(&self) -> Option { - match self { - Self::Counter(i) => Some(*i), - _ => None, - } - } - - pub fn as_boolean(&self) -> Option { - match self { - Self::Boolean(i) => Some(*i), - _ => None, - } - } - - pub fn as_double(&self) -> Option { - match self { - Self::Double(d) => Some(*d), - _ => None, - } - } - - pub fn as_uuid(&self) -> Option { - match self { - Self::Uuid(u) => Some(*u), - Self::Timeuuid(u) => Some(*u), - _ => None, - } - } - - pub fn as_float(&self) -> Option { - match self { - Self::Float(f) => Some(*f), - _ => None, - } - } - - pub fn as_int(&self) -> Option { - match self { - Self::Int(i) => Some(*i), - _ => None, - } - } - - pub fn as_bigint(&self) -> Option { - match self { - Self::BigInt(i) => Some(*i), - Self::Timestamp(d) => Some(d.num_milliseconds()), - _ => None, - } - } - - pub fn as_tinyint(&self) -> Option { - match self { - Self::TinyInt(i) => Some(*i), - _ => None, - } - } - - pub fn as_smallint(&self) -> Option { - match self { - Self::SmallInt(i) => Some(*i), - _ => None, - } - } - - pub fn as_blob(&self) -> Option<&Vec> { - match self { - Self::Blob(v) => Some(v), - _ => None, - } - } - - pub fn as_text(&self) -> Option<&String> { - match self { - Self::Text(s) => Some(s), - _ => None, - } - } - - pub fn as_timeuuid(&self) -> Option { - match self { - Self::Timeuuid(u) => Some(*u), - _ => None, - } - } - - pub fn into_string(self) -> Option { - match self { - Self::Ascii(s) => Some(s), - Self::Text(s) => Some(s), - _ => None, - } - } - - pub fn into_blob(self) -> Option> { - match self { - Self::Blob(b) => Some(b), - _ => None, - } - } - - pub fn as_inet(&self) -> Option { - match self { - Self::Inet(a) => Some(*a), - _ => None, - } - } - - pub fn as_list(&self) -> Option<&Vec> { - match self { - Self::List(s) => Some(s), - _ => None, - } - } - - pub fn as_set(&self) -> Option<&Vec> { - match self { - Self::Set(s) => Some(s), - _ => None, - } - } - - pub fn as_map(&self) -> Option<&Vec<(CqlValue, CqlValue)>> { - match self { - Self::Map(s) => Some(s), - _ => None, - } - } - - pub fn as_udt(&self) -> Option<&Vec<(String, Option)>> { - match self { - Self::UserDefinedType { fields, .. } => Some(fields), - _ => None, - } - } - - pub fn into_vec(self) -> Option> { - match self { - Self::List(s) => Some(s), - Self::Set(s) => Some(s), - _ => None, - } - } - - pub fn into_pair_vec(self) -> Option> { - match self { - Self::Map(s) => Some(s), - _ => None, - } - } - - pub fn into_udt_pair_vec(self) -> Option)>> { - match self { - Self::UserDefinedType { fields, .. } => Some(fields), - _ => None, - } - } - - pub fn into_varint(self) -> Option { - match self { - Self::Varint(i) => Some(i), - _ => None, - } - } - - pub fn into_decimal(self) -> Option { - match self { - Self::Decimal(i) => Some(i), - _ => None, - } - } - // TODO -} - -#[derive(Debug, Clone)] -pub struct ColumnSpec { - pub table_spec: TableSpec, - pub name: String, - pub typ: ColumnType, -} - -#[derive(Debug, Default)] -pub struct ResultMetadata { - col_count: usize, - pub paging_state: Option, - pub col_specs: Vec, -} - -#[derive(Debug, Copy, Clone)] -pub struct PartitionKeyIndex { - /// index in the serialized values - pub index: u16, - /// sequence number in partition key - pub sequence: u16, -} - -#[derive(Debug, Clone)] -pub struct PreparedMetadata { - pub flags: i32, - pub col_count: usize, - /// pk_indexes are sorted by `index` and can be reordered in partition key order - /// using `sequence` field - pub pk_indexes: Vec, - pub col_specs: Vec, -} - -#[derive(Debug, Default, PartialEq)] -pub struct Row { - pub columns: Vec>, -} - -impl Row { - /// Allows converting Row into tuple of rust types or custom struct deriving FromRow - pub fn into_typed(self) -> StdResult { - RowT::from_row(self) - } -} - -#[derive(Debug)] -pub struct Rows { - pub metadata: ResultMetadata, - pub rows_count: usize, - pub rows: Vec, -} - -#[derive(Debug)] -pub enum Result { - Void, - Rows(Rows), - SetKeyspace(SetKeyspace), - Prepared(Prepared), - SchemaChange(SchemaChange), -} - -fn deser_table_spec(buf: &mut &[u8]) -> StdResult { - let ks_name = types::read_string(buf)?.to_owned(); - let table_name = types::read_string(buf)?.to_owned(); - Ok(TableSpec { - ks_name, - table_name, - }) -} - -fn deser_type(buf: &mut &[u8]) -> StdResult { - use ColumnType::*; - let id = types::read_short(buf)?; - Ok(match id { - 0x0000 => { - let type_str: String = types::read_string(buf)?.to_string(); - match type_str.as_str() { - "org.apache.cassandra.db.marshal.DurationType" => Duration, - _ => Custom(type_str), - } - } - 0x0001 => Ascii, - 0x0002 => BigInt, - 0x0003 => Blob, - 0x0004 => Boolean, - 0x0005 => Counter, - 0x0006 => Decimal, - 0x0007 => Double, - 0x0008 => Float, - 0x0009 => Int, - 0x000B => Timestamp, - 0x000C => Uuid, - 0x000D => Text, - 0x000E => Varint, - 0x000F => Timeuuid, - 0x0010 => Inet, - 0x0011 => Date, - 0x0012 => Time, - 0x0013 => SmallInt, - 0x0014 => TinyInt, - 0x0015 => Duration, - 0x0020 => List(Box::new(deser_type(buf)?)), - 0x0021 => Map(Box::new(deser_type(buf)?), Box::new(deser_type(buf)?)), - 0x0022 => Set(Box::new(deser_type(buf)?)), - 0x0030 => { - let keyspace_name: String = types::read_string(buf)?.to_string(); - let type_name: String = types::read_string(buf)?.to_string(); - let fields_size: usize = types::read_short(buf)?.try_into()?; - - let mut field_types: Vec<(String, ColumnType)> = Vec::with_capacity(fields_size); - - for _ in 0..fields_size { - let field_name: String = types::read_string(buf)?.to_string(); - let field_type: ColumnType = deser_type(buf)?; - - field_types.push((field_name, field_type)); - } - - UserDefinedType { - type_name, - keyspace: keyspace_name, - field_types, - } - } - 0x0031 => { - let len: usize = types::read_short(buf)?.try_into()?; - let mut types = Vec::with_capacity(len); - for _ in 0..len { - types.push(deser_type(buf)?); - } - Tuple(types) - } - id => { - // TODO implement other types - return Err(ParseError::TypeNotImplemented(id)); - } - }) -} - -fn deser_col_specs( - buf: &mut &[u8], - global_table_spec: &Option, - col_count: usize, -) -> StdResult, ParseError> { - let mut col_specs = Vec::with_capacity(col_count); - for _ in 0..col_count { - let table_spec = if let Some(spec) = global_table_spec { - spec.clone() - } else { - deser_table_spec(buf)? - }; - let name = types::read_string(buf)?.to_owned(); - let typ = deser_type(buf)?; - col_specs.push(ColumnSpec { - table_spec, - name, - typ, - }); - } - Ok(col_specs) -} - -fn deser_result_metadata(buf: &mut &[u8]) -> StdResult { - let flags = types::read_int(buf)?; - let global_tables_spec = flags & 0x0001 != 0; - let has_more_pages = flags & 0x0002 != 0; - let no_metadata = flags & 0x0004 != 0; - - let col_count: usize = types::read_int(buf)?.try_into()?; - - let paging_state = if has_more_pages { - Some(types::read_bytes(buf)?.to_owned().into()) - } else { - None - }; - - if no_metadata { - return Ok(ResultMetadata { - col_count, - paging_state, - col_specs: vec![], - }); - } - - let global_table_spec = if global_tables_spec { - Some(deser_table_spec(buf)?) - } else { - None - }; - - let col_specs = deser_col_specs(buf, &global_table_spec, col_count)?; - - Ok(ResultMetadata { - col_count, - paging_state, - col_specs, - }) -} - -fn deser_prepared_metadata(buf: &mut &[u8]) -> StdResult { - let flags = types::read_int(buf)?; - let global_tables_spec = flags & 0x0001 != 0; - - let col_count = types::read_int_length(buf)?; - - let pk_count: usize = types::read_int(buf)?.try_into()?; - - let mut pk_indexes = Vec::with_capacity(pk_count); - for i in 0..pk_count { - pk_indexes.push(PartitionKeyIndex { - index: types::read_short(buf)? as u16, - sequence: i as u16, - }); - } - pk_indexes.sort_unstable_by_key(|pki| pki.index); - - let global_table_spec = if global_tables_spec { - Some(deser_table_spec(buf)?) - } else { - None - }; - - let col_specs = deser_col_specs(buf, &global_table_spec, col_count)?; - - Ok(PreparedMetadata { - flags, - col_count, - pk_indexes, - col_specs, - }) -} - -pub fn deser_cql_value(typ: &ColumnType, buf: &mut &[u8]) -> StdResult { - use ColumnType::*; - - if buf.is_empty() { - match typ { - Ascii | Blob | Text => { - // can't be empty - } - _ => return Ok(CqlValue::Empty), - } - } - - Ok(match typ { - Custom(type_str) => { - return Err(ParseError::BadIncomingData(format!( - "Support for custom types is not yet implemented: {}", - type_str - ))); - } - Ascii => { - if !buf.is_ascii() { - return Err(ParseError::BadIncomingData( - "String is not ascii!".to_string(), - )); - } - CqlValue::Ascii(str::from_utf8(buf)?.to_owned()) - } - Boolean => { - if buf.len() != 1 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 1 not {}", - buf.len() - ))); - } - CqlValue::Boolean(buf[0] != 0x00) - } - Blob => CqlValue::Blob(buf.to_vec()), - Date => { - if buf.len() != 4 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 4 not {}", - buf.len() - ))); - } - - let date_value = buf.read_u32::()?; - CqlValue::Date(date_value) - } - Counter => { - if buf.len() != 8 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 8 not {}", - buf.len() - ))); - } - CqlValue::Counter(crate::frame::value::Counter(buf.read_i64::()?)) - } - Decimal => { - let scale = types::read_int(buf)? as i64; - let int_value = num_bigint::BigInt::from_signed_bytes_be(buf); - let big_decimal: BigDecimal = BigDecimal::from((int_value, scale)); - - CqlValue::Decimal(big_decimal) - } - Double => { - if buf.len() != 8 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 8 not {}", - buf.len() - ))); - } - CqlValue::Double(buf.read_f64::()?) - } - Float => { - if buf.len() != 4 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 4 not {}", - buf.len() - ))); - } - CqlValue::Float(buf.read_f32::()?) - } - Int => { - if buf.len() != 4 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 4 not {}", - buf.len() - ))); - } - CqlValue::Int(buf.read_i32::()?) - } - SmallInt => { - if buf.len() != 2 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 2 not {}", - buf.len() - ))); - } - - CqlValue::SmallInt(buf.read_i16::()?) - } - TinyInt => { - if buf.len() != 1 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 1 not {}", - buf.len() - ))); - } - CqlValue::TinyInt(buf.read_i8()?) - } - BigInt => { - if buf.len() != 8 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 8 not {}", - buf.len() - ))); - } - CqlValue::BigInt(buf.read_i64::()?) - } - Text => CqlValue::Text(str::from_utf8(buf)?.to_owned()), - Timestamp => { - if buf.len() != 8 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 8 not {}", - buf.len() - ))); - } - let millis = buf.read_i64::()?; - - CqlValue::Timestamp(chrono::Duration::milliseconds(millis)) - } - Time => { - if buf.len() != 8 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 8 not {}", - buf.len() - ))); - } - let nanoseconds: i64 = buf.read_i64::()?; - - // Valid values are in the range 0 to 86399999999999 - if !(0..=86399999999999).contains(&nanoseconds) { - return Err(ParseError::BadIncomingData(format! { - "Invalid time value only 0 to 86399999999999 allowed: {}.", nanoseconds - })); - } - - CqlValue::Time(chrono::Duration::nanoseconds(nanoseconds)) - } - Timeuuid => { - if buf.len() != 16 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 16 not {}", - buf.len() - ))); - } - let uuid = uuid::Uuid::from_slice(buf).expect("Deserializing Uuid failed."); - CqlValue::Timeuuid(uuid) - } - Duration => { - let months = i32::try_from(vint_decode(buf)?)?; - let days = i32::try_from(vint_decode(buf)?)?; - let nanoseconds = vint_decode(buf)?; - - CqlValue::Duration(CqlDuration { - months, - days, - nanoseconds, - }) - } - Inet => CqlValue::Inet(match buf.len() { - 4 => { - let ret = IpAddr::from(<[u8; 4]>::try_from(&buf[0..4])?); - buf.advance(4); - ret - } - 16 => { - let ret = IpAddr::from(<[u8; 16]>::try_from(&buf[0..16])?); - buf.advance(16); - ret - } - v => { - return Err(ParseError::BadIncomingData(format!( - "Invalid inet bytes length: {}", - v - ))); - } - }), - Uuid => { - if buf.len() != 16 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 16 not {}", - buf.len() - ))); - } - let uuid = uuid::Uuid::from_slice(buf).expect("Deserializing Uuid failed."); - CqlValue::Uuid(uuid) - } - Varint => CqlValue::Varint(num_bigint::BigInt::from_signed_bytes_be(buf)), - List(type_name) => { - let len: usize = types::read_int(buf)?.try_into()?; - let mut res = Vec::with_capacity(len); - for _ in 0..len { - let mut b = types::read_bytes(buf)?; - res.push(deser_cql_value(type_name, &mut b)?); - } - CqlValue::List(res) - } - Map(key_type, value_type) => { - let len: usize = types::read_int(buf)?.try_into()?; - let mut res = Vec::with_capacity(len); - for _ in 0..len { - let mut b = types::read_bytes(buf)?; - let key = deser_cql_value(key_type, &mut b)?; - b = types::read_bytes(buf)?; - let val = deser_cql_value(value_type, &mut b)?; - res.push((key, val)); - } - CqlValue::Map(res) - } - Set(type_name) => { - let len: usize = types::read_int(buf)?.try_into()?; - let mut res = Vec::with_capacity(len); - for _ in 0..len { - // TODO: is `null` allowed as set element? Should we use read_bytes_opt? - let mut b = types::read_bytes(buf)?; - res.push(deser_cql_value(type_name, &mut b)?); - } - CqlValue::Set(res) - } - UserDefinedType { - type_name, - keyspace, - field_types, - } => { - let mut fields: Vec<(String, Option)> = Vec::new(); - - for (field_name, field_type) in field_types { - // If a field is added to a UDT and we read an old (frozen ?) version of it, - // the driver will fail to parse the whole UDT. - // This is why we break the parsing after we reach the end of the serialized UDT. - if buf.is_empty() { - break; - } - - let mut field_value: Option = None; - if let Some(mut field_val_bytes) = types::read_bytes_opt(buf)? { - field_value = Some(deser_cql_value(field_type, &mut field_val_bytes)?); - } - - fields.push((field_name.clone(), field_value)); - } - - CqlValue::UserDefinedType { - keyspace: keyspace.clone(), - type_name: type_name.clone(), - fields, - } - } - Tuple(type_names) => { - let mut res = Vec::with_capacity(type_names.len()); - for type_name in type_names { - match types::read_bytes_opt(buf)? { - Some(mut b) => res.push(Some(deser_cql_value(type_name, &mut b)?)), - None => res.push(None), - }; - } - - CqlValue::Tuple(res) - } - }) -} - -fn deser_rows(buf: &mut &[u8]) -> StdResult { - let metadata = deser_result_metadata(buf)?; - - // TODO: the protocol allows an optimization (which must be explicitly requested on query by - // the driver) where the column metadata is not sent with the result. - // Implement this optimization. We'll then need to take the column types by a parameter. - // Beware of races; our column types may be outdated. - assert!(metadata.col_count == metadata.col_specs.len()); - - let rows_count: usize = types::read_int(buf)?.try_into()?; - - let mut rows = Vec::with_capacity(rows_count); - for _ in 0..rows_count { - let mut columns = Vec::with_capacity(metadata.col_count); - for i in 0..metadata.col_count { - let v = if let Some(mut b) = types::read_bytes_opt(buf)? { - Some(deser_cql_value(&metadata.col_specs[i].typ, &mut b)?) - } else { - None - }; - columns.push(v); - } - rows.push(Row { columns }); - } - Ok(Rows { - metadata, - rows_count, - rows, - }) -} - -fn deser_set_keyspace(buf: &mut &[u8]) -> StdResult { - let keyspace_name = types::read_string(buf)?.to_string(); - - Ok(SetKeyspace { keyspace_name }) -} - -fn deser_prepared(buf: &mut &[u8]) -> StdResult { - let id_len = types::read_short(buf)? as usize; - let id: Bytes = buf[0..id_len].to_owned().into(); - buf.advance(id_len); - let prepared_metadata = deser_prepared_metadata(buf)?; - let result_metadata = deser_result_metadata(buf)?; - Ok(Prepared { - id, - prepared_metadata, - result_metadata, - }) -} - -#[allow(clippy::unnecessary_wraps)] -fn deser_schema_change(buf: &mut &[u8]) -> StdResult { - Ok(SchemaChange { - event: SchemaChangeEvent::deserialize(buf)?, - }) -} - -pub fn deserialize(buf: &mut &[u8]) -> StdResult { - use self::Result::*; - Ok(match types::read_int(buf)? { - 0x0001 => Void, - 0x0002 => Rows(deser_rows(buf)?), - 0x0003 => SetKeyspace(deser_set_keyspace(buf)?), - 0x0004 => Prepared(deser_prepared(buf)?), - 0x0005 => SchemaChange(deser_schema_change(buf)?), - k => { - return Err(ParseError::BadIncomingData(format!( - "Unknown query result id: {}", - k - ))) - } - }) -} - -#[cfg(test)] -mod tests { - use crate as scylla; - use crate::frame::value::{Counter, CqlDuration}; - use bigdecimal::BigDecimal; - use chrono::Duration; - use chrono::NaiveDate; - use num_bigint::BigInt; - use num_bigint::ToBigInt; - use scylla::frame::response::result::{ColumnType, CqlValue}; - use std::str::FromStr; - use uuid::Uuid; - - #[test] - fn test_deserialize_text_types() { - let buf: Vec = vec![0x41]; - let int_slice = &mut &buf[..]; - let ascii_serialized = super::deser_cql_value(&ColumnType::Ascii, int_slice).unwrap(); - let text_serialized = super::deser_cql_value(&ColumnType::Text, int_slice).unwrap(); - assert_eq!(ascii_serialized, CqlValue::Ascii("A".to_string())); - assert_eq!(text_serialized, CqlValue::Text("A".to_string())); - } - - #[test] - fn test_deserialize_uuid_inet_types() { - let my_uuid = Uuid::parse_str("00000000000000000000000000000001").unwrap(); - - let uuid_buf: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]; - let uuid_slice = &mut &uuid_buf[..]; - let uuid_serialize = super::deser_cql_value(&ColumnType::Uuid, uuid_slice).unwrap(); - assert_eq!(uuid_serialize, CqlValue::Uuid(my_uuid)); - - let time_uuid_serialize = - super::deser_cql_value(&ColumnType::Timeuuid, uuid_slice).unwrap(); - assert_eq!(time_uuid_serialize, CqlValue::Timeuuid(my_uuid)); - - let my_ip = "::1".parse().unwrap(); - let ip_buf: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]; - let ip_slice = &mut &ip_buf[..]; - let ip_serialize = super::deser_cql_value(&ColumnType::Inet, ip_slice).unwrap(); - assert_eq!(ip_serialize, CqlValue::Inet(my_ip)); - - let max_ip = "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff".parse().unwrap(); - let max_ip_buf: Vec = vec![ - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - ]; - let max_ip_slice = &mut &max_ip_buf[..]; - let max_ip_serialize = super::deser_cql_value(&ColumnType::Inet, max_ip_slice).unwrap(); - assert_eq!(max_ip_serialize, CqlValue::Inet(max_ip)); - } - - #[test] - fn test_floating_points() { - let float: f32 = 0.5; - let double: f64 = 2.0; - - let float_buf: Vec = vec![63, 0, 0, 0]; - let float_slice = &mut &float_buf[..]; - let float_serialize = super::deser_cql_value(&ColumnType::Float, float_slice).unwrap(); - assert_eq!(float_serialize, CqlValue::Float(float)); - - let double_buf: Vec = vec![64, 0, 0, 0, 0, 0, 0, 0]; - let double_slice = &mut &double_buf[..]; - let double_serialize = super::deser_cql_value(&ColumnType::Double, double_slice).unwrap(); - assert_eq!(double_serialize, CqlValue::Double(double)); - } - - #[test] - fn test_varint() { - struct Test<'a> { - value: BigInt, - encoding: &'a [u8], - } - - /* - Table taken from CQL Binary Protocol v4 spec - - Value | Encoding - ------|--------- - 0 | 0x00 - 1 | 0x01 - 127 | 0x7F - 128 | 0x0080 - 129 | 0x0081 - -1 | 0xFF - -128 | 0x80 - -129 | 0xFF7F - */ - let tests = [ - Test { - value: 0.to_bigint().unwrap(), - encoding: &[0x00], - }, - Test { - value: 1.to_bigint().unwrap(), - encoding: &[0x01], - }, - Test { - value: 127.to_bigint().unwrap(), - encoding: &[0x7F], - }, - Test { - value: 128.to_bigint().unwrap(), - encoding: &[0x00, 0x80], - }, - Test { - value: 129.to_bigint().unwrap(), - encoding: &[0x00, 0x81], - }, - Test { - value: (-1).to_bigint().unwrap(), - encoding: &[0xFF], - }, - Test { - value: (-128).to_bigint().unwrap(), - encoding: &[0x80], - }, - Test { - value: (-129).to_bigint().unwrap(), - encoding: &[0xFF, 0x7F], - }, - ]; - - for t in tests.iter() { - let value = super::deser_cql_value(&ColumnType::Varint, &mut &*t.encoding).unwrap(); - assert_eq!(CqlValue::Varint(t.value.clone()), value); - } - } - - #[test] - fn test_decimal() { - struct Test<'a> { - value: BigDecimal, - encoding: &'a [u8], - } - - let tests = [ - Test { - value: BigDecimal::from_str("-1.28").unwrap(), - encoding: &[0x0, 0x0, 0x0, 0x2, 0x80], - }, - Test { - value: BigDecimal::from_str("1.29").unwrap(), - encoding: &[0x0, 0x0, 0x0, 0x2, 0x0, 0x81], - }, - Test { - value: BigDecimal::from_str("0").unwrap(), - encoding: &[0x0, 0x0, 0x0, 0x0, 0x0], - }, - Test { - value: BigDecimal::from_str("123").unwrap(), - encoding: &[0x0, 0x0, 0x0, 0x0, 0x7b], - }, - ]; - - for t in tests.iter() { - let value = super::deser_cql_value(&ColumnType::Decimal, &mut &*t.encoding).unwrap(); - assert_eq!(CqlValue::Decimal(t.value.clone()), value); - } - } - - #[test] - fn test_deserialize_counter() { - let counter: Vec = vec![0, 0, 0, 0, 0, 0, 1, 0]; - let counter_slice = &mut &counter[..]; - let counter_serialize = - super::deser_cql_value(&ColumnType::Counter, counter_slice).unwrap(); - assert_eq!(counter_serialize, CqlValue::Counter(Counter(256))); - } - - #[test] - fn test_deserialize_blob() { - let blob: Vec = vec![0, 1, 2, 3]; - let blob_slice = &mut &blob[..]; - let blob_serialize = super::deser_cql_value(&ColumnType::Blob, blob_slice).unwrap(); - assert_eq!(blob_serialize, CqlValue::Blob(blob)); - } - - #[test] - fn test_deserialize_bool() { - let bool_buf: Vec = vec![0x00]; - let bool_slice = &mut &bool_buf[..]; - let bool_serialize = super::deser_cql_value(&ColumnType::Boolean, bool_slice).unwrap(); - assert_eq!(bool_serialize, CqlValue::Boolean(false)); - - let bool_buf: Vec = vec![0x01]; - let bool_slice = &mut &bool_buf[..]; - let bool_serialize = super::deser_cql_value(&ColumnType::Boolean, bool_slice).unwrap(); - assert_eq!(bool_serialize, CqlValue::Boolean(true)); - } - - #[test] - fn test_deserialize_int_types() { - let int_buf: Vec = vec![0, 0, 0, 4]; - let int_slice = &mut &int_buf[..]; - let int_serialized = super::deser_cql_value(&ColumnType::Int, int_slice).unwrap(); - assert_eq!(int_serialized, CqlValue::Int(4)); - - let smallint_buf: Vec = vec![0, 4]; - let smallint_slice = &mut &smallint_buf[..]; - let smallint_serialized = - super::deser_cql_value(&ColumnType::SmallInt, smallint_slice).unwrap(); - assert_eq!(smallint_serialized, CqlValue::SmallInt(4)); - - let tinyint_buf: Vec = vec![4]; - let tinyint_slice = &mut &tinyint_buf[..]; - let tinyint_serialized = - super::deser_cql_value(&ColumnType::TinyInt, tinyint_slice).unwrap(); - assert_eq!(tinyint_serialized, CqlValue::TinyInt(4)); - - let bigint_buf: Vec = vec![0, 0, 0, 0, 0, 0, 0, 4]; - let bigint_slice = &mut &bigint_buf[..]; - let bigint_serialized = super::deser_cql_value(&ColumnType::BigInt, bigint_slice).unwrap(); - assert_eq!(bigint_serialized, CqlValue::BigInt(4)); - } - - #[test] - fn test_list_from_cql() { - let my_vec: Vec = vec![CqlValue::Int(20), CqlValue::Int(2), CqlValue::Int(13)]; - - let cql: CqlValue = CqlValue::List(my_vec); - let decoded = cql.into_vec().unwrap(); - - assert_eq!(decoded[0], CqlValue::Int(20)); - assert_eq!(decoded[1], CqlValue::Int(2)); - assert_eq!(decoded[2], CqlValue::Int(13)); - } - - #[test] - fn test_set_from_cql() { - let my_vec: Vec = vec![CqlValue::Int(20), CqlValue::Int(2), CqlValue::Int(13)]; - - let cql: CqlValue = CqlValue::Set(my_vec); - let decoded = cql.as_set().unwrap(); - - assert_eq!(decoded[0], CqlValue::Int(20)); - assert_eq!(decoded[1], CqlValue::Int(2)); - assert_eq!(decoded[2], CqlValue::Int(13)); - } - - #[test] - fn test_map_from_cql() { - let my_vec: Vec<(CqlValue, CqlValue)> = vec![ - (CqlValue::Int(20), CqlValue::Int(21)), - (CqlValue::Int(2), CqlValue::Int(3)), - ]; - - let cql: CqlValue = CqlValue::Map(my_vec); - - // Test borrowing. - let decoded = cql.as_map().unwrap(); - - assert_eq!(CqlValue::Int(20), decoded[0].0); - assert_eq!(CqlValue::Int(21), decoded[0].1); - - assert_eq!(CqlValue::Int(2), decoded[1].0); - assert_eq!(CqlValue::Int(3), decoded[1].1); - - // Test taking the ownership. - let decoded = cql.into_pair_vec().unwrap(); - - assert_eq!(CqlValue::Int(20), decoded[0].0); - assert_eq!(CqlValue::Int(21), decoded[0].1); - - assert_eq!(CqlValue::Int(2), decoded[1].0); - assert_eq!(CqlValue::Int(3), decoded[1].1); - } - - #[test] - fn test_udt_from_cql() { - let my_fields: Vec<(String, Option)> = vec![ - ("fst".to_string(), Some(CqlValue::Int(10))), - ("snd".to_string(), Some(CqlValue::Boolean(true))), - ]; - - let cql: CqlValue = CqlValue::UserDefinedType { - keyspace: "".to_string(), - type_name: "".to_string(), - fields: my_fields, - }; - - // Test borrowing. - let decoded = cql.as_udt().unwrap(); - - assert_eq!("fst".to_string(), decoded[0].0); - assert_eq!(Some(CqlValue::Int(10)), decoded[0].1); - - assert_eq!("snd".to_string(), decoded[1].0); - assert_eq!(Some(CqlValue::Boolean(true)), decoded[1].1); - - let decoded = cql.into_udt_pair_vec().unwrap(); - - assert_eq!("fst".to_string(), decoded[0].0); - assert_eq!(Some(CqlValue::Int(10)), decoded[0].1); - - assert_eq!("snd".to_string(), decoded[1].0); - assert_eq!(Some(CqlValue::Boolean(true)), decoded[1].1); - } - - #[test] - fn date_deserialize() { - // Date is correctly parsed from a 4 byte array - let four_bytes: [u8; 4] = [12, 23, 34, 45]; - let date: CqlValue = - super::deser_cql_value(&ColumnType::Date, &mut four_bytes.as_ref()).unwrap(); - assert_eq!(date, CqlValue::Date(u32::from_be_bytes(four_bytes))); - - // Date is parsed as u32 not i32, u32::MAX is u32::MAX - let date: CqlValue = - super::deser_cql_value(&ColumnType::Date, &mut u32::MAX.to_be_bytes().as_ref()) - .unwrap(); - assert_eq!(date, CqlValue::Date(u32::MAX)); - - // Trying to parse a 0, 3 or 5 byte array fails - super::deser_cql_value(&ColumnType::Date, &mut [].as_ref()).unwrap(); - super::deser_cql_value(&ColumnType::Date, &mut [1, 2, 3].as_ref()).unwrap_err(); - super::deser_cql_value(&ColumnType::Date, &mut [1, 2, 3, 4, 5].as_ref()).unwrap_err(); - - // 2^31 when converted to NaiveDate is 1970-01-01 - let unix_epoch: NaiveDate = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); - let date: CqlValue = - super::deser_cql_value(&ColumnType::Date, &mut 2_u32.pow(31).to_be_bytes().as_ref()) - .unwrap(); - - assert_eq!(date.as_date().unwrap(), unix_epoch); - - // 2^31 - 30 when converted to NaiveDate is 1969-12-02 - let before_epoch: NaiveDate = NaiveDate::from_ymd_opt(1969, 12, 2).unwrap(); - let date: CqlValue = super::deser_cql_value( - &ColumnType::Date, - &mut (2_u32.pow(31) - 30).to_be_bytes().as_ref(), - ) - .unwrap(); - - assert_eq!(date.as_date().unwrap(), before_epoch); - - // 2^31 + 30 when converted to NaiveDate is 1970-01-31 - let after_epoch: NaiveDate = NaiveDate::from_ymd_opt(1970, 1, 31).unwrap(); - let date: CqlValue = super::deser_cql_value( - &ColumnType::Date, - &mut (2_u32.pow(31) + 30).to_be_bytes().as_ref(), - ) - .unwrap(); - - assert_eq!(date.as_date().unwrap(), after_epoch); - - // 0 and u32::MAX is out of NaiveDate range, fails with an error, not panics - assert!( - super::deser_cql_value(&ColumnType::Date, &mut 0_u32.to_be_bytes().as_ref()) - .unwrap() - .as_date() - .is_none() - ); - - assert!( - super::deser_cql_value(&ColumnType::Date, &mut u32::MAX.to_be_bytes().as_ref()) - .unwrap() - .as_date() - .is_none() - ); - - // It's hard to test NaiveDate more because it involves calculating days between calendar dates - // There are more tests using database queries that should cover it - } - - #[test] - fn test_time_deserialize() { - // Time is an i64 - nanoseconds since midnight - // in range 0..=86399999999999 - - let max_time: i64 = 24 * 60 * 60 * 1_000_000_000 - 1; - assert_eq!(max_time, 86399999999999); - - // Check that basic values are deserialized correctly - for test_val in [0, 1, 18463, max_time].iter() { - let bytes: [u8; 8] = test_val.to_be_bytes(); - let cql_value: CqlValue = - super::deser_cql_value(&ColumnType::Time, &mut &bytes[..]).unwrap(); - assert_eq!(cql_value, CqlValue::Time(Duration::nanoseconds(*test_val))); - } - - // Negative values cause an error - // Values bigger than 86399999999999 cause an error - for test_val in [-1, i64::MIN, max_time + 1, i64::MAX].iter() { - let bytes: [u8; 8] = test_val.to_be_bytes(); - super::deser_cql_value(&ColumnType::Time, &mut &bytes[..]).unwrap_err(); - } - - // chrono::Duration has enough precision to represent nanoseconds accurately - assert_eq!(Duration::nanoseconds(1).num_nanoseconds().unwrap(), 1); - assert_eq!( - Duration::nanoseconds(7364737473).num_nanoseconds().unwrap(), - 7364737473 - ); - assert_eq!( - Duration::nanoseconds(86399999999999) - .num_nanoseconds() - .unwrap(), - 86399999999999 - ); - } - - #[test] - fn test_timestamp_deserialize() { - // Timestamp is an i64 - milliseconds since unix epoch - - // Check that test values are deserialized correctly - for test_val in &[0, -1, 1, 74568745, -4584658, i64::MIN, i64::MAX] { - let bytes: [u8; 8] = test_val.to_be_bytes(); - let cql_value: CqlValue = - super::deser_cql_value(&ColumnType::Timestamp, &mut &bytes[..]).unwrap(); - assert_eq!( - cql_value, - CqlValue::Timestamp(Duration::milliseconds(*test_val)) - ); - - // Check that Duration converted back to i64 is correct - assert_eq!( - Duration::milliseconds(*test_val).num_milliseconds(), - *test_val - ); - } - } - - #[test] - fn test_serialize_empty() { - use crate::frame::value::Value; - - let empty = CqlValue::Empty; - let mut v = Vec::new(); - empty.serialize(&mut v).unwrap(); - - assert_eq!(v, vec![0, 0, 0, 0]); - } - - #[test] - fn test_duration_deserialize() { - let bytes = [0xc, 0x12, 0xe2, 0x8c, 0x39, 0xd2]; - let cql_value: CqlValue = - super::deser_cql_value(&ColumnType::Duration, &mut &bytes[..]).unwrap(); - assert_eq!( - cql_value, - CqlValue::Duration(CqlDuration { - months: 6, - days: 9, - nanoseconds: 21372137 - }) - ); - } - - #[test] - fn test_deserialize_empty_payload() { - for (test_type, res_cql) in [ - (ColumnType::Ascii, CqlValue::Ascii("".to_owned())), - (ColumnType::Boolean, CqlValue::Empty), - (ColumnType::Blob, CqlValue::Blob(vec![])), - (ColumnType::Counter, CqlValue::Empty), - (ColumnType::Date, CqlValue::Empty), - (ColumnType::Decimal, CqlValue::Empty), - (ColumnType::Double, CqlValue::Empty), - (ColumnType::Float, CqlValue::Empty), - (ColumnType::Int, CqlValue::Empty), - (ColumnType::BigInt, CqlValue::Empty), - (ColumnType::Text, CqlValue::Text("".to_owned())), - (ColumnType::Timestamp, CqlValue::Empty), - (ColumnType::Inet, CqlValue::Empty), - (ColumnType::List(Box::new(ColumnType::Int)), CqlValue::Empty), - ( - ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Int)), - CqlValue::Empty, - ), - (ColumnType::Set(Box::new(ColumnType::Int)), CqlValue::Empty), - ( - ColumnType::UserDefinedType { - type_name: "".to_owned(), - keyspace: "".to_owned(), - field_types: vec![], - }, - CqlValue::Empty, - ), - (ColumnType::SmallInt, CqlValue::Empty), - (ColumnType::TinyInt, CqlValue::Empty), - (ColumnType::Time, CqlValue::Empty), - (ColumnType::Timeuuid, CqlValue::Empty), - (ColumnType::Tuple(vec![]), CqlValue::Empty), - (ColumnType::Uuid, CqlValue::Empty), - (ColumnType::Varint, CqlValue::Empty), - ] { - let cql_value: CqlValue = super::deser_cql_value(&test_type, &mut &[][..]).unwrap(); - - assert_eq!(cql_value, res_cql); - } - } - - #[test] - fn test_timeuuid_deserialize() { - // A few random timeuuids generated manually - let tests = [ - ( - "8e14e760-7fa8-11eb-bc66-000000000001", - [ - 0x8e, 0x14, 0xe7, 0x60, 0x7f, 0xa8, 0x11, 0xeb, 0xbc, 0x66, 0, 0, 0, 0, 0, 0x01, - ], - ), - ( - "9b349580-7fa8-11eb-bc66-000000000001", - [ - 0x9b, 0x34, 0x95, 0x80, 0x7f, 0xa8, 0x11, 0xeb, 0xbc, 0x66, 0, 0, 0, 0, 0, 0x01, - ], - ), - ( - "5d74bae0-7fa3-11eb-bc66-000000000001", - [ - 0x5d, 0x74, 0xba, 0xe0, 0x7f, 0xa3, 0x11, 0xeb, 0xbc, 0x66, 0, 0, 0, 0, 0, 0x01, - ], - ), - ]; - - for (uuid_str, uuid_bytes) in &tests { - let cql_val: CqlValue = - super::deser_cql_value(&ColumnType::Timeuuid, &mut &uuid_bytes[..]).unwrap(); - - match cql_val { - CqlValue::Timeuuid(uuid) => { - assert_eq!(uuid.as_bytes(), uuid_bytes); - assert_eq!(Uuid::parse_str(uuid_str).unwrap(), uuid); - } - _ => panic!("Timeuuid parsed as wrong CqlValue"), - } - } - } -} diff --git a/scylla-cql/src/frame/response/supported.rs b/scylla-cql/src/frame/response/supported.rs deleted file mode 100644 index a2e5a56..0000000 --- a/scylla-cql/src/frame/response/supported.rs +++ /dev/null @@ -1,16 +0,0 @@ -use crate::frame::frame_errors::ParseError; -use crate::frame::types; -use std::collections::HashMap; - -#[derive(Debug)] -pub struct Supported { - pub options: HashMap>, -} - -impl Supported { - pub fn deserialize(buf: &mut &[u8]) -> Result { - let options = types::read_string_multimap(buf)?; - - Ok(Supported { options }) - } -} diff --git a/scylla-cql/src/frame/server_event_type.rs b/scylla-cql/src/frame/server_event_type.rs deleted file mode 100644 index 1a4d5bf..0000000 --- a/scylla-cql/src/frame/server_event_type.rs +++ /dev/null @@ -1,37 +0,0 @@ -use crate::frame::frame_errors::ParseError; -use std::fmt; -use std::str::FromStr; - -pub enum EventType { - TopologyChange, - StatusChange, - SchemaChange, -} - -impl fmt::Display for EventType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match &self { - Self::TopologyChange => "TOPOLOGY_CHANGE", - Self::StatusChange => "STATUS_CHANGE", - Self::SchemaChange => "SCHEMA_CHANGE", - }; - - write!(f, "{}", s) - } -} - -impl FromStr for EventType { - type Err = ParseError; - - fn from_str(s: &str) -> Result { - match s { - "TOPOLOGY_CHANGE" => Ok(Self::TopologyChange), - "STATUS_CHANGE" => Ok(Self::StatusChange), - "SCHEMA_CHANGE" => Ok(Self::SchemaChange), - _ => Err(ParseError::BadIncomingData(format!( - "Invalid type event type: {}", - s - ))), - } - } -} diff --git a/scylla-cql/src/frame/types.rs b/scylla-cql/src/frame/types.rs deleted file mode 100644 index a10b6db..0000000 --- a/scylla-cql/src/frame/types.rs +++ /dev/null @@ -1,864 +0,0 @@ -//! CQL binary protocol in-wire types. - -use super::frame_errors::ParseError; -use byteorder::{BigEndian, ReadBytesExt}; -use bytes::{Buf, BufMut}; -use num_enum::TryFromPrimitive; -use std::collections::HashMap; -use std::convert::TryFrom; -use std::convert::TryInto; -use std::net::IpAddr; -use std::net::SocketAddr; -use std::str; -use uuid::Uuid; - -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, TryFromPrimitive)] -#[repr(i16)] -pub enum Consistency { - Any = 0x0000, - One = 0x0001, - Two = 0x0002, - Three = 0x0003, - Quorum = 0x0004, - All = 0x0005, - LocalQuorum = 0x0006, - EachQuorum = 0x0007, - LocalOne = 0x000A, -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, TryFromPrimitive)] -#[repr(i16)] -pub enum SerialConsistency { - Serial = 0x0008, - LocalSerial = 0x0009, -} - -// LegacyConsistency exists, because Scylla may return a SerialConsistency value -// as Consistency when returning certain error types - the distinction between -// Consistency and SerialConsistency is not really a thing in CQL. -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] -pub enum LegacyConsistency { - Regular(Consistency), - Serial(SerialConsistency), -} - -impl Default for Consistency { - fn default() -> Self { - Consistency::LocalQuorum - } -} - -impl std::fmt::Display for Consistency { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) - } -} - -impl std::fmt::Display for SerialConsistency { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) - } -} - -impl std::fmt::Display for LegacyConsistency { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - Self::Regular(c) => c.fmt(f), - Self::Serial(c) => c.fmt(f), - } - } -} - -impl From for ParseError { - fn from(_err: std::num::TryFromIntError) -> Self { - ParseError::BadIncomingData("Integer conversion out of range".to_string()) - } -} - -impl From for ParseError { - fn from(_err: std::str::Utf8Error) -> Self { - ParseError::BadIncomingData("UTF8 serialization failed".to_string()) - } -} - -impl From for ParseError { - fn from(_err: std::array::TryFromSliceError) -> Self { - ParseError::BadIncomingData("array try from slice failed".to_string()) - } -} - -fn read_raw_bytes<'a>(count: usize, buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> { - if buf.len() < count { - return Err(ParseError::BadIncomingData(format!( - "Not enough bytes! expected: {} received: {}", - count, - buf.len(), - ))); - } - let (ret, rest) = buf.split_at(count); - *buf = rest; - Ok(ret) -} - -pub fn read_int(buf: &mut &[u8]) -> Result { - let v = buf.read_i32::()?; - Ok(v) -} - -pub fn write_int(v: i32, buf: &mut impl BufMut) { - buf.put_i32(v); -} - -pub fn read_int_length(buf: &mut &[u8]) -> Result { - let v = read_int(buf)?; - let v: usize = v.try_into()?; - - Ok(v) -} - -fn write_int_length(v: usize, buf: &mut impl BufMut) -> Result<(), ParseError> { - let v: i32 = v.try_into()?; - - write_int(v, buf); - Ok(()) -} - -#[test] -fn type_int() { - let vals = vec![i32::MIN, -1, 0, 1, i32::MAX]; - for val in vals.iter() { - let mut buf = Vec::new(); - write_int(*val, &mut buf); - assert_eq!(read_int(&mut &buf[..]).unwrap(), *val); - } -} - -pub fn read_long(buf: &mut &[u8]) -> Result { - let v = buf.read_i64::()?; - Ok(v) -} - -pub fn write_long(v: i64, buf: &mut impl BufMut) { - buf.put_i64(v); -} - -#[test] -fn type_long() { - let vals = vec![i64::MIN, -1, 0, 1, i64::MAX]; - for val in vals.iter() { - let mut buf = Vec::new(); - write_long(*val, &mut buf); - assert_eq!(read_long(&mut &buf[..]).unwrap(), *val); - } -} - -pub fn read_short(buf: &mut &[u8]) -> Result { - let v = buf.read_i16::()?; - Ok(v) -} - -pub fn write_short(v: i16, buf: &mut impl BufMut) { - buf.put_i16(v); -} - -pub(crate) fn read_short_length(buf: &mut &[u8]) -> Result { - let v = read_short(buf)?; - let v: usize = v.try_into()?; - Ok(v) -} - -fn write_short_length(v: usize, buf: &mut impl BufMut) -> Result<(), ParseError> { - let v: i16 = v.try_into()?; - write_short(v, buf); - Ok(()) -} - -#[test] -fn type_short() { - let vals = vec![i16::MIN, -1, 0, 1, i16::MAX]; - for val in vals.iter() { - let mut buf = Vec::new(); - write_short(*val, &mut buf); - assert_eq!(read_short(&mut &buf[..]).unwrap(), *val); - } -} - -// https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec#L208 -pub fn read_bytes_opt<'a>(buf: &mut &'a [u8]) -> Result, ParseError> { - let len = read_int(buf)?; - if len < 0 { - return Ok(None); - } - let len = len as usize; - let v = Some(read_raw_bytes(len, buf)?); - Ok(v) -} - -// Same as read_bytes, but we assume the value won't be `null` -pub fn read_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> { - let len = read_int_length(buf)?; - let v = read_raw_bytes(len, buf)?; - Ok(v) -} - -pub fn read_short_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> { - let len = read_short_length(buf)?; - let v = read_raw_bytes(len, buf)?; - Ok(v) -} - -pub fn write_bytes(v: &[u8], buf: &mut impl BufMut) -> Result<(), ParseError> { - write_int_length(v.len(), buf)?; - buf.put_slice(v); - Ok(()) -} - -pub fn write_bytes_opt(v: Option<&Vec>, buf: &mut impl BufMut) -> Result<(), ParseError> { - match v { - Some(bytes) => { - write_int_length(bytes.len(), buf)?; - buf.put_slice(bytes); - } - None => write_int(-1, buf), - } - - Ok(()) -} - -pub fn write_short_bytes(v: &[u8], buf: &mut impl BufMut) -> Result<(), ParseError> { - write_short_length(v.len(), buf)?; - buf.put_slice(v); - Ok(()) -} - -pub fn read_bytes_map(buf: &mut &[u8]) -> Result>, ParseError> { - let len = read_short_length(buf)?; - let mut v = HashMap::with_capacity(len); - for _ in 0..len { - let key = read_string(buf)?.to_owned(); - let val = read_bytes(buf)?.to_owned(); - v.insert(key, val); - } - Ok(v) -} - -pub fn write_bytes_map(v: &HashMap, buf: &mut impl BufMut) -> Result<(), ParseError> -where - B: AsRef<[u8]>, -{ - let len = v.len(); - write_short_length(len, buf)?; - for (key, val) in v.iter() { - write_string(key, buf)?; - write_bytes(val.as_ref(), buf)?; - } - Ok(()) -} - -#[test] -fn type_bytes_map() { - let mut val = HashMap::new(); - val.insert("".to_owned(), vec![]); - val.insert("EXTENSION1".to_owned(), vec![1, 2, 3]); - val.insert("EXTENSION2".to_owned(), vec![4, 5, 6]); - let mut buf = Vec::new(); - write_bytes_map(&val, &mut buf).unwrap(); - assert_eq!(read_bytes_map(&mut &*buf).unwrap(), val); -} - -pub fn read_string<'a>(buf: &mut &'a [u8]) -> Result<&'a str, ParseError> { - let len = read_short_length(buf)?; - let raw = read_raw_bytes(len, buf)?; - let v = str::from_utf8(raw)?; - Ok(v) -} - -pub fn write_string(v: &str, buf: &mut impl BufMut) -> Result<(), ParseError> { - let raw = v.as_bytes(); - write_short_length(v.len(), buf)?; - buf.put_slice(raw); - Ok(()) -} - -#[test] -fn type_string() { - let vals = vec![String::from(""), String::from("hello, world!")]; - for val in vals.iter() { - let mut buf = Vec::new(); - write_string(val, &mut buf).unwrap(); - assert_eq!(read_string(&mut &buf[..]).unwrap(), *val); - } -} - -pub fn read_long_string<'a>(buf: &mut &'a [u8]) -> Result<&'a str, ParseError> { - let len = read_int_length(buf)?; - let raw = read_raw_bytes(len, buf)?; - let v = str::from_utf8(raw)?; - Ok(v) -} - -pub fn write_long_string(v: &str, buf: &mut impl BufMut) -> Result<(), ParseError> { - let raw = v.as_bytes(); - let len = raw.len(); - write_int_length(len, buf)?; - buf.put_slice(raw); - Ok(()) -} - -#[test] -fn type_long_string() { - let vals = vec![String::from(""), String::from("hello, world!")]; - for val in vals.iter() { - let mut buf = Vec::new(); - write_long_string(val, &mut buf).unwrap(); - assert_eq!(read_long_string(&mut &buf[..]).unwrap(), *val); - } -} - -pub fn read_string_map(buf: &mut &[u8]) -> Result, ParseError> { - let len = read_short_length(buf)?; - let mut v = HashMap::with_capacity(len); - for _ in 0..len { - let key = read_string(buf)?.to_owned(); - let val = read_string(buf)?.to_owned(); - v.insert(key, val); - } - Ok(v) -} - -pub fn write_string_map( - v: &HashMap, - buf: &mut impl BufMut, -) -> Result<(), ParseError> { - let len = v.len(); - write_short_length(len, buf)?; - for (key, val) in v.iter() { - write_string(key, buf)?; - write_string(val, buf)?; - } - Ok(()) -} - -#[test] -fn type_string_map() { - let mut val = HashMap::new(); - val.insert(String::from(""), String::from("")); - val.insert(String::from("CQL_VERSION"), String::from("3.0.0")); - val.insert(String::from("THROW_ON_OVERLOAD"), String::from("")); - let mut buf = Vec::new(); - write_string_map(&val, &mut buf).unwrap(); - assert_eq!(read_string_map(&mut &buf[..]).unwrap(), val); -} - -pub fn read_string_list(buf: &mut &[u8]) -> Result, ParseError> { - let len = read_short_length(buf)?; - let mut v = Vec::with_capacity(len); - for _ in 0..len { - v.push(read_string(buf)?.to_owned()); - } - Ok(v) -} - -pub fn write_string_list(v: &[String], buf: &mut impl BufMut) -> Result<(), ParseError> { - let len = v.len(); - write_short_length(len, buf)?; - for v in v.iter() { - write_string(v, buf)?; - } - Ok(()) -} - -#[test] -fn type_string_list() { - let val = vec![ - "".to_owned(), - "CQL_VERSION".to_owned(), - "THROW_ON_OVERLOAD".to_owned(), - ]; - - let mut buf = Vec::new(); - write_string_list(&val, &mut buf).unwrap(); - assert_eq!(read_string_list(&mut &buf[..]).unwrap(), val); -} - -pub fn read_string_multimap(buf: &mut &[u8]) -> Result>, ParseError> { - let len = read_short_length(buf)?; - let mut v = HashMap::with_capacity(len); - for _ in 0..len { - let key = read_string(buf)?.to_owned(); - let val = read_string_list(buf)?; - v.insert(key, val); - } - Ok(v) -} - -pub fn write_string_multimap( - v: &HashMap>, - buf: &mut impl BufMut, -) -> Result<(), ParseError> { - let len = v.len(); - write_short_length(len, buf)?; - for (key, val) in v.iter() { - write_string(key, buf)?; - write_string_list(val, buf)?; - } - Ok(()) -} - -#[test] -fn type_string_multimap() { - let mut val = HashMap::new(); - val.insert(String::from(""), vec![String::from("")]); - val.insert( - String::from("versions"), - vec![String::from("3.0.0"), String::from("4.2.0")], - ); - val.insert(String::from("empty"), vec![]); - let mut buf = Vec::new(); - write_string_multimap(&val, &mut buf).unwrap(); - assert_eq!(read_string_multimap(&mut &buf[..]).unwrap(), val); -} - -pub fn read_uuid(buf: &mut &[u8]) -> Result { - let raw = read_raw_bytes(16, buf)?; - - // It's safe to unwrap here because Uuid::from_slice only fails - // if the argument slice's length is not 16. - Ok(Uuid::from_slice(raw).unwrap()) -} - -pub fn write_uuid(uuid: &Uuid, buf: &mut impl BufMut) { - buf.put_slice(&uuid.as_bytes()[..]); -} - -#[test] -fn type_uuid() { - let u = Uuid::parse_str("f3b4958c-52a1-11e7-802a-010203040506").unwrap(); - let mut buf = Vec::new(); - write_uuid(&u, &mut buf); - let u2 = read_uuid(&mut &*buf).unwrap(); - assert_eq!(u, u2); -} - -pub fn read_consistency(buf: &mut &[u8]) -> Result { - let raw = read_short(buf)?; - let parsed = match Consistency::try_from(raw) { - Ok(c) => LegacyConsistency::Regular(c), - Err(_) => { - let parsed_serial = SerialConsistency::try_from(raw).map_err(|_| { - ParseError::BadIncomingData(format!("unknown consistency: {}", raw)) - })?; - LegacyConsistency::Serial(parsed_serial) - } - }; - Ok(parsed) -} - -pub fn write_consistency(c: Consistency, buf: &mut impl BufMut) { - write_short(c as i16, buf); -} - -pub fn write_serial_consistency(c: SerialConsistency, buf: &mut impl BufMut) { - write_short(c as i16, buf); -} - -#[test] -fn type_consistency() { - let c = Consistency::Quorum; - let mut buf = Vec::new(); - write_consistency(c, &mut buf); - let c2 = read_consistency(&mut &*buf).unwrap(); - assert_eq!(LegacyConsistency::Regular(c), c2); - - let c: i16 = 0x1234; - buf.clear(); - buf.put_i16(c); - let c_result = read_consistency(&mut &*buf); - assert!(c_result.is_err()); - - // Check that the error message contains information about the invalid value - let err_str = format!("{}", c_result.unwrap_err()); - assert!(err_str.contains(&format!("{}", c))); -} - -pub fn read_inet(buf: &mut &[u8]) -> Result { - let len = buf.read_u8()?; - let ip_addr = match len { - 4 => { - let ret = IpAddr::from(<[u8; 4]>::try_from(&buf[0..4])?); - buf.advance(4); - ret - } - 16 => { - let ret = IpAddr::from(<[u8; 16]>::try_from(&buf[0..16])?); - buf.advance(16); - ret - } - v => { - return Err(ParseError::BadIncomingData(format!( - "Invalid inet bytes length: {}", - v - ))) - } - }; - let port = read_int(buf)?; - - Ok(SocketAddr::new(ip_addr, port as u16)) -} - -pub fn write_inet(addr: SocketAddr, buf: &mut impl BufMut) { - match addr.ip() { - IpAddr::V4(v4) => { - buf.put_u8(4); - buf.put_slice(&v4.octets()); - } - IpAddr::V6(v6) => { - buf.put_u8(16); - buf.put_slice(&v6.octets()); - } - } - - write_int(addr.port() as i32, buf) -} - -#[test] -fn type_inet() { - use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; - - let iv4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1234); - let iv6 = SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 2345); - let mut buf = Vec::new(); - - write_inet(iv4, &mut buf); - let read_iv4 = read_inet(&mut &*buf).unwrap(); - assert_eq!(iv4, read_iv4); - buf.clear(); - - write_inet(iv6, &mut buf); - let read_iv6 = read_inet(&mut &*buf).unwrap(); - assert_eq!(iv6, read_iv6); -} - -fn zig_zag_encode(v: i64) -> u64 { - ((v >> 63) ^ (v << 1)) as u64 -} - -fn zig_zag_decode(v: u64) -> i64 { - ((v >> 1) as i64) ^ -((v & 1) as i64) -} - -fn unsigned_vint_encode(v: u64, buf: &mut Vec) { - let mut v = v; - let mut number_of_bytes = (639 - 9 * v.leading_zeros()) >> 6; - if number_of_bytes <= 1 { - return buf.put_u8(v as u8); - } - - if number_of_bytes != 9 { - let extra_bytes = number_of_bytes - 1; - let length_bits = !(0xff >> extra_bytes); - v |= (length_bits as u64) << (8 * extra_bytes); - } else { - buf.put_u8(0xff); - number_of_bytes -= 1; - } - buf.put_uint(v, number_of_bytes as usize) -} - -fn unsigned_vint_decode(buf: &mut &[u8]) -> Result { - let first_byte = buf.read_u8()?; - let extra_bytes = first_byte.leading_ones() as usize; - - let mut v = if extra_bytes != 8 { - let first_byte_bits = first_byte & (0xffu8 >> extra_bytes); - (first_byte_bits as u64) << (8 * extra_bytes) - } else { - 0 - }; - - if extra_bytes != 0 { - v += buf.read_uint::(extra_bytes)?; - } - - Ok(v) -} - -pub(crate) fn vint_encode(v: i64, buf: &mut Vec) { - unsigned_vint_encode(zig_zag_encode(v), buf) -} - -pub(crate) fn vint_decode(buf: &mut &[u8]) -> Result { - unsigned_vint_decode(buf).map(zig_zag_decode) -} - -#[test] -fn zig_zag_encode_test() { - assert_eq!(zig_zag_encode(0), 0); - assert_eq!(zig_zag_encode(-1), 1); - assert_eq!(zig_zag_encode(1), 2); - assert_eq!(zig_zag_encode(-2), 3); - assert_eq!(zig_zag_encode(2), 4); - assert_eq!(zig_zag_encode(-3), 5); - assert_eq!(zig_zag_encode(3), 6); -} - -#[test] -fn zig_zag_decode_test() { - assert_eq!(zig_zag_decode(0), 0); - assert_eq!(zig_zag_decode(1), -1); - assert_eq!(zig_zag_decode(2), 1); - assert_eq!(zig_zag_decode(3), -2); - assert_eq!(zig_zag_decode(4), 2); - assert_eq!(zig_zag_decode(5), -3); - assert_eq!(zig_zag_decode(6), 3); -} - -#[test] -fn unsigned_vint_encode_and_decode_test() { - let unsigned_vint_encoding = vec![ - (0, vec![0]), - (1, vec![1]), - (2, vec![2]), - ((1 << 2) - 1, vec![3]), - (1 << 2, vec![4]), - ((1 << 2) + 1, vec![5]), - ((1 << 3) - 1, vec![7]), - (1 << 3, vec![8]), - ((1 << 3) + 1, vec![9]), - ((1 << 4) - 1, vec![15]), - (1 << 4, vec![16]), - ((1 << 4) + 1, vec![17]), - ((1 << 5) - 1, vec![31]), - (1 << 5, vec![32]), - ((1 << 5) + 1, vec![33]), - ((1 << 6) - 1, vec![63]), - (1 << 6, vec![64]), - ((1 << 6) + 1, vec![65]), - ((1 << 7) - 1, vec![127]), - (1 << 7, vec![128, 128]), - ((1 << 7) + 1, vec![128, 129]), - ((1 << 8) - 1, vec![128, 255]), - (1 << 8, vec![129, 0]), - ((1 << 8) + 1, vec![129, 1]), - ((1 << 9) - 1, vec![129, 255]), - (1 << 9, vec![130, 0]), - ((1 << 9) + 1, vec![130, 1]), - ((1 << 10) - 1, vec![131, 255]), - (1 << 10, vec![132, 0]), - ((1 << 10) + 1, vec![132, 1]), - ((1 << 11) - 1, vec![135, 255]), - (1 << 11, vec![136, 0]), - ((1 << 11) + 1, vec![136, 1]), - ((1 << 12) - 1, vec![143, 255]), - (1 << 12, vec![144, 0]), - ((1 << 12) + 1, vec![144, 1]), - ((1 << 13) - 1, vec![159, 255]), - (1 << 13, vec![160, 0]), - ((1 << 13) + 1, vec![160, 1]), - ((1 << 14) - 1, vec![191, 255]), - (1 << 14, vec![192, 64, 0]), - ((1 << 14) + 1, vec![192, 64, 1]), - ((1 << 15) - 1, vec![192, 127, 255]), - (1 << 15, vec![192, 128, 0]), - ((1 << 15) + 1, vec![192, 128, 1]), - ((1 << 16) - 1, vec![192, 255, 255]), - (1 << 16, vec![193, 0, 0]), - ((1 << 16) + 1, vec![193, 0, 1]), - ((1 << 17) - 1, vec![193, 255, 255]), - (1 << 17, vec![194, 0, 0]), - ((1 << 17) + 1, vec![194, 0, 1]), - ((1 << 18) - 1, vec![195, 255, 255]), - (1 << 18, vec![196, 0, 0]), - ((1 << 18) + 1, vec![196, 0, 1]), - ((1 << 19) - 1, vec![199, 255, 255]), - (1 << 19, vec![200, 0, 0]), - ((1 << 19) + 1, vec![200, 0, 1]), - ((1 << 20) - 1, vec![207, 255, 255]), - (1 << 20, vec![208, 0, 0]), - ((1 << 20) + 1, vec![208, 0, 1]), - ((1 << 21) - 1, vec![223, 255, 255]), - (1 << 21, vec![224, 32, 0, 0]), - ((1 << 21) + 1, vec![224, 32, 0, 1]), - ((1 << 22) - 1, vec![224, 63, 255, 255]), - (1 << 22, vec![224, 64, 0, 0]), - ((1 << 22) + 1, vec![224, 64, 0, 1]), - ((1 << 23) - 1, vec![224, 127, 255, 255]), - (1 << 23, vec![224, 128, 0, 0]), - ((1 << 23) + 1, vec![224, 128, 0, 1]), - ((1 << 24) - 1, vec![224, 255, 255, 255]), - (1 << 24, vec![225, 0, 0, 0]), - ((1 << 24) + 1, vec![225, 0, 0, 1]), - ((1 << 25) - 1, vec![225, 255, 255, 255]), - (1 << 25, vec![226, 0, 0, 0]), - ((1 << 25) + 1, vec![226, 0, 0, 1]), - ((1 << 26) - 1, vec![227, 255, 255, 255]), - (1 << 26, vec![228, 0, 0, 0]), - ((1 << 26) + 1, vec![228, 0, 0, 1]), - ((1 << 27) - 1, vec![231, 255, 255, 255]), - (1 << 27, vec![232, 0, 0, 0]), - ((1 << 27) + 1, vec![232, 0, 0, 1]), - ((1 << 28) - 1, vec![239, 255, 255, 255]), - (1 << 28, vec![240, 16, 0, 0, 0]), - ((1 << 28) + 1, vec![240, 16, 0, 0, 1]), - ((1 << 29) - 1, vec![240, 31, 255, 255, 255]), - (1 << 29, vec![240, 32, 0, 0, 0]), - ((1 << 29) + 1, vec![240, 32, 0, 0, 1]), - ((1 << 30) - 1, vec![240, 63, 255, 255, 255]), - (1 << 30, vec![240, 64, 0, 0, 0]), - ((1 << 30) + 1, vec![240, 64, 0, 0, 1]), - ((1 << 31) - 1, vec![240, 127, 255, 255, 255]), - (1 << 31, vec![240, 128, 0, 0, 0]), - ((1 << 31) + 1, vec![240, 128, 0, 0, 1]), - ((1 << 32) - 1, vec![240, 255, 255, 255, 255]), - (1 << 32, vec![241, 0, 0, 0, 0]), - ((1 << 32) + 1, vec![241, 0, 0, 0, 1]), - ((1 << 33) - 1, vec![241, 255, 255, 255, 255]), - (1 << 33, vec![242, 0, 0, 0, 0]), - ((1 << 33) + 1, vec![242, 0, 0, 0, 1]), - ((1 << 34) - 1, vec![243, 255, 255, 255, 255]), - (1 << 34, vec![244, 0, 0, 0, 0]), - ((1 << 34) + 1, vec![244, 0, 0, 0, 1]), - ((1 << 35) - 1, vec![247, 255, 255, 255, 255]), - (1 << 35, vec![248, 8, 0, 0, 0, 0]), - ((1 << 35) + 1, vec![248, 8, 0, 0, 0, 1]), - ((1 << 36) - 1, vec![248, 15, 255, 255, 255, 255]), - (1 << 36, vec![248, 16, 0, 0, 0, 0]), - ((1 << 36) + 1, vec![248, 16, 0, 0, 0, 1]), - ((1 << 37) - 1, vec![248, 31, 255, 255, 255, 255]), - (1 << 37, vec![248, 32, 0, 0, 0, 0]), - ((1 << 37) + 1, vec![248, 32, 0, 0, 0, 1]), - ((1 << 38) - 1, vec![248, 63, 255, 255, 255, 255]), - (1 << 38, vec![248, 64, 0, 0, 0, 0]), - ((1 << 38) + 1, vec![248, 64, 0, 0, 0, 1]), - ((1 << 39) - 1, vec![248, 127, 255, 255, 255, 255]), - (1 << 39, vec![248, 128, 0, 0, 0, 0]), - ((1 << 39) + 1, vec![248, 128, 0, 0, 0, 1]), - ((1 << 40) - 1, vec![248, 255, 255, 255, 255, 255]), - (1 << 40, vec![249, 0, 0, 0, 0, 0]), - ((1 << 40) + 1, vec![249, 0, 0, 0, 0, 1]), - ((1 << 41) - 1, vec![249, 255, 255, 255, 255, 255]), - (1 << 41, vec![250, 0, 0, 0, 0, 0]), - ((1 << 41) + 1, vec![250, 0, 0, 0, 0, 1]), - ((1 << 42) - 1, vec![251, 255, 255, 255, 255, 255]), - (1 << 42, vec![252, 4, 0, 0, 0, 0, 0]), - ((1 << 42) + 1, vec![252, 4, 0, 0, 0, 0, 1]), - ((1 << 43) - 1, vec![252, 7, 255, 255, 255, 255, 255]), - (1 << 43, vec![252, 8, 0, 0, 0, 0, 0]), - ((1 << 43) + 1, vec![252, 8, 0, 0, 0, 0, 1]), - ((1 << 44) - 1, vec![252, 15, 255, 255, 255, 255, 255]), - (1 << 44, vec![252, 16, 0, 0, 0, 0, 0]), - ((1 << 44) + 1, vec![252, 16, 0, 0, 0, 0, 1]), - ((1 << 45) - 1, vec![252, 31, 255, 255, 255, 255, 255]), - (1 << 45, vec![252, 32, 0, 0, 0, 0, 0]), - ((1 << 45) + 1, vec![252, 32, 0, 0, 0, 0, 1]), - ((1 << 46) - 1, vec![252, 63, 255, 255, 255, 255, 255]), - (1 << 46, vec![252, 64, 0, 0, 0, 0, 0]), - ((1 << 46) + 1, vec![252, 64, 0, 0, 0, 0, 1]), - ((1 << 47) - 1, vec![252, 127, 255, 255, 255, 255, 255]), - (1 << 47, vec![252, 128, 0, 0, 0, 0, 0]), - ((1 << 47) + 1, vec![252, 128, 0, 0, 0, 0, 1]), - ((1 << 48) - 1, vec![252, 255, 255, 255, 255, 255, 255]), - (1 << 48, vec![253, 0, 0, 0, 0, 0, 0]), - ((1 << 48) + 1, vec![253, 0, 0, 0, 0, 0, 1]), - ((1 << 49) - 1, vec![253, 255, 255, 255, 255, 255, 255]), - (1 << 49, vec![254, 2, 0, 0, 0, 0, 0, 0]), - ((1 << 49) + 1, vec![254, 2, 0, 0, 0, 0, 0, 1]), - ((1 << 50) - 1, vec![254, 3, 255, 255, 255, 255, 255, 255]), - (1 << 50, vec![254, 4, 0, 0, 0, 0, 0, 0]), - ((1 << 50) + 1, vec![254, 4, 0, 0, 0, 0, 0, 1]), - ((1 << 51) - 1, vec![254, 7, 255, 255, 255, 255, 255, 255]), - (1 << 51, vec![254, 8, 0, 0, 0, 0, 0, 0]), - ((1 << 51) + 1, vec![254, 8, 0, 0, 0, 0, 0, 1]), - ((1 << 52) - 1, vec![254, 15, 255, 255, 255, 255, 255, 255]), - (1 << 52, vec![254, 16, 0, 0, 0, 0, 0, 0]), - ((1 << 52) + 1, vec![254, 16, 0, 0, 0, 0, 0, 1]), - ((1 << 53) - 1, vec![254, 31, 255, 255, 255, 255, 255, 255]), - (1 << 53, vec![254, 32, 0, 0, 0, 0, 0, 0]), - ((1 << 53) + 1, vec![254, 32, 0, 0, 0, 0, 0, 1]), - ((1 << 54) - 1, vec![254, 63, 255, 255, 255, 255, 255, 255]), - (1 << 54, vec![254, 64, 0, 0, 0, 0, 0, 0]), - ((1 << 54) + 1, vec![254, 64, 0, 0, 0, 0, 0, 1]), - ((1 << 55) - 1, vec![254, 127, 255, 255, 255, 255, 255, 255]), - (1 << 55, vec![254, 128, 0, 0, 0, 0, 0, 0]), - ((1 << 55) + 1, vec![254, 128, 0, 0, 0, 0, 0, 1]), - ((1 << 56) - 1, vec![254, 255, 255, 255, 255, 255, 255, 255]), - (1 << 56, vec![255, 1, 0, 0, 0, 0, 0, 0, 0]), - ((1 << 56) + 1, vec![255, 1, 0, 0, 0, 0, 0, 0, 1]), - ( - (1 << 57) - 1, - vec![255, 1, 255, 255, 255, 255, 255, 255, 255], - ), - (1 << 57, vec![255, 2, 0, 0, 0, 0, 0, 0, 0]), - ((1 << 57) + 1, vec![255, 2, 0, 0, 0, 0, 0, 0, 1]), - ( - (1 << 58) - 1, - vec![255, 3, 255, 255, 255, 255, 255, 255, 255], - ), - (1 << 58, vec![255, 4, 0, 0, 0, 0, 0, 0, 0]), - ((1 << 58) + 1, vec![255, 4, 0, 0, 0, 0, 0, 0, 1]), - ( - (1 << 59) - 1, - vec![255, 7, 255, 255, 255, 255, 255, 255, 255], - ), - (1 << 59, vec![255, 8, 0, 0, 0, 0, 0, 0, 0]), - ((1 << 59) + 1, vec![255, 8, 0, 0, 0, 0, 0, 0, 1]), - ( - (1 << 60) - 1, - vec![255, 15, 255, 255, 255, 255, 255, 255, 255], - ), - (1 << 60, vec![255, 16, 0, 0, 0, 0, 0, 0, 0]), - ((1 << 60) + 1, vec![255, 16, 0, 0, 0, 0, 0, 0, 1]), - ( - (1 << 61) - 1, - vec![255, 31, 255, 255, 255, 255, 255, 255, 255], - ), - (1 << 61, vec![255, 32, 0, 0, 0, 0, 0, 0, 0]), - ((1 << 61) + 1, vec![255, 32, 0, 0, 0, 0, 0, 0, 1]), - ( - (1 << 62) - 1, - vec![255, 63, 255, 255, 255, 255, 255, 255, 255], - ), - (1 << 62, vec![255, 64, 0, 0, 0, 0, 0, 0, 0]), - ((1 << 62) + 1, vec![255, 64, 0, 0, 0, 0, 0, 0, 1]), - ( - (1 << 63) - 1, - vec![255, 127, 255, 255, 255, 255, 255, 255, 255], - ), - (1 << 63, vec![255, 128, 0, 0, 0, 0, 0, 0, 0]), - ((1 << 63) + 1, vec![255, 128, 0, 0, 0, 0, 0, 0, 1]), - (u64::MAX, vec![255, 255, 255, 255, 255, 255, 255, 255, 255]), - ]; - - let mut buf = Vec::new(); - - for (v, result) in unsigned_vint_encoding.into_iter() { - unsigned_vint_encode(v, &mut buf); - assert_eq!(buf, result); - let decoded_v = unsigned_vint_decode(&mut buf.as_slice()).unwrap(); - assert_eq!(v, decoded_v); - buf.clear(); - } -} - -#[test] -fn vint_encode_and_decode_test() { - let mut buf: Vec = Vec::with_capacity(128); - - let mut check = |n: i64| { - vint_encode(n, &mut buf); - assert_eq!(vint_decode(&mut buf.as_slice()).unwrap(), n); - buf.clear(); - }; - - for i in 0..63 { - check((1 << i) - 1); - check(1 - (1 << i)); - check(1 << i); - check(-(1 << i)); - check((1 << i) + 1); - check(-1 - (1 << i)); - } - check(i64::MAX); - check(-i64::MAX); - check(i64::MIN) -} diff --git a/scylla-cql/src/frame/value.rs b/scylla-cql/src/frame/value.rs deleted file mode 100644 index 92524d2..0000000 --- a/scylla-cql/src/frame/value.rs +++ /dev/null @@ -1,1111 +0,0 @@ -use crate::frame::types; -use bigdecimal::BigDecimal; -use bytes::BufMut; -use chrono::prelude::*; -use chrono::Duration; -use num_bigint::BigInt; -use std::borrow::Cow; -use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; -use std::convert::TryInto; -use std::net::IpAddr; -use thiserror::Error; -use uuid::Uuid; - -use super::response::result::CqlValue; -use super::types::vint_encode; - -#[cfg(feature = "secret")] -use secrecy::{ExposeSecret, Secret, Zeroize}; - -/// Every value being sent in a query must implement this trait -/// serialize() should write the Value as [bytes] to the provided buffer -pub trait Value { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig>; -} - -#[derive(Debug, Error, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -#[error("Value too big to be sent in a request - max 2GiB allowed")] -pub struct ValueTooBig; - -/// Represents an unset value -pub struct Unset; - -/// Represents an counter value -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub struct Counter(pub i64); - -/// Enum providing a way to represent a value that might be unset -#[derive(Clone, Copy)] -pub enum MaybeUnset { - Unset, - Set(V), -} - -/// Wrapper that allows to send dates outside of NaiveDate range (-262145-1-1 to 262143-12-31) -/// Days since -5877641-06-23 i.e. 2^31 days before unix epoch -#[derive(Clone, Copy, PartialEq, Eq, Debug)] -pub struct Date(pub u32); - -/// Wrapper used to differentiate between Time and Timestamp as sending values -/// Milliseconds since unix epoch -#[derive(Clone, Copy, PartialEq, Eq, Debug)] -pub struct Timestamp(pub Duration); - -/// Wrapper used to differentiate between Time and Timestamp as sending values -/// Nanoseconds since midnight -#[derive(Clone, Copy, PartialEq, Eq, Debug)] -pub struct Time(pub Duration); - -/// Keeps a buffer with serialized Values -/// Allows adding new Values and iterating over serialized ones -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] -pub struct SerializedValues { - serialized_values: Vec, - values_num: i16, - contains_names: bool, -} - -/// Represents a CQL Duration value -#[derive(Clone, Debug, Copy, PartialEq, Eq)] -pub struct CqlDuration { - pub months: i32, - pub days: i32, - pub nanoseconds: i64, -} - -#[derive(Debug, Error, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub enum SerializeValuesError { - #[error("Too many values to add, max 32 767 values can be sent in a request")] - TooManyValues, - #[error("Mixing named and not named values is not allowed")] - MixingNamedAndNotNamedValues, - #[error(transparent)] - ValueTooBig(#[from] ValueTooBig), - #[error("Parsing serialized values failed")] - ParseError, -} - -pub type SerializedResult<'a> = Result, SerializeValuesError>; - -/// Represents list of values to be sent in a query -/// gets serialized and but into request -pub trait ValueList { - /// Provides a view of ValueList as SerializedValues - /// returns `Cow` to make impl ValueList for SerializedValues efficient - fn serialized(&self) -> SerializedResult<'_>; - - fn write_to_request(&self, buf: &mut impl BufMut) -> Result<(), SerializeValuesError> { - let serialized = self.serialized()?; - SerializedValues::write_to_request(&serialized, buf); - - Ok(()) - } -} - -impl SerializedValues { - /// Creates empty value list - pub const fn new() -> Self { - SerializedValues { - serialized_values: Vec::new(), - values_num: 0, - contains_names: false, - } - } - - pub fn with_capacity(capacity: usize) -> Self { - SerializedValues { - serialized_values: Vec::with_capacity(capacity), - values_num: 0, - contains_names: false, - } - } - - pub fn has_names(&self) -> bool { - self.contains_names - } - - /// A const empty instance, useful for taking references - pub const EMPTY: &'static SerializedValues = &SerializedValues::new(); - - /// Serializes value and appends it to the list - pub fn add_value(&mut self, val: &impl Value) -> Result<(), SerializeValuesError> { - if self.contains_names { - return Err(SerializeValuesError::MixingNamedAndNotNamedValues); - } - if self.values_num == i16::MAX { - return Err(SerializeValuesError::TooManyValues); - } - - let len_before_serialize: usize = self.serialized_values.len(); - - if let Err(e) = val.serialize(&mut self.serialized_values) { - self.serialized_values.resize(len_before_serialize, 0); - return Err(SerializeValuesError::from(e)); - } - - self.values_num += 1; - Ok(()) - } - - pub fn add_named_value( - &mut self, - name: &str, - val: &impl Value, - ) -> Result<(), SerializeValuesError> { - if self.values_num > 0 && !self.contains_names { - return Err(SerializeValuesError::MixingNamedAndNotNamedValues); - } - self.contains_names = true; - if self.values_num == i16::MAX { - return Err(SerializeValuesError::TooManyValues); - } - - let len_before_serialize: usize = self.serialized_values.len(); - - types::write_string(name, &mut self.serialized_values) - .map_err(|_| SerializeValuesError::ParseError)?; - - if let Err(e) = val.serialize(&mut self.serialized_values) { - self.serialized_values.resize(len_before_serialize, 0); - return Err(SerializeValuesError::from(e)); - } - - self.values_num += 1; - Ok(()) - } - - pub fn iter(&self) -> impl Iterator> { - SerializedValuesIterator { - serialized_values: &self.serialized_values, - contains_names: self.contains_names, - } - } - - pub fn write_to_request(&self, buf: &mut impl BufMut) { - buf.put_i16(self.values_num); - buf.put(&self.serialized_values[..]); - } - - pub fn is_empty(&self) -> bool { - self.values_num == 0 - } - - pub fn len(&self) -> i16 { - self.values_num - } -} - -#[derive(Clone, Copy)] -pub struct SerializedValuesIterator<'a> { - serialized_values: &'a [u8], - contains_names: bool, -} - -impl<'a> Iterator for SerializedValuesIterator<'a> { - type Item = Option<&'a [u8]>; - - fn next(&mut self) -> Option { - if self.serialized_values.is_empty() { - return None; - } - - // In case of named values, skip names - if self.contains_names { - types::read_short_bytes(&mut self.serialized_values).expect("badly encoded value name"); - } - - Some(types::read_bytes_opt(&mut self.serialized_values).expect("badly encoded value")) - } -} - -/// Represents List of ValueList for Batch statement -/// -/// This trait is not implemented directly, but rather implemented through `BatchValuesGatWorkaround` -/// (until GATs are made available in Rust) -pub trait BatchValues: for<'r> BatchValuesGatWorkaround<'r> {} -impl BatchValuesGatWorkaround<'r> + ?Sized> BatchValues for T {} - -pub trait BatchValuesGatWorkaround<'r, ImplicitBounds = &'r Self> { - /// For some unknown reason, this type, when not resolved to a concrete type for a given async function, - /// cannot live across await boundaries while maintaining the corresponding future `Send`, unless `'r: 'static` - /// - /// See for more details - type BatchValuesIter: BatchValuesIterator<'r>; - fn batch_values_iter(&'r self) -> Self::BatchValuesIter; -} - -/// An iterator-like for `ValueList` -/// -/// An instance of this can be easily obtained from `IT: Iterator`: that would be -/// `BatchValuesIteratorFromIterator` -/// -/// It's just essentially making methods from `ValueList` accessible instead of being an actual iterator because of -/// several compiler limitations that would otherwise be very complex to overcome.\ -/// (specifically, types being different would require yielding enums for tuple impls, and the trait -/// bound of `for<'r> ::BatchValuesIter as Iterator>::Item: ValueList` is very -/// hard to express considering several compiler limitations) -pub trait BatchValuesIterator<'a> { - fn next_serialized(&mut self) -> Option>; - fn write_next_to_request( - &mut self, - buf: &mut impl BufMut, - ) -> Option>; - fn skip_next(&mut self) -> Option<()>; -} - -/// Implements `BatchValuesIterator` from an `Iterator` over references to things that implement `ValueList` -/// -/// Essentially used internally by this lib to provide implementors of `BatchValuesIterator` for cases -/// that always serialize the same concrete `ValueList` type -pub struct BatchValuesIteratorFromIterator { - it: IT, -} - -impl<'r, 'a: 'r, IT, VL> BatchValuesIterator<'r> for BatchValuesIteratorFromIterator -where - IT: Iterator, - VL: ValueList + 'a, -{ - fn next_serialized(&mut self) -> Option> { - self.it.next().map(|vl| vl.serialized()) - } - fn write_next_to_request( - &mut self, - buf: &mut impl BufMut, - ) -> Option> { - self.it.next().map(|vl| vl.write_to_request(buf)) - } - fn skip_next(&mut self) -> Option<()> { - self.it.next().map(|_| ()) - } -} - -impl From for BatchValuesIteratorFromIterator -where - IT: Iterator, - IT::Item: ValueList, -{ - fn from(it: IT) -> Self { - BatchValuesIteratorFromIterator { it } - } -} - -// -// Value impls -// - -// Implement Value for primitive types -impl Value for i8 { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - buf.put_i32(1); - buf.put_i8(*self); - Ok(()) - } -} - -impl Value for i16 { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - buf.put_i32(2); - buf.put_i16(*self); - Ok(()) - } -} - -impl Value for i32 { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - buf.put_i32(4); - buf.put_i32(*self); - Ok(()) - } -} - -impl Value for i64 { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - buf.put_i32(8); - buf.put_i64(*self); - Ok(()) - } -} - -impl Value for BigDecimal { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - let (value, scale) = self.as_bigint_and_exponent(); - - let serialized = value.to_signed_bytes_be(); - let serialized_len: i32 = serialized.len().try_into().map_err(|_| ValueTooBig)?; - - buf.put_i32(serialized_len + 4); - buf.put_i32(scale.try_into().map_err(|_| ValueTooBig)?); - buf.extend_from_slice(&serialized); - - Ok(()) - } -} - -impl Value for NaiveDate { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - buf.put_i32(4); - let unix_epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); - - let days: u32 = self - .signed_duration_since(unix_epoch) - .num_days() - .checked_add(1 << 31) - .and_then(|days| days.try_into().ok()) // convert to u32 - .ok_or(ValueTooBig)?; - - buf.put_u32(days); - Ok(()) - } -} - -impl Value for Date { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - buf.put_i32(4); - buf.put_u32(self.0); - Ok(()) - } -} - -impl Value for Timestamp { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - buf.put_i32(8); - buf.put_i64(self.0.num_milliseconds()); - Ok(()) - } -} - -impl Value for Time { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - buf.put_i32(8); - buf.put_i64(self.0.num_nanoseconds().ok_or(ValueTooBig)?); - Ok(()) - } -} - -impl Value for DateTime { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - buf.put_i32(8); - buf.put_i64(self.timestamp_millis()); - Ok(()) - } -} - -#[cfg(feature = "secret")] -impl Value for Secret { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - self.expose_secret().serialize(buf) - } -} - -impl Value for bool { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - buf.put_i32(1); - let false_bytes: &[u8] = &[0x00]; - let true_bytes: &[u8] = &[0x01]; - if *self { - buf.put(true_bytes); - } else { - buf.put(false_bytes); - } - - Ok(()) - } -} - -impl Value for f32 { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - buf.put_i32(4); - buf.put_f32(*self); - Ok(()) - } -} - -impl Value for f64 { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - buf.put_i32(8); - buf.put_f64(*self); - Ok(()) - } -} - -impl Value for Uuid { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - buf.put_i32(16); - buf.extend_from_slice(self.as_bytes()); - Ok(()) - } -} - -impl Value for BigInt { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - let serialized = self.to_signed_bytes_be(); - let serialized_len: i32 = serialized.len().try_into().map_err(|_| ValueTooBig)?; - - buf.put_i32(serialized_len); - buf.extend_from_slice(&serialized); - - Ok(()) - } -} - -impl Value for &str { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - let str_bytes: &[u8] = self.as_bytes(); - let val_len: i32 = str_bytes.len().try_into().map_err(|_| ValueTooBig)?; - - buf.put_i32(val_len); - buf.put(str_bytes); - - Ok(()) - } -} - -impl Value for Vec { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - let val_len: i32 = self.len().try_into().map_err(|_| ValueTooBig)?; - buf.put_i32(val_len); - - buf.extend_from_slice(self); - - Ok(()) - } -} - -impl Value for IpAddr { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - match self { - IpAddr::V4(addr) => { - buf.put_i32(4); - buf.extend_from_slice(&addr.octets()); - } - IpAddr::V6(addr) => { - buf.put_i32(16); - buf.extend_from_slice(&addr.octets()); - } - } - - Ok(()) - } -} - -impl Value for String { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - <&str as Value>::serialize(&self.as_str(), buf) - } -} - -/// Every `Option` can be serialized as None -> NULL, Some(val) -> val.serialize() -impl Value for Option { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - match self { - Some(val) => ::serialize(val, buf), - None => { - buf.put_i32(-1); - Ok(()) - } - } - } -} - -impl Value for Unset { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - // Unset serializes itself to empty value with length = -2 - buf.put_i32(-2); - Ok(()) - } -} - -impl Value for Counter { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - self.0.serialize(buf) - } -} - -impl Value for CqlDuration { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - let bytes_num_pos: usize = buf.len(); - buf.put_i32(0); - - vint_encode(self.months as i64, buf); - vint_encode(self.days as i64, buf); - vint_encode(self.nanoseconds, buf); - - let written_bytes: usize = buf.len() - bytes_num_pos - 4; - let written_bytes_i32: i32 = written_bytes.try_into().map_err(|_| ValueTooBig)?; - buf[bytes_num_pos..(bytes_num_pos + 4)].copy_from_slice(&written_bytes_i32.to_be_bytes()); - - Ok(()) - } -} - -impl Value for MaybeUnset { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - match self { - MaybeUnset::Set(v) => v.serialize(buf), - MaybeUnset::Unset => Unset.serialize(buf), - } - } -} - -// Every &impl Value and &dyn Value should also implement Value -impl Value for &T { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - ::serialize(*self, buf) - } -} - -// Every Boxed Value should also implement Value -impl Value for Box { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - ::serialize(self.as_ref(), buf) - } -} - -fn serialize_map( - kv_iter: impl Iterator, - kv_count: usize, - buf: &mut Vec, -) -> Result<(), ValueTooBig> { - let bytes_num_pos: usize = buf.len(); - buf.put_i32(0); - - buf.put_i32(kv_count.try_into().map_err(|_| ValueTooBig)?); - for (key, value) in kv_iter { - ::serialize(&key, buf)?; - ::serialize(&value, buf)?; - } - - let written_bytes: usize = buf.len() - bytes_num_pos - 4; - let written_bytes_i32: i32 = written_bytes.try_into().map_err(|_| ValueTooBig)?; - buf[bytes_num_pos..(bytes_num_pos + 4)].copy_from_slice(&written_bytes_i32.to_be_bytes()); - - Ok(()) -} - -fn serialize_list_or_set<'a, V: 'a + Value>( - elements_iter: impl Iterator, - element_count: usize, - buf: &mut Vec, -) -> Result<(), ValueTooBig> { - let bytes_num_pos: usize = buf.len(); - buf.put_i32(0); - - buf.put_i32(element_count.try_into().map_err(|_| ValueTooBig)?); - for value in elements_iter { - ::serialize(value, buf)?; - } - - let written_bytes: usize = buf.len() - bytes_num_pos - 4; - let written_bytes_i32: i32 = written_bytes.try_into().map_err(|_| ValueTooBig)?; - buf[bytes_num_pos..(bytes_num_pos + 4)].copy_from_slice(&written_bytes_i32.to_be_bytes()); - - Ok(()) -} - -impl Value for HashSet { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - serialize_list_or_set(self.iter(), self.len(), buf) - } -} - -impl Value for HashMap { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - serialize_map(self.iter(), self.len(), buf) - } -} - -impl Value for BTreeSet { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - serialize_list_or_set(self.iter(), self.len(), buf) - } -} - -impl Value for BTreeMap { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - serialize_map(self.iter(), self.len(), buf) - } -} - -impl Value for Vec { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - serialize_list_or_set(self.iter(), self.len(), buf) - } -} - -impl Value for &[T] { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - serialize_list_or_set(self.iter(), self.len(), buf) - } -} - -fn serialize_tuple( - elem_iter: impl Iterator, - buf: &mut Vec, -) -> Result<(), ValueTooBig> { - let bytes_num_pos: usize = buf.len(); - buf.put_i32(0); - - for elem in elem_iter { - elem.serialize(buf)?; - } - - let written_bytes: usize = buf.len() - bytes_num_pos - 4; - let written_bytes_i32: i32 = written_bytes.try_into().map_err(|_| ValueTooBig)?; - buf[bytes_num_pos..(bytes_num_pos + 4)].copy_from_slice(&written_bytes_i32.to_be_bytes()); - - Ok(()) -} - -fn serialize_empty(buf: &mut Vec) -> Result<(), ValueTooBig> { - buf.put_i32(0); - Ok(()) -} - -impl Value for CqlValue { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - match self { - CqlValue::Map(m) => serialize_map(m.iter().map(|(k, v)| (k, v)), m.len(), buf), - CqlValue::Tuple(t) => serialize_tuple(t.iter(), buf), - - // A UDT value is composed of successive [bytes] values, one for each field of the UDT - // value (in the order defined by the type), so they serialize in a same way tuples do. - CqlValue::UserDefinedType { fields, .. } => { - serialize_tuple(fields.iter().map(|(_, value)| value), buf) - } - - CqlValue::Date(d) => Date(*d).serialize(buf), - CqlValue::Duration(d) => d.serialize(buf), - CqlValue::Timestamp(t) => Timestamp(*t).serialize(buf), - CqlValue::Time(t) => Time(*t).serialize(buf), - - CqlValue::Ascii(s) | CqlValue::Text(s) => s.serialize(buf), - CqlValue::List(v) | CqlValue::Set(v) => v.serialize(buf), - - CqlValue::Blob(b) => b.serialize(buf), - CqlValue::Boolean(b) => b.serialize(buf), - CqlValue::Counter(c) => c.serialize(buf), - CqlValue::Decimal(d) => d.serialize(buf), - CqlValue::Double(d) => d.serialize(buf), - CqlValue::Float(f) => f.serialize(buf), - CqlValue::Int(i) => i.serialize(buf), - CqlValue::BigInt(i) => i.serialize(buf), - CqlValue::Inet(i) => i.serialize(buf), - CqlValue::SmallInt(s) => s.serialize(buf), - CqlValue::TinyInt(t) => t.serialize(buf), - CqlValue::Timeuuid(t) => t.serialize(buf), - CqlValue::Uuid(u) => u.serialize(buf), - CqlValue::Varint(v) => v.serialize(buf), - - CqlValue::Empty => serialize_empty(buf), - } - } -} - -macro_rules! impl_value_for_tuple { - ( $($Ti:ident),* ; $($FieldI:tt),* ) => { - impl<$($Ti),+> Value for ($($Ti,)+) - where - $($Ti: Value),+ - { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - let bytes_num_pos: usize = buf.len(); - buf.put_i32(0); - $( - <$Ti as Value>::serialize(&self.$FieldI, buf)?; - )* - - let written_bytes: usize = buf.len() - bytes_num_pos - 4; - let written_bytes_i32: i32 = written_bytes.try_into().map_err(|_| ValueTooBig) ?; - buf[bytes_num_pos..(bytes_num_pos+4)].copy_from_slice(&written_bytes_i32.to_be_bytes()); - - Ok(()) - } - } - } -} - -impl_value_for_tuple!(T0; 0); -impl_value_for_tuple!(T0, T1; 0, 1); -impl_value_for_tuple!(T0, T1, T2; 0, 1, 2); -impl_value_for_tuple!(T0, T1, T2, T3; 0, 1, 2, 3); -impl_value_for_tuple!(T0, T1, T2, T3, T4; 0, 1, 2, 3, 4); -impl_value_for_tuple!(T0, T1, T2, T3, T4, T5; 0, 1, 2, 3, 4, 5); -impl_value_for_tuple!(T0, T1, T2, T3, T4, T5, T6; 0, 1, 2, 3, 4, 5, 6); -impl_value_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7; 0, 1, 2, 3, 4, 5, 6, 7); -impl_value_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8; 0, 1, 2, 3, 4, 5, 6, 7, 8); -impl_value_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); -impl_value_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10); -impl_value_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11); -impl_value_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12); -impl_value_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13); -impl_value_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14); -impl_value_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - -// -// ValueList impls -// - -// Implement ValueList for the unit type -impl ValueList for () { - fn serialized(&self) -> SerializedResult<'_> { - Ok(Cow::Owned(SerializedValues::new())) - } -} - -// Implement ValueList for &[] - u8 because otherwise rust can't infer type -impl ValueList for [u8; 0] { - fn serialized(&self) -> SerializedResult<'_> { - Ok(Cow::Owned(SerializedValues::new())) - } -} - -// Implement ValueList for slices of Value types -impl ValueList for &[T] { - fn serialized(&self) -> SerializedResult<'_> { - let mut result = SerializedValues::with_capacity(self.len()); - for val in *self { - result.add_value(val)?; - } - - Ok(Cow::Owned(result)) - } -} - -// Implement ValueList for Vec -impl ValueList for Vec { - fn serialized(&self) -> SerializedResult<'_> { - let mut result = SerializedValues::with_capacity(self.len()); - for val in self { - result.add_value(val)?; - } - - Ok(Cow::Owned(result)) - } -} - -// Implement ValueList for maps, which serializes named values -macro_rules! impl_value_list_for_map { - ($map_type:ident, $key_type:ty) => { - impl ValueList for $map_type<$key_type, T> { - fn serialized(&self) -> SerializedResult<'_> { - let mut result = SerializedValues::with_capacity(self.len()); - for (key, val) in self { - result.add_named_value(key, val)?; - } - - Ok(Cow::Owned(result)) - } - } - }; -} - -impl_value_list_for_map!(HashMap, String); -impl_value_list_for_map!(HashMap, &str); -impl_value_list_for_map!(BTreeMap, String); -impl_value_list_for_map!(BTreeMap, &str); - -// Implement ValueList for tuples of Values of size up to 16 - -// Here is an example implementation for (T0, ) -// Further variants are done using a macro -impl ValueList for (T0,) { - fn serialized(&self) -> SerializedResult<'_> { - let mut result = SerializedValues::with_capacity(1); - result.add_value(&self.0)?; - Ok(Cow::Owned(result)) - } -} - -macro_rules! impl_value_list_for_tuple { - ( $($Ti:ident),* ; $($FieldI:tt),* ; $size: expr) => { - impl<$($Ti),+> ValueList for ($($Ti,)+) - where - $($Ti: Value),+ - { - fn serialized(&self) -> SerializedResult<'_> { - let mut result = SerializedValues::with_capacity($size); - $( - result.add_value(&self.$FieldI) ?; - )* - Ok(Cow::Owned(result)) - } - } - } -} - -impl_value_list_for_tuple!(T0, T1; 0, 1; 2); -impl_value_list_for_tuple!(T0, T1, T2; 0, 1, 2; 3); -impl_value_list_for_tuple!(T0, T1, T2, T3; 0, 1, 2, 3; 4); -impl_value_list_for_tuple!(T0, T1, T2, T3, T4; 0, 1, 2, 3, 4; 5); -impl_value_list_for_tuple!(T0, T1, T2, T3, T4, T5; 0, 1, 2, 3, 4, 5; 6); -impl_value_list_for_tuple!(T0, T1, T2, T3, T4, T5, T6; 0, 1, 2, 3, 4, 5, 6; 7); -impl_value_list_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7; 0, 1, 2, 3, 4, 5, 6, 7; 8); -impl_value_list_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8; 0, 1, 2, 3, 4, 5, 6, 7, 8; 9); -impl_value_list_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9; 10); -impl_value_list_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10; 11); -impl_value_list_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11; 12); -impl_value_list_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12; 13); -impl_value_list_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13; 14); -impl_value_list_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14; 15); -impl_value_list_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15; 16); - -// Every &impl ValueList should also implement ValueList -impl ValueList for &T { - fn serialized(&self) -> SerializedResult<'_> { - ::serialized(*self) - } -} - -impl ValueList for SerializedValues { - fn serialized(&self) -> SerializedResult<'_> { - Ok(Cow::Borrowed(self)) - } -} - -impl<'b> ValueList for Cow<'b, SerializedValues> { - fn serialized(&self) -> SerializedResult<'_> { - Ok(Cow::Borrowed(self.as_ref())) - } -} - -// -// BatchValues impls -// - -/// Implements `BatchValues` from an `Iterator` over references to things that implement `ValueList` -/// -/// This is to avoid requiring allocating a new `Vec` containing all the `ValueList`s directly: -/// with this, one can write: -/// `session.batch(&batch, BatchValuesFromIterator::from(lines_to_insert.iter().map(|l| &l.value_list)))` -/// where `lines_to_insert` may also contain e.g. data to pick the statement... -/// -/// The underlying iterator will always be cloned at least once, once to compute the length if it can't be known -/// in advance, and be re-cloned at every retry. -/// It is consequently expected that the provided iterator is cheap to clone (e.g. `slice.iter().map(...)`). -pub struct BatchValuesFromIter { - it: IT, -} - -impl<'a, IT, VL> BatchValuesFromIter -where - IT: Iterator + Clone, - VL: ValueList + 'a, -{ - pub fn new(into_iter: impl IntoIterator) -> Self { - Self { - it: into_iter.into_iter(), - } - } -} - -impl From for BatchValuesFromIter -where - IT: Iterator + Clone, - IT::Item: ValueList, -{ - fn from(it: IT) -> Self { - Self { it } - } -} - -impl<'r, 'a: 'r, IT, VL> BatchValuesGatWorkaround<'r> for BatchValuesFromIter -where - IT: Iterator + Clone, - VL: ValueList + 'a, -{ - type BatchValuesIter = BatchValuesIteratorFromIterator; - fn batch_values_iter(&'r self) -> Self::BatchValuesIter { - self.it.clone().into() - } -} - -// Implement BatchValues for slices of ValueList types -impl<'r, T: ValueList> BatchValuesGatWorkaround<'r> for [T] { - type BatchValuesIter = BatchValuesIteratorFromIterator>; - fn batch_values_iter(&'r self) -> Self::BatchValuesIter { - self.iter().into() - } -} - -// Implement BatchValues for Vec -impl<'r, T: ValueList> BatchValuesGatWorkaround<'r> for Vec { - type BatchValuesIter = BatchValuesIteratorFromIterator>; - fn batch_values_iter(&'r self) -> Self::BatchValuesIter { - BatchValuesGatWorkaround::batch_values_iter(self.as_slice()) - } -} - -// Here is an example implementation for (T0, ) -// Further variants are done using a macro -impl<'r, T0: ValueList> BatchValuesGatWorkaround<'r> for (T0,) { - type BatchValuesIter = BatchValuesIteratorFromIterator>; - fn batch_values_iter(&'r self) -> Self::BatchValuesIter { - std::iter::once(&self.0).into() - } -} - -pub struct TupleValuesIter<'a, T> { - tuple: &'a T, - idx: usize, -} - -macro_rules! impl_batch_values_for_tuple { - ( $($Ti:ident),* ; $($FieldI:tt),* ; $TupleSize:tt) => { - impl<'r, $($Ti),+> BatchValuesGatWorkaround<'r> for ($($Ti,)+) - where - $($Ti: ValueList),+ - { - type BatchValuesIter = TupleValuesIter<'r, ($($Ti,)+)>; - fn batch_values_iter(&'r self) -> Self::BatchValuesIter { - TupleValuesIter { - tuple: self, - idx: 0, - } - } - } - impl<'r, $($Ti),+> BatchValuesIterator<'r> for TupleValuesIter<'r, ($($Ti,)+)> - where - $($Ti: ValueList),+ - { - fn next_serialized(&mut self) -> Option> { - let serialized_value_res = match self.idx { - $( - $FieldI => self.tuple.$FieldI.serialized(), - )* - _ => return None, - }; - self.idx += 1; - Some(serialized_value_res) - } - fn write_next_to_request( - &mut self, - buf: &mut impl BufMut, - ) -> Option> { - let ret = match self.idx { - $( - $FieldI => self.tuple.$FieldI.write_to_request(buf), - )* - _ => return None, - }; - self.idx += 1; - Some(ret) - } - fn skip_next(&mut self) -> Option<()> { - if self.idx < $TupleSize { - self.idx += 1; - Some(()) - } else { - None - } - } - } - } -} - -impl_batch_values_for_tuple!(T0, T1; 0, 1; 2); -impl_batch_values_for_tuple!(T0, T1, T2; 0, 1, 2; 3); -impl_batch_values_for_tuple!(T0, T1, T2, T3; 0, 1, 2, 3; 4); -impl_batch_values_for_tuple!(T0, T1, T2, T3, T4; 0, 1, 2, 3, 4; 5); -impl_batch_values_for_tuple!(T0, T1, T2, T3, T4, T5; 0, 1, 2, 3, 4, 5; 6); -impl_batch_values_for_tuple!(T0, T1, T2, T3, T4, T5, T6; 0, 1, 2, 3, 4, 5, 6; 7); -impl_batch_values_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7; 0, 1, 2, 3, 4, 5, 6, 7; 8); -impl_batch_values_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8; 0, 1, 2, 3, 4, 5, 6, 7, 8; 9); -impl_batch_values_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9; 10); -impl_batch_values_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10; 11); -impl_batch_values_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11; 12); -impl_batch_values_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12; 13); -impl_batch_values_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13; 14); -impl_batch_values_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14; 15); -impl_batch_values_for_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15; - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15; 16); - -// Every &impl BatchValues should also implement BatchValues -impl<'a, 'r, T: BatchValues + ?Sized> BatchValuesGatWorkaround<'r> for &'a T { - type BatchValuesIter = >::BatchValuesIter; - fn batch_values_iter(&'r self) -> Self::BatchValuesIter { - >::batch_values_iter(*self) - } -} - -/// Allows reusing already-serialized first value -/// -/// We'll need to build a `SerializedValues` for the first ~`ValueList` of a batch to figure out the shard (#448). -/// Once that is done, we can use that instead of re-serializing. -/// -/// This struct implements both `BatchValues` and `BatchValuesIterator` for that purpose -pub struct BatchValuesFirstSerialized<'f, T> { - first: Option<&'f SerializedValues>, - rest: T, -} - -impl<'f, T: BatchValues> BatchValuesFirstSerialized<'f, T> { - pub fn new(batch_values: T, already_serialized_first: Option<&'f SerializedValues>) -> Self { - Self { - first: already_serialized_first, - rest: batch_values, - } - } -} - -impl<'r, 'f, BV: BatchValues> BatchValuesGatWorkaround<'r> for BatchValuesFirstSerialized<'f, BV> { - type BatchValuesIter = - BatchValuesFirstSerialized<'f, >::BatchValuesIter>; - fn batch_values_iter(&'r self) -> Self::BatchValuesIter { - BatchValuesFirstSerialized { - first: self.first, - rest: self.rest.batch_values_iter(), - } - } -} - -impl<'a, 'f: 'a, IT: BatchValuesIterator<'a>> BatchValuesIterator<'a> - for BatchValuesFirstSerialized<'f, IT> -{ - fn next_serialized(&mut self) -> Option> { - match self.first.take() { - Some(first) => { - self.rest.skip_next(); - Some(Ok(Cow::Borrowed(first))) - } - None => self.rest.next_serialized(), - } - } - fn write_next_to_request( - &mut self, - buf: &mut impl BufMut, - ) -> Option> { - match self.first.take() { - Some(first) => { - self.rest.skip_next(); - first.write_to_request(buf); - Some(Ok(())) - } - None => self.rest.write_next_to_request(buf), - } - } - fn skip_next(&mut self) -> Option<()> { - self.rest.skip_next(); - self.first.take().map(|_| ()) - } -} diff --git a/scylla-cql/src/frame/value_tests.rs b/scylla-cql/src/frame/value_tests.rs deleted file mode 100644 index 7f7135e..0000000 --- a/scylla-cql/src/frame/value_tests.rs +++ /dev/null @@ -1,720 +0,0 @@ -use crate::frame::value::{BatchValuesGatWorkaround, BatchValuesIterator}; - -use super::value::{ - BatchValues, Date, MaybeUnset, SerializeValuesError, SerializedValues, Time, Timestamp, Unset, - Value, ValueList, ValueTooBig, -}; -use bytes::BufMut; -use chrono::{Duration, NaiveDate}; -use std::{borrow::Cow, convert::TryInto}; -use uuid::Uuid; - -fn serialized(val: impl Value) -> Vec { - let mut result: Vec = Vec::new(); - val.serialize(&mut result).unwrap(); - result -} - -#[test] -fn basic_serialization() { - assert_eq!(serialized(8_i8), vec![0, 0, 0, 1, 8]); - assert_eq!(serialized(16_i16), vec![0, 0, 0, 2, 0, 16]); - assert_eq!(serialized(32_i32), vec![0, 0, 0, 4, 0, 0, 0, 32]); - assert_eq!( - serialized(64_i64), - vec![0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 64] - ); - - assert_eq!(serialized("abc"), vec![0, 0, 0, 3, 97, 98, 99]); - assert_eq!(serialized("abc".to_string()), vec![0, 0, 0, 3, 97, 98, 99]); -} - -#[test] -fn naive_date_serialization() { - // 1970-01-31 is 2^31 - let unix_epoch: NaiveDate = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); - assert_eq!(serialized(unix_epoch), vec![0, 0, 0, 4, 128, 0, 0, 0]); - assert_eq!(2_u32.pow(31).to_be_bytes(), [128, 0, 0, 0]); - - // 1969-12-02 is 2^31 - 30 - let before_epoch: NaiveDate = NaiveDate::from_ymd_opt(1969, 12, 2).unwrap(); - assert_eq!( - serialized(before_epoch), - vec![0, 0, 0, 4, 127, 255, 255, 226] - ); - assert_eq!((2_u32.pow(31) - 30).to_be_bytes(), [127, 255, 255, 226]); - - // 1970-01-31 is 2^31 + 30 - let after_epoch: NaiveDate = NaiveDate::from_ymd_opt(1970, 1, 31).unwrap(); - assert_eq!(serialized(after_epoch), vec![0, 0, 0, 4, 128, 0, 0, 30]); - assert_eq!((2_u32.pow(31) + 30).to_be_bytes(), [128, 0, 0, 30]); -} - -#[test] -fn date_serialization() { - assert_eq!(serialized(Date(0)), vec![0, 0, 0, 4, 0, 0, 0, 0]); - assert_eq!( - serialized(Date(u32::MAX)), - vec![0, 0, 0, 4, 255, 255, 255, 255] - ); -} - -#[test] -fn time_serialization() { - // Time is an i64 - nanoseconds since midnight - // in range 0..=86399999999999 - - let max_time: i64 = 24 * 60 * 60 * 1_000_000_000 - 1; - assert_eq!(max_time, 86399999999999); - - // Check that basic values are serialized correctly - // Invalid values are also serialized correctly - database will respond with an error - for test_val in [0, 1, 15, 18463, max_time, -1, -324234, max_time + 16].iter() { - let test_time: Time = Time(Duration::nanoseconds(*test_val)); - let bytes: Vec = serialized(test_time); - - let mut expected_bytes: Vec = vec![0, 0, 0, 8]; - expected_bytes.extend_from_slice(&test_val.to_be_bytes()); - - assert_eq!(bytes, expected_bytes); - assert_eq!(expected_bytes.len(), 12); - } - - // Durations so long that nanoseconds don't fit in i64 cause an error - let long_time = Time(Duration::milliseconds(i64::MAX)); - assert_eq!(long_time.serialize(&mut Vec::new()), Err(ValueTooBig)); -} - -#[test] -fn timestamp_serialization() { - // Timestamp is milliseconds since unix epoch represented as i64 - - for test_val in &[0, -1, 1, -45345346, 453451, i64::MIN, i64::MAX] { - let test_timestamp: Timestamp = Timestamp(Duration::milliseconds(*test_val)); - let bytes: Vec = serialized(test_timestamp); - - let mut expected_bytes: Vec = vec![0, 0, 0, 8]; - expected_bytes.extend_from_slice(&test_val.to_be_bytes()); - - assert_eq!(bytes, expected_bytes); - assert_eq!(expected_bytes.len(), 12); - } -} - -#[test] -fn datetime_serialization() { - use chrono::{DateTime, NaiveDateTime, Utc}; - // Datetime is milliseconds since unix epoch represented as i64 - let max_time: i64 = 24 * 60 * 60 * 1_000_000_000 - 1; - - for test_val in &[0, 1, 15, 18463, max_time, max_time + 16] { - let native_datetime = NaiveDateTime::from_timestamp_opt( - *test_val / 1000, - ((*test_val % 1000) as i32 * 1_000_000) as u32, - ) - .expect("invalid or out-of-range datetime"); - let test_datetime = DateTime::::from_utc(native_datetime, Utc); - let bytes: Vec = serialized(test_datetime); - - let mut expected_bytes: Vec = vec![0, 0, 0, 8]; - expected_bytes.extend_from_slice(&test_val.to_be_bytes()); - - assert_eq!(bytes, expected_bytes); - assert_eq!(expected_bytes.len(), 12); - } -} - -#[test] -fn timeuuid_serialization() { - // A few random timeuuids generated manually - let tests = [ - [ - 0x8e, 0x14, 0xe7, 0x60, 0x7f, 0xa8, 0x11, 0xeb, 0xbc, 0x66, 0, 0, 0, 0, 0, 0x01, - ], - [ - 0x9b, 0x34, 0x95, 0x80, 0x7f, 0xa8, 0x11, 0xeb, 0xbc, 0x66, 0, 0, 0, 0, 0, 0x01, - ], - [ - 0x5d, 0x74, 0xba, 0xe0, 0x7f, 0xa3, 0x11, 0xeb, 0xbc, 0x66, 0, 0, 0, 0, 0, 0x01, - ], - ]; - - for uuid_bytes in &tests { - let uuid = Uuid::from_slice(uuid_bytes.as_ref()).unwrap(); - let uuid_serialized: Vec = serialized(uuid); - - let mut expected_serialized: Vec = vec![0, 0, 0, 16]; - expected_serialized.extend_from_slice(uuid_bytes.as_ref()); - - assert_eq!(uuid_serialized, expected_serialized); - } -} - -#[test] -fn option_value() { - assert_eq!(serialized(Some(32_i32)), vec![0, 0, 0, 4, 0, 0, 0, 32]); - let null_i32: Option = None; - assert_eq!(serialized(null_i32), &(-1_i32).to_be_bytes()[..]); -} - -#[test] -fn unset_value() { - assert_eq!(serialized(Unset), &(-2_i32).to_be_bytes()[..]); - - let unset_i32: MaybeUnset = MaybeUnset::Unset; - assert_eq!(serialized(unset_i32), &(-2_i32).to_be_bytes()[..]); - - let set_i32: MaybeUnset = MaybeUnset::Set(32); - assert_eq!(serialized(set_i32), vec![0, 0, 0, 4, 0, 0, 0, 32]); -} - -#[test] -fn ref_value() { - fn serialized_generic(val: T) -> Vec { - let mut result: Vec = Vec::new(); - val.serialize(&mut result).unwrap(); - result - } - - // This trickery is needed to prevent the compiler from performing deref coercions on refs - // and effectively defeating the purpose of this test. With specialisations provided - // in such an explicit way, the compiler is not allowed to coerce. - fn check(x: &T, y: T) { - assert_eq!(serialized_generic::<&T>(x), serialized_generic::(y)); - } - - check(&1_i32, 1_i32); -} - -#[test] -fn empty_serialized_values() { - const EMPTY: SerializedValues = SerializedValues::new(); - assert_eq!(EMPTY.len(), 0); - assert!(EMPTY.is_empty()); - assert_eq!(EMPTY.iter().next(), None); - - let mut empty_request = Vec::::new(); - EMPTY.write_to_request(&mut empty_request); - assert_eq!(empty_request, vec![0, 0]); -} - -#[test] -fn serialized_values() { - let mut values = SerializedValues::new(); - assert!(values.is_empty()); - - // Add first value - values.add_value(&8_i8).unwrap(); - { - assert_eq!(values.len(), 1); - assert!(!values.is_empty()); - - let mut request = Vec::::new(); - values.write_to_request(&mut request); - assert_eq!(request, vec![0, 1, 0, 0, 0, 1, 8]); - - assert_eq!(values.iter().collect::>(), vec![Some([8].as_ref())]); - } - - // Add second value - values.add_value(&16_i16).unwrap(); - { - assert_eq!(values.len(), 2); - assert!(!values.is_empty()); - - let mut request = Vec::::new(); - values.write_to_request(&mut request); - assert_eq!(request, vec![0, 2, 0, 0, 0, 1, 8, 0, 0, 0, 2, 0, 16]); - - assert_eq!( - values.iter().collect::>(), - vec![Some([8].as_ref()), Some([0, 16].as_ref())] - ); - } - - // Add a value that's too big, recover gracefully - struct TooBigValue; - impl Value for TooBigValue { - fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - // serialize some - buf.put_i32(1); - - // then throw an error - Err(ValueTooBig) - } - } - - assert_eq!( - values.add_value(&TooBigValue), - Err(SerializeValuesError::ValueTooBig(ValueTooBig)) - ); - - // All checks for two values should still pass - { - assert_eq!(values.len(), 2); - assert!(!values.is_empty()); - - let mut request = Vec::::new(); - values.write_to_request(&mut request); - assert_eq!(request, vec![0, 2, 0, 0, 0, 1, 8, 0, 0, 0, 2, 0, 16]); - - assert_eq!( - values.iter().collect::>(), - vec![Some([8].as_ref()), Some([0, 16].as_ref())] - ); - } -} - -#[test] -fn unit_value_list() { - let serialized_unit: SerializedValues = - <() as ValueList>::serialized(&()).unwrap().into_owned(); - assert!(serialized_unit.is_empty()); -} - -#[test] -fn empty_array_value_list() { - let serialized_arr: SerializedValues = <[u8; 0] as ValueList>::serialized(&[]) - .unwrap() - .into_owned(); - assert!(serialized_arr.is_empty()); -} - -#[test] -fn slice_value_list() { - let values: &[i32] = &[1, 2, 3]; - let serialized: SerializedValues = <&[i32] as ValueList>::serialized(&values) - .unwrap() - .into_owned(); - - assert_eq!( - serialized.iter().collect::>(), - vec![ - Some([0, 0, 0, 1].as_ref()), - Some([0, 0, 0, 2].as_ref()), - Some([0, 0, 0, 3].as_ref()) - ] - ); -} - -#[test] -fn vec_value_list() { - let values: Vec = vec![1, 2, 3]; - let serialized: SerializedValues = as ValueList>::serialized(&values) - .unwrap() - .into_owned(); - - assert_eq!( - serialized.iter().collect::>(), - vec![ - Some([0, 0, 0, 1].as_ref()), - Some([0, 0, 0, 2].as_ref()), - Some([0, 0, 0, 3].as_ref()) - ] - ); -} - -#[test] -fn tuple_value_list() { - fn check_i8_tuple(tuple: impl ValueList, expected: core::ops::Range) { - let serialized: SerializedValues = tuple.serialized().unwrap().into_owned(); - assert_eq!(serialized.len() as usize, expected.len()); - - let serialized_vals: Vec = serialized - .iter() - .map(|o: Option<&[u8]>| o.unwrap()[0]) - .collect(); - - let expected: Vec = expected.collect(); - - assert_eq!(serialized_vals, expected); - } - - check_i8_tuple((1_i8,), 1..2); - check_i8_tuple((1_i8, 2_i8), 1..3); - check_i8_tuple((1_i8, 2_i8, 3_i8), 1..4); - check_i8_tuple((1_i8, 2_i8, 3_i8, 4_i8), 1..5); - check_i8_tuple((1_i8, 2_i8, 3_i8, 4_i8, 5_i8), 1..6); - check_i8_tuple((1_i8, 2_i8, 3_i8, 4_i8, 5_i8, 6_i8), 1..7); - check_i8_tuple((1_i8, 2_i8, 3_i8, 4_i8, 5_i8, 6_i8, 7_i8), 1..8); - check_i8_tuple((1_i8, 2_i8, 3_i8, 4_i8, 5_i8, 6_i8, 7_i8, 8_i8), 1..9); - check_i8_tuple( - (1_i8, 2_i8, 3_i8, 4_i8, 5_i8, 6_i8, 7_i8, 8_i8, 9_i8), - 1..10, - ); - check_i8_tuple( - (1_i8, 2_i8, 3_i8, 4_i8, 5_i8, 6_i8, 7_i8, 8_i8, 9_i8, 10_i8), - 1..11, - ); - check_i8_tuple( - ( - 1_i8, 2_i8, 3_i8, 4_i8, 5_i8, 6_i8, 7_i8, 8_i8, 9_i8, 10_i8, 11_i8, - ), - 1..12, - ); - check_i8_tuple( - ( - 1_i8, 2_i8, 3_i8, 4_i8, 5_i8, 6_i8, 7_i8, 8_i8, 9_i8, 10_i8, 11_i8, 12_i8, - ), - 1..13, - ); - check_i8_tuple( - ( - 1_i8, 2_i8, 3_i8, 4_i8, 5_i8, 6_i8, 7_i8, 8_i8, 9_i8, 10_i8, 11_i8, 12_i8, 13_i8, - ), - 1..14, - ); - check_i8_tuple( - ( - 1_i8, 2_i8, 3_i8, 4_i8, 5_i8, 6_i8, 7_i8, 8_i8, 9_i8, 10_i8, 11_i8, 12_i8, 13_i8, 14_i8, - ), - 1..15, - ); - check_i8_tuple( - ( - 1_i8, 2_i8, 3_i8, 4_i8, 5_i8, 6_i8, 7_i8, 8_i8, 9_i8, 10_i8, 11_i8, 12_i8, 13_i8, - 14_i8, 15_i8, - ), - 1..16, - ); - check_i8_tuple( - ( - 1_i8, 2_i8, 3_i8, 4_i8, 5_i8, 6_i8, 7_i8, 8_i8, 9_i8, 10_i8, 11_i8, 12_i8, 13_i8, - 14_i8, 15_i8, 16_i8, - ), - 1..17, - ); -} - -#[test] -fn ref_value_list() { - let values: &[i32] = &[1, 2, 3]; - let serialized: SerializedValues = <&&[i32] as ValueList>::serialized(&&values) - .unwrap() - .into_owned(); - - assert_eq!( - serialized.iter().collect::>(), - vec![ - Some([0, 0, 0, 1].as_ref()), - Some([0, 0, 0, 2].as_ref()), - Some([0, 0, 0, 3].as_ref()) - ] - ); -} - -#[test] -fn serialized_values_value_list() { - let mut ser_values = SerializedValues::new(); - ser_values.add_value(&1_i32).unwrap(); - ser_values.add_value(&"qwertyuiop").unwrap(); - - let ser_ser_values: Cow = ser_values.serialized().unwrap(); - assert!(matches!(ser_ser_values, Cow::Borrowed(_))); - - assert_eq!(&ser_values, ser_ser_values.as_ref()); -} - -#[test] -fn cow_serialized_values_value_list() { - let cow_ser_values: Cow = Cow::Owned(SerializedValues::new()); - - let serialized: Cow = cow_ser_values.serialized().unwrap(); - assert!(matches!(serialized, Cow::Borrowed(_))); - - assert_eq!(cow_ser_values.as_ref(), serialized.as_ref()); -} - -#[test] -fn slice_batch_values() { - let batch_values: &[&[i8]] = &[&[1, 2], &[2, 3, 4, 5], &[6]]; - let mut it = batch_values.batch_values_iter(); - { - let mut request: Vec = Vec::new(); - it.write_next_to_request(&mut request).unwrap().unwrap(); - assert_eq!(request, vec![0, 2, 0, 0, 0, 1, 1, 0, 0, 0, 1, 2]); - } - - { - let mut request: Vec = Vec::new(); - it.write_next_to_request(&mut request).unwrap().unwrap(); - assert_eq!( - request, - vec![0, 4, 0, 0, 0, 1, 2, 0, 0, 0, 1, 3, 0, 0, 0, 1, 4, 0, 0, 0, 1, 5] - ); - } - - { - let mut request: Vec = Vec::new(); - it.write_next_to_request(&mut request).unwrap().unwrap(); - assert_eq!(request, vec![0, 1, 0, 0, 0, 1, 6]); - } - - assert_eq!(it.write_next_to_request(&mut Vec::new()), None); -} - -#[test] -fn vec_batch_values() { - let batch_values: Vec> = vec![vec![1, 2], vec![2, 3, 4, 5], vec![6]]; - - let mut it = batch_values.batch_values_iter(); - { - let mut request: Vec = Vec::new(); - it.write_next_to_request(&mut request).unwrap().unwrap(); - assert_eq!(request, vec![0, 2, 0, 0, 0, 1, 1, 0, 0, 0, 1, 2]); - } - - { - let mut request: Vec = Vec::new(); - it.write_next_to_request(&mut request).unwrap().unwrap(); - assert_eq!( - request, - vec![0, 4, 0, 0, 0, 1, 2, 0, 0, 0, 1, 3, 0, 0, 0, 1, 4, 0, 0, 0, 1, 5] - ); - } - - { - let mut request: Vec = Vec::new(); - it.write_next_to_request(&mut request).unwrap().unwrap(); - assert_eq!(request, vec![0, 1, 0, 0, 0, 1, 6]); - } -} - -#[test] -fn tuple_batch_values() { - fn check_twoi32_tuple(tuple: impl BatchValues, size: usize) { - let mut it = tuple.batch_values_iter(); - for i in 0..size { - let mut request: Vec = Vec::new(); - it.write_next_to_request(&mut request).unwrap().unwrap(); - - let mut expected: Vec = Vec::new(); - let i: i32 = i.try_into().unwrap(); - expected.put_i16(2); - expected.put_i32(4); - expected.put_i32(i + 1); - expected.put_i32(4); - expected.put_i32(2 * (i + 1)); - - assert_eq!(request, expected); - } - } - - // rustfmt wants to have each tuple inside a tuple in a separate line - // so we end up with 170 lines of tuples - // FIXME: Is there some cargo fmt flag to fix this? - - check_twoi32_tuple(((1, 2),), 1); - check_twoi32_tuple(((1, 2), (2, 4)), 2); - check_twoi32_tuple(((1, 2), (2, 4), (3, 6)), 3); - check_twoi32_tuple(((1, 2), (2, 4), (3, 6), (4, 8)), 4); - check_twoi32_tuple(((1, 2), (2, 4), (3, 6), (4, 8), (5, 10)), 5); - check_twoi32_tuple(((1, 2), (2, 4), (3, 6), (4, 8), (5, 10), (6, 12)), 6); - check_twoi32_tuple( - ((1, 2), (2, 4), (3, 6), (4, 8), (5, 10), (6, 12), (7, 14)), - 7, - ); - check_twoi32_tuple( - ( - (1, 2), - (2, 4), - (3, 6), - (4, 8), - (5, 10), - (6, 12), - (7, 14), - (8, 16), - ), - 8, - ); - check_twoi32_tuple( - ( - (1, 2), - (2, 4), - (3, 6), - (4, 8), - (5, 10), - (6, 12), - (7, 14), - (8, 16), - (9, 18), - ), - 9, - ); - check_twoi32_tuple( - ( - (1, 2), - (2, 4), - (3, 6), - (4, 8), - (5, 10), - (6, 12), - (7, 14), - (8, 16), - (9, 18), - (10, 20), - ), - 10, - ); - check_twoi32_tuple( - ( - (1, 2), - (2, 4), - (3, 6), - (4, 8), - (5, 10), - (6, 12), - (7, 14), - (8, 16), - (9, 18), - (10, 20), - (11, 22), - ), - 11, - ); - check_twoi32_tuple( - ( - (1, 2), - (2, 4), - (3, 6), - (4, 8), - (5, 10), - (6, 12), - (7, 14), - (8, 16), - (9, 18), - (10, 20), - (11, 22), - (12, 24), - ), - 12, - ); - check_twoi32_tuple( - ( - (1, 2), - (2, 4), - (3, 6), - (4, 8), - (5, 10), - (6, 12), - (7, 14), - (8, 16), - (9, 18), - (10, 20), - (11, 22), - (12, 24), - (13, 26), - ), - 13, - ); - check_twoi32_tuple( - ( - (1, 2), - (2, 4), - (3, 6), - (4, 8), - (5, 10), - (6, 12), - (7, 14), - (8, 16), - (9, 18), - (10, 20), - (11, 22), - (12, 24), - (13, 26), - (14, 28), - ), - 14, - ); - check_twoi32_tuple( - ( - (1, 2), - (2, 4), - (3, 6), - (4, 8), - (5, 10), - (6, 12), - (7, 14), - (8, 16), - (9, 18), - (10, 20), - (11, 22), - (12, 24), - (13, 26), - (14, 28), - (15, 30), - ), - 15, - ); - check_twoi32_tuple( - ( - (1, 2), - (2, 4), - (3, 6), - (4, 8), - (5, 10), - (6, 12), - (7, 14), - (8, 16), - (9, 18), - (10, 20), - (11, 22), - (12, 24), - (13, 26), - (14, 28), - (15, 30), - (16, 32), - ), - 16, - ); -} - -#[test] -#[allow(clippy::needless_borrow)] -fn ref_batch_values() { - let batch_values: &[&[i8]] = &[&[1, 2], &[2, 3, 4, 5], &[6]]; - - return check_ref_bv::<&&&&&[&[i8]]>(&&&&batch_values); - fn check_ref_bv(batch_values: B) { - let mut it = >::batch_values_iter(&batch_values); - - let mut request: Vec = Vec::new(); - it.write_next_to_request(&mut request).unwrap().unwrap(); - assert_eq!(request, vec![0, 2, 0, 0, 0, 1, 1, 0, 0, 0, 1, 2]); - } -} - -#[test] -#[allow(clippy::needless_borrow)] -fn check_ref_tuple() { - fn assert_has_batch_values(bv: BV) { - let mut it = bv.batch_values_iter(); - let mut request: Vec = Vec::new(); - while let Some(res) = it.write_next_to_request(&mut request) { - res.unwrap() - } - } - let s = String::from("hello"); - let tuple: ((&str,),) = ((&s,),); - assert_has_batch_values(&tuple); - let tuple2: ((&str, &str), (&str, &str)) = ((&s, &s), (&s, &s)); - assert_has_batch_values(&tuple2); -} - -#[test] -fn check_batch_values_iterator_is_not_lending() { - // This is an interesting property if we want to improve the batch shard selection heuristic - fn f(bv: impl BatchValues) { - let mut it = bv.batch_values_iter(); - let mut it2 = bv.batch_values_iter(); - // Make sure we can hold all these at the same time - let v = vec![ - it.next_serialized().unwrap().unwrap(), - it2.next_serialized().unwrap().unwrap(), - it.next_serialized().unwrap().unwrap(), - it2.next_serialized().unwrap().unwrap(), - ]; - let _ = v; - } - f(((10,), (11,))) -} diff --git a/scylla-cql/src/lib.rs b/scylla-cql/src/lib.rs deleted file mode 100644 index 47b58b4..0000000 --- a/scylla-cql/src/lib.rs +++ /dev/null @@ -1,21 +0,0 @@ -pub mod errors; -pub mod frame; -#[macro_use] -pub mod macros; - -pub use crate::frame::response::cql_to_rust; -pub use crate::frame::response::cql_to_rust::FromRow; - -pub use crate::frame::types::Consistency; - -#[doc(hidden)] -pub mod _macro_internal { - pub use crate::frame::response::cql_to_rust::{ - FromCqlVal, FromCqlValError, FromRow, FromRowError, - }; - pub use crate::frame::response::result::{CqlValue, Row}; - pub use crate::frame::value::{ - SerializedResult, SerializedValues, Value, ValueList, ValueTooBig, - }; - pub use crate::macros::*; -} diff --git a/scylla-cql/src/macros.rs b/scylla-cql/src/macros.rs deleted file mode 100644 index 8d60312..0000000 --- a/scylla-cql/src/macros.rs +++ /dev/null @@ -1,19 +0,0 @@ -/// #[derive(FromRow)] derives FromRow for struct -/// Works only on simple structs without generics etc -pub use scylla_macros::FromRow; - -/// #[derive(FromUserType)] allows to parse struct as a User Defined Type -/// Works only on simple structs without generics etc -pub use scylla_macros::FromUserType; - -/// #[derive(IntoUserType)] allows to pass struct a User Defined Type Value in queries -/// Works only on simple structs without generics etc -pub use scylla_macros::IntoUserType; - -/// #[derive(ValueList)] allows to pass struct as a list of values for a query -pub use scylla_macros::ValueList; - -// Reexports for derive(IntoUserType) -pub use bytes::{BufMut, Bytes, BytesMut}; - -pub use crate::impl_from_cql_value_from_method; diff --git a/scylla-macros/Cargo.toml b/scylla-macros/Cargo.toml deleted file mode 100644 index 9c34a2a..0000000 --- a/scylla-macros/Cargo.toml +++ /dev/null @@ -1,17 +0,0 @@ -[package] -name = "scylla-macros" -version = "0.1.2" -edition = "2021" -description = "proc macros for scylla async CQL driver" -repository = "https://github.com/scylladb/scylla-rust-driver" -readme = "../README.md" -categories = ["database"] -license = "MIT OR Apache-2.0" - -[lib] -proc-macro = true - -[dependencies] -syn = "1.0" -quote = "1.0" -proc-macro2 = "1.0" \ No newline at end of file diff --git a/scylla-macros/src/from_row.rs b/scylla-macros/src/from_row.rs deleted file mode 100644 index 7375636..0000000 --- a/scylla-macros/src/from_row.rs +++ /dev/null @@ -1,60 +0,0 @@ -use proc_macro::TokenStream; -use quote::{quote, quote_spanned}; -use syn::{spanned::Spanned, DeriveInput}; - -/// #[derive(FromRow)] derives FromRow for struct -pub fn from_row_derive(tokens_input: TokenStream) -> TokenStream { - let item = syn::parse::(tokens_input).expect("No DeriveInput"); - let path = crate::parser::get_path(&item).expect("No path"); - let struct_fields = crate::parser::parse_named_fields(&item, "FromRow"); - - let struct_name = &item.ident; - let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl(); - - // Generates tokens for field_name: field_type::from_cql(vals_iter.next().ok_or(...)?), ... - let set_fields_code = struct_fields.named.iter().map(|field| { - let field_name = &field.ident; - let field_type = &field.ty; - - quote_spanned! {field.span() => - #field_name: { - let (col_ix, col_value) = vals_iter - .next() - .unwrap(); // vals_iter size is checked before this code is reached, so - // it is safe to unwrap - - <#field_type as FromCqlVal<::std::option::Option>>::from_cql(col_value) - .map_err(|e| FromRowError::BadCqlVal { - err: e, - column: col_ix, - })? - }, - } - }); - - let fields_count = struct_fields.named.len(); - let generated = quote! { - impl #impl_generics #path::_macro_internal::FromRow for #struct_name #ty_generics #where_clause { - fn from_row(row: #path::_macro_internal::Row) - -> ::std::result::Result { - use #path::_macro_internal::{CqlValue, FromCqlVal, FromRow, FromRowError}; - use ::std::result::Result::{Ok, Err}; - use ::std::iter::{Iterator, IntoIterator}; - - if #fields_count != row.columns.len() { - return Err(FromRowError::WrongRowSize { - expected: #fields_count, - actual: row.columns.len(), - }); - } - let mut vals_iter = row.columns.into_iter().enumerate(); - - Ok(#struct_name { - #(#set_fields_code)* - }) - } - } - }; - - TokenStream::from(generated) -} diff --git a/scylla-macros/src/from_user_type.rs b/scylla-macros/src/from_user_type.rs deleted file mode 100644 index b66b955..0000000 --- a/scylla-macros/src/from_user_type.rs +++ /dev/null @@ -1,78 +0,0 @@ -use proc_macro::TokenStream; -use quote::{quote, quote_spanned}; -use syn::{spanned::Spanned, DeriveInput}; - -/// #[derive(FromUserType)] allows to parse a struct as User Defined Type -pub fn from_user_type_derive(tokens_input: TokenStream) -> TokenStream { - let item = syn::parse::(tokens_input).expect("No DeriveInput"); - let path = crate::parser::get_path(&item).expect("Couldn't get path to the scylla crate"); - let struct_fields = crate::parser::parse_named_fields(&item, "FromUserType"); - - let struct_name = &item.ident; - let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl(); - - // Generates tokens for field_name: field_type::from_cql(fields.remove(stringify!(#field_name)).unwrap_or(None)) ?, ... - let set_fields_code = struct_fields.named.iter().map(|field| { - let field_name = &field.ident; - let field_type = &field.ty; - - quote_spanned! {field.span() => - #field_name: <#field_type as FromCqlVal<::std::option::Option>>::from_cql( - { - let received_field_name: Option<&::std::string::String> = fields_iter - .peek() - .map(|(ref name, _)| name); - - // Order of received fields is the same as the order of processed struct's - // fields. There cannot be an extra received field present, so it is safe to - // assign subsequent received fields, to processed struct's fields (inserting - // None if there is no received field corresponding to processed struct's - // field) - if let Some(received_field_name) = received_field_name { - if received_field_name == stringify!(#field_name) { - let (_, value) = fields_iter.next().unwrap(); - value - } else { - None - } - } else { - None - } - } - ) ?, - } - }); - - let generated = quote! { - impl #impl_generics #path::_macro_internal::FromCqlVal<#path::_macro_internal::CqlValue> for #struct_name #ty_generics #where_clause { - fn from_cql(cql_val: #path::_macro_internal::CqlValue) - -> ::std::result::Result { - use ::std::collections::BTreeMap; - use ::std::option::Option::{self, Some, None}; - use ::std::result::Result::{Ok, Err}; - use #path::_macro_internal::{FromCqlVal, FromCqlValError, CqlValue}; - use ::std::iter::{Iterator, IntoIterator}; - - // Interpret CqlValue as CQlValue::UserDefinedType - let mut fields_iter = match cql_val { - CqlValue::UserDefinedType{fields, ..} => fields.into_iter().peekable(), - _ => return Err(FromCqlValError::BadCqlType), - }; - - // Parse struct using values from fields - let result = #struct_name { - #(#set_fields_code)* - }; - - // There should be no unused fields when reading user defined type - if fields_iter.next().is_some() { - return Err(FromCqlValError::BadCqlType); - } - - return Ok(result); - } - } - }; - - TokenStream::from(generated) -} diff --git a/scylla-macros/src/into_user_type.rs b/scylla-macros/src/into_user_type.rs deleted file mode 100644 index f181df6..0000000 --- a/scylla-macros/src/into_user_type.rs +++ /dev/null @@ -1,49 +0,0 @@ -use proc_macro::TokenStream; -use quote::{quote, quote_spanned}; -use syn::{spanned::Spanned, DeriveInput}; - -/// #[derive(IntoUserType)] allows to parse a struct as User Defined Type -pub fn into_user_type_derive(tokens_input: TokenStream) -> TokenStream { - let item = syn::parse::(tokens_input).expect("No DeriveInput"); - let path = crate::parser::get_path(&item).expect("No path"); - let struct_fields = crate::parser::parse_named_fields(&item, "IntoUserType"); - - let struct_name = &item.ident; - let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl(); - - let serialize_code = struct_fields.named.iter().map(|field| { - let field_name = &field.ident; - - quote_spanned! {field.span() => - <_ as Value>::serialize(&self.#field_name, buf) ?; - } - }); - - let generated = quote! { - impl #impl_generics #path::_macro_internal::Value for #struct_name #ty_generics #where_clause { - fn serialize(&self, buf: &mut ::std::vec::Vec<::core::primitive::u8>) -> ::std::result::Result<(), #path::_macro_internal::ValueTooBig> { - use #path::_macro_internal::{BufMut, ValueTooBig, Value}; - use ::std::convert::TryInto; - use ::core::primitive::{usize, i32}; - - // Reserve space to put serialized size in - let total_size_index: usize = buf.len(); - buf.put_i32(0); - - let len_before_serialize = buf.len(); - - // Serialize fields - #(#serialize_code)* - - // Put serialized size in its place - let total_size : usize = buf.len() - len_before_serialize; - let total_size_i32: i32 = total_size.try_into().map_err(|_| ValueTooBig) ?; - buf[total_size_index..(total_size_index+4)].copy_from_slice(&total_size_i32.to_be_bytes()[..]); - - ::std::result::Result::Ok(()) - } - } - }; - - TokenStream::from(generated) -} diff --git a/scylla-macros/src/lib.rs b/scylla-macros/src/lib.rs deleted file mode 100644 index f5ad28a..0000000 --- a/scylla-macros/src/lib.rs +++ /dev/null @@ -1,35 +0,0 @@ -use proc_macro::TokenStream; - -mod from_row; -mod from_user_type; -mod into_user_type; -mod parser; -mod value_list; - -/// #[derive(FromRow)] derives FromRow for struct -/// Works only on simple structs without generics etc -#[proc_macro_derive(FromRow, attributes(scylla_crate))] -pub fn from_row_derive(tokens_input: TokenStream) -> TokenStream { - from_row::from_row_derive(tokens_input) -} - -/// #[derive(FromUserType)] allows to parse a struct as User Defined Type -/// Works only on simple structs without generics etc -#[proc_macro_derive(FromUserType, attributes(scylla_crate))] -pub fn from_user_type_derive(tokens_input: TokenStream) -> TokenStream { - from_user_type::from_user_type_derive(tokens_input) -} - -/// #[derive(IntoUserType)] allows to parse a struct as User Defined Type -/// Works only on simple structs without generics etc -#[proc_macro_derive(IntoUserType, attributes(scylla_crate))] -pub fn into_user_type_derive(tokens_input: TokenStream) -> TokenStream { - into_user_type::into_user_type_derive(tokens_input) -} - -/// #[derive(ValueList)] derives ValueList for struct -/// Works only on simple structs without generics etc -#[proc_macro_derive(ValueList, attributes(scylla_crate))] -pub fn value_list_derive(tokens_input: TokenStream) -> TokenStream { - value_list::value_list_derive(tokens_input) -} diff --git a/scylla-macros/src/parser.rs b/scylla-macros/src/parser.rs deleted file mode 100644 index 46cf772..0000000 --- a/scylla-macros/src/parser.rs +++ /dev/null @@ -1,65 +0,0 @@ -use syn::{Data, DeriveInput, Fields, FieldsNamed}; -use syn::{Lit, Meta}; - -/// Parses the tokens_input to a DeriveInput and returns the struct name from which it derives and -/// the named fields -pub(crate) fn parse_named_fields<'a>( - input: &'a DeriveInput, - current_derive: &str, -) -> &'a FieldsNamed { - match &input.data { - Data::Struct(data) => match &data.fields { - Fields::Named(named_fields) => named_fields, - _ => panic!( - "derive({}) works only for structs with named fields. Tuples don't need derive.", - current_derive - ), - }, - _ => panic!("derive({}) works only on structs!", current_derive), - } -} - -pub(crate) fn get_path(input: &DeriveInput) -> Result { - let mut this_path: Option = None; - for attr in input.attrs.iter() { - match attr.parse_meta() { - Ok(Meta::NameValue(meta_name_value)) => { - if !meta_name_value.path.is_ident("scylla_crate") { - continue; - } - if let Lit::Str(lit_str) = &meta_name_value.lit { - let path_val = &lit_str.value().parse::().unwrap(); - if this_path.is_none() { - this_path = Some(quote::quote!(#path_val)); - } else { - return Err(syn::Error::new_spanned( - &meta_name_value.lit, - "the `scylla_crate` attribute was set multiple times", - )); - } - } else { - return Err(syn::Error::new_spanned( - &meta_name_value.lit, - "the `scylla_crate` attribute should be a string literal", - )); - } - } - Ok(other) => { - if !other.path().is_ident("scylla_crate") { - continue; - } - return Err(syn::Error::new_spanned( - other, - "the `scylla_crate` attribute have a single value", - )); - } - Err(err) => { - return Err(err); - } - } - } - match this_path { - Some(path) => Ok(path), - None => Ok(quote::quote!(scylla)), - } -} diff --git a/scylla-macros/src/value_list.rs b/scylla-macros/src/value_list.rs deleted file mode 100644 index 6a841d1..0000000 --- a/scylla-macros/src/value_list.rs +++ /dev/null @@ -1,31 +0,0 @@ -use proc_macro::TokenStream; -use quote::quote; -use syn::DeriveInput; - -/// #[derive(ValueList)] allows to parse a struct as a list of values, -/// which can be fed to the query directly. -pub fn value_list_derive(tokens_input: TokenStream) -> TokenStream { - let item = syn::parse::(tokens_input).expect("No DeriveInput"); - let path = crate::parser::get_path(&item).expect("No path"); - let struct_fields = crate::parser::parse_named_fields(&item, "ValueList"); - - let struct_name = &item.ident; - let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl(); - - let values_len = struct_fields.named.len(); - let field_name = struct_fields.named.iter().map(|field| &field.ident); - let generated = quote! { - impl #impl_generics #path::_macro_internal::ValueList for #struct_name #ty_generics #where_clause { - fn serialized(&self) -> #path::_macro_internal::SerializedResult { - let mut result = #path::_macro_internal::SerializedValues::with_capacity(#values_len); - #( - result.add_value(&self.#field_name)?; - )* - - ::std::result::Result::Ok(::std::borrow::Cow::Owned(result)) - } - } - }; - - TokenStream::from(generated) -} diff --git a/scylla-udf-macros/Cargo.toml b/scylla-udf-macros/Cargo.toml new file mode 100644 index 0000000..201c751 --- /dev/null +++ b/scylla-udf-macros/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "scylla-udf-macros" +edition.workspace = true +version.workspace = true +repository.workspace = true +license.workspace = true +rust-version.workspace = true +description = "Implementation of scylla-udf macros" +readme = "../README.md" +keywords = ["scylla", "udf", "macro"] +categories = ["database", "wasm"] + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1.0.36" +quote = "1.0.15" +syn = { version = "1.0.86", features = ["full"] } diff --git a/scylla-udf-macros/src/export_newtype.rs b/scylla-udf-macros/src/export_newtype.rs new file mode 100644 index 0000000..0570f1e --- /dev/null +++ b/scylla-udf-macros/src/export_newtype.rs @@ -0,0 +1,127 @@ +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use quote::quote; +use syn::Fields; + +struct NewtypeStruct { + struct_name: syn::Ident, + field_type: syn::Type, + generics: syn::Generics, +} + +fn get_newtype_struct(st: &syn::ItemStruct) -> Result { + let struct_name = &st.ident; + let struct_fields = match &st.fields { + Fields::Unnamed(named_fields) => named_fields, + _ => { + return Err(syn::Error::new_spanned( + st, + "#[scylla_udf::export_newtype] error: struct has named fields.", + ) + .to_compile_error()); + } + }; + if struct_fields.unnamed.len() > 1 { + return Err(syn::Error::new_spanned( + st, + "#[scylla_udf::export_newtype] error: struct has more than 1 field.", + ) + .to_compile_error()); + } + let field_type = match struct_fields.unnamed.first() { + Some(field) => &field.ty, + None => { + return Err(syn::Error::new_spanned( + st, + "#[scylla_udf::export_newtype] error: struct has no fields.", + ) + .to_compile_error()); + } + }; + + Ok(NewtypeStruct { + struct_name: struct_name.clone(), + field_type: field_type.clone(), + generics: st.generics.clone(), + }) +} + +fn impl_wasm_convertible(nst: &NewtypeStruct, path: &TokenStream2) -> TokenStream2 { + let struct_name = &nst.struct_name; + let struct_type = &nst.field_type; + let (impl_generics, ty_generics, where_clause) = nst.generics.split_for_impl(); + quote! { + impl #impl_generics ::#path::WasmConvertible for #struct_name #ty_generics #where_clause { + type WasmType = <#struct_type as ::#path::WasmConvertible>::WasmType; + fn from_wasm(arg: Self::WasmType) -> Self { + #struct_name(<#struct_type as ::#path::WasmConvertible>::from_wasm(arg)) + } + fn to_wasm(&self) -> Self::WasmType { + <#struct_type as ::#path::WasmConvertible>::to_wasm(&self.0) + } + } + } +} + +fn impl_to_col_type(nst: &NewtypeStruct, path: &TokenStream2) -> TokenStream2 { + let struct_name = &nst.struct_name; + let struct_type = &nst.field_type; + let (impl_generics, ty_generics, where_clause) = nst.generics.split_for_impl(); + quote! { + impl #impl_generics ::#path::ToColumnType for #struct_name #ty_generics #where_clause { + fn to_column_type() -> ::#path::ColumnType { + <#struct_type as ::#path::ToColumnType>::to_column_type() + } + } + } +} + +fn impl_value(nst: &NewtypeStruct, path: &TokenStream2) -> TokenStream2 { + let struct_name = &nst.struct_name; + let struct_type = &nst.field_type; + let (impl_generics, ty_generics, where_clause) = nst.generics.split_for_impl(); + + quote! { + impl #impl_generics ::#path::Value for #struct_name #ty_generics #where_clause { + fn serialize(&self, buf: &mut ::std::vec::Vec<::core::primitive::u8>) -> ::std::result::Result<(), ::#path::ValueTooBig> { + <#struct_type as ::#path::Value>::serialize(&self.0, buf) + } + } + } +} + +fn impl_from_cql_val(nst: &NewtypeStruct, path: &TokenStream2) -> TokenStream2 { + let struct_name = &nst.struct_name; + let struct_type = &nst.field_type; + let (impl_generics, ty_generics, where_clause) = nst.generics.split_for_impl(); + + quote! { + impl #impl_generics ::#path::FromCqlVal<::#path::CqlValue> for #struct_name #ty_generics #where_clause { + fn from_cql(val: ::#path::CqlValue) -> ::std::result::Result { + <#struct_type as ::#path::FromCqlVal<::#path::CqlValue>>::from_cql(val).map(|v| #struct_name(v)) + } + } + } +} + +pub(crate) fn export_newtype(attrs: TokenStream, item: TokenStream) -> TokenStream { + let st = syn::parse_macro_input!(item as syn::ItemStruct); + let atrs = syn::parse_macro_input!(attrs as syn::AttributeArgs); + let path = crate::path::get_path(&atrs).expect("Couldn't get path to the scylla_udf crate"); + let newtype_struct = match get_newtype_struct(&st) { + Ok(nst) => nst, + Err(e) => return e.into(), + }; + let wasm_convertible = impl_wasm_convertible(&newtype_struct, &path); + let to_col_type = impl_to_col_type(&newtype_struct, &path); + let value = impl_value(&newtype_struct, &path); + let from_cql_val = impl_from_cql_val(&newtype_struct, &path); + quote! { + #st + #wasm_convertible + #to_col_type + #value + #from_cql_val + } + .into() +} diff --git a/scylla-udf-macros/src/export_udf.rs b/scylla-udf-macros/src/export_udf.rs new file mode 100644 index 0000000..59adfe8 --- /dev/null +++ b/scylla-udf-macros/src/export_udf.rs @@ -0,0 +1,112 @@ +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use quote::{format_ident, quote}; +use syn::spanned::Spanned; +use syn::{parse_macro_input, FnArg, ItemFn}; + +fn get_parameters_and_arguments( + item: &ItemFn, + path: &TokenStream2, +) -> Result<(Vec, Vec), TokenStream2> { + let inputs = &item.sig.inputs; + let mut parameters = Vec::with_capacity(inputs.len()); + let mut arguments = Vec::with_capacity(inputs.len()); + for (idx, i) in inputs.iter().enumerate() { + if let FnArg::Typed(pat) = i { + let ident = format_ident!("arg_{}", idx); + let typ = &pat.ty; + parameters.push(quote! { #ident: <#typ as ::#path::WasmConvertible>::WasmType }); + arguments.push(quote! { <#typ as ::#path::WasmConvertible>::from_wasm(#ident) }); + } else { + return Err(syn::Error::new( + i.span(), + "unexpected untyped self parameter in a scylla_udf function.", + ) + .to_compile_error()); + } + } + Ok((parameters, arguments)) +} + +fn get_output_type_and_block( + item: &ItemFn, + arguments: &[TokenStream2], + path: &TokenStream2, +) -> Result<(TokenStream2, TokenStream2), TokenStream2> { + let fun_name = item.sig.ident.clone(); + if let syn::ReturnType::Type(_, typ) = &item.sig.output { + Ok(( + quote! { -> <#typ as ::#path::WasmConvertible>::WasmType }, + quote! { { + <#typ as ::#path::WasmConvertible>::to_wasm(&#fun_name(#(#arguments),*)) + } }, + )) + } else { + Err(syn::Error::new( + item.sig.output.span(), + "scylla_udf function should return a value.", + ) + .to_compile_error()) + } +} + +fn get_exported_fun( + item: &ItemFn, + parameters: &[TokenStream2], + output_type_token: TokenStream2, + exported_block: TokenStream2, +) -> TokenStream2 { + let fun_name = &item.sig.ident; + let exported_fun_ident = format_ident!("{}{}", "_scylla_internal_", fun_name); + // The exported function doesn't need to be pub, because it will be included in the final + // binary anyway due to the #[export_name] attribute. No pub helps with the UDT implementation. + let sig_exported = quote! { + extern "C" fn #exported_fun_ident(#(#parameters),*) #output_type_token + }; + let fun_name_string = fun_name.to_string(); + let export_name = quote! { + #[export_name = #fun_name_string] + }; + quote! { + #export_name + #sig_exported #exported_block + } +} + +/// The macro transforms a function: +/// ```ignore +/// #[scylla_udf::export_udf] +/// fn foo(arg1: u32, arg2: String) -> u32 { +/// arg1 + arg2.len() as u32 +/// } +/// ``` +/// into something like: +/// ```ignore +/// fn foo(arg1: u32, arg2: String) -> u32 { +/// arg1 + arg2.len() as u32 +/// } +/// #[export_name = "foo"] +/// extern "C" fn _scylla_internal_foo(arg1: u32, arg2: WasmPtr) -> u32 { +/// foo(from_wasm(arg1), from_wasm(arg2)).to_wasm() +/// } +/// ``` +pub(crate) fn export_udf(attrs: TokenStream, input: TokenStream) -> TokenStream { + let item = parse_macro_input!(input as ItemFn); + let atrs = syn::parse_macro_input!(attrs as syn::AttributeArgs); + let path = crate::path::get_path(&atrs).expect("Couldn't get path to the scylla_udf crate"); + let (parameters, arguments) = match get_parameters_and_arguments(&item, &path) { + Ok(pa) => pa, + Err(e) => return e.into(), + }; + let (output_type_token, exported_block) = + match get_output_type_and_block(&item, &arguments, &path) { + Ok(oe) => oe, + Err(e) => return e.into(), + }; + let exported_fun = get_exported_fun(&item, ¶meters, output_type_token, exported_block); + quote! { + #item + #exported_fun + } + .into() +} diff --git a/scylla-udf-macros/src/export_udt.rs b/scylla-udf-macros/src/export_udt.rs new file mode 100644 index 0000000..0f3eec4 --- /dev/null +++ b/scylla-udf-macros/src/export_udt.rs @@ -0,0 +1,77 @@ +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use quote::{quote, quote_spanned}; +use syn::spanned::Spanned; +use syn::Fields; + +pub fn impl_wasm_convertible(st: &syn::ItemStruct, path: &TokenStream2) -> TokenStream2 { + let struct_name = &st.ident; + let (impl_generics, ty_generics, where_clause) = st.generics.split_for_impl(); + quote! { + impl #impl_generics ::#path::WasmConvertible for #struct_name #ty_generics #where_clause { + type WasmType = ::#path::WasmPtr; + fn from_wasm(arg: Self::WasmType) -> Self { + ::from_wasmptr(arg) + } + fn to_wasm(&self) -> Self::WasmType { + ::to_wasmptr(self) + } + } + } +} + +pub fn impl_to_col_type(st: &syn::ItemStruct, path: &TokenStream2) -> TokenStream2 { + let struct_name = &st.ident; + let struct_name_string = struct_name.to_string(); + let struct_fields = match &st.fields { + Fields::Named(named_fields) => named_fields, + _ => { + return syn::Error::new_spanned( + st, + "#[scylla_udf::export_udt] works only for structs with named fields.", + ) + .to_compile_error() + } + }; + let (impl_generics, ty_generics, where_clause) = st.generics.split_for_impl(); + let fields_column_types = struct_fields.named.iter().map(|field| { + // we matched with Fields::Named above, so we can unwrap + let field_name = field.ident.as_ref().unwrap().to_string(); + let field_type = &field.ty; + quote_spanned! {field.span() => + (#field_name.to_string(), <#field_type as ::#path::ToColumnType>::to_column_type()), + } + }); + quote! { + impl #impl_generics ::#path::ToColumnType for #struct_name #ty_generics #where_clause { + fn to_column_type() -> ::#path::ColumnType { + use ::std::string::ToString; + ::#path::ColumnType::UserDefinedType { + type_name: #struct_name_string.to_string(), + keyspace: "".to_string(), + field_types: ::std::vec![#(#fields_column_types)*], + } + } + } + } +} + +pub(crate) fn export_udt(attrs: TokenStream, item: TokenStream) -> TokenStream { + let st = syn::parse_macro_input!(item as syn::ItemStruct); + let atrs = syn::parse_macro_input!(attrs as syn::AttributeArgs); + let path_no_internal = crate::path::get_path_no_internal(&atrs) + .expect("Couldn't get path to the scylla_udf crate"); + let path_string = path_no_internal.to_string(); + let path = crate::path::append_internal(&path_no_internal); + let wasm_convertible = impl_wasm_convertible(&st, &path); + let to_col_type = impl_to_col_type(&st, &path); + quote! { + #[derive(::#path::FromUserType)] + #[derive(::#path::IntoUserType)] + #[scylla_crate = #path_string] + #st + #wasm_convertible + #to_col_type + } + .into() +} diff --git a/scylla-udf-macros/src/lib.rs b/scylla-udf-macros/src/lib.rs new file mode 100644 index 0000000..4eb2963 --- /dev/null +++ b/scylla-udf-macros/src/lib.rs @@ -0,0 +1,22 @@ +use proc_macro::TokenStream; + +mod export_newtype; +mod export_udf; +mod export_udt; + +#[proc_macro_attribute] +pub fn export_udt(attrs: TokenStream, item: TokenStream) -> TokenStream { + export_udt::export_udt(attrs, item) +} + +#[proc_macro_attribute] +pub fn export_udf(attrs: TokenStream, item: TokenStream) -> TokenStream { + export_udf::export_udf(attrs, item) +} + +#[proc_macro_attribute] +pub fn export_newtype(attrs: TokenStream, item: TokenStream) -> TokenStream { + export_newtype::export_newtype(attrs, item) +} + +pub(crate) mod path; diff --git a/scylla-udf-macros/src/path.rs b/scylla-udf-macros/src/path.rs new file mode 100644 index 0000000..76df8a1 --- /dev/null +++ b/scylla-udf-macros/src/path.rs @@ -0,0 +1,62 @@ +use proc_macro2::TokenStream as TokenStream2; +use syn::{AttributeArgs, Error, Lit, Meta, NestedMeta}; + +// function that returns the value of the "crate" attribute given AttributeArgs +pub(crate) fn get_path_no_internal(atrs: &AttributeArgs) -> Result { + let mut this_path: Option = None; + for attr in atrs.iter() { + match attr { + NestedMeta::Lit(lit) => { + return Err(Error::new_spanned( + lit, + "unexpected literal attribute for `scylla_udf`", + )); + } + NestedMeta::Meta(meta) => { + if !meta.path().is_ident("crate") { + return Err(Error::new_spanned( + meta, + "unexpected meta attribute for `scylla_udf`", + )); + } + match meta { + Meta::NameValue(meta_name_value) => { + if let Lit::Str(lit_str) = &meta_name_value.lit { + let path_val = + &lit_str.value().parse::().unwrap(); + if this_path.is_none() { + this_path = Some(quote::quote!(#path_val)); + } else { + return Err(syn::Error::new_spanned( + &meta_name_value.lit, + "the `crate` attribute was set multiple times", + )); + } + } else { + return Err(syn::Error::new_spanned( + &meta_name_value.lit, + "the `crate` attribute should be a string literal", + )); + } + } + other => { + return Err(Error::new_spanned( + other, + "the `crate` attribute have a single value", + )); + } + } + } + } + } + Ok(this_path.unwrap_or_else(|| quote::quote!(scylla_udf))) +} + +pub(crate) fn append_internal(path: &TokenStream2) -> TokenStream2 { + quote::quote!(#path::_macro_internal) +} + +pub(crate) fn get_path(atrs: &AttributeArgs) -> Result { + let path = get_path_no_internal(atrs)?; + Ok(append_internal(&path)) +} diff --git a/scylla-udf/Cargo.toml b/scylla-udf/Cargo.toml new file mode 100644 index 0000000..0e8d806 --- /dev/null +++ b/scylla-udf/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "scylla-udf" +edition.workspace = true +version.workspace = true +repository.workspace = true +license.workspace = true +rust-version.workspace = true +description = "Proc macros for scylla rust UDFs bindings" +readme = "../README.md" +keywords = ["scylla", "udf"] +categories = ["database", "wasm"] + +[dependencies] +bigdecimal = "0.2.0" +bytes = "1.2.1" +chrono = "0.4" +libc = "0.2.119" +num-bigint = "0.3" +scylla-udf-macros = { workspace = true } +scylla-cql = "0.0.4" +uuid = "1.0" diff --git a/scylla-udf/src/abi_exports.rs b/scylla-udf/src/abi_exports.rs new file mode 100644 index 0000000..ed9b0f7 --- /dev/null +++ b/scylla-udf/src/abi_exports.rs @@ -0,0 +1,25 @@ +extern "C" { + fn malloc(size: u32) -> *mut u8; + fn free(ptr: *mut u8); +} + +/// # Safety +/// - caller must ensure that the size is valid, if the allocation fails +/// - the caller must not dereference the returned pointer +#[no_mangle] +#[doc(hidden)] +pub(crate) unsafe extern "C" fn _scylla_malloc(size: u32) -> u32 { + malloc(size) as u32 +} + +/// # Safety +/// - caller must ensure that the pointer is valid +#[no_mangle] +#[doc(hidden)] +pub(crate) unsafe extern "C" fn _scylla_free(ptr: u32) { + free(ptr as *mut u8) +} + +#[no_mangle] +#[doc(hidden)] +static _scylla_abi: u32 = 2; diff --git a/scylla-udf/src/from_wasmptr.rs b/scylla-udf/src/from_wasmptr.rs new file mode 100644 index 0000000..4615f13 --- /dev/null +++ b/scylla-udf/src/from_wasmptr.rs @@ -0,0 +1,24 @@ +use crate::to_columntype::ToColumnType; +use crate::wasmptr::WasmPtr; +use scylla_cql::cql_to_rust::FromCqlVal; +use scylla_cql::frame::response::result::{deser_cql_value, CqlValue}; + +pub trait FromWasmPtr { + fn from_wasmptr(wasmptr: WasmPtr) -> Self; +} + +impl FromWasmPtr for T +where + T: FromCqlVal> + ToColumnType, +{ + fn from_wasmptr(wasmptr: WasmPtr) -> Self { + if wasmptr.is_null() { + return T::from_cql(None).unwrap(); + } + let mut slice = wasmptr.as_slice().expect("WasmPtr::as_slice returned None"); + T::from_cql(Some( + deser_cql_value(&T::to_column_type(), &mut slice).unwrap(), + )) + .unwrap() + } +} diff --git a/scylla-udf/src/lib.rs b/scylla-udf/src/lib.rs new file mode 100644 index 0000000..aa56717 --- /dev/null +++ b/scylla-udf/src/lib.rs @@ -0,0 +1,90 @@ +mod abi_exports; +mod from_wasmptr; +mod to_columntype; +mod to_wasmptr; +mod wasm_convertible; +mod wasmptr; + +/// Not a part of public API. May change in minor releases. +/// Contains all the items used by the scylla_udf macros. +#[doc(hidden)] +pub mod _macro_internal { + pub use crate::from_wasmptr::FromWasmPtr; + pub use crate::to_columntype::ToColumnType; + pub use crate::to_wasmptr::ToWasmPtr; + pub use crate::wasm_convertible::WasmConvertible; + pub use crate::wasmptr::WasmPtr; + pub use scylla_cql::_macro_internal::*; + pub use scylla_cql::frame::response::result::ColumnType; +} + +/// This macro allows using a Rust function as a Scylla UDF. +/// +/// The function must have arguments and return value of Rust types that can be mapped to CQL types, +/// the macro takes care of converting the arguments from CQL types to Rust types and back. +/// The function must not have the `#[no_mangle]` attribute, it will be added by the macro. +/// +/// For example, for a function: +/// ``` +/// #[scylla_udf::export_udf] +/// fn foo(arg: i32) -> i32 { +/// arg + 1 +/// } +/// ``` +/// you can use the compiled binary in its text format as a UDF in Scylla: +/// ```text +/// CREATE FUNCTION foo(arg int) RETURNS NULL ON NULL INPUT RETURNS int LANGUAGE rust AS '(module ...)`; +/// ``` +pub use scylla_udf_macros::export_udf; + +/// This macro allows mapping a Rust struct to a UDT from Scylla, and using in a scylla_udf function. +/// +/// To use it, you need to define a struct with the same fields as the Scylla UDT. +/// For example, for a UDT defined as: +/// ```text +/// CREATE TYPE udt ( +/// a int, +/// b double, +/// c text, +/// ); +/// ``` +/// you need to define a struct: +/// ``` +/// #[scylla_udf::export_udt] +/// struct Udt { +/// a: i32, +/// b: f64, +/// c: String, +/// } +/// ``` +pub use scylla_udf_macros::export_udt; + +/// This macro allows (de)serializing a cql type to/from a Rust "newtype" struct. +/// +/// The macro takes a "newtype" struct (tuple struct with only one field) and generates all implementations for (de)serialization +/// traits used in the scylla_udf macros by treating the struct as the inner type itself. +/// +/// This allows overriding the impls for the inner type, while still being able to use it in the types of parameters or return +/// values of scylla_udf functions. +/// +/// For example, for a function using a newtype struct: +/// ``` +/// #[scylla_udf::export_newtype] +/// struct MyInt(i32); +/// +/// #[scylla_udf::export_udf] +/// fn foo(arg: MyInt) -> MyInt { +/// ... +/// } +/// ``` +/// and a table: +/// ```text +/// CREATE TABLE table (x int PRIMARY KEY); +/// ``` +/// you can use the function in a query: +/// ```text +/// SELECT foo(x) FROM table; +/// ``` +pub use scylla_udf_macros::export_newtype; + +pub use scylla_cql::frame::value::{Counter, CqlDuration, Time, Timestamp}; diff --git a/scylla-udf/src/to_columntype.rs b/scylla-udf/src/to_columntype.rs new file mode 100644 index 0000000..b45cf70 --- /dev/null +++ b/scylla-udf/src/to_columntype.rs @@ -0,0 +1,96 @@ +pub use scylla_cql::frame::response::result::ColumnType; +use scylla_cql::frame::value::{Counter, CqlDuration, Time, Timestamp}; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; + +pub trait ToColumnType { + fn to_column_type() -> ColumnType; +} + +// This macro implements ToColumnType given a Rust type and the resulting ColumnType +macro_rules! impl_to_col_type { + ($rust_type:ty, $col_type:expr) => { + impl ToColumnType for $rust_type { + fn to_column_type() -> ColumnType { + $col_type + } + } + }; +} + +impl_to_col_type!(bool, ColumnType::Boolean); +impl_to_col_type!(Vec, ColumnType::Blob); +impl_to_col_type!(Counter, ColumnType::Counter); +impl_to_col_type!(chrono::NaiveDate, ColumnType::Date); +impl_to_col_type!(bigdecimal::BigDecimal, ColumnType::Decimal); +impl_to_col_type!(f64, ColumnType::Double); +impl_to_col_type!(CqlDuration, ColumnType::Duration); +impl_to_col_type!(f32, ColumnType::Float); +impl_to_col_type!(i32, ColumnType::Int); +impl_to_col_type!(i64, ColumnType::BigInt); +impl_to_col_type!(String, ColumnType::Text); +impl_to_col_type!(Timestamp, ColumnType::Timestamp); +impl_to_col_type!(std::net::IpAddr, ColumnType::Inet); +impl_to_col_type!(i16, ColumnType::SmallInt); +impl_to_col_type!(i8, ColumnType::TinyInt); +impl_to_col_type!(Time, ColumnType::Time); +impl_to_col_type!(uuid::Uuid, ColumnType::Uuid); +impl_to_col_type!(num_bigint::BigInt, ColumnType::Varint); + +impl ToColumnType for Vec { + fn to_column_type() -> ColumnType { + ColumnType::List(Box::new(T::to_column_type())) + } +} + +impl ToColumnType for BTreeMap { + fn to_column_type() -> ColumnType { + ColumnType::Map(Box::new(K::to_column_type()), Box::new(V::to_column_type())) + } +} + +impl ToColumnType for HashMap { + fn to_column_type() -> ColumnType { + ColumnType::Map(Box::new(K::to_column_type()), Box::new(V::to_column_type())) + } +} + +impl ToColumnType for BTreeSet { + fn to_column_type() -> ColumnType { + ColumnType::Set(Box::new(T::to_column_type())) + } +} + +impl ToColumnType for HashSet { + fn to_column_type() -> ColumnType { + ColumnType::Set(Box::new(T::to_column_type())) + } +} + +macro_rules! tuple_impls { + ( $( $types:ident )* ) => { + impl<$($types: ToColumnType),*> ToColumnType for ($($types,)*) { + fn to_column_type() -> ColumnType { + ColumnType::Tuple(vec![$($types::to_column_type()),*]) + } + } + }; +} + +tuple_impls! { A } +tuple_impls! { A B } +tuple_impls! { A B C } +tuple_impls! { A B C D } +tuple_impls! { A B C D E } +tuple_impls! { A B C D E F } +tuple_impls! { A B C D E F G } +tuple_impls! { A B C D E F G H } +tuple_impls! { A B C D E F G H I } +tuple_impls! { A B C D E F G H I J } +tuple_impls! { A B C D E F G H I J K } +tuple_impls! { A B C D E F G H I J K L } + +impl ToColumnType for Option { + fn to_column_type() -> ColumnType { + T::to_column_type() + } +} diff --git a/scylla-udf/src/to_wasmptr.rs b/scylla-udf/src/to_wasmptr.rs new file mode 100644 index 0000000..6c082a8 --- /dev/null +++ b/scylla-udf/src/to_wasmptr.rs @@ -0,0 +1,24 @@ +use crate::wasmptr::WasmPtr; +use core::convert::TryInto; +use scylla_cql::frame::value::Value; + +pub trait ToWasmPtr { + fn to_wasmptr(&self) -> WasmPtr; +} + +impl ToWasmPtr for T { + fn to_wasmptr(&self) -> WasmPtr { + let mut bytes = Vec::::new(); + self.serialize(&mut bytes).expect("Error serializing value"); + let size = u32::from_be_bytes(bytes[..4].try_into().expect("slice with incorrect length")); + if size == u32::MAX { + return WasmPtr::null(); + } + let mut dest = WasmPtr::with_size(size).expect("Failed to allocate memory"); + let dest_slice = dest + .as_mut_slice() + .expect("WasmPtr::as_mut_slice returned None"); + dest_slice.copy_from_slice(&bytes[4..]); + dest + } +} diff --git a/scylla-udf/src/wasm_convertible.rs b/scylla-udf/src/wasm_convertible.rs new file mode 100644 index 0000000..087c021 --- /dev/null +++ b/scylla-udf/src/wasm_convertible.rs @@ -0,0 +1,327 @@ +use crate::from_wasmptr::FromWasmPtr; +use crate::to_wasmptr::ToWasmPtr; +use crate::wasmptr::WasmPtr; +use scylla_cql::frame::value::{Counter, CqlDuration, Time, Timestamp}; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use std::convert::TryFrom; + +pub trait WasmConvertible { + type WasmType; + fn from_wasm(arg: Self::WasmType) -> Self; + fn to_wasm(&self) -> Self::WasmType; +} + +// This macro implements WasmConvertible given a Rust type and the resulting WasmType +macro_rules! impl_wasm_convertible_native { + ($rust_type:ty) => { + impl WasmConvertible for $rust_type { + type WasmType = $rust_type; + fn from_wasm(arg: Self::WasmType) -> Self { + arg + } + fn to_wasm(&self) -> Self::WasmType { + *self + } + } + }; +} + +impl_wasm_convertible_native!(i32); +impl_wasm_convertible_native!(i64); +impl_wasm_convertible_native!(f32); +impl_wasm_convertible_native!(f64); + +// This macro implements WasmConvertible given a Rust type that can be converted to the given WasmType using TryFrom +macro_rules! impl_wasm_convertible_scalar { + ($rust_type:ty, $scalar_type:ty) => { + impl WasmConvertible for $rust_type { + type WasmType = $scalar_type; + fn from_wasm(arg: Self::WasmType) -> Self { + <$rust_type>::try_from(arg) + .expect("Failed to convert from wasm type to a rust type") + } + fn to_wasm(&self) -> Self::WasmType { + *self as Self::WasmType + } + } + }; +} + +impl_wasm_convertible_scalar!(i8, i32); +impl_wasm_convertible_scalar!(i16, i32); + +// Can't convert bool to i32 using TryFrom, so we need a special implementation +impl WasmConvertible for bool { + type WasmType = i32; + fn from_wasm(arg: Self::WasmType) -> Self { + arg != 0 + } + fn to_wasm(&self) -> Self::WasmType { + i32::from(*self) + } +} + +// This macro implements WasmConvertible given a Rust type that can be (de)serialized using FromWasmPtr and ToWasmPtr +macro_rules! impl_wasm_convertible_serialized { + ($rust_type:ty) => { + impl WasmConvertible for $rust_type { + type WasmType = WasmPtr; + fn from_wasm(arg: Self::WasmType) -> Self { + ::from_wasmptr(arg) + } + fn to_wasm(&self) -> Self::WasmType { + ::to_wasmptr(self) + } + } + }; +} + +impl_wasm_convertible_serialized!(Counter); +impl_wasm_convertible_serialized!(chrono::NaiveDate); +impl_wasm_convertible_serialized!(bigdecimal::BigDecimal); +impl_wasm_convertible_serialized!(CqlDuration); +impl_wasm_convertible_serialized!(String); +impl_wasm_convertible_serialized!(Timestamp); +impl_wasm_convertible_serialized!(std::net::IpAddr); +impl_wasm_convertible_serialized!(Time); +impl_wasm_convertible_serialized!(uuid::Uuid); +impl_wasm_convertible_serialized!(num_bigint::BigInt); + +// This macro implements WasmConvertible given a Rust type with a generic parameter T that can be (de)serialized using FromWasmPtr and ToWasmPtr +macro_rules! impl_wasm_convertible_serialized_generic { + ($rust_type:ty) => { + impl WasmConvertible for $rust_type + where + $rust_type: FromWasmPtr + ToWasmPtr, + { + type WasmType = WasmPtr; + fn from_wasm(arg: Self::WasmType) -> Self { + ::from_wasmptr(arg) + } + fn to_wasm(&self) -> Self::WasmType { + ::to_wasmptr(self) + } + } + }; +} + +impl_wasm_convertible_serialized_generic!(Option); +// Implements both lists and blobs +impl_wasm_convertible_serialized_generic!(Vec); +impl_wasm_convertible_serialized_generic!(BTreeSet); +impl_wasm_convertible_serialized_generic!(HashSet); + +// This macro implements WasmConvertible given a Rust type with generic parameters K and V that can be (de)serialized using FromWasmPtr and ToWasmPtr +macro_rules! impl_wasm_convertible_serialized_double_generic { + ($rust_type:ty) => { + impl WasmConvertible for $rust_type + where + $rust_type: FromWasmPtr + ToWasmPtr, + { + type WasmType = WasmPtr; + fn from_wasm(arg: Self::WasmType) -> Self { + ::from_wasmptr(arg) + } + fn to_wasm(&self) -> Self::WasmType { + ::to_wasmptr(self) + } + } + }; +} + +impl_wasm_convertible_serialized_double_generic!(BTreeMap); +impl_wasm_convertible_serialized_double_generic!(HashMap); + +// This macro implements WasmConvertible for tuples of types that can be (de)serialized using FromWasmPtr and ToWasmPtr +macro_rules! impl_wasm_convertible_serialized_tuple { + ( $( $types:ident )* ) => { + impl<$($types),*> WasmConvertible for ($($types,)*) + where + ($($types,)*): FromWasmPtr + ToWasmPtr + { + type WasmType = WasmPtr; + fn from_wasm(arg: Self::WasmType) -> Self { + ::from_wasmptr(arg) + } + fn to_wasm(&self) -> Self::WasmType { + ::to_wasmptr(self) + } + } + }; +} + +impl_wasm_convertible_serialized_tuple! { A } +impl_wasm_convertible_serialized_tuple! { A B } +impl_wasm_convertible_serialized_tuple! { A B C } +impl_wasm_convertible_serialized_tuple! { A B C D } +impl_wasm_convertible_serialized_tuple! { A B C D E } +impl_wasm_convertible_serialized_tuple! { A B C D E F } +impl_wasm_convertible_serialized_tuple! { A B C D E F G } +impl_wasm_convertible_serialized_tuple! { A B C D E F G H } +impl_wasm_convertible_serialized_tuple! { A B C D E F G H I } +impl_wasm_convertible_serialized_tuple! { A B C D E F G H I J } +impl_wasm_convertible_serialized_tuple! { A B C D E F G H I J K } +impl_wasm_convertible_serialized_tuple! { A B C D E F G H I J K L } + +#[cfg(test)] +mod tests { + use super::WasmConvertible; + use crate::*; + + #[test] + fn i32_convert() { + assert_eq!(i32::from_wasm(42_i32.to_wasm()), 42_i32); + assert_eq!(i32::from_wasm(-42_i32), -42_i32); + assert_eq!((-42_i32).to_wasm(), -42_i32); + } + #[test] + fn i64_convert() { + assert_eq!(i64::from_wasm(42_i64.to_wasm()), 42_i64); + assert_eq!(i64::from_wasm(-42_i64), -42_i64); + assert_eq!((-42_i64).to_wasm(), -42_i64); + } + #[test] + fn f32_convert() { + assert_eq!(f32::from_wasm(0.42_f32.to_wasm()), 0.42_f32); + assert_eq!(f32::from_wasm(-0.42_f32), -0.42_f32); + assert_eq!((-0.42_f32).to_wasm(), -0.42_f32); + } + #[test] + fn f64_convert() { + assert_eq!(f64::from_wasm(0.42_f64.to_wasm()), 0.42_f64); + assert_eq!(f64::from_wasm(-0.42_f64), -0.42_f64); + assert_eq!((-0.42_f64).to_wasm(), -0.42_f64); + } + #[test] + fn i8_convert() { + assert_eq!(i8::from_wasm(42_i8.to_wasm()), 42_i8); + assert_eq!(i8::from_wasm(-42_i32), -42_i8); + assert_eq!((-42_i8).to_wasm(), -42_i32); + } + #[test] + fn i16_convert() { + assert_eq!(i64::from_wasm(42_i64.to_wasm()), 42_i64); + assert_eq!(i64::from_wasm(-42_i64), -42_i64); + assert_eq!((-42_i64).to_wasm(), -42_i64); + } + #[test] + fn bool_convert() { + assert!(bool::from_wasm(true.to_wasm())); + assert!(bool::from_wasm(1_i32)); + assert_eq!(bool::to_wasm(&false), 0_i32); + } + #[test] + fn blob_convert() { + let blob: Vec = vec![1, 2, 3, 4, 5]; + assert_eq!(Vec::::from_wasm(blob.to_wasm()), blob); + } + #[test] + fn counter_convert() { + assert_eq!(Counter::from_wasm(Counter(13).to_wasm()), Counter(13)); + } + #[test] + fn naive_date_convert() { + use chrono::NaiveDate; + let date = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + assert_eq!(NaiveDate::from_wasm(date.to_wasm()), date); + } + #[test] + fn big_decimal_convert() { + use bigdecimal::BigDecimal; + use std::str::FromStr; + let bigd = BigDecimal::from_str("547318970434573134570").unwrap(); + assert_eq!(BigDecimal::from_wasm(bigd.to_wasm()), bigd); + } + #[test] + fn cql_duration_convert() { + let dur = CqlDuration { + months: 1, + days: 2, + nanoseconds: 3, + }; + assert_eq!(CqlDuration::from_wasm(dur.to_wasm()), dur); + } + #[test] + fn timestamp_convert() { + use chrono::Duration; + let ts = Timestamp(Duration::weeks(2)); + assert_eq!(Timestamp::from_wasm(ts.to_wasm()), ts); + } + #[test] + fn string_convert() { + let s = String::from("abc"); + assert_eq!(String::from_wasm(s.to_wasm()), s); + } + #[test] + fn inet_convert() { + use std::net::IpAddr; + let ip = IpAddr::from([127, 0, 0, 1]); + assert_eq!(IpAddr::from_wasm(ip.to_wasm()), ip); + } + #[test] + fn time_convert() { + use chrono::Duration; + let t = Time(Duration::hours(3)); + assert_eq!(Time::from_wasm(t.to_wasm()), t); + } + #[test] + fn uuid_convert() { + use uuid::Uuid; + let uuid = Uuid::NAMESPACE_OID; + assert_eq!(Uuid::from_wasm(uuid.to_wasm()), uuid); + } + #[test] + fn big_int_convert() { + use num_bigint::BigInt; + use std::str::FromStr; + let bi = BigInt::from_str("420000000000000000").unwrap(); + assert_eq!(BigInt::from_wasm(bi.to_wasm()), bi); + } + + #[test] + fn vec_convert() { + // convert vec of strings + let vec = vec![String::from("a"), String::from("b")]; + assert_eq!(Vec::::from_wasm(vec.to_wasm()), vec); + } + #[test] + fn hashset_convert() { + use std::collections::HashSet; + let mut set = HashSet::new(); + set.insert(String::from("a")); + set.insert(String::from("b")); + assert_eq!(HashSet::::from_wasm(set.to_wasm()), set); + } + #[test] + fn btreeset_convert() { + use std::collections::BTreeSet; + let mut set = BTreeSet::new(); + set.insert((1, String::from("a"))); + set.insert((3, String::from("b"))); + assert_eq!(BTreeSet::<(i32, String)>::from_wasm(set.to_wasm()), set); + } + #[test] + fn hashmap_convert() { + use std::collections::HashMap; + let mut map = HashMap::new(); + map.insert(String::from("a"), 5_i16); + map.insert(String::from("b"), 55_i16); + assert_eq!(HashMap::::from_wasm(map.to_wasm()), map); + } + #[test] + fn btreemap_convert() { + use std::collections::BTreeMap; + let mut map = BTreeMap::new(); + map.insert((1, 2), String::from("a")); + map.insert((3, 4), String::from("b")); + assert_eq!( + BTreeMap::<(i32, i32), String>::from_wasm(map.to_wasm()), + map + ); + } + #[test] + fn tuple_convert() { + let tup = (String::from("a"), 5_i8); + assert_eq!(<(String, i8)>::from_wasm(tup.to_wasm()), tup); + } +} diff --git a/scylla-udf/src/wasmptr.rs b/scylla-udf/src/wasmptr.rs new file mode 100644 index 0000000..b6b3604 --- /dev/null +++ b/scylla-udf/src/wasmptr.rs @@ -0,0 +1,74 @@ +use crate::abi_exports::{_scylla_free, _scylla_malloc}; + +// A unique pointer to an object in the WASM memory. +// Contains the serialized size of the object in the high 32 bits, and the pointer +// to the object in the low 32 bits. A null pointer is represented by a size of u32::MAX. +// The pointer is allocated with _scylla_malloc and freed with _scylla_free. +#[repr(transparent)] +pub struct WasmPtr(u64); + +impl WasmPtr { + pub fn with_size(size: u32) -> Option { + if size == u32::MAX { + // u32::MAX is reserved for null + return None; + } + + // SAFETY: the size fits in a u32, so it's valid to allocate that much memory + // and we do not dereference the pointer if the allocation fails + let ptr = unsafe { _scylla_malloc(size) }; + if ptr == 0 { + return None; + } + Some(WasmPtr(((size as u64) << 32) + ptr as u64)) + } + + pub const fn size(&self) -> Option { + let size = self.0 >> 32; + if size == u32::MAX as u64 { + None + } else { + Some(size as usize) + } + } + + pub const fn null() -> WasmPtr { + WasmPtr((u32::MAX as u64) << 32) + } + + pub const fn is_null(&self) -> bool { + self.size().is_none() + } + + fn raw(&self) -> *mut u8 { + (self.0 & 0xffffffff) as *mut u8 + } + + fn raw_mut(&self) -> *mut u8 { + (self.0 & 0xffffffff) as *mut u8 + } + + pub fn as_slice<'a>(&self) -> Option<&'a [u8]> { + // SAFETY: the `dest` pointer is a succesful result of allocating `size` bytes and it's always aligned to a u8 + self.size() + .map(|size| unsafe { std::slice::from_raw_parts::<'a>(self.raw(), size) }) + } + + pub fn as_mut_slice<'a>(&mut self) -> Option<&'a mut [u8]> { + if let Some(size) = self.size() { + // SAFETY: the `dest` pointer is a succesful result of allocating `size` bytes and it's always aligned to a u8 + Some(unsafe { std::slice::from_raw_parts_mut::<'a>(self.raw_mut(), size) }) + } else { + None + } + } +} + +impl Drop for WasmPtr { + fn drop(&mut self) { + if !self.is_null() { + // SAFETY: the `dest` pointer is a succesful result of a _scylla_malloc call, so it's valid + unsafe { _scylla_free(self.raw() as u32) }; + } + } +} diff --git a/tests/Cargo.toml b/tests/Cargo.toml new file mode 100644 index 0000000..0343080 --- /dev/null +++ b/tests/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "tests" +edition.workspace = true +version.workspace = true +repository.workspace = true +license.workspace = true +rust-version.workspace = true +publish = false + +[dependencies] +scylla-udf = { workspace = true } +bigdecimal = "0.2.0" +bytes = "1.2.1" +chrono = "0.4" +libc = "0.2.119" +num-bigint = "0.3" +uuid = "1.0" + +[[test]] +name = "hygiene" +path = "hygiene.rs" +crate-type = ["cdylib"] diff --git a/tests/hygiene.rs b/tests/hygiene.rs new file mode 100644 index 0000000..fa96413 --- /dev/null +++ b/tests/hygiene.rs @@ -0,0 +1,30 @@ +#![no_implicit_prelude] + +extern crate scylla_udf as _scylla_udf; + +#[derive(::core::fmt::Debug, ::core::cmp::PartialEq, ::std::marker::Copy, ::std::clone::Clone)] +#[::_scylla_udf::export_udt(crate = "_scylla_udf")] +struct TestStruct { + a: ::core::primitive::i32, +} +#[derive(::core::fmt::Debug, ::core::cmp::PartialEq, ::std::marker::Copy, ::std::clone::Clone)] +#[::_scylla_udf::export_newtype(crate = "_scylla_udf")] +struct TestNewtype(::core::primitive::i32); + +// Macro can only be expanded if TestStruct and TestNewtype were +// properly expanded. +#[::_scylla_udf::export_udf(crate = "_scylla_udf")] +fn test_fn(arg1: TestNewtype, arg2: TestStruct) -> (TestNewtype, TestStruct) { + (arg1, arg2) +} + +#[test] +fn test_renamed() { + use ::_scylla_udf::_macro_internal::WasmConvertible; + let arg1 = TestNewtype(16); + let arg2 = TestStruct { a: 16 }; + let rets = _scylla_internal_test_fn(arg1.to_wasm(), arg2.to_wasm()); + let (ret1, ret2) = <(TestNewtype, TestStruct)>::from_wasm(rets); + ::std::assert_eq!(arg1, ret1); + ::std::assert_eq!(arg2, ret2); +}