
Commit b9df37e

Merge pull request #6 from AllenNeuralDynamics/feat-pypi-release
Feat pypi release
2 parents 88d1078 + 568dd4b commit b9df37e

19 files changed: +233 -166 lines

.flake8 (new file, +7)

@@ -0,0 +1,7 @@
+[flake8]
+exclude =
+    .git,
+    __pycache__,
+    build
+max-complexity = 10
+max-line-length = 100
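The new configuration caps cyclomatic complexity at 10 and line length at 100 characters; the single-line logging calls introduced throughout the renamed modules below fit within that limit. As a rough, hypothetical sketch (not part of this commit), the same checks can be driven from Python through flake8's legacy API; whether the project's .flake8 file is discovered automatically depends on the flake8 version and working directory, and the "code/" path is only a placeholder:

import flake8.api.legacy as flake8_api

# Sketch only: build a style guide and lint a placeholder path.
# flake8 normally reads the .flake8 file from the current directory,
# but behavior can vary across versions, so treat this as illustrative.
style_guide = flake8_api.get_style_guide()
report = style_guide.check_files(["code/"])
print(f"violations found: {report.total_errors}")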

README.md (+2 -2)

@@ -1,6 +1,6 @@
-# aind-z1-cell-segmentation
+# aind-large-scale-cellpose

-Large-scale cell segmentation using cellpose for Z1 data.
+Large-scale cell segmentation using cellpose.

 The approach is the following:

File renamed without changes.

code/cell_segmentation/cellpose_segmentation/combine_gradients.py → code/aind_large_scale_cellpose/cellpose_segmentation/combine_gradients.py (+6 -12)

@@ -81,9 +81,9 @@ def execute_worker(
     )

     # Cell probability above threshold
-    cell_probability = (
-        data[0][-1] + data[1][-1] + data[2][-1] > cellprob_threshold
-    ).astype(np.uint8)
+    cell_probability = (data[0][-1] + data[1][-1] + data[2][-1] > cellprob_threshold).astype(
+        np.uint8
+    )

     # Looking at flows within cell areas
     dP_masked = dP * cell_probability

@@ -218,9 +218,7 @@ def combine_gradients(
         0,
         0,
     )
-    logger.info(
-        f"Overlap size based on cell diameter * 2: {overlap_prediction_chunksize}"
-    )
+    logger.info(f"Overlap size based on cell diameter * 2: {overlap_prediction_chunksize}")

     lazy_data = (
         ImageReaderFactory()

@@ -270,18 +268,14 @@ def combine_gradients(
     logger.info(
         f"Combined gradients: {output_combined_gradients} - chunks: {output_combined_gradients.chunks}"  # noqa: E501
     )
-    logger.info(
-        f"Cell probabilities path: {output_cellprob} - chunks: {output_cellprob.chunks}"
-    )
+    logger.info(f"Cell probabilities path: {output_cellprob} - chunks: {output_cellprob.chunks}")

     # Estimating total batches
     total_batches = np.prod(zarr_dataset.lazy_data.shape) / (
         np.prod(zarr_dataset.prediction_chunksize) * batch_size
     )
     samples_per_iter = n_workers * batch_size
-    logger.info(
-        f"Number of batches: {total_batches} - Samples per iteration: {samples_per_iter}"
-    )
+    logger.info(f"Number of batches: {total_batches} - Samples per iteration: {samples_per_iter}")

     logger.info(f"{20*'='} Starting combination of gradients {20*'='}")
     start_time = time()
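For orientation on the first hunk above: the worker sums the cell-probability channel predicted along each of the three axes, thresholds that sum, and uses the resulting binary map to restrict the flow field to cell regions. A minimal standalone sketch of that step with invented shapes and a placeholder cellprob_threshold (the real arrays and threshold come from the pipeline's configuration):

import numpy as np

# Toy per-axis predictions: each array is (channels, z, y, x),
# with the cell-probability map stored in the last channel.
data = [np.random.randn(4, 8, 16, 16).astype(np.float32) for _ in range(3)]
dP = np.random.randn(3, 8, 16, 16).astype(np.float32)  # flow field (dz, dy, dx)
cellprob_threshold = 0.0  # placeholder value

# Cell probability above threshold (same expression as the reformatted line)
cell_probability = (data[0][-1] + data[1][-1] + data[2][-1] > cellprob_threshold).astype(
    np.uint8
)

# Looking at flows within cell areas only
dP_masked = dP * cell_probability
print(cell_probability.shape, dP_masked.shape)  # (8, 16, 16) (3, 8, 16, 16)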

code/cell_segmentation/cellpose_segmentation/compute_flows.py → code/aind_large_scale_cellpose/cellpose_segmentation/compute_flows.py (+7 -20)

@@ -15,8 +15,7 @@
 import zarr
 from aind_large_scale_prediction._shared.types import ArrayLike, PathLike
 from aind_large_scale_prediction.generator.dataset import create_data_loader
-from aind_large_scale_prediction.generator.utils import (
-    recover_global_position, unpad_global_coords)
+from aind_large_scale_prediction.generator.utils import recover_global_position, unpad_global_coords
 from aind_large_scale_prediction.io import ImageReaderFactory
 from cellpose import core
 from cellpose.dynamics import follow_flows

@@ -69,9 +68,7 @@ def computing_overlapping_hist_and_seed_finding(
     # Flatten p and compute edges
     p_flatten = p.astype("int32").reshape(p.shape[0], -1)
     shape0 = p.shape[1:]
-    edges = [
-        np.arange(-0.5 - rpad, shape0[i] + 0.5 + rpad, 1) for i in range(len(shape0))
-    ]
+    edges = [np.arange(-0.5 - rpad, shape0[i] + 0.5 + rpad, 1) for i in range(len(shape0))]

     # Compute histogram
     h, _ = np.histogramdd(tuple(p_flatten), bins=edges)

@@ -94,9 +91,7 @@ def computing_overlapping_hist_and_seed_finding(
     # Compute pixel coordinates
     pix_local = np.column_stack(seeds_sorted).astype(np.uint32)
-    pix_global = pix_local + np.array(
-        [global_coord.start for global_coord in global_coords]
-    )
+    pix_global = pix_local + np.array([global_coord.start for global_coord in global_coords])

     return pix_global, pix_local, h

@@ -336,12 +331,8 @@ def generate_flows_and_centroids(
     multiprocessing.set_start_method("spawn", force=True)

     # Getting overlap prediction chunksize
-    overlap_prediction_chunksize = (0,) + tuple(
-        [axis_overlap * 2] * len(prediction_chunksize[-3:])
-    )
-    logger.info(
-        f"Overlap size based on cell diameter * 2: {overlap_prediction_chunksize}"
-    )
+    overlap_prediction_chunksize = (0,) + tuple([axis_overlap * 2] * len(prediction_chunksize[-3:]))
+    logger.info(f"Overlap size based on cell diameter * 2: {overlap_prediction_chunksize}")

     lazy_data = (
         ImageReaderFactory()

@@ -406,9 +397,7 @@ def generate_flows_and_centroids(
         np.prod(zarr_dataset.prediction_chunksize) * batch_size
     )
     samples_per_iter = n_workers * batch_size
-    logger.info(
-        f"Number of batches: {total_batches} - Samples per iteration: {samples_per_iter}"
-    )
+    logger.info(f"Number of batches: {total_batches} - Samples per iteration: {samples_per_iter}")

     logger.info(f"{20*'='} Combining flows and creating histograms {20*'='}")
     start_time = time()

@@ -465,9 +454,7 @@ def generate_flows_and_centroids(
         curr_picked_blocks = 0
         picked_blocks = []
         time_proc_blocks_end = time()
-        logger.info(
-            f"Time processing blocks: {time_proc_blocks_end - time_proc_blocks}"
-        )
+        logger.info(f"Time processing blocks: {time_proc_blocks_end - time_proc_blocks}")

     if curr_picked_blocks != 0:
         logger.info(f"Blocks not processed inside of loop: {curr_picked_blocks}")

code/cell_segmentation/cellpose_segmentation/compute_masks.py → code/aind_large_scale_cellpose/cellpose_segmentation/compute_masks.py (+14 -27)

@@ -18,8 +18,7 @@
 import psutil
 import zarr
 from aind_large_scale_prediction.generator.dataset import create_data_loader
-from aind_large_scale_prediction.generator.utils import (
-    recover_global_position, unpad_global_coords)
+from aind_large_scale_prediction.generator.utils import recover_global_position, unpad_global_coords
 from aind_large_scale_prediction.io import ImageReaderFactory
 from cellpose import metrics
 from scipy.ndimage import binary_fill_holes, grey_dilation, map_coordinates

@@ -100,9 +99,7 @@ def create_initial_mask(
     # Expand each voxel 3 voxels around it
     for i, e in enumerate(expand):
-        epix = (
-            e[:, np.newaxis] + np.expand_dims(cell_centroids[k][i], 0) - 1
-        )
+        epix = e[:, np.newaxis] + np.expand_dims(cell_centroids[k][i], 0) - 1
         # Flattenning points around a point inside ZYX
         epix = epix.flatten()

@@ -155,7 +152,10 @@
 def remove_bad_flow_masks(
-    masks: ArrayLike, flows: ArrayLike, threshold: Optional[float] = 0.4, device=None
+    masks: ArrayLike,
+    flows: ArrayLike,
+    threshold: Optional[float] = 0.4,
+    device=None,
 ) -> ArrayLike:
     """
     Removes bad flows within the generated initial mask.

@@ -219,9 +219,7 @@ def fill_holes_and_remove_small_masks(
     """

     if masks.ndim > 3 or masks.ndim < 2:
-        raise ValueError(
-            "masks_to_outlines takes 2D or 3D array, not %dD array" % masks.ndim
-        )
+        raise ValueError("masks_to_outlines takes 2D or 3D array, not %dD array" % masks.ndim)

     masks_properties = regionprops(masks)

@@ -388,9 +386,7 @@ def extract_global_to_local(
     ]

     # Mapping to the local coordinate system of the chunk
-    picked_global_ids_with_cells[..., :3] = (
-        picked_global_ids_with_cells[..., :3] - start_pos - pad
-    )
+    picked_global_ids_with_cells[..., :3] = picked_global_ids_with_cells[..., :3] - start_pos - pad

     # Validating seeds are within block boundaries
     picked_global_ids_with_cells = picked_global_ids_with_cells[

@@ -531,9 +527,7 @@ def execute_worker(
         f"Global slices: {global_coord_pos} - Unpadded global slices: {unpadded_global_slice[1:]} - Local slices: {unpadded_local_slice[1:]}"  # noqa: E501
     )

-    global_points_path = (
-        f"{cell_centroids_path}/global_seeds_{unpadded_global_slice[1:]}.npy"
-    )
+    global_points_path = f"{cell_centroids_path}/global_seeds_{unpadded_global_slice[1:]}.npy"

     # Unpadded block mask zeros if seeds don't exist in that area
     chunked_seg_mask = np.zeros(data.shape[1:], dtype=np.uint32)

@@ -685,6 +679,7 @@ def generate_masks(
         output_seg_dtype = np.uint32  # get_output_seg_data_type(n_cells=n_ids.shape[0])
         global_seeds = np.vstack((global_seeds.T, n_ids)).T
+        np.save(f"{results_folder}/cell_centroids.npy", global_seeds)

     else:
         raise ValueError("Please, provide the global seeds")

@@ -736,12 +731,8 @@ def generate_masks(
     multiprocessing.set_start_method("spawn", force=True)

     # Getting overlap prediction chunksize
-    overlap_prediction_chunksize = (0,) + tuple(
-        [axis_overlap * 2] * len(prediction_chunksize[-3:])
-    )
-    logger.info(
-        f"Overlap size based on cell diameter * 2: {overlap_prediction_chunksize}"
-    )
+    overlap_prediction_chunksize = (0,) + tuple([axis_overlap * 2] * len(prediction_chunksize[-3:]))
+    logger.info(f"Overlap size based on cell diameter * 2: {overlap_prediction_chunksize}")

     lazy_data = (
         ImageReaderFactory()

@@ -799,18 +790,14 @@ def generate_masks(
     hists = zarr.open(hists_path, "r")

-    logger.info(
-        f"Creating masks in path: {output_seg_masks} chunks: {output_seg_masks.chunks}"
-    )
+    logger.info(f"Creating masks in path: {output_seg_masks} chunks: {output_seg_masks.chunks}")

     # Estimating total batches
     total_batches = np.prod(zarr_dataset.lazy_data.shape) / (
         np.prod(zarr_dataset.prediction_chunksize) * batch_size
     )
     samples_per_iter = n_workers * batch_size
-    logger.info(
-        f"Number of batches: {total_batches} - Samples per iteration: {samples_per_iter}"
-    )
+    logger.info(f"Number of batches: {total_batches} - Samples per iteration: {samples_per_iter}")

     logger.info(f"{20*'='} Starting mask generation {20*'='}")
     start_time = time()
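Most hunks in this file are pure reformatting, but two are worth noting: extract_global_to_local maps seed coordinates from the global volume into a chunk's padded local frame before validating that they fall inside the block, and generate_masks now also saves the detected centroids to cell_centroids.npy. A small standalone sketch of the coordinate mapping; the seed values, chunk start, padding, and block shape are invented, and the in-bounds filter at the end is only a plausible version of the validation (its exact form is not shown in the hunk):

import numpy as np

# Hypothetical seeds: columns are (z, y, x, cell_id) in global coordinates
picked_global_ids_with_cells = np.array(
    [
        [130, 260, 5, 1],
        [140, 300, 40, 2],  # falls outside the toy block below
    ],
    dtype=np.int64,
)
start_pos = np.array([128, 256, 0])  # global start of the chunk (placeholder)
pad = 2                              # overlap padding (placeholder)
block_shape = (8, 16, 16)            # local (z, y, x) extent of the chunk

# Mapping to the local coordinate system of the chunk
picked_global_ids_with_cells[..., :3] = picked_global_ids_with_cells[..., :3] - start_pos - pad

# Validating seeds are within block boundaries (illustrative filter)
inside = np.all(
    (picked_global_ids_with_cells[..., :3] >= 0)
    & (picked_global_ids_with_cells[..., :3] < np.array(block_shape)),
    axis=-1,
)
print(picked_global_ids_with_cells[inside])  # only the first seed survives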

code/cell_segmentation/cellpose_segmentation/compute_percentiles.py → code/aind_large_scale_cellpose/cellpose_segmentation/compute_percentiles.py (+8 -12)

@@ -9,8 +9,7 @@
 import numpy as np
 from aind_large_scale_prediction._shared.types import ArrayLike
 from aind_large_scale_prediction.generator.utils import concatenate_lazy_data
-from aind_large_scale_prediction.generator.zarr_slice_generator import \
-    BlockedZarrArrayIterator
+from aind_large_scale_prediction.generator.zarr_slice_generator import BlockedZarrArrayIterator
 from aind_large_scale_prediction.io import extract_data
 from dask import config as da_cfg
 from distributed import Client, LocalCluster

@@ -72,9 +71,7 @@ def get_channel_percentiles(
         Dictionary with the computed percentiles.
     """
     # Iterate through the input array in steps equal to the block shape dimensions
-    slices_to_process = list(
-        BlockedZarrArrayIterator.gen_slices(array.shape, block_shape)
-    )
+    slices_to_process = list(BlockedZarrArrayIterator.gen_slices(array.shape, block_shape))

     percentiles = {}

@@ -206,17 +203,18 @@ def combine_percentiles(percentiles: Dict, method: Optional[str] = "min_max") ->
         channel_percentiles_cmb = None
         if method == "min_max":
             channel_percentiles_cmb = np.array(
-                [np.min(channel_percentiles[0]), np.max(channel_percentiles[1])]
+                [
+                    np.min(channel_percentiles[0]),
+                    np.max(channel_percentiles[1]),
+                ]
             )

         elif method == "median":
             channel_percentiles_cmb = np.median(channel_percentiles, axis=1)

         combined_percentiles.append(list(channel_percentiles_cmb))

-        print(
-            f"Channel {chn_idx}: {channel_percentiles} - cmb: {channel_percentiles_cmb}"
-        )
+        print(f"Channel {chn_idx}: {channel_percentiles} - cmb: {channel_percentiles_cmb}")

     return combined_percentiles

@@ -282,9 +280,7 @@ def compute_percentiles(
         threads_per_worker=threads_per_worker,
     )

-    combined_percentiles = combine_percentiles(
-        percentiles=percentiles, method=combine_method
-    )
+    combined_percentiles = combine_percentiles(percentiles=percentiles, method=combine_method)

     return combined_percentiles, percentiles
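The combine_percentiles hunk makes the two combination strategies easier to read: with "min_max" the global lower bound is the minimum of the per-block low percentiles and the upper bound the maximum of the per-block high percentiles, while "median" takes the median across blocks for each bound. A toy sketch of both modes for a single channel; the per-block percentile values below are made up:

import numpy as np

# Hypothetical per-block percentiles for one channel:
# row 0 = low percentile per block, row 1 = high percentile per block
channel_percentiles = np.array(
    [
        [110.0, 95.0, 120.0],      # low bound computed in each block
        [4200.0, 3900.0, 4450.0],  # high bound computed in each block
    ]
)

# "min_max": most permissive global range across blocks
min_max = np.array(
    [
        np.min(channel_percentiles[0]),
        np.max(channel_percentiles[1]),
    ]
)

# "median": typical range across blocks
median = np.median(channel_percentiles, axis=1)

print(min_max)  # [  95. 4450.]
print(median)   # [ 110. 4200.]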
