Skip to content

Commit 43a3a38

Browse files
committed
681: implement rtrimmed_length udf
Adds the rtrimmed_length implementation according to the Snowflake spec
1 parent 19fd41b commit 43a3a38

File tree

2 files changed

+123
-0
lines changed

2 files changed

+123
-0
lines changed

crates/df-builtins/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ mod is_array;
2626
mod is_object;
2727
mod nullifzero;
2828
mod parse_json;
29+
mod rtrimmed_length;
2930
pub mod table;
3031
mod time_from_parts;
3132
mod timestamp_from_parts;
@@ -49,6 +50,7 @@ pub fn register_udfs(registry: &mut dyn FunctionRegistry) -> Result<()> {
4950
nullifzero::get_udf(),
5051
is_object::get_udf(),
5152
is_array::get_udf(),
53+
rtrimmed_length::get_udf(),
5254
Arc::new(ScalarUDF::from(ToBooleanFunc::new(false))),
5355
Arc::new(ScalarUDF::from(ToBooleanFunc::new(true))),
5456
Arc::new(ScalarUDF::from(ToTimeFunc::new(false))),
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
use datafusion::arrow::{array::UInt64Array, datatypes::DataType};
2+
use datafusion::error::Result as DFResult;
3+
use datafusion_common::cast::as_string_array;
4+
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
5+
use std::any::Any;
6+
use std::sync::Arc;
7+
8+
// rtrimmed_length SQL function
9+
// Returns the length of its argument, minus trailing whitespace, but including leading whitespace.
10+
// Syntax: RTRIMMED_LENGTH( <string_expr> )
11+
#[derive(Debug)]
12+
pub struct RTrimmedLengthFunc {
13+
signature: Signature,
14+
}
15+
16+
impl Default for RTrimmedLengthFunc {
17+
fn default() -> Self {
18+
Self::new()
19+
}
20+
}
21+
22+
impl RTrimmedLengthFunc {
23+
pub fn new() -> Self {
24+
Self {
25+
signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable),
26+
}
27+
}
28+
}
29+
30+
impl ScalarUDFImpl for RTrimmedLengthFunc {
31+
fn as_any(&self) -> &dyn Any {
32+
self
33+
}
34+
35+
fn name(&self) -> &'static str {
36+
"rtrimmed_length"
37+
}
38+
39+
fn signature(&self) -> &Signature {
40+
&self.signature
41+
}
42+
43+
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
44+
Ok(DataType::UInt64)
45+
}
46+
47+
#[allow(clippy::as_conversions)]
48+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
49+
let ScalarFunctionArgs { args, .. } = args;
50+
51+
let arr = match &args[0] {
52+
ColumnarValue::Array(arr) => arr,
53+
ColumnarValue::Scalar(v) => &v.to_array()?,
54+
};
55+
56+
let strs = as_string_array(&arr)?;
57+
58+
let new_array = strs
59+
.iter()
60+
.map(|array_elem| array_elem.map(|value| value.trim_end_matches(' ').len() as u64))
61+
.collect::<UInt64Array>();
62+
63+
Ok(ColumnarValue::Array(Arc::new(new_array)))
64+
}
65+
}
66+
67+
// NOTE(review): the macro body is not visible here; presumably it generates the
// module-level `get_udf()` constructor that `register_udfs` in lib.rs calls — confirm.
super::macros::make_udf_function!(RTrimmedLengthFunc);
68+
69+
// End-to-end test: runs RTRIMMED_LENGTH through SQL on a DataFusion SessionContext.
#[cfg(test)]
70+
mod tests {
71+
use super::*;
72+
use datafusion::prelude::SessionContext;
73+
use datafusion_common::assert_batches_eq;
74+
use datafusion_expr::ScalarUDF;
75+
76+
#[tokio::test]
77+
async fn test_it_works() -> DFResult<()> {
78+
let ctx = SessionContext::new();
79+
// Register only this UDF on a fresh context.
ctx.register_udf(ScalarUDF::from(RTrimmedLengthFunc::new()));
80+
81+
let create = "CREATE OR REPLACE TABLE test_strings (s STRING);";
82+
ctx.sql(create).await?.collect().await?;
83+
84+
// Fixture rows: space-padded strings, a whitespace-only string, an empty string,
// strings ending in tab/newline, and NULL.
// NOTE(review): whitespace inside these literals and in the expected table below
// looks collapsed in this view — confirm against the repository before relying on them.
let insert = r"
85+
INSERT INTO test_strings VALUES
86+
(' ABCD '),
87+
(' ABCDEFG'),
88+
('ABCDEFGH '),
89+
(' '),
90+
(''),
91+
('ABC'),
92+
(E'ABCDEFGH \t'),
93+
(E'ABCDEFGH \n'),
94+
(NULL);
95+
";
96+
ctx.sql(insert).await?.collect().await?;
97+
98+
let q = "SELECT RTRIMMED_LENGTH(s) FROM test_strings;";
99+
let result = ctx.sql(q).await?.collect().await?;
100+
101+
// One expected output row per inserted row; the NULL input renders as an empty cell.
assert_batches_eq!(
102+
&[
103+
"+---------------------------------+",
104+
"| rtrimmed_length(test_strings.s) |",
105+
"+---------------------------------+",
106+
"| 6 |",
107+
"| 10 |",
108+
"| 8 |",
109+
"| 0 |",
110+
"| 0 |",
111+
"| 3 |",
112+
"| 11 |",
113+
"| 11 |",
114+
"| |",
115+
"+---------------------------------+",
116+
],
117+
&result
118+
);
119+
Ok(())
120+
}
121+
}

0 commit comments

Comments
 (0)