
Commit 972011a (parent: abf0dc6)

update transform and lr scheduler
add PolynomialLR and Transpose

File tree: 9 files changed (+223 -78 lines)


pymic/loss/seg/deep_sup.py (+53 -10)

@@ -2,8 +2,43 @@
 from __future__ import print_function, division
 
 import torch.nn as nn
+from torch.nn.functional import interpolate
 from pymic.loss.seg.abstract import AbstractSegLoss
 
+def match_prediction_and_gt_shape(pred, gt, mode = 0):
+    pred_shape = list(pred.shape)
+    gt_shape   = list(gt.shape)
+    dim = len(pred_shape) - 2
+    shape_match = False
+    if(dim == 2):
+        if(pred_shape[-1] == gt_shape[-1] and pred_shape[-2] == gt_shape[-2]):
+            shape_match = True
+    else:
+        if(pred_shape[-1] == gt_shape[-1] and pred_shape[-2] == gt_shape[-2]
+            and pred_shape[-3] == gt_shape[-3]):
+            shape_match = True
+    if(shape_match):
+        return pred, gt
+
+    interp_mode = 'bilinear' if dim == 2 else 'trilinear'
+    if(mode == 0):
+        pred_new = interpolate(pred, gt_shape[2:], mode = interp_mode)
+        gt_new = gt
+    elif(mode == 1):
+        pred_new = pred
+        gt_new = interpolate(gt, pred_shape[2:], mode = interp_mode)
+    elif(mode == 2):
+        pred_new = pred
+        if(dim == 2):
+            avg_pool = nn.AdaptiveAvgPool2d(pred_shape[-2:])
+        else:
+            avg_pool = nn.AdaptiveAvgPool3d(pred_shape[-3:])
+        gt_new = avg_pool(gt)
+    else:
+        raise ValueError("mode should be 0, 1 or 2, but {0:} was given".format(mode))
+    return pred_new, gt_new
+
+
 class DeepSuperviseLoss(AbstractSegLoss):
     '''
     Combine deep supervision with a basic loss function.
@@ -12,28 +47,36 @@ class DeepSuperviseLoss(AbstractSegLoss):
 
     :param `loss_softmax`: (optional, bool)
         Apply softmax to the prediction of network or not. Default is True.
-    :param `deep_suervise_weight`: (list) A list of weight for each deep supervision scale. \n
     :param `base_loss`: (nn.Module) The basic function used for each scale.
+    :param `deep_supervise_weight`: (list) A list of weights for each deep supervision scale.
+    :param `deep_supervise_mode`: (int) Mode for deep supervision when the prediction
+        has a smaller shape than the ground truth. 0: upsample the prediction to the size
+        of the ground truth. 1: downsample the ground truth to the size of the prediction
+        via interpolation. 2: downsample the ground truth via adaptive average pooling.
 
     '''
     def __init__(self, params):
        super(DeepSuperviseLoss, self).__init__(params)
-        self.deep_sup_weight = params.get('deep_suervise_weight', None)
-        self.base_loss = params['base_loss']
+        self.base_loss       = params['base_loss']
+        self.deep_sup_weight = params.get('deep_supervise_weight', None)
+        self.deep_sup_mode   = params.get('deep_supervise_mode', 0)
 
     def forward(self, loss_input_dict):
-        predict = loss_input_dict['prediction']
-        if(not isinstance(predict, (list,tuple))):
+        pred = loss_input_dict['prediction']
+        gt = loss_input_dict['ground_truth']
+        if(not isinstance(pred, (list,tuple))):
             raise ValueError("""For deep supervision, the prediction should
                 be a list or a tuple""")
-        predict_num = len(predict)
+        pred_num = len(pred)
         if(self.deep_sup_weight is None):
-            self.deep_sup_weight = [1.0] * predict_num
+            self.deep_sup_weight = [1.0] * pred_num
         else:
-            assert(predict_num == len(self.deep_sup_weight))
+            assert(pred_num == len(self.deep_sup_weight))
         loss_sum, weight_sum = 0.0, 0.0
-        for i in range(predict_num):
-            loss_input_dict['prediction'] = predict[i]
+        for i in range(pred_num):
+            pred_i, gt_i = match_prediction_and_gt_shape(pred[i], gt, self.deep_sup_mode)
+            loss_input_dict['prediction'] = pred_i
+            loss_input_dict['ground_truth'] = gt_i
             temp_loss = self.base_loss(loss_input_dict)
             loss_sum += temp_loss * self.deep_sup_weight[i]
             weight_sum += self.deep_sup_weight[i]
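To illustrate the three modes, here is a minimal sketch with illustrative 2D shapes (the tensors and sizes are made up, not from the commit):

import torch
from pymic.loss.seg.deep_sup import match_prediction_and_gt_shape

pred = torch.rand(2, 4, 64, 64)    # low-resolution side output [N, C, H, W]
gt   = torch.rand(2, 4, 128, 128)  # one-hot ground truth at full resolution

# mode 0: the prediction is upsampled to the ground-truth size
p, g = match_prediction_and_gt_shape(pred, gt, mode = 0)
print(p.shape, g.shape)   # [2, 4, 128, 128] for both

# mode 1: the ground truth is interpolated down to the prediction size
p, g = match_prediction_and_gt_shape(pred, gt, mode = 1)
print(p.shape, g.shape)   # [2, 4, 64, 64] for both

# mode 2: the ground truth is reduced by adaptive average pooling instead
p, g = match_prediction_and_gt_shape(pred, gt, mode = 2)
print(p.shape, g.shape)   # [2, 4, 64, 64] for both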

pymic/net/net2d/unet2d.py (+1 -5)

@@ -240,18 +240,14 @@ def forward(self, x):
         x_d0 = self.up4(x_d1, x0)
         output = self.out_conv(x_d0)
         if(self.deep_sup):
-            out_shape = list(output.shape)[2:]
             output1 = self.out_conv1(x_d1)
-            output1 = interpolate(output1, out_shape, mode = 'bilinear')
             output2 = self.out_conv2(x_d2)
-            output2 = interpolate(output2, out_shape, mode = 'bilinear')
             output3 = self.out_conv3(x_d3)
-            output3 = interpolate(output3, out_shape, mode = 'bilinear')
             output = [output, output1, output2, output3]
 
             if(len(x_shape) == 5):
-                new_shape = [N, D] + list(output[0].shape)[1:]
                 for i in range(len(output)):
+                    new_shape = [N, D] + list(output[i].shape)[1:]
                     output[i] = torch.transpose(torch.reshape(output[i], new_shape), 1, 2)
         elif(len(x_shape) == 5):
             new_shape = [N, D] + list(output.shape)[1:]
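With the interpolation removed from the network, the side outputs keep their native scales and shape matching now happens inside the loss. A hedged end-to-end sketch (the DiceLoss import path, parameter values, and output shapes are assumptions for illustration, not from the commit):

import torch
from pymic.loss.seg.dice import DiceLoss
from pymic.loss.seg.deep_sup import DeepSuperviseLoss

params = {'loss_softmax': True,
          'base_loss': DiceLoss({'loss_softmax': True}),
          'deep_supervise_weight': [1.0, 0.5, 0.25, 0.125],
          'deep_supervise_mode': 0}
loss_fn = DeepSuperviseLoss(params)

# multi-scale outputs as a deep-supervised UNet2D would return them
preds = [torch.rand(2, 4, 128, 128), torch.rand(2, 4, 64, 64),
         torch.rand(2, 4, 32, 32),   torch.rand(2, 4, 16, 16)]
gt    = torch.rand(2, 4, 128, 128)   # one-hot ground truth
loss  = loss_fn({'prediction': preds, 'ground_truth': gt})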

pymic/net_run/agent_seg.py (+13 -13)

@@ -173,9 +173,9 @@ def training(self):
             train_dice_list.append(dice_list.cpu().numpy())
         train_avg_loss = train_loss / iter_valid
         train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
-        train_avg_dice = train_cls_dice.mean()
+        train_avg_dice = train_cls_dice[1:].mean()
 
-        train_scalers = {'loss': train_avg_loss, 'avg_dice':train_avg_dice,\
+        train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\
             'class_dice': train_cls_dice}
         return train_scalers
 
@@ -214,14 +214,14 @@ def validation(self):
 
         valid_avg_loss = np.asarray(valid_loss_list).mean()
         valid_cls_dice = np.asarray(valid_dice_list).mean(axis = 0)
-        valid_avg_dice = valid_cls_dice.mean()
-        valid_scalers = {'loss': valid_avg_loss, 'avg_dice': valid_avg_dice,\
+        valid_avg_dice = valid_cls_dice[1:].mean()
+        valid_scalers = {'loss': valid_avg_loss, 'avg_fg_dice': valid_avg_dice,\
             'class_dice': valid_cls_dice}
         return valid_scalers
 
     def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
         loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']}
-        dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']}
+        dice_scalar ={'train':train_scalars['avg_fg_dice'], 'valid':valid_scalars['avg_fg_dice']}
         self.summ_writer.add_scalars('loss', loss_scalar, glob_it)
         self.summ_writer.add_scalars('dice', dice_scalar, glob_it)
         self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it)
@@ -231,11 +231,11 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
                 'valid':valid_scalars['class_dice'][c]}
             self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it)
 
-        logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format(
-            train_scalars['loss'], train_scalars['avg_dice']) + "[" + \
+        logging.info('train loss {0:.4f}, avg foreground dice {1:.4f} '.format(
+            train_scalars['loss'], train_scalars['avg_fg_dice']) + "[" + \
             ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]")
-        logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format(
-            valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \
+        logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format(
+            valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \
             ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]")
 
     def train_valid(self):
@@ -295,7 +295,7 @@ def train_valid(self):
             valid_scalars = self.validation()
             t2 = time.time()
             if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
-                self.scheduler.step(valid_scalars['avg_dice'])
+                self.scheduler.step(valid_scalars['avg_fg_dice'])
             else:
                 self.scheduler.step()
 
@@ -304,8 +304,8 @@
             logging.info('learning rate {0:}'.format(lr_value))
             logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1))
             self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it)
-            if(valid_scalars['avg_dice'] > self.max_val_dice):
-                self.max_val_dice = valid_scalars['avg_dice']
+            if(valid_scalars['avg_fg_dice'] > self.max_val_dice):
+                self.max_val_dice = valid_scalars['avg_fg_dice']
                 self.max_val_it = self.glob_it
                 if(len(device_ids) > 1):
                     self.best_model_wts = copy.deepcopy(self.net.module.state_dict())
@@ -316,7 +316,7 @@ def train_valid(self):
                 self.glob_it - self.max_val_it > early_stop_it) else False
             if ((self.glob_it in iter_save_list) or stop_now):
                 save_dict = {'iteration': self.glob_it,
-                    'valid_pred': valid_scalars['avg_dice'],
+                    'valid_pred': valid_scalars['avg_fg_dice'],
                     'model_state_dict': self.net.module.state_dict() \
                         if len(device_ids) > 1 else self.net.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict()}
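The renaming from avg_dice to avg_fg_dice reflects a real behavioral change: class 0 (background) is now excluded from the average, so the tracked metric is no longer inflated by the usually easy background class. A small numeric sketch with made-up per-class dice values:

import numpy as np

valid_cls_dice = np.asarray([0.99, 0.80, 0.70])  # class 0 is background
print(valid_cls_dice.mean())       # ~0.83: old avg_dice, inflated by background
print(valid_cls_dice[1:].mean())   # 0.75: new avg_fg_dice, foreground classes only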

pymic/net_run/get_optimizer.py (+8 -4)

@@ -39,8 +39,9 @@ def get_optimizer(name, net_params, optim_params):
 
 
 def get_lr_scheduler(optimizer, sched_params):
-    name = sched_params["lr_scheduler"]
-    val_it = sched_params["iter_valid"]
+    name       = sched_params["lr_scheduler"]
+    val_it     = sched_params["iter_valid"]
+    epoch_last = sched_params["last_iter"] / val_it
     if(name is None):
         return None
     if(keyword_match(name, "ReduceLROnPlateau")):
@@ -52,16 +53,19 @@ def get_lr_scheduler(optimizer, sched_params):
     elif(keyword_match(name, "MultiStepLR")):
         lr_milestones = sched_params["lr_milestones"]
         lr_milestones = [int(item / val_it) for item in lr_milestones]
-        epoch_last = sched_params["last_iter"] / val_it
         lr_gamma = sched_params["lr_gamma"]
         scheduler = lr_scheduler.MultiStepLR(optimizer,
             lr_milestones, lr_gamma, epoch_last)
     elif(keyword_match(name, "CosineAnnealingLR")):
         epoch_max = sched_params["iter_max"] / val_it
-        epoch_last = sched_params["last_iter"] / val_it
         lr_min = sched_params.get("lr_min", 0)
         scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
             epoch_max, lr_min, epoch_last)
+    elif(keyword_match(name, "PolynomialLR")):
+        epoch_max = sched_params["iter_max"] / val_it
+        power = sched_params["lr_power"]
+        scheduler = lr_scheduler.PolynomialLR(optimizer,
+            epoch_max, power, epoch_last)
     else:
         raise ValueError("unsupported lr scheduler {0:}".format(name))
     return scheduler
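A hedged configuration sketch for the new PolynomialLR branch (the key names follow the code above; the concrete values are illustrative, and torch.optim.lr_scheduler.PolynomialLR requires PyTorch >= 1.13):

import torch
from torch.optim import SGD
from pymic.net_run.get_optimizer import get_lr_scheduler

optimizer = SGD([torch.nn.Parameter(torch.zeros(1))], lr = 0.01)
sched_params = {
    "lr_scheduler": "PolynomialLR",
    "iter_valid":   100,     # the scheduler steps once per validation round
    "iter_max":     10000,   # total_iters becomes 10000 / 100 = 100 steps
    "last_iter":    -100,    # fresh run: epoch_last = -100 / 100 = -1
    "lr_power":     0.9,
}
scheduler = get_lr_scheduler(optimizer, sched_params)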

pymic/net_run/infer_func.py (+18 -1)

@@ -15,7 +15,8 @@ class Inferer(object):
     :param `sliding_window_stride`: (optional, list) The sliding window stride.
     :param `tta_mode`: (optional, int) The test time augmentation mode. Default
         is 0 (no test time augmentation). The other option is 1 (augmentation
-        with horinzontal and vertical flipping).
+        with horizontal and vertical flipping) and 2 (ensemble of inference
+        in axial, sagittal and coronal views for 2D networks applied to 3D volumes).
     """
     def __init__(self, config):
         self.config = config
@@ -170,6 +171,22 @@ def run(self, model, image):
            outputs3 = torch.flip(outputs3, [-1])
            outputs4 = torch.flip(outputs4, [-2, -1])
            outputs = (outputs1 + outputs2 + outputs3 + outputs4) / 4
+        elif(tta_mode == 2):
+            outputs1 = self.__infer(image)
+            outputs2 = self.__infer(torch.transpose(image, -1, -3))
+            outputs3 = self.__infer(torch.transpose(image, -2, -3))
+            if(isinstance(outputs1, (tuple, list))):
+                outputs = []
+                for i in range(len(outputs1)):
+                    temp_out1 = outputs1[i]
+                    temp_out2 = torch.transpose(outputs2[i], -1, -3)
+                    temp_out3 = torch.transpose(outputs3[i], -2, -3)
+                    temp_mean = (temp_out1 + temp_out2 + temp_out3) / 3
+                    outputs.append(temp_mean)
+            else:
+                outputs2 = torch.transpose(outputs2, -1, -3)
+                outputs3 = torch.transpose(outputs3, -2, -3)
+                outputs = (outputs1 + outputs2 + outputs3) / 3
         else:
             raise ValueError("Undefined tta_mode {0:}".format(tta_mode))
         return outputs
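The mode-2 branch relies on torch.transpose being its own inverse: swapping axes -1/-3 (or -2/-3) of a [N, C, D, H, W] volume presents the sagittal (or coronal) view to the 2D network, and the same swap maps the prediction back before the three views are averaged. A minimal sketch of the round trip (shapes are illustrative):

import torch

image = torch.rand(1, 1, 32, 64, 96)      # [N, C, D, H, W]
sag = torch.transpose(image, -1, -3)      # [1, 1, 96, 64, 32]
cor = torch.transpose(image, -2, -3)      # [1, 1, 64, 32, 96]
# after inference, the same transpose restores the axial layout
assert torch.equal(torch.transpose(sag, -1, -3), image)
assert torch.equal(torch.transpose(cor, -2, -3), image)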

pymic/net_run/net_run.py (+10 -2)

@@ -3,6 +3,7 @@
 import logging
 import os
 import sys
+import shutil
 from pymic.util.parse_config import *
 from pymic.net_run.agent_cls import ClassificationAgent
 from pymic.net_run.agent_seg import SegmentationAgent
@@ -22,8 +23,15 @@ def main():
     log_dir = config['training']['ckpt_save_dir']
     if(not os.path.exists(log_dir)):
         os.makedirs(log_dir, exist_ok=True)
-    logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
-                        format='%(message)s')
+    if(stage == "train"):
+        dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1]
+        shutil.copy(cfg_file, log_dir + "/" + dst_cfg)
+    if sys.version.startswith("3.9"):
+        logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
+            format='%(message)s', force=True) # for python 3.9
+    else:
+        logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
+            format='%(message)s') # for python 3.6
     logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
     logging_config(config)
     task = config['dataset']['task_type']
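The version check exists because logging.basicConfig is a no-op once the root logger already has handlers; the force=True flag (available since Python 3.8) removes existing handlers first, which matters when the function is called more than once in a session. A standalone illustration:

import logging

logging.basicConfig(filename = "first.log", level = logging.INFO, format = '%(message)s')
# without force=True, this second call would be silently ignored
logging.basicConfig(filename = "second.log", level = logging.INFO,
                    format = '%(message)s', force = True)
logging.info("this record goes to second.log")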

pymic/test/test_nifty_dataset.py (+61 -43)

@@ -1,50 +1,68 @@
 # -*- coding: utf-8 -*-
 from __future__ import print_function, division
+import sys
+import numpy as np
+from pymic.util.parse_config import *
+from pymic.net_run.agent_cls import ClassificationAgent
+from pymic.net_run.agent_seg import SegmentationAgent
+import SimpleITK as sitk
 
-import os
-import torch
-import pandas as pd
-import numpy as np
-from skimage import io, transform
-from torch.utils.data import Dataset, DataLoader
-from torchvision import transforms, utils
-from pymic.io.image_read_write import *
-from pymic.io.nifty_dataset import NiftyDataset
-from pymic.io.transform3d import *
+def save_array_as_nifty_volume(data, image_name, reference_name = None):
+    """
+    Save a numpy array as nifty image
 
-if __name__ == "__main__":
-    root_dir = '/home/guotai/data/brats/BraTS2018_Training'
-    csv_file = '/home/guotai/projects/torch_brats/brats/config/brats18_train_train.csv'
-
-    crop1 = CropWithBoundingBox(start = None, output_size = [4, 144, 180, 144])
-    norm = ChannelWiseNormalize(mean = None, std = None, zero_to_random = True)
-    labconv = LabelConvert([0, 1, 2, 4], [0, 1, 2, 3])
-    crop2 = RandomCrop([128, 128, 128])
-    rescale =Rescale([64, 64, 64])
-    transform_list = [crop1, norm, labconv, crop2,rescale, ToTensor()]
-    transformed_dataset = NiftyDataset(root_dir=root_dir,
-                                csv_file = csv_file,
-                                modal_num = 4,
-                                transform = transforms.Compose(transform_list))
-    dataloader = DataLoader(transformed_dataset, batch_size=4,
-                        shuffle=True, num_workers=4)
-    # Helper function to show a batch
+    :param data: (numpy.ndarray) A numpy array with shape [Depth, Height, Width].
+    :param image_name: (str) The output file name.
+    :param reference_name: (str) File name of the reference image of which
+        meta information is used.
+    """
+    img = sitk.GetImageFromArray(data)
+    if(reference_name is not None):
+        img_ref = sitk.ReadImage(reference_name)
+        #img.CopyInformation(img_ref)
+        img.SetSpacing(img_ref.GetSpacing())
+        img.SetOrigin(img_ref.GetOrigin())
+        img.SetDirection(img_ref.GetDirection())
+    sitk.WriteImage(img, image_name)
 
+def main():
+    """
+    The main function for running a network for training or inference.
+    """
+    if(len(sys.argv) < 3):
+        print('Number of arguments should be 3. e.g.')
+        print('python test_nifty_dataset.py train config.cfg')
+        exit()
+    stage = str(sys.argv[1])
+    cfg_file = str(sys.argv[2])
+    config = parse_config(cfg_file)
+    config = synchronize_config(config)
+    # task = config['dataset']['task_type']
+    # assert task in ['cls', 'cls_nexcl', 'seg']
+    # if(task == 'cls' or task == 'cls_nexcl'):
+    #     agent = ClassificationAgent(config, stage)
+    # else:
+    #     agent = SegmentationAgent(config, stage)
+    agent = SegmentationAgent(config, stage)
+    agent.create_dataset()
+    data_loader = agent.train_loader if stage == "train" else agent.test_loader
+    it = 0
+    for data in data_loader:
+        inputs = agent.convert_tensor_type(data['image'])
+        labels_prob = agent.convert_tensor_type(data['label_prob'])
+        for i in range(inputs.shape[0]):
+            image_i = inputs[i][0]
+            label_i = np.argmax(labels_prob[i], axis = 0)
+            print(image_i.shape, label_i.shape)
+            image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i)
+            label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i)
+            save_array_as_nifty_volume(image_i, image_name, reference_name = None)
+            save_array_as_nifty_volume(label_i, label_name, reference_name = None)
+        it = it + 1
+        if(it == 10):
+            break
 
-    for i_batch, sample_batched in enumerate(dataloader):
-        print(i_batch, sample_batched['image'].size(),
-              sample_batched['label'].size())
+if __name__ == "__main__":
+    main()
+
 
-    # # observe 4th batch and stop.
-    modals = ['flair', 't1ce', 't1', 't2']
-    if i_batch == 0:
-        image = sample_batched['image'].numpy()
-        label = sample_batched['label'].numpy()
-        for i in range(image.shape[0]):
-            for mod in range(4):
-                image_i = image[i][mod]
-                label_i = label[i][0]
-                image_name = "temp/image_{0:}_{1:}.nii.gz".format(i, modals[mod])
-                label_name = "temp/label_{0:}.nii.gz".format(i)
-                save_array_as_nifty_volume(image_i, image_name, reference_name = None)
-                save_array_as_nifty_volume(label_i, label_name, reference_name = None)
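A quick usage sketch for the new save_array_as_nifty_volume helper (the array contents and file names are made up; note that the test script writes into temp/ without creating it first):

import os
import numpy as np
from pymic.test.test_nifty_dataset import save_array_as_nifty_volume

os.makedirs("temp", exist_ok = True)
volume = np.random.rand(32, 64, 64).astype(np.float32)   # [Depth, Height, Width]
save_array_as_nifty_volume(volume, "temp/random_volume.nii.gz")
# pass reference_name to copy spacing / origin / direction from an existing scan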
