Skip to content

Commit 3be148b

Browse files
adding more samples
1 parent af7a354 commit 3be148b

16 files changed

+30
-17
lines changed

modules.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -15,29 +15,39 @@ def to_scalar(arr):
1515
def weights_init(m):
1616
classname = m.__class__.__name__
1717
if classname.find('Conv') != -1:
18-
nn.init.xavier_uniform_(m.weight.data)
19-
m.bias.data.fill_(0)
18+
try:
19+
nn.init.xavier_uniform_(m.weight.data)
20+
m.bias.data.fill_(0)
21+
except AttributeError:
22+
print("Skipping initialization of ", classname)
2023

2124

2225
class VAE(nn.Module):
2326
def __init__(self, input_dim, dim, z_dim):
2427
super().__init__()
2528
self.encoder = nn.Sequential(
2629
nn.Conv2d(input_dim, dim, 4, 2, 1),
30+
nn.BatchNorm2d(dim),
2731
nn.ReLU(True),
2832
nn.Conv2d(dim, dim, 4, 2, 1),
33+
nn.BatchNorm2d(dim),
2934
nn.ReLU(True),
3035
nn.Conv2d(dim, dim, 5, 1, 0),
36+
nn.BatchNorm2d(dim),
3137
nn.ReLU(True),
3238
nn.Conv2d(dim, z_dim * 2, 3, 1, 0),
39+
nn.BatchNorm2d(z_dim * 2)
3340
)
3441

3542
self.decoder = nn.Sequential(
3643
nn.ConvTranspose2d(z_dim, dim, 3, 1, 0),
44+
nn.BatchNorm2d(dim),
3745
nn.ReLU(True),
3846
nn.ConvTranspose2d(dim, dim, 5, 1, 0),
47+
nn.BatchNorm2d(dim),
3948
nn.ReLU(True),
4049
nn.ConvTranspose2d(dim, dim, 4, 2, 1),
50+
nn.BatchNorm2d(dim),
4151
nn.ReLU(True),
4252
nn.ConvTranspose2d(dim, input_dim, 4, 2, 1),
4353
nn.Tanh()
@@ -204,7 +214,7 @@ def forward(self, x_v, x_h, h):
204214
class GatedPixelCNN(nn.Module):
205215
def __init__(self, input_dim=256, dim=64, n_layers=15):
206216
super().__init__()
207-
self.dim = 64
217+
self.dim = dim
208218

209219
# Create embedding layer to embed input
210220
self.embedding = nn.Embedding(input_dim, dim)
@@ -225,11 +235,13 @@ def __init__(self, input_dim=256, dim=64, n_layers=15):
225235

226236
# Add the output layer
227237
self.output_conv = nn.Sequential(
228-
nn.Conv2d(dim, dim, 1),
238+
nn.Conv2d(dim, 512, 1),
229239
nn.ReLU(True),
230-
nn.Conv2d(dim, input_dim, 1)
240+
nn.Conv2d(512, input_dim, 1)
231241
)
232242

243+
self.apply(weights_init)
244+
233245
def forward(self, x, label):
234246
shp = x.size() + (-1, )
235247
x = self.embedding(x.view(-1)).view(shp) # (B, H, W, C)

pixelcnn_prior.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,16 @@
1111
N_EPOCHS = 100
1212
PRINT_INTERVAL = 100
1313
ALWAYS_SAVE = True
14-
DATASET = 'MNIST' # CIFAR10 | MNIST | FashionMNIST
14+
DATASET = 'CIFAR10' # CIFAR10 | MNIST | FashionMNIST
1515
NUM_WORKERS = 4
1616

17-
LATENT_SHAPE = (7, 7) # (8, 8) -> 32x32 images, (7, 7) -> 28x28 images
18-
INPUT_DIM = 1 # 3 (RGB) | 1 (Grayscale)
19-
DIM = 64
17+
LATENT_SHAPE = (8, 8) # (8, 8) -> 32x32 images, (7, 7) -> 28x28 images
18+
INPUT_DIM = 3 # 3 (RGB) | 1 (Grayscale)
19+
DIM = 256
2020
VAE_DIM = 256
21-
N_LAYERS = 15
21+
N_LAYERS = 12
2222
K = 512
23-
LR = 1e-3
23+
LR = 3e-4
2424

2525
DEVICE = torch.device('cuda') # torch.device('cpu')
2626

samples/reconstructions_CIFAR10.png

-137 KB
Binary file not shown.

samples/samples_CIFAR10.png

3.18 KB
Loading

samples/samples_FashionMNIST.png

-1.45 KB
Loading

samples/samples_MNIST.png

-768 Bytes
Loading
134 KB
Loading
49.1 KB
Loading

samples/vae_reconstructions_MNIST.png

30.3 KB
Loading

samples/vae_samples_CIFAR10.png

-3.16 KB
Loading

samples/vae_samples_FashionMNIST.png

51.1 KB
Loading

samples/vae_samples_MNIST.png

2.2 KB
Loading
4.12 KB
Loading
158 Bytes
Loading

vae.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414
BATCH_SIZE = 32
1515
N_EPOCHS = 100
1616
PRINT_INTERVAL = 500
17-
DATASET = 'CIFAR10' # CIFAR10 | MNIST | FashionMNIST
17+
DATASET = 'FashionMNIST' # CIFAR10 | MNIST | FashionMNIST
1818
NUM_WORKERS = 4
1919

20-
INPUT_DIM = 3
20+
INPUT_DIM = 1
2121
DIM = 256
2222
Z_DIM = 128
23-
LR = 3e-4
23+
LR = 1e-3
2424

2525

2626
preproc_transform = transforms.Compose([

vqvae.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,13 @@ def train():
8181
loss_commit.backward()
8282
opt.step()
8383

84-
nll = -Normal(x_tilde, torch.ones_like(x_tilde)).log_prob(x)
85-
log_px = nll.mean().item() - np.log(128) + np.log(K)
84+
N = x.numel()
85+
nll = Normal(x_tilde, torch.ones_like(x_tilde)).log_prob(x)
86+
log_px = nll.sum() / N + np.log(128) - np.log(K * 2)
8687
log_px /= np.log(2)
8788

8889
train_loss.append(
89-
[log_px] + to_scalar([loss_recons, loss_vq])
90+
[log_px.item()] + to_scalar([loss_recons, loss_vq])
9091
)
9192

9293
if (batch_idx + 1) % PRINT_INTERVAL == 0:

0 commit comments

Comments
 (0)