# Source code for teras._src.models.gans.pcgain.pcgain

import keras

from teras._src import backend
from teras._src.api_export import teras_export


@teras_export("teras.models.PCGAIN")
class PCGAIN(backend.models.PCGAIN):
    """
    PCGAIN is a missing data imputation model based on the GAIN architecture.
    It is proposed by Yufeng Wang et al. in the paper,
    "PC-GAIN: Pseudo-label Conditional Generative Adversarial Imputation
    Networks for Incomplete Data".

    Reference(s):
        https://arxiv.org/abs/2011.07770

    Args:
        generator: keras.Model, An instance of `GAINGenerator` model or any
            customized model that can work in its place.
        discriminator: keras.Model, An instance of `GAINDiscriminator` model
            or any customized model that can work in its place.
        classifier: keras.Model, An instance of the `PCGAINClassifier`
            trained on the imputed pretraining dataset coupled with the
            pseudo-labels generated by the clustering algorithm.
        hint_rate: float, Hint rate will be used to sample binary vectors
            for `hint vectors` generation. Must be between 0. and 1.
            Hint vectors ensure that generated samples follow the
            underlying data distribution. Defaults to 0.9.
        alpha: float, Hyper parameter for the generator loss computation
            that controls how much weight should be given to the MSE loss.
            Precisely, `generator_loss` = `cross_entropy_loss` +
            `alpha` * `mse_loss`. The higher the `alpha`, the more the
            mse_loss will affect the overall generator loss.
            Defaults to 200.
        beta: float, Hyper parameter for generator loss computation that
            controls the contribution of the classifier's loss to the
            overall generator loss. Defaults to 100.
        seed: int, seed to make results of random ops deterministic.
            Defaults to 1337.
        **kwargs: Additional keyword arguments, forwarded unchanged to the
            `backend.models.PCGAIN` base class constructor.
    """

    def __init__(
        self,
        generator: keras.Model,
        discriminator: keras.Model,
        classifier: keras.Model,
        hint_rate: float = 0.9,
        alpha: float = 200.,
        beta: float = 100.,
        seed: int = 1337,
        **kwargs,
    ):
        # This public class adds no logic of its own. Every argument is
        # forwarded by keyword to `backend.models.PCGAIN`, and the class is
        # exposed publicly as `teras.models.PCGAIN` via `teras_export`.
        super().__init__(
            generator=generator,
            discriminator=discriminator,
            classifier=classifier,
            hint_rate=hint_rate,
            alpha=alpha,
            beta=beta,
            seed=seed,
            **kwargs,
        )