diff --git a/apply_factor.py b/apply_factor.py index 1a66365f..7dd22a2b 100755 --- a/apply_factor.py +++ b/apply_factor.py @@ -51,6 +51,12 @@ type=str, help="name of the closed form factorization result factor file", ) + parser.add_argument( + "--torch_seed", + type=int, + default=0, + help="seed for generating random numbers", + ) args = parser.parse_args() @@ -61,6 +67,8 @@ trunc = g.mean_latent(4096) + if args.torch_seed > 0: + torch.manual_seed(args.torch_seed) latent = torch.randn(args.n_sample, 512, device=args.device) latent = g.get_latent(latent)