# Source code for teras._src.models.autoencoders.tvae.decoder
import keras
from teras._src.layers.layer_list import LayerList
from teras._src.typing import IntegerSequence
from teras._src.api_export import teras_export
[docs]
@teras_export("teras.models.TVAEDecoder")
class TVAEDecoder(keras.Model):
    """
    Decoder for the TVAE model as proposed by Lei Xu et al. in the paper,
    "Modeling Tabular Data using Conditional GAN".

    The decoder passes its input through a stack of ReLU dense layers
    (the decompression block) followed by a linear projection to
    ``data_dim`` features. Alongside the projected output it returns a
    trainable vector of per-feature ``sigmas`` of shape ``(data_dim,)``,
    initialised to 0.1.

    Reference(s):
        https://arxiv.org/abs/1907.00503

    Args:
        data_dim: int, Dimensionality of the input dataset, i.e. the
            number of features produced by the decoder.
        decompression_dims: A sequence of integers. For each value in
            the sequence, a dense layer of that dimensionality is added to
            construct a decompression block.

    Raises:
        ValueError: If ``decompression_dims`` is not a list or tuple.
    """

    def __init__(self,
                 data_dim: int,
                 decompression_dims: IntegerSequence = (128, 128),
                 **kwargs):
        super().__init__(**kwargs)
        if not isinstance(decompression_dims, (list, tuple)):
            raise ValueError(
                f"`decompression_dims` must be a sequence of integers. "
                f"Received: {decompression_dims}")
        self.data_dim = data_dim
        self.decompression_dims = decompression_dims

        # NOTE: the attribute name `compression_block` and the LayerList name
        # below describe an encoder, but they are kept unchanged so that the
        # object paths of previously saved models stay the same.
        self.compression_block = [
            keras.layers.Dense(units=units,
                               activation="relu",
                               name=f"decompression_layer_{i}")
            for i, units in enumerate(self.decompression_dims, start=1)
        ]
        self.decompression_block = LayerList(
            self.compression_block,
            sequential=True,
            name="tvae_encoder_compression_block"
        )
        self.projection_layer = keras.layers.Dense(self.data_dim,
                                                   name="projection_layer")

        # Keep a reference to the trainable *variable* itself. The previous
        # `add_weight(initializer="ones") * 0.1` stored a tensor computed
        # once at construction time. That tensor is detached from the
        # variable, so the sigmas could never be trained. Initialising the
        # variable to 0.1 gives the same starting value while keeping it
        # trainable.
        # NOTE(review): checkpoints saved with the old code hold 1.0 for this
        # variable (it was never updated), so loading them now yields
        # sigmas == 1.0 rather than the old effective 0.1.
        self.sigmas = self.add_weight(
            shape=(self.data_dim,),
            initializer=keras.initializers.Constant(0.1),
            trainable=True,
            name="sigmas")

    def build(self, input_shape):
        # Build the sub-layers eagerly: the decompression block first, then
        # the projection layer on the block's output shape.
        self.decompression_block.build(input_shape)
        input_shape = self.decompression_block.compute_output_shape(
            input_shape)
        self.projection_layer.build(input_shape)

    def call(self, inputs):
        """
        Decode ``inputs``.

        Returns:
            A tuple ``(x_generated, sigmas)``. ``x_generated`` is the
            projection of the decompressed inputs to ``data_dim`` features,
            and ``sigmas`` is the current value of the trainable sigma
            vector of shape ``(data_dim,)``.
        """
        x_generated = self.projection_layer(self.decompression_block(inputs))
        # Read the variable inside `call` so the read is visible to each
        # backend's gradient tracking / stateless-call machinery.
        sigmas = keras.ops.convert_to_tensor(self.sigmas)
        return x_generated, sigmas

    def predict_step(self, z):
        # Dataset elements may arrive wrapped, e.g. as `(z,)`; unwrap them
        # the same way Keras' default `predict_step` does.
        z, _, _ = keras.utils.unpack_x_y_sample_weight(z)
        generated_samples, _ = self(z, training=False)
        return generated_samples

    def compute_output_shape(self, input_shape):
        # The dense layers keep every leading axis and replace only the last
        # one with `data_dim`; the sigmas always have shape `(data_dim,)`.
        return (tuple(input_shape[:-1]) + (self.data_dim,),
                (self.data_dim,))

    def get_config(self):
        config = super().get_config()
        config.update({
            'data_dim': self.data_dim,
            'decompression_dims': self.decompression_dims
        })
        return config