Skip to content

Commit 368641f

Browse files
authored
feat: multifile handling in pin_upload/pin_download (#319)
* first prototype of working pin_upload
* save single file uploads as Paths too
* handle pin_download as well
* change back connect api
* update tests
* add tests for upload/download
* return hashes in a list
1 parent e64874f commit 368641f

File tree

6 files changed

+122
-72
lines changed

6 files changed

+122
-72
lines changed

pins/boards.py

+52-22
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from .cache import PinsCache
1717
from .config import get_allow_rsc_short_name
18-
from .drivers import default_title, load_data, load_file, save_data
18+
from .drivers import REQUIRES_SINGLE_FILE, default_title, load_data, load_file, save_data
1919
from .errors import PinsError, PinsVersionError
2020
from .meta import Meta, MetaFactory, MetaRaw
2121
from .utils import ExtendMethodDoc, inform, warn_deprecated
@@ -243,9 +243,17 @@ def _pin_store(
243243
if isinstance(x, (tuple, list)) and len(x) == 1:
244244
x = x[0]
245245

246-
_p = Path(x)
247-
_base_len = len(_p.name) - len("".join(_p.suffixes))
248-
object_name = _p.name[:_base_len]
246+
if not isinstance(x, (list, tuple)):
247+
_p = Path(x)
248+
_base_len = len(_p.name) - len("".join(_p.suffixes))
249+
object_name = _p.name[:_base_len]
250+
else:
251+
# multifile upload, keep list of filenames
252+
object_name = []
253+
for file in x:
254+
_p = Path(file)
255+
# _base_len = len(_p.name) - len("".join(_p.suffixes))
256+
object_name.append(_p.name) # [:_base_len])
249257
else:
250258
object_name = None
251259

@@ -415,20 +423,32 @@ def pin_download(self, name, version=None, hash=None) -> Sequence[str]:
415423
if hash is not None:
416424
raise NotImplementedError("TODO: validate hash")
417425

426+
fnames = [meta.file] if isinstance(meta.file, str) else meta.file
427+
pin_type = meta.type
428+
429+
if len(fnames) > 1 and pin_type in REQUIRES_SINGLE_FILE:
430+
raise ValueError("Cannot load data when more than 1 file")
431+
418432
pin_name = self.path_to_pin(name)
433+
files = []
419434

420-
# TODO: raise for multiple files
421-
# fetch file
422-
with load_file(
423-
meta, self.fs, self.construct_path([pin_name, meta.version.version])
424-
) as f:
425-
# could also check whether f isinstance of PinCache
426-
fname = getattr(f, "name", None)
435+
for fname in fnames:
436+
# fetch file
437+
with load_file(
438+
fname,
439+
self.fs,
440+
self.construct_path([pin_name, meta.version.version]),
441+
pin_type,
442+
) as f:
443+
# could also check whether f isinstance of PinCache
444+
fname = getattr(f, "name", None)
427445

428-
if fname is None:
429-
raise PinsError("pin_download requires a cache.")
446+
if fname is None:
447+
raise PinsError("pin_download requires a cache.")
430448

431-
return [str(Path(fname).absolute())]
449+
files.append(str(Path(fname).absolute()))
450+
451+
return files
432452

433453
def pin_upload(
434454
self,
@@ -461,6 +481,12 @@ def pin_upload(
461481
This gets stored on the Meta.user field.
462482
"""
463483

484+
if isinstance(paths, (list, tuple)):
485+
# check if all paths exist
486+
for path in paths:
487+
if not Path(path).is_file():
488+
raise PinsError(f"Path is not a valid file: {path}")
489+
464490
return self._pin_store(
465491
paths,
466492
name,
@@ -665,7 +691,7 @@ def prepare_pin_version(
665691
metadata: Mapping | None = None,
666692
versioned: bool | None = None,
667693
created: datetime | None = None,
668-
object_name: str | None = None,
694+
object_name: str | list[str] | None = None,
669695
):
670696
meta = self._create_meta(
671697
pin_dir_path,
@@ -710,14 +736,18 @@ def _create_meta(
710736
# create metadata from object on disk ---------------------------------
711737
# save all pin data to a temporary folder (including data.txt), so we
712738
# can fs.put it all straight onto the backend filesystem
713-
714-
if object_name is None:
715-
p_obj = Path(pin_dir_path) / name
739+
apply_suffix = True
740+
if isinstance(object_name, (list, tuple)):
741+
apply_suffix = False
742+
p_obj = []
743+
for obj in object_name:
744+
p_obj.append(str(Path(pin_dir_path) / obj))
745+
elif object_name is None:
746+
p_obj = str(Path(pin_dir_path) / name)
716747
else:
717-
p_obj = Path(pin_dir_path) / object_name
718-
748+
p_obj = str(Path(pin_dir_path) / object_name)
719749
# file is saved locally in order to hash, calc size
720-
file_names = save_data(x, str(p_obj), type)
750+
file_names = save_data(x, p_obj, type, apply_suffix)
721751

722752
meta = self.meta_factory.create(
723753
pin_dir_path,
@@ -910,7 +940,7 @@ def pin_download(self, name, version=None, hash=None) -> Sequence[str]:
910940
meta = self.pin_meta(name, version)
911941

912942
if isinstance(meta, MetaRaw):
913-
f = load_file(meta, self.fs, None)
943+
f = load_file(meta.file, self.fs, None, meta.type)
914944
else:
915945
raise NotImplementedError(
916946
"TODO: pin_download currently can only read a url to a single file."

pins/drivers.py

+34-34
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
UNSAFE_TYPES = frozenset(["joblib"])
13-
REQUIRES_SINGLE_FILE = frozenset(["csv", "joblib", "file"])
13+
REQUIRES_SINGLE_FILE = frozenset(["csv", "joblib"])
1414

1515

1616
def _assert_is_pandas_df(x, file_type: str) -> None:
@@ -22,35 +22,24 @@ def _assert_is_pandas_df(x, file_type: str) -> None:
2222
)
2323

2424

25-
def load_path(meta, path_to_version):
26-
# Check that only a single file name was given
27-
fnames = [meta.file] if isinstance(meta.file, str) else meta.file
28-
29-
_type = meta.type
30-
31-
if len(fnames) > 1 and _type in REQUIRES_SINGLE_FILE:
32-
raise ValueError("Cannot load data when more than 1 file")
33-
25+
def load_path(filename: str, path_to_version, pin_type=None):
3426
# file path creation ------------------------------------------------------
35-
36-
if _type == "table":
27+
if pin_type == "table":
3728
# this type contains an rds and csv files named data.{ext}, so we match
3829
# R pins behavior and hardcode the name
39-
target_fname = "data.csv"
40-
else:
41-
target_fname = fnames[0]
30+
filename = "data.csv"
4231

4332
if path_to_version is not None:
44-
path_to_file = f"{path_to_version}/{target_fname}"
33+
path_to_file = f"{path_to_version}/{filename}"
4534
else:
4635
# BoardUrl doesn't have versions, and the file is the full url
47-
path_to_file = target_fname
36+
path_to_file = filename
4837

4938
return path_to_file
5039

5140

52-
def load_file(meta: Meta, fs, path_to_version):
53-
return fs.open(load_path(meta, path_to_version))
41+
def load_file(filename: str, fs, path_to_version, pin_type):
42+
return fs.open(load_path(filename, path_to_version, pin_type))
5443

5544

5645
def load_data(
@@ -81,7 +70,7 @@ def load_data(
8170
" * https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations"
8271
)
8372

84-
with load_file(meta, fs, path_to_version) as f:
73+
with load_file(meta.file, fs, path_to_version, meta.type) as f:
8574
if meta.type == "csv":
8675
import pandas as pd
8776

@@ -136,7 +125,9 @@ def load_data(
136125
raise NotImplementedError(f"No driver for type {meta.type}")
137126

138127

139-
def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequence[str]":
128+
def save_data(
129+
obj, fname, pin_type=None, apply_suffix: bool = True
130+
) -> "str | Sequence[str]":
140131
# TODO: extensible saving with deferred importing
141132
# TODO: how to encode arguments to saving / loading drivers?
142133
# e.g. pandas index options
@@ -145,59 +136,68 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen
145136
# of saving / loading objects different ways.
146137

147138
if apply_suffix:
148-
if type == "file":
139+
if pin_type == "file":
149140
suffix = "".join(Path(obj).suffixes)
150141
else:
151-
suffix = f".{type}"
142+
suffix = f".{pin_type}"
152143
else:
153144
suffix = ""
154145

155-
final_name = f"{fname}{suffix}"
146+
if isinstance(fname, list):
147+
final_name = fname
148+
else:
149+
final_name = f"{fname}{suffix}"
156150

157-
if type == "csv":
151+
if pin_type == "csv":
158152
_assert_is_pandas_df(obj, file_type=type)
159153

160154
obj.to_csv(final_name, index=False)
161155

162-
elif type == "arrow":
156+
elif pin_type == "arrow":
163157
# NOTE: R pins accepts the type arrow, and saves it as feather.
164158
# we allow reading this type, but raise an error for writing.
165159
_assert_is_pandas_df(obj, file_type=type)
166160

167161
obj.to_feather(final_name)
168162

169-
elif type == "feather":
163+
elif pin_type == "feather":
170164
_assert_is_pandas_df(obj, file_type=type)
171165

172166
raise NotImplementedError(
173167
'Saving data as type "feather" no longer supported. Use type "arrow" instead.'
174168
)
175169

176-
elif type == "parquet":
170+
elif pin_type == "parquet":
177171
_assert_is_pandas_df(obj, file_type=type)
178172

179173
obj.to_parquet(final_name)
180174

181-
elif type == "joblib":
175+
elif pin_type == "joblib":
182176
import joblib
183177

184178
joblib.dump(obj, final_name)
185179

186-
elif type == "json":
180+
elif pin_type == "json":
187181
import json
188182

189183
json.dump(obj, open(final_name, "w"))
190184

191-
elif type == "file":
185+
elif pin_type == "file":
192186
import contextlib
193187
import shutil
194188

189+
if isinstance(obj, list):
190+
for file, final in zip(obj, final_name):
191+
with contextlib.suppress(shutil.SameFileError):
192+
shutil.copyfile(str(file), final)
193+
return obj
195194
# ignore the case where the source is the same as the target
196-
with contextlib.suppress(shutil.SameFileError):
197-
shutil.copyfile(str(obj), final_name)
195+
else:
196+
with contextlib.suppress(shutil.SameFileError):
197+
shutil.copyfile(str(obj), final_name)
198198

199199
else:
200-
raise NotImplementedError(f"Cannot save type: {type}")
200+
raise NotImplementedError(f"Cannot save type: {pin_type}")
201201

202202
return final_name
203203

pins/meta.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,12 @@ def create(
245245

246246
raise NotImplementedError("Cannot create from file object.")
247247
else:
248-
raise NotImplementedError("TODO: creating meta from multiple files")
248+
if isinstance(files, (list, tuple)):
249+
from pathlib import Path
250+
251+
file_name = [Path(f).name for f in files]
252+
file_size = [Path(f).stat().st_size for f in files]
253+
version = Version.from_files(files, created)
249254

250255
return Meta(
251256
title=title,

pins/tests/test_boards.py

+20
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,26 @@ def test_board_pin_upload_path_list(board_with_cache, tmp_path):
266266
(pin_path,) = board_with_cache.pin_download("cool_pin")
267267

268268

269+
def test_board_pin_download_filename_multifile(board_with_cache, tmp_path):
270+
# create and save data
271+
df = pd.DataFrame({"x": [1, 2, 3]})
272+
273+
path1, path2 = tmp_path / "data1.csv", tmp_path / "data2.csv"
274+
df.to_csv(path1, index=False)
275+
df.to_csv(path2, index=False)
276+
277+
meta = board_with_cache.pin_upload([path1, path2], "cool_pin")
278+
279+
assert meta.type == "file"
280+
assert meta.file == ["data1.csv", "data2.csv"]
281+
282+
pin_path = board_with_cache.pin_download("cool_pin")
283+
284+
assert len(pin_path) == 2
285+
assert Path(pin_path[0]).name == "data1.csv"
286+
assert Path(pin_path[1]).name == "data2.csv"
287+
288+
269289
def test_board_pin_write_rsc_index_html(board, tmp_path: Path, snapshot):
270290
if board.fs.protocol != "rsc":
271291
pytest.skip()

pins/tests/test_drivers.py

+3-11
Original file line numberDiff line numberDiff line change
@@ -164,31 +164,23 @@ def test_driver_apply_suffix_false(tmp_path: Path):
164164

165165

166166
class TestLoadFile:
167-
def test_multi_file_raises(self):
168-
class _MockMetaMultiFile:
169-
file: str | list[str] = ["a", "b"]
170-
type: str = "csv"
171-
172-
with pytest.raises(ValueError, match="Cannot load data when more than 1 file"):
173-
load_path(_MockMetaMultiFile(), None)
174-
175167
def test_str_file(self):
176168
class _MockMetaStrFile:
177169
file: str = "a"
178170
type: str = "csv"
179171

180-
assert load_path(_MockMetaStrFile(), None) == "a"
172+
assert load_path(_MockMetaStrFile().file, None, _MockMetaStrFile().type) == "a"
181173

182174
def test_table(self):
183175
class _MockMetaTable:
184176
file: str = "a"
185177
type: str = "table"
186178

187-
assert load_path(_MockMetaTable(), None) == "data.csv"
179+
assert load_path(_MockMetaTable().file, None, _MockMetaTable().type) == "data.csv"
188180

189181
def test_version(self):
190182
class _MockMetaTable:
191183
file: str = "a"
192184
type: str = "csv"
193185

194-
assert load_path(_MockMetaTable(), "v1") == "v1/a"
186+
assert load_path(_MockMetaTable().file, "v1", _MockMetaTable().type) == "v1/a"

pins/versions.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
from dataclasses import asdict, dataclass
55
from datetime import datetime
6+
from pathlib import Path
67
from typing import Mapping, Sequence
78

89
from xxhash import xxh64
@@ -56,9 +57,7 @@ def render_created(self):
5657
def hash_file(f: IOBase, block_size: int = -1) -> str:
5758
# TODO: what kind of things implement the "buffer API"?
5859
hasher = xxh64()
59-
6060
buf = f.read(block_size)
61-
6261
while len(buf) > 0:
6362
hasher.update(buf)
6463
buf = f.read(block_size)
@@ -99,14 +98,18 @@ def from_files(
9998
) -> Version:
10099
hashes = []
101100
for f in files:
102-
hash_ = cls.hash_file(open(f, "rb") if isinstance(f, str) else f)
101+
hash_ = cls.hash_file(open(f, "rb") if isinstance(f, (str, Path)) else f)
103102
hashes.append(hash_)
104103

105104
if created is None:
106105
created = datetime.now()
107106

108107
if len(hashes) > 1:
109-
raise NotImplementedError("Only 1 file may be currently be hashed")
108+
# Combine the hashes into a single string
109+
combined_hashes = "".join(hashes)
110+
111+
# Create an xxh64 hash of the combined string
112+
hashes = [xxh64(combined_hashes).hexdigest()]
110113

111114
return cls(created, hashes[0])
112115

0 commit comments

Comments
 (0)