@@ -800,48 +800,51 @@ def test_split_filepath():
800
800
801
801
def test_update_stratified_split ():
802
802
803
- dataset = (
804
- Dataset .from_dataframe (pd .DataFrame (dict (
805
- index = np .arange (100 ),
806
- number = np .random .randn (100 ),
807
- stratify = np .random .randint (0 , 10 , 100 ),
808
- )))
809
- .map (tuple )
810
- )
803
+ for _ in range (5 ):
811
804
812
- filepath = Path ('tmp_test_split.json' )
805
+ dataset = (
806
+ Dataset .from_dataframe (pd .DataFrame (dict (
807
+ index = np .arange (100 ),
808
+ number = np .random .randn (100 ),
809
+ stratify1 = np .random .randint (0 , 10 , 100 ),
810
+ stratify2 = np .random .randint (0 , 10 , 100 ),
811
+ )))
812
+ .map (tuple )
813
+ )
813
814
814
- splits1 = (
815
- dataset
816
- .subset (lambda df : df ['index' ] < 50 )
817
- .split (
818
- key_column = 'index' ,
819
- proportions = dict (train = 0.8 , test = 0.2 ),
820
- filepath = filepath ,
821
- stratify_column = 'stratify' ,
815
+ filepath = Path ('tmp_test_split.json' )
816
+
817
+ splits1 = (
818
+ dataset
819
+ .subset (lambda df : df ['index' ] < 50 )
820
+ .split (
821
+ key_column = 'index' ,
822
+ proportions = dict (train = 0.8 , test = 0.2 ),
823
+ filepath = filepath ,
824
+ stratify_column = 'stratify1' ,
825
+ )
822
826
)
823
- )
824
827
825
- splits2 = (
826
- dataset
827
- .split (
828
- key_column = 'index' ,
829
- proportions = dict (train = 0.8 , test = 0.2 ),
830
- filepath = filepath ,
831
- stratify_column = 'stratify' ,
828
+ splits2 = (
829
+ dataset
830
+ .split (
831
+ key_column = 'index' ,
832
+ proportions = dict (train = 0.8 , test = 0.2 ),
833
+ filepath = filepath ,
834
+ stratify_column = 'stratify2' ,
835
+ )
832
836
)
833
- )
834
837
835
- assert (
836
- splits1 ['train' ].dataframe ['index' ]
837
- .isin (splits2 ['train' ].dataframe ['index' ])
838
- .all ()
839
- )
838
+ assert (
839
+ splits1 ['train' ].dataframe ['index' ]
840
+ .isin (splits2 ['train' ].dataframe ['index' ])
841
+ .all ()
842
+ )
840
843
841
- assert (
842
- splits1 ['test' ].dataframe ['index' ]
843
- .isin (splits2 ['test' ].dataframe ['index' ])
844
- .all ()
845
- )
844
+ assert (
845
+ splits1 ['test' ].dataframe ['index' ]
846
+ .isin (splits2 ['test' ].dataframe ['index' ])
847
+ .all ()
848
+ )
846
849
847
- filepath .unlink ()
850
+ filepath .unlink ()
0 commit comments