Skip to content

Commit d561f07

Browse files
authored
Made label data optional for inference and adopted other required changes (#2183)
* Made label data optional for inference and adopted other required changes
* Address comments
* Fix format issues
* Update version numbers for newly generated checkpoints
1 parent 193ea36 commit d561f07

File tree

4 files changed

+48
-48
lines changed

4 files changed

+48
-48
lines changed
Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import keras
2-
31
from keras_hub.src.api_export import keras_hub_export
42
from keras_hub.src.models.image_segmenter_preprocessor import (
53
ImageSegmenterPreprocessor,
@@ -8,25 +6,9 @@
86
from keras_hub.src.models.segformer.segformer_image_converter import (
97
SegFormerImageConverter,
108
)
11-
from keras_hub.src.utils.tensor_utils import preprocessing_function
12-
13-
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
14-
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
159

1610

1711
@keras_hub_export("keras_hub.models.SegFormerImageSegmenterPreprocessor")
1812
class SegFormerImageSegmenterPreprocessor(ImageSegmenterPreprocessor):
1913
backbone_cls = SegFormerBackbone
2014
image_converter_cls = SegFormerImageConverter
21-
22-
@preprocessing_function
23-
def call(self, x, y=None, sample_weight=None):
24-
if self.image_converter:
25-
x = self.image_converter(x)
26-
if y is not None:
27-
y = self.image_converter(y)
28-
29-
x = x / 255
30-
x = (x - IMAGENET_DEFAULT_MEAN) / IMAGENET_DEFAULT_STD
31-
32-
return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)

keras_hub/src/models/segformer/segformer_image_segmenter_tests.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,45 @@
11
import numpy as np
22
import pytest
3-
from keras import ops
43

54
from keras_hub.src.models.mit.mit_backbone import MiTBackbone
65
from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
76
from keras_hub.src.models.segformer.segformer_image_segmenter import (
87
SegFormerImageSegmenter,
98
)
9+
from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import ( # noqa: E501
10+
SegFormerImageSegmenterPreprocessor,
11+
)
1012
from keras_hub.src.tests.test_case import TestCase
1113

1214

1315
class SegFormerTest(TestCase):
1416
def setUp(self):
1517
image_encoder = MiTBackbone(
16-
depths=[2, 2],
17-
image_shape=(224, 224, 3),
18+
layerwise_depths=[2, 2],
19+
image_shape=(32, 32, 3),
1820
hidden_dims=[32, 64],
1921
num_layers=2,
20-
blockwise_num_heads=[1, 2],
21-
blockwise_sr_ratios=[8, 4],
22+
layerwise_num_heads=[1, 2],
23+
layerwise_sr_ratios=[8, 4],
2224
max_drop_path_rate=0.1,
23-
patch_sizes=[7, 3],
24-
strides=[4, 2],
25+
layerwise_patch_sizes=[7, 3],
26+
layerwise_strides=[4, 2],
2527
)
2628
projection_filters = 256
29+
self.preprocessor = SegFormerImageSegmenterPreprocessor()
2730
self.backbone = SegFormerBackbone(
2831
image_encoder=image_encoder, projection_filters=projection_filters
2932
)
3033

31-
self.input_size = 224
32-
self.input_data = ops.ones((2, self.input_size, self.input_size, 3))
34+
self.input_size = 32
35+
self.input_data = np.ones((2, self.input_size, self.input_size, 3))
36+
self.label_data = np.ones((2, self.input_size, self.input_size, 4))
3337

34-
self.init_kwargs = {"backbone": self.backbone, "num_classes": 4}
38+
self.init_kwargs = {
39+
"backbone": self.backbone,
40+
"num_classes": 4,
41+
"preprocessor": self.preprocessor,
42+
}
3543

3644
def test_segformer_segmenter_construction(self):
3745
SegFormerImageSegmenter(backbone=self.backbone, num_classes=4)
@@ -42,19 +50,19 @@ def test_segformer_call(self):
4250
backbone=self.backbone, num_classes=4
4351
)
4452

45-
images = np.random.uniform(size=(2, 224, 224, 4))
53+
images = np.random.uniform(size=(2, 32, 32, 3))
4654
segformer_output = segformer(images)
4755
segformer_predict = segformer.predict(images)
4856

49-
assert segformer_output.shape == images.shape
50-
assert segformer_predict.shape == images.shape
57+
self.assertAllEqual(segformer_output.shape, (2, 32, 32, 4))
58+
self.assertAllEqual(segformer_predict.shape, (2, 32, 32, 4))
5159

5260
def test_task(self):
5361
self.run_task_test(
5462
cls=SegFormerImageSegmenter,
5563
init_kwargs={**self.init_kwargs},
56-
train_data=self.input_data,
57-
expected_output_shape=(2, 224, 224),
64+
train_data=(self.input_data, self.label_data),
65+
expected_output_shape=(2, 32, 32, 4),
5866
)
5967

6068
@pytest.mark.large

keras_hub/src/models/segformer/segformer_presets.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"params": 3719027,
1111
"path": "segformer_b0",
1212
},
13-
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b0_ade20k_512/2",
13+
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b0_ade20k_512/3",
1414
},
1515
"segformer_b1_ade20k_512": {
1616
"metadata": {
@@ -21,7 +21,7 @@
2121
"params": 13682643,
2222
"path": "segformer_b1",
2323
},
24-
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b1_ade20k_512/2",
24+
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b1_ade20k_512/5",
2525
},
2626
"segformer_b2_ade20k_512": {
2727
"metadata": {
@@ -32,7 +32,7 @@
3232
"params": 24727507,
3333
"path": "segformer_b2",
3434
},
35-
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b2_ade20k_512/2",
35+
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b2_ade20k_512/3",
3636
},
3737
"segformer_b3_ade20k_512": {
3838
"metadata": {
@@ -43,7 +43,7 @@
4343
"params": 44603347,
4444
"path": "segformer_b3",
4545
},
46-
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b3_ade20k_512/2",
46+
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b3_ade20k_512/3",
4747
},
4848
"segformer_b4_ade20k_512": {
4949
"metadata": {
@@ -54,7 +54,7 @@
5454
"params": 61373907,
5555
"path": "segformer_b4",
5656
},
57-
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b4_ade20k_512/2",
57+
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b4_ade20k_512/3",
5858
},
5959
"segformer_b5_ade20k_640": {
6060
"metadata": {
@@ -65,7 +65,7 @@
6565
"params": 81974227,
6666
"path": "segformer_b5",
6767
},
68-
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b5_ade20k_640/2",
68+
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b5_ade20k_640/3",
6969
},
7070
"segformer_b0_cityscapes_1024": {
7171
"metadata": {
@@ -76,7 +76,7 @@
7676
"params": 3719027,
7777
"path": "segformer_b0",
7878
},
79-
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b0_cityscapes_1024/2",
79+
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b0_cityscapes_1024/3",
8080
},
8181
"segformer_b1_cityscapes_1024": {
8282
"metadata": {
@@ -87,7 +87,7 @@
8787
"params": 13682643,
8888
"path": "segformer_b1",
8989
},
90-
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b1_ade20k_512/2",
90+
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b1_ade20k_512/1",
9191
},
9292
"segformer_b2_cityscapes_1024": {
9393
"metadata": {
@@ -98,7 +98,7 @@
9898
"params": 24727507,
9999
"path": "segformer_b2",
100100
},
101-
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b2_cityscapes_1024/2",
101+
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b2_cityscapes_1024/3",
102102
},
103103
"segformer_b3_cityscapes_1024": {
104104
"metadata": {
@@ -109,7 +109,7 @@
109109
"params": 44603347,
110110
"path": "segformer_b3",
111111
},
112-
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b3_cityscapes_1024/2",
112+
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b3_cityscapes_1024/3",
113113
},
114114
"segformer_b4_cityscapes_1024": {
115115
"metadata": {
@@ -120,7 +120,7 @@
120120
"params": 61373907,
121121
"path": "segformer_b4",
122122
},
123-
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b4_cityscapes_1024/2",
123+
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b4_cityscapes_1024/3",
124124
},
125125
"segformer_b5_cityscapes_1024": {
126126
"metadata": {
@@ -131,6 +131,6 @@
131131
"params": 81974227,
132132
"path": "segformer_b5",
133133
},
134-
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b5_cityscapes_1024/2",
134+
"kaggle_handle": "kaggle://keras/segformer/keras/segformer_b5_cityscapes_1024/3",
135135
},
136136
}

tools/checkpoint_conversion/convert_segformer_checkpoints.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Usage example
2-
# python tools/checkpoint_conversion/convert_mix_transformer.py \
3-
# --preset "B0_ade_512"
2+
# python tools/checkpoint_conversion/convert_segformer_checkpoints.py \
3+
# --preset "b0_ade20k_512"
44

55
import numpy as np
66
from absl import app
@@ -94,7 +94,17 @@ def main(_):
9494
)
9595
num_classes = 150 if "ade20k" in FLAGS.preset else 19
9696

97-
preprocessor = SegFormerImageSegmenterPreprocessor()
97+
image_converter = keras_hub.layers.SegFormerImageConverter(
98+
height=resolution,
99+
width=resolution,
100+
scale=[
101+
1.0 / (0.229 * 255.0),
102+
1.0 / (0.224 * 255.0),
103+
1.0 / (0.225 * 255.0),
104+
],
105+
offset=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
106+
)
107+
preprocessor = SegFormerImageSegmenterPreprocessor(image_converter)
98108
segformer_segmenter = keras_hub.models.SegFormerImageSegmenter(
99109
backbone=segformer_backbone,
100110
num_classes=num_classes,

0 commit comments

Comments (0)