
Commit 20372f0

Lucas-rbnt and KumoLiu authored
Implementation of a Masked Autoencoder for representation learning (#8152)
This follows a previous PR (#7598), in which the official implementation was under an incompatible license. This is a clean-sheet implementation I developed. The code is fairly straightforward, involving a transformer encoder and decoder; the primary changes are in how masks are selected and how patches are organized as they pass through the model.

In the official masked autoencoder implementation, noise is first generated and then sorted twice using `torch.argsort`. This rearranges the tokens and identifies which ones are retained, ultimately selecting only a subset of the shuffled indices. In our implementation, we use `torch.multinomial` to generate the mask indices, followed by simple indexing to manage the sub-selection of patches for encoding and the reordering with mask tokens in the decoder (a minimal sketch contrasting the two approaches follows this message). **Let me know if you need a detailed, line-by-line explanation of the new code, including how it works and how it differs from the previous version.**

### Description

Implementation of the Masked Autoencoder as described in the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/pdf/2111.06377.pdf) by Kaiming He et al. Its effectiveness has already been demonstrated on medical tasks in [Self Pre-training with Masked Autoencoders for Medical Image Classification and Segmentation](https://arxiv.org/abs/2203.05573). The PR contains the architecture and the associated unit tests.

**Note:** The output includes the prediction, a tensor of size ($BS$, $N_{tokens}$, $D$), and the associated mask of size ($BS$, $N_{tokens}$). The mask is used to apply the loss only to masked patches, but I'm not sure this is the best output format; what do you think?

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested with the `make html` command in the `docs/` folder.

---

Signed-off-by: Lucas Robinet <[email protected]>
Signed-off-by: Lucas Robinet <[email protected]>
Co-authored-by: YunLiu <[email protected]>
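As referenced above, a minimal, self-contained sketch contrasting the two masking strategies. The shapes and the 0.75 ratio are illustrative, and the `argsort` variant paraphrases the official MAE recipe rather than quoting it:

```python
import torch

batch_size, num_tokens, dim = 2, 8, 4
x = torch.randn(batch_size, num_tokens, dim)  # embedded patch tokens
num_keep = int(num_tokens * (1 - 0.75))       # tokens that stay visible

# Official MAE style: random noise sorted twice with torch.argsort.
noise = torch.rand(batch_size, num_tokens)
ids_shuffle = torch.argsort(noise, dim=1)        # random permutation of the tokens
ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation, used later by the decoder
ids_keep = ids_shuffle[:, :num_keep]
x_kept = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, dim))

# This PR's style: sample the kept indices directly, then index.
selected = torch.multinomial(torch.ones(batch_size, num_tokens), num_keep, replacement=False)
x_kept = x[torch.arange(batch_size).unsqueeze(1), selected]
mask = torch.ones(batch_size, num_tokens, dtype=torch.int)
mask[torch.arange(batch_size).unsqueeze(1), selected] = 0  # 0 = visible, 1 = masked
```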
1 parent: e73257c

File tree

4 files changed: +377 −0 lines

docs/source/networks.rst (+5)

```diff
@@ -630,6 +630,11 @@ Nets
 .. autoclass:: ViTAutoEnc
     :members:
 
+`MaskedAutoEncoderViT`
+~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: MaskedAutoEncoderViT
+    :members:
+
 `FullyConnectedNet`
 ~~~~~~~~~~~~~~~~~~~
 .. autoclass:: FullyConnectedNet
```

monai/networks/nets/__init__.py (+1)

```diff
@@ -53,6 +53,7 @@
 from .generator import Generator
 from .highresnet import HighResBlock, HighResNet
 from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
+from .masked_autoencoder_vit import MaskedAutoEncoderViT
 from .mednext import (
     MedNeXt,
     MedNext,
```
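With this re-export in place, the new network is importable from the usual nets namespace:

```python
from monai.networks.nets import MaskedAutoEncoderViT
```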
monai/networks/nets/masked_autoencoder_vit.py (+211)

```python
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence

import numpy as np
import torch
import torch.nn as nn

from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
from monai.networks.blocks.transformerblock import TransformerBlock
from monai.networks.layers import trunc_normal_
from monai.utils import ensure_tuple_rep
from monai.utils.module import look_up_option

SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"}

__all__ = ["MaskedAutoEncoderViT"]


class MaskedAutoEncoderViT(nn.Module):
    """
    Masked Autoencoder (ViT), based on: "Kaiming He et al.,
    Masked Autoencoders Are Scalable Vision Learners <https://arxiv.org/abs/2111.06377>"
    Only a subset of the patches passes through the encoder. The decoder tries to reconstruct
    the masked patches, resulting in improved training speed.
    """

    def __init__(
        self,
        in_channels: int,
        img_size: Sequence[int] | int,
        patch_size: Sequence[int] | int,
        hidden_size: int = 768,
        mlp_dim: int = 512,
        num_layers: int = 12,
        num_heads: int = 12,
        masking_ratio: float = 0.75,
        decoder_hidden_size: int = 384,
        decoder_mlp_dim: int = 512,
        decoder_num_layers: int = 4,
        decoder_num_heads: int = 12,
        proj_type: str = "conv",
        pos_embed_type: str = "sincos",
        decoder_pos_embed_type: str = "sincos",
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
        qkv_bias: bool = False,
        save_attn: bool = False,
    ) -> None:
        """
        Args:
            in_channels: number of input channels.
            img_size: dimension of input image.
            patch_size: dimension of patch size.
            hidden_size: dimension of hidden layer. Defaults to 768.
            mlp_dim: dimension of feedforward layer. Defaults to 512.
            num_layers: number of transformer blocks. Defaults to 12.
            num_heads: number of attention heads. Defaults to 12.
            masking_ratio: ratio of patches to be masked. Defaults to 0.75.
            decoder_hidden_size: dimension of hidden layer for decoder. Defaults to 384.
            decoder_mlp_dim: dimension of feedforward layer for decoder. Defaults to 512.
            decoder_num_layers: number of transformer blocks for decoder. Defaults to 4.
            decoder_num_heads: number of attention heads for decoder. Defaults to 12.
            proj_type: patch embedding layer type. Defaults to "conv".
            pos_embed_type: position embedding type. Defaults to "sincos".
            decoder_pos_embed_type: position embedding type for the decoder. Defaults to "sincos".
            dropout_rate: fraction of the input units to drop. Defaults to 0.0.
            spatial_dims: number of spatial dimensions. Defaults to 3.
            qkv_bias: apply bias to the qkv linear layer in the self-attention block. Defaults to False.
            save_attn: make the attention in the self-attention block accessible. Defaults to False.

        Examples::

            # for single channel input with image size of (96,96,96) and sin-cos positional encoding
            >>> net = MaskedAutoEncoderViT(in_channels=1, img_size=(96,96,96), patch_size=(16,16,16),
                    pos_embed_type='sincos')

            # for 3-channel with image size of (128,128,128) and a learnable positional encoding
            >>> net = MaskedAutoEncoderViT(in_channels=3, img_size=128, patch_size=16, pos_embed_type='learnable')

            # for 3-channel with image size of (224,224) and a masking ratio of 0.25
            >>> net = MaskedAutoEncoderViT(in_channels=3, img_size=(224,224), patch_size=(16,16), masking_ratio=0.25,
                    spatial_dims=2)
        """

        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError(f"dropout_rate should be between 0 and 1, got {dropout_rate}.")

        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size should be divisible by num_heads.")

        if decoder_hidden_size % decoder_num_heads != 0:
            raise ValueError("decoder_hidden_size should be divisible by decoder_num_heads.")

        self.patch_size = ensure_tuple_rep(patch_size, spatial_dims)
        self.img_size = ensure_tuple_rep(img_size, spatial_dims)
        self.spatial_dims = spatial_dims
        for m, p in zip(self.img_size, self.patch_size):
            if m % p != 0:
                raise ValueError(f"img_size={img_size} should be divisible by patch_size={patch_size}.")

        self.decoder_hidden_size = decoder_hidden_size

        if masking_ratio <= 0 or masking_ratio >= 1:
            raise ValueError(f"masking_ratio should be in the range (0, 1), got {masking_ratio}.")

        self.masking_ratio = masking_ratio
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))

        self.patch_embedding = PatchEmbeddingBlock(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=patch_size,
            hidden_size=hidden_size,
            num_heads=num_heads,
            proj_type=proj_type,
            pos_embed_type=pos_embed_type,
            dropout_rate=dropout_rate,
            spatial_dims=self.spatial_dims,
        )
        blocks = [
            TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
            for _ in range(num_layers)
        ]
        self.blocks = nn.Sequential(*blocks, nn.LayerNorm(hidden_size))

        # decoder
        self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size)

        self.mask_tokens = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))

        self.decoder_pos_embed_type = look_up_option(decoder_pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES)
        self.decoder_pos_embedding = nn.Parameter(torch.zeros(1, self.patch_embedding.n_patches, decoder_hidden_size))

        decoder_blocks = [
            TransformerBlock(decoder_hidden_size, decoder_mlp_dim, decoder_num_heads, dropout_rate, qkv_bias, save_attn)
            for _ in range(decoder_num_layers)
        ]
        self.decoder_blocks = nn.Sequential(*decoder_blocks, nn.LayerNorm(decoder_hidden_size))
        self.decoder_pred = nn.Linear(decoder_hidden_size, int(np.prod(self.patch_size)) * in_channels)

        self._init_weights()

    def _init_weights(self):
        """
        Similar to monai/networks/blocks/patchembedding.py for the decoder positional encoding and for the mask and
        classification tokens.
        """
        if self.decoder_pos_embed_type == "none":
            pass
        elif self.decoder_pos_embed_type == "learnable":
            trunc_normal_(self.decoder_pos_embedding, mean=0.0, std=0.02, a=-2.0, b=2.0)
        elif self.decoder_pos_embed_type == "sincos":
            grid_size = []
            for in_size, pa_size in zip(self.img_size, self.patch_size):
                grid_size.append(in_size // pa_size)

            self.decoder_pos_embedding = build_sincos_position_embedding(
                grid_size, self.decoder_hidden_size, self.spatial_dims
            )

        else:
            raise ValueError(f"decoder_pos_embed_type {self.decoder_pos_embed_type} not supported.")

        # initialize the mask and classification tokens like nn.Linear (instead of nn.Conv2d)
        trunc_normal_(self.mask_tokens, mean=0.0, std=0.02, a=-2.0, b=2.0)
        trunc_normal_(self.cls_token, mean=0.0, std=0.02, a=-2.0, b=2.0)

    def _masking(self, x, masking_ratio: float | None = None):
        batch_size, num_tokens, _ = x.shape
        percentage_to_keep = 1 - masking_ratio if masking_ratio is not None else 1 - self.masking_ratio
        # sample, for each batch element, the indices of the tokens that stay visible
        selected_indices = torch.multinomial(
            torch.ones(batch_size, num_tokens), int(percentage_to_keep * num_tokens), replacement=False
        )
        x_masked = x[torch.arange(batch_size).unsqueeze(1), selected_indices]  # gather the selected tokens
        # binary mask over all tokens: 0 = visible (kept), 1 = masked
        mask = torch.ones(batch_size, num_tokens, dtype=torch.int).to(x.device)
        mask[torch.arange(batch_size).unsqueeze(-1), selected_indices] = 0

        return x_masked, selected_indices, mask

    def forward(self, x, masking_ratio: float | None = None):
        x = self.patch_embedding(x)
        x, selected_indices, mask = self._masking(x, masking_ratio=masking_ratio)

        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # encoder: only the visible tokens (plus the cls token) go through the blocks
        x = self.blocks(x)

        # decoder
        x = self.decoder_embed(x)

        # scatter the encoded visible tokens back to their original positions,
        # filling the masked positions with the learnable mask token
        x_ = self.mask_tokens.repeat(x.shape[0], mask.shape[1], 1)
        x_[torch.arange(x.shape[0]).unsqueeze(-1), selected_indices] = x[:, 1:, :]  # no cls token
        x_ = x_ + self.decoder_pos_embedding
        x = torch.cat([x[:, :1, :], x_], dim=1)
        x = self.decoder_blocks(x)
        x = self.decoder_pred(x)

        x = x[:, 1:, :]  # drop the cls token from the reconstruction
        return x, mask
```
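Regarding the output format discussed in the description, here is a minimal sketch (not part of the PR) of how the returned `(prediction, mask)` pair can drive a reconstruction loss restricted to masked patches. The `patchify_2d` helper is hypothetical, and its intra-patch ordering would need to be checked against `PatchEmbeddingBlock` before real use:

```python
import torch
from monai.networks.nets import MaskedAutoEncoderViT


def patchify_2d(imgs: torch.Tensor, p: int) -> torch.Tensor:
    # hypothetical helper: (B, C, H, W) -> (B, num_patches, p*p*C)
    b, c, h, w = imgs.shape
    x = imgs.reshape(b, c, h // p, p, w // p, p)
    return x.permute(0, 2, 4, 3, 5, 1).reshape(b, (h // p) * (w // p), p * p * c)


net = MaskedAutoEncoderViT(in_channels=1, img_size=(32, 32), patch_size=(8, 8), spatial_dims=2)
imgs = torch.randn(2, 1, 32, 32)
pred, mask = net(imgs)         # pred: (2, 16, 64), mask: (2, 16) with 1 = masked
target = patchify_2d(imgs, 8)  # (2, 16, 64)

loss_per_patch = ((pred - target) ** 2).mean(dim=-1)  # (2, 16), MSE per patch
loss = (loss_per_patch * mask).sum() / mask.sum()     # average over masked patches only
```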

tests/test_masked_autoencoder_vit.py (+160)

```python
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets.masked_autoencoder_vit import MaskedAutoEncoderViT
from tests.utils import skip_if_quick

TEST_CASE_MaskedAutoEncoderViT = []
for masking_ratio in [0.5]:
    for dropout_rate in [0.6]:
        for in_channels in [4]:
            for hidden_size in [768]:
                for img_size in [96, 128]:
                    for patch_size in [16]:
                        for num_heads in [12]:
                            for mlp_dim in [3072]:
                                for num_layers in [4]:
                                    for decoder_hidden_size in [384]:
                                        for decoder_mlp_dim in [512]:
                                            for decoder_num_layers in [4]:
                                                for decoder_num_heads in [16]:
                                                    for pos_embed_type in ["sincos", "learnable"]:
                                                        for proj_type in ["conv", "perceptron"]:
                                                            for nd in (2, 3):
                                                                test_case = [
                                                                    {
                                                                        "in_channels": in_channels,
                                                                        "img_size": (img_size,) * nd,
                                                                        "patch_size": (patch_size,) * nd,
                                                                        "hidden_size": hidden_size,
                                                                        "mlp_dim": mlp_dim,
                                                                        "num_layers": num_layers,
                                                                        "decoder_hidden_size": decoder_hidden_size,
                                                                        "decoder_mlp_dim": decoder_mlp_dim,
                                                                        "decoder_num_layers": decoder_num_layers,
                                                                        "decoder_num_heads": decoder_num_heads,
                                                                        "pos_embed_type": pos_embed_type,
                                                                        "masking_ratio": masking_ratio,
                                                                        "decoder_pos_embed_type": pos_embed_type,
                                                                        "num_heads": num_heads,
                                                                        "proj_type": proj_type,
                                                                        "dropout_rate": dropout_rate,
                                                                    },
                                                                    (2, in_channels, *([img_size] * nd)),
                                                                    (2, (img_size // patch_size) ** nd, in_channels * (patch_size**nd)),
                                                                ]
                                                                if nd == 2:
                                                                    test_case[0]["spatial_dims"] = 2  # type: ignore
                                                                TEST_CASE_MaskedAutoEncoderViT.append(test_case)

TEST_CASE_ill_args = [
    [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (16, 16, 16), "dropout_rate": 5.0}],
    [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "pos_embed_type": "sin"}],
    [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "decoder_pos_embed_type": "sin"}],
    [{"in_channels": 1, "img_size": (32, 32, 32), "patch_size": (64, 64, 64)}],
    [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "num_layers": 12, "num_heads": 14}],
    [{"in_channels": 1, "img_size": (97, 97, 97), "patch_size": (16, 16, 16)}],
    [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "masking_ratio": 1.1}],
    [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "masking_ratio": -0.1}],
]


@skip_if_quick
class TestMaskedAutoencoderViT(unittest.TestCase):

    @parameterized.expand(TEST_CASE_MaskedAutoEncoderViT)
    def test_shape(self, input_param, input_shape, expected_shape):
        net = MaskedAutoEncoderViT(**input_param)
        with eval_mode(net):
            result, _ = net(torch.randn(input_shape))
        self.assertEqual(result.shape, expected_shape)

    def test_frozen_pos_embedding(self):
        net = MaskedAutoEncoderViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16))

        self.assertEqual(net.decoder_pos_embedding.requires_grad, False)

    @parameterized.expand(TEST_CASE_ill_args)
    def test_ill_arg(self, input_param):
        with self.assertRaises(ValueError):
            MaskedAutoEncoderViT(**input_param)

    def test_access_attn_matrix(self):
        # input format
        in_channels = 1
        img_size = (96, 96, 96)
        patch_size = (16, 16, 16)
        in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2])

        # no data in the matrix
        no_matrix_access_blk = MaskedAutoEncoderViT(in_channels=in_channels, img_size=img_size, patch_size=patch_size)
        no_matrix_access_blk(torch.randn(in_shape))
        assert isinstance(no_matrix_access_blk.blocks[0].attn.att_mat, torch.Tensor)
        # number of elements is zero
        assert no_matrix_access_blk.blocks[0].attn.att_mat.nelement() == 0

        # be able to access the attention matrix
        matrix_access_blk = MaskedAutoEncoderViT(
            in_channels=in_channels, img_size=img_size, patch_size=patch_size, save_attn=True
        )
        matrix_access_blk(torch.randn(in_shape))

        # (96/16)**3 = 216 patches, 25% kept with the default masking ratio of 0.75,
        # i.e. 54 visible tokens plus the cls token -> attention over 55 tokens
        assert matrix_access_blk.blocks[0].attn.att_mat.shape == (in_shape[0], 12, 55, 55)

    def test_masking_ratio(self):
        # input format
        in_channels = 1
        img_size = (96, 96, 96)
        patch_size = (16, 16, 16)
        in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2])

        # masking ratio 0.25
        masking_ratio_blk = MaskedAutoEncoderViT(
            in_channels=in_channels, img_size=img_size, patch_size=patch_size, masking_ratio=0.25, save_attn=True
        )
        masking_ratio_blk(torch.randn(in_shape))
        desired_num_tokens = int(
            (img_size[0] // patch_size[0])
            * (img_size[1] // patch_size[1])
            * (img_size[2] // patch_size[2])
            * (1 - 0.25)
        )
        assert masking_ratio_blk.blocks[0].attn.att_mat.shape[-1] - 1 == desired_num_tokens

        # masking ratio 0.33
        masking_ratio_blk = MaskedAutoEncoderViT(
            in_channels=in_channels, img_size=img_size, patch_size=patch_size, masking_ratio=0.33, save_attn=True
        )
        masking_ratio_blk(torch.randn(in_shape))
        desired_num_tokens = int(
            (img_size[0] // patch_size[0])
            * (img_size[1] // patch_size[1])
            * (img_size[2] // patch_size[2])
            * (1 - 0.33)
        )

        assert masking_ratio_blk.blocks[0].attn.att_mat.shape[-1] - 1 == desired_num_tokens


if __name__ == "__main__":
    unittest.main()
```
