Skip to content

Commit a7a74fa

Browse files
authored
Safeguard against potential inexact row count being smaller than exact null count (#9007)
* Safeguard against potential inexact row count being smaller than exact null count
* Add test hitting the former overflow panic
1 parent ed24539 commit a7a74fa

File tree

2 files changed

+120
-91
lines changed

2 files changed

+120
-91
lines changed

datafusion/common/src/stats.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,15 @@ impl<T: Debug + Clone + PartialEq + Eq + PartialOrd> Precision<T> {
4848

4949
/// Transform the value in this [`Precision`] object, if one exists, using
5050
/// the given function. Preserves the exactness state.
51-
pub fn map<F>(self, f: F) -> Precision<T>
51+
pub fn map<U, F>(self, f: F) -> Precision<U>
5252
where
53-
F: Fn(T) -> T,
53+
F: Fn(T) -> U,
54+
U: Debug + Clone + PartialEq + Eq + PartialOrd,
5455
{
5556
match self {
5657
Precision::Exact(val) => Precision::Exact(f(val)),
5758
Precision::Inexact(val) => Precision::Inexact(f(val)),
58-
_ => self,
59+
_ => Precision::<U>::Absent,
5960
}
6061
}
6162

datafusion/physical-plan/src/joins/utils.rs

Lines changed: 116 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -955,7 +955,12 @@ fn max_distinct_count(
955955
let result = match num_rows {
956956
Precision::Absent => Precision::Absent,
957957
Precision::Inexact(count) => {
958-
Precision::Inexact(count - stats.null_count.get_value().unwrap_or(&0))
958+
// To safeguard against inexact number of rows (e.g. 0) being smaller than
959+
// an exact null count we need to do a checked subtraction.
960+
match count.checked_sub(*stats.null_count.get_value().unwrap_or(&0)) {
961+
None => Precision::Inexact(0),
962+
Some(non_null_count) => Precision::Inexact(non_null_count),
963+
}
959964
}
960965
Precision::Exact(count) => {
961966
let count = count - stats.null_count.get_value().unwrap_or(&0);
@@ -1468,6 +1473,7 @@ mod tests {
14681473
use arrow::error::{ArrowError, Result as ArrowResult};
14691474
use arrow_schema::SortOptions;
14701475

1476+
use datafusion_common::stats::Precision::{Absent, Exact, Inexact};
14711477
use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue};
14721478

14731479
fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> {
@@ -1635,25 +1641,26 @@ mod tests {
16351641
}
16361642

16371643
fn create_column_stats(
1638-
min: Option<i64>,
1639-
max: Option<i64>,
1640-
distinct_count: Option<usize>,
1644+
min: Precision<i64>,
1645+
max: Precision<i64>,
1646+
distinct_count: Precision<usize>,
1647+
null_count: Precision<usize>,
16411648
) -> ColumnStatistics {
16421649
ColumnStatistics {
1643-
distinct_count: distinct_count
1644-
.map(Precision::Inexact)
1645-
.unwrap_or(Precision::Absent),
1646-
min_value: min
1647-
.map(|size| Precision::Inexact(ScalarValue::from(size)))
1648-
.unwrap_or(Precision::Absent),
1649-
max_value: max
1650-
.map(|size| Precision::Inexact(ScalarValue::from(size)))
1651-
.unwrap_or(Precision::Absent),
1652-
..Default::default()
1650+
distinct_count,
1651+
min_value: min.map(ScalarValue::from),
1652+
max_value: max.map(ScalarValue::from),
1653+
null_count,
16531654
}
16541655
}
16551656

1656-
type PartialStats = (usize, Option<i64>, Option<i64>, Option<usize>);
1657+
type PartialStats = (
1658+
usize,
1659+
Precision<i64>,
1660+
Precision<i64>,
1661+
Precision<usize>,
1662+
Precision<usize>,
1663+
);
16571664

16581665
// This is mainly for validating all the edge cases of the estimation, but
16591666
// more advanced (and real world test cases) are below where we need some control
@@ -1670,133 +1677,156 @@ mod tests {
16701677
//
16711678
// distinct(left) == NaN, distinct(right) == NaN
16721679
(
1673-
(10, Some(1), Some(10), None),
1674-
(10, Some(1), Some(10), None),
1675-
Some(Precision::Inexact(10)),
1680+
(10, Inexact(1), Inexact(10), Absent, Absent),
1681+
(10, Inexact(1), Inexact(10), Absent, Absent),
1682+
Some(Inexact(10)),
16761683
),
16771684
// range(left) > range(right)
16781685
(
1679-
(10, Some(6), Some(10), None),
1680-
(10, Some(8), Some(10), None),
1681-
Some(Precision::Inexact(20)),
1686+
(10, Inexact(6), Inexact(10), Absent, Absent),
1687+
(10, Inexact(8), Inexact(10), Absent, Absent),
1688+
Some(Inexact(20)),
16821689
),
16831690
// range(right) > range(left)
16841691
(
1685-
(10, Some(8), Some(10), None),
1686-
(10, Some(6), Some(10), None),
1687-
Some(Precision::Inexact(20)),
1692+
(10, Inexact(8), Inexact(10), Absent, Absent),
1693+
(10, Inexact(6), Inexact(10), Absent, Absent),
1694+
Some(Inexact(20)),
16881695
),
16891696
// range(left) > len(left), range(right) > len(right)
16901697
(
1691-
(10, Some(1), Some(15), None),
1692-
(20, Some(1), Some(40), None),
1693-
Some(Precision::Inexact(10)),
1698+
(10, Inexact(1), Inexact(15), Absent, Absent),
1699+
(20, Inexact(1), Inexact(40), Absent, Absent),
1700+
Some(Inexact(10)),
16941701
),
16951702
// When we have distinct count.
16961703
(
1697-
(10, Some(1), Some(10), Some(10)),
1698-
(10, Some(1), Some(10), Some(10)),
1699-
Some(Precision::Inexact(10)),
1704+
(10, Inexact(1), Inexact(10), Inexact(10), Absent),
1705+
(10, Inexact(1), Inexact(10), Inexact(10), Absent),
1706+
Some(Inexact(10)),
17001707
),
17011708
// distinct(left) > distinct(right)
17021709
(
1703-
(10, Some(1), Some(10), Some(5)),
1704-
(10, Some(1), Some(10), Some(2)),
1705-
Some(Precision::Inexact(20)),
1710+
(10, Inexact(1), Inexact(10), Inexact(5), Absent),
1711+
(10, Inexact(1), Inexact(10), Inexact(2), Absent),
1712+
Some(Inexact(20)),
17061713
),
17071714
// distinct(right) > distinct(left)
17081715
(
1709-
(10, Some(1), Some(10), Some(2)),
1710-
(10, Some(1), Some(10), Some(5)),
1711-
Some(Precision::Inexact(20)),
1716+
(10, Inexact(1), Inexact(10), Inexact(2), Absent),
1717+
(10, Inexact(1), Inexact(10), Inexact(5), Absent),
1718+
Some(Inexact(20)),
17121719
),
17131720
// min(left) < 0 (range(left) > range(right))
17141721
(
1715-
(10, Some(-5), Some(5), None),
1716-
(10, Some(1), Some(5), None),
1717-
Some(Precision::Inexact(10)),
1722+
(10, Inexact(-5), Inexact(5), Absent, Absent),
1723+
(10, Inexact(1), Inexact(5), Absent, Absent),
1724+
Some(Inexact(10)),
17181725
),
17191726
// min(right) < 0, max(right) < 0 (range(right) > range(left))
17201727
(
1721-
(10, Some(-25), Some(-20), None),
1722-
(10, Some(-25), Some(-15), None),
1723-
Some(Precision::Inexact(10)),
1728+
(10, Inexact(-25), Inexact(-20), Absent, Absent),
1729+
(10, Inexact(-25), Inexact(-15), Absent, Absent),
1730+
Some(Inexact(10)),
17241731
),
17251732
// range(left) < 0, range(right) >= 0
17261733
// (there isn't a case where both left and right ranges are negative
17271734
// so one of them is always going to work, this just proves negative
17281735
// ranges with bigger absolute values are not accidentally used).
17291736
(
1730-
(10, Some(-10), Some(0), None),
1731-
(10, Some(0), Some(10), Some(5)),
1732-
Some(Precision::Inexact(10)),
1737+
(10, Inexact(-10), Inexact(0), Absent, Absent),
1738+
(10, Inexact(0), Inexact(10), Inexact(5), Absent),
1739+
Some(Inexact(10)),
17331740
),
17341741
// range(left) = 1, range(right) = 1
17351742
(
1736-
(10, Some(1), Some(1), None),
1737-
(10, Some(1), Some(1), None),
1738-
Some(Precision::Inexact(100)),
1743+
(10, Inexact(1), Inexact(1), Absent, Absent),
1744+
(10, Inexact(1), Inexact(1), Absent, Absent),
1745+
Some(Inexact(100)),
17391746
),
17401747
//
17411748
// Edge cases
17421749
// ==========
17431750
//
17441751
// No column level stats.
1745-
((10, None, None, None), (10, None, None, None), None),
1752+
(
1753+
(10, Absent, Absent, Absent, Absent),
1754+
(10, Absent, Absent, Absent, Absent),
1755+
None,
1756+
),
17461757
// No min or max (or both).
1747-
((10, None, None, Some(3)), (10, None, None, Some(3)), None),
17481758
(
1749-
(10, Some(2), None, Some(3)),
1750-
(10, None, Some(5), Some(3)),
1759+
(10, Absent, Absent, Inexact(3), Absent),
1760+
(10, Absent, Absent, Inexact(3), Absent),
1761+
None,
1762+
),
1763+
(
1764+
(10, Inexact(2), Absent, Inexact(3), Absent),
1765+
(10, Absent, Inexact(5), Inexact(3), Absent),
17511766
None,
17521767
),
17531768
(
1754-
(10, None, Some(3), Some(3)),
1755-
(10, Some(1), None, Some(3)),
1769+
(10, Absent, Inexact(3), Inexact(3), Absent),
1770+
(10, Inexact(1), Absent, Inexact(3), Absent),
1771+
None,
1772+
),
1773+
(
1774+
(10, Absent, Inexact(3), Absent, Absent),
1775+
(10, Inexact(1), Absent, Absent, Absent),
17561776
None,
17571777
),
1758-
((10, None, Some(3), None), (10, Some(1), None, None), None),
17591778
// Non overlapping min/max (when exact=False).
17601779
(
1761-
(10, Some(0), Some(10), None),
1762-
(10, Some(11), Some(20), None),
1763-
Some(Precision::Inexact(0)),
1780+
(10, Inexact(0), Inexact(10), Absent, Absent),
1781+
(10, Inexact(11), Inexact(20), Absent, Absent),
1782+
Some(Inexact(0)),
17641783
),
17651784
(
1766-
(10, Some(11), Some(20), None),
1767-
(10, Some(0), Some(10), None),
1768-
Some(Precision::Inexact(0)),
1785+
(10, Inexact(11), Inexact(20), Absent, Absent),
1786+
(10, Inexact(0), Inexact(10), Absent, Absent),
1787+
Some(Inexact(0)),
17691788
),
17701789
// distinct(left) = 0, distinct(right) = 0
17711790
(
1772-
(10, Some(1), Some(10), Some(0)),
1773-
(10, Some(1), Some(10), Some(0)),
1791+
(10, Inexact(1), Inexact(10), Inexact(0), Absent),
1792+
(10, Inexact(1), Inexact(10), Inexact(0), Absent),
17741793
None,
17751794
),
1795+
// Inexact row count < exact null count with absent distinct count
1796+
(
1797+
(0, Inexact(1), Inexact(10), Absent, Exact(5)),
1798+
(10, Inexact(1), Inexact(10), Absent, Absent),
1799+
Some(Inexact(0)),
1800+
),
17761801
];
17771802

17781803
for (left_info, right_info, expected_cardinality) in cases {
17791804
let left_num_rows = left_info.0;
1780-
let left_col_stats =
1781-
vec![create_column_stats(left_info.1, left_info.2, left_info.3)];
1805+
let left_col_stats = vec![create_column_stats(
1806+
left_info.1,
1807+
left_info.2,
1808+
left_info.3,
1809+
left_info.4,
1810+
)];
17821811

17831812
let right_num_rows = right_info.0;
17841813
let right_col_stats = vec![create_column_stats(
17851814
right_info.1,
17861815
right_info.2,
17871816
right_info.3,
1817+
right_info.4,
17881818
)];
17891819

17901820
assert_eq!(
17911821
estimate_inner_join_cardinality(
17921822
Statistics {
1793-
num_rows: Precision::Inexact(left_num_rows),
1794-
total_byte_size: Precision::Absent,
1823+
num_rows: Inexact(left_num_rows),
1824+
total_byte_size: Absent,
17951825
column_statistics: left_col_stats.clone(),
17961826
},
17971827
Statistics {
1798-
num_rows: Precision::Inexact(right_num_rows),
1799-
total_byte_size: Precision::Absent,
1828+
num_rows: Inexact(right_num_rows),
1829+
total_byte_size: Absent,
18001830
column_statistics: right_col_stats.clone(),
18011831
},
18021832
),
@@ -1814,9 +1844,7 @@ mod tests {
18141844
);
18151845

18161846
assert_eq!(
1817-
partial_join_stats
1818-
.clone()
1819-
.map(|s| Precision::Inexact(s.num_rows)),
1847+
partial_join_stats.clone().map(|s| Inexact(s.num_rows)),
18201848
expected_cardinality.clone()
18211849
);
18221850
assert_eq!(
@@ -1832,13 +1860,13 @@ mod tests {
18321860
#[test]
18331861
fn test_inner_join_cardinality_multiple_column() -> Result<()> {
18341862
let left_col_stats = vec![
1835-
create_column_stats(Some(0), Some(100), Some(100)),
1836-
create_column_stats(Some(100), Some(500), Some(150)),
1863+
create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
1864+
create_column_stats(Inexact(100), Inexact(500), Inexact(150), Absent),
18371865
];
18381866

18391867
let right_col_stats = vec![
1840-
create_column_stats(Some(0), Some(100), Some(50)),
1841-
create_column_stats(Some(100), Some(500), Some(200)),
1868+
create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
1869+
create_column_stats(Inexact(100), Inexact(500), Inexact(200), Absent),
18421870
];
18431871

18441872
// We have statistics about 4 columns, where the highest distinct
@@ -1916,15 +1944,15 @@ mod tests {
19161944
];
19171945

19181946
let left_col_stats = vec![
1919-
create_column_stats(Some(0), Some(100), Some(100)),
1920-
create_column_stats(Some(0), Some(500), Some(500)),
1921-
create_column_stats(Some(1000), Some(10000), None),
1947+
create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
1948+
create_column_stats(Inexact(0), Inexact(500), Inexact(500), Absent),
1949+
create_column_stats(Inexact(1000), Inexact(10000), Absent, Absent),
19221950
];
19231951

19241952
let right_col_stats = vec![
1925-
create_column_stats(Some(0), Some(100), Some(50)),
1926-
create_column_stats(Some(0), Some(2000), Some(2500)),
1927-
create_column_stats(Some(0), Some(100), None),
1953+
create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
1954+
create_column_stats(Inexact(0), Inexact(2000), Inexact(2500), Absent),
1955+
create_column_stats(Inexact(0), Inexact(100), Absent, Absent),
19281956
];
19291957

19301958
for (join_type, expected_num_rows) in cases {
@@ -1965,15 +1993,15 @@ mod tests {
19651993
// Join on a=c, x=y (ignores b/d) where x and y do not intersect
19661994

19671995
let left_col_stats = vec![
1968-
create_column_stats(Some(0), Some(100), Some(100)),
1969-
create_column_stats(Some(0), Some(500), Some(500)),
1970-
create_column_stats(Some(1000), Some(10000), None),
1996+
create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
1997+
create_column_stats(Inexact(0), Inexact(500), Inexact(500), Absent),
1998+
create_column_stats(Inexact(1000), Inexact(10000), Absent, Absent),
19711999
];
19722000

19732001
let right_col_stats = vec![
1974-
create_column_stats(Some(0), Some(100), Some(50)),
1975-
create_column_stats(Some(0), Some(2000), Some(2500)),
1976-
create_column_stats(Some(0), Some(100), None),
2002+
create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
2003+
create_column_stats(Inexact(0), Inexact(2000), Inexact(2500), Absent),
2004+
create_column_stats(Inexact(0), Inexact(100), Absent, Absent),
19772005
];
19782006

19792007
let join_on = vec![

0 commit comments

Comments
 (0)