Commit b58e883

Authored by johnzielke, ericspod, and KumoLiu
selfattention block: Remove the fc linear layer if it is not used (#8325)
### Description

When `include_fc = False`, the `nn.Linear` layer is unused. This leads to errors and warnings when training with the PyTorch Distributed Data Parallel infrastructure, since the parameters of the unused `nn.Linear` layer will not have gradients attached.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: John Zielke <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: YunLiu <[email protected]>
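The failure mode described above can be reproduced outside MONAI. The following minimal sketch (the module name and sizes are made up for illustration) shows that a registered but never-called `nn.Linear` ends up with `grad=None` after backward, which is what DistributedDataParallel reports as parameters that did not receive gradients unless `find_unused_parameters=True` is enabled:

```python
import torch
import torch.nn as nn

# Illustration only (not MONAI code): a module that registers a Linear layer
# but never calls it in forward(), mirroring SABlock with include_fc=False
# before this fix.
class BlockWithUnusedFC(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.out_proj = nn.Linear(dim, dim)    # registered but never used
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * self.scale                  # out_proj is skipped entirely

block = BlockWithUnusedFC()
block(torch.randn(2, 8)).sum().backward()

# The unused layer's parameters receive no gradients; DDP raises an error for
# such parameters unless find_unused_parameters=True is passed.
print([n for n, p in block.named_parameters() if p.grad is None])
# -> ['out_proj.weight', 'out_proj.bias']
```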
1 parent f27517b commit b58e883

File tree: 3 files changed, +29 −4 lines changed

monai/networks/blocks/selfattention.py (+5 −1)
```diff
@@ -101,7 +101,11 @@ def __init__(
 
         self.num_heads = num_heads
         self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
-        self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)
+        self.out_proj: Union[nn.Linear, nn.Identity]
+        if include_fc:
+            self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)
+        else:
+            self.out_proj = nn.Identity()
 
         self.qkv: Union[nn.Linear, nn.Identity]
         self.to_q: Union[nn.Linear, nn.Identity]
```

monai/networks/nets/diffusion_model_unet.py (+3 −3)
```diff
@@ -1847,9 +1847,9 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
             new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias")
 
             # projection
-            new_state_dict[f"{block}.attn.out_proj.weight"] = old_state_dict.pop(f"{block}.proj_attn.weight")
-            new_state_dict[f"{block}.attn.out_proj.bias"] = old_state_dict.pop(f"{block}.proj_attn.bias")
-
+            if f"{block}.attn.out_proj.weight" in new_state_dict and f"{block}.attn.out_proj.bias" in new_state_dict:
+                new_state_dict[f"{block}.attn.out_proj.weight"] = old_state_dict.pop(f"{block}.proj_attn.weight")
+                new_state_dict[f"{block}.attn.out_proj.bias"] = old_state_dict.pop(f"{block}.proj_attn.bias")
         # fix the cross attention blocks
         cross_attention_blocks = [
             k.replace(".out_proj.weight", "")
```
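The added guard matters because `load_state_dict` with the default `strict=True` rejects keys that the target module never registered, which is exactly what an `nn.Identity` out_proj would trigger if the old projection weights were copied over unconditionally. A generic PyTorch illustration (not MONAI code):

```python
import torch
import torch.nn as nn

# nn.Identity registers no parameters, so a state dict that still carries
# projection weights for it cannot be loaded strictly.
net = nn.Sequential(nn.Identity())
try:
    net.load_state_dict({"0.weight": torch.zeros(1, 1)})
except RuntimeError as err:
    print(err)  # Unexpected key(s) in state_dict: "0.weight"
```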

tests/networks/blocks/test_selfattention.py (+21 −0)
```diff
@@ -227,6 +227,27 @@ def test_flash_attention(self):
         out_2 = block_wo_flash_attention(test_data)
         assert_allclose(out_1, out_2, atol=1e-4)
 
+    @parameterized.expand([[True, False], [True, True], [False, True], [False, False]])
+    def test_no_extra_weights_if_no_fc(self, include_fc, use_combined_linear):
+        input_param = {
+            "hidden_size": 360,
+            "num_heads": 4,
+            "dropout_rate": 0.0,
+            "rel_pos_embedding": None,
+            "input_size": (16, 32),
+            "include_fc": include_fc,
+            "use_combined_linear": use_combined_linear,
+        }
+        net = SABlock(**input_param)
+        if not include_fc:
+            self.assertNotIn("out_proj.weight", net.state_dict())
+            self.assertNotIn("out_proj.bias", net.state_dict())
+            self.assertIsInstance(net.out_proj, torch.nn.Identity)
+        else:
+            self.assertIn("out_proj.weight", net.state_dict())
+            self.assertIn("out_proj.bias", net.state_dict())
+            self.assertIsInstance(net.out_proj, torch.nn.Linear)
+
 
 if __name__ == "__main__":
     unittest.main()
```
