
Commit 6053345

linfangjian.vendor authored and zhengmiao committed
[Refactor] Refactor cityscapes metrics
1 parent eef12a0 commit 6053345

File tree

9 files changed (+266, -5 lines)

mmseg/datasets/pipelines/formatting.py

+3 -2
@@ -41,8 +41,9 @@ class PackSegInputs(BaseTransform):
     """

     def __init__(self,
-                 meta_keys=('img_path', 'ori_shape', 'img_shape', 'pad_shape',
-                            'scale_factor', 'flip', 'flip_direction')):
+                 meta_keys=('img_path', 'seg_map_path', 'ori_shape',
+                            'img_shape', 'pad_shape', 'scale_factor', 'flip',
+                            'flip_direction')):
         self.meta_keys = meta_keys

     def transform(self, results: dict) -> dict:
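
For reference, the added 'seg_map_path' key is packed into each sample's meta information alongside the image path, so downstream consumers (such as the Cityscapes metric below) can recover the ground-truth annotation location. A minimal sketch of a pipeline using PackSegInputs with the updated meta_keys; the LoadImageFromFile/LoadAnnotations transforms and the dict-style config are the usual mmseg conventions and are not part of this commit:

# Illustrative pipeline config (a sketch, not taken from this commit).
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='PackSegInputs',
        meta_keys=('img_path', 'seg_map_path', 'ori_shape', 'img_shape',
                   'pad_shape', 'scale_factor', 'flip', 'flip_direction')),
]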

mmseg/metrics/__init__.py

+2 -1
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .citys_metric import CitysMetric
 from .iou_metric import IoUMetric

-__all__ = ['IoUMetric']
+__all__ = ['IoUMetric', 'CitysMetric']
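
With CitysMetric exported from mmseg.metrics, it can be imported directly or built through the registry; a minimal sketch assuming the standard MMEngine registry pattern:

from mmseg.metrics import CitysMetric
from mmseg.registry import METRICS

# Both lines construct the same metric; configs normally use the registry route.
metric = CitysMetric(citys_metrics=['cityscapes'])
metric = METRICS.build(dict(type='CitysMetric', citys_metrics=['cityscapes']))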

mmseg/metrics/citys_metric.py

+147
@@ -0,0 +1,147 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, List, Optional, Sequence

import mmcv
import numpy as np
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger, print_log
from PIL import Image

from mmseg.registry import METRICS


@METRICS.register_module()
class CitysMetric(BaseMetric):
    """Cityscapes evaluation metric.

    Args:
        ignore_index (int): Index that will be ignored in evaluation.
            Default: 255.
        citys_metrics (list[str] | str): Metrics to be evaluated.
            Default: ['cityscapes'].
        to_label_id (bool): Whether to convert output to label_id for
            submission. Default: True.
        suffix (str): The directory where formatted prediction PNG files are
            written; each file is named '<basename>.png' under this directory.
            Default: '.format_cityscapes'.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """

    def __init__(self,
                 ignore_index: int = 255,
                 citys_metrics: List[str] = ['cityscapes'],
                 to_label_id: bool = True,
                 suffix: str = '.format_cityscapes',
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

        self.ignore_index = ignore_index
        self.metrics = citys_metrics
        assert self.metrics[0] == 'cityscapes'
        self.to_label_id = to_label_id
        self.suffix = suffix

    def process(self, data_batch: Sequence[dict],
                predictions: Sequence[dict]) -> None:
        """Process one batch of data and predictions.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence[dict]): A batch of data from the dataloader.
            predictions (Sequence[dict]): A batch of outputs from the model.
        """
        mmcv.mkdir_or_exist(self.suffix)

        for pred in predictions:
            pred_label = pred['pred_sem_seg']['data'][0].cpu().numpy()
            # results2img
            if self.to_label_id:
                pred_label = self._convert_to_label_id(pred_label)
            basename = osp.splitext(osp.basename(pred['img_path']))[0]
            png_filename = osp.join(self.suffix, f'{basename}.png')
            output = Image.fromarray(pred_label.astype(np.uint8)).convert('P')
            import cityscapesscripts.helpers.labels as CSLabels
            palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
            for label_id, label in CSLabels.id2label.items():
                palette[label_id] = label.color
            output.putpalette(palette)
            output.save(png_filename)

        ann_dir = osp.join(
            data_batch[0]['data_sample']['seg_map_path'].split('val')[0],
            'val')
        self.results.append(ann_dir)

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): Testing results of the dataset.

        Returns:
            dict[str: float]: Cityscapes evaluation results.
        """
        logger: MMLogger = MMLogger.get_current_instance()
        try:
            import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval  # noqa
        except ImportError:
            raise ImportError('Please run "pip install cityscapesscripts" to '
                              'install cityscapesscripts first.')
        msg = 'Evaluating in Cityscapes style'

        if logger is None:
            msg = '\n' + msg
        print_log(msg, logger=logger)

        result_dir = self.suffix

        eval_results = dict()
        print_log(f'Evaluating results under {result_dir} ...', logger=logger)

        CSEval.args.evalInstLevelScore = True
        CSEval.args.predictionPath = osp.abspath(result_dir)
        CSEval.args.evalPixelAccuracy = True
        CSEval.args.JSONOutput = False

        seg_map_list = []
        pred_list = []
        ann_dir = results[0]
        # when evaluating with official cityscapesscripts,
        # **_gtFine_labelIds.png is used
        for seg_map in mmcv.scandir(
                ann_dir, 'gtFine_labelIds.png', recursive=True):
            seg_map_list.append(osp.join(ann_dir, seg_map))
            pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))
        metric = dict()
        eval_results.update(
            CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
        metric['averageScoreCategories'] = eval_results[
            'averageScoreCategories']
        metric['averageScoreInstCategories'] = eval_results[
            'averageScoreInstCategories']
        return metric

    @staticmethod
    def _convert_to_label_id(result):
        """Convert trainId to id for cityscapes."""
        if isinstance(result, str):
            result = np.load(result)
        import cityscapesscripts.helpers.labels as CSLabels
        result_copy = result.copy()
        for trainId, label in CSLabels.trainId2label.items():
            result_copy[result == trainId] = label.id

        return result_copy
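
Because the class is registered under METRICS, it can be plugged into an evaluation config. A sketch of how it might be wired in as the validation evaluator; the val_evaluator/test_evaluator field names follow the MMEngine runner convention, and the suffix value here is illustrative:

# Hypothetical config fragment: write formatted prediction PNGs to
# 'work_dirs/format_cityscapes' and score them with cityscapesscripts.
val_evaluator = dict(
    type='CitysMetric',
    citys_metrics=['cityscapes'],
    to_label_id=True,
    suffix='work_dirs/format_cityscapes')
test_evaluator = val_evaluator

Note that compute_metrics imports cityscapesscripts, so pip install cityscapesscripts is required before this evaluator can run.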

tests/test_datasets/test_dataset.py

+2 -2
@@ -173,10 +173,10 @@ def test_cityscapes():
         data_prefix=dict(
             img_path=osp.join(
                 osp.dirname(__file__),
-                '../data/pseudo_cityscapes_dataset/leftImg8bit'),
+                '../data/pseudo_cityscapes_dataset/leftImg8bit/val'),
             seg_map_path=osp.join(
                 osp.dirname(__file__),
-                '../data/pseudo_cityscapes_dataset/gtFine')))
+                '../data/pseudo_cityscapes_dataset/gtFine/val')))
     assert len(test_dataset) == 1

+112
@@ -0,0 +1,112 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np
import torch
from mmengine.data import BaseDataElement, PixelData

from mmseg.core import SegDataSample
from mmseg.metrics import CitysMetric


class TestCitysMetric(TestCase):

    def _demo_mm_inputs(self,
                        batch_size=1,
                        image_shapes=(3, 128, 256),
                        num_classes=5):
        """Create a superset of inputs needed to run test or train batches.

        Args:
            batch_size (int): batch size. Defaults to 1.
            image_shapes (List[tuple], Optional): image shape.
                Defaults to (3, 128, 256).
            num_classes (int): number of different classes.
                Defaults to 5.
        """
        if isinstance(image_shapes, list):
            assert len(image_shapes) == batch_size
        else:
            image_shapes = [image_shapes] * batch_size

        packed_inputs = []
        for idx in range(batch_size):
            image_shape = image_shapes[idx]
            _, h, w = image_shape

            mm_inputs = dict()
            data_sample = SegDataSample()
            gt_semantic_seg = np.random.randint(
                0, num_classes, (1, h, w), dtype=np.uint8)
            gt_semantic_seg = torch.LongTensor(gt_semantic_seg)
            gt_sem_seg_data = dict(data=gt_semantic_seg)
            data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
            mm_inputs['data_sample'] = data_sample.to_dict()
            mm_inputs['data_sample']['seg_map_path'] = \
                'tests/data/pseudo_cityscapes_dataset/gtFine/val/\
                frankfurt/frankfurt_000000_000294_gtFine_labelTrainIds.png'

            packed_inputs.append(mm_inputs)

        return packed_inputs

    def _demo_mm_model_output(self,
                              batch_size=1,
                              image_shapes=(3, 128, 256),
                              num_classes=5):
        """Create a superset of outputs needed to run test or train batches.

        Args:
            batch_size (int): batch size. Defaults to 1.
            image_shapes (List[tuple], Optional): image shape.
                Defaults to (3, 128, 256).
            num_classes (int): number of different classes.
                Defaults to 5.
        """
        results_dict = dict()
        _, h, w = image_shapes
        seg_logit = torch.randn(batch_size, num_classes, h, w)
        results_dict['seg_logits'] = seg_logit
        seg_pred = np.random.randint(
            0, num_classes, (batch_size, h, w), dtype=np.uint8)
        seg_pred = torch.LongTensor(seg_pred)
        results_dict['pred_sem_seg'] = seg_pred

        batch_datasamples = [
            SegDataSample()
            for _ in range(results_dict['pred_sem_seg'].shape[0])
        ]
        for key, value in results_dict.items():
            for i in range(value.shape[0]):
                setattr(batch_datasamples[i], key, PixelData(data=value[i]))

        _predictions = []
        for pred in batch_datasamples:
            if isinstance(pred, BaseDataElement):
                test_data = pred.to_dict()
                test_data['img_path'] = \
                    'tests/data/pseudo_cityscapes_dataset/leftImg8bit/val/\
                    frankfurt/frankfurt_000000_000294_leftImg8bit.png'

                _predictions.append(test_data)
            else:
                _predictions.append(pred)
        return _predictions

    def test_evaluate(self):
        """Test using the metric in the same way as Evaluator."""

        data_batch = self._demo_mm_inputs()
        predictions = self._demo_mm_model_output()
        iou_metric = CitysMetric(citys_metrics=['cityscapes'])
        iou_metric.process(data_batch, predictions)
        res = iou_metric.evaluate(6)
        self.assertIsInstance(res, dict)
        # test to_label_id = True
        iou_metric = CitysMetric(
            citys_metrics=['cityscapes'], to_label_id=True)
        iou_metric.process(data_batch, predictions)
        res = iou_metric.evaluate(6)
        self.assertIsInstance(res, dict)
        import shutil
        shutil.rmtree('.format_cityscapes')
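
The test above encodes the contract that CitysMetric.process expects: each prediction dict carries a 'pred_sem_seg' entry plus the source 'img_path', and each data_batch item carries the ground-truth 'seg_map_path' inside 'data_sample'. A condensed sketch of that contract, assuming cityscapesscripts is installed and the pseudo-dataset fixtures are available from the repository root; the class count and image size are illustrative:

import torch

from mmseg.metrics import CitysMetric

# One prediction dict and one data_batch dict per image, mirroring the test.
seg_pred = torch.randint(0, 19, (1, 128, 256))  # trainIds for a single image
prediction = {
    'pred_sem_seg': {'data': seg_pred},
    'img_path': ('tests/data/pseudo_cityscapes_dataset/leftImg8bit/val/'
                 'frankfurt/frankfurt_000000_000294_leftImg8bit.png'),
}
data = {
    'data_sample': {
        'seg_map_path': ('tests/data/pseudo_cityscapes_dataset/gtFine/val/'
                         'frankfurt/frankfurt_000000_000294_gtFine_'
                         'labelTrainIds.png'),
    },
}
metric = CitysMetric(citys_metrics=['cityscapes'])
metric.process(data_batch=[data], predictions=[prediction])
# metric.evaluate(size) then runs the official cityscapesscripts evaluation
# over the PNGs written to '.format_cityscapes'.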
