Skip to content

Commit 60b9c42

Browse files
Commit authored and committed
fix: dataset.concat reworked to allow operations like subset and merge afterwards
1 parent d201f79 commit 60b9c42

File tree

3 files changed

+64
-5
lines changed

3 files changed

+64
-5
lines changed

.flake8

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[flake8]
2+
max-line-length = 80
3+
select = C,E,F,W,B,B950
4+
extend-ignore = E203, E501

datastream/dataset.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
)
66
from pathlib import Path
77
from functools import lru_cache
8+
import string
9+
import random
810
import textwrap
911
import inspect
1012
import numpy as np
@@ -377,13 +379,23 @@ def get_item(dataframe, index):
377379
get_item=get_item,
378380
)
379381
else:
382+
dataset_column = (
383+
'__concat__'
384+
+ ''.join([random.choice(string.ascii_lowercase) for _ in range(8)])
385+
)
386+
387+
new_dataframe = pd.concat([dataset.dataframe for dataset in datasets])
388+
new_dataframe[dataset_column] = [
389+
from_concat_mapping(index)[0]
390+
for index in range(len(new_dataframe))
391+
]
380392

381393
def get_item(dataframe, index):
382-
dataset_index, _ = from_concat_mapping(index)
394+
dataset_index = int(dataframe.iloc[index][dataset_column])
383395
return datasets[dataset_index].get_item(dataframe, index)
384396

385397
return Dataset(
386-
dataframe=pd.concat([dataset.dataframe for dataset in datasets]),
398+
dataframe=new_dataframe,
387399
length=sum(map(len, datasets)),
388400
get_item=get_item,
389401
)

datastream/datastream.py

+46-3
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ def test_sequential_sampler():
527527
assert next(it) == ['a', 'b', 'c', 'a', 'b', 'c']
528528

529529

530-
def test_merge_concat():
530+
def test_concat_merge():
531531
dataset = Dataset.concat([
532532
Dataset.from_subscriptable([1, 2]),
533533
Dataset.from_subscriptable([1, 3, 5]),
@@ -536,8 +536,51 @@ def test_merge_concat():
536536
datastream = Datastream.merge([
537537
Datastream(dataset),
538538
Datastream(dataset.subset(
539-
lambda df: df["index"] <= 3
539+
lambda df: [index < 3 for index in range(len(df))]
540540
)),
541541
])
542542

543-
list(datastream)
543+
assert len(dataset.subset(
544+
lambda df: [index < 3 for index in range(len(df))]
545+
)) == 3
546+
547+
assert len(list(datastream)) == 6
548+
549+
550+
def test_combine_concat_merge():
551+
dataset = Dataset.concat([
552+
Dataset.zip([
553+
Dataset.from_subscriptable([1]),
554+
Dataset.from_subscriptable([2]),
555+
]),
556+
Dataset.combine([
557+
Dataset.from_subscriptable([3, 3]),
558+
Dataset.from_subscriptable([4, 4, 4]),
559+
]),
560+
])
561+
562+
datastream = Datastream.merge([
563+
Datastream(dataset),
564+
Datastream(Dataset.zip([
565+
Dataset.from_subscriptable([5]),
566+
Dataset.from_subscriptable([6]),
567+
])),
568+
])
569+
570+
assert len(list(datastream)) == 2
571+
572+
573+
def test_last_batch():
574+
from datastream.samplers import SequentialSampler
575+
576+
datastream = Datastream(
577+
Dataset.from_subscriptable(list('abc'))
578+
)
579+
assert list(map(len, datastream.data_loader(batch_size=4))) == [3]
580+
assert list(map(len, datastream.data_loader(batch_size=4, n_batches_per_epoch=2))) == [4, 4]
581+
582+
datastream = Datastream(
583+
Dataset.from_subscriptable(list('abc')),
584+
SequentialSampler(3),
585+
)
586+
assert list(map(len, datastream.data_loader(batch_size=2))) == [2, 1]

0 commit comments

Comments (0)