@@ -77,6 +77,93 @@ def forward(self, x1, x2):
77
77
x = torch .cat ([x2 , x1 ], dim = 1 )
78
78
return self .conv (x )
79
79
80
+ class Encoder (nn .Module ):
81
+ """
82
+ Encoder of 3D UNet.
83
+
84
+ Parameters are given in the `params` dictionary, and should include the
85
+ following fields:
86
+
87
+ :param in_chns: (int) Input channel number.
88
+ :param feature_chns: (list) Feature channel for each resolution level.
89
+ The length should be 4 or 5, such as [16, 32, 64, 128, 256].
90
+ :param dropout: (list) The dropout ratio for each resolution level.
91
+ The length should be the same as that of `feature_chns`.
92
+ """
93
+ def __init__ (self , params ):
94
+ super (Encoder , self ).__init__ ()
95
+ self .params = params
96
+ self .in_chns = self .params ['in_chns' ]
97
+ self .ft_chns = self .params ['feature_chns' ]
98
+ self .dropout = self .params ['dropout' ]
99
+ assert (len (self .ft_chns ) == 5 or len (self .ft_chns ) == 4 )
100
+
101
+ self .in_conv = ConvBlock (self .in_chns , self .ft_chns [0 ], self .dropout [0 ])
102
+ self .down1 = DownBlock (self .ft_chns [0 ], self .ft_chns [1 ], self .dropout [1 ])
103
+ self .down2 = DownBlock (self .ft_chns [1 ], self .ft_chns [2 ], self .dropout [2 ])
104
+ self .down3 = DownBlock (self .ft_chns [2 ], self .ft_chns [3 ], self .dropout [3 ])
105
+ if (len (self .ft_chns ) == 5 ):
106
+ self .down4 = DownBlock (self .ft_chns [3 ], self .ft_chns [4 ], self .dropout [4 ])
107
+
108
+ def forward (self , x ):
109
+ x0 = self .in_conv (x )
110
+ x1 = self .down1 (x0 )
111
+ x2 = self .down2 (x1 )
112
+ x3 = self .down3 (x2 )
113
+ output = [x0 , x1 , x2 , x3 ]
114
+ if (len (self .ft_chns ) == 5 ):
115
+ x4 = self .down4 (x3 )
116
+ output .append (x4 )
117
+ return output
118
+
119
+ class Decoder (nn .Module ):
120
+ """
121
+ Decoder of 3D UNet.
122
+
123
+ Parameters are given in the `params` dictionary, and should include the
124
+ following fields:
125
+
126
+ :param in_chns: (int) Input channel number.
127
+ :param feature_chns: (list) Feature channel for each resolution level.
128
+ The length should be 4 or 5, such as [16, 32, 64, 128, 256].
129
+ :param dropout: (list) The dropout ratio for each resolution level.
130
+ The length should be the same as that of `feature_chns`.
131
+ :param class_num: (int) The class number for segmentation task.
132
+ :param trilinear: (bool) Using bilinear for up-sampling or not.
133
+ If False, deconvolution will be used for up-sampling.
134
+ """
135
+ def __init__ (self , params ):
136
+ super (Decoder , self ).__init__ ()
137
+ self .params = params
138
+ self .in_chns = self .params ['in_chns' ]
139
+ self .ft_chns = self .params ['feature_chns' ]
140
+ self .dropout = self .params ['dropout' ]
141
+ self .n_class = self .params ['class_num' ]
142
+ self .trilinear = self .params ['trilinear' ]
143
+
144
+ assert (len (self .ft_chns ) == 5 or len (self .ft_chns ) == 4 )
145
+
146
+ if (len (self .ft_chns ) == 5 ):
147
+ self .up1 = UpBlock (self .ft_chns [4 ], self .ft_chns [3 ], self .ft_chns [3 ], self .dropout [3 ], self .bilinear )
148
+ self .up2 = UpBlock (self .ft_chns [3 ], self .ft_chns [2 ], self .ft_chns [2 ], self .dropout [2 ], self .bilinear )
149
+ self .up3 = UpBlock (self .ft_chns [2 ], self .ft_chns [1 ], self .ft_chns [1 ], self .dropout [1 ], self .bilinear )
150
+ self .up4 = UpBlock (self .ft_chns [1 ], self .ft_chns [0 ], self .ft_chns [0 ], self .dropout [0 ], self .bilinear )
151
+ self .out_conv = nn .Conv3d (self .ft_chns [0 ], self .n_class , kernel_size = 1 )
152
+
153
+ def forward (self , x ):
154
+ if (len (self .ft_chns ) == 5 ):
155
+ assert (len (x ) == 5 )
156
+ x0 , x1 , x2 , x3 , x4 = x
157
+ x_d3 = self .up1 (x4 , x3 )
158
+ else :
159
+ assert (len (x ) == 4 )
160
+ x0 , x1 , x2 , x3 = x
161
+ x_d3 = x3
162
+ x_d2 = self .up2 (x_d3 , x2 )
163
+ x_d1 = self .up3 (x_d2 , x1 )
164
+ x_d0 = self .up4 (x_d1 , x0 )
165
+ output = self .out_conv (x_d0 )
166
+ return output
80
167
81
168
class UNet3D (nn .Module ):
82
169
"""
0 commit comments