1
1
# -*- coding: utf-8 -*-
2
2
from __future__ import print_function , division
3
+ import sys
4
+ import numpy as np
5
+ from pymic .util .parse_config import *
6
+ from pymic .net_run .agent_cls import ClassificationAgent
7
+ from pymic .net_run .agent_seg import SegmentationAgent
8
+ import SimpleITK as sitk
3
9
4
- import os
5
- import torch
6
- import pandas as pd
7
- import numpy as np
8
- from skimage import io , transform
9
- from torch .utils .data import Dataset , DataLoader
10
- from torchvision import transforms , utils
11
- from pymic .io .image_read_write import *
12
- from pymic .io .nifty_dataset import NiftyDataset
13
- from pymic .io .transform3d import *
10
+ def save_array_as_nifty_volume (data , image_name , reference_name = None ):
11
+ """
12
+ Save a numpy array as nifty image
14
13
15
- if __name__ == "__main__" :
16
- root_dir = '/home/guotai/data/brats/BraTS2018_Training'
17
- csv_file = '/home/guotai/projects/torch_brats/brats/config/brats18_train_train.csv'
18
-
19
- crop1 = CropWithBoundingBox (start = None , output_size = [4 , 144 , 180 , 144 ])
20
- norm = ChannelWiseNormalize (mean = None , std = None , zero_to_random = True )
21
- labconv = LabelConvert ([0 , 1 , 2 , 4 ], [0 , 1 , 2 , 3 ])
22
- crop2 = RandomCrop ([128 , 128 , 128 ])
23
- rescale = Rescale ([64 , 64 , 64 ])
24
- transform_list = [crop1 , norm , labconv , crop2 ,rescale , ToTensor ()]
25
- transformed_dataset = NiftyDataset (root_dir = root_dir ,
26
- csv_file = csv_file ,
27
- modal_num = 4 ,
28
- transform = transforms .Compose (transform_list ))
29
- dataloader = DataLoader (transformed_dataset , batch_size = 4 ,
30
- shuffle = True , num_workers = 4 )
31
- # Helper function to show a batch
14
+ :param data: (numpy.ndarray) A numpy array with shape [Depth, Height, Width].
15
+ :param image_name: (str) The ouput file name.
16
+ :param reference_name: (str) File name of the reference image of which
17
+ meta information is used.
18
+ """
19
+ img = sitk .GetImageFromArray (data )
20
+ if (reference_name is not None ):
21
+ img_ref = sitk .ReadImage (reference_name )
22
+ #img.CopyInformation(img_ref)
23
+ img .SetSpacing (img_ref .GetSpacing ())
24
+ img .SetOrigin (img_ref .GetOrigin ())
25
+ img .SetDirection (img_ref .GetDirection ())
26
+ sitk .WriteImage (img , image_name )
32
27
28
+ def main ():
29
+ """
30
+ The main function for running a network for training or inference.
31
+ """
32
+ if (len (sys .argv ) < 3 ):
33
+ print ('Number of arguments should be 3. e.g.' )
34
+ print ('python test_nifty_dataset.py train config.cfg' )
35
+ exit ()
36
+ stage = str (sys .argv [1 ])
37
+ cfg_file = str (sys .argv [2 ])
38
+ config = parse_config (cfg_file )
39
+ config = synchronize_config (config )
40
+ # task = config['dataset']['task_type']
41
+ # assert task in ['cls', 'cls_nexcl', 'seg']
42
+ # if(task == 'cls' or task == 'cls_nexcl'):
43
+ # agent = ClassificationAgent(config, stage)
44
+ # else:
45
+ # agent = SegmentationAgent(config, stage)
46
+ agent = SegmentationAgent (config , stage )
47
+ agent .create_dataset ()
48
+ data_loader = agent .train_loader if stage == "train" else agent .test_loader
49
+ it = 0
50
+ for data in data_loader :
51
+ inputs = agent .convert_tensor_type (data ['image' ])
52
+ labels_prob = agent .convert_tensor_type (data ['label_prob' ])
53
+ for i in range (inputs .shape [0 ]):
54
+ image_i = inputs [i ][0 ]
55
+ label_i = np .argmax (labels_prob [i ], axis = 0 )
56
+ print (image_i .shape , label_i .shape )
57
+ image_name = "temp/image_{0:}_{1:}.nii.gz" .format (it , i )
58
+ label_name = "temp/label_{0:}_{1:}.nii.gz" .format (it , i )
59
+ save_array_as_nifty_volume (image_i , image_name , reference_name = None )
60
+ save_array_as_nifty_volume (label_i , label_name , reference_name = None )
61
+ it = it + 1
62
+ if (it == 10 ):
63
+ break
33
64
34
- for i_batch , sample_batched in enumerate ( dataloader ) :
35
- print ( i_batch , sample_batched [ 'image' ]. size (),
36
- sample_batched [ 'label' ]. size ())
65
+ if __name__ == "__main__" :
66
+ main ()
67
+
37
68
38
- # # observe 4th batch and stop.
39
- modals = ['flair' , 't1ce' , 't1' , 't2' ]
40
- if i_batch == 0 :
41
- image = sample_batched ['image' ].numpy ()
42
- label = sample_batched ['label' ].numpy ()
43
- for i in range (image .shape [0 ]):
44
- for mod in range (4 ):
45
- image_i = image [i ][mod ]
46
- label_i = label [i ][0 ]
47
- image_name = "temp/image_{0:}_{1:}.nii.gz" .format (i , modals [mod ])
48
- label_name = "temp/label_{0:}.nii.gz" .format (i )
49
- save_array_as_nifty_volume (image_i , image_name , reference_name = None )
50
- save_array_as_nifty_volume (label_i , label_name , reference_name = None )
0 commit comments