diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml
index cc23e99e8cba..19af21ec910b 100644
--- a/.github/workflows/dev.yml
+++ b/.github/workflows/dev.yml
@@ -30,7 +30,7 @@ jobs:
       - name: Checkout
         uses: actions/checkout@v4
       - name: Setup Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: "3.10"
       - name: Audit licenses

diff --git a/.github/workflows/dev_pr.yml b/.github/workflows/dev_pr.yml
index 85aabc188934..77b257743331 100644
--- a/.github/workflows/dev_pr.yml
+++ b/.github/workflows/dev_pr.yml
@@ -46,7 +46,7 @@ jobs:
           github.event_name == 'pull_request_target' &&
             (github.event.action == 'opened' ||
              github.event.action == 'synchronize')
-        uses: actions/labeler@v4.3.0
+        uses: actions/labeler@v5.0.0
        with:
           repo-token: ${{ secrets.GITHUB_TOKEN }}
           configuration-path: .github/workflows/dev_pr/labeler.yml

diff --git a/.github/workflows/dev_pr/labeler.yml b/.github/workflows/dev_pr/labeler.yml
index e84cf5efb1d8..34a37948785b 100644
--- a/.github/workflows/dev_pr/labeler.yml
+++ b/.github/workflows/dev_pr/labeler.yml
@@ -16,35 +16,37 @@
 # under the License.

 development-process:
-  - dev/**.*
-  - .github/**.*
-  - ci/**.*
-  - .asf.yaml
+- changed-files:
+  - any-glob-to-any-file: ['dev/**.*', '.github/**.*', 'ci/**.*', '.asf.yaml']

 documentation:
-  - docs/**.*
-  - README.md
-  - ./**/README.md
-  - DEVELOPERS.md
-  - datafusion/docs/**.*
+- changed-files:
+  - any-glob-to-any-file: ['docs/**.*', 'README.md', './**/README.md', 'DEVELOPERS.md', 'datafusion/docs/**.*']

 sql:
-  - datafusion/sql/**/*
+- changed-files:
+  - any-glob-to-any-file: ['datafusion/sql/**/*']

 logical-expr:
-  - datafusion/expr/**/*
+- changed-files:
+  - any-glob-to-any-file: ['datafusion/expr/**/*']

 physical-expr:
-  - datafusion/physical-expr/**/*
+- changed-files:
+  - any-glob-to-any-file: ['datafusion/physical-expr/**/*']

 optimizer:
-  - datafusion/optimizer/**/*
+- changed-files:
+  - any-glob-to-any-file: ['datafusion/optimizer/**/*']

 core:
-  - datafusion/core/**/*
+- changed-files:
+  - any-glob-to-any-file: ['datafusion/core/**/*']

 substrait:
-  - datafusion/substrait/**/*
+- changed-files:
+  - any-glob-to-any-file: ['datafusion/substrait/**/*']

 sqllogictest:
-  - datafusion/sqllogictest/**/*
+- changed-files:
+  - any-glob-to-any-file: ['datafusion/sqllogictest/**/*']

diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml
index 14b2038e8794..ab6a615ab60b 100644
--- a/.github/workflows/docs.yaml
+++ b/.github/workflows/docs.yaml
@@ -24,7 +24,7 @@ jobs:
           path: asf-site

       - name: Setup Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
        with:
           python-version: "3.10"

diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index 485d179571e3..099aab061435 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -348,7 +348,7 @@ jobs:
      - uses: actions/checkout@v4
        with:
          submodules: true
-     - uses: actions/setup-python@v4
+     - uses: actions/setup-python@v5
        with:
          python-version: "3.8"
      - name: Install PyArrow

diff --git a/Cargo.toml b/Cargo.toml
index 60befdf1cfb7..2bcbe059ab25 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -85,7 +85,7 @@ parquet = { version = "49.0.0", default-features = false, features = ["arrow", "async", "object_store"] }
 rand = "0.8"
 rstest = "0.18.0"
 serde_json = "1"
-sqlparser = { version = "0.39.0", features = ["visitor"] }
+sqlparser = { version = "0.40.0", features = ["visitor"] }
 tempfile = "3"
 thiserror = "1.0.44"
 chrono = { version = "0.4.31", default-features = false }

diff --git a/README.md b/README.md
index f5ee1d6d806f..883700a39355 100644
--- a/README.md
+++ b/README.md
@@ -28,6 +28,7 @@ in-memory format. [Python Bindings](https://github.com/apache/arrow-datafusion-python)
 Here are links to some important information

 - [Project Site](https://arrow.apache.org/datafusion)
+- [Installation](https://arrow.apache.org/datafusion/user-guide/cli.html#installation)
 - [Rust Getting Started](https://arrow.apache.org/datafusion/user-guide/example-usage.html)
 - [Rust DataFrame API](https://arrow.apache.org/datafusion/user-guide/dataframe.html)
 - [Rust API docs](https://docs.rs/datafusion/latest/datafusion)

diff --git a/benchmarks/compare.py b/benchmarks/compare.py
index 80aa3c76b754..ec2b28fa0556 100755
--- a/benchmarks/compare.py
+++ b/benchmarks/compare.py
@@ -109,7 +109,6 @@ def compare(
     noise_threshold: float,
 ) -> None:
     baseline = BenchmarkRun.load_from_file(baseline_path)
-
     comparison = BenchmarkRun.load_from_file(comparison_path)

     console = Console()
@@ -124,27 +123,57 @@ def compare(
     table.add_column(comparison_header, justify="right", style="dim")
     table.add_column("Change", justify="right", style="dim")

+    faster_count = 0
+    slower_count = 0
+    no_change_count = 0
+    total_baseline_time = 0
+    total_comparison_time = 0
+
     for baseline_result, comparison_result in zip(baseline.queries, comparison.queries):
         assert baseline_result.query == comparison_result.query

+        total_baseline_time += baseline_result.execution_time
+        total_comparison_time += comparison_result.execution_time
+
         change = comparison_result.execution_time / baseline_result.execution_time
         if (1.0 - noise_threshold) <= change <= (1.0 + noise_threshold):
-            change = "no change"
+            change_text = "no change"
+            no_change_count += 1
         elif change < 1.0:
-            change = f"+{(1 / change):.2f}x faster"
+            change_text = f"+{(1 / change):.2f}x faster"
+            faster_count += 1
         else:
-            change = f"{change:.2f}x slower"
+            change_text = f"{change:.2f}x slower"
+            slower_count += 1

         table.add_row(
             f"Q{baseline_result.query}",
             f"{baseline_result.execution_time:.2f}ms",
             f"{comparison_result.execution_time:.2f}ms",
-            change,
+            change_text,
         )

     console.print(table)

+    # Calculate averages
+    avg_baseline_time = total_baseline_time / len(baseline.queries)
+    avg_comparison_time = total_comparison_time / len(comparison.queries)
+
+    # Summary table
+    summary_table = Table(show_header=True, header_style="bold magenta")
+    summary_table.add_column("Benchmark Summary", justify="left", style="dim")
+    summary_table.add_column("", justify="right", style="dim")
+
+    summary_table.add_row(f"Total Time ({baseline_header})", f"{total_baseline_time:.2f}ms")
+    summary_table.add_row(f"Total Time ({comparison_header})", f"{total_comparison_time:.2f}ms")
+    summary_table.add_row(f"Average Time ({baseline_header})", f"{avg_baseline_time:.2f}ms")
+    summary_table.add_row(f"Average Time ({comparison_header})", f"{avg_comparison_time:.2f}ms")
+    summary_table.add_row("Queries Faster", str(faster_count))
+    summary_table.add_row("Queries Slower", str(slower_count))
+    summary_table.add_row("Queries with No Change", str(no_change_count))
+
+    console.print(summary_table)

 def main() -> None:
     parser = ArgumentParser()
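Reviewer note: the new counters above all hang off one classification rule — a comparison/baseline runtime ratio within `noise_threshold` of 1.0 is noise, below 1.0 is a speedup reported as its reciprocal, above 1.0 is a slowdown. A standalone sketch of that rule (written in Rust to match the rest of the patch; the timings are hypothetical, and this code is not part of the change):

```rust
/// Mirror of compare.py's classification: ratios within `noise_threshold`
/// of 1.0 are noise; ratios below 1.0 are reported as a reciprocal speedup.
fn classify(change: f64, noise_threshold: f64) -> String {
    if (1.0 - noise_threshold) <= change && change <= (1.0 + noise_threshold) {
        "no change".to_string()
    } else if change < 1.0 {
        format!("+{:.2}x faster", 1.0 / change)
    } else {
        format!("{:.2}x slower", change)
    }
}

fn main() {
    // Hypothetical timings: baseline 100ms vs comparisons of 80ms, 102ms, 130ms
    assert_eq!(classify(80.0 / 100.0, 0.05), "+1.25x faster");
    assert_eq!(classify(102.0 / 100.0, 0.05), "no change");
    assert_eq!(classify(130.0 / 100.0, 0.05), "1.30x slower");
}
```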
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index fa2832ab3fc6..76be04d5ef67 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -178,7 +178,7 @@ dependencies = [
  "chrono",
  "chrono-tz",
  "half",
- "hashbrown 0.14.2",
+ "hashbrown 0.14.3",
  "num",
 ]

@@ -304,7 +304,7 @@ dependencies = [
  "arrow-data",
  "arrow-schema",
  "half",
- "hashbrown 0.14.2",
+ "hashbrown 0.14.3",
 ]

 [[package]]
@@ -360,9 +360,9 @@ dependencies = [
 [[package]]
 name = "async-compression"
-version = "0.4.4"
+version = "0.4.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f658e2baef915ba0f26f1f7c42bfb8e12f532a01f449a090ded75ae7a07e9ba2"
+checksum = "bc2d0cfb2a7388d34f590e76686704c494ed7aaceed62ee1ba35cbf363abc2a5"
 dependencies = [
  "bzip2",
  "flate2",
@@ -820,9 +820,9 @@ checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223"

 [[package]]
 name = "bytes-utils"
-version = "0.1.3"
+version = "0.1.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e47d3a8076e283f3acd27400535992edb3ba4b5bb72f8891ad8fbe7932a7d4b9"
+checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35"
 dependencies = [
  "bytes",
  "either",
@@ -851,10 +851,11 @@ dependencies = [
 [[package]]
 name = "cc"
-version = "1.0.84"
+version = "1.0.83"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0f8e7c90afad890484a21653d08b6e209ae34770fb5ee298f9c699fcc1e5c856"
+checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0"
 dependencies = [
+ "jobserver",
  "libc",
 ]

@@ -874,7 +875,7 @@ dependencies = [
  "iana-time-zone",
  "num-traits",
  "serde",
- "windows-targets",
+ "windows-targets 0.48.5",
 ]

 [[package]]
@@ -988,9 +989,9 @@ checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2"

 [[package]]
 name = "core-foundation"
-version = "0.9.3"
+version = "0.9.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146"
+checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
 dependencies = [
  "core-foundation-sys",
  "libc",
@@ -998,9 +999,9 @@ dependencies = [

 [[package]]
 name = "core-foundation-sys"
-version = "0.8.4"
+version = "0.8.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
+checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f"

 [[package]]
 name = "core2"
@@ -1089,7 +1090,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
 dependencies = [
  "cfg-if",
- "hashbrown 0.14.2",
+ "hashbrown 0.14.3",
  "lock_api",
  "once_cell",
  "parking_lot_core",
@@ -1121,7 +1122,7 @@ dependencies = [
  "futures",
  "glob",
  "half",
- "hashbrown 0.14.2",
+ "hashbrown 0.14.3",
  "indexmap 2.1.0",
  "itertools 0.12.0",
  "log",
@@ -1154,11 +1155,13 @@ dependencies = [
  "clap",
  "ctor",
  "datafusion",
+ "datafusion-common",
  "dirs",
  "env_logger",
  "mimalloc",
  "object_store",
  "parking_lot",
+ "parquet",
  "predicates",
  "regex",
  "rstest",
@@ -1196,7 +1199,7 @@ dependencies = [
  "datafusion-common",
  "datafusion-expr",
  "futures",
- "hashbrown 0.14.2",
+ "hashbrown 0.14.3",
  "log",
  "object_store",
  "parking_lot",
@@ -1229,7 +1232,7 @@ dependencies = [
  "datafusion-common",
  "datafusion-expr",
  "datafusion-physical-expr",
- "hashbrown 0.14.2",
+ "hashbrown 0.14.3",
  "itertools 0.12.0",
  "log",
  "regex-syntax",
@@ -1252,7 +1255,7 @@ dependencies = [
  "datafusion-common",
  "datafusion-expr",
  "half",
- "hashbrown 0.14.2",
+ "hashbrown 0.14.3",
  "hex",
  "indexmap 2.1.0",
  "itertools 0.12.0",
@@ -1284,7 +1287,7 @@ dependencies = [
  "datafusion-physical-expr",
  "futures",
  "half",
- "hashbrown 0.14.2",
+ "hashbrown 0.14.3",
  "indexmap 2.1.0",
  "itertools 0.12.0",
  "log",
@@ -1310,9 +1313,9 @@ dependencies = [
 [[package]]
 name = "deranged"
-version = "0.3.9"
+version = "0.3.10"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0f32d04922c60427da6f9fef14d042d9edddef64cb9d4ce0d64d0685fbeb1fd3"
+checksum = "8eb30d70a07a3b04884d2677f06bec33509dc67ca60d92949e5535352d3191dc"
 dependencies = [
  "powerfmt",
 ]
@@ -1423,12 +1426,12 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"

 [[package]]
 name = "errno"
-version = "0.3.6"
+version = "0.3.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7c18ee0ed65a5f1f81cac6b1d213b69c35fa47d4252ad41f1486dbd8226fe36e"
+checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245"
 dependencies = [
  "libc",
- "windows-sys",
+ "windows-sys 0.52.0",
 ]

 [[package]]
@@ -1464,7 +1467,7 @@ checksum = "ef033ed5e9bad94e55838ca0ca906db0e043f517adda0c8b79c7a8c66c93c1b5"
 dependencies = [
  "cfg-if",
  "rustix",
- "windows-sys",
+ "windows-sys 0.48.0",
 ]

 [[package]]
@@ -1510,9 +1513,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"

 [[package]]
 name = "form_urlencoded"
-version = "1.2.0"
+version = "1.2.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652"
+checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456"
 dependencies = [
  "percent-encoding",
 ]
@@ -1635,9 +1638,9 @@ dependencies = [

 [[package]]
 name = "gimli"
-version = "0.28.0"
+version = "0.28.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0"
+checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"

 [[package]]
 name = "glob"
@@ -1647,9 +1650,9 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"

 [[package]]
 name = "h2"
-version = "0.3.21"
+version = "0.3.22"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "91fc23aa11be92976ef4729127f1a74adf36d8436f7816b185d18df956790833"
+checksum = "4d6250322ef6e60f93f9a2162799302cd6f68f79f6e5d85c8c16f14d1d958178"
 dependencies = [
  "bytes",
  "fnv",
@@ -1657,7 +1660,7 @@ dependencies = [
  "futures-sink",
  "futures-util",
  "http",
- "indexmap 1.9.3",
+ "indexmap 2.1.0",
  "slab",
  "tokio",
  "tokio-util",
@@ -1692,9 +1695,9 @@ dependencies = [

 [[package]]
 name = "hashbrown"
-version = "0.14.2"
+version = "0.14.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f93e7192158dbcda357bdec5fb5788eebf8bbac027f3f33e719d29135ae84156"
+checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604"
 dependencies = [
  "ahash",
  "allocator-api2",
@@ -1738,9 +1741,9 @@ dependencies = [

 [[package]]
 name = "http"
-version = "0.2.10"
+version = "0.2.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f95b9abcae896730d42b78e09c155ed4ddf82c07b4de772c64aee5b2d8b7c150"
+checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb"
 dependencies = [
  "bytes",
  "fnv",
@@ -1824,7 +1827,7 @@ dependencies = [
  "futures-util",
  "http",
  "hyper",
- "rustls 0.21.8",
+ "rustls 0.21.9",
  "tokio",
  "tokio-rustls 0.24.1",
 ]
@@ -1854,9 +1857,9 @@ dependencies = [

 [[package]]
 name = "idna"
-version = "0.4.0"
+version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c"
+checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6"
 dependencies = [
  "unicode-bidi",
  "unicode-normalization",
@@ -1879,7 +1882,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f"
 dependencies = [
  "equivalent",
- "hashbrown 0.14.2",
+ "hashbrown 0.14.3",
 ]

 [[package]]
@@ -1927,11 +1930,20 @@ version = "1.0.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38"

+[[package]]
+name = "jobserver"
+version = "0.1.27"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8c37f63953c4c63420ed5fd3d6d398c719489b9f872b9fa683262f8edd363c7d"
+dependencies = [
+ "libc",
+]
+
 [[package]]
 name = "js-sys"
-version = "0.3.65"
+version = "0.3.66"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "54c0c35952f67de54bb584e9fd912b3023117cbafc0a77d8f3dee1fb5f572fe8"
+checksum = "cee9c64da59eae3b50095c18d3e74f8b73c0b86d2792824ff01bbce68ba229ca"
 dependencies = [
  "wasm-bindgen",
 ]
@@ -2065,9 +2077,9 @@ dependencies = [

 [[package]]
 name = "linux-raw-sys"
-version = "0.4.11"
+version = "0.4.12"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "969488b55f8ac402214f3f5fd243ebb7206cf82de60d3172994707a4bcc2b829"
+checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456"

 [[package]]
 name = "lock_api"
@@ -2147,13 +2159,13 @@ dependencies = [

 [[package]]
 name = "mio"
-version = "0.8.9"
+version = "0.8.10"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3dce281c5e46beae905d4de1870d8b1509a9142b62eedf18b443b011ca8343d0"
+checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09"
 dependencies = [
  "libc",
  "wasi",
- "windows-sys",
+ "windows-sys 0.48.0",
 ]

 [[package]]
@@ -2297,7 +2309,7 @@ dependencies = [
  "quick-xml",
  "rand",
  "reqwest",
- "ring 0.17.5",
+ "ring 0.17.7",
  "rustls-pemfile",
  "serde",
  "serde_json",
@@ -2361,7 +2373,7 @@ dependencies = [
  "libc",
  "redox_syscall",
  "smallvec",
- "windows-targets",
+ "windows-targets 0.48.5",
 ]

 [[package]]
@@ -2384,7 +2396,7 @@ dependencies = [
  "chrono",
  "flate2",
  "futures",
- "hashbrown 0.14.2",
+ "hashbrown 0.14.3",
  "lz4_flex",
  "num",
  "num-bigint",
@@ -2415,9 +2427,9 @@ checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c"

 [[package]]
 name = "percent-encoding"
-version = "2.3.0"
+version = "2.3.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94"
+checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"

 [[package]]
 name = "petgraph"
@@ -2574,9 +2586,9 @@ dependencies = [

 [[package]]
 name = "proc-macro2"
-version = "1.0.69"
+version = "1.0.70"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da"
+checksum = "39278fbbf5fb4f646ce651690877f89d1c5811a3d4acb27700c1cb3cdb78fd3b"
 dependencies = [
  "unicode-ident",
 ]
@@ -2724,7 +2736,7 @@ dependencies = [
  "once_cell",
  "percent-encoding",
  "pin-project-lite",
- "rustls 0.21.8",
+ "rustls 0.21.9",
  "rustls-pemfile",
  "serde",
  "serde_json",
@@ -2760,16 +2772,16 @@ dependencies = [

 [[package]]
 name = "ring"
-version = "0.17.5"
+version = "0.17.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fb0205304757e5d899b9c2e448b867ffd03ae7f988002e47cd24954391394d0b"
+checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74"
 dependencies = [
  "cc",
  "getrandom",
  "libc",
  "spin 0.9.8",
  "untrusted 0.9.0",
- "windows-sys",
+ "windows-sys 0.48.0",
 ]

 [[package]]
@@ -2821,15 +2833,15 @@ dependencies = [

 [[package]]
 name = "rustix"
-version = "0.38.21"
+version = "0.38.26"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2b426b0506e5d50a7d8dafcf2e81471400deb602392c7dd110815afb4eaf02a3"
+checksum = "9470c4bf8246c8daf25f9598dca807fb6510347b1e1cfa55749113850c79d88a"
 dependencies = [
  "bitflags 2.4.1",
  "errno",
  "libc",
  "linux-raw-sys",
- "windows-sys",
+ "windows-sys 0.52.0",
 ]

 [[package]]
@@ -2846,12 +2858,12 @@ dependencies = [

 [[package]]
 name = "rustls"
-version = "0.21.8"
+version = "0.21.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "446e14c5cda4f3f30fe71863c34ec70f5ac79d6087097ad0bb433e1be5edf04c"
+checksum = "629648aced5775d558af50b2b4c7b02983a04b312126d45eeead26e7caa498b9"
 dependencies = [
  "log",
- "ring 0.17.5",
+ "ring 0.17.7",
  "rustls-webpki",
  "sct",
 ]
@@ -2883,7 +2895,7 @@ version = "0.101.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765"
 dependencies = [
- "ring 0.17.5",
+ "ring 0.17.7",
  "untrusted 0.9.0",
 ]

@@ -2937,7 +2949,7 @@ version = "0.1.22"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88"
 dependencies = [
- "windows-sys",
+ "windows-sys 0.48.0",
 ]

 [[package]]
@@ -2952,7 +2964,7 @@ version = "0.7.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414"
 dependencies = [
- "ring 0.17.5",
+ "ring 0.17.7",
  "untrusted 0.9.0",
 ]

@@ -2993,18 +3005,18 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4"

 [[package]]
 name = "serde"
-version = "1.0.192"
+version = "1.0.193"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bca2a08484b285dcb282d0f67b26cadc0df8b19f8c12502c13d966bf9482f001"
+checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89"
 dependencies = [
  "serde_derive",
 ]

 [[package]]
 name = "serde_derive"
-version = "1.0.192"
+version = "1.0.193"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d6c7207fbec9faa48073f3e3074cbe553af6ea512d7c21ba46e434e70ea9fbc1"
+checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -3090,9 +3102,9 @@ dependencies = [

 [[package]]
 name = "snap"
-version = "1.1.0"
+version = "1.1.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5e9f0ab6ef7eb7353d9119c170a436d1bf248eea575ac42d19d12f4e34130831"
+checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b"

 [[package]]
 name = "socket2"
@@ -3111,7 +3123,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9"
 dependencies = [
  "libc",
- "windows-sys",
+ "windows-sys 0.48.0",
 ]

 [[package]]
@@ -3128,9 +3140,9 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"

 [[package]]
 name = "sqlparser"
-version = "0.39.0"
+version = "0.40.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "743b4dc2cbde11890ccb254a8fc9d537fa41b36da00de2a1c5e9848c9bc42bd7"
+checksum = "7c80afe31cdb649e56c0d9bb5503be9166600d68a852c38dd445636d126858e5"
 dependencies = [
  "log",
  "sqlparser_derive",
 ]

@@ -3138,9 +3150,9 @@ dependencies = [
 [[package]]
 name = "sqlparser_derive"
-version = "0.1.1"
+version = "0.2.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "55fe75cb4a364c7f7ae06c7dbbc8d84bddd85d6cdf9975963c3935bc1991761e"
+checksum = "3e9c2e1dde0efa87003e7923d94a90f46e3274ad1649f51de96812be561f041f"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -3246,14 +3258,14 @@ dependencies = [
  "fastrand 2.0.1",
  "redox_syscall",
  "rustix",
- "windows-sys",
+ "windows-sys 0.48.0",
 ]

 [[package]]
 name = "termcolor"
-version = "1.3.0"
+version = "1.4.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6093bad37da69aab9d123a8091e4be0aa4a03e4d601ec641c327398315f62b64"
+checksum = "ff1bc3d3f05aff0403e8ac0d92ced918ec05b666a43f83297ccef5bea8a3d449"
 dependencies = [
  "winapi-util",
 ]
@@ -3368,7 +3380,7 @@ dependencies = [
  "pin-project-lite",
  "socket2 0.5.5",
  "tokio-macros",
- "windows-sys",
+ "windows-sys 0.48.0",
 ]

 [[package]]
@@ -3399,7 +3411,7 @@ version = "0.24.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081"
 dependencies = [
- "rustls 0.21.8",
+ "rustls 0.21.9",
  "tokio",
 ]

@@ -3577,9 +3589,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"

 [[package]]
 name = "url"
-version = "2.4.1"
+version = "2.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "143b538f18257fac9cad154828a57c6bf5157e1aa604d4816b5995bf6de87ae5"
+checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633"
 dependencies = [
  "form_urlencoded",
  "idna",
@@ -3600,9 +3612,9 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"

 [[package]]
 name = "uuid"
-version = "1.5.0"
+version = "1.6.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "88ad59a7560b41a70d191093a945f0b87bc1deeda46fb237479708a1d6b6cdfc"
+checksum = "5e395fcf16a7a3d8127ec99782007af141946b4795001f876d54fb0d55978560"
 dependencies = [
  "getrandom",
  "serde",
 ]
@@ -3656,9 +3668,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"

 [[package]]
 name = "wasm-bindgen"
-version = "0.2.88"
+version = "0.2.89"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7daec296f25a1bae309c0cd5c29c4b260e510e6d813c286b19eaadf409d40fce"
+checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e"
 dependencies = [
  "cfg-if",
  "wasm-bindgen-macro",
 ]

@@ -3666,9 +3678,9 @@ dependencies = [
 [[package]]
 name = "wasm-bindgen-backend"
-version = "0.2.88"
+version = "0.2.89"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e397f4664c0e4e428e8313a469aaa58310d302159845980fd23b0f22a847f217"
+checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826"
 dependencies = [
  "bumpalo",
  "log",
@@ -3681,9 +3693,9 @@ dependencies = [

 [[package]]
 name = "wasm-bindgen-futures"
-version = "0.4.38"
+version = "0.4.39"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9afec9963e3d0994cac82455b2b3502b81a7f40f9a0d32181f7528d9f4b43e02"
+checksum = "ac36a15a220124ac510204aec1c3e5db8a22ab06fd6706d881dc6149f8ed9a12"
 dependencies = [
  "cfg-if",
  "js-sys",
@@ -3693,9 +3705,9 @@ dependencies = [

 [[package]]
 name = "wasm-bindgen-macro"
-version = "0.2.88"
+version = "0.2.89"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5961017b3b08ad5f3fe39f1e79877f8ee7c23c5e5fd5eb80de95abc41f1f16b2"
+checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2"
 dependencies = [
  "quote",
  "wasm-bindgen-macro-support",
 ]

@@ -3703,9 +3715,9 @@ dependencies = [
 [[package]]
 name = "wasm-bindgen-macro-support"
-version = "0.2.88"
+version = "0.2.89"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c5353b8dab669f5e10f5bd76df26a9360c748f054f862ff5f3f8aae0c7fb3907"
+checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -3716,9 +3728,9 @@ dependencies = [

 [[package]]
 name = "wasm-bindgen-shared"
-version = "0.2.88"
+version = "0.2.89"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0d046c5d029ba91a1ed14da14dca44b68bf2f124cfbaf741c54151fdb3e0750b"
+checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f"

 [[package]]
 name = "wasm-streams"
@@ -3735,9 +3747,9 @@ dependencies = [

 [[package]]
 name = "web-sys"
-version = "0.3.65"
+version = "0.3.66"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5db499c5f66323272151db0e666cd34f78617522fb0c1604d31a27c50c206a85"
+checksum = "50c24a44ec86bb68fbecd1b3efed7e85ea5621b39b35ef2766b66cd984f8010f"
 dependencies = [
  "js-sys",
  "wasm-bindgen",
 ]
@@ -3749,15 +3761,15 @@ version = "0.22.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53"
 dependencies = [
- "ring 0.17.5",
+ "ring 0.17.7",
  "untrusted 0.9.0",
 ]

 [[package]]
 name = "webpki-roots"
-version = "0.25.2"
+version = "0.25.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "14247bb57be4f377dfb94c72830b8ce8fc6beac03cf4bf7b9732eadd414123fc"
+checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10"

 [[package]]
 name = "winapi"
@@ -3796,7 +3808,7 @@ version = "0.51.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f1f8cf84f35d2db49a46868f947758c7a1138116f7fac3bc844f43ade1292e64"
 dependencies = [
- "windows-targets",
+ "windows-targets 0.48.5",
 ]

 [[package]]
@@ -3805,7 +3817,16 @@ version = "0.48.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9"
 dependencies = [
- "windows-targets",
+ "windows-targets 0.48.5",
+]
+
+[[package]]
+name = "windows-sys"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
+dependencies = [
+ "windows-targets 0.52.0",
 ]

 [[package]]
@@ -3814,13 +3835,28 @@ version = "0.48.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c"
 dependencies = [
- "windows_aarch64_gnullvm",
- "windows_aarch64_msvc",
- "windows_i686_gnu",
- "windows_i686_msvc",
- "windows_x86_64_gnu",
- "windows_x86_64_gnullvm",
- "windows_x86_64_msvc",
+ "windows_aarch64_gnullvm 0.48.5",
+ "windows_aarch64_msvc 0.48.5",
+ "windows_i686_gnu 0.48.5",
+ "windows_i686_msvc 0.48.5",
+ "windows_x86_64_gnu 0.48.5",
+ "windows_x86_64_gnullvm 0.48.5",
+ "windows_x86_64_msvc 0.48.5",
+]
+
+[[package]]
+name = "windows-targets"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd"
+dependencies = [
+ "windows_aarch64_gnullvm 0.52.0",
+ "windows_aarch64_msvc 0.52.0",
+ "windows_i686_gnu 0.52.0",
+ "windows_i686_msvc 0.52.0",
+ "windows_x86_64_gnu 0.52.0",
+ "windows_x86_64_gnullvm 0.52.0",
+ "windows_x86_64_msvc 0.52.0",
 ]

 [[package]]
@@ -3829,42 +3865,84 @@ version = "0.48.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"

+[[package]]
+name = "windows_aarch64_gnullvm"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea"
+
 [[package]]
 name = "windows_aarch64_msvc"
 version = "0.48.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"

+[[package]]
+name = "windows_aarch64_msvc"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef"
+
 [[package]]
 name = "windows_i686_gnu"
 version = "0.48.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"

+[[package]]
+name = "windows_i686_gnu"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313"
+
 [[package]]
 name = "windows_i686_msvc"
 version = "0.48.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"

+[[package]]
+name = "windows_i686_msvc"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a"
+
 [[package]]
 name = "windows_x86_64_gnu"
 version = "0.48.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"

+[[package]]
+name = "windows_x86_64_gnu"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd"
+
 [[package]]
 name = "windows_x86_64_gnullvm"
 version = "0.48.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"

+[[package]]
+name = "windows_x86_64_gnullvm"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e"
+
 [[package]]
 name = "windows_x86_64_msvc"
 version = "0.48.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"

+[[package]]
+name = "windows_x86_64_msvc"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04"
+
 [[package]]
 name = "winreg"
 version = "0.50.0"
@@ -3872,7 +3950,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1"
 dependencies = [
  "cfg-if",
- "windows-sys",
+ "windows-sys 0.48.0",
 ]

 [[package]]
@@ -3892,18 +3970,18 @@ dependencies = [

 [[package]]
 name = "zerocopy"
-version = "0.7.25"
+version = "0.7.29"
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cd369a67c0edfef15010f980c3cbe45d7f651deac2cd67ce097cd801de16557" +checksum = "5d075cf85bbb114e933343e087b92f2146bac0d55b534cbb8188becf0039948e" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.25" +version = "0.7.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2f140bda219a26ccc0cdb03dba58af72590c53b22642577d88a927bc5c87d6b" +checksum = "86cd5ca076997b97ef09d3ad65efe811fa68c9e874cb636ccb211223a813b0c2" dependencies = [ "proc-macro2", "quote", @@ -3912,9 +3990,9 @@ dependencies = [ [[package]] name = "zeroize" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" [[package]] name = "zstd" diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index dd7a077988cb..5ce318aea3ac 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -40,6 +40,7 @@ env_logger = "0.9" mimalloc = { version = "0.1", default-features = false } object_store = { version = "0.8.0", features = ["aws", "gcp"] } parking_lot = { version = "0.12" } +parquet = { version = "49.0.0", default-features = false } regex = "1.8" rustyline = "11.0" tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } @@ -48,5 +49,6 @@ url = "2.2" [dev-dependencies] assert_cmd = "2.0" ctor = "0.2.0" +datafusion-common = { path = "../datafusion/common" } predicates = "3.0" rstest = "0.17" diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 1869e15ef584..8af534cd1375 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -211,7 +211,7 @@ async fn exec_and_print( })?; let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?; for statement in statements { - let plan = ctx.state().statement_to_plan(statement).await?; + let mut plan = ctx.state().statement_to_plan(statement).await?; // For plans like `Explain` ignore `MaxRows` option and always display all rows let should_ignore_maxrows = matches!( @@ -221,14 +221,13 @@ async fn exec_and_print( | LogicalPlan::Analyze(_) ); - let df = match &plan { - LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) => { - create_external_table(ctx, cmd).await?; - ctx.execute_logical_plan(plan).await? - } - _ => ctx.execute_logical_plan(plan).await?, - }; - + // Note that cmd is a mutable reference so that create_external_table function can remove all + // datafusion-cli specific options before passing through to datafusion. Otherwise, datafusion + // will raise Configuration errors. 
+ if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { + create_external_table(ctx, cmd).await?; + } + let df = ctx.execute_logical_plan(plan).await?; let results = df.collect().await?; let print_options = if should_ignore_maxrows { @@ -247,7 +246,7 @@ async fn exec_and_print( async fn create_external_table( ctx: &SessionContext, - cmd: &CreateExternalTable, + cmd: &mut CreateExternalTable, ) -> Result<()> { let table_path = ListingTableUrl::parse(&cmd.location)?; let scheme = table_path.scheme(); @@ -288,15 +287,32 @@ async fn create_external_table( #[cfg(test)] mod tests { + use std::str::FromStr; + use super::*; use datafusion::common::plan_err; + use datafusion_common::{file_options::StatementOptions, FileTypeWriterOptions}; async fn create_external_table_test(location: &str, sql: &str) -> Result<()> { let ctx = SessionContext::new(); - let plan = ctx.state().create_logical_plan(sql).await?; + let mut plan = ctx.state().create_logical_plan(sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { create_external_table(&ctx, cmd).await?; + let options: Vec<_> = cmd + .options + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + let statement_options = StatementOptions::new(options); + let file_type = + datafusion_common::FileType::from_str(cmd.file_type.as_str())?; + + let _file_type_writer_options = FileTypeWriterOptions::build( + &file_type, + ctx.state().config_options(), + &statement_options, + )?; } else { return plan_err!("LogicalPlan is not a CreateExternalTable"); } diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index eeebe713d716..24f3399ee2be 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -16,12 +16,26 @@ // under the License. //! 
Functions that are query-able and searchable via the `\h` command -use arrow::array::StringArray; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::array::{Int64Array, StringArray}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; +use async_trait::async_trait; +use datafusion::common::DataFusionError; +use datafusion::common::{plan_err, Column}; +use datafusion::datasource::function::TableFunctionImpl; +use datafusion::datasource::TableProvider; use datafusion::error::Result; +use datafusion::execution::context::SessionState; +use datafusion::logical_expr::Expr; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::scalar::ScalarValue; +use parquet::file::reader::FileReader; +use parquet::file::serialized_reader::SerializedFileReader; +use parquet::file::statistics::Statistics; use std::fmt; +use std::fs::File; use std::str::FromStr; use std::sync::Arc; @@ -196,3 +210,208 @@ pub fn display_all_functions() -> Result<()> { println!("{}", pretty_format_batches(&[batch]).unwrap()); Ok(()) } + +/// PARQUET_META table function +struct ParquetMetadataTable { + schema: SchemaRef, + batch: RecordBatch, +} + +#[async_trait] +impl TableProvider for ParquetMetadataTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> arrow::datatypes::SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> datafusion::logical_expr::TableType { + datafusion::logical_expr::TableType::Base + } + + async fn scan( + &self, + _state: &SessionState, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(Arc::new(MemoryExec::try_new( + &[vec![self.batch.clone()]], + TableProvider::schema(self), + projection.cloned(), + )?)) + } +} + +pub struct ParquetMetadataFunc {} + +impl TableFunctionImpl for ParquetMetadataFunc { + fn call(&self, exprs: &[Expr]) -> Result> { + let filename = match exprs.get(0) { + Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) => s, // single quote: parquet_metadata('x.parquet') + Some(Expr::Column(Column { name, .. 
})) => name, // double quote: parquet_metadata("x.parquet") + _ => { + return plan_err!( + "parquet_metadata requires string argument as its input" + ); + } + }; + + let file = File::open(filename.clone())?; + let reader = SerializedFileReader::new(file)?; + let metadata = reader.metadata(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("filename", DataType::Utf8, true), + Field::new("row_group_id", DataType::Int64, true), + Field::new("row_group_num_rows", DataType::Int64, true), + Field::new("row_group_num_columns", DataType::Int64, true), + Field::new("row_group_bytes", DataType::Int64, true), + Field::new("column_id", DataType::Int64, true), + Field::new("file_offset", DataType::Int64, true), + Field::new("num_values", DataType::Int64, true), + Field::new("path_in_schema", DataType::Utf8, true), + Field::new("type", DataType::Utf8, true), + Field::new("stats_min", DataType::Utf8, true), + Field::new("stats_max", DataType::Utf8, true), + Field::new("stats_null_count", DataType::Int64, true), + Field::new("stats_distinct_count", DataType::Int64, true), + Field::new("stats_min_value", DataType::Utf8, true), + Field::new("stats_max_value", DataType::Utf8, true), + Field::new("compression", DataType::Utf8, true), + Field::new("encodings", DataType::Utf8, true), + Field::new("index_page_offset", DataType::Int64, true), + Field::new("dictionary_page_offset", DataType::Int64, true), + Field::new("data_page_offset", DataType::Int64, true), + Field::new("total_compressed_size", DataType::Int64, true), + Field::new("total_uncompressed_size", DataType::Int64, true), + ])); + + // construct recordbatch from metadata + let mut filename_arr = vec![]; + let mut row_group_id_arr = vec![]; + let mut row_group_num_rows_arr = vec![]; + let mut row_group_num_columns_arr = vec![]; + let mut row_group_bytes_arr = vec![]; + let mut column_id_arr = vec![]; + let mut file_offset_arr = vec![]; + let mut num_values_arr = vec![]; + let mut path_in_schema_arr = vec![]; + let mut type_arr = vec![]; + let mut stats_min_arr = vec![]; + let mut stats_max_arr = vec![]; + let mut stats_null_count_arr = vec![]; + let mut stats_distinct_count_arr = vec![]; + let mut stats_min_value_arr = vec![]; + let mut stats_max_value_arr = vec![]; + let mut compression_arr = vec![]; + let mut encodings_arr = vec![]; + let mut index_page_offset_arr = vec![]; + let mut dictionary_page_offset_arr = vec![]; + let mut data_page_offset_arr = vec![]; + let mut total_compressed_size_arr = vec![]; + let mut total_uncompressed_size_arr = vec![]; + for (rg_idx, row_group) in metadata.row_groups().iter().enumerate() { + for (col_idx, column) in row_group.columns().iter().enumerate() { + filename_arr.push(filename.clone()); + row_group_id_arr.push(rg_idx as i64); + row_group_num_rows_arr.push(row_group.num_rows()); + row_group_num_columns_arr.push(row_group.num_columns() as i64); + row_group_bytes_arr.push(row_group.total_byte_size()); + column_id_arr.push(col_idx as i64); + file_offset_arr.push(column.file_offset()); + num_values_arr.push(column.num_values()); + path_in_schema_arr.push(column.column_path().to_string()); + type_arr.push(column.column_type().to_string()); + if let Some(s) = column.statistics() { + let (min_val, max_val) = if s.has_min_max_set() { + let (min_val, max_val) = match s { + Statistics::Boolean(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::Int32(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::Int64(val) => { + (val.min().to_string(), 
val.max().to_string()) + } + Statistics::Int96(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::Float(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::Double(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::ByteArray(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::FixedLenByteArray(val) => { + (val.min().to_string(), val.max().to_string()) + } + }; + (Some(min_val), Some(max_val)) + } else { + (None, None) + }; + stats_min_arr.push(min_val.clone()); + stats_max_arr.push(max_val.clone()); + stats_null_count_arr.push(Some(s.null_count() as i64)); + stats_distinct_count_arr.push(s.distinct_count().map(|c| c as i64)); + stats_min_value_arr.push(min_val); + stats_max_value_arr.push(max_val); + } else { + stats_min_arr.push(None); + stats_max_arr.push(None); + stats_null_count_arr.push(None); + stats_distinct_count_arr.push(None); + stats_min_value_arr.push(None); + stats_max_value_arr.push(None); + }; + compression_arr.push(format!("{:?}", column.compression())); + encodings_arr.push(format!("{:?}", column.encodings())); + index_page_offset_arr.push(column.index_page_offset()); + dictionary_page_offset_arr.push(column.dictionary_page_offset()); + data_page_offset_arr.push(column.data_page_offset()); + total_compressed_size_arr.push(column.compressed_size()); + total_uncompressed_size_arr.push(column.uncompressed_size()); + } + } + + let rb = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(filename_arr)), + Arc::new(Int64Array::from(row_group_id_arr)), + Arc::new(Int64Array::from(row_group_num_rows_arr)), + Arc::new(Int64Array::from(row_group_num_columns_arr)), + Arc::new(Int64Array::from(row_group_bytes_arr)), + Arc::new(Int64Array::from(column_id_arr)), + Arc::new(Int64Array::from(file_offset_arr)), + Arc::new(Int64Array::from(num_values_arr)), + Arc::new(StringArray::from(path_in_schema_arr)), + Arc::new(StringArray::from(type_arr)), + Arc::new(StringArray::from(stats_min_arr)), + Arc::new(StringArray::from(stats_max_arr)), + Arc::new(Int64Array::from(stats_null_count_arr)), + Arc::new(Int64Array::from(stats_distinct_count_arr)), + Arc::new(StringArray::from(stats_min_value_arr)), + Arc::new(StringArray::from(stats_max_value_arr)), + Arc::new(StringArray::from(compression_arr)), + Arc::new(StringArray::from(encodings_arr)), + Arc::new(Int64Array::from(index_page_offset_arr)), + Arc::new(Int64Array::from(dictionary_page_offset_arr)), + Arc::new(Int64Array::from(data_page_offset_arr)), + Arc::new(Int64Array::from(total_compressed_size_arr)), + Arc::new(Int64Array::from(total_uncompressed_size_arr)), + ], + )?; + + let parquet_metadata = ParquetMetadataTable { schema, batch: rb }; + Ok(Arc::new(parquet_metadata)) + } +} diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index c069f458f196..8b1a9816afc0 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -22,6 +22,7 @@ use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool}; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::prelude::SessionContext; use datafusion_cli::catalog::DynamicFileCatalog; +use datafusion_cli::functions::ParquetMetadataFunc; use datafusion_cli::{ exec, print_format::PrintFormat, @@ -185,6 +186,8 @@ pub async fn main() -> Result<()> { ctx.state().catalog_list(), ctx.state_weak_ref(), ))); + // register `parquet_metadata` table function to get metadata from parquet files + 
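Reviewer note: `ParquetMetadataFunc` above is essentially an Arrow-shaped view over the parquet crate's footer-metadata API. Stripped of the array building, the underlying read is just this (standalone sketch using the same `parquet` crate calls imported above; `data.parquet` is a placeholder path, and error handling is simplified):

```rust
use parquet::file::reader::FileReader;
use parquet::file::serialized_reader::SerializedFileReader;
use std::fs::File;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Open any parquet file; "data.parquet" is a placeholder.
    let file = File::open("data.parquet")?;
    let reader = SerializedFileReader::new(file)?;
    let metadata = reader.metadata();

    // The same per-row-group, per-column walk the UDTF flattens into arrays.
    for (rg_idx, row_group) in metadata.row_groups().iter().enumerate() {
        println!("row group {rg_idx}: {} rows", row_group.num_rows());
        for column in row_group.columns() {
            println!(
                "  {} ({}): {} values, {} bytes compressed",
                column.column_path(),
                column.column_type(),
                column.num_values(),
                column.compressed_size()
            );
        }
    }
    Ok(())
}
```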
ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); let mut print_options = PrintOptions { format: args.format, @@ -328,6 +331,8 @@ fn extract_memory_pool_size(size: &str) -> Result { #[cfg(test)] mod tests { + use datafusion::assert_batches_eq; + use super::*; fn assert_conversion(input: &str, expected: Result) { @@ -385,4 +390,34 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_parquet_metadata_works() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); + + // input with single quote + let sql = + "SELECT * FROM parquet_metadata('../datafusion/core/tests/data/fixed_size_list_array.parquet')"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + + let excepted = [ + "+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", + "| filename | row_group_id | row_group_num_rows | row_group_num_columns | row_group_bytes | column_id | file_offset | num_values | path_in_schema | type | stats_min | stats_max | stats_null_count | stats_distinct_count | stats_min_value | stats_max_value | compression | encodings | index_page_offset | dictionary_page_offset | data_page_offset | total_compressed_size | total_uncompressed_size |", + "+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", + "| ../datafusion/core/tests/data/fixed_size_list_array.parquet | 0 | 2 | 1 | 123 | 0 | 125 | 4 | \"f0.list.item\" | INT64 | 1 | 4 | 0 | | 1 | 4 | SNAPPY | [RLE_DICTIONARY, PLAIN, RLE] | | 4 | 46 | 121 | 123 |", + "+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", + ]; + assert_batches_eq!(excepted, &rbs); + + // input with double quote + let sql = + "SELECT * FROM parquet_metadata(\"../datafusion/core/tests/data/fixed_size_list_array.parquet\")"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + assert_batches_eq!(excepted, &rbs); + + Ok(()) + } } diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index c39d1915eb43..9d79c7e0ec78 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -30,20 +30,23 @@ use url::Url; pub async fn get_s3_object_store_builder( url: &Url, - cmd: &CreateExternalTable, + cmd: &mut CreateExternalTable, ) -> Result { let bucket_name = 
diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs
index c39d1915eb43..9d79c7e0ec78 100644
--- a/datafusion-cli/src/object_storage.rs
+++ b/datafusion-cli/src/object_storage.rs
@@ -30,20 +30,22 @@ use url::Url;

 pub async fn get_s3_object_store_builder(
     url: &Url,
-    cmd: &CreateExternalTable,
+    cmd: &mut CreateExternalTable,
 ) -> Result<AmazonS3Builder> {
     let bucket_name = get_bucket_name(url)?;
     let mut builder = AmazonS3Builder::from_env().with_bucket_name(bucket_name);

     if let (Some(access_key_id), Some(secret_access_key)) = (
-        cmd.options.get("access_key_id"),
-        cmd.options.get("secret_access_key"),
+        // These options are datafusion-cli specific and must be removed before passing through to datafusion.
+        // Otherwise, a Configuration error will be raised.
+        cmd.options.remove("access_key_id"),
+        cmd.options.remove("secret_access_key"),
     ) {
         builder = builder
             .with_access_key_id(access_key_id)
             .with_secret_access_key(secret_access_key);

-        if let Some(session_token) = cmd.options.get("session_token") {
+        if let Some(session_token) = cmd.options.remove("session_token") {
             builder = builder.with_token(session_token);
         }
     } else {
@@ -66,7 +69,7 @@ pub async fn get_s3_object_store_builder(
         builder = builder.with_credentials(credentials);
     }

-    if let Some(region) = cmd.options.get("region") {
+    if let Some(region) = cmd.options.remove("region") {
         builder = builder.with_region(region);
     }

@@ -99,7 +102,7 @@ impl CredentialProvider for S3CredentialProvider {

 pub fn get_oss_object_store_builder(
     url: &Url,
-    cmd: &CreateExternalTable,
+    cmd: &mut CreateExternalTable,
 ) -> Result<AmazonS3Builder> {
     let bucket_name = get_bucket_name(url)?;
     let mut builder = AmazonS3Builder::from_env()
@@ -109,15 +112,15 @@ pub fn get_oss_object_store_builder(
         .with_region("do_not_care");

     if let (Some(access_key_id), Some(secret_access_key)) = (
-        cmd.options.get("access_key_id"),
-        cmd.options.get("secret_access_key"),
+        cmd.options.remove("access_key_id"),
+        cmd.options.remove("secret_access_key"),
     ) {
         builder = builder
             .with_access_key_id(access_key_id)
             .with_secret_access_key(secret_access_key);
     }

-    if let Some(endpoint) = cmd.options.get("endpoint") {
+    if let Some(endpoint) = cmd.options.remove("endpoint") {
         builder = builder.with_endpoint(endpoint);
     }

@@ -126,21 +129,21 @@ pub fn get_oss_object_store_builder(

 pub fn get_gcs_object_store_builder(
     url: &Url,
-    cmd: &CreateExternalTable,
+    cmd: &mut CreateExternalTable,
 ) -> Result<GoogleCloudStorageBuilder> {
     let bucket_name = get_bucket_name(url)?;
     let mut builder = GoogleCloudStorageBuilder::from_env().with_bucket_name(bucket_name);

-    if let Some(service_account_path) = cmd.options.get("service_account_path") {
+    if let Some(service_account_path) = cmd.options.remove("service_account_path") {
         builder = builder.with_service_account_path(service_account_path);
     }

-    if let Some(service_account_key) = cmd.options.get("service_account_key") {
+    if let Some(service_account_key) = cmd.options.remove("service_account_key") {
         builder = builder.with_service_account_key(service_account_key);
     }

     if let Some(application_credentials_path) =
-        cmd.options.get("application_credentials_path")
+        cmd.options.remove("application_credentials_path")
     {
         builder = builder.with_application_credentials(application_credentials_path);
     }
@@ -180,9 +183,9 @@ mod tests {
         let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('access_key_id' '{access_key_id}', 'secret_access_key' '{secret_access_key}', 'region' '{region}', 'session_token' {session_token}) LOCATION '{location}'");

         let ctx = SessionContext::new();
-        let plan = ctx.state().create_logical_plan(&sql).await?;
+        let mut plan = ctx.state().create_logical_plan(&sql).await?;

-        if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan {
+        if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan {
             let builder = get_s3_object_store_builder(table_url.as_ref(), cmd).await?;
             // get the actual configuration information, then assert_eq!
             let config = [
@@ -212,9 +215,9 @@ mod tests {
         let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('access_key_id' '{access_key_id}', 'secret_access_key' '{secret_access_key}', 'endpoint' '{endpoint}') LOCATION '{location}'");

         let ctx = SessionContext::new();
-        let plan = ctx.state().create_logical_plan(&sql).await?;
+        let mut plan = ctx.state().create_logical_plan(&sql).await?;

-        if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan {
+        if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan {
             let builder = get_oss_object_store_builder(table_url.as_ref(), cmd)?;
             // get the actual configuration information, then assert_eq!
             let config = [
@@ -244,9 +247,9 @@ mod tests {
         let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('service_account_path' '{service_account_path}', 'service_account_key' '{service_account_key}', 'application_credentials_path' '{application_credentials_path}') LOCATION '{location}'");

         let ctx = SessionContext::new();
-        let plan = ctx.state().create_logical_plan(&sql).await?;
+        let mut plan = ctx.state().create_logical_plan(&sql).await?;

-        if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan {
+        if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan {
             let builder = get_gcs_object_store_builder(table_url.as_ref(), cmd)?;
             // get the actual configuration information, then assert_eq!
             let config = [
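Reviewer note: the `get(...)` to `remove(...)` switch in this file is the behavioral heart of the change — credential options are consumed by the CLI so that the options map later handed to DataFusion no longer contains keys the engine would reject. A minimal sketch of the pattern (the map contents here are hypothetical, not DataFusion's validated option set):

```rust
use std::collections::HashMap;

// Options map shaped like `CreateExternalTable::options`; keys/values are made up.
fn main() {
    let mut options: HashMap<String, String> = HashMap::from([
        ("access_key_id".to_string(), "AKIA-example".to_string()),
        ("compression".to_string(), "zstd".to_string()),
    ]);

    // `remove` both reads the value and deletes the entry, unlike `get`,
    // so CLI-only keys never reach the engine's option validation.
    let access_key_id = options.remove("access_key_id");
    assert_eq!(access_key_id.as_deref(), Some("AKIA-example"));

    // Only options meant for the engine remain to be passed through.
    assert_eq!(options.len(), 1);
    assert!(options.contains_key("compression"));
}
```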
governing permissions and limitations // under the License. +use arrow::array::{BooleanArray, Int32Array}; +use arrow::record_batch::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::error::Result; use datafusion::optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext}; use datafusion::physical_expr::execution_props::ExecutionProps; +use datafusion::physical_expr::{ + analyze, create_physical_expr, AnalysisContext, ExprBoundaries, PhysicalExpr, +}; use datafusion::prelude::*; use datafusion_common::{ScalarValue, ToDFSchema}; use datafusion_expr::expr::BinaryExpr; -use datafusion_expr::Operator; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::{ColumnarValue, ExprSchemable, Operator}; +use std::sync::Arc; /// This example demonstrates the DataFusion [`Expr`] API. /// /// DataFusion comes with a powerful and extensive system for /// representing and manipulating expressions such as `A + 5` and `X -/// IN ('foo', 'bar', 'baz')` and many other constructs. +/// IN ('foo', 'bar', 'baz')`. +/// +/// In addition to building and manipulating [`Expr`]s, DataFusion +/// also comes with APIs for evaluation, simplification, and analysis. +/// +/// The code in this example shows how to: +/// 1. Create [`Exprs`] using different APIs: [`main`]` +/// 2. Evaluate [`Exprs`] against data: [`evaluate_demo`] +/// 3. Simplify expressions: [`simplify_demo`] +/// 4. Analyze predicates for boundary ranges: [`range_analysis_demo`] #[tokio::main] async fn main() -> Result<()> { // The easiest way to do create expressions is to use the - // "fluent"-style API, like this: + // "fluent"-style API: let expr = col("a") + lit(5); - // this creates the same expression as the following though with - // much less code, + // The same same expression can be created directly, with much more code: let expr2 = Expr::BinaryExpr(BinaryExpr::new( Box::new(col("a")), Operator::Plus, @@ -44,15 +59,51 @@ async fn main() -> Result<()> { )); assert_eq!(expr, expr2); + // See how to evaluate expressions + evaluate_demo()?; + + // See how to simplify expressions simplify_demo()?; + // See how to analyze ranges in expressions + range_analysis_demo()?; + + Ok(()) +} + +/// DataFusion can also evaluate arbitrary expressions on Arrow arrays. +fn evaluate_demo() -> Result<()> { + // For example, let's say you have some integers in an array + let batch = RecordBatch::try_from_iter([( + "a", + Arc::new(Int32Array::from(vec![4, 5, 6, 7, 8, 7, 4])) as _, + )])?; + + // If you want to find all rows where the expression `a < 5 OR a = 8` is true + let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8))); + + // First, you make a "physical expression" from the logical `Expr` + let physical_expr = physical_expr(&batch.schema(), expr)?; + + // Now, you can evaluate the expression against the RecordBatch + let result = physical_expr.evaluate(&batch)?; + + // The result contain an array that is true only for where `a < 5 OR a = 8` + let expected_result = Arc::new(BooleanArray::from(vec![ + true, false, false, false, true, false, true, + ])) as _; + assert!( + matches!(&result, ColumnarValue::Array(r) if r == &expected_result), + "result: {:?}", + result + ); + Ok(()) } -/// In addition to easy construction, DataFusion exposes APIs for -/// working with and simplifying such expressions that call into the -/// same powerful and extensive implementation used for the query -/// engine. 
+/// In addition to easy construction, DataFusion exposes APIs for simplifying +/// such expressions so they are more efficient to evaluate. This code is also +/// used by the query engine to optimize queries. fn simplify_demo() -> Result<()> { // For example, let's say you have created an expression such as // ts = to_timestamp("2020-09-08T12:00:00+00:00") @@ -94,7 +145,7 @@ fn simplify_demo() -> Result<()> { make_field("b", DataType::Boolean), ]) .to_dfschema_ref()?; - let context = SimplifyContext::new(&props).with_schema(schema); + let context = SimplifyContext::new(&props).with_schema(schema.clone()); let simplifier = ExprSimplifier::new(context); // basic arithmetic simplification @@ -120,6 +171,64 @@ fn simplify_demo() -> Result<()> { col("i").lt(lit(10)) ); + // String --> Date simplification + // `cast('2020-09-01' as date)` --> 18506 + assert_eq!( + simplifier.simplify(lit("2020-09-01").cast_to(&DataType::Date32, &schema)?)?, + lit(ScalarValue::Date32(Some(18506))) + ); + + Ok(()) +} + +/// DataFusion also has APIs for analyzing predicates (boolean expressions) to +/// determine any range restrictions on the inputs required for the predicate +/// to evaluate to true. +fn range_analysis_demo() -> Result<()> { + // For example, let's say you are interested in finding data for all days + // in the month of September, 2020 + let september_1 = ScalarValue::Date32(Some(18506)); // 2020-09-01 + let october_1 = ScalarValue::Date32(Some(18536)); // 2020-10-01 + + // The predicate to find all such days could be + // `date > '2020-09-01' AND date < '2020-10-01'` + let expr = col("date") + .gt(lit(september_1.clone())) + .and(col("date").lt(lit(october_1.clone()))); + + // Using the analysis API, DataFusion can determine that the value of `date` + // must be in the range `['2020-09-01', '2020-10-01']`. If your data is + // organized in files according to day, this information permits skipping + // entire files without reading them. + // + // While this simple example could be handled with a special case, the + // DataFusion API handles arbitrary expressions (so for example, you don't + // have to handle the case where the predicate clauses are reversed such as + // `date < '2020-10-01' AND date > '2020-09-01'`) + + // As always, we need to tell DataFusion the type of column "date" + let schema = Schema::new(vec![make_field("date", DataType::Date32)]); + + // You can provide DataFusion any known boundaries on the values of `date` + // (for example, maybe you know you only have data up to `2020-09-15`), but + // in this case, let's say we don't know any boundaries beforehand so we use + // `try_new_unbounded` + let boundaries = ExprBoundaries::try_new_unbounded(&schema)?; + + // Now, we invoke the analysis code to perform the range analysis + let physical_expr = physical_expr(&schema, expr)?; + let analysis_result = + analyze(&physical_expr, AnalysisContext::new(boundaries), &schema)?; + + // The result of the analysis is a range, encoded as an `Interval`, for + // each column in the schema, that must hold in order for the predicate + // to be true.
+ // + // In this case, we can see that, as expected, `analyze` has figured out + // that `date` must be in the range `['2020-09-01', '2020-10-01']` + let expected_range = Interval::try_new(september_1, october_1)?; + assert_eq!(analysis_result.boundaries[0].interval, expected_range); + Ok(()) } @@ -132,3 +241,18 @@ fn make_ts_field(name: &str) -> Field { let tz = None; make_field(name, DataType::Timestamp(TimeUnit::Nanosecond, tz)) } + +/// Build a physical expression from a logical one, after applying simplification and type coercion +pub fn physical_expr(schema: &Schema, expr: Expr) -> Result<Arc<dyn PhysicalExpr>> { + let df_schema = schema.clone().to_dfschema_ref()?; + + // Simplify + let props = ExecutionProps::new(); + let simplifier = + ExprSimplifier::new(SimplifyContext::new(&props).with_schema(df_schema.clone())); + + // apply type coercion here to ensure types match + let expr = simplifier.coerce(expr, df_schema.clone())?; + + create_physical_expr(&expr, df_schema.as_ref(), schema, &props) +} diff --git a/datafusion-examples/examples/memtable.rs b/datafusion-examples/examples/memtable.rs index bef8f3e5bb8f..5cce578039e7 100644 --- a/datafusion-examples/examples/memtable.rs +++ b/datafusion-examples/examples/memtable.rs @@ -40,7 +40,7 @@ async fn main() -> Result<()> { timeout(Duration::from_secs(10), async move { let result = dataframe.collect().await.unwrap(); - let record_batch = result.get(0).unwrap(); + let record_batch = result.first().unwrap(); assert_eq!(1, record_batch.column(0).len()); dbg!(record_batch.columns()); diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs new file mode 100644 index 000000000000..e120c5e7bf8e --- /dev/null +++ b/datafusion-examples/examples/simple_udtf.rs @@ -0,0 +1,177 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::csv::reader::Format; +use arrow::csv::ReaderBuilder; +use async_trait::async_trait; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::datasource::function::TableFunctionImpl; +use datafusion::datasource::TableProvider; +use datafusion::error::Result; +use datafusion::execution::context::{ExecutionProps, SessionState}; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionContext; +use datafusion_common::{plan_err, DataFusionError, ScalarValue}; +use datafusion_expr::{Expr, TableType}; +use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext}; +use std::fs::File; +use std::io::Seek; +use std::path::Path; +use std::sync::Arc; + +// To define your own table function, you only need to do the following 3 things: +// 1. Implement your own [`TableProvider`] +// 2.
Implement your own [`TableFunctionImpl`] and return your [`TableProvider`] +// 3. Register the function using [`SessionContext::register_udtf`] + +/// This example demonstrates how to register a TableFunction +#[tokio::main] +async fn main() -> Result<()> { + // create local execution context + let ctx = SessionContext::new(); + + // register the table function that will be called in SQL statements by `read_csv` + ctx.register_udtf("read_csv", Arc::new(LocalCsvTableFunc {})); + + let testdata = datafusion::test_util::arrow_test_data(); + let csv_file = format!("{testdata}/csv/aggregate_test_100.csv"); + + // Pass 2 arguments, read csv with at most 2 rows (simplify logic makes 1+1 --> 2) + let df = ctx + .sql(format!("SELECT * FROM read_csv('{csv_file}', 1 + 1);").as_str()) + .await?; + df.show().await?; + + // just run, return all rows + let df = ctx + .sql(format!("SELECT * FROM read_csv('{csv_file}');").as_str()) + .await?; + df.show().await?; + + Ok(()) +} + +/// Table Function that mimics the [`read_csv`] function in DuckDB. +/// +/// Usage: `read_csv(filename, [limit])` +/// +/// [`read_csv`]: https://duckdb.org/docs/data/csv/overview.html +struct LocalCsvTable { + schema: SchemaRef, + limit: Option<usize>, + batches: Vec<RecordBatch>, +} + +#[async_trait] +impl TableProvider for LocalCsvTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &SessionState, + projection: Option<&Vec<usize>>, + _filters: &[Expr], + _limit: Option<usize>, + ) -> Result<Arc<dyn ExecutionPlan>> { + let batches = if let Some(max_return_lines) = self.limit { + // get max return rows from self.batches + let mut batches = vec![]; + let mut lines = 0; + for batch in &self.batches { + let batch_lines = batch.num_rows(); + if lines + batch_lines > max_return_lines { + let batch_lines = max_return_lines - lines; + batches.push(batch.slice(0, batch_lines)); + break; + } else { + batches.push(batch.clone()); + lines += batch_lines; + } + } + batches + } else { + self.batches.clone() + }; + Ok(Arc::new(MemoryExec::try_new( + &[batches], + TableProvider::schema(self), + projection.cloned(), + )?)) + } +} +struct LocalCsvTableFunc {} + +impl TableFunctionImpl for LocalCsvTableFunc { + fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> { + let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.first() else { + return plan_err!("read_csv requires at least one string argument"); + }; + + let limit = exprs + .get(1) + .map(|expr| { + // try to simplify the expression, so 1+2 becomes 3, for example + let execution_props = ExecutionProps::new(); + let info = SimplifyContext::new(&execution_props); + let expr = ExprSimplifier::new(info).simplify(expr.clone())?; + + if let Expr::Literal(ScalarValue::Int64(Some(limit))) = expr { + Ok(limit as usize) + } else { + plan_err!("Limit must be an integer") + } + }) + .transpose()?; + + let (schema, batches) = read_csv_batches(path)?; + + let table = LocalCsvTable { + schema, + limit, + batches, + }; + Ok(Arc::new(table)) + } +} + +fn read_csv_batches(csv_path: impl AsRef<Path>) -> Result<(SchemaRef, Vec<RecordBatch>)> { + let mut file = File::open(csv_path)?; + let (schema, _) = Format::default().infer_schema(&mut file, None)?; + file.rewind()?; + + let reader = ReaderBuilder::new(Arc::new(schema.clone())) + .with_header(true) + .build(file)?; + let mut batches = vec![]; + for batch in reader { + batches.push(batch?); + } + let schema = Arc::new(schema); + Ok((schema, batches)) +} diff --git
a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index ba2072ecc151..03fb5ea320a0 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -524,6 +524,11 @@ config_namespace! { /// The maximum estimated size in bytes for one input side of a HashJoin /// will be collected into a single partition pub hash_join_single_partition_threshold: usize, default = 1024 * 1024 + + /// The default filter selectivity used by Filter Statistics + /// when an exact selectivity cannot be determined. Valid values are + /// between 0 (no selectivity) and 100 (all rows are selected). + pub default_filter_selectivity: u8, default = 20 } } @@ -877,6 +882,7 @@ config_field!(String); config_field!(bool); config_field!(usize); config_field!(f64); +config_field!(u8); config_field!(u64); /// An implementation trait used to recursively walk configuration diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 52cd85675824..e06f947ad5e7 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -199,9 +199,16 @@ impl DFSchema { pub fn with_functional_dependencies( mut self, functional_dependencies: FunctionalDependencies, - ) -> Self { - self.functional_dependencies = functional_dependencies; - self + ) -> Result<Self> { + if functional_dependencies.is_valid(self.fields.len()) { + self.functional_dependencies = functional_dependencies; + Ok(self) + } else { + _plan_err!( + "Invalid functional dependency: {:?}", + functional_dependencies + ) + } } /// Create a new schema that contains the fields from this schema followed by the fields @@ -1476,8 +1483,8 @@ mod tests { DFSchema::new_with_metadata([a, b].to_vec(), HashMap::new()).unwrap(), ); let schema: Schema = df_schema.as_ref().clone().into(); - let a_df = df_schema.fields.get(0).unwrap().field(); - let a_arrow = schema.fields.get(0).unwrap(); + let a_df = df_schema.fields.first().unwrap().field(); + let a_arrow = schema.fields.first().unwrap(); assert_eq!(a_df.metadata(), a_arrow.metadata()) } diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 9114c669ab8b..4ae30ae86cdd 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -564,18 +564,16 @@ mod test { assert_eq!( err.split(DataFusionError::BACK_TRACE_SEP) .collect::<Vec<&str>>() - .get(0) + .first() .unwrap(), &"Error during planning: Err" ); - assert!( - err.split(DataFusionError::BACK_TRACE_SEP) - .collect::<Vec<&str>>() - .get(1) - .unwrap() - .len() - > 0 - ); + assert!(!err + .split(DataFusionError::BACK_TRACE_SEP) + .collect::<Vec<&str>>() + .get(1) + .unwrap() + .is_empty()); } #[cfg(not(feature = "backtrace"))] diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index fbddcddab4bc..1cb1751d713e 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -24,6 +24,7 @@ use std::ops::Deref; use std::vec::IntoIter; use crate::error::_plan_err; +use crate::utils::{merge_and_order_indices, set_difference}; use crate::{DFSchema, DFSchemaRef, DataFusionError, JoinType, Result}; use sqlparser::ast::TableConstraint; @@ -271,6 +272,29 @@ impl FunctionalDependencies { self.deps.extend(other.deps); } + /// Sanity checks if functional dependencies are valid. For example, if + /// there are 10 fields, we cannot receive any index greater than 9.
+ pub fn is_valid(&self, n_field: usize) -> bool { + self.deps.iter().all( + |FunctionalDependence { + source_indices, + target_indices, + .. + }| { + source_indices + .iter() + .max() + .map(|&max_index| max_index < n_field) + .unwrap_or(true) + && target_indices + .iter() + .max() + .map(|&max_index| max_index < n_field) + .unwrap_or(true) + }, + ) + } + /// Adds the `offset` value to `source_indices` and `target_indices` for /// each functional dependency. pub fn add_offset(&mut self, offset: usize) { @@ -413,6 +437,14 @@ impl FunctionalDependencies { } } +impl Deref for FunctionalDependencies { + type Target = [FunctionalDependence]; + + fn deref(&self) -> &Self::Target { + self.deps.as_slice() + } +} + /// Calculates functional dependencies for aggregate output, when there is a GROUP BY expression. pub fn aggregate_functional_dependencies( aggr_input_schema: &DFSchema, @@ -434,44 +466,56 @@ pub fn aggregate_functional_dependencies( } in &func_dependencies.deps { // Keep source indices in a `HashSet` to prevent duplicate entries: - let mut new_source_indices = HashSet::new(); + let mut new_source_indices = vec![]; + let mut new_source_field_names = vec![]; let source_field_names = source_indices .iter() .map(|&idx| aggr_input_fields[idx].qualified_name()) .collect::<Vec<_>>(); + for (idx, group_by_expr_name) in group_by_expr_names.iter().enumerate() { // When one of the input determinant expressions matches with // the GROUP BY expression, add the index of the GROUP BY // expression as a new determinant key: if source_field_names.contains(group_by_expr_name) { - new_source_indices.insert(idx); + new_source_indices.push(idx); + new_source_field_names.push(group_by_expr_name.clone()); } } + let existing_target_indices = + get_target_functional_dependencies(aggr_input_schema, group_by_expr_names); + let new_target_indices = get_target_functional_dependencies( + aggr_input_schema, + &new_source_field_names, + ); + let mode = if existing_target_indices == new_target_indices + && new_target_indices.is_some() + { + // If dependency covers all GROUP BY expressions, mode will be `Single`: + Dependency::Single + } else { + // Otherwise, existing mode is preserved: + *mode + }; // All of the composite indices occur in the GROUP BY expression: if new_source_indices.len() == source_indices.len() { aggregate_func_dependencies.push( FunctionalDependence::new( - new_source_indices.into_iter().collect(), + new_source_indices, target_indices.clone(), *nullable, ) - // input uniqueness stays the same when GROUP BY matches with input functional dependence determinants - .with_mode(*mode), + .with_mode(mode), ); } } + // If we have a single GROUP BY key, we can guarantee uniqueness after // aggregation: if group_by_expr_names.len() == 1 { // If `source_indices` contain 0, delete this functional dependency // as it will be added anyway with mode `Dependency::Single`: - if let Some(idx) = aggregate_func_dependencies - .iter() - .position(|item| item.source_indices.contains(&0)) - { - // Delete the functional dependency that contains zeroth idx: - aggregate_func_dependencies.remove(idx); - } + aggregate_func_dependencies.retain(|item| !item.source_indices.contains(&0)); // Add a new functional dependency associated with the whole table: aggregate_func_dependencies.push( // Use nullable property of the group by expression @@ -519,8 +563,61 @@ pub fn get_target_functional_dependencies( combined_target_indices.extend(target_indices.iter()); } } - (!combined_target_indices.is_empty()) -
.then_some(combined_target_indices.iter().cloned().collect::<Vec<_>>()) + (!combined_target_indices.is_empty()).then_some({ + let mut result = combined_target_indices.into_iter().collect::<Vec<_>>(); + result.sort(); + result + }) +} + +/// Returns indices for the minimal subset of GROUP BY expressions that are +/// functionally equivalent to the original set of GROUP BY expressions. +pub fn get_required_group_by_exprs_indices( + schema: &DFSchema, + group_by_expr_names: &[String], +) -> Option<Vec<usize>> { + let dependencies = schema.functional_dependencies(); + let field_names = schema + .fields() + .iter() + .map(|item| item.qualified_name()) + .collect::<Vec<_>>(); + let mut groupby_expr_indices = group_by_expr_names + .iter() + .map(|group_by_expr_name| { + field_names + .iter() + .position(|field_name| field_name == group_by_expr_name) + }) + .collect::<Option<Vec<_>>>()?; + + groupby_expr_indices.sort(); + for FunctionalDependence { + source_indices, + target_indices, + .. + } in &dependencies.deps + { + if source_indices + .iter() + .all(|source_idx| groupby_expr_indices.contains(source_idx)) + { + // If all source indices are among GROUP BY expression indices, we + // can remove target indices from GROUP BY expression indices and + // use source indices instead. + groupby_expr_indices = set_difference(&groupby_expr_indices, target_indices); + groupby_expr_indices = + merge_and_order_indices(groupby_expr_indices, source_indices); + } + } + groupby_expr_indices + .iter() + .map(|idx| { + group_by_expr_names + .iter() + .position(|name| &field_names[*idx] == name) + }) + .collect() } /// Updates entries inside the `entries` vector with their corresponding diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 90fb4a88149c..ed547782e4a5 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -20,6 +20,7 @@ mod dfschema; mod error; mod functional_dependencies; mod join_type; +mod param_value; #[cfg(feature = "pyarrow")] mod pyarrow; mod schema_reference; @@ -55,10 +56,12 @@ pub use file_options::file_type::{ }; pub use file_options::FileTypeWriterOptions; pub use functional_dependencies::{ - aggregate_functional_dependencies, get_target_functional_dependencies, Constraint, - Constraints, Dependency, FunctionalDependence, FunctionalDependencies, + aggregate_functional_dependencies, get_required_group_by_exprs_indices, + get_target_functional_dependencies, Constraint, Constraints, Dependency, + FunctionalDependence, FunctionalDependencies, }; pub use join_type::{JoinConstraint, JoinSide, JoinType}; +pub use param_value::ParamValues; pub use scalar::{ScalarType, ScalarValue}; pub use schema_reference::{OwnedSchemaReference, SchemaReference}; pub use stats::{ColumnStatistics, Statistics}; diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs new file mode 100644 index 000000000000..253c312b66d5 --- /dev/null +++ b/datafusion/common/src/param_value.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::error::{_internal_err, _plan_err}; +use crate::{DataFusionError, Result, ScalarValue}; +use arrow_schema::DataType; +use std::collections::HashMap; + +/// The parameter value corresponding to the placeholder +#[derive(Debug, Clone)] +pub enum ParamValues { +    /// for positional query parameters, like select * from test where a > $1 and b = $2 +    LIST(Vec<ScalarValue>), +    /// for named query parameters, like select * from test where a > $foo and b = $goo +    MAP(HashMap<String, ScalarValue>), +} + +impl ParamValues { +    /// Verify parameter list length and type +    pub fn verify(&self, expect: &Vec<DataType>) -> Result<()> { +        match self { +            ParamValues::LIST(list) => { +                // Verify if the number of params matches the number of values +                if expect.len() != list.len() { +                    return _plan_err!( +                        "Expected {} parameters, got {}", +                        expect.len(), +                        list.len() +                    ); +                } + +                // Verify if the types of the params matches the types of the values +                let iter = expect.iter().zip(list.iter()); +                for (i, (param_type, value)) in iter.enumerate() { +                    if *param_type != value.data_type() { +                        return _plan_err!( +                            "Expected parameter of type {:?}, got {:?} at index {}", +                            param_type, +                            value.data_type(), +                            i +                        ); +                    } +                } +                Ok(()) +            } +            ParamValues::MAP(_) => { +                // If it is a named query, variables can be reused, +                // but the lengths are not necessarily equal +                Ok(()) +            } +        } +    } + +    pub fn get_placeholders_with_values( +        &self, +        id: &String, +        data_type: &Option<DataType>, +    ) -> Result<ScalarValue> { +        match self { +            ParamValues::LIST(list) => { +                if id.is_empty() || id == "$0" { +                    return _plan_err!("Empty placeholder id"); +                } +                // convert id (in format $1, $2, ..) to idx (0, 1, ..) +                let idx = id[1..].parse::<usize>().map_err(|e| { +                    DataFusionError::Internal(format!( +                        "Failed to parse placeholder id: {e}" +                    )) +                })? - 1; +                // value at the idx-th position in param_values should be the value for the placeholder +                let value = list.get(idx).ok_or_else(|| { +                    DataFusionError::Internal(format!( +                        "No value found for placeholder with id {id}" +                    )) +                })?; +                // check if the data type of the value matches the data type of the placeholder +                if Some(value.data_type()) != *data_type { +                    return _internal_err!( +                        "Placeholder value type mismatch: expected {:?}, got {:?}", +                        data_type, +                        value.data_type() +                    ); +                } +                Ok(value.clone()) +            } +            ParamValues::MAP(map) => { +                // convert name (in format $a, $b, ..) to mapped values (a, b, ..)
+                let name = &id[1..]; +                // value at the name position in param_values should be the value for the placeholder +                let value = map.get(name).ok_or_else(|| { +                    DataFusionError::Internal(format!( +                        "No value found for placeholder with name {id}" +                    )) +                })?; +                // check if the data type of the value matches the data type of the placeholder +                if Some(value.data_type()) != *data_type { +                    return _internal_err!( +                        "Placeholder value type mismatch: expected {:?}, got {:?}", +                        data_type, +                        value.data_type() +                    ); +                } +                Ok(value.clone()) +            } +        } +    } +} + +impl From<Vec<ScalarValue>> for ParamValues { +    fn from(value: Vec<ScalarValue>) -> Self { +        Self::LIST(value) +    } +} + +impl<K> From<Vec<(K, ScalarValue)>> for ParamValues +where +    K: Into<String>, +{ +    fn from(value: Vec<(K, ScalarValue)>) -> Self { +        let value: HashMap<String, ScalarValue> = +            value.into_iter().map(|(k, v)| (k.into(), v)).collect(); +        Self::MAP(value) +    } +} + +impl<K> From<HashMap<K, ScalarValue>> for ParamValues +where +    K: Into<String>, +{ +    fn from(value: HashMap<K, ScalarValue>) -> Self { +        let value: HashMap<String, ScalarValue> = +            value.into_iter().map(|(k, v)| (k.into(), v)).collect(); +        Self::MAP(value) +    } +} diff --git a/datafusion/common/src/pyarrow.rs b/datafusion/common/src/pyarrow.rs index aa0153919360..f4356477532f 100644 --- a/datafusion/common/src/pyarrow.rs +++ b/datafusion/common/src/pyarrow.rs @@ -119,7 +119,7 @@ mod tests { ScalarValue::Boolean(Some(true)), ScalarValue::Int32(Some(23)), ScalarValue::Float64(Some(12.34)), - ScalarValue::Utf8(Some("Hello!".to_string())), + ScalarValue::from("Hello!"), ScalarValue::Date32(Some(1234)), ]; diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 3431d71468ea..d730fbf89b72 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -31,7 +31,6 @@ use crate::cast::{ use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err}; use crate::hash_utils::create_hashes; use crate::utils::{array_into_large_list_array, array_into_list_array}; -use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::compute::kernels::numeric::*; use arrow::datatypes::{i256, Fields, SchemaBuilder}; use arrow::util::display::{ArrayFormatter, FormatOptions}; @@ -39,20 +38,63 @@ use arrow::{ array::*, compute::kernels::cast::{cast_with_options, CastOptions}, datatypes::{ - ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, - Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType, - IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, TimeUnit, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, - DECIMAL128_MAX_PRECISION, + ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, Int16Type, + Int32Type, Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, + IntervalUnit, IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, DECIMAL128_MAX_PRECISION, }, }; use arrow_array::cast::as_list_array; +use arrow_array::types::ArrowTimestampType; use arrow_array::{ArrowNativeTypeOp, Scalar}; -/// Represents a dynamically typed, nullable single value. -/// This is the single-valued counter-part to arrow's [`Array`]. +/// A dynamically typed, nullable single value (the single-valued counter-part +/// to arrow's [`Array`]) /// +/// # Performance +/// +/// In general, please use arrow [`Array`]s rather than [`ScalarValue`] whenever +/// possible, as it is far more efficient for multiple values.
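One point worth illustrating from the `# Performance` note above: when many copies of one value are needed, a `ScalarValue` can be expanded into an Arrow array up front. A minimal sketch, assuming `to_array_of_size` is fallible in this version, in the same way `to_array` is used fallibly elsewhere in this diff:

```rust
use arrow::array::Array;
use datafusion_common::{Result, ScalarValue};

fn main() -> Result<()> {
    // A single typed value...
    let scalar = ScalarValue::Int32(Some(10));

    // ...expanded into an array of 1000 identical values, so downstream
    // kernels can work on columnar data instead of one value at a time.
    let array = scalar.to_array_of_size(1000)?;
    assert_eq!(array.len(), 1000);
    Ok(())
}
```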
+/// +/// # Example +/// ``` +/// # use datafusion_common::ScalarValue; +/// // Create single scalar value for an Int32 value +/// let s1 = ScalarValue::Int32(Some(10)); +/// +/// // You can also create values using the From impl: +/// let s2 = ScalarValue::from(10i32); +/// assert_eq!(s1, s2); +/// ``` +/// +/// # Null Handling +/// +/// `ScalarValue` represents null values in the same way as Arrow. Nulls are +/// "typed" in the sense that a null value in an [`Int32Array`] is different +/// than a null value in a [`Float64Array`], and is different than the values in +/// a [`NullArray`]. +/// +/// ``` +/// # fn main() -> datafusion_common::Result<()> { +/// # use std::collections::hash_set::Difference; +/// # use datafusion_common::ScalarValue; +/// # use arrow::datatypes::DataType; +/// // You can create a 'null' Int32 value directly: +/// let s1 = ScalarValue::Int32(None); +/// +/// // You can also create a null value for a given datatype: +/// let s2 = ScalarValue::try_from(&DataType::Int32)?; +/// assert_eq!(s1, s2); +/// +/// // Note that this is DIFFERENT than a `ScalarValue::Null` +/// let s3 = ScalarValue::Null; +/// assert_ne!(s1, s3); +/// # Ok(()) +/// # } +/// ``` +/// +/// # Further Reading /// See [datatypes](https://arrow.apache.org/docs/python/api/datatypes.html) for /// details on datatypes and the [format](https://github.com/apache/arrow/blob/master/format/Schema.fbs#L354-L375) /// for the definitive reference. @@ -317,69 +359,47 @@ impl PartialOrd for ScalarValue { (FixedSizeBinary(_, _), _) => None, (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2), (LargeBinary(_), _) => None, - (List(arr1), List(arr2)) | (FixedSizeList(arr1), FixedSizeList(arr2)) => { - if arr1.data_type() == arr2.data_type() { - let list_arr1 = as_list_array(arr1); - let list_arr2 = as_list_array(arr2); - if list_arr1.len() != list_arr2.len() { - return None; - } - for i in 0..list_arr1.len() { - let arr1 = list_arr1.value(i); - let arr2 = list_arr2.value(i); - - let lt_res = - arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; - let eq_res = - arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; - - for j in 0..lt_res.len() { - if lt_res.is_valid(j) && lt_res.value(j) { - return Some(Ordering::Less); - } - if eq_res.is_valid(j) && !eq_res.value(j) { - return Some(Ordering::Greater); - } - } + (List(arr1), List(arr2)) + | (FixedSizeList(arr1), FixedSizeList(arr2)) + | (LargeList(arr1), LargeList(arr2)) => { + // ScalarValue::List / ScalarValue::FixedSizeList / ScalarValue::LargeList are ensured to have length 1 + assert_eq!(arr1.len(), 1); + assert_eq!(arr2.len(), 1); + + if arr1.data_type() != arr2.data_type() { + return None; + } + + fn first_array_for_list(arr: &ArrayRef) -> ArrayRef { + if let Some(arr) = arr.as_list_opt::<i32>() { + arr.value(0) + } else if let Some(arr) = arr.as_list_opt::<i64>() { + arr.value(0) + } else if let Some(arr) = arr.as_fixed_size_list_opt() { + arr.value(0) + } else { + unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen") } - Some(Ordering::Equal) - } else { - None } - } - (LargeList(arr1), LargeList(arr2)) => { - if arr1.data_type() == arr2.data_type() { - let list_arr1 = as_large_list_array(arr1); - let list_arr2 = as_large_list_array(arr2); - if list_arr1.len() != list_arr2.len() { - return None; + + let arr1 = first_array_for_list(arr1); + let arr2 = first_array_for_list(arr2); + + let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; + let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?;
+ + for j in 0..lt_res.len() { + if lt_res.is_valid(j) && lt_res.value(j) { + return Some(Ordering::Less); } - for i in 0..list_arr1.len() { - let arr1 = list_arr1.value(i); - let arr2 = list_arr2.value(i); - - let lt_res = - arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; - let eq_res = - arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; - - for j in 0..lt_res.len() { - if lt_res.is_valid(j) && lt_res.value(j) { - return Some(Ordering::Less); - } - if eq_res.is_valid(j) && !eq_res.value(j) { - return Some(Ordering::Greater); - } - } + if eq_res.is_valid(j) && !eq_res.value(j) { + return Some(Ordering::Greater); } - Some(Ordering::Equal) - } else { - None } + + Some(Ordering::Equal) } - (List(_), _) => None, - (LargeList(_), _) => None, - (FixedSizeList(_), _) => None, + (List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None, (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), (Date32(_), _) => None, (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), @@ -731,7 +751,7 @@ impl ScalarValue { /// Returns a [`ScalarValue::Utf8`] representing `val` pub fn new_utf8(val: impl Into<String>) -> Self { - ScalarValue::Utf8(Some(val.into())) + ScalarValue::from(val.into()) } /// Returns a [`ScalarValue::IntervalYearMonth`] representing @@ -755,6 +775,20 @@ impl ScalarValue { ScalarValue::IntervalMonthDayNano(Some(val)) } + /// Returns a [`ScalarValue`] representing + /// `value` and `tz_opt` timezone + pub fn new_timestamp<T: ArrowTimestampType>( + value: Option<i64>, + tz_opt: Option<Arc<str>>, + ) -> Self { + match T::UNIT { + TimeUnit::Second => ScalarValue::TimestampSecond(value, tz_opt), + TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(value, tz_opt), + TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(value, tz_opt), + TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(value, tz_opt), + } + } + /// Create a zero value in the given type. pub fn new_zero(datatype: &DataType) -> Result<ScalarValue> { assert!(datatype.is_primitive()); @@ -1325,103 +1359,36 @@ impl ScalarValue { }}; } - macro_rules! build_array_list_primitive { - ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident, $LIST_TY:ident, $SCALAR_LIST:pat) => {{ - Ok::<ArrayRef, DataFusionError>(Arc::new($LIST_TY::from_iter_primitive::<$ARRAY_TY, _, _>( - scalars.into_iter().map(|x| match x{ - ScalarValue::List(arr) if matches!(x, $SCALAR_LIST) => { - // `ScalarValue::List` contains a single element `ListArray`. - let list_arr = as_list_array(&arr); - if list_arr.is_null(0) { - Ok(None) - } else { - let primitive_arr = - list_arr.values().as_primitive::<$ARRAY_TY>(); - Ok(Some( - primitive_arr.into_iter().collect::<Vec<Option<$NATIVE_TYPE>>>(), - )) - } - } - ScalarValue::LargeList(arr) if matches!(x, $SCALAR_LIST) =>{ - // `ScalarValue::List` contains a single element `ListArray`. - let list_arr = as_large_list_array(&arr); - if list_arr.is_null(0) { - Ok(None) - } else { - let primitive_arr = - list_arr.values().as_primitive::<$ARRAY_TY>(); - Ok(Some( - primitive_arr.into_iter().collect::<Vec<Option<$NATIVE_TYPE>>>(), - )) - } - } - sv => _internal_err!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected {:?}, got {:?}", - data_type, sv - ), - }) - .collect::<Result<Vec<_>>>()?, - ))) - }}; - } - - macro_rules! build_array_list_string { - ($BUILDER:ident, $STRING_ARRAY:ident,$LIST_BUILDER:ident,$SCALAR_LIST:pat) => {{ - let mut builder = $LIST_BUILDER::new($BUILDER::new()); - for scalar in scalars.into_iter() { - match scalar { - ScalarValue::List(arr) if matches!(scalar, $SCALAR_LIST) => { - // `ScalarValue::List` contains a single element `ListArray`.
- let list_arr = as_list_array(&arr); - - if list_arr.is_null(0) { - builder.append(false); - continue; - } - - let string_arr = $STRING_ARRAY(list_arr.values()); - - for v in string_arr.iter() { - if let Some(v) = v { - builder.values().append_value(v); - } else { - builder.values().append_null(); - } - } - builder.append(true); - } - ScalarValue::LargeList(arr) if matches!(scalar, $SCALAR_LIST) => { - // `ScalarValue::List` contains a single element `ListArray`. - let list_arr = as_large_list_array(&arr); - - if list_arr.is_null(0) { - builder.append(false); - continue; - } - - let string_arr = $STRING_ARRAY(list_arr.values()); - - for v in string_arr.iter() { - if let Some(v) = v { - builder.values().append_value(v); - } else { - builder.values().append_null(); - } - } - builder.append(true); - } - sv => { - return _internal_err!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected List, got {:?}", - sv - ) - } - } + fn build_list_array( + scalars: impl IntoIterator<Item = ScalarValue>, + ) -> Result<ArrayRef> { + let arrays = scalars + .into_iter() + .map(|s| s.to_array()) + .collect::<Result<Vec<_>>>()?; + + let capacity = Capacities::Array(arrays.iter().map(|arr| arr.len()).sum()); + // ScalarValue::List contains a single element ListArray. + let nulls = arrays + .iter() + .map(|arr| arr.is_null(0)) + .collect::<Vec<_>>(); + let arrays_data = arrays.iter().map(|arr| arr.to_data()).collect::<Vec<_>>(); + + let arrays_ref = arrays_data.iter().collect::<Vec<_>>(); + let mut mutable = + MutableArrayData::with_capacities(arrays_ref, true, capacity); + + // ScalarValue::List contains a single element ListArray. + for (index, is_null) in (0..arrays.len()).zip(nulls.into_iter()) { + if is_null { + mutable.extend_nulls(1) + } else { + mutable.extend(index, 0, 1); } - Arc::new(builder.finish()) - }}; + } + let data = mutable.freeze(); + Ok(arrow_array::make_array(data)) } let array: ArrayRef = match &data_type { @@ -1498,228 +1465,7 @@ impl ScalarValue { DataType::Interval(IntervalUnit::MonthDayNano) => { build_array_primitive!(IntervalMonthDayNanoArray, IntervalMonthDayNano) } - DataType::List(fields) if fields.data_type() == &DataType::Int8 => { - build_array_list_primitive!( - Int8Type, - Int8, - i8, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::Int16 => { - build_array_list_primitive!( - Int16Type, - Int16, - i16, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::Int32 => { - build_array_list_primitive!( - Int32Type, - Int32, - i32, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::Int64 => { - build_array_list_primitive!( - Int64Type, - Int64, - i64, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::UInt8 => { - build_array_list_primitive!( - UInt8Type, - UInt8, - u8, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::UInt16 => { - build_array_list_primitive!( - UInt16Type, - UInt16, - u16, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::UInt32 => { - build_array_list_primitive!( - UInt32Type, - UInt32, - u32, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::UInt64 => { - build_array_list_primitive!( - UInt64Type, - UInt64, - u64, - ListArray, - ScalarValue::List(_) - )?
- } - DataType::List(fields) if fields.data_type() == &DataType::Float32 => { - build_array_list_primitive!( - Float32Type, - Float32, - f32, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::Float64 => { - build_array_list_primitive!( - Float64Type, - Float64, - f64, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::Utf8 => { - build_array_list_string!( - StringBuilder, - as_string_array, - ListBuilder, - ScalarValue::List(_) - ) - } - DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => { - build_array_list_string!( - LargeStringBuilder, - as_largestring_array, - ListBuilder, - ScalarValue::List(_) - ) - } - DataType::List(_) => { - // Fallback case handling homogeneous lists with any ScalarValue element type - let list_array = ScalarValue::iter_to_array_list(scalars)?; - Arc::new(list_array) - } - DataType::LargeList(fields) if fields.data_type() == &DataType::Int8 => { - build_array_list_primitive!( - Int8Type, - Int8, - i8, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::Int16 => { - build_array_list_primitive!( - Int16Type, - Int16, - i16, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::Int32 => { - build_array_list_primitive!( - Int32Type, - Int32, - i32, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::Int64 => { - build_array_list_primitive!( - Int64Type, - Int64, - i64, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::UInt8 => { - build_array_list_primitive!( - UInt8Type, - UInt8, - u8, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::UInt16 => { - build_array_list_primitive!( - UInt16Type, - UInt16, - u16, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::UInt32 => { - build_array_list_primitive!( - UInt32Type, - UInt32, - u32, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::UInt64 => { - build_array_list_primitive!( - UInt64Type, - UInt64, - u64, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::Float32 => { - build_array_list_primitive!( - Float32Type, - Float32, - f32, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::Float64 => { - build_array_list_primitive!( - Float64Type, - Float64, - f64, - LargeListArray, - ScalarValue::LargeList(_) - )? 
- } - DataType::LargeList(fields) if fields.data_type() == &DataType::Utf8 => { - build_array_list_string!( - StringBuilder, - as_string_array, - LargeListBuilder, - ScalarValue::LargeList(_) - ) - } - DataType::LargeList(fields) if fields.data_type() == &DataType::LargeUtf8 => { - build_array_list_string!( - LargeStringBuilder, - as_largestring_array, - LargeListBuilder, - ScalarValue::LargeList(_) - ) - } - DataType::LargeList(_) => { - // Fallback case handling homogeneous lists with any ScalarValue element type - let list_array = ScalarValue::iter_to_large_array_list(scalars)?; - Arc::new(list_array) - } + DataType::List(_) | DataType::LargeList(_) => build_list_array(scalars)?, DataType::Struct(fields) => { // Initialize a Vector to store the ScalarValues for each column let mut columns: Vec<Vec<ScalarValue>> = @@ -1899,116 +1645,6 @@ impl ScalarValue { Ok(array) } - /// This function build ListArray with nulls with nulls buffer. - fn iter_to_array_list( - scalars: impl IntoIterator<Item = ScalarValue>, - ) -> Result<ListArray> { - let mut elements: Vec<ArrayRef> = vec![]; - let mut valid = BooleanBufferBuilder::new(0); - let mut offsets = vec![]; - - for scalar in scalars { - if let ScalarValue::List(arr) = scalar { - // `ScalarValue::List` contains a single element `ListArray`. - let list_arr = as_list_array(&arr); - - if list_arr.is_null(0) { - // Repeat previous offset index - offsets.push(0); - - // Element is null - valid.append(false); - } else { - let arr = list_arr.values().to_owned(); - offsets.push(arr.len()); - elements.push(arr); - - // Element is valid - valid.append(true); - } - } else { - return _internal_err!( - "Expected ScalarValue::List element.
Received {scalar:?}" - ); - } - } - - // Concatenate element arrays to create single flat array - let element_arrays: Vec<&dyn Array> = - elements.iter().map(|a| a.as_ref()).collect(); - - let flat_array = match arrow::compute::concat(&element_arrays) { - Ok(flat_array) => flat_array, - Err(err) => return Err(DataFusionError::ArrowError(err)), - }; - - let buffer = valid.finish(); - - let list_array = LargeListArray::new( - Arc::new(Field::new("item", flat_array.data_type().clone(), true)), - OffsetBuffer::from_lengths(offsets), - flat_array, - Some(NullBuffer::new(buffer)), - ); - - Ok(list_array) - } - fn build_decimal_array( value: Option, precision: u8, @@ -2656,7 +2292,7 @@ impl ScalarValue { /// Try to parse `value` into a ScalarValue of type `target_type` pub fn try_from_string(value: String, target_type: &DataType) -> Result { - let value = ScalarValue::Utf8(Some(value)); + let value = ScalarValue::from(value); let cast_options = CastOptions { safe: false, format_options: Default::default(), @@ -3022,6 +2658,12 @@ impl FromStr for ScalarValue { } } +impl From for ScalarValue { + fn from(value: String) -> Self { + ScalarValue::Utf8(Some(value)) + } +} + impl From> for ScalarValue { fn from(value: Vec<(&str, ScalarValue)>) -> Self { let (fields, scalars): (SchemaBuilder, Vec<_>) = value @@ -3471,21 +3113,23 @@ impl ScalarType for TimestampNanosecondType { #[cfg(test)] mod tests { + use super::*; + use std::cmp::Ordering; use std::sync::Arc; + use chrono::NaiveDate; + use rand::Rng; + + use arrow::buffer::OffsetBuffer; use arrow::compute::kernels; use arrow::compute::{concat, is_null}; use arrow::datatypes::ArrowPrimitiveType; use arrow::util::pretty::pretty_format_columns; use arrow_array::ArrowNumericType; - use chrono::NaiveDate; - use rand::Rng; use crate::cast::{as_string_array, as_uint32_array, as_uint64_array}; - use super::*; - #[test] fn test_to_array_of_size_for_list() { let arr = ListArray::from_iter_primitive::(vec![Some(vec![ @@ -3532,9 +3176,9 @@ mod tests { #[test] fn test_list_to_array_string() { let scalars = vec![ - ScalarValue::Utf8(Some(String::from("rust"))), - ScalarValue::Utf8(Some(String::from("arrow"))), - ScalarValue::Utf8(Some(String::from("data-fusion"))), + ScalarValue::from("rust"), + ScalarValue::from("arrow"), + ScalarValue::from("data-fusion"), ]; let array = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8); @@ -3548,28 +3192,77 @@ mod tests { assert_eq!(result, &expected); } + fn build_list( + values: Vec>>>, + ) -> Vec { + values + .into_iter() + .map(|v| { + let arr = if v.is_some() { + Arc::new( + GenericListArray::::from_iter_primitive::( + vec![v], + ), + ) + } else if O::IS_LARGE { + new_null_array( + &DataType::LargeList(Arc::new(Field::new( + "item", + DataType::Int64, + true, + ))), + 1, + ) + } else { + new_null_array( + &DataType::List(Arc::new(Field::new( + "item", + DataType::Int64, + true, + ))), + 1, + ) + }; + + if O::IS_LARGE { + ScalarValue::LargeList(arr) + } else { + ScalarValue::List(arr) + } + }) + .collect() + } + #[test] fn iter_to_array_primitive_test() { - let scalars = vec![ - ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), - Some(3), - ])]), - )), - ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![Some(vec![ - Some(4), - Some(5), - ])]), - )), - ]; + // List[[1,2,3]], List[null], List[[4,5]] + let scalars = build_list::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); let array = 
ScalarValue::iter_to_array(scalars).unwrap(); let list_array = as_list_array(&array); + // List[[1,2,3], null, [4,5]] let expected = ListArray::from_iter_primitive::<Int64Type, _, _>(vec![ Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + assert_eq!(list_array, &expected); + + let scalars = build_list::<i64>(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + + let array = ScalarValue::iter_to_array(scalars).unwrap(); + let list_array = as_large_list_array(&array); + let expected = LargeListArray::from_iter_primitive::<Int64Type, _, _>(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, Some(vec![Some(4), Some(5)]), ]); assert_eq!(list_array, &expected); @@ -3944,24 +3637,6 @@ mod tests { ])]), )); assert_eq!(a.partial_cmp(&b), Some(Ordering::Less)); - - let a = - ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::<Int32Type, _, _>(vec![ - Some(vec![Some(10), Some(2), Some(3)]), - None, - Some(vec![Some(10), Some(2), Some(3)]), - ]), - )); - let b = - ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::<Int32Type, _, _>(vec![ - Some(vec![Some(10), Some(2), Some(3)]), - None, - Some(vec![Some(10), Some(2), Some(3)]), - ]), - )); - assert_eq!(a.partial_cmp(&b), Some(Ordering::Equal)); } #[test] @@ -4645,6 +4320,16 @@ mod tests { ); } + #[test] + fn test_scalar_value_from_string() { + let scalar = ScalarValue::from("foo"); + assert_eq!(scalar, ScalarValue::Utf8(Some("foo".to_string()))); + let scalar = ScalarValue::from("foo".to_string()); + assert_eq!(scalar, ScalarValue::Utf8(Some("foo".to_string()))); + let scalar = ScalarValue::from_str("foo").unwrap(); + assert_eq!(scalar, ScalarValue::Utf8(Some("foo".to_string()))); + } + #[test] fn test_scalar_struct() { let field_a = Arc::new(Field::new("A", DataType::Int32, false)); @@ -4663,7 +4348,7 @@ mod tests { Some(vec![ ScalarValue::Int32(Some(23)), ScalarValue::Boolean(Some(false)), - ScalarValue::Utf8(Some("Hello".to_string())), + ScalarValue::from("Hello"), ScalarValue::from(vec![ ("e", ScalarValue::from(2i16)), ("f", ScalarValue::from(3i64)), @@ -4856,17 +4541,17 @@ mod tests { // Define struct scalars let s0 = ScalarValue::from(vec![ - ("A", ScalarValue::Utf8(Some(String::from("First")))), + ("A", ScalarValue::from("First")), ("primitive_list", l0), ]); let s1 = ScalarValue::from(vec![ - ("A", ScalarValue::Utf8(Some(String::from("Second")))), + ("A", ScalarValue::from("Second")), ("primitive_list", l1), ]); let s2 = ScalarValue::from(vec![ - ("A", ScalarValue::Utf8(Some(String::from("Third")))), + ("A", ScalarValue::from("Third")), ("primitive_list", l2), ]); @@ -5024,69 +4709,37 @@ mod tests { assert_eq!(array, &expected); } - #[test] - fn test_nested_lists() { - // Define inner list scalars - let a1 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![ - Some(1), - Some(2), - Some(3), - ])]); - let a2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![ - Some(4), - Some(5), - ])]); - let l1 = ListArray::new( - Arc::new(Field::new( - "item", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - true, - )), - OffsetBuffer::<i32>::from_lengths([1, 1]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - let a1 = - ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![Some(6)])]); - let a2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![ - Some(7), - Some(8), - ])]); - let l2 = ListArray::new( - Arc::new(Field::new( - "item", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - true, - )), - OffsetBuffer::<i32>::from_lengths([1, 1]), - arrow::compute::concat(&[&a1,
&a2]).unwrap(), - None, - ); - - let a1 = - ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![Some(9)])]); - let l3 = ListArray::new( + fn build_2d_list(data: Vec<Option<i32>>) -> ListArray { + let a1 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(data)]); + ListArray::new( Arc::new(Field::new( "item", DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), true, )), OffsetBuffer::<i32>::from_lengths([1]), - arrow::compute::concat(&[&a1]).unwrap(), + Arc::new(a1), None, - ); + ) + } + + #[test] + fn test_nested_lists() { + // Define inner list scalars + let arr1 = build_2d_list(vec![Some(1), Some(2), Some(3)]); + let arr2 = build_2d_list(vec![Some(4), Some(5)]); + let arr3 = build_2d_list(vec![Some(6)]); let array = ScalarValue::iter_to_array(vec![ - ScalarValue::List(Arc::new(l1)), - ScalarValue::List(Arc::new(l2)), - ScalarValue::List(Arc::new(l3)), + ScalarValue::List(Arc::new(arr1)), + ScalarValue::List(Arc::new(arr2)), + ScalarValue::List(Arc::new(arr3)), ]) .unwrap(); let array = as_list_array(&array); // Construct expected array with array builders - let inner_builder = Int32Array::builder(8); + let inner_builder = Int32Array::builder(6); let middle_builder = ListBuilder::new(inner_builder); let mut outer_builder = ListBuilder::new(middle_builder); @@ -5094,6 +4747,7 @@ mod tests { outer_builder.values().values().append_value(2); outer_builder.values().values().append_value(3); outer_builder.values().append(true); + outer_builder.append(true); outer_builder.values().values().append_value(4); outer_builder.values().values().append_value(5); @@ -5102,14 +4756,6 @@ mod tests { outer_builder.values().values().append_value(6); outer_builder.values().append(true); - - outer_builder.values().values().append_value(7); - outer_builder.values().values().append_value(8); - outer_builder.values().append(true); - outer_builder.append(true); - - outer_builder.values().values().append_value(9); - outer_builder.values().append(true); outer_builder.append(true); let expected = outer_builder.finish(); @@ -5153,7 +4799,7 @@ mod tests { check_scalar_cast(ScalarValue::Float64(None), DataType::Int16); check_scalar_cast( - ScalarValue::Utf8(Some("foo".to_string())), + ScalarValue::from("foo"), DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), ); @@ -5434,10 +5080,7 @@ mod tests { (ScalarValue::Int8(None), ScalarValue::Int16(Some(1))), (ScalarValue::Int8(Some(1)), ScalarValue::Int16(None)), // Unsupported types - ( - ScalarValue::Utf8(Some("foo".to_string())), - ScalarValue::Utf8(Some("bar".to_string())), - ), + (ScalarValue::from("foo"), ScalarValue::from("bar")), ( ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(false)), diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index 12d4f516b4d0..fecab8835e50 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -25,7 +25,8 @@ use arrow::compute; use arrow::compute::{partition, SortColumn, SortOptions}; use arrow::datatypes::{Field, SchemaRef, UInt32Type}; use arrow::record_batch::RecordBatch; -use arrow_array::{Array, LargeListArray, ListArray}; +use arrow_array::{Array, LargeListArray, ListArray, RecordBatchOptions}; +use arrow_schema::DataType; use sqlparser::ast::Ident; use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; @@ -89,8 +90,12 @@ pub fn get_record_batch_at_indices( indices: &PrimitiveArray<UInt32Type>, ) -> Result<RecordBatch> { let new_columns = get_arrayref_at_indices(record_batch.columns(), indices)?; - RecordBatch::try_new(record_batch.schema(), new_columns) -
.map_err(DataFusionError::ArrowError) + RecordBatch::try_new_with_options( + record_batch.schema(), + new_columns, + &RecordBatchOptions::new().with_row_count(Some(indices.len())), + ) + .map_err(DataFusionError::ArrowError) } /// This function compares two tuples depending on the given sort options. @@ -134,7 +139,7 @@ pub fn bisect( ) -> Result<usize> { let low: usize = 0; let high: usize = item_columns - .get(0) + .first() .ok_or_else(|| { DataFusionError::Internal("Column array shouldn't be empty".to_string()) })? @@ -185,7 +190,7 @@ pub fn linear_search( ) -> Result<usize> { let low: usize = 0; let high: usize = item_columns - .get(0) + .first() .ok_or_else(|| { DataFusionError::Internal("Column array shouldn't be empty".to_string()) })? @@ -402,6 +407,37 @@ pub fn arrays_into_list_array( )) } +/// Get the base type of a data type. +/// +/// Example +/// ``` +/// use arrow::datatypes::{DataType, Field}; +/// use datafusion_common::utils::base_type; +/// use std::sync::Arc; +/// +/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); +/// assert_eq!(base_type(&data_type), DataType::Int32); +/// +/// let data_type = DataType::Int32; +/// assert_eq!(base_type(&data_type), DataType::Int32); +/// ``` +pub fn base_type(data_type: &DataType) -> DataType { + if let DataType::List(field) = data_type { + base_type(field.data_type()) + } else { + data_type.to_owned() + } +} + +/// Compute the number of dimensions in a list data type. +pub fn list_ndims(data_type: &DataType) -> u64 { + if let DataType::List(field) = data_type { + 1 + list_ndims(field.data_type()) + } else { + 0 + } +} + /// An extension trait for smart pointers. Provides an interface to get a /// raw pointer to the data (with metadata stripped away). /// diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 0b7aa1509820..7caf91e24f2f 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -167,3 +167,7 @@ name = "sort" [[bench]] harness = false name = "topk_aggregate" + +[[bench]] +harness = false +name = "array_expression" diff --git a/datafusion/core/benches/array_expression.rs b/datafusion/core/benches/array_expression.rs new file mode 100644 index 000000000000..95bc93e0e353 --- /dev/null +++ b/datafusion/core/benches/array_expression.rs @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
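For context on the `[[bench]]` entry registered above: `harness = false` is required because criterion supplies its own `main` through `criterion_main!`, replacing the default libtest harness. A minimal sketch of that harness shape (hypothetical `bench_add` function, not the benchmark added in this diff), which would run via `cargo bench --bench <name>`:

```rust
use criterion::{black_box, criterion_group, criterion_main, Criterion};

fn bench_add(c: &mut Criterion) {
    c.bench_function("add", |b| {
        // `black_box` keeps the optimizer from folding the work away,
        // so the timing reflects the actual computation.
        b.iter(|| black_box(2_i64) + black_box(40_i64))
    });
}

criterion_group!(benches, bench_add);
criterion_main!(benches);
```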
+ +#[macro_use] +extern crate criterion; +extern crate arrow; +extern crate datafusion; + +mod data_utils; +use crate::criterion::Criterion; +use arrow_array::cast::AsArray; +use arrow_array::types::Int64Type; +use arrow_array::{ArrayRef, Int64Array, ListArray}; +use datafusion_physical_expr::array_expressions; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + // Construct large arrays for benchmarking + + let array_len = 100000000; + + let array = (0..array_len).map(|_| Some(2_i64)).collect::<Vec<_>>(); + let list_array = ListArray::from_iter_primitive::<Int64Type, _, _>(vec![ + Some(array.clone()), + Some(array.clone()), + Some(array), + ]); + let from_array = Int64Array::from_value(2, 3); + let to_array = Int64Array::from_value(-2, 3); + + let args = vec![ + Arc::new(list_array) as ArrayRef, + Arc::new(from_array) as ArrayRef, + Arc::new(to_array) as ArrayRef, + ]; + + let array = (0..array_len).map(|_| Some(-2_i64)).collect::<Vec<_>>(); + let expected_array = ListArray::from_iter_primitive::<Int64Type, _, _>(vec![ + Some(array.clone()), + Some(array.clone()), + Some(array), + ]); + + // Benchmark array functions + + c.bench_function("array_replace", |b| { + b.iter(|| { + assert_eq!( + array_expressions::array_replace_all(args.as_slice()) + .unwrap() + .as_list::<i32>(), + criterion::black_box(&expected_array) + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/core/benches/sort_limit_query_sql.rs b/datafusion/core/benches/sort_limit_query_sql.rs index efed5a04e7a5..cfd4b8bc4bba 100644 --- a/datafusion/core/benches/sort_limit_query_sql.rs +++ b/datafusion/core/benches/sort_limit_query_sql.rs @@ -99,7 +99,7 @@ fn create_context() -> Arc<Mutex<SessionContext>> { ctx_holder.lock().push(Arc::new(Mutex::new(ctx))) }); - let ctx = ctx_holder.lock().get(0).unwrap().clone(); + let ctx = ctx_holder.lock().first().unwrap().clone(); ctx } diff --git a/datafusion/core/benches/sql_query_with_io.rs b/datafusion/core/benches/sql_query_with_io.rs index 1f9b4dc6ccf7..c7a838385bd6 100644 --- a/datafusion/core/benches/sql_query_with_io.rs +++ b/datafusion/core/benches/sql_query_with_io.rs @@ -93,10 +93,9 @@ async fn setup_files(store: Arc<dyn ObjectStore>) { for partition in 0..TABLE_PARTITIONS { for file in 0..PARTITION_FILES { let data = create_parquet_file(&mut rng, file * FILE_ROWS); - let location = Path::try_from(format!( + let location = Path::from(format!( "{table_name}/partition={partition}/{file}.parquet" - )) - .unwrap(); + )); store.put(&location, data).await.unwrap(); } } diff --git a/datafusion/core/src/catalog/listing_schema.rs b/datafusion/core/src/catalog/listing_schema.rs index 0d5c49f377d0..c3c682689542 100644 --- a/datafusion/core/src/catalog/listing_schema.rs +++ b/datafusion/core/src/catalog/listing_schema.rs @@ -149,6 +149,7 @@ impl ListingSchemaProvider { unbounded: false, options: Default::default(), constraints: Constraints::empty(), + column_defaults: Default::default(), }, ) .await?; diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 89e82fa952bb..af335cd790c3 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -23,43 +23,43 @@ mod parquet; use std::any::Any; use std::sync::Arc; +use crate::arrow::datatypes::{Schema, SchemaRef}; +use crate::arrow::record_batch::RecordBatch; +use crate::arrow::util::pretty; +use crate::datasource::{provider_as_source, MemTable, TableProvider}; +use crate::error::Result; +use crate::execution::{ + context::{SessionState, TaskContext}, + FunctionRegistry, +}; +use
crate::logical_expr::utils::find_window_exprs; +use crate::logical_expr::{ + col, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Partitioning, TableType, +}; +use crate::physical_plan::{ + collect, collect_partitioned, execute_stream, execute_stream_partitioned, + ExecutionPlan, SendableRecordBatchStream, +}; +use crate::prelude::SessionContext; + use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; use arrow::compute::{cast, concat}; use arrow::csv::WriterBuilder; use arrow::datatypes::{DataType, Field}; -use async_trait::async_trait; use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ - DataFusionError, FileType, FileTypeWriterOptions, SchemaError, UnnestOptions, + Column, DFSchema, DataFusionError, FileType, FileTypeWriterOptions, ParamValues, + SchemaError, UnnestOptions, }; use datafusion_expr::dml::CopyOptions; - -use datafusion_common::{Column, DFSchema, ScalarValue}; use datafusion_expr::{ avg, count, is_null, max, median, min, stddev, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, }; -use crate::arrow::datatypes::Schema; -use crate::arrow::datatypes::SchemaRef; -use crate::arrow::record_batch::RecordBatch; -use crate::arrow::util::pretty; -use crate::datasource::{provider_as_source, MemTable, TableProvider}; -use crate::error::Result; -use crate::execution::{ - context::{SessionState, TaskContext}, - FunctionRegistry, -}; -use crate::logical_expr::{ - col, utils::find_window_exprs, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, - Partitioning, TableType, -}; -use crate::physical_plan::SendableRecordBatchStream; -use crate::physical_plan::{collect, collect_partitioned}; -use crate::physical_plan::{execute_stream, execute_stream_partitioned, ExecutionPlan}; -use crate::prelude::SessionContext; +use async_trait::async_trait; /// Contains options that control how data is /// written out from a DataFrame @@ -1227,11 +1227,32 @@ impl DataFrame { /// ], /// &results /// ); + /// // Note you can also provide named parameters + /// let results = ctx + /// .sql("SELECT a FROM example WHERE b = $my_param") + /// .await? + /// // replace $my_param with value 2 + /// // Note you can also use a HashMap as well + /// .with_param_values(vec![ + /// ("my_param", ScalarValue::from(2i64)) + /// ])? 
+ /// .collect() + /// .await?; + /// assert_batches_eq!( + /// &[ + /// "+---+", + /// "| a |", + /// "+---+", + /// "| 1 |", + /// "+---+", + /// ], + /// &results + /// ); /// # Ok(()) /// # } /// ``` - pub fn with_param_values(self, param_values: Vec<ScalarValue>) -> Result<Self> { - let plan = self.plan.with_param_values(param_values)?; + pub fn with_param_values(self, query_values: impl Into<ParamValues>) -> Result<Self> { + let plan = self.plan.with_param_values(query_values)?; Ok(Self::new(self.session_state, plan)) } @@ -1321,24 +1342,43 @@ impl TableProvider for DataFrameTableProvider { mod tests { use std::vec; - use arrow::array::Int32Array; - use arrow::datatypes::DataType; + use super::*; + use crate::execution::context::SessionConfig; + use crate::physical_plan::{ColumnarValue, Partitioning, PhysicalExpr}; + use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name}; + use crate::{assert_batches_sorted_eq, execution::context::SessionContext}; + use arrow::array::{self, Int32Array}; + use arrow::datatypes::DataType; + use datafusion_common::{Constraint, Constraints, ScalarValue}; use datafusion_expr::{ avg, cast, count, count_distinct, create_udf, expr, lit, max, min, sum, - BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame, - WindowFunction, + BinaryExpr, BuiltInWindowFunction, Operator, ScalarFunctionImplementation, + Volatility, WindowFrame, WindowFunction, }; use datafusion_physical_expr::expressions::Column; - - use crate::execution::context::SessionConfig; - use crate::physical_plan::ColumnarValue; - use crate::physical_plan::Partitioning; - use crate::physical_plan::PhysicalExpr; - use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name}; - use crate::{assert_batches_sorted_eq, execution::context::SessionContext}; - - use super::*; + use datafusion_physical_plan::get_plan_string; + + pub fn table_with_constraints() -> Arc<dyn TableProvider> { + let dual_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + ])); + let batch = RecordBatch::try_new( + dual_schema.clone(), + vec![ + Arc::new(array::Int32Array::from(vec![1])), + Arc::new(array::StringArray::from(vec!["a"])), + ], + ) + .unwrap(); + let provider = MemTable::try_new(dual_schema, vec![vec![batch]]) + .unwrap() + .with_constraints(Constraints::new_unverified(vec![Constraint::PrimaryKey( + vec![0], + )])); + Arc::new(provider) + } async fn assert_logical_expr_schema_eq_physical_expr_schema( df: DataFrame, @@ -1535,6 +1575,294 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_aggregate_with_pk() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let table1 = table_with_constraints(); + let df = ctx.read_table(table1)?; + let col_id = Expr::Column(datafusion_common::Column { + relation: None, + name: "id".to_string(), + }); + let col_name = Expr::Column(datafusion_common::Column { + relation: None, + name: "name".to_string(), + }); + + // group by contains id column + let group_expr = vec![col_id.clone()]; + let aggr_expr = vec![]; + let df = df.aggregate(group_expr, aggr_expr)?; + + // expr list contains id, name + let expr_list = vec![col_id, col_name]; + let df = df.select(expr_list)?; + let physical_plan = df.clone().create_physical_plan().await?; + let expected = vec![ + "AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ]; + //
Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + // Since id and name are functionally dependent, we can use name among expressions + // even if it is not part of the group by expression. + let df_results = collect(physical_plan, ctx.task_ctx()).await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + ["+----+------+", + "| id | name |", + "+----+------+", + "| 1 | a |", + "+----+------+",], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_with_pk2() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let table1 = table_with_constraints(); + let df = ctx.read_table(table1)?; + let col_id = Expr::Column(datafusion_common::Column { + relation: None, + name: "id".to_string(), + }); + let col_name = Expr::Column(datafusion_common::Column { + relation: None, + name: "name".to_string(), + }); + + // group by contains id column + let group_expr = vec![col_id.clone()]; + let aggr_expr = vec![]; + let df = df.aggregate(group_expr, aggr_expr)?; + + let condition1 = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col_id.clone()), + Operator::Eq, + Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))), + )); + let condition2 = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col_name), + Operator::Eq, + Box::new(Expr::Literal(ScalarValue::Utf8(Some("a".to_string())))), + )); + // Predicate refers to the id and name fields + let predicate = Expr::BinaryExpr(BinaryExpr::new( + Box::new(condition1), + Operator::And, + Box::new(condition2), + )); + let df = df.filter(predicate)?; + let physical_plan = df.clone().create_physical_plan().await?; + + let expected = vec![ + "CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: id@0 = 1 AND name@1 = a", + " AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ]; + // Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + // Since id and name are functionally dependent, we can use name among expressions + // even if it is not part of the group by expression.
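+ // Execute the plan and verify that the single row (id = 1, name = "a") survives the filter.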
+ let df_results = collect(physical_plan, ctx.task_ctx()).await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + ["+----+------+", + "| id | name |", + "+----+------+", + "| 1 | a |", + "+----+------+",], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_with_pk3() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let table1 = table_with_constraints(); + let df = ctx.read_table(table1)?; + let col_id = Expr::Column(datafusion_common::Column { + relation: None, + name: "id".to_string(), + }); + let col_name = Expr::Column(datafusion_common::Column { + relation: None, + name: "name".to_string(), + }); + + // group by contains id column + let group_expr = vec![col_id.clone()]; + let aggr_expr = vec![]; + // group by id, + let df = df.aggregate(group_expr, aggr_expr)?; + + let condition1 = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col_id.clone()), + Operator::Eq, + Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))), + )); + // Predicate refers to id field + let predicate = condition1; + // id=1 + let df = df.filter(predicate)?; + // Select expression refers to the id and name columns. + // id, name + let df = df.select(vec![col_id.clone(), col_name.clone()])?; + let physical_plan = df.clone().create_physical_plan().await?; + + let expected = vec![ + "CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: id@0 = 1", + " AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ]; + // Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + // Since id and name are functionally dependent, we can use name among expressions + // even if it is not part of the group by expression. + let df_results = collect(physical_plan, ctx.task_ctx()).await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + ["+----+------+", + "| id | name |", + "+----+------+", + "| 1 | a |", + "+----+------+",], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_with_pk4() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let table1 = table_with_constraints(); + let df = ctx.read_table(table1)?; + let col_id = Expr::Column(datafusion_common::Column { + relation: None, + name: "id".to_string(), + }); + + // group by contains id column + let group_expr = vec![col_id.clone()]; + let aggr_expr = vec![]; + // group by id, + let df = df.aggregate(group_expr, aggr_expr)?; + + let condition1 = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col_id.clone()), + Operator::Eq, + Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))), + )); + // Predicate refers to id field + let predicate = condition1; + // id=1 + let df = df.filter(predicate)?; + // Select expression refers to id column. + // id + let df = df.select(vec![col_id.clone()])?; + let physical_plan = df.clone().create_physical_plan().await?; + + // In this case aggregate shouldn't be expanded, since these + // columns are not used.
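+ // The expected optimized plan therefore groups on "id" alone: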
+ let expected = vec![ + "CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: id@0 = 1", + " AggregateExec: mode=Single, gby=[id@0 as id], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ]; + // Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + // Since id and name are functionally dependent, we can use name among expressions + // even if it is not part of the group by expression. + let df_results = collect(physical_plan, ctx.task_ctx()).await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!([ + "+----+", + "| id |", + "+----+", + "| 1 |", + "+----+",], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_alias() -> Result<()> { + let df = test_table().await?; + + let df = df + // GROUP BY `c2 + 1` + .aggregate(vec![col("c2") + lit(1)], vec![])? + // SELECT `c2 + 1` as c2 + .select(vec![(col("c2") + lit(1)).alias("c2")])? + // GROUP BY c2 as "c2" (alias in expr is not supported by SQL) + .aggregate(vec![col("c2").alias("c2")], vec![])?; + + let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!([ + "+----+", + "| c2 |", + "+----+", + "| 2 |", + "| 3 |", + "| 4 |", + "| 5 |", + "| 6 |", + "+----+", + ], + &df_results + ); + + Ok(()) + } + #[tokio::test] async fn test_distinct() -> Result<()> { let t = test_table().await?; diff --git a/datafusion/core/src/datasource/empty.rs b/datafusion/core/src/datasource/empty.rs index 77160aa5d1c0..5100987520ee 100644 --- a/datafusion/core/src/datasource/empty.rs +++ b/datafusion/core/src/datasource/empty.rs @@ -77,7 +77,7 @@ impl TableProvider for EmptyTable { // even though there is no data, projections apply let projected_schema = project_schema(&self.schema, projection)?; Ok(Arc::new( - EmptyExec::new(false, projected_schema).with_partitions(self.partitions), + EmptyExec::new(projected_schema).with_partitions(self.partitions), )) } } diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index cf6b87408107..09e54558f12e 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -1803,8 +1803,8 @@ mod tests { // there is only one row group in one file.
assert_eq!(page_index.len(), 1); assert_eq!(offset_index.len(), 1); - let page_index = page_index.get(0).unwrap(); - let offset_index = offset_index.get(0).unwrap(); + let page_index = page_index.first().unwrap(); + let offset_index = offset_index.first().unwrap(); // 13 col in one row group assert_eq!(page_index.len(), 13); diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/core/src/datasource/file_format/write/demux.rs index 27c65dd459ec..fa4ed8437015 100644 --- a/datafusion/core/src/datasource/file_format/write/demux.rs +++ b/datafusion/core/src/datasource/file_format/write/demux.rs @@ -264,12 +264,9 @@ async fn hive_style_partitions_demuxer( // TODO: upstream RecordBatch::take to arrow-rs let take_indices = builder.finish(); let struct_array: StructArray = rb.clone().into(); - let parted_batch = RecordBatch::try_from( + let parted_batch = RecordBatch::from( arrow::compute::take(&struct_array, &take_indices, None)?.as_struct(), - ) - .map_err(|_| { - DataFusionError::Internal("Unexpected error partitioning batch!".into()) - })?; + ); // Get or create channel for this batch let part_tx = match value_map.get_mut(&part_key) { diff --git a/datafusion/core/src/datasource/function.rs b/datafusion/core/src/datasource/function.rs new file mode 100644 index 000000000000..2fd352ee4eb3 --- /dev/null +++ b/datafusion/core/src/datasource/function.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
A table that uses a function to generate data + +use super::TableProvider; + +use datafusion_common::Result; +use datafusion_expr::Expr; + +use std::sync::Arc; + +/// A trait for table function implementations +pub trait TableFunctionImpl: Sync + Send { + /// Create a table provider + fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>>; +} + +/// A table that uses a function to generate data +pub struct TableFunction { + /// Name of the table function + name: String, + /// Function implementation + fun: Arc<dyn TableFunctionImpl>, +} + +impl TableFunction { + /// Create a new table function + pub fn new(name: String, fun: Arc<dyn TableFunctionImpl>) -> Self { + Self { name, fun } + } + + /// Get the name of the table function + pub fn name(&self) -> &str { + &self.name + } + + /// Get the function implementation and generate a table + pub fn create_table_provider(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> { + self.fun.call(args) + } +} diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index a4505cf62d6a..3536c098bd76 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -526,19 +526,13 @@ mod tests { f1.object_meta.location.as_ref(), "tablepath/mypartition=val1/file.parquet" ); - assert_eq!( - &f1.partition_values, - &[ScalarValue::Utf8(Some(String::from("val1"))),] - ); + assert_eq!(&f1.partition_values, &[ScalarValue::from("val1")]); let f2 = &pruned[1]; assert_eq!( f2.object_meta.location.as_ref(), "tablepath/mypartition=val1/other=val3/file.parquet" ); - assert_eq!( - f2.partition_values, - &[ScalarValue::Utf8(Some(String::from("val1"))),] - ); + assert_eq!(f2.partition_values, &[ScalarValue::from("val1"),]); } #[tokio::test] @@ -579,10 +573,7 @@ mod tests { ); assert_eq!( &f1.partition_values, - &[ - ScalarValue::Utf8(Some(String::from("p1v2"))), - ScalarValue::Utf8(Some(String::from("p2v1"))) - ] + &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1"),] ); let f2 = &pruned[1]; assert_eq!( @@ -591,10 +582,7 @@ mod tests { ); assert_eq!( &f2.partition_values, - &[ - ScalarValue::Utf8(Some(String::from("p1v2"))), - ScalarValue::Utf8(Some(String::from("p2v1"))) - ] + &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1")] ); } diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index a3be57db3a83..0ce1b43fe456 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -17,6 +17,7 @@ //! The table implementation. +use std::collections::HashMap; use std::str::FromStr; use std::{any::Any, sync::Arc}; @@ -156,7 +157,7 @@ impl ListingTableConfig { /// Infer `ListingOptions` based on `table_path` suffix. pub async fn infer_options(self, state: &SessionState) -> Result<Self> { - let store = if let Some(url) = self.table_paths.get(0) { + let store = if let Some(url) = self.table_paths.first() { state.runtime_env().object_store(url)? } else { return Ok(self); @@ -164,7 +165,7 @@ impl ListingTableConfig { let file = self .table_paths - .get(0) + .first() .unwrap() .list_all_files(state, store.as_ref(), "") .await? @@ -190,7 +191,7 @@ impl ListingTableConfig { pub async fn infer_schema(self, state: &SessionState) -> Result<Self> { match self.options { Some(options) => { - let schema = if let Some(url) = self.table_paths.get(0) { + let schema = if let Some(url) = self.table_paths.first() { options.infer_schema(state, url).await?
} else { Arc::new(Schema::empty()) @@ -489,7 +490,7 @@ impl ListingOptions { /// /// # Features /// -/// 1. Merges schemas if the files have compatible but not indentical schemas +/// 1. Merges schemas if the files have compatible but not identical schemas /// /// 2. Hive-style partitioning support, where a path such as /// `/files/date=1/1/2022/data.parquet` is injected as a `date` column. @@ -558,6 +559,7 @@ pub struct ListingTable { collected_statistics: FileStatisticsCache, infinite_source: bool, constraints: Constraints, + column_defaults: HashMap<String, Expr>, } impl ListingTable { @@ -596,6 +598,7 @@ collected_statistics: Arc::new(DefaultFileStatisticsCache::default()), infinite_source, constraints: Constraints::empty(), + column_defaults: HashMap::new(), }; Ok(table) @@ -607,6 +610,15 @@ self } + /// Assign column defaults + pub fn with_column_defaults( + mut self, + column_defaults: HashMap<String, Expr>, + ) -> Self { + self.column_defaults = column_defaults; + self + } + /// Set the [`FileStatisticsCache`] used to cache parquet file statistics. /// /// Setting a statistics cache on the `SessionContext` can avoid refetching statistics @@ -673,7 +685,7 @@ impl TableProvider for ListingTable { if partitioned_file_lists.is_empty() { let schema = self.schema(); let projected_schema = project_schema(&schema, projection)?; - return Ok(Arc::new(EmptyExec::new(false, projected_schema))); + return Ok(Arc::new(EmptyExec::new(projected_schema))); } // extract types of partition columns @@ -698,10 +710,10 @@ None }; - let object_store_url = if let Some(url) = self.table_paths.get(0) { + let object_store_url = if let Some(url) = self.table_paths.first() { url.object_store() } else { - return Ok(Arc::new(EmptyExec::new(false, Arc::new(Schema::empty())))); + return Ok(Arc::new(EmptyExec::new(Arc::new(Schema::empty())))); }; // create the execution plan self.options @@ -823,7 +835,7 @@ // Multiple sort orders in outer vec are equivalent, so we pass only the first one let ordering = self .try_create_output_ordering()? - .get(0) + .first() .ok_or(DataFusionError::Internal( "Expected ListingTable to have a sort order, but none found!".into(), ))? @@ -844,6 +856,10 @@ .create_writer_physical_plan(input, state, config, order_requirements) .await } + + fn get_column_default(&self, column: &str) -> Option<&Expr> { + self.column_defaults.get(column) + } } impl ListingTable { @@ -856,7 +872,7 @@ filters: &'a [Expr], limit: Option<usize>, ) -> Result<(Vec<Vec<PartitionedFile>>, Statistics)> { - let store = if let Some(url) = self.table_paths.get(0) { + let store = if let Some(url) = self.table_paths.first() { ctx.runtime_env().object_store(url)?
} else { return Ok((vec![], Statistics::new_unknown(&self.file_schema))); diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index f70a82035108..96436306c641 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -228,7 +228,8 @@ impl TableProviderFactory for ListingTableFactory { .with_cache(state.runtime_env().cache_manager.get_file_statistic_cache()); let table = provider .with_definition(cmd.definition.clone()) - .with_constraints(cmd.constraints.clone()); + .with_constraints(cmd.constraints.clone()) + .with_column_defaults(cmd.column_defaults.clone()); Ok(Arc::new(table)) } } @@ -279,6 +280,7 @@ mod tests { unbounded: false, options: HashMap::new(), constraints: Constraints::empty(), + column_defaults: HashMap::new(), }; let table_provider = factory.create(&state, &cmd).await.unwrap(); let listing_table = table_provider diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index a841518d9c8f..7c044b29366d 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -19,9 +19,9 @@ use datafusion_physical_plan::metrics::MetricsSet; use futures::StreamExt; -use hashbrown::HashMap; use log::debug; use std::any::Any; +use std::collections::HashMap; use std::fmt::{self, Debug}; use std::sync::Arc; diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 45f9bee6a58b..2e516cc36a01 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -23,6 +23,7 @@ pub mod avro_to_arrow; pub mod default_table_source; pub mod empty; pub mod file_format; +pub mod function; pub mod listing; pub mod listing_table_factory; pub mod memory; diff --git a/datafusion/core/src/datasource/physical_plan/avro.rs b/datafusion/core/src/datasource/physical_plan/avro.rs index b97f162fd2f5..885b4c5d3911 100644 --- a/datafusion/core/src/datasource/physical_plan/avro.rs +++ b/datafusion/core/src/datasource/physical_plan/avro.rs @@ -406,8 +406,7 @@ mod tests { .await?; let mut partitioned_file = PartitionedFile::from(meta); - partitioned_file.partition_values = - vec![ScalarValue::Utf8(Some("2021-10-26".to_owned()))]; + partitioned_file.partition_values = vec![ScalarValue::from("2021-10-26")]; let avro_exec = AvroExec::new(FileScanConfig { // select specific columns of the files as well as the partitioning diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 75aa343ffbfc..816a82543bab 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -872,8 +872,7 @@ mod tests { // Add partition columns config.table_partition_cols = vec![Field::new("date", DataType::Utf8, false)]; - config.file_groups[0][0].partition_values = - vec![ScalarValue::Utf8(Some("2021-10-26".to_owned()))]; + config.file_groups[0][0].partition_values = vec![ScalarValue::from("2021-10-26")]; // We should be able to project on the partition column // Which is supposed to be after the file fields diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index 68e996391cc3..d308397ab6e2 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ 
b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -654,15 +654,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "2021".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "10".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "26".to_owned(), - ))), + wrap_partition_value_in_dict(ScalarValue::from("2021")), + wrap_partition_value_in_dict(ScalarValue::from("10")), + wrap_partition_value_in_dict(ScalarValue::from("26")), ], ) .expect("Projection of partition columns into record batch failed"); @@ -688,15 +682,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "2021".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "10".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "27".to_owned(), - ))), + wrap_partition_value_in_dict(ScalarValue::from("2021")), + wrap_partition_value_in_dict(ScalarValue::from("10")), + wrap_partition_value_in_dict(ScalarValue::from("27")), ], ) .expect("Projection of partition columns into record batch failed"); @@ -724,15 +712,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "2021".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "10".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "28".to_owned(), - ))), + wrap_partition_value_in_dict(ScalarValue::from("2021")), + wrap_partition_value_in_dict(ScalarValue::from("10")), + wrap_partition_value_in_dict(ScalarValue::from("28")), ], ) .expect("Projection of partition columns into record batch failed"); @@ -758,9 +740,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - ScalarValue::Utf8(Some("2021".to_owned())), - ScalarValue::Utf8(Some("10".to_owned())), - ScalarValue::Utf8(Some("26".to_owned())), + ScalarValue::from("2021"), + ScalarValue::from("10"), + ScalarValue::from("26"), ], ) .expect("Projection of partition columns into record batch failed"); diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 73dcb32ac81f..9c3b523a652c 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -357,9 +357,9 @@ mod tests { ) .unwrap(); let meta = file_groups - .get(0) + .first() .unwrap() - .get(0) + .first() .unwrap() .clone() .object_meta; @@ -391,9 +391,9 @@ mod tests { ) .unwrap(); let path = file_groups - .get(0) + .first() .unwrap() - .get(0) + .first() .unwrap() .object_meta .location diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 4cf115d03a9b..14e550eab1d5 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -135,7 +135,7 @@ impl DisplayAs for FileScanConfig { write!(f, ", infinite_source=true")?; } - if let Some(ordering) = orderings.get(0) { + if let Some(ordering) = orderings.first() { if !ordering.is_empty() { let start = if orderings.len() == 1 { ", output_ordering=" diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs 
b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 95aae71c779e..641b7bbb1596 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -468,8 +468,10 @@ impl FileOpener for ParquetOpener { ParquetRecordBatchStreamBuilder::new_with_options(reader, options) .await?; + let file_schema = builder.schema().clone(); + let (schema_mapping, adapted_projections) = - schema_adapter.map_schema(builder.schema())?; + schema_adapter.map_schema(&file_schema)?; // let predicate = predicate.map(|p| reassign_predicate_columns(p, builder.schema(), true)).transpose()?; let mask = ProjectionMask::roots( @@ -481,8 +483,8 @@ impl FileOpener for ParquetOpener { if let Some(predicate) = pushdown_filters.then_some(predicate).flatten() { let row_filter = row_filter::build_row_filter( &predicate, - builder.schema().as_ref(), - table_schema.as_ref(), + &file_schema, + &table_schema, builder.metadata(), reorder_predicates, &file_metrics, @@ -507,6 +509,7 @@ impl FileOpener for ParquetOpener { let file_metadata = builder.metadata().clone(); let predicate = pruning_predicate.as_ref().map(|p| p.as_ref()); let mut row_groups = row_groups::prune_row_groups_by_statistics( + &file_schema, builder.parquet_schema(), file_metadata.row_groups(), file_range, @@ -1603,11 +1606,11 @@ mod tests { let partitioned_file = PartitionedFile { object_meta: meta, partition_values: vec![ - ScalarValue::Utf8(Some("2021".to_owned())), + ScalarValue::from("2021"), ScalarValue::UInt8(Some(10)), ScalarValue::Dictionary( Box::new(DataType::UInt16), - Box::new(ScalarValue::Utf8(Some("26".to_owned()))), + Box::new(ScalarValue::from("26")), ), ], range: None, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 0ab2046097c4..7c3f7d9384ab 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -55,6 +55,7 @@ use super::ParquetFileMetrics; /// Note: This method currently ignores ColumnOrder /// pub(crate) fn prune_row_groups_by_statistics( + arrow_schema: &Schema, parquet_schema: &SchemaDescriptor, groups: &[RowGroupMetaData], range: Option, @@ -80,7 +81,7 @@ pub(crate) fn prune_row_groups_by_statistics( let pruning_stats = RowGroupPruningStatistics { parquet_schema, row_group_metadata: metadata, - arrow_schema: predicate.schema().as_ref(), + arrow_schema, }; match predicate.prune(&pruning_stats) { Ok(values) => { @@ -350,6 +351,7 @@ mod tests { use arrow::datatypes::Schema; use arrow::datatypes::{DataType, Field}; use datafusion_common::{config::ConfigOptions, TableReference, ToDFSchema}; + use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{ builder::LogicalTableSource, cast, col, lit, AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF, @@ -415,11 +417,11 @@ mod tests { fn row_group_pruning_predicate_simple_expr() { use datafusion_expr::{col, lit}; // int > 1 => c1_max > 1 - let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); let expr = col("c1").gt(lit(15)); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let field = PrimitiveTypeField::new("c1", 
PhysicalType::INT32); let schema_descr = get_test_schema_descr(vec![field]); @@ -435,6 +437,7 @@ let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, &[rgm1, rgm2], None, @@ -449,11 +452,11 @@ fn row_group_pruning_predicate_missing_stats() { use datafusion_expr::{col, lit}; // int > 1 => c1_max > 1 - let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); let expr = col("c1").gt(lit(15)); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32); let schema_descr = get_test_schema_descr(vec![field]); @@ -470,6 +473,7 @@ // is null / undefined so the first row group can't be filtered out assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, &[rgm1, rgm2], None, @@ -518,6 +522,7 @@ // when conditions are joined using AND assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, groups, None, @@ -531,12 +536,13 @@ // this bypasses the entire predicate expression and no row groups are filtered out let expr = col("c1").gt(lit(15)).or(col("c2").rem(lit(2)).eq(lit(0))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = PruningPredicate::try_new(expr, schema).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); // if conditions in predicate are joined with OR and an unsupported expression is used // this bypasses the entire predicate expression and no row groups are filtered out assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, groups, None, @@ -547,6 +553,64 @@ ); } + #[test] + fn row_group_pruning_predicate_file_schema() { + use datafusion_expr::{col, lit}; + // test row group predicate when file schema is different than table schema + // c1 > 0 + let table_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ])); + let expr = col("c1").gt(lit(0)); + let expr = logical2physical(&expr, &table_schema); + let pruning_predicate = + PruningPredicate::try_new(expr, table_schema.clone()).unwrap(); + + // Model a file schema's column order c2 then c1, which is the opposite + // of the table schema + let file_schema = Arc::new(Schema::new(vec![ + Field::new("c2", DataType::Int32, false), + Field::new("c1", DataType::Int32, false), + ])); + let schema_descr = get_test_schema_descr(vec![ + PrimitiveTypeField::new("c2", PhysicalType::INT32), + PrimitiveTypeField::new("c1", PhysicalType::INT32), + ]); + // rg1 has c2 less than zero, c1 greater than zero + let rgm1 = get_row_group_meta_data( + &schema_descr, + vec![ + ParquetStatistics::int32(Some(-10), Some(-1), None, 0, false), // c2 + ParquetStatistics::int32(Some(1), Some(10), None, 0, false), + ], + ); + // rg2 has c2 greater than zero, c1 less than zero + let rgm2 = get_row_group_meta_data( + &schema_descr, + vec![ + ParquetStatistics::int32(Some(1), Some(10), None, 0, false), + ParquetStatistics::int32(Some(-10), Some(-1), None, 0, false), + ], + ); + + let metrics = parquet_file_metrics(); + let groups = &[rgm1, rgm2]; + // the first row group should be left because c1 is greater than zero + // the second should be filtered
out because c1 is less than zero + assert_eq!( + prune_row_groups_by_statistics( + &file_schema, // NB must be file schema, not table_schema + &schema_descr, + groups, + None, + Some(&pruning_predicate), + &metrics + ), + vec![0] + ); + } + fn gen_row_group_meta_data_for_pruning_predicate() -> Vec { let schema_descr = get_test_schema_descr(vec![ PrimitiveTypeField::new("c1", PhysicalType::INT32), @@ -580,13 +644,14 @@ mod tests { let schema_descr = arrow_to_parquet_schema(&schema).unwrap(); let expr = col("c1").gt(lit(15)).and(col("c2").is_null()); let expr = logical2physical(&expr, &schema); - let pruning_predicate = PruningPredicate::try_new(expr, schema).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let groups = gen_row_group_meta_data_for_pruning_predicate(); let metrics = parquet_file_metrics(); // First row group was filtered out because it contains no null value on "c2". assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, &groups, None, @@ -612,7 +677,7 @@ mod tests { .gt(lit(15)) .and(col("c2").eq(lit(ScalarValue::Boolean(None)))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = PruningPredicate::try_new(expr, schema).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let groups = gen_row_group_meta_data_for_pruning_predicate(); let metrics = parquet_file_metrics(); @@ -620,6 +685,7 @@ mod tests { // pass predicates. Ideally these should both be false assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, &groups, None, @@ -638,8 +704,11 @@ mod tests { // INT32: c1 > 5, the c1 is decimal(9,2) // The type of scalar value if decimal(9,2), don't need to do cast - let schema = - Schema::new(vec![Field::new("c1", DataType::Decimal128(9, 2), false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Decimal128(9, 2), + false, + )])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32) .with_logical_type(LogicalType::Decimal { scale: 2, @@ -650,8 +719,7 @@ mod tests { let schema_descr = get_test_schema_descr(vec![field]); let expr = col("c1").gt(lit(ScalarValue::Decimal128(Some(500), 9, 2))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let rgm1 = get_row_group_meta_data( &schema_descr, // [1.00, 6.00] @@ -679,6 +747,7 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, &[rgm1, rgm2, rgm3], None, @@ -692,8 +761,11 @@ mod tests { // The c1 type is decimal(9,0) in the parquet file, and the type of scalar is decimal(5,2). 
// We should convert all type to the coercion type, which is decimal(11,2) // The decimal of arrow is decimal(5,2), the decimal of parquet is decimal(9,0) - let schema = - Schema::new(vec![Field::new("c1", DataType::Decimal128(9, 0), false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Decimal128(9, 0), + false, + )])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32) .with_logical_type(LogicalType::Decimal { @@ -708,8 +780,7 @@ mod tests { Decimal128(11, 2), )); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let rgm1 = get_row_group_meta_data( &schema_descr, // [100, 600] @@ -743,6 +814,7 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, &[rgm1, rgm2, rgm3, rgm4], None, @@ -753,8 +825,11 @@ mod tests { ); // INT64: c1 < 5, the c1 is decimal(18,2) - let schema = - Schema::new(vec![Field::new("c1", DataType::Decimal128(18, 2), false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Decimal128(18, 2), + false, + )])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT64) .with_logical_type(LogicalType::Decimal { scale: 2, @@ -765,8 +840,7 @@ mod tests { let schema_descr = get_test_schema_descr(vec![field]); let expr = col("c1").lt(lit(ScalarValue::Decimal128(Some(500), 18, 2))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let rgm1 = get_row_group_meta_data( &schema_descr, // [6.00, 8.00] @@ -791,6 +865,7 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, &[rgm1, rgm2, rgm3], None, @@ -802,8 +877,11 @@ mod tests { // FIXED_LENGTH_BYTE_ARRAY: c1 = decimal128(100000, 28, 3), the c1 is decimal(18,2) // the type of parquet is decimal(18,2) - let schema = - Schema::new(vec![Field::new("c1", DataType::Decimal128(18, 2), false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Decimal128(18, 2), + false, + )])); let field = PrimitiveTypeField::new("c1", PhysicalType::FIXED_LEN_BYTE_ARRAY) .with_logical_type(LogicalType::Decimal { scale: 2, @@ -817,8 +895,7 @@ mod tests { let left = cast(col("c1"), DataType::Decimal128(28, 3)); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); // we must use the big-endian when encode the i128 to bytes or vec[u8]. 
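// (i.e. two's-complement big-endian bytes, as produced by `i128::to_be_bytes`)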
let rgm1 = get_row_group_meta_data( &schema_descr, @@ -862,6 +939,7 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, &[rgm1, rgm2, rgm3], None, @@ -873,8 +951,11 @@ mod tests { // BYTE_ARRAY: c1 = decimal128(100000, 28, 3), the c1 is decimal(18,2) // the type of parquet is decimal(18,2) - let schema = - Schema::new(vec![Field::new("c1", DataType::Decimal128(18, 2), false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Decimal128(18, 2), + false, + )])); let field = PrimitiveTypeField::new("c1", PhysicalType::BYTE_ARRAY) .with_logical_type(LogicalType::Decimal { scale: 2, @@ -888,8 +969,7 @@ mod tests { let left = cast(col("c1"), DataType::Decimal128(28, 3)); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); // we must use the big-endian when encode the i128 to bytes or vec[u8]. let rgm1 = get_row_group_meta_data( &schema_descr, @@ -922,6 +1002,7 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, &[rgm1, rgm2, rgm3], None, @@ -994,6 +1075,26 @@ mod tests { create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() } + // Note the values in the `String` column are: + // ❯ select * from './parquet-testing/data/data_index_bloom_encoding_stats.parquet'; + // +-----------+ + // | String | + // +-----------+ + // | Hello | + // | This is | + // | a | + // | test | + // | How | + // | are you | + // | doing | + // | today | + // | the quick | + // | brown fox | + // | jumps | + // | over | + // | the lazy | + // | dog | + // +-----------+ #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_simple_expr() { // load parquet file @@ -1002,7 +1103,7 @@ mod tests { let path = format!("{testdata}/{file_name}"); let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - // generate pruning predicate + // generate pruning predicate `(String = "Hello_Not_exists")` let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); let expr = col(r#""String""#).eq(lit("Hello_Not_Exists")); let expr = logical2physical(&expr, &schema); @@ -1029,7 +1130,7 @@ mod tests { let path = format!("{testdata}/{file_name}"); let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - // generate pruning predicate + // generate pruning predicate `(String = "Hello_Not_exists" OR String = "Hello_Not_exists2")` let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); let expr = lit("1").eq(lit("1")).and( col(r#""String""#) @@ -1091,7 +1192,7 @@ mod tests { let path = format!("{testdata}/{file_name}"); let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - // generate pruning predicate + // generate pruning predicate `(String = "Hello")` let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); let expr = col(r#""String""#).eq(lit("Hello")); let expr = logical2physical(&expr, &schema); @@ -1110,6 +1211,94 @@ mod tests { assert_eq!(pruned_row_groups, row_groups); } + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_exists_2_values() { + // load parquet file + let testdata = datafusion_common::test_util::parquet_test_data(); + let file_name = "data_index_bloom_encoding_stats.parquet"; + let path = 
format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + // generate pruning predicate `(String = "Hello") OR (String = "the quick")` + let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); + let expr = col(r#""String""#) + .eq(lit("Hello")) + .or(col(r#""String""#).eq(lit("the quick"))); + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let row_groups = vec![0]; + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert_eq!(pruned_row_groups, row_groups); + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_exists_3_values() { + // load parquet file + let testdata = datafusion_common::test_util::parquet_test_data(); + let file_name = "data_index_bloom_encoding_stats.parquet"; + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + // generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")` + let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); + let expr = col(r#""String""#) + .eq(lit("Hello")) + .or(col(r#""String""#).eq(lit("the quick"))) + .or(col(r#""String""#).eq(lit("are you"))); + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let row_groups = vec![0]; + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert_eq!(pruned_row_groups, row_groups); + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_or_not_eq() { + // load parquet file + let testdata = datafusion_common::test_util::parquet_test_data(); + let file_name = "data_index_bloom_encoding_stats.parquet"; + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + // generate pruning predicate `(String != "foo") OR (String != "bar")` + let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); + let expr = col(r#""String""#) + .not_eq(lit("foo")) + .or(col(r#""String""#).not_eq(lit("bar"))); + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let row_groups = vec![0]; + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert_eq!(pruned_row_groups, row_groups); + } + #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_without_bloom_filter() { // load parquet file @@ -1118,7 +1307,7 @@ let path = format!("{testdata}/{file_name}"); let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - // generate pruning predicate + // generate pruning predicate on a column without a bloom filter let schema = Schema::new(vec![Field::new("string_col", DataType::Utf8, false)]); let expr = col(r#""string_col""#).eq(lit("0")); let expr = logical2physical(&expr, &schema); diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs index 6965968b6f25..e7512499eb9d 100644 --- a/datafusion/core/src/datasource/stream.rs +++ b/datafusion/core/src/datasource/stream.rs @@ -31,7 +31,7
@@ use async_trait::async_trait; use futures::StreamExt; use tokio::task::spawn_blocking; -use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_common::{plan_err, Constraints, DataFusionError, Result}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::{CreateExternalTable, Expr, TableType}; use datafusion_physical_plan::common::AbortOnDropSingle; @@ -100,6 +100,7 @@ pub struct StreamConfig { encoding: StreamEncoding, header: bool, order: Vec>, + constraints: Constraints, } impl StreamConfig { @@ -118,6 +119,7 @@ impl StreamConfig { encoding: StreamEncoding::Csv, order: vec![], header: false, + constraints: Constraints::empty(), } } @@ -145,6 +147,12 @@ impl StreamConfig { self } + /// Assign constraints + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.constraints = constraints; + self + } + fn reader(&self) -> Result> { let file = File::open(&self.location)?; let schema = self.schema.clone(); @@ -215,6 +223,10 @@ impl TableProvider for StreamTable { self.0.schema.clone() } + fn constraints(&self) -> Option<&Constraints> { + Some(&self.0.constraints) + } + fn table_type(&self) -> TableType { TableType::Base } diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index dbebedce3c97..58a4f08341d6 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -26,6 +26,7 @@ mod parquet; use crate::{ catalog::{CatalogList, MemoryCatalogList}, datasource::{ + function::{TableFunction, TableFunctionImpl}, listing::{ListingOptions, ListingTable}, provider::TableProviderFactory, }, @@ -42,7 +43,7 @@ use datafusion_common::{ use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ logical_plan::{DdlStatement, Statement}, - StringifiedPlan, UserDefinedLogicalNode, WindowUDF, + Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, }; pub use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::var_provider::is_system_variables; @@ -803,6 +804,14 @@ impl SessionContext { .add_var_provider(variable_type, provider); } + /// Register a table UDF with this context + pub fn register_udtf(&self, name: &str, fun: Arc) { + self.state.write().table_functions.insert( + name.to_owned(), + Arc::new(TableFunction::new(name.to_owned(), fun)), + ); + } + /// Registers a scalar UDF within this context. 
/// /// Note in SQL queries, function names are looked up using @@ -1241,6 +1250,8 @@ pub struct SessionState { query_planner: Arc<dyn QueryPlanner + Send + Sync>, /// Collection of catalogs containing schemas and ultimately TableProviders catalog_list: Arc<dyn CatalogList>, + /// Table Functions + table_functions: HashMap<String, Arc<TableFunction>>, /// Scalar functions that are registered with the context scalar_functions: HashMap<String, Arc<ScalarUDF>>, /// Aggregate functions registered in the context @@ -1339,6 +1350,7 @@ impl SessionState { physical_optimizers: PhysicalOptimizer::new(), query_planner: Arc::new(DefaultQueryPlanner {}), catalog_list, + table_functions: HashMap::new(), scalar_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), @@ -1877,6 +1889,22 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { .ok_or_else(|| plan_datafusion_err!("table '{name}' not found")) } + fn get_table_function_source( + &self, + name: &str, + args: Vec<Expr>, + ) -> Result<Arc<dyn TableSource>> { + let tbl_func = self + .state + .table_functions + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?; + let provider = tbl_func.create_table_provider(&args)?; + + Ok(provider_as_source(provider)) + } + fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> { self.state.scalar_functions().get(name).cloned() } diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index bf9a4abf4f2d..b3ebbc6e3637 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -283,12 +283,20 @@ //! //! ## Plan Representations //! -//! Logical planning yields [`LogicalPlan`]s nodes and [`Expr`] +//! ### Logical Plans +//! Logical planning yields [`LogicalPlan`] nodes and [`Expr`] //! expressions which are [`Schema`] aware and represent statements //! independent of how they are physically executed. //! A [`LogicalPlan`] is a Directed Acyclic Graph (DAG) of other //! [`LogicalPlan`]s, each potentially containing embedded [`Expr`]s. //! +//! Examples of working with and executing `Expr`s can be found in the +//! [`expr_api`.rs] example +//! +//! [`expr_api`.rs]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/expr_api.rs +//! +//! ### Physical Plans +//! +//! An [`ExecutionPlan`] (sometimes referred to as a "physical plan") //! is a plan that can be executed against data. It is a DAG of other //!
[`ExecutionPlan`]s each potentially containing expressions of the diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 4265e3ff80d0..795857b10ef5 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -22,7 +22,6 @@ use super::optimizer::PhysicalOptimizerRule; use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_plan::aggregates::AggregateExec; -use crate::physical_plan::empty::EmptyExec; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::{expressions, AggregateExpr, ExecutionPlan, Statistics}; use crate::scalar::ScalarValue; @@ -30,6 +29,7 @@ use crate::scalar::ScalarValue; use datafusion_common::stats::Precision; use datafusion_common::tree_node::TreeNode; use datafusion_expr::utils::COUNT_STAR_EXPANSION; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; /// Optimizer that uses available statistics for aggregate functions #[derive(Default)] @@ -82,7 +82,7 @@ impl PhysicalOptimizerRule for AggregateStatistics { // input can be entirely removed Ok(Arc::new(ProjectionExec::try_new( projections, - Arc::new(EmptyExec::new(true, plan.schema())), + Arc::new(PlaceholderRowExec::new(plan.schema())), )?)) } else { plan.map_children(|child| self.optimize(child, _config)) diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index ff052b5f040c..14715ede500a 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -53,14 +53,15 @@ use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::windows::{ get_best_fitting_window, BoundedWindowAggExec, WindowAggExec, }; -use crate::physical_plan::{with_new_children_if_necessary, Distribution, ExecutionPlan}; +use crate::physical_plan::{ + with_new_children_if_necessary, Distribution, ExecutionPlan, InputOrderMode, +}; use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::windows::PartitionSearchMode; use itertools::izip; /// This rule inspects [`SortExec`]'s in the given physical plan and removes the @@ -611,7 +612,7 @@ fn analyze_window_sort_removal( window_expr.to_vec(), window_child, partitionby_exprs.to_vec(), - PartitionSearchMode::Sorted, + InputOrderMode::Sorted, )?) 
as _ } else { Arc::new(WindowAggExec::try_new( diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index a7ecd1ca655c..6b2fe24acf00 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -434,7 +434,7 @@ fn hash_join_convert_symmetric_subrule( config_options: &ConfigOptions, ) -> Option> { if let Some(hash_join) = input.plan.as_any().downcast_ref::() { - let ub_flags = &input.children_unbounded; + let ub_flags = input.children_unbounded(); let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); input.unbounded = left_unbounded || right_unbounded; let result = if left_unbounded && right_unbounded { @@ -511,7 +511,7 @@ fn hash_join_swap_subrule( _config_options: &ConfigOptions, ) -> Option> { if let Some(hash_join) = input.plan.as_any().downcast_ref::() { - let ub_flags = &input.children_unbounded; + let ub_flags = input.children_unbounded(); let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); input.unbounded = left_unbounded || right_unbounded; let result = if left_unbounded @@ -577,7 +577,7 @@ fn apply_subrules( } let is_unbounded = input .plan - .unbounded_output(&input.children_unbounded) + .unbounded_output(&input.children_unbounded()) // Treat the case where an operator can not run on unbounded data as // if it can and it outputs unbounded data. Do not raise an error yet. // Such operators may be fixed, adjusted or replaced later on during @@ -1253,6 +1253,7 @@ mod hash_join_tests { use arrow::record_batch::RecordBatch; use datafusion_common::utils::DataPtr; use datafusion_common::JoinType; + use datafusion_physical_plan::empty::EmptyExec; use std::sync::Arc; struct TestCase { @@ -1620,10 +1621,22 @@ mod hash_join_tests { false, )?; + let children = vec![ + PipelineStatePropagator { + plan: Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), + unbounded: left_unbounded, + children: vec![], + }, + PipelineStatePropagator { + plan: Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), + unbounded: right_unbounded, + children: vec![], + }, + ]; let initial_hash_join_state = PipelineStatePropagator { plan: Arc::new(join), unbounded: false, - children_unbounded: vec![left_unbounded, right_unbounded], + children, }; let optimized_hash_join = diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index 43ae7dbfe7b6..d59248aadf05 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -70,19 +70,27 @@ impl PhysicalOptimizerRule for PipelineChecker { pub struct PipelineStatePropagator { pub(crate) plan: Arc, pub(crate) unbounded: bool, - pub(crate) children_unbounded: Vec, + pub(crate) children: Vec, } impl PipelineStatePropagator { /// Constructs a new, default pipelining state. pub fn new(plan: Arc) -> Self { - let length = plan.children().len(); + let children = plan.children(); PipelineStatePropagator { plan, unbounded: false, - children_unbounded: vec![false; length], + children: children.into_iter().map(Self::new).collect(), } } + + /// Returns the children unboundedness information. 
+ pub fn children_unbounded(&self) -> Vec { + self.children + .iter() + .map(|c| c.unbounded) + .collect::>() + } } impl TreeNode for PipelineStatePropagator { @@ -90,9 +98,8 @@ impl TreeNode for PipelineStatePropagator { where F: FnMut(&Self) -> Result, { - let children = self.plan.children(); - for child in children { - match op(&PipelineStatePropagator::new(child))? { + for child in &self.children { + match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), @@ -106,25 +113,18 @@ impl TreeNode for PipelineStatePropagator { where F: FnMut(Self) -> Result, { - let children = self.plan.children(); - if !children.is_empty() { - let new_children = children + if !self.children.is_empty() { + let new_children = self + .children .into_iter() - .map(PipelineStatePropagator::new) .map(transform) .collect::>>()?; - let children_unbounded = new_children - .iter() - .map(|c| c.unbounded) - .collect::>(); - let children_plans = new_children - .into_iter() - .map(|child| child.plan) - .collect::>(); + let children_plans = new_children.iter().map(|c| c.plan.clone()).collect(); + Ok(PipelineStatePropagator { plan: with_new_children_if_necessary(self.plan, children_plans)?.into(), unbounded: self.unbounded, - children_unbounded, + children: new_children, }) } else { Ok(self) @@ -149,7 +149,7 @@ pub fn check_finiteness_requirements( } input .plan - .unbounded_output(&input.children_unbounded) + .unbounded_output(&input.children_unbounded()) .map(|value| { input.unbounded = value; Transformed::Yes(input) diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 7ebb64ab858a..67a2eaf0d9b3 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -20,6 +20,7 @@ //! projections one by one if the operator below is amenable to this. If a //! projection reaches a source, it can even dissappear from the plan entirely. +use std::collections::HashMap; use std::sync::Arc; use super::output_requirements::OutputRequirementExec; @@ -42,9 +43,9 @@ use crate::physical_plan::{Distribution, ExecutionPlan}; use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_common::JoinSide; -use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::{ Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, }; @@ -245,12 +246,36 @@ fn try_swapping_with_streaming_table( } /// Unifies `projection` with its input (which is also a [`ProjectionExec`]). -/// Two consecutive projections can always merge into a single projection. fn try_unifying_projections( projection: &ProjectionExec, child: &ProjectionExec, ) -> Result>> { let mut projected_exprs = vec![]; + let mut column_ref_map: HashMap = HashMap::new(); + + // Collect the column references usage in the outer projection. 
+ projection.expr().iter().for_each(|(expr, _)| { + expr.apply(&mut |expr| { + Ok({ + if let Some(column) = expr.as_any().downcast_ref::() { + *column_ref_map.entry(column.clone()).or_default() += 1; + } + VisitRecursion::Continue + }) + }) + .unwrap(); + }); + + // Merging these projections is not always beneficial: if a non-trivial + // expression in the child projection is referenced more than once by the + // outer projection, keeping the projections separate caches the + // non-trivial computation. + // See discussion in: https://github.com/apache/arrow-datafusion/issues/8296 + if column_ref_map.iter().any(|(column, count)| { + *count > 1 && !is_expr_trivial(&child.expr()[column.index()].0.clone()) + }) { + return Ok(None); + } + for (expr, alias) in projection.expr() { // If there is no match in the input projection, we cannot unify these // projections. This case will arise if the projection expression contains @@ -265,6 +290,13 @@ .map(|e| Some(Arc::new(e) as _)) } +/// Checks if the given expression is trivial. +/// An expression is considered trivial if it is either a `Column` or a `Literal`. +fn is_expr_trivial(expr: &Arc) -> bool { + expr.as_any().downcast_ref::().is_some() + || expr.as_any().downcast_ref::().is_some() +} + /// Tries to swap `projection` with its input (`output_req`). If possible, /// performs the swap and returns [`OutputRequirementExec`] as the top plan. /// Otherwise, returns `None`. @@ -348,6 +380,10 @@ fn try_swapping_with_filter( }; FilterExec::try_new(new_predicate, make_with_child(projection, filter.input())?) + .and_then(|e| { + let selectivity = filter.default_selectivity(); + e.with_default_selectivity(selectivity) + }) .map(|e| Some(Arc::new(e) as _)) } diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index de508327fade..b2ba7596db8d 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -66,43 +66,57 @@ use log::trace; /// min_values("X") -> None /// ``` pub trait PruningStatistics { - /// return the minimum values for the named column, if known. - /// Note: the returned array must contain `num_containers()` rows + /// Return the minimum values for the named column, if known. + /// + /// If the minimum value for a particular container is not known, the + /// returned array should have `null` in that row. If the minimum value is + /// not known for any row, return `None`. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows fn min_values(&self, column: &Column) -> Option; - /// return the maximum values for the named column, if known. - /// Note: the returned array must contain `num_containers()` rows. + /// Return the maximum values for the named column, if known. + /// + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows fn max_values(&self, column: &Column) -> Option; - /// return the number of containers (e.g. row groups) being - /// pruned with these statistics + /// Return the number of containers (e.g. row groups) being + /// pruned with these statistics (the number of rows in each returned array) fn num_containers(&self) -> usize; - /// return the number of null values for the named column as an + /// Return the number of null values for the named column as an /// `Option`. /// - /// Note: the returned array must contain `num_containers()` rows.
+ /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows fn null_counts(&self, column: &Column) -> Option; } -/// Evaluates filter expressions on statistics, rather than the actual data. If -/// no rows could possibly pass the filter entire containers can be "pruned" -/// (skipped), without reading any actual data, leading to significant +/// Evaluates filter expressions on statistics such as min/max values and null +/// counts, attempting to prove a "container" (e.g. Parquet Row Group) can be +/// skipped without reading the actual data, potentially leading to significant /// performance improvements. /// -/// [`PruningPredicate`]s are used to prune (avoid scanning) Parquet Row Groups +/// For example, [`PruningPredicate`]s are used to prune Parquet Row Groups /// based on the min/max values found in the Parquet metadata. If the /// `PruningPredicate` can guarantee that no rows in the Row Group match the /// filter, the entire Row Group is skipped during query execution. /// -/// Note that this API is designed to be general, as it works: +/// The `PruningPredicate` API is general, allowing it to be used for pruning +/// other types of containers (e.g. files) based on statistics that may be +/// known from external catalogs (e.g. Delta Lake) or other sources. Thus it +/// supports: /// /// 1. Arbitrary expressions (including user defined functions) /// -/// 2. Anything that implements the [`PruningStatistics`] trait, not just -/// Parquet metadata, allowing it to be used by other systems to prune entities -/// (e.g. entire files) if the statistics are known via some other source, such -/// as a catalog. +/// 2. Vectorized evaluation (provide more than one set of statistics at a time) +/// so it is suitable for pruning 1000s of containers. +/// +/// 3. Anything that implements the [`PruningStatistics`] trait, not just +/// Parquet metadata. /// /// # Example /// @@ -122,6 +136,7 @@ pub trait PruningStatistics { /// B: true (rows might match x = 5) /// C: true (rows might match x = 5) /// ``` +/// /// See [`PruningPredicate::try_new`] and [`PruningPredicate::prune`] for more information. #[derive(Debug, Clone)] pub struct PruningPredicate { @@ -251,8 +266,12 @@ fn is_always_true(expr: &Arc) -> bool { .unwrap_or_default() } -/// Records for which columns statistics are necessary to evaluate a -/// pruning predicate. +/// Describes which columns' statistics are necessary to evaluate a +/// [`PruningPredicate`]. +/// +/// This structure permits reading and creating the minimum number of statistics, +/// which is important since statistics may be non-trivial to read (e.g. large +/// strings or when there are 1000s of columns).
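To make the reworked trait documentation concrete, here is a sketch (not part of the diff) of a non-Parquet `PruningStatistics` implementation of the kind the docs now allude to, e.g. backed by an external catalog; the `XStats` type and the single tracked column `"x"` are hypothetical:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array};
use datafusion::physical_optimizer::pruning::PruningStatistics;
use datafusion_common::Column;

/// Hypothetical catalog-backed statistics: one entry per container (file),
/// tracking min/max for a single column "x". `None` entries mean "unknown".
struct XStats {
    mins: Vec<Option<i64>>,
    maxs: Vec<Option<i64>>,
}

impl PruningStatistics for XStats {
    fn min_values(&self, column: &Column) -> Option<ArrayRef> {
        // Unknown per-container minimums become nulls in the returned array;
        // statistics are only known for column "x" at all.
        (column.name == "x")
            .then(|| Arc::new(Int64Array::from(self.mins.clone())) as ArrayRef)
    }

    fn max_values(&self, column: &Column) -> Option<ArrayRef> {
        (column.name == "x")
            .then(|| Arc::new(Int64Array::from(self.maxs.clone())) as ArrayRef)
    }

    fn num_containers(&self) -> usize {
        // Every array returned above must have exactly this many rows.
        self.mins.len()
    }

    fn null_counts(&self, _column: &Column) -> Option<ArrayRef> {
        None // null counts are not tracked at all in this sketch
    }
}
```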
/// /// Handles creating references to the min/max statistics /// for columns as well as recording which statistics are needed diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index cc62cda41266..37a76eff1ee2 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -35,7 +35,7 @@ use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::union::UnionExec; use crate::physical_plan::windows::create_window_expr; -use crate::physical_plan::{ExecutionPlan, Partitioning}; +use crate::physical_plan::{ExecutionPlan, InputOrderMode, Partitioning}; use crate::prelude::{CsvReadOptions, SessionContext}; use arrow_schema::{Schema, SchemaRef, SortOptions}; @@ -44,7 +44,6 @@ use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunction}; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; -use datafusion_physical_plan::windows::PartitionSearchMode; use async_trait::async_trait; @@ -240,7 +239,7 @@ pub fn bounded_window_exec( .unwrap()], input.clone(), vec![], - PartitionSearchMode::Sorted, + InputOrderMode::Sorted, ) .unwrap(), ) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 9e64eb9c5108..ab38b3ec6d2f 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -63,12 +63,10 @@ use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::union::UnionExec; use crate::physical_plan::unnest::UnnestExec; use crate::physical_plan::values::ValuesExec; -use crate::physical_plan::windows::{ - BoundedWindowAggExec, PartitionSearchMode, WindowAggExec, -}; +use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{ - aggregates, displayable, udaf, windows, AggregateExpr, ExecutionPlan, Partitioning, - PhysicalExpr, WindowExpr, + aggregates, displayable, udaf, windows, AggregateExpr, ExecutionPlan, InputOrderMode, + Partitioning, PhysicalExpr, WindowExpr, }; use arrow::compute::SortOptions; @@ -86,13 +84,14 @@ use datafusion_expr::expr::{ Cast, GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, TryCast, WindowFunction, }; -use datafusion_expr::expr_rewriter::{unalias, unnormalize_cols}; +use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ DescribeTable, DmlStatement, ScalarFunctionDefinition, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_sql::utils::window_expr_common_partition_keys; use async_trait::async_trait; @@ -562,8 +561,7 @@ impl DefaultPhysicalPlanner { // doesn't know (nor should care) how the relation was // referred to in the query let filters = unnormalize_cols(filters.iter().cloned()); - let unaliased: Vec = filters.into_iter().map(unalias).collect(); - source.scan(session_state, projection.as_ref(), &unaliased, *fetch).await + source.scan(session_state, projection.as_ref(), &filters, *fetch).await } LogicalPlan::Copy(CopyTo{ input, @@ -762,7 +760,7 @@ impl DefaultPhysicalPlanner { window_expr, input_exec, physical_partition_keys, - 
PartitionSearchMode::Sorted, + InputOrderMode::Sorted, )?) } else { Arc::new(WindowAggExec::try_new( @@ -918,19 +916,14 @@ impl DefaultPhysicalPlanner { &input_schema, session_state, )?; - Ok(Arc::new(FilterExec::try_new(runtime_expr, physical_input)?)) + let selectivity = session_state.config().options().optimizer.default_filter_selectivity; + let filter = FilterExec::try_new(runtime_expr, physical_input)?; + Ok(Arc::new(filter.with_default_selectivity(selectivity)?)) } - LogicalPlan::Union(Union { inputs, schema }) => { + LogicalPlan::Union(Union { inputs, .. }) => { let physical_plans = self.create_initial_plan_multi(inputs.iter().map(|lp| lp.as_ref()), session_state).await?; - if schema.fields().len() < physical_plans[0].schema().fields().len() { - // `schema` could be a subset of the child schema. For example - // for query "select count(*) from (select a from t union all select a from t)" - // `schema` is empty but child schema contains one field `a`. - Ok(Arc::new(UnionExec::try_new_with_schema(physical_plans, schema.clone())?)) - } else { - Ok(Arc::new(UnionExec::new(physical_plans))) - } + Ok(Arc::new(UnionExec::new(physical_plans))) } LogicalPlan::Repartition(Repartition { input, @@ -1204,10 +1197,15 @@ impl DefaultPhysicalPlanner { } LogicalPlan::Subquery(_) => todo!(), LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row, + produce_one_row: false, schema, }) => Ok(Arc::new(EmptyExec::new( - *produce_one_row, + SchemaRef::new(schema.as_ref().to_owned().into()), + ))), + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema, + }) => Ok(Arc::new(PlaceholderRowExec::new( SchemaRef::new(schema.as_ref().to_owned().into()), ))), LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => { @@ -2020,7 +2018,7 @@ impl DefaultPhysicalPlanner { let mut column_names = StringBuilder::new(); let mut data_types = StringBuilder::new(); let mut is_nullables = StringBuilder::new(); - for (_, field) in table_schema.fields().iter().enumerate() { + for field in table_schema.fields() { column_names.append_value(field.name()); // "System supplied type" --> Use debug format of the datatype @@ -2775,7 +2773,7 @@ mod tests { digraph { 1[shape=box label="ProjectionExec: expr=[id@0 + 2 as employee.id + Int32(2)]", tooltip=""] - 2[shape=box label="EmptyExec: produce_one_row=false", tooltip=""] + 2[shape=box label="EmptyExec", tooltip=""] 1 -> 2 [arrowhead=none, arrowtail=normal, dir=back] } // End DataFusion GraphViz Plan diff --git a/datafusion/core/src/test/variable.rs b/datafusion/core/src/test/variable.rs index a55513841561..38207b42cb7b 100644 --- a/datafusion/core/src/test/variable.rs +++ b/datafusion/core/src/test/variable.rs @@ -37,7 +37,7 @@ impl VarProvider for SystemVar { /// get system variable value fn get_value(&self, var_names: Vec) -> Result { let s = format!("{}-{}", "system-var", var_names.concat()); - Ok(ScalarValue::Utf8(Some(s))) + Ok(ScalarValue::from(s)) } fn get_type(&self, _: &[String]) -> Option { @@ -61,7 +61,7 @@ impl VarProvider for UserDefinedVar { fn get_value(&self, var_names: Vec) -> Result { if var_names[0] != "@integer" { let s = format!("{}-{}", "user-defined-var", var_names.concat()); - Ok(ScalarValue::Utf8(Some(s))) + Ok(ScalarValue::from(s)) } else { Ok(ScalarValue::Int32(Some(41))) } diff --git a/datafusion/core/tests/custom_sources.rs b/datafusion/core/tests/custom_sources.rs index daf1ef41a297..a9ea5cc2a35c 100644 --- a/datafusion/core/tests/custom_sources.rs +++ b/datafusion/core/tests/custom_sources.rs @@ -30,7 +30,6 @@ use 
datafusion::execution::context::{SessionContext, SessionState, TaskContext}; use datafusion::logical_expr::{ col, Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE, }; -use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::{ collect, ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, @@ -42,6 +41,7 @@ use datafusion_common::project_schema; use datafusion_common::stats::Precision; use async_trait::async_trait; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use futures::stream::Stream; /// Also run all tests that are found in the `custom_sources_cases` directory @@ -256,9 +256,9 @@ async fn optimizers_catch_all_statistics() { let physical_plan = df.create_physical_plan().await.unwrap(); - // when the optimization kicks in, the source is replaced by an EmptyExec + // when the optimization kicks in, the source is replaced by an PlaceholderRowExec assert!( - contains_empty_exec(Arc::clone(&physical_plan)), + contains_place_holder_exec(Arc::clone(&physical_plan)), "Expected aggregate_statistics optimizations missing: {physical_plan:?}" ); @@ -283,12 +283,12 @@ async fn optimizers_catch_all_statistics() { assert_eq!(format!("{:?}", actual[0]), format!("{expected:?}")); } -fn contains_empty_exec(plan: Arc) -> bool { - if plan.as_any().is::() { +fn contains_place_holder_exec(plan: Arc) -> bool { + if plan.as_any().is::() { true } else if plan.children().len() != 1 { false } else { - contains_empty_exec(Arc::clone(&plan.children()[0])) + contains_place_holder_exec(Arc::clone(&plan.children()[0])) } } diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 10f4574020bf..c6b8e0e01b4f 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -1323,6 +1323,91 @@ async fn unnest_array_agg() -> Result<()> { Ok(()) } +#[tokio::test] +async fn unnest_with_redundant_columns() -> Result<()> { + let mut shape_id_builder = UInt32Builder::new(); + let mut tag_id_builder = UInt32Builder::new(); + + for shape_id in 1..=3 { + for tag_id in 1..=3 { + shape_id_builder.append_value(shape_id as u32); + tag_id_builder.append_value((shape_id * 10 + tag_id) as u32); + } + } + + let batch = RecordBatch::try_from_iter(vec![ + ("shape_id", Arc::new(shape_id_builder.finish()) as ArrayRef), + ("tag_id", Arc::new(tag_id_builder.finish()) as ArrayRef), + ])?; + + let ctx = SessionContext::new(); + ctx.register_batch("shapes", batch)?; + let df = ctx.table("shapes").await?; + + let results = df.clone().collect().await?; + let expected = vec![ + "+----------+--------+", + "| shape_id | tag_id |", + "+----------+--------+", + "| 1 | 11 |", + "| 1 | 12 |", + "| 1 | 13 |", + "| 2 | 21 |", + "| 2 | 22 |", + "| 2 | 23 |", + "| 3 | 31 |", + "| 3 | 32 |", + "| 3 | 33 |", + "+----------+--------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + // Doing an `array_agg` by `shape_id` produces: + let df = df + .clone() + .aggregate( + vec![col("shape_id")], + vec![array_agg(col("shape_id")).alias("shape_id2")], + )? + .unnest_column("shape_id2")? 
+ .select(vec![col("shape_id")])?; + + let optimized_plan = df.clone().into_optimized_plan()?; + let expected = vec![ + "Projection: shapes.shape_id [shape_id:UInt32]", + " Unnest: shape_id2 [shape_id:UInt32, shape_id2:UInt32;N]", + " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]", + " TableScan: shapes projection=[shape_id] [shape_id:UInt32]", + ]; + + let formatted = optimized_plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let results = df.collect().await?; + let expected = [ + "+----------+", + "| shape_id |", + "+----------+", + "| 1 |", + "| 1 |", + "| 1 |", + "| 2 |", + "| 2 |", + "| 2 |", + "| 3 |", + "| 3 |", + "| 3 |", + "+----------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + async fn create_test_table(name: &str) -> Result { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index af96063ffb5f..44ff71d02392 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -25,9 +25,9 @@ use arrow::util::pretty::pretty_format_batches; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::windows::{ - create_window_expr, BoundedWindowAggExec, PartitionSearchMode, WindowAggExec, + create_window_expr, BoundedWindowAggExec, WindowAggExec, }; -use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::physical_plan::{collect, ExecutionPlan, InputOrderMode}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::type_coercion::aggregates::coerce_types; @@ -43,9 +43,7 @@ use hashbrown::HashMap; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; -use datafusion_physical_plan::windows::PartitionSearchMode::{ - Linear, PartiallySorted, Sorted, -}; +use datafusion_physical_plan::InputOrderMode::{Linear, PartiallySorted, Sorted}; #[tokio::test(flavor = "multi_thread", worker_threads = 16)] async fn window_bounded_window_random_comparison() -> Result<()> { @@ -385,9 +383,9 @@ async fn run_window_test( random_seed: u64, partition_by_columns: Vec<&str>, orderby_columns: Vec<&str>, - search_mode: PartitionSearchMode, + search_mode: InputOrderMode, ) -> Result<()> { - let is_linear = !matches!(search_mode, PartitionSearchMode::Sorted); + let is_linear = !matches!(search_mode, InputOrderMode::Sorted); let mut rng = StdRng::seed_from_u64(random_seed); let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); diff --git a/datafusion/core/tests/parquet/file_statistics.rs b/datafusion/core/tests/parquet/file_statistics.rs index 1ea154303d69..9f94a59a3e59 100644 --- a/datafusion/core/tests/parquet/file_statistics.rs +++ b/datafusion/core/tests/parquet/file_statistics.rs @@ -133,7 +133,7 @@ async fn list_files_with_session_level_cache() { assert_eq!(get_list_file_cache_size(&state1), 1); let fg = &parquet1.base_config().file_groups; assert_eq!(fg.len(), 1); - assert_eq!(fg.get(0).unwrap().len(), 1); + 
assert_eq!(fg.first().unwrap().len(), 1); //Session 2 first time list files //check session 1 cache result not show in session 2 @@ -144,7 +144,7 @@ async fn list_files_with_session_level_cache() { assert_eq!(get_list_file_cache_size(&state2), 1); let fg2 = &parquet2.base_config().file_groups; assert_eq!(fg2.len(), 1); - assert_eq!(fg2.get(0).unwrap().len(), 1); + assert_eq!(fg2.first().unwrap().len(), 1); //Session 1 second time list files //check session 1 cache result not show in session 2 @@ -155,7 +155,7 @@ async fn list_files_with_session_level_cache() { assert_eq!(get_list_file_cache_size(&state1), 1); let fg = &parquet3.base_config().file_groups; assert_eq!(fg.len(), 1); - assert_eq!(fg.get(0).unwrap().len(), 1); + assert_eq!(fg.first().unwrap().len(), 1); // List same file no increase assert_eq!(get_list_file_cache_size(&state1), 1); } diff --git a/datafusion/core/tests/path_partition.rs b/datafusion/core/tests/path_partition.rs index dd8eb52f67c7..abe6ab283aff 100644 --- a/datafusion/core/tests/path_partition.rs +++ b/datafusion/core/tests/path_partition.rs @@ -168,9 +168,9 @@ async fn parquet_distinct_partition_col() -> Result<()> { assert_eq!(min_limit, resulting_limit); let s = ScalarValue::try_from_array(results[0].column(1), 0)?; - let month = match extract_as_utf(&s) { - Some(month) => month, - s => panic!("Expected month as Dict(_, Utf8) found {s:?}"), + let month = match s { + ScalarValue::Utf8(Some(month)) => month, + s => panic!("Expected month as Utf8 found {s:?}"), }; let sql_on_partition_boundary = format!( @@ -191,15 +191,6 @@ async fn parquet_distinct_partition_col() -> Result<()> { Ok(()) } -fn extract_as_utf(v: &ScalarValue) -> Option { - if let ScalarValue::Dictionary(_, v) = v { - if let ScalarValue::Utf8(v) = v.as_ref() { - return v.clone(); - } - } - None -} - #[tokio::test] async fn csv_filter_with_file_col() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index ecb5766a3bb5..37f8cefc9080 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -575,7 +575,7 @@ async fn explain_analyze_runs_optimizers() { // This happens as an optimization pass where count(*) can be // answered using statistics only. - let expected = "EmptyExec: produce_one_row=true"; + let expected = "PlaceholderRowExec"; let sql = "EXPLAIN SELECT count(*) from alltypes_plain"; let actual = execute_to_batches(&ctx, sql).await; @@ -806,7 +806,7 @@ async fn explain_physical_plan_only() { let expected = vec![vec![ "physical_plan", "ProjectionExec: expr=[2 as COUNT(*)]\ - \n EmptyExec: produce_one_row=true\ + \n PlaceholderRowExec\ \n", ]]; assert_eq!(expected, actual); diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 47de6ec857da..94fc8015a78a 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. 
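The `EmptyExec: produce_one_row=true` to `PlaceholderRowExec` renames in these test expectations reflect the operator split made earlier in this patch: `EmptyExec` no longer takes a `produce_one_row` flag, and the one-row case gets its own operator. A minimal sketch (not part of the diff) of the two constructors after the change, assuming both modules are re-exported under `datafusion::physical_plan` as elsewhere in this patch:

```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::Schema;
use datafusion::physical_plan::empty::EmptyExec;
use datafusion::physical_plan::placeholder_row::PlaceholderRowExec;

fn main() {
    let schema = Arc::new(Schema::empty());
    // Produces zero rows: stands in for a relation proven to be empty.
    let _no_rows = EmptyExec::new(schema.clone());
    // Produces exactly one (possibly column-less) row: backs `SELECT`
    // without `FROM` and statistics-only rewrites such as `count(*)` above.
    let _one_row = PlaceholderRowExec::new(schema);
}
```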
-use std::convert::TryFrom; use std::sync::Arc; use arrow::{ diff --git a/datafusion/core/tests/sql/parquet.rs b/datafusion/core/tests/sql/parquet.rs index c2844a2b762a..8f810a929df3 100644 --- a/datafusion/core/tests/sql/parquet.rs +++ b/datafusion/core/tests/sql/parquet.rs @@ -263,7 +263,7 @@ async fn parquet_list_columns() { assert_eq!( as_string_array(&utf8_list_array.value(0)).unwrap(), - &StringArray::try_from(vec![Some("abc"), Some("efg"), Some("hij"),]).unwrap() + &StringArray::from(vec![Some("abc"), Some("efg"), Some("hij"),]) ); assert_eq!( diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 63f3e979305a..cbdea9d72948 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -525,6 +525,53 @@ async fn test_prepare_statement() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_named_query_parameters() -> Result<()> { + let tmp_dir = TempDir::new()?; + let partition_count = 4; + let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?; + + // sql to statement then to logical plan with parameters + // c1 defined as UINT32, c2 defined as UInt64 + let results = ctx + .sql("SELECT c1, c2 FROM test WHERE c1 > $coo AND c1 < $foo") + .await? + .with_param_values(vec![ + ("foo", ScalarValue::UInt32(Some(3))), + ("coo", ScalarValue::UInt32(Some(0))), + ])? + .collect() + .await?; + let expected = vec![ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 1 | 1 |", + "| 1 | 2 |", + "| 1 | 3 |", + "| 1 | 4 |", + "| 1 | 5 |", + "| 1 | 6 |", + "| 1 | 7 |", + "| 1 | 8 |", + "| 1 | 9 |", + "| 1 | 10 |", + "| 2 | 1 |", + "| 2 | 2 |", + "| 2 | 3 |", + "| 2 | 4 |", + "| 2 | 5 |", + "| 2 | 6 |", + "| 2 | 7 |", + "| 2 | 8 |", + "| 2 | 9 |", + "| 2 | 10 |", + "+----+----+", + ]; + assert_batches_sorted_eq!(expected, &results); + Ok(()) +} + #[tokio::test] async fn parallel_query_with_filter() -> Result<()> { let tmp_dir = TempDir::new()?; diff --git a/datafusion/core/tests/user_defined/mod.rs b/datafusion/core/tests/user_defined/mod.rs index 09c7c3d3266b..6c6d966cc3aa 100644 --- a/datafusion/core/tests/user_defined/mod.rs +++ b/datafusion/core/tests/user_defined/mod.rs @@ -26,3 +26,6 @@ mod user_defined_plan; /// Tests for User Defined Window Functions mod user_defined_window_functions; + +/// Tests for User Defined Table Functions +mod user_defined_table_functions; diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs new file mode 100644 index 000000000000..b5d10b1c5b9b --- /dev/null +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -0,0 +1,219 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::Int64Array; +use arrow::csv::reader::Format; +use arrow::csv::ReaderBuilder; +use async_trait::async_trait; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::datasource::function::TableFunctionImpl; +use datafusion::datasource::TableProvider; +use datafusion::error::Result; +use datafusion::execution::context::SessionState; +use datafusion::execution::TaskContext; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::prelude::SessionContext; +use datafusion_common::{assert_batches_eq, DFSchema, ScalarValue}; +use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableType}; +use std::fs::File; +use std::io::Seek; +use std::path::Path; +use std::sync::Arc; + +/// Test a simple UDTF that defines `read_csv` with parameters +#[tokio::test] +async fn test_simple_read_csv_udtf() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_udtf("read_csv", Arc::new(SimpleCsvTableFunc {})); + + let csv_file = "tests/tpch-csv/nation.csv"; + // read csv with at most 5 rows + let rbs = ctx + .sql(format!("SELECT * FROM read_csv('{csv_file}', 5);").as_str()) + .await? + .collect() + .await?; + + let expected = [ + "+-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+", + "| n_nationkey | n_name | n_regionkey | n_comment |", + "+-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+", + "| 1 | ARGENTINA | 1 | al foxes promise slyly according to the regular accounts. bold requests alon |", + "| 2 | BRAZIL | 1 | y alongside of the pending deposits. carefully special packages are about the ironic forges. slyly special |", + "| 3 | CANADA | 1 | eas hang ironic, silent packages. slyly regular packages are furiously over the tithes. fluffily bold |", + "| 4 | EGYPT | 4 | y above the carefully unusual theodolites. final dugouts are quickly across the furiously regular d |", + "| 5 | ETHIOPIA | 0 | ven packages wake quickly. regu |", + "+-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+", ]; + assert_batches_eq!(expected, &rbs); + + // just run, return all rows + let rbs = ctx + .sql(format!("SELECT * FROM read_csv('{csv_file}');").as_str()) + .await? + .collect() + .await?; + let expected = [ + "+-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+", + "| n_nationkey | n_name | n_regionkey | n_comment |", + "+-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+", + "| 1 | ARGENTINA | 1 | al foxes promise slyly according to the regular accounts. bold requests alon |", + "| 2 | BRAZIL | 1 | y alongside of the pending deposits. carefully special packages are about the ironic forges. slyly special |", + "| 3 | CANADA | 1 | eas hang ironic, silent packages. slyly regular packages are furiously over the tithes. fluffily bold |", + "| 4 | EGYPT | 4 | y above the carefully unusual theodolites. final dugouts are quickly across the furiously regular d |", + "| 5 | ETHIOPIA | 0 | ven packages wake quickly.
regu |", + "| 6 | FRANCE | 3 | refully final requests. regular, ironi |", + "| 7 | GERMANY | 3 | l platelets. regular accounts x-ray: unusual, regular acco |", + "| 8 | INDIA | 2 | ss excuses cajole slyly across the packages. deposits print aroun |", + "| 9 | INDONESIA | 2 | slyly express asymptotes. regular deposits haggle slyly. carefully ironic hockey players sleep blithely. carefull |", + "| 10 | IRAN | 4 | efully alongside of the slyly final dependencies. |", + "+-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+" + ]; + assert_batches_eq!(expected, &rbs); + + Ok(()) +} + +struct SimpleCsvTable { + schema: SchemaRef, + exprs: Vec, + batches: Vec, +} + +#[async_trait] +impl TableProvider for SimpleCsvTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + state: &SessionState, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + let batches = if !self.exprs.is_empty() { + let max_return_lines = self.interpreter_expr(state).await?; + // get max return rows from self.batches + let mut batches = vec![]; + let mut lines = 0; + for batch in &self.batches { + let batch_lines = batch.num_rows(); + if lines + batch_lines > max_return_lines as usize { + let batch_lines = max_return_lines as usize - lines; + batches.push(batch.slice(0, batch_lines)); + break; + } else { + batches.push(batch.clone()); + lines += batch_lines; + } + } + batches + } else { + self.batches.clone() + }; + Ok(Arc::new(MemoryExec::try_new( + &[batches], + TableProvider::schema(self), + projection.cloned(), + )?)) + } +} + +impl SimpleCsvTable { + async fn interpreter_expr(&self, state: &SessionState) -> Result { + use datafusion::logical_expr::expr_rewriter::normalize_col; + use datafusion::logical_expr::utils::columnize_expr; + let plan = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema: Arc::new(DFSchema::empty()), + }); + let logical_plan = Projection::try_new( + vec![columnize_expr( + normalize_col(self.exprs[0].clone(), &plan)?, + plan.schema(), + )], + Arc::new(plan), + ) + .map(LogicalPlan::Projection)?; + let rbs = collect( + state.create_physical_plan(&logical_plan).await?, + Arc::new(TaskContext::from(state)), + ) + .await?; + let limit = rbs[0] + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(0); + Ok(limit) + } +} + +struct SimpleCsvTableFunc {} + +impl TableFunctionImpl for SimpleCsvTableFunc { + fn call(&self, exprs: &[Expr]) -> Result> { + let mut new_exprs = vec![]; + let mut filepath = String::new(); + for expr in exprs { + match expr { + Expr::Literal(ScalarValue::Utf8(Some(ref path))) => { + filepath = path.clone() + } + expr => new_exprs.push(expr.clone()), + } + } + let (schema, batches) = read_csv_batches(filepath)?; + let table = SimpleCsvTable { + schema, + exprs: new_exprs.clone(), + batches, + }; + Ok(Arc::new(table)) + } +} + +fn read_csv_batches(csv_path: impl AsRef) -> Result<(SchemaRef, Vec)> { + let mut file = File::open(csv_path)?; + let (schema, _) = Format::default() + .with_header(true) + .infer_schema(&mut file, None)?; + file.rewind()?; + + let reader = ReaderBuilder::new(Arc::new(schema.clone())) + .with_header(true) + .build(file)?; + let mut batches = vec![]; + for batch in reader { + batches.push(batch?); + } + let schema = Arc::new(schema);
+ Ok((schema, batches)) +} diff --git a/datafusion/execution/src/cache/cache_unit.rs b/datafusion/execution/src/cache/cache_unit.rs index c54839061c8a..25f9b9fa4d68 100644 --- a/datafusion/execution/src/cache/cache_unit.rs +++ b/datafusion/execution/src/cache/cache_unit.rs @@ -228,7 +228,7 @@ mod tests { cache.put(&meta.location, vec![meta.clone()].into()); assert_eq!( - cache.get(&meta.location).unwrap().get(0).unwrap().clone(), + cache.get(&meta.location).unwrap().first().unwrap().clone(), meta.clone() ); } diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index cfcc205b5625..8556335b395a 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -86,7 +86,7 @@ impl SessionConfig { /// Set a generic `str` configuration option pub fn set_str(self, key: &str, value: &str) -> Self { - self.set(key, ScalarValue::Utf8(Some(value.to_string()))) + self.set(key, ScalarValue::from(value)) } /// Customize batch size diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index a51941fdee11..977b556b26cf 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -130,6 +130,8 @@ pub enum BuiltinScalarFunction { // array functions /// array_append ArrayAppend, + /// array_sort + ArraySort, /// array_concat ArrayConcat, /// array_has @@ -144,6 +146,8 @@ pub enum BuiltinScalarFunction { ArrayPopBack, /// array_dims ArrayDims, + /// array_distinct + ArrayDistinct, /// array_element ArrayElement, /// array_empty @@ -398,12 +402,14 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Tanh => Volatility::Immutable, BuiltinScalarFunction::Trunc => Volatility::Immutable, BuiltinScalarFunction::ArrayAppend => Volatility::Immutable, + BuiltinScalarFunction::ArraySort => Volatility::Immutable, BuiltinScalarFunction::ArrayConcat => Volatility::Immutable, BuiltinScalarFunction::ArrayEmpty => Volatility::Immutable, BuiltinScalarFunction::ArrayHasAll => Volatility::Immutable, BuiltinScalarFunction::ArrayHasAny => Volatility::Immutable, BuiltinScalarFunction::ArrayHas => Volatility::Immutable, BuiltinScalarFunction::ArrayDims => Volatility::Immutable, + BuiltinScalarFunction::ArrayDistinct => Volatility::Immutable, BuiltinScalarFunction::ArrayElement => Volatility::Immutable, BuiltinScalarFunction::ArrayExcept => Volatility::Immutable, BuiltinScalarFunction::ArrayLength => Volatility::Immutable, @@ -545,6 +551,7 @@ impl BuiltinScalarFunction { Ok(data_type) } BuiltinScalarFunction::ArrayAppend => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArraySort => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayConcat => { let mut expr_type = Null; let mut max_dims = 0; @@ -582,6 +589,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayDims => { Ok(List(Arc::new(Field::new("item", UInt64, true)))) } + BuiltinScalarFunction::ArrayDistinct => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] { List(field) => Ok(field.data_type().clone()), _ => plan_err!( @@ -909,6 +917,9 @@ impl BuiltinScalarFunction { // for now, the list is small, as we do not have many built-in functions. 
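For context on the `ArraySort`/`ArrayDistinct` entries being threaded through this enum (registered under the SQL names `array_sort`/`list_sort` and `array_distinct`/`list_distinct` below), here is a sketch of expected SQL-level usage. It is not part of the diff: the table `t` and list column `a` are hypothetical, and the `'DESC'`/`'NULLS FIRST'` string arguments are my reading of the variadic signature, not something this patch pins down:

```rust
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

// Sketch only: assumes a table "t" with a list-typed column "a" is registered.
async fn array_fn_examples(ctx: &SessionContext) -> Result<()> {
    // Default ascending sort, and duplicate removal.
    ctx.sql("SELECT array_sort(a), array_distinct(a) FROM t")
        .await?
        .show()
        .await?;
    // The optional second/third arguments (direction and null placement)
    // are passed as string literals under the variadic signature above.
    ctx.sql("SELECT array_sort(a, 'DESC', 'NULLS FIRST') FROM t")
        .await?
        .show()
        .await?;
    Ok(())
}
```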
match self { BuiltinScalarFunction::ArrayAppend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArraySort => { + Signature::variadic_any(self.volatility()) + } BuiltinScalarFunction::ArrayPopFront => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayPopBack => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayConcat => { @@ -926,6 +937,7 @@ impl BuiltinScalarFunction { Signature::variadic_any(self.volatility()) } BuiltinScalarFunction::ArrayNdims => Signature::any(1, self.volatility()), + BuiltinScalarFunction::ArrayDistinct => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayPosition => { Signature::variadic_any(self.volatility()) } @@ -1023,6 +1035,7 @@ impl BuiltinScalarFunction { 1, vec![ Int64, + Float64, Timestamp(Nanosecond, None), Timestamp(Microsecond, None), Timestamp(Millisecond, None), @@ -1557,10 +1570,12 @@ impl BuiltinScalarFunction { "array_push_back", "list_push_back", ], + BuiltinScalarFunction::ArraySort => &["array_sort", "list_sort"], BuiltinScalarFunction::ArrayConcat => { &["array_concat", "array_cat", "list_concat", "list_cat"] } BuiltinScalarFunction::ArrayDims => &["array_dims", "list_dims"], + BuiltinScalarFunction::ArrayDistinct => &["array_distinct", "list_distinct"], BuiltinScalarFunction::ArrayEmpty => &["empty"], BuiltinScalarFunction::ArrayElement => &[ "array_element", diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index ee9b0ad6f967..f0aab95b8f0d 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -373,6 +373,24 @@ impl ScalarFunctionDefinition { ScalarFunctionDefinition::Name(func_name) => func_name.as_ref(), } } + + /// Whether this function is volatile, i.e. whether it can return different results + /// when evaluated multiple times with the same input. + pub fn is_volatile(&self) -> Result { + match self { + ScalarFunctionDefinition::BuiltIn(fun) => { + Ok(fun.volatility() == crate::Volatility::Volatile) + } + ScalarFunctionDefinition::UDF(udf) => { + Ok(udf.signature().volatility == crate::Volatility::Volatile) + } + ScalarFunctionDefinition::Name(func) => { + internal_err!( + "Cannot determine volatility of unresolved function: {func}" + ) + } + } + } } impl ScalarFunction { @@ -671,7 +689,7 @@ impl InSubquery { } } -/// Placeholder, representing bind parameter values such as `$1`. +/// Placeholder, representing bind parameter values such as `$1` or `$name`. /// /// The type of these parameters is inferred using [`Expr::infer_placeholder_types`] /// or can be specified directly using `PREPARE` statements. @@ -1044,7 +1062,7 @@ impl Expr { Expr::GetIndexedField(GetIndexedField { expr: Box::new(self), field: GetFieldAccess::NamedStructField { - name: ScalarValue::Utf8(Some(name.into())), + name: ScalarValue::from(name.into()), }, }) } @@ -1692,14 +1710,28 @@ fn create_names(exprs: &[Expr]) -> Result { .join(", ")) } +/// Whether the given expression is volatile, i.e. whether it can return different results +/// when evaluated multiple times with the same input. 
+pub fn is_volatile(expr: &Expr) -> Result { + match expr { + Expr::ScalarFunction(func) => func.func_def.is_volatile(), + _ => Ok(false), + } +} + #[cfg(test)] mod test { use crate::expr::Cast; use crate::expr_fn::col; - use crate::{case, lit, Expr}; + use crate::{ + case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ReturnTypeFunction, + ScalarFunctionDefinition, ScalarFunctionImplementation, ScalarUDF, Signature, + Volatility, + }; use arrow::datatypes::DataType; use datafusion_common::Column; use datafusion_common::{Result, ScalarValue}; + use std::sync::Arc; #[test] fn format_case_when() -> Result<()> { @@ -1800,4 +1832,45 @@ mod test { "UInt32(1) OR UInt32(2)" ); } + + #[test] + fn test_is_volatile_scalar_func_definition() { + // BuiltIn + assert!( + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Random) + .is_volatile() + .unwrap() + ); + assert!( + !ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Abs) + .is_volatile() + .unwrap() + ); + + // UDF + let return_type: ReturnTypeFunction = + Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); + let fun: ScalarFunctionImplementation = + Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a")))); + let udf = Arc::new(ScalarUDF::new( + "TestScalarUDF", + &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), + &return_type, + &fun, + )); + assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap()); + + let udf = Arc::new(ScalarUDF::new( + "TestScalarUDF", + &Signature::uniform(1, vec![DataType::Float32], Volatility::Volatile), + &return_type, + &fun, + )); + assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap()); + + // Unresolved function + ScalarFunctionDefinition::Name(Arc::from("UnresolvedFunc")) + .is_volatile() + .expect_err("Shouldn't determine volatility of unresolved function"); + } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 6148226f6b1a..cedf1d845137 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -583,6 +583,8 @@ scalar_expr!( "appends an element to the end of an array." ); +scalar_expr!(ArraySort, array_sort, array desc null_first, "returns sorted array."); + scalar_expr!( ArrayPopBack, array_pop_back, @@ -658,6 +660,12 @@ scalar_expr!( array, "returns the number of dimensions of the array." ); +scalar_expr!( + ArrayDistinct, + array_distinct, + array, + "return distinct values from the array after removing duplicates." +); scalar_expr!( ArrayPosition, array_position, @@ -1184,6 +1192,7 @@ mod test { test_scalar_expr!(FromUnixtime, from_unixtime, unixtime); test_scalar_expr!(ArrayAppend, array_append, array, element); + test_scalar_expr!(ArraySort, array_sort, array, desc, null_first); test_scalar_expr!(ArrayPopFront, array_pop_front, array); test_scalar_expr!(ArrayPopBack, array_pop_back, array); test_unary_scalar_expr!(ArrayDims, array_dims); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 2795ac5f0962..e5b0185d90e0 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -83,13 +83,12 @@ impl ExprSchemable for Expr { Expr::Cast(Cast { data_type, .. }) | Expr::TryCast(TryCast { data_type, .. 
}) => Ok(data_type.clone()), Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let arg_data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { - let arg_data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - // verify that input data types is consistent with function's `TypeSignature` data_types(&arg_data_types, &fun.signature()).map_err(|_| { plan_datafusion_err!( @@ -105,11 +104,7 @@ impl ExprSchemable for Expr { fun.return_type(&arg_data_types) } ScalarFunctionDefinition::UDF(fun) => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - Ok(fun.return_type(&data_types)?) + Ok(fun.return_type(&arg_data_types)?) } ScalarFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") diff --git a/datafusion/expr/src/literal.rs b/datafusion/expr/src/literal.rs index effc31553819..2f04729af2ed 100644 --- a/datafusion/expr/src/literal.rs +++ b/datafusion/expr/src/literal.rs @@ -43,19 +43,19 @@ pub trait TimestampLiteral { impl Literal for &str { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::from(*self)) } } impl Literal for String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::from(self.as_ref())) } } impl Literal for &String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::from(self.as_ref())) } } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index c4ff9fe95435..2264949cf42a 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -50,9 +50,9 @@ use crate::{ use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::{ - plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, - DataFusionError, FileType, OwnedTableReference, Result, ScalarValue, TableReference, - ToDFSchema, UnnestOptions, + get_target_functional_dependencies, plan_datafusion_err, plan_err, Column, DFField, + DFSchema, DFSchemaRef, DataFusionError, FileType, OwnedTableReference, Result, + ScalarValue, TableReference, ToDFSchema, UnnestOptions, }; /// Default table name for unnamed table @@ -906,6 +906,9 @@ impl LogicalPlanBuilder { ) -> Result { let group_expr = normalize_cols(group_expr, &self.plan)?; let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; + + let group_expr = + add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?; Aggregate::try_new(Arc::new(self.plan), group_expr, aggr_expr) .map(LogicalPlan::Aggregate) .map(Self::from) @@ -1166,10 +1169,46 @@ pub fn build_join_schema( ); let mut metadata = left.metadata().clone(); metadata.extend(right.metadata().clone()); - DFSchema::new_with_metadata(fields, metadata) - .map(|schema| schema.with_functional_dependencies(func_dependencies)) + let schema = DFSchema::new_with_metadata(fields, metadata)?; + schema.with_functional_dependencies(func_dependencies) } +/// Add additional "synthetic" group by expressions based on functional +/// dependencies. +/// +/// For example, if we are grouping on `[c1]`, and we know from +/// functional dependencies that column `c1` determines `c2`, this function +/// adds `c2` to the group by list. 
+/// +/// This allows MySQL style selects like +/// `SELECT col FROM t WHERE pk = 5` if col is unique +fn add_group_by_exprs_from_dependencies( + mut group_expr: Vec, + schema: &DFSchemaRef, +) -> Result> { + // Names of the fields produced by the GROUP BY exprs for example, `GROUP BY + // c1 + 1` produces an output field named `"c1 + 1"` + let mut group_by_field_names = group_expr + .iter() + .map(|e| e.display_name()) + .collect::>>()?; + + if let Some(target_indices) = + get_target_functional_dependencies(schema, &group_by_field_names) + { + for idx in target_indices { + let field = schema.field(idx); + let expr = + Expr::Column(Column::new(field.qualifier().cloned(), field.name())); + let expr_name = expr.display_name()?; + if !group_by_field_names.contains(&expr_name) { + group_by_field_names.push(expr_name); + group_expr.push(expr); + } + } + } + Ok(group_expr) +} /// Errors if one or more expressions have equal names. pub(crate) fn validate_unique_names<'a>( node_name: &str, @@ -1491,7 +1530,7 @@ pub fn unnest_with_options( let df_schema = DFSchema::new_with_metadata(fields, metadata)?; // We can use the existing functional dependencies: let deps = input_schema.functional_dependencies().clone(); - let schema = Arc::new(df_schema.with_functional_dependencies(deps)); + let schema = Arc::new(df_schema.with_functional_dependencies(deps)?); Ok(LogicalPlan::Unnest(Unnest { input: Arc::new(input), diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index 97551a941abf..e74992d99373 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -194,6 +194,8 @@ pub struct CreateExternalTable { pub options: HashMap, /// The list of constraints in the schema, such as primary key, unique, etc. pub constraints: Constraints, + /// Default values for columns + pub column_defaults: HashMap, } // Hashing refers to a subset of fields considered in PartialEq. diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index ea7a48d2c4f4..d74015bf094d 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -33,6 +33,7 @@ use crate::logical_plan::{DmlStatement, Statement}; use crate::utils::{ enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, grouping_set_expr_count, grouping_set_to_exprlist, inspect_expr_pre, + split_conjunction, }; use crate::{ build_join_schema, expr_vec_fmt, BinaryExpr, CreateMemoryTable, CreateView, Expr, @@ -47,8 +48,8 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, - DFField, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies, - OwnedTableReference, Result, ScalarValue, UnnestOptions, + DFField, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependencies, + OwnedTableReference, ParamValues, Result, UnnestOptions, }; // backwards compatibility pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; @@ -877,19 +878,19 @@ impl LogicalPlan { input: Arc::new(inputs[0].clone()), })) } - LogicalPlan::Explain(_) => { - // Explain should be handled specially in the optimizers; - // If this check cannot pass it means some optimizer pass is - // trying to optimize Explain directly - if expr.is_empty() { - return plan_err!("Invalid EXPLAIN command. Expression is empty"); - } - - if inputs.is_empty() { - return plan_err!("Invalid EXPLAIN command. 
Inputs are empty"); - } - - Ok(self.clone()) + LogicalPlan::Explain(e) => { + assert!( + expr.is_empty(), + "Invalid EXPLAIN command. Expression should empty" + ); + assert_eq!(inputs.len(), 1, "Invalid EXPLAIN command. Inputs are empty"); + Ok(LogicalPlan::Explain(Explain { + verbose: e.verbose, + plan: Arc::new(inputs[0].clone()), + stringified_plans: e.stringified_plans.clone(), + schema: e.schema.clone(), + logical_optimization_succeeded: e.logical_optimization_succeeded, + })) } LogicalPlan::Prepare(Prepare { name, data_types, .. @@ -945,7 +946,7 @@ impl LogicalPlan { // We can use the existing functional dependencies as is: .with_functional_dependencies( input.schema().functional_dependencies().clone(), - ), + )?, ); Ok(LogicalPlan::Unnest(Unnest { @@ -976,9 +977,10 @@ impl LogicalPlan { /// .filter(col("id").eq(placeholder("$1"))).unwrap() /// .build().unwrap(); /// - /// assert_eq!("Filter: t1.id = $1\ - /// \n TableScan: t1", - /// plan.display_indent().to_string() + /// assert_eq!( + /// "Filter: t1.id = $1\ + /// \n TableScan: t1", + /// plan.display_indent().to_string() /// ); /// /// // Fill in the parameter $1 with a literal 3 @@ -986,39 +988,37 @@ impl LogicalPlan { /// ScalarValue::from(3i32) // value at index 0 --> $1 /// ]).unwrap(); /// - /// assert_eq!("Filter: t1.id = Int32(3)\ - /// \n TableScan: t1", - /// plan.display_indent().to_string() + /// assert_eq!( + /// "Filter: t1.id = Int32(3)\ + /// \n TableScan: t1", + /// plan.display_indent().to_string() /// ); + /// + /// // Note you can also used named parameters + /// // Build SELECT * FROM t1 WHRERE id = $my_param + /// let plan = table_scan(Some("t1"), &schema, None).unwrap() + /// .filter(col("id").eq(placeholder("$my_param"))).unwrap() + /// .build().unwrap() + /// // Fill in the parameter $my_param with a literal 3 + /// .with_param_values(vec![ + /// ("my_param", ScalarValue::from(3i32)), + /// ]).unwrap(); + /// + /// assert_eq!( + /// "Filter: t1.id = Int32(3)\ + /// \n TableScan: t1", + /// plan.display_indent().to_string() + /// ); + /// /// ``` pub fn with_param_values( self, - param_values: Vec, + param_values: impl Into, ) -> Result { + let param_values = param_values.into(); match self { LogicalPlan::Prepare(prepare_lp) => { - // Verify if the number of params matches the number of values - if prepare_lp.data_types.len() != param_values.len() { - return plan_err!( - "Expected {} parameters, got {}", - prepare_lp.data_types.len(), - param_values.len() - ); - } - - // Verify if the types of the params matches the types of the values - let iter = prepare_lp.data_types.iter().zip(param_values.iter()); - for (i, (param_type, value)) in iter.enumerate() { - if *param_type != value.data_type() { - return plan_err!( - "Expected parameter of type {:?}, got {:?} at index {}", - param_type, - value.data_type(), - i - ); - } - } - + param_values.verify(&prepare_lp.data_types)?; let input_plan = prepare_lp.input; input_plan.replace_params_with_values(¶m_values) } @@ -1033,7 +1033,13 @@ impl LogicalPlan { pub fn max_rows(self: &LogicalPlan) -> Option { match self { LogicalPlan::Projection(Projection { input, .. }) => input.max_rows(), - LogicalPlan::Filter(Filter { input, .. }) => input.max_rows(), + LogicalPlan::Filter(filter) => { + if filter.is_scalar() { + Some(1) + } else { + filter.input.max_rows() + } + } LogicalPlan::Window(Window { input, .. }) => input.max_rows(), LogicalPlan::Aggregate(Aggregate { input, group_expr, .. 
@@ -1182,7 +1188,7 @@ impl LogicalPlan {
     /// See [`Self::with_param_values`] for examples and usage
     pub fn replace_params_with_values(
         &self,
-        param_values: &[ScalarValue],
+        param_values: &ParamValues,
     ) -> Result<LogicalPlan> {
         let new_exprs = self
             .expressions()
@@ -1202,7 +1208,7 @@ impl LogicalPlan {
         self.with_new_exprs(new_exprs, &new_inputs_with_values)
     }
 
-    /// Walk the logical plan, find any `PlaceHolder` tokens, and return a map of their IDs and DataTypes
+    /// Walk the logical plan, find any `Placeholder` tokens, and return a map of their IDs and DataTypes
     pub fn get_parameter_types(
         &self,
     ) -> Result<HashMap<String, Option<DataType>>, DataFusionError> {
@@ -1239,36 +1245,15 @@ impl LogicalPlan {
     /// corresponding values provided in the params_values
     fn replace_placeholders_with_values(
         expr: Expr,
-        param_values: &[ScalarValue],
+        param_values: &ParamValues,
     ) -> Result<Expr> {
         expr.transform(&|expr| {
             match &expr {
                 Expr::Placeholder(Placeholder { id, data_type }) => {
-                    if id.is_empty() || id == "$0" {
-                        return plan_err!("Empty placeholder id");
-                    }
-                    // convert id (in format $1, $2, ..) to idx (0, 1, ..)
-                    let idx = id[1..].parse::<usize>().map_err(|e| {
-                        DataFusionError::Internal(format!(
-                            "Failed to parse placeholder id: {e}"
-                        ))
-                    })? - 1;
-                    // value at the idx-th position in param_values should be the value for the placeholder
-                    let value = param_values.get(idx).ok_or_else(|| {
-                        DataFusionError::Internal(format!(
-                            "No value found for placeholder with id {id}"
-                        ))
-                    })?;
-                    // check if the data type of the value matches the data type of the placeholder
-                    if Some(value.data_type()) != *data_type {
-                        return internal_err!(
-                            "Placeholder value type mismatch: expected {:?}, got {:?}",
-                            data_type,
-                            value.data_type()
-                        );
-                    }
+                    let value =
+                        param_values.get_placeholders_with_values(id, data_type)?;
                     // Replace the placeholder with the value
-                    Ok(Transformed::Yes(Expr::Literal(value.clone())))
+                    Ok(Transformed::Yes(Expr::Literal(value)))
                 }
                 Expr::ScalarSubquery(qry) => {
                     let subquery =
@@ -1849,8 +1834,9 @@ pub fn projection_schema(input: &LogicalPlan, exprs: &[Expr]) -> Result<Arc<DFSchema>>
+    /// Is this filter guaranteed to return 0 or 1 row in a given instantiation?
+    ///
+    /// This function will return `true` if its predicate contains a conjunction of
+    /// `col(a) = <expr>`, where its schema has a unique filter that is covered
+    /// by this conjunction.
+ /// + /// For example, for the table: + /// ```sql + /// CREATE TABLE t (a INTEGER PRIMARY KEY, b INTEGER); + /// ``` + /// `Filter(a = 2).is_scalar() == true` + /// , whereas + /// `Filter(b = 2).is_scalar() == false` + /// and + /// `Filter(a = 2 OR b = 2).is_scalar() == false` + fn is_scalar(&self) -> bool { + let schema = self.input.schema(); + + let functional_dependencies = self.input.schema().functional_dependencies(); + let unique_keys = functional_dependencies.iter().filter(|dep| { + let nullable = dep.nullable + && dep + .source_indices + .iter() + .any(|&source| schema.field(source).is_nullable()); + !nullable + && dep.mode == Dependency::Single + && dep.target_indices.len() == schema.fields().len() + }); + + let exprs = split_conjunction(&self.predicate); + let eq_pred_cols: HashSet<_> = exprs + .iter() + .filter_map(|expr| { + let Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) = expr + else { + return None; + }; + // This is a no-op filter expression + if left == right { + return None; + } + + match (left.as_ref(), right.as_ref()) { + (Expr::Column(_), Expr::Column(_)) => None, + (Expr::Column(c), _) | (_, Expr::Column(c)) => { + Some(schema.index_of_column(c).unwrap()) + } + _ => None, + } + }) + .collect(); + + // If we have a functional dependence that is a subset of our predicate, + // this filter is scalar + for key in unique_keys { + if key.source_indices.iter().all(|c| eq_pred_cols.contains(c)) { + return true; + } + } + false + } } /// Window its input based on a set of window spec and window function (e.g. SUM or RANK) @@ -1965,7 +2018,7 @@ impl Window { window_expr, schema: Arc::new( DFSchema::new_with_metadata(window_fields, metadata)? - .with_functional_dependencies(window_func_dependencies), + .with_functional_dependencies(window_func_dependencies)?, ), }) } @@ -2035,7 +2088,7 @@ impl TableScan { .map(|p| { let projected_func_dependencies = func_dependencies.project_functional_dependencies(p, p.len()); - DFSchema::new_with_metadata( + let df_schema = DFSchema::new_with_metadata( p.iter() .map(|i| { DFField::from_qualified( @@ -2045,15 +2098,13 @@ impl TableScan { }) .collect(), schema.metadata().clone(), - ) - .map(|df_schema| { - df_schema.with_functional_dependencies(projected_func_dependencies) - }) + )?; + df_schema.with_functional_dependencies(projected_func_dependencies) }) .unwrap_or_else(|| { - DFSchema::try_from_qualified_schema(table_name.clone(), &schema).map( - |df_schema| df_schema.with_functional_dependencies(func_dependencies), - ) + let df_schema = + DFSchema::try_from_qualified_schema(table_name.clone(), &schema)?; + df_schema.with_functional_dependencies(func_dependencies) })?; let projected_schema = Arc::new(projected_schema); Ok(Self { @@ -2365,7 +2416,7 @@ impl Aggregate { calc_func_dependencies_for_aggregate(&group_expr, &input, &schema)?; let new_schema = schema.as_ref().clone(); let schema = Arc::new( - new_schema.with_functional_dependencies(aggregate_func_dependencies), + new_schema.with_functional_dependencies(aggregate_func_dependencies)?, ); Ok(Self { input, @@ -2575,13 +2626,19 @@ pub struct Unnest { #[cfg(test)] mod tests { + use std::collections::HashMap; + use std::sync::Arc; + use super::*; + use crate::builder::LogicalTableSource; use crate::logical_plan::table_scan; use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet}; + use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::tree_node::TreeNodeVisitor; - use datafusion_common::{not_impl_err, DFSchema, 
TableReference};
-    use std::collections::HashMap;
+    use datafusion_common::{
+        not_impl_err, Constraint, DFSchema, ScalarValue, TableReference,
+    };
 
     fn employee_schema() -> Schema {
         Schema::new(vec![
@@ -3028,7 +3085,8 @@ digraph {
             .build()
             .unwrap();
 
-        plan.replace_params_with_values(&[42i32.into()])
+        let param_values = vec![ScalarValue::Int32(Some(42))];
+        plan.replace_params_with_values(&param_values.clone().into())
             .expect_err("unexpectedly succeeded to replace an invalid placeholder");
 
         // test $0 placeholder
@@ -3041,7 +3099,7 @@ digraph {
             .build()
             .unwrap();
 
-        plan.replace_params_with_values(&[42i32.into()])
+        plan.replace_params_with_values(&param_values.into())
             .expect_err("unexpectedly succeeded to replace an invalid placeholder");
     }
 
@@ -3076,4 +3134,106 @@ digraph {
             .unwrap()
             .is_nullable());
     }
+
+    #[test]
+    fn test_filter_is_scalar() {
+        // a filter on a column with no uniqueness constraint is not scalar
+        let schema =
+            Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+
+        let source = Arc::new(LogicalTableSource::new(schema));
+        let schema = Arc::new(
+            DFSchema::try_from_qualified_schema(
+                TableReference::bare("tab"),
+                &source.schema(),
+            )
+            .unwrap(),
+        );
+        let scan = Arc::new(LogicalPlan::TableScan(TableScan {
+            table_name: TableReference::bare("tab"),
+            source: source.clone(),
+            projection: None,
+            projected_schema: schema.clone(),
+            filters: vec![],
+            fetch: None,
+        }));
+        let col = schema.field(0).qualified_column();
+
+        let filter = Filter::try_new(
+            Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))),
+            scan,
+        )
+        .unwrap();
+        assert!(!filter.is_scalar());
+        let unique_schema = Arc::new(
+            schema
+                .as_ref()
+                .clone()
+                .with_functional_dependencies(
+                    FunctionalDependencies::new_from_constraints(
+                        Some(&Constraints::new_unverified(vec![Constraint::Unique(
+                            vec![0],
+                        )])),
+                        1,
+                    ),
+                )
+                .unwrap(),
+        );
+        let scan = Arc::new(LogicalPlan::TableScan(TableScan {
+            table_name: TableReference::bare("tab"),
+            source,
+            projection: None,
+            projected_schema: unique_schema.clone(),
+            filters: vec![],
+            fetch: None,
+        }));
+        let col = schema.field(0).qualified_column();
+
+        let filter = Filter::try_new(
+            Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))),
+            scan,
+        )
+        .unwrap();
+        assert!(filter.is_scalar());
+    }
+
+    #[test]
+    fn test_transform_explain() {
+        let schema = Schema::new(vec![
+            Field::new("foo", DataType::Int32, false),
+            Field::new("bar", DataType::Int32, false),
+        ]);
+
+        let plan = table_scan(TableReference::none(), &schema, None)
+            .unwrap()
+            .explain(false, false)
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let external_filter =
+            col("foo").eq(Expr::Literal(ScalarValue::Boolean(Some(true))));
+
+        // after transformation, because plan is not the same anymore,
+        // the parent plan is built again with call to LogicalPlan::with_new_inputs -> with_new_exprs
+        let plan = plan
+            .transform(&|plan| match plan {
+                LogicalPlan::TableScan(table) => {
+                    let filter = Filter::try_new(
+                        external_filter.clone(),
+                        Arc::new(LogicalPlan::TableScan(table)),
+                    )
+                    .unwrap();
+                    Ok(Transformed::Yes(LogicalPlan::Filter(filter)))
+                }
+                x => Ok(Transformed::No(x)),
+            })
+            .unwrap();
+
+        let expected = "Explain\
+                        \n  Filter: foo = Boolean(true)\
+                        \n    TableScan: ?table?";
+        let actual = format!("{}", plan.display_indent());
+        assert_eq!(expected.to_string(), actual)
+    }
 }
diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs
index 1027e97d061a..dd9449198796 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++
b/datafusion/expr/src/type_coercion/binary.rs @@ -116,7 +116,7 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result }) } AtArrow | ArrowAt => { - // ArrowAt and AtArrow check for whether one array ic contained in another. + // ArrowAt and AtArrow check for whether one array is contained in another. // The result type is boolean. Signature::comparison defines this signature. // Operation has nothing to do with comparison array_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 7d126a0f3373..abdd7f5f57f6 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -17,6 +17,10 @@ //! Expression utilities +use std::cmp::Ordering; +use std::collections::HashSet; +use std::sync::Arc; + use crate::expr::{Alias, Sort, WindowFunction}; use crate::expr_rewriter::strip_outer_reference; use crate::logical_plan::Aggregate; @@ -25,16 +29,15 @@ use crate::{ and, BinaryExpr, Cast, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, TryCast, }; + use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::{ internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, }; + use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem, WildcardAdditionalOptions}; -use std::cmp::Ordering; -use std::collections::HashSet; -use std::sync::Arc; /// The value to which `COUNT(*)` is expanded to in /// `COUNT()` expressions @@ -433,7 +436,7 @@ pub fn expand_qualified_wildcard( let qualified_schema = DFSchema::new_with_metadata(qualified_fields, schema.metadata().clone())? // We can use the functional dependencies as is, since it only stores indices: - .with_functional_dependencies(schema.functional_dependencies().clone()); + .with_functional_dependencies(schema.functional_dependencies().clone())?; let excluded_columns = if let Some(WildcardAdditionalOptions { opt_exclude, opt_except, @@ -501,7 +504,6 @@ pub fn generate_sort_key( let res = final_sort_keys .into_iter() .zip(is_partition_flag) - .map(|(lhs, rhs)| (lhs, rhs)) .collect::>(); Ok(res) } @@ -731,11 +733,7 @@ fn agg_cols(agg: &Aggregate) -> Vec { .collect() } -fn exprlist_to_fields_aggregate( - exprs: &[Expr], - plan: &LogicalPlan, - agg: &Aggregate, -) -> Result> { +fn exprlist_to_fields_aggregate(exprs: &[Expr], agg: &Aggregate) -> Result> { let agg_cols = agg_cols(agg); let mut fields = vec![]; for expr in exprs { @@ -744,7 +742,7 @@ fn exprlist_to_fields_aggregate( // resolve against schema of input to aggregate fields.push(expr.to_field(agg.input.schema())?); } - _ => fields.push(expr.to_field(plan.schema())?), + _ => fields.push(expr.to_field(&agg.schema)?), } } Ok(fields) @@ -761,15 +759,7 @@ pub fn exprlist_to_fields<'a>( // `GROUPING(person.state)` so in order to resolve `person.state` in this case we need to // look at the input to the aggregate instead. 
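    // For example, given the (hypothetical) query
    //   SELECT GROUPING(person.state) FROM person GROUP BY GROUPING SETS ((person.state), ())
    // `person.state` must be resolved against the aggregate's input schema, because the
    // aggregate's output schema only exposes the `GROUPING(person.state)` field.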
let fields = match plan {
-        LogicalPlan::Aggregate(agg) => {
-            Some(exprlist_to_fields_aggregate(&exprs, plan, agg))
-        }
-        LogicalPlan::Window(window) => match window.input.as_ref() {
-            LogicalPlan::Aggregate(agg) => {
-                Some(exprlist_to_fields_aggregate(&exprs, plan, agg))
-            }
-            _ => None,
-        },
+        LogicalPlan::Aggregate(agg) => Some(exprlist_to_fields_aggregate(&exprs, agg)),
         _ => None,
     };
     if let Some(fields) = fields {
@@ -1241,10 +1231,9 @@ pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::expr_vec_fmt;
     use crate::{
-        col, cube, expr, grouping_set, lit, rollup, AggregateFunction, WindowFrame,
-        WindowFunction,
+        col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, AggregateFunction,
+        WindowFrame, WindowFunction,
     };
 
     #[test]
diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs
index 5f161b85dd9a..2701ca1ecf3b 100644
--- a/datafusion/expr/src/window_frame.rs
+++ b/datafusion/expr/src/window_frame.rs
@@ -23,6 +23,8 @@
 //! - An ending frame boundary,
 //! - An EXCLUDE clause.
 
+use crate::expr::Sort;
+use crate::Expr;
 use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue};
 use sqlparser::ast;
 use sqlparser::parser::ParserError::ParserError;
@@ -142,31 +144,57 @@ impl WindowFrame {
     }
 }
 
-/// Construct equivalent explicit window frames for implicit corner cases.
-/// With this processing, we may assume in downstream code that RANGE/GROUPS
-/// frames contain an appropriate ORDER BY clause.
-pub fn regularize(mut frame: WindowFrame, order_bys: usize) -> Result<WindowFrame> {
-    if frame.units == WindowFrameUnits::Range && order_bys != 1 {
+/// Regularizes the ORDER BY clause of the window definition for implicit corner cases.
+pub fn regularize_window_order_by(
+    frame: &WindowFrame,
+    order_by: &mut Vec<Expr>,
+) -> Result<()> {
+    if frame.units == WindowFrameUnits::Range && order_by.len() != 1 {
         // Normally, RANGE frames require an ORDER BY clause with exactly one
-        // column. However, an ORDER BY clause may be absent in two edge cases.
+        // column. However, an ORDER BY clause may be absent, or may have more
+        // than one column, in two edge cases:
+        // 1. start bound is UNBOUNDED or CURRENT ROW
+        // 2. end bound is CURRENT ROW or UNBOUNDED.
+        // In these cases, we regularize an absent ORDER BY clause; an ORDER BY
+        // clause with more than one column is left unchanged. Note that this
+        // follows Postgres behavior.
         if (frame.start_bound.is_unbounded()
             || frame.start_bound == WindowFrameBound::CurrentRow)
             && (frame.end_bound == WindowFrameBound::CurrentRow
                 || frame.end_bound.is_unbounded())
         {
-            if order_bys == 0 {
-                frame.units = WindowFrameUnits::Rows;
-                frame.start_bound =
-                    WindowFrameBound::Preceding(ScalarValue::UInt64(None));
-                frame.end_bound = WindowFrameBound::Following(ScalarValue::UInt64(None));
+            // An absent ORDER BY clause is equivalent to an ORDER BY clause
+            // with a constant value as the sort key.
+            if order_by.is_empty() {
+                order_by.push(Expr::Sort(Sort::new(
+                    Box::new(Expr::Literal(ScalarValue::UInt64(Some(1)))),
+                    true,
+                    false,
+                )));
             }
-        } else {
+        }
+    }
+    Ok(())
+}
+
+/// Checks if the given window frame is valid. In particular, if the frame is RANGE
+/// with offset PRECEDING/FOLLOWING, it must have exactly one ORDER BY column.
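+///
+/// For example, `RANGE BETWEEN 2 PRECEDING AND CURRENT ROW` is only valid with
+/// exactly one ORDER BY column, whereas
+/// `RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW` is accepted with any
+/// number of ORDER BY columns, since its bounds are not offsets.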
+pub fn check_window_frame(frame: &WindowFrame, order_bys: usize) -> Result<()> { + if frame.units == WindowFrameUnits::Range && order_bys != 1 { + // See `regularize_window_order_by`. + if !(frame.start_bound.is_unbounded() + || frame.start_bound == WindowFrameBound::CurrentRow) + || !(frame.end_bound == WindowFrameBound::CurrentRow + || frame.end_bound.is_unbounded()) + { plan_err!("RANGE requires exactly one ORDER BY column")? } } else if frame.units == WindowFrameUnits::Groups && order_bys == 0 { plan_err!("GROUPS requires an ORDER BY clause")? }; - Ok(frame) + Ok(()) } /// There are five ways to describe starting and ending frame boundaries: diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index 946a80dd844a..610f1ecaeae9 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -268,7 +268,20 @@ impl BuiltInWindowFunction { BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { Signature::any(1, Volatility::Immutable) } - BuiltInWindowFunction::Ntile => Signature::any(1, Volatility::Immutable), + BuiltInWindowFunction::Ntile => Signature::uniform( + 1, + vec![ + DataType::UInt64, + DataType::UInt32, + DataType::UInt16, + DataType::UInt8, + DataType::Int64, + DataType::Int32, + DataType::Int16, + DataType::Int8, + ], + Volatility::Immutable, + ), BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable), } } diff --git a/datafusion/optimizer/README.md b/datafusion/optimizer/README.md index b8e5b93e6692..4f9e0fb98526 100644 --- a/datafusion/optimizer/README.md +++ b/datafusion/optimizer/README.md @@ -153,7 +153,7 @@ Looking at the `EXPLAIN` output we can see that the optimizer has effectively re | logical_plan | Projection: Int64(3) AS Int64(1) + Int64(2) | | | EmptyRelation | | physical_plan | ProjectionExec: expr=[3 as Int64(1) + Int64(2)] | -| | EmptyExec: produce_one_row=true | +| | PlaceholderRowExec | | | | +---------------+-------------------------------------------------+ ``` @@ -318,7 +318,7 @@ In the following example, the `type_coercion` and `simplify_expressions` passes | logical_plan | Projection: Utf8("3.2") AS foo | | | EmptyRelation | | initial_physical_plan | ProjectionExec: expr=[3.2 as foo] | -| | EmptyExec: produce_one_row=true | +| | PlaceholderRowExec | | | | | physical_plan after aggregate_statistics | SAME TEXT AS ABOVE | | physical_plan after join_selection | SAME TEXT AS ABOVE | @@ -326,7 +326,7 @@ In the following example, the `type_coercion` and `simplify_expressions` passes | physical_plan after repartition | SAME TEXT AS ABOVE | | physical_plan after add_merge_exec | SAME TEXT AS ABOVE | | physical_plan | ProjectionExec: expr=[3.2 as foo] | -| | EmptyExec: produce_one_row=true | +| | PlaceholderRowExec | | | | +------------------------------------------------------------+---------------------------------------------------------------------------+ ``` diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index e3b86f5db78f..91611251d9dd 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -503,7 +503,10 @@ fn coerce_window_frame( let target_type = match window_frame.units { WindowFrameUnits::Range => { if let Some(col_type) = current_types.first() { - if col_type.is_numeric() || is_utf8_or_large_utf8(col_type) { + if col_type.is_numeric() + || is_utf8_or_large_utf8(col_type) + || matches!(col_type, 
DataType::Null)
+                {
                     col_type
                 } else if is_datetime(col_type) {
                     &DataType::Interval(IntervalUnit::MonthDayNano)
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 1d21407a6985..1e089257c61a 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -29,7 +29,7 @@ use datafusion_common::tree_node::{
 use datafusion_common::{
     internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
 };
-use datafusion_expr::expr::Alias;
+use datafusion_expr::expr::{is_volatile, Alias};
 use datafusion_expr::logical_plan::{
     Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
 };
@@ -113,6 +113,8 @@ impl CommonSubexprEliminate {
         let Projection { expr, input, .. } = projection;
         let input_schema = Arc::clone(input.schema());
         let mut expr_set = ExprSet::new();
+
+        // Visit the expr list and build a map from expr identifier to occurrence count (`expr_set`).
         let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?;
 
         let (mut new_expr, new_input) =
@@ -516,7 +518,7 @@ enum ExprMask {
 }
 
 impl ExprMask {
-    fn ignores(&self, expr: &Expr) -> bool {
+    fn ignores(&self, expr: &Expr) -> Result<bool> {
         let is_normal_minus_aggregates = matches!(
             expr,
             Expr::Literal(..)
@@ -527,12 +529,14 @@ impl ExprMask {
                 | Expr::Wildcard { .. }
         );
 
+        let is_volatile = is_volatile(expr)?;
+
         let is_aggr = matches!(expr, Expr::AggregateFunction(..));
 
-        match self {
-            Self::Normal => is_normal_minus_aggregates || is_aggr,
-            Self::NormalAndAggregates => is_normal_minus_aggregates,
-        }
+        Ok(match self {
+            Self::Normal => is_volatile || is_normal_minus_aggregates || is_aggr,
+            Self::NormalAndAggregates => is_volatile || is_normal_minus_aggregates,
+        })
     }
 }
 
@@ -624,7 +628,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
         let (idx, sub_expr_desc) = self.pop_enter_mark();
 
         // Skip exprs that should not be recognized.
-        if self.expr_mask.ignores(expr) {
+        if self.expr_mask.ignores(expr)? {
             self.id_array[idx].0 = self.series_number;
             let desc = Self::desc_expr(expr);
             self.visit_stack.push(VisitRecord::ExprItem(desc));
diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs
index 7844ca7909fc..4386253740aa 100644
--- a/datafusion/optimizer/src/eliminate_limit.rs
+++ b/datafusion/optimizer/src/eliminate_limit.rs
@@ -97,7 +97,7 @@ mod tests {
         let optimizer = Optimizer::with_rules(vec![Arc::new(EliminateLimit::new())]);
         let optimized_plan = optimizer
             .optimize_recursively(
-                optimizer.rules.get(0).unwrap(),
+                optimizer.rules.first().unwrap(),
                 plan,
                 &OptimizerContext::new(),
             )?
diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs
index bbf704a83c55..7ae9f7edf5e5 100644
--- a/datafusion/optimizer/src/optimize_projections.rs
+++ b/datafusion/optimizer/src/optimize_projections.rs
@@ -15,33 +15,42 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Optimizer rule to prune unnecessary Columns from the intermediate schemas inside the [LogicalPlan].
-//! This rule
-//! - Removes unnecessary columns that are not showed at the output, and that are not used during computation.
-//! - Adds projection to decrease table column size before operators that benefits from less memory at its input.
-//! - Removes unnecessary [LogicalPlan::Projection] from the [LogicalPlan].
+//! Optimizer rule to prune unnecessary columns from intermediate schemas
+//!
inside the [`LogicalPlan`]. This rule: +//! - Removes unnecessary columns that do not appear at the output and/or are +//! not used during any computation step. +//! - Adds projections to decrease table column size before operators that +//! benefit from a smaller memory footprint at its input. +//! - Removes unnecessary [`LogicalPlan::Projection`]s from the [`LogicalPlan`]. + +use std::collections::HashSet; +use std::sync::Arc; + use crate::optimizer::ApplyOrder; -use datafusion_common::{Column, DFSchema, DFSchemaRef, JoinType, Result}; -use datafusion_expr::expr::{Alias, ScalarFunction}; +use crate::{OptimizerConfig, OptimizerRule}; + +use arrow::datatypes::SchemaRef; +use datafusion_common::{ + get_required_group_by_exprs_indices, Column, DFSchema, DFSchemaRef, JoinType, Result, +}; +use datafusion_expr::expr::{Alias, ScalarFunction, ScalarFunctionDefinition}; use datafusion_expr::{ logical_plan::LogicalPlan, projection_schema, Aggregate, BinaryExpr, Cast, Distinct, - Expr, Projection, ScalarFunctionDefinition, TableScan, Window, + Expr, GroupingSet, Projection, TableScan, Window, }; + use hashbrown::HashMap; use itertools::{izip, Itertools}; -use std::collections::HashSet; -use std::sync::Arc; - -use crate::{OptimizerConfig, OptimizerRule}; -/// A rule for optimizing logical plans by removing unused Columns/Fields. +/// A rule for optimizing logical plans by removing unused columns/fields. /// -/// `OptimizeProjections` is an optimizer rule that identifies and eliminates columns from a logical plan -/// that are not used in any downstream operations. This can improve query performance and reduce unnecessary -/// data processing. +/// `OptimizeProjections` is an optimizer rule that identifies and eliminates +/// columns from a logical plan that are not used by downstream operations. +/// This can improve query performance and reduce unnecessary data processing. /// -/// The rule analyzes the input logical plan, determines the necessary column indices, and then removes any -/// unnecessary columns. Additionally, it eliminates any unnecessary projections in the plan. +/// The rule analyzes the input logical plan, determines the necessary column +/// indices, and then removes any unnecessary columns. It also removes any +/// unnecessary projections from the plan tree. #[derive(Default)] pub struct OptimizeProjections {} @@ -58,8 +67,8 @@ impl OptimizerRule for OptimizeProjections { plan: &LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - // All of the fields at the output are necessary. - let indices = require_all_indices(plan); + // All output fields are necessary: + let indices = (0..plan.schema().fields().len()).collect::>(); optimize_projections(plan, config, &indices) } @@ -72,30 +81,35 @@ impl OptimizerRule for OptimizeProjections { } } -/// Removes unnecessary columns (e.g Columns that are not referred at the output schema and -/// Columns that are not used during any computation, expression evaluation) from the logical plan and its inputs. +/// Removes unnecessary columns (e.g. columns that do not appear in the output +/// schema and/or are not used during any computation step such as expression +/// evaluation) from the logical plan and its inputs. /// -/// # Arguments +/// # Parameters /// -/// - `plan`: A reference to the input `LogicalPlan` to be optimized. -/// - `_config`: A reference to the optimizer configuration (not currently used). -/// - `indices`: A slice of column indices that represent the necessary column indices for downstream operations. 
+/// - `plan`: A reference to the input `LogicalPlan` to optimize.
+/// - `config`: A reference to the optimizer configuration.
+/// - `indices`: A slice of column indices that represent the necessary column
+///   indices for downstream operations.
 ///
 /// # Returns
 ///
-/// - `Ok(Some(LogicalPlan))`: An optimized `LogicalPlan` with unnecessary columns removed.
-/// - `Ok(None)`: If the optimization process results in a logical plan that doesn't require further propagation.
-/// - `Err(error)`: If an error occurs during the optimization process.
+/// A `Result` object with the following semantics:
+///
+/// - `Ok(Some(LogicalPlan))`: An optimized `LogicalPlan` without unnecessary
+///   columns.
+/// - `Ok(None)`: Signal that the given logical plan did not require any change.
+/// - `Err(error)`: An error occurred during the optimization process.
 fn optimize_projections(
     plan: &LogicalPlan,
-    _config: &dyn OptimizerConfig,
+    config: &dyn OptimizerConfig,
     indices: &[usize],
 ) -> Result<Option<LogicalPlan>> {
     // `child_required_indices` stores
     // - indices of the columns required for each child
     // - a flag indicating whether putting a projection above children is beneficial for the parent.
     // As an example LogicalPlan::Filter benefits from small tables. Hence for filter child this flag would be `true`.
-    let child_required_indices: Option<Vec<(Vec<usize>, bool)>> = match plan {
+    let child_required_indices: Vec<(Vec<usize>, bool)> = match plan {
         LogicalPlan::Sort(_)
         | LogicalPlan::Filter(_)
         | LogicalPlan::Repartition(_)
         | LogicalPlan::Union(_)
         | LogicalPlan::SubqueryAlias(_)
         | LogicalPlan::Distinct(Distinct::On(_)) => {
-            // Re-route required indices from the parent + column indices referred by expressions in the plan
-            // to the child.
-            // All of these operators benefits from small tables at their inputs. Hence projection_beneficial flag is `true`.
+            // Pass index requirements from the parent as well as column indices
+            // that appear in this plan's expressions to its child. All these
+            // operators benefit from "small" inputs, so the projection_beneficial
+            // flag is `true`.
             let exprs = plan.expressions();
-            let child_req_indices = plan
-                .inputs()
+            plan.inputs()
                 .into_iter()
                 .map(|input| {
-                    let required_indices =
-                        get_all_required_indices(indices, input, exprs.iter())?;
-                    Ok((required_indices, true))
+                    get_all_required_indices(indices, input, exprs.iter())
+                        .map(|idxs| (idxs, true))
                })
-                .collect::<Result<Vec<_>>>()?;
-            Some(child_req_indices)
+                .collect::<Result<Vec<_>>>()?
        }
        LogicalPlan::Limit(_) | LogicalPlan::Prepare(_) => {
-            // Re-route required indices from the parent + column indices referred by expressions in the plan
-            // to the child.
-            // Limit, Prepare doesn't benefit from small column numbers. Hence projection_beneficial flag is `false`.
+            // Pass index requirements from the parent as well as column indices
+            // that appear in this plan's expressions to its child. These operators
+            // do not benefit from "small" inputs, so the projection_beneficial
+            // flag is `false`.
            let exprs = plan.expressions();
-            let child_req_indices = plan
-                .inputs()
+            plan.inputs()
                .into_iter()
                .map(|input| {
-                    let required_indices =
-                        get_all_required_indices(indices, input, exprs.iter())?;
-                    Ok((required_indices, false))
+                    get_all_required_indices(indices, input, exprs.iter())
+                        .map(|idxs| (idxs, false))
                })
-                .collect::<Result<Vec<_>>>()?;
-            Some(child_req_indices)
+                .collect::<Result<Vec<_>>>()?
}
        LogicalPlan::Copy(_)
        | LogicalPlan::Ddl(_)
        | LogicalPlan::Dml(_)
        | LogicalPlan::Explain(_)
        | LogicalPlan::Analyze(_)
        | LogicalPlan::Subquery(_)
        | LogicalPlan::Distinct(Distinct::All(_)) => {
-            // Require all of the fields of the Dml, Ddl, Copy, Explain, Analyze, Subquery, Distinct::All input(s).
-            // Their child plan can be treated as final plan. Otherwise expected schema may not match.
-            // TODO: For some subquery variants we may not need to require all indices for its input.
-            // such as Exists.
-            let child_requirements = plan
-                .inputs()
+            // These plans require all their fields, and their children should
+            // be treated as final plans -- otherwise, we may have a schema
+            // mismatch.
+            // TODO: For some subquery variants (e.g. a subquery arising from an
+            //       EXISTS expression), we may not need to require all indices.
+            plan.inputs()
                .iter()
-                .map(|input| {
-                    // Require all of the fields for each input.
-                    // No projection since all of the fields at the child is required
-                    (require_all_indices(input), false)
-                })
-                .collect::<Vec<_>>();
-            Some(child_requirements)
+                .map(|input| ((0..input.schema().fields().len()).collect_vec(), false))
+                .collect::<Vec<_>>()
        }
        LogicalPlan::EmptyRelation(_)
        | LogicalPlan::Statement(_)
        | LogicalPlan::Values(_)
        | LogicalPlan::Extension(_)
        | LogicalPlan::DescribeTable(_) => {
-            // EmptyRelation, Values, DescribeTable, Statement has no inputs stop iteration
-
-            // TODO: Add support for extension
-            // It is not known how to direct requirements to children for LogicalPlan::Extension.
-            // Safest behaviour is to stop propagation.
-            None
+            // These operators have no inputs, so stop the optimization process.
+            // TODO: Add support for `LogicalPlan::Extension`.
+            return Ok(None);
        }
        LogicalPlan::Projection(proj) => {
            return if let Some(proj) = merge_consecutive_projections(proj)? {
-                rewrite_projection_given_requirements(&proj, _config, indices)?
-                    .map(|res| Ok(Some(res)))
-                    // Even if projection cannot be optimized, return merged version
-                    .unwrap_or_else(|| Ok(Some(LogicalPlan::Projection(proj))))
+                Ok(Some(
+                    rewrite_projection_given_requirements(&proj, config, indices)?
+                        // Even if we cannot optimize the projection, merge if possible:
+                        .unwrap_or_else(|| LogicalPlan::Projection(proj)),
+                ))
            } else {
-                rewrite_projection_given_requirements(proj, _config, indices)
+                rewrite_projection_given_requirements(proj, config, indices)
            };
        }
        LogicalPlan::Aggregate(aggregate) => {
-            // Split parent requirements to group by and aggregate sections
-            let group_expr_len = aggregate.group_expr_len()?;
-            let (_group_by_reqs, mut aggregate_reqs): (Vec<usize>, Vec<usize>) =
-                indices.iter().partition(|&&idx| idx < group_expr_len);
-            // Offset aggregate indices so that they point to valid indices at the `aggregate.aggr_expr`
-            aggregate_reqs
-                .iter_mut()
-                .for_each(|idx| *idx -= group_expr_len);
-
-            // Group by expressions are same
-            let new_group_bys = aggregate.group_expr.clone();
-
-            // Only use absolutely necessary aggregate expressions required by parent.
- let new_aggr_expr = get_at_indices(&aggregate.aggr_expr, &aggregate_reqs); + // Split parent requirements to GROUP BY and aggregate sections: + let n_group_exprs = aggregate.group_expr_len()?; + let (group_by_reqs, mut aggregate_reqs): (Vec, Vec) = + indices.iter().partition(|&&idx| idx < n_group_exprs); + // Offset aggregate indices so that they point to valid indices at + // `aggregate.aggr_expr`: + for idx in aggregate_reqs.iter_mut() { + *idx -= n_group_exprs; + } + + // Get absolutely necessary GROUP BY fields: + let group_by_expr_existing = aggregate + .group_expr + .iter() + .map(|group_by_expr| group_by_expr.display_name()) + .collect::>>()?; + let new_group_bys = if let Some(simplest_groupby_indices) = + get_required_group_by_exprs_indices( + aggregate.input.schema(), + &group_by_expr_existing, + ) { + // Some of the fields in the GROUP BY may be required by the + // parent even if these fields are unnecessary in terms of + // functional dependency. + let required_indices = + merge_slices(&simplest_groupby_indices, &group_by_reqs); + get_at_indices(&aggregate.group_expr, &required_indices) + } else { + aggregate.group_expr.clone() + }; + + // Only use the absolutely necessary aggregate expressions required + // by the parent: + let mut new_aggr_expr = get_at_indices(&aggregate.aggr_expr, &aggregate_reqs); let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter()); - let necessary_indices = - indices_referred_by_exprs(&aggregate.input, all_exprs_iter)?; + let schema = aggregate.input.schema(); + let necessary_indices = indices_referred_by_exprs(schema, all_exprs_iter)?; let aggregate_input = if let Some(input) = - optimize_projections(&aggregate.input, _config, &necessary_indices)? + optimize_projections(&aggregate.input, config, &necessary_indices)? { input } else { aggregate.input.as_ref().clone() }; - // Simplify input of the aggregation by adding a projection so that its input only contains - // absolutely necessary columns for the aggregate expressions. Please no that we use aggregate.input.schema() - // because necessary_indices refers to fields in this schema. - let necessary_exprs = - get_required_exprs(aggregate.input.schema(), &necessary_indices); - let (aggregate_input, _is_added) = - add_projection_on_top_if_helpful(aggregate_input, necessary_exprs, true)?; + // Simplify the input of the aggregation by adding a projection so + // that its input only contains absolutely necessary columns for + // the aggregate expressions. Note that necessary_indices refer to + // fields in `aggregate.input.schema()`. + let necessary_exprs = get_required_exprs(schema, &necessary_indices); + let (aggregate_input, _) = + add_projection_on_top_if_helpful(aggregate_input, necessary_exprs)?; + + // Aggregations always need at least one aggregate expression. + // With a nested count, we don't require any column as input, but + // still need to create a correct aggregate, which may be optimized + // out later. As an example, consider the following query: + // + // SELECT COUNT(*) FROM (SELECT COUNT(*) FROM [...]) + // + // which always returns 1. + if new_aggr_expr.is_empty() + && new_group_bys.is_empty() + && !aggregate.aggr_expr.is_empty() + { + new_aggr_expr = vec![aggregate.aggr_expr[0].clone()]; + } - // Create new aggregate plan with updated input, and absolutely necessary fields. 
+ // Create a new aggregate plan with the updated input and only the + // absolutely necessary fields: return Aggregate::try_new( Arc::new(aggregate_input), new_group_bys, @@ -222,43 +261,48 @@ fn optimize_projections( .map(|aggregate| Some(LogicalPlan::Aggregate(aggregate))); } LogicalPlan::Window(window) => { - // Split parent requirements to child and window expression sections. + // Split parent requirements to child and window expression sections: let n_input_fields = window.input.schema().fields().len(); let (child_reqs, mut window_reqs): (Vec, Vec) = indices.iter().partition(|&&idx| idx < n_input_fields); - // Offset window expr indices so that they point to valid indices at the `window.window_expr` - window_reqs - .iter_mut() - .for_each(|idx| *idx -= n_input_fields); + // Offset window expression indices so that they point to valid + // indices at `window.window_expr`: + for idx in window_reqs.iter_mut() { + *idx -= n_input_fields; + } - // Only use window expressions that are absolutely necessary by parent requirements. + // Only use window expressions that are absolutely necessary according + // to parent requirements: let new_window_expr = get_at_indices(&window.window_expr, &window_reqs); - // All of the required column indices at the input of the window by parent, and window expression requirements. + // Get all the required column indices at the input, either by the + // parent or window expression requirements. let required_indices = get_all_required_indices( &child_reqs, &window.input, new_window_expr.iter(), )?; let window_child = if let Some(new_window_child) = - optimize_projections(&window.input, _config, &required_indices)? + optimize_projections(&window.input, config, &required_indices)? { new_window_child } else { window.input.as_ref().clone() }; - // When no window expression is necessary, just use window input. (Remove window operator) + return if new_window_expr.is_empty() { + // When no window expression is necessary, use the input directly: Ok(Some(window_child)) } else { // Calculate required expressions at the input of the window. - // Please note that we use `old_child`, because `required_indices` refers to `old_child`. + // Please note that we use `old_child`, because `required_indices` + // refers to `old_child`. let required_exprs = get_required_exprs(window.input.schema(), &required_indices); - let (window_child, _is_added) = - add_projection_on_top_if_helpful(window_child, required_exprs, true)?; - let window = Window::try_new(new_window_expr, Arc::new(window_child))?; - Ok(Some(LogicalPlan::Window(window))) + let (window_child, _) = + add_projection_on_top_if_helpful(window_child, required_exprs)?; + Window::try_new(new_window_expr, Arc::new(window_child)) + .map(|window| Some(LogicalPlan::Window(window))) }; } LogicalPlan::Join(join) => { @@ -270,323 +314,402 @@ fn optimize_projections( get_all_required_indices(&left_req_indices, &join.left, exprs.iter())?; let right_indices = get_all_required_indices(&right_req_indices, &join.right, exprs.iter())?; - // Join benefits from small columns numbers at its input (decreases memory usage) - // Hence each child benefits from projection. - Some(vec![(left_indices, true), (right_indices, true)]) + // Joins benefit from "small" input tables (lower memory usage). 
+ // Therefore, each child benefits from projection: + vec![(left_indices, true), (right_indices, true)] } LogicalPlan::CrossJoin(cross_join) => { let left_len = cross_join.left.schema().fields().len(); let (left_child_indices, right_child_indices) = split_join_requirements(left_len, indices, &JoinType::Inner); - // Join benefits from small columns numbers at its input (decreases memory usage) - // Hence each child benefits from projection. - Some(vec![ - (left_child_indices, true), - (right_child_indices, true), - ]) + // Joins benefit from "small" input tables (lower memory usage). + // Therefore, each child benefits from projection: + vec![(left_child_indices, true), (right_child_indices, true)] } LogicalPlan::TableScan(table_scan) => { - let projection_fields = table_scan.projected_schema.fields(); let schema = table_scan.source.schema(); - // We expect to find all of the required indices of the projected schema fields. - // among original schema. If at least one of them cannot be found. Use all of the fields in the file. - // (No projection at the source) - let projection = indices - .iter() - .map(|&idx| { - schema.fields().iter().position(|field_source| { - projection_fields[idx].field() == field_source - }) - }) - .collect::>>(); + // Get indices referred to in the original (schema with all fields) + // given projected indices. + let projection = with_indices(&table_scan.projection, schema, |map| { + indices.iter().map(|&idx| map[idx]).collect() + }); - return Ok(Some(LogicalPlan::TableScan(TableScan::try_new( + return TableScan::try_new( table_scan.table_name.clone(), table_scan.source.clone(), - projection, + Some(projection), table_scan.filters.clone(), table_scan.fetch, - )?))); + ) + .map(|table| Some(LogicalPlan::TableScan(table))); } }; - let child_required_indices = - if let Some(child_required_indices) = child_required_indices { - child_required_indices - } else { - // Stop iteration, cannot propagate requirement down below this operator. - return Ok(None); - }; - let new_inputs = izip!(child_required_indices, plan.inputs().into_iter()) .map(|((required_indices, projection_beneficial), child)| { - let (input, mut is_changed) = if let Some(new_input) = - optimize_projections(child, _config, &required_indices)? + let (input, is_changed) = if let Some(new_input) = + optimize_projections(child, config, &required_indices)? { (new_input, true) } else { (child.clone(), false) }; let project_exprs = get_required_exprs(child.schema(), &required_indices); - let (input, is_projection_added) = add_projection_on_top_if_helpful( - input, - project_exprs, - projection_beneficial, - )?; - is_changed |= is_projection_added; - Ok(is_changed.then_some(input)) + let (input, proj_added) = if projection_beneficial { + add_projection_on_top_if_helpful(input, project_exprs)? + } else { + (input, false) + }; + Ok((is_changed || proj_added).then_some(input)) }) - .collect::>>>()?; - // All of the children are same in this case, no need to change plan + .collect::>>()?; if new_inputs.iter().all(|child| child.is_none()) { + // All children are the same in this case, no need to change the plan: Ok(None) } else { - // At least one of the children is changed. + // At least one of the children is changed: let new_inputs = izip!(new_inputs, plan.inputs()) - // If new_input is `None`, this means child is not changed. Hence use `old_child` during construction. 
+            // If new_input is `None`, this means child is not changed, so use
+            // `old_child` during construction:
             .map(|(new_input, old_child)| new_input.unwrap_or_else(|| old_child.clone()))
             .collect::<Vec<_>>();
-        let res = plan.with_new_inputs(&new_inputs)?;
-        Ok(Some(res))
+        plan.with_new_inputs(&new_inputs).map(Some)
     }
 }
 
-/// Merge Consecutive Projections
+/// This function applies the given function `f` to the projection indices
+/// `proj_indices` if they exist. Otherwise, applies `f` to a default set
+/// of indices according to `schema`.
+fn with_indices<F>(
+    proj_indices: &Option<Vec<usize>>,
+    schema: SchemaRef,
+    mut f: F,
+) -> Vec<usize>
+where
+    F: FnMut(&[usize]) -> Vec<usize>,
+{
+    match proj_indices {
+        Some(indices) => f(indices.as_slice()),
+        None => {
+            let range: Vec<usize> = (0..schema.fields.len()).collect();
+            f(range.as_slice())
+        }
+    }
+}
+
+/// Merges consecutive projections.
 ///
 /// Given a projection `proj`, this function attempts to merge it with a previous
-/// projection if it exists and if the merging is beneficial. Merging is considered
-/// beneficial when expressions in the current projection are non-trivial and referred to
-/// more than once in its input fields. This can act as a caching mechanism for non-trivial
-/// computations.
+/// projection if it exists and if merging is beneficial. Merging is considered
+/// beneficial when expressions in the current projection are non-trivial and
+/// appear more than once in its input fields. This can act as a caching mechanism
+/// for non-trivial computations.
 ///
-/// # Arguments
+/// # Parameters
 ///
 /// * `proj` - A reference to the `Projection` to be merged.
 ///
 /// # Returns
 ///
-/// A `Result` containing an `Option` of the merged `Projection`. If merging is not beneficial
-/// it returns `Ok(None)`.
+/// A `Result` object with the following semantics:
+///
+/// - `Ok(Some(Projection))`: Merge was beneficial and successful. Contains the
+///   merged projection.
+/// - `Ok(None)`: Signals that merge is not beneficial (and has not taken place).
+/// - `Err(error)`: An error occurred during the function call.
 fn merge_consecutive_projections(proj: &Projection) -> Result<Option<Projection>> {
-    let prev_projection = if let LogicalPlan::Projection(prev) = proj.input.as_ref() {
-        prev
-    } else {
+    let LogicalPlan::Projection(prev_projection) = proj.input.as_ref() else {
         return Ok(None);
     };
 
-    // Count usages (referral counts) of each projection expression in its input fields
-    let column_referral_map: HashMap<Column, usize> = proj
-        .expr
-        .iter()
-        .flat_map(|expr| expr.to_columns())
-        .fold(HashMap::new(), |mut map, cols| {
-            cols.into_iter()
-                .for_each(|col| *map.entry(col).or_default() += 1);
-            map
-        });
-
-    // Merging these projections is not beneficial, e.g
-    // If an expression is not trivial and it is referred more than 1, consecutive projections will be
-    // beneficial as caching mechanism for non-trivial computations.
-    // See discussion in: https://github.com/apache/arrow-datafusion/issues/8296
-    if column_referral_map.iter().any(|(col, usage)| {
-        *usage > 1
+    // Count usages (referrals) of each projection expression in its input fields:
+    let mut column_referral_map = HashMap::<Column, usize>::new();
+    for columns in proj.expr.iter().flat_map(|expr| expr.to_columns()) {
+        for col in columns.into_iter() {
+            *column_referral_map.entry(col.clone()).or_default() += 1;
+        }
+    }
+
+    // If an expression is non-trivial and appears more than once, consecutive
+    // projections will benefit from a compute-once approach.
For details, see:
+    // https://github.com/apache/arrow-datafusion/issues/8296
+    if column_referral_map.into_iter().any(|(col, usage)| {
+        usage > 1
             && !is_expr_trivial(
                 &prev_projection.expr
-                    [prev_projection.schema.index_of_column(col).unwrap()],
+                    [prev_projection.schema.index_of_column(&col).unwrap()],
             )
     }) {
         return Ok(None);
     }
 
-    // If all of the expression of the top projection can be rewritten. Rewrite expressions and create a new projection
+    // If all the expressions of the top projection can be rewritten, do so and
+    // create a new projection:
     let new_exprs = proj
         .expr
         .iter()
         .map(|expr| rewrite_expr(expr, prev_projection))
         .collect::<Result<Option<Vec<_>>>>()?;
-    new_exprs
-        .map(|exprs| Projection::try_new(exprs, prev_projection.input.clone()))
-        .transpose()
+    if let Some(new_exprs) = new_exprs {
+        let new_exprs = new_exprs
+            .into_iter()
+            .zip(proj.expr.iter())
+            .map(|(new_expr, old_expr)| {
+                new_expr.alias_if_changed(old_expr.name_for_alias()?)
+            })
+            .collect::<Result<Vec<_>>>()?;
+        Projection::try_new(new_exprs, prev_projection.input.clone()).map(Some)
+    } else {
+        Ok(None)
+    }
 }
 
-/// Trim Expression
-///
-/// Trim the given expression by removing any unnecessary layers of abstraction.
+/// Trim the given expression by removing any unnecessary layers of aliasing.
 /// If the expression is an alias, the function returns the underlying expression.
-/// Otherwise, it returns the original expression unchanged.
-///
-/// # Arguments
+/// Otherwise, it returns the given expression as is.
 ///
-/// * `expr` - The input expression to be trimmed.
+/// Without trimming, we can end up with unnecessary indirections inside expressions
+/// during projection merges.
 ///
-/// # Returns
-///
-/// The trimmed expression. If the input is an alias, the underlying expression is returned.
-///
-/// Without trimming, during projection merge we can end up unnecessary indirections inside the expressions.
 /// Consider:
 ///
-/// Projection (a1 + b1 as sum1)
-/// --Projection (a as a1, b as b1)
-/// ----Source (a, b)
+/// ```text
+/// Projection(a1 + b1 as sum1)
+/// --Projection(a as a1, b as b1)
+/// ----Source(a, b)
+/// ```
 ///
-/// After merge we want to produce
+/// After merge, we want to produce:
 ///
-/// Projection (a + b as sum1)
+/// ```text
+/// Projection(a + b as sum1)
 /// --Source(a, b)
+/// ```
 ///
-/// Without trimming we would end up
+/// Without trimming, we would end up with:
 ///
-/// Projection (a as a1 + b as b1 as sum1)
+/// ```text
+/// Projection((a as a1 + b as b1) as sum1)
 /// --Source(a, b)
+/// ```
 fn trim_expr(expr: Expr) -> Expr {
     match expr {
-        Expr::Alias(alias) => *alias.expr,
+        Expr::Alias(alias) => trim_expr(*alias.expr),
         _ => expr,
     }
 }
 
-// Check whether expression is trivial (e.g it doesn't include computation.)
+// Check whether `expr` is trivial; i.e. it doesn't imply any computation.
 fn is_expr_trivial(expr: &Expr) -> bool {
     matches!(expr, Expr::Column(_) | Expr::Literal(_))
 }
 
-// Exit early when None is seen.
+// Exit early when there is no rewrite to do.
 macro_rules! rewrite_expr_with_check {
     ($expr:expr, $input:expr) => {
-        if let Some(val) = rewrite_expr($expr, $input)? {
-            val
+        if let Some(value) = rewrite_expr($expr, $input)? {
+            value
         } else {
             return Ok(None);
         }
     };
 }
 
-// Rewrites expression using its input projection (Merges consecutive projection expressions).
-/// Rewrites an projections expression using its input projection
-/// (Helper during merging consecutive projection expressions).
+/// Rewrites a projection expression using the projection before it (i.e.
its input)
+/// This is a subroutine to the `merge_consecutive_projections` function.
 ///
-/// # Arguments
+/// # Parameters
 ///
-/// * `expr` - A reference to the expression to be rewritten.
-/// * `input` - A reference to the input (itself a projection) of the projection expression.
+/// * `expr` - A reference to the expression to rewrite.
+/// * `input` - A reference to the input of the projection expression (itself
+///   a projection).
 ///
 /// # Returns
 ///
-/// A `Result` containing an `Option` of the rewritten expression. If the rewrite is successful,
-/// it returns `Ok(Some)` with the modified expression. If the expression cannot be rewritten
-/// it returns `Ok(None)`.
+/// A `Result` object with the following semantics:
+///
+/// - `Ok(Some(Expr))`: Rewrite was successful. Contains the rewritten result.
+/// - `Ok(None)`: Signals that `expr` cannot be rewritten.
+/// - `Err(error)`: An error occurred during the function call.
 fn rewrite_expr(expr: &Expr, input: &Projection) -> Result<Option<Expr>> {
-    Ok(match expr {
+    let result = match expr {
         Expr::Column(col) => {
-            // Find index of column
+            // Find index of column:
             let idx = input.schema.index_of_column(col)?;
-            Some(input.expr[idx].clone())
-        }
-        Expr::BinaryExpr(binary) => {
-            let lhs = trim_expr(rewrite_expr_with_check!(&binary.left, input));
-            let rhs = trim_expr(rewrite_expr_with_check!(&binary.right, input));
-            Some(Expr::BinaryExpr(BinaryExpr::new(
-                Box::new(lhs),
-                binary.op,
-                Box::new(rhs),
-            )))
+            input.expr[idx].clone()
         }
-        Expr::Alias(alias) => {
-            let new_expr = trim_expr(rewrite_expr_with_check!(&alias.expr, input));
-            Some(Expr::Alias(Alias::new(
-                new_expr,
-                alias.relation.clone(),
-                alias.name.clone(),
-            )))
-        }
-        Expr::Literal(_val) => Some(expr.clone()),
+        Expr::BinaryExpr(binary) => Expr::BinaryExpr(BinaryExpr::new(
+            Box::new(trim_expr(rewrite_expr_with_check!(&binary.left, input))),
+            binary.op,
+            Box::new(trim_expr(rewrite_expr_with_check!(&binary.right, input))),
+        )),
+        Expr::Alias(alias) => Expr::Alias(Alias::new(
+            trim_expr(rewrite_expr_with_check!(&alias.expr, input)),
+            alias.relation.clone(),
+            alias.name.clone(),
+        )),
+        Expr::Literal(_) => expr.clone(),
         Expr::Cast(cast) => {
             let new_expr = rewrite_expr_with_check!(&cast.expr, input);
-            Some(Expr::Cast(Cast::new(
-                Box::new(new_expr),
-                cast.data_type.clone(),
-            )))
+            Expr::Cast(Cast::new(Box::new(new_expr), cast.data_type.clone()))
         }
         Expr::ScalarFunction(scalar_fn) => {
-            let fun = if let ScalarFunctionDefinition::BuiltIn(fun) = scalar_fn.func_def {
-                fun
-            } else {
+            // TODO: Support UDFs.
+            let ScalarFunctionDefinition::BuiltIn(fun) = scalar_fn.func_def else {
                 return Ok(None);
             };
-            scalar_fn
+            return Ok(scalar_fn
                 .args
                 .iter()
                 .map(|expr| rewrite_expr(expr, input))
-                .collect::<Result<Option<Vec<_>>>>()?
-                .map(|new_args| Expr::ScalarFunction(ScalarFunction::new(fun, new_args)))
+                .collect::<Result<Option<Vec<_>>>>()?
+                .map(|new_args| {
+                    Expr::ScalarFunction(ScalarFunction::new(fun, new_args))
+                }));
         }
-        _ => {
-            // Unsupported type to merge in consecutive projections
-            None
-        }
-    })
+        // Unsupported type for consecutive projection merge analysis.
+        _ => return Ok(None),
+    };
+    Ok(Some(result))
 }
 
-/// Retrieves a set of outer-referenced columns from an expression.
-/// Please note that `expr.to_columns()` API doesn't return these columns.
+/// Retrieves a set of outer-referenced columns by the given expression, `expr`.
+/// Note that the `Expr::to_columns()` function doesn't return these columns.
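+///
+/// For example, in the (hypothetical) correlated query
+/// `SELECT * FROM t1 WHERE EXISTS (SELECT 1 FROM t2 WHERE t2.a = t1.b)`,
+/// the inner subquery refers to `t1.b` as an outer-referenced column.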
/// -/// # Arguments +/// # Parameters /// -/// * `expr` - The expression to be analyzed for outer-referenced columns. +/// * `expr` - The expression to analyze for outer-referenced columns. /// /// # Returns /// -/// A `HashSet` containing columns that are referenced by the expression. -fn outer_columns(expr: &Expr) -> HashSet { +/// If the function can safely infer all outer-referenced columns, returns a +/// `Some(HashSet)` containing these columns. Otherwise, returns `None`. +fn outer_columns(expr: &Expr) -> Option> { let mut columns = HashSet::new(); - outer_columns_helper(expr, &mut columns); - columns + outer_columns_helper(expr, &mut columns).then_some(columns) } -/// Helper function to accumulate outer-referenced columns referred by the `expr`. +/// A recursive subroutine that accumulates outer-referenced columns by the +/// given expression, `expr`. /// -/// # Arguments +/// # Parameters +/// +/// * `expr` - The expression to analyze for outer-referenced columns. +/// * `columns` - A mutable reference to a `HashSet` where detected +/// columns are collected. /// -/// * `expr` - The expression to be analyzed for outer-referenced columns. -/// * `columns` - A mutable reference to a `HashSet` where the detected columns are collected. -fn outer_columns_helper(expr: &Expr, columns: &mut HashSet) { +/// Returns `true` if it can safely collect all outer-referenced columns. +/// Otherwise, returns `false`. +fn outer_columns_helper(expr: &Expr, columns: &mut HashSet) -> bool { match expr { Expr::OuterReferenceColumn(_, col) => { columns.insert(col.clone()); + true } Expr::BinaryExpr(binary_expr) => { - outer_columns_helper(&binary_expr.left, columns); - outer_columns_helper(&binary_expr.right, columns); + outer_columns_helper(&binary_expr.left, columns) + && outer_columns_helper(&binary_expr.right, columns) } Expr::ScalarSubquery(subquery) => { - for expr in &subquery.outer_ref_columns { - outer_columns_helper(expr, columns); - } + let exprs = subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns) } Expr::Exists(exists) => { - for expr in &exists.subquery.outer_ref_columns { - outer_columns_helper(expr, columns); + let exprs = exists.subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns) + } + Expr::Alias(alias) => outer_columns_helper(&alias.expr, columns), + Expr::InSubquery(insubquery) => { + let exprs = insubquery.subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns) + } + Expr::IsNotNull(expr) | Expr::IsNull(expr) => outer_columns_helper(expr, columns), + Expr::Cast(cast) => outer_columns_helper(&cast.expr, columns), + Expr::Sort(sort) => outer_columns_helper(&sort.expr, columns), + Expr::AggregateFunction(aggregate_fn) => { + outer_columns_helper_multi(aggregate_fn.args.iter(), columns) + && aggregate_fn + .order_by + .as_ref() + .map_or(true, |obs| outer_columns_helper_multi(obs.iter(), columns)) + && aggregate_fn + .filter + .as_ref() + .map_or(true, |filter| outer_columns_helper(filter, columns)) + } + Expr::WindowFunction(window_fn) => { + outer_columns_helper_multi(window_fn.args.iter(), columns) + && outer_columns_helper_multi(window_fn.order_by.iter(), columns) + && outer_columns_helper_multi(window_fn.partition_by.iter(), columns) + } + Expr::GroupingSet(groupingset) => match groupingset { + GroupingSet::GroupingSets(multi_exprs) => multi_exprs + .iter() + .all(|e| outer_columns_helper_multi(e.iter(), columns)), + GroupingSet::Cube(exprs) | GroupingSet::Rollup(exprs) => { + 
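            // Cube and Rollup variants hold a single flat list of expressions: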
-/// Helper function to accumulate outer-referenced columns referred by the `expr`. +/// A recursive subroutine that accumulates outer-referenced columns by the +/// given expression, `expr`. /// -/// # Arguments +/// # Parameters +/// +/// * `expr` - The expression to analyze for outer-referenced columns. +/// * `columns` - A mutable reference to a `HashSet<Column>` where detected +/// columns are collected. /// -/// * `expr` - The expression to be analyzed for outer-referenced columns. -/// * `columns` - A mutable reference to a `HashSet<Column>` where the detected columns are collected. -fn outer_columns_helper(expr: &Expr, columns: &mut HashSet<Column>) { +/// Returns `true` if it can safely collect all outer-referenced columns. +/// Otherwise, returns `false`. +fn outer_columns_helper(expr: &Expr, columns: &mut HashSet<Column>) -> bool { match expr { Expr::OuterReferenceColumn(_, col) => { columns.insert(col.clone()); + true } Expr::BinaryExpr(binary_expr) => { - outer_columns_helper(&binary_expr.left, columns); - outer_columns_helper(&binary_expr.right, columns); + outer_columns_helper(&binary_expr.left, columns) + && outer_columns_helper(&binary_expr.right, columns) } Expr::ScalarSubquery(subquery) => { - for expr in &subquery.outer_ref_columns { - outer_columns_helper(expr, columns); - } + let exprs = subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns) } Expr::Exists(exists) => { - for expr in &exists.subquery.outer_ref_columns { - outer_columns_helper(expr, columns); + let exprs = exists.subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns) + } + Expr::Alias(alias) => outer_columns_helper(&alias.expr, columns), + Expr::InSubquery(insubquery) => { + let exprs = insubquery.subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns) + } + Expr::IsNotNull(expr) | Expr::IsNull(expr) => outer_columns_helper(expr, columns), + Expr::Cast(cast) => outer_columns_helper(&cast.expr, columns), + Expr::Sort(sort) => outer_columns_helper(&sort.expr, columns), + Expr::AggregateFunction(aggregate_fn) => { + outer_columns_helper_multi(aggregate_fn.args.iter(), columns) + && aggregate_fn + .order_by + .as_ref() + .map_or(true, |obs| outer_columns_helper_multi(obs.iter(), columns)) + && aggregate_fn + .filter + .as_ref() + .map_or(true, |filter| outer_columns_helper(filter, columns)) + } + Expr::WindowFunction(window_fn) => { + outer_columns_helper_multi(window_fn.args.iter(), columns) + && outer_columns_helper_multi(window_fn.order_by.iter(), columns) + && outer_columns_helper_multi(window_fn.partition_by.iter(), columns) + } + Expr::GroupingSet(groupingset) => match groupingset { + GroupingSet::GroupingSets(multi_exprs) => multi_exprs + .iter() + .all(|e| outer_columns_helper_multi(e.iter(), columns)), + GroupingSet::Cube(exprs) | GroupingSet::Rollup(exprs) => { + outer_columns_helper_multi(exprs.iter(), columns) } + }, + Expr::ScalarFunction(scalar_fn) => { + outer_columns_helper_multi(scalar_fn.args.iter(), columns) + } + Expr::Like(like) => { + outer_columns_helper(&like.expr, columns) + && outer_columns_helper(&like.pattern, columns) } - Expr::Alias(alias) => { - outer_columns_helper(&alias.expr, columns); + Expr::InList(in_list) => { + outer_columns_helper(&in_list.expr, columns) + && outer_columns_helper_multi(in_list.list.iter(), columns) } - _ => {} + Expr::Case(case) => { + let when_then_exprs = case + .when_then_expr + .iter() + .flat_map(|(first, second)| [first.as_ref(), second.as_ref()]); + outer_columns_helper_multi(when_then_exprs, columns) + && case + .expr + .as_ref() + .map_or(true, |expr| outer_columns_helper(expr, columns)) + && case + .else_expr + .as_ref() + .map_or(true, |expr| outer_columns_helper(expr, columns)) + } + Expr::Column(_) | Expr::Literal(_) | Expr::Wildcard { .. } => true, + _ => false, } } -/// Generates the required expressions(Column) that resides at `indices` of the `input_schema`. +/// A recursive subroutine that accumulates outer-referenced columns by the +/// given expressions (`exprs`). +/// +/// # Parameters +/// +/// * `exprs` - The expressions to analyze for outer-referenced columns. +/// * `columns` - A mutable reference to a `HashSet<Column>` where detected +/// columns are collected. +/// +/// Returns `true` if it can safely collect all outer-referenced columns. +/// Otherwise, returns `false`. +fn outer_columns_helper_multi<'a>( + mut exprs: impl Iterator<Item = &'a Expr>, + columns: &mut HashSet<Column>, +) -> bool { + exprs.all(|e| outer_columns_helper(e, columns)) +} + +/// Generates the required expressions (columns) that reside at `indices` of +/// the given `input_schema`. /// /// # Arguments /// /// * `input_schema` - A reference to the input schema. -/// * `indices` - A slice of `usize` indices specifying which columns are required. +/// * `indices` - A slice of `usize` indices specifying required columns. /// /// # Returns /// -/// A vector of `Expr::Column` expressions, that sits at `indices` of the `input_schema`. +/// A vector of `Expr::Column` expressions residing at `indices` of the `input_schema`. fn get_required_exprs(input_schema: &Arc<DFSchema>, indices: &[usize]) -> Vec<Expr> { let fields = input_schema.fields(); indices @@ -595,58 +718,70 @@ fn get_required_exprs(input_schema: &Arc<DFSchema>, indices: &[usize]) -> Vec<Expr> -fn indices_referred_by_exprs<'a, I: Iterator<Item = &'a Expr>>( - input: &LogicalPlan, - exprs: I, +/// A [`Result`] object containing the indices of all required fields in +/// `input_schema` to calculate all `exprs` successfully. +fn indices_referred_by_exprs<'a>( + input_schema: &DFSchemaRef, + exprs: impl Iterator<Item = &'a Expr>, ) -> Result<Vec<usize>> { - let new_indices = exprs - .flat_map(|expr| indices_referred_by_expr(input.schema(), expr)) + let indices = exprs + .map(|expr| indices_referred_by_expr(input_schema, expr)) + .collect::<Result<Vec<_>>>()?; + Ok(indices + .into_iter() .flatten() - // Make sure no duplicate entries exists and indices are ordered. + // Make sure no duplicate entries exist and indices are ordered: .sorted() .dedup() - .collect::<Vec<_>>(); - Ok(new_indices) + .collect()) }
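Replacing the named generic `I: Iterator<Item = &'a Expr>` with `impl Iterator<Item = &'a Expr>` in argument position is a pure signature simplification: callers are unchanged and monomorphization is identical, though the parameter can no longer be pinned with a turbofish. A compilable sketch of the two spellings (toy `Expr`):

```rust
struct Expr;

// Before: a named type parameter.
fn count_old<'a, I: Iterator<Item = &'a Expr>>(exprs: I) -> usize {
    exprs.count()
}

// After: `impl Trait` in argument position.
fn count_new<'a>(exprs: impl Iterator<Item = &'a Expr>) -> usize {
    exprs.count()
}

fn main() {
    let exprs = vec![Expr, Expr];
    assert_eq!(count_old(exprs.iter()), count_new(exprs.iter()));
}
```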
-/// Get indices of the necessary fields referred by the `expr` among input schema. +/// Get indices of the fields referred to by the given expression `expr` within +/// the given schema (`input_schema`). /// -/// # Arguments +/// # Parameters /// -/// * `input_schema`: The input schema to search for indices referred by expr. -/// * `expr`: An expression for which we want to find necessary field indices at the input schema. +/// * `input_schema`: The input schema to analyze for index requirements. +/// * `expr`: An expression for which we want to find necessary field indices. /// /// # Returns /// -/// A [Result] object that contains the required field indices of the `input_schema`, to be able to calculate -/// the `expr` successfully. +/// A [`Result`] object containing the indices of all required fields in +/// `input_schema` to calculate `expr` successfully. fn indices_referred_by_expr( input_schema: &DFSchemaRef, expr: &Expr, ) -> Result<Vec<usize>> { let mut cols = expr.to_columns()?; - // Get outer referenced columns (expr.to_columns() doesn't return these columns). - cols.extend(outer_columns(expr)); - cols.iter() - .filter(|&col| input_schema.has_column(col)) - .map(|col| input_schema.index_of_column(col)) - .collect::<Result<Vec<_>>>() + // Get outer-referenced columns: + if let Some(outer_cols) = outer_columns(expr) { + cols.extend(outer_cols); + } else { + // We cannot tell whether `expr` references outer columns. Hence, do + // not assume anything and require all the schema indices at the input: + return Ok((0..input_schema.fields().len()).collect()); + } + Ok(cols + .iter() + .flat_map(|col| input_schema.index_of_column(col)) + .collect()) } -/// Get all required indices for the input (indices required by parent + indices referred by `exprs`) +/// Gets all required indices for the input; i.e. those required by the parent +/// and those referred to by `exprs`. /// -/// # Arguments +/// # Parameters /// /// * `parent_required_indices` - A slice of indices required by the parent plan. /// * `input` - The input logical plan to analyze for index requirements. @@ -654,30 +789,28 @@ fn indices_referred_by_expr( /// /// # Returns /// -/// A `Result` containing a vector of `usize` indices containing all required indices. -fn get_all_required_indices<'a, I: Iterator<Item = &'a Expr>>( +/// A `Result` containing a vector of `usize` indices containing all the required +/// indices. +fn get_all_required_indices<'a>( parent_required_indices: &[usize], input: &LogicalPlan, - exprs: I, + exprs: impl Iterator<Item = &'a Expr>, ) -> Result<Vec<usize>> { - let referred_indices = indices_referred_by_exprs(input, exprs)?; - Ok(merge_vectors(parent_required_indices, &referred_indices)) + indices_referred_by_exprs(input.schema(), exprs) + .map(|indices| merge_slices(parent_required_indices, &indices)) }
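The conservative fallback above (require every input column when the analysis is inconclusive) is the safe default for a pruning pass: over-requiring merely keeps extra columns, while under-requiring would produce wrong plans. In miniature:

```rust
/// Returns the column indices an expression needs, or all of them when the
/// analysis is inconclusive (`known` is `None`).
fn required_indices(known: Option<Vec<usize>>, field_count: usize) -> Vec<usize> {
    match known {
        Some(indices) => indices,
        // Inconclusive: assume everything is needed.
        None => (0..field_count).collect(),
    }
}

fn main() {
    assert_eq!(required_indices(Some(vec![0, 2]), 4), vec![0, 2]);
    assert_eq!(required_indices(None, 4), vec![0, 1, 2, 3]);
}
```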
-/// Retrieves a list of expressions at specified indices from a slice of expressions. +/// Retrieves the expressions at specified indices within the given slice. Ignores +/// any invalid indices. /// -/// This function takes a slice of expressions `exprs` and a slice of `usize` indices `indices`. -/// It returns a new vector containing the expressions from `exprs` that correspond to the provided indices (with bound check). -/// -/// # Arguments +/// # Parameters /// -/// * `exprs` - A slice of expressions from which expressions are to be retrieved. -/// * `indices` - A slice of `usize` indices specifying the positions of the expressions to be retrieved. +/// * `exprs` - A slice of expressions to index into. +/// * `indices` - A slice of indices specifying the positions of expressions sought. /// /// # Returns /// -/// A vector of expressions that correspond to the specified indices. If any index is out of bounds, -/// the associated expression is skipped in the result. +/// A vector of expressions corresponding to specified indices. fn get_at_indices(exprs: &[Expr], indices: &[usize]) -> Vec<Expr> { indices .iter() @@ -686,158 +819,148 @@ fn get_at_indices(exprs: &[Expr], indices: &[usize]) -> Vec<Expr> { .collect() } -/// Merges two slices of `usize` values into a single vector with sorted (ascending) and deduplicated elements. -/// -/// # Arguments -/// -/// * `lhs` - The first slice of `usize` values to be merged. -/// * `rhs` - The second slice of `usize` values to be merged. -/// -/// # Returns -/// -/// A vector of `usize` values containing the merged, sorted, and deduplicated elements from `lhs` and `rhs`. -/// As an example merge of [3, 2, 4] and [3, 6, 1] will produce [1, 2, 3, 6] -fn merge_vectors(lhs: &[usize], rhs: &[usize]) -> Vec<usize> { - let mut merged = lhs.to_vec(); - merged.extend(rhs); - // Make sure to run sort before dedup. - // Dedup removes consecutive same entries - // If sort is run before it, all duplicates are removed. - merged.sort(); - merged.dedup(); - merged +/// Merges two slices into a single vector with sorted (ascending) and +/// deduplicated elements. For example, merging `[3, 2, 4]` and `[3, 6, 1]` +/// will produce `[1, 2, 3, 4, 6]`. +fn merge_slices<T: Clone + Ord>(left: &[T], right: &[T]) -> Vec<T> { + // Make sure to sort before deduping, which removes the duplicates: + left.iter() + .cloned() + .chain(right.iter().cloned()) + .sorted() + .dedup() + .collect() }
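`merge_slices` is now generic over any `T: Clone + Ord`, so the same helper serves index vectors and anything else ordered. A quick standalone check of the documented example (assumes the `itertools` crate, which this module already uses for `sorted`/`dedup`):

```rust
use itertools::Itertools;

fn merge_slices<T: Clone + Ord>(left: &[T], right: &[T]) -> Vec<T> {
    // Sort before dedup: `dedup` only removes *consecutive* duplicates.
    left.iter()
        .cloned()
        .chain(right.iter().cloned())
        .sorted()
        .dedup()
        .collect()
}

fn main() {
    assert_eq!(merge_slices(&[3, 2, 4], &[3, 6, 1]), vec![1, 2, 3, 4, 6]);
}
```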
-/// Splits requirement indices for a join into left and right children based on the join type. +/// Splits requirement indices for a join into left and right children based on +/// the join type. /// -/// This function takes the length of the left child, a slice of requirement indices, and the type -/// of join (e.g., INNER, LEFT, RIGHT, etc.) as arguments. Depending on the join type, it divides -/// the requirement indices into those that apply to the left child and those that apply to the right child. +/// This function takes the length of the left child, a slice of requirement +/// indices, and the type of join (e.g. `INNER`, `LEFT`, `RIGHT`) as arguments. +/// Depending on the join type, it divides the requirement indices into those +/// that apply to the left child and those that apply to the right child. /// -/// - For INNER, LEFT, RIGHT, and FULL joins, the requirements are split between left and right children. -/// The right child indices are adjusted to point to valid positions in the right child by subtracting -/// the length of the left child. +/// - For `INNER`, `LEFT`, `RIGHT` and `FULL` joins, the requirements are split +/// between left and right children. The right child indices are adjusted to +/// point to valid positions within the right child by subtracting the length +/// of the left child. /// -/// - For LEFT ANTI, LEFT SEMI, RIGHT SEMI, and RIGHT ANTI joins, all requirements are re-routed to either -/// the left child or the right child directly, depending on the join type. +/// - For `LEFT ANTI`, `LEFT SEMI`, `RIGHT SEMI` and `RIGHT ANTI` joins, all +/// requirements are re-routed to either the left child or the right child +/// directly, depending on the join type. /// -/// # Arguments +/// # Parameters /// /// * `left_len` - The length of the left child. /// * `indices` - A slice of requirement indices. -/// * `join_type` - The type of join (e.g., INNER, LEFT, RIGHT, etc.). +/// * `join_type` - The type of join (e.g. `INNER`, `LEFT`, `RIGHT`). /// /// # Returns /// -/// A tuple containing two vectors of `usize` indices: the first vector represents the requirements for -/// the left child, and the second vector represents the requirements for the right child. The indices -/// are appropriately split and adjusted based on the join type. +/// A tuple containing two vectors of `usize` indices: The first vector represents +/// the requirements for the left child, and the second vector represents the +/// requirements for the right child. The indices are appropriately split and +/// adjusted based on the join type. fn split_join_requirements( left_len: usize, indices: &[usize], join_type: &JoinType, ) -> (Vec<usize>, Vec<usize>) { match join_type { - // In these cases requirements split to left and right child. + // In these cases requirements are split between left/right children: JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { - let (left_child_reqs, mut right_child_reqs): (Vec<usize>, Vec<usize>) = + let (left_reqs, mut right_reqs): (Vec<usize>, Vec<usize>) = indices.iter().partition(|&&idx| idx < left_len); - // Decrease right side index by `left_len` so that they point to valid positions in the right child. - right_child_reqs.iter_mut().for_each(|idx| *idx -= left_len); - (left_child_reqs, right_child_reqs) + // Decrease right side indices by `left_len` so that they point to valid + // positions within the right child: + for idx in right_reqs.iter_mut() { + *idx -= left_len; + } + (left_reqs, right_reqs) } // All requirements can be re-routed to left child directly. JoinType::LeftAnti | JoinType::LeftSemi => (indices.to_vec(), vec![]), - // All requirements can be re-routed to right side directly. (No need to change index, join schema is right child schema.) + // All requirements can be re-routed to right side directly. + // No need to change index, join schema is right child schema. JoinType::RightSemi | JoinType::RightAnti => (vec![], indices.to_vec()), } }
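Partitioning requirement indices at `left_len` and re-basing the right-hand side is a classic off-by-one trap; the `INNER`/`LEFT`/`RIGHT`/`FULL` branch boils down to this standalone check:

```rust
/// Split indices into (left, right), re-basing right-side indices so they
/// are valid within the right child's schema.
fn split_at(left_len: usize, indices: &[usize]) -> (Vec<usize>, Vec<usize>) {
    let (left, mut right): (Vec<usize>, Vec<usize>) =
        indices.iter().partition(|&&idx| idx < left_len);
    for idx in right.iter_mut() {
        *idx -= left_len;
    }
    (left, right)
}

fn main() {
    // Left child has 3 columns; join output column 4 is right child column 1.
    assert_eq!(split_at(3, &[0, 2, 4]), (vec![0, 2], vec![1]));
}
```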
-/// Adds a projection on top of a logical plan if it is beneficial and reduces the number of columns for the parent operator. +/// Adds a projection on top of a logical plan if doing so reduces the number +/// of columns for the parent operator. /// -/// This function takes a `LogicalPlan`, a list of projection expressions, and a flag indicating whether -/// the projection is beneficial. If the projection is beneficial and reduces the number of columns in -/// the plan, a new `LogicalPlan` with the projection is created and returned, along with a `true` flag. -/// If the projection is unnecessary or doesn't reduce the number of columns, the original plan is returned -/// with a `false` flag. +/// This function takes a `LogicalPlan` and a list of projection expressions. +/// If the projection is beneficial (it reduces the number of columns in the +/// plan) a new `LogicalPlan` with the projection is created and returned, along +/// with a `true` flag. If the projection doesn't reduce the number of columns, +/// the original plan is returned with a `false` flag. /// -/// # Arguments +/// # Parameters /// /// * `plan` - The input `LogicalPlan` to potentially add a projection to. /// * `project_exprs` - A list of expressions for the projection. -/// * `projection_beneficial` - A flag indicating whether the projection is beneficial. /// /// # Returns /// -/// A `Result` containing a tuple with two values: the resulting `LogicalPlan` (with or without -/// the added projection) and a `bool` flag indicating whether the projection was added (`true`) or not (`false`). +/// A `Result` containing a tuple with two values: The resulting `LogicalPlan` +/// (with or without the added projection) and a `bool` flag indicating if a +/// projection was added (`true`) or not (`false`). fn add_projection_on_top_if_helpful( plan: LogicalPlan, project_exprs: Vec<Expr>, - projection_beneficial: bool, ) -> Result<(LogicalPlan, bool)> { - // Make sure projection decreases table column size, otherwise it is unnecessary. - if !projection_beneficial || project_exprs.len() >= plan.schema().fields().len() { + // Make sure projection decreases the number of columns, otherwise it is unnecessary. + if project_exprs.len() >= plan.schema().fields().len() { Ok((plan, false)) } else { - let new_plan = Projection::try_new(project_exprs, Arc::new(plan)) - .map(LogicalPlan::Projection)?; - Ok((new_plan, true)) + Projection::try_new(project_exprs, Arc::new(plan)) + .map(|proj| (LogicalPlan::Projection(proj), true)) } }
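The rewritten `else` branch maps the constructor's `Result` directly instead of binding, unwrapping with `?`, and re-wrapping; this is the general pattern behind several of the changes in this file. A toy sketch of the before/after shape (hypothetical types, not DataFusion's):

```rust
#[derive(Debug)]
struct Projection {
    width: usize,
}

#[derive(Debug)]
enum Plan {
    Projection(Projection),
}

fn try_new(width: usize) -> Result<Projection, String> {
    (width > 0)
        .then_some(Projection { width })
        .ok_or_else(|| "empty projection".to_string())
}

// One `map` over the `Result` instead of bind-then-wrap with `?`.
fn build(width: usize) -> Result<(Plan, bool), String> {
    try_new(width).map(|p| (Plan::Projection(p), true))
}

fn main() {
    assert!(build(2).is_ok());
    assert!(build(0).is_err());
}
```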
-/// Collects and returns a vector of all indices of the fields in the schema of a logical plan. +/// Rewrite the given projection according to the fields required by its +/// ancestors. /// -/// # Arguments +/// # Parameters /// -/// * `plan` - A reference to the `LogicalPlan` for which indices are required. +/// * `proj` - A reference to the original projection to rewrite. +/// * `config` - A reference to the optimizer configuration. +/// * `indices` - A slice of indices representing the columns required by the +/// ancestors of the given projection. /// /// # Returns /// -/// A vector of `usize` indices representing all fields in the schema of the provided logical plan. -fn require_all_indices(plan: &LogicalPlan) -> Vec<usize> { - (0..plan.schema().fields().len()).collect() -} - -/// Rewrite Projection Given Required fields by its parent(s). -/// -/// # Arguments +/// A `Result` object with the following semantics: /// -/// * `proj` - A reference to the original projection to be rewritten. -/// * `_config` - A reference to the optimizer configuration (unused in the function). -/// * `indices` - A slice of indices representing the required columns by the parent(s) of projection. -/// -/// # Returns -/// -/// A `Result` containing an `Option` of the rewritten logical plan. If the -/// rewrite is successful, it returns `Some` with the optimized logical plan. -/// If the logical plan remains unchanged it returns `Ok(None)`. +/// - `Ok(Some(LogicalPlan))`: Contains the rewritten projection. +/// - `Ok(None)`: No rewrite necessary. +/// - `Err(error)`: An error occurred during the function call. fn rewrite_projection_given_requirements( proj: &Projection, - _config: &dyn OptimizerConfig, + config: &dyn OptimizerConfig, indices: &[usize], ) -> Result<Option<LogicalPlan>> { let exprs_used = get_at_indices(&proj.expr, indices); - let required_indices = indices_referred_by_exprs(&proj.input, exprs_used.iter())?; + let required_indices = + indices_referred_by_exprs(proj.input.schema(), exprs_used.iter())?; return if let Some(input) = - optimize_projections(&proj.input, _config, &required_indices)? + optimize_projections(&proj.input, config, &required_indices)? { if &projection_schema(&input, &exprs_used)? == input.schema() { Ok(Some(input)) } else { - let new_proj = Projection::try_new(exprs_used, Arc::new(input))?; - let new_proj = LogicalPlan::Projection(new_proj); - Ok(Some(new_proj)) + Projection::try_new(exprs_used, Arc::new(input)) + .map(|proj| Some(LogicalPlan::Projection(proj))) } } else if exprs_used.len() < proj.expr.len() { - // Projection expression used is different than the existing projection - // In this case, even if child doesn't change we should update projection to use less columns. + // Projection expression used is different from the existing projection. + // In this case, even if the child doesn't change, we should update the + // projection to use fewer columns: if &projection_schema(&proj.input, &exprs_used)? == proj.input.schema() { Ok(Some(proj.input.as_ref().clone())) } else { - let new_proj = Projection::try_new(exprs_used, proj.input.clone())?; - let new_proj = LogicalPlan::Projection(new_proj); - Ok(Some(new_proj)) + Projection::try_new(exprs_used, proj.input.clone()) + .map(|proj| Some(LogicalPlan::Projection(proj))) } } else { // Projection doesn't change. @@ -847,15 +970,16 @@ fn rewrite_projection_given_requirements( #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::optimize_projections::OptimizeProjections; - use datafusion_common::Result; + use crate::test::{assert_optimized_plan_eq, test_table_scan}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{Result, TableReference}; use datafusion_expr::{ - binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, LogicalPlan, - Operator, + binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder, + table_scan, Expr, LogicalPlan, Operator, }; - use std::sync::Arc; - - use crate::test::*; fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) } @@ -900,4 +1024,39 @@ mod tests { \n TableScan: test projection=[a]"; assert_optimized_plan_equal(&plan, expected) } + + #[test] + fn merge_nested_alias() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").alias("alias1").alias("alias2")])? + .project(vec![col("alias2").alias("alias")])? + .build()?; + + let expected = "Projection: test.a AS alias\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_nested_count() -> Result<()> { + let schema = Schema::new(vec![Field::new("foo", DataType::Int32, false)]); + + let groups: Vec<Expr> = vec![]; + + let plan = table_scan(TableReference::none(), &schema, None) + .unwrap() + .aggregate(groups.clone(), vec![count(lit(1))]) + .unwrap() + .aggregate(groups, vec![count(lit(1))]) + .unwrap() + .build() + .unwrap(); + + let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ + \n Projection: \ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ + \n TableScan: ?table? projection=[]"; + assert_optimized_plan_equal(&plan, expected) + } } diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 7af46ed70adf..0dc34cb809eb 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -17,6 +17,10 @@ //!
Query optimizer traits +use std::collections::HashSet; +use std::sync::Arc; +use std::time::Instant; + use crate::common_subexpr_eliminate::CommonSubexprEliminate; use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery; use crate::eliminate_cross_join::EliminateCrossJoin; @@ -41,15 +45,14 @@ use crate::simplify_expressions::SimplifyExpressions; use crate::single_distinct_to_groupby::SingleDistinctToGroupBy; use crate::unwrap_cast_in_comparison::UnwrapCastInComparison; use crate::utils::log_plan; -use chrono::{DateTime, Utc}; + use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::LogicalPlan; +use datafusion_expr::logical_plan::LogicalPlan; + +use chrono::{DateTime, Utc}; use log::{debug, warn}; -use std::collections::HashSet; -use std::sync::Arc; -use std::time::Instant; /// `OptimizerRule` transforms one [`LogicalPlan`] into another which /// computes the same results, but in a potentially more efficient @@ -447,17 +450,18 @@ pub(crate) fn assert_schema_is_the_same( #[cfg(test)] mod tests { + use std::sync::{Arc, Mutex}; + + use super::ApplyOrder; use crate::optimizer::Optimizer; use crate::test::test_table_scan; use crate::{OptimizerConfig, OptimizerContext, OptimizerRule}; + use datafusion_common::{ plan_err, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, }; use datafusion_expr::logical_plan::EmptyRelation; use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, Projection}; - use std::sync::{Arc, Mutex}; - - use super::ApplyOrder; #[test] fn skip_failing_rule() { diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index e8f116d89466..c090fb849a82 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1062,7 +1062,7 @@ mod tests { ]); let mut optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index bdd66347631c..10cc1879aeeb 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -625,7 +625,7 @@ mod tests { let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]); let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? 
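The repeated `rules.get(0)` → `rules.first()` changes above are the standard fix for clippy's `get_first` lint; the two calls are equivalent for slices and vectors:

```rust
fn main() {
    let rules = vec!["a", "b"];
    // Equivalent, but `first` states the intent directly.
    assert_eq!(rules.get(0), rules.first());

    let empty: Vec<&str> = Vec::new();
    assert_eq!(empty.first(), None);
}
```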
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 41c71c9d9aff..e2fbd5e927a1 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -3297,10 +3297,7 @@ mod tests { col("c4"), NullableInterval::from(ScalarValue::UInt32(Some(9))), ), - ( - col("c1"), - NullableInterval::from(ScalarValue::Utf8(Some("a".to_string()))), - ), + (col("c1"), NullableInterval::from(ScalarValue::from("a"))), ]; let output = simplify_with_guarantee(expr.clone(), guarantees); assert_eq!(output, lit(false)); @@ -3323,8 +3320,8 @@ mod tests { col("c1"), NullableInterval::NotNull { values: Interval::try_new( - ScalarValue::Utf8(Some("d".to_string())), - ScalarValue::Utf8(Some("f".to_string())), + ScalarValue::from("d"), + ScalarValue::from("f"), ) .unwrap(), }, diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 3cfaae858e2d..860dc326b9b0 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -406,8 +406,8 @@ mod tests { col("x"), NullableInterval::MaybeNull { values: Interval::try_new( - ScalarValue::Utf8(Some("abc".to_string())), - ScalarValue::Utf8(Some("def".to_string())), + ScalarValue::from("abc"), + ScalarValue::from("def"), ) .unwrap(), }, @@ -463,7 +463,7 @@ mod tests { ScalarValue::Int32(Some(1)), ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(None), - ScalarValue::Utf8(Some("abc".to_string())), + ScalarValue::from("abc"), ScalarValue::LargeUtf8(Some("def".to_string())), ScalarValue::Date32(Some(18628)), ScalarValue::Date32(None), diff --git a/datafusion/optimizer/src/simplify_expressions/regex.rs b/datafusion/optimizer/src/simplify_expressions/regex.rs index b9d9821b43f0..175b70f2b10e 100644 --- a/datafusion/optimizer/src/simplify_expressions/regex.rs +++ b/datafusion/optimizer/src/simplify_expressions/regex.rs @@ -84,7 +84,7 @@ impl OperatorMode { let like = Like { negated: self.not, expr, - pattern: Box::new(Expr::Literal(ScalarValue::Utf8(Some(pattern)))), + pattern: Box::new(Expr::Literal(ScalarValue::from(pattern))), escape_char: None, case_insensitive: self.i, }; diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 917ddc565c9e..e691fe9a5351 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -158,7 +158,7 @@ pub fn assert_optimized_plan_eq( let optimizer = Optimizer::with_rules(vec![rule.clone()]); let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? 
@@ -199,7 +199,7 @@ pub fn assert_optimized_plan_eq_display_indent( let optimizer = Optimizer::with_rules(vec![rule]); let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), ) @@ -233,7 +233,7 @@ pub fn assert_optimizer_err( ) { let optimizer = Optimizer::with_rules(vec![rule]); let res = optimizer.optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), ); @@ -255,7 +255,7 @@ pub fn assert_optimization_skipped( let optimizer = Optimizer::with_rules(vec![rule]); let new_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index e593b07361e2..d857c6154ea9 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -15,8 +15,11 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; +use std::collections::HashMap; +use std::sync::Arc; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; -use chrono::{DateTime, NaiveDateTime, Utc}; use datafusion_common::config::ConfigOptions; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; @@ -28,9 +31,8 @@ use datafusion_sql::sqlparser::ast::Statement; use datafusion_sql::sqlparser::dialect::GenericDialect; use datafusion_sql::sqlparser::parser::Parser; use datafusion_sql::TableReference; -use std::any::Any; -use std::collections::HashMap; -use std::sync::Arc; + +use chrono::{DateTime, NaiveDateTime, Utc}; #[cfg(test)] #[ctor::ctor] @@ -324,11 +326,10 @@ fn push_down_filter_groupby_expr_contains_alias() { fn test_same_name_but_not_ambiguous() { let sql = "SELECT t1.col_int32 AS col_int32 FROM test t1 intersect SELECT col_int32 FROM test t2"; let plan = test_sql(sql).unwrap(); - let expected = "LeftSemi Join: col_int32 = t2.col_int32\ - \n Aggregate: groupBy=[[col_int32]], aggr=[[]]\ - \n Projection: t1.col_int32 AS col_int32\ - \n SubqueryAlias: t1\ - \n TableScan: test projection=[col_int32]\ + let expected = "LeftSemi Join: t1.col_int32 = t2.col_int32\ + \n Aggregate: groupBy=[[t1.col_int32]], aggr=[[]]\ + \n SubqueryAlias: t1\ + \n TableScan: test projection=[col_int32]\ \n SubqueryAlias: t2\ \n TableScan: test projection=[col_int32]"; assert_eq!(expected, format!("{plan:?}")); diff --git a/datafusion/physical-expr/benches/in_list.rs b/datafusion/physical-expr/benches/in_list.rs index db017326083a..90bfc5efb61e 100644 --- a/datafusion/physical-expr/benches/in_list.rs +++ b/datafusion/physical-expr/benches/in_list.rs @@ -57,7 +57,7 @@ fn do_benches( .collect(); let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::Utf8(Some(random_string(&mut rng, string_length)))) + .map(|_| ScalarValue::from(random_string(&mut rng, string_length))) .collect(); do_bench( diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs index dcc8c37e7484..cf980f4c3f16 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs @@ -309,7 +309,7 @@ impl GroupsAccumulator for 
GroupsAccumulatorAdapter { // double check each array has the same length (aka the // accumulator was implemented correctly - if let Some(first_col) = arrays.get(0) { + if let Some(first_col) = arrays.first() { for arr in &arrays { assert_eq!(arr.len(), first_col.len()) } } diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index f5b708e8894e..7e3ef2a2abab 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -1297,12 +1297,7 @@ mod tests { #[test] fn max_utf8() -> Result<()> { let a: ArrayRef = Arc::new(StringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!( - a, - DataType::Utf8, - Max, - ScalarValue::Utf8(Some("d".to_string())) - ) + generic_test_op!(a, DataType::Utf8, Max, ScalarValue::from("d")) } #[test] @@ -1319,12 +1314,7 @@ mod tests { #[test] fn min_utf8() -> Result<()> { let a: ArrayRef = Arc::new(StringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!( - a, - DataType::Utf8, - Min, - ScalarValue::Utf8(Some("a".to_string())) - ) + generic_test_op!(a, DataType::Utf8, Min, ScalarValue::from("a")) } #[test] diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs index 74c083959ed8..7adc736932ad 100644 --- a/datafusion/physical-expr/src/aggregate/string_agg.rs +++ b/datafusion/physical-expr/src/aggregate/string_agg.rs @@ -204,7 +204,7 @@ mod tests { ) .unwrap(); - let delimiter = Arc::new(Literal::new(ScalarValue::Utf8(Some(delimiter)))); + let delimiter = Arc::new(Literal::new(ScalarValue::from(delimiter))); let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), true)]); let agg = create_aggregate_expr( &function, diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs index dc12bdf46acd..f43434362a19 100644 --- a/datafusion/physical-expr/src/analysis.rs +++ b/datafusion/physical-expr/src/analysis.rs @@ -72,8 +72,12 @@ impl AnalysisContext { } } -/// Represents the boundaries of the resulting value from a physical expression, -/// if it were to be an expression, if it were to be evaluated. +/// Represents the boundaries (e.g. min and max values) of a particular column. +/// +/// This is used in range analysis of expressions, to determine if the expression +/// limits the value of particular columns (e.g. analyzing an expression such as +/// `time < 50` would result in a boundary interval for `time` having a max +/// value of `50`). #[derive(Clone, Debug, PartialEq)] pub struct ExprBoundaries { pub column: Column, @@ -111,6 +115,23 @@ impl ExprBoundaries { distinct_count: col_stats.distinct_count.clone(), }) } + + /// Create `ExprBoundaries` that represent no known bounds for all the + /// columns in `schema`. + pub fn try_new_unbounded(schema: &Schema) -> Result<Vec<Self>> { + schema + .fields() + .iter() + .enumerate() + .map(|(i, field)| { + Ok(Self { + column: Column::new(field.name(), i), + interval: Interval::make_unbounded(field.data_type())?, + distinct_count: Precision::Absent, + }) + }) + .collect() + } } /// Attempts to refine column boundaries and compute a selectivity value.
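`try_new_unbounded` gives range analysis a starting point when no statistics are known: one unconstrained interval per schema field. A standalone sketch of the same enumerate-and-build shape (toy types, not the real `Interval`):

```rust
#[derive(Debug)]
struct Bounds {
    column_index: usize,
    min: Option<f64>, // None = unbounded
    max: Option<f64>,
}

/// One fully-unbounded `Bounds` per field, analogous to
/// `ExprBoundaries::try_new_unbounded`.
fn unbounded(field_count: usize) -> Vec<Bounds> {
    (0..field_count)
        .map(|i| Bounds { column_index: i, min: None, max: None })
        .collect()
}

fn main() {
    let b = unbounded(3);
    assert_eq!(b.len(), 3);
    assert!(b.iter().all(|x| x.min.is_none() && x.max.is_none()));
}
```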
diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index e6543808b97a..c2dc88b10773 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -28,14 +28,14 @@ use arrow::datatypes::{DataType, Field, UInt64Type}; use arrow::row::{RowConverter, SortField}; use arrow_buffer::NullBuffer; -use arrow_schema::FieldRef; +use arrow_schema::{FieldRef, SortOptions}; use datafusion_common::cast::{ - as_generic_string_array, as_int64_array, as_list_array, as_string_array, + as_generic_list_array, as_generic_string_array, as_int64_array, as_large_list_array, + as_list_array, as_null_array, as_string_array, }; -use datafusion_common::utils::array_into_list_array; +use datafusion_common::utils::{array_into_list_array, list_ndims}; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, - DataFusionError, Result, + exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result, }; use itertools::Itertools; @@ -102,6 +102,7 @@ fn compare_element_to_list( ) -> Result { let indices = UInt32Array::from(vec![row_index as u32]); let element_array_row = arrow::compute::take(element_array, &indices, None)?; + // Compute all positions in list_row_array (that is itself an // array) that are equal to `from_array_row` let res = match element_array_row.data_type() { @@ -170,36 +171,11 @@ fn compute_array_length( value = downcast_arg!(value, ListArray).value(0); current_dimension += 1; } - _ => return Ok(None), - } - } -} - -/// Returns the dimension of the array -fn compute_array_ndims(arr: Option) -> Result> { - Ok(compute_array_ndims_with_datatype(arr)?.0) -} - -/// Returns the dimension and the datatype of elements of the array -fn compute_array_ndims_with_datatype( - arr: Option, -) -> Result<(Option, DataType)> { - let mut res: u64 = 1; - let mut value = match arr { - Some(arr) => arr, - None => return Ok((None, DataType::Null)), - }; - if value.is_empty() { - return Ok((None, DataType::Null)); - } - - loop { - match value.data_type() { - DataType::List(..) => { - value = downcast_arg!(value, ListArray).value(0); - res += 1; + DataType::LargeList(..) => { + value = downcast_arg!(value, LargeListArray).value(0); + current_dimension += 1; } - data_type => return Ok((Some(res), data_type.clone())), + _ => return Ok(None), } } } @@ -280,7 +256,7 @@ macro_rules! call_array_function { } /// Convert one or more [`ArrayRef`] of the same type into a -/// `ListArray` +/// `ListArray` or 'LargeListArray' depending on the offset size. /// /// # Example (non nested) /// @@ -319,7 +295,10 @@ macro_rules! call_array_function { /// └──────────────┘ └──────────────┘ └─────────────────────────────┘ /// col1 col2 output /// ``` -fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { +fn array_array( + args: &[ArrayRef], + data_type: DataType, +) -> Result { // do not accept 0 arguments. 
if args.is_empty() { return plan_err!("Array requires at least one argument"); @@ -336,8 +315,9 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { total_len += arg_data.len(); data.push(arg_data); } - let mut offsets = Vec::with_capacity(total_len); - offsets.push(0); + + let mut offsets: Vec = Vec::with_capacity(total_len); + offsets.push(O::usize_as(0)); let capacity = Capacities::Array(total_len); let data_ref = data.iter().collect::>(); @@ -355,11 +335,11 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { mutable.extend_nulls(1); } } - offsets.push(mutable.len() as i32); + offsets.push(O::usize_as(mutable.len())); } - let data = mutable.freeze(); - Ok(Arc::new(ListArray::try_new( + + Ok(Arc::new(GenericListArray::::try_new( Arc::new(Field::new("item", data_type, true)), OffsetBuffer::new(offsets.into()), arrow_array::make_array(data), @@ -384,143 +364,69 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result { let array = new_null_array(&DataType::Null, arrays.len()); Ok(Arc::new(array_into_list_array(array))) } - data_type => array_array(arrays, data_type), + DataType::LargeList(..) => array_array::(arrays, data_type), + _ => array_array::(arrays, data_type), } } -fn return_empty(return_null: bool, data_type: DataType) -> Arc { - if return_null { - new_null_array(&data_type, 1) - } else { - new_empty_array(&data_type) - } -} +/// array_element SQL function +/// +/// There are two arguments for array_element, the first one is the array, the second one is the 1-indexed index. +/// `array_element(array, index)` +/// +/// For example: +/// > array_element(\[1, 2, 3], 2) -> 2 +pub fn array_element(args: &[ArrayRef]) -> Result { + let list_array = as_list_array(&args[0])?; + let indexes = as_int64_array(&args[1])?; -macro_rules! list_slice { - ($ARRAY:expr, $I:expr, $J:expr, $RETURN_ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); - if $I == 0 && $J == 0 || $ARRAY.is_empty() { - return return_empty($RETURN_ELEMENT, $ARRAY.data_type().clone()); - } + let values = list_array.values(); + let original_data = values.to_data(); + let capacity = Capacities::Array(original_data.len()); - let i = if $I < 0 { - if $I.abs() as usize > array.len() { - return return_empty(true, $ARRAY.data_type().clone()); - } + // use_nulls: true, we don't construct List for array_element, so we need explicit nulls. 
+ let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], true, capacity); - (array.len() as i64 + $I + 1) as usize + fn adjusted_array_index(index: i64, len: usize) -> Option { + // 0 ~ len - 1 + let adjusted_zero_index = if index < 0 { + index + len as i64 } else { - if $I == 0 { - 1 - } else { - $I as usize - } + index - 1 }; - let j = if $J < 0 { - if $J.abs() as usize > array.len() { - return return_empty(true, $ARRAY.data_type().clone()); - } - if $RETURN_ELEMENT { - (array.len() as i64 + $J + 1) as usize - } else { - (array.len() as i64 + $J) as usize - } + if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 { + Some(adjusted_zero_index) } else { - if $J == 0 { - 1 - } else { - if $J as usize > array.len() { - array.len() - } else { - $J as usize - } - } - }; + // Out of bounds + None + } + } - if i > j || i as usize > $ARRAY.len() { - return_empty($RETURN_ELEMENT, $ARRAY.data_type().clone()) - } else { - Arc::new(array.slice((i - 1), (j + 1 - i))) + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + let start = offset_window[0] as usize; + let end = offset_window[1] as usize; + let len = end - start; + + // array is null + if len == 0 { + mutable.extend_nulls(1); + continue; } - }}; -} -macro_rules! slice { - ($ARRAY:expr, $KEY:expr, $EXTRA_KEY:expr, $RETURN_ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let sliced_array: Vec> = $ARRAY - .iter() - .zip($KEY.iter()) - .zip($EXTRA_KEY.iter()) - .map(|((arr, i), j)| match (arr, i, j) { - (Some(arr), Some(i), Some(j)) => { - list_slice!(arr, i, j, $RETURN_ELEMENT, $ARRAY_TYPE) - } - (Some(arr), None, Some(j)) => { - list_slice!(arr, 1i64, j, $RETURN_ELEMENT, $ARRAY_TYPE) - } - (Some(arr), Some(i), None) => { - list_slice!(arr, i, arr.len() as i64, $RETURN_ELEMENT, $ARRAY_TYPE) - } - (Some(arr), None, None) if !$RETURN_ELEMENT => arr, - _ => return_empty($RETURN_ELEMENT, $ARRAY.value_type().clone()), - }) - .collect(); + let index = adjusted_array_index(indexes.value(row_index), len); - // concat requires input of at least one array - if sliced_array.is_empty() { - Ok(return_empty($RETURN_ELEMENT, $ARRAY.value_type())) + if let Some(index) = index { + mutable.extend(0, start + index as usize, start + index as usize + 1); } else { - let vec = sliced_array - .iter() - .map(|a| a.as_ref()) - .collect::>(); - let mut i: i32 = 0; - let mut offsets = vec![i]; - offsets.extend( - vec.iter() - .map(|a| { - i += a.len() as i32; - i - }) - .collect::>(), - ); - let values = compute::concat(vec.as_slice()).unwrap(); - - if $RETURN_ELEMENT { - Ok(values) - } else { - let field = - Arc::new(Field::new("item", $ARRAY.value_type().clone(), true)); - Ok(Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - values, - None, - )?)) - } + // Index out of bounds + mutable.extend_nulls(1); } - }}; -} - -fn define_array_slice( - list_array: &ListArray, - key: &Int64Array, - extra_key: &Int64Array, - return_element: bool, -) -> Result { - macro_rules! 
array_function { - ($ARRAY_TYPE:ident) => { - slice!(list_array, key, extra_key, return_element, $ARRAY_TYPE) - }; } - call_array_function!(list_array.value_type(), true) -} -pub fn array_element(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let key = as_int64_array(&args[1])?; - define_array_slice(list_array, key, key, true) + let data = mutable.freeze(); + Ok(arrow_array::make_array(data)) } fn general_except( @@ -559,7 +465,7 @@ fn general_except( dedup.clear(); } - if let Some(values) = converter.convert_rows(rows)?.get(0) { + if let Some(values) = converter.convert_rows(rows)?.first() { Ok(GenericListArray::::new( field.to_owned(), OffsetBuffer::new(offsets.into()), @@ -601,47 +507,136 @@ pub fn array_except(args: &[ArrayRef]) -> Result { } } +/// array_slice SQL function +/// +/// We follow the behavior of array_slice in DuckDB +/// Note that array_slice is 1-indexed. And there are two additional arguments `from` and `to` in array_slice. +/// +/// > array_slice(array, from, to) +/// +/// Positive index is treated as the index from the start of the array. If the +/// `from` index is smaller than 1, it is treated as 1. If the `to` index is larger than the +/// length of the array, it is treated as the length of the array. +/// +/// Negative index is treated as the index from the end of the array. If the index +/// is larger than the length of the array, it is NOT VALID, either in `from` or `to`. +/// The `to` index is exclusive like python slice syntax. +/// +/// See test cases in `array.slt` for more details. pub fn array_slice(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?; - let key = as_int64_array(&args[1])?; - let extra_key = as_int64_array(&args[2])?; - define_array_slice(list_array, key, extra_key, false) -} - -fn general_array_pop( - list_array: &GenericListArray, - from_back: bool, -) -> Result<(Vec, Vec)> { - if from_back { - let key = vec![0; list_array.len()]; - // Atttetion: `arr.len() - 1` in extra key defines the last element position (position = index + 1, not inclusive) we want in the new array. - let extra_key: Vec<_> = list_array - .iter() - .map(|x| x.map_or(0, |arr| arr.len() as i64 - 1)) - .collect(); - Ok((key, extra_key)) - } else { - // Atttetion: 2 in the `key`` defines the first element position (position = index + 1) we want in the new array. - // We only handle two cases of the first element index: if the old array has any elements, starts from 2 (index + 1), or starts from initial. - let key: Vec<_> = list_array.iter().map(|x| x.map_or(0, |_| 2)).collect(); - let extra_key: Vec<_> = list_array - .iter() - .map(|x| x.map_or(0, |arr| arr.len() as i64)) - .collect(); - Ok((key, extra_key)) + let from_array = as_int64_array(&args[1])?; + let to_array = as_int64_array(&args[2])?; + + let values = list_array.values(); + let original_data = values.to_data(); + let capacity = Capacities::Array(original_data.len()); + + // use_nulls: false, we don't need nulls but empty array for array_slice, so we don't need explicit nulls but adjust offset to indicate nulls. + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], false, capacity); + + // We have the slice syntax compatible with DuckDB v0.8.1. + // The rule `adjusted_from_index` and `adjusted_to_index` follows the rule of array_slice in duckdb. 
+ + fn adjusted_from_index(index: i64, len: usize) -> Option<i64> { + // 0 ~ len - 1 + let adjusted_zero_index = if index < 0 { + index + len as i64 + } else { + // array_slice(arr, 1, to) is the same as array_slice(arr, 0, to) + std::cmp::max(index - 1, 0) + }; + + if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 { + Some(adjusted_zero_index) + } else { + // Out of bounds + None + } } + + fn adjusted_to_index(index: i64, len: usize) -> Option<i64> { + // 0 ~ len - 1 + let adjusted_zero_index = if index < 0 { + // array_slice in duckdb with negative to_index is python-like, so index itself is exclusive + index + len as i64 - 1 + } else { + // array_slice(arr, from, len + 1) is the same as array_slice(arr, from, len) + std::cmp::min(index - 1, len as i64 - 1) + }; + + if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 { + Some(adjusted_zero_index) + } else { + // Out of bounds + None + } + } + + let mut offsets = vec![0]; + + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + let start = offset_window[0] as usize; + let end = offset_window[1] as usize; + let len = end - start; + + // len 0 indicates the array is null, so return an empty array in this row. + if len == 0 { + offsets.push(offsets[row_index]); + continue; + } + + // If index is null, we consider it as the minimum / maximum index of the array. + let from_index = if from_array.is_null(row_index) { + Some(0) + } else { + adjusted_from_index(from_array.value(row_index), len) + }; + + let to_index = if to_array.is_null(row_index) { + Some(len as i64 - 1) + } else { + adjusted_to_index(to_array.value(row_index), len) + }; + + if let (Some(from), Some(to)) = (from_index, to_index) { + if from <= to { + assert!(start + to as usize <= end); + mutable.extend(0, start + from as usize, start + to as usize + 1); + offsets.push(offsets[row_index] + (to - from + 1) as i32); + } else { + // invalid range, return empty array + offsets.push(offsets[row_index]); + } + } else { + // invalid range, return empty array + offsets.push(offsets[row_index]); + } + } + + let data = mutable.freeze(); + + Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new("item", list_array.value_type(), true)), + OffsetBuffer::new(offsets.into()), + arrow_array::make_array(data), + None, + )?)) } +/// array_pop_back SQL function pub fn array_pop_back(args: &[ArrayRef]) -> Result<ArrayRef> { let list_array = as_list_array(&args[0])?; - let (key, extra_key) = general_array_pop(list_array, true)?; - - define_array_slice( - list_array, - &Int64Array::from(key), - &Int64Array::from(extra_key), - false, - ) + let from_array = Int64Array::from(vec![1; list_array.len()]); + let to_array = Int64Array::from( + list_array + .iter() + .map(|arr| arr.map_or(0, |arr| arr.len() as i64 - 1)) + .collect::<Vec<i64>>(), + ); + let args = vec![args[0].clone(), Arc::new(from_array), Arc::new(to_array)]; + array_slice(args.as_slice()) } /// Appends or prepends elements to a ListArray. @@ -715,7 +710,7 @@ fn general_append_and_prepend( /// # Arguments /// /// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero.) values.
-/// +/// /// # Examples /// /// gen_range(3) => [0, 1, 2] @@ -765,16 +760,18 @@ pub fn gen_range(args: &[ArrayRef]) -> Result { Ok(arr) } +/// array_pop_front SQL function pub fn array_pop_front(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?; - let (key, extra_key) = general_array_pop(list_array, false)?; - - define_array_slice( - list_array, - &Int64Array::from(key), - &Int64Array::from(extra_key), - false, - ) + let from_array = Int64Array::from(vec![2; list_array.len()]); + let to_array = Int64Array::from( + list_array + .iter() + .map(|arr| arr.map_or(0, |arr| arr.len() as i64)) + .collect::>(), + ); + let args = vec![args[0].clone(), Arc::new(from_array), Arc::new(to_array)]; + array_slice(args.as_slice()) } /// Array_append SQL function @@ -799,6 +796,85 @@ pub fn array_append(args: &[ArrayRef]) -> Result { Ok(res) } +/// Array_sort SQL function +pub fn array_sort(args: &[ArrayRef]) -> Result { + let sort_option = match args.len() { + 1 => None, + 2 => { + let sort = as_string_array(&args[1])?.value(0); + Some(SortOptions { + descending: order_desc(sort)?, + nulls_first: true, + }) + } + 3 => { + let sort = as_string_array(&args[1])?.value(0); + let nulls_first = as_string_array(&args[2])?.value(0); + Some(SortOptions { + descending: order_desc(sort)?, + nulls_first: order_nulls_first(nulls_first)?, + }) + } + _ => return internal_err!("array_sort expects 1 to 3 arguments"), + }; + + let list_array = as_list_array(&args[0])?; + let row_count = list_array.len(); + + let mut array_lengths = vec![]; + let mut arrays = vec![]; + let mut valid = BooleanBufferBuilder::new(row_count); + for i in 0..row_count { + if list_array.is_null(i) { + array_lengths.push(0); + valid.append(false); + } else { + let arr_ref = list_array.value(i); + let arr_ref = arr_ref.as_ref(); + + let sorted_array = compute::sort(arr_ref, sort_option)?; + array_lengths.push(sorted_array.len()); + arrays.push(sorted_array); + valid.append(true); + } + } + + // Assume all arrays have the same data type + let data_type = list_array.value_type(); + let buffer = valid.finish(); + + let elements = arrays + .iter() + .map(|a| a.as_ref()) + .collect::>(); + + let list_arr = ListArray::new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::from_lengths(array_lengths), + Arc::new(compute::concat(elements.as_slice())?), + Some(NullBuffer::new(buffer)), + ); + Ok(Arc::new(list_arr)) +} + +fn order_desc(modifier: &str) -> Result { + match modifier.to_uppercase().as_str() { + "DESC" => Ok(true), + "ASC" => Ok(false), + _ => internal_err!("the second parameter of array_sort expects DESC or ASC"), + } +} + +fn order_nulls_first(modifier: &str) -> Result { + match modifier.to_uppercase().as_str() { + "NULLS FIRST" => Ok(true), + "NULLS LAST" => Ok(false), + _ => internal_err!( + "the third parameter of array_sort expects NULLS FIRST or NULLS LAST" + ), + } +} + /// Array_prepend SQL function pub fn array_prepend(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[1])?; @@ -824,10 +900,7 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result { fn align_array_dimensions(args: Vec) -> Result> { let args_ndim = args .iter() - .map(|arg| compute_array_ndims(Some(arg.to_owned()))) - .collect::>>()? 
- .into_iter() - .map(|x| x.unwrap_or(0)) + .map(|arg| datafusion_common::utils::list_ndims(arg.data_type())) .collect::>(); let max_ndim = args_ndim.iter().max().unwrap_or(&0); @@ -918,6 +991,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result { Arc::new(compute::concat(elements.as_slice())?), Some(NullBuffer::new(buffer)), ); + Ok(Arc::new(list_arr)) } @@ -925,11 +999,11 @@ fn concat_internal(args: &[ArrayRef]) -> Result { pub fn array_concat(args: &[ArrayRef]) -> Result { let mut new_args = vec![]; for arg in args { - let (ndim, lower_data_type) = - compute_array_ndims_with_datatype(Some(arg.clone()))?; - if ndim.is_none() || ndim == Some(1) { - return not_impl_err!("Array is not type '{lower_data_type:?}'."); - } else if !lower_data_type.equals_datatype(&DataType::Null) { + let ndim = list_ndims(arg.data_type()); + let base_type = datafusion_common::utils::base_type(arg.data_type()); + if ndim == 0 { + return not_impl_err!("Array is not type '{base_type:?}'."); + } else if !base_type.eq(&DataType::Null) { new_args.push(arg.clone()); } } @@ -939,12 +1013,21 @@ pub fn array_concat(args: &[ArrayRef]) -> Result { /// Array_empty SQL function pub fn array_empty(args: &[ArrayRef]) -> Result { - if args[0].as_any().downcast_ref::().is_some() { + if as_null_array(&args[0]).is_ok() { // Make sure to return Boolean type. return Ok(Arc::new(BooleanArray::new_null(args[0].len()))); } + let array_type = args[0].data_type(); + + match array_type { + DataType::List(_) => array_empty_dispatch::(&args[0]), + DataType::LargeList(_) => array_empty_dispatch::(&args[0]), + _ => internal_err!("array_empty does not support type '{array_type:?}'."), + } +} - let array = as_list_array(&args[0])?; +fn array_empty_dispatch(array: &ArrayRef) -> Result { + let array = as_generic_list_array::(array)?; let builder = array .iter() .map(|arr| arr.map(|arr| arr.len() == arr.null_count())) @@ -1340,84 +1423,76 @@ fn general_replace( ) -> Result { // Build up the offsets for the final output array let mut offsets: Vec = vec![0]; - let data_type = list_array.value_type(); - let mut new_values = vec![]; + let values = list_array.values(); + let original_data = values.to_data(); + let to_data = to_array.to_data(); + let capacity = Capacities::Array(original_data.len()); - // n is the number of elements to replace in this row - for (row_index, (list_array_row, n)) in - list_array.iter().zip(arr_n.iter()).enumerate() - { - let last_offset: i32 = offsets - .last() - .copied() - .ok_or_else(|| internal_datafusion_err!("offsets should not be empty"))?; + // First array is the original array, second array is the element to replace with. 
+ let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &to_data], + false, + capacity, + ); - match list_array_row { - Some(list_array_row) => { - // Compute all positions in list_row_array (that is itself an - // array) that are equal to `from_array_row` - let eq_array = compare_element_to_list( - &list_array_row, - &from_array, - row_index, - true, - )?; + let mut valid = BooleanBufferBuilder::new(list_array.len()); - // Use MutableArrayData to build the replaced array - let original_data = list_array_row.to_data(); - let to_data = to_array.to_data(); - let capacity = Capacities::Array(original_data.len() + to_data.len()); + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + if list_array.is_null(row_index) { + offsets.push(offsets[row_index]); + valid.append(false); + continue; + } - // First array is the original array, second array is the element to replace with. - let mut mutable = MutableArrayData::with_capacities( - vec![&original_data, &to_data], - false, - capacity, - ); - let original_idx = 0; - let replace_idx = 1; - - let mut counter = 0; - for (i, to_replace) in eq_array.iter().enumerate() { - if let Some(true) = to_replace { - mutable.extend(replace_idx, row_index, row_index + 1); - counter += 1; - if counter == *n { - // copy original data for any matches past n - mutable.extend(original_idx, i + 1, eq_array.len()); - break; - } - } else { - // copy original data for false / null matches - mutable.extend(original_idx, i, i + 1); - } - } + let start = offset_window[0] as usize; + let end = offset_window[1] as usize; - let data = mutable.freeze(); - let replaced_array = arrow_array::make_array(data); + let list_array_row = list_array.value(row_index); - offsets.push(last_offset + replaced_array.len() as i32); - new_values.push(replaced_array); - } - None => { - // Null element results in a null row (no new offsets) - offsets.push(last_offset); + // Compute all positions in list_row_array (that is itself an + // array) that are equal to `from_array_row` + let eq_array = + compare_element_to_list(&list_array_row, &from_array, row_index, true)?; + + let original_idx = 0; + let replace_idx = 1; + let n = arr_n[row_index]; + let mut counter = 0; + + // All elements are false, no need to replace, just copy original data + if eq_array.false_count() == eq_array.len() { + mutable.extend(original_idx, start, end); + offsets.push(offsets[row_index] + (end - start) as i32); + valid.append(true); + continue; + } + + for (i, to_replace) in eq_array.iter().enumerate() { + if let Some(true) = to_replace { + mutable.extend(replace_idx, row_index, row_index + 1); + counter += 1; + if counter == n { + // copy original data for any matches past n + mutable.extend(original_idx, start + i + 1, end); + break; + } + } else { + // copy original data for false / null matches + mutable.extend(original_idx, start + i, start + i + 1); } } + + offsets.push(offsets[row_index] + (end - start) as i32); + valid.append(true); } - let values = if new_values.is_empty() { - new_empty_array(&data_type) - } else { - let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); - arrow::compute::concat(&new_values)? 
- }; + let data = mutable.freeze(); Ok(Arc::new(ListArray::try_new( - Arc::new(Field::new("item", data_type, true)), + Arc::new(Field::new("item", list_array.value_type(), true)), OffsetBuffer::new(offsets.into()), - values, - list_array.nulls().cloned(), + arrow_array::make_array(data), + Some(NullBuffer::new(valid.finish())), )?)) } @@ -1434,7 +1509,7 @@ pub fn array_replace_n(args: &[ArrayRef]) -> Result { } pub fn array_replace_all(args: &[ArrayRef]) -> Result { - // replace all occurences (up to "i64::MAX") + // replace all occurrences (up to "i64::MAX") let arr_n = vec![i64::MAX; args[0].len()]; general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) } @@ -1515,32 +1590,33 @@ pub fn array_union(args: &[ArrayRef]) -> Result { } let array1 = &args[0]; let array2 = &args[1]; + + fn union_arrays( + array1: &ArrayRef, + array2: &ArrayRef, + l_field_ref: &Arc, + r_field_ref: &Arc, + ) -> Result { + match (l_field_ref.data_type(), r_field_ref.data_type()) { + (DataType::Null, _) => Ok(array2.clone()), + (_, DataType::Null) => Ok(array1.clone()), + (_, _) => { + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = union_generic_lists::(list1, list2, l_field_ref)?; + Ok(Arc::new(result)) + } + } + } + match (array1.data_type(), array2.data_type()) { (DataType::Null, _) => Ok(array2.clone()), (_, DataType::Null) => Ok(array1.clone()), (DataType::List(l_field_ref), DataType::List(r_field_ref)) => { - match (l_field_ref.data_type(), r_field_ref.data_type()) { - (DataType::Null, _) => Ok(array2.clone()), - (_, DataType::Null) => Ok(array1.clone()), - (_, _) => { - let list1 = array1.as_list::(); - let list2 = array2.as_list::(); - let result = union_generic_lists::(list1, list2, l_field_ref)?; - Ok(Arc::new(result)) - } - } + union_arrays::(array1, array2, l_field_ref, r_field_ref) } (DataType::LargeList(l_field_ref), DataType::LargeList(r_field_ref)) => { - match (l_field_ref.data_type(), r_field_ref.data_type()) { - (DataType::Null, _) => Ok(array2.clone()), - (_, DataType::Null) => Ok(array1.clone()), - (_, _) => { - let list1 = array1.as_list::(); - let list2 = array2.as_list::(); - let result = union_generic_lists::(list1, list2, l_field_ref)?; - Ok(Arc::new(result)) - } - } + union_arrays::(array1, array2, l_field_ref, r_field_ref) } _ => { internal_err!( @@ -1721,11 +1797,11 @@ pub fn flatten(args: &[ArrayRef]) -> Result { Ok(Arc::new(flattened_array) as ArrayRef) } -/// Array_length SQL function -pub fn array_length(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let dimension = if args.len() == 2 { - as_int64_array(&args[1])?.clone() +/// Dispatch array length computation based on the offset type. 
+fn array_length_dispatch(array: &[ArrayRef]) -> Result { + let list_array = as_generic_list_array::(&array[0])?; + let dimension = if array.len() == 2 { + as_int64_array(&array[1])?.clone() } else { Int64Array::from_value(1, list_array.len()) }; @@ -1739,6 +1815,18 @@ pub fn array_length(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } +/// Array_length SQL function +pub fn array_length(args: &[ArrayRef]) -> Result { + match &args[0].data_type() { + DataType::List(_) => array_length_dispatch::(args), + DataType::LargeList(_) => array_length_dispatch::(args), + _ => internal_err!( + "array_length does not support type '{:?}'", + args[0].data_type() + ), + } +} + /// Array_dims SQL function pub fn array_dims(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?; @@ -1754,92 +1842,137 @@ pub fn array_dims(args: &[ArrayRef]) -> Result { /// Array_ndims SQL function pub fn array_ndims(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; + if let Some(list_array) = args[0].as_list_opt::() { + let ndims = datafusion_common::utils::list_ndims(list_array.data_type()); - let result = list_array - .iter() - .map(compute_array_ndims) - .collect::>()?; + let mut data = vec![]; + for arr in list_array.iter() { + if arr.is_some() { + data.push(Some(ndims)) + } else { + data.push(None) + } + } - Ok(Arc::new(result) as ArrayRef) + Ok(Arc::new(UInt64Array::from(data)) as ArrayRef) + } else { + Ok(Arc::new(UInt64Array::from(vec![0; args[0].len()])) as ArrayRef) + } } -/// Array_has SQL function -pub fn array_has(args: &[ArrayRef]) -> Result { - let array = as_list_array(&args[0])?; - let element = &args[1]; +/// Represents the type of comparison for array_has. +#[derive(Debug, PartialEq)] +enum ComparisonType { + // array_has_all + All, + // array_has_any + Any, + // array_has + Single, +} + +fn general_array_has_dispatch( + array: &ArrayRef, + sub_array: &ArrayRef, + comparison_type: ComparisonType, +) -> Result { + let array = if comparison_type == ComparisonType::Single { + let arr = as_generic_list_array::(array)?; + check_datatypes("array_has", &[arr.values(), sub_array])?; + arr + } else { + check_datatypes("array_has", &[array, sub_array])?; + as_generic_list_array::(array)? + }; - check_datatypes("array_has", &[array.values(), element])?; let mut boolean_builder = BooleanArray::builder(array.len()); let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; - let r_values = converter.convert_columns(&[element.clone()])?; - for (row_idx, arr) in array.iter().enumerate() { - if let Some(arr) = arr { + + let element = sub_array.clone(); + let sub_array = if comparison_type != ComparisonType::Single { + as_generic_list_array::(sub_array)? + } else { + array + }; + + for (row_idx, (arr, sub_arr)) in array.iter().zip(sub_array.iter()).enumerate() { + if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { let arr_values = converter.convert_columns(&[arr])?; - let res = arr_values - .iter() - .dedup() - .any(|x| x == r_values.row(row_idx)); + let sub_arr_values = if comparison_type != ComparisonType::Single { + converter.convert_columns(&[sub_arr])? + } else { + converter.convert_columns(&[element.clone()])? 
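+ // Single (`array_has`): the needle is a plain (non-list) column, so the + // whole converted needle column is kept and `row_idx` picks the matching + // entry below; for All/Any the needle is itself a list, converted per row above.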
+ }; + + let res = match comparison_type { + ComparisonType::All => sub_arr_values + .iter() + .dedup() + .all(|elem| arr_values.iter().dedup().any(|x| x == elem)), + ComparisonType::Any => sub_arr_values + .iter() + .dedup() + .any(|elem| arr_values.iter().dedup().any(|x| x == elem)), + ComparisonType::Single => arr_values + .iter() + .dedup() + .any(|x| x == sub_arr_values.row(row_idx)), + }; + + boolean_builder.append_value(res); } } Ok(Arc::new(boolean_builder.finish())) } -/// Array_has_any SQL function -pub fn array_has_any(args: &[ArrayRef]) -> Result<ArrayRef> { - check_datatypes("array_has_any", &[&args[0], &args[1]])?; +/// Array_has SQL function +pub fn array_has(args: &[ArrayRef]) -> Result<ArrayRef> { + let array_type = args[0].data_type(); - let array = as_list_array(&args[0])?; - let sub_array = as_list_array(&args[1])?; - let mut boolean_builder = BooleanArray::builder(array.len()); + match array_type { + DataType::List(_) => { + general_array_has_dispatch::<i32>(&args[0], &args[1], ComparisonType::Single) + } + DataType::LargeList(_) => { + general_array_has_dispatch::<i64>(&args[0], &args[1], ComparisonType::Single) + } + _ => internal_err!("array_has does not support type '{array_type:?}'."), + } +} - let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; - for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { - if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let arr_values = converter.convert_columns(&[arr])?; - let sub_arr_values = converter.convert_columns(&[sub_arr])?; +/// Array_has_any SQL function +pub fn array_has_any(args: &[ArrayRef]) -> Result<ArrayRef> { + let array_type = args[0].data_type(); - let mut res = false; - for elem in sub_arr_values.iter().dedup() { - res |= arr_values.iter().dedup().any(|x| x == elem); - if res { - break; - } - } - boolean_builder.append_value(res); + match array_type { + DataType::List(_) => { + general_array_has_dispatch::<i32>(&args[0], &args[1], ComparisonType::Any) } + DataType::LargeList(_) => { + general_array_has_dispatch::<i64>(&args[0], &args[1], ComparisonType::Any) + } + _ => internal_err!("array_has_any does not support type '{array_type:?}'."), } - Ok(Arc::new(boolean_builder.finish())) } /// Array_has_all SQL function pub fn array_has_all(args: &[ArrayRef]) -> Result<ArrayRef> { - check_datatypes("array_has_all", &[&args[0], &args[1]])?; + let array_type = args[0].data_type(); - let array = as_list_array(&args[0])?; - let sub_array = as_list_array(&args[1])?; - - let mut boolean_builder = BooleanArray::builder(array.len()); - - let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; - for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { - if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let arr_values = converter.convert_columns(&[arr])?; - let sub_arr_values = converter.convert_columns(&[sub_arr])?; - - let mut res = true; - for elem in sub_arr_values.iter().dedup() { - res &= arr_values.iter().dedup().any(|x| x == elem); - if !res { - break; - } - } - boolean_builder.append_value(res); + match array_type { + DataType::List(_) => { + general_array_has_dispatch::<i32>(&args[0], &args[1], ComparisonType::All) } + DataType::LargeList(_) => { + general_array_has_dispatch::<i64>(&args[0], &args[1], ComparisonType::All) + } + _ => internal_err!("array_has_all does not support type '{array_type:?}'."), } - Ok(Arc::new(boolean_builder.finish())) } /// Splits string at occurrences of delimiter and returns an array of parts @@ -1974,7 +2107,7 @@ pub fn 
array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> { }; offsets.push(last_offset + rows.len() as i32); let arrays = converter.convert_rows(rows)?; - let array = match arrays.get(0) { + let array = match arrays.first() { Some(array) => array.clone(), None => { return internal_err!( @@ -1997,6 +2130,66 @@ pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> { } } +pub fn general_array_distinct<OffsetSize: OffsetSizeTrait>( + array: &GenericListArray<OffsetSize>, + field: &FieldRef, +) -> Result<ArrayRef> { + let dt = array.value_type(); + let mut offsets = Vec::with_capacity(array.len()); + offsets.push(OffsetSize::usize_as(0)); + let mut new_arrays = Vec::with_capacity(array.len()); + let converter = RowConverter::new(vec![SortField::new(dt.clone())])?; + // distinct for each list in ListArray + for arr in array.iter().flatten() { + let values = converter.convert_columns(&[arr])?; + // sort elements in list and remove duplicates + let rows = values.iter().sorted().dedup().collect::<Vec<_>>(); + let last_offset: OffsetSize = offsets.last().copied().unwrap(); + offsets.push(last_offset + OffsetSize::usize_as(rows.len())); + let arrays = converter.convert_rows(rows)?; + let array = match arrays.first() { + Some(array) => array.clone(), + None => { + return internal_err!("array_distinct: failed to get array from rows") + } + }; + new_arrays.push(array); + } + let offsets = OffsetBuffer::new(offsets.into()); + let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>(); + let values = compute::concat(&new_arrays_ref)?; + Ok(Arc::new(GenericListArray::<OffsetSize>::try_new( + field.clone(), + offsets, + values, + None, + )?)) +} + +/// array_distinct SQL function +/// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4] +pub fn array_distinct(args: &[ArrayRef]) -> Result<ArrayRef> { + assert_eq!(args.len(), 1); + + // handle null + if args[0].data_type() == &DataType::Null { + return Ok(args[0].clone()); + } + + // handle for list & largelist + match args[0].data_type() { + DataType::List(field) => { + let array = as_list_array(&args[0])?; + general_array_distinct(array, field) + } + DataType::LargeList(field) => { + let array = as_large_list_array(&args[0])?; + general_array_distinct(array, field) + } + _ => internal_err!("array_distinct only supports list arrays"), + } +} + #[cfg(test)] mod tests { use super::*; @@ -2023,10 +2216,10 @@ mod tests { .unwrap(); let expected = as_list_array(&array2d_1).unwrap(); - let expected_dim = compute_array_ndims(Some(array2d_1.to_owned())).unwrap(); + let expected_dim = datafusion_common::utils::list_ndims(array2d_1.data_type()); assert_ne!(as_list_array(&res[0]).unwrap(), expected); assert_eq!( - compute_array_ndims(Some(res[0].clone())).unwrap(), + datafusion_common::utils::list_ndims(res[0].data_type()), expected_dim ); @@ -2036,10 +2229,10 @@ mod tests { let res = align_array_dimensions(vec![array1d_1, Arc::new(array3d_2.clone())]).unwrap(); let expected = as_list_array(&array3d_1).unwrap(); - let expected_dim = compute_array_ndims(Some(array3d_1.to_owned())).unwrap(); + let expected_dim = datafusion_common::utils::list_ndims(array3d_1.data_type()); assert_ne!(as_list_array(&res[0]).unwrap(), expected); assert_eq!( - compute_array_ndims(Some(res[0].clone())).unwrap(), + datafusion_common::utils::list_ndims(res[0].data_type()), expected_dim ); } diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index 0d42708c97ec..bbeb2b0dce86 100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -36,6 +36,7 @@ 
use arrow::{ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }, }; +use arrow_array::types::ArrowTimestampType; use arrow_array::{ timezone::Tz, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampSecondArray, @@ -43,7 +44,7 @@ use arrow_array::{ use chrono::prelude::*; use chrono::{Duration, Months, NaiveDate}; use datafusion_common::cast::{ - as_date32_array, as_date64_array, as_generic_string_array, + as_date32_array, as_date64_array, as_generic_string_array, as_primitive_array, as_timestamp_microsecond_array, as_timestamp_millisecond_array, as_timestamp_nanosecond_array, as_timestamp_second_array, }; @@ -130,6 +131,10 @@ fn string_to_timestamp_nanos_shim(s: &str) -> Result { } /// to_timestamp SQL function +/// +/// Note: `to_timestamp` returns `Timestamp(Nanosecond)` though its arguments are interpreted as **seconds**. The supported range for integer input is between `-9223372037` and `9223372036`. +/// Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. +/// Please use `to_timestamp_seconds` for the input outside of supported bounds. pub fn to_timestamp(args: &[ColumnarValue]) -> Result { handle::( args, @@ -331,7 +336,7 @@ fn date_trunc_coarse(granularity: &str, value: i64, tz: Option) -> Result, tz: Option, @@ -399,123 +404,61 @@ pub fn date_trunc(args: &[ColumnarValue]) -> Result { return exec_err!("Granularity of `date_trunc` must be non-null scalar Utf8"); }; + fn process_array( + array: &dyn Array, + granularity: String, + tz_opt: &Option>, + ) -> Result { + let parsed_tz = parse_tz(tz_opt)?; + let array = as_primitive_array::(array)?; + let array = array + .iter() + .map(|x| general_date_trunc(T::UNIT, &x, parsed_tz, granularity.as_str())) + .collect::>>()? + .with_timezone_opt(tz_opt.clone()); + Ok(ColumnarValue::Array(Arc::new(array))) + } + + fn process_scalar( + v: &Option, + granularity: String, + tz_opt: &Option>, + ) -> Result { + let parsed_tz = parse_tz(tz_opt)?; + let value = general_date_trunc(T::UNIT, v, parsed_tz, granularity.as_str())?; + let value = ScalarValue::new_timestamp::(value, tz_opt.clone()); + Ok(ColumnarValue::Scalar(value)) + } + Ok(match array { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => { - let parsed_tz = parse_tz(tz_opt)?; - let value = - _date_trunc(TimeUnit::Nanosecond, v, parsed_tz, granularity.as_str())?; - let value = ScalarValue::TimestampNanosecond(value, tz_opt.clone()); - ColumnarValue::Scalar(value) + process_scalar::(v, granularity, tz_opt)? } ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(v, tz_opt)) => { - let parsed_tz = parse_tz(tz_opt)?; - let value = - _date_trunc(TimeUnit::Microsecond, v, parsed_tz, granularity.as_str())?; - let value = ScalarValue::TimestampMicrosecond(value, tz_opt.clone()); - ColumnarValue::Scalar(value) + process_scalar::(v, granularity, tz_opt)? } ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, tz_opt)) => { - let parsed_tz = parse_tz(tz_opt)?; - let value = - _date_trunc(TimeUnit::Millisecond, v, parsed_tz, granularity.as_str())?; - let value = ScalarValue::TimestampMillisecond(value, tz_opt.clone()); - ColumnarValue::Scalar(value) + process_scalar::(v, granularity, tz_opt)? 
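+ // (All four scalar timestamp variants route through the same generic + // helper; only the `ArrowTimestampType` parameter changes.)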
} ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => { - let parsed_tz = parse_tz(tz_opt)?; - let value = - _date_trunc(TimeUnit::Second, v, parsed_tz, granularity.as_str())?; - let value = ScalarValue::TimestampSecond(value, tz_opt.clone()); - ColumnarValue::Scalar(value) + process_scalar::(v, granularity, tz_opt)? } ColumnarValue::Array(array) => { let array_type = array.data_type(); match array_type { DataType::Timestamp(TimeUnit::Second, tz_opt) => { - let parsed_tz = parse_tz(tz_opt)?; - let array = as_timestamp_second_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Second, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()? - .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) + process_array::(array, granularity, tz_opt)? } DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { - let parsed_tz = parse_tz(tz_opt)?; - let array = as_timestamp_millisecond_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Millisecond, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()? - .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) + process_array::(array, granularity, tz_opt)? } DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { - let parsed_tz = parse_tz(tz_opt)?; - let array = as_timestamp_microsecond_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Microsecond, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()? - .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) + process_array::(array, granularity, tz_opt)? } DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { - let parsed_tz = parse_tz(tz_opt)?; - let array = as_timestamp_nanosecond_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Nanosecond, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()? - .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) - } - _ => { - let parsed_tz = None; - let array = as_timestamp_nanosecond_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Nanosecond, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()?; - - ColumnarValue::Array(Arc::new(array)) + process_array::(array, granularity, tz_opt)? 
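+ // All four time units share this generic helper; `T::UNIT` inside + // `process_array` recovers the unit that the removed per-unit bodies + // hard-coded.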
} + _ => process_array::(array, granularity, &None)?, } } _ => { @@ -971,6 +914,11 @@ pub fn to_timestamp_invoke(args: &[ColumnarValue]) -> Result { &DataType::Timestamp(TimeUnit::Nanosecond, None), None, ), + DataType::Float64 => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Nanosecond, None), + None, + ), DataType::Timestamp(_, None) => cast_column( &args[0], &DataType::Timestamp(TimeUnit::Nanosecond, None), @@ -1340,7 +1288,7 @@ mod tests { .collect::() .with_timezone_opt(tz_opt.clone()); let result = date_trunc(&[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some("day".to_string()))), + ColumnarValue::Scalar(ScalarValue::from("day")), ColumnarValue::Array(Arc::new(input)), ]) .unwrap(); diff --git a/datafusion/physical-expr/src/equivalence.rs b/datafusion/physical-expr/src/equivalence.rs index f9f03300f5e9..4a562f4ef101 100644 --- a/datafusion/physical-expr/src/equivalence.rs +++ b/datafusion/physical-expr/src/equivalence.rs @@ -1520,7 +1520,7 @@ fn update_ordering( node.state = SortProperties::Ordered(options); } else if !node.expr.children().is_empty() { // We have an intermediate (non-leaf) node, account for its children: - node.state = node.expr.get_ordering(&node.children_states); + node.state = node.expr.get_ordering(&node.children_state()); } else if node.expr.as_any().is::() { // We have a Literal, which is the other possible leaf node type: node.state = node.expr.get_ordering(&[]); diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index b718b5017c5e..0c4ed3c12549 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -140,8 +140,7 @@ impl PhysicalExpr for CastExpr { let mut s = state; self.expr.hash(&mut s); self.cast_type.hash(&mut s); - // Add `self.cast_options` when hash is available - // https://github.com/apache/arrow-rs/pull/4395 + self.cast_options.hash(&mut s); } /// A [`CastExpr`] preserves the ordering of its child. @@ -157,8 +156,7 @@ impl PartialEq for CastExpr { .map(|x| { self.expr.eq(&x.expr) && self.cast_type == x.cast_type - // TODO: Use https://github.com/apache/arrow-rs/issues/2966 when available - && self.cast_options.safe == x.cast_options.safe + && self.cast_options == x.cast_options }) .unwrap_or(false) } @@ -176,7 +174,20 @@ pub fn cast_column( kernels::cast::cast_with_options(array, cast_type, &cast_options)?, )), ColumnarValue::Scalar(scalar) => { - let scalar_array = scalar.to_array()?; + let scalar_array = if cast_type + == &DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None) + { + if let ScalarValue::Float64(Some(float_ts)) = scalar { + ScalarValue::Int64( + Some((float_ts * 1_000_000_000_f64).trunc() as i64), + ) + .to_array()? + } else { + scalar.to_array()? + } + } else { + scalar.to_array()? 
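+ // Other scalars keep the plain scalar-to-array path; the actual cast + // happens below in `cast_with_options`. The `Float64` branch above + // pre-scales epoch seconds to integer nanoseconds so the fractional + // part survives the cast.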
+ }; let cast_array = kernels::cast::cast_with_options( &scalar_array, cast_type, @@ -201,7 +212,10 @@ pub fn cast_with_options( let expr_type = expr.data_type(input_schema)?; if expr_type == cast_type { Ok(expr.clone()) - } else if can_cast_types(&expr_type, &cast_type) { + } else if can_cast_types(&expr_type, &cast_type) + || (expr_type == DataType::Float64 + && cast_type == DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None)) + { Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options))) } else { not_impl_err!("Unsupported CAST from {expr_type:?} to {cast_type:?}") diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs index 7d5f16c454d6..43fd5a812a16 100644 --- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs @@ -110,7 +110,7 @@ impl GetIndexedFieldExpr { Self::new( arg, GetFieldAccessExpr::NamedStructField { - name: ScalarValue::Utf8(Some(name.into())), + name: ScalarValue::from(name.into()), }, ) } @@ -453,7 +453,7 @@ mod tests { .evaluate(&batch)? .into_array(batch.num_rows()) .expect("Failed to convert to array"); - assert!(result.is_null(0)); + assert!(result.is_empty()); Ok(()) } diff --git a/datafusion/physical-expr/src/expressions/nullif.rs b/datafusion/physical-expr/src/expressions/nullif.rs index 252bd10c3e73..dcd883f92965 100644 --- a/datafusion/physical-expr/src/expressions/nullif.rs +++ b/datafusion/physical-expr/src/expressions/nullif.rs @@ -154,7 +154,7 @@ mod tests { let a = StringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::Utf8(Some("bar".to_string()))); + let lit_array = ColumnarValue::Scalar(ScalarValue::from("bar")); let result = nullif_func(&[a, lit_array])?; let result = result.into_array(0).expect("Failed to convert to array"); diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 72c7f492166d..53de85843919 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -329,6 +329,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayAppend => { Arc::new(|args| make_scalar_function(array_expressions::array_append)(args)) } + BuiltinScalarFunction::ArraySort => { + Arc::new(|args| make_scalar_function(array_expressions::array_sort)(args)) + } BuiltinScalarFunction::ArrayConcat => { Arc::new(|args| make_scalar_function(array_expressions::array_concat)(args)) } @@ -347,6 +350,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayDims => { Arc::new(|args| make_scalar_function(array_expressions::array_dims)(args)) } + BuiltinScalarFunction::ArrayDistinct => { + Arc::new(|args| make_scalar_function(array_expressions::array_distinct)(args)) + } BuiltinScalarFunction::ArrayElement => { Arc::new(|args| make_scalar_function(array_expressions::array_element)(args)) } @@ -834,9 +840,9 @@ pub fn create_physical_fun( } let input_data_type = args[0].data_type(); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(format!( + Ok(ColumnarValue::Scalar(ScalarValue::from(format!( "{input_data_type}" - ))))) + )))) }), BuiltinScalarFunction::OverLay => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 5501647da2c3..9c212cb81f6b 100644 --- a/datafusion/physical-expr/src/planner.rs 
+++ b/datafusion/physical-expr/src/planner.rs @@ -348,50 +348,38 @@ pub fn create_physical_expr( ))) } - Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { - ScalarFunctionDefinition::BuiltIn(fun) => { - let physical_args = args - .iter() - .map(|e| { - create_physical_expr( - e, - input_dfschema, - input_schema, - execution_props, - ) - }) - .collect::>>()?; - functions::create_physical_expr( - fun, - &physical_args, - input_schema, - execution_props, - ) - } - ScalarFunctionDefinition::UDF(fun) => { - let mut physical_args = vec![]; - for e in args { - physical_args.push(create_physical_expr( - e, - input_dfschema, + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let mut physical_args = args + .iter() + .map(|e| { + create_physical_expr(e, input_dfschema, input_schema, execution_props) + }) + .collect::>>()?; + match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + functions::create_physical_expr( + fun, + &physical_args, input_schema, execution_props, - )?); + ) + } + ScalarFunctionDefinition::UDF(fun) => { + // udfs with zero params expect null array as input + if args.is_empty() { + physical_args.push(Arc::new(Literal::new(ScalarValue::Null))); + } + udf::create_physical_expr( + fun.clone().as_ref(), + &physical_args, + input_schema, + ) } - // udfs with zero params expect null array as input - if args.is_empty() { - physical_args.push(Arc::new(Literal::new(ScalarValue::Null))); + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") } - udf::create_physical_expr( - fun.clone().as_ref(), - &physical_args, - input_schema, - ) } - ScalarFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") - } - }, + } Expr::Between(Between { expr, negated, diff --git a/datafusion/physical-expr/src/sort_expr.rs b/datafusion/physical-expr/src/sort_expr.rs index 664a6b65b7f7..914d76f9261a 100644 --- a/datafusion/physical-expr/src/sort_expr.rs +++ b/datafusion/physical-expr/src/sort_expr.rs @@ -26,7 +26,7 @@ use crate::PhysicalExpr; use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::record_batch::RecordBatch; use arrow_schema::Schema; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::Result; use datafusion_expr::ColumnarValue; /// Represents Sort operation for a column in a RecordBatch @@ -65,11 +65,7 @@ impl PhysicalSortExpr { let value_to_sort = self.expr.evaluate(batch)?; let array_to_sort = match value_to_sort { ColumnarValue::Array(array) => array, - ColumnarValue::Scalar(scalar) => { - return exec_err!( - "Sort operation is not applicable to scalar value {scalar}" - ); - } + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(batch.num_rows())?, }; Ok(SortColumn { values: array_to_sort, diff --git a/datafusion/physical-expr/src/sort_properties.rs b/datafusion/physical-expr/src/sort_properties.rs index f8648abdf7a7..f51374461776 100644 --- a/datafusion/physical-expr/src/sort_properties.rs +++ b/datafusion/physical-expr/src/sort_properties.rs @@ -17,13 +17,12 @@ use std::{ops::Neg, sync::Arc}; -use crate::PhysicalExpr; use arrow_schema::SortOptions; + +use crate::PhysicalExpr; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::Result; -use itertools::Itertools; - /// To propagate [`SortOptions`] across the [`PhysicalExpr`], it is insufficient /// to simply use `Option`: There must be a differentiation between /// unordered columns and literal values, since literals 
may not break the ordering @@ -35,11 +34,12 @@ use itertools::Itertools; /// sorted data; however the ((a_ordered + 999) + c_ordered) expression can. Therefore, /// we need two different variants for literals and unordered columns as literals are /// often more ordering-friendly under most mathematical operations. -#[derive(PartialEq, Debug, Clone, Copy)] +#[derive(PartialEq, Debug, Clone, Copy, Default)] pub enum SortProperties { /// Use the ordinary [`SortOptions`] struct to represent ordered data: Ordered(SortOptions), // This alternative represents unordered data: + #[default] Unordered, // Singleton is used for single-valued literal numbers: Singleton, @@ -151,34 +151,24 @@ impl Neg for SortProperties { pub struct ExprOrdering { pub expr: Arc, pub state: SortProperties, - pub children_states: Vec, + pub children: Vec, } impl ExprOrdering { /// Creates a new [`ExprOrdering`] with [`SortProperties::Unordered`] states /// for `expr` and its children. pub fn new(expr: Arc) -> Self { - let size = expr.children().len(); + let children = expr.children(); Self { expr, - state: SortProperties::Unordered, - children_states: vec![SortProperties::Unordered; size], + state: Default::default(), + children: children.into_iter().map(Self::new).collect(), } } - /// Updates this [`ExprOrdering`]'s children states with the given states. - pub fn with_new_children(mut self, children_states: Vec) -> Self { - self.children_states = children_states; - self - } - - /// Creates new [`ExprOrdering`] objects for each child of the expression. - pub fn children_expr_orderings(&self) -> Vec { - self.expr - .children() - .into_iter() - .map(ExprOrdering::new) - .collect() + /// Get a reference to each child state. + pub fn children_state(&self) -> Vec { + self.children.iter().map(|c| c.state).collect() } } @@ -187,8 +177,8 @@ impl TreeNode for ExprOrdering { where F: FnMut(&Self) -> Result, { - for child in self.children_expr_orderings() { - match op(&child)? { + for child in &self.children { + match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), @@ -197,25 +187,19 @@ impl TreeNode for ExprOrdering { Ok(VisitRecursion::Continue) } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - if self.children_states.is_empty() { + if self.children.is_empty() { Ok(self) } else { - let child_expr_orderings = self.children_expr_orderings(); - // After mapping over the children, the function `F` applies to the - // current object and updates its state. - Ok(self.with_new_children( - child_expr_orderings - .into_iter() - // Update children states after this transformation: - .map(transform) - // Extract the state (i.e. 
sort properties) information: - .map_ok(|c| c.state) - .collect::>>()?, - )) + self.children = self + .children + .into_iter() + .map(transform) + .collect::>>()?; + Ok(self) } } } diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index 91d21f95e41f..7d9fecf61407 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -37,8 +37,11 @@ use datafusion_common::{ }; use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; -use std::iter; use std::sync::Arc; +use std::{ + fmt::{Display, Formatter}, + iter, +}; use uuid::Uuid; /// applies a unary expression to `args[0]` that is expected to be downcastable to @@ -133,53 +136,6 @@ pub fn ascii(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } -/// Removes the longest string containing only characters in characters (a space by default) from the start and end of string. -/// btrim('xyxtrimyyx', 'xyz') = 'trim' -pub fn btrim(args: &[ArrayRef]) -> Result { - match args.len() { - 1 => { - let string_array = as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| { - string.map(|string: &str| { - string.trim_start_matches(' ').trim_end_matches(' ') - }) - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let characters_array = as_generic_string_array::(&args[1])?; - - let result = string_array - .iter() - .zip(characters_array.iter()) - .map(|(string, characters)| match (string, characters) { - (None, _) => None, - (_, None) => None, - (Some(string), Some(characters)) => { - let chars: Vec = characters.chars().collect(); - Some( - string - .trim_start_matches(&chars[..]) - .trim_end_matches(&chars[..]), - ) - } - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - other => internal_err!( - "btrim was called with {other} arguments. It requires at least 1 and at most 2." - ), - } -} - /// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. /// chr(65) = 'A' pub fn chr(args: &[ArrayRef]) -> Result { @@ -346,44 +302,95 @@ pub fn lower(args: &[ColumnarValue]) -> Result { handle(args, |string| string.to_ascii_lowercase(), "lower") } -/// Removes the longest string containing only characters in characters (a space by default) from the start of string. 
-/// ltrim('zzzytest', 'xyz') = 'test' -pub fn ltrim(args: &[ArrayRef]) -> Result { +enum TrimType { + Left, + Right, + Both, +} + +impl Display for TrimType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + TrimType::Left => write!(f, "ltrim"), + TrimType::Right => write!(f, "rtrim"), + TrimType::Both => write!(f, "btrim"), + } + } +} + +fn general_trim( + args: &[ArrayRef], + trim_type: TrimType, +) -> Result { + let func = match trim_type { + TrimType::Left => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_start_matches::<&[char]>(input, pattern.as_ref()) + }, + TrimType::Right => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_end_matches::<&[char]>(input, pattern.as_ref()) + }, + TrimType::Both => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_end_matches::<&[char]>( + str::trim_start_matches::<&[char]>(input, pattern.as_ref()), + pattern.as_ref(), + ) + }, + }; + + let string_array = as_generic_string_array::(&args[0])?; + match args.len() { 1 => { - let string_array = as_generic_string_array::(&args[0])?; - let result = string_array .iter() - .map(|string| string.map(|string: &str| string.trim_start_matches(' '))) + .map(|string| string.map(|string: &str| func(string, " "))) .collect::>(); Ok(Arc::new(result) as ArrayRef) } 2 => { - let string_array = as_generic_string_array::(&args[0])?; let characters_array = as_generic_string_array::(&args[1])?; let result = string_array .iter() .zip(characters_array.iter()) .map(|(string, characters)| match (string, characters) { - (Some(string), Some(characters)) => { - let chars: Vec = characters.chars().collect(); - Some(string.trim_start_matches(&chars[..])) - } + (Some(string), Some(characters)) => Some(func(string, characters)), _ => None, }) .collect::>(); Ok(Arc::new(result) as ArrayRef) } - other => internal_err!( - "ltrim was called with {other} arguments. It requires at least 1 and at most 2." - ), + other => { + internal_err!( + "{trim_type} was called with {other} arguments. It requires at least 1 and at most 2." + ) + } } } +/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. +/// btrim('xyxtrimyyx', 'xyz') = 'trim' +pub fn btrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Both) +} + +/// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. +/// ltrim('zzzytest', 'xyz') = 'test' +pub fn ltrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Left) +} + +/// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. +/// rtrim('testxxzx', 'xyz') = 'test' +pub fn rtrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Right) +} + /// Repeats string the specified number of times. /// repeat('Pg', 4) = 'PgPgPgPg' pub fn repeat(args: &[ArrayRef]) -> Result { @@ -422,44 +429,6 @@ pub fn replace(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } -/// Removes the longest string containing only characters in characters (a space by default) from the end of string. 
-/// rtrim('testxxzx', 'xyz') = 'test' -pub fn rtrim(args: &[ArrayRef]) -> Result { - match args.len() { - 1 => { - let string_array = as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| string.map(|string: &str| string.trim_end_matches(' '))) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let characters_array = as_generic_string_array::(&args[1])?; - - let result = string_array - .iter() - .zip(characters_array.iter()) - .map(|(string, characters)| match (string, characters) { - (Some(string), Some(characters)) => { - let chars: Vec = characters.chars().collect(); - Some(string.trim_end_matches(&chars[..])) - } - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - other => internal_err!( - "rtrim was called with {other} arguments. It requires at least 1 and at most 2." - ), - } -} - /// Splits string at occurrences of delimiter and returns the n'th field (counting from one). /// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def' pub fn split_part(args: &[ArrayRef]) -> Result { diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index ed62956de8e0..71a7ff5fb778 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -129,10 +129,11 @@ pub struct ExprTreeNode { impl ExprTreeNode { pub fn new(expr: Arc) -> Self { + let children = expr.children(); ExprTreeNode { expr, data: None, - child_nodes: vec![], + child_nodes: children.into_iter().map(Self::new).collect_vec(), } } @@ -140,12 +141,8 @@ impl ExprTreeNode { &self.expr } - pub fn children(&self) -> Vec> { - self.expr - .children() - .into_iter() - .map(ExprTreeNode::new) - .collect() + pub fn children(&self) -> &[ExprTreeNode] { + &self.child_nodes } } @@ -155,7 +152,7 @@ impl TreeNode for ExprTreeNode { F: FnMut(&Self) -> Result, { for child in self.children() { - match op(&child)? { + match op(child)? 
{ VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), @@ -170,7 +167,7 @@ impl TreeNode for ExprTreeNode { F: FnMut(Self) -> Result, { self.child_nodes = self - .children() + .child_nodes .into_iter() .map(transform) .collect::>>()?; diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs index 49aac0877ab3..f5442e1b0fee 100644 --- a/datafusion/physical-expr/src/window/ntile.rs +++ b/datafusion/physical-expr/src/window/ntile.rs @@ -96,8 +96,9 @@ impl PartitionEvaluator for NtileEvaluator { ) -> Result { let num_rows = num_rows as u64; let mut vec: Vec = Vec::new(); + let n = u64::min(self.n, num_rows); for i in 0..num_rows { - let res = i * self.n / num_rows; + let res = i * n / num_rows; vec.push(res + 1) } Ok(Arc::new(UInt64Array::from(vec))) diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 10ff9edb8912..e7c7a42cf902 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -17,22 +17,18 @@ use crate::aggregates::group_values::GroupValues; use ahash::RandomState; -use arrow::compute::cast; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, Rows, SortField}; -use arrow_array::{Array, ArrayRef}; -use arrow_schema::{DataType, SchemaRef}; +use arrow_array::ArrayRef; +use arrow_schema::SchemaRef; use datafusion_common::hash_utils::create_hashes; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; use datafusion_physical_expr::EmitTo; use hashbrown::raw::RawTable; /// A [`GroupValues`] making use of [`Rows`] pub struct GroupValuesRows { - /// The output schema - schema: SchemaRef, - /// Converter for the group values row_converter: RowConverter, @@ -79,7 +75,6 @@ impl GroupValuesRows { let map = RawTable::with_capacity(0); Ok(Self { - schema, row_converter, map, map_size: 0, @@ -170,7 +165,7 @@ impl GroupValues for GroupValuesRows { .take() .expect("Can not emit from empty rows"); - let mut output = match emit_to { + let output = match emit_to { EmitTo::All => { let output = self.row_converter.convert_rows(&group_values)?; group_values.clear(); @@ -203,20 +198,6 @@ impl GroupValues for GroupValuesRows { } }; - // TODO: Materialize dictionaries in group keys (#7647) - for (field, array) in self.schema.fields.iter().zip(&mut output) { - let expected = field.data_type(); - if let DataType::Dictionary(_, v) = expected { - let actual = array.data_type(); - if v.as_ref() != actual { - return Err(DataFusionError::Internal(format!( - "Converted group rows expected dictionary of {v} got {actual}" - ))); - } - *array = cast(array.as_ref(), expected)?; - } - } - self.group_values = Some(group_values); Ok(output) } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 7d7fba6ef6c3..2f69ed061ce1 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -27,17 +27,16 @@ use crate::aggregates::{ }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use crate::windows::{ - get_ordered_partition_by_indices, get_window_mode, PartitionSearchMode, -}; +use crate::windows::{get_ordered_partition_by_indices, get_window_mode}; use crate::{ - 
DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, Partitioning, SendableRecordBatchStream, Statistics, }; use arrow::array::ArrayRef; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_schema::DataType; use datafusion_common::stats::Precision; use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_execution::TaskContext; @@ -286,6 +285,9 @@ pub struct AggregateExec { limit: Option, /// Input plan, could be a partial aggregate or the input to the aggregate pub input: Arc, + /// Original aggregation schema, could be different from `schema` before dictionary group + /// keys get materialized + original_schema: SchemaRef, /// Schema after the aggregate is applied schema: SchemaRef, /// Input schema before any aggregation is applied. For partial aggregate this will be the @@ -300,7 +302,9 @@ pub struct AggregateExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, required_input_ordering: Option, - partition_search_mode: PartitionSearchMode, + /// Describes how the input is ordered relative to the group by columns + input_order_mode: InputOrderMode, + /// Describe how the output is ordered output_ordering: Option, } @@ -405,15 +409,15 @@ fn get_aggregate_search_mode( aggr_expr: &mut [Arc], order_by_expr: &mut [Option], ordering_req: &mut Vec, -) -> PartitionSearchMode { +) -> InputOrderMode { let groupby_exprs = group_by .expr .iter() .map(|(item, _)| item.clone()) .collect::>(); - let mut partition_search_mode = PartitionSearchMode::Linear; + let mut input_order_mode = InputOrderMode::Linear; if !group_by.is_single() || groupby_exprs.is_empty() { - return partition_search_mode; + return input_order_mode; } if let Some((should_reverse, mode)) = @@ -435,9 +439,9 @@ fn get_aggregate_search_mode( ); *ordering_req = reverse_order_bys(ordering_req); } - partition_search_mode = mode; + input_order_mode = mode; } - partition_search_mode + input_order_mode } /// Check whether group by expression contains all of the expression inside `requirement` @@ -469,7 +473,7 @@ impl AggregateExec { input: Arc, input_schema: SchemaRef, ) -> Result { - let schema = create_schema( + let original_schema = create_schema( &input.schema(), &group_by.expr, &aggr_expr, @@ -477,7 +481,11 @@ impl AggregateExec { mode, )?; - let schema = Arc::new(schema); + let schema = Arc::new(materialize_dict_group_keys( + &original_schema, + group_by.expr.len(), + )); + let original_schema = Arc::new(original_schema); // Reset ordering requirement to `None` if aggregator is not order-sensitive order_by_expr = aggr_expr .iter() @@ -507,7 +515,7 @@ impl AggregateExec { &input.equivalence_properties(), )?; let mut ordering_req = requirement.unwrap_or(vec![]); - let partition_search_mode = get_aggregate_search_mode( + let input_order_mode = get_aggregate_search_mode( &group_by, &input, &mut aggr_expr, @@ -552,13 +560,14 @@ impl AggregateExec { filter_expr, order_by_expr, input, + original_schema, schema, input_schema, projection_mapping, metrics: ExecutionPlanMetricsSet::new(), required_input_ordering, limit: None, - partition_search_mode, + input_order_mode, output_ordering, }) } @@ -758,8 +767,8 @@ impl DisplayAs for AggregateExec { write!(f, ", lim=[{limit}]")?; } - if self.partition_search_mode != PartitionSearchMode::Linear { - write!(f, ", ordering_mode={:?}", self.partition_search_mode)?; + if self.input_order_mode != InputOrderMode::Linear { + write!(f, ", 
ordering_mode={:?}", self.input_order_mode)?; } } } @@ -810,7 +819,7 @@ impl ExecutionPlan for AggregateExec { /// infinite, returns an error to indicate this. fn unbounded_output(&self, children: &[bool]) -> Result<bool> { if children[0] { - if self.partition_search_mode == PartitionSearchMode::Linear { + if self.input_order_mode == InputOrderMode::Linear { // Cannot run without breaking pipeline. plan_err!( "Aggregate Error: `GROUP BY` clauses with columns without ordering and GROUPING SETS are not supported for unbounded inputs." ) @@ -973,6 +982,24 @@ fn create_schema( Ok(Schema::new(fields)) } +/// Returns the schema with dictionary group keys materialized as their value types. +/// The actual conversion happens in `RowConverter` and we don't do an unnecessary +/// conversion back into dictionaries +fn materialize_dict_group_keys(schema: &Schema, group_count: usize) -> Schema { + let fields = schema + .fields + .iter() + .enumerate() + .map(|(i, field)| match field.data_type() { + DataType::Dictionary(_, value_data_type) if i < group_count => { + Field::new(field.name(), *value_data_type.clone(), field.is_nullable()) + } + _ => Field::clone(field), + }) + .collect::<Vec<_>>(); + Schema::new(fields) +} + fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef { let group_fields = schema.fields()[0..group_count].to_vec(); Arc::new(Schema::new(group_fields)) } diff --git a/datafusion/physical-plan/src/aggregates/order/mod.rs b/datafusion/physical-plan/src/aggregates/order/mod.rs index f72d2f06e459..b258b97a9e84 100644 --- a/datafusion/physical-plan/src/aggregates/order/mod.rs +++ b/datafusion/physical-plan/src/aggregates/order/mod.rs @@ -23,7 +23,7 @@ use datafusion_physical_expr::{EmitTo, PhysicalSortExpr}; mod full; mod partial; -use crate::windows::PartitionSearchMode; +use crate::InputOrderMode; pub(crate) use full::GroupOrderingFull; pub(crate) use partial::GroupOrderingPartial; @@ -42,18 +42,16 @@ impl GroupOrdering { /// Create a `GroupOrdering` for the specified ordering pub fn try_new( input_schema: &Schema, - mode: &PartitionSearchMode, + mode: &InputOrderMode, ordering: &[PhysicalSortExpr], ) -> Result { match mode { - PartitionSearchMode::Linear => Ok(GroupOrdering::None), - PartitionSearchMode::PartiallySorted(order_indices) => { + InputOrderMode::Linear => Ok(GroupOrdering::None), + InputOrderMode::PartiallySorted(order_indices) => { GroupOrderingPartial::try_new(input_schema, order_indices, ordering) .map(GroupOrdering::Partial) } - PartitionSearchMode::Sorted => { - Ok(GroupOrdering::Full(GroupOrderingFull::new())) - } + InputOrderMode::Sorted => Ok(GroupOrdering::Full(GroupOrderingFull::new())), } } diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index f96417fc323b..89614fd3020c 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -324,7 +324,9 @@ impl GroupedHashAggregateStream { .map(create_group_accumulator) .collect::<Result<Vec<_>>>()?; - let group_schema = group_schema(&agg_schema, agg_group_by.expr.len()); + // we need to use the original schema so the RowConverter in group_values below + // will do the proper conversion of dictionaries into value types + let group_schema = group_schema(&agg.original_schema, agg_group_by.expr.len()); let spill_expr = group_schema .fields .into_iter() @@ -344,7 +346,7 @@ impl GroupedHashAggregateStream { .find_longest_permutation(&agg_group_by.output_exprs()); let group_ordering = GroupOrdering::try_new( &group_schema, -
&agg.partition_search_mode, + &agg.input_order_mode, ordering.as_slice(), )?; diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index aa368251ebf3..612e164be0e2 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -132,7 +132,7 @@ impl<'a> DisplayableExecutionPlan<'a> { /// ```dot /// strict digraph dot_plan { // 0[label="ProjectionExec: expr=[id@0 + 2 as employee.id + Int32(2)]",tooltip=""] - // 1[label="EmptyExec: produce_one_row=false",tooltip=""] + // 1[label="EmptyExec",tooltip=""] // 0 -> 1 // } /// ``` diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index a3e1fb79edb5..41c8dbed1453 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! EmptyRelation execution plan +//! EmptyRelation with produce_one_row=false execution plan use std::any::Any; use std::sync::Arc; @@ -24,19 +24,16 @@ use super::expressions::PhysicalSortExpr; use super::{common, DisplayAs, SendableRecordBatchStream, Statistics}; use crate::{memory::MemoryStream, DisplayFormatType, ExecutionPlan, Partitioning}; -use arrow::array::{ArrayRef, NullArray}; -use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; +use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use log::trace; -/// Execution plan for empty relation (produces no rows) +/// Execution plan for empty relation with produce_one_row=false #[derive(Debug)] pub struct EmptyExec { - /// Specifies whether this exec produces a row or not - produce_one_row: bool, /// The schema for the produced row schema: SchemaRef, /// Number of partitions @@ -45,9 +42,8 @@ pub struct EmptyExec { impl EmptyExec { /// Create a new EmptyExec - pub fn new(produce_one_row: bool, schema: SchemaRef) -> Self { + pub fn new(schema: SchemaRef) -> Self { EmptyExec { - produce_one_row, schema, partitions: 1, } @@ -59,36 +55,8 @@ impl EmptyExec { self } - /// Specifies whether this exec produces a row or not - pub fn produce_one_row(&self) -> bool { - self.produce_one_row - } - fn data(&self) -> Result> { - let batch = if self.produce_one_row { - let n_field = self.schema.fields.len(); - // hack for https://github.com/apache/arrow-datafusion/pull/3242 - let n_field = if n_field == 0 { 1 } else { n_field }; - vec![RecordBatch::try_new( - Arc::new(Schema::new( - (0..n_field) - .map(|i| { - Field::new(format!("placeholder_{i}"), DataType::Null, true) - }) - .collect::(), - )), - (0..n_field) - .map(|_i| { - let ret: ArrayRef = Arc::new(NullArray::new(1)); - ret - }) - .collect(), - )?] 
- } else { - vec![] - }; - - Ok(batch) + Ok(vec![]) } } @@ -100,7 +68,7 @@ impl DisplayAs for EmptyExec { ) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "EmptyExec: produce_one_row={}", self.produce_one_row) + write!(f, "EmptyExec") } } } @@ -133,10 +101,7 @@ impl ExecutionPlan for EmptyExec { self: Arc, _: Vec>, ) -> Result> { - Ok(Arc::new(EmptyExec::new( - self.produce_one_row, - self.schema.clone(), - ))) + Ok(Arc::new(EmptyExec::new(self.schema.clone()))) } fn execute( @@ -184,7 +149,7 @@ mod tests { let task_ctx = Arc::new(TaskContext::default()); let schema = test::aggr_test_schema(); - let empty = EmptyExec::new(false, schema.clone()); + let empty = EmptyExec::new(schema.clone()); assert_eq!(empty.schema(), schema); // we should have no results @@ -198,16 +163,11 @@ mod tests { #[test] fn with_new_children() -> Result<()> { let schema = test::aggr_test_schema(); - let empty = Arc::new(EmptyExec::new(false, schema.clone())); - let empty_with_row = Arc::new(EmptyExec::new(true, schema)); + let empty = Arc::new(EmptyExec::new(schema.clone())); let empty2 = with_new_children_if_necessary(empty.clone(), vec![])?.into(); assert_eq!(empty.schema(), empty2.schema()); - let empty_with_row_2 = - with_new_children_if_necessary(empty_with_row.clone(), vec![])?.into(); - assert_eq!(empty_with_row.schema(), empty_with_row_2.schema()); - let too_many_kids = vec![empty2]; assert!( with_new_children_if_necessary(empty, too_many_kids).is_err(), @@ -220,44 +180,11 @@ mod tests { async fn invalid_execute() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); let schema = test::aggr_test_schema(); - let empty = EmptyExec::new(false, schema); + let empty = EmptyExec::new(schema); // ask for the wrong partition assert!(empty.execute(1, task_ctx.clone()).is_err()); assert!(empty.execute(20, task_ctx).is_err()); Ok(()) } - - #[tokio::test] - async fn produce_one_row() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); - let schema = test::aggr_test_schema(); - let empty = EmptyExec::new(true, schema); - - let iter = empty.execute(0, task_ctx)?; - let batches = common::collect(iter).await?; - - // should have one item - assert_eq!(batches.len(), 1); - - Ok(()) - } - - #[tokio::test] - async fn produce_one_row_multiple_partition() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); - let schema = test::aggr_test_schema(); - let partitions = 3; - let empty = EmptyExec::new(true, schema).with_partitions(partitions); - - for n in 0..partitions { - let iter = empty.execute(n, task_ctx.clone())?; - let batches = common::collect(iter).await?; - - // should have one item - assert_eq!(batches.len(), 1); - } - - Ok(()) - } } diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 903f4c972ebd..56a1b4e17821 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -61,6 +61,8 @@ pub struct FilterExec { input: Arc, /// Execution metrics metrics: ExecutionPlanMetricsSet, + /// Selectivity for statistics. 
0 = no rows, 100 = all rows + default_selectivity: u8, } impl FilterExec { @@ -74,6 +76,7 @@ impl FilterExec { predicate, input: input.clone(), metrics: ExecutionPlanMetricsSet::new(), + default_selectivity: 20, }), other => { plan_err!("Filter predicate must return boolean values, not {other:?}") @@ -81,6 +84,17 @@ impl FilterExec { } } + pub fn with_default_selectivity( + mut self, + default_selectivity: u8, + ) -> Result { + if default_selectivity > 100 { + return plan_err!("Default filter selectivity needs to be less than or equal to 100"); + } + self.default_selectivity = default_selectivity; + Ok(self) + } + /// The expression to filter on. This expression must evaluate to a boolean value. pub fn predicate(&self) -> &Arc<dyn PhysicalExpr> { &self.predicate @@ -90,6 +104,11 @@ impl FilterExec { pub fn input(&self) -> &Arc<dyn ExecutionPlan> { &self.input } + + /// The default selectivity + pub fn default_selectivity(&self) -> u8 { + self.default_selectivity + } } impl DisplayAs for FilterExec { @@ -166,6 +185,10 @@ impl ExecutionPlan for FilterExec { mut children: Vec<Arc<dyn ExecutionPlan>>, ) -> Result<Arc<dyn ExecutionPlan>> { FilterExec::try_new(self.predicate.clone(), children.swap_remove(0)) + .and_then(|e| { + let selectivity = e.default_selectivity(); + e.with_default_selectivity(selectivity) + }) .map(|e| Arc::new(e) as _) } @@ -196,10 +219,7 @@ impl ExecutionPlan for FilterExec { let input_stats = self.input.statistics()?; let schema = self.schema(); if !check_support(predicate, &schema) { - // assume filter selects 20% of rows if we cannot do anything smarter - // tracking issue for making this configurable: - // https://github.com/apache/arrow-datafusion/issues/8133 - let selectivity = 0.2_f64; + let selectivity = self.default_selectivity as f64 / 100.0; let mut stats = input_stats.into_inexact(); stats.num_rows = stats.num_rows.with_estimated_selectivity(selectivity); stats.total_byte_size = stats @@ -987,4 +1007,54 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_validation_filter_selectivity() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics::new_unknown(&schema), + schema, + )); + // WHERE a = 10 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )); + let filter = FilterExec::try_new(predicate, input)?; + assert!(filter.with_default_selectivity(120).is_err()); + Ok(()) + } + + #[tokio::test] + async fn test_custom_filter_selectivity() -> Result<()> { + // Need a decimal to trigger inexact selectivity + let schema = Schema::new(vec![Field::new("a", DataType::Decimal128(2, 3), false)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(4000), + column_statistics: vec![ColumnStatistics { + ..Default::default() + }], + }, + schema, + )); + // WHERE a = 10 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Decimal128(Some(10), 10, 10))), + )); + let filter = FilterExec::try_new(predicate, input)?; + let statistics = filter.statistics()?; + assert_eq!(statistics.num_rows, Precision::Inexact(200)); + assert_eq!(statistics.total_byte_size, Precision::Inexact(800)); + let filter = filter.with_default_selectivity(40)?; + let statistics = filter.statistics()?; + assert_eq!(statistics.num_rows, Precision::Inexact(400)); + assert_eq!(statistics.total_byte_size, Precision::Inexact(1600)); + Ok(()) + } } diff 
--git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 4c928d44caf4..938c9e4d343d 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -476,12 +476,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Exact(3), }, ], @@ -512,12 +508,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Exact(3 * right_row_count), }, ColumnStatistics { @@ -548,12 +540,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Exact(3), }, ], @@ -584,12 +572,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Absent, // we don't know the row count on the right }, ColumnStatistics { diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index b2c69b467e9c..6c9e97e03cb7 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -58,6 +58,8 @@ pub mod joins; pub mod limit; pub mod memory; pub mod metrics; +mod ordering; +pub mod placeholder_row; pub mod projection; pub mod repartition; pub mod sorts; @@ -72,6 +74,7 @@ pub mod windows; pub use crate::display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay}; pub use crate::metrics::Metric; +pub use crate::ordering::InputOrderMode; pub use crate::topk::TopK; pub use crate::visitor::{accept, visit_execution_plan, ExecutionPlanVisitor}; diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 39cd47452eff..7de474fda11c 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -55,7 +55,7 @@ impl fmt::Debug for MemoryExec { write!(f, "partitions: [...]")?; write!(f, "schema: {:?}", self.projected_schema)?; write!(f, "projection: {:?}", self.projection)?; - if let Some(sort_info) = &self.sort_information.get(0) { + if let Some(sort_info) = &self.sort_information.first() { write!(f, ", output_ordering: {:?}", sort_info)?; } Ok(()) diff --git a/datafusion/physical-plan/src/ordering.rs b/datafusion/physical-plan/src/ordering.rs new file mode 100644 index 000000000000..047f89eef193 --- /dev/null +++ b/datafusion/physical-plan/src/ordering.rs @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more 
contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Specifies how the input to an aggregation or window operator is ordered +/// relative to its `GROUP BY` or `PARTITION BY` expressions. +/// +/// For example, if the existing ordering is `[a ASC, b ASC, c ASC]`: +/// +/// ## Window Functions +/// - A `PARTITION BY b` clause can use `Linear` mode. +/// - A `PARTITION BY a, c` or a `PARTITION BY c, a` can use +/// `PartiallySorted([0])` or `PartiallySorted([1])` modes, respectively. +/// (The vector stores the index of `a` in the respective PARTITION BY expression.) +/// - A `PARTITION BY a, b` or a `PARTITION BY b, a` can use `Sorted` mode. +/// +/// ## Aggregations +/// - A `GROUP BY b` clause can use `Linear` mode. +/// - A `GROUP BY a, c` or a `GROUP BY c, a` can use +/// `PartiallySorted([0])` or `PartiallySorted([1])` modes, respectively. +/// (The vector stores the index of `a` in the respective GROUP BY expression.) +/// - A `GROUP BY a, b` or a `GROUP BY b, a` can use `Sorted` mode. +/// +/// Note that these are the same examples as above, but with `GROUP BY` instead of +/// `PARTITION BY` to make the examples easier to read. +#[derive(Debug, Clone, PartialEq)] +pub enum InputOrderMode { + /// There is no partial permutation of the expressions satisfying the + /// existing ordering. + Linear, + /// There is a partial permutation of the expressions satisfying the + /// existing ordering. Indices describing the longest partial permutation + /// are stored in the vector. + PartiallySorted(Vec<usize>), + /// There is a (full) permutation of the expressions satisfying the + /// existing ordering. + Sorted, +} diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs new file mode 100644 index 000000000000..94f32788530b --- /dev/null +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -0,0 +1,229 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
EmptyRelation produce_one_row=true execution plan + +use std::any::Any; +use std::sync::Arc; + +use super::expressions::PhysicalSortExpr; +use super::{common, DisplayAs, SendableRecordBatchStream, Statistics}; +use crate::{memory::MemoryStream, DisplayFormatType, ExecutionPlan, Partitioning}; + +use arrow::array::{ArrayRef, NullArray}; +use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_execution::TaskContext; + +use log::trace; + +/// Execution plan for empty relation with produce_one_row=true +#[derive(Debug)] +pub struct PlaceholderRowExec { + /// The schema for the produced row + schema: SchemaRef, + /// Number of partitions + partitions: usize, +} + +impl PlaceholderRowExec { + /// Create a new PlaceholderRowExec + pub fn new(schema: SchemaRef) -> Self { + PlaceholderRowExec { + schema, + partitions: 1, + } + } + + /// Create a new PlaceholderRowExec with the specified partition number + pub fn with_partitions(mut self, partitions: usize) -> Self { + self.partitions = partitions; + self + } + + fn data(&self) -> Result<Vec<RecordBatch>> { + Ok({ + let n_field = self.schema.fields.len(); + // hack for https://github.com/apache/arrow-datafusion/pull/3242 + let n_field = if n_field == 0 { 1 } else { n_field }; + vec![RecordBatch::try_new( + Arc::new(Schema::new( + (0..n_field) + .map(|i| { + Field::new(format!("placeholder_{i}"), DataType::Null, true) + }) + .collect::<Fields>(), + )), + (0..n_field) + .map(|_i| { + let ret: ArrayRef = Arc::new(NullArray::new(1)); + ret + }) + .collect(), + )?] + }) + } +} + +impl DisplayAs for PlaceholderRowExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "PlaceholderRowExec") + } + } + } +} + +impl ExecutionPlan for PlaceholderRowExec { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> { + vec![] + } + + /// Get the output partitioning of this plan + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(self.partitions) + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn with_new_children( + self: Arc<Self>, + _: Vec<Arc<dyn ExecutionPlan>>, + ) -> Result<Arc<dyn ExecutionPlan>> { + Ok(Arc::new(PlaceholderRowExec::new(self.schema.clone()))) + } + + fn execute( + &self, + partition: usize, + context: Arc<TaskContext>, + ) -> Result<SendableRecordBatchStream> { + trace!("Start PlaceholderRowExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + + if partition >= self.partitions { + return internal_err!( + "PlaceholderRowExec invalid partition {} (expected less than {})", + partition, + self.partitions + ); + } + + Ok(Box::pin(MemoryStream::try_new( + self.data()?, + self.schema.clone(), + None, + )?)) + } + + fn statistics(&self) -> Result<Statistics> { + let batch = self + .data() + .expect("Create single row placeholder RecordBatch should not fail"); + Ok(common::compute_record_batch_statistics( + &[batch], + &self.schema, + None, + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::with_new_children_if_necessary; + use crate::{common, test}; + + #[test] + fn with_new_children() -> Result<()> { + let schema = test::aggr_test_schema(); + + let placeholder = 
Arc::new(PlaceholderRowExec::new(schema)); + + let placeholder_2 = + with_new_children_if_necessary(placeholder.clone(), vec![])?.into(); + assert_eq!(placeholder.schema(), placeholder_2.schema()); + + let too_many_kids = vec![placeholder_2]; + assert!( + with_new_children_if_necessary(placeholder, too_many_kids).is_err(), + "expected error when providing list of kids" + ); + Ok(()) + } + + #[tokio::test] + async fn invalid_execute() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = test::aggr_test_schema(); + let placeholder = PlaceholderRowExec::new(schema); + + // ask for the wrong partition + assert!(placeholder.execute(1, task_ctx.clone()).is_err()); + assert!(placeholder.execute(20, task_ctx).is_err()); + Ok(()) + } + + #[tokio::test] + async fn produce_one_row() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = test::aggr_test_schema(); + let placeholder = PlaceholderRowExec::new(schema); + + let iter = placeholder.execute(0, task_ctx)?; + let batches = common::collect(iter).await?; + + // should have one item + assert_eq!(batches.len(), 1); + + Ok(()) + } + + #[tokio::test] + async fn produce_one_row_multiple_partition() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = test::aggr_test_schema(); + let partitions = 3; + let placeholder = PlaceholderRowExec::new(schema).with_partitions(partitions); + + for n in 0..partitions { + let iter = placeholder.execute(n, task_ctx.clone())?; + let batches = common::collect(iter).await?; + + // should have one item + assert_eq!(batches.len(), 1); + } + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index 2e1d3dbf94f5..cc2ab62049ed 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -397,12 +397,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Exact(3), }, ColumnStatistics { @@ -439,12 +435,8 @@ mod tests { column_statistics: vec![ ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Exact(3), }, ColumnStatistics { diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs index 71e6cba6741e..1f6ee1f117aa 100644 --- a/datafusion/physical-plan/src/test/exec.rs +++ b/datafusion/physical-plan/src/test/exec.rs @@ -790,7 +790,7 @@ impl Stream for PanicStream { } else { self.ready = true; // get called again - cx.waker().clone().wake(); + cx.waker().wake_by_ref(); return Poll::Pending; } } diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 9700605ce406..14ef9c2ec27b 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -38,7 +38,7 @@ use crate::stream::ObservedStream; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::stats::Precision; 
-use datafusion_common::{exec_err, internal_err, DFSchemaRef, DataFusionError, Result}; +use datafusion_common::{exec_err, internal_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; @@ -95,38 +95,6 @@ pub struct UnionExec { } impl UnionExec { - /// Create a new UnionExec with specified schema. - /// The `schema` should always be a subset of the schema of `inputs`, - /// otherwise, an error will be returned. - pub fn try_new_with_schema( - inputs: Vec>, - schema: DFSchemaRef, - ) -> Result { - let mut exec = Self::new(inputs); - let exec_schema = exec.schema(); - let fields = schema - .fields() - .iter() - .map(|dff| { - exec_schema - .field_with_name(dff.name()) - .cloned() - .map_err(|_| { - DataFusionError::Internal(format!( - "Cannot find the field {:?} in child schema", - dff.name() - )) - }) - }) - .collect::>>()?; - let schema = Arc::new(Schema::new_with_metadata( - fields, - exec.schema().metadata().clone(), - )); - exec.schema = schema; - Ok(exec) - } - /// Create a new UnionExec pub fn new(inputs: Vec>) -> Self { let schema = union_schema(&inputs); @@ -706,12 +674,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Exact(3), }, ColumnStatistics { @@ -735,12 +699,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Absent, - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "c", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "b", - )))), + max_value: Precision::Exact(ScalarValue::from("c")), + min_value: Precision::Exact(ScalarValue::from("b")), null_count: Precision::Absent, }, ColumnStatistics { @@ -765,12 +725,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Absent, - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Absent, }, ColumnStatistics { diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 8156ab1fa31b..431a43bc6055 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -31,15 +31,16 @@ use crate::expressions::PhysicalSortExpr; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use crate::windows::{ calc_requirements, get_ordered_partition_by_indices, get_partition_by_sort_exprs, - window_equivalence_properties, PartitionSearchMode, + window_equivalence_properties, }; use crate::{ ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, - Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, + InputOrderMode, Partitioning, RecordBatchStream, SendableRecordBatchStream, + Statistics, WindowExpr, }; use arrow::{ - array::{Array, ArrayRef, UInt32Builder}, + array::{Array, ArrayRef, RecordBatchOptions, UInt32Builder}, compute::{concat, concat_batches, sort_to_indices}, datatypes::{Schema, SchemaBuilder, 
SchemaRef}, record_batch::RecordBatch, @@ -50,7 +51,7 @@ use datafusion_common::utils::{ evaluate_partition_ranges, get_arrayref_at_indices, get_at_indices, get_record_batch_at_indices, get_row_at_idx, }; -use datafusion_common::{exec_err, plan_err, DataFusionError, Result}; +use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_expr::window_state::{PartitionBatchState, WindowAggState}; use datafusion_expr::ColumnarValue; @@ -81,8 +82,8 @@ pub struct BoundedWindowAggExec { pub partition_keys: Vec>, /// Execution metrics metrics: ExecutionPlanMetricsSet, - /// Partition by search mode - pub partition_search_mode: PartitionSearchMode, + /// Describes how the input is ordered relative to the partition keys + pub input_order_mode: InputOrderMode, /// Partition by indices that define ordering // For example, if input ordering is ORDER BY a, b and window expression // contains PARTITION BY b, a; `ordered_partition_by_indices` would be 1, 0. @@ -98,13 +99,13 @@ impl BoundedWindowAggExec { window_expr: Vec>, input: Arc, partition_keys: Vec>, - partition_search_mode: PartitionSearchMode, + input_order_mode: InputOrderMode, ) -> Result { let schema = create_schema(&input.schema(), &window_expr)?; let schema = Arc::new(schema); let partition_by_exprs = window_expr[0].partition_by(); - let ordered_partition_by_indices = match &partition_search_mode { - PartitionSearchMode::Sorted => { + let ordered_partition_by_indices = match &input_order_mode { + InputOrderMode::Sorted => { let indices = get_ordered_partition_by_indices( window_expr[0].partition_by(), &input, @@ -115,10 +116,8 @@ impl BoundedWindowAggExec { (0..partition_by_exprs.len()).collect::>() } } - PartitionSearchMode::PartiallySorted(ordered_indices) => { - ordered_indices.clone() - } - PartitionSearchMode::Linear => { + InputOrderMode::PartiallySorted(ordered_indices) => ordered_indices.clone(), + InputOrderMode::Linear => { vec![] } }; @@ -128,7 +127,7 @@ impl BoundedWindowAggExec { schema, partition_keys, metrics: ExecutionPlanMetricsSet::new(), - partition_search_mode, + input_order_mode, ordered_partition_by_indices, }) } @@ -162,8 +161,8 @@ impl BoundedWindowAggExec { fn get_search_algo(&self) -> Result> { let partition_by_sort_keys = self.partition_by_sort_keys()?; let ordered_partition_by_indices = self.ordered_partition_by_indices.clone(); - Ok(match &self.partition_search_mode { - PartitionSearchMode::Sorted => { + Ok(match &self.input_order_mode { + InputOrderMode::Sorted => { // In Sorted mode, all partition by columns should be ordered. 
if self.window_expr()[0].partition_by().len() != ordered_partition_by_indices.len() @@ -175,7 +174,7 @@ impl BoundedWindowAggExec { ordered_partition_by_indices, }) } - PartitionSearchMode::Linear | PartitionSearchMode::PartiallySorted(_) => { + InputOrderMode::Linear | InputOrderMode::PartiallySorted(_) => { Box::new(LinearSearch::new(ordered_partition_by_indices)) } }) @@ -203,7 +202,7 @@ impl DisplayAs for BoundedWindowAggExec { ) }) .collect(); - let mode = &self.partition_search_mode; + let mode = &self.input_order_mode; write!(f, "wdw=[{}], mode=[{:?}]", g.join(", "), mode)?; } } @@ -244,7 +243,7 @@ impl ExecutionPlan for BoundedWindowAggExec { fn required_input_ordering(&self) -> Vec>> { let partition_bys = self.window_expr()[0].partition_by(); let order_keys = self.window_expr()[0].order_by(); - if self.partition_search_mode != PartitionSearchMode::Sorted + if self.input_order_mode != InputOrderMode::Sorted || self.ordered_partition_by_indices.len() >= partition_bys.len() { let partition_bys = self @@ -283,7 +282,7 @@ impl ExecutionPlan for BoundedWindowAggExec { self.window_expr.clone(), children[0].clone(), self.partition_keys.clone(), - self.partition_search_mode.clone(), + self.input_order_mode.clone(), )?)) } @@ -586,7 +585,7 @@ impl LinearSearch { .map(|item| match item.evaluate(record_batch)? { ColumnarValue::Array(array) => Ok(array), ColumnarValue::Scalar(scalar) => { - plan_err!("Sort operation is not applicable to scalar value {scalar}") + scalar.to_array_of_size(record_batch.num_rows()) } }) .collect() @@ -1027,8 +1026,11 @@ impl BoundedWindowAggStream { .iter() .map(|elem| elem.slice(n_out, n_to_keep)) .collect::>(); - self.input_buffer = - RecordBatch::try_new(self.input_buffer.schema(), batch_to_keep)?; + self.input_buffer = RecordBatch::try_new_with_options( + self.input_buffer.schema(), + batch_to_keep, + &RecordBatchOptions::new().with_row_count(Some(n_to_keep)), + )?; Ok(()) } @@ -1114,7 +1116,7 @@ fn get_aggregate_result_out_column( mod tests { use crate::common::collect; use crate::memory::MemoryExec; - use crate::windows::{BoundedWindowAggExec, PartitionSearchMode}; + use crate::windows::{BoundedWindowAggExec, InputOrderMode}; use crate::{get_plan_string, ExecutionPlan}; use arrow_array::RecordBatch; use arrow_schema::{DataType, Field, Schema}; @@ -1201,7 +1203,7 @@ mod tests { window_exprs, memory_exec, vec![], - PartitionSearchMode::Sorted, + InputOrderMode::Sorted, ) .map(|e| Arc::new(e) as Arc)?; diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index d97e3c93a136..3187e6b0fbd3 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -27,7 +27,7 @@ use crate::{ cume_dist, dense_rank, lag, lead, percent_rank, rank, Literal, NthValue, Ntile, PhysicalSortExpr, RowNumber, }, - udaf, unbounded_output, ExecutionPlan, PhysicalExpr, + udaf, unbounded_output, ExecutionPlan, InputOrderMode, PhysicalExpr, }; use arrow::datatypes::Schema; @@ -54,30 +54,6 @@ pub use datafusion_physical_expr::window::{ BuiltInWindowExpr, PlainAggregateWindowExpr, WindowExpr, }; -#[derive(Debug, Clone, PartialEq)] -/// Specifies aggregation grouping and/or window partitioning properties of a -/// set of expressions in terms of the existing ordering. -/// For example, if the existing ordering is `[a ASC, b ASC, c ASC]`: -/// - A `PARTITION BY b` clause will result in `Linear` mode. 
-/// - A `PARTITION BY a, c` or a `PARTITION BY c, a` clause will result in -/// `PartiallySorted([0])` or `PartiallySorted([1])` modes, respectively. -/// The vector stores the index of `a` in the respective PARTITION BY expression. -/// - A `PARTITION BY a, b` or a `PARTITION BY b, a` clause will result in -/// `Sorted` mode. -/// Note that the examples above are applicable for `GROUP BY` clauses too. -pub enum PartitionSearchMode { - /// There is no partial permutation of the expressions satisfying the - /// existing ordering. - Linear, - /// There is a partial permutation of the expressions satisfying the - /// existing ordering. Indices describing the longest partial permutation - /// are stored in the vector. - PartiallySorted(Vec), - /// There is a (full) permutation of the expressions satisfying the - /// existing ordering. - Sorted, -} - /// Create a physical expression for window function pub fn create_window_expr( fun: &WindowFunction, @@ -189,15 +165,26 @@ fn create_built_in_window_expr( BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name)), BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name)), BuiltInWindowFunction::Ntile => { - let n: i64 = get_scalar_value_from_args(args, 0)? - .ok_or_else(|| { - DataFusionError::Execution( - "NTILE requires at least 1 argument".to_string(), - ) - })? - .try_into()?; - let n: u64 = n as u64; - Arc::new(Ntile::new(name, n)) + let n = get_scalar_value_from_args(args, 0)?.ok_or_else(|| { + DataFusionError::Execution( + "NTILE requires a positive integer".to_string(), + ) + })?; + + if n.is_null() { + return exec_err!("NTILE requires a positive integer, but finds NULL"); + } + + if n.is_unsigned() { + let n: u64 = n.try_into()?; + Arc::new(Ntile::new(name, n)) + } else { + let n: i64 = n.try_into()?; + if n <= 0 { + return exec_err!("NTILE requires a positive integer"); + } + Arc::new(Ntile::new(name, n as u64)) + } } BuiltInWindowFunction::Lag => { let arg = args[0].clone(); @@ -403,17 +390,17 @@ pub fn get_best_fitting_window( // of the window_exprs are same. let partitionby_exprs = window_exprs[0].partition_by(); let orderby_keys = window_exprs[0].order_by(); - let (should_reverse, partition_search_mode) = - if let Some((should_reverse, partition_search_mode)) = + let (should_reverse, input_order_mode) = + if let Some((should_reverse, input_order_mode)) = get_window_mode(partitionby_exprs, orderby_keys, input) { - (should_reverse, partition_search_mode) + (should_reverse, input_order_mode) } else { return Ok(None); }; let is_unbounded = unbounded_output(input); - if !is_unbounded && partition_search_mode != PartitionSearchMode::Sorted { - // Executor has bounded input and `partition_search_mode` is not `PartitionSearchMode::Sorted` + if !is_unbounded && input_order_mode != InputOrderMode::Sorted { + // Executor has bounded input and `input_order_mode` is not `InputOrderMode::Sorted` // in this case removing the sort is not helpful, return: return Ok(None); }; @@ -441,13 +428,13 @@ pub fn get_best_fitting_window( window_expr, input.clone(), physical_partition_keys.to_vec(), - partition_search_mode, + input_order_mode, )?) as _)) - } else if partition_search_mode != PartitionSearchMode::Sorted { + } else if input_order_mode != InputOrderMode::Sorted { // For `WindowAggExec` to work correctly PARTITION BY columns should be sorted. - // Hence, if `partition_search_mode` is not `PartitionSearchMode::Sorted` we should convert - // input ordering such that it can work with PartitionSearchMode::Sorted (add `SortExec`). 
- // Effectively `WindowAggExec` works only in PartitionSearchMode::Sorted mode. + // Hence, if `input_order_mode` is not `Sorted` we should convert + // input ordering such that it can work with `Sorted` (add `SortExec`). + // Effectively `WindowAggExec` works only in `Sorted` mode. Ok(None) } else { Ok(Some(Arc::new(WindowAggExec::try_new( @@ -463,16 +450,16 @@ pub fn get_best_fitting_window( /// is sufficient to run the current window operator. /// - A `None` return value indicates that we can not remove the sort in question /// (input ordering is not sufficient to run current window executor). -/// - A `Some((bool, PartitionSearchMode))` value indicates that the window operator +/// - A `Some((bool, InputOrderMode))` value indicates that the window operator /// can run with existing input ordering, so we can remove `SortExec` before it. /// The `bool` field in the return value represents whether we should reverse window -/// operator to remove `SortExec` before it. The `PartitionSearchMode` field represents +/// operator to remove `SortExec` before it. The `InputOrderMode` field represents /// the mode this window operator should work in to accommodate the existing ordering. pub fn get_window_mode( partitionby_exprs: &[Arc], orderby_keys: &[PhysicalSortExpr], input: &Arc, -) -> Option<(bool, PartitionSearchMode)> { +) -> Option<(bool, InputOrderMode)> { let input_eqs = input.equivalence_properties(); let mut partition_by_reqs: Vec = vec![]; let (_, indices) = input_eqs.find_longest_permutation(partitionby_exprs); @@ -493,11 +480,11 @@ pub fn get_window_mode( if partition_by_eqs.ordering_satisfy_requirement(&req) { // Window can be run with existing ordering let mode = if indices.len() == partitionby_exprs.len() { - PartitionSearchMode::Sorted + InputOrderMode::Sorted } else if indices.is_empty() { - PartitionSearchMode::Linear + InputOrderMode::Linear } else { - PartitionSearchMode::PartiallySorted(indices) + InputOrderMode::PartiallySorted(indices) }; return Some((should_swap, mode)); } @@ -521,7 +508,7 @@ mod tests { use futures::FutureExt; - use PartitionSearchMode::{Linear, PartiallySorted, Sorted}; + use InputOrderMode::{Linear, PartiallySorted, Sorted}; fn create_test_schema() -> Result { let nullable_column = Field::new("nullable_col", DataType::Int32, true); @@ -781,11 +768,11 @@ mod tests { // Second field in the tuple is Vec where each element in the vector represents ORDER BY columns // For instance, vec!["c"], corresponds to ORDER BY c ASC NULLS FIRST, (ordering is default ordering. We do not check // for reversibility in this test). - // Third field in the tuple is Option, which corresponds to expected algorithm mode. + // Third field in the tuple is Option, which corresponds to expected algorithm mode. // None represents that existing ordering is not sufficient to run executor with any one of the algorithms // (We need to add SortExec to be able to run it). - // Some(PartitionSearchMode) represents, we can run algorithm with existing ordering; and algorithm should work in - // PartitionSearchMode. + // Some(InputOrderMode) represents, we can run algorithm with existing ordering; and algorithm should work in + // InputOrderMode. let test_cases = vec![ (vec!["a"], vec!["a"], Some(Sorted)), (vec!["a"], vec!["b"], Some(Sorted)), @@ -870,7 +857,7 @@ mod tests { } let res = get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded); - // Since reversibility is not important in this test. 
Convert Option<(bool, PartitionSearchMode)> to Option + // Since reversibility is not important in this test. Convert Option<(bool, InputOrderMode)> to Option let res = res.map(|(_, mode)| mode); assert_eq!( res, *expected, @@ -901,12 +888,12 @@ mod tests { // Second field in the tuple is Vec<(str, bool, bool)> where each element in the vector represents ORDER BY columns // For instance, vec![("c", false, false)], corresponds to ORDER BY c ASC NULLS LAST, // similarly, vec![("c", true, true)], corresponds to ORDER BY c DESC NULLS FIRST, - // Third field in the tuple is Option<(bool, PartitionSearchMode)>, which corresponds to expected result. + // Third field in the tuple is Option<(bool, InputOrderMode)>, which corresponds to expected result. // None represents that existing ordering is not sufficient to run executor with any one of the algorithms // (We need to add SortExec to be able to run it). - // Some((bool, PartitionSearchMode)) represents, we can run algorithm with existing ordering. Algorithm should work in - // PartitionSearchMode, bool field represents whether we should reverse window expressions to run executor with existing ordering. - // For instance, `Some((false, PartitionSearchMode::Sorted))`, represents that we shouldn't reverse window expressions. And algorithm + // Some((bool, InputOrderMode)) represents, we can run algorithm with existing ordering. Algorithm should work in + // InputOrderMode, bool field represents whether we should reverse window expressions to run executor with existing ordering. + // For instance, `Some((false, InputOrderMode::Sorted))`, represents that we shouldn't reverse window expressions. And algorithm // should work in Sorted mode to work with existing ordering. let test_cases = vec![ // PARTITION BY a, b ORDER BY c ASC NULLS LAST diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 8c2fd5369e33..f391592dfe76 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -216,6 +216,7 @@ message CreateExternalTableNode { bool unbounded = 14; map options = 11; Constraints constraints = 15; + map column_defaults = 16; } message PrepareNode { @@ -643,6 +644,8 @@ enum ScalarFunction { Levenshtein = 125; SubstrIndex = 126; FindInSet = 127; + ArraySort = 128; + ArrayDistinct = 129; } message ScalarFunctionNode { @@ -838,6 +841,8 @@ message Field { // for complex data types like structs, unions repeated Field children = 4; map metadata = 5; + int64 dict_id = 6; + bool dict_ordered = 7; } message FixedSizeBinary{ @@ -1159,6 +1164,8 @@ message PhysicalPlanNode { AnalyzeExecNode analyze = 23; JsonSinkExecNode json_sink = 24; SymmetricHashJoinExecNode symmetric_hash_join = 25; + InterleaveExecNode interleave = 26; + PlaceholderRowExecNode placeholder_row = 27; } } @@ -1368,6 +1375,7 @@ message PhysicalNegativeNode { message FilterExecNode { PhysicalPlanNode input = 1; PhysicalExprNode expr = 2; + uint32 default_filter_selectivity = 3; } message FileGroup { @@ -1451,6 +1459,10 @@ message SymmetricHashJoinExecNode { JoinFilter filter = 8; } +message InterleaveExecNode { + repeated PhysicalPlanNode inputs = 1; +} + message UnionExecNode { repeated PhysicalPlanNode inputs = 1; } @@ -1484,8 +1496,11 @@ message JoinOn { } message EmptyExecNode { - bool produce_one_row = 1; - Schema schema = 2; + Schema schema = 1; +} + +message PlaceholderRowExecNode { + Schema schema = 1; } message ProjectionExecNode { @@ -1502,7 +1517,7 @@ enum AggregateMode { SINGLE_PARTITIONED = 4; } 
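For reference, a hedged sketch of matching on the renamed `InputOrderMode` enum (re-exported from `datafusion-physical-plan` earlier in this diff); the `describe` function is illustrative only, not part of the patch:

```rust
use datafusion_physical_plan::InputOrderMode;

// Sketch: branch on the mode reported by get_window_mode(...) for a
// given PARTITION BY / ORDER BY combination.
fn describe(mode: &InputOrderMode) -> String {
    match mode {
        InputOrderMode::Sorted => "all keys ordered; can stream".to_string(),
        InputOrderMode::PartiallySorted(indices) => {
            format!("keys at indices {indices:?} are ordered")
        }
        InputOrderMode::Linear => "no key ordering; linear search".to_string(),
    }
}
```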
-message PartiallySortedPartitionSearchMode { +message PartiallySortedInputOrderMode { repeated uint64 columns = 6; } @@ -1511,9 +1526,9 @@ message WindowAggExecNode { repeated PhysicalWindowExprNode window_expr = 2; repeated PhysicalExprNode partition_keys = 5; // Set optional to `None` for `BoundedWindowAggExec`. - oneof partition_search_mode { + oneof input_order_mode { EmptyMessage linear = 7; - PartiallySortedPartitionSearchMode partially_sorted = 8; + PartiallySortedInputOrderMode partially_sorted = 8; EmptyMessage sorted = 9; } } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index b8c5f6a4aae8..d506b5dcce53 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -4026,6 +4026,9 @@ impl serde::Serialize for CreateExternalTableNode { if self.constraints.is_some() { len += 1; } + if !self.column_defaults.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.CreateExternalTableNode", len)?; if let Some(v) = self.name.as_ref() { struct_ser.serialize_field("name", v)?; @@ -4069,6 +4072,9 @@ impl serde::Serialize for CreateExternalTableNode { if let Some(v) = self.constraints.as_ref() { struct_ser.serialize_field("constraints", v)?; } + if !self.column_defaults.is_empty() { + struct_ser.serialize_field("columnDefaults", &self.column_defaults)?; + } struct_ser.end() } } @@ -4099,6 +4105,8 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { "unbounded", "options", "constraints", + "column_defaults", + "columnDefaults", ]; #[allow(clippy::enum_variant_names)] @@ -4117,6 +4125,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { Unbounded, Options, Constraints, + ColumnDefaults, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4152,6 +4161,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { "unbounded" => Ok(GeneratedField::Unbounded), "options" => Ok(GeneratedField::Options), "constraints" => Ok(GeneratedField::Constraints), + "columnDefaults" | "column_defaults" => Ok(GeneratedField::ColumnDefaults), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4185,6 +4195,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { let mut unbounded__ = None; let mut options__ = None; let mut constraints__ = None; + let mut column_defaults__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { @@ -4273,6 +4284,14 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { } constraints__ = map_.next_value()?; } + GeneratedField::ColumnDefaults => { + if column_defaults__.is_some() { + return Err(serde::de::Error::duplicate_field("columnDefaults")); + } + column_defaults__ = Some( + map_.next_value::>()? 
+ ); + } } } Ok(CreateExternalTableNode { @@ -4290,6 +4309,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { unbounded: unbounded__.unwrap_or_default(), options: options__.unwrap_or_default(), constraints: constraints__, + column_defaults: column_defaults__.unwrap_or_default(), }) } } @@ -6369,16 +6389,10 @@ impl serde::Serialize for EmptyExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.produce_one_row { - len += 1; - } if self.schema.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.EmptyExecNode", len)?; - if self.produce_one_row { - struct_ser.serialize_field("produceOneRow", &self.produce_one_row)?; - } if let Some(v) = self.schema.as_ref() { struct_ser.serialize_field("schema", v)?; } @@ -6392,14 +6406,11 @@ impl<'de> serde::Deserialize<'de> for EmptyExecNode { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "produce_one_row", - "produceOneRow", "schema", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - ProduceOneRow, Schema, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -6422,7 +6433,6 @@ impl<'de> serde::Deserialize<'de> for EmptyExecNode { E: serde::de::Error, { match value { - "produceOneRow" | "produce_one_row" => Ok(GeneratedField::ProduceOneRow), "schema" => Ok(GeneratedField::Schema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -6443,16 +6453,9 @@ impl<'de> serde::Deserialize<'de> for EmptyExecNode { where V: serde::de::MapAccess<'de>, { - let mut produce_one_row__ = None; let mut schema__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::ProduceOneRow => { - if produce_one_row__.is_some() { - return Err(serde::de::Error::duplicate_field("produceOneRow")); - } - produce_one_row__ = Some(map_.next_value()?); - } GeneratedField::Schema => { if schema__.is_some() { return Err(serde::de::Error::duplicate_field("schema")); @@ -6462,7 +6465,6 @@ impl<'de> serde::Deserialize<'de> for EmptyExecNode { } } Ok(EmptyExecNode { - produce_one_row: produce_one_row__.unwrap_or_default(), schema: schema__, }) } @@ -6890,6 +6892,12 @@ impl serde::Serialize for Field { if !self.metadata.is_empty() { len += 1; } + if self.dict_id != 0 { + len += 1; + } + if self.dict_ordered { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.Field", len)?; if !self.name.is_empty() { struct_ser.serialize_field("name", &self.name)?; @@ -6906,6 +6914,13 @@ impl serde::Serialize for Field { if !self.metadata.is_empty() { struct_ser.serialize_field("metadata", &self.metadata)?; } + if self.dict_id != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("dictId", ToString::to_string(&self.dict_id).as_str())?; + } + if self.dict_ordered { + struct_ser.serialize_field("dictOrdered", &self.dict_ordered)?; + } struct_ser.end() } } @@ -6922,6 +6937,10 @@ impl<'de> serde::Deserialize<'de> for Field { "nullable", "children", "metadata", + "dict_id", + "dictId", + "dict_ordered", + "dictOrdered", ]; #[allow(clippy::enum_variant_names)] @@ -6931,6 +6950,8 @@ impl<'de> serde::Deserialize<'de> for Field { Nullable, Children, Metadata, + DictId, + DictOrdered, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6957,6 +6978,8 @@ impl<'de> serde::Deserialize<'de> for Field { "nullable" => Ok(GeneratedField::Nullable), "children" => Ok(GeneratedField::Children), "metadata" => Ok(GeneratedField::Metadata), + "dictId" | "dict_id" => Ok(GeneratedField::DictId), + 
"dictOrdered" | "dict_ordered" => Ok(GeneratedField::DictOrdered), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6981,6 +7004,8 @@ impl<'de> serde::Deserialize<'de> for Field { let mut nullable__ = None; let mut children__ = None; let mut metadata__ = None; + let mut dict_id__ = None; + let mut dict_ordered__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { @@ -7015,6 +7040,20 @@ impl<'de> serde::Deserialize<'de> for Field { map_.next_value::>()? ); } + GeneratedField::DictId => { + if dict_id__.is_some() { + return Err(serde::de::Error::duplicate_field("dictId")); + } + dict_id__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::DictOrdered => { + if dict_ordered__.is_some() { + return Err(serde::de::Error::duplicate_field("dictOrdered")); + } + dict_ordered__ = Some(map_.next_value()?); + } } } Ok(Field { @@ -7023,6 +7062,8 @@ impl<'de> serde::Deserialize<'de> for Field { nullable: nullable__.unwrap_or_default(), children: children__.unwrap_or_default(), metadata: metadata__.unwrap_or_default(), + dict_id: dict_id__.unwrap_or_default(), + dict_ordered: dict_ordered__.unwrap_or_default(), }) } } @@ -7797,6 +7838,9 @@ impl serde::Serialize for FilterExecNode { if self.expr.is_some() { len += 1; } + if self.default_filter_selectivity != 0 { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.FilterExecNode", len)?; if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; @@ -7804,6 +7848,9 @@ impl serde::Serialize for FilterExecNode { if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; } + if self.default_filter_selectivity != 0 { + struct_ser.serialize_field("defaultFilterSelectivity", &self.default_filter_selectivity)?; + } struct_ser.end() } } @@ -7816,12 +7863,15 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { const FIELDS: &[&str] = &[ "input", "expr", + "default_filter_selectivity", + "defaultFilterSelectivity", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Input, Expr, + DefaultFilterSelectivity, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7845,6 +7895,7 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { match value { "input" => Ok(GeneratedField::Input), "expr" => Ok(GeneratedField::Expr), + "defaultFilterSelectivity" | "default_filter_selectivity" => Ok(GeneratedField::DefaultFilterSelectivity), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7866,6 +7917,7 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { { let mut input__ = None; let mut expr__ = None; + let mut default_filter_selectivity__ = None; while let Some(k) = map_.next_key()? 
{ match k { GeneratedField::Input => { @@ -7880,11 +7932,20 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { } expr__ = map_.next_value()?; } + GeneratedField::DefaultFilterSelectivity => { + if default_filter_selectivity__.is_some() { + return Err(serde::de::Error::duplicate_field("defaultFilterSelectivity")); + } + default_filter_selectivity__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } } } Ok(FilterExecNode { input: input__, expr: expr__, + default_filter_selectivity: default_filter_selectivity__.unwrap_or_default(), }) } } @@ -9165,6 +9226,97 @@ impl<'de> serde::Deserialize<'de> for InListNode { deserializer.deserialize_struct("datafusion.InListNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for InterleaveExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.inputs.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.InterleaveExecNode", len)?; + if !self.inputs.is_empty() { + struct_ser.serialize_field("inputs", &self.inputs)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for InterleaveExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "inputs", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Inputs, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "inputs" => Ok(GeneratedField::Inputs), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = InterleaveExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.InterleaveExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut inputs__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Inputs => { + if inputs__.is_some() { + return Err(serde::de::Error::duplicate_field("inputs")); + } + inputs__ = Some(map_.next_value()?); + } + } + } + Ok(InterleaveExecNode { + inputs: inputs__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.InterleaveExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for IntervalMonthDayNanoValue { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -14967,7 +15119,7 @@ impl<'de> serde::Deserialize<'de> for PartialTableReference { deserializer.deserialize_struct("datafusion.PartialTableReference", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PartiallySortedPartitionSearchMode { +impl serde::Serialize for PartiallySortedInputOrderMode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -14978,14 +15130,14 @@ impl serde::Serialize for PartiallySortedPartitionSearchMode { if !self.columns.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PartiallySortedPartitionSearchMode", len)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PartiallySortedInputOrderMode", len)?; if !self.columns.is_empty() { struct_ser.serialize_field("columns", &self.columns.iter().map(ToString::to_string).collect::>())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PartiallySortedPartitionSearchMode { +impl<'de> serde::Deserialize<'de> for PartiallySortedInputOrderMode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -15029,13 +15181,13 @@ impl<'de> serde::Deserialize<'de> for PartiallySortedPartitionSearchMode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PartiallySortedPartitionSearchMode; + type Value = PartiallySortedInputOrderMode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PartiallySortedPartitionSearchMode") + formatter.write_str("struct datafusion.PartiallySortedInputOrderMode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -15053,12 +15205,12 @@ impl<'de> serde::Deserialize<'de> for PartiallySortedPartitionSearchMode { } } } - Ok(PartiallySortedPartitionSearchMode { + Ok(PartiallySortedInputOrderMode { columns: columns__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PartiallySortedPartitionSearchMode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PartiallySortedInputOrderMode", FIELDS, GeneratedVisitor) } } impl serde::Serialize for PartitionColumn { @@ -17847,6 +17999,12 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::SymmetricHashJoin(v) => { struct_ser.serialize_field("symmetricHashJoin", v)?; } + physical_plan_node::PhysicalPlanType::Interleave(v) => { + struct_ser.serialize_field("interleave", v)?; + } + physical_plan_node::PhysicalPlanType::PlaceholderRow(v) => { + struct_ser.serialize_field("placeholderRow", v)?; + } } } struct_ser.end() @@ -17895,6 +18053,9 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "jsonSink", "symmetric_hash_join", "symmetricHashJoin", + "interleave", + "placeholder_row", + "placeholderRow", ]; #[allow(clippy::enum_variant_names)] @@ -17923,6 +18084,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { Analyze, JsonSink, 
SymmetricHashJoin, + Interleave, + PlaceholderRow, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17968,6 +18131,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "analyze" => Ok(GeneratedField::Analyze), "jsonSink" | "json_sink" => Ok(GeneratedField::JsonSink), "symmetricHashJoin" | "symmetric_hash_join" => Ok(GeneratedField::SymmetricHashJoin), + "interleave" => Ok(GeneratedField::Interleave), + "placeholderRow" | "placeholder_row" => Ok(GeneratedField::PlaceholderRow), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18156,6 +18321,20 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("symmetricHashJoin")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SymmetricHashJoin) +; + } + GeneratedField::Interleave => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("interleave")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Interleave) +; + } + GeneratedField::PlaceholderRow => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("placeholderRow")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::PlaceholderRow) ; } } @@ -19186,6 +19365,97 @@ impl<'de> serde::Deserialize<'de> for PlaceholderNode { deserializer.deserialize_struct("datafusion.PlaceholderNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PlaceholderRowExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.schema.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PlaceholderRowExecNode", len)?; + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PlaceholderRowExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "schema", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Schema, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "schema" => Ok(GeneratedField::Schema), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PlaceholderRowExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PlaceholderRowExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { 
+ let mut schema__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; + } + } + } + Ok(PlaceholderRowExecNode { + schema: schema__, + }) + } + } + deserializer.deserialize_struct("datafusion.PlaceholderRowExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PlanType { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -20865,6 +21135,8 @@ impl serde::Serialize for ScalarFunction { Self::Levenshtein => "Levenshtein", Self::SubstrIndex => "SubstrIndex", Self::FindInSet => "FindInSet", + Self::ArraySort => "ArraySort", + Self::ArrayDistinct => "ArrayDistinct", }; serializer.serialize_str(variant) } @@ -21004,6 +21276,8 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Levenshtein", "SubstrIndex", "FindInSet", + "ArraySort", + "ArrayDistinct", ]; struct GeneratedVisitor; @@ -21172,6 +21446,8 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Levenshtein" => Ok(ScalarFunction::Levenshtein), "SubstrIndex" => Ok(ScalarFunction::SubstrIndex), "FindInSet" => Ok(ScalarFunction::FindInSet), + "ArraySort" => Ok(ScalarFunction::ArraySort), + "ArrayDistinct" => Ok(ScalarFunction::ArrayDistinct), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } @@ -25639,7 +25915,7 @@ impl serde::Serialize for WindowAggExecNode { if !self.partition_keys.is_empty() { len += 1; } - if self.partition_search_mode.is_some() { + if self.input_order_mode.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.WindowAggExecNode", len)?; @@ -25652,15 +25928,15 @@ impl serde::Serialize for WindowAggExecNode { if !self.partition_keys.is_empty() { struct_ser.serialize_field("partitionKeys", &self.partition_keys)?; } - if let Some(v) = self.partition_search_mode.as_ref() { + if let Some(v) = self.input_order_mode.as_ref() { match v { - window_agg_exec_node::PartitionSearchMode::Linear(v) => { + window_agg_exec_node::InputOrderMode::Linear(v) => { struct_ser.serialize_field("linear", v)?; } - window_agg_exec_node::PartitionSearchMode::PartiallySorted(v) => { + window_agg_exec_node::InputOrderMode::PartiallySorted(v) => { struct_ser.serialize_field("partiallySorted", v)?; } - window_agg_exec_node::PartitionSearchMode::Sorted(v) => { + window_agg_exec_node::InputOrderMode::Sorted(v) => { struct_ser.serialize_field("sorted", v)?; } } @@ -25743,7 +26019,7 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { let mut input__ = None; let mut window_expr__ = None; let mut partition_keys__ = None; - let mut partition_search_mode__ = None; + let mut input_order_mode__ = None; while let Some(k) = map_.next_key()? 
{ match k { GeneratedField::Input => { @@ -25765,24 +26041,24 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { partition_keys__ = Some(map_.next_value()?); } GeneratedField::Linear => { - if partition_search_mode__.is_some() { + if input_order_mode__.is_some() { return Err(serde::de::Error::duplicate_field("linear")); } - partition_search_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::PartitionSearchMode::Linear) + input_order_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::InputOrderMode::Linear) ; } GeneratedField::PartiallySorted => { - if partition_search_mode__.is_some() { + if input_order_mode__.is_some() { return Err(serde::de::Error::duplicate_field("partiallySorted")); } - partition_search_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::PartitionSearchMode::PartiallySorted) + input_order_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::InputOrderMode::PartiallySorted) ; } GeneratedField::Sorted => { - if partition_search_mode__.is_some() { + if input_order_mode__.is_some() { return Err(serde::de::Error::duplicate_field("sorted")); } - partition_search_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::PartitionSearchMode::Sorted) + input_order_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::InputOrderMode::Sorted) ; } } @@ -25791,7 +26067,7 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { input: input__, window_expr: window_expr__.unwrap_or_default(), partition_keys: partition_keys__.unwrap_or_default(), - partition_search_mode: partition_search_mode__, + input_order_mode: input_order_mode__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index c31bc4ab5948..8aadc96349ca 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -360,6 +360,11 @@ pub struct CreateExternalTableNode { >, #[prost(message, optional, tag = "15")] pub constraints: ::core::option::Option, + #[prost(map = "string, message", tag = "16")] + pub column_defaults: ::std::collections::HashMap< + ::prost::alloc::string::String, + LogicalExprNode, + >, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -1022,6 +1027,10 @@ pub struct Field { ::prost::alloc::string::String, ::prost::alloc::string::String, >, + #[prost(int64, tag = "6")] + pub dict_id: i64, + #[prost(bool, tag = "7")] + pub dict_ordered: bool, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -1516,7 +1525,7 @@ pub mod owned_table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27" )] pub physical_plan_type: ::core::option::Option, } @@ -1575,6 +1584,10 @@ pub mod physical_plan_node { JsonSink(::prost::alloc::boxed::Box), #[prost(message, tag = "25")] SymmetricHashJoin(::prost::alloc::boxed::Box), + #[prost(message, tag = "26")] + Interleave(super::InterleaveExecNode), + #[prost(message, tag = "27")] + PlaceholderRow(super::PlaceholderRowExecNode), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1916,6 +1929,8 @@ pub struct FilterExecNode { pub input: 
::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "2")] pub expr: ::core::option::Option, + #[prost(uint32, tag = "3")] + pub default_filter_selectivity: u32, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -2031,6 +2046,12 @@ pub struct SymmetricHashJoinExecNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct InterleaveExecNode { + #[prost(message, repeated, tag = "1")] + pub inputs: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionExecNode { #[prost(message, repeated, tag = "1")] pub inputs: ::prost::alloc::vec::Vec, @@ -2084,9 +2105,13 @@ pub struct JoinOn { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct EmptyExecNode { - #[prost(bool, tag = "1")] - pub produce_one_row: bool, - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "1")] + pub schema: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PlaceholderRowExecNode { + #[prost(message, optional, tag = "1")] pub schema: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] @@ -2101,7 +2126,7 @@ pub struct ProjectionExecNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] -pub struct PartiallySortedPartitionSearchMode { +pub struct PartiallySortedInputOrderMode { #[prost(uint64, repeated, tag = "6")] pub columns: ::prost::alloc::vec::Vec, } @@ -2115,21 +2140,19 @@ pub struct WindowAggExecNode { #[prost(message, repeated, tag = "5")] pub partition_keys: ::prost::alloc::vec::Vec, /// Set optional to `None` for `BoundedWindowAggExec`. - #[prost(oneof = "window_agg_exec_node::PartitionSearchMode", tags = "7, 8, 9")] - pub partition_search_mode: ::core::option::Option< - window_agg_exec_node::PartitionSearchMode, - >, + #[prost(oneof = "window_agg_exec_node::InputOrderMode", tags = "7, 8, 9")] + pub input_order_mode: ::core::option::Option, } /// Nested message and enum types in `WindowAggExecNode`. pub mod window_agg_exec_node { /// Set optional to `None` for `BoundedWindowAggExec`. #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum PartitionSearchMode { + pub enum InputOrderMode { #[prost(message, tag = "7")] Linear(super::EmptyMessage), #[prost(message, tag = "8")] - PartiallySorted(super::PartiallySortedPartitionSearchMode), + PartiallySorted(super::PartiallySortedInputOrderMode), #[prost(message, tag = "9")] Sorted(super::EmptyMessage), } @@ -2596,6 +2619,8 @@ pub enum ScalarFunction { Levenshtein = 125, SubstrIndex = 126, FindInSet = 127, + ArraySort = 128, + ArrayDistinct = 129, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2732,6 +2757,8 @@ impl ScalarFunction { ScalarFunction::Levenshtein => "Levenshtein", ScalarFunction::SubstrIndex => "SubstrIndex", ScalarFunction::FindInSet => "FindInSet", + ScalarFunction::ArraySort => "ArraySort", + ScalarFunction::ArrayDistinct => "ArrayDistinct", } } /// Creates an enum from field names used in the ProtoBuf definition. 
@@ -2865,6 +2892,8 @@ impl ScalarFunction { "Levenshtein" => Some(Self::Levenshtein), "SubstrIndex" => Some(Self::SubstrIndex), "FindInSet" => Some(Self::FindInSet), + "ArraySort" => Some(Self::ArraySort), + "ArrayDistinct" => Some(Self::ArrayDistinct), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index ae3628bddeb2..193e0947d6d9 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -39,15 +39,17 @@ use datafusion_common::{ internal_err, plan_datafusion_err, Column, Constraint, Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result, ScalarValue, }; +use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - abs, acos, acosh, array, array_append, array_concat, array_dims, array_element, - array_except, array_has, array_has_all, array_has_any, array_intersect, array_length, - array_ndims, array_position, array_positions, array_prepend, array_remove, - array_remove_all, array_remove_n, array_repeat, array_replace, array_replace_all, - array_replace_n, array_slice, array_to_string, arrow_typeof, ascii, asin, asinh, - atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length, - chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date, - current_time, date_bin, date_part, date_trunc, decode, degrees, digest, encode, exp, + abs, acos, acosh, array, array_append, array_concat, array_dims, array_distinct, + array_element, array_except, array_has, array_has_all, array_has_any, + array_intersect, array_length, array_ndims, array_position, array_positions, + array_prepend, array_remove, array_remove_all, array_remove_n, array_repeat, + array_replace, array_replace_all, array_replace_n, array_slice, array_sort, + array_to_string, arrow_typeof, ascii, asin, asinh, atan, atan2, atanh, bit_length, + btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, concat_expr, + concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part, + date_trunc, decode, degrees, digest, encode, exp, expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, lcm, left, levenshtein, ln, log, log10, log2, @@ -58,7 +60,6 @@ use datafusion_expr::{ sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substr_index, substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos, to_timestamp_seconds, translate, trim, trunc, upper, uuid, - window_frame::regularize, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, @@ -376,8 +377,20 @@ impl TryFrom<&protobuf::Field> for Field { type Error = Error; fn try_from(field: &protobuf::Field) -> Result { let datatype = field.arrow_type.as_deref().required("arrow_type")?; - Ok(Self::new(field.name.as_str(), datatype, field.nullable) - .with_metadata(field.metadata.clone())) + let field = if field.dict_id != 0 { + Self::new_dict( + field.name.as_str(), + datatype, + field.nullable, + field.dict_id, + field.dict_ordered, + ) + .with_metadata(field.metadata.clone()) + } else { + Self::new(field.name.as_str(), datatype, field.nullable) + .with_metadata(field.metadata.clone()) + }; + Ok(field) } } @@ -463,6 +476,7 @@ impl From<&protobuf::ScalarFunction> for 
BuiltinScalarFunction { ScalarFunction::Rtrim => Self::Rtrim, ScalarFunction::ToTimestamp => Self::ToTimestamp, ScalarFunction::ArrayAppend => Self::ArrayAppend, + ScalarFunction::ArraySort => Self::ArraySort, ScalarFunction::ArrayConcat => Self::ArrayConcat, ScalarFunction::ArrayEmpty => Self::ArrayEmpty, ScalarFunction::ArrayExcept => Self::ArrayExcept, @@ -470,6 +484,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrayHasAny => Self::ArrayHasAny, ScalarFunction::ArrayHas => Self::ArrayHas, ScalarFunction::ArrayDims => Self::ArrayDims, + ScalarFunction::ArrayDistinct => Self::ArrayDistinct, ScalarFunction::ArrayElement => Self::ArrayElement, ScalarFunction::Flatten => Self::Flatten, ScalarFunction::ArrayLength => Self::ArrayLength, @@ -1070,7 +1085,7 @@ pub fn parse_expr( .iter() .map(|e| parse_expr(e, registry)) .collect::, _>>()?; - let order_by = expr + let mut order_by = expr .order_by .iter() .map(|e| parse_expr(e, registry)) @@ -1080,7 +1095,8 @@ pub fn parse_expr( .as_ref() .map::, _>(|window_frame| { let window_frame = window_frame.clone().try_into()?; - regularize(window_frame, order_by.len()) + check_window_frame(&window_frame, order_by.len()) + .map(|_| window_frame) }) .transpose()? .ok_or_else(|| { @@ -1088,6 +1104,7 @@ pub fn parse_expr( "missing window frame during deserialization".to_string(), ) })?; + regularize_window_order_by(&window_frame, &mut order_by)?; match window_function { window_expr_node::WindowFunction::AggrFunction(i) => { @@ -1343,6 +1360,11 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::ArraySort => Ok(array_sort( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + )), ScalarFunction::ArrayPopFront => { Ok(array_pop_front(parse_expr(&args[0], registry)?)) } @@ -1446,6 +1468,9 @@ pub fn parse_expr( ScalarFunction::ArrayDims => { Ok(array_dims(parse_expr(&args[0], registry)?)) } + ScalarFunction::ArrayDistinct => { + Ok(array_distinct(parse_expr(&args[0], registry)?)) + } ScalarFunction::ArrayElement => Ok(array_element( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 851f062bd51f..50bca0295def 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+use std::collections::HashMap; use std::fmt::Debug; use std::str::FromStr; use std::sync::Arc; @@ -521,6 +522,13 @@ impl AsLogicalPlan for LogicalPlanNode { order_exprs.push(order_expr) } + let mut column_defaults = + HashMap::with_capacity(create_extern_table.column_defaults.len()); + for (col_name, expr) in &create_extern_table.column_defaults { + let expr = from_proto::parse_expr(expr, ctx)?; + column_defaults.insert(col_name.clone(), expr); + } + Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable(CreateExternalTable { schema: pb_schema.try_into()?, name: from_owned_table_reference(create_extern_table.name.as_ref(), "CreateExternalTable")?, @@ -540,6 +548,7 @@ impl AsLogicalPlan for LogicalPlanNode { unbounded: create_extern_table.unbounded, options: create_extern_table.options.clone(), constraints: constraints.into(), + column_defaults, }))) } LogicalPlanType::CreateView(create_view) => { @@ -1298,6 +1307,7 @@ impl AsLogicalPlan for LogicalPlanNode { unbounded, options, constraints, + column_defaults, }, )) => { let mut converted_order_exprs: Vec = vec![]; @@ -1312,6 +1322,12 @@ impl AsLogicalPlan for LogicalPlanNode { converted_order_exprs.push(temp); } + let mut converted_column_defaults = + HashMap::with_capacity(column_defaults.len()); + for (col_name, expr) in column_defaults { + converted_column_defaults.insert(col_name.clone(), expr.try_into()?); + } + Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateExternalTable( protobuf::CreateExternalTableNode { @@ -1329,6 +1345,7 @@ impl AsLogicalPlan for LogicalPlanNode { unbounded: *unbounded, options: options.clone(), constraints: Some(constraints.clone().into()), + column_defaults: converted_column_defaults, }, )), }) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index ab8e850014e5..2997d147424d 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -108,6 +108,8 @@ impl TryFrom<&Field> for protobuf::Field { nullable: field.is_nullable(), children: Vec::new(), metadata: field.metadata().clone(), + dict_id: field.dict_id().unwrap_or(0), + dict_ordered: field.dict_is_ordered().unwrap_or(false), }) } } @@ -792,40 +794,39 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { .to_string(), )) } - Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { - ScalarFunctionDefinition::BuiltIn(fun) => { - let fun: protobuf::ScalarFunction = fun.try_into()?; - let args: Vec = args - .iter() - .map(|e| e.try_into()) - .collect::, Error>>()?; - Self { - expr_type: Some(ExprType::ScalarFunction( - protobuf::ScalarFunctionNode { - fun: fun.into(), + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let args = args + .iter() + .map(|expr| expr.try_into()) + .collect::, Error>>()?; + match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + let fun: protobuf::ScalarFunction = fun.try_into()?; + Self { + expr_type: Some(ExprType::ScalarFunction( + protobuf::ScalarFunctionNode { + fun: fun.into(), + args, + }, + )), + } + } + ScalarFunctionDefinition::UDF(fun) => Self { + expr_type: Some(ExprType::ScalarUdfExpr( + protobuf::ScalarUdfExprNode { + fun_name: fun.name().to_string(), args, }, )), - } - } - ScalarFunctionDefinition::UDF(fun) => Self { - expr_type: Some(ExprType::ScalarUdfExpr( - protobuf::ScalarUdfExprNode { - fun_name: fun.name().to_string(), - args: args - .iter() - .map(|expr| expr.try_into()) - .collect::, Error>>()?, - }, - )), - }, - 
ScalarFunctionDefinition::Name(_) => { - return Err(Error::NotImplemented( + }, + ScalarFunctionDefinition::Name(_) => { + return Err(Error::NotImplemented( "Proto serialization error: Trying to serialize a unresolved function" .to_string(), )); + } } - }, + } Expr::Not(expr) => { let expr = Box::new(protobuf::Not { expr: Some(Box::new(expr.as_ref().try_into()?)), @@ -1503,6 +1504,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Rtrim => Self::Rtrim, BuiltinScalarFunction::ToTimestamp => Self::ToTimestamp, BuiltinScalarFunction::ArrayAppend => Self::ArrayAppend, + BuiltinScalarFunction::ArraySort => Self::ArraySort, BuiltinScalarFunction::ArrayConcat => Self::ArrayConcat, BuiltinScalarFunction::ArrayEmpty => Self::ArrayEmpty, BuiltinScalarFunction::ArrayExcept => Self::ArrayExcept, @@ -1510,6 +1512,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrayHasAny => Self::ArrayHasAny, BuiltinScalarFunction::ArrayHas => Self::ArrayHas, BuiltinScalarFunction::ArrayDims => Self::ArrayDims, + BuiltinScalarFunction::ArrayDistinct => Self::ArrayDistinct, BuiltinScalarFunction::ArrayElement => Self::ArrayElement, BuiltinScalarFunction::Flatten => Self::Flatten, BuiltinScalarFunction::ArrayLength => Self::ArrayLength, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 6714c35dc615..73091a6fced9 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -44,16 +44,16 @@ use datafusion::physical_plan::joins::{ }; use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use datafusion::physical_plan::union::UnionExec; -use datafusion::physical_plan::windows::{ - BoundedWindowAggExec, PartitionSearchMode, WindowAggExec, -}; +use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; +use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion::physical_plan::{ - udaf, AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr, + udaf, AggregateExpr, ExecutionPlan, InputOrderMode, Partitioning, PhysicalExpr, + WindowExpr, }; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; use prost::bytes::BufMut; @@ -159,7 +159,16 @@ impl AsExecutionPlan for PhysicalPlanNode { .to_owned(), ) })?; - Ok(Arc::new(FilterExec::try_new(predicate, input)?)) + let filter_selectivity = filter.default_filter_selectivity.try_into(); + let filter = FilterExec::try_new(predicate, input)?; + match filter_selectivity { + Ok(filter_selectivity) => Ok(Arc::new( + filter.with_default_selectivity(filter_selectivity)?, + )), + Err(_) => Err(DataFusionError::Internal( + "filter_selectivity in PhysicalPlanNode is invalid ".to_owned(), + )), + } } PhysicalPlanType::CsvScan(scan) => Ok(Arc::new(CsvExec::new( parse_protobuf_file_scan_config( @@ -313,20 +322,18 @@ impl AsExecutionPlan for PhysicalPlanNode { }) .collect::>>>()?; - if let Some(partition_search_mode) = - window_agg.partition_search_mode.as_ref() - { - let partition_search_mode = match partition_search_mode 
{ - window_agg_exec_node::PartitionSearchMode::Linear(_) => { - PartitionSearchMode::Linear + if let Some(input_order_mode) = window_agg.input_order_mode.as_ref() { + let input_order_mode = match input_order_mode { + window_agg_exec_node::InputOrderMode::Linear(_) => { + InputOrderMode::Linear } - window_agg_exec_node::PartitionSearchMode::PartiallySorted( - protobuf::PartiallySortedPartitionSearchMode { columns }, - ) => PartitionSearchMode::PartiallySorted( + window_agg_exec_node::InputOrderMode::PartiallySorted( + protobuf::PartiallySortedInputOrderMode { columns }, + ) => InputOrderMode::PartiallySorted( columns.iter().map(|c| *c as usize).collect(), ), - window_agg_exec_node::PartitionSearchMode::Sorted(_) => { - PartitionSearchMode::Sorted + window_agg_exec_node::InputOrderMode::Sorted(_) => { + InputOrderMode::Sorted } }; @@ -334,7 +341,7 @@ impl AsExecutionPlan for PhysicalPlanNode { physical_window_expr, input, partition_keys, - partition_search_mode, + input_order_mode, )?)) } else { Ok(Arc::new(WindowAggExec::try_new( @@ -539,7 +546,7 @@ impl AsExecutionPlan for PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - registry, &schema + registry, &schema, )?; let column_indices = f.column_indices .iter() @@ -550,7 +557,7 @@ impl AsExecutionPlan for PhysicalPlanNode { i.side)) )?; - Ok(ColumnIndex{ + Ok(ColumnIndex { index: i.index as usize, side: side.into(), }) @@ -628,7 +635,7 @@ impl AsExecutionPlan for PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - registry, &schema + registry, &schema, )?; let column_indices = f.column_indices .iter() @@ -639,7 +646,7 @@ impl AsExecutionPlan for PhysicalPlanNode { i.side)) )?; - Ok(ColumnIndex{ + Ok(ColumnIndex { index: i.index as usize, side: side.into(), }) @@ -687,6 +694,17 @@ impl AsExecutionPlan for PhysicalPlanNode { } Ok(Arc::new(UnionExec::new(inputs))) } + PhysicalPlanType::Interleave(interleave) => { + let mut inputs: Vec> = vec![]; + for input in &interleave.inputs { + inputs.push(input.try_into_physical_plan( + registry, + runtime, + extension_codec, + )?); + } + Ok(Arc::new(InterleaveExec::try_new(inputs)?)) + } PhysicalPlanType::CrossJoin(crossjoin) => { let left: Arc = into_physical_plan( &crossjoin.left, @@ -704,7 +722,11 @@ impl AsExecutionPlan for PhysicalPlanNode { } PhysicalPlanType::Empty(empty) => { let schema = Arc::new(convert_required!(empty.schema)?); - Ok(Arc::new(EmptyExec::new(empty.produce_one_row, schema))) + Ok(Arc::new(EmptyExec::new(schema))) + } + PhysicalPlanType::PlaceholderRow(placeholder) => { + let schema = Arc::new(convert_required!(placeholder.schema)?); + Ok(Arc::new(PlaceholderRowExec::new(schema))) } PhysicalPlanType::Sort(sort) => { let input: Arc = @@ -729,7 +751,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr,registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -776,7 +798,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? 
.as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr,registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -839,7 +861,7 @@ impl AsExecutionPlan for PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - registry, &schema + registry, &schema, )?; let column_indices = f.column_indices .iter() @@ -850,7 +872,7 @@ impl AsExecutionPlan for PhysicalPlanNode { i.side)) )?; - Ok(ColumnIndex{ + Ok(ColumnIndex { index: i.index as usize, side: side.into(), }) @@ -991,6 +1013,7 @@ impl AsExecutionPlan for PhysicalPlanNode { protobuf::FilterExecNode { input: Some(Box::new(input)), expr: Some(exec.predicate().clone().try_into()?), + default_filter_selectivity: exec.default_selectivity() as u32, }, ))), }); @@ -1289,7 +1312,17 @@ impl AsExecutionPlan for PhysicalPlanNode { return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Empty( protobuf::EmptyExecNode { - produce_one_row: empty.produce_one_row(), + schema: Some(schema), + }, + )), + }); + } + + if let Some(empty) = plan.downcast_ref::() { + let schema = empty.schema().as_ref().try_into()?; + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::PlaceholderRow( + protobuf::PlaceholderRowExecNode { schema: Some(schema), }, )), @@ -1456,6 +1489,21 @@ impl AsExecutionPlan for PhysicalPlanNode { }); } + if let Some(interleave) = plan.downcast_ref::() { + let mut inputs: Vec = vec![]; + for input in interleave.inputs() { + inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( + input.to_owned(), + extension_codec, + )?); + } + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Interleave( + protobuf::InterleaveExecNode { inputs }, + )), + }); + } + if let Some(exec) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), @@ -1560,7 +1608,7 @@ impl AsExecutionPlan for PhysicalPlanNode { input: Some(Box::new(input)), window_expr, partition_keys, - partition_search_mode: None, + input_order_mode: None, }, ))), }); @@ -1584,24 +1632,20 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|e| e.clone().try_into()) .collect::>>()?; - let partition_search_mode = match &exec.partition_search_mode { - PartitionSearchMode::Linear => { - window_agg_exec_node::PartitionSearchMode::Linear( - protobuf::EmptyMessage {}, - ) - } - PartitionSearchMode::PartiallySorted(columns) => { - window_agg_exec_node::PartitionSearchMode::PartiallySorted( - protobuf::PartiallySortedPartitionSearchMode { + let input_order_mode = match &exec.input_order_mode { + InputOrderMode::Linear => window_agg_exec_node::InputOrderMode::Linear( + protobuf::EmptyMessage {}, + ), + InputOrderMode::PartiallySorted(columns) => { + window_agg_exec_node::InputOrderMode::PartiallySorted( + protobuf::PartiallySortedInputOrderMode { columns: columns.iter().map(|c| *c as u64).collect(), }, ) } - PartitionSearchMode::Sorted => { - window_agg_exec_node::PartitionSearchMode::Sorted( - protobuf::EmptyMessage {}, - ) - } + InputOrderMode::Sorted => window_agg_exec_node::InputOrderMode::Sorted( + protobuf::EmptyMessage {}, + ), }; return Ok(protobuf::PhysicalPlanNode { @@ -1610,7 +1654,7 @@ impl AsExecutionPlan for PhysicalPlanNode { input: Some(Box::new(input)), window_expr, partition_keys, - partition_search_mode: Some(partition_search_mode), 
+ input_order_mode: Some(input_order_mode), }, ))), }); diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 45727c39a373..8e15b5d0d480 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -217,11 +217,10 @@ async fn roundtrip_custom_memory_tables() -> Result<()> { async fn roundtrip_custom_listing_tables() -> Result<()> { let ctx = SessionContext::new(); - // Make sure during round-trip, constraint information is preserved let query = "CREATE EXTERNAL TABLE multiple_ordered_table_with_pk ( a0 INTEGER, - a INTEGER, - b INTEGER, + a INTEGER DEFAULT 1*2 + 3, + b INTEGER DEFAULT NULL, c INTEGER, d INTEGER, primary key(c) @@ -232,11 +231,13 @@ async fn roundtrip_custom_listing_tables() -> Result<()> { WITH ORDER (c ASC) LOCATION '../core/tests/data/window_2.csv';"; - let plan = ctx.sql(query).await?.into_optimized_plan()?; + let plan = ctx.state().create_logical_plan(query).await?; let bytes = logical_plan_to_bytes(&plan)?; let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + // Use exact matching to verify everything. Make sure during round-trip, + // information like constraints, column defaults, and other aspects of the plan are preserved. + assert_eq!(plan, logical_round_trip); Ok(()) } @@ -730,7 +731,7 @@ fn round_trip_scalar_values() { ))), ScalarValue::Dictionary( Box::new(DataType::Int32), - Box::new(ScalarValue::Utf8(Some("foo".into()))), + Box::new(ScalarValue::from("foo")), ), ScalarValue::Dictionary( Box::new(DataType::Int32), @@ -971,6 +972,45 @@ fn round_trip_datatype() { } } +#[test] +fn roundtrip_dict_id() -> Result<()> { + let dict_id = 42; + let field = Field::new( + "keys", + DataType::List(Arc::new(Field::new_dict( + "item", + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), + true, + dict_id, + false, + ))), + false, + ); + let schema = Arc::new(Schema::new(vec![field])); + + // encode + let mut buf: Vec = vec![]; + let schema_proto: datafusion_proto::generated::datafusion::Schema = + schema.try_into().unwrap(); + schema_proto.encode(&mut buf).unwrap(); + + // decode + let schema_proto = + datafusion_proto::generated::datafusion::Schema::decode(buf.as_slice()).unwrap(); + let decoded: Schema = (&schema_proto).try_into()?; + + // assert + let keys = decoded.fields().iter().last().unwrap(); + match keys.data_type() { + DataType::List(field) => { + assert_eq!(field.dict_id(), Some(dict_id), "dict_id should be retained"); + } + _ => panic!("Invalid type"), + } + + Ok(()) +} + #[test] fn roundtrip_null_scalar_values() { let test_types = vec![ diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index d7d762d470d7..da76209dbb49 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -49,13 +49,16 @@ use datafusion::physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, }; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; use datafusion::physical_plan::projection::ProjectionExec; +use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; +use 
datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::windows::{ BuiltInWindowExpr, PlainAggregateWindowExpr, WindowAggExec, }; use datafusion::physical_plan::{ - functions, udaf, AggregateExpr, ExecutionPlan, PhysicalExpr, Statistics, + functions, udaf, AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, Statistics, }; use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; @@ -102,7 +105,7 @@ fn roundtrip_test_with_context( #[test] fn roundtrip_empty() -> Result<()> { - roundtrip_test(Arc::new(EmptyExec::new(false, Arc::new(Schema::empty())))) + roundtrip_test(Arc::new(EmptyExec::new(Arc::new(Schema::empty())))) } #[test] @@ -115,7 +118,7 @@ fn roundtrip_date_time_interval() -> Result<()> { false, ), ]); - let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); + let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); let date_expr = col("some_date", &schema)?; let literal_expr = col("some_interval", &schema)?; let date_time_interval_expr = @@ -130,7 +133,7 @@ fn roundtrip_date_time_interval() -> Result<()> { #[test] fn roundtrip_local_limit() -> Result<()> { roundtrip_test(Arc::new(LocalLimitExec::new( - Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), + Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), 25, ))) } @@ -138,7 +141,7 @@ fn roundtrip_local_limit() -> Result<()> { #[test] fn roundtrip_global_limit() -> Result<()> { roundtrip_test(Arc::new(GlobalLimitExec::new( - Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), + Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), 0, Some(25), ))) @@ -147,7 +150,7 @@ fn roundtrip_global_limit() -> Result<()> { #[test] fn roundtrip_global_skip_no_limit() -> Result<()> { roundtrip_test(Arc::new(GlobalLimitExec::new( - Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), + Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), 10, None, // no limit ))) @@ -177,8 +180,8 @@ fn roundtrip_hash_join() -> Result<()> { ] { for partition_mode in &[PartitionMode::Partitioned, PartitionMode::CollectLeft] { roundtrip_test(Arc::new(HashJoinExec::try_new( - Arc::new(EmptyExec::new(false, schema_left.clone())), - Arc::new(EmptyExec::new(false, schema_right.clone())), + Arc::new(EmptyExec::new(schema_left.clone())), + Arc::new(EmptyExec::new(schema_right.clone())), on.clone(), None, join_type, @@ -209,8 +212,8 @@ fn roundtrip_nested_loop_join() -> Result<()> { JoinType::RightSemi, ] { roundtrip_test(Arc::new(NestedLoopJoinExec::try_new( - Arc::new(EmptyExec::new(false, schema_left.clone())), - Arc::new(EmptyExec::new(false, schema_right.clone())), + Arc::new(EmptyExec::new(schema_left.clone())), + Arc::new(EmptyExec::new(schema_right.clone())), None, join_type, )?))?; @@ -231,21 +234,21 @@ fn roundtrip_window() -> Result<()> { }; let builtin_window_expr = Arc::new(BuiltInWindowExpr::new( - Arc::new(NthValue::first( - "FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", - col("a", &schema)?, - DataType::Int64, - )), - &[col("b", &schema)?], - &[PhysicalSortExpr { - expr: col("a", &schema)?, - options: SortOptions { - descending: false, - nulls_first: false, - }, - }], - Arc::new(window_frame), - )); + Arc::new(NthValue::first( + "FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", + col("a", &schema)?, + DataType::Int64, + )), + &[col("b", &schema)?], + &[PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions 
{ + descending: false, + nulls_first: false, + }, + }], + Arc::new(window_frame), + )); let plain_aggr_window_expr = Arc::new(PlainAggregateWindowExpr::new( Arc::new(Avg::new( @@ -275,7 +278,7 @@ fn roundtrip_window() -> Result<()> { Arc::new(window_frame), )); - let input = Arc::new(EmptyExec::new(false, schema.clone())); + let input = Arc::new(EmptyExec::new(schema.clone())); roundtrip_test(Arc::new(WindowAggExec::try_new( vec![ @@ -309,7 +312,7 @@ fn rountrip_aggregate() -> Result<()> { aggregates.clone(), vec![None], vec![None], - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(schema.clone())), schema, )?)) } @@ -377,7 +380,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { aggregates.clone(), vec![None], vec![None], - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(schema.clone())), schema, )?), ctx, @@ -403,7 +406,7 @@ fn roundtrip_filter_with_not_and_in_list() -> Result<()> { let and = binary(not, Operator::And, in_list, &schema)?; roundtrip_test(Arc::new(FilterExec::try_new( and, - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(schema.clone())), )?)) } @@ -430,7 +433,7 @@ fn roundtrip_sort() -> Result<()> { ]; roundtrip_test(Arc::new(SortExec::new( sort_exprs, - Arc::new(EmptyExec::new(false, schema)), + Arc::new(EmptyExec::new(schema)), ))) } @@ -458,11 +461,11 @@ fn roundtrip_sort_preserve_partitioning() -> Result<()> { roundtrip_test(Arc::new(SortExec::new( sort_exprs.clone(), - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(schema.clone())), )))?; roundtrip_test(Arc::new( - SortExec::new(sort_exprs, Arc::new(EmptyExec::new(false, schema))) + SortExec::new(sort_exprs, Arc::new(EmptyExec::new(schema))) .with_preserve_partitioning(true), )) } @@ -512,7 +515,7 @@ fn roundtrip_builtin_scalar_function() -> Result<()> { let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let input = Arc::new(EmptyExec::new(false, schema.clone())); + let input = Arc::new(EmptyExec::new(schema.clone())); let execution_props = ExecutionProps::new(); @@ -539,7 +542,7 @@ fn roundtrip_scalar_udf() -> Result<()> { let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let input = Arc::new(EmptyExec::new(false, schema.clone())); + let input = Arc::new(EmptyExec::new(schema.clone())); let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef); @@ -592,7 +595,7 @@ fn roundtrip_distinct_count() -> Result<()> { aggregates.clone(), vec![None], vec![None], - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(schema.clone())), schema, )?)) } @@ -603,7 +606,7 @@ fn roundtrip_like() -> Result<()> { Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Utf8, false), ]); - let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); + let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); let like_expr = like( false, false, @@ -630,13 +633,13 @@ fn roundtrip_get_indexed_field_named_struct_field() -> Result<()> { ]; let schema = Schema::new(fields); - let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); + let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); let col_arg = col("arg", &schema)?; let get_indexed_field_expr = Arc::new(GetIndexedFieldExpr::new( col_arg, GetFieldAccessExpr::NamedStructField { - name: ScalarValue::Utf8(Some(String::from("name"))), + name: 
ScalarValue::from("name"), }, )); @@ -657,7 +660,7 @@ fn roundtrip_get_indexed_field_list_index() -> Result<()> { ]; let schema = Schema::new(fields); - let input = Arc::new(EmptyExec::new(true, Arc::new(schema.clone()))); + let input = Arc::new(PlaceholderRowExec::new(Arc::new(schema.clone()))); let col_arg = col("arg", &schema)?; let col_key = col("key", &schema)?; @@ -684,7 +687,7 @@ fn roundtrip_get_indexed_field_list_range() -> Result<()> { ]; let schema = Schema::new(fields); - let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); + let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); let col_arg = col("arg", &schema)?; let col_start = col("start", &schema)?; @@ -710,7 +713,7 @@ fn roundtrip_analyze() -> Result<()> { let field_a = Field::new("plan_type", DataType::Utf8, false); let field_b = Field::new("plan", DataType::Utf8, false); let schema = Schema::new(vec![field_a, field_b]); - let input = Arc::new(EmptyExec::new(true, Arc::new(schema.clone()))); + let input = Arc::new(PlaceholderRowExec::new(Arc::new(schema.clone()))); roundtrip_test(Arc::new(AnalyzeExec::new( false, @@ -725,7 +728,7 @@ fn roundtrip_json_sink() -> Result<()> { let field_a = Field::new("plan_type", DataType::Utf8, false); let field_b = Field::new("plan", DataType::Utf8, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let input = Arc::new(EmptyExec::new(true, schema.clone())); + let input = Arc::new(PlaceholderRowExec::new(schema.clone())); let file_sink_config = FileSinkConfig { object_store_url: ObjectStoreUrl::local_filesystem(), @@ -785,8 +788,8 @@ fn roundtrip_sym_hash_join() -> Result<()> { ] { roundtrip_test(Arc::new( datafusion::physical_plan::joins::SymmetricHashJoinExec::try_new( - Arc::new(EmptyExec::new(false, schema_left.clone())), - Arc::new(EmptyExec::new(false, schema_right.clone())), + Arc::new(EmptyExec::new(schema_left.clone())), + Arc::new(EmptyExec::new(schema_right.clone())), on.clone(), None, join_type, @@ -798,3 +801,34 @@ fn roundtrip_sym_hash_join() -> Result<()> { } Ok(()) } + +#[test] +fn roundtrip_union() -> Result<()> { + let field_a = Field::new("col", DataType::Int64, false); + let schema_left = Schema::new(vec![field_a.clone()]); + let schema_right = Schema::new(vec![field_a]); + let left = EmptyExec::new(Arc::new(schema_left)); + let right = EmptyExec::new(Arc::new(schema_right)); + let inputs: Vec> = vec![Arc::new(left), Arc::new(right)]; + let union = UnionExec::new(inputs); + roundtrip_test(Arc::new(union)) +} + +#[test] +fn roundtrip_interleave() -> Result<()> { + let field_a = Field::new("col", DataType::Int64, false); + let schema_left = Schema::new(vec![field_a.clone()]); + let schema_right = Schema::new(vec![field_a]); + let partition = Partitioning::Hash(vec![], 3); + let left = RepartitionExec::try_new( + Arc::new(EmptyExec::new(Arc::new(schema_left))), + partition.clone(), + )?; + let right = RepartitionExec::try_new( + Arc::new(EmptyExec::new(Arc::new(schema_right))), + partition.clone(), + )?; + let inputs: Vec> = vec![Arc::new(left), Arc::new(right)]; + let interleave = InterleaveExec::try_new(inputs)?; + roundtrip_test(Arc::new(interleave)) +} diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 958e03879842..73de4fa43907 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -21,7 +21,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::function::suggest_valid_function; -use 
datafusion_expr::window_frame::regularize; +use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, WindowFunction, }; @@ -92,21 +92,30 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .into_iter() .map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)) .collect::<Result<Vec<Expr>>>()?; - let order_by = - self.order_by_to_sort_expr(&window.order_by, schema, planner_context)?; + let mut order_by = self.order_by_to_sort_expr( + &window.order_by, + schema, + planner_context, + // Numeric literals in window function ORDER BY are treated as constants + false, + )?; let window_frame = window .window_frame .as_ref() .map(|window_frame| { let window_frame = window_frame.clone().try_into()?; - regularize(window_frame, order_by.len()) + check_window_frame(&window_frame, order_by.len()) + .map(|_| window_frame) }) .transpose()?; + let window_frame = if let Some(window_frame) = window_frame { + regularize_window_order_by(&window_frame, &mut order_by)?; window_frame } else { WindowFrame::new(!order_by.is_empty()) }; + if let Ok(fun) = self.find_window_func(&name) { let expr = match fun { WindowFunction::AggregateFunction(aggregate_fun) => { @@ -143,7 +152,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // next, aggregate built-ins if let Ok(fun) = AggregateFunction::from_str(&name) { let order_by = - self.order_by_to_sort_expr(&order_by, schema, planner_context)?; + self.order_by_to_sort_expr(&order_by, schema, planner_context, true)?; let order_by = (!order_by.is_empty()).then_some(order_by); let args = self.function_args_to_expr(args, schema, planner_context)?; let filter: Option<Box<Expr>> = filter diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index b8c130055a5a..27351e10eb34 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -171,7 +171,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Expr::ScalarFunction(ScalarFunction::new( BuiltinScalarFunction::DatePart, vec![ - Expr::Literal(ScalarValue::Utf8(Some(format!("{field}")))), + Expr::Literal(ScalarValue::from(format!("{field}"))), self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ], ))) @@ -555,7 +555,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } = array_agg; let order_by = if let Some(order_by) = order_by { - Some(self.order_by_to_sort_expr(&order_by, input_schema, planner_context)?) + Some(self.order_by_to_sort_expr( + &order_by, + input_schema, + planner_context, + true, + )?) } else { None }; @@ -739,7 +744,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Value( Value::SingleQuotedString(s) | Value::DoubleQuotedString(s), ) => GetFieldAccess::NamedStructField { - name: ScalarValue::Utf8(Some(s)), + name: ScalarValue::from(s), }, SQLExpr::JsonAccess { left, diff --git a/datafusion/sql/src/expr/order_by.rs b/datafusion/sql/src/expr/order_by.rs index 1dccc2376f0b..772255bd9773 100644 --- a/datafusion/sql/src/expr/order_by.rs +++ b/datafusion/sql/src/expr/order_by.rs @@ -24,12 +24,17 @@ use datafusion_expr::Expr; use sqlparser::ast::{Expr as SQLExpr, OrderByExpr, Value}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { - /// convert sql [OrderByExpr] to `Vec<Expr>` + /// Convert sql [OrderByExpr] to `Vec<Expr>`. + /// + /// If `literal_to_column` is true, treat any numeric literals (e.g. `2`) as a 1-based index + /// into the SELECT list (e.g. `SELECT a, b FROM table ORDER BY 2`). 
+ /// If false, interpret numeric literals as constant values. pub(crate) fn order_by_to_sort_expr( &self, exprs: &[OrderByExpr], schema: &DFSchema, planner_context: &mut PlannerContext, + literal_to_column: bool, ) -> Result<Vec<Expr>> { let mut expr_vec = vec![]; for e in exprs { let OrderByExpr { asc, expr, nulls_first, } = e; let expr = match expr { - SQLExpr::Value(Value::Number(v, _)) => { + SQLExpr::Value(Value::Number(v, _)) if literal_to_column => { let field_index = v .parse::<usize>() .map_err(|err| plan_datafusion_err!("{}", err))?; diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index a3f29da488ba..708f7c60011a 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -108,7 +108,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } Ok(index) => index - 1, Err(_) => { - return plan_err!("Invalid placeholder, not a number: {param}"); + return if param_data_types.is_empty() { + Ok(Expr::Placeholder(Placeholder::new(param, None))) + } else { + // when PREPARE Statement, param_data_types length is always 0 + plan_err!("Invalid placeholder, not a number: {param}") + }; } }; // Check if the placeholder is in the parameter list diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 622e5aca799a..c5c30e3a2253 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -52,6 +52,15 @@ pub trait ContextProvider { } /// Getter for a datasource fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>>; + /// Getter for a table function + fn get_table_function_source( + &self, + _name: &str, + _args: Vec<Expr>, + ) -> Result<Arc<dyn TableSource>> { + not_impl_err!("Table Functions are not supported") + } + /// Getter for a UDF description fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>>; /// Getter for a UDAF description diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 643f41d84485..dd4cab126261 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -161,7 +161,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let order_by_rex = - self.order_by_to_sort_expr(&order_by, plan.schema(), planner_context)?; + self.order_by_to_sort_expr(&order_by, plan.schema(), planner_context, true)?; if let LogicalPlan::Distinct(Distinct::On(ref distinct_on)) = plan { // In case of `DISTINCT ON` we must capture the sort expressions since during the plan diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 180743d19b7b..b233f47a058f 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -16,9 +16,11 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_common::{ + not_impl_err, plan_err, DFSchema, DataFusionError, Result, TableReference, +}; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder}; -use sqlparser::ast::TableFactor; +use sqlparser::ast::{FunctionArg, FunctionArgExpr, TableFactor}; mod join; @@ -30,24 +32,58 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context: &mut PlannerContext, ) -> Result<LogicalPlan> { let (plan, alias) = match relation { - TableFactor::Table { name, alias, .. 
} => { - // normalize name and alias - let table_ref = self.object_name_to_table_reference(name)?; - let table_name = table_ref.to_string(); - let cte = planner_context.get_cte(&table_name); - ( - match ( - cte, - self.context_provider.get_table_source(table_ref.clone()), - ) { - (Some(cte_plan), _) => Ok(cte_plan.clone()), - (_, Ok(provider)) => { - LogicalPlanBuilder::scan(table_ref, provider, None)?.build() - } - (None, Err(e)) => Err(e), - }?, - alias, - ) + TableFactor::Table { + name, alias, args, .. + } => { + if let Some(func_args) = args { + let tbl_func_name = name.0.first().unwrap().value.to_string(); + let args = func_args + .into_iter() + .flat_map(|arg| { + if let FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) = arg + { + self.sql_expr_to_logical_expr( + expr, + &DFSchema::empty(), + planner_context, + ) + } else { + plan_err!("Unsupported function argument type: {:?}", arg) + } + }) + .collect::<Vec<_>>(); + let provider = self + .context_provider + .get_table_function_source(&tbl_func_name, args)?; + let plan = LogicalPlanBuilder::scan( + TableReference::Bare { + table: std::borrow::Cow::Borrowed("tmp_table"), + }, + provider, + None, + )? + .build()?; + (plan, alias) + } else { + // normalize name and alias + let table_ref = self.object_name_to_table_reference(name)?; + let table_name = table_ref.to_string(); + let cte = planner_context.get_cte(&table_name); + ( + match ( + cte, + self.context_provider.get_table_source(table_ref.clone()), + ) { + (Some(cte_plan), _) => Ok(cte_plan.clone()), + (_, Ok(provider)) => { + LogicalPlanBuilder::scan(table_ref, provider, None)? + .build() + } + (None, Err(e)) => Err(e), + }?, + alias, + ) + } } TableFactor::Derived { subquery, alias, .. diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index c546ca755206..a0819e4aaf8e 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -25,10 +25,7 @@ use crate::utils::{ }; use datafusion_common::Column; -use datafusion_common::{ - get_target_functional_dependencies, not_impl_err, plan_err, DFSchemaRef, - DataFusionError, Result, -}; +use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_expr::expr::Alias; use datafusion_expr::expr_rewriter::{ normalize_col, normalize_col_with_schemas_and_ambiguity_check, @@ -384,7 +381,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &[&[plan.schema()]], &plan.using_columns()?, )?; - let expr = col.alias(self.normalizer.normalize(alias)); + let name = self.normalizer.normalize(alias); + // Avoid adding an alias if the column name is the same. + let expr = match &col { + Expr::Column(column) if column.name.eq(&name) => col, + _ => col.alias(name), + }; Ok(vec![expr]) } SelectItem::Wildcard(options) => { @@ -529,14 +531,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { group_by_exprs: &[Expr], aggr_exprs: &[Expr], ) -> Result<(LogicalPlan, Vec<Expr>, Option<Expr>)> { - let group_by_exprs = - get_updated_group_by_exprs(group_by_exprs, select_exprs, input.schema())?; - // create the aggregate plan let plan = LogicalPlanBuilder::from(input.clone()) - .aggregate(group_by_exprs.clone(), aggr_exprs.to_vec())? + .aggregate(group_by_exprs.to_vec(), aggr_exprs.to_vec())? .build()?; + let group_by_exprs = if let LogicalPlan::Aggregate(agg) = &plan { + &agg.group_expr + } else { + unreachable!(); + }; + // in this next section of code we are re-writing the projection to refer to columns // output by the aggregate plan. 
For example, if the projection contains the expression // `SUM(a)` then we replace that with a reference to a column `SUM(a)` produced by @@ -545,7 +550,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // combine the original grouping and aggregate expressions into one list (note that // we do not add the "having" expression since that is not part of the projection) let mut aggr_projection_exprs = vec![]; - for expr in &group_by_exprs { + for expr in group_by_exprs { match expr { Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { aggr_projection_exprs.extend_from_slice(exprs) @@ -654,61 +659,3 @@ fn match_window_definitions( } Ok(()) } - -/// Update group by exprs, according to functional dependencies -/// The query below -/// -/// SELECT sn, amount -/// FROM sales_global -/// GROUP BY sn -/// -/// cannot be calculated, because it has a column(`amount`) which is not -/// part of group by expression. -/// However, if we know that, `sn` is determinant of `amount`. We can -/// safely, determine value of `amount` for each distinct `sn`. For these cases -/// we rewrite the query above as -/// -/// SELECT sn, amount -/// FROM sales_global -/// GROUP BY sn, amount -/// -/// Both queries, are functionally same. \[Because, (`sn`, `amount`) and (`sn`) -/// defines the identical groups. \] -/// This function updates group by expressions such that select expressions that are -/// not in group by expression, are added to the group by expressions if they are dependent -/// of the sub-set of group by expressions. -fn get_updated_group_by_exprs( - group_by_exprs: &[Expr], - select_exprs: &[Expr], - schema: &DFSchemaRef, -) -> Result> { - let mut new_group_by_exprs = group_by_exprs.to_vec(); - let fields = schema.fields(); - let group_by_expr_names = group_by_exprs - .iter() - .map(|group_by_expr| group_by_expr.display_name()) - .collect::>>()?; - // Get targets that can be used in a select, even if they do not occur in aggregation: - if let Some(target_indices) = - get_target_functional_dependencies(schema, &group_by_expr_names) - { - // Calculate dependent fields names with determinant GROUP BY expression: - let associated_field_names = target_indices - .iter() - .map(|idx| fields[*idx].qualified_name()) - .collect::>(); - // Expand GROUP BY expressions with select expressions: If a GROUP - // BY expression is a determinant key, we can use its dependent - // columns in select statements also. - for expr in select_exprs { - let expr_name = format!("{}", expr); - if !new_group_by_exprs.contains(expr) - && associated_field_names.contains(&expr_name) - { - new_group_by_exprs.push(expr.clone()); - } - } - } - - Ok(new_group_by_exprs) -} diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index aa2f0583cb99..12083554f093 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -458,6 +458,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if ignore { plan_err!("Insert-ignore clause not supported")?; } + let Some(source) = source else { + plan_err!("Inserts without a source not supported")? 
+ }; let _ = into; // optional keyword doesn't change behavior self.insert_to_plan(table_name, columns, source, overwrite) } @@ -566,7 +569,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }); Ok(LogicalPlan::Statement(statement)) } - Statement::Rollback { chain } => { + Statement::Rollback { chain, savepoint } => { + if savepoint.is_some() { + plan_err!("Savepoints not supported")?; + } let statement = PlanStatement::TransactionEnd(TransactionEnd { conclusion: TransactionConclusion::Rollback, chain, @@ -704,7 +710,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut all_results = vec![]; for expr in order_exprs { // Convert each OrderByExpr to a SortExpr: - let expr_vec = self.order_by_to_sort_expr(&expr, schema, planner_context)?; + let expr_vec = + self.order_by_to_sort_expr(&expr, schema, planner_context, true)?; // Verify that columns of all SortExprs exist in the schema: for expr in expr_vec.iter() { for column in expr.to_columns()?.iter() { @@ -755,11 +762,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )?; } + let mut planner_context = PlannerContext::new(); + + let column_defaults = self + .build_column_defaults(&columns, &mut planner_context)? + .into_iter() + .collect(); + let schema = self.build_schema(columns)?; let df_schema = schema.to_dfschema_ref()?; let ordered_exprs = - self.build_order_by(order_exprs, &df_schema, &mut PlannerContext::new())?; + self.build_order_by(order_exprs, &df_schema, &mut planner_context)?; // External tables do not support schemas at the moment, so the name is just a table name let name = OwnedTableReference::bare(name); @@ -781,6 +795,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { unbounded, options, constraints, + column_defaults, }, ))) } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index d5b06bcf815f..48ba50145308 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -22,11 +22,11 @@ use std::{sync::Arc, vec}; use arrow_schema::*; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; -use datafusion_common::plan_err; use datafusion_common::{ assert_contains, config::ConfigOptions, DataFusionError, Result, ScalarValue, TableReference, }; +use datafusion_common::{plan_err, ParamValues}; use datafusion_expr::{ logical_plan::{LogicalPlan, Prepare}, AggregateUDF, ScalarUDF, TableSource, WindowUDF, @@ -471,6 +471,10 @@ Dml: op=[Insert Into] table=[test_decimal] "INSERT INTO person (id, first_name, last_name) VALUES ($2, $4, $6)", "Error during planning: Placeholder type could not be resolved" )] +#[case::placeholder_type_unresolved( + "INSERT INTO person (id, first_name, last_name) VALUES ($id, $first_name, $last_name)", + "Error during planning: Can't parse placeholder: $id" +)] #[test] fn test_insert_schema_errors(#[case] sql: &str, #[case] error: &str) { let err = logical_plan(sql).unwrap_err(); @@ -2674,7 +2678,7 @@ fn prepare_stmt_quick_test( fn prepare_stmt_replace_params_quick_test( plan: LogicalPlan, - param_values: Vec, + param_values: impl Into, expected_plan: &str, ) -> LogicalPlan { // replace params @@ -3542,13 +3546,24 @@ fn test_select_unsupported_syntax_errors(#[case] sql: &str, #[case] error: &str) fn select_order_by_with_cast() { let sql = "SELECT first_name AS first_name FROM (SELECT first_name AS first_name FROM person) ORDER BY CAST(first_name as INT)"; - let expected = "Sort: CAST(first_name AS first_name AS Int32) ASC NULLS LAST\ - \n Projection: first_name AS first_name\ - \n 
Projection: person.first_name AS first_name\ + let expected = "Sort: CAST(person.first_name AS Int32) ASC NULLS LAST\ + \n Projection: person.first_name\ + \n Projection: person.first_name\ \n TableScan: person"; quick_test(sql, expected); } +#[test] +fn test_avoid_add_alias() { + // Avoid adding an alias if the column name is the same. + // plan1 = plan2 + let sql = "select person.id as id from person order by person.id"; + let plan1 = logical_plan(sql).unwrap(); + let sql = "select id from person order by id"; + let plan2 = logical_plan(sql).unwrap(); + assert_eq!(format!("{plan1:?}"), format!("{plan2:?}")); +} + #[test] fn test_duplicated_left_join_key_inner_join() { // person.id * 2 happen twice in left side. @@ -3726,7 +3741,7 @@ fn test_prepare_statement_to_plan_no_param() { /////////////////// // replace params with values - let param_values = vec![]; + let param_values: Vec<ScalarValue> = vec![]; let expected_plan = "Projection: person.id, person.age\ \n Filter: person.age = Int64(10)\ \n TableScan: person"; @@ -3740,7 +3755,7 @@ fn test_prepare_statement_to_plan_one_param_no_value_panic() { let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; let plan = logical_plan(sql).unwrap(); // declare 1 param but provide 0 - let param_values = vec![]; + let param_values: Vec<ScalarValue> = vec![]; assert_eq!( plan.with_param_values(param_values) .unwrap_err() @@ -3853,7 +3868,7 @@ Projection: person.id, orders.order_id assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; + let param_values = vec![ScalarValue::Int32(Some(10))].into(); let expected_plan = r#" Projection: person.id, orders.order_id Inner Join: Filter: person.id = orders.customer_id AND person.age = Int32(10) @@ -3885,7 +3900,7 @@ Projection: person.id, person.age assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; + let param_values = vec![ScalarValue::Int32(Some(10))].into(); let expected_plan = r#" Projection: person.id, person.age Filter: person.age = Int32(10) @@ -3919,7 +3934,8 @@ Projection: person.id, person.age assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))]; + let param_values = + vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))].into(); let expected_plan = r#" Projection: person.id, person.age Filter: person.age BETWEEN Int32(10) AND Int32(30) @@ -3955,7 +3971,7 @@ Projection: person.id, person.age assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::UInt32(Some(10))]; + let param_values = vec![ScalarValue::UInt32(Some(10))].into(); let expected_plan = r#" Projection: person.id, person.age Filter: person.age = () @@ -3995,7 +4011,8 @@ Dml: op=[Update] table=[person] assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))]; + let param_values = + vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))].into(); let expected_plan = r#" Dml: op=[Update] table=[person] Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 @@ -4032,9 +4049,10 @@ fn test_prepare_statement_insert_infer() { // replace params with values let param_values 
= vec![ ScalarValue::UInt32(Some(1)), - ScalarValue::Utf8(Some("Alan".to_string())), - ScalarValue::Utf8(Some("Turing".to_string())), - ]; + ScalarValue::from("Alan"), + ScalarValue::from("Turing"), + ] + .into(); let expected_plan = "Dml: op=[Insert Into] table=[person]\ \n Projection: column1 AS id, column2 AS first_name, column3 AS last_name, \ CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \ @@ -4113,11 +4131,11 @@ fn test_prepare_statement_to_plan_multi_params() { // replace params with values let param_values = vec![ ScalarValue::Int32(Some(10)), - ScalarValue::Utf8(Some("abc".to_string())), + ScalarValue::from("abc"), ScalarValue::Float64(Some(100.0)), ScalarValue::Int32(Some(20)), ScalarValue::Float64(Some(200.0)), - ScalarValue::Utf8(Some("xyz".to_string())), + ScalarValue::from("xyz"), ]; let expected_plan = "Projection: person.id, person.age, Utf8(\"xyz\")\ @@ -4183,8 +4201,8 @@ fn test_prepare_statement_to_plan_value_list() { /////////////////// // replace params with values let param_values = vec![ - ScalarValue::Utf8(Some("a".to_string())), - ScalarValue::Utf8(Some("b".to_string())), + ScalarValue::from("a".to_string()), + ScalarValue::from("b".to_string()), ]; let expected_plan = "Projection: t.num, t.letter\ \n SubqueryAlias: t\ diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 88590055484f..7cfc9c707d43 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2421,6 +2421,15 @@ select max(x_dict) from value_dict group by x_dict % 2 order by max(x_dict); 4 5 +query T +select arrow_typeof(x_dict) from value_dict group by x_dict; +---- +Int32 +Int32 +Int32 +Int32 +Int32 + statement ok drop table value @@ -3190,3 +3199,16 @@ FROM my_data GROUP BY dummy ---- text1, text1, text1 + + +# Queries with nested count(*) + +query I +select count(*) from (select count(*) from (select 1)); +---- +1 + +query I +select count(*) from (select count(*) a, count(*) b from (select 1)); +---- +1 \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 9e3ac3bf08f6..1202a2b1e99d 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -182,6 +182,38 @@ AS VALUES (make_array([[1], [2]], [[2], [3]]), make_array([1], [2])) ; +statement ok +CREATE TABLE array_distinct_table_1D +AS VALUES + (make_array(1, 1, 2, 2, 3)), + (make_array(1, 2, 3, 4, 5)), + (make_array(3, 5, 3, 3, 3)) +; + +statement ok +CREATE TABLE array_distinct_table_1D_UTF8 +AS VALUES + (make_array('a', 'a', 'bc', 'bc', 'def')), + (make_array('a', 'bc', 'def', 'defg', 'defg')), + (make_array('defg', 'defg', 'defg', 'defg', 'defg')) +; + +statement ok +CREATE TABLE array_distinct_table_2D +AS VALUES + (make_array([1,2], [1,2], [3,4], [3,4], [5,6])), + (make_array([1,2], [3,4], [5,6], [7,8], [9,10])), + (make_array([5,6], [5,6], NULL)) +; + +statement ok +CREATE TABLE array_distinct_table_1D_large +AS VALUES + (arrow_cast(make_array(1, 1, 2, 2, 3), 'LargeList(Int64)')), + (arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), + (arrow_cast(make_array(3, 5, 3, 3, 3), 'LargeList(Int64)')) +; + statement ok CREATE TABLE array_intersect_table_1D AS VALUES @@ -1052,6 +1084,44 @@ select make_array(['a','b'], null); ---- [[a, b], ] +## array_sort (aliases: `list_sort`) +query ??? 
+select array_sort(make_array(1, 3, null, 5, NULL, -5)), array_sort(make_array(1, 3, null, 2), 'ASC'), array_sort(make_array(1, 3, null, 2), 'desc', 'NULLS FIRST'); +---- +[, , -5, 1, 3, 5] [, 1, 2, 3] [, 3, 2, 1] + +query ? +select array_sort(column1, 'DESC', 'NULLS LAST') from arrays_values; +---- +[10, 9, 8, 7, 6, 5, 4, 3, 2, ] +[20, 18, 17, 16, 15, 14, 13, 12, 11, ] +[30, 29, 28, 27, 26, 25, 23, 22, 21, ] +[40, 39, 38, 37, 35, 34, 33, 32, 31, ] +NULL +[50, 49, 48, 47, 46, 45, 44, 43, 42, 41] +[60, 59, 58, 57, 56, 55, 54, 52, 51, ] +[70, 69, 68, 67, 66, 65, 64, 63, 62, 61] + +query ? +select array_sort(column1, 'ASC', 'NULLS FIRST') from arrays_values; +---- +[, 2, 3, 4, 5, 6, 7, 8, 9, 10] +[, 11, 12, 13, 14, 15, 16, 17, 18, 20] +[, 21, 22, 23, 25, 26, 27, 28, 29, 30] +[, 31, 32, 33, 34, 35, 37, 38, 39, 40] +NULL +[41, 42, 43, 44, 45, 46, 47, 48, 49, 50] +[, 51, 52, 54, 55, 56, 57, 58, 59, 60] +[61, 62, 63, 64, 65, 66, 67, 68, 69, 70] + + +## list_sort (aliases: `array_sort`) +query ??? +select list_sort(make_array(1, 3, null, 5, NULL, -5)), list_sort(make_array(1, 3, null, 2), 'ASC'), list_sort(make_array(1, 3, null, 2), 'desc', 'NULLS FIRST'); +---- +[, , -5, 1, 3, 5] [, 1, 2, 3] [, 3, 2, 1] + + ## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`) # TODO: array_append with NULLs @@ -1224,7 +1294,7 @@ select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, ma # array_repeat scalar function #1 query ???????? -select +select array_repeat(1, 5), array_repeat(3.14, 3), array_repeat('l', 4), @@ -1257,7 +1327,7 @@ AS VALUES (0, 3, 3.3, 'datafusion', make_array(8, 9)); query ?????? -select +select array_repeat(column2, column1), array_repeat(column3, column1), array_repeat(column4, column1), @@ -1272,7 +1342,7 @@ from array_repeat_table; [] [] [] [] [3, 3, 3] [] statement ok -drop table array_repeat_table; +drop table array_repeat_table; ## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`) @@ -2188,7 +2258,7 @@ select array_remove(make_array(1, 2, 2, 1, 1), 2), array_remove(make_array(1.0, [1, 2, 1, 1] [2.0, 2.0, 1.0, 1.0] [h, e, l, o] query ??? 
-select +select array_remove(make_array(1, null, 2, 3), 2), array_remove(make_array(1.1, null, 2.2, 3.3), 1.1), array_remove(make_array('a', null, 'bc'), 'a'); @@ -2371,24 +2441,44 @@ select array_length(make_array(1, 2, 3, 4, 5)), array_length(make_array(1, 2, 3) ---- 5 3 3 +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))')); +---- +5 3 3 + # array_length scalar function #2 query III select array_length(make_array(1, 2, 3, 4, 5), 1), array_length(make_array(1, 2, 3), 1), array_length(make_array([1, 2], [3, 4], [5, 6]), 1); ---- 5 3 3 +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 1), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 1), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))'), 1); +---- +5 3 3 + # array_length scalar function #3 query III select array_length(make_array(1, 2, 3, 4, 5), 2), array_length(make_array(1, 2, 3), 2), array_length(make_array([1, 2], [3, 4], [5, 6]), 2); ---- NULL NULL 2 +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 2), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))'), 2); +---- +NULL NULL 2 + # array_length scalar function #4 query II select array_length(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 1), array_length(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 2); ---- 3 2 +query II +select array_length(arrow_cast(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 'LargeList(List(List(Int64)))'), 1), array_length(arrow_cast(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 'LargeList(List(List(Int64)))'), 2); +---- +3 2 + # array_length scalar function #5 query III select array_length(make_array()), array_length(make_array(), 1), array_length(make_array(), 2) @@ -2407,6 +2497,11 @@ select list_length(make_array(1, 2, 3, 4, 5)), list_length(make_array(1, 2, 3)), ---- 5 3 3 NULL +query III +select list_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), list_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), list_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))')); +---- +5 3 3 + # array_length with columns query I select array_length(column1, column3) from arrays_values; @@ -2420,6 +2515,18 @@ NULL NULL NULL +query I +select array_length(arrow_cast(column1, 'LargeList(Int64)'), column3) from arrays_values; +---- +10 +NULL +NULL +NULL +NULL +NULL +NULL +NULL + # array_length with columns and scalars query II select array_length(array[array[1, 2], array[3, 4]], column3), array_length(column1, 1) from arrays_values; @@ -2433,6 +2540,18 @@ NULL 10 NULL 10 NULL 10 +query II +select array_length(arrow_cast(array[array[1, 2], array[3, 4]], 'LargeList(List(Int64))'), column3), array_length(arrow_cast(column1, 'LargeList(Int64)'), 1) from arrays_values; +---- +2 10 +2 10 +NULL 10 +NULL 10 +NULL NULL +NULL 10 +NULL 10 +NULL 10 + ## array_dims (aliases: `list_dims`) # array dims error @@ -2479,10 +2598,44 @@ NULL [3] [4] ## array_ndims (aliases: `list_ndims`) # array_ndims scalar function #1 + query III -select array_ndims(make_array(1, 2, 3)), array_ndims(make_array([1, 2], [3, 4])), array_ndims(make_array([[[[1], [2]]]])); +select + array_ndims(1), + 
array_ndims(null), + array_ndims([2, 3]); ---- -1 2 5 +0 0 1 + +statement ok +CREATE TABLE array_ndims_table +AS VALUES + (1, [1, 2, 3], [[7]], [[[[[10]]]]]), + (2, [4, 5], [[8]], [[[[[10]]]]]), + (null, [6], [[9]], [[[[[10]]]]]), + (3, [6], [[9]], [[[[[10]]]]]) +; + +query IIII +select + array_ndims(column1), + array_ndims(column2), + array_ndims(column3), + array_ndims(column4) +from array_ndims_table; +---- +0 1 2 5 +0 1 2 5 +0 1 2 5 +0 1 2 5 + +statement ok +drop table array_ndims_table; + +query I +select array_ndims(arrow_cast([null], 'List(List(List(Int64)))')); +---- +3 # array_ndims scalar function #2 query II @@ -2494,7 +2647,7 @@ select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ query II select array_ndims(make_array()), array_ndims(make_array(make_array())) ---- -NULL 2 +1 2 # list_ndims scalar function #4 (function alias `array_ndims`) query III @@ -2505,7 +2658,7 @@ select list_ndims(make_array(1, 2, 3)), list_ndims(make_array([1, 2], [3, 4])), query II select array_ndims(make_array()), array_ndims(make_array(make_array())) ---- -NULL 2 +1 2 # array_ndims with columns query III @@ -2538,6 +2691,23 @@ select array_has(make_array(1,2), 1), ---- true true true true true false true false true false true false +query BBBBBBBBBBBB +select array_has(arrow_cast(make_array(1,2), 'LargeList(Int64)'), 1), + array_has(arrow_cast(make_array(1,2,NULL), 'LargeList(Int64)'), 1), + array_has(arrow_cast(make_array([2,3], [3,4]), 'LargeList(List(Int64))'), make_array(2,3)), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([1], [2,3])), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([4,5], [6])), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([1])), + array_has(arrow_cast(make_array([[[1]]]), 'LargeList(List(List(List(Int64))))'), make_array([[1]])), + array_has(arrow_cast(make_array([[[1]]], [[[1], [2]]]), 'LargeList(List(List(List(Int64))))'), make_array([[2]])), + array_has(arrow_cast(make_array([[[1]]], [[[1], [2]]]), 'LargeList(List(List(List(Int64))))'), make_array([[1], [2]])), + list_has(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 4), + array_contains(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 3), + list_contains(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 0) +; +---- +true true true true true false true false true false true false + query BBB select array_has(column1, column2), array_has_all(column3, column4), @@ -2547,6 +2717,15 @@ from array_has_table_1D; true true true false false false +query BBB +select array_has(arrow_cast(column1, 'LargeList(Int64)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Int64)'), arrow_cast(column4, 'LargeList(Int64)')), + array_has_any(arrow_cast(column5, 'LargeList(Int64)'), arrow_cast(column6, 'LargeList(Int64)')) +from array_has_table_1D; +---- +true true true +false false false + query BBB select array_has(column1, column2), array_has_all(column3, column4), @@ -2556,6 +2735,15 @@ from array_has_table_1D_Float; true true false false false true +query BBB +select array_has(arrow_cast(column1, 'LargeList(Float64)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Float64)'), arrow_cast(column4, 'LargeList(Float64)')), + array_has_any(arrow_cast(column5, 'LargeList(Float64)'), arrow_cast(column6, 'LargeList(Float64)')) +from array_has_table_1D_Float; +---- +true true false +false false true + query BBB 
select array_has(column1, column2), array_has_all(column3, column4), @@ -2565,6 +2753,15 @@ from array_has_table_1D_Boolean; false true true true true true +query BBB +select array_has(arrow_cast(column1, 'LargeList(Boolean)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Boolean)'), arrow_cast(column4, 'LargeList(Boolean)')), + array_has_any(arrow_cast(column5, 'LargeList(Boolean)'), arrow_cast(column6, 'LargeList(Boolean)')) +from array_has_table_1D_Boolean; +---- +false true true +true true true + query BBB select array_has(column1, column2), array_has_all(column3, column4), @@ -2574,6 +2771,15 @@ from array_has_table_1D_UTF8; true true false false false true +query BBB +select array_has(arrow_cast(column1, 'LargeList(Utf8)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Utf8)'), arrow_cast(column4, 'LargeList(Utf8)')), + array_has_any(arrow_cast(column5, 'LargeList(Utf8)'), arrow_cast(column6, 'LargeList(Utf8)')) +from array_has_table_1D_UTF8; +---- +true true false +false false true + query BB select array_has(column1, column2), array_has_all(column3, column4) @@ -2582,6 +2788,14 @@ from array_has_table_2D; false true true false +query BB +select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), column2), + array_has_all(arrow_cast(column3, 'LargeList(List(Int64))'), arrow_cast(column4, 'LargeList(List(Int64))')) +from array_has_table_2D; +---- +false true +true false + query B select array_has_all(column1, column2) from array_has_table_2D_float; @@ -2589,6 +2803,13 @@ from array_has_table_2D_float; true false +query B +select array_has_all(arrow_cast(column1, 'LargeList(List(Float64))'), arrow_cast(column2, 'LargeList(List(Float64))')) +from array_has_table_2D_float; +---- +true +false + query B select array_has(column1, column2) from array_has_table_3D; ---- @@ -2600,6 +2821,17 @@ true false true +query B +select array_has(arrow_cast(column1, 'LargeList(List(List(Int64)))'), column2) from array_has_table_3D; +---- +false +true +false +false +true +false +true + query BBBB select array_has(column1, make_array(5, 6)), array_has(column1, make_array(7, NULL)), @@ -2614,6 +2846,20 @@ false true false false false false false false false false false false +query BBBB +select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(5, 6)), + array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(7, NULL)), + array_has(arrow_cast(column2, 'LargeList(Float64)'), 5.5), + array_has(arrow_cast(column3, 'LargeList(Utf8)'), 'o') +from arrays; +---- +false false false true +true false true false +true false false true +false true false false +false false false false +false false false false + query BBBBBBBBBBBBB select array_has_all(make_array(1,2,3), make_array(1,3)), array_has_all(make_array(1,2,3), make_array(1,4)), @@ -2632,6 +2878,91 @@ select array_has_all(make_array(1,2,3), make_array(1,3)), ---- true false true false false false true true false false true false true +query BBBBBBBBBBBBB +select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(1,3), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,4), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,3]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], 
[3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1]]), 'LargeList(List(List(Int64)))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,10,100), 'LargeList(Int64)')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(10,100),'LargeList(Int64)')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'LargeList(List(List(Int64)))')) +; +---- +true false true false false false true true false false true false true + +query BBBBBBBBBBBBB +select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(1,3), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,4), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,3]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1]]), 'LargeList(List(List(Int64)))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,10,100), 'LargeList(Int64)')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(10,100),'LargeList(Int64)')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'LargeList(List(List(Int64)))')) +; +---- +true false true false false false true true false false true false true + +## array_distinct + +query ? +select array_distinct(null); +---- +NULL + +query ? +select array_distinct([]); +---- +[] + +query ? +select array_distinct([[], []]); +---- +[[]] + +query ? 
+select array_distinct(column1) +from array_distinct_table_1D; +---- +[1, 2, 3] +[1, 2, 3, 4, 5] +[3, 5] + +query ? +select array_distinct(column1) +from array_distinct_table_1D_UTF8; +---- +[a, bc, def] +[a, bc, def, defg] +[defg] + +query ? +select array_distinct(column1) +from array_distinct_table_2D; +---- +[[1, 2], [3, 4], [5, 6]] +[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] +[, [5, 6]] + +query ? +select array_distinct(column1) +from array_distinct_table_1D_large; +---- +[1, 2, 3] +[1, 2, 3, 4, 5] +[3, 5] + query ??? select array_intersect(column1, column2), array_intersect(column3, column4), @@ -2693,7 +3024,7 @@ from array_intersect_table_3D; query ?????? SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)), array_intersect(make_array(1,3,5), make_array(2,4,6)), - array_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), + array_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), array_intersect(make_array(true, false), make_array(true)), array_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)), array_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4])) @@ -2724,7 +3055,7 @@ NULL query ?????? SELECT list_intersect(make_array(1,2,3), make_array(2,3,4)), list_intersect(make_array(1,3,5), make_array(2,4,6)), - list_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), + list_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), list_intersect(make_array(true, false), make_array(true)), list_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)), list_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4])) @@ -3056,18 +3387,33 @@ select empty(make_array(1)); ---- false +query B +select empty(arrow_cast(make_array(1), 'LargeList(Int64)')); +---- +false + # empty scalar function #2 query B select empty(make_array()); ---- true +query B +select empty(arrow_cast(make_array(), 'LargeList(Null)')); +---- +true + # empty scalar function #3 query B select empty(make_array(NULL)); ---- false +query B +select empty(arrow_cast(make_array(NULL), 'LargeList(Null)')); +---- +false + # empty scalar function #4 query B select empty(NULL); @@ -3086,6 +3432,17 @@ NULL false false +query B +select empty(arrow_cast(column1, 'LargeList(List(Int64))')) from arrays; +---- +false +false +false +false +NULL +false +false + query ? 
SELECT string_to_array('abcxxxdef', 'xxx') ---- diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 18792735ffed..4583ef319b7f 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -94,7 +94,7 @@ EXPLAIN select count(*) from (values ('a', 1, 100), ('a', 2, 150)) as t (c1,c2,c ---- physical_plan ProjectionExec: expr=[2 as COUNT(*)] ---EmptyExec: produce_one_row=true +--PlaceholderRowExec statement ok set datafusion.explain.physical_plan_only = false @@ -368,7 +368,7 @@ Projection: List([[1, 2, 3], [4, 5, 6]]) AS make_array(make_array(Int64(1),Int64 --EmptyRelation physical_plan ProjectionExec: expr=[[[1, 2, 3], [4, 5, 6]] as make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6)))] ---EmptyExec: produce_one_row=true +--PlaceholderRowExec query TT explain select [[1, 2, 3], [4, 5, 6]]; @@ -378,4 +378,4 @@ Projection: List([[1, 2, 3], [4, 5, 6]]) AS make_array(make_array(Int64(1),Int64 --EmptyRelation physical_plan ProjectionExec: expr=[[[1, 2, 3], [4, 5, 6]] as make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6)))] ---EmptyExec: produce_one_row=true +--PlaceholderRowExec diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 4f55ea316bb9..1903088b0748 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -995,3 +995,9 @@ query ? SELECT find_in_set(NULL, NULL) ---- NULL + +# Verify that multiple calls to volatile functions like `random()` are not combined / optimized away +query B +SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random() r1, random() r2) WHERE r1 > 0 AND r2 > 0) +---- +false diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index 1d6d7dc671fa..b7be4d78b583 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -3211,6 +3211,21 @@ SELECT s.sn, s.amount, 2*s.sn 3 200 6 4 100 8 +# we should be able to re-write group by expression +# using functional dependencies for complex expressions also. +# In this case, we use 2*s.amount instead of s.amount. 
+query IRI +SELECT s.sn, 2*s.amount, 2*s.sn + FROM sales_global_with_pk AS s + GROUP BY sn + ORDER BY sn +---- +0 60 0 +1 100 2 +2 150 4 +3 400 6 +4 200 8 + query IRI SELECT s.sn, s.amount, 2*s.sn FROM sales_global_with_pk_alternate AS s @@ -3364,7 +3379,7 @@ SELECT column1, COUNT(*) as column2 FROM (VALUES (['a', 'b'], 1), (['c', 'd', 'e # primary key should be aware from which columns it is associated -statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression r.sn could not be resolved from available columns: l.sn, SUM\(l.amount\) +statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression r.sn could not be resolved from available columns: l.sn, l.zip_code, l.country, l.ts, l.currency, l.amount, SUM\(l.amount\) SELECT l.sn, r.sn, SUM(l.amount), r.amount FROM sales_global_with_pk AS l JOIN sales_global_with_pk AS r @@ -3456,7 +3471,7 @@ ORDER BY r.sn 4 100 2022-01-03T10:00:00 # after join, new window expressions shouldn't be associated with primary keys -statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression rn1 could not be resolved from available columns: r.sn, SUM\(r.amount\) +statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression rn1 could not be resolved from available columns: r.sn, r.ts, r.amount, SUM\(r.amount\) SELECT r.sn, SUM(r.amount), rn1 FROM (SELECT r.ts, r.sn, r.amount, @@ -3784,6 +3799,192 @@ AggregateExec: mode=FinalPartitioned, gby=[c@0 as c, b@1 as b], aggr=[SUM(multip ----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true +statement ok +set datafusion.execution.target_partitions = 1; + +query TT +EXPLAIN SELECT c, sum1 + FROM + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) +GROUP BY c; +---- +logical_plan +Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, sum1]], aggr=[[]] +--Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[c, d] +physical_plan +AggregateExec: mode=Single, gby=[c@0 as c, sum1@1 as sum1], aggr=[], ordering_mode=PartiallySorted([0]) +--ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1] +----AggregateExec: mode=Single, gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +query TT +EXPLAIN SELECT c, sum1, SUM(b) OVER() as sumb + FROM + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c); +---- +logical_plan +Projection: multiple_ordered_table_with_pk.c, sum1, SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS sumb +--WindowAggr: windowExpr=[[SUM(CAST(multiple_ordered_table_with_pk.b AS Int64)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +----Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, 
SUM(multiple_ordered_table_with_pk.d) AS sum1 +------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +--------TableScan: multiple_ordered_table_with_pk projection=[b, c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, sum1@2 as sum1, SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@3 as sumb] +--WindowAggExec: wdw=[SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }] +----ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true + +query TT +EXPLAIN SELECT lhs.c, rhs.c, lhs.sum1, rhs.sum1 + FROM + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) as lhs + JOIN + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) as rhs + ON lhs.b=rhs.b; +---- +logical_plan +Projection: lhs.c, rhs.c, lhs.sum1, rhs.sum1 +--Inner Join: lhs.b = rhs.b +----SubqueryAlias: lhs +------Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----------TableScan: multiple_ordered_table_with_pk projection=[b, c, d] +----SubqueryAlias: rhs +------Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----------TableScan: multiple_ordered_table_with_pk projection=[b, c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, c@3 as c, sum1@2 as sum1, sum1@5 as sum1] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(b@1, b@1)] +------ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +--------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true +------ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +--------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true + +query TT +EXPLAIN SELECT lhs.c, rhs.c, lhs.sum1, 
rhs.sum1 + FROM + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) as lhs + CROSS JOIN + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) as rhs; +---- +logical_plan +Projection: lhs.c, rhs.c, lhs.sum1, rhs.sum1 +--CrossJoin: +----SubqueryAlias: lhs +------Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----------TableScan: multiple_ordered_table_with_pk projection=[c, d] +----SubqueryAlias: rhs +------Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----------TableScan: multiple_ordered_table_with_pk projection=[c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, c@2 as c, sum1@1 as sum1, sum1@3 as sum1] +--CrossJoinExec +----ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1] +------AggregateExec: mode=Single, gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true +----ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1] +------AggregateExec: mode=Single, gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +# we do not yet generate a physical plan for Repartition (e.g. Distribute By queries).
+query TT +EXPLAIN SELECT a, b, sum1 +FROM (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) +DISTRIBUTE BY a +---- +logical_plan +Repartition: DistributeBy(a) +--Projection: multiple_ordered_table_with_pk.a, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[a, b, c, d] + +# union with aggregate +query TT +EXPLAIN SELECT c, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c +UNION ALL + SELECT c, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c +---- +logical_plan +Union +--Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[a, c, d] +--Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[a, c, d] +physical_plan +UnionExec +--ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true +--ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true + +# table scan should be simplified. 
+query TT +EXPLAIN SELECT c, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c +---- +logical_plan +Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----TableScan: multiple_ordered_table_with_pk projection=[a, c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +--AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true + +# limit should be simplified +query TT +EXPLAIN SELECT * + FROM (SELECT c, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c + LIMIT 5) +---- +logical_plan +Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--Limit: skip=0, fetch=5 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[a, c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +--GlobalLimitExec: skip=0, fetch=5 +----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true + +statement ok +set datafusion.execution.target_partitions = 8; + # Tests for single distinct to group by optimization rule statement ok CREATE TABLE t(x int) AS VALUES (1), (2), (1); @@ -4047,3 +4248,34 @@ set datafusion.sql_parser.dialect = 'Generic'; statement ok drop table aggregate_test_100; + + +# Create an unbounded external table with primary key +# column c +statement ok +CREATE EXTERNAL TABLE unbounded_multiple_ordered_table_with_pk ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER primary key, + d INTEGER +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC, b ASC) +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv'; + +# Query below can be executed, since c is a primary key.
+query III rowsort +SELECT c, a, SUM(d) +FROM unbounded_multiple_ordered_table_with_pk +GROUP BY c +ORDER BY c +LIMIT 5 +---- +0 0 0 +1 0 2 +2 0 0 +3 0 0 +4 0 1 diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 741ff724781f..5c6bf6e2dac1 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -188,6 +188,7 @@ datafusion.explain.logical_plan_only false datafusion.explain.physical_plan_only false datafusion.explain.show_statistics false datafusion.optimizer.allow_symmetric_joins_without_pruning true +datafusion.optimizer.default_filter_selectivity 20 datafusion.optimizer.enable_distinct_aggregation_soft_limit true datafusion.optimizer.enable_round_robin_repartition true datafusion.optimizer.enable_topk_aggregation true @@ -261,6 +262,7 @@ datafusion.explain.logical_plan_only false When set to true, the explain stateme datafusion.explain.physical_plan_only false When set to true, the explain statement will only print physical plans datafusion.explain.show_statistics false When set to true, the explain statement will print operator statistics for physical plans datafusion.optimizer.allow_symmetric_joins_without_pruning true Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. +datafusion.optimizer.default_filter_selectivity 20 The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). datafusion.optimizer.enable_distinct_aggregation_soft_limit true When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. datafusion.optimizer.enable_round_robin_repartition true When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores datafusion.optimizer.enable_topk_aggregation true When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible diff --git a/datafusion/sqllogictest/test_files/insert.slt b/datafusion/sqllogictest/test_files/insert.slt index 75252b3b7c35..e20b3779459b 100644 --- a/datafusion/sqllogictest/test_files/insert.slt +++ b/datafusion/sqllogictest/test_files/insert.slt @@ -382,6 +382,27 @@ select a,b,c,d from test_column_defaults 1 10 100 ABC NULL 20 500 default_text +# fill the timestamp column with the default value `now()` again; it should be different from the previous one +query IIITP +insert into test_column_defaults(a, b, c, d) values(2, 20, 200, 'DEF') +---- +1 + +# Ensure that the default expression `now()` is evaluated during insertion, not optimized away. +# Rows are inserted at different times, so their timestamp values should be different.
+query I rowsort +select count(distinct e) from test_column_defaults +---- +3 + +# Expect all rows to be true as now() was inserted into the table +query B rowsort +select e < now() from test_column_defaults +---- +true +true +true + statement ok drop table test_column_defaults diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt index 39323479ff74..85c2db7faaf6 100644 --- a/datafusion/sqllogictest/test_files/insert_to_external.slt +++ b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -543,3 +543,70 @@ select * from table_without_values; statement ok drop table table_without_values; + + +### Test for specifying a column's default value + +statement ok +CREATE EXTERNAL TABLE test_column_defaults( + a int, + b int not null default null, + c int default 100*2+300, + d text default lower('DEFAULT_TEXT'), + e timestamp default now() +) STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q6' +OPTIONS (create_local_path 'true'); + +# fill in all column values +query IIITP +insert into test_column_defaults values(1, 10, 100, 'ABC', now()) +---- +1 + +statement error DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable +insert into test_column_defaults(a) values(2) + +query IIITP +insert into test_column_defaults(b) values(20) +---- +1 + +query IIIT rowsort +select a,b,c,d from test_column_defaults +---- +1 10 100 ABC +NULL 20 500 default_text + +# fill the timestamp column with the default value `now()` again; it should be different from the previous one +query IIITP +insert into test_column_defaults(a, b, c, d) values(2, 20, 200, 'DEF') +---- +1 + +# Ensure that the default expression `now()` is evaluated during insertion, not optimized away. +# Rows are inserted at different times, so their timestamp values should be different. +query I rowsort +select count(distinct e) from test_column_defaults +---- +3 + +# Expect all rows to be true as now() was inserted into the table +query B rowsort +select e < now() from test_column_defaults +---- +true +true +true + +statement ok +drop table test_column_defaults + +# test invalid default value +statement error DataFusion error: Error during planning: Column reference is not allowed in the DEFAULT expression : Schema error: No field named a. +CREATE EXTERNAL TABLE test_column_defaults( + a int, + b int default a+1 +) STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q7' +OPTIONS (create_local_path 'true'); diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index 874d849e9a29..386ffe766b19 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -556,7 +556,7 @@ query TT explain select * from t1 join t2 on false; ---- logical_plan EmptyRelation -physical_plan EmptyExec: produce_one_row=false +physical_plan EmptyExec # Make batch size smaller than table row number. to introduce parallelism to the plan.
statement ok diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt index 182195112e87..e063d6e8960a 100644 --- a/datafusion/sqllogictest/test_files/limit.slt +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -312,7 +312,7 @@ Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ----TableScan: t1 projection=[], fetch=14 physical_plan ProjectionExec: expr=[0 as COUNT(*)] ---EmptyExec: produce_one_row=true +--PlaceholderRowExec query I SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11); @@ -330,7 +330,7 @@ Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ----TableScan: t1 projection=[], fetch=11 physical_plan ProjectionExec: expr=[2 as COUNT(*)] ---EmptyExec: produce_one_row=true +--PlaceholderRowExec query I SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8); @@ -348,7 +348,7 @@ Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ----TableScan: t1 projection=[] physical_plan ProjectionExec: expr=[2 as COUNT(*)] ---EmptyExec: produce_one_row=true +--PlaceholderRowExec query I SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8); diff --git a/datafusion/sqllogictest/test_files/schema_evolution.slt b/datafusion/sqllogictest/test_files/schema_evolution.slt new file mode 100644 index 000000000000..36d54159e24d --- /dev/null +++ b/datafusion/sqllogictest/test_files/schema_evolution.slt @@ -0,0 +1,140 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +########## +# Tests for schema evolution -- reading +# data from different files with different schemas +########## + + +statement ok +CREATE EXTERNAL TABLE parquet_table(a varchar, b int, c float) STORED AS PARQUET +LOCATION 'test_files/scratch/schema_evolution/parquet_table/'; + +# File1 has only columns a and b +statement ok +COPY ( + SELECT column1 as a, column2 as b + FROM ( VALUES ('foo', 1), ('foo', 2), ('foo', 3) ) + ) TO 'test_files/scratch/schema_evolution/parquet_table/1.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); + + +# File2 has only b +statement ok +COPY ( + SELECT column1 as b + FROM ( VALUES (10) ) + ) TO 'test_files/scratch/schema_evolution/parquet_table/2.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); + +# File3 has a column 'z' which does not appear in the table, +# but also values for a which do appear in the table +statement ok +COPY ( + SELECT column1 as z, column2 as a + FROM ( VALUES ('bar', 'foo'), ('blarg', 'foo') ) + ) TO 'test_files/scratch/schema_evolution/parquet_table/3.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); + +# File4 has data for b and a (reversed) and c +statement ok +COPY ( + SELECT column1 as b, column2 as a, column3 as c + FROM ( VALUES (100, 'foo', 10.5), (200, 'foo', 12.6), (300, 'bzz', 13.7) ) + ) TO 'test_files/scratch/schema_evolution/parquet_table/4.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); + +# The logical distribution of `a`, `b` and `c` in the files is like this: + +## File1: +# foo 1 NULL +# foo 2 NULL +# foo 3 NULL +# +## File2: +# NULL 10 NULL +# +## File3: +# foo NULL NULL +# foo NULL NULL +# +## File4: +# foo 100 10.5 +# foo 200 12.6 +# bzz 300 13.7 + +# Show all the data +query TIR rowsort +select * from parquet_table; +---- +NULL 10 NULL +bzz 300 13.7 +foo 1 NULL +foo 100 10.5 +foo 2 NULL +foo 200 12.6 +foo 3 NULL +foo NULL NULL +foo NULL NULL + +# Should see all 7 rows that have 'a=foo' +query TIR rowsort +select * from parquet_table where a = 'foo'; +---- +foo 1 NULL +foo 100 10.5 +foo 2 NULL +foo 200 12.6 +foo 3 NULL +foo NULL NULL +foo NULL NULL + +query TIR rowsort +select * from parquet_table where a != 'foo'; +---- +bzz 300 13.7 + +# this should produce at least one row +query TIR rowsort +select * from parquet_table where a is NULL; +---- +NULL 10 NULL + +query TIR rowsort +select * from parquet_table where b > 5; +---- +NULL 10 NULL +bzz 300 13.7 +foo 100 10.5 +foo 200 12.6 + + +query TIR rowsort +select * from parquet_table where b < 150; +---- +NULL 10 NULL +foo 1 NULL +foo 100 10.5 +foo 2 NULL +foo 3 NULL + +query TIR rowsort +select * from parquet_table where c > 11.0; +---- +bzz 300 13.7 +foo 200 12.6 diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index bb81c5a9a138..ea570b99d4dd 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -868,6 +868,21 @@ statement error DataFusion error: Error during planning: EXCLUDE or EXCEPT conta SELECT * EXCLUDE(d, b, c, a, a, b, c, d) FROM table1 +# avoid adding an alias if the column name is the same +query TT +EXPLAIN select a as a FROM table1 order by a +---- +logical_plan +Sort: table1.a ASC NULLS LAST +--TableScan: table1 projection=[a] +physical_plan +SortExec: expr=[a@0 ASC NULLS LAST] +--MemoryExec: partitions=1, partition_sizes=[1] + +# ambiguous column references in a join +query error DataFusion error: Schema error: Ambiguous reference to unqualified field a +EXPLAIN select a as a FROM table1 t1 CROSS JOIN
table1 t2 order by a + # run below query in multi partitions statement ok set datafusion.execution.target_partitions = 2; @@ -1041,3 +1056,51 @@ drop table annotated_data_finite2; statement ok drop table t; + +statement ok +create table t(x bigint, y bigint) as values (1,2), (1,3); + +query II +select z+1, y from (select x+1 as z, y from t) where y > 1; +---- +3 2 +3 3 + +query TT +EXPLAIN SELECT x/2, x/2+1 FROM t; +---- +logical_plan +Projection: t.x / Int64(2)Int64(2)t.x AS t.x / Int64(2), t.x / Int64(2)Int64(2)t.x AS t.x / Int64(2) + Int64(1) +--Projection: t.x / Int64(2) AS t.x / Int64(2)Int64(2)t.x +----TableScan: t projection=[x] +physical_plan +ProjectionExec: expr=[t.x / Int64(2)Int64(2)t.x@0 as t.x / Int64(2), t.x / Int64(2)Int64(2)t.x@0 + 1 as t.x / Int64(2) + Int64(1)] +--ProjectionExec: expr=[x@0 / 2 as t.x / Int64(2)Int64(2)t.x] +----MemoryExec: partitions=1, partition_sizes=[1] + +query II +SELECT x/2, x/2+1 FROM t; +---- +0 1 +0 1 + +query TT +EXPLAIN SELECT abs(x), abs(x) + abs(y) FROM t; +---- +logical_plan +Projection: abs(t.x)t.x AS abs(t.x), abs(t.x)t.x AS abs(t.x) + abs(t.y) +--Projection: abs(t.x) AS abs(t.x)t.x, t.y +----TableScan: t projection=[x, y] +physical_plan +ProjectionExec: expr=[abs(t.x)t.x@0 as abs(t.x), abs(t.x)t.x@0 + abs(y@1) as abs(t.x) + abs(t.y)] +--ProjectionExec: expr=[abs(x@0) as abs(t.x)t.x, y@1 as y] +----MemoryExec: partitions=1, partition_sizes=[1] + +query II +SELECT abs(x), abs(x) + abs(y) FROM t; +---- +1 3 +1 4 + +statement ok +DROP TABLE t; diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 430e676fa477..3e0fcb7aa96e 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -49,6 +49,13 @@ CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES (44, 'x', 3), (55, 'w', 3); +statement ok +CREATE TABLE t3(t3_id INT PRIMARY KEY, t3_name TEXT, t3_int INT) AS VALUES +(11, 'e', 3), +(22, 'f', 1), +(44, 'g', 3), +(55, 'h', 3); + statement ok CREATE EXTERNAL TABLE IF NOT EXISTS customer ( c_custkey BIGINT, @@ -419,6 +426,17 @@ SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2 statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int) as t2_int from t1 +#non_aggregated_correlated_scalar_subquery_unique +query II rowsort +SELECT t1_id, (SELECT t3_int FROM t3 WHERE t3.t3_id = t1.t1_id) as t3_int from t1 +---- +11 3 +22 1 +33 NULL +44 3 + + +#non_aggregated_correlated_scalar_subquery statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1_int group by t2_int) as t2_int from t1 diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index 3830d8f86812..71b6ddf33f39 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -291,6 +291,35 @@ SELECT COUNT(*) FROM ts_data_secs where ts > to_timestamp_seconds('2020-09-08T12 ---- 2 + +# to_timestamp float inputs + +query PPP +SELECT to_timestamp(1.1) as c1, cast(1.1 as timestamp) as c2, 1.1::timestamp as c3; +---- +1970-01-01T00:00:01.100 1970-01-01T00:00:01.100 1970-01-01T00:00:01.100 + +query PPP 
+SELECT to_timestamp(-1.1) as c1, cast(-1.1 as timestamp) as c2, (-1.1)::timestamp as c3; +---- +1969-12-31T23:59:58.900 1969-12-31T23:59:58.900 1969-12-31T23:59:58.900 + +query PPP +SELECT to_timestamp(0.0) as c1, cast(0.0 as timestamp) as c2, 0.0::timestamp as c3; +---- +1970-01-01T00:00:00 1970-01-01T00:00:00 1970-01-01T00:00:00 + +query PPP +SELECT to_timestamp(1.23456789) as c1, cast(1.23456789 as timestamp) as c2, 1.23456789::timestamp as c3; +---- +1970-01-01T00:00:01.234567890 1970-01-01T00:00:01.234567890 1970-01-01T00:00:01.234567890 + +query PPP +SELECT to_timestamp(123456789.123456789) as c1, cast(123456789.123456789 as timestamp) as c2, 123456789.123456789::timestamp as c3; +---- +1973-11-29T21:33:09.123456784 1973-11-29T21:33:09.123456784 1973-11-29T21:33:09.123456784 + + # from_unixtime # 1599566400 is '2020-09-08T12:00:00+00:00' diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index 0f255cdb9fb9..b4e338875e24 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -82,6 +82,11 @@ SELECT 2 as x 1 2 +query I +select count(*) from (select id from t1 union all select id from t2) +---- +6 + # csv_union_all statement ok CREATE EXTERNAL TABLE aggregate_test_100 ( @@ -546,11 +551,11 @@ UnionExec ------CoalesceBatchesExec: target_batch_size=2 --------RepartitionExec: partitioning=Hash([Int64(1)@0], 4), input_partitions=1 ----------AggregateExec: mode=Partial, gby=[1 as Int64(1)], aggr=[] -------------EmptyExec: produce_one_row=true +------------PlaceholderRowExec --ProjectionExec: expr=[2 as a] -----EmptyExec: produce_one_row=true +----PlaceholderRowExec --ProjectionExec: expr=[3 as a] -----EmptyExec: produce_one_row=true +----PlaceholderRowExec # test UNION ALL aliases correctly with aliased subquery query TT @@ -578,7 +583,7 @@ UnionExec --------RepartitionExec: partitioning=Hash([n@0], 4), input_partitions=1 ----------AggregateExec: mode=Partial, gby=[n@0 as n], aggr=[COUNT(*)] ------------ProjectionExec: expr=[5 as n] ---------------EmptyExec: produce_one_row=true +--------------PlaceholderRowExec --ProjectionExec: expr=[1 as count, MAX(Int64(10))@0 as n] ----AggregateExec: mode=Single, gby=[], aggr=[MAX(Int64(10))] -------EmptyExec: produce_one_row=true +------PlaceholderRowExec diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index b2491478d84e..f3de5b54fc8b 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -279,13 +279,13 @@ SortPreservingMergeExec: [b@0 ASC NULLS LAST] ------------AggregateExec: mode=Partial, gby=[b@1 as b], aggr=[MAX(d.a)] --------------UnionExec ----------------ProjectionExec: expr=[1 as a, aa as b] -------------------EmptyExec: produce_one_row=true +------------------PlaceholderRowExec ----------------ProjectionExec: expr=[3 as a, aa as b] -------------------EmptyExec: produce_one_row=true +------------------PlaceholderRowExec ----------------ProjectionExec: expr=[5 as a, bb as b] -------------------EmptyExec: produce_one_row=true +------------------PlaceholderRowExec ----------------ProjectionExec: expr=[7 as a, bb as b] -------------------EmptyExec: produce_one_row=true +------------------PlaceholderRowExec # Check actual result: query TI @@ -365,13 +365,13 @@ SortPreservingMergeExec: [b@0 ASC NULLS LAST] --------------RepartitionExec: partitioning=Hash([b@1], 4), input_partitions=4 ----------------UnionExec 
------------------ProjectionExec: expr=[1 as a, aa as b] ---------------------EmptyExec: produce_one_row=true +--------------------PlaceholderRowExec ------------------ProjectionExec: expr=[3 as a, aa as b] ---------------------EmptyExec: produce_one_row=true +--------------------PlaceholderRowExec ------------------ProjectionExec: expr=[5 as a, bb as b] ---------------------EmptyExec: produce_one_row=true +--------------------PlaceholderRowExec ------------------ProjectionExec: expr=[7 as a, bb as b] ---------------------EmptyExec: produce_one_row=true +--------------------PlaceholderRowExec # check actual result @@ -3581,3 +3581,215 @@ CREATE TABLE new_table AS SELECT NTILE(2) OVER(ORDER BY c1) AS ntile_2 FROM aggr statement ok DROP TABLE new_table; + +statement ok +CREATE TABLE t1 (a int) AS VALUES (1), (2), (3); + +query I +SELECT NTILE(9223377) OVER(ORDER BY a) FROM t1; +---- +1 +2 +3 + +query I +SELECT NTILE(9223372036854775809) OVER(ORDER BY a) FROM t1; +---- +1 +2 +3 + +query error DataFusion error: Execution error: NTILE requires a positive integer +SELECT NTILE(-922337203685477580) OVER(ORDER BY a) FROM t1; + +query error DataFusion error: Execution error: Table 't' doesn't exist\. +DROP TABLE t; + +# NTILE with PARTITION BY, those tests from duckdb: https://github.com/duckdb/duckdb/blob/main/test/sql/window/test_ntile.test +statement ok +CREATE TABLE score_board (team_name VARCHAR, player VARCHAR, score INTEGER) as VALUES + ('Mongrels', 'Apu', 350), + ('Mongrels', 'Ned', 666), + ('Mongrels', 'Meg', 1030), + ('Mongrels', 'Burns', 1270), + ('Simpsons', 'Homer', 1), + ('Simpsons', 'Lisa', 710), + ('Simpsons', 'Marge', 990), + ('Simpsons', 'Bart', 2010) + +query TTII +SELECT + team_name, + player, + score, + NTILE(2) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s +ORDER BY team_name, score; +---- +Mongrels Apu 350 1 +Mongrels Ned 666 1 +Mongrels Meg 1030 2 +Mongrels Burns 1270 2 +Simpsons Homer 1 1 +Simpsons Lisa 710 1 +Simpsons Marge 990 2 +Simpsons Bart 2010 2 + +query TTII +SELECT + team_name, + player, + score, + NTILE(2) OVER (ORDER BY score ASC) AS NTILE +FROM score_board s +ORDER BY score; +---- +Simpsons Homer 1 1 +Mongrels Apu 350 1 +Mongrels Ned 666 1 +Simpsons Lisa 710 1 +Simpsons Marge 990 2 +Mongrels Meg 1030 2 +Mongrels Burns 1270 2 +Simpsons Bart 2010 2 + +query TTII +SELECT + team_name, + player, + score, + NTILE(1000) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s +ORDER BY team_name, score; +---- +Mongrels Apu 350 1 +Mongrels Ned 666 2 +Mongrels Meg 1030 3 +Mongrels Burns 1270 4 +Simpsons Homer 1 1 +Simpsons Lisa 710 2 +Simpsons Marge 990 3 +Simpsons Bart 2010 4 + +query TTII +SELECT + team_name, + player, + score, + NTILE(1) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s +ORDER BY team_name, score; +---- +Mongrels Apu 350 1 +Mongrels Ned 666 1 +Mongrels Meg 1030 1 +Mongrels Burns 1270 1 +Simpsons Homer 1 1 +Simpsons Lisa 710 1 +Simpsons Marge 990 1 +Simpsons Bart 2010 1 + +# incorrect number of parameters for ntile +query error DataFusion error: Execution error: NTILE requires a positive integer, but finds NULL +SELECT + NTILE(NULL) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +query error DataFusion error: Execution error: NTILE requires a positive integer +SELECT + NTILE(-1) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +query error DataFusion error: Execution error: NTILE requires a positive integer 
+SELECT + NTILE(0) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement error +SELECT + NTILE() OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement error +SELECT + NTILE(1,2) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement error +SELECT + NTILE(1,2,3) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement error +SELECT + NTILE(1,2,3,4) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement ok +DROP TABLE score_board; + +# Regularize RANGE frame +query error DataFusion error: Error during planning: RANGE requires exactly one ORDER BY column +select a, + rank() over (order by a, a + 1 RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a + +query II +select a, + rank() over (order by a RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 2 + +query error DataFusion error: Error during planning: RANGE requires exactly one ORDER BY column +select a, + rank() over (RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a + +query II +select a, + rank() over (order by a, a + 1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 2 + +query II +select a, + rank() over (order by a RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 2 + +query II +select a, + rank() over (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 1 + +query I +select rank() over (RANGE between UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q; +---- +1 +1 + +query II +select a, + rank() over (order by 1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 1 + +query II +select a, + rank() over (order by null RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 1 diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index cf05d814a5cb..ffc9d094ab91 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -434,6 +434,15 @@ pub async fn from_substrait_rel( None => plan_err!("JoinRel without join condition is not allowed"), } } + Some(RelType::Cross(cross)) => { + let left: LogicalPlanBuilder = LogicalPlanBuilder::from( + from_substrait_rel(ctx, cross.left.as_ref().unwrap(), extensions).await?, + ); + let right = + from_substrait_rel(ctx, cross.right.as_ref().unwrap(), extensions) + .await?; + left.cross_join(right)?.build() + } Some(RelType::Read(read)) => match &read.as_ref().read_type { Some(ReadType::NamedTable(nt)) => { let table_reference = match nt.names.len() { diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index d576e70711df..c5f1278be6e0 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -19,7 +19,7 @@ use std::collections::HashMap; use std::ops::Deref; use std::sync::Arc; -use 
datafusion::logical_expr::{Distinct, Like, WindowFrameUnits}; +use datafusion::logical_expr::{CrossJoin, Distinct, Like, WindowFrameUnits}; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, error::{DataFusionError, Result}, @@ -40,6 +40,7 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Opera use datafusion::prelude::Expr; use prost_types::Any as ProtoAny; use substrait::proto::expression::window_function::BoundsType; +use substrait::proto::CrossRel; use substrait::{ proto::{ aggregate_function::AggregationInvocation, @@ -332,6 +333,23 @@ pub fn to_substrait_rel( }))), })) } + LogicalPlan::CrossJoin(cross_join) => { + let CrossJoin { + left, + right, + schema: _, + } = cross_join; + let left = to_substrait_rel(left.as_ref(), ctx, extension_info)?; + let right = to_substrait_rel(right.as_ref(), ctx, extension_info)?; + Ok(Box::new(Rel { + rel_type: Some(RelType::Cross(Box::new(CrossRel { + common: None, + left: Some(left), + right: Some(right), + advanced_extension: None, + }))), + })) + } LogicalPlan::SubqueryAlias(alias) => { // Do nothing if encounters SubqueryAlias // since there is no corresponding relation type in Substrait diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 1c5dbe9ce884..691fba864449 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -394,6 +394,11 @@ async fn roundtrip_inlist_4() -> Result<()> { roundtrip("SELECT * FROM data WHERE f NOT IN ('a', 'b', 'c', 'd')").await } +#[tokio::test] +async fn roundtrip_cross_join() -> Result<()> { + roundtrip("SELECT * FROM data CROSS JOIN data2").await +} + #[tokio::test] async fn roundtrip_inner_join() -> Result<()> { roundtrip("SELECT data.a FROM data JOIN data2 ON data.a = data2.a").await diff --git a/docs/source/library-user-guide/custom-table-providers.md b/docs/source/library-user-guide/custom-table-providers.md index ca0e9de779ef..9da207da68f3 100644 --- a/docs/source/library-user-guide/custom-table-providers.md +++ b/docs/source/library-user-guide/custom-table-providers.md @@ -25,7 +25,7 @@ This section will also touch on how to have DataFusion use the new `TableProvide ## Table Provider and Scan -The `scan` method on the `TableProvider` is likely its most important. It returns an `ExecutionPlan` that DataFusion will use to read the actual data during execution o the query. +The `scan` method on the `TableProvider` is likely its most important. It returns an `ExecutionPlan` that DataFusion will use to read the actual data during execution of the query. ### Scan diff --git a/docs/source/library-user-guide/working-with-exprs.md b/docs/source/library-user-guide/working-with-exprs.md index a8baf24d5f0a..96be8ef7f1ae 100644 --- a/docs/source/library-user-guide/working-with-exprs.md +++ b/docs/source/library-user-guide/working-with-exprs.md @@ -17,7 +17,7 @@ under the License. --> -# Working with Exprs +# Working with `Expr`s @@ -48,12 +48,11 @@ As another example, the SQL expression `a + b * c` would be represented as an `E └────────────────────┘ └────────────────────┘ ``` -As the writer of a library, you may want to use or create `Expr`s to represent computations that you want to perform. This guide will walk you through how to make your own scalar UDF as an `Expr` and how to rewrite `Expr`s to inline the simple UDF. +As the writer of a library, you can use `Expr`s to represent computations that you want to perform. 
This guide will walk you through how to make your own scalar UDF as an `Expr` and how to rewrite `Expr`s to inline the simple UDF. -There are also executable examples for working with `Expr`s: +## Creating and Evaluating `Expr`s -- [rewrite_expr.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/rewrite_expr.rs) -- [expr_api.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/expr_api.rs) +Please see [expr_api.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/expr_api.rs) for well-commented code for creating, evaluating, simplifying, and analyzing `Expr`s. ## A Scalar UDF Example @@ -79,7 +78,9 @@ let expr = add_one_udf.call(vec![col("my_column")]); If you'd like to learn more about `Expr`s, before we get into the details of creating and rewriting them, you can read the [expression user-guide](./../user-guide/expressions.md). -## Rewriting Exprs +## Rewriting `Expr`s + +[rewrite_expr.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/rewrite_expr.rs) contains example code for rewriting `Expr`s. Rewriting Expressions is the process of taking an `Expr` and transforming it into another `Expr`. This is useful for a number of reasons, including: diff --git a/docs/source/user-guide/cli.md b/docs/source/user-guide/cli.md index e8fdae7bb097..525ab090ce51 100644 --- a/docs/source/user-guide/cli.md +++ b/docs/source/user-guide/cli.md @@ -31,7 +31,9 @@ The easiest way to install DataFusion CLI a spin is via `cargo install datafusio ### Install and run using Homebrew (on MacOS) -DataFusion CLI can also be installed via Homebrew (on MacOS). Install it as any other pre-built software like this: +DataFusion CLI can also be installed via Homebrew (on MacOS). If you don't have Homebrew installed, you can check how to install it [here](https://docs.brew.sh/Installation). + +Install it as any other pre-built software like this: ```bash brew install datafusion @@ -46,6 +48,34 @@ brew install datafusion datafusion-cli ``` +### Install and run using PyPI + +DataFusion CLI can also be installed via PyPI. You can check how to install pip [here](https://pip.pypa.io/en/latest/installation/).
+ +Install it as any other pre-built software like this: + +```bash +pip3 install datafusion +# Defaulting to user installation because normal site-packages is not writeable +# Collecting datafusion +# Downloading datafusion-33.0.0-cp38-abi3-macosx_11_0_arm64.whl.metadata (9.6 kB) +# Collecting pyarrow>=11.0.0 (from datafusion) +# Downloading pyarrow-14.0.1-cp39-cp39-macosx_11_0_arm64.whl.metadata (3.0 kB) +# Requirement already satisfied: numpy>=1.16.6 in /Users/Library/Python/3.9/lib/python/site-packages (from pyarrow>=11.0.0->datafusion) (1.23.4) +# Downloading datafusion-33.0.0-cp38-abi3-macosx_11_0_arm64.whl (13.5 MB) +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.5/13.5 MB 3.6 MB/s eta 0:00:00 +# Downloading pyarrow-14.0.1-cp39-cp39-macosx_11_0_arm64.whl (24.0 MB) +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 24.0/24.0 MB 36.4 MB/s eta 0:00:00 +# Installing collected packages: pyarrow, datafusion +# Attempting uninstall: pyarrow +# Found existing installation: pyarrow 10.0.1 +# Uninstalling pyarrow-10.0.1: +# Successfully uninstalled pyarrow-10.0.1 +# Successfully installed datafusion-33.0.0 pyarrow-14.0.1 + +datafusion-cli +``` + ### Run using Docker There is no officially published Docker image for the DataFusion CLI, so it is necessary to build from source diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 11363f0657f6..d5a43e429e09 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -99,6 +99,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top down process to reorder the join keys | | datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory | | datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition | +| datafusion.optimizer.default_filter_selectivity | 20 | The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). | | datafusion.explain.logical_plan_only | false | When set to true, the explain statement will only print logical plans | | datafusion.explain.physical_plan_only | false | When set to true, the explain statement will only print physical plans | | datafusion.explain.show_statistics | false | When set to true, the explain statement will print operator statistics for physical plans | diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 257c50dfa497..b8689e556741 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -215,6 +215,7 @@ Unlike to some databases the math functions in Datafusion works the same way as | array_has_all(array, sub-array) | Returns true if all elements of sub-array exist in array `array_has_all([1,2,3], [1,3]) -> true` | | array_has_any(array, sub-array) | Returns true if any elements exist in both arrays `array_has_any([1,2,3], [1,4]) -> true` | | array_dims(array) | Returns an array of the array's dimensions. 
`array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]` | +| array_distinct(array) | Returns the distinct values from the array. `array_distinct([1, 3, 2, 3, 1, 2, 4]) -> [1, 2, 3, 4]` | | array_element(array, index) | Extracts the element with the index n from the array `array_element([1, 2, 3, 4], 3) -> 3` | | flatten(array) | Converts an array of arrays to a flat array `flatten([[1], [2, 3], [4, 5, 6]]) -> [1, 2, 3, 4, 5, 6]` | | array_length(array, dimension) | Returns the length of the array dimension. `array_length([1, 2, 3, 4, 5]) -> 5` | diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index c0889d94dbac..9a9bec9df77b 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1457,11 +1457,14 @@ extract(field FROM source) ### `to_timestamp` Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). -Supports strings, integer, and unsigned integer types as input. +Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) +Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) and return the corresponding timestamp. +Note: `to_timestamp` returns `Timestamp(Nanosecond)`. The supported range for integer input is between `-9223372037` and `9223372036`. +The supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` for input outside of the supported bounds. + ``` to_timestamp(expression) ``` @@ -1552,6 +1555,7 @@ from_unixtime(expression) ## Array Functions - [array_append](#array_append) +- [array_sort](#array_sort) - [array_cat](#array_cat) - [array_concat](#array_concat) - [array_contains](#array_contains) @@ -1581,6 +1585,7 @@ from_unixtime(expression) - [cardinality](#cardinality) - [empty](#empty) - [list_append](#list_append) +- [list_sort](#list_sort) - [list_cat](#list_cat) - [list_concat](#list_concat) - [list_dims](#list_dims) @@ -1642,6 +1647,36 @@ array_append(array, element) - list_append - list_push_back +### `array_sort` + +Sorts the array. + +``` +array_sort(array, desc, nulls_first) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **desc**: Whether to sort in descending order (`ASC` or `DESC`). +- **nulls_first**: Whether to sort nulls first (`NULLS FIRST` or `NULLS LAST`). + +#### Example + +``` +❯ select array_sort([3, 1, 2]); ++-----------------------------+ +| array_sort(List([3,1,2])) | ++-----------------------------+ +| [1, 2, 3] | ++-----------------------------+ +``` + +#### Aliases + +- list_sort + ### `array_cat` _Alias of [array_concat](#array_concat)._ @@ -2368,7 +2403,7 @@ array_except(array1, array2) +----------------------------------------------------+ | array_except([1, 2, 3, 4], [3, 4, 5, 6]); | +----------------------------------------------------+ -| [3, 4] | +| [1, 2] | +----------------------------------------------------+ ``` @@ -2430,6 +2465,10 @@ empty(array) _Alias of [array_append](#array_append)._ +### `list_sort` + +_Alias of [array_sort](#array_sort)._ + ### `list_cat` _Alias of [array_concat](#array_concat)._
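The Substrait changes above come in matched halves: `to_substrait_rel` gains a `LogicalPlan::CrossJoin -> CrossRel` arm, `from_substrait_rel` gains the inverse `RelType::Cross` arm, and `roundtrip_cross_join` checks that a plan survives the round trip. A standalone sketch of such a round trip might look like the following. Treat it as illustrative only: the plan-level helpers `to_substrait_plan`/`from_substrait_plan`, the mutability of the context argument, and the CSV file names are assumptions, not things shown in this diff.

```rust
use datafusion::error::Result;
use datafusion::prelude::{CsvReadOptions, SessionContext};
// Assumed plan-level entry points; the hunks above only show the
// Rel-level functions to_substrait_rel / from_substrait_rel.
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
use datafusion_substrait::logical_plan::producer::to_substrait_plan;

#[tokio::main]
async fn main() -> Result<()> {
    let mut ctx = SessionContext::new();
    // Hypothetical input files; any two registered tables would do.
    ctx.register_csv("data", "data.csv", CsvReadOptions::new()).await?;
    ctx.register_csv("data2", "data2.csv", CsvReadOptions::new()).await?;

    let plan = ctx
        .sql("SELECT * FROM data CROSS JOIN data2")
        .await?
        .into_optimized_plan()?;

    // LogicalPlan -> substrait::proto::Plan, through the new CrossRel arm ...
    let proto = to_substrait_plan(&plan, &ctx)?;
    // ... and back, through the new RelType::Cross arm.
    let plan2 = from_substrait_plan(&mut ctx, &proto).await?;

    // As in the roundtrip tests, the two plans should print identically.
    assert_eq!(format!("{plan:?}"), format!("{plan2:?}"));
    Ok(())
}
```

Leaving `common: None` and `advanced_extension: None` in the produced `CrossRel` mirrors how the producer fills in the other relation types.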
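On the timestamp side, the documented `to_timestamp` integer bounds and the `...09.123456784` expectations in the new sqllogictest cases both follow from `Timestamp(Nanosecond)` being an `i64` count of nanoseconds, reached for doubles via `f64` arithmetic. Below is a dependency-free sketch of that arithmetic; the exact rounding inside arrow's cast kernels may differ slightly.

```rust
fn main() {
    // Timestamp(Nanosecond) stores i64 nanoseconds since the unix epoch,
    // so the largest whole-second integer input is i64::MAX / 10^9.
    println!("{}", i64::MAX / 1_000_000_000); // 9223372036

    // An f64 carries just under 16 significant decimal digits, so the
    // 18-digit literal below cannot survive the trip through f64 intact:
    let nanos = (123456789.123456789_f64 * 1e9) as i64;
    println!("{nanos}"); // 123456789123456784, i.e. ...T21:33:09.123456784
}
```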