@@ -867,7 +867,8 @@ def test_groupby_bins(chunk_labels, kwargs, chunks, engine, method) -> None:
867
867
if chunk_labels :
868
868
labels = dask .array .from_array (labels , chunks = chunks )
869
869
870
- with raise_if_dask_computes ():
870
+ max_computes = 1 if method == "cohorts" else 0
871
+ with raise_if_dask_computes (max_computes ):
871
872
actual , * groups = groupby_reduce (
872
873
array , labels , func = "count" , fill_value = 0 , engine = engine , method = method , ** kwargs
873
874
)
@@ -1072,7 +1073,9 @@ def test_cohorts_nd_by(by_is_dask, func, method, axis, engine):
1072
1073
):
1073
1074
pytest .skip ()
1074
1075
if axis is not None and method != "map-reduce" :
1075
- pytest .xfail ()
1076
+ pytest .skip ()
1077
+ if by_is_dask and method == "blockwise" :
1078
+ pytest .skip ()
1076
1079
1077
1080
o = dask .array .ones ((3 ,), chunks = - 1 )
1078
1081
o2 = dask .array .ones ((2 , 3 ), chunks = - 1 )
@@ -1092,6 +1095,9 @@ def test_cohorts_nd_by(by_is_dask, func, method, axis, engine):
1092
1095
fill_value = - 123
1093
1096
1094
1097
kwargs = dict (func = func , engine = engine , method = method , axis = axis , fill_value = fill_value )
1098
+ if by_is_dask and axis is not None and method == "map-reduce" :
1099
+ kwargs ["expected_groups" ] = pd .Index ([1 , 2 , 3 , 4 , 30 , 31 , 40 ])
1100
+
1095
1101
if "quantile" in func :
1096
1102
kwargs ["finalize_kwargs" ] = {"q" : DEFAULT_QUANTILE }
1097
1103
actual , groups = groupby_reduce (array , by , ** kwargs )
@@ -1102,6 +1108,7 @@ def test_cohorts_nd_by(by_is_dask, func, method, axis, engine):
1102
1108
if isinstance (by , dask .array .Array ):
1103
1109
cache .clear ()
1104
1110
actual_cohorts = find_group_cohorts (by , array .chunks [- by .ndim :])
1111
+ cache .clear ()
1105
1112
expected_cohorts = find_group_cohorts (by .compute (), array .chunks [- by .ndim :])
1106
1113
assert actual_cohorts == expected_cohorts
1107
1114
# assert cache.nbytes
0 commit comments