
Commit abf0dc6

add dual-branch 3D unet
1 parent 452ccef commit abf0dc6

File tree

5 files changed: +157 -21 lines

pymic/net/net3d/unet3d.py (+87)
@@ -77,6 +77,93 @@ def forward(self, x1, x2):
         x = torch.cat([x2, x1], dim=1)
         return self.conv(x)
 
+class Encoder(nn.Module):
+    """
+    Encoder of 3D UNet.
+
+    Parameters are given in the `params` dictionary, and should include the
+    following fields:
+
+    :param in_chns: (int) Input channel number.
+    :param feature_chns: (list) Feature channel for each resolution level.
+        The length should be 4 or 5, such as [16, 32, 64, 128, 256].
+    :param dropout: (list) The dropout ratio for each resolution level.
+        The length should be the same as that of `feature_chns`.
+    """
+    def __init__(self, params):
+        super(Encoder, self).__init__()
+        self.params  = params
+        self.in_chns = self.params['in_chns']
+        self.ft_chns = self.params['feature_chns']
+        self.dropout = self.params['dropout']
+        assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)
+
+        self.in_conv = ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0])
+        self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1])
+        self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2])
+        self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3])
+        if(len(self.ft_chns) == 5):
+            self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4])
+
+    def forward(self, x):
+        x0 = self.in_conv(x)
+        x1 = self.down1(x0)
+        x2 = self.down2(x1)
+        x3 = self.down3(x2)
+        output = [x0, x1, x2, x3]
+        if(len(self.ft_chns) == 5):
+            x4 = self.down4(x3)
+            output.append(x4)
+        return output
+
+class Decoder(nn.Module):
+    """
+    Decoder of 3D UNet.
+
+    Parameters are given in the `params` dictionary, and should include the
+    following fields:
+
+    :param in_chns: (int) Input channel number.
+    :param feature_chns: (list) Feature channel for each resolution level.
+        The length should be 4 or 5, such as [16, 32, 64, 128, 256].
+    :param dropout: (list) The dropout ratio for each resolution level.
+        The length should be the same as that of `feature_chns`.
+    :param class_num: (int) The class number for the segmentation task.
+    :param trilinear: (bool) Use trilinear interpolation for up-sampling or not.
+        If False, deconvolution will be used for up-sampling.
+    """
+    def __init__(self, params):
+        super(Decoder, self).__init__()
+        self.params  = params
+        self.in_chns = self.params['in_chns']
+        self.ft_chns = self.params['feature_chns']
+        self.dropout = self.params['dropout']
+        self.n_class = self.params['class_num']
+        self.trilinear = self.params['trilinear']
+
+        assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)
+
+        if(len(self.ft_chns) == 5):
+            self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear)
+        self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear)
+        self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear)
+        self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear)
+        self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1)
+
+    def forward(self, x):
+        if(len(self.ft_chns) == 5):
+            assert(len(x) == 5)
+            x0, x1, x2, x3, x4 = x
+            x_d3 = self.up1(x4, x3)
+        else:
+            assert(len(x) == 4)
+            x0, x1, x2, x3 = x
+            x_d3 = x3
+        x_d2 = self.up2(x_d3, x2)
+        x_d1 = self.up3(x_d2, x1)
+        x_d0 = self.up4(x_d1, x0)
+        output = self.out_conv(x_d0)
+        return output
 
 class UNet3D(nn.Module):
     """

pymic/net/net3d/unet3d_dual_branch.py (+46)
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+from __future__ import print_function, division
+
+import torch
+import torch.nn as nn
+from pymic.net.net3d.unet3d import *
+
+class UNet3D_DualBranch(nn.Module):
+    """
+    A dual-branch network using UNet3D as the backbone.
+
+    * Reference: Xiangde Luo, Minhao Hu, Wenjun Liao, Shuwei Zhai, Tao Song, Guotai Wang,
+      Shaoting Zhang. Scribble-Supervised Medical Image Segmentation via
+      Dual-Branch Network and Dynamically Mixed Pseudo Labels Supervision.
+      `MICCAI 2022. <https://arxiv.org/abs/2203.02106>`_
+
+    The parameters for the backbone should be given in the `params` dictionary.
+    See :mod:`pymic.net.net3d.unet3d.UNet3D` for details.
+    In addition, the following field should be included:
+
+    :param output_mode: (str) How to obtain the result during inference.
+        `average`: taking the average of the two branches.
+        `first`: taking the result of the first branch.
+        `second`: taking the result of the second branch.
+    """
+    def __init__(self, params):
+        super(UNet3D_DualBranch, self).__init__()
+        self.output_mode = params.get("output_mode", "average")
+        self.encoder  = Encoder(params)
+        self.decoder1 = Decoder(params)
+        self.decoder2 = Decoder(params)
+
+    def forward(self, x):
+        f = self.encoder(x)
+        output1 = self.decoder1(f)
+        output2 = self.decoder2(f)
+
+        if(self.training):
+            return output1, output2
+        else:
+            if(self.output_mode == "average"):
+                return (output1 + output2)/2
+            elif(self.output_mode == "first"):
+                return output1
+            else:
+                return output2
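
A minimal sketch of the training/inference behavior (illustrative values; `params` as in the Encoder/Decoder example above): in training mode the network returns one prediction per branch, while in eval mode the two outputs are merged according to `output_mode`.

    import torch
    from pymic.net.net3d.unet3d_dual_branch import UNet3D_DualBranch

    params['output_mode'] = 'average'
    net = UNet3D_DualBranch(params)
    x = torch.rand(1, 1, 32, 64, 64)

    net.train()
    out1, out2 = net(x)    # one prediction per decoder branch, for the two-branch loss

    net.eval()
    out = net(x)           # merged prediction: here, the average of the two branches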

pymic/net/net_dict_seg.py (+4 -1)
@@ -24,6 +24,7 @@
 from pymic.net.net3d.unet2d5 import UNet2D5
 from pymic.net.net3d.unet3d import UNet3D
 from pymic.net.net3d.unet3d_scse import UNet3D_ScSE
+from pymic.net.net3d.unet3d_dual_branch import UNet3D_DualBranch
 
 SegNetDict = {
     'UNet2D': UNet2D,
@@ -35,5 +36,7 @@
     'UNet2D_ScSE': UNet2D_ScSE,
     'UNet2D5': UNet2D5,
     'UNet3D': UNet3D,
-    'UNet3D_ScSE': UNet3D_ScSE
+    'UNet3D_ScSE': UNet3D_ScSE,
+    'UNet3D_DualBranch': UNet3D_DualBranch
+
     }
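
With this entry in place, the new network can be selected by name from a configuration. A hedged sketch of the lookup, mirroring how PyMIC resolves a network name from the registry (the name string is illustrative):

    from pymic.net.net_dict_seg import SegNetDict

    net_class = SegNetDict['UNet3D_DualBranch']
    net = net_class(params)   # params as in the examples above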

pymic/transform/rescale.py (+12 -12)
@@ -1,9 +1,7 @@
 # -*- coding: utf-8 -*-
 from __future__ import print_function, division
 
-import torch
 import json
-import math
 import random
 import numpy as np
 from scipy import ndimage
@@ -104,11 +102,13 @@ def __init__(self, params):
         assert isinstance(self.ratio1, (float, list, tuple))
 
     def __call__(self, sample):
-        if(np.random.uniform() > self.prob):
-            sample['RandomRescale_triggered'] = False
-            return sample
-        else:
-            sample['RandomRescale_triggered'] = True
+        # if(random.random() > self.prob):
+        #     print("rescale not started")
+        #     sample['RandomRescale_triggered'] = False
+        #     return sample
+        # else:
+        #     print("rescale started")
+        #     sample['RandomRescale_triggered'] = True
         image = sample['image']
         input_shape = image.shape
         input_dim = len(input_shape) - 1
@@ -125,7 +125,7 @@ def __call__(self, sample):
         image_t = ndimage.interpolation.zoom(image, scale, order = 1)
 
         sample['image'] = image_t
-        sample['RandomRescale_origin_shape'] = json.dumps(input_shape)
+        sample['RandomRescale_Param'] = json.dumps(input_shape)
         if('label' in sample and self.task == 'segmentation'):
             label = sample['label']
             label = ndimage.interpolation.zoom(label, scale, order = 0)
@@ -140,11 +140,11 @@ def __call__(self, sample):
     def inverse_transform_for_prediction(self, sample):
         if(not sample['RandomRescale_triggered']):
             return sample
-        if(isinstance(sample['RandomRescale_origin_shape'], list) or \
-           isinstance(sample['RandomRescale_origin_shape'], tuple)):
-            origin_shape = json.loads(sample['RandomRescale_origin_shape'][0])
+        if(isinstance(sample['RandomRescale_Param'], list) or \
+           isinstance(sample['RandomRescale_Param'], tuple)):
+            origin_shape = json.loads(sample['RandomRescale_Param'][0])
         else:
-            origin_shape = json.loads(sample['RandomRescale_origin_shape'])
+            origin_shape = json.loads(sample['RandomRescale_Param'])
         origin_dim = len(origin_shape) - 1
         predict = sample['predict']
         input_shape = predict.shape
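
The original shape is stored under the renamed key as a JSON string so it can travel through the sample dictionary; a quick illustration of the round trip (the shape is illustrative):

    import json
    input_shape = (1, 96, 128, 128)      # [channel, D, H, W]
    s = json.dumps(input_shape)          # '[1, 96, 128, 128]'
    origin_shape = json.loads(s)         # back to [1, 96, 128, 128]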

pymic/transform/rotate.py (+8 -8)
@@ -1,9 +1,7 @@
 # -*- coding: utf-8 -*-
 from __future__ import print_function, division
 
-import torch
 import json
-import math
 import random
 import numpy as np
 from scipy import ndimage
@@ -53,11 +51,11 @@ def __apply_transformation(self, image, transform_param_list, order = 1):
         return image
 
     def __call__(self, sample):
-        if(np.random.uniform() > self.prob):
-            sample['RandomRotate_triggered'] = False
-            return sample
-        else:
-            sample['RandomRotate_triggered'] = True
+        # if(random.random() > self.prob):
+        #     sample['RandomRotate_triggered'] = False
+        #     return sample
+        # else:
+        #     sample['RandomRotate_triggered'] = True
         image = sample['image']
         input_shape = image.shape
         input_dim = len(input_shape) - 1
@@ -74,7 +72,9 @@ def __call__(self, sample):
             angle_w = np.random.uniform(self.angle_range_w[0], self.angle_range_w[1])
             transform_param_list.append([angle_w, (-2, -3)])
         assert(len(transform_param_list) > 0)
-
+        # select a random transform from the candidate list rather than
+        # composing all of them, for higher efficiency
+        transform_param_list = [random.choice(transform_param_list)]
         sample['RandomRotate_Param'] = json.dumps(transform_param_list)
         image_t = self.__apply_transformation(image, transform_param_list, 1)
         sample['image'] = image_t
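
The change applies exactly one randomly chosen rotation per call instead of composing one rotation per enabled axis pair. An illustrative sketch of the new selection step (angles, axes, and image shape are made up):

    import random
    import numpy as np
    from scipy import ndimage

    transform_param_list = [[30.0, (-1, -2)], [15.0, (-1, -3)]]   # per-plane candidates
    transform_param_list = [random.choice(transform_param_list)]  # keep a single rotation
    image = np.random.rand(1, 16, 32, 32)                         # [channel, D, H, W]
    for angle, axes in transform_param_list:
        image = ndimage.rotate(image, angle, axes=axes, reshape=False, order=1)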
