@@ -527,7 +527,7 @@ def test_sequential_sampler():
527
527
assert next (it ) == ['a' , 'b' , 'c' , 'a' , 'b' , 'c' ]
528
528
529
529
530
- def test_merge_concat ():
530
+ def test_concat_merge ():
531
531
dataset = Dataset .concat ([
532
532
Dataset .from_subscriptable ([1 , 2 ]),
533
533
Dataset .from_subscriptable ([1 , 3 , 5 ]),
@@ -536,8 +536,51 @@ def test_merge_concat():
536
536
datastream = Datastream .merge ([
537
537
Datastream (dataset ),
538
538
Datastream (dataset .subset (
539
- lambda df : df [ " index" ] <= 3
539
+ lambda df : [ index < 3 for index in range ( len ( df ))]
540
540
)),
541
541
])
542
542
543
- list (datastream )
543
+ assert len (dataset .subset (
544
+ lambda df : [index < 3 for index in range (len (df ))]
545
+ )) == 3
546
+
547
+ assert len (list (datastream )) == 6
548
+
549
+
550
+ def test_combine_concat_merge ():
551
+ dataset = Dataset .concat ([
552
+ Dataset .zip ([
553
+ Dataset .from_subscriptable ([1 ]),
554
+ Dataset .from_subscriptable ([2 ]),
555
+ ]),
556
+ Dataset .combine ([
557
+ Dataset .from_subscriptable ([3 , 3 ]),
558
+ Dataset .from_subscriptable ([4 , 4 , 4 ]),
559
+ ]),
560
+ ])
561
+
562
+ datastream = Datastream .merge ([
563
+ Datastream (dataset ),
564
+ Datastream (Dataset .zip ([
565
+ Dataset .from_subscriptable ([5 ]),
566
+ Dataset .from_subscriptable ([6 ]),
567
+ ])),
568
+ ])
569
+
570
+ assert len (list (datastream )) == 2
571
+
572
+
573
+ def test_last_batch ():
574
+ from datastream .samplers import SequentialSampler
575
+
576
+ datastream = Datastream (
577
+ Dataset .from_subscriptable (list ('abc' ))
578
+ )
579
+ assert list (map (len , datastream .data_loader (batch_size = 4 ))) == [3 ]
580
+ assert list (map (len , datastream .data_loader (batch_size = 4 , n_batches_per_epoch = 2 ))) == [4 , 4 ]
581
+
582
+ datastream = Datastream (
583
+ Dataset .from_subscriptable (list ('abc' )),
584
+ SequentialSampler (3 ),
585
+ )
586
+ assert list (map (len , datastream .data_loader (batch_size = 2 ))) == [2 , 1 ]
0 commit comments