API¶
DataWrapper¶
- class wgan.DataWrapper(df, continuous_vars=[], categorical_vars=[], context_vars=[], continuous_lower_bounds=dict(), continuous_upper_bounds=dict())¶
Class for processing raw training data for training a Wasserstein GAN
- Parameters
df (pandas.DataFrame) – Training data frame, includes both variables to be generated, and variables to be conditioned on
continuous_vars (list) – List of str of continuous variables to be generated
categorical_vars (list) – List of str of categorical variables to be generated
context_vars (list) – List of str of variables that are conditioned on for cWGAN
continuous_lower_bounds (dict) – Key is element of continuous_vars, value is lower limit on that variable.
continuous_upper_bounds (dict) – Key is element of continuous_vars, value is upper limit on that variable.
- apply_generator(generator, df)¶
Replaces or inserts columns in a DataFrame with data simulated by the generator; the number of generated rows equals the number of rows in the DataFrame that is passed
- Parameters
df (pandas.DataFrame) – Must contain the columns listed in self.variables["context"], which the generator is conditioned on. Even without context variables, len(df) is used to infer the desired sample size, so you must supply at least pd.DataFrame(index=range(n))
generator (wgan_model.Generator) – Trained generator for simulating data
- Returns
Original DataFrame with columns replaced by generated data where possible.
- Return type
pandas.DataFrame
- deprocess(x, context)¶
Unscales generated tensors from WGAN output back to the original scale
- Parameters
x (torch.tensor) – Generated data
context (torch.tensor) – Data conditioned on
- Returns
df – DataFrame with data converted back to original scale
- Return type
pandas.DataFrame
- preprocess(df)¶
Scale training data for training in WGANs
- Parameters
df (pandas.DataFrame) – raw training data
- Returns
x (torch.tensor) – training data to be generated by WGAN
context (torch.tensor) – training data to be conditioned on by WGAN
Specifications¶
- class wgan.Specifications(data_wrapper, optimizer=torch.optim.Adam, critic_d_hidden=[128, 128, 128], critic_dropout=0, critic_steps=15, critic_lr=0.0001, critic_gp_factor=5, generator_d_hidden=[128, 128, 128], generator_dropout=0.1, generator_lr=0.0001, generator_d_noise='generator_d_output', generator_optimizer='optimizer', max_epochs=1000, batch_size=32, test_set_size=16, load_checkpoint=None, save_checkpoint=None, save_every=100, print_every=200, device='cuda' if torch.cuda.is_available() else 'cpu')¶
Class used to set up WGAN training specifications before training Generator and Critic.
- Parameters
data_wrapper (wgan_model.DataWrapper) – Object containing details on data frame to be trained
optimizer (torch.optim.Optimizer) – The torch.optim.Optimizer class used for training the networks; torch.optim.Adam by default
critic_d_hidden (list) – List of int, length equal to the number of hidden layers in the critic, giving the size of each hidden layer.
critic_dropout (float) – Dropout parameter for critic (see Srivastava et al 2014)
critic_steps (int) – Number of critic training steps taken for each generator training step
critic_lr (float) – Initial learning rate for critic
critic_gp_factor (float) – Weight on gradient penalty for critic loss function
generator_d_hidden (list) – List of int, length equal to the number of hidden layers in generator, giving the size of each hidden layer.
generator_dropout (float) – Dropout parameter for generator (See Srivastava et al 2014)
generator_lr (float) – Initial learning rate for generator
generator_d_noise (int) – The dimension of the noise input to the generator. Defaults to the output dimension of the generator.
generator_optimizer (torch.optim.Optimizer) – The torch.optim.Optimizer class used for training the generator network, if different from optimizer; by default the same
max_epochs (int) – The number of times to train the network on the whole dataset.
batch_size (int) – The batch size for each training iteration.
test_set_size (int) – Size of the holdout test set used to compute the out-of-sample Wasserstein distance.
load_checkpoint (str) – Filepath to existing model weights to start training from.
save_checkpoint (str) – Filepath of the folder in which to save model weights every save_every iterations.
save_every (int) – If save_checkpoint is not None, how often (in iterations) to save model weights during training.
print_every (int) – How often to print training status during training.
device (str) – Either “cuda” if GPU is available or “cpu” if not
Generator¶
- class wgan.Generator(specifications)¶
torch.nn.Module class for generator network in WGAN
- Parameters
specifications (wgan_model.Specifications) – parameters for training WGAN
Critic¶
- class wgan.Critic(specifications)¶
torch.nn.Module for critic in WGAN framework
- Parameters
specifications (wgan_model.Specifications) – parameters for training WGAN
train¶
- wgan.train(generator, critic, x, context, specifications, penalty=None)¶
Function for training the generator and critic in a conditional WGAN-GP. If context is empty, trains a regular WGAN-GP. See Gulrajani et al. 2017 for details on the training procedure.
- Parameters
generator (wgan_model.Generator) – Generator network to be trained
critic (wgan_model.Critic) – Critic network to be trained
x (torch.tensor) – Training data for generated data
context (torch.tensor) – Data conditioned on for generating data
specifications (wgan_model.Specifications) – Includes all the tuning parameters for training
compare_dfs¶
- wgan.compare_dfs(df_real, df_fake, scatterplot=dict(x=[], y=[], samples=400, smooth=0), table_groupby=[], histogram=dict(variables=[], nrow=1, ncol=1), figsize=3, save=False, path='')¶
Diagnostic function for comparing real and generated data from WGAN models. Prints comparisons of means and standard deviations, and plots histograms and scatterplots.
- Parameters
df_real (pandas.DataFrame) – real data
df_fake (pandas.DataFrame) – data produced by generator
scatterplot (dict) – Contains specifications for plotting scatterplots of variables in real and fake data
table_groupby (list) – List of variables to group mean and standard deviation table by
histogram (dict) – Contains specifications for plotting histograms comparing marginal densities of real and fake data
save (bool) – Whether to save results to file instead of printing them
path (string) – Path to save diagnostics for model