Commit c39f769
chore: reformat black and fix build
committed Oct 25, 2022 · 1 parent 5bd8f9a

21 files changed (+994, -881 lines)
 

.github/workflows/publish.yml

Lines changed: 16 additions & 12 deletions

@@ -9,40 +9,44 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [3.7, 3.8, 3.9]
+        python-version: [3.8]
     steps:
     - uses: actions/checkout@v2
     - name: Set up Python ${{ matrix.python-version }}
       uses: actions/setup-python@v2
       with:
         python-version: ${{ matrix.python-version }}
 
+    - name: Install poetry
+      run: |
+        python -m pip install --upgrade pip
+        curl -sSL https://install.python-poetry.org | python - --version 1.2.2
+        echo "${HOME}/.local/bin" >> $GITHUB_PATH
+
     - name: Install dependencies
       run: |
-        curl -sSL https://install.python-poetry.org | python - --version 1.2.1
-        $HOME/.local/bin/poetry install --no-root
+        poetry install --no-root
 
     - name: Run tests
       run: |
-        $HOME/.local/bin/poetry run pytest
+        poetry run pytest
 
     - name: Build wheels
       run: |
-        $HOME/.local/bin/poetry version $(git tag --points-at HEAD)
-        $HOME/.local/bin/poetry build
+        poetry version $(git tag --points-at HEAD)
+        poetry build
 
     - name: Test install package
       run: |
-        mkdir test_install
-        cd test_install
-        $HOME/.local/bin/poetry init
-        $HOME/.local/bin/poetry add ../dist/$(ls dist/*.whl)
+        poetry new test-install
+        cd test-install
+        poetry add ../dist/$(ls ../dist/*.whl)
 
-        $HOME/.local/bin/poetry run python -c "import datastream"
+        poetry run python -c "import datastream"
 
     - name: Upload
       env:
         USERNAME: __token__
         PASSWORD: ${{ secrets.PYPI_TOKEN }}
       run: |
-        $HOME/.local/bin/poetry publish --username=$USERNAME --password=$PASSWORD
+        poetry publish --username=$USERNAME --password=$PASSWORD

.github/workflows/test.yml

Lines changed: 9 additions & 4 deletions

@@ -24,18 +24,23 @@ jobs:
           ${{ runner.os }}-pip-
           ${{ runner.os }}-
 
+    - name: Install poetry
+      run: |
+        python -m pip install --upgrade pip
+        curl -sSL https://install.python-poetry.org | python - --version 1.2.2
+        echo "${HOME}/.local/bin" >> $GITHUB_PATH
+
     - name: Install dependencies
       run: |
-        curl -sSL https://install.python-poetry.org | python - --version 1.2.1
-        $HOME/.local/bin/poetry install install
+        poetry install
 
     - name: Run tests
       run: |
-        $HOME/.local/bin/poetry install run pytest
+        poetry run pytest
 
     - name: Build wheels
       run: |
-        $HOME/.local/bin/poetry install build
+        poetry build
 
   build-docs:
     runs-on: ubuntu-latest

README.rst

Lines changed: 0 additions & 6 deletions

@@ -110,9 +110,3 @@ Install from source
 ===================
 
 .. pip install -e .
-
-To patch the code locally for `Python 3.6` run `patch-python3.6.sh`.
-
-.. code-block:: bash
-
-    $ ./patch-python3.6.sh

datastream/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -2,7 +2,8 @@
 from datastream.datastream import Datastream
 
 from pkg_resources import get_distribution, DistributionNotFound
+
 try:
-    __version__ = get_distribution('pytorch-datastream').version
+    __version__ = get_distribution("pytorch-datastream").version
 except DistributionNotFound:
     pass

datastream/dataset.py

Lines changed: 239 additions & 218 deletions
Large diffs are not rendered by default.

datastream/datastream.py

Lines changed: 167 additions & 174 deletions
Large diffs are not rendered by default.

datastream/samplers/merge_sampler.py

Lines changed: 18 additions & 25 deletions

@@ -28,9 +28,7 @@ def __init__(self, samplers, datasets, ns):
             ns=ns,
             length=MergeSampler.merged_samplers_length(samplers, ns),
             from_mapping=Dataset.create_from_concat_mapping(datasets),
-            merged_samplers=MergeSampler.merge_samplers(
-                samplers, datasets, ns
-            ),
+            merged_samplers=MergeSampler.merge_samplers(samplers, datasets, ns),
         )
 
     def __len__(self):

@@ -41,10 +39,7 @@ def __iter__(self):
 
     @staticmethod
     def merged_samplers_length(samplers, ns):
-        return (
-            min([len(sampler) / n for sampler, n in zip(samplers, ns)])
-            * sum(ns)
-        )
+        return min([len(sampler) / n for sampler, n in zip(samplers, ns)]) * sum(ns)
 
     @staticmethod
     def merge_samplers(samplers, datasets, ns):

@@ -54,13 +49,18 @@ def batch(iterable, n):
             while True:
                 yield [next(iterable) for _ in range(n)]
 
-        index_batch = zip(*[
-            batch(map(
-                partial(to_mapping, dataset_index),
-                repeat_map_chain(iter, sampler),
-            ), n)
-            for dataset_index, (sampler, n) in enumerate(zip(samplers, ns))
-        ])
+        index_batch = zip(
+            *[
+                batch(
+                    map(
+                        partial(to_mapping, dataset_index),
+                        repeat_map_chain(iter, sampler),
+                    ),
+                    n,
+                )
+                for dataset_index, (sampler, n) in enumerate(zip(samplers, ns))
+            ]
+        )
 
         return chain.from_iterable(chain.from_iterable(index_batch))
 

@@ -74,25 +74,18 @@ def update_weights_(self, function):
 
     def update_example_weight_(self, weight, index):
         dataset_index, inner_index = self.from_mapping(index)
-        self.samplers[dataset_index].update_example_weight_(
-            weight, inner_index
-        )
+        self.samplers[dataset_index].update_example_weight_(weight, inner_index)
 
     def sample_proportion(self, proportion):
         return MergeSampler(
-            [
-                sampler.sample_proportion(proportion)
-                for sampler in self.samplers
-            ],
+            [sampler.sample_proportion(proportion) for sampler in self.samplers],
             self.datasets,
             self.ns,
         )
 
     def state_dict(self):
-        return dict(
-            samplers=[sampler.state_dict() for sampler in self.samplers]
-        )
+        return dict(samplers=[sampler.state_dict() for sampler in self.samplers])
 
     def load_state_dict(self, state_dict):
-        for sampler, state_dict in zip(self.samplers, state_dict['samplers']):
+        for sampler, state_dict in zip(self.samplers, state_dict["samplers"]):
             sampler.load_state_dict(state_dict)
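
A minimal sketch of the round-robin pattern behind merge_samplers — batch, zip, and a double chain.from_iterable — using plain lists in place of samplers and omitting the concat-index mapping (the names below are illustrative, not library code):

    from itertools import chain, islice

    def batch(iterable, n):
        # group an infinite iterator into lists of size n
        while True:
            yield [next(iterable) for _ in range(n)]

    def repeat(values):
        # cycle forever, like repeat_map_chain(iter, sampler)
        while True:
            yield from values

    samplers = [["a1", "a2"], ["b1"]]
    ns = [2, 1]  # draw 2 from the first sampler for every 1 from the second

    index_batch = zip(*[batch(repeat(s), n) for s, n in zip(samplers, ns)])
    merged = chain.from_iterable(chain.from_iterable(index_batch))
    print(list(islice(merged, 6)))  # ['a1', 'a2', 'b1', 'a1', 'a2', 'b1']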

datastream/samplers/multi_sampler.py

Lines changed: 12 additions & 17 deletions

@@ -28,7 +28,7 @@ def __init__(self, samplers, dataset):
             merged_samplers=MultiSampler.merge_samplers(
                 samplers,
                 [1 for _ in samplers],
-            )
+            ),
         )
 
     @staticmethod

@@ -50,10 +50,12 @@ def batch(iterable, n):
             while True:
                 yield [next(iterable) for _ in range(n)]
 
-        index_batch = zip(*[
-            batch(repeat_map_chain(iter, sampler), n)
-            for sampler, n in zip(samplers, ns)
-        ])
+        index_batch = zip(
+            *[
+                batch(repeat_map_chain(iter, sampler), n)
+                for sampler, n in zip(samplers, ns)
+            ]
+        )
 
         return chain.from_iterable(chain.from_iterable(index_batch))
 

@@ -66,24 +68,17 @@ def update_weights_(self, function):
 
     def update_example_weight_(self, weights, index):
         for sampler, weight in zip(self.samplers, weights):
-            sampler.update_example_weight_(
-                weight, index
-            )
+            sampler.update_example_weight_(weight, index)
 
     def sample_proportion(self, proportion):
         return MultiSampler(
-            [
-                sampler.sample_proportion(proportion)
-                for sampler in self.samplers
-            ],
-            self.dataset
+            [sampler.sample_proportion(proportion) for sampler in self.samplers],
+            self.dataset,
         )
 
     def state_dict(self):
-        return dict(
-            samplers=[sampler.state_dict() for sampler in self.samplers]
-        )
+        return dict(samplers=[sampler.state_dict() for sampler in self.samplers])
 
     def load_state_dict(self, state_dict):
-        for sampler, state_dict in zip(self.samplers, state_dict['samplers']):
+        for sampler, state_dict in zip(self.samplers, state_dict["samplers"]):
             sampler.load_state_dict(state_dict)

datastream/samplers/repeat_sampler.py

Lines changed: 3 additions & 3 deletions

@@ -14,16 +14,16 @@ class Config:
         arbitrary_types_allowed = True
 
     def __init__(self, sampler, length, epoch_bound=False):
-        '''
+        """
         Wrapper that repeats and limits length of sampling based on
         epoch length and batch size
-        '''
+        """
         BaseModel.__init__(
             self,
             sampler=sampler,
             length=length,
             epoch_bound=epoch_bound,
-            queue=iter(sampler)
+            queue=iter(sampler),
         )
 
     def __iter__(self):
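
The __iter__ body is not part of this diff; the behaviour the docstring describes — drawing from a queue and restarting the wrapped sampler on exhaustion, up to a fixed length — might be sketched like this (an illustration of the idea, assuming a non-empty sampler, not the class's actual code):

    def repeat_limited(sampler, length):
        # yield `length` indices, restarting the wrapped sampler as needed
        queue = iter(sampler)
        for _ in range(length):
            try:
                yield next(queue)
            except StopIteration:
                queue = iter(sampler)
                yield next(queue)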

datastream/samplers/sequential_sampler.py

Lines changed: 2 additions & 6 deletions

@@ -12,8 +12,7 @@ class Config:
 
     def __init__(self, length):
         BaseModel.__init__(
-            self,
-            sampler=torch.utils.data.SequentialSampler(torch.ones(length))
+            self, sampler=torch.utils.data.SequentialSampler(torch.ones(length))
         )
 
     def __len__(self):

@@ -23,7 +22,4 @@ def __iter__(self):
         return iter(self.sampler)
 
     def sample_proportion(self, proportion):
-        return SequentialSampler(min(
-            len(self),
-            int(len(self) * proportion)
-        ))
+        return SequentialSampler(min(len(self), int(len(self) * proportion)))
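
For context, the wrapped torch sampler simply iterates indices in order, so sample_proportion only shortens the range; a quick standalone check against the torch API:

    import torch

    sampler = torch.utils.data.SequentialSampler(torch.ones(4))
    print(list(sampler))  # [0, 1, 2, 3]
    # sample_proportion(0.5) on a length-4 sampler keeps min(4, int(4 * 0.5)) = 2 indices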

datastream/samplers/standard_sampler.py

Lines changed: 3 additions & 3 deletions

@@ -21,7 +21,7 @@ def __init__(self, length, proportion=1.0, replacement=False):
                 torch.ones(length).double(),
                 num_samples=int(max(1, min(length, length * proportion))),
                 replacement=replacement,
-            )
+            ),
         )
 
     def __len__(self):

@@ -41,7 +41,7 @@ def update_weights_(self, function):
         self.sampler.weights[:] = function(self.sampler.weights)
 
     def update_example_weight_(self, weight, index):
-        if hasattr(weight, 'item'):
+        if hasattr(weight, "item"):
             weight = weight.item()
 
         self.sampler.weights[index] = weight

@@ -59,4 +59,4 @@ def state_dict(self):
         return dict(weights=self.sampler.weights)
 
     def load_state_dict(self, state_dict):
-        self.sampler.weights[:] = state_dict['weights']
+        self.sampler.weights[:] = state_dict["weights"]
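
The in-place update that update_example_weight_ performs can be reproduced directly on torch's WeightedRandomSampler, which stores its weights as a mutable double tensor (a standalone sketch using only the torch API):

    import torch

    sampler = torch.utils.data.WeightedRandomSampler(
        torch.ones(5).double(), num_samples=5, replacement=False
    )
    sampler.weights[2] = 10.0  # upweight example 2 in place
    print(sampler.weights)  # tensor([ 1.,  1., 10.,  1.,  1.], dtype=torch.float64)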

datastream/samplers/zip_sampler.py

Lines changed: 8 additions & 17 deletions

@@ -50,9 +50,7 @@ def zip_samplers(samplers, datasets):
     def weight(self, index):
         return [
             sampler.weight(inner_index)
-            for sampler, inner_index in zip(
-                self.samplers, self.from_mapping(index)
-            )
+            for sampler, inner_index in zip(self.samplers, self.from_mapping(index))
         ]
 
     def update_weights_(self, function):

@@ -61,24 +59,17 @@ def update_weights_(self, function):
 
     def update_example_weight_(self, weights, index):
         inner_indices = self.from_mapping(index)
-        for sampler, weight, inner_index in zip(
-            self.samplers, weights, inner_indices
-        ):
-            sampler.update_example_weight_(
-                weight, inner_index
-            )
+        for sampler, weight, inner_index in zip(self.samplers, weights, inner_indices):
+            sampler.update_example_weight_(weight, inner_index)
 
     def sample_proportion(self, proportion):
-        return ZipSampler([
-            sampler.sample_proportion(proportion)
-            for sampler in self.samplers
-        ])
+        return ZipSampler(
+            [sampler.sample_proportion(proportion) for sampler in self.samplers]
+        )
 
     def state_dict(self):
-        return dict(
-            samplers=[sampler.state_dict() for sampler in self.samplers]
-        )
+        return dict(samplers=[sampler.state_dict() for sampler in self.samplers])
 
     def load_state_dict(self, state_dict):
-        for sampler, state_dict in zip(self.samplers, state_dict['samplers']):
+        for sampler, state_dict in zip(self.samplers, state_dict["samplers"]):
             sampler.load_state_dict(state_dict)

datastream/tools/numpy_seed.py

Lines changed: 10 additions & 10 deletions

@@ -3,7 +3,8 @@
 
 
 def numpy_seed(seed):
-    '''Function decorator that sets a temporary numpy seed during execution'''
+    """Function decorator that sets a temporary numpy seed during execution"""
+
     def decorator(fn):
         @wraps(fn)
         def seeded_function(*args, **kwargs):

@@ -12,25 +13,24 @@ def seeded_function(*args, **kwargs):
             output = fn(*args, **kwargs)
             np.random.set_state(random_state)
             return output
+
         return seeded_function
+
     return decorator
 
 
 def test_numpy_seed():
-
     def get_random_uniform(min, max):
         return np.random.random() * (max - min) + min
 
     random_state = np.random.get_state()
     numpy_seed(1)(get_random_uniform)(-1, 1)
     assert np.all(random_state[1] == np.random.get_state()[1])
 
-    assert (
-        numpy_seed(1)(get_random_uniform)(-1, 1) ==
-        numpy_seed(1)(get_random_uniform)(-1, 1)
-    )
+    assert numpy_seed(1)(get_random_uniform)(-1, 1) == numpy_seed(1)(
+        get_random_uniform
+    )(-1, 1)
 
-    assert (
-        numpy_seed(1)(get_random_uniform)(-1, 1) !=
-        numpy_seed(None)(get_random_uniform)(-1, 1)
-    )
+    assert numpy_seed(1)(get_random_uniform)(-1, 1) != numpy_seed(None)(
+        get_random_uniform
+    )(-1, 1)
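
As the test shows, calls are deterministic for a fixed seed while the surrounding global random state is restored afterwards; typical decorator usage looks roughly like this (import path assumed from the file location):

    import numpy as np
    from datastream.tools.numpy_seed import numpy_seed  # path assumed

    @numpy_seed(1)
    def draw():
        return np.random.random()

    assert draw() == draw()  # each call runs under the same temporary seed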

datastream/tools/split_dataframes.py

Lines changed: 51 additions & 47 deletions

@@ -14,45 +14,50 @@ def split_dataframes(
     filepath: Optional[Path] = None,
     frozen: Optional[bool] = False,
 ):
-    '''
+    """
     Split and save result. Add new examples and continue from the old split.
 
     As new examples come in it can handle:
     - Changing test size
     - Adapt after removing examples from dataset
     - Adapt to new stratification
-    '''
+    """
     if abs(sum(proportions.values()) - 1.0) >= 1e-5:
-        raise ValueError(' '.join([
-            'Expected sum of proportions to be 1.',
-            f'Proportions were {tuple(proportions.values())}',
-        ]))
+        raise ValueError(
+            " ".join(
+                [
+                    "Expected sum of proportions to be 1.",
+                    f"Proportions were {tuple(proportions.values())}",
+                ]
+            )
+        )
 
     if filepath is not None and filepath.exists():
         split = json.loads(filepath.read_text())
 
         if set(proportions.keys()) != set(split.keys()):
-            raise ValueError(' '.join([
-                'Expected split names in split file to be the same as the',
-                'keys in proportions',
-            ]))
+            raise ValueError(
+                " ".join(
+                    [
+                        "Expected split names in split file to be the same as the",
+                        "keys in proportions",
+                    ]
+                )
+            )
     else:
-        split = {
-            split_name: list()
-            for split_name in proportions.keys()
-        }
+        split = {split_name: list() for split_name in proportions.keys()}
 
     key_dataframe = pd.DataFrame({key_column: np.sort(dataframe[key_column].unique())})
 
     if frozen:
         if sum(map(len, split.values())) == 0:
-            raise ValueError('Frozen split is empty')
+            raise ValueError("Frozen split is empty")
         n_unassigned = (~key_dataframe[key_column].isin(sum(split.values(), []))).sum()
         if n_unassigned > 0:
             warnings.warn(
                 (
-                    f'Found {n_unassigned} unassigned examples when splitting the dataset.'
-                    ' The split is frozen so they will will be discarded'
+                    f"Found {n_unassigned} unassigned examples when splitting the dataset."
+                    " The split is frozen so they will will be discarded"
                 ),
                 UserWarning,
             )

@@ -120,23 +125,23 @@ def n_target_split(keys, proportion):
 
 
 def selected(k, unassigned):
-    return np.random.choice(
-        unassigned, size=k, replace=False
-    ).tolist()
+    return np.random.choice(unassigned, size=k, replace=False).tolist()
 
 
 def mock_dataframe():
-    return pd.DataFrame(dict(
-        index=np.arange(100),
-        number=np.random.randn(100),
-    ))
+    return pd.DataFrame(
+        dict(
+            index=np.arange(100),
+            number=np.random.randn(100),
+        )
+    )
 
 
 def test_standard():
-    split_file = Path('test_standard.json')
+    split_file = Path("test_standard.json")
     split_dataframes_ = split_dataframes(
         mock_dataframe(),
-        key_column='index',
+        key_column="index",
         proportions=dict(
             gradient=0.8,
             early_stopping=0.1,

@@ -151,18 +156,17 @@ def test_standard():
 
 
 def test_group_split_dataframe():
-    dataframe = mock_dataframe().assign(group=lambda df: df['index'] // 4)
+    dataframe = mock_dataframe().assign(group=lambda df: df["index"] // 4)
     split_dataframes_ = split_dataframes(
         dataframe,
-        key_column='group',
+        key_column="group",
         proportions=dict(
             train=0.8,
             compare=0.2,
         ),
     )
-    group_overlap = (
-        set(split_dataframes_['train'].group)
-        .intersection(split_dataframes_['compare'].group)
+    group_overlap = set(split_dataframes_["train"].group).intersection(
+        split_dataframes_["compare"].group
     )
     assert len(group_overlap) == 0
     assert tuple(map(len, split_dataframes_.values())) == (80, 20)

@@ -171,11 +175,11 @@ def test_group_split_dataframe():
 def test_validate_proportions():
     from pytest import raises
 
-    split_file = Path('test_validate_proportions.json')
+    split_file = Path("test_validate_proportions.json")
     with raises(ValueError):
         split_dataframes(
             mock_dataframe(),
-            key_column='index',
+            key_column="index",
             proportions=dict(train=0.4, test=0.4),
             filepath=split_file,
         )

@@ -184,11 +188,11 @@ def test_missing_key_column():
 def test_missing_key_column():
     from pytest import raises
 
-    split_file = Path('test_missing_key_column.json')
+    split_file = Path("test_missing_key_column.json")
     with raises(KeyError):
         split_dataframes(
             mock_dataframe(),
-            key_column='should_fail',
+            key_column="should_fail",
             proportions=dict(train=0.8, test=0.2),
             filepath=split_file,
         )

@@ -198,18 +202,18 @@ def test_missing_key_column():
 
 
 def test_no_split():
-    '''we do not need to support this'''
+    """we do not need to support this"""
     split_dataframes(
         mock_dataframe(),
-        key_column='index',
+        key_column="index",
         proportions=dict(all=1.0),
     )
 
 
 def test_split_empty():
     split_dataframes_ = split_dataframes(
         mock_dataframe().iloc[:0],
-        key_column='index',
+        key_column="index",
         proportions=dict(train=0.8, test=0.2),
     )
     for df in split_dataframes_.values():

@@ -219,28 +223,28 @@ def test_split_empty():
 def test_split_single_row():
     split_dataframes_ = split_dataframes(
         mock_dataframe().iloc[:1],
-        key_column='index',
+        key_column="index",
         proportions=dict(train=0.9999, test=0.0001),
     )
-    assert len(split_dataframes_['train']) == 1
-    assert len(split_dataframes_['test']) == 0
+    assert len(split_dataframes_["train"]) == 1
+    assert len(split_dataframes_["test"]) == 0
 
 
 def test_changed_split_names():
     from pytest import raises
 
-    split_file = Path('test_changed_split_names.json')
+    split_file = Path("test_changed_split_names.json")
     split_dataframes(
         mock_dataframe(),
-        key_column='index',
+        key_column="index",
         proportions=dict(train=0.8, test=0.2),
         filepath=split_file,
     )
 
     with raises(ValueError):
         split_dataframes(
             mock_dataframe(),
-            key_column='index',
+            key_column="index",
             proportions=dict(should_fail=0.8, test=0.2),
             filepath=split_file,
         )

@@ -255,15 +259,15 @@ def test_frozen():
     with raises(ValueError):
         split_dataframes(
             dataframe,
-            key_column='index',
+            key_column="index",
             proportions=dict(train=0.8, test=0.2),
             frozen=True,
         )
 
-    split_file = Path('test_frozen.json')
+    split_file = Path("test_frozen.json")
     split_dataframes(
         dataframe,
-        key_column='index',
+        key_column="index",
         proportions=dict(train=0.8, test=0.2),
         filepath=split_file,
     )

datastream/tools/star.py

Lines changed: 3 additions & 1 deletion

@@ -2,8 +2,10 @@
 
 
 def star(fn):
-    '''Wrap function to expand input to arguments'''
+    """Wrap function to expand input to arguments"""
+
     @wraps(fn)
     def wrapper(args):
         return fn(*args)
+
     return wrapper
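
For reference, star turns a function of several positional arguments into one that accepts a single tuple (import path assumed from the file location):

    from datastream.tools.star import star  # path assumed

    add = star(lambda a, b: a + b)
    assert add((1, 2)) == 3  # the tuple is expanded to positional arguments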

datastream/tools/starcompose.py

Lines changed: 13 additions & 11 deletions

@@ -1,11 +1,9 @@
-
-
 def starcompose(*transforms):
-    '''
+    """
     left compose functions together and expand tuples to args
 
     Use starcompose.debug for verbose output when debugging
-    '''
+    """
 
     # TODO: consider doing starcompose with inner function calls rather than
     # a loop

@@ -16,24 +14,28 @@ def _compose(*x):
         else:
             x = t(x)
         return x
+
     return _compose
 
 
 def starcompose_debug(*transforms):
-    '''
+    """
     verbose starcompose for debugging
-    '''
-    print('starcompose debug')
+    """
+    print("starcompose debug")
+
     def _compose(*x):
         for index, t in enumerate(transforms):
-            print(f'{index}:, fn={t}, x={x}')
+            print(f"{index}:, fn={t}, x={x}")
             if type(x) is tuple:
                 x = t(*x)
             else:
                 x = t(x)
         return x
+
     return _compose
 
+
 starcompose.debug = starcompose_debug
 

@@ -42,16 +44,16 @@ def test_starcompose():
 
     test = starcompose(lambda x, y: x + y)
     if test(3, 5) != 8:
-        raise Exception('Two args inputs failed')
+        raise Exception("Two args inputs failed")
 
     test = starcompose(lambda x: sum(x))
     if test((3, 5)) != 8:
-        raise Exception('Tuple input failed')
+        raise Exception("Tuple input failed")
 
     test = starcompose(
         lambda x: (x, x),
         lambda x, y: x + y,
         lambda x: x * 2,
     )
     if test(10) != 40:
-        raise Exception('Expanded tuple for inner function failed')
+        raise Exception("Expanded tuple for inner function failed")

datastream/tools/stratified_split.py

Lines changed: 4 additions & 9 deletions

@@ -13,19 +13,14 @@ def stratified_split(
     seed: Optional[int] = None,
     frozen: Optional[bool] = False,
 ):
-    if (
-        stratify_column is not None
-        and any(dataset.dataframe[key_column].duplicated())
-    ):
+    if stratify_column is not None and any(dataset.dataframe[key_column].duplicated()):
         # mathematically impossible in the general case
         warnings.warn(
-            'Trying to do stratified split with non-unique key column'
-            ' - cannot guarantee correct splitting of key values.'
+            "Trying to do stratified split with non-unique key column"
+            " - cannot guarantee correct splitting of key values."
         )
     strata = {
-        stratum_value: dataset.subset(
-            lambda df: df[stratify_column] == stratum_value
-        )
+        stratum_value: dataset.subset(lambda df: df[stratify_column] == stratum_value)
         for stratum_value in dataset.dataframe[stratify_column].unique()
     }
     split_strata = [

datastream/tools/verify_split.py

Lines changed: 11 additions & 6 deletions

@@ -5,7 +5,7 @@
 
 @validate_arguments
 def verify_split(old_path: Path, new_path: Path):
-    '''
+    """
     Verify that no keys from an old split are present in a different new split.
 
     .. highlight:: python

@@ -16,7 +16,7 @@ def verify_split(old_path: Path, new_path: Path):
         "path/to/new/split.json",
     )
 
-    '''
+    """
     for old_split_name, old_split in json.loads(old_path.read_text()).items():
         for new_split_name, new_split in json.loads(new_path.read_text()).items():
             if (

@@ -26,8 +26,13 @@ def verify_split(old_path: Path, new_path: Path):
                 raise ValueError(
                     f'Some keys from old split "{old_split_name}"'
                     f' are present in new split "{new_split_name}":\n'
-                    + str("\n".join(
-                        [str(old_split[index]) for index in range(min(10, len(old_split)))]
-                        + (["..."] if len(old_split) > 10 else [])
-                    ))
+                    + str(
+                        "\n".join(
+                            [
+                                str(old_split[index])
+                                for index in range(min(10, len(old_split)))
+                            ]
+                            + (["..."] if len(old_split) > 10 else [])
+                        )
+                    )
                 )
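
A small end-to-end sketch, assuming split files map split names to lists of keys (the format produced by split_dataframes); the file names here are illustrative:

    import json
    from pathlib import Path

    from datastream.tools.verify_split import verify_split  # path assumed

    Path("old_split.json").write_text(json.dumps(dict(train=[1, 2], test=[3])))
    Path("new_split.json").write_text(json.dumps(dict(train=[1, 2, 4], test=[3, 5])))

    verify_split("old_split.json", "new_split.json")  # passes: no key changed split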

docs/source/requirements.txt

Lines changed: 5 additions & 5 deletions

@@ -23,18 +23,18 @@ lazy-object-proxy==1.4.3
 MarkupSafe==1.1.1
 mccabe==0.6.1
 more-itertools==8.3.0
-numpy==1.18.5
+numpy==1.23.4
 packaging==20.4
 pandas==1.1.5
 pkginfo==1.5.0.1
 pluggy==0.13.1
-py==1.10.0
+py==1.11.0
 pycparser==2.20
 pydantic==1.8.2
 Pygments==2.7.4
 pylint==2.5.3
 pyparsing==2.4.7
-pyspark==3.0.3
+pyspark==3.3.0
 pytest==5.4.3
 python-dateutil==2.8.1
 pytz==2020.1

@@ -58,12 +58,12 @@ sphinxcontrib-jsmath==1.0.1
 sphinxcontrib-qthelp==1.0.3
 sphinxcontrib-serializinghtml==1.1.4
 toml==0.10.1
-torch==1.8.1
+torch==1.12.1
 tqdm==4.46.1
 twine==3.1.1
 typing-extensions==3.10.0.0
 urllib3==1.26.5
-waitress==1.4.4
+waitress==2.1.1
 wcwidth==0.2.4
 webencodings==0.5.1
 WebOb==1.8.6

poetry.lock

Lines changed: 413 additions & 302 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 5 additions & 4 deletions

@@ -2,10 +2,10 @@
 name = "pytorch-datastream"
 version = "0.0.0"
 description = "Simple dataset to dataloader library for pytorch"
-authors = ["Aiwizo"]
+authors = ["NextML"]
 license = "Apache-2.0"
 readme = "README.rst"
-repository = "https://github.com/Aiwizo/pytorch-datastream"
+repository = "https://github.com/nextml-code/pytorch-datastream"
 documentation = "https://pytorch-datastream.readthedocs.io"
 keywords = [
     "pytorch",

@@ -34,16 +34,17 @@ packages = [
 ]
 
 [tool.poetry.dependencies]
-python = "^3.7"
+python = "^3.8"
 torch = "^1.4.0"
 numpy = "^1.17.0"
 pandas = "^1.0.5"
 pydantic = "^1.5.0"
 
-[tool.poetry.dev-dependencies]
+[tool.poetry.group.dev.dependencies]
 pylint = "^2.6.0"
 flake8 = "^3.8.4"
 pytest = "^6.1.2"
+black = "^22.10.0"
 
 [build-system]
 requires = ["poetry-core>=1.0.0"]
