Skip to content

Commit ff95416

Browse files
authored
[Features]Support dump segment predition (#2712)
## Motivation 1. It is used to save the segmentation predictions as files and upload these files to a test server ## Modification 1. Add output_file and format only in `IoUMetric` ## BC-breaking (Optional) No ## Use cases (Optional) If this PR introduces a new feature, it is better to list some use cases here, and update the documentation. ## Checklist 1. Pre-commit or other linting tools are used to fix the potential lint issues. 3. The modification is covered by complete unit tests. If not, please add more unit test to ensure the correctness. 4. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMDet3D. 5. The documentation has been modified accordingly, like docstring or example tutorials.
1 parent f6de1aa commit ff95416

File tree

10 files changed

+347
-33
lines changed

10 files changed

+347
-33
lines changed

docs/en/migration/interface.md

+29
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,35 @@ Compared with MMSeg0.x, MMSeg1.x provides fewer command line arguments in `tools
6565
<td>--cfg-options randomness.deterministic=True</td>
6666
</table>
6767

68+
## Test launch
69+
70+
Similar to training launch, there are only common arguments in tools/test.py of MMSegmentation 1.x.
71+
Below is the difference in test scripts,
72+
please refer to [this documentation](../user_guides/4_train_test.md) for more details about test launch.
73+
74+
<table class="docutils">
75+
<tr>
76+
<td>Function</td>
77+
<td>0.x</td>
78+
<td>1.x</td>
79+
</tr>
80+
<tr>
81+
<td>Evaluation metrics</td>
82+
<td>--eval mIoU</td>
83+
<td>--cfg-options test_evaluator.type=IoUMetric</td>
84+
</tr>
85+
<tr>
86+
<td>Whether to use test time augmentation</td>
87+
<td>--aug-test</td>
88+
<td>--tta</td>
89+
</tr>
90+
<tr>
91+
<td>Whether save the output results without perform evaluation</td>
92+
<td>--format-only</td>
93+
<td>--cfg-options test_evaluator.format_only=True</td>
94+
</tr>
95+
</table>
96+
6897
## Configuration file
6998

7099
### Model settings

docs/en/user_guides/4_train_test.md

+96-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ This tool accepts several optional arguments, including:
7070
export CUDA_VISIBLE_DEVICES=-1
7171
```
7272

73-
And then run the script [above](#testing-on-a-single-gpu).
73+
then run the script [above](#testing-on-a-single-gpu).
7474

7575
## Training and testing on multiple GPUs and multiple machines
7676

@@ -218,3 +218,98 @@ You can check [the source code](../../../tools/slurm_test.sh) to review full arg
218218
CUDA_VISIBLE_DEVICES=0,1,2,3 GPUS=4 MASTER_PORT=29500 sh tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config1.py ${WORK_DIR}
219219
CUDA_VISIBLE_DEVICES=4,5,6,7 GPUS=4 MASTER_PORT=29501 sh tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config2.py ${WORK_DIR}
220220
```
221+
222+
## Testing and saving segment files
223+
224+
### Basic Usage
225+
226+
When you want to save the results, you can use `--out` to specify the output directory.
227+
228+
```shell
229+
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} --out ${OUTPUT_DIR}
230+
```
231+
232+
Here is an example to save the predicted results from model `fcn_r50-d8_4xb4-80k_ade20k-512x512` on ADE20k validatation dataset.
233+
234+
```shell
235+
python tools/test.py configs/fcn/fcn_r50-d8_4xb4-80k_ade20k-512x512.py ckpt/fcn_r50-d8_512x512_80k_ade20k_20200614_144016-f8ac5082.pth --out work_dirs/format_results
236+
```
237+
238+
You also can modify the config file to define `output_dir`. We also take
239+
`fcn_r50-d8_4xb4-80k_ade20k-512x512` as example just add
240+
`test_evaluator` in `configs/fcn/fcn_r50-d8_4xb4-80k_ade20k-512x512.py`
241+
242+
```python
243+
test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'], output_dir='work_dirs/format_results')
244+
```
245+
246+
then run command without `--out`:
247+
248+
```shell
249+
python tools/test.py configs/fcn/fcn_r50-d8_4xb4-80k_ade20k-512x512.py ckpt/fcn_r50-d8_512x512_80k_ade20k_20200614_144016-f8ac5082.pth
250+
```
251+
252+
If you would like to only save the predicted results without evaluation as annotation is not released by the official dataset, you can set `format_only=True` and modify `test_dataloader`.
253+
As there is no annotation in dataset, we remove `dict(type='LoadAnnotations')` from `test_dataloader` Here is the example configuration:
254+
255+
```python
256+
test_evaluator = dict(
257+
type='IoUMetric',
258+
iou_metrics=['mIoU'],
259+
format_only=True,
260+
output_dir='work_dirs/format_results')
261+
test_dataloader = dict(
262+
batch_size=1,
263+
num_workers=4,
264+
persistent_workers=True,
265+
sampler=dict(type='DefaultSampler', shuffle=False),
266+
dataset=dict(
267+
type = 'ADE20KDataset'
268+
data_root='data/ade/release_test',
269+
data_prefix=dict(img_path='testing'),
270+
# we don't load annotation in test transform pipeline.
271+
pipeline=[
272+
dict(type='LoadImageFromFile'),
273+
dict(type='Resize', scale=(2048, 512), keep_ratio=True),
274+
dict(type='PackSegInputs')
275+
]))
276+
```
277+
278+
then run test command:
279+
280+
```shell
281+
python tools/test.py configs/fcn/fcn_r50-d8_4xb4-80k_ade20k-512x512.py ckpt/fcn_r50-d8_512x512_80k_ade20k_20200614_144016-f8ac5082.pth
282+
```
283+
284+
### Testing Cityscape dataset and save predicted segment files
285+
286+
We recommend `CityscapesMetric` which is the wrapper of Cityscapes'sdk, when you want to
287+
save the predicted results of Cityscape test dataset to submit them in [Cityscape test server](https://www.cityscapes-dataset.com/submit/). Here is the example configuration:
288+
289+
```python
290+
test_evaluator = dict(
291+
type='CityscapesMetric',
292+
format_only=True,
293+
keep_results=True,
294+
output_dir='work_dirs/format_results')
295+
test_dataloader = dict(
296+
batch_size=1,
297+
num_workers=4,
298+
persistent_workers=True,
299+
sampler=dict(type='DefaultSampler', shuffle=False),
300+
dataset=dict(
301+
type='CityscapesDataset',
302+
data_root='data/cityscapes/',
303+
data_prefix=dict(img_path='leftImg8bit/test'),
304+
pipeline=[
305+
dict(type='LoadImageFromFile'),
306+
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
307+
dict(type='PackSegInputs')
308+
]))
309+
```
310+
311+
then run test command, for example:
312+
313+
```shell
314+
python tools/test.py configs/fcn/fcn_r18-d8_4xb2-80k_cityscapes-512x1024.py ckpt/fcn_r18-d8_512x1024_80k_cityscapes_20201225_021327-6c50f8b4.pth
315+
```

docs/zh_cn/migration/interface.md

+29-2
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,33 @@ OpenMMLab 2.0 的主要改进是发布了 MMEngine,它为启动训练任务的
6565
<td>--cfg-options randomness.deterministic=True</td>
6666
</table>
6767

68+
## 测试启动
69+
70+
与训练启动类似,MMSegmentation 1.x 的测试启动脚本在 tools/test.py 中仅提供关键命令行参数,以下是测试启动脚本的区别,更多关于测试启动的细节请参考[这里](../user_guides/4_train_test.md)
71+
72+
<table class="docutils">
73+
<tr>
74+
<td>功能</td>
75+
<td>0.x</td>
76+
<td>1.x</td>
77+
</tr>
78+
<tr>
79+
<td>指定评测指标</td>
80+
<td>--eval mIoU</td>
81+
<td>--cfg-options test_evaluator.type=IoUMetric</td>
82+
</tr>
83+
<tr>
84+
<td>测试时数据增强</td>
85+
<td>--aug-test</td>
86+
<td>--tta</td>
87+
</tr>
88+
<tr>
89+
<td>测试时是否只保存预测结果不计算评测指标</td>
90+
<td>--format-only</td>
91+
<td>--cfg-options test_evaluator.format_only=True</td>
92+
</tr>
93+
</table>
94+
6895
## 配置文件
6996

7097
### 模型设置
@@ -98,7 +125,7 @@ OpenMMLab 2.0 的主要改进是发布了 MMEngine,它为启动训练任务的
98125

99126
**data** 的更改:
100127

101-
原版 `data` 字段被拆分为 `train_dataloader``val_dataloader``test_dataloader`。这允许我们以细粒度配置它们。例如,您可以在训练和测试期间指定不同的采样器和批次大小。
128+
原版 `data` 字段被拆分为 `train_dataloader``val_dataloader``test_dataloader`,允许我们以细粒度配置它们。例如,您可以在训练和测试期间指定不同的采样器和批次大小。
102129
`samples_per_gpu` 重命名为 `batch_size`
103130
`workers_per_gpu` 重命名为 `num_workers`
104131

@@ -144,7 +171,7 @@ test_dataloader = val_dataloader
144171
</tr>
145172
</table>
146173

147-
**流程**变更
174+
**数据增强变换流程**变更
148175

149176
- 原始格式转换 **`ToTensor`****`ImageToTensor`****`Collect`** 组合为 [`PackSegInputs`](mmseg.datasets.transforms.PackSegInputs)
150177
- 我们不建议在数据集流程中执行 **`Normalize`****Pad**。请将其从流程中删除,并将其设置在 `data_preprocessor` 字段中。

docs/zh_cn/user_guides/4_train_test.md

+92
Original file line numberDiff line numberDiff line change
@@ -223,3 +223,95 @@ GPUS=4 sh tools/slurm_train.sh dev pspnet configs/pspnet/pspnet_r50-d8_512x1024_
223223
CUDA_VISIBLE_DEVICES=0,1,2,3 GPUS=4 MASTER_PORT=29500 sh tools/slurm_train.sh ${分区} ${任务名} config1.py ${工作路径}
224224
CUDA_VISIBLE_DEVICES=4,5,6,7 GPUS=4 MASTER_PORT=29501 sh tools/slurm_train.sh ${分区} ${任务名} config2.py ${工作路径}
225225
```
226+
227+
## 测试并保存分割结果
228+
229+
### 基础使用
230+
231+
当需要保存测试输出的分割结果,用 `--out` 指定分割结果输出路径
232+
233+
```shell
234+
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} --out ${OUTPUT_DIR}
235+
```
236+
237+
以保存模型 `fcn_r50-d8_4xb4-80k_ade20k-512x512` 在 ADE20K 验证数据集上的结果为例:
238+
239+
```shell
240+
python tools/test.py configs/fcn/fcn_r50-d8_4xb4-80k_ade20k-512x512.py ckpt/fcn_r50-d8_512x512_80k_ade20k_20200614_144016-f8ac5082.pth --out work_dirs/format_results
241+
```
242+
243+
或者通过配置文件定义 `output_dir`。例如在 `configs/fcn/fcn_r50-d8_4xb4-80k_ade20k-512x512.py` 添加 `test_evaluator` 定义:
244+
245+
```python
246+
test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'], output_dir='work_dirs/format_results')
247+
```
248+
249+
然后执行相同功能的命令不需要再使用 `--out`
250+
251+
```shell
252+
python tools/test.py configs/fcn/fcn_r50-d8_4xb4-80k_ade20k-512x512.py ckpt/fcn_r50-d8_512x512_80k_ade20k_20200614_144016-f8ac5082.pth
253+
```
254+
255+
当测试的数据集没有提供标注,评测时没有真值可以参与计算,因此需要设置 `format_only=True`
256+
同时需要修改 `test_dataloader`,由于没有标注,我们需要在数据增强变换中删掉 `dict(type='LoadAnnotations')`,以下是一个配置示例:
257+
258+
```python
259+
test_evaluator = dict(
260+
type='IoUMetric',
261+
iou_metrics=['mIoU'],
262+
format_only=True,
263+
output_dir='work_dirs/format_results')
264+
test_dataloader = dict(
265+
batch_size=1,
266+
num_workers=4,
267+
persistent_workers=True,
268+
sampler=dict(type='DefaultSampler', shuffle=False),
269+
dataset=dict(
270+
type = 'ADE20KDataset'
271+
data_root='data/ade/release_test',
272+
data_prefix=dict(img_path='testing'),
273+
# 测试数据变换中没有加载标注
274+
pipeline=[
275+
dict(type='LoadImageFromFile'),
276+
dict(type='Resize', scale=(2048, 512), keep_ratio=True),
277+
dict(type='PackSegInputs')
278+
]))
279+
```
280+
281+
然后执行测试命令:
282+
283+
```shell
284+
python tools/test.py configs/fcn/fcn_r50-d8_4xb4-80k_ade20k-512x512.py ckpt/fcn_r50-d8_512x512_80k_ade20k_20200614_144016-f8ac5082.pth
285+
```
286+
287+
### 测试 Cityscapes 数据集并保存输出分割结果
288+
289+
推荐使用 `CityscapesMetric` 来保存模型在 Cityscapes 数据集上的测试结果,以下是一个配置示例:
290+
291+
```python
292+
test_evaluator = dict(
293+
type='CityscapesMetric',
294+
format_only=True,
295+
keep_results=True,
296+
output_dir='work_dirs/format_results')
297+
test_dataloader = dict(
298+
batch_size=1,
299+
num_workers=4,
300+
persistent_workers=True,
301+
sampler=dict(type='DefaultSampler', shuffle=False),
302+
dataset=dict(
303+
type='CityscapesDataset',
304+
data_root='data/cityscapes/',
305+
data_prefix=dict(img_path='leftImg8bit/test'),
306+
pipeline=[
307+
dict(type='LoadImageFromFile'),
308+
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
309+
dict(type='PackSegInputs')
310+
]))
311+
```
312+
313+
然后执行相同的命令,例如:
314+
315+
```shell
316+
python tools/test.py configs/fcn/fcn_r18-d8_4xb2-80k_cityscapes-512x1024.py ckpt/fcn_r18-d8_512x1024_80k_cityscapes_20201225_021327-6c50f8b4.pth
317+
```

mmseg/datasets/transforms/formatting.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class PackSegInputs(BaseTransform):
4444
def __init__(self,
4545
meta_keys=('img_path', 'seg_map_path', 'ori_shape',
4646
'img_shape', 'pad_shape', 'scale_factor', 'flip',
47-
'flip_direction')):
47+
'flip_direction', 'reduce_zero_label')):
4848
self.meta_keys = meta_keys
4949

5050
def transform(self, results: dict) -> dict:

mmseg/evaluation/metrics/citys_metric.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ def __init__(self,
5151
format_only: bool = False,
5252
keep_results: bool = False,
5353
collect_device: str = 'cpu',
54-
prefix: Optional[str] = None) -> None:
54+
prefix: Optional[str] = None,
55+
**kwargs) -> None:
5556
super().__init__(collect_device=collect_device, prefix=prefix)
5657
if CSEval is None:
5758
raise ImportError('Please run "pip install cityscapesscripts" to '
@@ -97,10 +98,14 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
9798
osp.join(self.output_dir, f'{basename}.png'))
9899
output = Image.fromarray(pred_label.astype(np.uint8)).convert('P')
99100
output.save(png_filename)
100-
# when evaluating with official cityscapesscripts,
101-
# **_gtFine_labelIds.png is used
102-
gt_filename = data_sample['seg_map_path'].replace(
103-
'labelTrainIds.png', 'labelIds.png')
101+
if self.format_only:
102+
# format_only always for test dataset without ground truth
103+
gt_filename = ''
104+
else:
105+
# when evaluating with official cityscapesscripts,
106+
# **_gtFine_labelIds.png is used
107+
gt_filename = data_sample['seg_map_path'].replace(
108+
'labelTrainIds.png', 'labelIds.png')
104109
self.results.append((png_filename, gt_filename))
105110

106111
def compute_metrics(self, results: list) -> Dict[str, float]:

0 commit comments

Comments
 (0)