Skip to content

Commit f2dad93

Browse files
authored
Merge pull request #26 from HiLab-git/dev
Dev
2 parents ee51fa9 + 42feb23 commit f2dad93

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+1843
-679
lines changed

README.md

+18-14
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,60 @@
11
# PyMIC: A Pytorch-Based Toolkit for Medical Image Computing
22

3-
PyMIC is a pytorch-based toolkit for medical image computing with deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with higher dimension, multiple modalities and low contrast. The toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configure files.
3+
PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with high dimension and large volume, multiple modalities and difficulies in annotating. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations.
44

55
Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. It was originally developed for COVID-19 pneumonia lesion segmentation from CT images. If you use this toolkit, please cite the following paper:
66

7-
87
* G. Wang, X. Liu, C. Li, Z. Xu, J. Ruan, H. Zhu, T. Meng, K. Li, N. Huang, S. Zhang.
98
[A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions from CT Images.][tmi2020] IEEE Transactions on Medical Imaging. 39(8):2653-2663, 2020. DOI: [10.1109/TMI.2020.3000314][tmi2020]
109

1110
[tmi2020]:https://ieeexplore.ieee.org/document/9109297
1211

1312

14-
# Advantages
15-
PyMIC provides some basic modules for medical image computing that can be share by different applications. We currently provide the following functions:
13+
# Features
14+
PyMIC provides flixible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions:
15+
* Support for annotation-efficient image segmentation, especially for semi-supervised, weakly-supervised and noisy-label learning.
16+
* User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc) and easily integrate them into PyMIC.
1617
* Easy-to-use I/O interface to read and write different 2D and 3D images.
18+
* Various data pre-processing/transformation methods before sending a tensor into a network.
19+
* Implementation of typical neural networks for medical image segmentation.
1720
* Re-useable training and testing pipeline that can be transferred to different tasks.
18-
* Various data pre-processing methods before sending a tensor into a network.
19-
* Implementation of loss functions, especially for image segmentation.
20-
* Implementation of evaluation metrics to get quantitative evaluation of your methods (for segmentation).
21+
* Evaluation metrics for quantitative evaluation of your methods.
2122

2223
# Usage
2324
## Requirement
2425
* [Pytorch][torch_link] version >=1.0.1
2526
* [TensorboardX][tbx_link] to visualize training performance
2627
* Some common python packages such as Numpy, Pandas, SimpleITK
28+
* See `requirements.txt` for details.
2729

2830
[torch_link]:https://pytorch.org/
2931
[tbx_link]:https://github.com/lanpa/tensorboardX
3032

3133
## Installation
32-
Run the following command to install the current released version of PyMIC:
34+
Run the following command to install the latest released version of PyMIC:
3335

3436
```bash
3537
pip install PYMIC
3638
```
37-
To install a specific version of PYMIC such as 0.2.4, run:
39+
To install a specific version of PYMIC such as 0.3.0, run:
3840

3941
```bash
40-
pip install PYMIC==0.2.4
42+
pip install PYMIC==0.3.0
4143
```
4244
Alternatively, you can download the source code for the latest version. Run the following command to compile and install:
4345

4446
```bash
4547
python setup.py install
4648
```
4749

48-
## Examples
49-
[PyMIC_examples][examples] provides some examples of starting to use PyMIC. For beginners, you only need to simply change the configuration files to select different datasets, networks and training methods for running the code. For advanced users, you can develop your own modules based on this package. You can find both types of examples
50+
## How to start
51+
* [PyMIC_examples][exp_link] shows some examples of starting to use PyMIC.
52+
* [PyMIC_doc][docs_link] provides documentation of this project.
5053

51-
[examples]: https://github.com/HiLab-git/PyMIC_examples
54+
[docs_link]:https://pymic.readthedocs.io/en/latest/
55+
[exp_link]:https://github.com/HiLab-git/PyMIC_examples
5256

53-
# Projects based on PyMIC
57+
## Projects based on PyMIC
5458
Using PyMIC, it becomes easy to develop deep learning models for different projects, such as the following:
5559

5660
1, [MyoPS][myops] Winner of the MICCAI 2020 myocardial pathology segmentation (MyoPS) Challenge.

pymic/loss/loss_dict_seg.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
# -*- coding: utf-8 -*-
22
from __future__ import print_function, division
33
import torch.nn as nn
4-
from pymic.loss.seg.ce import CrossEntropyLoss, GeneralizedCrossEntropyLoss
4+
from pymic.loss.seg.ce import CrossEntropyLoss, GeneralizedCELoss
55
from pymic.loss.seg.dice import DiceLoss, FocalDiceLoss, NoiseRobustDiceLoss
66
from pymic.loss.seg.slsr import SLSRLoss
77
from pymic.loss.seg.exp_log import ExpLogLoss
88
from pymic.loss.seg.mse import MSELoss, MAELoss
99

1010
SegLossDict = {'CrossEntropyLoss': CrossEntropyLoss,
11-
'GeneralizedCrossEntropyLoss': GeneralizedCrossEntropyLoss,
11+
'GeneralizedCELoss': GeneralizedCELoss,
1212
'SLSRLoss': SLSRLoss,
1313
'DiceLoss': DiceLoss,
1414
'FocalDiceLoss': FocalDiceLoss,

pymic/loss/seg/ce.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pymic.loss.seg.util import reshape_tensor_to_2D
77

88
class CrossEntropyLoss(nn.Module):
9-
def __init__(self, params):
9+
def __init__(self, params = None):
1010
super(CrossEntropyLoss, self).__init__()
1111
if(params is None):
1212
self.softmax = True
@@ -59,41 +59,44 @@ def forward(self, loss_input_dict):
5959
ce = torch.mean(ce)
6060
return ce
6161

62-
class GeneralizedCrossEntropyLoss(nn.Module):
62+
class GeneralizedCELoss(nn.Module):
6363
"""
6464
Generalized cross entropy loss to deal with noisy labels.
6565
Z. Zhang et al. Generalized Cross Entropy Loss for Training Deep Neural Networks
6666
with Noisy Labels, NeurIPS 2018.
6767
"""
6868
def __init__(self, params):
69-
super(GeneralizedCrossEntropyLoss, self).__init__()
70-
self.enable_pix_weight = params['GeneralizedCrossEntropyLoss_Enable_Pixel_Weight'.lower()]
71-
self.enable_cls_weight = params['GeneralizedCrossEntropyLoss_Enable_Class_Weight'.lower()]
72-
self.q = params['GeneralizedCrossEntropyLoss_q'.lower()]
69+
"""
70+
q: in (0, 1), becmomes MAE when q = 1
71+
"""
72+
super(GeneralizedCELoss, self).__init__()
73+
self.enable_pix_weight = params.get('GeneralizedCELoss_Enable_Pixel_Weight', False)
74+
self.enable_cls_weight = params.get('GeneralizedCELoss_Enable_Class_Weight', False)
75+
self.q = params.get('GeneralizedCELoss_q', 0.5)
76+
self.softmax = params.get('loss_softmax', True)
7377

7478
def forward(self, loss_input_dict):
7579
predict = loss_input_dict['prediction']
76-
soft_y = loss_input_dict['ground_truth']
77-
pix_w = loss_input_dict['pixel_weight']
78-
cls_w = loss_input_dict['class_weight']
79-
softmax = loss_input_dict['softmax']
80+
soft_y = loss_input_dict['ground_truth']
8081

8182
if(isinstance(predict, (list, tuple))):
8283
predict = predict[0]
83-
if(softmax):
84+
if(self.softmax):
8485
predict = nn.Softmax(dim = 1)(predict)
8586
predict = reshape_tensor_to_2D(predict)
8687
soft_y = reshape_tensor_to_2D(soft_y)
8788
gce = (1.0 - torch.pow(predict, self.q)) / self.q * soft_y
8889

8990
if(self.enable_cls_weight):
91+
cls_w = loss_input_dict.get('class_weight', None)
9092
if(cls_w is None):
9193
raise ValueError("Class weight is enabled but not defined")
9294
gce = torch.sum(gce * cls_w, dim = 1)
9395
else:
9496
gce = torch.sum(gce, dim = 1)
9597

9698
if(self.enable_pix_weight):
99+
pix_w = loss_input_dict.get('pixel_weight', None)
97100
if(pix_w is None):
98101
raise ValueError("Pixel weight is enabled but not defined")
99102
pix_w = reshape_tensor_to_2D(pix_w)

pymic/loss/seg/mumford_shah.py

+2-25
Original file line numberDiff line numberDiff line change
@@ -4,37 +4,14 @@
44
import torch
55
import torch.nn as nn
66

7-
class DiceLoss(nn.Module):
8-
def __init__(self, params = None):
9-
super(DiceLoss, self).__init__()
10-
if(params is None):
11-
self.softmax = True
12-
else:
13-
self.softmax = params.get('loss_softmax', True)
14-
15-
def forward(self, loss_input_dict):
16-
predict = loss_input_dict['prediction']
17-
soft_y = loss_input_dict['ground_truth']
18-
19-
if(isinstance(predict, (list, tuple))):
20-
predict = predict[0]
21-
if(self.softmax):
22-
predict = nn.Softmax(dim = 1)(predict)
23-
predict = reshape_tensor_to_2D(predict)
24-
soft_y = reshape_tensor_to_2D(soft_y)
25-
dice_score = get_classwise_dice(predict, soft_y)
26-
dice_loss = 1.0 - dice_score.mean()
27-
return dice_loss
28-
297
class MumfordShahLoss(nn.Module):
308
"""
319
Implementation of Mumford Shah Loss in this paper:
32-
Boah Kim and Jong Chul Ye, Mumford–Shah Loss Functional
10+
Boah Kim and Jong Chul Ye: Mumford–Shah Loss Functional
3311
for Image Segmentation With Deep Learning. IEEE TIP, 2019.
3412
The oringial implementation is availabel at:
3513
https://github.com/jongcye/CNN_MumfordShah_Loss
36-
37-
currently only 2D version is supported.
14+
Currently only 2D version is supported.
3815
"""
3916
def __init__(self, params = None):
4017
super(MumfordShahLoss, self).__init__()

pymic/loss/seg/slsr.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
"""
33
Spatial Label Smoothing Regularization (SLSR) loss for learning from
44
noisy annotatins according to the following paper:
5-
Minqing Zhang, Jiantao Gao et al., Characterizing Label Errors:
6-
Confident Learning for Noisy-Labeled Image Segmentation, MICCAI 2020.
5+
Minqing Zhang, Jiantao Gao et al.:
6+
Characterizing Label Errors: Confident Learning for Noisy-Labeled Image
7+
Segmentation, MICCAI 2020.
8+
https://link.springer.com/chapter/10.1007/978-3-030-59710-8_70
79
"""
810
from __future__ import print_function, division
911

@@ -17,7 +19,7 @@ def __init__(self, params):
1719
if(params is None):
1820
params = {}
1921
self.softmax = params.get('loss_softmax', True)
20-
self.epsilon = params.get('slsrloss_softmax', 0.25)
22+
self.epsilon = params.get('slsrloss_epsilon', 0.25)
2123

2224
def forward(self, loss_input_dict):
2325
predict = loss_input_dict['prediction']
@@ -35,7 +37,6 @@ def forward(self, loss_input_dict):
3537
soft_y = reshape_tensor_to_2D(soft_y)
3638
if(pix_w is not None):
3739
pix_w = reshape_tensor_to_2D(pix_w > 0).float()
38-
3940
# smooth labels for pixels in the unconfident mask
4041
smooth_y = (soft_y - 0.5) * (0.5 - self.epsilon) / 0.5 + 0.5
4142
smooth_y = pix_w * smooth_y + (1 - pix_w) * soft_y

pymic/net/net2d/unet2d.py

+65-4
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,67 @@ def forward(self, x1, x2):
7272
x = torch.cat([x2, x1], dim=1)
7373
return self.conv(x)
7474

75+
class Encoder(nn.Module):
76+
def __init__(self, params):
77+
super(Encoder, self).__init__()
78+
self.params = params
79+
self.in_chns = self.params['in_chns']
80+
self.ft_chns = self.params['feature_chns']
81+
self.dropout = self.params['dropout']
82+
assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)
83+
84+
self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0])
85+
self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1])
86+
self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2])
87+
self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3])
88+
if(len(self.ft_chns) == 5):
89+
self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4])
90+
91+
def forward(self, x):
92+
x0 = self.in_conv(x)
93+
x1 = self.down1(x0)
94+
x2 = self.down2(x1)
95+
x3 = self.down3(x2)
96+
output = [x0, x1, x2, x3]
97+
if(len(self.ft_chns) == 5):
98+
x4 = self.down4(x3)
99+
output.append(x4)
100+
return output
101+
102+
class Decoder(nn.Module):
103+
def __init__(self, params):
104+
super(Decoder, self).__init__()
105+
self.params = params
106+
self.in_chns = self.params['in_chns']
107+
self.ft_chns = self.params['feature_chns']
108+
self.dropout = self.params['dropout']
109+
self.n_class = self.params['class_num']
110+
self.bilinear = self.params['bilinear']
111+
112+
assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)
113+
114+
if(len(self.ft_chns) == 5):
115+
self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear)
116+
self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear)
117+
self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear)
118+
self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear)
119+
self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1)
120+
121+
def forward(self, x):
122+
if(len(self.ft_chns) == 5):
123+
assert(len(x) == 5)
124+
x0, x1, x2, x3, x4 = x
125+
x_d3 = self.up1(x4, x3)
126+
else:
127+
assert(len(x) == 4)
128+
x0, x1, x2, x3 = x
129+
x_d3 = x3
130+
x_d2 = self.up2(x_d3, x2)
131+
x_d1 = self.up3(x_d2, x1)
132+
x_d0 = self.up4(x_d1, x0)
133+
output = self.out_conv(x_d0)
134+
return output
135+
75136
class UNet2D(nn.Module):
76137
def __init__(self, params):
77138
super(UNet2D, self).__init__()
@@ -91,10 +152,10 @@ def __init__(self, params):
91152
self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3])
92153
if(len(self.ft_chns) == 5):
93154
self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4])
94-
self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], 0.0, self.bilinear)
95-
self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], 0.0, self.bilinear)
96-
self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], 0.0, self.bilinear)
97-
self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], 0.0, self.bilinear)
155+
self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear)
156+
self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear)
157+
self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear)
158+
self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear)
98159

99160
self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1)
100161
if(self.deep_sup):

0 commit comments

Comments
 (0)