Skip to content

Commit e6ac8bf

Browse files
authored
363-update gaussian 1d (#1020)
* update gaussian 1d Signed-off-by: Wenqi Li <[email protected]> * update gaussian related test cases Signed-off-by: Wenqi Li <[email protected]> * [MONAI] python code formatting Signed-off-by: monai-bot <[email protected]>
1 parent 197f501 commit e6ac8bf

13 files changed

+154
-89
lines changed

monai/networks/layers/convutils.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import Sequence, Tuple, Union
1313

1414
import numpy as np
15+
import torch
1516

1617
__all__ = ["same_padding", "stride_minus_kernel_padding", "calculate_out_shape", "gaussian_1d"]
1718

@@ -77,7 +78,7 @@ def calculate_out_shape(
7778
return out_shape if len(out_shape) > 1 else out_shape[0]
7879

7980

80-
def gaussian_1d(sigma: float, truncated: float = 4.0) -> np.ndarray:
81+
def gaussian_1d(sigma: Union[float, torch.Tensor], truncated: float = 4.0) -> np.ndarray:
8182
"""
8283
one dimensional gaussian kernel.
8384
@@ -86,18 +87,17 @@ def gaussian_1d(sigma: float, truncated: float = 4.0) -> np.ndarray:
8687
truncated: tail length
8788
8889
Raises:
89-
ValueError: When ``sigma`` is nonpositive.
90+
ValueError: When ``sigma`` is non-positive.
9091
9192
Returns:
92-
1D numpy array
93+
1D torch tensor
9394
9495
"""
95-
if sigma <= 0:
96-
raise ValueError(f"sigma must be positive, got {sigma}.")
97-
96+
sigma = torch.as_tensor(sigma).float()
97+
if sigma <= 0 or truncated <= 0:
98+
raise ValueError(f"sigma and truncated must be positive, got {sigma} and {truncated}.")
9899
tail = int(sigma * truncated + 0.5)
99-
sigma2 = sigma * sigma
100-
x = np.arange(-tail, tail + 1)
101-
out = np.exp(-0.5 / sigma2 * x ** 2)
102-
out /= out.sum()
103-
return out
100+
x = torch.arange(-tail, tail + 1).float()
101+
t = 1 / (torch.tensor(2.0).sqrt() * sigma)
102+
out = 0.5 * ((t * (x + 0.5)).erf() - (t * (x - 0.5)).erf())
103+
return out.clamp(min=0)

tests/test_gaussian.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2020 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import numpy as np
15+
import torch
16+
17+
from monai.networks.layers.convutils import gaussian_1d
18+
19+
20+
class TestGaussian1d(unittest.TestCase):
21+
def test_gaussian(self):
22+
np.testing.assert_allclose(
23+
gaussian_1d(0.5, 8),
24+
torch.tensor(
25+
[
26+
0.0000e00,
27+
2.9802e-07,
28+
1.3496e-03,
29+
1.5731e-01,
30+
6.8269e-01,
31+
1.5731e-01,
32+
1.3496e-03,
33+
2.9802e-07,
34+
0.0000e00,
35+
]
36+
),
37+
rtol=1e-4,
38+
)
39+
40+
np.testing.assert_allclose(
41+
gaussian_1d(1, 1),
42+
torch.tensor([0.24173, 0.382925, 0.24173]),
43+
rtol=1e-4,
44+
)
45+
46+
def test_wrong_sigma(self):
47+
with self.assertRaises(ValueError):
48+
gaussian_1d(-1, 10)
49+
with self.assertRaises(ValueError):
50+
gaussian_1d(1, -10)
51+
52+
53+
if __name__ == "__main__":
54+
unittest.main()

tests/test_gaussian_filter.py

+31-29
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,16 @@ def test_1d(self):
2525
[
2626
[
2727
[
28-
0.56658804,
29-
0.69108766,
30-
0.79392236,
31-
0.86594427,
32-
0.90267116,
33-
0.9026711,
34-
0.8659443,
35-
0.7939224,
36-
0.6910876,
37-
0.56658804,
28+
0.5654129,
29+
0.68915915,
30+
0.79146194,
31+
0.8631974,
32+
0.8998163,
33+
0.8998163,
34+
0.8631973,
35+
0.79146194,
36+
0.6891592,
37+
0.5654129,
3838
]
3939
]
4040
]
@@ -49,9 +49,9 @@ def test_2d(self):
4949
[
5050
[
5151
[
52-
[0.13380532, 0.14087981, 0.13380532],
53-
[0.14087981, 0.14832835, 0.14087981],
54-
[0.13380532, 0.14087981, 0.13380532],
52+
[0.13239081, 0.13932934, 0.13239081],
53+
[0.13932936, 0.14663152, 0.13932936],
54+
[0.13239081, 0.13932934, 0.13239081],
5555
]
5656
]
5757
]
@@ -65,29 +65,30 @@ def test_2d(self):
6565
def test_3d(self):
6666
a = torch.ones(1, 1, 4, 3, 4)
6767
g = GaussianFilter(3, 3, 3).to(torch.device("cpu:0"))
68+
6869
expected = np.array(
6970
[
7071
[
7172
[
7273
[
73-
[0.07294822, 0.08033235, 0.08033235, 0.07294822],
74-
[0.07680509, 0.08457965, 0.08457965, 0.07680509],
75-
[0.07294822, 0.08033235, 0.08033235, 0.07294822],
74+
[0.07189433, 0.07911152, 0.07911152, 0.07189433],
75+
[0.07566228, 0.08325771, 0.08325771, 0.07566228],
76+
[0.07189433, 0.07911152, 0.07911152, 0.07189433],
7677
],
7778
[
78-
[0.08033235, 0.08846395, 0.08846395, 0.08033235],
79-
[0.08457965, 0.09314119, 0.09314119, 0.08457966],
80-
[0.08033235, 0.08846396, 0.08846396, 0.08033236],
79+
[0.07911152, 0.08705322, 0.08705322, 0.07911152],
80+
[0.08325771, 0.09161563, 0.09161563, 0.08325771],
81+
[0.07911152, 0.08705322, 0.08705322, 0.07911152],
8182
],
8283
[
83-
[0.08033235, 0.08846395, 0.08846395, 0.08033235],
84-
[0.08457965, 0.09314119, 0.09314119, 0.08457966],
85-
[0.08033235, 0.08846396, 0.08846396, 0.08033236],
84+
[0.07911152, 0.08705322, 0.08705322, 0.07911152],
85+
[0.08325771, 0.09161563, 0.09161563, 0.08325771],
86+
[0.07911152, 0.08705322, 0.08705322, 0.07911152],
8687
],
8788
[
88-
[0.07294822, 0.08033235, 0.08033235, 0.07294822],
89-
[0.07680509, 0.08457965, 0.08457965, 0.07680509],
90-
[0.07294822, 0.08033235, 0.08033235, 0.07294822],
89+
[0.07189433, 0.07911152, 0.07911152, 0.07189433],
90+
[0.07566228, 0.08325771, 0.08325771, 0.07566228],
91+
[0.07189433, 0.07911152, 0.07911152, 0.07189433],
9192
],
9293
]
9394
]
@@ -98,14 +99,15 @@ def test_3d(self):
9899
def test_3d_sigmas(self):
99100
a = torch.ones(1, 1, 4, 3, 2)
100101
g = GaussianFilter(3, [3, 2, 1], 3).to(torch.device("cpu:0"))
102+
101103
expected = np.array(
102104
[
103105
[
104106
[
105-
[[0.1422854, 0.1422854], [0.15806103, 0.15806103], [0.1422854, 0.1422854]],
106-
[[0.15668818, 0.15668817], [0.17406069, 0.17406069], [0.15668818, 0.15668817]],
107-
[[0.15668818, 0.15668817], [0.17406069, 0.17406069], [0.15668818, 0.15668817]],
108-
[[0.1422854, 0.1422854], [0.15806103, 0.15806103], [0.1422854, 0.1422854]],
107+
[[0.13690521, 0.13690521], [0.15181276, 0.15181276], [0.13690521, 0.13690521]],
108+
[[0.1506486, 0.15064861], [0.16705267, 0.16705267], [0.1506486, 0.15064861]],
109+
[[0.1506486, 0.15064861], [0.16705267, 0.16705267], [0.1506486, 0.15064861]],
110+
[[0.13690521, 0.13690521], [0.15181276, 0.15181276], [0.13690521, 0.13690521]],
109111
]
110112
]
111113
]

tests/test_gaussian_sharpen.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
2222
np.array(
2323
[
24-
[[4.0335875, 3.362756, 4.0335875], [3.588128, 2.628216, 3.588128], [4.491922, 3.8134987, 4.491922]],
25-
[[10.427719, 8.744948, 10.427719], [8.97032, 6.5705404, 8.970321], [10.886056, 9.195692, 10.886056]],
24+
[[4.1081963, 3.4950666, 4.1081963], [3.7239995, 2.8491793, 3.7239995], [4.569839, 3.9529324, 4.569839]],
25+
[[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]],
2626
]
2727
),
2828
]
@@ -32,8 +32,8 @@
3232
np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
3333
np.array(
3434
[
35-
[[4.146659, 4.392873, 4.146659], [8.031006, 8.804623, 8.031005], [10.127394, 11.669131, 10.127394]],
36-
[[14.852196, 16.439377, 14.852201], [20.077503, 22.011555, 20.077507], [20.832941, 23.715641, 20.832935]],
35+
[[4.513644, 4.869134, 4.513644], [8.467242, 9.4004135, 8.467242], [10.416813, 12.0653515, 10.416813]],
36+
[[15.711488, 17.569994, 15.711488], [21.16811, 23.501041, 21.16811], [21.614658, 24.766209, 21.614658]],
3737
]
3838
),
3939
]
@@ -43,8 +43,8 @@
4343
np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
4444
np.array(
4545
[
46-
[[3.129089, 3.0711129, 3.129089], [6.783306, 6.8526435, 6.7833037], [11.901203, 13.098082, 11.901203]],
47-
[[14.401806, 15.198004, 14.401809], [16.958261, 17.131605, 16.958261], [23.17392, 25.224974, 23.17392]],
46+
[[3.3324685, 3.335536, 3.3324673], [7.7666636, 8.16056, 7.7666636], [12.662973, 14.317837, 12.6629715]],
47+
[[15.329051, 16.57557, 15.329051], [19.41665, 20.40139, 19.416655], [24.659554, 27.557873, 24.659554]],
4848
]
4949
),
5050
]

tests/test_gaussian_sharpend.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
{"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])},
2222
np.array(
2323
[
24-
[[4.0335875, 3.362756, 4.0335875], [3.588128, 2.628216, 3.588128], [4.491922, 3.8134987, 4.491922]],
25-
[[10.427719, 8.744948, 10.427719], [8.97032, 6.5705404, 8.970321], [10.886056, 9.195692, 10.886056]],
24+
[[4.1081963, 3.4950666, 4.1081963], [3.7239995, 2.8491793, 3.7239995], [4.569839, 3.9529324, 4.569839]],
25+
[[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]],
2626
]
2727
),
2828
]
@@ -32,8 +32,8 @@
3232
{"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])},
3333
np.array(
3434
[
35-
[[4.146659, 4.392873, 4.146659], [8.031006, 8.804623, 8.031005], [10.127394, 11.669131, 10.127394]],
36-
[[14.852196, 16.439377, 14.852201], [20.077503, 22.011555, 20.077507], [20.832941, 23.715641, 20.832935]],
35+
[[4.513644, 4.869134, 4.513644], [8.467242, 9.4004135, 8.467242], [10.416813, 12.0653515, 10.416813]],
36+
[[15.711488, 17.569994, 15.711488], [21.16811, 23.501041, 21.16811], [21.614658, 24.766209, 21.614658]],
3737
]
3838
),
3939
]
@@ -43,8 +43,8 @@
4343
{"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])},
4444
np.array(
4545
[
46-
[[3.129089, 3.0711129, 3.129089], [6.783306, 6.8526435, 6.7833037], [11.901203, 13.098082, 11.901203]],
47-
[[14.401806, 15.198004, 14.401809], [16.958261, 17.131605, 16.958261], [23.17392, 25.224974, 23.17392]],
46+
[[3.3324685, 3.335536, 3.3324673], [7.7666636, 8.16056, 7.7666636], [12.662973, 14.317837, 12.6629715]],
47+
[[15.329051, 16.57557, 15.329051], [19.41665, 20.40139, 19.416655], [24.659554, 27.557873, 24.659554]],
4848
]
4949
),
5050
]

tests/test_gaussian_smooth.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,12 @@
2121
np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
2222
np.array(
2323
[
24-
[[0.5999930, 0.7056839, 0.5999930], [0.8140513, 0.9574494, 0.8140513], [0.7842673, 0.9224188, 0.7842673]],
25-
[[1.6381884, 1.926761, 1.6381884], [2.0351284, 2.3936234, 2.0351284], [1.8224627, 2.143496, 1.8224627]],
24+
[
25+
[0.59167546, 0.69312394, 0.59167546],
26+
[0.7956997, 0.93213004, 0.7956997],
27+
[0.7668002, 0.8982755, 0.7668002],
28+
],
29+
[[1.6105323, 1.8866735, 1.6105323], [1.9892492, 2.3303251, 1.9892492], [1.7856569, 2.091825, 1.7856569]],
2630
]
2731
),
2832
]
@@ -32,8 +36,8 @@
3236
np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
3337
np.array(
3438
[
35-
[[0.893521, 0.99973595, 0.893521], [1.785628, 1.9978896, 1.7856278], [2.2983139, 2.5715199, 2.2983139]],
36-
[[3.2873974, 3.6781778, 3.2873974], [4.46407, 4.9947243, 4.46407], [4.69219, 5.2499614, 4.69219]],
39+
[[0.8424794, 0.99864554, 0.8424794], [1.678146, 1.9892154, 1.678146], [1.9889624, 2.3576462, 1.9889624]],
40+
[[2.966061, 3.5158648, 2.966061], [4.1953645, 4.973038, 4.1953645], [4.112544, 4.8748655, 4.1125436]],
3741
]
3842
),
3943
]
@@ -43,8 +47,8 @@
4347
np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
4448
np.array(
4549
[
46-
[[0.91108215, 1.0193846, 0.91108215], [1.236127, 1.3830683, 1.236127], [1.1909003, 1.3324654, 1.1909003]],
47-
[[2.4875693, 2.7832723, 2.487569], [3.0903177, 3.457671, 3.0903175], [2.7673876, 3.0963533, 2.7673874]],
50+
[[0.8542037, 1.0125432, 0.8542037], [1.1487541, 1.3616928, 1.1487541], [1.1070318, 1.3122368, 1.1070318]],
51+
[[2.3251305, 2.756128, 2.3251305], [2.8718853, 3.4042323, 2.8718853], [2.5779586, 3.0558217, 2.5779586]],
4852
]
4953
),
5054
]

tests/test_gaussian_smoothd.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,12 @@
2121
{"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])},
2222
np.array(
2323
[
24-
[[0.5999930, 0.7056839, 0.5999930], [0.8140513, 0.9574494, 0.8140513], [0.7842673, 0.9224188, 0.7842673]],
25-
[[1.6381884, 1.926761, 1.6381884], [2.0351284, 2.3936234, 2.0351284], [1.8224627, 2.143496, 1.8224627]],
24+
[
25+
[0.59167546, 0.69312394, 0.59167546],
26+
[0.7956997, 0.93213004, 0.7956997],
27+
[0.7668002, 0.8982755, 0.7668002],
28+
],
29+
[[1.6105323, 1.8866735, 1.6105323], [1.9892492, 2.3303251, 1.9892492], [1.7856569, 2.091825, 1.7856569]],
2630
]
2731
),
2832
]
@@ -32,8 +36,8 @@
3236
{"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])},
3337
np.array(
3438
[
35-
[[0.893521, 0.99973595, 0.893521], [1.785628, 1.9978896, 1.7856278], [2.2983139, 2.5715199, 2.2983139]],
36-
[[3.2873974, 3.6781778, 3.2873974], [4.46407, 4.9947243, 4.46407], [4.69219, 5.2499614, 4.69219]],
39+
[[0.8424794, 0.99864554, 0.8424794], [1.678146, 1.9892154, 1.678146], [1.9889624, 2.3576462, 1.9889624]],
40+
[[2.966061, 3.5158648, 2.966061], [4.1953645, 4.973038, 4.1953645], [4.112544, 4.8748655, 4.1125436]],
3741
]
3842
),
3943
]
@@ -43,8 +47,8 @@
4347
{"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])},
4448
np.array(
4549
[
46-
[[0.91108215, 1.0193846, 0.91108215], [1.236127, 1.3830683, 1.236127], [1.1909003, 1.3324654, 1.1909003]],
47-
[[2.4875693, 2.7832723, 2.487569], [3.0903177, 3.457671, 3.0903175], [2.7673876, 3.0963533, 2.7673874]],
50+
[[0.8542037, 1.0125432, 0.8542037], [1.1487541, 1.3616928, 1.1487541], [1.1070318, 1.3122368, 1.1070318]],
51+
[[2.3251305, 2.756128, 2.3251305], [2.8718853, 3.4042323, 2.8718853], [2.5779586, 3.0558217, 2.5779586]],
4852
]
4953
),
5054
]

tests/test_rand_elastic_3d.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
"device": None,
5151
},
5252
{"img": torch.arange(27).reshape((1, 3, 3, 3)), "spatial_size": (2, 2, 2)},
53-
np.array([[[[6.492354, 7.5022864], [9.519528, 10.524366]], [[15.51277, 16.525297], [18.533852, 19.539217]]]]),
53+
np.array([[[[6.4939356, 7.50289], [9.518351, 10.522849]], [[15.512375, 16.523542], [18.531467, 19.53646]]]]),
5454
],
5555
[
5656
{
@@ -63,7 +63,7 @@
6363
"spatial_size": (2, 2, 2),
6464
},
6565
{"img": torch.arange(27).reshape((1, 3, 3, 3)), "mode": "bilinear"},
66-
np.array([[[[5.005563, 9.463698], [9.289501, 13.741863]], [[12.320587, 16.779654], [16.597677, 21.049414]]]]),
66+
np.array([[[[5.0069294, 9.463932], [9.287769, 13.739735]], [[12.319424, 16.777205], [16.594296, 21.045748]]]]),
6767
],
6868
]
6969

tests/test_rand_elasticd_3d.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
"spatial_size": (2, 2, 2),
6969
},
7070
{"img": torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 3, 3))},
71-
np.array([[[[6.492354, 7.5022864], [9.519528, 10.524366]], [[15.51277, 16.525297], [18.533852, 19.539217]]]]),
71+
np.array([[[[6.4939356, 7.50289], [9.518351, 10.522849]], [[15.512375, 16.523542], [18.531467, 19.53646]]]]),
7272
],
7373
[
7474
{
@@ -83,7 +83,7 @@
8383
"mode": "bilinear",
8484
},
8585
{"img": torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 3, 3))},
86-
np.array([[[[5.005563, 9.463698], [9.289501, 13.741863]], [[12.320587, 16.779654], [16.597677, 21.049414]]]]),
86+
np.array([[[[5.0069294, 9.463932], [9.287769, 13.739735]], [[12.319424, 16.777205], [16.594296, 21.045748]]]]),
8787
],
8888
[
8989
{
@@ -99,7 +99,7 @@
9999
},
100100
{"img": torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 3, 3))},
101101
{
102-
"img": torch.tensor([[[[5.0056, 9.4637], [9.2895, 13.7419]], [[12.3206, 16.7797], [16.5977, 21.0494]]]]),
102+
"img": torch.tensor([[[[5.0069, 9.4639], [9.2878, 13.7397]], [[12.3194, 16.7772], [16.5943, 21.0457]]]]),
103103
"seg": torch.tensor([[[[4.0, 14.0], [7.0, 14.0]], [[9.0, 19.0], [12.0, 22.0]]]]),
104104
},
105105
],

0 commit comments

Comments
 (0)