diff --git a/fid.py b/fid.py index 880a12a5..ffce06da 100755 --- a/fid.py +++ b/fid.py @@ -17,12 +17,15 @@ def extract_feature_from_samples( ): n_batch = n_sample // batch_size resid = n_sample - (n_batch * batch_size) - batch_sizes = [batch_size] * n_batch + [resid] + if resid == 0: + batch_sizes = [batch_size] * n_batch + else: + batch_sizes = [batch_size] * n_batch + [resid] features = [] for batch in tqdm(batch_sizes): latent = torch.randn(batch, 512, device=device) - img, _ = g([latent], truncation=truncation, truncation_latent=truncation_latent) + img, _ = generator([latent], truncation=truncation, truncation_latent=truncation_latent) feat = inception(img)[0].view(img.shape[0], -1) features.append(feat.to("cpu"))