Skip to content

Commit 04cbd33

Browse files
committed
Update for changes in apache/arrow-rs#2711
1 parent 259f2e4 commit 04cbd33

9 files changed

+111
-91
lines changed

Cargo.lock

Lines changed: 41 additions & 61 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,14 @@ name = "datafusion._internal"
5252
[profile.release]
5353
lto = true
5454
codegen-units = 1
55+
56+
57+
[patch.crates-io]
58+
arrow = { git = "https://github.com/tustvold/arrow-rs.git", rev = "6f62bb62f630cbd910ae5b1b04f97688af7c1b42" }
59+
parquet = { git = "https://github.com/tustvold/arrow-rs.git", rev = "6f62bb62f630cbd910ae5b1b04f97688af7c1b42" }
60+
arrow-buffer = { git = "https://github.com/tustvold/arrow-rs.git", rev = "6f62bb62f630cbd910ae5b1b04f97688af7c1b42" }
61+
arrow-schema = { git = "https://github.com/tustvold/arrow-rs.git", rev = "6f62bb62f630cbd910ae5b1b04f97688af7c1b42" }
62+
63+
datafusion = { git = "https://github.com/tustvold/arrow-datafusion.git", rev = "9de354bf45c0cc4121af04ea8138df7fddab76ed" }
64+
datafusion-expr = { git = "https://github.com/tustvold/arrow-datafusion.git", rev = "9de354bf45c0cc4121af04ea8138df7fddab76ed" }
65+
datafusion-common = { git = "https://github.com/tustvold/arrow-datafusion.git", rev = "9de354bf45c0cc4121af04ea8138df7fddab76ed"}

src/context.rs

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use pyo3::exceptions::{PyKeyError, PyValueError};
2424
use pyo3::prelude::*;
2525

2626
use datafusion::arrow::datatypes::Schema;
27+
use datafusion::arrow::pyarrow::PyArrowType;
2728
use datafusion::arrow::record_batch::RecordBatch;
2829
use datafusion::datasource::datasource::TableProvider;
2930
use datafusion::datasource::MemTable;
@@ -99,9 +100,16 @@ impl PySessionContext {
99100
Ok(PyDataFrame::new(df))
100101
}
101102

102-
fn create_dataframe(&mut self, partitions: Vec<Vec<RecordBatch>>) -> PyResult<PyDataFrame> {
103-
let table = MemTable::try_new(partitions[0][0].schema(), partitions)
104-
.map_err(DataFusionError::from)?;
103+
fn create_dataframe(
104+
&mut self,
105+
partitions: Vec<Vec<PyArrowType<RecordBatch>>>,
106+
) -> PyResult<PyDataFrame> {
107+
let schema = partitions[0][0].0.schema();
108+
let partitions = partitions
109+
.into_iter()
110+
.map(|x| x.into_iter().map(|x| x.0).collect())
111+
.collect();
112+
let table = MemTable::try_new(schema, partitions).map_err(DataFusionError::from)?;
105113

106114
// generate a random (unique) name for this table
107115
// table name cannot start with numeric digit
@@ -136,9 +144,13 @@ impl PySessionContext {
136144
fn register_record_batches(
137145
&mut self,
138146
name: &str,
139-
partitions: Vec<Vec<RecordBatch>>,
147+
partitions: Vec<Vec<PyArrowType<RecordBatch>>>,
140148
) -> PyResult<()> {
141-
let schema = partitions[0][0].schema();
149+
let schema = partitions[0][0].0.schema();
150+
let partitions = partitions
151+
.into_iter()
152+
.map(|x| x.into_iter().map(|x| x.0).collect())
153+
.collect();
142154
let table = MemTable::try_new(schema, partitions)?;
143155
self.ctx
144156
.register_table(name, Arc::new(table))
@@ -182,7 +194,7 @@ impl PySessionContext {
182194
&mut self,
183195
name: &str,
184196
path: PathBuf,
185-
schema: Option<Schema>,
197+
schema: Option<PyArrowType<Schema>>,
186198
has_header: bool,
187199
delimiter: &str,
188200
schema_infer_max_records: usize,
@@ -204,7 +216,7 @@ impl PySessionContext {
204216
.delimiter(delimiter[0])
205217
.schema_infer_max_records(schema_infer_max_records)
206218
.file_extension(file_extension);
207-
options.schema = schema.as_ref();
219+
options.schema = schema.as_ref().map(|x| &x.0);
208220

209221
let result = self.ctx.register_csv(name, path, options);
210222
wait_for_future(py, result).map_err(DataFusionError::from)?;

src/dataframe.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
use crate::utils::wait_for_future;
1919
use crate::{errors::DataFusionError, expression::PyExpr};
2020
use datafusion::arrow::datatypes::Schema;
21-
use datafusion::arrow::pyarrow::PyArrowConvert;
21+
use datafusion::arrow::pyarrow::{PyArrowConvert, PyArrowException, PyArrowType};
2222
use datafusion::arrow::util::pretty;
2323
use datafusion::dataframe::DataFrame;
2424
use datafusion::prelude::*;
@@ -65,8 +65,8 @@ impl PyDataFrame {
6565
}
6666

6767
/// Returns the schema from the logical plan
68-
fn schema(&self) -> Schema {
69-
self.df.schema().into()
68+
fn schema(&self) -> PyArrowType<Schema> {
69+
PyArrowType(self.df.schema().into())
7070
}
7171

7272
#[args(args = "*")]
@@ -144,7 +144,8 @@ impl PyDataFrame {
144144
fn show(&self, py: Python, num: usize) -> PyResult<()> {
145145
let df = self.df.limit(0, Some(num))?;
146146
let batches = wait_for_future(py, df.collect())?;
147-
Ok(pretty::print_batches(&batches)?)
147+
Ok(pretty::print_batches(&batches)
148+
.map_err(|err| PyArrowException::new_err(err.to_string()))?)
148149
}
149150

150151
/// Filter out duplicate rows
@@ -186,7 +187,8 @@ impl PyDataFrame {
186187
fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyResult<()> {
187188
let df = self.df.explain(verbose, analyze)?;
188189
let batches = wait_for_future(py, df.collect())?;
189-
Ok(pretty::print_batches(&batches)?)
190+
Ok(pretty::print_batches(&batches)
191+
.map_err(|err| PyArrowException::new_err(err.to_string()))?)
190192
}
191193

192194
/// Repartition a `DataFrame` based on a logical partitioning scheme.

src/dataset.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use std::sync::Arc;
2727
use async_trait::async_trait;
2828

2929
use datafusion::arrow::datatypes::SchemaRef;
30+
use datafusion::arrow::pyarrow::PyArrowType;
3031
use datafusion::datasource::datasource::TableProviderFilterPushDown;
3132
use datafusion::datasource::{TableProvider, TableType};
3233
use datafusion::error::{DataFusionError, Result as DFResult};
@@ -74,7 +75,14 @@ impl TableProvider for Dataset {
7475
Python::with_gil(|py| {
7576
let dataset = self.dataset.as_ref(py);
7677
// The `unwrap` calls below can panic, but since we checked that self.dataset is a pyarrow.dataset.Dataset, they never should
77-
Arc::new(dataset.getattr("schema").unwrap().extract().unwrap())
78+
Arc::new(
79+
dataset
80+
.getattr("schema")
81+
.unwrap()
82+
.extract::<PyArrowType<_>>()
83+
.unwrap()
84+
.0,
85+
)
7886
})
7987
}
8088

0 commit comments

Comments
 (0)