|
| 1 | +import tensorflow as tf |
| 2 | +import numpy as np |
| 3 | +import tensorflow.examples.tutorials.mnist.input_data as input_data |
| 4 | +import matplotlib.pyplot as plt |
| 5 | +import functools |
| 6 | + |
| 7 | +# References |
| 8 | +# https://danijar.com/structuring-your-tensorflow-models/ |
| 9 | +# https://jmetzen.github.io/2015-11-27/vae.html |
| 10 | + |
| 11 | +def xavier_init(fan_in, fan_out, constant = 1): |
| 12 | + with tf.name_scope('xavier'): |
| 13 | + low = -constant * np.sqrt(6.0 / (fan_in + fan_out)) |
| 14 | + high = constant * np.sqrt(6.0 / (fan_in + fan_out)) |
| 15 | + return tf.random_uniform((fan_in, fan_out), |
| 16 | + minval = low, maxval = high, |
| 17 | + dtype = tf.float32) |
| 18 | + |
| 19 | +def doublewrap(function): |
| 20 | + """ |
| 21 | + A decorator decorator, allowing to use the decorator to be used without |
| 22 | + parentheses if not arguments are provided. All arguments must be optional. |
| 23 | + """ |
| 24 | + @functools.wraps(function) |
| 25 | + def decorator(*args, **kwargs): |
| 26 | + if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): |
| 27 | + return function(args[0]) |
| 28 | + else: |
| 29 | + return lambda wrapee: function(wrapee, *args, **kwargs) |
| 30 | + return decorator |
| 31 | + |
| 32 | + |
| 33 | +@doublewrap |
| 34 | +def define_scope(function, scope=None, *args, **kwargs): |
| 35 | + """ |
| 36 | + A decorator for functions that define TensorFlow operations. The wrapped |
| 37 | + function will only be executed once. Subsequent calls to it will directly |
| 38 | + return the result so that operations are added to the graph only once. |
| 39 | + The operations added by the function live within a tf.variable_scope(). If |
| 40 | + this decorator is used with arguments, they will be forwarded to the |
| 41 | + variable scope. The scope name defaults to the name of the wrapped |
| 42 | + function. |
| 43 | + """ |
| 44 | + attribute = '_cache_' + function.__name__ |
| 45 | + name = scope or function.__name__ |
| 46 | + @property |
| 47 | + @functools.wraps(function) |
| 48 | + def decorator(self): |
| 49 | + if not hasattr(self, attribute): |
| 50 | + with tf.variable_scope(name, *args, **kwargs): |
| 51 | + setattr(self, attribute, function(self)) |
| 52 | + return getattr(self, attribute) |
| 53 | + return decorator |
| 54 | + |
| 55 | + |
| 56 | + |
| 57 | +class Autoencoder: |
| 58 | + def __init__(self, image, enc_dimensions = [784, 500, 200, 64], dec_dimensions = [64, 200, 500, 784]): |
| 59 | + self.image = image |
| 60 | + self.enc_dimensions = enc_dimensions |
| 61 | + self.dec_dimensions = dec_dimensions |
| 62 | + self.prediction |
| 63 | + self.optimize |
| 64 | + self.error |
| 65 | + |
| 66 | + @define_scope |
| 67 | + def prediction(self, input = 4): |
| 68 | + current_input = self.image |
| 69 | + print('made it') |
| 70 | + # ENCODER |
| 71 | + encoder = [] |
| 72 | + with tf.name_scope('Encoder'): |
| 73 | + for layer_i, n_output in enumerate(self.enc_dimensions[1:]): |
| 74 | + with tf.name_scope('enc_layer' + str(layer_i)): |
| 75 | + n_input = int(current_input.get_shape()[1]) |
| 76 | + W = tf.Variable(xavier_init(n_input, n_output), name = 'weight'+str(layer_i)) |
| 77 | + b = tf.Variable(tf.zeros(shape=(1, n_output)), name = 'bias'+str(layer_i)) |
| 78 | + encoder.append(W) |
| 79 | + current_input = tf.nn.elu(tf.add(tf.matmul(current_input, W), b), |
| 80 | + name='enclayer' + str(layer_i)) |
| 81 | + |
| 82 | + # DECODER |
| 83 | + with tf.name_scope('Decoder'): |
| 84 | + for layer_i, n_output in enumerate(self.dec_dimensions[1:]): |
| 85 | + with tf.name_scope('dec_layer' + str(layer_i)): |
| 86 | + n_input = int(current_input.get_shape()[1]) |
| 87 | + W = tf.Variable(xavier_init(n_input, n_output), name = 'weight'+str(layer_i)) |
| 88 | + b = tf.Variable(tf.zeros(shape=(1, n_output)), name = 'bias'+str(layer_i)) |
| 89 | + encoder.append(W) |
| 90 | + current_input = tf.nn.elu(tf.add(tf.matmul(current_input, W), b), |
| 91 | + name='declayer' + str(layer_i)) |
| 92 | + |
| 93 | + return current_input |
| 94 | + |
| 95 | + @define_scope |
| 96 | + def optimize(self): |
| 97 | + optimizer = tf.train.AdamOptimizer(learning_rate=0.001) |
| 98 | + return optimizer.minimize(self.error) |
| 99 | + |
| 100 | + @define_scope |
| 101 | + def error(self): |
| 102 | + error = tf.reduce_sum(tf.pow(tf.sub(self.prediction, self.image), 2)) |
| 103 | + tf.summary.scalar('error', error) |
| 104 | + return error |
| 105 | + |
| 106 | +def main(): |
| 107 | + mnist = input_data.read_data_sets('./mnist/', one_hot=True) |
| 108 | + mean_img = np.mean(mnist.train.images, axis=0) |
| 109 | + image = tf.placeholder(tf.float32, [None, 784]) |
| 110 | + autoencoder = Autoencoder(image) |
| 111 | + |
| 112 | + merged_summary = tf.summary.merge_all() |
| 113 | + sess = tf.Session() |
| 114 | + logpath = '/tmp/tensorflow_logs/example/1' |
| 115 | + test_writer = tf.summary.FileWriter(logpath, graph=tf.get_default_graph()) |
| 116 | + #train_writer = tf.summary.FileWriter('/train') |
| 117 | + sess.run(tf.global_variables_initializer()) |
| 118 | + |
| 119 | + for epoch_i in range(1): |
| 120 | + test_images = mnist.test.images |
| 121 | + test = np.array([img - mean_img for img in test_images]) |
| 122 | + error, summary = sess.run(fetches=[autoencoder.error, merged_summary], feed_dict={image: test}) |
| 123 | + test_writer.add_summary(summary, epoch_i) |
| 124 | + print('Test error {:6.2f}'.format(error)) |
| 125 | + for batch_i in range(60): |
| 126 | + batch_xs, _ = mnist.train.next_batch(100) |
| 127 | + train = np.array([img-mean_img for img in batch_xs]) |
| 128 | + _, summary = sess.run(fetches=[autoencoder.optimize, merged_summary], feed_dict={image: train}) |
| 129 | + #train_writer.add_summary(summary, epoch_i) |
| 130 | + |
| 131 | + # Plot example reconstructions |
| 132 | + n_examples = 15 |
| 133 | + test_xs, _ = mnist.test.next_batch(n_examples) |
| 134 | + test_xs_norm = np.array([img - mean_img for img in test_xs]) |
| 135 | + recon = sess.run(autoencoder.prediction, feed_dict={image: test_xs_norm}) |
| 136 | + fig, axs = plt.subplots(2, n_examples, figsize=(10, 2)) |
| 137 | + for example_i in range(n_examples): |
| 138 | + axs[0][example_i].imshow( |
| 139 | + np.reshape(test_xs[example_i, :], (28, 28))) |
| 140 | + axs[1][example_i].imshow( |
| 141 | + np.reshape([recon[example_i, :]], (28, 28))) |
| 142 | + fig.show() |
| 143 | + plt.draw() |
| 144 | + plt.waitforbuttonpress() |
| 145 | + |
| 146 | +if __name__ == '__main__': |
| 147 | + main() |
0 commit comments