diff --git a/rust/arrow/benches/csv_writer.rs b/rust/arrow/benches/csv_writer.rs index ec3bc5a6fab03..49b1eede37247 100644 --- a/rust/arrow/benches/csv_writer.rs +++ b/rust/arrow/benches/csv_writer.rs @@ -48,7 +48,7 @@ fn record_batches_to_csv() { let c3 = PrimitiveArray::::from(vec![3, 2, 1]); let c4 = PrimitiveArray::::from(vec![Some(true), Some(false), None]); - let b = RecordBatch::new( + let b = RecordBatch::try_new( Arc::new(schema), vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)], ); diff --git a/rust/arrow/examples/dynamic_types.rs b/rust/arrow/examples/dynamic_types.rs index 8e6bb5d41c01b..2f361f448d0b6 100644 --- a/rust/arrow/examples/dynamic_types.rs +++ b/rust/arrow/examples/dynamic_types.rs @@ -22,9 +22,10 @@ extern crate arrow; use arrow::array::*; use arrow::datatypes::*; +use arrow::error::Result; use arrow::record_batch::*; -fn main() { +fn main() -> Result<()> { // define schema let schema = Schema::new(vec![ Field::new("id", DataType::Int32, false), @@ -58,9 +59,10 @@ fn main() { ]); // build a record batch - let batch = RecordBatch::new(Arc::new(schema), vec![Arc::new(id), Arc::new(nested)]); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id), Arc::new(nested)])?; - process(&batch); + Ok(process(&batch)) } /// Create a new batch by performing a projection of id, nested.c @@ -88,7 +90,7 @@ fn process(batch: &RecordBatch) { Field::new("sum", DataType::Float64, false), ]); - let _ = RecordBatch::new( + let _ = RecordBatch::try_new( Arc::new(projected_schema), vec![ id.clone(), // NOTE: this is cloning the Arc not the array data diff --git a/rust/arrow/src/csv/reader.rs b/rust/arrow/src/csv/reader.rs index a511b93463d59..85b2ccd756475 100644 --- a/rust/arrow/src/csv/reader.rs +++ b/rust/arrow/src/csv/reader.rs @@ -329,7 +329,10 @@ impl Reader { let projected_schema = Arc::new(Schema::new(projected_fields)); match arrays { - Ok(arr) => Ok(Some(RecordBatch::new(projected_schema, arr))), + Ok(arr) => match RecordBatch::try_new(projected_schema, arr) { + Ok(batch) => Ok(Some(batch)), + Err(e) => Err(e), + }, Err(e) => Err(e), } } diff --git a/rust/arrow/src/csv/writer.rs b/rust/arrow/src/csv/writer.rs index bf1e582d310c7..945fb7101fd11 100644 --- a/rust/arrow/src/csv/writer.rs +++ b/rust/arrow/src/csv/writer.rs @@ -50,10 +50,10 @@ //! let c3 = PrimitiveArray::::from(vec![3, 2, 1]); //! let c4 = PrimitiveArray::::from(vec![Some(true), Some(false), None]); //! -//! let batch = RecordBatch::new( +//! let batch = RecordBatch::try_new( //! Arc::new(schema), //! vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)], -//! ); +//! ).unwrap(); //! //! let file = get_temp_file("out.csv", &[]); //! @@ -287,10 +287,11 @@ mod tests { let c3 = PrimitiveArray::::from(vec![3, 2, 1]); let c4 = PrimitiveArray::::from(vec![Some(true), Some(false), None]); - let batch = RecordBatch::new( + let batch = RecordBatch::try_new( Arc::new(schema), vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)], - ); + ) + .unwrap(); let file = get_temp_file("columns.csv", &[]); @@ -331,10 +332,11 @@ mod tests { let c3 = PrimitiveArray::::from(vec![3, 2, 1]); let c4 = PrimitiveArray::::from(vec![Some(true), Some(false), None]); - let batch = RecordBatch::new( + let batch = RecordBatch::try_new( Arc::new(schema), vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)], - ); + ) + .unwrap(); let file = get_temp_file("custom_options.csv", &[]); diff --git a/rust/arrow/src/error.rs b/rust/arrow/src/error.rs index 96ed944f98ebd..2f758d4b385fb 100644 --- a/rust/arrow/src/error.rs +++ b/rust/arrow/src/error.rs @@ -30,6 +30,7 @@ pub enum ArrowError { CsvError(String), JsonError(String), IoError(String), + InvalidArgumentError(String), } impl From<::std::io::Error> for ArrowError { diff --git a/rust/arrow/src/json/reader.rs b/rust/arrow/src/json/reader.rs index 14954923ed927..8bdbf89ea1516 100644 --- a/rust/arrow/src/json/reader.rs +++ b/rust/arrow/src/json/reader.rs @@ -487,7 +487,10 @@ impl Reader { let projected_schema = Arc::new(Schema::new(projected_fields)); match arrays { - Ok(arr) => Ok(Some(RecordBatch::new(projected_schema, arr))), + Ok(arr) => match RecordBatch::try_new(projected_schema, arr) { + Ok(batch) => Ok(Some(batch)), + Err(e) => Err(e), + }, Err(e) => Err(e), } } diff --git a/rust/arrow/src/record_batch.rs b/rust/arrow/src/record_batch.rs index e3da6284b0001..62f93b899bee6 100644 --- a/rust/arrow/src/record_batch.rs +++ b/rust/arrow/src/record_batch.rs @@ -25,6 +25,7 @@ use std::sync::Arc; use crate::array::*; use crate::datatypes::*; +use crate::error::{ArrowError, Result}; /// A batch of column-oriented data #[derive(Clone)] @@ -34,36 +35,61 @@ pub struct RecordBatch { } impl RecordBatch { - pub fn new(schema: Arc, columns: Vec) -> Self { - // assert that there are some columns - assert!( - columns.len() > 0, - "at least one column must be defined to create a record batch" - ); - // assert that all columns have the same row count + /// Creates a `RecordBatch` from a schema and columns + /// + /// Expects the following: + /// * the vec of columns to not be empty + /// * the schema and column data types to have equal lengths and match + /// * each array in columns to have the same length + pub fn try_new(schema: Arc, columns: Vec) -> Result { + // check that there are some columns + if columns.is_empty() { + return Err(ArrowError::InvalidArgumentError( + "at least one column must be defined to create a record batch" + .to_string(), + )); + } + // check that number of fields in schema match column length + if schema.fields().len() != columns.len() { + return Err(ArrowError::InvalidArgumentError( + "number of columns must match number of fields in schema".to_string(), + )); + } + // check that all columns have the same row count, and match the schema let len = columns[0].data().len(); - for i in 1..columns.len() { - assert_eq!( - len, - columns[i].len(), - "all columns in a record batch must have the same length" - ); + for i in 0..columns.len() { + if columns[i].len() != len { + return Err(ArrowError::InvalidArgumentError( + "all columns in a record batch must have the same length".to_string(), + )); + } + if columns[i].data_type() != schema.field(i).data_type() { + return Err(ArrowError::InvalidArgumentError(format!( + "column types must match schema types, expected {:?} but found {:?} at column index {}", + schema.field(i).data_type(), + columns[i].data_type(), + i))); + } } - RecordBatch { schema, columns } + Ok(RecordBatch { schema, columns }) } + /// Returns the schema of the record batch pub fn schema(&self) -> &Arc { &self.schema } + /// Number of columns in the record batch pub fn num_columns(&self) -> usize { self.columns.len() } + /// Number of rows in each column pub fn num_rows(&self) -> usize { self.columns[0].data().len() } + /// Get a reference to a column's array by index pub fn column(&self, i: usize) -> &ArrayRef { &self.columns[i] } @@ -103,7 +129,8 @@ mod tests { let b = BinaryArray::from(array_data); let record_batch = - RecordBatch::new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]); + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]) + .unwrap(); assert_eq!(5, record_batch.num_rows()); assert_eq!(2, record_batch.num_columns()); @@ -112,4 +139,26 @@ mod tests { assert_eq!(5, record_batch.column(0).data().len()); assert_eq!(5, record_batch.column(1).data().len()); } + + #[test] + fn create_record_batch_schema_mismatch() { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let a = Int64Array::from(vec![1, 2, 3, 4, 5]); + + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]); + assert!(!batch.is_ok()); + } + + #[test] + fn create_record_batch_record_mismatch() { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + let b = Int32Array::from(vec![1, 2, 3, 4, 5]); + + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]); + assert!(!batch.is_ok()); + } } diff --git a/rust/datafusion/src/datasource/memory.rs b/rust/datafusion/src/datasource/memory.rs index 33673934c65d6..5168ae9a56714 100644 --- a/rust/datafusion/src/datasource/memory.rs +++ b/rust/datafusion/src/datasource/memory.rs @@ -102,20 +102,25 @@ impl Table for MemTable { let projected_schema = Arc::new(Schema::new(projected_columns?)); - Ok(Rc::new(RefCell::new(MemBatchIterator { - schema: projected_schema.clone(), - index: 0, - batches: self - .batches - .iter() - .map(|batch| { - RecordBatch::new( - projected_schema.clone(), - columns.iter().map(|i| batch.column(*i).clone()).collect(), - ) - }) - .collect(), - }))) + let batches = self + .batches + .iter() + .map(|batch| { + RecordBatch::try_new( + projected_schema.clone(), + columns.iter().map(|i| batch.column(*i).clone()).collect(), + ) + }) + .collect(); + + match batches { + Ok(batches) => Ok(Rc::new(RefCell::new(MemBatchIterator { + schema: projected_schema.clone(), + index: 0, + batches, + }))), + Err(e) => Err(ExecutionError::ArrowError(e)), + } } } @@ -155,14 +160,15 @@ mod tests { Field::new("c", DataType::Int32, false), ])); - let batch = RecordBatch::new( + let batch = RecordBatch::try_new( schema.clone(), vec![ Arc::new(Int32Array::from(vec![1, 2, 3])), Arc::new(Int32Array::from(vec![4, 5, 6])), Arc::new(Int32Array::from(vec![7, 8, 9])), ], - ); + ) + .unwrap(); let provider = MemTable::new(schema, vec![batch]).unwrap(); @@ -183,14 +189,15 @@ mod tests { Field::new("c", DataType::Int32, false), ])); - let batch = RecordBatch::new( + let batch = RecordBatch::try_new( schema.clone(), vec![ Arc::new(Int32Array::from(vec![1, 2, 3])), Arc::new(Int32Array::from(vec![4, 5, 6])), Arc::new(Int32Array::from(vec![7, 8, 9])), ], - ); + ) + .unwrap(); let provider = MemTable::new(schema, vec![batch]).unwrap(); @@ -208,14 +215,15 @@ mod tests { Field::new("c", DataType::Int32, false), ])); - let batch = RecordBatch::new( + let batch = RecordBatch::try_new( schema.clone(), vec![ Arc::new(Int32Array::from(vec![1, 2, 3])), Arc::new(Int32Array::from(vec![4, 5, 6])), Arc::new(Int32Array::from(vec![7, 8, 9])), ], - ); + ) + .unwrap(); let provider = MemTable::new(schema, vec![batch]).unwrap(); @@ -243,14 +251,15 @@ mod tests { Field::new("c", DataType::Int32, false), ])); - let batch = RecordBatch::new( + let batch = RecordBatch::try_new( schema1.clone(), vec![ Arc::new(Int32Array::from(vec![1, 2, 3])), Arc::new(Int32Array::from(vec![4, 5, 6])), Arc::new(Int32Array::from(vec![7, 8, 9])), ], - ); + ) + .unwrap(); match MemTable::new(schema2, vec![batch]) { Err(ExecutionError::General(e)) => assert_eq!( diff --git a/rust/datafusion/src/execution/aggregate.rs b/rust/datafusion/src/execution/aggregate.rs index 87a9f111610ec..f9eb3e20fa339 100644 --- a/rust/datafusion/src/execution/aggregate.rs +++ b/rust/datafusion/src/execution/aggregate.rs @@ -800,7 +800,10 @@ impl AggregateRelation { } } - Ok(Some(RecordBatch::new(self.schema.clone(), result_columns))) + Ok(Some(RecordBatch::try_new( + self.schema.clone(), + result_columns, + )?)) } fn with_group_by(&mut self) -> Result> { @@ -1008,7 +1011,10 @@ impl AggregateRelation { result_arrays.push(array?); } - Ok(Some(RecordBatch::new(self.schema.clone(), result_arrays))) + Ok(Some(RecordBatch::try_new( + self.schema.clone(), + result_arrays, + )?)) } } @@ -1136,7 +1142,7 @@ mod tests { .unwrap(); let aggr_schema = Arc::new(Schema::new(vec![ - Field::new("c2", DataType::Int32, false), + Field::new("c2", DataType::UInt32, false), Field::new("min", DataType::Float64, false), Field::new("max", DataType::Float64, false), Field::new("sum", DataType::Float64, false), diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index b26b31099b5e1..8d4a9848ac433 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -187,8 +187,15 @@ impl ExecutionContext { .collect(); let compiled_aggr_expr = compiled_aggr_expr_result?; + let mut output_fields: Vec = vec![]; + for expr in group_expr { + output_fields.push(expr_to_field(expr, input_schema.as_ref())); + } + for expr in aggr_expr { + output_fields.push(expr_to_field(expr, input_schema.as_ref())); + } let rel = AggregateRelation::new( - Arc::new(Schema::empty()), //(expr_to_field(&compiled_group_expr, &input_schema))), + Arc::new(Schema::new(output_fields)), input_rel, compiled_group_expr, compiled_aggr_expr, diff --git a/rust/datafusion/src/execution/filter.rs b/rust/datafusion/src/execution/filter.rs index c4d1cecbf45ad..34679265bd1ef 100644 --- a/rust/datafusion/src/execution/filter.rs +++ b/rust/datafusion/src/execution/filter.rs @@ -71,7 +71,7 @@ impl Relation for FilterRelation { .collect(); let filtered_batch: RecordBatch = - RecordBatch::new(self.schema.clone(), filtered_columns?); + RecordBatch::try_new(self.schema.clone(), filtered_columns?)?; Ok(Some(filtered_batch)) } diff --git a/rust/datafusion/src/execution/limit.rs b/rust/datafusion/src/execution/limit.rs index bfd8706cc1ae8..c58e4fd819f53 100644 --- a/rust/datafusion/src/execution/limit.rs +++ b/rust/datafusion/src/execution/limit.rs @@ -66,7 +66,7 @@ impl Relation for LimitRelation { .collect(); let limited_batch: RecordBatch = - RecordBatch::new(self.schema.clone(), limited_columns?); + RecordBatch::try_new(self.schema.clone(), limited_columns?)?; self.num_consumed_rows += capacity; Ok(Some(limited_batch)) diff --git a/rust/datafusion/src/execution/projection.rs b/rust/datafusion/src/execution/projection.rs index bcf3ebba7d98d..a02213a57fbb7 100644 --- a/rust/datafusion/src/execution/projection.rs +++ b/rust/datafusion/src/execution/projection.rs @@ -64,7 +64,7 @@ impl Relation for ProjectRelation { ); let projected_batch: RecordBatch = - RecordBatch::new(Arc::new(schema), projected_columns?); + RecordBatch::try_new(Arc::new(schema), projected_columns?)?; Ok(Some(projected_batch)) }