Introduction
wgan is a Python module built on top of PyTorch that uses Wasserstein Generative Adversarial Networks with Gradient Penalty (WGAN-GP) to simulate data with a known ground truth from real datasets, in order to test the properties of different estimators, as described in Athey et al. [2019]. The module can simulate from either joint or conditional distributions. This documentation explains how to set up the data, train the models, generate the artificial data, and evaluate the models.
Generative Adversarial Networks (GANs) consist of two parts, a generator and a discriminator. The generator produces new observations that resemble the training data by maximizing the probability that the discriminator makes a mistake; the discriminator minimizes the probability of misclassifying generated data as real data. In the wgan module, both the generator and the discriminator are neural networks.
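To make the adversarial objective concrete, the following is a minimal sketch of the WGAN-GP losses for scalar observations, written in plain Python rather than PyTorch. It is not the module's implementation: the function names, the `gp_factor` and `eps` parameters, and the finite-difference gradient are assumptions chosen for illustration. In the WGAN-GP variant the critic outputs an unbounded score instead of a probability, and a penalty pushes the norm of its gradient toward 1 at points interpolated between real and generated observations.

```python
import random


def critic_loss(critic, real, fake, gp_factor=10.0, eps=1e-5, seed=0):
    """Toy WGAN-GP critic loss for scalar data (illustrative only).

    `critic` is any function mapping a float to a float score.
    `real` and `fake` are equal-length lists of floats.
    """
    rng = random.Random(seed)
    mean_real = sum(critic(x) for x in real) / len(real)
    mean_fake = sum(critic(x) for x in fake) / len(fake)

    # Gradient penalty evaluated at random interpolates of real and fake points.
    penalty = 0.0
    pairs = list(zip(real, fake))
    for x_real, x_fake in pairs:
        t = rng.random()
        x_hat = t * x_real + (1.0 - t) * x_fake
        # Central finite difference stands in for automatic differentiation.
        grad = (critic(x_hat + eps) - critic(x_hat - eps)) / (2.0 * eps)
        penalty += (abs(grad) - 1.0) ** 2
    penalty /= len(pairs)

    # The critic wants high scores on real data and low scores on fake data.
    return mean_fake - mean_real + gp_factor * penalty


def generator_loss(critic, fake):
    """The generator wants the critic to score its samples highly."""
    return -sum(critic(x) for x in fake) / len(fake)
```

For example, with the linear critic `lambda x: 2 * x` the gradient is 2 everywhere, so the penalty term is `gp_factor * (2 - 1) ** 2` regardless of which interpolates are drawn. In practice both networks are trained by alternating gradient steps on these two losses.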
The workflow for fitting a distribution and generating data from it using the module is as follows:
- Load the data into memory
- Initialize a DataWrapper object and specify the data type of each variable
- Initialize a Specifications object from the DataWrapper; it sets the hyperparameters for training
- Initialize the Generator (generator) and Critic (discriminator) from the Specifications
- Normalize the data with the DataWrapper object
- Train the Generator and Critic
- Replace columns in df with simulated data from the Generator using DataWrapper.apply_generator
- Check the generated data via compare_dfs
- Save the generated data
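The workflow can be sketched as a single function. This is a hedged outline, not a verified reference: the class names and `apply_generator` come from the steps listed here, while the keyword arguments (`continuous_vars`, `categorical_vars`, `context_vars`, `max_epochs`), the `preprocess` method used for normalization, and the `train` function with this argument order are assumptions that should be checked against the module's API reference. The helper name `simulate_dataframe` is hypothetical.

```python
def simulate_dataframe(df, continuous_vars, categorical_vars,
                       context_vars=(), max_epochs=1000):
    """Fit a WGAN-GP to `df` and return a copy with simulated columns.

    Sketch only: argument names and the `preprocess`/`train` calls are
    assumptions about the wgan API.
    """
    # Imported lazily so this sketch can be read without wgan installed.
    import wgan

    # Declare the data type of each variable.
    data_wrapper = wgan.DataWrapper(
        df,
        continuous_vars=list(continuous_vars),
        categorical_vars=list(categorical_vars),
        context_vars=list(context_vars),
    )

    # Training hyperparameters are derived from the DataWrapper.
    spec = wgan.Specifications(data_wrapper, max_epochs=max_epochs)

    # Both networks are built from the same Specifications.
    generator = wgan.Generator(spec)
    critic = wgan.Critic(spec)

    # Normalize the data (assumed to return the modeled variables
    # and the conditioning variables separately).
    x, context = data_wrapper.preprocess(df)

    # Train the Generator and Critic (assumed signature).
    wgan.train(generator, critic, x, context, spec)

    # Replace the modeled columns in df with simulated values.
    return data_wrapper.apply_generator(generator, df)
```

Assuming `df` is a pandas DataFrame, the result could then be checked with `wgan.compare_dfs(df, df_generated)` and saved with `df_generated.to_csv(...)`, where the output path is up to the user.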
For bug reports and feature requests, please submit an issue in the GitHub repository. The repository also contains a Google Colab tutorial.