diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py
index 86b4e68864..a205009df4 100644
--- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py
+++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py
@@ -43,7 +43,8 @@ class MaisiGroupNorm3D(nn.GroupNorm):
         num_channels: Number of channels for the group norm.
         eps: Epsilon value for numerical stability.
         affine: Whether to use learnable affine parameters, default to `True`.
-        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+        norm_float16: If True, convert the output of MaisiGroupNorm3D to float16; if False, convert it to float32;
+            if None, keep the datatype of the input. Defaults to `False`.
         print_info: Whether to print information, default to `False`.
         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
     """
@@ -54,7 +55,7 @@ def __init__(
         num_channels: int,
         eps: float = 1e-5,
         affine: bool = True,
-        norm_float16: bool = False,
+        norm_float16: bool | None = False,
         print_info: bool = False,
         save_mem: bool = True,
     ):
@@ -67,6 +68,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.print_info:
             logger.info(f"MaisiGroupNorm3D with input size: {input.size()}")
 
+        target_dtype = input.dtype
+
         if len(input.shape) != 5:
             raise ValueError("Expected a 5D tensor")
 
@@ -75,13 +78,17 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
         inputs = []
         for i in range(input.size(1)):
-            array = input[:, i : i + 1, ...].to(dtype=torch.float32)
+            array = input[:, i : i + 1, ...]
+            if self.norm_float16 is not None:
+                array = array.to(dtype=torch.float32)
             mean = array.mean([2, 3, 4, 5], keepdim=True)
             std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_()
-            if self.norm_float16:
+            if self.norm_float16 is None:
+                inputs.append(((array - mean) / std).to(dtype=target_dtype))
+            elif self.norm_float16:
                 inputs.append(((array - mean) / std).to(dtype=torch.float16))
             else:
-                inputs.append((array - mean) / std)
+                inputs.append(((array - mean) / std).to(dtype=torch.float32))
 
         del input
         _empty_cuda_cache(self.save_mem)
@@ -393,7 +400,8 @@ class MaisiResBlock(nn.Module):
         out_channels: Number of output channels.
         num_splits: Number of splits for the input tensor.
         dim_split: Dimension of splitting for the input tensor.
-        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+        norm_float16: If True, convert the output of MaisiGroupNorm3D to float16; if False, convert it to float32;
+            if None, keep the datatype of the input. Defaults to `False`.
         print_info: Whether to print information, default to `False`.
         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
     """
@@ -407,7 +415,7 @@ def __init__(
         out_channels: int,
         num_splits: int,
         dim_split: int,
-        norm_float16: bool = False,
+        norm_float16: bool | None = False,
         print_info: bool = False,
         save_mem: bool = True,
     ) -> None:
@@ -524,7 +532,8 @@ class MaisiEncoder(nn.Module):
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         num_splits: Number of splits for the input tensor.
         dim_split: Dimension of splitting for the input tensor.
-        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+        norm_float16: If True, convert the output of MaisiGroupNorm3D to float16; if False, convert it to float32;
+            if None, keep the datatype of the input. Defaults to `False`.
         print_info: Whether to print information, default to `False`.
         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
     """
@@ -541,7 +550,7 @@ def __init__(
         attention_levels: Sequence[bool],
         num_splits: int,
         dim_split: int,
-        norm_float16: bool = False,
+        norm_float16: bool | None = False,
         print_info: bool = False,
         save_mem: bool = True,
         with_nonlocal_attn: bool = True,
@@ -714,7 +723,8 @@ class MaisiDecoder(nn.Module):
         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
         num_splits: Number of splits for the input tensor.
         dim_split: Dimension of splitting for the input tensor.
-        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+        norm_float16: If True, convert the output of MaisiGroupNorm3D to float16; if False, convert it to float32;
+            if None, keep the datatype of the input. Defaults to `False`.
         print_info: Whether to print information, default to `False`.
         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
     """
@@ -731,7 +741,7 @@ def __init__(
         attention_levels: Sequence[bool],
         num_splits: int,
         dim_split: int,
-        norm_float16: bool = False,
+        norm_float16: bool | None = False,
         print_info: bool = False,
         save_mem: bool = True,
         with_nonlocal_attn: bool = True,
@@ -905,7 +915,8 @@ class AutoencoderKlMaisi(AutoencoderKL):
         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
         num_splits: Number of splits for the input tensor.
         dim_split: Dimension of splitting for the input tensor.
-        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+        norm_float16: If True, convert the output of MaisiGroupNorm3D to float16; if False, convert it to float32;
+            if None, keep the datatype of the input. Defaults to `False`.
         print_info: Whether to print information, default to `False`.
         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
""" @@ -930,7 +941,7 @@ def __init__( use_convtranspose: bool = False, num_splits: int = 16, dim_split: int = 0, - norm_float16: bool = False, + norm_float16: bool | None = False, print_info: bool = False, save_mem: bool = True, ) -> None: diff --git a/tests/apps/maisi/networks/test_autoencoderkl_maisi.py b/tests/apps/maisi/networks/test_autoencoderkl_maisi.py index 6b9aae1d17..0834e6278b 100644 --- a/tests/apps/maisi/networks/test_autoencoderkl_maisi.py +++ b/tests/apps/maisi/networks/test_autoencoderkl_maisi.py @@ -75,27 +75,43 @@ else: CASES = CASES_NO_ATTENTION +test_dtypes = [torch.float32] +if device.type == "cuda": + test_dtypes.append(torch.bfloat16) + test_dtypes.append(torch.float16) + +DTYPE_CASES = [] +for dtype in test_dtypes: + for case in CASES: + for norm_float in [False, None]: + if dtype != torch.float32 and norm_float is not None: + continue + new_case = [{**case[0], "norm_float16": norm_float}, case[1], case[2], case[3]] # type: ignore[dict-item] + DTYPE_CASES.append(new_case + [dtype]) + class TestAutoencoderKlMaisi(unittest.TestCase): - @parameterized.expand(CASES) - def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape): - net = AutoencoderKlMaisi(**input_param).to(device) + + @parameterized.expand(DTYPE_CASES) + def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape, dtype): + net = AutoencoderKlMaisi(**input_param).to(device=device, dtype=dtype) + print(input_param) with eval_mode(net): - result = net.forward(torch.randn(input_shape).to(device)) + result = net.forward(torch.randn(input_shape).to(device=device, dtype=dtype)) self.assertEqual(result[0].shape, expected_shape) self.assertEqual(result[1].shape, expected_latent_shape) self.assertEqual(result[2].shape, expected_latent_shape) - @parameterized.expand(CASES) + @parameterized.expand(DTYPE_CASES) @SkipIfBeforePyTorchVersion((1, 11)) def test_shape_with_convtranspose_and_checkpointing( - self, input_param, input_shape, expected_shape, expected_latent_shape + self, input_param, input_shape, expected_shape, expected_latent_shape, dtype ): input_param = input_param.copy() input_param.update({"use_checkpointing": True, "use_convtranspose": True}) - net = AutoencoderKlMaisi(**input_param).to(device) + net = AutoencoderKlMaisi(**input_param).to(device=device, dtype=dtype) with eval_mode(net): - result = net.forward(torch.randn(input_shape).to(device)) + result = net.forward(torch.randn(input_shape).to(device=device, dtype=dtype)) self.assertEqual(result[0].shape, expected_shape) self.assertEqual(result[1].shape, expected_latent_shape) self.assertEqual(result[2].shape, expected_latent_shape)