Source code for teras._src.models.gans.gain.gain
import keras
from teras._src import backend
from teras._src.api_export import teras_export
[docs]
@teras_export("teras.models.GAIN")
class GAIN(backend.models.GAIN):
"""
GAIN is a missing data imputation model based on GANs. This is an
implementation of the GAIN architecture proposed by Jinsung Yoon et al.
in the paper,
"GAIN: Missing Data Imputation using Generative Adversarial Nets"
In GAIN, the generator observes some components of a real data vector,
imputes the missing components conditioned on what is actually observed, and
outputs a completed vector.
The discriminator then takes a completed vector and attempts to determine
which components were actually observed and which were imputed. It also
utilizes a novel hint mechanism, which ensures that generator does in
fact learn to generate samples according to the true data distribution.
Reference(s):
https://arxiv.org/abs/1806.02920
Args:
generator: keras.Model, An instance of `GAINGenerator` model or any
customized model that can work in its place.
discriminator: keras.Model, An instance of `GAINDiscriminator` model
or any customized model that can work in its place.
hint_rate: float, Hint rate will be used to sample binary vectors for
`hint vectors` generation. Must be between 0. and 1.
Hint vectors ensure that generated samples follow the underlying
data distribution.
Defaults to 0.9
alpha: float, Hyper parameter for the generator loss computation that
controls how much weight should be given to the MSE loss.
Precisely,
`generator_loss` = `cross_entropy_loss` + `alpha` * `mse_loss`
The higher the `alpha`, the more the mse_loss will affect the
overall generator loss.
Defaults to 100.
"""
[docs]
def __init__(self,
generator: keras.Model,
discriminator: keras.Model,
hint_rate: float = 0.9,
alpha: float = 100.,
**kwargs):
super().__init__(generator=generator,
discriminator=discriminator,
hint_rate=hint_rate,
alpha=alpha,
**kwargs)