Skip to content

8328 nnunet bundle integration #8469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 166 commits into
base: dev
Choose a base branch
from
Open
Changes from all commits
Commits
Show all changes
166 commits
Select commit Hold shift + click to select a range
13bdec5
Add nnUNet integration and corresponding unit tests
SimoneBendazzoli93 Jan 28, 2025
74aaf73
Implement nnUNet model conversion to MONAI bundle format and enhance …
SimoneBendazzoli93 Feb 5, 2025
ccdc76c
Merge branch 'Project-MONAI:dev' into dev
SimoneBendazzoli93 Feb 5, 2025
b61e4e1
Refactor nnUNet bundle integration tests for clarity and remove redun…
SimoneBendazzoli93 Feb 5, 2025
8e4a66c
Code reformatting
SimoneBendazzoli93 Feb 5, 2025
49dbb5d
DCO Remediation Commit for simben <[email protected]>
SimoneBendazzoli93 Feb 5, 2025
5a04fe0
nibabel importing moved to setUp
SimoneBendazzoli93 Feb 5, 2025
24643b8
DCO Remediation Commit for simben <[email protected]>
SimoneBendazzoli93 Feb 5, 2025
253dab1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2025
43c694b
Add nnUNet Bundle documentation and related functions to bundle.rst
SimoneBendazzoli93 Feb 6, 2025
569df7b
Refactor nnUNet documentation and examples for clarity; update fold p…
SimoneBendazzoli93 Feb 6, 2025
fcf5ac0
Clean up whitespace in nnunet.py and add test for nnunet bundle integ…
SimoneBendazzoli93 Feb 6, 2025
c846b6d
Fix type conversion for folds and improve input_files type checking i…
SimoneBendazzoli93 Feb 6, 2025
2da5ca9
DCO Remediation Commit for simben <[email protected]>
SimoneBendazzoli93 Feb 6, 2025
48d53a4
Fix documentation for output tensor return in ModelnnUNetWrapper
SimoneBendazzoli93 Feb 6, 2025
230cb9b
Remove outdated method documentation for forward pass in ModelnnUNetW…
SimoneBendazzoli93 Feb 6, 2025
507bca8
Merge branch 'dev' into 8328-nnunet-bundle-integration
SimoneBendazzoli93 Feb 12, 2025
e7a44f5
Merge branch 'Project-MONAI:dev' into dev
SimoneBendazzoli93 Feb 12, 2025
8ba37ad
Merge branch 'dev' into 8328-nnunet-bundle-integration
SimoneBendazzoli93 Feb 13, 2025
2ec9207
Merge remote-tracking branch 'origin/dev' into 8328-nnunet-bundle-int…
SimoneBendazzoli93 Feb 18, 2025
bce281e
Merge remote-tracking branch 'upstream/dev' into 8328-nnunet-bundle-i…
SimoneBendazzoli93 Feb 18, 2025
ea8028f
Add integration tests for nnUNet bundle functionality
SimoneBendazzoli93 Feb 18, 2025
1a1c78d
Merge branch 'Project-MONAI:dev' into dev
SimoneBendazzoli93 Feb 18, 2025
97c98bd
Merge branch 'Project-MONAI:dev' into 8328-nnunet-bundle-integration
SimoneBendazzoli93 Feb 24, 2025
9dc6532
DCO Remediation Commit for simben <[email protected]>
SimoneBendazzoli93 Feb 24, 2025
a209e25
Merge branch 'Project-MONAI:dev' into dev
SimoneBendazzoli93 Feb 28, 2025
7b6a247
Merge remote-tracking branch 'origin/dev' into 8328-nnunet-bundle-int…
SimoneBendazzoli93 Feb 28, 2025
d29ff7d
Merge branch 'dev' into 8328-nnunet-bundle-integration
ericspod Mar 4, 2025
8bbb63b
Refactor nnUNet imports for improved module organization
SimoneBendazzoli93 Mar 10, 2025
2e6bb14
Merge branch '8328-nnunet-bundle-integration' of https://github.com/S…
SimoneBendazzoli93 Mar 10, 2025
6e97f39
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2025
a7cad28
Update nnUNet import paths and comment out Transposed transform in in…
SimoneBendazzoli93 Mar 10, 2025
6e1a2bd
Merge branch '8328-nnunet-bundle-integration' of https://github.com/S…
SimoneBendazzoli93 Mar 10, 2025
e099b08
Merge branch 'Project-MONAI:dev' into dev
SimoneBendazzoli93 Mar 10, 2025
0734eb3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2025
fe173b6
Merge remote-tracking branch 'origin/dev' into 8328-nnunet-bundle-int…
SimoneBendazzoli93 Mar 10, 2025
10aa0ce
Merge branch '8328-nnunet-bundle-integration' of https://github.com/S…
SimoneBendazzoli93 Mar 10, 2025
c107015
DCO Remediation Commit for simben <[email protected]>
SimoneBendazzoli93 Mar 10, 2025
3b5e80b
Update documentation for nnUNet Bundle integration
SimoneBendazzoli93 Mar 10, 2025
224e924
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2025
1da18e1
Refactor imports and improve code formatting in nnUNet bundle
SimoneBendazzoli93 Mar 10, 2025
04a12e2
Merge branch '8328-nnunet-bundle-integration' of https://github.com/S…
SimoneBendazzoli93 Mar 10, 2025
fe02136
DCO Remediation Commit for simben <[email protected]>
SimoneBendazzoli93 Mar 10, 2025
9f91aaf
Fix formatting in test_integration_nnunet_bundle.py for improved read…
SimoneBendazzoli93 Mar 10, 2025
a2bc247
DCO Remediation Commit for simben <[email protected]>
SimoneBendazzoli93 Mar 10, 2025
23de2bc
Refactor forward method in ModelnnUNetWrapper for clarity and type co…
SimoneBendazzoli93 Mar 10, 2025
6c32444
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2025
f43125a
Remove commented-out code in ModelnnUNetWrapper for improved readability
SimoneBendazzoli93 Mar 10, 2025
1032c9b
Merge branch '8328-nnunet-bundle-integration' of https://github.com/S…
SimoneBendazzoli93 Mar 10, 2025
30cb6c5
Comment out the torch.compile line in ModelnnUNetWrapper to prevent p…
SimoneBendazzoli93 Mar 10, 2025
2c24a6d
Merge branch 'Project-MONAI:dev' into dev
SimoneBendazzoli93 Mar 24, 2025
568d25a
Merge branch '8328-nnunet-bundle-integration' into dev
SimoneBendazzoli93 Mar 24, 2025
1a30a0b
Add JSON generation widgets for nnUNet and update requirements
SimoneBendazzoli93 Mar 24, 2025
ca851cd
Add modality_list parameter to nnUNetExecutor and related functions
SimoneBendazzoli93 Mar 25, 2025
fee1bb0
Fix modality_list check in prepare_config and add debug print statement
SimoneBendazzoli93 Mar 25, 2025
1972504
Rename nnUNetMONAIModelWrapper to ModelnnUNetWrapper for consistency
SimoneBendazzoli93 Mar 26, 2025
5c633f2
SimoneBendazzoli93 Mar 28, 2025
052ef64
Add original_dataset_name to nnunet_plans in preprocess function
SimoneBendazzoli93 Mar 28, 2025
1c41164
Remove unused nvflare module files and restore polygraphy in requirem…
SimoneBendazzoli93 Mar 30, 2025
60185d1
Add new functions to nnunet_bundle for converting between MONAI and n…
SimoneBendazzoli93 Mar 30, 2025
b0ecb2c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 30, 2025
678334e
DCO Remediation Commit for Simone Bendazzoli <[email protected]>
SimoneBendazzoli93 Mar 30, 2025
9da69fc
Merge branch '8328-nnunet-bundle-integration' of https://github.com/S…
SimoneBendazzoli93 Mar 30, 2025
6178082
Add ModelnnUNetWrapper import to nnunet bundle
SimoneBendazzoli93 Mar 30, 2025
2a83641
Merge branch 'Project-MONAI:dev' into dev
SimoneBendazzoli93 Mar 30, 2025
570d068
Merge remote-tracking branch 'origin/dev' into 8328-nnunet-bundle-int…
SimoneBendazzoli93 Mar 30, 2025
78a7d14
DCO Remediation Commit for Simone Bendazzoli <[email protected]>
SimoneBendazzoli93 Mar 30, 2025
8d132f8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 30, 2025
88a5e5a
Refactor nnUNet integration: update type hints and improve parameter …
SimoneBendazzoli93 Mar 30, 2025
49d0897
Refactor nnUNet bundle: clean up import order and improve code format…
SimoneBendazzoli93 Mar 30, 2025
0f8335b
Merge branch '8328-nnunet-bundle-integration' of https://github.com/S…
SimoneBendazzoli93 Mar 30, 2025
18d5a4c
DCO Remediation Commit for Simone Bendazzoli <[email protected]>
SimoneBendazzoli93 Mar 30, 2025
4ca028a
Enhance nnUNet bundle: add nnUNetPredictor import and update type hin…
SimoneBendazzoli93 Mar 30, 2025
050651c
DCO Remediation Commit for Simone Bendazzoli <[email protected]>
SimoneBendazzoli93 Mar 30, 2025
b881cd3
Refactor nnUNet bundle: remove unused nnUNetTrainer and nnUNetPredict…
SimoneBendazzoli93 Mar 30, 2025
88a28d2
DCO Remediation Commit for Simone Bendazzoli <[email protected]>
SimoneBendazzoli93 Mar 30, 2025
782f1fd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 30, 2025
8e510e1
Update docstring in get_nnunet_trainer to include link for supported …
SimoneBendazzoli93 Mar 30, 2025
f74a39c
Merge branch '8328-nnunet-bundle-integration' of https://github.com/S…
SimoneBendazzoli93 Mar 30, 2025
8854557
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 30, 2025
7d86a73
DCO Remediation Commit for Simone Bendazzoli <[email protected]>
SimoneBendazzoli93 Mar 30, 2025
a346719
Merge branch '8328-nnunet-bundle-integration' of https://github.com/S…
SimoneBendazzoli93 Mar 30, 2025
5422368
Update docstring in get_nnunet_trainer for better readability of supp…
SimoneBendazzoli93 Mar 30, 2025
aff9cbe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 30, 2025
5527ac8
DCO Remediation Commit for Simone Bendazzoli <[email protected]>
SimoneBendazzoli93 Mar 30, 2025
d969b79
Merge branch '8328-nnunet-bundle-integration' of https://github.com/S…
SimoneBendazzoli93 Mar 30, 2025
7d60fd7
nvflare support
SimoneBendazzoli93 Mar 31, 2025
ed2360c
Merge branch 'dev' into 8328-nnunet-bundle-integration
SimoneBendazzoli93 Apr 1, 2025
3b13218
Update requirements.txt
SimoneBendazzoli93 Apr 1, 2025
fbf6105
Add nnunet_root_folder parameter to train function
SimoneBendazzoli93 Apr 2, 2025
d1035ca
```
SimoneBendazzoli93 Apr 2, 2025
a8b0a23
Merge branch 'dev' of https://github.com/SimoneBendazzoli93/MONAI int…
SimoneBendazzoli93 Apr 2, 2025
47798af
Remove conditional print statement for torch.compile in nnUNetWrapper
SimoneBendazzoli93 Apr 2, 2025
c04cc05
Merge branch 'dev' into 8328-nnunet-bundle-integration
SimoneBendazzoli93 Apr 2, 2025
0578b22
Remove unused nvflare module files
SimoneBendazzoli93 Apr 2, 2025
6d3fb0c
Update torch version requirement in requirements.txt
SimoneBendazzoli93 Apr 2, 2025
fd12ece
Add continue_training parameter to nnUNetExecutor and update job conf…
SimoneBendazzoli93 Apr 3, 2025
fed603b
DCO Remediation Commit for Simone Bendazzoli <[email protected]>
SimoneBendazzoli93 Apr 3, 2025
d4a9a9e
Fix config file path for continue training in train function
SimoneBendazzoli93 Apr 3, 2025
b08d416
Refactor continue_training handling in job config generation and trai…
SimoneBendazzoli93 Apr 3, 2025
ff64bd0
Fix config file path for training in nvflare_nnunet.py
SimoneBendazzoli93 Apr 3, 2025
96df74c
Add finalize_bundle function to handle MONAI bundle finalization and …
SimoneBendazzoli93 Apr 7, 2025
6c6346b
Add finalize_config call in generate_configs to ensure proper configu…
SimoneBendazzoli93 Apr 7, 2025
7a84968
Add validation checks and logging in finalize_bundle for improved err…
SimoneBendazzoli93 Apr 7, 2025
b8de5c4
Add finalize task handling and update data source config path format
SimoneBendazzoli93 Apr 7, 2025
665d989
Enhance ModelnnUNetWrapper and get_nnunet_monai_predictor to accept a…
SimoneBendazzoli93 Apr 7, 2025
6480bba
Add logging of validation summary to finalize_bundle for improved tra…
SimoneBendazzoli93 Apr 8, 2025
c4b3453
Add run_job function to submit jobs with optional meta configuration …
SimoneBendazzoli93 Apr 8, 2025
2cf94c7
Store job ID after submitting job in run_job function for better trac…
SimoneBendazzoli93 Apr 8, 2025
2d09694
Add dataset_name_or_id to nnunet_config and update data source config…
SimoneBendazzoli93 Apr 8, 2025
fcf6981
Update preprocess function to return dataset names for improved data …
SimoneBendazzoli93 Apr 8, 2025
b13b3e8
Refactor process_client_response to simplify response validation logic
SimoneBendazzoli93 Apr 8, 2025
e5ad7f0
Merge remote-tracking branch 'upstream/dev' into dev
SimoneBendazzoli93 Apr 16, 2025
397d0eb
Increase timeout value from 6000 to 60000 in multiple configuration f…
SimoneBendazzoli93 Apr 16, 2025
f819112
Add nnUNetPrepareBundleJsonGenerator and cross-site validation config…
SimoneBendazzoli93 Apr 23, 2025
f8cb161
Add cross-site validation function and refactor training metrics comp…
SimoneBendazzoli93 Apr 23, 2025
d62fd04
Add utility functions for NIFTI file handling and data preparation
SimoneBendazzoli93 Apr 23, 2025
3207bb7
Rename parameter for clarity in prepare_data_folder_api function
SimoneBendazzoli93 Apr 23, 2025
c5e64c7
Add labels and regions_class_order parameters to create_new_dataset_j…
SimoneBendazzoli93 Apr 25, 2025
4a15b97
Add cross-site evaluation and validation metrics computation functions
SimoneBendazzoli93 Apr 25, 2025
71b30d7
Disable sorting of keys in JSON output for create_new_dataset_json fu…
SimoneBendazzoli93 Apr 25, 2025
188fa53
Disable sorting of keys in JSON output for create_new_dataset_json fu…
SimoneBendazzoli93 Apr 28, 2025
db2f65b
Add logging for case processing in concatenate_modalities function
SimoneBendazzoli93 Apr 29, 2025
6e82b7e
Update print statement in concatenate_modalities function to use pati…
SimoneBendazzoli93 Apr 29, 2025
e55f2f5
Add label_dict parameter to nnUNetExecutor and update bundle configur…
SimoneBendazzoli93 May 1, 2025
d6a7777
Refactor dataset_name_or_id handling in nnUNetExecutor and update net…
SimoneBendazzoli93 May 2, 2025
8c2330d
Fix argument passing for run_cross_site_validation in nnUNetExecutor
SimoneBendazzoli93 May 2, 2025
9656483
Fix train_postprocessing assignment in prepare_bundle function for re…
SimoneBendazzoli93 May 2, 2025
c267fdf
Enhance error message in _check_converted method to include local wei…
SimoneBendazzoli93 May 2, 2025
460481f
Add ResultDownloader class for job result processing and model file m…
SimoneBendazzoli93 May 3, 2025
a50ef4c
Update network configuration in prepare_bundle function and remove Re…
SimoneBendazzoli93 May 3, 2025
f63cc5e
Fix image dimension handling and enable verbose logging in get_nnunet…
SimoneBendazzoli93 May 5, 2025
ac94578
Refactor training and validation processes by introducing API functio…
SimoneBendazzoli93 May 6, 2025
9f9261d
Add dataset_name parameter to MLflow logging in prepare_bundle_api
SimoneBendazzoli93 May 7, 2025
2e6f7c9
Add original_path and model_name parameters to cross-site validation …
SimoneBendazzoli93 May 15, 2025
75d6720
Fix spacing property extraction in ModelnnUNetWrapper for MetaTensor …
SimoneBendazzoli93 May 15, 2025
406923a
Add debug print for labels and include background in validation and t…
SimoneBendazzoli93 May 16, 2025
84d25c6
Update train_data configuration in prepare_bundle_api based on nnUNet…
SimoneBendazzoli93 May 21, 2025
c8104f4
Enhance prepare_data_folder_api to support 'monai-label' dataset form…
SimoneBendazzoli93 May 27, 2025
78f6650
Add nnUNetV2Runner initialization in finalize_bundle for dataset vali…
SimoneBendazzoli93 May 27, 2025
670ff15
Update finalize_bundle to support federated training in finalize_bund…
SimoneBendazzoli93 May 27, 2025
1cd4312
Fix metrics initialization in compute_validation_metrics for proper l…
SimoneBendazzoli93 May 27, 2025
f5ea8e8
Add debug prints for HD95, ASD, and Dice metrics in compute_validatio…
SimoneBendazzoli93 May 27, 2025
310745f
Refactor compute_validation_metrics to correct metric indexing and re…
SimoneBendazzoli93 May 27, 2025
6611789
Fix initialization of mean summary in compute_validation_metrics to e…
SimoneBendazzoli93 May 27, 2025
2f111d5
Add debug print for labels in finalize_bundle to aid in validation tr…
SimoneBendazzoli93 May 27, 2025
f42840d
Fix label handling in finalize_bundle to ensure correct metric logging
SimoneBendazzoli93 May 27, 2025
50faeef
Fix label handling in finalize_bundle to support label formatting and…
SimoneBendazzoli93 May 27, 2025
0fe60f0
Add warning for missing labels in finalize_bundle to improve metric l…
SimoneBendazzoli93 May 27, 2025
f2dbc26
Fix mlflow_params key assignment in prepare_bundle_api for consistency
SimoneBendazzoli93 Jun 3, 2025
db5ae69
Fix model name check in ModelnnUNetWrapper to ensure it ends with .pt…
SimoneBendazzoli93 Jun 3, 2025
c8471e7
Merge remote-tracking branch 'upstream/dev' into dev
SimoneBendazzoli93 Jun 3, 2025
02eec30
Merge branch 'dev' into 8328-nnunet-bundle-integration
SimoneBendazzoli93 Jun 3, 2025
3b61e9f
Enhance ModelnnUNetWrapper to accept additional parameters for datase…
SimoneBendazzoli93 Jun 3, 2025
eae66cb
Fix error message in MonaiAlgo to remove local weight dict keys from …
SimoneBendazzoli93 Jun 3, 2025
70418a6
Remove unused label and regions_class_order parameters from dataset J…
SimoneBendazzoli93 Jun 3, 2025
aba934d
Refactor create_new_dataset_json to remove unused labels and regions_…
SimoneBendazzoli93 Jun 3, 2025
9500de1
Add background label to new dataset JSON creation and update foregrou…
SimoneBendazzoli93 Jun 3, 2025
f2dad2c
Fix formatting in create_new_dataset_json by removing unnecessary bla…
SimoneBendazzoli93 Jun 3, 2025
dfeb9fc
Update torch.load calls to include weights_only=False for improved mo…
SimoneBendazzoli93 Jun 3, 2025
6f7e558
DCO Remediation Commit for Simone Bendazzoli <[email protected]>
SimoneBendazzoli93 Jun 3, 2025
d6bec33
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2025
92e1108
Update type hints for dataset_json, plans, and nnunet_config paramete…
SimoneBendazzoli93 Jun 3, 2025
c10b3b6
Merge branch '8328-nnunet-bundle-integration' of https://github.com/S…
SimoneBendazzoli93 Jun 3, 2025
c90e960
DCO Remediation Commit for Simone Bendazzoli <[email protected]>
SimoneBendazzoli93 Jun 3, 2025
a748a59
Merge branch 'dev' into 8328-nnunet-bundle-integration
SimoneBendazzoli93 Jun 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 87 additions & 31 deletions monai/apps/nnunet/nnunet_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def get_nnunet_trainer(
cudnn.benchmark = True

if pretrained_model is not None:
state_dict = torch.load(pretrained_model)
state_dict = torch.load(pretrained_model, weights_only=False)
if "network_weights" in state_dict:
nnunet_trainer.network._orig_mod.load_state_dict(state_dict["network_weights"])
return nnunet_trainer
Expand All @@ -152,6 +152,12 @@ class ModelnnUNetWrapper(torch.nn.Module):
The folder path where the model and related files are stored.
model_name : str, optional
The name of the model file, by default "model.pt".
dataset_json : dict, optional
The dataset JSON file containing dataset information.
plans : dict, optional
The plans JSON file containing model configuration.
nnunet_config : dict, optional
The nnUNet configuration dictionary containing model parameters.

Attributes
----------
Expand All @@ -162,11 +168,19 @@ class ModelnnUNetWrapper(torch.nn.Module):

Notes
-----
This class integrates nnUNet model with MONAI framework by loading necessary configurations,
This class integrates nnUNet model with MONAI framework by loading configurations,
restoring network architecture, and setting up the predictor for inference.
"""

def __init__(self, predictor: object, model_folder: Union[str, Path], model_name: str = "model.pt"): # type: ignore
def __init__(
self,
predictor: object,
model_folder: Union[str, Path],
model_name: str = "model.pt",
dataset_json: Optional[dict] = None,
plans: Optional[dict] = None,
nnunet_config: Optional[dict] = None,
): # type: ignore
super().__init__()
self.predictor = predictor

Expand All @@ -175,30 +189,43 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager

# Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor
dataset_json = load_json(join(Path(model_training_output_dir).parent, "dataset.json"))
plans = load_json(join(Path(model_training_output_dir).parent, "plans.json"))
if dataset_json is None:
dataset_json = load_json(join(Path(model_training_output_dir).parent, "dataset.json"))
if plans is None:
plans = load_json(join(Path(model_training_output_dir).parent, "plans.json"))
plans_manager = PlansManager(plans)

parameters = []

checkpoint = torch.load(
join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu")
)
trainer_name = checkpoint["trainer_name"]
configuration_name = checkpoint["init_args"]["configuration"]
inference_allowed_mirroring_axes = (
checkpoint["inference_allowed_mirroring_axes"]
if "inference_allowed_mirroring_axes" in checkpoint.keys()
else None
)
if Path(model_training_output_dir).joinpath(model_name).is_file():
monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu"))
if nnunet_config is None:
checkpoint = torch.load(
join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"),
map_location=torch.device("cpu"),
weights_only=False,
)
trainer_name = checkpoint["trainer_name"]
configuration_name = checkpoint["init_args"]["configuration"]
inference_allowed_mirroring_axes = (
checkpoint["inference_allowed_mirroring_axes"]
if "inference_allowed_mirroring_axes" in checkpoint.keys()
else None
)
else:
trainer_name = nnunet_config["trainer_name"]
configuration_name = nnunet_config["configuration"]
inference_allowed_mirroring_axes = nnunet_config["inference_allowed_mirroring_axes"]

if Path(model_training_output_dir).joinpath(model_name).is_file() and model_name.endswith(".pt"):
monai_checkpoint = torch.load(
join(model_training_output_dir, model_name), map_location=torch.device("cpu"), weights_only=False
)
if "network_weights" in monai_checkpoint.keys():
parameters.append(monai_checkpoint["network_weights"])
else:
parameters.append(monai_checkpoint)

configuration_manager = plans_manager.get_configuration(configuration_name)

import nnunetv2
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
Expand Down Expand Up @@ -255,7 +282,16 @@ def forward(self, x: MetaTensor) -> MetaTensor:
"""
if isinstance(x, MetaTensor):
if "pixdim" in x.meta:
properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()}
if x.meta["pixdim"].ndim == 1:
if x.meta["pixdim"][0] == 1:
properties_or_list_of_properties = {"spacing": x.meta["pixdim"][1:4].tolist()}
else:
properties_or_list_of_properties = {"spacing": x.meta["pixdim"][:3].tolist()}
else:
if x.meta["pixdim"][0][0] == 1:
properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()}
else:
properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][:3].numpy().tolist()}
elif "affine" in x.meta:
spacing = [
abs(x.meta["affine"][0][0].item()),
Expand All @@ -269,6 +305,8 @@ def forward(self, x: MetaTensor) -> MetaTensor:
raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.")

image_or_list_of_images = x.cpu().numpy()[0, :]
image_or_list_of_images = np.transpose(image_or_list_of_images, (0, 3, 2, 1))
properties_or_list_of_properties["spacing"] = properties_or_list_of_properties["spacing"][::-1]

# input_files should be a list of file paths, one per modality
prediction_output = self.predictor.predict_from_list_of_npy_arrays( # type: ignore
Expand All @@ -286,11 +324,17 @@ def forward(self, x: MetaTensor) -> MetaTensor:
for out in prediction_output: # Add batch and channel dimensions
out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0)))
out_tensor = torch.cat(out_tensors, 0) # Concatenate along batch dimension

out_tensor = out_tensor.permute(0, 1, 4, 3, 2)
return MetaTensor(out_tensor, meta=x.meta)


def get_nnunet_monai_predictor(model_folder: Union[str, Path], model_name: str = "model.pt") -> ModelnnUNetWrapper:
def get_nnunet_monai_predictor(
model_folder: Union[str, Path],
model_name: str = "model.pt",
dataset_json: Optional[dict] = None,
plans: Optional[dict] = None,
nnunet_config: Optional[dict] = None,
) -> ModelnnUNetWrapper:
"""
Initializes and returns a `nnUNetMONAIModelWrapper` containing the corresponding `nnUNetPredictor`.
The model folder should contain the following files, created during training:
Expand Down Expand Up @@ -321,6 +365,12 @@ def get_nnunet_monai_predictor(model_folder: Union[str, Path], model_name: str =
The folder where the model is stored.
model_name : str, optional
The name of the model file, by default "model.pt".
dataset_json : dict, optional
The dataset JSON file containing dataset information.
plans : dict, optional
The plans JSON file containing model configuration.
nnunet_config : dict, optional
The nnUNet configuration dictionary containing model parameters.

Returns
-------
Expand All @@ -335,12 +385,12 @@ def get_nnunet_monai_predictor(model_folder: Union[str, Path], model_name: str =
use_gaussian=True,
use_mirroring=False,
device=torch.device("cuda", 0),
verbose=False,
verbose_preprocessing=False,
verbose=True,
verbose_preprocessing=True,
allow_tqdm=True,
)
# initializes the network architecture, loads the checkpoint
wrapper = ModelnnUNetWrapper(predictor, model_folder, model_name)
wrapper = ModelnnUNetWrapper(predictor, model_folder, model_name, dataset_json, plans, nnunet_config)
return wrapper


Expand Down Expand Up @@ -383,8 +433,12 @@ def convert_nnunet_to_monai_bundle(nnunet_config: dict, bundle_root_folder: str,
dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}"
)

nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"))
nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"))
nnunet_checkpoint_final = torch.load(
Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"), weights_only=False
)
nnunet_checkpoint_best = torch.load(
Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"), weights_only=False
)

nnunet_checkpoint = {}
nnunet_checkpoint["inference_allowed_mirroring_axes"] = nnunet_checkpoint_final["inference_allowed_mirroring_axes"]
Expand Down Expand Up @@ -470,7 +524,7 @@ def get_network_from_nnunet_plans(
if model_ckpt is None:
return network
else:
state_dict = torch.load(model_ckpt)
state_dict = torch.load(model_ckpt, weights_only=False)
network.load_state_dict(state_dict[model_key_in_ckpt])
return network

Expand Down Expand Up @@ -534,7 +588,7 @@ def subfiles(

Path(nnunet_model_folder).joinpath(f"fold_{fold}").mkdir(parents=True, exist_ok=True)

nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth")
nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth", weights_only=False)
latest_checkpoints: list[str] = subfiles(
Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_epoch", sort=True
)
Expand All @@ -545,7 +599,7 @@ def subfiles(
epochs.sort()
final_epoch: int = epochs[-1]
monai_last_checkpoint: dict = torch.load(
f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt"
f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt", weights_only=False
)

best_checkpoints: list[str] = subfiles(
Expand All @@ -558,10 +612,11 @@ def subfiles(
key_metrics.sort()
best_key_metric: str = key_metrics[-1]
monai_best_checkpoint: dict = torch.load(
f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt"
f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt", weights_only=False
)

nnunet_checkpoint["optimizer_state"] = monai_last_checkpoint["optimizer_state"]
if "optimizer_state" in monai_last_checkpoint:
nnunet_checkpoint["optimizer_state"] = monai_last_checkpoint["optimizer_state"]

nnunet_checkpoint["network_weights"] = odict()

Expand All @@ -577,7 +632,8 @@ def subfiles(

nnunet_checkpoint["network_weights"] = odict()

nnunet_checkpoint["optimizer_state"] = monai_best_checkpoint["optimizer_state"]
if "optimizer_state" in monai_last_checkpoint:
nnunet_checkpoint["optimizer_state"] = monai_best_checkpoint["optimizer_state"]

for key in monai_best_checkpoint["network_weights"]:
nnunet_checkpoint["network_weights"][key] = monai_best_checkpoint["network_weights"][key]
Expand Down
Loading