Skip to content

Commit 9471858

Browse files
author
FelixAbrahamsson
committed
feature: add sequentialsampler
1 parent 160a8cf commit 9471858

File tree

3 files changed

+41
-0
lines changed

3 files changed

+41
-0
lines changed

datastream/datastream.py

+13
Original file line numberDiff line numberDiff line change
@@ -505,3 +505,16 @@ def test_take():
505505
Datastream(Dataset.from_subscriptable(list('d'))),
506506
])
507507
assert len(list(datastream.take(2).data_loader(batch_size=1))) == 2
508+
509+
510+
def test_sequential_sampler():
511+
512+
from datastream.samplers import SequentialSampler
513+
514+
dataset = Dataset.from_subscriptable(list('abc'))
515+
datastream = Datastream(dataset, SequentialSampler(len(dataset))).take(2)
516+
assert len(list(datastream.data_loader(batch_size=1))) == 2
517+
518+
datastream = Datastream(dataset, SequentialSampler(len(dataset)))
519+
it = iter(datastream.data_loader(batch_size=6, n_batches_per_epoch=10))
520+
assert next(it) == ['a', 'b', 'c', 'a', 'b', 'c']

datastream/samplers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from datastream.samplers.standard_sampler import StandardSampler
2+
from datastream.samplers.sequential_sampler import SequentialSampler
23
from datastream.samplers.merge_sampler import MergeSampler
34
from datastream.samplers.multi_sampler import MultiSampler
45
from datastream.samplers.repeat_sampler import RepeatSampler
+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from __future__ import annotations
2+
from pydantic import BaseModel
3+
import torch
4+
5+
6+
class SequentialSampler(BaseModel, torch.utils.data.Sampler):
7+
sampler: torch.utils.data.SequentialSampler
8+
9+
class Config:
10+
arbitrary_types_allowed = True
11+
allow_mutation = False
12+
13+
def __init__(self, length):
14+
BaseModel.__init__(
15+
self,
16+
sampler=torch.utils.data.SequentialSampler(torch.ones(length))
17+
)
18+
19+
def __len__(self):
20+
return len(self.sampler)
21+
22+
def __iter__(self):
23+
return iter(self.sampler)
24+
25+
def sample_proportion(self, proportion):
26+
sampler = SequentialSampler(int(len(self) * proportion))
27+
return sampler

0 commit comments

Comments
 (0)