Source code for teras._src.models.gans.ctgan.generator

from teras._src import backend
from teras._src.typing import IntegerSequence
from teras._src.api_export import teras_export


[docs] @teras_export("teras.models.CTGANGenerator") class CTGANGenerator(backend.models.CTGANGenerator): """ CTGANGenerator for CTGAN architecture as proposed by Lei Xu et al. in the paper, "Modeling Tabular data using Conditional GAN". Reference(s): https://arxiv.org/abs/1907.00503 Args: data_dim: int, The dimensionality of the dataset. It will also be the dimensionality of the output produced by the generator. Note the dimensionality must be equal to the dimensionality of dataset that is passed to the fit method and not necessarily the dimensionality of the raw input dataset as sometimes data transformation alters the dimensionality of the dataset. metadata: dict, `CTGANGenerator` applies different activation functions to its outputs depending on the type of features (categorical or continuous). And to determine the feature types and for other computations during the activation step, the ``metadata`` computed during the data transformation step, is required. It can be accessed through the `.metadata` property attribute of the `CTGANDataTransformer` instance which was used to transform the raw input data. Note that, this is NOT the same metadata as `features_metadata`, which is computed using the `get_metadata_for_embedding` utility function from `teras.utils`. You must access it through the `.metadata` property attribute of the `CTGANDataTransformer`. hidden_dims: Sequence, A sequence of integers that is used to construct the hidden block. For each value, a `CTGANGeneratorLayer` of that dimensionality is added. Defaults to [256, 256] """
[docs] def __init__(self, data_dim: int, metadata: dict, hidden_dims: IntegerSequence = (256, 256), seed: int = 1337, **kwargs): super().__init__(data_dim=data_dim, metadata=metadata, hidden_dims=hidden_dims, seed=seed, **kwargs)