Skip to content

Commit d38bc56

Browse files
authored
Merge pull request #18 from MHGL/Test_Add_IoU_Eval
Test add io u eval
2 parents 8db4c23 + 0244cd3 commit d38bc56

File tree

3 files changed

+29
-6
lines changed

3 files changed

+29
-6
lines changed

test.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# -*- coding:utf-8 -*-
12
import sys
23
import os
34
import time
@@ -7,12 +8,15 @@
78
import torch.nn.functional as F
89
import deepvac
910
from deepvac import LOG, Deepvac
11+
from utils.utils_IOU_eval import IOUEval
1012
from data.dataloader import OsWalkDataset2
1113

1214
class ESPNetTest(Deepvac):
1315
def __init__(self, deepvac_config):
1416
super(ESPNetTest, self).__init__(deepvac_config)
1517
os.makedirs(self.config.show_output_dir, exist_ok=True)
18+
if self.config.test_label_path is not None:
19+
self.config.iou_eval = IOUEval(self.config.cls_num)
1620

1721
def preIter(self):
1822
assert len(self.config.target) == 1, 'config.core.test_batch_size must be set to 1 in current test mode.'
@@ -35,6 +39,11 @@ def postIter(self):
3539
savepath = os.path.join(self.config.show_output_dir, filename)
3640
mask_savepath = os.path.join(self.config.show_output_dir, mask_filename)
3741

42+
if self.config.test_label_path:
43+
label_file = os.path.join(self.config.test_label_path, filename.replace(".jpg", ".png"))
44+
self.config.label = cv2.imread(label_file, 0)
45+
self.config.iou_eval.addBatch(self.config.mask, self.config.label)
46+
3847
classMap_numpy_color = np.zeros((h, w, 3), dtype=np.uint8)
3948
for idx in np.unique(self.config.mask):
4049
[r, g, b] = self.config.pallete[idx]
@@ -44,6 +53,17 @@ def postIter(self):
4453
cv2.imwrite(mask_savepath, classMap_numpy_color)
4554
LOG.logI('{}: [out cv image save to {}] [{}/{}]\n'.format(self.config.phase, savepath, self.config.test_step + 1, len(self.config.test_loader)))
4655

56+
def testFly(self):
57+
if self.config.test_loader:
58+
self.test()
59+
if self.config.test_label_path is None:
60+
return
61+
*_, self.config.mIOU = self.config.iou_eval.getMetric()
62+
LOG.logI(">>> {}: [dataset: {}, mIOU: {:.3f}]".format(self.config.phase, self.config.filepath.split('/')[-2], self.config.mIOU))
63+
return
64+
65+
LOG.logE("You have to reimplement testFly() in subclass {} if you didn't set any valid input, e.g. config.core.test_loader.".format(self.name()), exit=True)
66+
4767

4868
if __name__ == "__main__":
4969
from config import config
@@ -55,13 +75,16 @@ def check_args(idx, argv):
5575
config.core.model_path = sys.argv[1]
5676
if check_args(2, sys.argv):
5777
config.test_sample_path = sys.argv[2]
78+
if check_args(3, sys.argv):
79+
config.core.test_label_path = sys.argv[3]
5880

5981
if (config.core.model_path is None) or (config.test_sample_path is None):
6082
helper = '''model_path or test_sample_path not found, please check:
6183
config.core.model_path or sys.argv[1] to init model path
6284
config.test_sample_path or sys.argv[2] to init test sample path
85+
config.test_label_path or sys.argv[3] to init test sample path (not required)
6386
for example:
64-
python3 test.py <trained-model-path> <test sample path>'''
87+
python3 test.py <trained-model-path> <test sample path> [test label path(not required)]'''
6588
print(helper)
6689
sys.exit(1)
6790

train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def train(self):
2121
for i, loader in enumerate(self.config.train_loader_list):
2222
self.config.train_loader = loader
2323
super(ESPNetTrain, self).train()
24-
24+
2525
#only save model for last loader
2626
def doSave(self):
2727
if not self.config.train_loader.is_last_loader:
@@ -34,9 +34,9 @@ def postIter(self):
3434

3535
self.config.epoch_loss.append(self.config.loss.item())
3636
if self.config.phase == 'TRAIN':
37-
self.iou_eval_train.addBatch(self.config.output[0].max(1)[1].data, self.config.target.data)
37+
self.iou_eval_train.addBatch(self.config.output[0].max(1)[1].data.cpu().numpy(), self.config.target.data.cpu().numpy())
3838
else:
39-
self.iou_eval_val.addBatch(self.config.output[0].max(1)[1].data, self.config.target.data)
39+
self.iou_eval_val.addBatch(self.config.output[0].max(1)[1].data.cpu().numpy(), self.config.target.data.cpu().numpy())
4040

4141
def preEpoch(self):
4242
self.config.epoch_loss = []

utils/utils_IOU_eval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ def compute_hist(self, predict, gth):
2121
return hist
2222

2323
def addBatch(self, predict, gth):
24-
predict = predict.cpu().numpy().flatten()
25-
gth = gth.cpu().numpy().flatten()
24+
predict = predict.flatten()
25+
gth = gth.flatten()
2626

2727
epsilon = 0.00000001
2828
hist = self.compute_hist(predict, gth)

0 commit comments

Comments
 (0)