Source code for teras._src.preprocessing.data_transformers.tvae

from teras._src.preprocessing.data_transformers.ctgan import CTGANDataTransformer as _BaseDataTransformer
from teras._src.typing import FeaturesNamesType


class TVAEDataTransformer(_BaseDataTransformer):
    """
    Data transformer for TVAE.

    Behaves exactly like `CTGANDataTransformer`: this subclass adds no
    logic of its own and only forwards its constructor arguments,
    unchanged, to the base class. It exists as a convenience wrapper.

    Reference(s):
        https://arxiv.org/abs/1907.00503
        https://github.com/sdv-dev/CTGAN/

    Note:
        `covariance_type`, `weight_concentration_prior_type` and
        `weight_concentration_prior` are passed through untouched.
        Their names and defaults match parameters of sklearn's
        `BayesianGaussianMixture` (sklearn has no class named
        `GaussianMixtureModel`); presumably that is the estimator the
        base class builds -- confirm against `CTGANDataTransformer`.

    Args:
        continuous_features: list,
            List of continuous features names in the dataset.
            Defaults to None.
        categorical_features: list,
            List of categorical features names in the dataset.
            Defaults to None.
        max_clusters: int,
            Maximum number of clusters to use in
            `ModeSpecificNormalization`. Defaults to 10.
        std_multiplier: int,
            Multiplies the standard deviation in the normalization.
            Defaults to 4.
        weight_threshold: float,
            The minimum value a component weight can take to be
            considered a valid component. `weights_` under this value
            will be ignored. (Taken from the official implementation.)
            Defaults to 0.005.
        covariance_type: str,
            Covariance parameter of the mixture model (see Note).
            Defaults to "full".
        weight_concentration_prior_type: str,
            Weight concentration prior type of the mixture model
            (see Note). Defaults to "dirichlet_process".
        weight_concentration_prior: float,
            Weight concentration prior of the mixture model (see Note).
            Defaults to 0.001.
    """

    def __init__(self,
                 continuous_features: FeaturesNamesType = None,
                 categorical_features: FeaturesNamesType = None,
                 max_clusters: int = 10,
                 std_multiplier: int = 4,
                 weight_threshold: float = 0.005,
                 covariance_type: str = "full",
                 weight_concentration_prior_type: str = "dirichlet_process",
                 weight_concentration_prior: float = 0.001
                 ):
        """Forward every argument, by keyword, to the base transformer."""
        # Keyword forwarding keeps this wrapper correct even if the base
        # class orders its parameters differently from this signature.
        super().__init__(
            continuous_features=continuous_features,
            categorical_features=categorical_features,
            max_clusters=max_clusters,
            std_multiplier=std_multiplier,
            weight_threshold=weight_threshold,
            covariance_type=covariance_type,
            weight_concentration_prior_type=weight_concentration_prior_type,
            weight_concentration_prior=weight_concentration_prior,
        )