Skip to content

Add method to download meshes #1307

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions webknossos/webknossos/annotation/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from pathlib import Path
from shutil import copyfileobj
from tempfile import TemporaryDirectory
from typing import BinaryIO, Union, cast, overload
from typing import BinaryIO, Literal, Union, cast, overload
from zipfile import ZIP_DEFLATED, ZipFile
from zlib import Z_BEST_SPEED

Expand All @@ -61,8 +61,9 @@
from zipp import Path as ZipPath

import webknossos._nml as wknml
from webknossos.geometry.mag import Mag

from ..client.api_client.models import ApiAnnotation
from ..client.api_client.models import ApiAnnotation, ApiMeshAdHoc, ApiMeshPrecomputed
from ..dataset import (
SEGMENTATION_CATEGORY,
DataFormat,
Expand Down Expand Up @@ -473,6 +474,72 @@ def download(
else:
return annotation

def download_mesh(
self,
segment_id: int,
output_dir: PathLike,
tracing_id: str | None = None,
layer_name: str | None = None,
is_precomputed: bool = False,
mesh_file_name: str | None = None,
datastore_url: str | None = None,
lod: int = 0,
mapping_name: str | None = None,
mapping_type: Literal["agglomerate", "json"] | None = None,
mag: Mag | None = None,
seed_position: Vec3Int | None = None,
token: str | None = None,
) -> UPath:
from ..client.context import _get_context
from ..datastore import Datastore

context = _get_context()
datastore_url = datastore_url or Datastore.get_upload_url()
tracingstore = context.get_tracingstore_api_client()
mesh: ApiMeshAdHoc | ApiMeshPrecomputed
if is_precomputed:
assert mesh_file_name is not None
mesh = ApiMeshPrecomputed(
lod=lod,
mesh_file_name=mesh_file_name,
segment_id=segment_id,
mapping_name=mapping_name,
)
else:
assert mapping_type is not None
assert mag is not None
assert seed_position is not None
mesh = ApiMeshAdHoc(
lod=lod,
segment_id=segment_id,
mapping_name=mapping_name,
mapping_type=mapping_type,
mag=mag.to_tuple(),
seed_position=seed_position.to_tuple(),
)
file_path = UPath(output_dir) / f"{tracing_id}_{segment_id}.stl"
file_path.parent.mkdir(parents=True, exist_ok=True)
if tracing_id is None:
datastore = context.get_datastore_api_client(datastore_url=datastore_url)
mesh_download = datastore.annotation_download_mesh(
mesh,
organization_id=self.organization_id,
directory_name=self.dataset_name,
layer_name=layer_name,
token=token,
)
else:
mesh_download = tracingstore.annotation_download_mesh(
mesh=mesh,
tracing_id=tracing_id,
token=token,
)

with file_path.open("wb") as f:
for chunk in mesh_download:
f.write(chunk)
return file_path

@classmethod
def open_remote(
cls,
Expand Down
1 change: 1 addition & 0 deletions webknossos/webknossos/client/api_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .datastore_api_client import DatastoreApiClient
from .tracingstore_api_client import TracingstoreApiClient
from .errors import ApiClientError
from .wk_api_client import WkApiClient

Expand Down
19 changes: 19 additions & 0 deletions webknossos/webknossos/client/api_client/_abstract_api_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Iterator
from typing import Any, TypeVar

import httpx
Expand Down Expand Up @@ -84,6 +85,24 @@ def _post_json(
timeout_seconds=timeout_seconds,
)

def _post_json_with_response_stream(
self,
route: str,
body_structured: Any,
query: Query | None = None,
retry_count: int = 0,
timeout_seconds: float | None = None,
) -> Iterator[bytes]:
body_json = self._prepare_for_json(body_structured)
response = self._post(
route,
body_json=body_json,
query=query,
retry_count=retry_count,
timeout_seconds=timeout_seconds,
)
yield from response.iter_bytes()

def _get_file(
self, route: str, query: Query | None = None, retry_count: int = 0
) -> tuple[bytes, str]:
Expand Down
20 changes: 20 additions & 0 deletions webknossos/webknossos/client/api_client/datastore_api_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from collections.abc import Iterator

from webknossos.client.api_client.models import (
ApiDatasetAnnounceUpload,
ApiDatasetManualUploadSuccess,
ApiDatasetUploadInformation,
ApiDatasetUploadSuccess,
ApiMeshAdHoc,
ApiMeshPrecomputed,
ApiReserveDatasetUploadInformation,
)

Expand Down Expand Up @@ -107,3 +111,19 @@ def dataset_get_raw_data(
}
response = self._get(route, query)
return response.content, response.headers.get("MISSING-BUCKETS")

def annotation_download_mesh(
self,
mesh: ApiMeshPrecomputed | ApiMeshAdHoc,
organization_id: str,
directory_name: str,
layer_name: str,
token: str | None,
) -> Iterator[bytes]:
route = f"/datasets/{organization_id}/{directory_name}/layers/{layer_name}/meshes/fullMesh.stl"
query: Query = {"token": token}
yield from self._post_json_with_response_stream(
route=route,
body_structured=mesh,
query=query,
)
28 changes: 27 additions & 1 deletion webknossos/webknossos/client/api_client/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Literal

import attr

Expand Down Expand Up @@ -34,6 +34,12 @@ class ApiDataStore:
allows_upload: bool


@attr.s(auto_attribs=True)
class ApiTracingStore:
name: str
url: str


@attr.s(auto_attribs=True)
class ApiTeam:
id: str
Expand Down Expand Up @@ -355,3 +361,23 @@ class ApiFolder:
allowed_teams_cumulative: list[ApiTeam]
is_editable: bool
metadata: list[ApiMetadata] | None = None


@attr.s(auto_attribs=True)
class ApiMeshPrecomputed:
lod: int
mesh_file_name: str
segment_id: int
mapping_name: str | None


@attr.s(auto_attribs=True)
class ApiMeshAdHoc:
lod: int
segment_id: int # if mapping name is set, this is an agglomerate id
mapping_name: str | None
mapping_type: Literal["json", "agglomerate"]
mag: tuple[int, int, int]
seed_position: tuple[int, int, int]
mesh_file_name: None = None # None means ad-hoc mesh
additional_coordinates: None = None # ND datasets are not supported yet
42 changes: 42 additions & 0 deletions webknossos/webknossos/client/api_client/tracingstore_api_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from collections.abc import Iterator

from webknossos.client.api_client.models import (
ApiMeshAdHoc,
ApiMeshPrecomputed,
)

from ._abstract_api_client import AbstractApiClient, Query


class TracingstoreApiClient(AbstractApiClient):
# Client to use the HTTP API of WEBKNOSSOS datastore servers.
# When adding a method here, use the utility methods from AbstractApiClient
# and add more as needed.
# Methods here are prefixed with the domain, e.g. dataset_finish_upload (not finish_dataset_upload)

def __init__(
self,
base_url: str,
timeout_seconds: float,
headers: dict[str, str] | None = None,
):
super().__init__(timeout_seconds, headers)
self.base_url = base_url

@property
def url_prefix(self) -> str:
return f"{self.base_url}/tracings"

def annotation_download_mesh(
self,
mesh: ApiMeshPrecomputed | ApiMeshAdHoc,
tracing_id: str,
token: str | None,
) -> Iterator[bytes]:
route = f"/volume/{tracing_id}/fullMesh.stl"
query: Query = {"token": token}
yield from self._post_json_with_response_stream(
route=route,
body_structured=mesh,
query=query,
)
5 changes: 5 additions & 0 deletions webknossos/webknossos/client/api_client/wk_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ApiDatasetId,
ApiDatasetIsValidNewNameResponse,
ApiDataStore,
ApiTracingStore,
ApiDataStoreToken,
ApiFolderWithParent,
ApiLoggedTimeGroupedByMonth,
Expand Down Expand Up @@ -135,6 +136,10 @@ def datastore_list(self) -> list[ApiDataStore]:
route = "/datastores"
return self._get_json(route, list[ApiDataStore])

def tracingstore(self) -> ApiTracingStore:
route = "/tracingstore"
return self._get_json(route, ApiTracingStore)

def project_info_by_name(self, project_name: str) -> ApiProject:
route = f"/projects/byName/{project_name}"
return self._get_json(route, ApiProject)
Expand Down
13 changes: 12 additions & 1 deletion webknossos/webknossos/client/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from rich.prompt import Prompt

from ._defaults import DEFAULT_HTTP_TIMEOUT, DEFAULT_WEBKNOSSOS_URL
from .api_client import DatastoreApiClient, WkApiClient
from .api_client import DatastoreApiClient, WkApiClient, TracingstoreApiClient

load_dotenv()

Expand Down Expand Up @@ -166,6 +166,17 @@ def get_datastore_api_client(self, datastore_url: str) -> DatastoreApiClient:
headers=headers,
)

def get_tracingstore_api_client(self) -> TracingstoreApiClient:
if self.datastore_token is not None:
headers = {"X-Auth-Token": self.datastore_token}
api_tracingstore = self.api_client_with_auth.tracingstore()

return TracingstoreApiClient(
base_url=api_tracingstore.url,
timeout_seconds=self.timeout,
headers=headers,
)


_webknossos_context_var: ContextVar[_WebknossosContext] = ContextVar(
"_webknossos_context_var", default=_WebknossosContext()
Expand Down
Empty file.
35 changes: 35 additions & 0 deletions webknossos/webknossos/tracingstore/tracingstore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import attr


@attr.frozen
class Tracingstore:
"""Tracingstore class for interactions with the tracing store."""

name: str
url: str

@classmethod
def get_tracingstore(
cls,
) -> "Tracingstore":
"""Get the tracingstore for current webknossos url.


Returns:
Tracingstore object

Examples:
```
# Get a list of all datastores that allow dataset uploads
tracingstore = Tracingstore.get_tracingstore()
```
"""

from ..client.context import _get_context

context = _get_context()
api_tracingstore = context.api_client_with_auth.tracingstore()
return cls(
api_tracingstore.name,
api_tracingstore.url,
)
Loading