Source code for teras._src.models.gans.gain.discriminator

import keras

from teras._src.layers.layer_list import LayerList
from teras._src.typing import IntegerSequence, ActivationType
from teras._src.api_export import teras_export


@teras_export("teras.models.GAINDiscriminator")
class GAINDiscriminator(keras.Model):
    """
    Discriminator model for the GAIN architecture proposed by
    Jinsung Yoon et al. in the paper
    GAIN: Missing Data Imputation using Generative Adversarial Nets.

    Note that the Generator and Discriminator share the exact same
    architecture by default. They differ in the inputs they receive
    and their loss functions.

    Reference(s):
        https://arxiv.org/abs/1806.02920

    Args:
        data_dim: int, The dimensionality of the input dataset.
            Note the dimensionality must be equal to the dimensionality
            of the transformed dataset that is passed to the fit method
            and not that of the original dataset, as the dimensionality
            of the raw input dataset may change during transformation.
            One way to access the dimensionality of the transformed
            dataset is through the `.data_dim` attribute of the
            `GAINDataSampler` instance used in sampling the dataset.
            It is also the number of units of the output layer.
        hidden_dims: list or tuple, A sequence of hidden dimensionalities
            for constructing the hidden block. For each value, a `Dense`
            layer of that dimensionality is added to the hidden block.
            By default, `hidden_dims` = [`data_dim`, `data_dim`].
        activation_hidden: Activation function to use for the hidden
            layers in the hidden block. Defaults to "relu".
        activation_out: Activation function to use for the output
            layer of the Discriminator. Defaults to "sigmoid".

    Raises:
        ValueError: If `hidden_dims` is given but is not a list or tuple.
    """

    def __init__(self,
                 data_dim: int,
                 hidden_dims: IntegerSequence = None,
                 activation_hidden: ActivationType = "relu",
                 activation_out: ActivationType = "sigmoid",
                 **kwargs):
        super().__init__(**kwargs)
        if hidden_dims is not None and not isinstance(hidden_dims,
                                                      (list, tuple)):
            raise ValueError(
                "`hidden_dims` must be a list or tuple of units which "
                "determines the number of Discriminator blocks and the "
                f"dimensionality of those blocks. But {hidden_dims} was "
                "passed.")
        self.data_dim = data_dim
        self.hidden_dims = hidden_dims
        self.activation_hidden = activation_hidden
        self.activation_out = activation_out

        if self.hidden_dims is None:
            # Default architecture: two hidden layers as wide as the data.
            self.hidden_dims = [self.data_dim] * 2

        # Build the Dense layers first and hand them to LayerList directly,
        # rather than assigning a temporary plain list to a model attribute.
        self.hidden_block = LayerList(
            [
                keras.layers.Dense(
                    units=dim,
                    activation=self.activation_hidden,
                    kernel_initializer="glorot_normal",
                )
                for dim in self.hidden_dims
            ],
            name="discriminator_hidden_block",
        )
        self.output_layer = keras.layers.Dense(
            self.data_dim,
            activation=self.activation_out,
            name="discriminator_output_layer",
        )

    def build(self, input_shape):
        """Build the hidden block, then the output layer on its output shape."""
        self.hidden_block.build(input_shape)
        input_shape = self.hidden_block.compute_output_shape(input_shape)
        self.output_layer.build(input_shape)

    def call(self, inputs, **kwargs):
        """Pass `inputs` through the hidden block and the output layer."""
        # inputs is the concatenation of `hint` and manipulated
        # Generator output (i.e. generated samples).
        # `hint` has the same dimensions as data
        # so the inputs received are 2x the dimensions of original data
        x = self.hidden_block(inputs)
        x = self.output_layer(x)
        return x

    def compute_output_shape(self, input_shape):
        """Return `input_shape` with its last axis replaced by `data_dim`."""
        # `input_shape` may arrive as a list; coerce the prefix to a tuple
        # so the concatenation below does not raise a TypeError.
        return tuple(input_shape[:-1]) + (self.data_dim,)

    def get_config(self):
        """Return the constructor arguments needed to recreate this model."""
        config = super().get_config()
        config.update({
            'data_dim': self.data_dim,
            'hidden_dims': self.hidden_dims,
            'activation_hidden': self.activation_hidden,
            'activation_out': self.activation_out,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Recreate a `GAINDiscriminator` from a `get_config()` dict."""
        # Work on a copy so the caller's config dict is not mutated.
        config = dict(config)
        data_dim = config.pop("data_dim")
        return cls(data_dim=data_dim, **config)