Skip to content

Commit 78c1460

Browse files
authored
Merge pull request #33 from HiLab-git/dev
Dev
2 parents 1ecadee + 603df81 commit 78c1460

32 files changed

+432
-188
lines changed

README.md

+6-6
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ PyMIC is a pytorch-based toolkit for medical image computing with annotation-eff
44

55
Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. If you use this toolkit, please cite the following paper:
66

7-
* G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. (2022).
8-
[PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation.][arxiv2022] arXiv, 2208.09350.
7+
* G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. (2023).
8+
[PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation.][arxiv2022] Computer Methods and Programs in Biomedicine (CMPB). February 2023, 107398.
99

1010
[arxiv2022]:http://arxiv.org/abs/2208.09350
1111

@@ -14,11 +14,11 @@ BibTeX entry:
1414
@article{Wang2022pymic,
1515
author = {Guotai Wang and Xiangde Luo and Ran Gu and Shuojue Yang and Yijie Qu and Shuwei Zhai and Qianfei Zhao and Kang Li and Shaoting Zhang},
1616
title = {{PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation}},
17-
year = {2022},
17+
year = {2023},
1818
url = {http://arxiv.org/abs/2208.09350},
19-
journal = {arXiv},
20-
volume = {2208.09350},
21-
pages = {1-10},
19+
journal = {Computer Methods and Programs in Biomedicine},
20+
volume = {February},
21+
pages = {107398},
2222
}
2323

2424
# Features

pymic/loss/seg/deep_sup.py

+53-10
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,43 @@
22
from __future__ import print_function, division
33

44
import torch.nn as nn
5+
from torch.nn.functional import interpolate
56
from pymic.loss.seg.abstract import AbstractSegLoss
67

8+
def match_prediction_and_gt_shape(pred, gt, mode = 0):
    """
    Make a prediction and its ground truth share the same spatial shape.

    The data is treated as 2D when `pred` has four dimensions and as 3D
    otherwise. Only the trailing two (2D) or three (3D) dimensions are
    compared. When they already agree, both inputs are returned unchanged
    and `mode` is not inspected.

    :param pred: (tensor) The prediction. Its shape from index 2 onward is
        used as the target size in modes 1 and 2.
    :param gt: (tensor) The ground truth. Its shape from index 2 onward is
        used as the target size in mode 0.
    :param mode: (int) How a shape mismatch is resolved.
        0: resize the prediction to the size of the ground truth via
        bilinear/trilinear interpolation.
        1: resize the ground truth to the size of the prediction via
        bilinear/trilinear interpolation.
        2: resize the ground truth to the size of the prediction via
        adaptive average pooling.

    :return: A tuple `(pred_new, gt_new)`.
    :raises ValueError: if the shapes differ and `mode` is not 0, 1 or 2.
    """
    pred_shape = list(pred.shape)
    gt_shape   = list(gt.shape)
    dim = len(pred_shape) - 2

    # Two trailing axes are compared for 2D data, three in every other case.
    n_spatial = 2 if dim == 2 else 3
    if(pred_shape[-n_spatial:] == gt_shape[-n_spatial:]):
        return pred, gt

    interp_mode = 'bilinear' if dim == 2 else 'trilinear'
    if(mode == 0):
        pred_new = interpolate(pred, gt_shape[2:], mode = interp_mode)
        gt_new   = gt
    elif(mode == 1):
        pred_new = pred
        gt_new   = interpolate(gt, pred_shape[2:], mode = interp_mode)
    elif(mode == 2):
        pred_new = pred
        if(dim == 2):
            avg_pool = nn.AdaptiveAvgPool2d(pred_shape[-2:])
        else:
            avg_pool = nn.AdaptiveAvgPool3d(pred_shape[-3:])
        gt_new = avg_pool(gt)
    else:
        raise ValueError("mode should be 0, 1 or 2, but {0:} was given".format(mode))
    return pred_new, gt_new
40+
41+
742
class DeepSuperviseLoss(AbstractSegLoss):
843
'''
944
Combine deep supervision with a basic loss function.
@@ -12,28 +47,36 @@ class DeepSuperviseLoss(AbstractSegLoss):
1247
1348
:param `loss_softmax`: (optional, bool)
1449
Apply softmax to the prediction of network or not. Default is True.
15-
:param `deep_suervise_weight`: (list) A list of weight for each deep supervision scale. \n
1650
:param `base_loss`: (nn.Module) The basic function used for each scale.
51+
:param `deep_supervise_weight`: (list) A list of weight for each deep supervision scale.
52+
:param `deep_supervise_mode`: (int) Mode for deep supervision when the prediction
53+
has a smaller shape than the ground truth. 0: upsample the prediction to the size
54+
of the ground truth. 1: downsample the ground truth to the size of the prediction
55+
via interpolation. 2: downsample the ground truth via adaptive average pooling.
1756
1857
'''
1958
def __init__(self, params):
2059
super(DeepSuperviseLoss, self).__init__(params)
21-
self.deep_sup_weight = params.get('deep_suervise_weight', None)
22-
self.base_loss = params['base_loss']
60+
self.base_loss = params['base_loss']
61+
self.deep_sup_weight = params.get('deep_supervise_weight', None)
62+
self.deep_sup_mode = params.get('deep_supervise_mode', 0)
2363

2464
def forward(self, loss_input_dict):
25-
predict = loss_input_dict['prediction']
26-
if(not isinstance(predict, (list,tuple))):
65+
pred = loss_input_dict['prediction']
66+
gt = loss_input_dict['ground_truth']
67+
if(not isinstance(pred, (list,tuple))):
2768
raise ValueError("""For deep supervision, the prediction should
2869
be a list or a tuple""")
29-
predict_num = len(predict)
70+
pred_num = len(pred)
3071
if(self.deep_sup_weight is None):
31-
self.deep_sup_weight = [1.0] * predict_num
72+
self.deep_sup_weight = [1.0] * pred_num
3273
else:
33-
assert(predict_num == len(self.deep_sup_weight))
74+
assert(pred_num == len(self.deep_sup_weight))
3475
loss_sum, weight_sum = 0.0, 0.0
35-
for i in range(predict_num):
36-
loss_input_dict['prediction'] = predict[i]
76+
for i in range(pred_num):
77+
pred_i, gt_i = match_prediction_and_gt_shape(pred[i], gt, self.deep_sup_mode)
78+
loss_input_dict['prediction'] = pred_i
79+
loss_input_dict['ground_truth'] = gt_i
3780
temp_loss = self.base_loss(loss_input_dict)
3881
loss_sum += temp_loss * self.deep_sup_weight[i]
3982
weight_sum += self.deep_sup_weight[i]

pymic/net/net2d/unet2d.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -240,18 +240,14 @@ def forward(self, x):
240240
x_d0 = self.up4(x_d1, x0)
241241
output = self.out_conv(x_d0)
242242
if(self.deep_sup):
243-
out_shape = list(output.shape)[2:]
244243
output1 = self.out_conv1(x_d1)
245-
output1 = interpolate(output1, out_shape, mode = 'bilinear')
246244
output2 = self.out_conv2(x_d2)
247-
output2 = interpolate(output2, out_shape, mode = 'bilinear')
248245
output3 = self.out_conv3(x_d3)
249-
output3 = interpolate(output3, out_shape, mode = 'bilinear')
250246
output = [output, output1, output2, output3]
251247

252248
if(len(x_shape) == 5):
253-
new_shape = [N, D] + list(output[0].shape)[1:]
254249
for i in range(len(output)):
250+
new_shape = [N, D] + list(output[i].shape)[1:]
255251
output[i] = torch.transpose(torch.reshape(output[i], new_shape), 1, 2)
256252
elif(len(x_shape) == 5):
257253
new_shape = [N, D] + list(output.shape)[1:]

pymic/net/net3d/unet3d.py

+87
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,93 @@ def forward(self, x1, x2):
7777
x = torch.cat([x2, x1], dim=1)
7878
return self.conv(x)
7979

80+
class Encoder(nn.Module):
    """
    Encoder of 3D UNet.

    The configuration is read from the `params` dictionary, which should
    provide the following fields:

    :param in_chns: (int) Input channel number.
    :param feature_chns: (list) Feature channel for each resolution level.
        The length should be 4 or 5, such as [16, 32, 64, 128, 256].
    :param dropout: (list) The dropout ratio for each resolution level.
        The length should be the same as that of `feature_chns`.
    """
    def __init__(self, params):
        super(Encoder, self).__init__()
        self.params  = params
        self.in_chns = params['in_chns']
        self.ft_chns = params['feature_chns']
        self.dropout = params['dropout']
        assert(len(self.ft_chns) in (4, 5))

        chns, drop = self.ft_chns, self.dropout
        self.in_conv = ConvBlock(self.in_chns, chns[0], drop[0])
        self.down1   = DownBlock(chns[0], chns[1], drop[1])
        self.down2   = DownBlock(chns[1], chns[2], drop[2])
        self.down3   = DownBlock(chns[2], chns[3], drop[3])
        # The fourth down-sampling stage only exists for five-level networks.
        if(len(chns) == 5):
            self.down4 = DownBlock(chns[3], chns[4], drop[4])

    def forward(self, x):
        """
        Return the list of feature maps of all resolution levels, ordered
        from the first (input) level to the last level.
        """
        features = [self.in_conv(x)]
        for down in (self.down1, self.down2, self.down3):
            features.append(down(features[-1]))
        if(len(self.ft_chns) == 5):
            features.append(self.down4(features[-1]))
        return features
118+
119+
class Decoder(nn.Module):
    """
    Decoder of 3D UNet.

    Parameters are given in the `params` dictionary, and should include the
    following fields:

    :param in_chns: (int) Input channel number.
    :param feature_chns: (list) Feature channel for each resolution level.
        The length should be 4 or 5, such as [16, 32, 64, 128, 256].
    :param dropout: (list) The dropout ratio for each resolution level.
        The length should be the same as that of `feature_chns`.
    :param class_num: (int) The class number for segmentation task.
    :param trilinear: (bool) Using trilinear for up-sampling or not.
        If False, deconvolution will be used for up-sampling.
    """
    def __init__(self, params):
        super(Decoder, self).__init__()
        self.params    = params
        self.in_chns   = self.params['in_chns']
        self.ft_chns   = self.params['feature_chns']
        self.dropout   = self.params['dropout']
        self.n_class   = self.params['class_num']
        self.trilinear = self.params['trilinear']

        assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)

        # The up-sampling flag is stored as `self.trilinear`; it is the value
        # that has to be handed to every UpBlock.
        # `up1` only exists for five-level networks.
        if(len(self.ft_chns) == 5):
            self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear)
        self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear)
        self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear)
        self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear)
        self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1)

    def forward(self, x):
        """
        Decode a list of feature maps into the segmentation output.

        :param x: (list or tuple) Feature maps ordered from the first level
            to the last level. The length must be 5 when `feature_chns`
            has five entries and 4 otherwise.

        :return: The output of the final 1x1x1 convolution.
        """
        if(len(self.ft_chns) == 5):
            assert(len(x) == 5)
            x0, x1, x2, x3, x4 = x
            x_d3 = self.up1(x4, x3)
        else:
            assert(len(x) == 4)
            x0, x1, x2, x3 = x
            # Without a fifth level the last feature map is used directly.
            x_d3 = x3
        x_d2 = self.up2(x_d3, x2)
        x_d1 = self.up3(x_d2, x1)
        x_d0 = self.up4(x_d1, x0)
        output = self.out_conv(x_d0)
        return output
80167

81168
class UNet3D(nn.Module):
82169
"""

pymic/net/net3d/unet3d_dual_branch.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import print_function, division
3+
4+
import torch
5+
import torch.nn as nn
6+
from pymic.net.net3d.unet3d import *
7+
8+
class UNet3D_DualBranch(nn.Module):
    """
    A dual branch network using UNet3D as backbone.

    * Reference: Xiangde Luo, Minhao Hu, Wenjun Liao, Shuwei Zhai, Tao Song, Guotai Wang,
      Shaoting Zhang. Scribble-Supervised Medical Image Segmentation via
      Dual-Branch Network and Dynamically Mixed Pseudo Labels Supervision.
      `MICCAI 2022. <https://arxiv.org/abs/2203.02106>`_

    The parameters for the backbone should be given in the `params` dictionary.
    See :mod:`pymic.net.net3d.unet3d.UNet3D` for details.
    In addition, the following field should be included:

    :param output_mode: (str) How to obtain the result during the inference.
        `average`: taking average of the two branches.
        `first`: taking the result in the first branch.
        `second`: taking the result in the second branch.
        Any other value is handled like `second`. Default is `average`.
    """
    def __init__(self, params):
        super(UNet3D_DualBranch, self).__init__()
        self.output_mode = params.get("output_mode", "average")
        # One shared encoder feeds two decoders built from the same params.
        self.encoder  = Encoder(params)
        self.decoder1 = Decoder(params)
        self.decoder2 = Decoder(params)

    def forward(self, x):
        """
        Return both branch outputs as a tuple in training mode; otherwise
        return a single output selected by `output_mode`.
        """
        features = self.encoder(x)
        branch1  = self.decoder1(features)
        branch2  = self.decoder2(features)

        if(self.training):
            return branch1, branch2
        if(self.output_mode == "average"):
            return (branch1 + branch2)/2
        if(self.output_mode == "first"):
            return branch1
        return branch2

pymic/net/net_dict_seg.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from pymic.net.net3d.unet2d5 import UNet2D5
2525
from pymic.net.net3d.unet3d import UNet3D
2626
from pymic.net.net3d.unet3d_scse import UNet3D_ScSE
27+
from pymic.net.net3d.unet3d_dual_branch import UNet3D_DualBranch
2728

2829
SegNetDict = {
2930
'UNet2D': UNet2D,
@@ -35,5 +36,7 @@
3536
'UNet2D_ScSE': UNet2D_ScSE,
3637
'UNet2D5': UNet2D5,
3738
'UNet3D': UNet3D,
38-
'UNet3D_ScSE': UNet3D_ScSE
39+
'UNet3D_ScSE': UNet3D_ScSE,
40+
'UNet3D_DualBranch': UNet3D_DualBranch
41+
3942
}

pymic/net_run/agent_abstract.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def worker_init_fn(worker_id):
259259

260260
bn_train = self.config['dataset']['train_batch_size']
261261
bn_valid = self.config['dataset'].get('valid_batch_size', 1)
262-
num_worker = self.config['dataset'].get('num_workder', 16)
262+
num_worker = self.config['dataset'].get('num_worker', 16)
263263
g_train, g_valid = torch.Generator(), torch.Generator()
264264
g_train.manual_seed(self.random_seed)
265265
g_valid.manual_seed(self.random_seed)

pymic/net_run/agent_seg.py

+17-21
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,6 @@ def training(self):
162162
loss = self.get_loss_value(data, outputs, labels_prob)
163163
loss.backward()
164164
self.optimizer.step()
165-
if(self.scheduler is not None and \
166-
not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
167-
self.scheduler.step()
168-
169165
train_loss = train_loss + loss.item()
170166
# get dice evaluation for each class
171167
if(isinstance(outputs, tuple) or isinstance(outputs, list)):
@@ -177,9 +173,9 @@ def training(self):
177173
train_dice_list.append(dice_list.cpu().numpy())
178174
train_avg_loss = train_loss / iter_valid
179175
train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
180-
train_avg_dice = train_cls_dice.mean()
176+
train_avg_dice = train_cls_dice[1:].mean()
181177

182-
train_scalers = {'loss': train_avg_loss, 'avg_dice':train_avg_dice,\
178+
train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\
183179
'class_dice': train_cls_dice}
184180
return train_scalers
185181

@@ -218,18 +214,14 @@ def validation(self):
218214

219215
valid_avg_loss = np.asarray(valid_loss_list).mean()
220216
valid_cls_dice = np.asarray(valid_dice_list).mean(axis = 0)
221-
valid_avg_dice = valid_cls_dice.mean()
222-
223-
if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
224-
self.scheduler.step(valid_avg_dice)
225-
226-
valid_scalers = {'loss': valid_avg_loss, 'avg_dice': valid_avg_dice,\
217+
valid_avg_dice = valid_cls_dice[1:].mean()
218+
valid_scalers = {'loss': valid_avg_loss, 'avg_fg_dice': valid_avg_dice,\
227219
'class_dice': valid_cls_dice}
228220
return valid_scalers
229221

230222
def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
231223
loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']}
232-
dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']}
224+
dice_scalar ={'train':train_scalars['avg_fg_dice'], 'valid':valid_scalars['avg_fg_dice']}
233225
self.summ_writer.add_scalars('loss', loss_scalar, glob_it)
234226
self.summ_writer.add_scalars('dice', dice_scalar, glob_it)
235227
self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it)
@@ -239,11 +231,11 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
239231
'valid':valid_scalars['class_dice'][c]}
240232
self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it)
241233

242-
logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format(
243-
train_scalars['loss'], train_scalars['avg_dice']) + "[" + \
234+
logging.info('train loss {0:.4f}, avg foreground dice {1:.4f} '.format(
235+
train_scalars['loss'], train_scalars['avg_fg_dice']) + "[" + \
244236
' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]")
245-
logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format(
246-
valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \
237+
logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format(
238+
valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \
247239
' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]")
248240

249241
def train_valid(self):
@@ -300,16 +292,20 @@ def train_valid(self):
300292
t0 = time.time()
301293
train_scalars = self.training()
302294
t1 = time.time()
303-
304295
valid_scalars = self.validation()
305296
t2 = time.time()
297+
if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
298+
self.scheduler.step(valid_scalars['avg_fg_dice'])
299+
else:
300+
self.scheduler.step()
301+
306302
self.glob_it = it + iter_valid
307303
logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it))
308304
logging.info('learning rate {0:}'.format(lr_value))
309305
logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1))
310306
self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it)
311-
if(valid_scalars['avg_dice'] > self.max_val_dice):
312-
self.max_val_dice = valid_scalars['avg_dice']
307+
if(valid_scalars['avg_fg_dice'] > self.max_val_dice):
308+
self.max_val_dice = valid_scalars['avg_fg_dice']
313309
self.max_val_it = self.glob_it
314310
if(len(device_ids) > 1):
315311
self.best_model_wts = copy.deepcopy(self.net.module.state_dict())
@@ -320,7 +316,7 @@ def train_valid(self):
320316
self.glob_it - self.max_val_it > early_stop_it) else False
321317
if ((self.glob_it in iter_save_list) or stop_now):
322318
save_dict = {'iteration': self.glob_it,
323-
'valid_pred': valid_scalars['avg_dice'],
319+
'valid_pred': valid_scalars['avg_fg_dice'],
324320
'model_state_dict': self.net.module.state_dict() \
325321
if len(device_ids) > 1 else self.net.state_dict(),
326322
'optimizer_state_dict': self.optimizer.state_dict()}

0 commit comments

Comments
 (0)