@@ -38,12 +38,19 @@ def extract_image(filename, num_channels):
38
38
image = tf .image .convert_image_dtype (image , tf .float32 )
39
39
return image
40
40
41
- def per_pixel_mean_stddev (dataset ):
41
+ def per_pixel_mean_stddev (dataset , image_size ):
42
42
"""
43
43
Compute the mean of each pixel over the entire dataset.
44
44
45
45
"""
46
- return
46
+ maximum = image_size * image_size * 3
47
+ initial_state = tf .constant ([0. ]* maximum )
48
+ count = dataset .reduce (0 , lambda x , _ : x + 1 )
49
+ dataset_resized = dataset .map (lambda x : resize ([x ], image_size ))
50
+ dataset_per_pixel = dataset_resized .map (lambda x : tf .reshape (x , [- 1 ]))
51
+ pixel_sum = dataset_per_pixel .reduce (initial_state , lambda x , y : x + y )
52
+ pixel_mean = tf .divide (pixel_sum , tf .to_float (count ))
53
+ return pixel_mean
47
54
48
55
def per_channel_mean_stddev (dataset ):
49
56
"""
@@ -53,7 +60,8 @@ def channel_mean_stddev(decoded_image):
53
60
means = tf .reduce_mean (decoded_image , axis = [0 ,1 ])
54
61
stddev = tf .sqrt (tf .reduce_mean (tf .square (decoded_image - means ), axis = [0 ,1 ]))
55
62
return tf .stack ([means , stddev ])
56
- return dataset .map (lambda x : channel_mean_stddev (x ))
63
+ per_channel_mean_stddev_dataset = dataset .map (lambda x : channel_mean_stddev (x ))
64
+ return per_channel_mean_stddev_dataset
57
65
58
66
def per_mean_stddev (dataset ):
59
67
"""
@@ -72,12 +80,17 @@ def encode_stats(alpha):
72
80
"""
73
81
pass
74
82
83
+ def resize (image , image_size ):
84
+ rank_assertion = tf .Assert (
85
+ tf .equal (tf .rank (image ), 4 ),
86
+ ['Rank of image must be equal to 4.' ])
87
+ with tf .control_dependencies ([rank_assertion ]):
88
+ image = tf .image .resize_bilinear (image , [image_size , image_size ])[0 ]
89
+ return image
90
+
91
+
75
92
a , b = load_images ("D:/MURA-v1.1/train/*/*/*/*.png" , 3 ,image_extension = 'png' )
76
- a = per_channel_mean_stddev (a )
77
- a_iter = a .make_one_shot_iterator ()
78
- image = a_iter .get_next ()
79
- b_iter = b .make_one_shot_iterator ()
80
- label = b_iter .get_next ()
93
+ a = per_pixel_mean_stddev (a , 300 )
81
94
with tf .Session () as sess :
82
- for i in range ( 10 ):
83
- print (sess . run ([ image , label ]) )
95
+ sess . run ( a )
96
+ print (a )
0 commit comments