
intsystems/implicit-reparameterization-trick

Latest commit: b4a0fc2 · Dec 2, 2024

Implicit Reparametrization Trick


Title: Implicit Reparametrization Trick for BMM
Authors: Matvei Kreinin, Maria Nikitina, Petr Babkin, Iryna Zabarianska
Consultant: Oleg Bakhteev, PhD

💡 Description

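This repository contains an educational project for the Bayesian Multimodeling course. It implements algorithms for sampling from various distributions using the implicit reparameterization trick. The trick makes samples differentiable with respect to the distribution parameters even when no explicit (standardizing) reparameterization is available.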
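Let z ~ q(z; θ) have CDF F(z; θ). Hold u = F(z; θ) fixed and differentiate. This gives the implicit reparameterization gradient dz/dθ = −(∂F(z; θ)/∂θ) / q(z; θ), which needs neither an inverse CDF nor a standardization function.

The sketch below illustrates this in plain Python and is not the library's code. Some assumptions apply:
- The helper names (normal_cdf, normal_pdf, implicit_grad) are hypothetical.
- ∂F/∂θ is approximated by central finite differences, where an autodiff framework would differentiate F exactly.

For the Normal it should reproduce the explicit gradients dz/dloc = 1 and dz/dscale = ε.

```python
import math

def normal_cdf(z, loc, scale):
    """CDF F(z; loc, scale) of N(loc, scale^2), via the error function."""
    return 0.5 * (1.0 + math.erf((z - loc) / (scale * math.sqrt(2.0))))

def normal_pdf(z, loc, scale):
    """Density q(z; loc, scale) of N(loc, scale^2)."""
    u = (z - loc) / scale
    return math.exp(-0.5 * u * u) / (scale * math.sqrt(2.0 * math.pi))

def implicit_grad(cdf, pdf, z, params, h=1e-6):
    """Return dz/dtheta_i = -(dF/dtheta_i) / q(z) for every parameter theta_i.

    z is held fixed while each parameter is perturbed; dF/dtheta_i is a
    central finite difference, so only the CDF and density are required.
    """
    density = pdf(z, *params)
    grads = []
    for i in range(len(params)):
        up, down = list(params), list(params)
        up[i] += h
        down[i] -= h
        dF = (cdf(z, *up) - cdf(z, *down)) / (2.0 * h)
        grads.append(-dF / density)
    return grads

# Explicit reparameterization would write z = loc + scale * eps,
# giving dz/dloc = 1 and dz/dscale = eps; the implicit route should agree.
loc, scale, eps = 1.5, 0.7, 0.8
z = loc + scale * eps
grads = implicit_grad(normal_cdf, normal_pdf, z, (loc, scale))
print(grads)  # approximately [1.0, 0.8]
```

The same implicit_grad works for any univariate distribution whose CDF and density can be evaluated. This is why the implicit trick extends to distributions without a convenient standardization.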

🗃 Scope

We plan to implement the following distributions in our library:

  • Gaussian (normal) distribution (*)
  • Dirichlet distribution (including its special case, the Beta distribution) (*)
  • Mixture of distributions from the same family (**)
  • Student's t-distribution (**) (*)
  • von Mises distribution (***)
  • Sampling from an arbitrary factorized distribution (***)

(*) - this distribution already has a reparameterized sampler (rsample) in torch; we will implement it for comparison.

(**) - this distribution is a backup option; its inclusion is not yet certain.

(***) - how to implement this distribution is not yet clear; its inclusion is not yet certain.
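Some of the uncertain items have no closed-form CDF or inverse CDF. Implicit reparameterization only needs F(z; θ) and the density, so one possible route is to compute the CDF numerically. The sketch below does this under stated assumptions:
- It uses the von Mises distribution with mean direction fixed at 0 and support [−π, π].
- F(z; κ) and the normalizing constant are evaluated with composite Simpson's rule.
- ∂F/∂κ is approximated by finite differences.

All helper names are hypothetical. The numerical settings (1000 Simpson intervals, step sizes) are illustrative only, and this is not how the library implements the von Mises distribution.

```python
import math

def simpson(f, a, b, n=1000):
    """Composite Simpson's rule for the integral of f over [a, b]; n must be even."""
    h = (b - a) / n
    total = f(a) + f(b)
    for i in range(1, n):
        total += (4.0 if i % 2 else 2.0) * f(a + i * h)
    return total * h / 3.0

def vm_unnormalized(t, kappa):
    """exp(kappa * cos t): von Mises density up to a constant, mean direction 0."""
    return math.exp(kappa * math.cos(t))

def vm_norm(kappa):
    """Normalizing constant over [-pi, pi], computed numerically."""
    return simpson(lambda t: vm_unnormalized(t, kappa), -math.pi, math.pi)

def vm_pdf(z, kappa):
    return vm_unnormalized(z, kappa) / vm_norm(kappa)

def vm_cdf(z, kappa):
    """F(z; kappa): numerical integral of the density from -pi to z."""
    return simpson(lambda t: vm_unnormalized(t, kappa), -math.pi, z) / vm_norm(kappa)

def vm_dz_dkappa(z, kappa, h=1e-5):
    """Implicit gradient dz/dkappa = -(dF/dkappa) / q(z), dF/dkappa by central differences."""
    dF = (vm_cdf(z, kappa + h) - vm_cdf(z, kappa - h)) / (2.0 * h)
    return -dF / vm_pdf(z, kappa)

def vm_icdf(u, kappa, iters=60):
    """Inverse CDF by bisection; used below only to check the gradient."""
    lo, hi = -math.pi, math.pi
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if vm_cdf(mid, kappa) < u:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

kappa, u = 2.0, 0.7
z = vm_icdf(u, kappa)                  # the sample corresponding to u
implicit = vm_dz_dkappa(z, kappa)      # the gradient itself needs no inversion
d = 1e-3                               # reference: differentiate the inverse CDF at fixed u
reference = (vm_icdf(u, kappa + d) - vm_icdf(u, kappa - d)) / (2.0 * d)
print(implicit, reference)             # the two estimates should agree closely
```

The bisection-based inverse appears only as a check; the implicit gradient never inverts the CDF. Because u > 0.5 puts z above the mean direction, increasing κ pulls z toward 0, so the gradient comes out negative.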

📚 Stack

We plan to inherit from the torch.distributions.Distribution class. Each distribution therefore has to implement the methods that class declares but leaves unimplemented, such as rsample, log_prob and cdf.

👨‍💻 Usage

In this example, we use the library in a Variational Autoencoder (VAE). The encoder predicts the parameters of a normal distribution, and the latent code is sampled from that distribution.

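>>> import irt
>>> params = Encoder(inputs)  # user-defined encoder network returning (loc, scale)
>>> gauss = irt.Normal(*params)
>>> deviated = gauss.rsample()  # reparameterized sample: gradients flow back to the encoder
>>> outputs = Decoder(deviated)  # user-defined decoder network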
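In the VAE example, rsample matters because the decoder's loss must be differentiated through the sample with respect to the encoder's outputs. The stdlib-only sketch below shows that pathwise gradient under these assumptions:
- A toy loss f(z) = z² stands in for the decoder loss.
- Function names are hypothetical.
- Finite differences of the CDF stand in for autodiff.

Since E[z²] = loc² + scale² for z ~ N(loc, scale²), the exact gradients are 2·loc and 2·scale. The Monte Carlo average of f'(z)·dz/dθ, with dz/dθ from the implicit formula, should approach them.

```python
import math
import random

def normal_cdf(z, loc, scale):
    return 0.5 * (1.0 + math.erf((z - loc) / (scale * math.sqrt(2.0))))

def normal_pdf(z, loc, scale):
    u = (z - loc) / scale
    return math.exp(-0.5 * u * u) / (scale * math.sqrt(2.0 * math.pi))

def dz_dparams(z, loc, scale, h=1e-6):
    """Implicit reparameterization: dz/dtheta = -(dF/dtheta) / q(z), z held fixed."""
    q = normal_pdf(z, loc, scale)
    dF_dloc = (normal_cdf(z, loc + h, scale) - normal_cdf(z, loc - h, scale)) / (2 * h)
    dF_dscale = (normal_cdf(z, loc, scale + h) - normal_cdf(z, loc, scale - h)) / (2 * h)
    return -dF_dloc / q, -dF_dscale / q

def pathwise_grad(loc, scale, n=50_000, seed=0):
    """Monte Carlo estimate of d/d(loc, scale) of E[f(z)] for the toy loss f(z) = z**2."""
    rng = random.Random(seed)
    g_loc = g_scale = 0.0
    for _ in range(n):
        z = rng.gauss(loc, scale)        # only the sample is used; no noise variable below
        dz_dloc, dz_dscale = dz_dparams(z, loc, scale)
        df_dz = 2.0 * z                   # derivative of the toy loss
        g_loc += df_dz * dz_dloc          # chain rule: df/dtheta = df/dz * dz/dtheta
        g_scale += df_dz * dz_dscale
    return g_loc / n, g_scale / n

loc, scale = 0.5, 1.2
grads = pathwise_grad(loc, scale)
print(grads)  # close to the exact (2 * loc, 2 * scale) = (1.0, 2.4)
```

In an actual VAE, the framework's autodiff would perform this chain rule through the decoder automatically, provided rsample supplies dz/dθ.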

In the next example, we build a mixture of distributions with our library.

>>> import irt
>>> params_1, params_2 = Encoder(inputs)  # hypothetical: one (loc, scale) pair per component
>>> mix = irt.Mixture([irt.Normal(*params_1), irt.Normal(*params_2)])
>>> deviated = mix.rsample()
>>> outputs = Decoder(deviated)
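For a mixture of same-family components, the explicit trick breaks down because choosing a component is a discrete step. The mixture CDF F(z) = Σ_k w_k Φ((z − loc_k)/scale_k) is still differentiable, so applying dz/dθ = −(∂F/∂θ)/q(z) gives closed forms:
- dz/dloc_k is the responsibility r_k(z) = w_k N(z; loc_k, scale_k²)/q(z).
- dz/dscale_k = r_k(z)·(z − loc_k)/scale_k.

The sketch below covers a 1-D mixture of Normals. It assumes fixed weights (gradients with respect to w_k are not shown) and uses hypothetical helper names; it is not the library's Mixture class.

```python
import math
import random

def normal_pdf(z, loc, scale):
    u = (z - loc) / scale
    return math.exp(-0.5 * u * u) / (scale * math.sqrt(2.0 * math.pi))

def mixture_sample(weights, locs, scales, rng):
    """Draw z from sum_k w_k N(loc_k, scale_k^2): pick a component, then sample it."""
    k = rng.choices(range(len(weights)), weights=weights)[0]
    return rng.gauss(locs[k], scales[k])

def mixture_implicit_grads(z, weights, locs, scales):
    """Implicit gradients of a mixture sample w.r.t. each component's loc and scale.

    dz/dloc_k   = w_k N_k(z) / q(z)           (the responsibility r_k(z))
    dz/dscale_k = r_k(z) * (z - loc_k) / scale_k
    """
    comps = [w * normal_pdf(z, m, s) for w, m, s in zip(weights, locs, scales)]
    q = sum(comps)                               # mixture density q(z)
    resp = [c / q for c in comps]                # responsibilities r_k(z)
    d_locs = resp
    d_scales = [r * (z - m) / s for r, m, s in zip(resp, locs, scales)]
    return d_locs, d_scales

rng = random.Random(0)
weights, locs, scales = [0.3, 0.7], [-1.0, 2.0], [0.5, 1.0]
z = mixture_sample(weights, locs, scales, rng)
d_locs, d_scales = mixture_implicit_grads(z, weights, locs, scales)
print(z, d_locs, d_scales)
```

A quick sanity check: shifting every loc_k by the same amount shifts z by that amount, so the dz/dloc_k values sum to 1. The discrete component choice never enters the gradient.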

📬 Links
