diff --git a/benchmarks/src/bin/dfbench.rs b/benchmarks/src/bin/dfbench.rs
index 81aa5437dd5f..887d23a9d7c3 100644
--- a/benchmarks/src/bin/dfbench.rs
+++ b/benchmarks/src/bin/dfbench.rs
@@ -53,12 +53,12 @@ pub async fn main() -> Result<()> {
     env_logger::init();
 
     match Options::from_args() {
-        Options::Tpch(opt) => opt.run().await,
+        Options::Tpch(opt) => Box::pin(opt.run()).await,
         Options::TpchConvert(opt) => opt.run().await,
         Options::Clickbench(opt) => opt.run().await,
         Options::ParquetFilter(opt) => opt.run().await,
         Options::Sort(opt) => opt.run().await,
         Options::SortTpch(opt) => opt.run().await,
-        Options::Imdb(opt) => opt.run().await,
+        Options::Imdb(opt) => Box::pin(opt.run()).await,
     }
 }
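Why `Box::pin(opt.run()).await` instead of `opt.run().await`: with the catalog API now async, `run()`'s generated future presumably grows much larger, and boxing it moves that state machine to the heap so the enclosing `main` future stays small. A minimal sketch of the pattern, with a hypothetical `run_benchmark` standing in for `opt.run()`:

    use std::future::Future;
    use std::pin::Pin;

    // Hypothetical stand-in for a benchmark's `opt.run()` whose generated
    // future is large.
    async fn run_benchmark() { /* ... large async state machine ... */ }

    async fn dispatch() {
        // Box::pin moves the future's state machine to the heap; only a
        // pointer-sized handle lives inside `dispatch`'s own future.
        let boxed: Pin<Box<dyn Future<Output = ()>>> = Box::pin(run_benchmark());
        boxed.await;
    }

The same pattern appears again in the `imdb` and `tpch` binaries below.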
diff --git a/benchmarks/src/bin/external_aggr.rs b/benchmarks/src/bin/external_aggr.rs
index 6438593a20a0..86e5f7cbf814 100644
--- a/benchmarks/src/bin/external_aggr.rs
+++ b/benchmarks/src/bin/external_aggr.rs
@@ -245,9 +245,9 @@ impl ExternalAggrConfig {
                 table,
                 start.elapsed().as_millis()
             );
-            ctx.register_table(table, Arc::new(memtable))?;
+            ctx.register_table(table, Arc::new(memtable)).await?;
         } else {
-            ctx.register_table(table, table_provider)?;
+            ctx.register_table(table, table_provider).await?;
         }
     }
     Ok(())
diff --git a/benchmarks/src/bin/h2o.rs b/benchmarks/src/bin/h2o.rs
index 1ddeb786a591..81dd0a4cc7fd 100644
--- a/benchmarks/src/bin/h2o.rs
+++ b/benchmarks/src/bin/h2o.rs
@@ -94,7 +94,7 @@ async fn group_by(opt: &GroupBy) -> Result<()> {
         let partition_size = num_cpus::get();
         let memtable =
             MemTable::load(Arc::new(csv), Some(partition_size), &ctx.state()).await?;
-        ctx.register_table("x", Arc::new(memtable))?;
+        ctx.register_table("x", Arc::new(memtable)).await?;
     } else {
         ctx.register_csv("x", path, CsvReadOptions::default().schema(&schema))
             .await?;
diff --git a/benchmarks/src/bin/imdb.rs b/benchmarks/src/bin/imdb.rs
index 13421f8a89a9..5ce99928df66 100644
--- a/benchmarks/src/bin/imdb.rs
+++ b/benchmarks/src/bin/imdb.rs
@@ -53,7 +53,7 @@ pub async fn main() -> Result<()> {
     env_logger::init();
     match ImdbOpt::from_args() {
         ImdbOpt::Benchmark(BenchmarkSubCommandOpt::DataFusionBenchmark(opt)) => {
-            opt.run().await
+            Box::pin(opt.run()).await
         }
         ImdbOpt::Convert(opt) => opt.run().await,
     }
diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs
index 3270b082cfb4..ca2bb8e57c0e 100644
--- a/benchmarks/src/bin/tpch.rs
+++ b/benchmarks/src/bin/tpch.rs
@@ -58,7 +58,7 @@ async fn main() -> Result<()> {
     env_logger::init();
     match TpchOpt::from_args() {
         TpchOpt::Benchmark(BenchmarkSubCommandOpt::DataFusionBenchmark(opt)) => {
-            opt.run().await
+            Box::pin(opt.run()).await
         }
         TpchOpt::Convert(opt) => opt.run().await,
     }
diff --git a/benchmarks/src/imdb/run.rs b/benchmarks/src/imdb/run.rs
index 47c356990881..9ea099d2fcb5 100644
--- a/benchmarks/src/imdb/run.rs
+++ b/benchmarks/src/imdb/run.rs
@@ -358,9 +358,9 @@ impl RunOpt {
                 table,
                 start.elapsed().as_millis()
             );
-            ctx.register_table(*table, Arc::new(memtable))?;
+            ctx.register_table(*table, Arc::new(memtable)).await?;
         } else {
-            ctx.register_table(*table, table_provider)?;
+            ctx.register_table(*table, table_provider).await?;
         }
     }
     Ok(())
diff --git a/benchmarks/src/sort_tpch.rs b/benchmarks/src/sort_tpch.rs
index 4b83b3b8889a..546fb4267860 100644
--- a/benchmarks/src/sort_tpch.rs
+++ b/benchmarks/src/sort_tpch.rs
@@ -236,9 +236,9 @@ impl RunOpt {
                 table,
                 start.elapsed().as_millis()
             );
-            ctx.register_table(table, Arc::new(memtable))?;
+            ctx.register_table(table, Arc::new(memtable)).await?;
         } else {
-            ctx.register_table(table, table_provider)?;
+            ctx.register_table(table, table_provider).await?;
         }
     }
     Ok(())
diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs
index 9ff1f72d8606..531693d0d4a3 100644
--- a/benchmarks/src/tpch/run.rs
+++ b/benchmarks/src/tpch/run.rs
@@ -182,9 +182,9 @@ impl RunOpt {
                 table,
                 start.elapsed().as_millis()
             );
-            ctx.register_table(*table, Arc::new(memtable))?;
+            ctx.register_table(*table, Arc::new(memtable)).await?;
         } else {
-            ctx.register_table(*table, table_provider)?;
+            ctx.register_table(*table, table_provider).await?;
         }
     }
     Ok(())
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index d2a92fea311e..946c7600768e 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -1248,6 +1248,7 @@ dependencies = [
 "datafusion-execution",
 "datafusion-expr",
 "datafusion-physical-plan",
+ "futures",
 "parking_lot",
]
diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs
index ceb72dbc546b..81c92e42a0ed 100644
--- a/datafusion-cli/src/catalog.rs
+++ b/datafusion-cli/src/catalog.rs
@@ -31,6 +31,7 @@ use datafusion::execution::session_state::SessionStateBuilder;
 
 use async_trait::async_trait;
 use dirs::home_dir;
+use futures::stream::BoxStream;
 use parking_lot::RwLock;
 
 /// Wraps another catalog, automatically registering the required object stores for the file locations
@@ -49,28 +50,29 @@ impl DynamicObjectStoreCatalog {
     }
 }
 
+#[async_trait]
 impl CatalogProviderList for DynamicObjectStoreCatalog {
     fn as_any(&self) -> &dyn Any {
         self
     }
 
-    fn register_catalog(
+    async fn register_catalog(
         &self,
         name: String,
         catalog: Arc<dyn CatalogProvider>,
-    ) -> Option<Arc<dyn CatalogProvider>> {
-        self.inner.register_catalog(name, catalog)
+    ) -> Result<Option<Arc<dyn CatalogProvider>>> {
+        self.inner.register_catalog(name, catalog).await
     }
 
-    fn catalog_names(&self) -> Vec<String> {
-        self.inner.catalog_names()
+    async fn catalog_names(&self) -> BoxStream<'static, Result<String>> {
+        self.inner.catalog_names().await
     }
 
-    fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
+    async fn catalog(&self, name: &str) -> Result<Option<Arc<dyn CatalogProvider>>> {
         let state = self.state.clone();
-        self.inner.catalog(name).map(|catalog| {
+        Ok(self.inner.catalog(name).await?.map(|catalog| {
             Arc::new(DynamicObjectStoreCatalogProvider::new(catalog, state)) as _
-        })
+        }))
     }
 }
 
@@ -90,28 +92,29 @@ impl DynamicObjectStoreCatalogProvider {
     }
 }
 
+#[async_trait]
 impl CatalogProvider for DynamicObjectStoreCatalogProvider {
     fn as_any(&self) -> &dyn Any {
         self
     }
 
-    fn schema_names(&self) -> Vec<String> {
-        self.inner.schema_names()
+    async fn schema_names(&self) -> BoxStream<'static, Result<String>> {
+        self.inner.schema_names().await
     }
 
-    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
+    async fn schema(&self, name: &str) -> Result<Option<Arc<dyn SchemaProvider>>> {
         let state = self.state.clone();
-        self.inner.schema(name).map(|schema| {
+        Ok(self.inner.schema(name).await?.map(|schema| {
             Arc::new(DynamicObjectStoreSchemaProvider::new(schema, state)) as _
-        })
+        }))
     }
 
-    fn register_schema(
+    async fn register_schema(
        &self,
        name: &str,
        schema: Arc<dyn SchemaProvider>,
    ) -> Result<Option<Arc<dyn SchemaProvider>>> {
-        self.inner.register_schema(name, schema)
+        self.inner.register_schema(name, schema).await
    }
 }
 
@@ -138,16 +141,16 @@ impl SchemaProvider for DynamicObjectStoreSchemaProvider {
         self
     }
 
-    fn table_names(&self) -> Vec<String> {
-        self.inner.table_names()
+    async fn table_names(&self) -> BoxStream<'static, Result<String>> {
+        self.inner.table_names().await
     }
 
-    fn register_table(
+    async fn register_table(
         &self,
         name: String,
         table: Arc<dyn TableProvider>,
     ) -> Result<Option<Arc<dyn TableProvider>>> {
-        self.inner.register_table(name, table)
+        self.inner.register_table(name, table).await
     }
 
     async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
@@ -166,7 +169,7 @@ impl SchemaProvider for DynamicObjectStoreSchemaProvider {
             .ok_or_else(|| plan_datafusion_err!("locking error"))?
             .read()
             .clone();
-        let mut builder = SessionStateBuilder::from(state.clone());
+        let mut builder = SessionStateBuilder::new_from_existing(state.clone()).await;
         let optimized_name = substitute_tilde(name.to_owned());
         let table_url = ListingTableUrl::parse(optimized_name.as_str())?;
         let scheme = table_url.scheme();
@@ -194,7 +197,7 @@ impl SchemaProvider for DynamicObjectStoreSchemaProvider {
             }
             _ => {}
         };
-        state = builder.build();
+        state = builder.build().await;
         let store = get_object_store(
             &state,
             table_url.scheme(),
@@ -208,12 +211,15 @@ impl SchemaProvider for DynamicObjectStoreSchemaProvider {
         self.inner.table(name).await
     }
 
-    fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
-        self.inner.deregister_table(name)
+    async fn deregister_table(
+        &self,
+        name: &str,
+    ) -> Result<Option<Arc<dyn TableProvider>>> {
+        self.inner.deregister_table(name).await
     }
 
-    fn table_exist(&self, name: &str) -> bool {
-        self.inner.table_exist(name)
+    async fn table_exist(&self, name: &str) -> bool {
+        self.inner.table_exist(name).await
     }
 }
 
@@ -234,8 +240,9 @@ mod tests {
 
     use datafusion::catalog::SchemaProvider;
     use datafusion::prelude::SessionContext;
+    use futures::TryStreamExt;
 
-    fn setup_context() -> (SessionContext, Arc<dyn SchemaProvider>) {
+    async fn setup_context() -> (SessionContext, Arc<dyn SchemaProvider>) {
         let ctx = SessionContext::new();
         ctx.register_catalog_list(Arc::new(DynamicObjectStoreCatalog::new(
             ctx.state().catalog_list().clone(),
@@ -247,10 +254,30 @@ mod tests {
             ctx.state_weak_ref(),
         ) as &dyn CatalogProviderList;
         let catalog = provider
-            .catalog(provider.catalog_names().first().unwrap())
+            .catalog(
+                &provider
+                    .catalog_names()
+                    .await
+                    .try_next()
+                    .await
+                    .unwrap()
+                    .unwrap(),
+            )
+            .await
+            .unwrap()
             .unwrap();
         let schema = catalog
-            .schema(catalog.schema_names().first().unwrap())
+            .schema(
+                &catalog
+                    .schema_names()
+                    .await
+                    .try_next()
+                    .await
+                    .unwrap()
+                    .unwrap(),
+            )
+            .await
+            .unwrap()
             .unwrap();
         (ctx, schema)
     }
@@ -262,7 +289,7 @@ mod tests {
         let domain = "example.com";
         let location = format!("http://{domain}/file.parquet");
 
-        let (ctx, schema) = setup_context();
+        let (ctx, schema) = setup_context().await;
 
         // That's a non-registered table so expecting None here
         let table = schema.table(&location).await?;
@@ -287,7 +314,7 @@ mod tests {
         let bucket = "examples3bucket";
         let location = format!("s3://{bucket}/file.parquet");
 
-        let (ctx, schema) = setup_context();
+        let (ctx, schema) = setup_context().await;
 
         let table = schema.table(&location).await?;
         assert!(table.is_none());
@@ -309,7 +336,7 @@ mod tests {
         let bucket = "examplegsbucket";
         let location = format!("gs://{bucket}/file.parquet");
 
-        let (ctx, schema) = setup_context();
+        let (ctx, schema) = setup_context().await;
 
         let table = schema.table(&location).await?;
         assert!(table.is_none());
@@ -329,7 +356,7 @@ mod tests {
     #[tokio::test]
     async fn query_invalid_location_test() {
         let location = "ts://file.parquet";
-        let (_ctx, schema) = setup_context();
+        let (_ctx, schema) = setup_context().await;
 
         assert!(schema.table(location).await.is_err());
     }
diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs
index 4c6c352ff339..6c01b1f909ad 100644
--- a/datafusion-cli/src/main.rs
+++ b/datafusion-cli/src/main.rs
@@ -176,7 +176,8 @@ async fn main_inner() -> Result<()> {
     // enable dynamic file query
     let ctx =
         SessionContext::new_with_config_rt(session_config.clone(), Arc::new(runtime_env))
-            .enable_url_table();
+            .enable_url_table()
+            .await;
     ctx.refresh_catalogs().await?;
 
     // install dynamic catalog provider that can register required object stores
     ctx.register_catalog_list(Arc::new(DynamicObjectStoreCatalog::new(
diff --git a/datafusion-examples/examples/advanced_parquet_index.rs b/datafusion-examples/examples/advanced_parquet_index.rs
index 67b745d4074e..24fd72f7aa05 100644
--- a/datafusion-examples/examples/advanced_parquet_index.rs
+++ b/datafusion-examples/examples/advanced_parquet_index.rs
@@ -168,7 +168,8 @@ async fn main() -> Result<()> {
     // SessionContext for running queries that has the table provider
     // registered as "index_table"
     let ctx = SessionContext::new();
-    ctx.register_table("index_table", Arc::clone(&provider) as _)?;
+    ctx.register_table("index_table", Arc::clone(&provider) as _)
+        .await?;
 
     // register object store provider so that urls like `file://` work
     let url = Url::try_from("file://").unwrap();
diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs
index 414596bdc678..812cc10781f2 100644
--- a/datafusion-examples/examples/advanced_udaf.rs
+++ b/datafusion-examples/examples/advanced_udaf.rs
@@ -198,7 +198,7 @@ impl Accumulator for GeometricMean {
 }
 
 // create local session context with an in-memory table
-fn create_context() -> Result<SessionContext> {
+async fn create_context() -> Result<SessionContext> {
     use datafusion::datasource::MemTable;
     // define a schema.
     let schema = Arc::new(Schema::new(vec![
@@ -227,7 +227,7 @@ fn create_context() -> Result<SessionContext> {
     // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
     let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
-    ctx.register_table("t", Arc::new(provider))?;
+    ctx.register_table("t", Arc::new(provider)).await?;
 
     Ok(ctx)
 }
@@ -401,7 +401,7 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
 
 #[tokio::main]
 async fn main() -> Result<()> {
-    let ctx = create_context()?;
+    let ctx = create_context().await?;
 
     // create the AggregateUDF
     let geometric_mean = AggregateUDF::from(GeoMeanUdaf::new());
diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/advanced_udf.rs
index aee3be6c9285..22c78de2ecd3 100644
--- a/datafusion-examples/examples/advanced_udf.rs
+++ b/datafusion-examples/examples/advanced_udf.rs
@@ -195,7 +195,7 @@ impl ScalarUDFImpl for PowUdf {
 /// and invoke it via the DataFrame API and SQL
 #[tokio::main]
 async fn main() -> Result<()> {
-    let ctx = create_context()?;
+    let ctx = create_context().await?;
 
     // create the UDF
     let pow = ScalarUDF::from(PowUdf::new());
@@ -234,7 +234,7 @@ async fn main() -> Result<()> {
 /// | 5.1 | 4.0 |
 /// +-----+-----+
 /// ```
-fn create_context() -> Result<SessionContext> {
+async fn create_context() -> Result<SessionContext> {
     // define data.
     let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1]));
     let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]));
@@ -244,6 +244,6 @@ fn create_context() -> Result<SessionContext> {
     let ctx = SessionContext::new();
 
     // declare a table in memory. In Spark API, this corresponds to createDataFrame(...).
-    ctx.register_batch("t", batch)?;
+    ctx.register_batch("t", batch).await?;
 
     Ok(ctx)
 }
diff --git a/datafusion-examples/examples/analyzer_rule.rs b/datafusion-examples/examples/analyzer_rule.rs
index bd067be97b8b..623c8d407b8f 100644
--- a/datafusion-examples/examples/analyzer_rule.rs
+++ b/datafusion-examples/examples/analyzer_rule.rs
@@ -46,7 +46,7 @@ pub async fn main() -> Result<()> {
     let ctx = SessionContext::new();
     ctx.add_analyzer_rule(Arc::clone(&rule) as _);
 
-    ctx.register_batch("employee", employee_batch())?;
+    ctx.register_batch("employee", employee_batch()).await?;
 
     // Now, planning any SQL statement also invokes the AnalyzerRule
     let plan = ctx
diff --git a/datafusion-examples/examples/catalog.rs b/datafusion-examples/examples/catalog.rs
index f40f1dfb5a15..f584a55d01e8 100644
--- a/datafusion-examples/examples/catalog.rs
+++ b/datafusion-examples/examples/catalog.rs
@@ -29,6 +29,7 @@ use datafusion::{
     execution::context::SessionState,
     prelude::SessionContext,
 };
+use futures::{stream::BoxStream, StreamExt};
 use std::sync::RwLock;
 use std::{any::Any, collections::HashMap, path::Path, sync::Arc};
 use std::{fs::File, io::Write};
@@ -74,11 +75,17 @@ async fn main() -> Result<()> {
         .await?;
 
     // register schemas into catalog
-    catalog.register_schema("schema_a", schema_a.clone())?;
-    catalog.register_schema("schema_b", schema_b.clone())?;
+    catalog
+        .register_schema("schema_a", schema_a.clone())
+        .await?;
+    catalog
+        .register_schema("schema_b", schema_b.clone())
+        .await?;
 
     // register our catalog in the context
-    ctx.register_catalog("dircat", Arc::new(catalog));
+    ctx.register_catalog("dircat", Arc::new(catalog))
+        .await
+        .unwrap();
     {
         // catalog was passed down into our custom catalog list since we override the ctx's default
         let catalogs = cataloglist.catalogs.read().unwrap();
@@ -184,9 +191,9 @@ impl SchemaProvider for DirSchema {
         self
     }
 
-    fn table_names(&self) -> Vec<String> {
+    async fn table_names(&self) -> BoxStream<'static, Result<String>> {
         let tables = self.tables.read().unwrap();
-        tables.keys().cloned().collect::<Vec<_>>()
+        futures::stream::iter(tables.keys().cloned().map(Ok).collect::<Vec<_>>()).boxed()
     }
 
@@ -194,11 +201,11 @@ impl SchemaProvider for DirSchema {
     async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
         let tables = self.tables.read().unwrap();
         Ok(tables.get(name).cloned())
     }
 
-    fn table_exist(&self, name: &str) -> bool {
+    async fn table_exist(&self, name: &str) -> bool {
         let tables = self.tables.read().unwrap();
         tables.contains_key(name)
     }
-    fn register_table(
+    async fn register_table(
         &self,
         name: String,
         table: Arc<dyn TableProvider>,
@@ -212,7 +219,10 @@ impl SchemaProvider for DirSchema {
 
     /// If supported by the implementation, removes an existing table from this schema and returns it.
     /// If no table of that name exists, returns Ok(None).
     #[allow(unused_variables)]
-    fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
+    async fn deregister_table(
+        &self,
+        name: &str,
+    ) -> Result<Option<Arc<dyn TableProvider>>> {
         let mut tables = self.tables.write().unwrap();
         log::info!("dropping table {name}");
         Ok(tables.remove(name))
@@ -230,11 +240,13 @@ impl DirCatalog {
         }
     }
 }
+
+#[async_trait]
 impl CatalogProvider for DirCatalog {
     fn as_any(&self) -> &dyn Any {
         self
     }
-    fn register_schema(
+    async fn register_schema(
         &self,
         name: &str,
         schema: Arc<dyn SchemaProvider>,
@@ -244,19 +256,19 @@ impl CatalogProvider for DirCatalog {
         Ok(Some(schema))
     }
 
-    fn schema_names(&self) -> Vec<String> {
+    async fn schema_names(&self) -> BoxStream<'static, Result<String>> {
         let schemas = self.schemas.read().unwrap();
-        schemas.keys().cloned().collect()
+        futures::stream::iter(schemas.keys().cloned().map(Ok).collect::<Vec<_>>()).boxed()
     }
 
-    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
+    async fn schema(&self, name: &str) -> Result<Option<Arc<dyn SchemaProvider>>> {
         let schemas = self.schemas.read().unwrap();
         let maybe_schema = schemas.get(name);
         if let Some(schema) = maybe_schema {
             let schema = schema.clone() as Arc<dyn SchemaProvider>;
-            Some(schema)
+            Ok(Some(schema))
         } else {
-            None
+            Ok(None)
         }
     }
 }
@@ -272,30 +284,32 @@ impl CustomCatalogProviderList {
         }
     }
 }
+
+#[async_trait]
 impl CatalogProviderList for CustomCatalogProviderList {
     fn as_any(&self) -> &dyn Any {
         self
     }
 
-    fn register_catalog(
+    async fn register_catalog(
         &self,
         name: String,
         catalog: Arc<dyn CatalogProvider>,
-    ) -> Option<Arc<dyn CatalogProvider>> {
+    ) -> Result<Option<Arc<dyn CatalogProvider>>> {
         let mut cats = self.catalogs.write().unwrap();
         cats.insert(name, catalog.clone());
-        Some(catalog)
+        Ok(Some(catalog))
     }
 
     /// Retrieves the list of available catalog names
-    fn catalog_names(&self) -> Vec<String> {
+    async fn catalog_names(&self) -> BoxStream<'static, Result<String>> {
         let cats = self.catalogs.read().unwrap();
-        cats.keys().cloned().collect()
+        futures::stream::iter(cats.keys().cloned().map(Ok).collect::<Vec<_>>()).boxed()
     }
 
     /// Retrieves a specific catalog by name, provided it exists.
-    fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
+    async fn catalog(&self, name: &str) -> Result<Option<Arc<dyn CatalogProvider>>> {
         let cats = self.catalogs.read().unwrap();
-        cats.get(name).cloned()
+        Ok(cats.get(name).cloned())
     }
 }
diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_file_format.rs
index 95168597ebaa..c35a881b64d1 100644
--- a/datafusion-examples/examples/custom_file_format.rs
+++ b/datafusion-examples/examples/custom_file_format.rs
@@ -179,7 +179,10 @@ impl GetExt for TSVFileFactory {
 #[tokio::main]
 async fn main() -> Result<()> {
     // Create a new context with the default configuration
-    let mut state = SessionStateBuilder::new().with_default_features().build();
+    let mut state = SessionStateBuilder::new()
+        .with_default_features()
+        .build()
+        .await;
 
     // Register the custom file format
     let file_format = Arc::new(TSVFileFactory::new());
@@ -189,7 +192,7 @@ async fn main() -> Result<()> {
     let ctx = SessionContext::new_with_state(state);
 
     let mem_table = create_mem_table();
-    ctx.register_table("mem_table", mem_table).unwrap();
+    ctx.register_table("mem_table", mem_table).await.unwrap();
 
     let temp_dir = tempdir().unwrap();
     let table_save_path = temp_dir.path().join("mem_table.tsv");
diff --git a/datafusion-examples/examples/dataframe.rs b/datafusion-examples/examples/dataframe.rs
index 59766e881e8b..fc4434cf319a 100644
--- a/datafusion-examples/examples/dataframe.rs
+++ b/datafusion-examples/examples/dataframe.rs
@@ -109,7 +109,7 @@ async fn read_csv(ctx: &SessionContext) -> Result<()> {
 
     // You can also create DataFrames from the result of sql queries
     // and using the `enable_url_table` refer to local files directly
-    let dyn_ctx = ctx.clone().enable_url_table();
+    let dyn_ctx = ctx.clone().enable_url_table().await;
     let csv_df = dyn_ctx
         .sql(&format!("SELECT rating, unixtime FROM '{}'", file_path))
         .await?;
@@ -127,7 +127,7 @@ async fn read_memory(ctx: &SessionContext) -> Result<()> {
     let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?;
 
     // declare a table in memory. In Apache Spark API, this corresponds to createDataFrame(...).
-    ctx.register_batch("t", batch)?;
+    ctx.register_batch("t", batch).await?;
     let df = ctx.table("t").await?;
 
     // construct an expression corresponding to "SELECT a, b FROM t WHERE b = 10" in SQL
diff --git a/datafusion-examples/examples/external_dependency/query-aws-s3.rs b/datafusion-examples/examples/external_dependency/query-aws-s3.rs
index da2d7e4879f9..9e07f2c98b81 100644
--- a/datafusion-examples/examples/external_dependency/query-aws-s3.rs
+++ b/datafusion-examples/examples/external_dependency/query-aws-s3.rs
@@ -64,7 +64,7 @@ async fn main() -> Result<()> {
     df.show().await?;
 
     // dynamic query by the file path
-    let ctx = ctx.enable_url_table();
+    let ctx = ctx.enable_url_table().await;
     let df = ctx
         .sql(format!(r#"SELECT * FROM '{}' LIMIT 10"#, &path).as_str())
         .await?;
diff --git a/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs b/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs
index 6e376ca866e8..564ab71a1f23 100644
--- a/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs
+++ b/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs
@@ -55,7 +55,8 @@ async fn main() -> Result<()> {
     let ctx = SessionContext::new();
 
     // Display the data to show the full cycle works.
-    ctx.register_table("external_table", Arc::new(foreign_table_provider))?;
+    ctx.register_table("external_table", Arc::new(foreign_table_provider))
+        .await?;
 
     let df = ctx.table("external_table").await?;
     df.show().await?;
diff --git a/datafusion-examples/examples/file_stream_provider.rs b/datafusion-examples/examples/file_stream_provider.rs
index e4fd937fd373..75b4785fa2b1 100644
--- a/datafusion-examples/examples/file_stream_provider.rs
+++ b/datafusion-examples/examples/file_stream_provider.rs
@@ -160,7 +160,7 @@ mod non_windows {
         let order = vec![vec![datafusion_expr::col("a1").sort(true, false)]];
 
         let provider = fifo_table(schema.clone(), fifo_path, order.clone());
-        ctx.register_table("fifo", provider)?;
+        ctx.register_table("fifo", provider).await?;
 
         let df = ctx.sql("SELECT * FROM fifo").await.unwrap();
         let mut stream = df.execute_stream().await.unwrap();
diff --git a/datafusion-examples/examples/flight/flight_sql_server.rs b/datafusion-examples/examples/flight/flight_sql_server.rs
index 2e46daf7cb4e..42b198eb9f83 100644
--- a/datafusion-examples/examples/flight/flight_sql_server.rs
+++ b/datafusion-examples/examples/flight/flight_sql_server.rs
@@ -170,11 +170,15 @@ impl FlightSqlServiceImpl {
         let mut schemas = vec![];
         let mut names = vec![];
         let mut types = vec![];
-        for catalog in ctx.catalog_names() {
-            let catalog_provider = ctx.catalog(&catalog).unwrap();
-            for schema in catalog_provider.schema_names() {
-                let schema_provider = catalog_provider.schema(&schema).unwrap();
-                for table in schema_provider.table_names() {
+        let mut catalog_names = ctx.catalog_names().await;
+        while let Some(catalog) = catalog_names.try_next().await.unwrap() {
+            let catalog_provider = ctx.catalog(&catalog).await.unwrap().unwrap();
+            let mut schema_names = catalog_provider.schema_names().await;
+            while let Some(schema) = schema_names.try_next().await.unwrap() {
+                let schema_provider =
+                    catalog_provider.schema(&schema).await.unwrap().unwrap();
+                let mut table_names = schema_provider.table_names().await;
+                while let Some(table) = table_names.try_next().await.unwrap() {
                     let table_provider =
                         schema_provider.table(&table).await.unwrap().unwrap();
                     catalogs.push(catalog.clone());
diff --git a/datafusion-examples/examples/make_date.rs b/datafusion-examples/examples/make_date.rs
index 98bbb21bbff8..4af8f6ea9314 100644
--- a/datafusion-examples/examples/make_date.rs
+++ b/datafusion-examples/examples/make_date.rs
@@ -49,7 +49,7 @@ async fn main() -> Result<()> {
     let ctx = SessionContext::new();
 
     // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
-    ctx.register_batch("t", batch)?;
+    ctx.register_batch("t", batch).await?;
     let df = ctx.table("t").await?;
 
     // use make_date function to convert col 'y', 'm' & 'd' to a date
diff --git a/datafusion-examples/examples/memtable.rs b/datafusion-examples/examples/memtable.rs
index 5cce578039e7..baae6fc7554d 100644
--- a/datafusion-examples/examples/memtable.rs
+++ b/datafusion-examples/examples/memtable.rs
@@ -34,7 +34,7 @@ async fn main() -> Result<()> {
     let ctx = SessionContext::new();
 
     // Register the in-memory table containing the data
-    ctx.register_table("users", Arc::new(mem_table))?;
+    ctx.register_table("users", Arc::new(mem_table)).await?;
 
     let dataframe = ctx.sql("SELECT * FROM users;").await?;
diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/optimizer_rule.rs
index 0f28a1670252..8df4c203a4bd 100644
--- a/datafusion-examples/examples/optimizer_rule.rs
+++ b/datafusion-examples/examples/optimizer_rule.rs
@@ -48,7 +48,7 @@ pub async fn main() -> Result<()> {
     ctx.add_optimizer_rule(Arc::new(MyOptimizerRule {}));
 
     // Now, let's plan and run queries with the new rule
-    ctx.register_batch("person", person_batch())?;
+    ctx.register_batch("person", person_batch()).await?;
     let sql = "SELECT * FROM person WHERE age = 22";
     let plan = ctx.sql(sql).await?.into_optimized_plan()?;
diff --git a/datafusion-examples/examples/parquet_index.rs b/datafusion-examples/examples/parquet_index.rs
index d6e17764442d..692d82eeea1e 100644
--- a/datafusion-examples/examples/parquet_index.rs
+++ b/datafusion-examples/examples/parquet_index.rs
@@ -121,7 +121,8 @@ async fn main() -> Result<()> {
     // Create a SessionContext for running queries that has the table provider
     // registered as "index_table"
     let ctx = SessionContext::new();
-    ctx.register_table("index_table", Arc::clone(&provider) as _)?;
+    ctx.register_table("index_table", Arc::clone(&provider) as _)
+        .await?;
 
     // register object store provider so that urls like `file://` work
     let url = Url::try_from("file://").unwrap();
diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs
index ef97bf9763b0..590ced51f148 100644
--- a/datafusion-examples/examples/simple_udaf.rs
+++ b/datafusion-examples/examples/simple_udaf.rs
@@ -26,7 +26,7 @@ use datafusion_common::cast::as_float64_array;
 use std::sync::Arc;
 
 // create local session context with an in-memory table
-fn create_context() -> Result<SessionContext> {
+async fn create_context() -> Result<SessionContext> {
     use datafusion::arrow::datatypes::{Field, Schema};
     use datafusion::datasource::MemTable;
     // define a schema.
@@ -47,7 +47,7 @@ fn create_context() -> Result<SessionContext> {
 
     // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
     let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
-    ctx.register_table("t", Arc::new(provider))?;
+    ctx.register_table("t", Arc::new(provider)).await?;
     Ok(ctx)
 }
 
@@ -137,7 +137,7 @@ impl Accumulator for GeometricMean {
 
 #[tokio::main]
 async fn main() -> Result<()> {
-    let ctx = create_context()?;
+    let ctx = create_context().await?;
 
     // here is where we define the UDAF. We also declare its signature:
     let geometric_mean = create_udaf(
diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs
index 6879a17f34be..dbb70110374a 100644
--- a/datafusion-examples/examples/simple_udf.rs
+++ b/datafusion-examples/examples/simple_udf.rs
@@ -42,7 +42,7 @@ use std::sync::Arc;
 /// | 5.1 | 4.0 |
 /// +-----+-----+
 /// ```
-fn create_context() -> Result<SessionContext> {
+async fn create_context() -> Result<SessionContext> {
     // define data.
     let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1]));
     let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]));
@@ -52,14 +52,14 @@ fn create_context() -> Result<SessionContext> {
     let ctx = SessionContext::new();
 
     // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
-    ctx.register_batch("t", batch)?;
+    ctx.register_batch("t", batch).await?;
     Ok(ctx)
 }
 
 /// In this example we will declare a single-type, single return type UDF that exponentiates f64, a^b
 #[tokio::main]
 async fn main() -> Result<()> {
-    let ctx = create_context()?;
+    let ctx = create_context().await?;
 
     // First, declare the actual implementation of the calculation
     let pow = Arc::new(|args: &[ColumnarValue]| {
diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs b/datafusion-examples/examples/simplify_udaf_expression.rs
index 52a27317e3c3..f7cff03fefed 100644
--- a/datafusion-examples/examples/simplify_udaf_expression.rs
+++ b/datafusion-examples/examples/simplify_udaf_expression.rs
@@ -108,7 +108,7 @@ impl AggregateUDFImpl for BetterAvgUdaf {
 }
 
 // create local session context with an in-memory table
-fn create_context() -> Result<SessionContext> {
+async fn create_context() -> Result<SessionContext> {
     use datafusion::datasource::MemTable;
     // define a schema.
     let schema = Arc::new(Schema::new(vec![
@@ -136,13 +136,13 @@ fn create_context() -> Result<SessionContext> {
 
     // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
     let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
-    ctx.register_table("t", Arc::new(provider))?;
+    ctx.register_table("t", Arc::new(provider)).await?;
     Ok(ctx)
 }
 
 #[tokio::main]
 async fn main() -> Result<()> {
-    let ctx = create_context()?;
+    let ctx = create_context().await?;
 
     let better_avg = AggregateUDF::from(BetterAvgUdaf::new());
     ctx.register_udaf(better_avg.clone());
diff --git a/datafusion-examples/examples/sql_analysis.rs b/datafusion-examples/examples/sql_analysis.rs
index 2158b8e4b016..b2f4a59f6fdb 100644
--- a/datafusion-examples/examples/sql_analysis.rs
+++ b/datafusion-examples/examples/sql_analysis.rs
@@ -275,7 +275,8 @@ from
         ctx.register_table(
             table.name,
             Arc::new(MemTable::try_new(Arc::new(table.schema.clone()), vec![])?),
-        )?;
+        )
+        .await?;
     }
 
     // We can create a LogicalPlan from a SQL query like this
     let logical_plan = ctx.sql(tpcds_query_88).await?.into_optimized_plan()?;
diff --git a/datafusion-examples/examples/to_char.rs b/datafusion-examples/examples/to_char.rs
index f8ed68b46f19..c5e45708cfb9 100644
--- a/datafusion-examples/examples/to_char.rs
+++ b/datafusion-examples/examples/to_char.rs
@@ -49,7 +49,7 @@ async fn main() -> Result<()> {
     let ctx = SessionContext::new();
 
     // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
-    ctx.register_batch("t", batch)?;
+    ctx.register_batch("t", batch).await?;
     let _ = ctx.table("t").await?;
 
     // use to_char function to convert col 'values' to timestamp type using
diff --git a/datafusion-examples/examples/to_date.rs b/datafusion-examples/examples/to_date.rs
index 99ee555ffc17..59e62bbeef56 100644
--- a/datafusion-examples/examples/to_date.rs
+++ b/datafusion-examples/examples/to_date.rs
@@ -45,7 +45,7 @@ async fn main() -> Result<()> {
     let ctx = SessionContext::new();
 
     // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
-    ctx.register_batch("t", batch)?;
+    ctx.register_batch("t", batch).await?;
     let df = ctx.table("t").await?;
 
     // use to_date function to convert col 'a' to timestamp type using the default parsing
diff --git a/datafusion-examples/examples/to_timestamp.rs b/datafusion-examples/examples/to_timestamp.rs
index 940c85df33c5..35247eb5e4fe 100644
--- a/datafusion-examples/examples/to_timestamp.rs
+++ b/datafusion-examples/examples/to_timestamp.rs
@@ -57,7 +57,7 @@ async fn main() -> Result<()> {
     let ctx = SessionContext::new();
 
     // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
-    ctx.register_batch("t", batch)?;
+    ctx.register_batch("t", batch).await?;
     let df = ctx.table("t").await?;
 
     // use to_timestamp function to convert col 'a' to timestamp type using the default parsing
diff --git a/datafusion/catalog/Cargo.toml b/datafusion/catalog/Cargo.toml
index f9801352087d..41e3cedbd9cf 100644
--- a/datafusion/catalog/Cargo.toml
+++ b/datafusion/catalog/Cargo.toml
@@ -34,6 +34,7 @@ datafusion-common = { workspace = true }
 datafusion-execution = { workspace = true }
 datafusion-expr = { workspace = true }
 datafusion-physical-plan = { workspace = true }
+futures = { workspace = true }
 parking_lot = { workspace = true }
 
 [lints]
diff --git a/datafusion/catalog/src/catalog.rs b/datafusion/catalog/src/catalog.rs
index 048a7f14ed37..9da3f0c091a1 100644
--- a/datafusion/catalog/src/catalog.rs
+++ b/datafusion/catalog/src/catalog.rs
@@ -20,8 +20,10 @@ use std::fmt::Debug;
 use std::sync::Arc;
 
 pub use crate::schema::SchemaProvider;
+use async_trait::async_trait;
 use datafusion_common::not_impl_err;
 use datafusion_common::Result;
+use futures::stream::BoxStream;
 
 /// Represents a catalog, comprising a number of named schemas.
 ///
@@ -102,16 +104,17 @@ use datafusion_common::Result;
 ///
 /// [`TableProvider`]: crate::TableProvider
+#[async_trait]
 pub trait CatalogProvider: Debug + Sync + Send {
     /// Returns the catalog provider as [`Any`]
     /// so that it can be downcast to a specific implementation.
     fn as_any(&self) -> &dyn Any;
 
     /// Retrieves the list of available schema names in this catalog.
-    fn schema_names(&self) -> Vec<String>;
+    async fn schema_names(&self) -> BoxStream<'static, Result<String>>;
 
     /// Retrieves a specific schema from the catalog by name, provided it exists.
-    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>>;
+    async fn schema(&self, name: &str) -> Result<Option<Arc<dyn SchemaProvider>>>;
 
     /// Adds a new schema to this catalog.
     ///
@@ -119,7 +122,7 @@ pub trait CatalogProvider: Debug + Sync + Send {
     /// the catalog and returned.
     ///
     /// By default returns a "Not Implemented" error
-    fn register_schema(
+    async fn register_schema(
         &self,
         name: &str,
         schema: Arc<dyn SchemaProvider>,
@@ -140,7 +143,7 @@ pub trait CatalogProvider: Debug + Sync + Send {
     /// does not exist.
     ///
     /// By default returns a "Not Implemented" error
-    fn deregister_schema(
+    async fn deregister_schema(
         &self,
         _name: &str,
         _cascade: bool,
@@ -153,6 +156,7 @@ pub trait CatalogProvider: Debug + Sync + Send {
 ///
 /// Please see the documentation on `CatalogProvider` for details of
 /// implementing a custom catalog.
+#[async_trait]
 pub trait CatalogProviderList: Debug + Sync + Send {
     /// Returns the catalog list as [`Any`]
     /// so that it can be downcast to a specific implementation.
@@ -160,15 +164,15 @@ pub trait CatalogProviderList: Debug + Sync + Send {
 
     /// Adds a new catalog to this catalog list
     /// If a catalog of the same name existed before, it is replaced in the list and returned.
-    fn register_catalog(
+    async fn register_catalog(
         &self,
         name: String,
         catalog: Arc<dyn CatalogProvider>,
-    ) -> Option<Arc<dyn CatalogProvider>>;
+    ) -> Result<Option<Arc<dyn CatalogProvider>>>;
 
     /// Retrieves the list of available catalog names
-    fn catalog_names(&self) -> Vec<String>;
+    async fn catalog_names(&self) -> BoxStream<'static, Result<String>>;
 
     /// Retrieves a specific catalog by name, provided it exists.
-    fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>>;
+    async fn catalog(&self, name: &str) -> Result<Option<Arc<dyn CatalogProvider>>>;
 }
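With `CatalogProvider` and `CatalogProviderList` now async and stream-based, an implementation returns a `BoxStream` of names rather than a `Vec`. A minimal in-memory sketch against the trait as defined above (the `SimpleCatalog` struct, its `RwLock`-guarded map, and the `datafusion::catalog` import paths are assumptions for illustration):

    use std::any::Any;
    use std::collections::HashMap;
    use std::sync::{Arc, RwLock};

    use async_trait::async_trait;
    use datafusion::catalog::{CatalogProvider, SchemaProvider};
    use datafusion_common::Result;
    use futures::stream::BoxStream;
    use futures::StreamExt;

    /// Hypothetical provider backed by a plain RwLock-guarded map.
    #[derive(Debug, Default)]
    struct SimpleCatalog {
        schemas: RwLock<HashMap<String, Arc<dyn SchemaProvider>>>,
    }

    #[async_trait]
    impl CatalogProvider for SimpleCatalog {
        fn as_any(&self) -> &dyn Any {
            self
        }

        async fn schema_names(&self) -> BoxStream<'static, Result<String>> {
            // collect owned names under the lock, then hand back a 'static stream
            let names: Vec<_> = self
                .schemas
                .read()
                .unwrap()
                .keys()
                .cloned()
                .map(Ok)
                .collect();
            futures::stream::iter(names).boxed()
        }

        async fn schema(&self, name: &str) -> Result<Option<Arc<dyn SchemaProvider>>> {
            Ok(self.schemas.read().unwrap().get(name).cloned())
        }

        // register_schema / deregister_schema keep their default
        // "Not Implemented" bodies from the trait.
    }

The `async fn ... -> BoxStream` shape lets an implementation that needs I/O open a cursor before returning the stream, while in-memory ones like this just snapshot their keys.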
diff --git a/datafusion/catalog/src/dynamic_file/catalog.rs b/datafusion/catalog/src/dynamic_file/catalog.rs
index ccccb9762eb4..eba45d1c6bdb 100644
--- a/datafusion/catalog/src/dynamic_file/catalog.rs
+++ b/datafusion/catalog/src/dynamic_file/catalog.rs
@@ -19,6 +19,8 @@
 use crate::{CatalogProvider, CatalogProviderList, SchemaProvider, TableProvider};
 use async_trait::async_trait;
+use datafusion_common::Result;
+use futures::stream::BoxStream;
 use std::any::Any;
 use std::fmt::Debug;
 use std::sync::Arc;
@@ -41,30 +43,31 @@ impl DynamicFileCatalog {
     }
 }
 
+#[async_trait]
 impl CatalogProviderList for DynamicFileCatalog {
     fn as_any(&self) -> &dyn Any {
         self
     }
 
-    fn register_catalog(
+    async fn register_catalog(
         &self,
         name: String,
         catalog: Arc<dyn CatalogProvider>,
-    ) -> Option<Arc<dyn CatalogProvider>> {
-        self.inner.register_catalog(name, catalog)
+    ) -> Result<Option<Arc<dyn CatalogProvider>>> {
+        Ok(self.inner.register_catalog(name, catalog).await?)
     }
 
-    fn catalog_names(&self) -> Vec<String> {
-        self.inner.catalog_names()
+    async fn catalog_names(&self) -> BoxStream<'static, Result<String>> {
+        self.inner.catalog_names().await
     }
 
-    fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
-        self.inner.catalog(name).map(|catalog| {
+    async fn catalog(&self, name: &str) -> Result<Option<Arc<dyn CatalogProvider>>> {
+        Ok(self.inner.catalog(name).await?.map(|catalog| {
             Arc::new(DynamicFileCatalogProvider::new(
                 catalog,
                 Arc::clone(&self.factory),
             )) as _
-        })
+        }))
     }
 }
 
@@ -86,30 +89,31 @@ impl DynamicFileCatalogProvider {
     }
 }
 
+#[async_trait]
 impl CatalogProvider for DynamicFileCatalogProvider {
     fn as_any(&self) -> &dyn Any {
         self
     }
 
-    fn schema_names(&self) -> Vec<String> {
-        self.inner.schema_names()
+    async fn schema_names(&self) -> BoxStream<'static, Result<String>> {
+        self.inner.schema_names().await
     }
 
-    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
-        self.inner.schema(name).map(|schema| {
+    async fn schema(&self, name: &str) -> Result<Option<Arc<dyn SchemaProvider>>> {
+        Ok(self.inner.schema(name).await?.map(|schema| {
             Arc::new(DynamicFileSchemaProvider::new(
                 schema,
                 Arc::clone(&self.factory),
             )) as _
-        })
+        }))
     }
 
-    fn register_schema(
+    async fn register_schema(
         &self,
         name: &str,
         schema: Arc<dyn SchemaProvider>,
-    ) -> datafusion_common::Result<Option<Arc<dyn SchemaProvider>>> {
-        self.inner.register_schema(name, schema)
+    ) -> Result<Option<Arc<dyn SchemaProvider>>> {
+        Ok(self.inner.register_schema(name, schema).await?)
     }
 }
 
@@ -141,14 +145,11 @@ impl SchemaProvider for DynamicFileSchemaProvider {
         self
     }
 
-    fn table_names(&self) -> Vec<String> {
-        self.inner.table_names()
+    async fn table_names(&self) -> BoxStream<'static, Result<String>> {
+        self.inner.table_names().await
     }
 
-    async fn table(
-        &self,
-        name: &str,
-    ) -> datafusion_common::Result<Option<Arc<dyn TableProvider>>> {
+    async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
         if let Some(table) = self.inner.table(name).await? {
             return Ok(Some(table));
         };
@@ -156,23 +157,23 @@ impl SchemaProvider for DynamicFileSchemaProvider {
         self.factory.try_new(name).await
     }
 
-    fn register_table(
+    async fn register_table(
         &self,
         name: String,
         table: Arc<dyn TableProvider>,
-    ) -> datafusion_common::Result<Option<Arc<dyn TableProvider>>> {
-        self.inner.register_table(name, table)
+    ) -> Result<Option<Arc<dyn TableProvider>>> {
+        self.inner.register_table(name, table).await
     }
 
-    fn deregister_table(
+    async fn deregister_table(
         &self,
         name: &str,
-    ) -> datafusion_common::Result<Option<Arc<dyn TableProvider>>> {
-        self.inner.deregister_table(name)
+    ) -> Result<Option<Arc<dyn TableProvider>>> {
+        self.inner.deregister_table(name).await
     }
 
-    fn table_exist(&self, name: &str) -> bool {
-        self.inner.table_exist(name)
+    async fn table_exist(&self, name: &str) -> bool {
+        self.inner.table_exist(name).await
     }
 }
 
@@ -180,8 +181,5 @@ impl SchemaProvider for DynamicFileSchemaProvider {
 #[async_trait]
 pub trait UrlTableFactory: Debug + Sync + Send {
     /// create a new table provider from the provided url
-    async fn try_new(
-        &self,
-        url: &str,
-    ) -> datafusion_common::Result<Option<Arc<dyn TableProvider>>>;
+    async fn try_new(&self, url: &str) -> Result<Option<Arc<dyn TableProvider>>>;
 }
diff --git a/datafusion/catalog/src/schema.rs b/datafusion/catalog/src/schema.rs
index 5b37348fd742..8db7dd3207c9 100644
--- a/datafusion/catalog/src/schema.rs
+++ b/datafusion/catalog/src/schema.rs
@@ -20,6 +20,7 @@
 
 use async_trait::async_trait;
 use datafusion_common::{exec_err, DataFusionError};
+use futures::stream::BoxStream;
 use std::any::Any;
 use std::fmt::Debug;
 use std::sync::Arc;
@@ -36,7 +37,7 @@ use datafusion_common::Result;
 pub trait SchemaProvider: Debug + Sync + Send {
     /// Returns the owner of the Schema, default is None. This value is reported
     /// as part of `information_tables.schemata`
-    fn owner_name(&self) -> Option<&str> {
+    async fn owner_name(&self) -> Option<&str> {
         None
     }
 
@@ -45,7 +46,7 @@ pub trait SchemaProvider: Debug + Sync + Send {
     fn as_any(&self) -> &dyn Any;
 
     /// Retrieves the list of available table names in this schema.
-    fn table_names(&self) -> Vec<String>;
+    async fn table_names(&self) -> BoxStream<'static, Result<String>>;
 
     /// Retrieves a specific table from the schema by name, if it exists,
     /// otherwise returns `None`.
@@ -60,7 +61,7 @@ pub trait SchemaProvider: Debug + Sync + Send {
     /// If a table of the same name was already registered, returns "Table
     /// already exists" error.
     #[allow(unused_variables)]
-    fn register_table(
+    async fn register_table(
         &self,
         name: String,
         table: Arc<dyn TableProvider>,
@@ -73,10 +74,13 @@ pub trait SchemaProvider: Debug + Sync + Send {
     ///
     /// If no `name` table exists, returns Ok(None).
     #[allow(unused_variables)]
-    fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
+    async fn deregister_table(
+        &self,
+        name: &str,
+    ) -> Result<Option<Arc<dyn TableProvider>>> {
         exec_err!("schema provider does not support deregistering tables")
     }
 
     /// Returns true if the table exists in the schema provider, false otherwise.
-    fn table_exist(&self, name: &str) -> bool;
+    async fn table_exist(&self, name: &str) -> bool;
 }
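Callers of the listing methods change shape too: instead of iterating a `Vec<String>`, they drain the returned stream with `TryStreamExt::try_next`, as the `flight_sql_server` and `information_schema` hunks do. A hedged sketch of a helper in that style (`collect_table_names` is hypothetical):

    use datafusion::catalog::SchemaProvider;
    use datafusion_common::Result;
    use futures::TryStreamExt;

    /// Hypothetical helper: drain a provider's table-name stream into a Vec,
    /// propagating the first error the stream yields.
    async fn collect_table_names(schema: &dyn SchemaProvider) -> Result<Vec<String>> {
        let mut names = Vec::new();
        let mut stream = schema.table_names().await;
        while let Some(name) = stream.try_next().await? {
            names.push(name);
        }
        Ok(names)
    }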
diff --git a/datafusion/core/benches/aggregate_query_sql.rs b/datafusion/core/benches/aggregate_query_sql.rs
index 1d8d87ada784..001c1ca6abe3 100644
--- a/datafusion/core/benches/aggregate_query_sql.rs
+++ b/datafusion/core/benches/aggregate_query_sql.rs
@@ -25,6 +25,7 @@ use crate::criterion::Criterion;
 use data_utils::create_table_provider;
 use datafusion::error::Result;
 use datafusion::execution::context::SessionContext;
+use futures::FutureExt;
 use parking_lot::Mutex;
 use std::sync::Arc;
 use tokio::runtime::Runtime;
@@ -42,7 +43,7 @@ fn create_context(
 ) -> Result<Arc<Mutex<SessionContext>>> {
     let ctx = SessionContext::new();
     let provider = create_table_provider(partitions_len, array_len, batch_size)?;
-    ctx.register_table("t", provider)?;
+    ctx.register_table("t", provider).now_or_never().unwrap()?;
     Ok(Arc::new(Mutex::new(ctx)))
 }
diff --git a/datafusion/core/benches/distinct_query_sql.rs b/datafusion/core/benches/distinct_query_sql.rs
index c242798a56f0..5957acbf016b 100644
--- a/datafusion/core/benches/distinct_query_sql.rs
+++ b/datafusion/core/benches/distinct_query_sql.rs
@@ -29,6 +29,7 @@ use datafusion::{datasource::MemTable, error::Result};
 use datafusion_execution::config::SessionConfig;
 use datafusion_execution::TaskContext;
 
+use futures::FutureExt;
 use parking_lot::Mutex;
 use std::{sync::Arc, time::Duration};
 use tokio::runtime::Runtime;
@@ -46,7 +47,7 @@ fn create_context(
 ) -> Result<Arc<Mutex<SessionContext>>> {
     let ctx = SessionContext::new();
     let provider = create_table_provider(partitions_len, array_len, batch_size)?;
-    ctx.register_table("t", provider)?;
+    ctx.register_table("t", provider).now_or_never().unwrap()?;
     Ok(Arc::new(Mutex::new(ctx)))
 }
 
@@ -137,7 +138,7 @@ pub async fn create_context_sampled_data(
     // Create the DataFrame
     let cfg = SessionConfig::new();
     let ctx = SessionContext::new_with_config(cfg);
-    let _ = ctx.register_table("traces", mem_table)?;
+    let _ = ctx.register_table("traces", mem_table).await?;
     let df = ctx.sql(sql).await?;
     let physical_plan = df.create_physical_plan().await?;
     Ok((physical_plan, ctx.task_ctx()))
diff --git a/datafusion/core/benches/filter_query_sql.rs b/datafusion/core/benches/filter_query_sql.rs
index 0e09ae09d7c2..311c53dbb1ee 100644
--- a/datafusion/core/benches/filter_query_sql.rs
+++ b/datafusion/core/benches/filter_query_sql.rs
@@ -23,7 +23,7 @@ use arrow::{
 use criterion::{criterion_group, criterion_main, Criterion};
 use datafusion::prelude::SessionContext;
 use datafusion::{datasource::MemTable, error::Result};
-use futures::executor::block_on;
+use futures::{executor::block_on, FutureExt};
 use std::sync::Arc;
 use tokio::runtime::Runtime;
 
@@ -60,7 +60,9 @@ fn create_context(array_len: usize, batch_size: usize) -> Result<SessionContext>
     // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
     let provider = MemTable::try_new(schema, vec![batches])?;
-    ctx.register_table("t", Arc::new(provider))?;
+    ctx.register_table("t", Arc::new(provider))
+        .now_or_never()
+        .unwrap()?;
 
     Ok(ctx)
 }
diff --git a/datafusion/core/benches/map_query_sql.rs b/datafusion/core/benches/map_query_sql.rs
index e4c5f7c5deb3..8a497a6510f1 100644
--- a/datafusion/core/benches/map_query_sql.rs
+++ b/datafusion/core/benches/map_query_sql.rs
@@ -19,6 +19,7 @@ use std::sync::Arc;
 
 use arrow_array::{ArrayRef, Int32Array, RecordBatch};
 use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use futures::FutureExt;
 use parking_lot::Mutex;
 use rand::prelude::ThreadRng;
 use rand::Rng;
@@ -55,7 +56,9 @@ fn t_batch(num: i32) -> RecordBatch {
 
 fn create_context(num: i32) -> datafusion_common::Result<Arc<Mutex<SessionContext>>> {
     let ctx = SessionContext::new();
-    ctx.register_batch("t", t_batch(num))?;
+    ctx.register_batch("t", t_batch(num))
+        .now_or_never()
+        .expect("default context should use synchronous in-memory catalog")?;
     Ok(Arc::new(Mutex::new(ctx)))
 }
diff --git a/datafusion/core/benches/math_query_sql.rs b/datafusion/core/benches/math_query_sql.rs
index 92c59d506640..2ca8ec5e6320 100644
--- a/datafusion/core/benches/math_query_sql.rs
+++ b/datafusion/core/benches/math_query_sql.rs
@@ -19,6 +19,7 @@
 extern crate criterion;
 use criterion::Criterion;
 
+use futures::FutureExt;
 use parking_lot::Mutex;
 
 use std::sync::Arc;
@@ -72,7 +73,9 @@ fn create_context(
     // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
     let provider = MemTable::try_new(schema, vec![batches])?;
-    ctx.register_table("t", Arc::new(provider))?;
+    ctx.register_table("t", Arc::new(provider))
+        .now_or_never()
+        .expect("default context should use synchronous in-memory catalog")?;
     Ok(Arc::new(Mutex::new(ctx)))
 }
diff --git a/datafusion/core/benches/sort_limit_query_sql.rs b/datafusion/core/benches/sort_limit_query_sql.rs
index cfd4b8bc4bba..8c6410103aef 100644
--- a/datafusion/core/benches/sort_limit_query_sql.rs
+++ b/datafusion/core/benches/sort_limit_query_sql.rs
@@ -24,6 +24,7 @@ use datafusion::datasource::listing::{
 };
 use datafusion::prelude::SessionConfig;
 
+use futures::FutureExt;
 use parking_lot::Mutex;
 
 use std::sync::Arc;
@@ -95,6 +96,8 @@ fn create_context() -> Arc<Mutex<SessionContext>> {
             .await
             .unwrap();
 
         ctx.register_table("aggregate_test_100", Arc::new(mem_table))
+            .now_or_never()
+            .expect("default context should use synchronous in-memory catalog")
             .unwrap();
 
         ctx_holder.lock().push(Arc::new(Mutex::new(ctx)))
     });
diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs
index 44320e7a287a..72ad35753114 100644
--- a/datafusion/core/benches/sql_planner.rs
+++ b/datafusion/core/benches/sql_planner.rs
@@ -28,6 +28,7 @@ use criterion::Bencher;
 use datafusion::datasource::MemTable;
 use datafusion::execution::context::SessionContext;
 use datafusion_common::ScalarValue;
+use futures::FutureExt;
 use itertools::Itertools;
 use std::fs::File;
 use std::io::{BufRead, BufReader};
@@ -78,15 +79,23 @@ fn create_table_provider(column_prefix: &str, num_columns: usize) -> Arc<MemTable> {
 
 fn create_context() -> SessionContext {
     let ctx = SessionContext::new();
-    ctx.register_table("t1", create_table_provider("a", 200))
-        .unwrap();
-    ctx.register_table("t2", create_table_provider("b", 200))
-        .unwrap();
-    ctx.register_table("t700", create_table_provider("c", 700))
-        .unwrap();
-    ctx.register_table("t1000", create_table_provider("d", 1000))
-        .unwrap();
-    ctx
+    async move {
+        ctx.register_table("t1", create_table_provider("a", 200))
+            .await
+            .unwrap();
+        ctx.register_table("t2", create_table_provider("b", 200))
+            .await
+            .unwrap();
+        ctx.register_table("t700", create_table_provider("c", 700))
+            .await
+            .unwrap();
+        ctx.register_table("t1000", create_table_provider("d", 1000))
+            .await
+            .unwrap();
+        ctx
+    }
+    .now_or_never()
+    .expect("default context should use synchronous in-memory catalog")
 }
 
@@ -97,6 +106,8 @@ fn register_defs(ctx: SessionContext, defs: Vec<TableDef>) -> SessionContext {
             name,
             Arc::new(MemTable::try_new(Arc::new(schema.clone()), vec![vec![]]).unwrap()),
         )
+        .now_or_never()
+        .expect("default context should use synchronous in-memory catalog")
         .unwrap();
     });
     ctx
diff --git a/datafusion/core/benches/topk_aggregate.rs b/datafusion/core/benches/topk_aggregate.rs
index 922cbd2b4229..37956edb8f5b 100644
--- a/datafusion/core/benches/topk_aggregate.rs
+++ b/datafusion/core/benches/topk_aggregate.rs
@@ -24,6 +24,7 @@ use datafusion::prelude::SessionContext;
 use datafusion::{datasource::MemTable, error::Result};
 use datafusion_execution::config::SessionConfig;
 use datafusion_execution::TaskContext;
+use futures::FutureExt;
 use std::sync::Arc;
 use tokio::runtime::Runtime;
 
@@ -42,7 +43,10 @@ async fn create_context(
     let opts = cfg.options_mut();
     opts.optimizer.enable_topk_aggregation = use_topk;
     let ctx = SessionContext::new_with_config(cfg);
-    let _ = ctx.register_table("traces", mem_table)?;
+    let _ = ctx
+        .register_table("traces", mem_table)
+        .now_or_never()
+        .expect("default context should use synchronous in-memory catalog")?;
     let sql = format!("select trace_id, max(timestamp_ms) from traces group by trace_id order by max(timestamp_ms) desc limit {limit};");
     let df = ctx.sql(sql.as_str()).await?;
     let physical_plan = df.create_physical_plan().await?;
diff --git a/datafusion/core/benches/window_query_sql.rs b/datafusion/core/benches/window_query_sql.rs
index 42a1e51be361..740e6067532f 100644
--- a/datafusion/core/benches/window_query_sql.rs
+++ b/datafusion/core/benches/window_query_sql.rs
@@ -25,6 +25,7 @@ use crate::criterion::Criterion;
 use data_utils::create_table_provider;
 use datafusion::error::Result;
 use datafusion::execution::context::SessionContext;
+use futures::FutureExt;
 use parking_lot::Mutex;
 use std::sync::Arc;
 use tokio::runtime::Runtime;
@@ -42,7 +43,9 @@ fn create_context(
 ) -> Result<Arc<Mutex<SessionContext>>> {
     let ctx = SessionContext::new();
     let provider = create_table_provider(partitions_len, array_len, batch_size)?;
-    ctx.register_table("t", provider)?;
+    ctx.register_table("t", provider)
+        .now_or_never()
+        .expect("default context should use synchronous in-memory catalog")?;
     Ok(Arc::new(Mutex::new(ctx)))
 }
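The criterion benches stay synchronous, so they drive the now-async registration with `FutureExt::now_or_never`, which polls the future exactly once and yields `None` if it is still pending; that is only sound because the default in-memory catalog resolves immediately. A sketch of the pattern as a hypothetical helper:

    use std::sync::Arc;

    use datafusion::catalog::TableProvider;
    use datafusion::prelude::SessionContext;
    use datafusion_common::Result;
    use futures::FutureExt;

    /// Hypothetical helper mirroring the bench changes above: poll the
    /// registration future exactly once instead of spinning up a runtime.
    fn register_table_sync(
        ctx: &SessionContext,
        name: &str,
        table: Arc<dyn TableProvider>,
    ) -> Result<()> {
        ctx.register_table(name, table)
            .now_or_never()
            .expect("default context should use synchronous in-memory catalog")?;
        Ok(())
    }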
diff --git a/datafusion/core/src/catalog_common/information_schema.rs b/datafusion/core/src/catalog_common/information_schema.rs
index 1d4a3c15f7ca..0d83f8b3f587 100644
--- a/datafusion/core/src/catalog_common/information_schema.rs
+++ b/datafusion/core/src/catalog_common/information_schema.rs
@@ -39,6 +39,8 @@ use async_trait::async_trait;
 use datafusion_common::error::Result;
 use datafusion_common::DataFusionError;
 use datafusion_expr::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF};
+use futures::stream::BoxStream;
+use futures::{StreamExt, TryStreamExt};
 use std::collections::{HashMap, HashSet};
 use std::fmt::Debug;
 use std::{any::Any, sync::Arc};
@@ -96,14 +98,17 @@ impl InformationSchemaConfig {
     ) -> Result<(), DataFusionError> {
         // create a mem table with the names of tables
-        for catalog_name in self.catalog_list.catalog_names() {
-            let catalog = self.catalog_list.catalog(&catalog_name).unwrap();
+        let mut catalog_names = self.catalog_list.catalog_names().await;
+        while let Some(catalog_name) = catalog_names.try_next().await? {
+            let catalog = self.catalog_list.catalog(&catalog_name).await?.unwrap();
 
-            for schema_name in catalog.schema_names() {
+            let mut schema_names = catalog.schema_names().await;
+            while let Some(schema_name) = schema_names.try_next().await? {
                 if schema_name != INFORMATION_SCHEMA {
                     // schema name may not exist in the catalog, so we need to check
-                    if let Some(schema) = catalog.schema(&schema_name) {
-                        for table_name in schema.table_names() {
+                    if let Some(schema) = catalog.schema(&schema_name).await? {
+                        let mut table_names = schema.table_names().await;
+                        while let Some(table_name) = table_names.try_next().await? {
                             if let Some(table) = schema.table(&table_name).await? {
                                 builder.add_table(
                                     &catalog_name,
@@ -131,33 +136,42 @@ impl InformationSchemaConfig {
         Ok(())
     }
 
-    async fn make_schemata(&self, builder: &mut InformationSchemataBuilder) {
-        for catalog_name in self.catalog_list.catalog_names() {
-            let catalog = self.catalog_list.catalog(&catalog_name).unwrap();
+    async fn make_schemata(
+        &self,
+        builder: &mut InformationSchemataBuilder,
+    ) -> Result<()> {
+        let mut catalog_names = self.catalog_list.catalog_names().await;
+        while let Some(catalog_name) = catalog_names.try_next().await? {
+            let catalog = self.catalog_list.catalog(&catalog_name).await?.unwrap();
 
-            for schema_name in catalog.schema_names() {
+            let mut schema_names = catalog.schema_names().await;
+            while let Some(schema_name) = schema_names.try_next().await? {
                 if schema_name != INFORMATION_SCHEMA {
-                    if let Some(schema) = catalog.schema(&schema_name) {
-                        let schema_owner = schema.owner_name();
+                    if let Some(schema) = catalog.schema(&schema_name).await? {
+                        let schema_owner = schema.owner_name().await;
                         builder.add_schemata(&catalog_name, &schema_name, schema_owner);
                     }
                 }
             }
         }
+        Ok(())
     }
 
     async fn make_views(
         &self,
         builder: &mut InformationSchemaViewBuilder,
     ) -> Result<(), DataFusionError> {
-        for catalog_name in self.catalog_list.catalog_names() {
-            let catalog = self.catalog_list.catalog(&catalog_name).unwrap();
+        let mut catalog_names = self.catalog_list.catalog_names().await;
+        while let Some(catalog_name) = catalog_names.try_next().await? {
+            let catalog = self.catalog_list.catalog(&catalog_name).await?.unwrap();
 
-            for schema_name in catalog.schema_names() {
+            let mut schema_names = catalog.schema_names().await;
+            while let Some(schema_name) = schema_names.try_next().await? {
                 if schema_name != INFORMATION_SCHEMA {
                     // schema name may not exist in the catalog, so we need to check
-                    if let Some(schema) = catalog.schema(&schema_name) {
-                        for table_name in schema.table_names() {
+                    if let Some(schema) = catalog.schema(&schema_name).await? {
+                        let mut table_names = schema.table_names().await;
+                        while let Some(table_name) = table_names.try_next().await? {
                             if let Some(table) = schema.table(&table_name).await? {
                                 builder.add_view(
                                     &catalog_name,
@@ -180,14 +194,17 @@ impl InformationSchemaConfig {
         &self,
         builder: &mut InformationSchemaColumnsBuilder,
     ) -> Result<(), DataFusionError> {
-        for catalog_name in self.catalog_list.catalog_names() {
-            let catalog = self.catalog_list.catalog(&catalog_name).unwrap();
+        let mut catalog_names = self.catalog_list.catalog_names().await;
+        while let Some(catalog_name) = catalog_names.try_next().await? {
+            let catalog = self.catalog_list.catalog(&catalog_name).await?.unwrap();
 
-            for schema_name in catalog.schema_names() {
+            let mut schema_names = catalog.schema_names().await;
+            while let Some(schema_name) = schema_names.try_next().await? {
                 if schema_name != INFORMATION_SCHEMA {
                     // schema name may not exist in the catalog, so we need to check
-                    if let Some(schema) = catalog.schema(&schema_name) {
-                        for table_name in schema.table_names() {
+                    if let Some(schema) = catalog.schema(&schema_name).await? {
+                        let mut table_names = schema.table_names().await;
+                        while let Some(table_name) = table_names.try_next().await? {
                             if let Some(table) = schema.table(&table_name).await? {
                                 for (field_position, field) in
                                     table.schema().fields().iter().enumerate()
@@ -469,11 +486,9 @@ impl SchemaProvider for InformationSchemaProvider {
         self
     }
 
-    fn table_names(&self) -> Vec<String> {
-        INFORMATION_SCHEMA_TABLES
-            .iter()
-            .map(|t| t.to_string())
-            .collect()
+    async fn table_names(&self) -> BoxStream<'static, Result<String>> {
+        let table_names = INFORMATION_SCHEMA_TABLES.iter().map(|s| Ok(s.to_string()));
+        futures::stream::iter(table_names).boxed()
     }
 
     async fn table(
@@ -497,7 +512,7 @@ impl SchemaProvider for InformationSchemaProvider {
         )))
     }
 
-    fn table_exist(&self, name: &str) -> bool {
+    async fn table_exist(&self, name: &str) -> bool {
         INFORMATION_SCHEMA_TABLES.contains(&name.to_ascii_lowercase().as_str())
     }
 }
@@ -994,7 +1009,7 @@ impl PartitionStream for InformationSchemata {
             Arc::clone(&self.schema),
             // TODO: Stream this
             futures::stream::once(async move {
-                config.make_schemata(&mut builder).await;
+                config.make_schemata(&mut builder).await?;
                 Ok(builder.finish())
             }),
         ))
diff --git a/datafusion/core/src/catalog_common/listing_schema.rs b/datafusion/core/src/catalog_common/listing_schema.rs
index dc55a07ef82d..8be9177df736 100644
--- a/datafusion/core/src/catalog_common/listing_schema.rs
+++ b/datafusion/core/src/catalog_common/listing_schema.rs
@@ -26,12 +26,13 @@ use crate::catalog::{SchemaProvider, TableProvider, TableProviderFactory};
 use crate::execution::context::SessionState;
 
 use datafusion_common::{
-    Constraints, DFSchema, DataFusionError, HashMap, TableReference,
+    Constraints, DFSchema, DataFusionError, HashMap, Result, TableReference,
 };
 use datafusion_expr::CreateExternalTable;
 
 use async_trait::async_trait;
-use futures::TryStreamExt;
+use futures::stream::BoxStream;
+use futures::{StreamExt, TryStreamExt};
 use itertools::Itertools;
 use object_store::ObjectStore;
 
@@ -88,7 +89,7 @@ impl ListingSchemaProvider {
     }
 
     /// Reload table information from ObjectStore
-    pub async fn refresh(&self, state: &SessionState) -> datafusion_common::Result<()> {
+    pub async fn refresh(&self, state: &SessionState) -> Result<()> {
         let entries: Vec<_> = self.store.list(Some(&self.path)).try_collect().await?;
         let base = Path::new(self.path.as_ref());
         let mut tables = HashSet::new();
@@ -123,7 +124,7 @@ impl ListingSchemaProvider {
                 DataFusionError::Internal("Cannot parse file name!".to_string())
             })?;
 
-            if !self.table_exist(table_name) {
+            if !self.table_exist(table_name).await {
                 let table_url = format!("{}/{}", self.authority, table_path);
                 let name = TableReference::bare(table_name);
 
@@ -148,8 +149,9 @@ impl ListingSchemaProvider {
                     },
                 )
                 .await?;
-                let _ =
-                    self.register_table(table_name.to_string(), Arc::clone(&provider))?;
+                let _ = self
+                    .register_table(table_name.to_string(), Arc::clone(&provider))
+                    .await?;
             }
         }
         Ok(())
@@ -162,13 +164,10 @@ impl SchemaProvider for ListingSchemaProvider {
         self
     }
 
-    fn table_names(&self) -> Vec<String> {
-        self.tables
self.tables - .lock() - .expect("Can't lock tables") - .keys() - .map(|it| it.to_string()) - .collect() + async fn table_names(&self) -> BoxStream<'static, Result> { + let tables = self.tables.lock().expect("Can't lock tables"); + let tables = tables.keys().map(|k| Ok(k.clone())).collect::>(); + futures::stream::iter(tables).boxed() } async fn table( @@ -183,11 +182,11 @@ impl SchemaProvider for ListingSchemaProvider { .cloned()) } - fn register_table( + async fn register_table( &self, name: String, table: Arc, - ) -> datafusion_common::Result>> { + ) -> Result>> { self.tables .lock() .expect("Can't lock tables") @@ -195,14 +194,14 @@ impl SchemaProvider for ListingSchemaProvider { Ok(Some(table)) } - fn deregister_table( + async fn deregister_table( &self, name: &str, - ) -> datafusion_common::Result>> { + ) -> Result>> { Ok(self.tables.lock().expect("Can't lock tables").remove(name)) } - fn table_exist(&self, name: &str) -> bool { + async fn table_exist(&self, name: &str) -> bool { self.tables .lock() .expect("Can't lock tables") diff --git a/datafusion/core/src/catalog_common/memory.rs b/datafusion/core/src/catalog_common/memory.rs index 6cdefc31f18c..c4512d364e3b 100644 --- a/datafusion/core/src/catalog_common/memory.rs +++ b/datafusion/core/src/catalog_common/memory.rs @@ -23,7 +23,9 @@ use crate::catalog::{ }; use async_trait::async_trait; use dashmap::DashMap; -use datafusion_common::{exec_err, DataFusionError}; +use datafusion_common::{exec_err, DataFusionError, Result}; +use futures::stream::BoxStream; +use futures::{StreamExt, TryStreamExt}; use std::any::Any; use std::sync::Arc; @@ -49,25 +51,31 @@ impl Default for MemoryCatalogProviderList { } } +#[async_trait] impl CatalogProviderList for MemoryCatalogProviderList { fn as_any(&self) -> &dyn Any { self } - fn register_catalog( + async fn register_catalog( &self, name: String, catalog: Arc, - ) -> Option> { - self.catalogs.insert(name, catalog) + ) -> Result>> { + Ok(self.catalogs.insert(name, catalog)) } - fn catalog_names(&self) -> Vec { - self.catalogs.iter().map(|c| c.key().clone()).collect() + async fn catalog_names(&self) -> BoxStream<'static, Result> { + let catalog_names = self + .catalogs + .iter() + .map(|keyval| Ok(keyval.key().clone())) + .collect::>(); + futures::stream::iter(catalog_names).boxed() } - fn catalog(&self, name: &str) -> Option> { - self.catalogs.get(name).map(|c| Arc::clone(c.value())) + async fn catalog(&self, name: &str) -> Result>> { + Ok(self.catalogs.get(name).map(|c| Arc::clone(c.value()))) } } @@ -92,34 +100,40 @@ impl Default for MemoryCatalogProvider { } } +#[async_trait] impl CatalogProvider for MemoryCatalogProvider { fn as_any(&self) -> &dyn Any { self } - fn schema_names(&self) -> Vec { - self.schemas.iter().map(|s| s.key().clone()).collect() + async fn schema_names(&self) -> BoxStream<'static, Result> { + let schema_names = self + .schemas + .iter() + .map(|keyval| Ok(keyval.key().clone())) + .collect::>(); + futures::stream::iter(schema_names).boxed() } - fn schema(&self, name: &str) -> Option> { - self.schemas.get(name).map(|s| Arc::clone(s.value())) + async fn schema(&self, name: &str) -> Result>> { + Ok(self.schemas.get(name).map(|s| Arc::clone(s.value()))) } - fn register_schema( + async fn register_schema( &self, name: &str, schema: Arc, - ) -> datafusion_common::Result>> { + ) -> Result>> { Ok(self.schemas.insert(name.into(), schema)) } - fn deregister_schema( + async fn deregister_schema( &self, name: &str, cascade: bool, - ) -> datafusion_common::Result>> { - if let Some(schema) 
= self.schema(name) { - let table_names = schema.table_names(); + ) -> Result>> { + if let Some(schema) = self.schema(name).await? { + let table_names = schema.table_names().await.try_collect::>().await?; match (table_names.is_empty(), cascade) { (true, _) | (false, true) => { let (_, removed) = self.schemas.remove(name).unwrap(); @@ -164,39 +178,41 @@ impl SchemaProvider for MemorySchemaProvider { self } - fn table_names(&self) -> Vec { - self.tables + async fn table_names(&self) -> BoxStream<'static, Result> { + let table_names = self + .tables .iter() - .map(|table| table.key().clone()) - .collect() + .map(|keyval| Ok(keyval.key().clone())) + .collect::>(); + futures::stream::iter(table_names).boxed() } async fn table( &self, name: &str, - ) -> datafusion_common::Result>, DataFusionError> { + ) -> Result>, DataFusionError> { Ok(self.tables.get(name).map(|table| Arc::clone(table.value()))) } - fn register_table( + async fn register_table( &self, name: String, table: Arc, - ) -> datafusion_common::Result>> { - if self.table_exist(name.as_str()) { + ) -> Result>> { + if self.table_exist(name.as_str()).await { return exec_err!("The table {name} already exists"); } Ok(self.tables.insert(name, table)) } - fn deregister_table( + async fn deregister_table( &self, name: &str, - ) -> datafusion_common::Result>> { + ) -> Result>> { Ok(self.tables.remove(name).map(|(_, table)| table)) } - fn table_exist(&self, name: &str) -> bool { + async fn table_exist(&self, name: &str) -> bool { self.tables.contains_key(name) } } @@ -214,55 +230,59 @@ mod test { use std::any::Any; use std::sync::Arc; - #[test] - fn memory_catalog_dereg_nonempty_schema() { + #[tokio::test] + async fn memory_catalog_dereg_nonempty_schema() { let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; let schema = Arc::new(MemorySchemaProvider::new()) as Arc; let test_table = Arc::new(EmptyTable::new(Arc::new(Schema::empty()))) as Arc; - schema.register_table("t".into(), test_table).unwrap(); + schema.register_table("t".into(), test_table).await.unwrap(); - cat.register_schema("foo", schema.clone()).unwrap(); + cat.register_schema("foo", schema.clone()).await.unwrap(); assert!( - cat.deregister_schema("foo", false).is_err(), + cat.deregister_schema("foo", false).await.is_err(), "dropping empty schema without cascade should error" ); - assert!(cat.deregister_schema("foo", true).unwrap().is_some()); + assert!(cat.deregister_schema("foo", true).await.unwrap().is_some()); } - #[test] - fn memory_catalog_dereg_empty_schema() { + #[tokio::test] + async fn memory_catalog_dereg_empty_schema() { let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; let schema = Arc::new(MemorySchemaProvider::new()) as Arc; - cat.register_schema("foo", schema).unwrap(); + cat.register_schema("foo", schema).await.unwrap(); - assert!(cat.deregister_schema("foo", false).unwrap().is_some()); + assert!(cat.deregister_schema("foo", false).await.unwrap().is_some()); } - #[test] - fn memory_catalog_dereg_missing() { + #[tokio::test] + async fn memory_catalog_dereg_missing() { let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; - assert!(cat.deregister_schema("foo", false).unwrap().is_none()); + assert!(cat.deregister_schema("foo", false).await.unwrap().is_none()); } - #[test] - fn default_register_schema_not_supported() { + #[tokio::test] + async fn default_register_schema_not_supported() { // mimic a new CatalogProvider and ensure it does not support registering schemas #[derive(Debug)] struct TestProvider {} + #[async_trait] impl CatalogProvider for 
TestProvider { fn as_any(&self) -> &dyn Any { self } - fn schema_names(&self) -> Vec { + async fn schema_names(&self) -> BoxStream<'static, Result> { unimplemented!() } - fn schema(&self, _name: &str) -> Option> { + async fn schema( + &self, + _name: &str, + ) -> Result>> { unimplemented!() } } @@ -270,7 +290,7 @@ mod test { let schema = Arc::new(MemorySchemaProvider::new()) as Arc; let catalog = Arc::new(TestProvider {}); - match catalog.register_schema("foo", schema) { + match catalog.register_schema("foo", schema).await { Ok(_) => panic!("unexpected OK"), Err(e) => assert_eq!(e.strip_backtrace(), "This feature is not implemented: Registering new schemas is not supported"), }; @@ -280,18 +300,24 @@ mod test { async fn test_mem_provider() { let provider = MemorySchemaProvider::new(); let table_name = "test_table_exist"; - assert!(!provider.table_exist(table_name)); - assert!(provider.deregister_table(table_name).unwrap().is_none()); + assert!(!provider.table_exist(table_name).await); + assert!(provider + .deregister_table(table_name) + .await + .unwrap() + .is_none()); let test_table = EmptyTable::new(Arc::new(Schema::empty())); // register table successfully assert!(provider .register_table(table_name.to_string(), Arc::new(test_table)) + .await .unwrap() .is_none()); - assert!(provider.table_exist(table_name)); + assert!(provider.table_exist(table_name).await); let other_table = EmptyTable::new(Arc::new(Schema::empty())); - let result = - provider.register_table(table_name.to_string(), Arc::new(other_table)); + let result = provider + .register_table(table_name.to_string(), Arc::new(other_table)) + .await; assert!(result.is_err()); } @@ -324,10 +350,16 @@ mod test { schema .register_table("alltypes_plain".to_string(), Arc::new(table)) + .await .unwrap(); - catalog.register_schema("active", Arc::new(schema)).unwrap(); - ctx.register_catalog("cat", Arc::new(catalog)); + catalog + .register_schema("active", Arc::new(schema)) + .await + .unwrap(); + ctx.register_catalog("cat", Arc::new(catalog)) + .await + .unwrap(); let df = ctx .sql("SELECT id, bool_col FROM cat.active.alltypes_plain") diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 82ee52d7b2e3..92898af4cd6b 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -2201,7 +2201,7 @@ mod tests { let batch = RecordBatch::try_from_iter(vec![("f.c1", Arc::new(array) as _)])?; let ctx = SessionContext::new(); - ctx.register_batch("t", batch)?; + ctx.register_batch("t", batch).await?; let df = ctx.table("t").await?.select_columns(&["f.c1"])?; @@ -2302,7 +2302,7 @@ mod tests { ])?; let ctx = SessionContext::new(); - ctx.register_batch("t", batch)?; + ctx.register_batch("t", batch).await?; let df = ctx.table("t").await?.drop_columns(&["f\"c1"])?; @@ -2334,7 +2334,7 @@ mod tests { ])?; let ctx = SessionContext::new(); - ctx.register_batch("t", batch)?; + ctx.register_batch("t", batch).await?; let df = ctx.table("t").await?.drop_columns(&["f.c1"])?; @@ -3166,7 +3166,8 @@ mod tests { let df_impl = DataFrame::new(ctx.state(), df.plan.clone()); // register a dataframe as a table - ctx.register_table("test_table", df_impl.clone().into_view())?; + ctx.register_table("test_table", df_impl.clone().into_view()) + .await?; // pull the table out let table = ctx.table("test_table").await?; @@ -3347,8 +3348,8 @@ mod tests { let ctx = SessionContext::new(); let table = df.into_view(); - ctx.register_table("t1", table.clone())?; - ctx.register_table("t2", table)?; + 
ctx.register_table("t1", table.clone()).await?; + ctx.register_table("t2", table).await?; let df = ctx .table("t1") .await? @@ -3463,8 +3464,8 @@ mod tests { let ctx = SessionContext::new(); let table = df.into_view(); - ctx.register_table("t1", table.clone())?; - ctx.register_table("t2", table)?; + ctx.register_table("t1", table.clone()).await?; + ctx.register_table("t2", table).await?; let actual_err = ctx .table("t1") @@ -3491,8 +3492,8 @@ mod tests { let ctx = SessionContext::new(); let table = df.into_view(); - ctx.register_table("t1", table.clone())?; - ctx.register_table("t2", table)?; + ctx.register_table("t1", table.clone()).await?; + ctx.register_table("t2", table).await?; let df = ctx .table("t1") .await? @@ -3659,7 +3660,7 @@ mod tests { )?; let ctx = SessionContext::new(); - ctx.register_batch("test", data)?; + ctx.register_batch("test", data).await?; let sql = r#" SELECT @@ -3682,7 +3683,7 @@ mod tests { let batch = RecordBatch::try_from_iter(vec![("f.c1", Arc::new(array) as _)])?; let ctx = SessionContext::new(); - ctx.register_batch("t", batch)?; + ctx.register_batch("t", batch).await?; let df = ctx .table("t") diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs index 0af68783c41f..0ebb86c71362 100644 --- a/datafusion/core/src/dataframe/parquet.rs +++ b/datafusion/core/src/dataframe/parquet.rs @@ -126,7 +126,8 @@ mod tests { ) .await?; - ctx.register_table("t1", ctx.table("test").await?.into_view())?; + ctx.register_table("t1", ctx.table("test").await?.into_view()) + .await?; let df = ctx .table("t1") diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 9f979ddf01e7..6b1449f3243f 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -991,7 +991,8 @@ mod tests { .with_config(cfg) .with_runtime_env(runtime) .with_default_features() - .build(); + .build() + .await; let integration = LocalFileSystem::new_with_prefix(arrow_test_data()).unwrap(); let path = Path::from("csv/aggregate_test_100.csv"); let csv = CsvFormat::default().with_has_header(true); diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 04c64156b125..9ca8e12228fa 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -567,7 +567,8 @@ mod tests { let (store, state) = make_test_store_and_state(&[ ("tablepath/mypartition=val1/notparquetfile", 100), ("tablepath/file.parquet", 100), - ]); + ]) + .await; let filter = Expr::eq(col("mypartition"), lit("val1")); let pruned = pruned_partition_list( &state, @@ -591,7 +592,8 @@ mod tests { ("tablepath/mypartition=val1/file.parquet", 100), ("tablepath/mypartition=val2/file.parquet", 100), ("tablepath/mypartition=val1/other=val3/file.parquet", 100), - ]); + ]) + .await; let filter = Expr::eq(col("mypartition"), lit("val1")); let pruned = pruned_partition_list( &state, @@ -630,7 +632,8 @@ mod tests { ("tablepath/part1=p1v2/part2=p2v1/file2.parquet", 100), ("tablepath/part1=p1v3/part2=p2v1/file2.parquet", 100), ("tablepath/part1=p1v2/part2=p2v2/file2.parquet", 100), - ]); + ]) + .await; let filter1 = Expr::eq(col("part1"), lit("p1v2")); let filter2 = Expr::eq(col("part2"), lit("p2v1")); let pruned = pruned_partition_list( diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index ffe49dd2ba11..6b2a0d987907 
100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -1348,7 +1348,7 @@ mod tests { async fn read_empty_table() -> Result<()> { let ctx = SessionContext::new(); let path = String::from("table/p1=v1/file.avro"); - register_test_store(&ctx, &[(&path, 100)]); + register_test_store(&ctx, &[(&path, 100)]).await; let opt = ListingOptions::new(Arc::new(AvroFormat {})) .with_file_extension(AvroFormat.get_ext()) @@ -1592,7 +1592,8 @@ mod tests { file_ext: Option<&str>, ) -> Result<()> { let ctx = SessionContext::new(); - register_test_store(&ctx, &files.iter().map(|f| (*f, 10)).collect::>()); + register_test_store(&ctx, &files.iter().map(|f| (*f, 10)).collect::>()) + .await; let format = AvroFormat {}; @@ -1626,7 +1627,8 @@ mod tests { file_ext: Option<&str>, ) -> Result<()> { let ctx = SessionContext::new(); - register_test_store(&ctx, &files.iter().map(|f| (*f, 10)).collect::>()); + register_test_store(&ctx, &files.iter().map(|f| (*f, 10)).collect::>()) + .await; let format = AvroFormat {}; @@ -2028,7 +2030,9 @@ mod tests { schema.clone(), vec![vec![batch.clone(), batch.clone()]], )?); - session_ctx.register_table("source", source_table.clone())?; + session_ctx + .register_table("source", source_table.clone()) + .await?; // Convert the source table into a provider so that it can be used in a query let source = provider_as_source(source_table); // Create a table scan logical plan to read from the source table diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index c1e0bea0b3ff..0663dc248ba1 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -629,10 +629,14 @@ mod tests { let session_ctx = SessionContext::new(); // Create and register the initial table with the provided schema and data let initial_table = Arc::new(MemTable::try_new(schema.clone(), initial_data)?); - session_ctx.register_table("t", initial_table.clone())?; + session_ctx + .register_table("t", initial_table.clone()) + .await?; // Create and register the source table with the provided schema and inserted data let source_table = Arc::new(MemTable::try_new(schema.clone(), inserted_data)?); - session_ctx.register_table("source", source_table.clone())?; + session_ctx + .register_table("source", source_table.clone()) + .await?; // Convert the source table into a provider so that it can be used in a query let source = provider_as_source(source_table); // Create a table scan logical plan to read from the source table diff --git a/datafusion/core/src/datasource/physical_plan/file_stream.rs b/datafusion/core/src/datasource/physical_plan/file_stream.rs index 18cda4524ab2..ca1d30083d79 100644 --- a/datafusion/core/src/datasource/physical_plan/file_stream.rs +++ b/datafusion/core/src/datasource/physical_plan/file_stream.rs @@ -636,7 +636,7 @@ mod tests { .map(|(name, size)| (name.as_str(), *size)) .collect(); - register_test_store(&ctx, &mock_files_ref); + register_test_store(&ctx, &mock_files_ref).await; let file_group = mock_files .into_iter() diff --git a/datafusion/core/src/datasource/view.rs b/datafusion/core/src/datasource/view.rs index 1ffe54e4b06c..75f90fb67790 100644 --- a/datafusion/core/src/datasource/view.rs +++ b/datafusion/core/src/datasource/view.rs @@ -461,7 +461,8 @@ mod tests { ) .await?; - ctx.register_table("t1", ctx.table("test").await?.into_view())?; + ctx.register_table("t1", ctx.table("test").await?.into_view()) + .await?; ctx.sql("CREATE VIEW t2 as SELECT * 
FROM t1").await?; @@ -492,7 +493,8 @@ mod tests { ) .await?; - ctx.register_table("t1", ctx.table("test").await?.into_view())?; + ctx.register_table("t1", ctx.table("test").await?.into_view()) + .await?; ctx.sql("CREATE VIEW t2 as SELECT * FROM t1").await?; diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 5f01d41c31e7..b90706ce8017 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -66,6 +66,8 @@ use datafusion_expr::{ planner::ExprPlanner, Expr, UserDefinedLogicalNode, WindowUDF, }; +use futures::stream::BoxStream; +use futures::{FutureExt, TryStreamExt}; // backwards compatibility pub use crate::execution::session_state::SessionState; @@ -212,9 +214,12 @@ where /// /// ``` /// # use std::sync::Arc; +/// # use datafusion::error::Result; /// # use datafusion::prelude::*; /// # use datafusion::execution::SessionStateBuilder; /// # use datafusion_execution::runtime_env::RuntimeEnvBuilder; +/// # #[tokio::main] +/// # async fn main() -> Result<()> { /// // Configure a 4k batch size /// let config = SessionConfig::new() .with_batch_size(4 * 1024); /// @@ -230,10 +235,13 @@ where /// .with_runtime_env(runtime_env) /// // include support for built in functions and configurations /// .with_default_features() -/// .build(); +/// .build() +/// .await; /// /// // Create a SessionContext /// let ctx = SessionContext::from(state); +/// # Ok(()) +/// # } /// ``` /// /// # Relationship between `SessionContext`, `SessionState`, and `TaskContext` @@ -285,13 +293,14 @@ impl SessionContext { /// Finds any [`ListingSchemaProvider`]s and instructs them to reload tables from "disk" pub async fn refresh_catalogs(&self) -> Result<()> { - let cat_names = self.catalog_names().clone(); - for cat_name in cat_names.iter() { - let cat = self.catalog(cat_name.as_str()).ok_or_else(|| { + let mut cat_names = self.catalog_names().await; + while let Some(cat_name) = cat_names.try_next().await? { + let cat = self.catalog(cat_name.as_str()).await?.ok_or_else(|| { DataFusionError::Internal("Catalog not found!".to_string()) })?; - for schema_name in cat.schema_names() { - let schema = cat.schema(schema_name.as_str()).ok_or_else(|| { + let mut schema_names = cat.schema_names().await; + while let Some(schema_name) = schema_names.try_next().await? { + let schema = cat.schema(&schema_name).await?.ok_or_else(|| { DataFusionError::Internal("Schema not found!".to_string()) })?; let lister = schema.as_any().downcast_ref::(); @@ -331,7 +340,7 @@ impl SessionContext { .with_config(config) .with_runtime_env(runtime) .with_default_features() - .build(); + .build().now_or_never().expect("state built without custom catalog should use synchronous in-memory catalog"); Self::new_with_state(state) } @@ -364,7 +373,7 @@ impl SessionContext { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new() - /// .enable_url_table(); // permit local file access + /// .enable_url_table().await; // permit local file access /// let results = ctx /// .sql("SELECT a, MIN(b) FROM 'tests/data/example.csv' as example GROUP BY a LIMIT 100") /// .await? 
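Downstream code sees this change mainly as two new `.await` points: building the session state and enabling URL tables. A minimal sketch of the updated calling convention, assuming a tokio runtime as in the tests elsewhere in this patch (the query is illustrative only):

    use datafusion::error::Result;
    use datafusion::execution::SessionStateBuilder;
    use datafusion::prelude::*;

    #[tokio::main]
    async fn main() -> Result<()> {
        // `build()` is now async because it may register the default catalog.
        let state = SessionStateBuilder::new()
            .with_default_features()
            .build()
            .await;
        // `enable_url_table()` is async as well after this change.
        let ctx = SessionContext::new_with_state(state)
            .enable_url_table()
            .await;
        let df = ctx.sql("SELECT 1").await?;
        df.show().await?;
        Ok(())
    }
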
@@ -383,7 +392,7 @@ impl SessionContext { /// # Ok(()) /// # } /// ``` - pub fn enable_url_table(self) -> Self { + pub async fn enable_url_table(self) -> Self { let current_catalog_list = Arc::clone(self.state.read().catalog_list()); let factory = Arc::new(DynamicListTableFactory::new(SessionStore::new())); let catalog_list = Arc::new(DynamicFileCatalog::new( @@ -392,8 +401,10 @@ impl SessionContext { )); let ctx: SessionContext = self .into_state_builder() + .await .with_catalog_list(catalog_list) .build() + .await .into(); // register new state with the factory factory.session_store().with_state(ctx.state_weak_ref()); @@ -410,21 +421,29 @@ impl SessionContext { /// # Example /// ``` /// # use std::sync::Arc; + /// # use datafusion::error::Result; /// # use datafusion::prelude::*; /// # use datafusion::execution::SessionStateBuilder; /// # use datafusion_optimizer::push_down_filter::PushDownFilter; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { /// let my_rule = PushDownFilter{}; // pretend it is a new rule /// // Create a new builder with a custom optimizer rule /// let context: SessionContext = SessionStateBuilder::new() /// .with_optimizer_rule(Arc::new(my_rule)) /// .build() + /// .await /// .into(); /// // Enable local file access and convert context back to a builder /// let builder = context /// .enable_url_table() - /// .into_state_builder(); + /// .await + /// .into_state_builder() + /// .await; + /// # Ok(()) + /// # } /// ``` - pub fn into_state_builder(self) -> SessionStateBuilder { + pub async fn into_state_builder(self) -> SessionStateBuilder { let SessionContext { session_id: _, session_start_time: _, @@ -434,7 +453,7 @@ impl SessionContext { Ok(rwlock) => rwlock.into_inner(), Err(state) => state.read().clone(), }; - SessionStateBuilder::from(state) + SessionStateBuilder::new_from_existing(state).await } /// Returns the time this `SessionContext` was created @@ -492,7 +511,7 @@ impl SessionContext { } /// Registers the [`RecordBatch`] as the specified table name - pub fn register_batch( + pub async fn register_batch( &self, table_name: &str, batch: RecordBatch, @@ -504,6 +523,7 @@ impl SessionContext { }, Arc::new(table), ) + .await } /// Return the [RuntimeEnv] used to run queries with this `SessionContext` @@ -568,7 +588,7 @@ impl SessionContext { /// .await? /// .collect() /// .await?; - /// assert!(ctx.table_exist("foo").unwrap()); + /// assert!(ctx.table_exist("foo").await.unwrap()); /// # Ok(()) /// # } /// ``` @@ -654,9 +674,7 @@ impl SessionContext { // stack overflows. 
match ddl { DdlStatement::CreateExternalTable(cmd) => { - (Box::pin(async move { self.create_external_table(&cmd).await }) - as std::pin::Pin + Send>>) - .await + self.create_external_table(&cmd).await } DdlStatement::CreateMemoryTable(cmd) => { Box::pin(self.create_memory_table(cmd)).await @@ -770,7 +788,7 @@ impl SessionContext { &self, cmd: &CreateExternalTable, ) -> Result { - let exist = self.table_exist(cmd.name.clone())?; + let exist = self.table_exist(cmd.name.clone()).await?; if cmd.temporary { return not_impl_err!("Temporary tables not supported"); @@ -787,7 +805,8 @@ impl SessionContext { let table_provider: Arc = self.create_custom_table(cmd).await?; - self.register_table(cmd.name.clone(), table_provider)?; + self.register_table(cmd.name.clone(), table_provider) + .await?; self.return_empty_dataframe() } @@ -813,7 +832,7 @@ impl SessionContext { match (if_not_exists, or_replace, table) { (true, false, Ok(_)) => self.return_empty_dataframe(), (false, true, Ok(_)) => { - self.deregister_table(name.clone())?; + self.deregister_table(name.clone()).await?; let schema = Arc::new(input.schema().as_ref().into()); let physical = DataFrame::new(self.state(), input); @@ -825,7 +844,7 @@ impl SessionContext { .with_column_defaults(column_defaults.into_iter().collect()), ); - self.register_table(name.clone(), table)?; + self.register_table(name.clone(), table).await?; self.return_empty_dataframe() } (true, true, Ok(_)) => { @@ -844,7 +863,7 @@ impl SessionContext { .with_column_defaults(column_defaults.into_iter().collect()), ); - self.register_table(name, table)?; + self.register_table(name, table).await?; self.return_empty_dataframe() } (false, false, Ok(_)) => exec_err!("Table '{name}' already exists"), @@ -868,15 +887,15 @@ impl SessionContext { match (or_replace, view) { (true, Ok(_)) => { - self.deregister_table(name.clone())?; + self.deregister_table(name.clone()).await?; let table = Arc::new(ViewTable::try_new((*input).clone(), definition)?); - self.register_table(name, table)?; + self.register_table(name, table).await?; self.return_empty_dataframe() } (_, Err(_)) => { let table = Arc::new(ViewTable::try_new((*input).clone(), definition)?); - self.register_table(name, table)?; + self.register_table(name, table).await?; self.return_empty_dataframe() } (false, Ok(_)) => exec_err!("Table '{name}' already exists"), @@ -895,9 +914,14 @@ impl SessionContext { let tokens: Vec<&str> = schema_name.split('.').collect(); let (catalog, schema_name) = match tokens.len() { 1 => { - let state = self.state.read(); - let name = &state.config().options().catalog.default_catalog; - let catalog = state.catalog_list().catalog(name).ok_or_else(|| { + let (name, catalog_list) = { + let state = self.state.read(); + let name = state.config().options().catalog.default_catalog.clone(); + let catalog_list = Arc::clone(state.catalog_list()); + (name, catalog_list) + }; + let catalog_fut = catalog_list.catalog(&name); + let catalog = catalog_fut.await?.ok_or_else(|| { DataFusionError::Execution(format!( "Missing default catalog '{name}'" )) @@ -906,20 +930,20 @@ impl SessionContext { } 2 => { let name = &tokens[0]; - let catalog = self.catalog(name).ok_or_else(|| { + let catalog = self.catalog(name).await?.ok_or_else(|| { DataFusionError::Execution(format!("Missing catalog '{name}'")) })?; (catalog, tokens[1]) } _ => return exec_err!("Unable to parse catalog from {schema_name}"), }; - let schema = catalog.schema(schema_name); + let schema = catalog.schema(schema_name).await?; match (if_not_exists, schema) { 
(true, Some(_)) => self.return_empty_dataframe(), (true, None) | (false, None) => { let schema = Arc::new(MemorySchemaProvider::new()); - catalog.register_schema(schema_name, schema)?; + catalog.register_schema(schema_name, schema).await?; self.return_empty_dataframe() } (false, Some(_)) => exec_err!("Schema '{schema_name}' already exists"), @@ -932,16 +956,16 @@ impl SessionContext { if_not_exists, .. } = cmd; - let catalog = self.catalog(catalog_name.as_str()); + let catalog = self.catalog(catalog_name.as_str()).await?; match (if_not_exists, catalog) { (true, Some(_)) => self.return_empty_dataframe(), (true, None) | (false, None) => { let new_catalog = Arc::new(MemoryCatalogProvider::new()); - self.state - .write() - .catalog_list() - .register_catalog(catalog_name, new_catalog); + let catalog_list = Arc::clone(self.state.write().catalog_list()); + catalog_list + .register_catalog(catalog_name, new_catalog) + .await?; self.return_empty_dataframe() } (false, Some(_)) => exec_err!("Catalog '{catalog_name}' already exists"), @@ -984,14 +1008,18 @@ impl SessionContext { schema: _, } = cmd; let catalog = { - let state = self.state.read(); - let catalog_name = match &name { - SchemaReference::Full { catalog, .. } => catalog.to_string(), - SchemaReference::Bare { .. } => { - state.config_options().catalog.default_catalog.to_string() - } + let (catalog_name, catalog_list) = { + let state = self.state.read(); + let catalog_name = match &name { + SchemaReference::Full { catalog, .. } => catalog.to_string(), + SchemaReference::Bare { .. } => { + state.config_options().catalog.default_catalog.to_string() + } + }; + let catalog_list = Arc::clone(state.catalog_list()); + (catalog_name, catalog_list) }; - if let Some(catalog) = state.catalog_list().catalog(&catalog_name) { + if let Some(catalog) = catalog_list.catalog(&catalog_name).await? { catalog } else if allow_missing { return self.return_empty_dataframe(); @@ -999,7 +1027,9 @@ impl SessionContext { return self.schema_doesnt_exist_err(name); } }; - let dereg = catalog.deregister_schema(name.schema_name(), cascade)?; + let dereg = catalog + .deregister_schema(name.schema_name(), cascade) + .await?; match (dereg, allow_missing) { (None, true) => self.return_empty_dataframe(), (None, false) => self.schema_doesnt_exist_err(name), @@ -1051,18 +1081,23 @@ impl SessionContext { let table_ref = table_ref.into(); let table = table_ref.table().to_owned(); let maybe_schema = { - let state = self.state.read(); - let resolved = state.resolve_table_ref(table_ref); - state - .catalog_list() - .catalog(&resolved.catalog) - .and_then(|c| c.schema(&resolved.schema)) + let (resolved, catalog_list) = { + let state = self.state.read(); + let resolved = state.resolve_table_ref(table_ref); + let catalog_list = Arc::clone(state.catalog_list()); + (resolved, catalog_list) + }; + if let Some(catalog) = catalog_list.catalog(&resolved.catalog).await? { + catalog.schema(&resolved.schema).await? + } else { + None + } }; if let Some(schema) = maybe_schema { if let Some(table_provider) = schema.table(&table).await? 
{ if table_provider.table_type() == table_type { - schema.deregister_table(&table)?; + schema.deregister_table(&table).await?; return Ok(true); } } @@ -1371,7 +1406,7 @@ impl SessionContext { .with_listing_options(options) .with_schema(resolved_schema); let table = ListingTable::try_new(config)?.with_definition(sql_definition); - self.register_table(table_ref, Arc::new(table))?; + self.register_table(table_ref, Arc::new(table)).await?; Ok(()) } @@ -1403,26 +1438,26 @@ impl SessionContext { /// /// Returns the [`CatalogProvider`] previously registered for this /// name, if any - pub fn register_catalog( + pub async fn register_catalog( &self, name: impl Into, catalog: Arc, - ) -> Option> { + ) -> Result>> { let name = name.into(); - self.state - .read() - .catalog_list() - .register_catalog(name, catalog) + let catalog_list = Arc::clone(self.state.read().catalog_list()); + catalog_list.register_catalog(name, catalog).await } /// Retrieves the list of available catalog names. - pub fn catalog_names(&self) -> Vec { - self.state.read().catalog_list().catalog_names() + pub async fn catalog_names(&self) -> BoxStream<'static, Result> { + let catalog_list = Arc::clone(self.state.read().catalog_list()); + catalog_list.catalog_names().await } /// Retrieves a [`CatalogProvider`] instance by name - pub fn catalog(&self, name: &str) -> Option> { - self.state.read().catalog_list().catalog(name) + pub async fn catalog(&self, name: &str) -> Result>> { + let catalog_list = Arc::clone(self.state.read().catalog_list()); + catalog_list.catalog(name).await } /// Registers a [`TableProvider`] as a table that can be @@ -1430,44 +1465,40 @@ impl SessionContext { /// /// If a table of the same name was already registered, returns "Table /// already exists" error. - pub fn register_table( + pub async fn register_table( &self, table_ref: impl Into, provider: Arc, ) -> Result>> { let table_ref: TableReference = table_ref.into(); let table = table_ref.table().to_owned(); - self.state - .read() - .schema_for_ref(table_ref)? - .register_table(table, provider) + let schema_fut = self.state.read().schema_for_ref(table_ref); + schema_fut.await?.register_table(table, provider).await } /// Deregisters the given table. /// /// Returns the registered provider, if any - pub fn deregister_table( + pub async fn deregister_table( &self, table_ref: impl Into, ) -> Result>> { let table_ref = table_ref.into(); let table = table_ref.table().to_owned(); - self.state - .read() - .schema_for_ref(table_ref)? - .deregister_table(&table) + let schema_fut = self.state.read().schema_for_ref(table_ref); + schema_fut.await?.deregister_table(&table).await } /// Return `true` if the specified table exists in the schema provider. - pub fn table_exist(&self, table_ref: impl Into) -> Result { + pub async fn table_exist( + &self, + table_ref: impl Into, + ) -> Result { let table_ref: TableReference = table_ref.into(); let table = table_ref.table(); let table_ref = table_ref.clone(); - Ok(self - .state - .read() - .schema_for_ref(table_ref)? 
- .table_exist(table)) + let schema_fut = self.state.read().schema_for_ref(table_ref); + Ok(schema_fut.await?.table_exist(table).await) } /// Retrieves a [`DataFrame`] representing a table previously @@ -1513,7 +1544,8 @@ impl SessionContext { ) -> Result> { let table_ref = table_ref.into(); let table = table_ref.table().to_string(); - let schema = self.state.read().schema_for_ref(table_ref)?; + let schema_fut = self.state.read().schema_for_ref(table_ref); + let schema = schema_fut.await?; match schema.table(&table).await? { Some(ref provider) => Ok(Arc::clone(provider)), _ => plan_err!("No table named '{table}'"), @@ -1631,12 +1663,6 @@ impl From for SessionContext { } } -impl From for SessionStateBuilder { - fn from(session: SessionContext) -> Self { - session.into_state_builder() - } -} - /// A planner used to add extensions to DataFusion logical and physical plans. #[async_trait] pub trait QueryPlanner: Debug { @@ -1855,7 +1881,7 @@ mod tests { ctx.register_variable(VarType::UserDefined, Arc::new(variable_provider)); let provider = test::create_table_dual(); - ctx.register_table("dual", provider)?; + ctx.register_table("dual", provider).await?; let results = plan_and_collect(&ctx, "SELECT @@version, @name, @integer + 1 FROM dual") @@ -1892,10 +1918,10 @@ mod tests { let ctx = create_ctx(&tmp_dir, partition_count).await?; let provider = test::create_table_dual(); - ctx.register_table("dual", provider)?; + ctx.register_table("dual", provider).await?; - assert!(ctx.deregister_table("dual")?.is_some()); - assert!(ctx.deregister_table("dual")?.is_none()); + assert!(ctx.deregister_table("dual").await?.is_some()); + assert!(ctx.deregister_table("dual").await?.is_none()); Ok(()) } @@ -1940,7 +1966,8 @@ mod tests { .with_config(cfg) .with_runtime_env(runtime) .with_default_features() - .build(); + .build() + .await; let ctx = SessionContext::new_with_state(session_state); ctx.refresh_catalogs().await?; @@ -1972,8 +1999,11 @@ mod tests { let session_state = SessionStateBuilder::new() .with_default_features() .with_config(cfg) - .build(); - let ctx = SessionContext::new_with_state(session_state).enable_url_table(); + .build() + .await; + let ctx = SessionContext::new_with_state(session_state) + .enable_url_table() + .await; let result = plan_and_collect( &ctx, format!("select c_name from '{}' limit 3;", &url).as_str(), @@ -2003,7 +2033,8 @@ mod tests { .with_runtime_env(runtime) .with_default_features() .with_query_planner(Arc::new(MyQueryPlanner {})) - .build(); + .build() + .await; let ctx = SessionContext::new_with_state(session_state); let df = ctx.sql("SELECT 1").await?; @@ -2018,7 +2049,8 @@ mod tests { ); assert!(matches!( - ctx.register_table("test", test::table_with_sequence(1, 1)?), + ctx.register_table("test", test::table_with_sequence(1, 1)?) 
+ .await, Err(DataFusionError::Plan(_)) )); @@ -2061,11 +2093,15 @@ mod tests { let schema = MemorySchemaProvider::new(); schema .register_table("test".to_owned(), test::table_with_sequence(1, 1).unwrap()) + .await .unwrap(); catalog .register_schema("my_schema", Arc::new(schema)) + .await + .unwrap(); + ctx.register_catalog("my_catalog", Arc::new(catalog)) + .await .unwrap(); - ctx.register_catalog("my_catalog", Arc::new(catalog)); for table_ref in &["my_catalog.my_schema.test", "my_schema.test", "test"] { let result = plan_and_collect( @@ -2093,16 +2129,25 @@ mod tests { let catalog_a = MemoryCatalogProvider::new(); let schema_a = MemorySchemaProvider::new(); schema_a - .register_table("table_a".to_owned(), test::table_with_sequence(1, 1)?)?; - catalog_a.register_schema("schema_a", Arc::new(schema_a))?; - ctx.register_catalog("catalog_a", Arc::new(catalog_a)); + .register_table("table_a".to_owned(), test::table_with_sequence(1, 1)?) + .await?; + catalog_a + .register_schema("schema_a", Arc::new(schema_a)) + .await?; + ctx.register_catalog("catalog_a", Arc::new(catalog_a)) + .await + .unwrap(); let catalog_b = MemoryCatalogProvider::new(); let schema_b = MemorySchemaProvider::new(); schema_b - .register_table("table_b".to_owned(), test::table_with_sequence(1, 2)?)?; - catalog_b.register_schema("schema_b", Arc::new(schema_b))?; - ctx.register_catalog("catalog_b", Arc::new(catalog_b)); + .register_table("table_b".to_owned(), test::table_with_sequence(1, 2)?) + .await?; + catalog_b + .register_schema("schema_b", Arc::new(schema_b)) + .await?; + ctx.register_catalog("catalog_b", Arc::new(catalog_b)) + .await?; let result = plan_and_collect( &ctx, @@ -2140,7 +2185,7 @@ mod tests { // register a single catalog let catalog = Arc::new(MemoryCatalogProvider::new()); let catalog_weak = Arc::downgrade(&catalog); - ctx.register_catalog("my_catalog", catalog); + ctx.register_catalog("my_catalog", catalog).await.unwrap(); let catalog_list_weak = { let state = ctx.state.read(); @@ -2207,7 +2252,8 @@ mod tests { let state = SessionStateBuilder::new() .with_default_features() .with_type_planner(Arc::new(MyTypePlanner {})) - .build(); + .build() + .await; let ctx = SessionContext::new_with_state(state); let result = ctx .sql("SELECT DATETIME '2021-01-01 00:00:00'") diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index e99cf8222381..2d5a807bf25a 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -66,6 +66,8 @@ use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::ExecutionPlan; use datafusion_sql::parser::{DFParser, Statement}; use datafusion_sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}; +use futures::future::BoxFuture; +use futures::FutureExt; use itertools::Itertools; use log::{debug, info}; use object_store::ObjectStore; @@ -268,12 +270,14 @@ impl SessionState { .with_runtime_env(runtime) .with_default_features() .build() + .now_or_never() + .expect("state built without custom catalog should use synchronous in-memory catalog") } /// Returns new [`SessionState`] using the provided /// [`SessionConfig`], [`RuntimeEnv`], and [`CatalogProviderList`] #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] - pub fn new_with_config_rt_and_catalog_list( + pub async fn new_with_config_rt_and_catalog_list( config: SessionConfig, runtime: Arc, catalog_list: Arc, @@ -284,6 +288,7 @@ impl SessionState { 
.with_catalog_list(catalog_list) .with_default_features() .build() + .await } pub(crate) fn resolve_table_ref( @@ -297,31 +302,44 @@ impl SessionState { } /// Retrieve the [`SchemaProvider`] for a specific [`TableReference`], if it - /// esists. + /// exists. pub fn schema_for_ref( &self, table_ref: impl Into, - ) -> datafusion_common::Result> { + ) -> BoxFuture<'static, datafusion_common::Result>> { let resolved_ref = self.resolve_table_ref(table_ref); + let catalog_list = Arc::clone(&self.catalog_list); if self.config.information_schema() && *resolved_ref.schema == *INFORMATION_SCHEMA { - return Ok(Arc::new(InformationSchemaProvider::new(Arc::clone( - &self.catalog_list, - )))); + return std::future::ready(Ok(Arc::new(InformationSchemaProvider::new( + catalog_list, + )) as Arc)) + .boxed(); } - self.catalog_list - .catalog(&resolved_ref.catalog) - .ok_or_else(|| { - plan_datafusion_err!( - "failed to resolve catalog: {}", - resolved_ref.catalog - ) - })? - .schema(&resolved_ref.schema) - .ok_or_else(|| { - plan_datafusion_err!("failed to resolve schema: {}", resolved_ref.schema) - }) + // Note: This function is not declared async so that we can force the returned future to be 'static, + // which is useful at a number of call sites to avoid holding the `self.state` + // lock across an await boundary. + async move { + catalog_list + .catalog(&resolved_ref.catalog) + .await? + .ok_or_else(|| { + plan_datafusion_err!( + "failed to resolve catalog: {}", + resolved_ref.catalog + ) + })? + .schema(&resolved_ref.schema) + .await? + .ok_or_else(|| { + plan_datafusion_err!( + "failed to resolve schema: {}", + resolved_ref.schema + ) + }) + } + .boxed() } #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] @@ -548,7 +566,7 @@ impl SessionState { let resolved = self.resolve_table_ref(reference); if let Entry::Vacant(v) = provider.tables.entry(resolved) { let resolved = v.key(); - if let Ok(schema) = self.schema_for_ref(resolved.clone()) { + if let Ok(schema) = self.schema_for_ref(resolved.clone()).await { if let Some(table) = schema.table(&resolved.table).await? { v.insert(provider_as_source(table)); } @@ -1019,10 +1037,12 @@ impl SessionStateBuilder { /// be cloned from what is set in the provided session state. If the default /// catalog exists in existing session state, the new session state will not /// create default catalog and schema. - pub fn new_from_existing(existing: SessionState) -> Self { + pub async fn new_from_existing(existing: SessionState) -> Self { let default_catalog_exist = existing .catalog_list() .catalog(&existing.config.options().catalog.default_catalog) + .await + .unwrap() .is_some(); // The new `with_create_default_catalog_and_schema` should be false if the default catalog exists let create_default_catalog_and_schema = existing @@ -1327,7 +1347,7 @@ impl SessionStateBuilder { /// Note that there is an explicit option for enabling catalog and schema defaults /// in [SessionConfig::create_default_catalog_and_schema] which if enabled /// will be built here.
- pub fn build(self) -> SessionState { + pub async fn build(self) -> SessionState { let Self { session_id, analyzer, @@ -1427,10 +1447,14 @@ impl SessionStateBuilder { &state.runtime_env, ); - state.catalog_list.register_catalog( - state.config.options().catalog.default_catalog.clone(), - Arc::new(default_catalog), - ); + state + .catalog_list + .register_catalog( + state.config.options().catalog.default_catalog.clone(), + Arc::new(default_catalog), + ) + .await + .unwrap(); } if let Some(analyzer_rules) = analyzer_rules { @@ -1621,12 +1645,6 @@ impl Default for SessionStateBuilder { } } -impl From for SessionStateBuilder { - fn from(state: SessionState) -> Self { - SessionStateBuilder::new_from_existing(state) - } -} - /// Adapter that implements the [`ContextProvider`] trait for a [`SessionState`] /// /// This is used so the SQL planner can access the state of the session without @@ -1975,8 +1993,8 @@ mod tests { use std::collections::HashMap; use std::sync::Arc; - #[test] - fn test_session_state_with_default_features() { + #[tokio::test] + async fn test_session_state_with_default_features() { // test array planners with and without builtin planners fn sql_to_expr(state: &SessionState) -> Result { let provider = SessionContextProvider { @@ -1994,18 +2012,21 @@ mod tests { query.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new()) } - let state = SessionStateBuilder::new().with_default_features().build(); + let state = SessionStateBuilder::new() + .with_default_features() + .build() + .await; assert!(sql_to_expr(&state).is_ok()); // if no builtin planners exist, you should register your own, otherwise returns error - let state = SessionStateBuilder::new().build(); + let state = SessionStateBuilder::new().build().await; assert!(sql_to_expr(&state).is_err()) } - #[test] - fn test_from_existing() -> Result<()> { + #[tokio::test] + async fn test_from_existing() -> Result<()> { fn employee_batch() -> RecordBatch { let name: ArrayRef = Arc::new(StringArray::from_iter_values(["Andy", "Andrew"])); @@ -2017,11 +2038,14 @@ mod tests { let session_state = SessionStateBuilder::new() .with_catalog_list(Arc::new(MemoryCatalogProviderList::new())) - .build(); + .build() + .await; let table_ref = session_state.resolve_table_ref("employee").to_string(); session_state - .schema_for_ref(&table_ref)? - .register_table("employee".to_string(), Arc::new(table))?; + .schema_for_ref(&table_ref) + .await? 
+ .register_table("employee".to_string(), Arc::new(table)) + .await?; let default_catalog = session_state .config @@ -2038,38 +2062,63 @@ mod tests { let is_exist = session_state .catalog_list() .catalog(default_catalog.as_str()) + .await .unwrap() - .schema(default_schema.as_str()) - .unwrap() - .table_exist("employee"); - assert!(is_exist); - let new_state = SessionStateBuilder::new_from_existing(session_state).build(); - assert!(new_state - .catalog_list() - .catalog(default_catalog.as_str()) .unwrap() .schema(default_schema.as_str()) + .await + .unwrap() .unwrap() - .table_exist("employee")); + .table_exist("employee") + .await; + assert!(is_exist); + let new_state = SessionStateBuilder::new_from_existing(session_state) + .await + .build() + .await; + assert!( + new_state + .catalog_list() + .catalog(default_catalog.as_str()) + .await + .unwrap() + .unwrap() + .schema(default_schema.as_str()) + .await + .unwrap() + .unwrap() + .table_exist("employee") + .await + ); // if `with_create_default_catalog_and_schema` is disabled, the new one shouldn't create default catalog and schema let disable_create_default = SessionConfig::default().with_create_default_catalog_and_schema(false); let without_default_state = SessionStateBuilder::new() .with_config(disable_create_default) - .build(); + .build() + .await; assert!(without_default_state .catalog_list() .catalog(&default_catalog) + .await + .unwrap() + .is_none()); + let new_state = SessionStateBuilder::new_from_existing(without_default_state) + .await + .build() + .await; + assert!(new_state + .catalog_list() + .catalog(&default_catalog) + .await + .unwrap() .is_none()); - let new_state = - SessionStateBuilder::new_from_existing(without_default_state).build(); - assert!(new_state.catalog_list().catalog(&default_catalog).is_none()); Ok(()) } - #[test] - fn test_session_state_with_optimizer_rules() { + #[tokio::test] + async fn test_session_state_with_optimizer_rules() { #[derive(Default, Debug)] struct DummyRule {} @@ -2081,14 +2130,16 @@ mod tests { // test building sessions with fresh set of rules let state = SessionStateBuilder::new() .with_optimizer_rules(vec![Arc::new(DummyRule {})]) - .build(); + .build() + .await; assert_eq!(state.optimizers().len(), 1); // test adding rules to default recommendations let state = SessionStateBuilder::new() .with_optimizer_rule(Arc::new(DummyRule {})) - .build(); + .build() + .await; assert_eq!( state.optimizers().len(), @@ -2096,18 +2147,19 @@ mod tests { ); } - #[test] - fn test_with_table_factories() -> Result<()> { + #[tokio::test] + async fn test_with_table_factories() -> Result<()> { use crate::test_util::TestTableFactory; - let state = SessionStateBuilder::new().build(); + let state = SessionStateBuilder::new().build().await; let table_factories = state.table_factories(); assert!(table_factories.is_empty()); let table_factory = Arc::new(TestTableFactory {}); let state = SessionStateBuilder::new() .with_table_factory("employee".to_string(), table_factory) - .build(); + .build() + .await; let table_factories = state.table_factories(); assert_eq!(table_factories.len(), 1); assert!(table_factories.contains_key("employee")); diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs index 7ba332c520c1..d0b54418191f 100644 --- a/datafusion/core/src/execution/session_state_defaults.rs +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -35,6 +35,7 @@ use datafusion_execution::object_store::ObjectStoreUrl; use 
datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::planner::ExprPlanner; use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; +use futures::FutureExt; use std::collections::HashMap; use std::sync::Arc; use url::Url; @@ -72,6 +73,8 @@ impl SessionStateDefaults { &config.options().catalog.default_schema, Arc::new(MemorySchemaProvider::new()), ) + .now_or_never() + .expect("memory catalog provider is synchronous") .expect("memory catalog provider can register schema"); Self::register_default_schema(config, table_factories, runtime, &default_catalog); @@ -202,6 +205,8 @@ impl SessionStateDefaults { ); let _ = default_catalog .register_schema("default", Arc::new(schema)) + .now_or_never() + .expect("memory catalog provider is synchronous") .expect("Failed to register default schema"); } diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 88bb0b6fef23..245d87dee7a4 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -71,7 +71,8 @@ async fn register_current_csv( true => { let source = FileStreamProvider::new_file(schema, path.into()); let config = StreamConfig::new(Arc::new(source)); - ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; + ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config)))) + .await?; } false => { ctx.register_csv(table_name, &path, CsvReadOptions::new().schema(&schema)) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 44537c951f94..f753380eb544 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -545,7 +545,7 @@ impl DefaultPhysicalPlanner { .. }) => { let name = table_name.table(); - let schema = session_state.schema_for_ref(table_name.clone())?; + let schema = session_state.schema_for_ref(table_name.clone()).await?; if let Some(provider) = schema.table(name).await? 
{ let input_exec = children.one()?; provider @@ -2033,7 +2033,7 @@ mod tests { use datafusion_functions_aggregate::expr_fn::sum; use datafusion_physical_expr::EquivalenceProperties; - fn make_session_state() -> SessionState { + async fn make_session_state() -> SessionState { let runtime = Arc::new(RuntimeEnv::default()); let config = SessionConfig::new().with_target_partitions(4); let config = config.set_bool("datafusion.optimizer.skip_failed_rules", false); @@ -2042,10 +2042,11 @@ mod tests { .with_runtime_env(runtime) .with_default_features() .build() + .await } async fn plan(logical_plan: &LogicalPlan) -> Result> { - let session_state = make_session_state(); + let session_state = make_session_state().await; // optimize the logical plan let logical_plan = session_state.optimize(logical_plan)?; let planner = DefaultPhysicalPlanner::default(); @@ -2087,7 +2088,7 @@ mod tests { let physical_input_schema = plan.schema(); let physical_input_schema = physical_input_schema.as_ref(); let logical_input_schema = logical_plan.schema(); - let session_state = make_session_state(); + let session_state = make_session_state().await; let cube = create_cube_physical_expr( &exprs, @@ -2114,7 +2115,7 @@ mod tests { let physical_input_schema = plan.schema(); let physical_input_schema = physical_input_schema.as_ref(); let logical_input_schema = logical_plan.schema(); - let session_state = make_session_state(); + let session_state = make_session_state().await; let rollup = create_rollup_physical_expr( &exprs, @@ -2140,7 +2141,7 @@ mod tests { let expr = planner.create_physical_expr( &col("a").not(), &dfschema, - &make_session_state(), + &make_session_state().await, )?; let expected = expressions::not(expressions::col("a", &schema)?)?; @@ -2170,7 +2171,7 @@ mod tests { #[tokio::test] async fn error_during_extension_planning() { - let session_state = make_session_state(); + let session_state = make_session_state().await; let planner = DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new( ErrorExtensionPlanner {}, )]); @@ -2231,7 +2232,7 @@ mod tests { #[tokio::test] async fn default_extension_planner() { - let session_state = make_session_state(); + let session_state = make_session_state().await; let planner = DefaultPhysicalPlanner::default(); let logical_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoOpExtensionNode::default()), @@ -2255,7 +2256,7 @@ mod tests { async fn bad_extension_planner() { // Test that creating an execution plan whose schema doesn't // match the logical plan's schema generates an error. 
- let session_state = make_session_state(); + let session_state = make_session_state().await; let planner = DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new( BadExtensionPlanner {}, )]); diff --git a/datafusion/core/src/test/object_store.rs b/datafusion/core/src/test/object_store.rs index cac430c5b49d..566d52f10e09 100644 --- a/datafusion/core/src/test/object_store.rs +++ b/datafusion/core/src/test/object_store.rs @@ -26,13 +26,15 @@ use std::sync::Arc; use url::Url; /// Returns a test object store with the provided `ctx` -pub fn register_test_store(ctx: &SessionContext, files: &[(&str, u64)]) { +pub async fn register_test_store(ctx: &SessionContext, files: &[(&str, u64)]) { let url = Url::parse("test://").unwrap(); - ctx.register_object_store(&url, make_test_store_and_state(files).0); + ctx.register_object_store(&url, make_test_store_and_state(files).await.0); } /// Create a test object store with the provided files -pub fn make_test_store_and_state(files: &[(&str, u64)]) -> (Arc, SessionState) { +pub async fn make_test_store_and_state( + files: &[(&str, u64)], +) -> (Arc, SessionState) { let memory = InMemory::new(); for (name, size) in files { @@ -45,7 +47,10 @@ pub fn make_test_store_and_state(files: &[(&str, u64)]) -> (Arc, Sessi ( Arc::new(memory), - SessionStateBuilder::new().with_default_features().build(), + SessionStateBuilder::new() + .with_default_features() + .build() + .await, ) } diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index c4c84d667a06..81b1b3fed6a7 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -357,7 +357,7 @@ impl RecordBatchStream for UnboundedStream { } /// This function creates an unbounded sorted file for testing purposes. 
-pub fn register_unbounded_file_with_ordering( +pub async fn register_unbounded_file_with_ordering( ctx: &SessionContext, schema: SchemaRef, file_path: &Path, @@ -368,7 +368,8 @@ pub fn register_unbounded_file_with_ordering( let config = StreamConfig::new(Arc::new(source)).with_order(file_sort_order); // Register table: - ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; + ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config)))) + .await?; Ok(()) } diff --git a/datafusion/core/tests/custom_sources_cases/mod.rs b/datafusion/core/tests/custom_sources_cases/mod.rs index e1bd14105e23..49cd839f11fc 100644 --- a/datafusion/core/tests/custom_sources_cases/mod.rs +++ b/datafusion/core/tests/custom_sources_cases/mod.rs @@ -268,6 +268,7 @@ async fn custom_source_dataframe() -> Result<()> { async fn optimizers_catch_all_statistics() { let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(CustomTableProvider)) + .await .unwrap(); let df = ctx diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index 09f7265d639a..2025d8beba56 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -245,7 +245,7 @@ async fn assert_provider_row_count(value: i64, expected_count: i64) -> Result<() let result_col: &Int64Array = as_primitive_array(results[0].column(0))?; assert_eq!(result_col.value(0), expected_count); - ctx.register_table("data", Arc::new(provider))?; + ctx.register_table("data", Arc::new(provider)).await?; let sql_results = ctx .sql(&format!("select count(*) from data where flag = {value}")) .await? 
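The registration changes above pair with the listing side of the new API: `catalog_names`, `schema_names`, and `table_names` now return `BoxStream<'static, Result<String>>` rather than `Vec<String>`. A hedged sketch of walking the catalog tree under the new signatures, assuming only the default in-memory catalog:

    use datafusion::error::Result;
    use datafusion::prelude::*;
    use futures::TryStreamExt;

    #[tokio::main]
    async fn main() -> Result<()> {
        let ctx = SessionContext::new();
        // catalog_names() now yields a stream of Result<String>,
        // so callers drain it with TryStreamExt::try_next.
        let mut catalogs = ctx.catalog_names().await;
        while let Some(catalog_name) = catalogs.try_next().await? {
            if let Some(catalog) = ctx.catalog(&catalog_name).await? {
                let mut schemas = catalog.schema_names().await;
                while let Some(schema_name) = schemas.try_next().await? {
                    println!("{catalog_name}.{schema_name}");
                }
            }
        }
        Ok(())
    }
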
diff --git a/datafusion/core/tests/custom_sources_cases/statistics.rs b/datafusion/core/tests/custom_sources_cases/statistics.rs index 41d182a3767b..1da782d79203 100644 --- a/datafusion/core/tests/custom_sources_cases/statistics.rs +++ b/datafusion/core/tests/custom_sources_cases/statistics.rs @@ -182,11 +182,11 @@ impl ExecutionPlan for StatisticsValidation { } } -fn init_ctx(stats: Statistics, schema: Schema) -> Result { +async fn init_ctx(stats: Statistics, schema: Schema) -> Result { let ctx = SessionContext::new(); let provider: Arc = Arc::new(StatisticsValidation::new(stats, Arc::new(schema))); - ctx.register_table("stats_table", provider)?; + ctx.register_table("stats_table", provider).await?; Ok(ctx) } @@ -220,7 +220,7 @@ fn fully_defined() -> (Statistics, Schema) { #[tokio::test] async fn sql_basic() -> Result<()> { let (stats, schema) = fully_defined(); - let ctx = init_ctx(stats.clone(), schema)?; + let ctx = init_ctx(stats.clone(), schema).await?; let df = ctx.sql("SELECT * from stats_table").await.unwrap(); let physical_plan = df.create_physical_plan().await.unwrap(); @@ -234,7 +234,7 @@ async fn sql_basic() -> Result<()> { #[tokio::test] async fn sql_filter() -> Result<()> { let (stats, schema) = fully_defined(); - let ctx = init_ctx(stats, schema)?; + let ctx = init_ctx(stats, schema).await?; let df = ctx .sql("SELECT * FROM stats_table WHERE c1 = 5") @@ -252,7 +252,7 @@ async fn sql_filter() -> Result<()> { async fn sql_limit() -> Result<()> { let (stats, schema) = fully_defined(); let col_stats = Statistics::unknown_column(&schema); - let ctx = init_ctx(stats.clone(), schema)?; + let ctx = init_ctx(stats.clone(), schema).await?; let df = ctx.sql("SELECT * FROM stats_table LIMIT 5").await.unwrap(); let physical_plan = df.create_physical_plan().await.unwrap(); @@ -281,7 +281,7 @@ async fn sql_limit() -> Result<()> { #[tokio::test] async fn sql_window() -> Result<()> { let (stats, schema) = fully_defined(); - let ctx = init_ctx(stats.clone(), schema)?; + let ctx = init_ctx(stats.clone(), schema).await?; let df = ctx .sql("SELECT c2, sum(c1) over (partition by c2) FROM stats_table") diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 1bd90fce839d..8f2e715a9dcc 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -69,7 +69,7 @@ async fn create_test_table() -> Result { let ctx = SessionContext::new(); - ctx.register_batch("test", batch)?; + ctx.register_batch("test", batch).await?; ctx.table("test").await } diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 439aa6147e9b..a52f2ddade9d 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -65,7 +65,7 @@ use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, max, sum}; #[tokio::test] async fn test_count_wildcard_on_sort() -> Result<()> { - let ctx = create_join_context()?; + let ctx = create_join_context().await?; let sql_results = ctx .sql("select b,count(*) from t1 group by b order by count(*)") @@ -92,7 +92,7 @@ async fn test_count_wildcard_on_sort() -> Result<()> { #[tokio::test] async fn test_count_wildcard_on_where_in() -> Result<()> { - let ctx = create_join_context()?; + let ctx = create_join_context().await?; let sql_results = ctx .sql("SELECT a,b FROM t1 WHERE a in (SELECT count(*) FROM t2)") .await? 
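Most of the test churn in this file is mechanical: every `register_batch` call gains an `.await` because it now routes through the async `register_table` path. A self-contained sketch of the new call shape (table and column names are arbitrary):

    use std::sync::Arc;
    use datafusion::arrow::array::Int32Array;
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::error::Result;
    use datafusion::prelude::*;

    #[tokio::main]
    async fn main() -> Result<()> {
        let array = Int32Array::from(vec![1, 2, 3]);
        let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(array) as _)])?;
        let ctx = SessionContext::new();
        // Registration awaits the (now async) schema provider.
        ctx.register_batch("t", batch).await?;
        ctx.table("t").await?.show().await?;
        Ok(())
    }
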
@@ -103,7 +103,7 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { // In the same SessionContext, AliasGenerator will increase subquery_alias id by 1 // https://github.com/apache/datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43 // for compare difference between sql and df logical plan, we need to create a new SessionContext here - let ctx = create_join_context()?; + let ctx = create_join_context().await?; let df_results = ctx .table("t1") .await? @@ -133,7 +133,7 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { #[tokio::test] async fn test_count_wildcard_on_where_exist() -> Result<()> { - let ctx = create_join_context()?; + let ctx = create_join_context().await?; let sql_results = ctx .sql("SELECT a, b FROM t1 WHERE EXISTS (SELECT count(*) FROM t2)") .await? @@ -169,7 +169,7 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> { #[tokio::test] async fn test_count_wildcard_on_window() -> Result<()> { - let ctx = create_join_context()?; + let ctx = create_join_context().await?; let sql_results = ctx .sql("select count(*) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1") @@ -207,7 +207,7 @@ async fn test_count_wildcard_on_window() -> Result<()> { #[tokio::test] async fn test_count_wildcard_on_aggregate() -> Result<()> { - let ctx = create_join_context()?; + let ctx = create_join_context().await?; register_alltypes_tiny_pages_parquet(&ctx).await?; let sql_results = ctx @@ -239,7 +239,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { #[tokio::test] async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { - let ctx = create_join_context()?; + let ctx = create_join_context().await?; let sql_results = ctx .sql("select a,b from t1 where (select count(*) from t2 where t1.a = t2.a)>0;") @@ -251,7 +251,7 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { // In the same SessionContext, AliasGenerator will increase subquery_alias id by 1 // https://github.com/apache/datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43 // for compare difference between sql and df logical plan, we need to create a new SessionContext here - let ctx = create_join_context()?; + let ctx = create_join_context().await?; let df_results = ctx .table("t1") .await? @@ -310,11 +310,11 @@ async fn join() -> Result<()> { let ctx = SessionContext::new(); - ctx.register_batch("aa", batch1)?; + ctx.register_batch("aa", batch1).await?; let df1 = ctx.table("aa").await?; - ctx.register_batch("aaa", batch2)?; + ctx.register_batch("aaa", batch2).await?; let df2 = ctx.table("aaa").await?; @@ -344,7 +344,7 @@ async fn sort_on_unprojected_columns() -> Result<()> { .unwrap(); let ctx = SessionContext::new(); - ctx.register_batch("t", batch).unwrap(); + ctx.register_batch("t", batch).await.unwrap(); let df = ctx .table("t") @@ -387,7 +387,7 @@ async fn sort_on_distinct_columns() -> Result<()> { .unwrap(); let ctx = SessionContext::new(); - ctx.register_batch("t", batch).unwrap(); + ctx.register_batch("t", batch).await.unwrap(); let df = ctx .table("t") .await @@ -429,7 +429,7 @@ async fn sort_on_distinct_unprojected_columns() -> Result<()> { // Cannot sort on a column after distinct that would add a new column let ctx = SessionContext::new(); - ctx.register_batch("t", batch)?; + ctx.register_batch("t", batch).await?; let err = ctx .table("t") .await? 
@@ -528,7 +528,7 @@ async fn filter_with_alias_overwrite() -> Result<()> { .unwrap(); let ctx = SessionContext::new(); - ctx.register_batch("t", batch).unwrap(); + ctx.register_batch("t", batch).await.unwrap(); let df = ctx .table("t") @@ -562,7 +562,7 @@ async fn select_with_alias_overwrite() -> Result<()> { )?; let ctx = SessionContext::new(); - ctx.register_batch("t", batch)?; + ctx.register_batch("t", batch).await?; let df = ctx .table("t") @@ -741,7 +741,7 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { #[tokio::test] async fn join_with_alias_filter() -> Result<()> { - let join_ctx = create_join_context()?; + let join_ctx = create_join_context().await?; let t1 = join_ctx.table("t1").await?; let t2 = join_ctx.table("t2").await?; let t1_schema = t1.schema().clone(); @@ -797,7 +797,7 @@ async fn join_with_alias_filter() -> Result<()> { #[tokio::test] async fn right_semi_with_alias_filter() -> Result<()> { - let join_ctx = create_join_context()?; + let join_ctx = create_join_context().await?; let t1 = join_ctx.table("t1").await?; let t2 = join_ctx.table("t2").await?; @@ -842,7 +842,7 @@ async fn right_semi_with_alias_filter() -> Result<()> { #[tokio::test] async fn right_anti_filter_push_down() -> Result<()> { - let join_ctx = create_join_context()?; + let join_ctx = create_join_context().await?; let t1 = join_ctx.table("t1").await?; let t2 = join_ctx.table("t2").await?; @@ -1046,7 +1046,7 @@ async fn unnest_fixed_list() -> Result<()> { let batch = get_fixed_list_batch()?; let ctx = SessionContext::new(); - ctx.register_batch("shapes", batch)?; + ctx.register_batch("shapes", batch).await?; let df = ctx.table("shapes").await?; let results = df.clone().collect().await?; @@ -1096,7 +1096,7 @@ async fn unnest_fixed_list_drop_nulls() -> Result<()> { let batch = get_fixed_list_batch()?; let ctx = SessionContext::new(); - ctx.register_batch("shapes", batch)?; + ctx.register_batch("shapes", batch).await?; let df = ctx.table("shapes").await?; let results = df.clone().collect().await?; @@ -1163,7 +1163,7 @@ async fn unnest_fixed_list_nonull() -> Result<()> { ])?; let ctx = SessionContext::new(); - ctx.register_batch("shapes", batch)?; + ctx.register_batch("shapes", batch).await?; let df = ctx.table("shapes").await?; let results = df.clone().collect().await?; @@ -1264,7 +1264,7 @@ async fn unnest_array_agg() -> Result<()> { ])?; let ctx = SessionContext::new(); - ctx.register_batch("shapes", batch)?; + ctx.register_batch("shapes", batch).await?; let df = ctx.table("shapes").await?; let results = df.clone().collect().await?; @@ -1354,7 +1354,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { ])?; let ctx = SessionContext::new(); - ctx.register_batch("shapes", batch)?; + ctx.register_batch("shapes", batch).await?; let df = ctx.table("shapes").await?; let results = df.clone().collect().await?; @@ -1547,7 +1547,8 @@ async fn test_read_batches() -> Result<()> { .with_config(config) .with_runtime_env(runtime) .with_default_features() - .build(); + .build() + .await; let ctx = SessionContext::new_with_state(state); let schema = Arc::new(Schema::new(vec![ @@ -1601,7 +1602,8 @@ async fn test_read_batches_empty() -> Result<()> { .with_config(config) .with_runtime_env(runtime) .with_default_features() - .build(); + .build() + .await; let ctx = SessionContext::new_with_state(state); let batches = vec![]; @@ -1615,7 +1617,10 @@ async fn test_read_batches_empty() -> Result<()> { #[tokio::test] async fn consecutive_projection_same_schema() -> Result<()> { - let state = 
SessionStateBuilder::new().with_default_features().build(); + let state = SessionStateBuilder::new() + .with_default_features() + .build() + .await; let ctx = SessionContext::new_with_state(state); let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); @@ -1687,7 +1692,7 @@ async fn create_test_table(name: &str) -> Result { let ctx = SessionContext::new(); - ctx.register_batch(name, batch)?; + ctx.register_batch(name, batch).await?; ctx.table(name).await } @@ -1702,7 +1707,7 @@ async fn aggregates_table(ctx: &SessionContext) -> Result { .await } -fn create_join_context() -> Result { +async fn create_join_context() -> Result { let t1 = Arc::new(Schema::new(vec![ Field::new("a", DataType::UInt32, false), Field::new("b", DataType::Utf8, false), @@ -1735,8 +1740,8 @@ fn create_join_context() -> Result { let ctx = SessionContext::new(); - ctx.register_batch("t1", batch1)?; - ctx.register_batch("t2", batch2)?; + ctx.register_batch("t1", batch1).await?; + ctx.register_batch("t2", batch2).await?; Ok(ctx) } @@ -1804,7 +1809,7 @@ async fn table_with_nested_types(n: usize) -> Result { ])?; let ctx = SessionContext::new(); - ctx.register_batch("shapes", batch)?; + ctx.register_batch("shapes", batch).await?; ctx.table("shapes").await } @@ -1876,7 +1881,7 @@ async fn table_with_mixed_lists() -> Result { ])?; let ctx = SessionContext::new(); - ctx.register_batch("mixed_lists", batch)?; + ctx.register_batch("mixed_lists", batch).await?; ctx.table("mixed_lists").await } @@ -1910,7 +1915,7 @@ async fn table_with_lists_and_nulls() -> Result { ])?; let ctx = SessionContext::new(); - ctx.register_batch("shapes", batch)?; + ctx.register_batch("shapes", batch).await?; ctx.table("shapes").await } @@ -1951,7 +1956,7 @@ async fn use_var_provider() -> Result<()> { .set_bool("datafusion.optimizer.skip_failed_rules", false); let ctx = SessionContext::new_with_config(config); - ctx.register_table("csv_table", mem_table)?; + ctx.register_table("csv_table", mem_table).await?; ctx.register_variable(VarType::UserDefined, Arc::new(HardcodedIntProvider {})); let dataframe = ctx @@ -2074,7 +2079,7 @@ async fn write_partitioned_parquet_results() -> Result<()> { let mem_table = Arc::new(MemTable::try_new(schema, vec![vec![record_batch]])?); // Register the table in the context - ctx.register_table("test", mem_table)?; + ctx.register_table("test", mem_table).await?; let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); let local_url = Url::parse("file://local").unwrap(); @@ -2239,7 +2244,7 @@ async fn sparse_union_is_null() { let ctx = SessionContext::new(); - ctx.register_batch("union_batch", batch).unwrap(); + ctx.register_batch("union_batch", batch).await.unwrap(); let df = ctx.table("union_batch").await.unwrap(); @@ -2316,7 +2321,7 @@ async fn dense_union_is_null() { let ctx = SessionContext::new(); - ctx.register_batch("union_batch", batch).unwrap(); + ctx.register_batch("union_batch", batch).await.unwrap(); let df = ctx.table("union_batch").await.unwrap(); @@ -2383,7 +2388,7 @@ async fn boolean_dictionary_as_filter() { let ctx = SessionContext::new(); - ctx.register_batch("dict_batch", batch).unwrap(); + ctx.register_batch("dict_batch", batch).await.unwrap(); let df = ctx.table("dict_batch").await.unwrap(); @@ -2436,7 +2441,9 @@ async fn boolean_dictionary_as_filter() { let batch = RecordBatch::try_new(schema, vec![Arc::new(nested_array)]).unwrap(); - ctx.register_batch("nested_dict_batch", batch).unwrap(); + ctx.register_batch("nested_dict_batch", batch) + .await + .unwrap(); let 
df = ctx.table("nested_dict_batch").await.unwrap(); diff --git a/datafusion/core/tests/execution/logical_plan.rs b/datafusion/core/tests/execution/logical_plan.rs index 168bf484e541..c3d2644170b9 100644 --- a/datafusion/core/tests/execution/logical_plan.rs +++ b/datafusion/core/tests/execution/logical_plan.rs @@ -68,7 +68,7 @@ async fn count_only_nulls() -> Result<()> { )?); // Execute and verify results - let session_state = SessionStateBuilder::new().build(); + let session_state = SessionStateBuilder::new().build().await; let physical_plan = session_state.create_physical_plan(&aggregate).await?; let result = collect(physical_plan, Arc::new(TaskContext::from(&session_state))).await?; diff --git a/datafusion/core/tests/fifo/mod.rs b/datafusion/core/tests/fifo/mod.rs index cb587e3510c2..6b950e391058 100644 --- a/datafusion/core/tests/fifo/mod.rs +++ b/datafusion/core/tests/fifo/mod.rs @@ -189,7 +189,7 @@ mod unix_test { ])); let provider = fifo_table(schema, fifo_path, vec![]); - ctx.register_table("left", provider).unwrap(); + ctx.register_table("left", provider).await.unwrap(); // Register right table let schema = aggr_test_schema(); @@ -251,10 +251,10 @@ mod unix_test { // Set unbounded sorted files read configuration let provider = fifo_table(schema.clone(), left_fifo.clone(), order.clone()); - ctx.register_table("left", provider)?; + ctx.register_table("left", provider).await?; let provider = fifo_table(schema.clone(), right_fifo.clone(), order); - ctx.register_table("right", provider)?; + ctx.register_table("right", provider).await?; // Execute the query, with no matching rows. (since key is modulus 10) let df = ctx diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 09d0c8d5ca2e..dabd06e2f19d 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -508,7 +508,7 @@ async fn group_by_string_test( provider }; - ctx.register_table("t", Arc::new(provider)).unwrap(); + ctx.register_table("t", Arc::new(provider)).await.unwrap(); let df = ctx .sql("SELECT a, COUNT(*) FROM t GROUP BY a") diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs index af454bee7ce8..a2ac9908968b 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs @@ -87,7 +87,7 @@ impl SessionContextGenerator { impl SessionContextGenerator { /// Generate the `SessionContext` for the baseline run - pub fn generate_baseline(&self) -> Result { + pub async fn generate_baseline(&self) -> Result { let schema = self.dataset.batches[0].schema(); let batches = self.dataset.batches.clone(); let provider = MemTable::try_new(schema, vec![batches])?; @@ -107,11 +107,11 @@ impl SessionContextGenerator { table_provider: Arc::new(provider), }; - builder.build() + builder.build().await } /// Randomly generate session context - pub fn generate(&self) -> Result { + pub async fn generate(&self) -> Result { let mut rng = thread_rng(); let schema = self.dataset.batches[0].schema(); let batches = self.dataset.batches.clone(); @@ -155,7 +155,7 @@ impl SessionContextGenerator { table_provider: Arc::new(provider), }; - builder.build() + builder.build().await } } @@ -179,7 +179,7 @@ struct GeneratedSessionContextBuilder { } impl GeneratedSessionContextBuilder { - fn build(self) -> Result { + async fn 
build(self) -> Result { // Build session context let mut session_config = SessionConfig::default(); session_config = session_config.set( @@ -200,7 +200,8 @@ impl GeneratedSessionContextBuilder { ); let ctx = SessionContext::new_with_config(session_config); - ctx.register_table(self.table_name, self.table_provider)?; + ctx.register_table(self.table_name, self.table_provider) + .await?; let params = SessionContextParams { batch_size: self.batch_size, @@ -312,10 +313,10 @@ mod test { let ctx_generator = SessionContextGenerator::new(Arc::new(dataset), "fuzz_table"); let query = "select b, count(a) from fuzz_table group by b"; - let baseline_wrapped_ctx = ctx_generator.generate_baseline().unwrap(); + let baseline_wrapped_ctx = ctx_generator.generate_baseline().await.unwrap(); let mut random_wrapped_ctxs = Vec::with_capacity(8); for _ in 0..8 { - let ctx = ctx_generator.generate().unwrap(); + let ctx = ctx_generator.generate().await.unwrap(); random_wrapped_ctxs.push(ctx); } diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs index d021e73f35b2..0b2ea6f130a0 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs @@ -233,6 +233,7 @@ impl AggregationFuzzer { // Generate the baseline context, and get the baseline result firstly let baseline_ctx_with_params = ctx_generator .generate_baseline() + .await .expect("should success to generate baseline session context"); let baseline_result = run_sql(&sql, &baseline_ctx_with_params.ctx) .await @@ -242,6 +243,7 @@ impl AggregationFuzzer { for _ in 0..CTX_GEN_ROUNDS { let ctx_with_params = ctx_generator .generate() + .await .expect("should success to generate session context"); let task = AggregationFuzzTestTask { dataset_ref: dataset_ref.clone(), diff --git a/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs b/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs index 64b858cebc84..e53311f2b3a0 100644 --- a/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs @@ -59,7 +59,7 @@ async fn run_distinct_count_test(mut generator: StringBatchGenerator) { ]; let provider = MemTable::try_new(schema, partitions).unwrap(); - ctx.register_table("t", Arc::new(provider)).unwrap(); + ctx.register_table("t", Arc::new(provider)).await.unwrap(); // input has two columns, a and b. The result is the number of distinct // values in each column. 
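The fuzz-test changes above all follow one pattern: building a MemTable stays synchronous, but registering it with the SessionContext now takes an .await. A minimal sketch of that migrated registration path, assuming only the async register_table signature shown in this diff (the helper name and sample data are illustrative):

use std::sync::Arc;

use datafusion::arrow::array::Int32Array;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

// Illustrative helper: registers a one-column in-memory table via the async API.
async fn register_numbers(ctx: &SessionContext) -> Result<()> {
    let batch = RecordBatch::try_from_iter(vec![(
        "a",
        Arc::new(Int32Array::from(vec![1, 2, 3])) as _,
    )])?;
    // Capture the schema before the batch is moved into the partition list.
    let schema = batch.schema();
    let provider = MemTable::try_new(schema, vec![vec![batch]])?;
    // register_table is now async and must be awaited.
    ctx.register_table("t", Arc::new(provider)).await?;
    Ok(())
}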
// diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index c431cd6303db..99873a8159ab 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -534,8 +534,10 @@ impl TestCase { None => builder, }; - let ctx = SessionContext::new_with_state(builder.build()); - ctx.register_table("t", table).expect("registering table"); + let ctx = SessionContext::new_with_state(builder.build().await); + ctx.register_table("t", table) + .await + .expect("registering table"); let query = query.expect("Test error: query not specified"); let df = ctx.sql(&query).await.expect("Planning query"); diff --git a/datafusion/core/tests/parquet/file_statistics.rs b/datafusion/core/tests/parquet/file_statistics.rs index 4b5d22bfa71f..52827ebdfd34 100644 --- a/datafusion/core/tests/parquet/file_statistics.rs +++ b/datafusion/core/tests/parquet/file_statistics.rs @@ -188,7 +188,10 @@ async fn get_listing_table( ) -> ListingTable { let schema = opt .infer_schema( - &SessionStateBuilder::new().with_default_features().build(), + &SessionStateBuilder::new() + .with_default_features() + .build() + .await, table_path, ) .await diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 3f68222a2ce3..42e5f3cee430 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -210,8 +210,8 @@ impl ContextWithParquet { ctx.register_parquet("t", &parquet_path, ParquetReadOptions::default()) .await .unwrap(); - let provider = ctx.deregister_table("t").unwrap().unwrap(); - ctx.register_table("t", provider.clone()).unwrap(); + let provider = ctx.deregister_table("t").await.unwrap().unwrap(); + ctx.register_table("t", provider.clone()).await.unwrap(); Self { _file: file, diff --git a/datafusion/core/tests/sql/create_drop.rs b/datafusion/core/tests/sql/create_drop.rs index 83712053b954..6bd4282196b5 100644 --- a/datafusion/core/tests/sql/create_drop.rs +++ b/datafusion/core/tests/sql/create_drop.rs @@ -22,7 +22,10 @@ use super::*; #[tokio::test] async fn create_custom_table() -> Result<()> { - let mut state = SessionStateBuilder::new().with_default_features().build(); + let mut state = SessionStateBuilder::new() + .with_default_features() + .build() + .await; state .table_factories_mut() .insert("DELTATABLE".to_string(), Arc::new(TestTableFactory {})); @@ -31,9 +34,9 @@ async fn create_custom_table() -> Result<()> { let sql = "CREATE EXTERNAL TABLE dt STORED AS DELTATABLE LOCATION 's3://bucket/schema/table';"; ctx.sql(sql).await.unwrap(); - let cat = ctx.catalog("datafusion").unwrap(); - let schema = cat.schema("public").unwrap(); - let exists = schema.table_exist("dt"); + let cat = ctx.catalog("datafusion").await.unwrap().unwrap(); + let schema = cat.schema("public").await.unwrap().unwrap(); + let exists = schema.table_exist("dt").await; assert!(exists, "Table should have been created!"); Ok(()) @@ -41,7 +44,10 @@ async fn create_custom_table() -> Result<()> { #[tokio::test] async fn create_external_table_with_ddl() -> Result<()> { - let mut state = SessionStateBuilder::new().with_default_features().build(); + let mut state = SessionStateBuilder::new() + .with_default_features() + .build() + .await; state .table_factories_mut() .insert("MOCKTABLE".to_string(), Arc::new(TestTableFactory {})); @@ -50,10 +56,10 @@ async fn create_external_table_with_ddl() -> Result<()> { let sql = "CREATE EXTERNAL TABLE dt (a_id integer, a_str string, a_bool boolean) STORED AS MOCKTABLE 
LOCATION 'mockprotocol://path/to/table';"; ctx.sql(sql).await.unwrap(); - let cat = ctx.catalog("datafusion").unwrap(); - let schema = cat.schema("public").unwrap(); + let cat = ctx.catalog("datafusion").await.unwrap().unwrap(); + let schema = cat.schema("public").await.unwrap().unwrap(); - let exists = schema.table_exist("dt"); + let exists = schema.table_exist("dt").await; assert!(exists, "Table should have been created!"); let table_schema = schema.table("dt").await.unwrap().unwrap().schema(); diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index fab92c0f9c2b..30cfcb33a47c 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -47,7 +47,8 @@ async fn join_change_in_planner() -> Result<()> { &left_file_path, "left", file_sort_order.clone(), - )?; + ) + .await?; let right_file_path = tmp_dir.path().join("right.csv"); File::create(right_file_path.clone()).unwrap(); register_unbounded_file_with_ordering( @@ -56,7 +57,8 @@ async fn join_change_in_planner() -> Result<()> { &right_file_path, "right", file_sort_order, - )?; + ) + .await?; let sql = "SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10"; let dataframe = ctx.sql(sql).await?; let physical_plan = dataframe.create_physical_plan().await?; @@ -115,7 +117,8 @@ async fn join_no_order_on_filter() -> Result<()> { &left_file_path, "left", file_sort_order.clone(), - )?; + ) + .await?; let right_file_path = tmp_dir.path().join("right.csv"); File::create(right_file_path.clone()).unwrap(); register_unbounded_file_with_ordering( @@ -124,7 +127,8 @@ async fn join_no_order_on_filter() -> Result<()> { &right_file_path, "right", file_sort_order, - )?; + ) + .await?; let sql = "SELECT * FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a3 > t2.a3 + 3 AND t1.a3 < t2.a3 + 10"; let dataframe = ctx.sql(sql).await?; let physical_plan = dataframe.create_physical_plan().await?; @@ -168,13 +172,15 @@ async fn join_change_in_planner_without_sort() -> Result<()> { ])); let left_source = FileStreamProvider::new_file(schema.clone(), left_file_path); let left = StreamConfig::new(Arc::new(left_source)); - ctx.register_table("left", Arc::new(StreamTable::new(Arc::new(left))))?; + ctx.register_table("left", Arc::new(StreamTable::new(Arc::new(left)))) + .await?; let right_file_path = tmp_dir.path().join("right.csv"); File::create(right_file_path.clone())?; let right_source = FileStreamProvider::new_file(schema, right_file_path); let right = StreamConfig::new(Arc::new(right_source)); - ctx.register_table("right", Arc::new(StreamTable::new(Arc::new(right))))?; + ctx.register_table("right", Arc::new(StreamTable::new(Arc::new(right)))) + .await?; let sql = "SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10"; let dataframe = ctx.sql(sql).await?; let physical_plan = dataframe.create_physical_plan().await?; @@ -220,12 +226,14 @@ async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> { ])); let left_source = FileStreamProvider::new_file(schema.clone(), left_file_path); let left = StreamConfig::new(Arc::new(left_source)); - ctx.register_table("left", Arc::new(StreamTable::new(Arc::new(left))))?; + ctx.register_table("left", Arc::new(StreamTable::new(Arc::new(left)))) + .await?; let right_file_path = tmp_dir.path().join("right.csv"); File::create(right_file_path.clone())?; let right_source = 
FileStreamProvider::new_file(schema.clone(), right_file_path); let right = StreamConfig::new(Arc::new(right_source)); - ctx.register_table("right", Arc::new(StreamTable::new(Arc::new(right))))?; + ctx.register_table("right", Arc::new(StreamTable::new(Arc::new(right)))) + .await?; let df = ctx.sql("SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10").await?; match df.create_physical_plan().await { Ok(_) => panic!("Expecting error."), diff --git a/datafusion/core/tests/sql/path_partition.rs b/datafusion/core/tests/sql/path_partition.rs index 975984e5b11f..b112233a31f4 100644 --- a/datafusion/core/tests/sql/path_partition.rs +++ b/datafusion/core/tests/sql/path_partition.rs @@ -264,7 +264,8 @@ async fn csv_filter_with_file_col() -> Result<()> { ], &[("date", DataType::Utf8)], "mirror:///mytable/", - ); + ) + .await; let result = ctx .sql("SELECT c1, c2 FROM t WHERE date='2021-10-27' and c1!='2021-10-27' LIMIT 5") @@ -302,7 +303,8 @@ async fn csv_filter_with_file_nonstring_col() -> Result<()> { ], &[("date", DataType::Date32)], "mirror:///mytable/", - ); + ) + .await; let result = ctx .sql("SELECT c1, c2, date FROM t WHERE date > '2021-10-27' LIMIT 5") @@ -340,7 +342,8 @@ async fn csv_projection_on_partition() -> Result<()> { ], &[("date", DataType::Date32)], "mirror:///mytable/", - ); + ) + .await; let result = ctx .sql("SELECT c1, date FROM t WHERE date='2021-10-27' LIMIT 5") @@ -379,7 +382,8 @@ async fn csv_grouping_by_partition() -> Result<()> { ], &[("date", DataType::Date32)], "mirror:///mytable/", - ); + ) + .await; let result = ctx .sql("SELECT date, count(*), count(distinct(c1)) FROM t WHERE date<='2021-10-27' GROUP BY date") @@ -573,7 +577,7 @@ async fn parquet_overlapping_columns() -> Result<()> { Ok(()) } -fn register_partitioned_aggregate_csv( +async fn register_partitioned_aggregate_csv( ctx: &SessionContext, store_paths: &[&str], partition_cols: &[(&str, DataType)], @@ -603,6 +607,7 @@ fn register_partitioned_aggregate_csv( let table = ListingTable::try_new(config).unwrap(); ctx.register_table("t", Arc::new(table)) + .await .expect("registering listing table failed"); } @@ -622,6 +627,7 @@ async fn register_partitioned_alltypes_parquet( ) .await; ctx.register_table("t", table) + .await .expect("registering listing table failed"); } diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 2e815303e3ce..c1d1f336dcdd 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -154,7 +154,7 @@ async fn prepared_statement_type_coercion() -> Result<()> { ("signed", Arc::new(signed_ints) as ArrayRef), ("unsigned", Arc::new(unsigned_ints) as ArrayRef), ])?; - ctx.register_batch("test", batch)?; + ctx.register_batch("test", batch).await?; let results = ctx.sql("SELECT signed, unsigned FROM test WHERE $1 >= signed AND signed <= $2 AND unsigned = $3") .await? .with_param_values(vec![ @@ -184,7 +184,7 @@ async fn test_parameter_type_coercion() -> Result<()> { ("signed", Arc::new(signed_ints) as ArrayRef), ("unsigned", Arc::new(unsigned_ints) as ArrayRef), ])?; - ctx.register_batch("test", batch)?; + ctx.register_batch("test", batch).await?; let results = ctx.sql("SELECT signed, unsigned FROM test WHERE $foo >= signed AND signed <= $bar AND unsigned <= $baz AND unsigned = $str") .await? 
.with_param_values(vec![ @@ -215,7 +215,7 @@ async fn test_parameter_invalid_types() -> Result<()> { ])]); let batch = RecordBatch::try_from_iter(vec![("list", Arc::new(list_array) as ArrayRef)])?; - ctx.register_batch("test", batch)?; + ctx.register_batch("test", batch).await?; let results = ctx .sql("SELECT list FROM test WHERE list = $1") .await? diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index 252d76d0f9d9..1ec148468309 100644 --- a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -1040,7 +1040,8 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> { Arc::new(table.schema.clone()), vec![vec![]], )?), - )?; + ) + .await?; } // some queries have multiple statements diff --git a/datafusion/core/tests/user_defined/insert_operation.rs b/datafusion/core/tests/user_defined/insert_operation.rs index ff14fa0be3fb..4a38af6b8f08 100644 --- a/datafusion/core/tests/user_defined/insert_operation.rs +++ b/datafusion/core/tests/user_defined/insert_operation.rs @@ -34,6 +34,7 @@ async fn insert_operation_is_passed_correctly_to_table_provider() { let ctx = session_ctx_with_dialect("SQLite"); let table_provider = Arc::new(TestInsertTableProvider::new()); ctx.register_table("testing", table_provider.clone()) + .await .unwrap(); let sql = "INSERT INTO testing (column) VALUES (1)"; diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 99c00615376f..331342bca615 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -59,7 +59,7 @@ use datafusion_functions_aggregate::average::AvgAccumulator; /// Test to show the contents of the setup #[tokio::test] async fn test_setup() { - let TestContext { ctx, test_state: _ } = TestContext::new(); + let TestContext { ctx, test_state: _ } = TestContext::new().await; let sql = "SELECT * from t order by time"; let expected = [ "+-------+----------------------------+", @@ -78,7 +78,7 @@ async fn test_setup() { /// Basic user defined aggregate #[tokio::test] async fn test_udaf() { - let TestContext { ctx, test_state } = TestContext::new(); + let TestContext { ctx, test_state } = TestContext::new().await; assert!(!test_state.update_batch()); let sql = "SELECT time_sum(time) from t"; let expected = [ @@ -97,7 +97,7 @@ async fn test_udaf() { /// User defined aggregate used as a window function #[tokio::test] async fn test_udaf_as_window() { - let TestContext { ctx, test_state } = TestContext::new(); + let TestContext { ctx, test_state } = TestContext::new().await; let sql = "SELECT time_sum(time) OVER() as time_sum from t"; let expected = [ "+----------------------------+", @@ -119,7 +119,7 @@ async fn test_udaf_as_window() { /// User defined aggregate used as a window function with a window frame #[tokio::test] async fn test_udaf_as_window_with_frame() { - let TestContext { ctx, test_state } = TestContext::new(); + let TestContext { ctx, test_state } = TestContext::new().await; let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t"; let expected = [ "+----------------------------+", @@ -144,7 +144,8 @@ async fn test_udaf_as_window_with_frame() { async fn test_udaf_as_window_with_frame_without_retract_batch() { let test_state = Arc::new(TestState::new().with_error_on_retract_batch()); - let TestContext { ctx, test_state: _ } = 
TestContext::new_with_test_state(test_state); + let TestContext { ctx, test_state: _ } = + TestContext::new_with_test_state(test_state).await; let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t"; // Note if this query ever does start working let err = execute(&ctx, sql).await.unwrap_err(); @@ -154,7 +155,7 @@ async fn test_udaf_as_window_with_frame_without_retract_batch() { /// Basic query for with a udaf returning a structure #[tokio::test] async fn test_udaf_returning_struct() { - let TestContext { ctx, test_state: _ } = TestContext::new(); + let TestContext { ctx, test_state: _ } = TestContext::new().await; let sql = "SELECT first(value, time) from t"; let expected = [ "+------------------------------------------------+", @@ -169,7 +170,7 @@ async fn test_udaf_returning_struct() { /// Demonstrate extracting the fields from a structure using a subquery #[tokio::test] async fn test_udaf_returning_struct_subquery() { - let TestContext { ctx, test_state: _ } = TestContext::new(); + let TestContext { ctx, test_state: _ } = TestContext::new().await; let sql = "select sq.first['value'], sq.first['time'] from (SELECT first(value, time) as first from t) as sq"; let expected = [ "+-----------------+----------------------------+", @@ -186,7 +187,7 @@ async fn test_udaf_shadows_builtin_fn() { let TestContext { mut ctx, test_state, - } = TestContext::new(); + } = TestContext::new().await; let sql = "SELECT sum(arrow_cast(time, 'Int64')) from t"; // compute with builtin `sum` aggregator @@ -233,7 +234,7 @@ async fn simple_udaf() -> Result<()> { let ctx = SessionContext::new(); let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider))?; + ctx.register_table("t", Arc::new(provider)).await?; // define a udaf, using a DataFusion's accumulator let my_avg = create_udaf( @@ -289,7 +290,7 @@ async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { let ctx = SessionContext::new(); let arr = Int32Array::from(vec![1]); let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; - ctx.register_batch("t", batch).unwrap(); + ctx.register_batch("t", batch).await.unwrap(); // Note capitalization let my_avg = create_udaf( @@ -333,7 +334,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { let ctx = SessionContext::new(); let arr = Int32Array::from(vec![1]); let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; - ctx.register_batch("t", batch).unwrap(); + ctx.register_batch("t", batch).await.unwrap(); let my_avg = create_udaf( "dummy", @@ -369,7 +370,7 @@ async fn test_groups_accumulator() -> Result<()> { let ctx = SessionContext::new(); let arr = Int32Array::from(vec![1]); let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(arr) as _)])?; - ctx.register_batch("t", batch).unwrap(); + ctx.register_batch("t", batch).await.unwrap(); let udaf = AggregateUDF::from(TestGroupsAccumulator { signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), @@ -391,7 +392,7 @@ async fn test_parameterized_aggregate_udf() -> Result<()> { )])?; let ctx = SessionContext::new(); - ctx.register_batch("t", batch)?; + ctx.register_batch("t", batch).await?; let t = ctx.table("t").await?; let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable); let udf1 = AggregateUDF::from(TestGroupsAccumulator { @@ -428,7 +429,7 @@ async fn test_parameterized_aggregate_udf() -> Result<()> { ]; 
assert_batches_eq!(expected, &actual); - ctx.deregister_table("t")?; + ctx.deregister_table("t").await?; Ok(()) } @@ -451,12 +452,12 @@ struct TestContext { } impl TestContext { - fn new() -> Self { + async fn new() -> Self { let test_state = Arc::new(TestState::new()); - Self::new_with_test_state(test_state) + Self::new_with_test_state(test_state).await } - fn new_with_test_state(test_state: Arc) -> Self { + async fn new_with_test_state(test_state: Arc) -> Self { let value = Float64Array::from(vec![3.0, 2.0, 1.0, 5.0, 5.0]); let time = TimestampNanosecondArray::from(vec![3000, 2000, 4000, 5000, 5000]); @@ -468,7 +469,7 @@ impl TestContext { let mut ctx = SessionContext::new(); - ctx.register_batch("t", batch).unwrap(); + ctx.register_batch("t", batch).await.unwrap(); // Tell DataFusion about the "first" function FirstSelector::register(&mut ctx); diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 520a91aeb4d6..d00c7278f846 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -262,14 +262,14 @@ async fn normal_query_with_analyzer() -> Result<()> { // Run the query using topk optimization async fn topk_query() -> Result<()> { // Note the only difference is that the top - let ctx = setup_table(make_topk_context()).await?; + let ctx = setup_table(make_topk_context().await).await?; run_and_compare_query(ctx, "Topk context").await } #[tokio::test] // Run EXPLAIN PLAN and show the plan was in fact rewritten async fn topk_plan() -> Result<()> { - let ctx = setup_table(make_topk_context()).await?; + let ctx = setup_table(make_topk_context().await).await?; let mut expected = ["| logical_plan after topk | TopK: k=3 |", "| | TableScan: sales projection=[customer_id,revenue] |"].join("\n"); @@ -295,7 +295,7 @@ async fn topk_plan() -> Result<()> { Ok(()) } -fn make_topk_context() -> SessionContext { +async fn make_topk_context() -> SessionContext { let config = SessionConfig::new().with_target_partitions(48); let runtime = Arc::new(RuntimeEnv::default()); let state = SessionStateBuilder::new() @@ -305,7 +305,8 @@ fn make_topk_context() -> SessionContext { .with_query_planner(Arc::new(TopKQueryPlanner {})) .with_optimizer_rule(Arc::new(TopKOptimizerRule {})) .with_analyzer_rule(Arc::new(MyAnalyzerRule {})) - .build(); + .build() + .await; SessionContext::new_with_state(state) } diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index a59394f90814..fc81ed417020 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -104,7 +104,7 @@ async fn scalar_udf() -> Result<()> { let ctx = SessionContext::new(); - ctx.register_batch("t", batch)?; + ctx.register_batch("t", batch).await?; let myfunc = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(l) = &args[0] else { @@ -171,7 +171,7 @@ async fn scalar_udf() -> Result<()> { assert_eq!(a.value(i) + b.value(i), sum.value(i)); } - ctx.deregister_table("t")?; + ctx.deregister_table("t").await?; Ok(()) } @@ -229,7 +229,7 @@ async fn test_row_mismatch_error_in_scalar_udf() -> Result<()> { let ctx = SessionContext::new(); - ctx.register_batch("t", batch)?; + ctx.register_batch("t", batch).await?; // udf that always return 1 row let buggy_udf = Arc::new(|_: &[ColumnarValue]| { @@ -266,7 +266,7 @@ 
async fn scalar_udf_zero_params() -> Result<()> { )?; let ctx = SessionContext::new(); - ctx.register_batch("t", batch)?; + ctx.register_batch("t", batch).await?; let get_100_udf = Simple0ArgsScalarUDF { name: "get_100".to_string(), @@ -318,7 +318,7 @@ async fn scalar_udf_override_built_in_scalar_function() -> Result<()> { )?; let ctx = SessionContext::new(); - ctx.register_batch("t", batch)?; + ctx.register_batch("t", batch).await?; // register a UDF that has the same name as a builtin function (abs) and just returns 1 regardless of input ctx.register_udf(create_udf( "abs", @@ -378,15 +378,17 @@ async fn udaf_as_window_func() -> Result<()> { ); let context = SessionContext::new(); - context.register_table( - "my_table", - Arc::new(datafusion::datasource::empty::EmptyTable::new(Arc::new( - Schema::new(vec![ - Field::new("a", DataType::UInt32, false), - Field::new("b", DataType::Int32, false), - ]), - ))), - )?; + context + .register_table( + "my_table", + Arc::new(datafusion::datasource::empty::EmptyTable::new(Arc::new( + Schema::new(vec![ + Field::new("a", DataType::UInt32, false), + Field::new("b", DataType::Int32, false), + ]), + ))), + ) + .await?; context.register_udaf(my_acc); let sql = "SELECT a, MY_ACC(b) OVER(PARTITION BY a) FROM my_table"; @@ -404,7 +406,7 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { let ctx = SessionContext::new(); let arr = Int32Array::from(vec![1]); let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; - ctx.register_batch("t", batch).unwrap(); + ctx.register_batch("t", batch).await.unwrap(); let myfunc = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(array) = &args[0] else { @@ -449,7 +451,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { let ctx = SessionContext::new(); let arr = Int32Array::from(vec![1]); let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; - ctx.register_batch("t", batch).unwrap(); + ctx.register_batch("t", batch).await.unwrap(); let myfunc = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(array) = &args[0] else { @@ -572,7 +574,7 @@ async fn volatile_scalar_udf_with_params() -> Result<()> { )?; let ctx = SessionContext::new(); - ctx.register_batch("t", batch)?; + ctx.register_batch("t", batch).await?; let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new(); @@ -637,7 +639,7 @@ async fn volatile_scalar_udf_with_params() -> Result<()> { )?; let ctx = SessionContext::new(); - ctx.register_batch("t", batch)?; + ctx.register_batch("t", batch).await?; let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new(); @@ -736,7 +738,7 @@ async fn test_user_defined_functions_cast_to_i64() -> Result<()> { vec![Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0]))], )?; - ctx.register_batch("t", batch)?; + ctx.register_batch("t", batch).await?; let cast_to_i64_udf = ScalarUDF::from(CastToI64UDF::new()); ctx.register_udf(cast_to_i64_udf); @@ -1293,7 +1295,7 @@ async fn test_parameterized_scalar_udf() -> Result<()> { )])?; let ctx = SessionContext::new(); - ctx.register_batch("t", batch)?; + ctx.register_batch("t", batch).await?; let t = ctx.table("t").await?; let foo_udf = ScalarUDF::from(MyRegexUdf::new("fo{2}")); let bar_udf = ScalarUDF::from(MyRegexUdf::new("[Bb]ar")); @@ -1323,7 +1325,7 @@ async fn test_parameterized_scalar_udf() -> Result<()> { ]; assert_batches_eq!(expected, &actual); - ctx.deregister_table("t")?; + ctx.deregister_table("t").await?; Ok(()) } diff --git 
a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 10ee0c5cd2dc..92477427ad21 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -63,7 +63,7 @@ const BOUNDED_WINDOW_QUERY: &str = #[tokio::test] async fn test_setup() { let test_state = TestState::new(); - let TestContext { ctx, test_state: _ } = TestContext::new(test_state); + let TestContext { ctx, test_state: _ } = TestContext::new(test_state).await; let sql = "SELECT * from t order by x, y"; let expected = vec![ @@ -89,7 +89,7 @@ async fn test_setup() { #[tokio::test] async fn test_udwf() { let test_state = TestState::new(); - let TestContext { ctx, test_state } = TestContext::new(test_state); + let TestContext { ctx, test_state } = TestContext::new(test_state).await; let expected = vec![ "+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+", @@ -133,7 +133,7 @@ async fn test_deregister_udwf() -> Result<()> { #[tokio::test] async fn test_udwf_with_alias() { let test_state = TestState::new(); - let TestContext { ctx, .. } = TestContext::new(test_state); + let TestContext { ctx, .. } = TestContext::new(test_state).await; let expected = vec![ "+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+", @@ -163,7 +163,7 @@ async fn test_udwf_with_alias() { #[tokio::test] async fn test_udwf_bounded_window_ignores_frame() { let test_state = TestState::new(); - let TestContext { ctx, test_state } = TestContext::new(test_state); + let TestContext { ctx, test_state } = TestContext::new(test_state).await; // Since the UDWF doesn't say it needs the window frame, the frame is ignored let expected = vec![ @@ -195,7 +195,7 @@ async fn test_udwf_bounded_window_ignores_frame() { #[tokio::test] async fn test_udwf_bounded_window() { let test_state = TestState::new().with_uses_window_frame(); - let TestContext { ctx, test_state } = TestContext::new(test_state); + let TestContext { ctx, test_state } = TestContext::new(test_state).await; let expected = vec![ "+---+---+-----+--------------------------------------------------------------------------------------------------------------+", @@ -228,7 +228,7 @@ async fn test_stateful_udwf() { let test_state = TestState::new() .with_supports_bounded_execution() .with_uses_window_frame(); - let TestContext { ctx, test_state } = TestContext::new(test_state); + let TestContext { ctx, test_state } = TestContext::new(test_state).await; let expected = vec![ "+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+", @@ -260,7 +260,7 @@ async fn test_stateful_udwf_bounded_window() { let test_state = TestState::new() .with_supports_bounded_execution() .with_uses_window_frame(); - let TestContext { ctx, test_state } = TestContext::new(test_state); + let TestContext { ctx, test_state } = TestContext::new(test_state).await; let expected = vec![ "+---+---+-----+--------------------------------------------------------------------------------------------------------------+", @@ -291,7 +291,7 @@ async fn test_stateful_udwf_bounded_window() { #[tokio::test] async fn test_udwf_query_include_rank() { let test_state = TestState::new().with_include_rank(); - let TestContext { ctx, test_state } = 
TestContext::new(test_state); + let TestContext { ctx, test_state } = TestContext::new(test_state).await; let expected = vec![ "+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+", @@ -323,7 +323,7 @@ async fn test_udwf_query_include_rank() { #[tokio::test] async fn test_udwf_bounded_query_include_rank() { let test_state = TestState::new().with_include_rank(); - let TestContext { ctx, test_state } = TestContext::new(test_state); + let TestContext { ctx, test_state } = TestContext::new(test_state).await; let expected = vec![ "+---+---+-----+--------------------------------------------------------------------------------------------------------------+", @@ -357,7 +357,7 @@ async fn test_udwf_bounded_window_returns_null() { let test_state = TestState::new() .with_uses_window_frame() .with_null_for_zero(); - let TestContext { ctx, test_state } = TestContext::new(test_state); + let TestContext { ctx, test_state } = TestContext::new(test_state).await; let expected = vec![ "+---+---+-----+--------------------------------------------------------------------------------------------------------------+", @@ -412,7 +412,7 @@ struct TestContext { } impl TestContext { - fn new(test_state: TestState) -> Self { + async fn new(test_state: TestState) -> Self { let test_state = Arc::new(test_state); let x = Int64Array::from(vec![1, 1, 1, 2, 2, 2, 2, 2, 2, 2]); let y = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]); @@ -427,7 +427,7 @@ impl TestContext { let mut ctx = SessionContext::new(); - ctx.register_batch("t", batch).unwrap(); + ctx.register_batch("t", batch).await.unwrap(); // Tell DataFusion about the window function OddCounter::register(&mut ctx, Arc::clone(&test_state)); diff --git a/datafusion/ffi/src/table_provider.rs b/datafusion/ffi/src/table_provider.rs index 01f7c46106a2..844490cb446d 100644 --- a/datafusion/ffi/src/table_provider.rs +++ b/datafusion/ffi/src/table_provider.rs @@ -225,7 +225,8 @@ unsafe extern "C" fn scan_fn_wrapper( let session = SessionStateBuilder::new() .with_default_features() .with_config(config.0) - .build(); + .build() + .await; let ctx = SessionContext::new_with_state(session); let filters = match filters_serialized.is_empty() { @@ -467,7 +468,8 @@ mod tests { let foreign_table_provider: ForeignTableProvider = (&ffi_provider).into(); - ctx.register_table("t", Arc::new(foreign_table_provider))?; + ctx.register_table("t", Arc::new(foreign_table_provider)) + .await?; let df = ctx.table("t").await?; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index cfb6862a0ca3..d80422716d20 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -215,7 +215,10 @@ async fn roundtrip_custom_tables() -> Result<()> { let mut table_factories: HashMap> = HashMap::new(); table_factories.insert("TESTTABLE".to_string(), Arc::new(TestTableFactory {})); - let mut state = SessionStateBuilder::new().with_default_features().build(); + let mut state = SessionStateBuilder::new() + .with_default_features() + .build() + .await; // replace factories *state.table_factories_mut() = table_factories; let ctx = SessionContext::new_with_state(state); diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index 2466303c32a9..3f133bb67172 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ 
b/datafusion/sqllogictest/src/test_context.rs @@ -99,7 +99,7 @@ impl TestContext { } } "dynamic_file.slt" => { - test_ctx.ctx = test_ctx.ctx.enable_url_table(); + test_ctx.ctx = test_ctx.ctx.enable_url_table().await; } "joins.slt" => { info!("Registering partition table tables"); @@ -241,6 +241,7 @@ pub async fn register_temp_table(ctx: &SessionContext) { "datafusion.public.temp", Arc::new(TestTable(TableType::Temporary)), ) + .await .unwrap(); } @@ -250,13 +251,17 @@ pub async fn register_table_with_many_types(ctx: &SessionContext) { catalog .register_schema("my_schema", Arc::new(schema)) + .await + .unwrap(); + ctx.register_catalog("my_catalog", Arc::new(catalog)) + .await .unwrap(); - ctx.register_catalog("my_catalog", Arc::new(catalog)); ctx.register_table( "my_catalog.my_schema.table_with_many_types", table_with_many_types(), ) + .await .unwrap(); } @@ -274,6 +279,7 @@ pub async fn register_table_with_map(ctx: &SessionContext) { let memory_table = MemTable::try_new(schema.into(), vec![vec![]]).unwrap(); ctx.register_table("table_with_map", Arc::new(memory_table)) + .await .unwrap(); } @@ -366,7 +372,9 @@ pub async fn register_metadata_tables(ctx: &SessionContext) { ) .unwrap(); - ctx.register_batch("table_with_metadata", batch).unwrap(); + ctx.register_batch("table_with_metadata", batch) + .await + .unwrap(); } /// Create a UDF function named "example". See the `sample_udf.rs` example diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs index 1389cac75b99..60ec84c31972 100644 --- a/datafusion/substrait/src/lib.rs +++ b/datafusion/substrait/src/lib.rs @@ -59,7 +59,7 @@ //! // Create a plan that scans table 't' //! let ctx = SessionContext::new(); //! let batch = RecordBatch::try_from_iter(vec![("x", Arc::new(Int32Array::from(vec![42])) as _)])?; -//! ctx.register_batch("t", batch)?; +//! ctx.register_batch("t", batch).await?; //! let df = ctx.sql("SELECT x from t").await?; //! let plan = df.into_optimized_plan()?; //! 
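The sqllogictest setup above exercises the whole catalog chain asynchronously. As a minimal caller-side sketch under the signatures this diff introduces (the helper name is illustrative; "datafusion" and "public" are DataFusion's default catalog and schema names):

use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use futures::TryStreamExt;

// Illustrative helper: walks catalog -> schema -> table names asynchronously.
async fn list_public_tables(ctx: &SessionContext) -> Result<Vec<String>> {
    // catalog() and schema() are now fallible futures returning Result<Option<_>>.
    let catalog = ctx.catalog("datafusion").await?.expect("default catalog");
    let schema = catalog.schema("public").await?.expect("public schema");
    // table_names() now yields a BoxStream<'static, Result<String>>,
    // which TryStreamExt::try_collect drains into a Vec<String>.
    schema.table_names().await.try_collect().await
}

Returning Result<Option<_>> rather than plain Option<_> presumably lets remote catalog implementations surface I/O errors instead of collapsing them into None.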
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 29019dfd74f3..2408e88ade01 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -2441,7 +2441,7 @@ mod test { #[tokio::test] async fn extended_expressions() -> Result<()> { - let state = SessionStateBuilder::default().build(); + let state = SessionStateBuilder::default().build().await; // One expression, empty input schema let expr = Expr::Literal(ScalarValue::Int32(Some(42))); @@ -2493,7 +2493,7 @@ mod test { #[tokio::test] async fn invalid_extended_expression() { - let state = SessionStateBuilder::default().build(); + let state = SessionStateBuilder::default().build().await; // Not ok if input schema is missing field referenced by expr let expr = Expr::Column("missing".into()); diff --git a/datafusion/substrait/src/logical_plan/state.rs b/datafusion/substrait/src/logical_plan/state.rs index 0bd749c1105d..e9f63e5fd66d 100644 --- a/datafusion/substrait/src/logical_plan/state.rs +++ b/datafusion/substrait/src/logical_plan/state.rs @@ -56,7 +56,7 @@ impl SubstraitPlanningState for SessionState { reference: &TableReference, ) -> Result>, DataFusionError> { let table = reference.table().to_string(); - let schema = self.schema_for_ref(reference.clone())?; + let schema = self.schema_for_ref(reference.clone()).await?; let table_provider = schema.table(&table).await?; Ok(table_provider) } diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index 219f656bb471..f28e9b112645 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -40,7 +40,7 @@ mod tests { )) .expect("failed to parse json"); - let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?; + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto).await?; let plan = from_substrait_plan(&ctx.state(), &proto).await?; Ok(format!("{}", plan)) } diff --git a/datafusion/substrait/tests/cases/emit_kind_tests.rs b/datafusion/substrait/tests/cases/emit_kind_tests.rs index 08537d0d110f..f00e8cd57e79 100644 --- a/datafusion/substrait/tests/cases/emit_kind_tests.rs +++ b/datafusion/substrait/tests/cases/emit_kind_tests.rs @@ -32,7 +32,7 @@ mod tests { let proto_plan = read_json( "tests/testdata/test_plans/emit_kind/direct_on_project.substrait.json", ); - let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan).await?; let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; let plan_str = format!("{}", plan); @@ -50,7 +50,7 @@ mod tests { let proto_plan = read_json( "tests/testdata/test_plans/emit_kind/emit_on_filter.substrait.json", ); - let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan).await?; let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; let plan_str = format!("{}", plan); @@ -69,7 +69,8 @@ mod tests { let state = SessionStateBuilder::new() .with_config(SessionConfig::default()) .with_default_features() - .build(); + .build() + .await; let ctx = SessionContext::new_with_state(state); ctx.register_csv("data", "tests/testdata/data.csv", CsvReadOptions::default()) .await?; diff --git a/datafusion/substrait/tests/cases/function_test.rs b/datafusion/substrait/tests/cases/function_test.rs 
index 043808456176..725783888645 100644 --- a/datafusion/substrait/tests/cases/function_test.rs +++ b/datafusion/substrait/tests/cases/function_test.rs @@ -28,7 +28,7 @@ mod tests { #[tokio::test] async fn contains_function_test() -> Result<()> { let proto_plan = read_json("tests/testdata/contains_plan.substrait.json"); - let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan).await?; let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; let plan_str = format!("{}", plan); diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 65f404bbda55..6419b64f5952 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -37,7 +37,7 @@ mod tests { // ./isthmus-cli/build/graal/isthmus --create "create table data (d boolean)" "select not d from data" let proto_plan = read_json("tests/testdata/test_plans/select_not_bool.substrait.json"); - let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan).await?; let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!( @@ -62,7 +62,7 @@ mod tests { // ./isthmus-cli/build/graal/isthmus --create "create table data (d int, part int, ord int)" "select sum(d) OVER (PARTITION BY part ORDER BY ord ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) AS lead_expr from data" let proto_plan = read_json("tests/testdata/test_plans/select_window.substrait.json"); - let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan).await?; let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!( @@ -81,7 +81,7 @@ mod tests { // This test confirms that reading a plan with non-nullable lists works as expected. let proto_plan = read_json("tests/testdata/test_plans/non_nullable_lists.substrait.json"); - let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan).await?; let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!(format!("{}", &plan), "Values: (List([1, 2]))"); diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index d03ab5182028..dae2234a3a1b 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -1372,7 +1372,8 @@ async fn create_context() -> Result { .with_runtime_env(Arc::new(RuntimeEnv::default())) .with_default_features() .with_serializer_registry(Arc::new(MockSerializerRegistry)) - .build(); + .build() + .await; // register udaf for test, e.g. 
`sum()` datafusion_functions_aggregate::register_all(&mut state) diff --git a/datafusion/substrait/tests/cases/substrait_validations.rs b/datafusion/substrait/tests/cases/substrait_validations.rs index c77bf1489f4e..1600c9f43fb0 100644 --- a/datafusion/substrait/tests/cases/substrait_validations.rs +++ b/datafusion/substrait/tests/cases/substrait_validations.rs @@ -30,7 +30,7 @@ mod tests { use std::collections::HashMap; use std::sync::Arc; - fn generate_context_with_table( + async fn generate_context_with_table( table_name: &str, fields: Vec<(&str, DataType, bool)>, ) -> Result { @@ -52,7 +52,8 @@ mod tests { ctx.register_table( table_ref, Arc::new(EmptyTable::new(df_schema.inner().clone())), - )?; + ) + .await?; Ok(ctx) } @@ -64,7 +65,7 @@ mod tests { let df_schema = vec![("a", DataType::Int32, false), ("b", DataType::Int32, true)]; - let ctx = generate_context_with_table("DATA", df_schema)?; + let ctx = generate_context_with_table("DATA", df_schema).await?; let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!( @@ -85,7 +86,7 @@ mod tests { ("a", DataType::Int32, false), ("c", DataType::Int32, false), ]; - let ctx = generate_context_with_table("DATA", df_schema)?; + let ctx = generate_context_with_table("DATA", df_schema).await?; let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!( @@ -108,7 +109,7 @@ mod tests { ("c", DataType::Int32, false), ("b", DataType::Int32, false), ]; - let ctx = generate_context_with_table("DATA", df_schema)?; + let ctx = generate_context_with_table("DATA", df_schema).await?; let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!( @@ -127,7 +128,7 @@ mod tests { let df_schema = vec![("a", DataType::Int32, false), ("c", DataType::Int32, true)]; - let ctx = generate_context_with_table("DATA", df_schema)?; + let ctx = generate_context_with_table("DATA", df_schema).await?; let res = from_substrait_plan(&ctx.state(), &proto_plan).await; assert!(res.is_err()); Ok(()) @@ -139,7 +140,8 @@ mod tests { read_json("tests/testdata/test_plans/simple_select.substrait.json"); let ctx = - generate_context_with_table("DATA", vec![("a", DataType::Date32, true)])?; + generate_context_with_table("DATA", vec![("a", DataType::Date32, true)]) + .await?; let res = from_substrait_plan(&ctx.state(), &proto_plan).await; assert!(res.is_err()); Ok(()) diff --git a/datafusion/substrait/tests/utils.rs b/datafusion/substrait/tests/utils.rs index 00cbfb0c412c..108a4798f8c5 100644 --- a/datafusion/substrait/tests/utils.rs +++ b/datafusion/substrait/tests/utils.rs @@ -46,7 +46,7 @@ pub mod test { .expect("failed to parse json") } - pub fn add_plan_schemas_to_ctx( + pub async fn add_plan_schemas_to_ctx( ctx: SessionContext, plan: &Plan, ) -> Result { @@ -66,7 +66,7 @@ pub mod test { } } for (table_reference, table) in schema_map.into_iter() { - ctx.register_table(table_reference, table)?; + ctx.register_table(table_reference, table).await?; } Ok(ctx) } diff --git a/docs/source/library-user-guide/using-the-dataframe-api.md b/docs/source/library-user-guide/using-the-dataframe-api.md index 7f3e28c255c6..82bc458264ee 100644 --- a/docs/source/library-user-guide/using-the-dataframe-api.md +++ b/docs/source/library-user-guide/using-the-dataframe-api.md @@ -88,7 +88,7 @@ async fn main() -> Result<()> { ("id", Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef), ("bank_account", Arc::new(Int32Array::from(vec![9000, 8000, 7000]))), ])?; - ctx.register_batch("users", data)?; + ctx.register_batch("users", data).await?; // Create a 
DataFrame using SQL let dataframe = ctx.sql("SELECT * FROM users;") .await?
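Putting the user-facing pieces together, a minimal end-to-end sketch under the async entry points changed in this patch (the awaited build() and register_batch() shown above, plus the already-async sql(); table name and data are illustrative):

use std::sync::Arc;

use datafusion::arrow::array::Int32Array;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::Result;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    // SessionStateBuilder::build() is now async.
    let state = SessionStateBuilder::new()
        .with_default_features()
        .build()
        .await;
    let ctx = SessionContext::new_with_state(state);
    let batch = RecordBatch::try_from_iter(vec![(
        "x",
        Arc::new(Int32Array::from(vec![42])) as _,
    )])?;
    // register_batch is now async as well.
    ctx.register_batch("t", batch).await?;
    ctx.sql("SELECT x FROM t").await?.show().await?;
    Ok(())
}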