-
Notifications
You must be signed in to change notification settings - Fork 7
Ngio projection #866
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Ngio projection #866
Changes from 6 commits
e55821e
605d469
5b46f01
dd6dba9
802bdb8
3da52c8
c8b050a
2d63e84
0155294
a17d4bb
1adf2bd
da10279
9e744e4
65f69be
b784936
b2abce6
f76134d
77a9941
2f03cfe
cb67710
387aaae
aacb3e2
6bbc5bf
bc25920
4c12f2b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,28 +12,38 @@ | |
""" | ||
Task for 3D->2D maximum-intensity projection. | ||
""" | ||
import logging | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
from typing import TYPE_CHECKING | ||
|
||
import anndata as ad | ||
import dask.array as da | ||
import zarr | ||
from ngio import NgffImage | ||
from ngio.utils import ngio_logger | ||
from pydantic import validate_call | ||
from zarr.errors import ContainsArrayError | ||
|
||
from fractal_tasks_core.ngff import load_NgffImageMeta | ||
from fractal_tasks_core.pyramids import build_pyramid | ||
from fractal_tasks_core.roi import ( | ||
convert_ROIs_from_3D_to_2D, | ||
) | ||
from fractal_tasks_core.tables import write_table | ||
from fractal_tasks_core.tables.v1 import get_tables_list_v1 | ||
|
||
from fractal_tasks_core.tasks.io_models import InitArgsMIP | ||
from fractal_tasks_core.tasks.projection_utils import DaskProjectionMethod | ||
from fractal_tasks_core.zarr_utils import OverwriteNotAllowedError | ||
|
||
if TYPE_CHECKING: | ||
from ngio.core import Image | ||
lorenzocerrone marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
def _compute_new_shape(source_image: Image) -> tuple[int]: | ||
"""Compute the new shape of the image after the projection. | ||
|
||
The new shape is the same as the original one, | ||
except for the z-axis, which is set to 1. | ||
""" | ||
on_disk_shape = source_image.on_disk_shape | ||
ngio_logger.info(f"Source {on_disk_shape=}") | ||
|
||
on_disk_z_index = source_image.dataset.on_disk_axes_names.index("z") | ||
|
||
logger = logging.getLogger(__name__) | ||
dest_on_disk_shape = list(on_disk_shape) | ||
dest_on_disk_shape[on_disk_z_index] = 1 | ||
ngio_logger.info(f"Destination {dest_on_disk_shape=}") | ||
return tuple(dest_on_disk_shape) | ||
|
||
|
||
@validate_call | ||
|
@@ -55,122 +65,65 @@ | |
`create_cellvoyager_ome_zarr_init`. | ||
""" | ||
method = DaskProjectionMethod(init_args.method) | ||
logger.info(f"{init_args.origin_url=}") | ||
logger.info(f"{zarr_url=}") | ||
logger.info(f"{method=}") | ||
ngio_logger.info(f"{init_args.origin_url=}") | ||
lorenzocerrone marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ngio_logger.info(f"{zarr_url=}") | ||
ngio_logger.info(f"{method=}") | ||
|
||
# Read image metadata | ||
ngff_image = load_NgffImageMeta(init_args.origin_url) | ||
# Currently not using the validation models due to wavelength_id issue | ||
# See #681 for discussion | ||
# new_attrs = ngff_image.model_dump(exclude_none=True) | ||
# Current way to get the necessary metadata for MIP | ||
group = zarr.open_group(init_args.origin_url, mode="r") | ||
new_attrs = group.attrs.asdict() | ||
|
||
# Create the zarr image with correct | ||
new_image_group = zarr.group(zarr_url) | ||
new_image_group.attrs.put(new_attrs) | ||
|
||
# Load 0-th level | ||
data_czyx = da.from_zarr(init_args.origin_url + "/0") | ||
num_channels = data_czyx.shape[0] | ||
chunksize_y = data_czyx.chunksize[-2] | ||
chunksize_x = data_czyx.chunksize[-1] | ||
logger.info(f"{num_channels=}") | ||
logger.info(f"{chunksize_y=}") | ||
logger.info(f"{chunksize_x=}") | ||
|
||
# Loop over channels | ||
accumulate_chl = [] | ||
for ind_ch in range(num_channels): | ||
# Perform MIP for each channel of level 0 | ||
project_yx = da.stack( | ||
[method.apply(data_czyx[ind_ch], axis=0)], axis=0 | ||
) | ||
accumulate_chl.append(project_yx) | ||
accumulated_array = da.stack(accumulate_chl, axis=0) | ||
|
||
# Write to disk (triggering execution) | ||
try: | ||
accumulated_array.to_zarr( | ||
f"{zarr_url}/0", | ||
overwrite=init_args.overwrite, | ||
dimension_separator="/", | ||
write_empty_chunks=False, | ||
) | ||
except ContainsArrayError as e: | ||
error_msg = ( | ||
f"Cannot write array to zarr group at '{zarr_url}/0', " | ||
f"with {init_args.overwrite=} (original error: {str(e)}).\n" | ||
"Hint: try setting overwrite=True." | ||
original_ngff_image = NgffImage(init_args.origin_url) | ||
orginal_image = original_ngff_image.get_image() | ||
|
||
if orginal_image.is_2d or orginal_image.is_2d_time_series: | ||
raise ValueError( | ||
"The input image is 2D, " | ||
"projection is only supported for 3D images." | ||
) | ||
logger.error(error_msg) | ||
raise OverwriteNotAllowedError(error_msg) | ||
|
||
# Starting from on-disk highest-resolution data, build and write to disk a | ||
# pyramid of coarser levels | ||
build_pyramid( | ||
zarrurl=zarr_url, | ||
# Compute the new shape and pixel size | ||
dest_on_disk_shape = _compute_new_shape(orginal_image) | ||
|
||
dest_pixel_size = orginal_image.pixel_size | ||
dest_pixel_size.z = 1.0 | ||
ngio_logger.info(f"New shape: {dest_on_disk_shape=}") | ||
|
||
# Create the new empty image | ||
new_ngff_image = original_ngff_image.derive_new_image( | ||
store=zarr_url, | ||
name="MIP", | ||
on_disk_shape=dest_on_disk_shape, | ||
pixel_sizes=dest_pixel_size, | ||
overwrite=init_args.overwrite, | ||
num_levels=ngff_image.num_levels, | ||
coarsening_xy=ngff_image.coarsening_xy, | ||
chunksize=(1, 1, chunksize_y, chunksize_x), | ||
) | ||
new_image = new_ngff_image.get_image() | ||
|
||
# Copy over any tables from the original zarr | ||
# Generate the list of tables: | ||
tables = get_tables_list_v1(init_args.origin_url) | ||
roi_tables = get_tables_list_v1(init_args.origin_url, table_type="ROIs") | ||
non_roi_tables = [table for table in tables if table not in roi_tables] | ||
|
||
for table in roi_tables: | ||
logger.info( | ||
f"Reading {table} from " | ||
f"{init_args.origin_url=}, convert it to 2D, and " | ||
"write it back to the new zarr file." | ||
) | ||
new_ROI_table = ad.read_zarr(f"{init_args.origin_url}/tables/{table}") | ||
old_ROI_table_attrs = zarr.open_group( | ||
f"{init_args.origin_url}/tables/{table}" | ||
).attrs.asdict() | ||
|
||
# Convert 3D ROIs to 2D | ||
pxl_sizes_zyx = ngff_image.get_pixel_sizes_zyx(level=0) | ||
new_ROI_table = convert_ROIs_from_3D_to_2D( | ||
new_ROI_table, pixel_size_z=pxl_sizes_zyx[0] | ||
) | ||
# Write new table | ||
write_table( | ||
new_image_group, | ||
table, | ||
new_ROI_table, | ||
table_attrs=old_ROI_table_attrs, | ||
overwrite=init_args.overwrite, | ||
) | ||
# Process the image | ||
z_axis_index = orginal_image.find_axis("z") | ||
source_dask = orginal_image.get_array( | ||
mode="dask", preserve_dimensions=True | ||
) | ||
|
||
for table in non_roi_tables: | ||
logger.info( | ||
f"Reading {table} from " | ||
f"{init_args.origin_url=}, and " | ||
"write it back to the new zarr file." | ||
) | ||
new_non_ROI_table = ad.read_zarr( | ||
f"{init_args.origin_url}/tables/{table}" | ||
) | ||
old_non_ROI_table_attrs = zarr.open_group( | ||
f"{init_args.origin_url}/tables/{table}" | ||
).attrs.asdict() | ||
|
||
# Write new table | ||
write_table( | ||
new_image_group, | ||
table, | ||
new_non_ROI_table, | ||
table_attrs=old_non_ROI_table_attrs, | ||
overwrite=init_args.overwrite, | ||
dest_dask = method.apply(dask_array=source_dask, axis=z_axis_index) | ||
dest_dask = da.expand_dims(dest_dask, axis=z_axis_index) | ||
new_image.set_array(dest_dask) | ||
new_image.consolidate() | ||
# Ends | ||
|
||
# Copy over the tables | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Naive question by someone who is not involved in ngio development/integration: where does this feature fit best? Would it make sense to have something like ngio.copy_tables(original_ngff_image, new_ngff_image, project_z=True) or at least ngio.copy_tables(original_ngff_image, new_ngff_image)
# and then set `z_length` manually or would it be just additional complexity? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The responsability of copying the tables is now moved to the |
||
for roi_table in original_ngff_image.tables.list(table_type="roi_table"): | ||
jluethi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
table = original_ngff_image.tables.get_table(roi_table) | ||
mip_table = new_ngff_image.tables.new( | ||
roi_table, table_type="roi_table", overwrite=True | ||
) | ||
|
||
roi_list = [] | ||
for roi in table.rois: | ||
roi.z_length = roi.z + 1 | ||
jluethi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
roi_list.append(roi) | ||
|
||
mip_table.set_rois(roi_list, overwrite=True) | ||
mip_table.consolidate() | ||
ngio_logger.info(f"Table {roi_table} copied.") | ||
|
||
# Generate image_list_updates | ||
image_list_update_dict = dict( | ||
image_list_updates=[ | ||
|
@@ -189,5 +142,5 @@ | |
|
||
run_fractal_task( | ||
lorenzocerrone marked this conversation as resolved.
Show resolved
Hide resolved
|
||
task_function=projection, | ||
logger_name=logger.name, | ||
logger_name=ngio_logger.name, | ||
) |
Uh oh!
There was an error while loading. Please reload this page.