# Source code for teras._src.models.gans.ctgan.ctgan
import keras
from teras._src import backend
from teras._src.api_export import teras_export
@teras_export("teras.models.CTGAN")
class CTGAN(backend.models.CTGAN):
    """
    CTGAN is a state-of-the-art tabular data generation architecture
    proposed by Lei Xu et al. in the paper,
    "Modeling Tabular Data using Conditional GAN".

    This public class adds no behavior of its own: every constructor
    argument is forwarded unchanged to ``backend.models.CTGAN`` (the
    implementation exposed by ``teras._src.backend``), and the class is
    exported under the public name ``teras.models.CTGAN``.

    Reference(s):
        https://arxiv.org/abs/1907.00503

    Args:
        generator: keras.Model, An instance of :py:class:`CTGANGenerator`.
        discriminator: keras.Model, An instance of
            :py:class:`CTGANDiscriminator`.
        metadata: dict, A dictionary containing features metadata computed
            during the data transformation step.
            It can be accessed through the `.metadata` property attribute of
            the :py:class:`CTGANDataTransformer` instance which was used to
            transform the raw input data.
            Note that this is NOT the same metadata as `features_metadata`,
            which is computed using the `get_metadata_for_embedding` utility
            function from :py:mod:`teras.utils`.
        latent_dim: int, Dimensionality of noise or `z` that serves as
            input to :py:class:`CTGANGenerator` to generate samples.
            Defaults to 128.
        seed: int, Seed for random sampling. Defaults to 1337.
        **kwargs: Any additional keyword arguments; they are passed through
            as-is to the ``backend.models.CTGAN`` base class constructor.
    """

    def __init__(self,
                 generator: keras.Model,
                 discriminator: keras.Model,
                 metadata: dict,
                 latent_dim: int = 128,
                 seed: int = 1337,
                 **kwargs):
        """Initialize the model by delegating to the backend base class.

        All arguments, including any extra ``**kwargs``, are forwarded
        unchanged by keyword to ``backend.models.CTGAN.__init__``.
        """
        super().__init__(generator=generator,
                         discriminator=discriminator,
                         metadata=metadata,
                         latent_dim=latent_dim,
                         seed=seed,
                         **kwargs)