API

DataWrapper

class wgan.DataWrapper(df, continuous_vars=[], categorical_vars=[], context_vars=[], continuous_lower_bounds=dict(), continuous_upper_bounds=dict())

Class for processing raw training data before training a Wasserstein GAN (WGAN).

Parameters
  • df (pandas.DataFrame) – Training data frame; includes both the variables to be generated and the variables to be conditioned on

  • continuous_vars (list) – List of column names (str) of the continuous variables to be generated

  • categorical_vars (list) – List of column names (str) of the categorical variables to be generated

  • context_vars (list) – List of column names (str) of the variables that are conditioned on in a conditional WGAN (cWGAN)

  • continuous_lower_bounds (dict) – Key is element of continuous_vars, value is lower limit on that variable.

  • continuous_upper_bounds (dict) – Key is element of continuous_vars, value is upper limit on that variable.
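As a minimal sketch of how these arguments fit together, the block below builds the lists and bound dictionaries and checks that every bound is keyed by a name in continuous_vars. The column names and the check_bounds helper are hypothetical and used only for illustration; they are not part of wgan.

```python
# Hypothetical column names, used only for illustration.
continuous_vars = ["income", "age"]
categorical_vars = ["employed"]
context_vars = ["treated"]

# Bounds are optional; each key must be an element of continuous_vars.
continuous_lower_bounds = {"income": 0.0, "age": 18.0}
continuous_upper_bounds = {"age": 100.0}


def check_bounds(continuous_vars, lower, upper):
    """Hypothetical helper (not part of wgan): validate the bound dictionaries."""
    unknown = (set(lower) | set(upper)) - set(continuous_vars)
    if unknown:
        raise ValueError(f"bounds given for non-continuous variables: {sorted(unknown)}")
    for name in set(lower) & set(upper):
        if lower[name] > upper[name]:
            raise ValueError(f"lower bound exceeds upper bound for {name!r}")
    return True


check_bounds(continuous_vars, continuous_lower_bounds, continuous_upper_bounds)

# With the package installed and a DataFrame df at hand, the call would look like:
# wgan.DataWrapper(df, continuous_vars, categorical_vars, context_vars,
#                  continuous_lower_bounds, continuous_upper_bounds)
```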

apply_generator(generator, df)

Replaces or inserts the columns produced by the generator in the DataFrame. The number of generated rows equals the number of rows in the DataFrame that is passed.

Parameters
  • df (pandas.DataFrame) – Must contain the columns listed in self.variables["context"], which the generator is conditioned on. Even without context variables, len(df) is used to infer the desired sample size, so supply at least pd.DataFrame(index=range(n))

  • generator (wgan_model.Generator) – Trained generator for simulating data

Returns

Original DataFrame with columns replaced by generated data where possible.

Return type

pandas.DataFrame

deprocess(x, context)

Unscale tensors from WGAN output to original scale

Parameters
  • x (torch.Tensor) – Generated data

  • context (torch.Tensor) – Data conditioned on

Returns

df – DataFrame with data converted back to original scale

Return type

pandas.DataFrame

preprocess(df)

Scale training data for use in WGAN training

Parameters

df (pandas.DataFrame) – raw training data

Returns

  • x (torch.Tensor) – Scaled training data for the variables to be generated by the WGAN

  • context (torch.Tensor) – Scaled training data for the variables to be conditioned on by the WGAN
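To illustrate what scaling in preprocess and unscaling in deprocess mean, here is a minimal sketch of a round trip for a single continuous column. It assumes standardization by mean and standard deviation, which is an assumption made for illustration and not a statement of how DataWrapper scales data. The helpers fit_scaler, scale, and unscale are hypothetical and not part of wgan.

```python
from statistics import mean, pstdev


def fit_scaler(values):
    """Return (mean, std) for one column; std falls back to 1.0 for a constant column."""
    m = mean(values)
    s = pstdev(values) or 1.0
    return m, s


def scale(values, m, s):
    """Map raw values to the standardized scale (the role preprocess plays)."""
    return [(v - m) / s for v in values]


def unscale(values, m, s):
    """Map standardized values back to the original scale (the role deprocess plays)."""
    return [v * s + m for v in values]


raw = [10.0, 20.0, 30.0]          # hypothetical raw column
m, s = fit_scaler(raw)
scaled = scale(raw, m, s)         # centered at zero with unit spread
restored = unscale(scaled, m, s)  # recovers raw up to floating-point error
```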

Specifications

class wgan.Specifications(data_wrapper, optimizer=torch.optim.Adam, critic_d_hidden=[128, 128, 128], critic_dropout=0, critic_steps=15, critic_lr=0.0001, critic_gp_factor=5, generator_d_hidden=[128, 128, 128], generator_dropout=0.1, generator_lr=0.0001, generator_d_noise='generator_d_output', generator_optimizer='optimizer', max_epochs=1000, batch_size=32, test_set_size=16, load_checkpoint=None, save_checkpoint=None, save_every=100, print_every=200, device='cuda' if torch.cuda.is_available() else 'cpu')

Class used to set up WGAN training specifications before training Generator and Critic.

Parameters
  • data_wrapper (wgan.DataWrapper) – Object containing details of the data frame to be trained on

  • optimizer (torch.optim.Optimizer) – The torch.optim.Optimizer class used for training the networks; defaults to torch.optim.Adam

  • critic_d_hidden (list) – List of int, length equal to the number of hidden layers in the critic, giving the size of each hidden layer.

  • critic_dropout (float) – Dropout parameter for critic (see Srivastava et al. 2014)

  • critic_steps (int) – Number of critic training steps taken for each generator training step

  • critic_lr (float) – Initial learning rate for critic

  • critic_gp_factor (float) – Weight on gradient penalty for critic loss function

  • generator_d_hidden (list) – List of int, length equal to the number of hidden layers in generator, giving the size of each hidden layer.

  • generator_dropout (float) – Dropout parameter for generator (see Srivastava et al. 2014)

  • generator_lr (float) – Initial learning rate for generator

  • generator_d_noise (int) – The dimension of the noise input to the generator. By default, set to the output dimension of the generator.

  • generator_optimizer (torch.optim.Optimizer) – The torch.optim.Optimizer class used for training the generator network if different from "optimizer"; defaults to the same optimizer

  • max_epochs (int) – The number of times to train the network on the whole dataset.

  • batch_size (int) – The batch size for each training iteration.

  • test_set_size (int) – Size of the holdout test set used to calculate the out-of-sample Wasserstein distance.

  • load_checkpoint (str) – Filepath to existing model weights to start training from.

  • save_checkpoint (str) – Path of the folder in which to save model weights every save_every iterations

  • save_every (int) – If save_checkpoint is not None, how often to save a checkpoint of the model weights during training.

  • print_every (int) – How often to print training status during training.

  • device (str) – Either "cuda" if a GPU is available or "cpu" if not
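The meaning of critic_steps can be sketched as an update schedule. The block below assumes the common WGAN-GP pattern in which critic_steps critic updates precede each generator update; the exact interleaving inside the library's training loop may differ, and update_schedule is a hypothetical helper that is not part of wgan.

```python
def update_schedule(generator_steps, critic_steps=15):
    """List the order of updates, assuming critic_steps critic updates per generator update."""
    schedule = []
    for _ in range(generator_steps):
        schedule.extend(["critic"] * critic_steps)
        schedule.append("generator")
    return schedule


schedule = update_schedule(generator_steps=2, critic_steps=3)
print(schedule)
# ['critic', 'critic', 'critic', 'generator', 'critic', 'critic', 'critic', 'generator']
```

Under this assumption, the default critic_steps=15 means the critic is updated fifteen times as often as the generator.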

Generator

class wgan.Generator(specifications)

torch.nn.Module class for generator network in WGAN

Parameters

specifications (wgan.Specifications) – Parameters for training the WGAN

Critic

class wgan.Critic(specifications)

torch.nn.Module for critic in WGAN framework

Parameters

specifications (wgan.Specifications) – Parameters for training the WGAN

train

wgan.train(generator, critic, x, context, specifications, penalty=None)

Function for training the generator and critic in a conditional WGAN-GP. If context is empty, trains a regular WGAN-GP. See Gulrajani et al. 2017 for details on the training procedure.
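For reference, the critic objective in Gulrajani et al. 2017 can be written as follows. Here D is the critic, P_r and P_g are the real and generated distributions, x-hat is sampled uniformly along straight lines between pairs of real and generated points, and lambda is the gradient-penalty weight, which corresponds to critic_gp_factor. Conditioning on context is suppressed in the notation, and the library's implementation may differ in detail.

```latex
L = \mathbb{E}_{\tilde{x} \sim P_g}\left[ D(\tilde{x}) \right]
  - \mathbb{E}_{x \sim P_r}\left[ D(x) \right]
  + \lambda \, \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}\left[ \left( \lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1 \right)^2 \right]
```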

Parameters
  • generator (wgan.Generator) – Generator network to be trained

  • critic (wgan.Critic) – Critic network to be trained

  • x (torch.Tensor) – Training data for the variables to be generated

  • context (torch.Tensor) – Training data for the variables to be conditioned on

  • specifications (wgan.Specifications) – Includes all the tuning parameters for training
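The documented pieces can be chained as in the following sketch, which follows the signatures listed on this page. The function name fit_and_sample and the settings max_epochs=200 and batch_size=64 are illustrative assumptions, not recommendations. The wgan import is placed inside the function so that the sketch can be defined without the package installed.

```python
def fit_and_sample(df, continuous_vars, categorical_vars=(), context_vars=()):
    """Sketch of the documented call sequence; needs the wgan package when called."""
    import wgan  # lazy import: only required at call time

    data_wrapper = wgan.DataWrapper(
        df,
        continuous_vars=list(continuous_vars),
        categorical_vars=list(categorical_vars),
        context_vars=list(context_vars),
    )
    # Illustrative settings only; all other parameters keep their defaults.
    specifications = wgan.Specifications(data_wrapper, max_epochs=200, batch_size=64)
    generator = wgan.Generator(specifications)
    critic = wgan.Critic(specifications)

    # Scale the raw data, then train the generator and critic on it.
    x, context = data_wrapper.preprocess(df)
    wgan.train(generator, critic, x, context, specifications)

    # Generate one row per row of df, conditioned on df's context columns.
    return data_wrapper.apply_generator(generator, df)
```

A call such as fit_and_sample(df, ["income"], context_vars=["treated"]) (hypothetical column names) would return a DataFrame in which the generated columns replace the originals.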

compare_dfs

wgan.compare_dfs(df_real, df_fake, scatterplot=dict(x=[], y=[], samples=400, smooth=0), table_groupby=[], histogram=dict(variables=[], nrow=1, ncol=1), figsize=3, save=False, path='')

Diagnostic function for comparing real and generated data from WGAN models. Prints comparisons of means and standard deviations, and plots histograms and scatterplots.

Parameters
  • df_real (pandas.DataFrame) – real data

  • df_fake (pandas.DataFrame) – data produced by generator

  • scatterplot (dict) – Contains specifications for plotting scatterplots of variables in real and fake data

  • table_groupby (list) – List of variables to group mean and standard deviation table by

  • histogram (dict) – Contains specifications for plotting histograms comparing marginal densities of real and fake data

  • save (bool) – Whether to save the results to file instead of printing them

  • path (str) – Path where the model diagnostics are saved
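The kind of mean and standard deviation comparison described above can be sketched with the standard library alone. This is an illustration of the idea, not the library's implementation: compare_columns is a hypothetical helper, and the data are made up.

```python
from statistics import mean, stdev


def compare_columns(real, fake):
    """Hypothetical helper: real and fake map a column name to a list of floats."""
    rows = {}
    for name in real:
        rows[name] = {
            "mean_real": mean(real[name]),
            "mean_fake": mean(fake[name]),
            "sd_real": stdev(real[name]),
            "sd_fake": stdev(fake[name]),
        }
    return rows


# Made-up values for a single hypothetical column.
real = {"income": [1.0, 2.0, 3.0]}
fake = {"income": [1.5, 2.0, 2.5]}
table = compare_columns(real, fake)
for name, stats in table.items():
    print(name, stats)
```

In this toy case the means agree while the generated column has a smaller spread, which is the sort of discrepancy such a table is meant to surface.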