diff --git a/webknossos/webknossos/annotation/annotation.py b/webknossos/webknossos/annotation/annotation.py index 2448398ef..6b0067314 100644 --- a/webknossos/webknossos/annotation/annotation.py +++ b/webknossos/webknossos/annotation/annotation.py @@ -51,7 +51,7 @@ from pathlib import Path from shutil import copyfileobj from tempfile import TemporaryDirectory -from typing import BinaryIO, Union, cast, overload +from typing import BinaryIO, Literal, Union, cast, overload from zipfile import ZIP_DEFLATED, ZipFile from zlib import Z_BEST_SPEED @@ -61,8 +61,9 @@ from zipp import Path as ZipPath import webknossos._nml as wknml +from webknossos.geometry.mag import Mag -from ..client.api_client.models import ApiAnnotation +from ..client.api_client.models import ApiAnnotation, ApiMeshAdHoc, ApiMeshPrecomputed from ..dataset import ( SEGMENTATION_CATEGORY, DataFormat, @@ -473,6 +474,75 @@ def download( else: return annotation + def download_mesh( + self, + segment_id: int, + output_dir: PathLike, + tracing_id: str | None = None, + layer_name: str | None = None, + is_precomputed: bool = False, + mesh_file_name: str | None = None, + datastore_url: str | None = None, + lod: int = 0, + mapping_name: str | None = None, + mapping_type: Literal["agglomerate", "json"] | None = None, + mag: Mag | None = None, + seed_position: Vec3Int | None = None, + token: str | None = None, + ) -> UPath: + from ..client.context import _get_context + from ..datastore import Datastore + + context = _get_context() + datastore_url = datastore_url or Datastore.get_upload_url() + tracingstore = context.get_tracingstore_api_client() + mesh: ApiMeshAdHoc | ApiMeshPrecomputed + if is_precomputed: + assert mesh_file_name is not None + mesh = ApiMeshPrecomputed( + lod=lod, + mesh_file_name=mesh_file_name, + segment_id=segment_id, + mapping_name=mapping_name, + ) + else: + assert mapping_type is not None + assert mag is not None + assert seed_position is not None + mesh = ApiMeshAdHoc( + lod=lod, + segment_id=segment_id, + mapping_name=mapping_name, + mapping_type=mapping_type, + mag=mag.to_tuple(), + seed_position=seed_position.to_tuple(), + ) + file_path = UPath(output_dir) / f"{tracing_id}_{segment_id}.stl" + file_path.parent.mkdir(parents=True, exist_ok=True) + if tracing_id is None: + datastore = context.get_datastore_api_client(datastore_url=datastore_url) + assert layer_name is not None, ( + "When you attempt to download a mesh without a tracing_id, the layer_name must be set." + ) + mesh_download = datastore.annotation_download_mesh( + mesh, + organization_id=self.organization_id or context.organization_id, + directory_name=self.dataset_name, + layer_name=layer_name, + token=token, + ) + else: + mesh_download = tracingstore.annotation_download_mesh( + mesh=mesh, + tracing_id=tracing_id, + token=token, + ) + + with file_path.open("wb") as f: + for chunk in mesh_download: + f.write(chunk) + return file_path + @classmethod def open_remote( cls, diff --git a/webknossos/webknossos/client/api_client/__init__.py b/webknossos/webknossos/client/api_client/__init__.py index 1d939c013..379ae9ab5 100644 --- a/webknossos/webknossos/client/api_client/__init__.py +++ b/webknossos/webknossos/client/api_client/__init__.py @@ -1,5 +1,11 @@ from .datastore_api_client import DatastoreApiClient from .errors import ApiClientError +from .tracingstore_api_client import TracingstoreApiClient from .wk_api_client import WkApiClient -__all__ = ["WkApiClient", "DatastoreApiClient", "ApiClientError"] +__all__ = [ + "WkApiClient", + "DatastoreApiClient", + "ApiClientError", + "TracingstoreApiClient", +] diff --git a/webknossos/webknossos/client/api_client/_abstract_api_client.py b/webknossos/webknossos/client/api_client/_abstract_api_client.py index d846d16ee..3578bc906 100644 --- a/webknossos/webknossos/client/api_client/_abstract_api_client.py +++ b/webknossos/webknossos/client/api_client/_abstract_api_client.py @@ -1,5 +1,6 @@ import logging from abc import ABC, abstractmethod +from collections.abc import Iterator from typing import Any, TypeVar import httpx @@ -84,6 +85,24 @@ def _post_json( timeout_seconds=timeout_seconds, ) + def _post_json_with_response_stream( + self, + route: str, + body_structured: Any, + query: Query | None = None, + retry_count: int = 0, + timeout_seconds: float | None = None, + ) -> Iterator[bytes]: + body_json = self._prepare_for_json(body_structured) + response = self._post( + route, + body_json=body_json, + query=query, + retry_count=retry_count, + timeout_seconds=timeout_seconds, + ) + yield from response.iter_bytes() + def _get_file( self, route: str, query: Query | None = None, retry_count: int = 0 ) -> tuple[bytes, str]: diff --git a/webknossos/webknossos/client/api_client/datastore_api_client.py b/webknossos/webknossos/client/api_client/datastore_api_client.py index 8f258c1d4..6181af21e 100644 --- a/webknossos/webknossos/client/api_client/datastore_api_client.py +++ b/webknossos/webknossos/client/api_client/datastore_api_client.py @@ -1,8 +1,12 @@ +from collections.abc import Iterator + from webknossos.client.api_client.models import ( ApiDatasetAnnounceUpload, ApiDatasetManualUploadSuccess, ApiDatasetUploadInformation, ApiDatasetUploadSuccess, + ApiMeshAdHoc, + ApiMeshPrecomputed, ApiReserveDatasetUploadInformation, ) @@ -107,3 +111,19 @@ def dataset_get_raw_data( } response = self._get(route, query) return response.content, response.headers.get("MISSING-BUCKETS") + + def annotation_download_mesh( + self, + mesh: ApiMeshPrecomputed | ApiMeshAdHoc, + organization_id: str, + directory_name: str, + layer_name: str, + token: str | None, + ) -> Iterator[bytes]: + route = f"/datasets/{organization_id}/{directory_name}/layers/{layer_name}/meshes/fullMesh.stl" + query: Query = {"token": token} + yield from self._post_json_with_response_stream( + route=route, + body_structured=mesh, + query=query, + ) diff --git a/webknossos/webknossos/client/api_client/models.py b/webknossos/webknossos/client/api_client/models.py index 7785add88..31aea9cf5 100644 --- a/webknossos/webknossos/client/api_client/models.py +++ b/webknossos/webknossos/client/api_client/models.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Literal import attr @@ -34,6 +34,12 @@ class ApiDataStore: allows_upload: bool +@attr.s(auto_attribs=True) +class ApiTracingStore: + name: str + url: str + + @attr.s(auto_attribs=True) class ApiTeam: id: str @@ -355,3 +361,23 @@ class ApiFolder: allowed_teams_cumulative: list[ApiTeam] is_editable: bool metadata: list[ApiMetadata] | None = None + + +@attr.s(auto_attribs=True) +class ApiMeshPrecomputed: + lod: int + mesh_file_name: str + segment_id: int + mapping_name: str | None + + +@attr.s(auto_attribs=True) +class ApiMeshAdHoc: + lod: int + segment_id: int # if mapping name is set, this is an agglomerate id + mapping_name: str | None + mapping_type: Literal["json", "agglomerate"] + mag: tuple[int, int, int] + seed_position: tuple[int, int, int] + mesh_file_name: None = None # None means ad-hoc mesh + additional_coordinates: None = None # ND datasets are not supported yet diff --git a/webknossos/webknossos/client/api_client/tracingstore_api_client.py b/webknossos/webknossos/client/api_client/tracingstore_api_client.py new file mode 100644 index 000000000..c3cb3f0eb --- /dev/null +++ b/webknossos/webknossos/client/api_client/tracingstore_api_client.py @@ -0,0 +1,42 @@ +from collections.abc import Iterator + +from webknossos.client.api_client.models import ( + ApiMeshAdHoc, + ApiMeshPrecomputed, +) + +from ._abstract_api_client import AbstractApiClient, Query + + +class TracingstoreApiClient(AbstractApiClient): + # Client to use the HTTP API of WEBKNOSSOS datastore servers. + # When adding a method here, use the utility methods from AbstractApiClient + # and add more as needed. + # Methods here are prefixed with the domain, e.g. dataset_finish_upload (not finish_dataset_upload) + + def __init__( + self, + base_url: str, + timeout_seconds: float, + headers: dict[str, str] | None = None, + ): + super().__init__(timeout_seconds, headers) + self.base_url = base_url + + @property + def url_prefix(self) -> str: + return f"{self.base_url}/tracings" + + def annotation_download_mesh( + self, + mesh: ApiMeshPrecomputed | ApiMeshAdHoc, + tracing_id: str, + token: str | None, + ) -> Iterator[bytes]: + route = f"/volume/{tracing_id}/fullMesh.stl" + query: Query = {"token": token} + yield from self._post_json_with_response_stream( + route=route, + body_structured=mesh, + query=query, + ) diff --git a/webknossos/webknossos/client/api_client/wk_api_client.py b/webknossos/webknossos/client/api_client/wk_api_client.py index 2ae44d8bf..3fe188eb9 100644 --- a/webknossos/webknossos/client/api_client/wk_api_client.py +++ b/webknossos/webknossos/client/api_client/wk_api_client.py @@ -22,6 +22,7 @@ ApiTaskParameters, ApiTeam, ApiTeamAdd, + ApiTracingStore, ApiUser, ApiWkBuildInfo, ) @@ -135,6 +136,10 @@ def datastore_list(self) -> list[ApiDataStore]: route = "/datastores" return self._get_json(route, list[ApiDataStore]) + def tracingstore(self) -> ApiTracingStore: + route = "/tracingstore" + return self._get_json(route, ApiTracingStore) + def project_info_by_name(self, project_name: str) -> ApiProject: route = f"/projects/byName/{project_name}" return self._get_json(route, ApiProject) diff --git a/webknossos/webknossos/client/context.py b/webknossos/webknossos/client/context.py index e389d8cb9..dc73fae82 100644 --- a/webknossos/webknossos/client/context.py +++ b/webknossos/webknossos/client/context.py @@ -56,7 +56,7 @@ from rich.prompt import Prompt from ._defaults import DEFAULT_HTTP_TIMEOUT, DEFAULT_WEBKNOSSOS_URL -from .api_client import DatastoreApiClient, WkApiClient +from .api_client import DatastoreApiClient, TracingstoreApiClient, WkApiClient load_dotenv() @@ -166,6 +166,17 @@ def get_datastore_api_client(self, datastore_url: str) -> DatastoreApiClient: headers=headers, ) + def get_tracingstore_api_client(self) -> TracingstoreApiClient: + if self.datastore_token is not None: + headers = {"X-Auth-Token": self.datastore_token} + api_tracingstore = self.api_client_with_auth.tracingstore() + + return TracingstoreApiClient( + base_url=api_tracingstore.url, + timeout_seconds=self.timeout, + headers=headers, + ) + _webknossos_context_var: ContextVar[_WebknossosContext] = ContextVar( "_webknossos_context_var", default=_WebknossosContext() diff --git a/webknossos/webknossos/tracingstore/__init__.py b/webknossos/webknossos/tracingstore/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/webknossos/webknossos/tracingstore/tracingstore.py b/webknossos/webknossos/tracingstore/tracingstore.py new file mode 100644 index 000000000..84c00e44a --- /dev/null +++ b/webknossos/webknossos/tracingstore/tracingstore.py @@ -0,0 +1,35 @@ +import attr + + +@attr.frozen +class Tracingstore: + """Tracingstore class for interactions with the tracing store.""" + + name: str + url: str + + @classmethod + def get_tracingstore( + cls, + ) -> "Tracingstore": + """Get the tracingstore for current webknossos url. + + + Returns: + Tracingstore object + + Examples: + ``` + # Get a list of all datastores that allow dataset uploads + tracingstore = Tracingstore.get_tracingstore() + ``` + """ + + from ..client.context import _get_context + + context = _get_context() + api_tracingstore = context.api_client_with_auth.tracingstore() + return cls( + api_tracingstore.name, + api_tracingstore.url, + )