diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py
index 7caff7b6..1d90e906 100644
--- a/napari_cellseg3d/code_models/model_framework.py
+++ b/napari_cellseg3d/code_models/model_framework.py
@@ -3,6 +3,8 @@
 from typing import TYPE_CHECKING
 
 import torch
+import torch.backends
+import torch.backends.mps
 
 if TYPE_CHECKING:
     import napari
@@ -94,6 +96,14 @@ def __init__(
         available_devices = ["CPU"] + [
             f"GPU {i}" for i in range(torch.cuda.device_count())
         ]
+        from napari_cellseg3d.utils import _is_mps_available
+
+        try:
+            if _is_mps_available(torch):
+                available_devices.append("MPS (beta)")
+        except Exception as e:
+            logger.error(f"Error while checking MPS availability : {e}")
+
         self.device_choice = ui.DropdownMenu(
             available_devices,
             parent=self,
@@ -345,6 +355,8 @@ def check_device_choice(self):
         elif "GPU" in choice:
             i = int(choice.split(" ")[1])
             device = f"cuda:{i}"
+        elif choice == "MPS (beta)":  # TODO : check if MPS is available
+            device = "mps"
         else:
             device = self.get_device()
         logger.debug(f"DEVICE choice : {device}")
diff --git a/napari_cellseg3d/code_models/worker_inference.py b/napari_cellseg3d/code_models/worker_inference.py
index 46ba77eb..38e2f036 100644
--- a/napari_cellseg3d/code_models/worker_inference.py
+++ b/napari_cellseg3d/code_models/worker_inference.py
@@ -838,6 +838,14 @@ def inference(self):
             torch.set_num_threads(1)  # required for threading on macOS ?
             self.log("Number of threads has been set to 1 for macOS")
 
+        if self.config.device == "mps":
+            from os import environ
+
+            # NOTE(review): PyTorch may only honor PYTORCH_ENABLE_MPS_FALLBACK
+            # if it is set before torch is imported; setting it this late might
+            # have no effect — confirm, or export it at process start instead.
+            environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
         try:
             dims = self.config.model_info.model_input_size
             self.log(f"MODEL DIMS : {dims}")
diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py
index 45f50582..fd9f2f32 100644
--- a/napari_cellseg3d/utils.py
+++ b/napari_cellseg3d/utils.py
@@ -7,6 +7,7 @@
 
 import napari
 import numpy as np
+import pkg_resources
 import torch
 from monai.transforms import Zoom
 from numpy.random import PCG64, Generator
@@ -645,3 +646,21 @@ def fraction_above_threshold(volume: np.array, threshold=0.5) -> float:
         f"non zero in above_thresh : {np.count_nonzero(above_thresh)}"
     )
     return np.count_nonzero(above_thresh) / np.size(flattened)
+
+
+def _is_mps_available(torch):
+    """Return True if torch is >= 1.12 and its MPS backend is available and built."""
+    # NOTE(review): pkg_resources is deprecated in modern setuptools; prefer
+    # packaging.version.parse (or torch.torch_version) when dependencies allow.
+    available = False
+    if pkg_resources.parse_version(
+        torch.__version__
+    ) >= pkg_resources.parse_version("1.12"):
+        LOGGER.debug("Torch version is 1.12 or higher, compatible with MPS")
+        if torch.backends.mps.is_available():
+            LOGGER.debug("MPS is available")
+            if torch.backends.mps.is_built():
+                LOGGER.debug("MPS is built")
+                available = True
+
+    return available