Skip to content

Commit ec201e8

Browse files
committed
Changed filestructure. Added references. Update readme
1 parent b45b9f8 commit ec201e8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

165 files changed

+639
-32
lines changed

Adversarial Variational Bayes/avb.py

+168
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import matplotlib
2+
matplotlib.use('Agg')
3+
import tensorflow as tf
4+
import numpy as np
5+
import matplotlib.pyplot as plt
6+
import matplotlib.gridspec as gridspec
7+
import os
8+
from tensorflow.examples.tutorials.mnist import input_data
9+
slim = tf.contrib.slim
10+
ds = tf.contrib.distributions
11+
st = tf.contrib.bayesflow.stochastic_tensor
12+
13+
# Reference: https://gist.github.com/poolio/b71eb943d6537d01f46e7b20e9225149
14+
15+
mnist = input_data.read_data_sets('./mnist/', one_hot=True)
16+
mb_size = 100
17+
z_dim = 256
18+
eps_dim = mnist.train.images.shape[1]
19+
X_dim = mnist.train.images.shape[1]
20+
y_dim = mnist.train.labels.shape[1]
21+
h_dim = 200
22+
c = 0
23+
lr = 1e-3
24+
25+
def plot(samples):
26+
fig = plt.figure(figsize=(4, 4))
27+
gs = gridspec.GridSpec(4, 4)
28+
gs.update(wspace=0.05, hspace=0.05)
29+
30+
for i, sample in enumerate(samples):
31+
ax = plt.subplot(gs[i])
32+
plt.axis('off')
33+
ax.set_xticklabels([])
34+
ax.set_yticklabels([])
35+
ax.set_aspect('equal')
36+
plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
37+
38+
return fig
39+
40+
41+
def xavier_init(size):
42+
in_dim = size[0]
43+
xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
44+
return tf.random_normal(shape=size, stddev=xavier_stddev)
45+
46+
47+
""" Q(z|X,eps) """
48+
X = tf.placeholder(tf.float32, shape=[None, X_dim], name='X')
49+
eps = tf.placeholder(tf.float32, shape=[None, eps_dim], name='eps')
50+
X_eps = tf.placeholder(tf.float32, shape=[None, X_dim], name='X_eps')
51+
52+
Q_W1 = tf.Variable(xavier_init([X_dim + eps_dim, h_dim]))
53+
Q_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
54+
55+
Q_W2 = tf.Variable(xavier_init([h_dim, z_dim]))
56+
Q_b2 = tf.Variable(tf.zeros(shape=[z_dim]))
57+
58+
theta_Q = [Q_W1, Q_W2, Q_b1, Q_b2]
59+
60+
def Q(X, eps):
61+
inputs = tf.concat([X, eps], 1)
62+
h = tf.nn.elu(tf.matmul(inputs, Q_W1) + Q_b1)
63+
z = tf.matmul(h, Q_W2) + Q_b2
64+
return z
65+
66+
67+
""" P(X|z) """
68+
P_W1 = tf.Variable(xavier_init([z_dim, h_dim]))
69+
P_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
70+
71+
P_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
72+
P_b2 = tf.Variable(tf.zeros(shape=[X_dim]))
73+
74+
theta_P = [P_W1, P_W2, P_b1, P_b2]
75+
76+
def P(z):
77+
h = tf.nn.elu(tf.matmul(z, P_W1) + P_b1)
78+
logits = tf.matmul(h, P_W2) + P_b2
79+
prob = tf.nn.sigmoid(logits)
80+
return prob
81+
82+
83+
""" D(z) """
84+
D_W1 = tf.Variable(xavier_init([z_dim + eps_dim, h_dim]))
85+
D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
86+
87+
D_W2 = tf.Variable(xavier_init([h_dim, 1]))
88+
D_b2 = tf.Variable(tf.zeros(shape=[1]))
89+
90+
theta_D = [D_W1, D_W2, D_b1, D_b2]
91+
92+
# Assumed to be good
93+
def D(X, z):
94+
h = tf.nn.elu(tf.matmul(tf.concat([X, z], 1), D_W1) + D_b1)
95+
out = tf.matmul(h, D_W2) + D_b2
96+
return out
97+
98+
99+
""" Training """
100+
z_sample = Q(X, eps)
101+
z_sample_fake = Q(X_eps, eps)
102+
p_prob = P(z_sample)
103+
104+
X_samples = P(z_sample_fake)
105+
106+
# Adversarial loss to approx. Q(z|X)
107+
D_real = D(X, z_sample)
108+
D_fake = D(X, z_sample_fake)
109+
110+
G_loss = -(-tf.reduce_mean(D_real) + tf.reduce_mean(tf.log(p_prob)))
111+
g_psi = tf.reduce_mean(
112+
tf.nn.sigmoid_cross_entropy_with_logits(labels=D_real, logits=tf.ones_like(D_real)) +
113+
tf.nn.sigmoid_cross_entropy_with_logits(labels=D_fake, logits=tf.zeros_like(D_fake)))
114+
115+
opt = tf.train.AdamOptimizer(1e-3, beta1=0.5)
116+
G_solver = opt.minimize(G_loss, var_list=theta_P + theta_Q)
117+
D_solver = opt.minimize(g_psi, var_list=theta_D)
118+
119+
sess = tf.Session()
120+
sess.run(tf.global_variables_initializer())
121+
122+
if not os.path.exists('out/'):
123+
os.makedirs('out/')
124+
if not os.path.exists('out2/'):
125+
os.makedirs('out2/')
126+
127+
#points_per_class = 300
128+
#labels = np.concatenate([[i] * points_per_class for i in xrange(params['input_dim'])])
129+
#np_data = np.eye(params['input_dim'], dtype=np.float32)[labels]
130+
131+
132+
i = 0
133+
for it in range(1000000):
134+
X_mb, _ = mnist.train.next_batch(mb_size)
135+
eps_mb = np.random.randn(mb_size, eps_dim)
136+
z_mb = np.random.randn(mb_size, eps_dim)
137+
138+
_, G_loss_curr = sess.run([G_solver, G_loss],
139+
feed_dict={X: X_mb, eps: eps_mb})
140+
141+
_, D_loss_curr = sess.run([D_solver, g_psi],
142+
feed_dict={X: X_mb, eps: eps_mb, X_eps: z_mb})
143+
144+
if it % 100 == 0:
145+
print('Iter: {}; G_loss: {:.4}; D_loss: {:.4}'
146+
.format(it, G_loss_curr, D_loss_curr))
147+
eps_mb = np.random.randn(4, eps_dim)
148+
X_mb, _ = mnist.train.next_batch(4)
149+
150+
samples = sess.run(X_samples, feed_dict={X_eps: np.random.randn(16, eps_dim), eps: np.random.randn(16, eps_dim)})
151+
152+
fig = plot(samples)
153+
plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
154+
plt.close(fig)
155+
156+
reconstructed, latent_rep = sess.run([p_prob, z_sample], feed_dict={X: X_mb, eps: eps_mb})
157+
n_examples = 4
158+
fig, axs = plt.subplots(3, n_examples, figsize=(4, 3))
159+
for example_i in range(n_examples):
160+
axs[0][example_i].imshow(
161+
np.reshape(X_mb[example_i, :], (28, 28)))
162+
axs[1][example_i].imshow(
163+
np.reshape(latent_rep[example_i, :], (16,16)))
164+
axs[2][example_i].imshow(
165+
np.reshape([reconstructed[example_i, :]], (28, 28)))
166+
plt.savefig('out2/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
167+
plt.close(fig)
168+
i += 1
31.3 KB
31.4 KB
31.4 KB
31.1 KB
30.9 KB
30.8 KB
31 KB
28.9 KB
29.6 KB
29.8 KB
28.2 KB
29.8 KB
28.5 KB
27.8 KB
28.3 KB
28.6 KB
27.4 KB
29.1 KB
27.7 KB
28.7 KB
29.1 KB
28.9 KB
29.2 KB
18 KB
31 KB
30.1 KB
30.1 KB
30.2 KB
29.8 KB
30.4 KB
29 KB
27.6 KB
30.6 KB
29.4 KB
29.8 KB
30.3 KB
30.1 KB
29.9 KB
29.8 KB
30.4 KB
29 KB
27.4 KB
29 KB
28.5 KB
29.5 KB
29.2 KB
29.5 KB
28.8 KB
28.1 KB
27.8 KB
29.9 KB
28.7 KB
27.4 KB
27.4 KB
27.1 KB
28.6 KB
29.4 KB
28.9 KB
29.3 KB
26.4 KB
29 KB
26.9 KB
29.7 KB
28.3 KB
28.4 KB
27.7 KB
28.1 KB
27.4 KB
26.6 KB
28.4 KB
12.5 KB
13.4 KB
12.8 KB
12.9 KB
12.8 KB
13.4 KB
14.2 KB
13.1 KB
12.6 KB
12 KB
12.9 KB
12.6 KB
13.1 KB
13.1 KB
12.7 KB
13.5 KB
13.3 KB
12.2 KB
13.3 KB
31.9 KB
24.2 KB
16.9 KB
16.6 KB
16.4 KB
16.6 KB
16.3 KB
17.1 KB
15.6 KB
15.6 KB
15.7 KB
15.4 KB
15.7 KB
16.1 KB
15.6 KB
15.6 KB
14.9 KB
16.5 KB
15.7 KB
15.9 KB
15.8 KB
15.7 KB
16.2 KB
18.6 KB
18.9 KB
17.8 KB
18.5 KB
21 KB
19.2 KB
17.2 KB
17.1 KB
18.2 KB
17.8 KB
19.6 KB
18.8 KB
18 KB
19.8 KB
19.3 KB
20 KB
16.9 KB
18.3 KB
17.1 KB
18.3 KB
19.6 KB
18.8 KB
20.2 KB
18.7 KB
19.4 KB
18.7 KB
17.7 KB
18.2 KB
18.4 KB
18.9 KB
17.4 KB
17.9 KB
19.7 KB
17.4 KB
18.8 KB
18.7 KB
18 KB
18.1 KB
16.4 KB
19.5 KB
19.5 KB
18.6 KB
18 KB
17.3 KB
18.3 KB
17.7 KB
19.2 KB

Autoencoder/autoencoder_class.py

+147
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import tensorflow as tf
2+
import numpy as np
3+
import tensorflow.examples.tutorials.mnist.input_data as input_data
4+
import matplotlib.pyplot as plt
5+
import functools
6+
7+
# References
8+
# https://danijar.com/structuring-your-tensorflow-models/
9+
# https://jmetzen.github.io/2015-11-27/vae.html
10+
11+
def xavier_init(fan_in, fan_out, constant = 1):
12+
with tf.name_scope('xavier'):
13+
low = -constant * np.sqrt(6.0 / (fan_in + fan_out))
14+
high = constant * np.sqrt(6.0 / (fan_in + fan_out))
15+
return tf.random_uniform((fan_in, fan_out),
16+
minval = low, maxval = high,
17+
dtype = tf.float32)
18+
19+
def doublewrap(function):
20+
"""
21+
A decorator decorator, allowing to use the decorator to be used without
22+
parentheses if not arguments are provided. All arguments must be optional.
23+
"""
24+
@functools.wraps(function)
25+
def decorator(*args, **kwargs):
26+
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
27+
return function(args[0])
28+
else:
29+
return lambda wrapee: function(wrapee, *args, **kwargs)
30+
return decorator
31+
32+
33+
@doublewrap
34+
def define_scope(function, scope=None, *args, **kwargs):
35+
"""
36+
A decorator for functions that define TensorFlow operations. The wrapped
37+
function will only be executed once. Subsequent calls to it will directly
38+
return the result so that operations are added to the graph only once.
39+
The operations added by the function live within a tf.variable_scope(). If
40+
this decorator is used with arguments, they will be forwarded to the
41+
variable scope. The scope name defaults to the name of the wrapped
42+
function.
43+
"""
44+
attribute = '_cache_' + function.__name__
45+
name = scope or function.__name__
46+
@property
47+
@functools.wraps(function)
48+
def decorator(self):
49+
if not hasattr(self, attribute):
50+
with tf.variable_scope(name, *args, **kwargs):
51+
setattr(self, attribute, function(self))
52+
return getattr(self, attribute)
53+
return decorator
54+
55+
56+
57+
class Autoencoder:
58+
def __init__(self, image, enc_dimensions = [784, 500, 200, 64], dec_dimensions = [64, 200, 500, 784]):
59+
self.image = image
60+
self.enc_dimensions = enc_dimensions
61+
self.dec_dimensions = dec_dimensions
62+
self.prediction
63+
self.optimize
64+
self.error
65+
66+
@define_scope
67+
def prediction(self, input = 4):
68+
current_input = self.image
69+
print('made it')
70+
# ENCODER
71+
encoder = []
72+
with tf.name_scope('Encoder'):
73+
for layer_i, n_output in enumerate(self.enc_dimensions[1:]):
74+
with tf.name_scope('enc_layer' + str(layer_i)):
75+
n_input = int(current_input.get_shape()[1])
76+
W = tf.Variable(xavier_init(n_input, n_output), name = 'weight'+str(layer_i))
77+
b = tf.Variable(tf.zeros(shape=(1, n_output)), name = 'bias'+str(layer_i))
78+
encoder.append(W)
79+
current_input = tf.nn.elu(tf.add(tf.matmul(current_input, W), b),
80+
name='enclayer' + str(layer_i))
81+
82+
# DECODER
83+
with tf.name_scope('Decoder'):
84+
for layer_i, n_output in enumerate(self.dec_dimensions[1:]):
85+
with tf.name_scope('dec_layer' + str(layer_i)):
86+
n_input = int(current_input.get_shape()[1])
87+
W = tf.Variable(xavier_init(n_input, n_output), name = 'weight'+str(layer_i))
88+
b = tf.Variable(tf.zeros(shape=(1, n_output)), name = 'bias'+str(layer_i))
89+
encoder.append(W)
90+
current_input = tf.nn.elu(tf.add(tf.matmul(current_input, W), b),
91+
name='declayer' + str(layer_i))
92+
93+
return current_input
94+
95+
@define_scope
96+
def optimize(self):
97+
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
98+
return optimizer.minimize(self.error)
99+
100+
@define_scope
101+
def error(self):
102+
error = tf.reduce_sum(tf.pow(tf.sub(self.prediction, self.image), 2))
103+
tf.summary.scalar('error', error)
104+
return error
105+
106+
def main():
107+
mnist = input_data.read_data_sets('./mnist/', one_hot=True)
108+
mean_img = np.mean(mnist.train.images, axis=0)
109+
image = tf.placeholder(tf.float32, [None, 784])
110+
autoencoder = Autoencoder(image)
111+
112+
merged_summary = tf.summary.merge_all()
113+
sess = tf.Session()
114+
logpath = '/tmp/tensorflow_logs/example/1'
115+
test_writer = tf.summary.FileWriter(logpath, graph=tf.get_default_graph())
116+
#train_writer = tf.summary.FileWriter('/train')
117+
sess.run(tf.global_variables_initializer())
118+
119+
for epoch_i in range(1):
120+
test_images = mnist.test.images
121+
test = np.array([img - mean_img for img in test_images])
122+
error, summary = sess.run(fetches=[autoencoder.error, merged_summary], feed_dict={image: test})
123+
test_writer.add_summary(summary, epoch_i)
124+
print('Test error {:6.2f}'.format(error))
125+
for batch_i in range(60):
126+
batch_xs, _ = mnist.train.next_batch(100)
127+
train = np.array([img-mean_img for img in batch_xs])
128+
_, summary = sess.run(fetches=[autoencoder.optimize, merged_summary], feed_dict={image: train})
129+
#train_writer.add_summary(summary, epoch_i)
130+
131+
# Plot example reconstructions
132+
n_examples = 15
133+
test_xs, _ = mnist.test.next_batch(n_examples)
134+
test_xs_norm = np.array([img - mean_img for img in test_xs])
135+
recon = sess.run(autoencoder.prediction, feed_dict={image: test_xs_norm})
136+
fig, axs = plt.subplots(2, n_examples, figsize=(10, 2))
137+
for example_i in range(n_examples):
138+
axs[0][example_i].imshow(
139+
np.reshape(test_xs[example_i, :], (28, 28)))
140+
axs[1][example_i].imshow(
141+
np.reshape([recon[example_i, :]], (28, 28)))
142+
fig.show()
143+
plt.draw()
144+
plt.waitforbuttonpress()
145+
146+
if __name__ == '__main__':
147+
main()

0 commit comments

Comments
 (0)