Skip to content

Commit 8f7d093

Browse files
committed
Fix first, last again
1 parent ebcd06c commit 8f7d093

File tree

1 file changed

+27
-7
lines changed

1 file changed

+27
-7
lines changed

flox/core.py

+27-7
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,9 @@ def _is_minmax_reduction(func: T_Agg) -> bool:
179179

180180

181181
def _is_first_last_reduction(func: T_Agg) -> bool:
182-
return isinstance(func, str) and func in ["nanfirst", "nanlast", "first", "last"]
182+
if isinstance(func, Aggregation):
183+
func = func.name
184+
return func in ["nanfirst", "nanlast", "first", "last"]
183185

184186

185187
def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
@@ -680,6 +682,7 @@ def rechunk_for_blockwise(
680682
abs(max(newchunks) - max(chunks)) / max(chunks) < BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD
681683
)
682684
):
685+
logger.debug("Rechunking to enable blockwise.")
683686
# Less than 25% change in number of chunks, let's do it
684687
return array.rechunk({axis: newchunks})
685688
else:
@@ -1668,7 +1671,12 @@ def dask_groupby_agg(
16681671
# This allows us to discover groups at compute time, support argreductions, lower intermediate
16691672
# memory usage (but method="cohorts" would also work to reduce memory in some cases)
16701673
labels_are_unknown = is_duck_dask_array(by_input) and expected_groups is None
1671-
do_simple_combine = not _is_arg_reduction(agg) and not labels_are_unknown
1674+
do_grouped_combine = (
1675+
_is_arg_reduction(agg)
1676+
or labels_are_unknown
1677+
or (_is_first_last_reduction(agg) and array.dtype.kind != "f")
1678+
)
1679+
do_simple_combine = not do_grouped_combine
16721680

16731681
if method == "blockwise":
16741682
# use the "non dask" code path, but applied blockwise
@@ -2012,8 +2020,13 @@ def _validate_reindex(
20122020
expected_groups,
20132021
any_by_dask: bool,
20142022
is_dask_array: bool,
2023+
array_dtype: Any,
20152024
) -> bool | None:
20162025
# logger.debug("Entering _validate_reindex: reindex is {}".format(reindex)) # noqa
2026+
def first_or_last():
2027+
return func in ["first", "last"] or (
2028+
_is_first_last_reduction(func) and array_dtype.kind != "f"
2029+
)
20172030

20182031
all_numpy = not is_dask_array and not any_by_dask
20192032
if reindex is True and not all_numpy:
@@ -2023,7 +2036,7 @@ def _validate_reindex(
20232036
raise ValueError(
20242037
"reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
20252038
)
2026-
if func in ["first", "last"]:
2039+
if first_or_last():
20272040
raise ValueError("reindex must be None or False when func is 'first' or 'last.")
20282041

20292042
if reindex is None:
@@ -2034,9 +2047,10 @@ def _validate_reindex(
20342047
if all_numpy:
20352048
return True
20362049

2037-
if func in ["first", "last"]:
2050+
if first_or_last():
20382051
# have to do the grouped_combine since there's no good fill_value
2039-
reindex = False
2052+
# Also needed for nanfirst, nanlast with no-NaN dtypes
2053+
return False
20402054

20412055
if method == "blockwise":
20422056
# for grouping by dask arrays, we set reindex=True
@@ -2439,7 +2453,13 @@ def groupby_reduce(
24392453
raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")
24402454

24412455
reindex = _validate_reindex(
2442-
reindex, func, method, expected_groups, any_by_dask, is_duck_dask_array(array)
2456+
reindex,
2457+
func,
2458+
method,
2459+
expected_groups,
2460+
any_by_dask,
2461+
is_duck_dask_array(array),
2462+
array.dtype,
24432463
)
24442464

24452465
if not is_duck_array(array):
@@ -2638,7 +2658,7 @@ def groupby_reduce(
26382658

26392659
# TODO: clean this up
26402660
reindex = _validate_reindex(
2641-
reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array)
2661+
reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array), array.dtype
26422662
)
26432663

26442664
if TYPE_CHECKING:

0 commit comments

Comments
 (0)