@@ -639,7 +639,9 @@ def rechunk_for_cohorts(
639
639
return array .rechunk ({axis : newchunks })
640
640
641
641
642
- def rechunk_for_blockwise (array : DaskArray , axis : T_Axis , labels : np .ndarray ) -> DaskArray :
642
+ def rechunk_for_blockwise (
643
+ array : DaskArray , axis : T_Axis , labels : np .ndarray , * , force : bool = True
644
+ ) -> DaskArray :
643
645
"""
644
646
Rechunks array so that group boundaries line up with chunk boundaries, allowing
645
647
embarrassingly parallel group reductions.
@@ -672,11 +674,16 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
672
674
return array
673
675
674
676
Δn = abs (len (newchunks ) - len (chunks ))
675
- if (Δn / len (chunks ) < BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD ) and (
676
- abs (max (newchunks ) - max (chunks )) / max (chunks ) < BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD
677
+ if force or (
678
+ (Δn / len (chunks ) < BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD )
679
+ and (
680
+ abs (max (newchunks ) - max (chunks )) / max (chunks ) < BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD
681
+ )
677
682
):
678
683
# Rechunk when forced, or when the change is cheap: fewer than ~25% more
# chunks and a similar max chunk size.
679
684
return array .rechunk ({axis : newchunks })
685
+ else :
686
+ return array
680
687
681
688
682
689
def reindex_ (
@@ -2496,7 +2503,7 @@ def groupby_reduce(
2496
2503
):
2497
2504
# Let's try rechunking for sorted 1D by.
2498
2505
(single_axis ,) = axis_
2499
- array = rechunk_for_blockwise (array , single_axis , by_ )
2506
+ array = rechunk_for_blockwise (array , single_axis , by_ , force = False )
2500
2507
2501
2508
if _is_first_last_reduction (func ):
2502
2509
if has_dask and nax != 1 :
0 commit comments