Skip to content

Commit 43362db

Browse files
committed
adding pretrained option to cellpose model name
1 parent b9df37e commit 43362db

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

code/aind_large_scale_cellpose/cellpose_segmentation/predict_gradients.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import logging
88
import multiprocessing
9+
from pathlib import Path
910
from time import time
1011
from typing import Callable, Dict, List, Optional, Tuple
1112

@@ -536,7 +537,7 @@ def large_scale_cellpose_gradients_per_axis(
536537
super_chunksize: Optional[Tuple[int, ...]] = None,
537538
lazy_callback_fn: Optional[Callable[[ArrayLike], ArrayLike]] = None,
538539
global_normalization: Optional[bool] = True,
539-
model_name: Optional[str] = "cyto",
540+
model_name: Optional[PathLike] = "cyto",
540541
cell_diameter: Optional[int] = 15,
541542
cell_channels: Optional[List[int]] = [0, 0],
542543
chn_percentiles: Optional[Dict] = None,
@@ -591,8 +592,9 @@ def large_scale_cellpose_gradients_per_axis(
591592
global_normalization: Optional[bool] = True
592593
If we want to normalize the data for cellpose.
593594
594-
model_name: Optional[str] = "cyto"
595-
Model name to be used by cellpose
595+
model_name: Optional[PathLike] = "cyto"
596+
Model name to be used by cellpose. This could also be a path
597+
pointing to a pretrained model.
596598
597599
cell_diameter: Optional[int] = 15
598600
Cell diameter for cellpose
@@ -710,7 +712,17 @@ def large_scale_cellpose_gradients_per_axis(
710712

711713
# Getting current GPU device and inizialing cellpose network
712714
sdevice, gpu = assign_device(use_torch=use_GPU, gpu=use_GPU)
713-
model = CellposeModel(gpu=gpu, model_type=model_name, diam_mean=cell_diameter, device=sdevice)
715+
716+
# Loading model, could be a pretrained model
717+
if Path(model_name).exists():
718+
model = CellposeModel(
719+
gpu=gpu, pretrained_model=str(model_name), diam_mean=cell_diameter, device=sdevice
720+
)
721+
722+
else:
723+
model = CellposeModel(
724+
gpu=gpu, model_type=str(model_name), diam_mean=cell_diameter, device=sdevice
725+
)
714726

715727
# Estimating total batches
716728
total_batches = np.prod(zarr_dataset.lazy_data.shape) / (

0 commit comments

Comments
 (0)