Source code for teras._src.layers.ctgan.generator_layer
import keras
from keras import ops
from teras._src.api_export import teras_export
[docs]
@teras_export("teras.layers.CTGANGeneratorLayer")
class CTGANGeneratorLayer(keras.layers.Layer):
"""
Residual Block for Generator as used by the authors of CTGAN
proposed in the paper Modeling Tabular data using Conditional GAN.
`outputs = Concat([ReLU(BatchNorm(Dense(inputs))), inputs])`
Reference(s):
https://arxiv.org/abs/1907.00503
Args:
dim: int, Dimensionality of the hidden layer.
Defaults to 256.
"""
[docs]
def __init__(self,
dim: int = 256,
**kwargs):
super().__init__(**kwargs)
self.dim = dim
self.dense = keras.layers.Dense(self.dim)
self.batch_norm = keras.layers.BatchNormalization()
self.relu = keras.layers.ReLU()
def build(self, input_shape):
self.dense.build(input_shape)
input_shape = self.dense.compute_output_shape(input_shape)
self.batch_norm.build(input_shape)
self.relu.build(input_shape)
def call(self, inputs):
x = self.dense(inputs)
x = self.batch_norm(x)
x = self.relu(x)
out = ops.concatenate([x, inputs], axis=1)
return out
def compute_output_shape(self, input_shape):
return input_shape[:-1] + (input_shape[-1] + self.dim,)
def get_config(self):
config = super().get_config()
config.update({"dim": self.dim})
return config