# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, List, Optional, Sequence

import mmcv
import numpy as np
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger, print_log
from PIL import Image

from mmseg.registry import METRICS


@METRICS.register_module()
class CitysMetric(BaseMetric):
    """Cityscapes evaluation metric.

    Args:
        ignore_index (int): Index that will be ignored in evaluation.
            Default: 255.
        citys_metrics (list[str] | str): Metrics to be evaluated.
            Default: ['cityscapes'].
        to_label_id (bool): Whether to convert the output to label_id for
            submission. Default: True.
        suffix (str): The directory where the formatted png files are
            saved. The prediction for an image ``xxx`` is written to
            ``suffix/xxx.png``. Default: '.format_cityscapes'.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
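
    Example:
        An illustrative evaluator config (the values shown are just the
        defaults of this class, not requirements)::

            val_evaluator = dict(
                type='CitysMetric',
                citys_metrics=['cityscapes'],
                to_label_id=True,
                suffix='.format_cityscapes')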
    """

    def __init__(self,
                 ignore_index: int = 255,
                 citys_metrics: List[str] = ['cityscapes'],
                 to_label_id: bool = True,
                 suffix: str = '.format_cityscapes',
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

        self.ignore_index = ignore_index
        self.metrics = citys_metrics
        assert self.metrics[0] == 'cityscapes'
        self.to_label_id = to_label_id
        self.suffix = suffix

    def process(self, data_batch: Sequence[dict],
                predictions: Sequence[dict]) -> None:
        """Process one batch of data and predictions.

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        Args:
            data_batch (Sequence[dict]): A batch of data from the dataloader.
            predictions (Sequence[dict]): A batch of outputs from the model.
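
        Note:
            Each element of ``predictions`` is expected to provide
            ``pred['pred_sem_seg']['data']`` (the predicted segmentation
            map) and ``pred['img_path']`` (used to name the saved png),
            mirroring the accesses in the loop below.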
        """
        mmcv.mkdir_or_exist(self.suffix)

        # Lazy import so that cityscapesscripts stays an optional dependency.
        import cityscapesscripts.helpers.labels as CSLabels

        for pred in predictions:
            pred_label = pred['pred_sem_seg']['data'][0].cpu().numpy()
            # Convert trainIds back to label ids and save the prediction as
            # a paletted png under ``self.suffix`` for later evaluation.
            if self.to_label_id:
                pred_label = self._convert_to_label_id(pred_label)
            basename = osp.splitext(osp.basename(pred['img_path']))[0]
            png_filename = osp.join(self.suffix, f'{basename}.png')
            output = Image.fromarray(pred_label.astype(np.uint8)).convert('P')
            palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
            for label_id, label in CSLabels.id2label.items():
                palette[label_id] = label.color
            output.putpalette(palette)
            output.save(png_filename)

        # Record the ground truth directory (the ``val`` split that the
        # annotation path points into) for ``compute_metrics``.
        ann_dir = osp.join(
            data_batch[0]['data_sample']['seg_map_path'].split('val')[0],
            'val')
        self.results.append(ann_dir)

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch, as
                collected by ``process``.

        Returns:
            dict[str, float]: Cityscapes evaluation results, with the keys
                'averageScoreCategories' and 'averageScoreInstCategories'.
        """
        logger: MMLogger = MMLogger.get_current_instance()
        try:
            import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval  # noqa
        except ImportError:
            raise ImportError('Please run "pip install cityscapesscripts" to '
                              'install cityscapesscripts first.')
        msg = 'Evaluating in Cityscapes style'

        if logger is None:
            msg = '\n' + msg
        print_log(msg, logger=logger)

        result_dir = self.suffix

        eval_results = dict()
        print_log(f'Evaluating results under {result_dir} ...', logger=logger)

        # Configure the official evaluation script to read the formatted
        # predictions saved by ``process``.
        CSEval.args.evalInstLevelScore = True
        CSEval.args.predictionPath = osp.abspath(result_dir)
        CSEval.args.evalPixelAccuracy = True
        CSEval.args.JSONOutput = False

        seg_map_list = []
        pred_list = []
        ann_dir = results[0]
        # when evaluating with official cityscapesscripts,
        # **_gtFine_labelIds.png is used
        for seg_map in mmcv.scandir(
                ann_dir, 'gtFine_labelIds.png', recursive=True):
            seg_map_list.append(osp.join(ann_dir, seg_map))
            pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))
        metric = dict()
        eval_results.update(
            CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
        metric['averageScoreCategories'] = eval_results[
            'averageScoreCategories']
        metric['averageScoreInstCategories'] = eval_results[
            'averageScoreInstCategories']
        return metric

    @staticmethod
    def _convert_to_label_id(result):
        """Convert trainId to id for cityscapes."""
        if isinstance(result, str):
            result = np.load(result)
        import cityscapesscripts.helpers.labels as CSLabels
        result_copy = result.copy()
        for trainId, label in CSLabels.trainId2label.items():
            result_copy[result == trainId] = label.id

        return result_copy
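
# Minimal usage sketch (illustrative only; ``batches`` and ``model_outputs``
# below are placeholders, not part of this module). In practice the metric is
# driven by the mmengine evaluation loop rather than called directly:
#
#     from mmseg.registry import METRICS
#
#     metric = METRICS.build(
#         dict(type='CitysMetric', suffix='.format_cityscapes'))
#     for data_batch, predictions in zip(batches, model_outputs):
#         metric.process(data_batch, predictions)
#     print(metric.compute_metrics(metric.results))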