Source code for teras._src.models.gans.pcgain.pcgain
import keras
from teras._src import backend
from teras._src.api_export import teras_export
[docs]
@teras_export("teras.models.PCGAIN")
class PCGAIN(backend.models.PCGAIN):
    """
    PCGAIN is a missing data imputation model based on the GAIN architecture.
    It was proposed by Yufeng Wang et al. in the paper,
    "PC-GAIN: Pseudo-label Conditional Generative Adversarial Imputation
    Networks for Incomplete Data"

    This class is the public, backend-agnostic entry point: every argument
    is forwarded unchanged to the backend-specific `PCGAIN` base class,
    which contains the actual training and imputation logic.

    Reference(s):
        https://arxiv.org/abs/2011.07770

    Args:
        generator: keras.Model, An instance of `GAINGenerator` model or any
            customized model that can work in its place.
        discriminator: keras.Model, An instance of `GAINDiscriminator` model
            or any customized model that can work in its place.
        classifier: keras.Model, An instance of the `PCGAINClassifier`
            trained on the imputed pretraining dataset coupled with the
            pseudo-labels generated by the clustering algorithm.
        hint_rate: float, Hint rate will be used to sample binary vectors for
            `hint vectors` generation. Must be between 0. and 1.
            Hint vectors ensure that generated samples follow the underlying
            data distribution.
            Defaults to 0.9.
        alpha: float, Hyperparameter for the generator loss computation that
            controls how much weight should be given to the MSE loss.
            The higher the `alpha`, the more the MSE loss will affect the
            overall generator loss.
            Together with `beta`, the generator loss is roughly
            `cross_entropy_loss` + `alpha` * `mse_loss`
            + `beta` * `classifier_loss`; the exact computation lives in the
            backend implementation.
            Defaults to 200.
        beta: float, Hyperparameter for the generator loss computation that
            controls the contribution of the classifier's loss to the
            overall generator loss. The higher the `beta`, the more the
            classifier's loss will affect the overall generator loss.
            Defaults to 100.
        seed: int, Seed to make results of random ops deterministic.
            Defaults to 1337.
        **kwargs: Additional keyword arguments, forwarded as-is to the
            backend `PCGAIN` base class (presumably standard `keras.Model`
            keyword arguments such as `name` -- confirm against the backend).
    """

    def __init__(self,
                 generator: keras.Model,
                 discriminator: keras.Model,
                 classifier: keras.Model,
                 hint_rate: float = 0.9,
                 alpha: float = 200.,
                 beta: float = 100.,
                 seed: int = 1337,
                 **kwargs):
        # Pure pass-through: no validation or transformation happens here,
        # so constraints such as `0 <= hint_rate <= 1` are left to the
        # backend implementation.
        super().__init__(generator=generator,
                         discriminator=discriminator,
                         classifier=classifier,
                         hint_rate=hint_rate,
                         alpha=alpha,
                         beta=beta,
                         seed=seed,
                         **kwargs)