Skip to content

Commit f541e13

Browse files
committed
Support casting Utf8 to Boolean (apache#1738)
1 parent 52d28cd commit f541e13

File tree

1 file changed

+62
-7
lines changed

1 file changed

+62
-7
lines changed

arrow/src/compute/kernels/cast.rs

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
135135
(Dictionary(_, value_type), _) => can_cast_types(value_type, to_type),
136136
(_, Dictionary(_, value_type)) => can_cast_types(from_type, value_type),
137137

138-
(_, Boolean) => DataType::is_numeric(from_type),
138+
(_, Boolean) => DataType::is_numeric(from_type) || from_type == &Utf8,
139139
(Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8,
140140

141141
(Utf8, LargeUtf8) => true,
@@ -252,6 +252,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
252252
///
253253
/// Behavior:
254254
/// * Boolean to Utf8: `true` => '1', `false` => `0`
255+
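/// * Utf8 to boolean: `true`, `yes`, `on`, `1` => `true`; `false`, `no`, `off`, `0` => `false`.
256+
///   Matching ignores ASCII case and surrounding whitespace, and abbreviations such as `t`,
///   `ye`, `fals` or `of` are accepted; any other string returns null or an error, depending
///   on whether the cast options are `safe`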
255257
/// * Utf8 to numeric: strings that can't be parsed to numbers return null, float strings
256258
/// in integer casts return null
257259
/// * Numeric to boolean: 0 returns `false`, any other value returns `true`
@@ -265,7 +267,6 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
265267
/// Unsupported Casts
266268
/// * To or from `StructArray`
267269
/// * List to primitive
268-
/// * Utf8 to boolean
269270
/// * Interval and duration
270271
pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
271272
cast_with_options(array, to_type, &DEFAULT_CAST_OPTIONS)
@@ -368,6 +369,8 @@ macro_rules! cast_decimal_to_float {
368369
///
369370
/// Behavior:
370371
/// * Boolean to Utf8: `true` => '1', `false` => `0`
372+
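/// * Utf8 to boolean: `true`, `yes`, `on`, `1` => `true`; `false`, `no`, `off`, `0` => `false`.
373+
///   Matching ignores ASCII case and surrounding whitespace, and abbreviations such as `t`,
///   `ye`, `fals` or `of` are accepted; any other string returns null when `cast_options.safe`
///   is `true` and an error otherwise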
371374
/// * Utf8 to numeric: strings that can't be parsed to numbers return null, float strings
372375
/// in integer casts return null
373376
/// * Numeric to boolean: 0 returns `false`, any other value returns `true`
@@ -381,7 +384,6 @@ macro_rules! cast_decimal_to_float {
381384
/// Unsupported Casts
382385
/// * To or from `StructArray`
383386
/// * List to primitive
384-
/// * Utf8 to boolean
385387
pub fn cast_with_options(
386388
array: &ArrayRef,
387389
to_type: &DataType,
@@ -589,10 +591,7 @@ pub fn cast_with_options(
589591
Int64 => cast_numeric_to_bool::<Int64Type>(array),
590592
Float32 => cast_numeric_to_bool::<Float32Type>(array),
591593
Float64 => cast_numeric_to_bool::<Float64Type>(array),
592-
Utf8 => Err(ArrowError::CastError(format!(
593-
"Casting from {:?} to {:?} not supported",
594-
from_type, to_type,
595-
))),
594+
Utf8 => cast_utf8_to_boolean(array, cast_options),
596595
_ => Err(ArrowError::CastError(format!(
597596
"Casting from {:?} to {:?} not supported",
598597
from_type, to_type,
@@ -1561,6 +1560,34 @@ fn cast_string_to_timestamp_ns<Offset: StringOffsetSizeTrait>(
15611560
Ok(Arc::new(array) as ArrayRef)
15621561
}
15631562

1563+
/// Casts a Utf8 string array to a `BooleanArray`.
///
/// Each non-null string is ASCII-lowercased and trimmed, then compared against a fixed
/// list of spellings:
///
/// * `t`, `tr`, `tru`, `true`, `y`, `ye`, `yes`, `on`, `1` => `true`
/// * `f`, `fa`, `fal`, `fals`, `false`, `n`, `no`, `of`, `off`, `0` => `false`
///
/// Null inputs stay null. Any other string (including the empty string and a bare `o`,
/// which is in neither list) becomes null when `cast_options.safe` is `true`, and makes
/// the whole cast fail with `ArrowError::CastError` when it is `false`.
1564+
fn cast_utf8_to_boolean(from: &ArrayRef, cast_options: &CastOptions) -> Result<ArrayRef> {
1565+
let array = as_string_array(from);
1566+
1567+
// Build the output by mapping every `Option<&str>` to `Result<Option<bool>>`.
let output_array = array
1568+
.iter()
1569+
.map(|value| match value {
1570+
// Normalise before matching so that e.g. " Y " and "TRUE" are recognised.
Some(value) => match value.to_ascii_lowercase().trim() {
1571+
"t" | "tr" | "tru" | "true" | "y" | "ye" | "yes" | "on" | "1" => {
1572+
Ok(Some(true))
1573+
}
1574+
"f" | "fa" | "fal" | "fals" | "false" | "n" | "no" | "of" | "off"
1575+
| "0" => Ok(Some(false)),
1576+
// `invalid_value` is the lowercased, trimmed text, so the error message below
// reports the normalised form rather than the original input string.
invalid_value => match cast_options.safe {
1577+
true => Ok(None),
1578+
false => Err(ArrowError::CastError(format!(
1579+
"Cannot cast string '{}' to value of Boolean type",
1580+
invalid_value,
1581+
))),
1582+
},
1583+
},
1584+
None => Ok(None),
1585+
})
1586+
// Collecting into `Result<BooleanArray>` stops at the first `Err`; `?` returns it.
.collect::<Result<BooleanArray>>()?;
1587+
1588+
Ok(Arc::new(output_array))
1589+
}
1590+
15641591
/// Cast numeric types to Boolean
15651592
///
15661593
/// Any zero value returns `false` while non-zero returns `true`
@@ -2538,6 +2565,34 @@ mod tests {
25382565
}
25392566
}
25402567

2568+
// Checks `cast` from Utf8 to Boolean: recognised spellings map to true/false, while an
// unrecognised string ("invalid") and the empty string are expected to come back as null.
#[test]
2569+
fn test_cast_utf8_to_bool() {
2570+
// " Y " covers both surrounding whitespace and upper-case input.
let strings = Arc::new(StringArray::from(vec![
2571+
"true", "false", "invalid", " Y ", "",
2572+
])) as ArrayRef;
2573+
let casted = cast(&strings, &DataType::Boolean).unwrap();
2574+
let expected =
2575+
BooleanArray::from(vec![Some(true), Some(false), None, Some(true), None]);
2576+
assert_eq!(*as_boolean_array(&casted), expected);
2577+
}
2578+
2579+
#[test]
2580+
fn test_cast_with_options_utf8_to_bool() {
2581+
let strings = Arc::new(StringArray::from(vec![
2582+
"true", "false", "invalid", " Y ", "",
2583+
])) as ArrayRef;
2584+
let casted =
2585+
cast_with_options(&strings, &DataType::Boolean, &CastOptions { safe: false });
2586+
match casted {
2587+
Ok(_) => panic!("expected error"),
2588+
Err(e) => {
2589+
assert!(e.to_string().contains(
2590+
"Cast error: Cannot cast string 'invalid' to value of Boolean type"
2591+
))
2592+
}
2593+
}
2594+
}
2595+
25412596
#[test]
25422597
fn test_cast_bool_to_i32() {
25432598
let a = BooleanArray::from(vec![Some(true), Some(false), None]);

0 commit comments

Comments
 (0)