diff --git a/source/ftrack_api/accessor/server.py b/source/ftrack_api/accessor/server.py index a8f5a188..459fb933 100644 --- a/source/ftrack_api/accessor/server.py +++ b/source/ftrack_api/accessor/server.py @@ -4,6 +4,7 @@ import os import hashlib import base64 +from typing import BinaryIO import requests @@ -23,6 +24,7 @@ def __init__(self, resource_identifier, session, mode="rb"): self.resource_identifier = resource_identifier self._session = session self._has_read = False + self._has_uploaded = False super(ServerFile, self).__init__() @@ -30,8 +32,8 @@ def flush(self): """Flush all changes.""" super(ServerFile, self).flush() - if self.mode == "wb": - self._write() + if not self._has_uploaded and self.mode == "wb": + self._flush_to_server() def read(self, limit=None): """Read file.""" @@ -41,6 +43,12 @@ def read(self, limit=None): return super(ServerFile, self).read(limit) + def write(self, content): + """Write *content* to file.""" + assert self._has_uploaded is False, "Cannot write to file after upload." + + return super().write(content) + def _read(self): """Read all remote content from key into wrapped_file.""" position = self.tell() @@ -66,13 +74,22 @@ def _read(self): for block in response.iter_content(ftrack_api.symbol.CHUNK_SIZE): self.wrapped_file.write(block) - self.flush() self.seek(position) + self._has_uploaded = False - def _write(self): + def _flush_to_server(self): """Write current data to remote key.""" position = self.tell() + self.upload_to_server(self.wrapped_file) + + self.seek(position) + + def upload_to_server(self, source_file: "BinaryIO"): + """ + Direct upload source to server. + Use with caution, it will forbid any further write operation until you read the file. + """ # Retrieve component from cache to construct a filename. component = self._session.get("FileComponent", self.resource_identifier) if not component: @@ -92,9 +109,9 @@ def _write(self): self._session, component_id=self.resource_identifier, file_name=name, - file_size=self._get_size(), - file=self.wrapped_file, - checksum=self._compute_checksum(), + file_size=self._get_size(source_file), + file=source_file, + checksum=self._compute_checksum(source_file), ) uploader.start() except Exception as error: @@ -102,18 +119,19 @@ def _write(self): "Failed to put file to server: {0}.".format(error) ) - self.seek(position) + self._has_uploaded = True - def _get_size(self): + @staticmethod + def _get_size(file: "BinaryIO"): """Return size of file in bytes.""" - position = self.tell() - length = self.seek(0, os.SEEK_END) - self.seek(position) + position = file.tell() + length = file.seek(0, os.SEEK_END) + file.seek(position) return length - def _compute_checksum(self): + @staticmethod + def _compute_checksum(fp: "BinaryIO"): """Return checksum for file.""" - fp = self.wrapped_file buf_size = ftrack_api.symbol.CHUNK_SIZE hash_obj = hashlib.md5() spos = fp.tell() diff --git a/source/ftrack_api/entity/location.py b/source/ftrack_api/entity/location.py index 3a575fb0..a669bfc3 100644 --- a/source/ftrack_api/entity/location.py +++ b/source/ftrack_api/entity/location.py @@ -5,6 +5,8 @@ import collections.abc import functools +import ftrack_api.data +import ftrack_api.accessor.server import ftrack_api.entity.base import ftrack_api.exception import ftrack_api.event.base @@ -346,13 +348,19 @@ def _add_data(self, component, resource_identifier, source): ) target_data = self.accessor.open(resource_identifier, "wb") - # Read/write data in chunks to avoid reading all into memory at the - # same time. - chunked_read = functools.partial( - source_data.read, ftrack_api.symbol.CHUNK_SIZE - ) - for chunk in iter(chunked_read, b""): - target_data.write(chunk) + if isinstance(source_data, ftrack_api.data.File) and isinstance( + target_data, ftrack_api.accessor.server.ServerFile + ): + # If source is a file and target is a server, use the server's upload method + target_data.upload_to_server(source_data.wrapped_file) + else: + # Read/write data in chunks to avoid reading all into memory at the + # same time. + chunked_read = functools.partial( + source_data.read, ftrack_api.symbol.CHUNK_SIZE + ) + for chunk in iter(chunked_read, b""): + target_data.write(chunk) target_data.close() source_data.close() diff --git a/source/ftrack_api/uploader.py b/source/ftrack_api/uploader.py index e672c33a..87a2f873 100644 --- a/source/ftrack_api/uploader.py +++ b/source/ftrack_api/uploader.py @@ -4,7 +4,7 @@ import logging import math import os -from typing import IO, Awaitable, Callable, TYPE_CHECKING, List, Optional +from typing import BinaryIO, Awaitable, Callable, TYPE_CHECKING, List, Optional import typing import anyio @@ -66,7 +66,7 @@ def __init__( component_id: str, file_name: str, file_size: int, - file: "IO", + file: "BinaryIO", checksum: Optional[str], ): self.session = session