PyTorch implementation of BasisVAE: Translation-invariant feature-level clustering with Variational Autoencoders

ethanweinberger/BasisVAE

 
 


BasisVAE

This is the PyTorch implementation of our AISTATS 2020 paper "BasisVAE: Translation-invariant feature-level clustering with Variational Autoencoders".

Summary

It would be desirable to have a joint modelling framework that performs dimensionality reduction and clustering of features simultaneously. Here, we embed such capabilities within the Variational Autoencoder (VAE) framework. Specifically, we propose BasisVAE: a combination of the VAE and a probabilistic clustering prior, which lets us learn a one-hot basis function representation as part of the decoder network. This is illustrated below. Furthermore, for scenarios where not all features are aligned, we develop an extension that handles translation-invariant basis functions.
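To make the decoder structure more concrete, here is a minimal NumPy sketch of the idea under simplifying assumptions. A small fixed MLP stands in for the shared network that outputs K basis functions. Each feature j is tied to one basis function through a one-hot assignment row and a feature-specific scale. A helper computes a posterior over assignments for a single feature, given fixed basis functions and Gaussian noise. The names basis_functions, decoder_mean and assignment_posterior are hypothetical and are not taken from decoder.py. In BasisVAE the assignments, scales and network weights are learned jointly with the encoder, and this sketch does not attempt that.

```python
import numpy as np

def basis_functions(z, W1, b1, W2, b2):
    """Map latent codes z of shape (N, D) to K basis function values of shape (N, K)."""
    h = np.tanh(z @ W1 + b1)  # (N, H) hidden layer
    return h @ W2 + b2        # (N, K) one column per basis function

def decoder_mean(z, params, assignments, scales):
    """Mean of feature j is scales[j] * phi_{k_j}(z).

    assignments: (P, K) matrix whose row j is one-hot (or a probability vector)
    selecting the basis function for feature j; scales: (P,) per-feature scales.
    """
    phi = basis_functions(z, *params)       # (N, K)
    return (phi @ assignments.T) * scales   # (N, P); scales broadcast over rows

def assignment_posterior(x_j, phi, scale_j, log_prior, sigma=1.0):
    """p(k | x_j) proportional to prior_k * prod_n N(x_nj | scale_j * phi_nk, sigma^2)."""
    resid = x_j[:, None] - scale_j * phi                    # (N, K)
    # Gaussian log-likelihood up to a constant that is the same for every k
    logits = log_prior - 0.5 * np.sum(resid ** 2, axis=0) / sigma ** 2
    logits -= logits.max()                                  # numerical stability
    w = np.exp(logits)
    return w / w.sum()

rng = np.random.default_rng(0)
N, D, H, K, P = 5, 1, 16, 4, 8
params = (rng.normal(size=(D, H)), np.zeros(H), rng.normal(size=(H, K)), np.zeros(K))
cluster_of_feature = np.array([0, 0, 1, 1, 2, 2, 3, 3])
assignments = np.eye(K)[cluster_of_feature]   # (P, K) one-hot rows
scales = rng.uniform(0.5, 2.0, size=P)
z = rng.normal(size=(N, D))

mu = decoder_mean(z, params, assignments, scales)
print(mu.shape)  # (5, 8)

# Which basis function best explains (noise-free) feature 2 under a uniform prior?
post = assignment_posterior(mu[:, 2], basis_functions(z, *params), scales[2],
                            np.log(np.full(K, 1.0 / K)), sigma=0.1)
print(post.round(3))
```

The assignment matrix only needs to be row-stochastic, so the same decoder_mean call also works with soft assignment probabilities in place of hard one-hot rows.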

Illustrations on synthetic data

We illustrate BasisVAE model fitting on the following toy data set, which has four groups of features:

Here is the inferred scale-invariant BasisVAE:

And here is the fit when we additionally allow translation invariance:
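The difference between the scale-invariant and translation-invariant settings can be illustrated with a hypothetical generator for toy data of this kind. This sketch makes several assumptions: a one-dimensional latent variable z, four hand-picked basis functions standing in for the four groups, a per-feature scale, and, for the translation-invariant case, a per-feature shift applied to z before the basis function is evaluated. These choices are illustrative only and do not reproduce the data set shown in the figures.

```python
import numpy as np

# Four hypothetical basis functions of a scalar latent z (one per feature group)
BASIS = [
    lambda z: np.sin(z),
    lambda z: np.exp(-z ** 2),
    lambda z: np.tanh(z),
    lambda z: 0.1 * z ** 2,
]

def feature_means(z, cluster, scale, shift):
    """Return an (N, P) matrix whose column j is scale[j] * BASIS[cluster[j]](z - shift[j])."""
    cols = [s * BASIS[k](z - t) for k, s, t in zip(cluster, scale, shift)]
    return np.stack(cols, axis=1)

rng = np.random.default_rng(1)
z = np.linspace(-3, 3, 200)
cluster = np.repeat(np.arange(4), 5)              # 20 features, 5 per group
scale = rng.uniform(0.5, 2.0, size=cluster.size)  # per-feature scale
shift = rng.uniform(-1.0, 1.0, size=cluster.size) # per-feature shift in z

aligned = feature_means(z, cluster, scale, np.zeros(cluster.size))  # scale only
shifted = feature_means(z, cluster, scale, shift)                   # scale + shift
X = shifted + 0.1 * rng.normal(size=shifted.shape)                  # noisy observations
```

Within a group, the columns of aligned differ only by a constant factor, whereas the columns of shifted are also offset along z, which a purely scale-invariant decoder cannot represent with a single shared basis function.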

Demo notebook

See this Colab notebook for a demo on a toy example.

Implementation details

The core component of BasisVAE is its specialised decoder. See decoder.py for details about its usage and implementation. The current implementation supports the following likelihoods: Gaussian, Bernoulli, negative binomial (NB), and zero-inflated negative binomial (ZINB).
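For reference, the four likelihoods can be written as per-observation log densities. This plain-Python sketch assumes a common parameterization: NB with mean mu and inverse dispersion theta, so that the variance is mu + mu**2 / theta, and ZINB with an additional zero-inflation probability pi. The function names and this parameterization are assumptions and may differ from what decoder.py actually uses.

```python
import math

def gaussian_logpdf(x, mu, sigma):
    """log N(x | mu, sigma^2)."""
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - (x - mu) ** 2 / (2 * sigma ** 2)

def bernoulli_logpmf(x, p):
    """log Bernoulli(x | p) for x in {0, 1} and 0 < p < 1."""
    return x * math.log(p) + (1 - x) * math.log(1 - p)

def nb_logpmf(x, mu, theta):
    """Negative binomial with mean mu and inverse dispersion theta."""
    return (math.lgamma(x + theta) - math.lgamma(theta) - math.lgamma(x + 1)
            + theta * math.log(theta / (theta + mu))
            + x * math.log(mu / (theta + mu)))

def zinb_logpmf(x, mu, theta, pi):
    """Zero-inflated NB: with probability pi the count is a structural zero."""
    if x == 0:
        return math.log(pi + (1 - pi) * math.exp(nb_logpmf(0, mu, theta)))
    return math.log(1 - pi) + nb_logpmf(x, mu, theta)

# Zero inflation puts extra probability mass on zero counts
print(zinb_logpmf(0, mu=3.0, theta=2.0, pi=0.2), nb_logpmf(0, mu=3.0, theta=2.0))
```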

Installation

pip install git+https://github.com/kasparmartens/BasisVAE.git

Languages

  • Python 91.7%
  • R 8.3%