Skip to content

Commit 0930c14

Browse files
author
FelixAbrahamsson
committed
improve: continue split even if we cant guarantee correct stratification
1 parent 294c962 commit 0930c14

File tree

2 files changed

+41
-38
lines changed

2 files changed

+41
-38
lines changed

datastream/dataset.py

+40-37
Original file line numberDiff line numberDiff line change
@@ -800,48 +800,51 @@ def test_split_filepath():
800800

801801
def test_update_stratified_split():
802802

803-
dataset = (
804-
Dataset.from_dataframe(pd.DataFrame(dict(
805-
index=np.arange(100),
806-
number=np.random.randn(100),
807-
stratify=np.random.randint(0, 10, 100),
808-
)))
809-
.map(tuple)
810-
)
803+
for _ in range(5):
811804

812-
filepath = Path('tmp_test_split.json')
805+
dataset = (
806+
Dataset.from_dataframe(pd.DataFrame(dict(
807+
index=np.arange(100),
808+
number=np.random.randn(100),
809+
stratify1=np.random.randint(0, 10, 100),
810+
stratify2=np.random.randint(0, 10, 100),
811+
)))
812+
.map(tuple)
813+
)
813814

814-
splits1 = (
815-
dataset
816-
.subset(lambda df: df['index'] < 50)
817-
.split(
818-
key_column='index',
819-
proportions=dict(train=0.8, test=0.2),
820-
filepath=filepath,
821-
stratify_column='stratify',
815+
filepath = Path('tmp_test_split.json')
816+
817+
splits1 = (
818+
dataset
819+
.subset(lambda df: df['index'] < 50)
820+
.split(
821+
key_column='index',
822+
proportions=dict(train=0.8, test=0.2),
823+
filepath=filepath,
824+
stratify_column='stratify1',
825+
)
822826
)
823-
)
824827

825-
splits2 = (
826-
dataset
827-
.split(
828-
key_column='index',
829-
proportions=dict(train=0.8, test=0.2),
830-
filepath=filepath,
831-
stratify_column='stratify',
828+
splits2 = (
829+
dataset
830+
.split(
831+
key_column='index',
832+
proportions=dict(train=0.8, test=0.2),
833+
filepath=filepath,
834+
stratify_column='stratify2',
835+
)
832836
)
833-
)
834837

835-
assert (
836-
splits1['train'].dataframe['index']
837-
.isin(splits2['train'].dataframe['index'])
838-
.all()
839-
)
838+
assert (
839+
splits1['train'].dataframe['index']
840+
.isin(splits2['train'].dataframe['index'])
841+
.all()
842+
)
840843

841-
assert (
842-
splits1['test'].dataframe['index']
843-
.isin(splits2['test'].dataframe['index'])
844-
.all()
845-
)
844+
assert (
845+
splits1['test'].dataframe['index']
846+
.isin(splits2['test'].dataframe['index'])
847+
.all()
848+
)
846849

847-
filepath.unlink()
850+
filepath.unlink()

datastream/tools/split_dataframes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def split_proportion(
8989
else:
9090
split = previous_split
9191
split[split_name] += selected(
92-
n_target_split_ - n_previous_split,
92+
min(n_target_split_ - n_previous_split, len(unassigned_)),
9393
unassigned_,
9494
)
9595
return split

0 commit comments

Comments
 (0)