Skip to content

Commit 0b00c6f

Browse files
authored
[Config] - Support for key-value pair configuration settings (#52)
* feat: impl a new Config class * fix: add u64 support for config
1 parent 259f2e4 commit 0b00c6f

File tree

5 files changed

+129
-4
lines changed

5 files changed

+129
-4
lines changed

datafusion/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,21 @@
2323
except ImportError:
2424
import importlib_metadata
2525

26-
2726
import pyarrow as pa
2827

2928
from ._internal import (
3029
AggregateUDF,
30+
Config,
3131
DataFrame,
3232
SessionContext,
3333
Expression,
3434
ScalarUDF,
3535
)
3636

37-
3837
__version__ = importlib_metadata.version(__name__)
3938

40-
4139
__all__ = [
40+
"Config",
4241
"DataFrame",
4342
"SessionContext",
4443
"Expression",

datafusion/tests/test_config.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from datafusion import Config
19+
import pytest
20+
21+
22+
@pytest.fixture
23+
def config():
24+
return Config()
25+
26+
27+
def test_get_then_set(config):
28+
config_key = "datafusion.optimizer.filter_null_join_keys"
29+
30+
assert config.get(config_key).as_py() is False
31+
32+
config.set(config_key, True)
33+
assert config.get(config_key).as_py() is True
34+
35+
36+
def test_get_all(config):
37+
config.get_all()
38+
39+
40+
def test_get_invalid_config(config):
41+
assert config.get("not.valid.key") is None

src/config.rs

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use pyo3::prelude::*;
19+
use pyo3::types::*;
20+
21+
use datafusion::config::ConfigOptions;
22+
use datafusion_common::ScalarValue;
23+
24+
#[pyclass(name = "Config", module = "datafusion", subclass)]
25+
#[derive(Clone)]
26+
pub(crate) struct PyConfig {
27+
config: ConfigOptions,
28+
}
29+
30+
#[pymethods]
31+
impl PyConfig {
32+
#[new]
33+
fn py_new() -> Self {
34+
Self {
35+
config: ConfigOptions::new(),
36+
}
37+
}
38+
39+
/// Get configurations from environment variables
40+
#[staticmethod]
41+
pub fn from_env() -> Self {
42+
Self {
43+
config: ConfigOptions::from_env(),
44+
}
45+
}
46+
47+
/// Get a configuration option
48+
pub fn get(&mut self, key: &str, py: Python) -> PyResult<PyObject> {
49+
Ok(self.config.get(key).into_py(py))
50+
}
51+
52+
/// Set a configuration option
53+
pub fn set(&mut self, key: &str, value: PyObject, py: Python) {
54+
self.config.set(key, py_obj_to_scalar_value(py, value))
55+
}
56+
57+
/// Get all configuration options
58+
pub fn get_all(&mut self, py: Python) -> PyResult<PyObject> {
59+
let dict = PyDict::new(py);
60+
for (key, value) in self.config.options() {
61+
dict.set_item(key, value.clone().into_py(py))?;
62+
}
63+
Ok(dict.into())
64+
}
65+
}
66+
67+
/// Convert a python object to a ScalarValue
68+
fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> ScalarValue {
69+
if let Ok(value) = obj.extract::<bool>(py) {
70+
ScalarValue::Boolean(Some(value))
71+
} else if let Ok(value) = obj.extract::<i64>(py) {
72+
ScalarValue::Int64(Some(value))
73+
} else if let Ok(value) = obj.extract::<u64>(py) {
74+
ScalarValue::UInt64(Some(value))
75+
} else if let Ok(value) = obj.extract::<f64>(py) {
76+
ScalarValue::Float64(Some(value))
77+
} else if let Ok(value) = obj.extract::<String>(py) {
78+
ScalarValue::Utf8(Some(value))
79+
} else {
80+
panic!("Unsupported value type")
81+
}
82+
}

src/expression.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use pyo3::{basic::CompareOp, prelude::*};
1919
use std::convert::{From, Into};
2020

2121
use datafusion::arrow::datatypes::DataType;
22-
use datafusion::logical_plan::{col, lit, Expr};
22+
use datafusion_expr::{col, lit, Expr};
2323

2424
use datafusion::scalar::ScalarValue;
2525

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ use pyo3::prelude::*;
2222
#[allow(clippy::borrow_deref_ref)]
2323
pub mod catalog;
2424
#[allow(clippy::borrow_deref_ref)]
25+
mod config;
26+
#[allow(clippy::borrow_deref_ref)]
2527
mod context;
2628
#[allow(clippy::borrow_deref_ref)]
2729
mod dataframe;
@@ -58,6 +60,7 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> {
5860
m.add_class::<expression::PyExpr>()?;
5961
m.add_class::<udf::PyScalarUDF>()?;
6062
m.add_class::<udaf::PyAggregateUDF>()?;
63+
m.add_class::<config::PyConfig>()?;
6164

6265
// Register the functions as a submodule
6366
let funcs = PyModule::new(py, "functions")?;

0 commit comments

Comments
 (0)