Skip to content

Commit eef12a0

Browse files
author
zhengmiao
committed
Merge branch 'limengzhang/fix_kwargs' into 'refactor_dev'
[Fix] Delete all **kwargs in Segmentor Forward function See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!52
2 parents 9ed2e1c + c5ad7fb commit eef12a0

File tree

5 files changed

+70
-79
lines changed

5 files changed

+70
-79
lines changed

mmseg/models/decode_heads/cascade_decode_head.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ def forward(self, inputs, prev_output):
2121
pass
2222

2323
def loss(self, inputs: List[Tensor], prev_output: Tensor,
24-
batch_data_samples: List[dict], train_cfg: ConfigType,
25-
**kwargs) -> Tensor:
24+
batch_data_samples: List[dict], train_cfg: ConfigType) -> Tensor:
2625
"""Forward function for training.
2726
2827
Args:
@@ -37,12 +36,12 @@ def loss(self, inputs: List[Tensor], prev_output: Tensor,
3736
dict[str, Tensor]: a dictionary of loss components
3837
"""
3938
seg_logits = self.forward(inputs, prev_output)
40-
losses = self.loss_by_feat(seg_logits, batch_data_samples, **kwargs)
39+
losses = self.loss_by_feat(seg_logits, batch_data_samples)
4140

4241
return losses
4342

4443
def predict(self, inputs: List[Tensor], prev_output: Tensor,
45-
batch_img_metas: List[dict], tese_cfg: ConfigType, **kwargs):
44+
batch_img_metas: List[dict], tese_cfg: ConfigType):
4645
"""Forward function for testing.
4746
4847
Args:
@@ -60,4 +59,4 @@ def predict(self, inputs: List[Tensor], prev_output: Tensor,
6059
"""
6160
seg_logits = self.forward(inputs, prev_output)
6261

63-
return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)
62+
return self.predict_by_feat(seg_logits, batch_img_metas)

mmseg/models/decode_heads/decode_head.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def _transform_inputs(self, inputs):
205205
return inputs
206206

207207
@abstractmethod
208-
def forward(self, inputs, **kwargs):
208+
def forward(self, inputs):
209209
"""Placeholder of forward function."""
210210
pass
211211

@@ -217,7 +217,7 @@ def cls_seg(self, feat):
217217
return output
218218

219219
def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
220-
train_cfg: ConfigType, **kwargs) -> dict:
220+
train_cfg: ConfigType) -> dict:
221221
"""Forward function for training.
222222
223223
Args:
@@ -230,12 +230,12 @@ def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
230230
Returns:
231231
dict[str, Tensor]: a dictionary of loss components
232232
"""
233-
seg_logits = self.forward(inputs, **kwargs)
234-
losses = self.loss_by_feat(seg_logits, batch_data_samples, **kwargs)
233+
seg_logits = self.forward(inputs)
234+
losses = self.loss_by_feat(seg_logits, batch_data_samples)
235235
return losses
236236

237237
def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
238-
test_cfg: ConfigType, **kwargs) -> List[Tensor]:
238+
test_cfg: ConfigType) -> List[Tensor]:
239239
"""Forward function for prediction.
240240
241241
Args:
@@ -250,18 +250,18 @@ def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
250250
Returns:
251251
List[Tensor]: Outputs segmentation logits map.
252252
"""
253-
seg_logits = self.forward(inputs, **kwargs)
253+
seg_logits = self.forward(inputs)
254254

255-
return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)
255+
return self.predict_by_feat(seg_logits, batch_img_metas)
256256

257257
def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
258258
gt_semantic_segs = [
259259
data_sample.gt_sem_seg.data for data_sample in batch_data_samples
260260
]
261261
return torch.stack(gt_semantic_segs, dim=0)
262262

263-
def loss_by_feat(self, seg_logits: Tensor, batch_data_samples: SampleList,
264-
**kwargs) -> dict:
263+
def loss_by_feat(self, seg_logits: Tensor,
264+
batch_data_samples: SampleList) -> dict:
265265
"""Compute segmentation loss.
266266
267267
Args:
@@ -309,8 +309,8 @@ def loss_by_feat(self, seg_logits: Tensor, batch_data_samples: SampleList,
309309
seg_logits, seg_label, ignore_index=self.ignore_index)
310310
return loss
311311

312-
def predict_by_feat(self, seg_logits: Tensor, batch_img_metas: List[dict],
313-
**kwargs) -> List[Tensor]:
312+
def predict_by_feat(self, seg_logits: Tensor,
313+
batch_img_metas: List[dict]) -> List[Tensor]:
314314
"""Transform a batch of output seg_logits to the input shape.
315315
316316
Args:

mmseg/models/segmentors/base.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,15 @@ def extract_feat(self, batch_inputs: Tensor) -> bool:
5353

5454
@abstractmethod
5555
def encode_decode(self, batch_inputs: Tensor,
56-
batch_data_samples: SampleList, **kwargs):
56+
batch_data_samples: SampleList):
5757
"""Placeholder for encode images with backbone and decode into a
5858
semantic segmentation map of the same size as input."""
5959
pass
6060

6161
def forward(self,
6262
batch_inputs: Tensor,
6363
batch_data_samples: OptSampleList = None,
64-
mode: str = 'tensor',
65-
**kwargs) -> ForwardResults:
64+
mode: str = 'tensor') -> ForwardResults:
6665
"""The unified entry for a forward process in both training and test.
6766
6867
The method should accept three modes: "tensor", "predict" and "loss":
@@ -92,33 +91,33 @@ def forward(self,
9291
- If ``mode="loss"``, return a dict of tensor.
9392
"""
9493
if mode == 'loss':
95-
return self.loss(batch_inputs, batch_data_samples, **kwargs)
94+
return self.loss(batch_inputs, batch_data_samples)
9695
elif mode == 'predict':
97-
return self.predict(batch_inputs, batch_data_samples, **kwargs)
96+
return self.predict(batch_inputs, batch_data_samples)
9897
elif mode == 'tensor':
99-
return self._forward(batch_inputs, batch_data_samples, **kwargs)
98+
return self._forward(batch_inputs, batch_data_samples)
10099
else:
101100
raise RuntimeError(f'Invalid mode "{mode}". '
102101
'Only supports loss, predict and tensor mode')
103102

104103
@abstractmethod
105-
def loss(self, batch_inputs: Tensor, batch_data_samples: SampleList,
106-
**kwargs) -> dict:
104+
def loss(self, batch_inputs: Tensor,
105+
batch_data_samples: SampleList) -> dict:
107106
"""Calculate losses from a batch of inputs and data samples."""
108107
pass
109108

110109
@abstractmethod
111-
def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList,
112-
**kwargs) -> SampleList:
110+
def predict(self, batch_inputs: Tensor,
111+
batch_data_samples: SampleList) -> SampleList:
113112
"""Predict results from a batch of inputs and data samples with post-
114113
processing."""
115114
pass
116115

117116
@abstractmethod
118-
def _forward(self,
119-
batch_inputs: Tensor,
120-
batch_data_samples: OptSampleList = None,
121-
**kwargs) -> Tuple[List[Tensor]]:
117+
def _forward(
118+
self,
119+
batch_inputs: Tensor,
120+
batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
122121
"""Network forward process.
123122
124123
Usually includes backbone, neck and head forward without any post-
@@ -127,7 +126,7 @@ def _forward(self,
127126
pass
128127

129128
@abstractmethod
130-
def aug_test(self, batch_inputs, batch_img_metas, **kwargs):
129+
def aug_test(self, batch_inputs, batch_img_metas):
131130
"""Placeholder for augmentation test."""
132131
pass
133132

mmseg/models/segmentors/cascade_encoder_decoder.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -70,29 +70,28 @@ def _init_decode_head(self, decode_head: ConfigType) -> None:
7070
self.align_corners = self.decode_head[-1].align_corners
7171
self.num_classes = self.decode_head[-1].num_classes
7272

73-
def encode_decode(self, batch_inputs: Tensor, batch_img_metas: List[dict],
74-
**kwargs) -> List[Tensor]:
73+
def encode_decode(self, batch_inputs: Tensor,
74+
batch_img_metas: List[dict]) -> List[Tensor]:
7575
"""Encode images with backbone and decode into a semantic segmentation
7676
map of the same size as input."""
7777
x = self.extract_feat(batch_inputs)
78-
out = self.decode_head[0].forward(x, **kwargs)
78+
out = self.decode_head[0].forward(x)
7979
for i in range(1, self.num_stages - 1):
80-
out = self.decode_head[i].forward(x, out, **kwargs)
80+
out = self.decode_head[i].forward(x, out)
8181
seg_logits_list = self.decode_head[-1].predict(x, out, batch_img_metas,
82-
self.test_cfg, **kwargs)
82+
self.test_cfg)
8383

8484
return seg_logits_list
8585

8686
def _decode_head_forward_train(self, batch_inputs: Tensor,
87-
batch_data_samples: SampleList,
88-
**kwargs) -> dict:
87+
batch_data_samples: SampleList) -> dict:
8988
"""Run forward function and calculate loss for decode head in
9089
training."""
9190
losses = dict()
9291

9392
loss_decode = self.decode_head[0].loss(batch_inputs,
9493
batch_data_samples,
95-
self.train_cfg, **kwargs)
94+
self.train_cfg)
9695

9796
losses.update(add_prefix(loss_decode, 'decode_0'))
9897
# get batch_img_metas
@@ -105,22 +104,20 @@ def _decode_head_forward_train(self, batch_inputs: Tensor,
105104
for i in range(1, self.num_stages):
106105
# forward test again, maybe unnecessary for most methods.
107106
if i == 1:
108-
prev_outputs = self.decode_head[0].forward(
109-
batch_inputs, **kwargs)
107+
prev_outputs = self.decode_head[0].forward(batch_inputs)
110108
else:
111109
prev_outputs = self.decode_head[i - 1].forward(
112-
batch_inputs, prev_outputs, **kwargs)
110+
batch_inputs, prev_outputs)
113111
loss_decode = self.decode_head[i].loss(batch_inputs, prev_outputs,
114112
batch_data_samples,
115-
self.train_cfg, **kwargs)
113+
self.train_cfg)
116114
losses.update(add_prefix(loss_decode, f'decode_{i}'))
117115

118116
return losses
119117

120118
def _forward(self,
121119
batch_inputs: Tensor,
122-
data_samples: OptSampleList = None,
123-
**kwargs) -> Tensor:
120+
data_samples: OptSampleList = None) -> Tensor:
124121
"""Network forward process.
125122
126123
Args:
@@ -137,6 +134,6 @@ def _forward(self,
137134
out = self.decode_head[0].forward(x)
138135
for i in range(1, self.num_stages):
139136
# TODO support PointRend tensor mode
140-
out = self.decode_head[i].forward(x, out, **kwargs)
137+
out = self.decode_head[i].forward(x, out)
141138

142139
return out

0 commit comments

Comments
 (0)