|
6 | 6 |
|
7 | 7 | import logging
|
8 | 8 | import multiprocessing
|
| 9 | +from pathlib import Path |
9 | 10 | from time import time
|
10 | 11 | from typing import Callable, Dict, List, Optional, Tuple
|
11 | 12 |
|
@@ -536,7 +537,7 @@ def large_scale_cellpose_gradients_per_axis(
|
536 | 537 | super_chunksize: Optional[Tuple[int, ...]] = None,
|
537 | 538 | lazy_callback_fn: Optional[Callable[[ArrayLike], ArrayLike]] = None,
|
538 | 539 | global_normalization: Optional[bool] = True,
|
539 |
| - model_name: Optional[str] = "cyto", |
| 540 | + model_name: Optional[PathLike] = "cyto", |
540 | 541 | cell_diameter: Optional[int] = 15,
|
541 | 542 | cell_channels: Optional[List[int]] = [0, 0],
|
542 | 543 | chn_percentiles: Optional[Dict] = None,
|
@@ -591,8 +592,9 @@ def large_scale_cellpose_gradients_per_axis(
|
591 | 592 | global_normalization: Optional[bool] = True
|
592 | 593 | If we want to normalize the data for cellpose.
|
593 | 594 |
|
594 |
| - model_name: Optional[str] = "cyto" |
595 |
| - Model name to be used by cellpose |
| 595 | + model_name: Optional[PathLike] = "cyto" |
| 596 | + Model name to be used by cellpose. This could also be a path |
| 597 | + pointing to a pretrained model. |
596 | 598 |
|
597 | 599 | cell_diameter: Optional[int] = 15
|
598 | 600 | Cell diameter for cellpose
|
@@ -710,7 +712,17 @@ def large_scale_cellpose_gradients_per_axis(
|
710 | 712 |
|
711 | 713 | # Getting current GPU device and inizialing cellpose network
|
712 | 714 | sdevice, gpu = assign_device(use_torch=use_GPU, gpu=use_GPU)
|
713 |
| - model = CellposeModel(gpu=gpu, model_type=model_name, diam_mean=cell_diameter, device=sdevice) |
| 715 | + |
| 716 | + # Loading model, could be a pretrained model |
| 717 | + if Path(model_name).exists(): |
| 718 | + model = CellposeModel( |
| 719 | + gpu=gpu, pretrained_model=str(model_name), diam_mean=cell_diameter, device=sdevice |
| 720 | + ) |
| 721 | + |
| 722 | + else: |
| 723 | + model = CellposeModel( |
| 724 | + gpu=gpu, model_type=str(model_name), diam_mean=cell_diameter, device=sdevice |
| 725 | + ) |
714 | 726 |
|
715 | 727 | # Estimating total batches
|
716 | 728 | total_batches = np.prod(zarr_dataset.lazy_data.shape) / (
|
|
0 commit comments