diff --git a/datafusion/core/src/datasource/datasource.rs b/datafusion/core/src/datasource/datasource.rs
index 6277ce146adf..8db075a30a79 100644
--- a/datafusion/core/src/datasource/datasource.rs
+++ b/datafusion/core/src/datasource/datasource.rs
@@ -21,7 +21,7 @@ use std::any::Any;
 use std::sync::Arc;
 
 use async_trait::async_trait;
-use datafusion_common::Statistics;
+use datafusion_common::{DataFusionError, Statistics};
 use datafusion_expr::{CreateExternalTable, LogicalPlan};
 pub use datafusion_expr::{TableProviderFilterPushDown, TableType};
 
@@ -97,6 +97,16 @@ pub trait TableProvider: Sync + Send {
     fn statistics(&self) -> Option<Statistics> {
         None
     }
+
+    /// Insert into this table
+    async fn insert_into(
+        &self,
+        _state: &SessionState,
+        _input: &LogicalPlan,
+    ) -> Result<()> {
+        let msg = "Insertion not implemented for this table".to_owned();
+        Err(DataFusionError::NotImplemented(msg))
+    }
 }
 
 /// A factory which creates [`TableProvider`]s at runtime given a URL.
diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs
index ac1f4947f87d..b5fa33e38827 100644
--- a/datafusion/core/src/datasource/memory.rs
+++ b/datafusion/core/src/datasource/memory.rs
@@ -19,18 +19,22 @@
 //! queried by DataFusion. This allows data to be pre-loaded into memory and then
 //! repeatedly queried without incurring additional file I/O overhead.
 
-use futures::StreamExt;
+use futures::{StreamExt, TryStreamExt};
 use std::any::Any;
 use std::sync::Arc;
 
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
 use async_trait::async_trait;
+use datafusion_expr::LogicalPlan;
+use tokio::sync::RwLock;
+use tokio::task;
 
 use crate::datasource::{TableProvider, TableType};
 use crate::error::{DataFusionError, Result};
 use crate::execution::context::SessionState;
 use crate::logical_expr::Expr;
+use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
 use crate::physical_plan::common;
 use crate::physical_plan::common::AbortOnDropSingle;
 use crate::physical_plan::memory::MemoryExec;
@@ -41,7 +45,7 @@ use crate::physical_plan::{repartition::RepartitionExec, Partitioning};
 
 #[derive(Debug)]
 pub struct MemTable {
     schema: SchemaRef,
-    batches: Vec<Vec<RecordBatch>>,
+    batches: Arc<RwLock<Vec<Vec<RecordBatch>>>>,
 }
 
@@ -54,7 +58,7 @@ impl MemTable {
         {
             Ok(Self {
                 schema,
-                batches: partitions,
+                batches: Arc::new(RwLock::new(partitions)),
             })
         } else {
             Err(DataFusionError::Plan(
@@ -143,22 +147,102 @@ impl TableProvider for MemTable {
         _filters: &[Expr],
         _limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
+        let batches = &self.batches.read().await;
         Ok(Arc::new(MemoryExec::try_new(
-            &self.batches.clone(),
+            batches,
             self.schema(),
             projection.cloned(),
         )?))
     }
+
+    /// Inserts the execution results of a given [LogicalPlan] into this [MemTable].
+    /// The `LogicalPlan` must have the same schema as this `MemTable`.
+    ///
+    /// # Arguments
+    ///
+    /// * `state` - The [SessionState] containing the context for executing the plan.
+    /// * `input` - The [LogicalPlan] to execute and insert.
+    ///
+    /// # Returns
+    ///
+    /// * A `Result` indicating success or failure.
+    async fn insert_into(&self, state: &SessionState, input: &LogicalPlan) -> Result<()> {
+        // Create a physical plan from the logical plan.
+        let plan = state.create_physical_plan(input).await?;
+
+        // Check that the schema of the plan matches the schema of this table.
+        if !plan.schema().eq(&self.schema) {
+            return Err(DataFusionError::Plan(
+                "Inserting query must have the same schema as the table.".to_string(),
+            ));
+        }
+
+        // Get the number of partitions in the plan and the table.
+        let plan_partition_count = plan.output_partitioning().partition_count();
+        let table_partition_count = self.batches.read().await.len();
+
+        // Adjust the plan as necessary to match the number of partitions in the table.
+        let plan: Arc<dyn ExecutionPlan> = if plan_partition_count
+            == table_partition_count
+            || table_partition_count == 0
+        {
+            plan
+        } else if table_partition_count == 1 {
+            // If the table has only one partition, coalesce the partitions in the plan.
+            Arc::new(CoalescePartitionsExec::new(plan))
+        } else {
+            // Otherwise, repartition the plan using a round-robin partitioning scheme.
+            Arc::new(RepartitionExec::try_new(
+                plan,
+                Partitioning::RoundRobinBatch(table_partition_count),
+            )?)
+        };
+
+        // Get the task context from the session state.
+        let task_ctx = state.task_ctx();
+
+        // Execute the plan and collect the results into batches.
+        let mut tasks = vec![];
+        for idx in 0..plan.output_partitioning().partition_count() {
+            let stream = plan.execute(idx, task_ctx.clone())?;
+            let handle = task::spawn(async move {
+                stream.try_collect().await.map_err(DataFusionError::from)
+            });
+            tasks.push(AbortOnDropSingle::new(handle));
+        }
+        let results = futures::future::join_all(tasks)
+            .await
+            .into_iter()
+            .map(|result| {
+                result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
+            })
+            .collect::<Result<Vec<Vec<RecordBatch>>>>()?;
+
+        // Write the results into the table.
+        let mut all_batches = self.batches.write().await;
+
+        if all_batches.is_empty() {
+            *all_batches = results
+        } else {
+            for (batches, result) in all_batches.iter_mut().zip(results.into_iter()) {
+                batches.extend(result);
+            }
+        }
+
+        Ok(())
+    }
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::datasource::provider_as_source;
     use crate::from_slice::FromSlice;
     use crate::prelude::SessionContext;
     use arrow::array::Int32Array;
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow::error::ArrowError;
+    use datafusion_expr::LogicalPlanBuilder;
     use futures::StreamExt;
     use std::collections::HashMap;
 
@@ -388,4 +472,135 @@ mod tests {
 
         Ok(())
     }
+
+    fn create_mem_table_scan(
+        schema: SchemaRef,
+        data: Vec<Vec<RecordBatch>>,
+    ) -> Result<Arc<LogicalPlan>> {
+        // Convert the table into a provider so that it can be used in a query
+        let provider = provider_as_source(Arc::new(MemTable::try_new(schema, data)?));
+        // Create a table scan logical plan to read from the table
+        Ok(Arc::new(
+            LogicalPlanBuilder::scan("source", provider, None)?.build()?,
+        ))
+    }
+
+    fn create_initial_ctx() -> Result<(SessionContext, SchemaRef, RecordBatch)> {
+        // Create a new session context
+        let session_ctx = SessionContext::new();
+        // Create a new schema with one field called "a" of type Int32
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+
+        // Create a new batch of data to insert into the table
+        let batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![Arc::new(Int32Array::from_slice([1, 2, 3]))],
+        )?;
+        Ok((session_ctx, schema, batch))
+    }
+
+    #[tokio::test]
+    async fn test_insert_into_single_partition() -> Result<()> {
+        let (session_ctx, schema, batch) = create_initial_ctx()?;
+        let initial_table = Arc::new(MemTable::try_new(
+            schema.clone(),
+            vec![vec![batch.clone()]],
+        )?);
+        // Create a table scan logical plan to read from the table
+        let single_partition_table_scan =
+            create_mem_table_scan(schema.clone(), vec![vec![batch.clone()]])?;
+        // Insert the data from the provider into the table
+        initial_table
+            .insert_into(&session_ctx.state(), &single_partition_table_scan)
+            .await?;
+        // Ensure that the table now contains two batches of data in the same partition
+        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);
+
+        // Create a new provider with 2 partitions
+        let multi_partition_table_scan = create_mem_table_scan(
+            schema.clone(),
+            vec![vec![batch.clone()], vec![batch]],
+        )?;
+
+        // Insert the data from the provider into the table; we expect the partitions to be coalesced.
+        initial_table
+            .insert_into(&session_ctx.state(), &multi_partition_table_scan)
+            .await?;
+        // Ensure that the table now contains 4 batches of data with only 1 partition
+        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4);
+        assert_eq!(initial_table.batches.read().await.len(), 1);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_insert_into_multiple_partition() -> Result<()> {
+        let (session_ctx, schema, batch) = create_initial_ctx()?;
+        // create a memory table with two partitions, each having one batch with the same data
+        let initial_table = Arc::new(MemTable::try_new(
+            schema.clone(),
+            vec![vec![batch.clone()], vec![batch.clone()]],
+        )?);
+
+        // create a table scan from a memory table with a single partition
+        let single_partition_table_scan = create_mem_table_scan(
+            schema.clone(),
+            vec![vec![batch.clone(), batch.clone()]],
+        )?;
+
+        // insert the data from the single-partition source into the initial table
+        initial_table
+            .insert_into(&session_ctx.state(), &single_partition_table_scan)
+            .await?;
+
+        // We expect round-robin repartitioning here; each partition gets 1 new batch.
+        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);
+        assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 2);
+
+        // create a table scan from a memory table with 2 partitions
+        let multi_partition_table_scan = create_mem_table_scan(
+            schema.clone(),
+            vec![vec![batch.clone()], vec![batch]],
+        )?;
+        // We expect a one-to-one partition mapping.
+        initial_table
+            .insert_into(&session_ctx.state(), &multi_partition_table_scan)
+            .await?;
+        // Ensure that the table now contains 3 batches of data with 2 partitions.
+        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 3);
+        assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 3);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_insert_into_empty_table() -> Result<()> {
+        let (session_ctx, schema, batch) = create_initial_ctx()?;
+        // create an empty memory table
+        let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![])?);
+
+        // create a table scan from a memory table with a single partition
+        let single_partition_table_scan = create_mem_table_scan(
+            schema.clone(),
+            vec![vec![batch.clone(), batch.clone()]],
+        )?;
+
+        // insert the data from the single-partition source into the initial table
+        initial_table
+            .insert_into(&session_ctx.state(), &single_partition_table_scan)
+            .await?;
+
+        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);
+
+        // create a table scan from a memory table with 2 partitions
+        let single_partition_table_scan = create_mem_table_scan(
+            schema.clone(),
+            vec![vec![batch.clone()], vec![batch]],
+        )?;
+        // We expect the partitions to be coalesced here.
+        initial_table
+            .insert_into(&session_ctx.state(), &single_partition_table_scan)
+            .await?;
+        // Ensure that the table now contains 4 batches of data in its single partition.
+        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4);
+        Ok(())
+    }
 }
diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index c985ccdc5e3a..aa02e4be3200 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -31,7 +31,7 @@ use crate::{
         optimizer::PhysicalOptimizerRule,
     },
 };
-use datafusion_expr::{DescribeTable, StringifiedPlan};
+use datafusion_expr::{DescribeTable, DmlStatement, StringifiedPlan, WriteOp};
 pub use datafusion_physical_expr::execution_props::ExecutionProps;
 use datafusion_physical_expr::var_provider::is_system_variables;
 use parking_lot::RwLock;
@@ -308,7 +308,8 @@ impl SessionContext {
 
     /// Creates a [`DataFrame`] that will execute a SQL query.
     ///
-    /// Note: This API implements DDL such as `CREATE TABLE` and `CREATE VIEW` with in-memory
+    /// Note: This API implements DDL statements such as `CREATE TABLE` and
+    /// `CREATE VIEW` and DML statements such as `INSERT INTO` with in-memory
     /// default implementations.
     ///
     /// If this is not desirable, consider using [`SessionState::create_logical_plan()`] which
@@ -318,6 +319,24 @@ impl SessionContext {
         let plan = self.state().create_logical_plan(sql).await?;
 
         match plan {
+            LogicalPlan::Dml(DmlStatement {
+                table_name,
+                op: WriteOp::Insert,
+                input,
+                ..
+            }) => {
+                if self.table_exist(&table_name)? {
+                    let name = table_name.table();
+                    let provider = self.table_provider(name).await?;
+                    provider.insert_into(&self.state(), &input).await?;
+                } else {
+                    return Err(DataFusionError::Execution(format!(
+                        "Table '{}' does not exist",
+                        table_name
+                    )));
+                }
+                self.return_empty_dataframe()
+            }
             LogicalPlan::CreateExternalTable(cmd) => {
                 self.create_external_table(&cmd).await
             }
diff --git a/datafusion/core/tests/sqllogictests/src/engines/datafusion/insert.rs b/datafusion/core/tests/sqllogictests/src/engines/datafusion/insert.rs
deleted file mode 100644
index a8fca3b16c06..000000000000
--- a/datafusion/core/tests/sqllogictests/src/engines/datafusion/insert.rs
+++ /dev/null
@@ -1,93 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use super::error::Result;
-use crate::engines::datafusion::util::LogicTestContextProvider;
-use crate::engines::output::DFOutput;
-use arrow::record_batch::RecordBatch;
-use datafusion::datasource::MemTable;
-use datafusion::prelude::SessionContext;
-use datafusion_common::{DFSchema, DataFusionError};
-use datafusion_expr::Expr as DFExpr;
-use datafusion_sql::planner::{object_name_to_table_reference, PlannerContext, SqlToRel};
-use sqllogictest::DBOutput;
-use sqlparser::ast::{Expr, SetExpr, Statement as SQLStatement};
-use std::sync::Arc;
-
-pub async fn insert(ctx: &SessionContext, insert_stmt: SQLStatement) -> Result<DFOutput> {
-    // First, use sqlparser to get table name and insert values
-    let table_reference;
-    let insert_values: Vec<Vec<Expr>>;
-    match insert_stmt {
-        SQLStatement::Insert {
-            table_name, source, ..
-        } => {
-            table_reference = object_name_to_table_reference(
-                table_name,
-                ctx.enable_ident_normalization(),
-            )?;
-
-            // Todo: check columns match table schema
-            match *source.body {
-                SetExpr::Values(values) => {
-                    insert_values = values.rows;
-                }
-                _ => {
-                    // Directly panic: make it easy to find the location of the error.
-                    panic!()
-                }
-            }
-        }
-        _ => unreachable!(),
-    }
-
-    // Second, get batches in table and destroy the old table
-    let mut origin_batches = ctx.table(&table_reference).await?.collect().await?;
-    let schema = ctx.table_provider(&table_reference).await?.schema();
-    ctx.deregister_table(&table_reference)?;
-
-    // Third, transfer insert values to `RecordBatch`
-    // Attention: schema info can be ignored. (insert values don't contain schema info)
-    let sql_to_rel = SqlToRel::new(&LogicTestContextProvider {});
-    let num_rows = insert_values.len();
-    for row in insert_values.into_iter() {
-        let logical_exprs = row
-            .into_iter()
-            .map(|expr| {
-                sql_to_rel.sql_to_expr(
-                    expr,
-                    &DFSchema::empty(),
-                    &mut PlannerContext::new(),
-                )
-            })
-            .collect::<Result<Vec<DFExpr>, DataFusionError>>()?;
-        // Directly use `select` to get `RecordBatch`
-        let dataframe = ctx.read_empty()?;
-        origin_batches.extend(dataframe.select(logical_exprs)?.collect().await?)
-    }
-
-    // Replace new batches schema to old schema
-    for batch in origin_batches.iter_mut() {
-        *batch = RecordBatch::try_new(schema.clone(), batch.columns().to_vec())?;
-    }
-
-    // Final, create new memtable with same schema.
-    let new_provider = MemTable::try_new(schema, vec![origin_batches])?;
-    ctx.register_table(&table_reference, Arc::new(new_provider))?;
-
-    Ok(DBOutput::StatementComplete(num_rows as u64))
-}
diff --git a/datafusion/core/tests/sqllogictests/src/engines/datafusion/mod.rs b/datafusion/core/tests/sqllogictests/src/engines/datafusion/mod.rs
index 1f8f7feb36e5..cdd6663a5e0b 100644
--- a/datafusion/core/tests/sqllogictests/src/engines/datafusion/mod.rs
+++ b/datafusion/core/tests/sqllogictests/src/engines/datafusion/mod.rs
@@ -26,13 +26,11 @@ use create_table::create_table;
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::prelude::SessionContext;
 use datafusion_sql::parser::{DFParser, Statement};
-use insert::insert;
 use sqllogictest::DBOutput;
 use sqlparser::ast::Statement as SQLStatement;
 
 mod create_table;
 mod error;
-mod insert;
 mod normalize;
 mod util;
 
@@ -85,7 +83,6 @@ async fn run_query(ctx: &SessionContext, sql: impl Into<String>) -> Result<DFOutput> {
-        SQLStatement::Insert { .. } => return insert(ctx, statement).await,
         SQLStatement::CreateTable {
             query,
             constraints,
diff --git a/datafusion/core/tests/sqllogictests/test_files/ddl.slt b/datafusion/core/tests/sqllogictests/test_files/ddl.slt
index 642093c364f3..59bfc91b541f 100644
--- a/datafusion/core/tests/sqllogictests/test_files/ddl.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/ddl.slt
@@ -63,7 +63,7 @@ statement error Table 'user' doesn't exist.
 DROP TABLE user;
 
 # Can not insert into a undefined table
-statement error No table named 'user'
+statement error DataFusion error: Error during planning: table 'datafusion.public.user' not found
 insert into user values(1, 20);
 
 ##########
@@ -421,9 +421,27 @@ statement ok
 DROP TABLE aggregate_simple
 
 
+# sql_table_insert
+statement ok
+CREATE TABLE abc AS VALUES (1,2,3), (4,5,6);
+
+statement ok
+CREATE TABLE xyz AS VALUES (1,3,3), (5,5,6);
+
+statement ok
+INSERT INTO abc SELECT * FROM xyz;
+
+query III
+SELECT * FROM abc
+----
+1 2 3
+4 5 6
+1 3 3
+5 5 6
+
 # Should create an empty table
 statement ok
-CREATE TABLE table_without_values(field1 BIGINT, field2 BIGINT);
+CREATE TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL);
 
 
 # Should skip existing table
@@ -444,8 +462,8 @@ CREATE OR REPLACE TABLE IF NOT EXISTS table_without_values(field1 BIGINT, field2
 statement ok
 insert into table_without_values values (1, 2), (2, 3), (2, 4);
 
-query II rowsort
-select * from table_without_values;
+query II
+select * from table_without_values
 ----
 1 2
 2 3
@@ -454,7 +472,7 @@ select * from table_without_values;
 
 # Should recreate existing table
 statement ok
-CREATE OR REPLACE TABLE table_without_values(field1 BIGINT, field2 BIGINT);
+CREATE OR REPLACE TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL);
 
 
 # Should insert into a recreated table
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index a1fa82cda24e..9be046934233 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -24,7 +24,7 @@ use crate::expr_rewriter::{
 };
 use crate::type_coercion::binary::comparison_coercion;
 use crate::utils::{columnize_expr, compare_sort_expr, exprlist_to_fields, from_plan};
-use crate::{and, binary_expr, Operator};
+use crate::{and, binary_expr, DmlStatement, Operator, WriteOp};
 use crate::{
     logical_plan::{
         Aggregate, Analyze, CrossJoin, Distinct, EmptyRelation, Explain, Filter, Join,
@@ -40,8 +40,8 @@ use crate::{
 };
 use arrow::datatypes::{DataType, Schema, SchemaRef};
 use datafusion_common::{
-    Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
-    ToDFSchema,
+    Column, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result,
+    ScalarValue, ToDFSchema,
 };
 use std::any::Any;
 use std::cmp::Ordering;
@@ -201,6 +201,21 @@ impl LogicalPlanBuilder {
         Self::scan_with_filters(table_name, table_source, projection, vec![])
     }
 
+    /// Create a [DmlStatement] for inserting the contents of this builder into the named table
+    pub fn insert_into(
+        input: LogicalPlan,
+        table_name: impl Into<OwnedTableReference>,
+        table_schema: &Schema,
+    ) -> Result<Self> {
+        let table_schema = table_schema.clone().to_dfschema_ref()?;
+        Ok(Self::from(LogicalPlan::Dml(DmlStatement {
+            table_name: table_name.into(),
+            table_schema,
+            op: WriteOp::Insert,
+            input: Arc::new(input),
+        })))
+    }
+
     /// Convert a table provider into a builder with a TableScan
     pub fn scan_with_filters(
         table_name: impl Into<OwnedTableReference>,
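
For reference, a minimal end-to-end sketch of the code path this patch adds (not part of the diff): an INSERT INTO issued through SessionContext::sql is planned as LogicalPlan::Dml with WriteOp::Insert and dispatched to the provider's insert_into. The table name `t` and column `a` below are illustrative; the column is declared as nullable Int64 so the schema produced by the VALUES plan matches the table schema exactly, mirroring the BIGINT NULL change in ddl.slt above.

use std::sync::Arc;

use arrow::array::Int64Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Register a single-partition in-memory table holding one batch.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int64Array::from(vec![Some(1), Some(2), Some(3)]))],
    )?;
    ctx.register_table("t", Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))?;

    // Planned as LogicalPlan::Dml(WriteOp::Insert) and routed to
    // MemTable::insert_into by SessionContext::sql.
    ctx.sql("INSERT INTO t VALUES (4), (5)").await?;

    // The appended rows are visible to subsequent queries against the same provider.
    ctx.sql("SELECT a FROM t").await?.show().await?;
    Ok(())
}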