Source code for teras._src.models.autoencoders.tvae.encoder

import keras

from teras._src.layers.layer_list import LayerList
from teras._src.typing import IntegerSequence
from teras._src.api_export import teras_export


[docs] @teras_export("teras.models.TVAEEncoder") class TVAEEncoder(keras.Model): """ Encoder for the TVAE model as proposed by Lei Xu et al. in the paper, "Modeling Tabular data using Conditional GAN". Reference(s): https://arxiv.org/abs/1907.00503 Args: latent_dim: int, Dimensionality of the learned latent space. Defaults to 128. compression_dims: Sequence, A sequence of integers. For each value in the sequence, a dense layer of that dimensions is added to construct a compression block. Defaults to (128, 128). """
[docs] def __init__(self, latent_dim: int = 128, compression_dims: IntegerSequence = (128, 128), **kwargs): super().__init__(**kwargs) if not isinstance(compression_dims, (list, tuple)): raise ValueError( f"`compression_dims` must be a sequence of integers." f"Received: {compression_dims}") self.latent_dim = latent_dim self.compression_dims = compression_dims self.compression_block = [] for i, units in enumerate(self.compression_dims, start=1): self.compression_block.append( keras.layers.Dense(units=units, activation="relu", name=f"compression_layer_{i}")) self.compression_block = LayerList( self.compression_block, sequential=True, name="tvae_encoder_compression_block" ) self.dense_mean = keras.layers.Dense(self.latent_dim, name="mean") self.dense_log_var = keras.layers.Dense(self.latent_dim, name="log_var")
def build(self, input_shape): self.compression_block.build(input_shape) input_shape = self.compression_block.compute_output_shape(input_shape) self.dense_mean.build(input_shape) self.dense_log_var.build(input_shape) def call(self, inputs): h = self.compression_block(inputs) mean = self.dense_mean(h) log_var = self.dense_log_var(h) return mean, log_var def compute_output_shape(self, input_shape): batch_size, dims = input_shape return ((batch_size, self.latent_dim), (batch_size, self.latent_dim)) def get_config(self): config = super().get_config() config.update({ 'latent_dim': self.latent_dim, 'compression_dims': self.compression_dims }) return config