Skip to content

Commit 71d3085

Browse files
committed
move filename determination to models.Response and file and progress bar management to clients
1 parent 05a02cb commit 71d3085

File tree

9 files changed

+249
-250
lines changed

9 files changed

+249
-250
lines changed

planet/clients/data.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from typing import Any, AsyncIterator, Callable, Dict, List, Optional
2121
import uuid
2222

23+
from tqdm.asyncio import tqdm
24+
2325
from ..data_filter import empty_filter
2426
from .. import exceptions
2527
from ..constants import PLANET_BASE_URL
@@ -586,20 +588,40 @@ async def download_asset(self,
586588
587589
Raises:
588590
planet.exceptions.APIError: On API error.
589-
planet.exceptions.ClientError: If asset is not active, asset
590-
description is not valid, or retry limit is exceeded.
591+
planet.exceptions.ClientError: If asset is not active or asset
592+
description is not valid.
591593
"""
592594
try:
593595
location = asset['location']
594596
except KeyError:
595597
raise exceptions.ClientError(
596598
'asset missing ["location"] entry. Is asset active?')
597599

598-
return await self._session.write(location,
599-
filename=filename,
600-
directory=directory,
601-
overwrite=overwrite,
602-
progress_bar=progress_bar)
600+
response = await self._session.request(method='GET', url=location)
601+
filename = filename or response.filename
602+
if not filename:
603+
raise exceptions.ClientError(
604+
f'Could not determine filename at {location}')
605+
606+
dl_path = Path(directory, filename)
607+
dl_path.parent.mkdir(exist_ok=True, parents=True)
608+
LOGGER.info(f'Downloading {dl_path}')
609+
610+
try:
611+
mode = 'wb' if overwrite else 'xb'
612+
with open(dl_path, mode) as fp:
613+
with tqdm(total=response.length,
614+
unit_scale=True,
615+
unit_divisor=1024 * 1024,
616+
unit='B',
617+
desc=str(filename),
618+
disable=not progress_bar) as progress:
619+
update = progress.update if progress_bar else LOGGER.debug
620+
await self._session.write(location, fp, update)
621+
except FileExistsError:
622+
LOGGER.info(f'File {dl_path} exists, not overwriting')
623+
624+
return dl_path
603625

604626
@staticmethod
605627
def validate_checksum(asset: dict, filename: Path):

planet/clients/orders.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
"""Functionality for interacting with the orders api"""
1616
import asyncio
1717
import logging
18+
from pathlib import Path
1819
import time
1920
from typing import AsyncIterator, Callable, List, Optional
2021
import uuid
2122
import json
2223
import hashlib
2324

24-
from pathlib import Path
25+
from tqdm.asyncio import tqdm
26+
2527
from .. import exceptions
2628
from ..constants import PLANET_BASE_URL
2729
from ..http import Session
@@ -255,11 +257,31 @@ async def download_asset(self,
255257
limit is exceeded.
256258
257259
"""
258-
return await self._session.write(location,
259-
filename=filename,
260-
directory=directory,
261-
overwrite=overwrite,
262-
progress_bar=progress_bar)
260+
response = await self._session.request(method='GET', url=location)
261+
filename = filename or response.filename
262+
length = response.length
263+
if not filename:
264+
raise exceptions.ClientError(
265+
f'Could not determine filename at {location}')
266+
267+
dl_path = Path(directory, filename)
268+
dl_path.parent.mkdir(exist_ok=True, parents=True)
269+
LOGGER.info(f'Downloading {dl_path}')
270+
271+
try:
272+
mode = 'wb' if overwrite else 'xb'
273+
with open(dl_path, mode) as fp:
274+
with tqdm(total=length,
275+
unit_scale=True,
276+
unit_divisor=1024 * 1024,
277+
unit='B',
278+
desc=str(filename),
279+
disable=not progress_bar) as progress:
280+
await self._session.write(location, fp, progress.update)
281+
except FileExistsError:
282+
LOGGER.info(f'File {dl_path} exists, not overwriting')
283+
284+
return dl_path
263285

264286
async def download_order(self,
265287
order_id: str,

planet/http.py

Lines changed: 13 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,11 @@
1818
from collections import Counter
1919
from http import HTTPStatus
2020
import logging
21-
import mimetypes
22-
from pathlib import Path
2321
import random
24-
import re
25-
import string
2622
import time
27-
from typing import Optional
28-
from urllib.parse import urlparse
23+
from typing import Callable, Optional
2924

3025
import httpx
31-
from tqdm.asyncio import tqdm
3226
from typing_extensions import Literal
3327

3428
from .auth import Auth, AuthType
@@ -332,6 +326,7 @@ async def _retry(self, func, *a, **kw):
332326
LOGGER.info(f'Retrying: sleeping {wait_time}s')
333327
await asyncio.sleep(wait_time)
334328
else:
329+
LOGGER.info('Retrying: failed')
335330
raise e
336331

337332
self.outcomes.update(['Successful'])
@@ -399,80 +394,32 @@ async def _send(self, request, stream=False) -> httpx.Response:
399394

400395
return http_resp
401396

402-
async def write(self,
403-
url: str,
404-
filename: Optional[str] = None,
405-
directory: Path = Path('.'),
406-
overwrite: bool = False,
407-
progress_bar: bool = False) -> Path:
397+
async def write(self, url: str, fp, callback: Optional[Callable] = None):
408398
"""Write data to local file with limiting and retries.
409399
410400
Parameters:
411-
url: Remote location url
412-
filename: Custom name to assign to downloaded file.
413-
directory: Base directory for file download. This directory will be
414-
created if it does not already exist.
415-
overwrite: Overwrite any existing files.
416-
progress_bar: Show progress bar during download.
417-
418-
Returns:
419-
Path to downloaded file.
401+
url: Remote location url.
402+
fp: Open write file pointer.
403+
callback: Function that handles write progress updates.
420404
421405
Raises:
422406
planet.exceptions.APIException: On API error.
423-
planet.exceptions.ClientError: When retry limit is exceeded.
424407
425408
"""
426409

427-
async def _write():
428-
async with self._client.stream('GET', url) as response:
429-
430-
dl_path = Path(
431-
directory,
432-
filename or _get_filename_from_response(response))
433-
dl_path.parent.mkdir(exist_ok=True, parents=True)
434-
435-
await self._write_response(response,
436-
dl_path,
437-
overwrite=overwrite,
438-
progress_bar=progress_bar)
439-
440-
return dl_path
441-
442410
async def _limited_write():
443411
async with self._limiter:
444-
dl_path = await _write()
445-
return dl_path
446-
447-
return await self._retry(_limited_write)
448-
449-
async def _write_response(self,
450-
response,
451-
filename,
452-
overwrite,
453-
progress_bar):
454-
total = int(response.headers["Content-Length"])
455-
456-
try:
457-
mode = 'wb' if overwrite else 'xb'
458-
with open(filename, mode) as fp:
459-
460-
with tqdm(total=total,
461-
unit_scale=True,
462-
unit_divisor=1024 * 1024,
463-
unit='B',
464-
desc=str(filename),
465-
disable=not progress_bar) as progress:
412+
async with self._client.stream('GET', url) as response:
466413
previous = response.num_bytes_downloaded
467414

468415
async for chunk in response.aiter_bytes():
469416
fp.write(chunk)
470-
new = response.num_bytes_downloaded - previous
471-
progress.update(new - previous)
472-
previous = new
473-
progress.update()
474-
except FileExistsError:
475-
LOGGER.info(f'File {filename} exists, not overwriting')
417+
current = response.num_bytes_downloaded
418+
if callback is not None:
419+
callback(current - previous)
420+
previous = current
421+
422+
await self._retry(_limited_write)
476423

477424
def client(self,
478425
name: Literal['data', 'orders', 'subscriptions'],
@@ -498,51 +445,6 @@ def client(self,
498445
raise exceptions.ClientError("No such client.")
499446

500447

501-
def _get_filename_from_response(response) -> str:
502-
"""The name of the response resource.
503-
504-
The default is to use the content-disposition header value from the
505-
response. If not found, falls back to resolving the name from the url
506-
or generating a random name with the type from the response.
507-
"""
508-
name = (_get_filename_from_headers(response.headers)
509-
or _get_filename_from_url(response.url)
510-
or _get_random_filename(response.headers.get('content-type')))
511-
return name
512-
513-
514-
def _get_filename_from_headers(headers):
515-
"""Get a filename from the Content-Disposition header, if available.
516-
517-
:param headers dict: a ``dict`` of response headers
518-
:returns: a filename (i.e. ``basename``)
519-
:rtype: str or None
520-
"""
521-
cd = headers.get('content-disposition', '')
522-
match = re.search('filename="?([^"]+)"?', cd)
523-
return match.group(1) if match else None
524-
525-
526-
def _get_filename_from_url(url: str) -> Optional[str]:
527-
"""Get a filename from a url.
528-
529-
Getting a name for Landsat imagery uses this function.
530-
"""
531-
path = urlparse(url).path
532-
name = path[path.rfind('/') + 1:]
533-
return name or None
534-
535-
536-
def _get_random_filename(content_type=None) -> str:
537-
"""Get a pseudo-random, Planet-looking filename.
538-
"""
539-
extension = mimetypes.guess_extension(content_type or '') or ''
540-
characters = string.ascii_letters + '0123456789'
541-
letters = ''.join(random.sample(characters, 8))
542-
name = 'planet-{}{}'.format(letters, extension)
543-
return name
544-
545-
546448
class AuthSession(BaseSession):
547449
"""Synchronous connection to the Planet Auth service."""
548450

planet/models.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
# limitations under the License.
1515
"""Manage data for requests and responses."""
1616
import logging
17-
from typing import AsyncGenerator, Callable, List
17+
import re
18+
from typing import AsyncGenerator, Callable, List, Optional
19+
from urllib.parse import urlparse
1820

1921
import httpx
2022

@@ -42,11 +44,67 @@ def status_code(self) -> int:
4244
"""HTTP status code"""
4345
return self._http_response.status_code
4446

47+
@property
48+
def filename(self) -> Optional[str]:
49+
"""Name of the download file.
50+
51+
The filename is None if the response does not represent a download.
52+
"""
53+
filename = None
54+
55+
if self.length is not None: # is a download file
56+
filename = _get_filename_from_response(self._http_response)
57+
58+
return filename
59+
60+
@property
61+
def length(self) -> Optional[int]:
62+
"""Length of the download file.
63+
64+
The length is None if the response does not represent a download.
65+
"""
66+
LOGGER.warning('here')
67+
try:
68+
length = int(self._http_response.headers["Content-Length"])
69+
except KeyError:
70+
length = None
71+
LOGGER.warning(length)
72+
return length
73+
4574
def json(self) -> dict:
4675
"""Response json"""
4776
return self._http_response.json()
4877

4978

79+
def _get_filename_from_response(response) -> Optional[str]:
80+
"""The name of the response resource.
81+
82+
The default is to use the content-disposition header value from the
83+
response. If not found, falls back to resolving the name from the url
84+
or generating a random name with the type from the response.
85+
"""
86+
name = (_get_filename_from_headers(response.headers)
87+
or _get_filename_from_url(str(response.url)))
88+
return name
89+
90+
91+
def _get_filename_from_headers(headers: httpx.Headers) -> Optional[str]:
92+
"""Get a filename from the Content-Disposition header, if available."""
93+
cd = headers.get('content-disposition', '')
94+
match = re.search('filename="?([^"]+)"?', cd)
95+
return match.group(1) if match else None
96+
97+
98+
def _get_filename_from_url(url: str) -> Optional[str]:
99+
"""Get a filename from the url.
100+
101+
Getting a name for Landsat imagery uses this function.
102+
"""
103+
path = urlparse(url).path
104+
name = path[path.rfind('/') + 1:]
105+
return name or None
106+
107+
50108
class Paged:
51109
"""Asynchronous iterator over results in a paged resource.
52110

tests/integration/test_data_api.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -837,11 +837,15 @@ async def _stream_img():
837837
# populate request parameter to avoid respx cloning, which throws
838838
# an error caused by respx and not this code
839839
# https://github.com/lundberg/respx/issues/130
840-
mock_resp = httpx.Response(HTTPStatus.OK,
841-
stream=_stream_img(),
842-
headers=img_headers,
843-
request='donotcloneme')
844-
respx.get(dl_url).return_value = mock_resp
840+
respx.get(dl_url).side_effect = [
841+
httpx.Response(HTTPStatus.OK,
842+
headers=img_headers,
843+
request='donotcloneme'),
844+
httpx.Response(HTTPStatus.OK,
845+
stream=_stream_img(),
846+
headers=img_headers,
847+
request='donotcloneme')
848+
]
845849

846850
basic_udm2_asset = {
847851
"_links": {
@@ -863,7 +867,8 @@ async def _stream_img():
863867

864868
path = await cl.download_asset(basic_udm2_asset,
865869
directory=tmpdir,
866-
overwrite=overwrite)
870+
overwrite=overwrite,
871+
progress_bar=False)
867872
assert path.name == 'img.tif'
868873
assert path.is_file()
869874

0 commit comments

Comments
 (0)