From c59cc0eaa3eca42cc41c561b6b306ee1a55876b4 Mon Sep 17 00:00:00 2001 From: Shion Honda <26x.orc.ed5.1hs@gmail.com> Date: Fri, 8 Apr 2022 23:21:27 +0900 Subject: [PATCH] fix fid calculation --- fid.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/fid.py b/fid.py index 880a12a5..ffce06da 100755 --- a/fid.py +++ b/fid.py @@ -17,12 +17,15 @@ def extract_feature_from_samples( ): n_batch = n_sample // batch_size resid = n_sample - (n_batch * batch_size) - batch_sizes = [batch_size] * n_batch + [resid] + if resid == 0: + batch_sizes = [batch_size] * n_batch + else: + batch_sizes = [batch_size] * n_batch + [resid] features = [] for batch in tqdm(batch_sizes): latent = torch.randn(batch, 512, device=device) - img, _ = g([latent], truncation=truncation, truncation_latent=truncation_latent) + img, _ = generator([latent], truncation=truncation, truncation_latent=truncation_latent) feat = inception(img)[0].view(img.shape[0], -1) features.append(feat.to("cpu"))