Skip to content

Commit 353ba73

Browse files
authored
Improvements to cp in generic (#1835)
1 parent 372ef69 commit 353ba73

File tree

2 files changed

+73
-61
lines changed

2 files changed

+73
-61
lines changed

fsspec/generic.py

+45-61
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,13 @@
1616

1717

1818
def set_generic_fs(protocol, **storage_options):
19+
"""Populate the dict used for method=="generic" lookups"""
1920
_generic_fs[protocol] = filesystem(protocol, **storage_options)
2021

2122

22-
default_method = "default"
23-
24-
25-
def _resolve_fs(url, method=None, protocol=None, storage_options=None):
23+
def _resolve_fs(url, method, protocol=None, storage_options=None):
2624
"""Pick instance of backend FS"""
27-
method = method or default_method
25+
url = url[0] if isinstance(url, (list, tuple)) else url
2826
protocol = protocol or split_protocol(url)[0]
2927
storage_options = storage_options or {}
3028
if method == "default":
@@ -159,7 +157,7 @@ class GenericFileSystem(AsyncFileSystem):
159157

160158
protocol = "generic" # there is no real reason to ever use a protocol with this FS
161159

162-
def __init__(self, default_method="default", **kwargs):
160+
def __init__(self, default_method="default", storage_options=None, **kwargs):
163161
"""
164162
165163
Parameters
@@ -171,22 +169,25 @@ def __init__(self, default_method="default", **kwargs):
171169
configured via the config system
172170
- "generic": takes instances from the `_generic_fs` dict in this module,
173171
which you must populate before use. Keys are by protocol
172+
- "options": expects storage_options, a dict mapping protocol to
173+
kwargs to use when constructing the filesystem
174174
- "current": takes the most recently instantiated version of each FS
175175
"""
176176
self.method = default_method
177+
self.st_opts = storage_options
177178
super().__init__(**kwargs)
178179

179180
def _parent(self, path):
180-
fs = _resolve_fs(path, self.method)
181+
fs = _resolve_fs(path, self.method, storage_options=self.st_opts)
181182
return fs.unstrip_protocol(fs._parent(path))
182183

183184
def _strip_protocol(self, path):
184185
# normalization only
185-
fs = _resolve_fs(path, self.method)
186+
fs = _resolve_fs(path, self.method, storage_options=self.st_opts)
186187
return fs.unstrip_protocol(fs._strip_protocol(path))
187188

188189
async def _find(self, path, maxdepth=None, withdirs=False, detail=False, **kwargs):
189-
fs = _resolve_fs(path, self.method)
190+
fs = _resolve_fs(path, self.method, storage_options=self.st_opts)
190191
if fs.async_impl:
191192
out = await fs._find(
192193
path, maxdepth=maxdepth, withdirs=withdirs, detail=True, **kwargs
@@ -251,7 +252,7 @@ async def _pipe_file(
251252
value,
252253
**kwargs,
253254
):
254-
fs = _resolve_fs(path, self.method)
255+
fs = _resolve_fs(path, self.method, storage_options=self.st_opts)
255256
if fs.async_impl:
256257
return await fs._pipe_file(path, value, **kwargs)
257258
else:
@@ -269,7 +270,7 @@ async def _rm(self, url, **kwargs):
269270

270271
async def _makedirs(self, path, exist_ok=False):
271272
logger.debug("Make dir %s", path)
272-
fs = _resolve_fs(path, self.method)
273+
fs = _resolve_fs(path, self.method, storage_options=self.st_opts)
273274
if fs.async_impl:
274275
await fs._makedirs(path, exist_ok=exist_ok)
275276
else:
@@ -288,42 +289,18 @@ async def _cp_file(
288289
url2,
289290
blocksize=2**20,
290291
callback=DEFAULT_CALLBACK,
292+
tempdir: Optional[str] = None,
291293
**kwargs,
292294
):
293295
fs = _resolve_fs(url, self.method)
294296
fs2 = _resolve_fs(url2, self.method)
295297
if fs is fs2:
296298
# pure remote
297299
if fs.async_impl:
298-
return await fs._cp_file(url, url2, **kwargs)
300+
return await fs._copy(url, url2, **kwargs)
299301
else:
300-
return fs.cp_file(url, url2, **kwargs)
301-
kw = {"blocksize": 0, "cache_type": "none"}
302-
try:
303-
f1 = (
304-
await fs.open_async(url, "rb")
305-
if hasattr(fs, "open_async")
306-
else fs.open(url, "rb", **kw)
307-
)
308-
callback.set_size(await maybe_await(f1.size))
309-
f2 = (
310-
await fs2.open_async(url2, "wb")
311-
if hasattr(fs2, "open_async")
312-
else fs2.open(url2, "wb", **kw)
313-
)
314-
while f1.size is None or f2.tell() < f1.size:
315-
data = await maybe_await(f1.read(blocksize))
316-
if f1.size is None and not data:
317-
break
318-
await maybe_await(f2.write(data))
319-
callback.absolute_update(f2.tell())
320-
finally:
321-
try:
322-
await maybe_await(f2.close())
323-
await maybe_await(f1.close())
324-
except NameError:
325-
# fail while opening f1 or f2
326-
pass
302+
return fs.copy(url, url2, **kwargs)
303+
await copy_file_op(fs, [url], fs2, [url2], tempdir, 1, on_error="raise")
327304

328305
async def _make_many_dirs(self, urls, exist_ok=True):
329306
fs = _resolve_fs(urls[0], self.method)
@@ -347,17 +324,22 @@ async def _copy(
347324
tempdir: Optional[str] = None,
348325
**kwargs,
349326
):
327+
# TODO: special case for one FS being local, which can use get/put
328+
# TODO: special case for one being memFS, which can use cat/pipe
350329
if recursive:
351-
raise NotImplementedError
352-
fs = _resolve_fs(path1[0], self.method)
353-
fs2 = _resolve_fs(path2[0], self.method)
354-
# not expanding paths atm., assume call is from rsync()
330+
raise NotImplementedError("Please use fsspec.generic.rsync")
331+
path1 = [path1] if isinstance(path1, str) else path1
332+
path2 = [path2] if isinstance(path2, str) else path2
333+
334+
fs = _resolve_fs(path1, self.method)
335+
fs2 = _resolve_fs(path2, self.method)
336+
355337
if fs is fs2:
356-
# pure remote
357338
if fs.async_impl:
358339
return await fs._copy(path1, path2, **kwargs)
359340
else:
360341
return fs.copy(path1, path2, **kwargs)
342+
361343
await copy_file_op(
362344
fs, path1, fs2, path2, tempdir, batch_size, on_error=on_error
363345
)
@@ -377,31 +359,33 @@ async def copy_file_op(
377359
fs2,
378360
u2,
379361
os.path.join(tempdir, uuid.uuid4().hex),
380-
on_error=on_error,
381362
)
382363
for u1, u2 in zip(url1, url2)
383364
]
384-
await _run_coros_in_chunks(coros, batch_size=batch_size)
365+
out = await _run_coros_in_chunks(
366+
coros, batch_size=batch_size, return_exceptions=True
367+
)
385368
finally:
386369
shutil.rmtree(tempdir)
370+
if on_error == "return":
371+
return out
372+
elif on_error == "raise":
373+
for o in out:
374+
if isinstance(o, Exception):
375+
raise o
387376

388377

389378
async def _copy_file_op(fs1, url1, fs2, url2, local, on_error="ignore"):
390-
ex = () if on_error == "raise" else Exception
391-
logger.debug("Copy %s -> %s", url1, url2)
392-
try:
393-
if fs1.async_impl:
394-
await fs1._get_file(url1, local)
395-
else:
396-
fs1.get_file(url1, local)
397-
if fs2.async_impl:
398-
await fs2._put_file(local, url2)
399-
else:
400-
fs2.put_file(local, url2)
401-
os.unlink(local)
402-
logger.debug("Copy %s -> %s; done", url1, url2)
403-
except ex as e:
404-
logger.debug("ignoring cp exception for %s: %s", url1, e)
379+
if fs1.async_impl:
380+
await fs1._get_file(url1, local)
381+
else:
382+
fs1.get_file(url1, local)
383+
if fs2.async_impl:
384+
await fs2._put_file(local, url2)
385+
else:
386+
fs2.put_file(local, url2)
387+
os.unlink(local)
388+
logger.debug("Copy %s -> %s; done", url1, url2)
405389

406390

407391
async def maybe_await(cor):

fsspec/tests/test_generic.py

+28
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,34 @@ def test_cat_async(server):
4848
assert fs.cat(server.realfile) == data
4949

5050

51+
def test_cp_one(server, tmpdir):
52+
fsspec.filesystem("http", headers={"give_length": "true", "head_ok": "true"})
53+
local = fsspec.filesystem("file")
54+
fn = f"file://{tmpdir}/afile"
55+
56+
fs = fsspec.filesystem("generic", default_method="current")
57+
58+
fs.copy([server.realfile], [fn])
59+
assert local.cat(fn) == data
60+
fs.rm(fn)
61+
assert not fs.exists(fn)
62+
63+
fs.copy(server.realfile, fn)
64+
assert local.cat(fn) == data
65+
fs.rm(fn)
66+
assert not fs.exists(fn)
67+
68+
fs.cp([server.realfile], [fn])
69+
assert local.cat(fn) == data
70+
fs.rm(fn)
71+
assert not fs.exists(fn)
72+
73+
fs.cp_file(server.realfile, fn)
74+
assert local.cat(fn) == data
75+
fs.rm(fn)
76+
assert not fs.exists(fn)
77+
78+
5179
def test_rsync(tmpdir, m):
5280
from fsspec.generic import GenericFileSystem, rsync
5381

0 commit comments

Comments
 (0)