Source code for teras._src.losses.ctgan

import keras
from keras import ops
from teras._src.api_export import teras_export


@teras_export("teras.losses.ctgan_generator_loss")
def ctgan_generator_loss(x_generated, y_pred_generated, cond_vectors, mask,
                         metadata):
    """
    Loss for the Generator model in the CTGAN architecture.

    CTGAN is a state-of-the-art tabular data generation architecture
    proposed by Lei Xu et al. in the paper,
    "Modeling Tabular data using Conditional GAN".

    The loss is the sum of two terms:

    1. the adversarial term, `-mean(y_pred_generated)`;
    2. the conditional term: for every categorical feature, the sparse
       categorical cross-entropy between the generated logits of that
       feature and the category selected in the conditional vector,
       weighted by `mask`, summed, and divided by the batch size.

    Reference(s):
        https://arxiv.org/abs/1907.00503

    Args:
        x_generated: Samples produced by the generator. The columns of
            each categorical feature are used as logits.
        y_pred_generated: Discriminator's output for the generated samples
        cond_vectors: Conditional vectors that are used for and with
            generated samples. Assumed to be the concatenation of one
            one-hot block per categorical feature, in the same order as
            `metadata["categorical"]["num_categories_all"]`
            (TODO confirm against the data sampler).
        mask: Mask created during the conditional vectors generation step.
            It is multiplied element-wise with the per-sample,
            per-categorical-feature cross-entropy values.
        metadata: dict, metadata computed during the data transformation
            step. The keys read here are
            `metadata["numerical"]["relative_indices_all"]`,
            `metadata["relative_indices_all"]` and
            `metadata["categorical"]["num_categories_all"]`.

    Returns:
        Generator's loss. When the metadata lists no categorical features
        only the adversarial term is returned.
    """
    adversarial_loss = -ops.mean(y_pred_generated)

    # reduction=None keeps one loss value per sample so that `mask` can
    # select which feature contributes for which sample.
    cross_entropy_loss = keras.losses.SparseCategoricalCrossentropy(
        from_logits=True,
        reduction=None)
    numerical_features_relative_indices = (
        metadata["numerical"]["relative_indices_all"])
    features_relative_indices_all = metadata["relative_indices_all"]
    num_categories_all = metadata["categorical"]["num_categories_all"]

    # the first k features in the data are numerical which we'll ignore as
    # we're only concerned with the categorical features here
    offset = len(numerical_features_relative_indices)

    per_feature_losses = []
    # Start of the current feature's block inside `cond_vectors`. It must
    # advance by the width of each block; using the loop counter instead
    # would make the windows of successive features overlap.
    cond_start = 0
    for i, index in enumerate(features_relative_indices_all[offset:]):
        num_categories = num_categories_all[i]
        logits = x_generated[:, index: index + num_categories]
        cond_block = cond_vectors[:, cond_start: cond_start + num_categories]
        labels = ops.argmax(cond_block, axis=1)
        per_feature_losses.append(
            cross_entropy_loss(y_true=labels, y_pred=logits))
        cond_start += num_categories

    # Nothing to condition on: stacking an empty list would raise.
    if not per_feature_losses:
        return adversarial_loss

    conditional_loss = ops.stack(per_feature_losses, axis=1)
    batch_size = ops.cast(ops.shape(y_pred_generated)[0], dtype="float32")
    conditional_loss = ops.sum(
        conditional_loss * ops.cast(mask, dtype="float32")) / batch_size

    # The two terms are added, as in the CTGAN paper; a product would
    # vanish whenever the conditional term does and would flip sign with
    # the discriminator's score.
    return adversarial_loss + conditional_loss
@teras_export("teras.losses.ctgan_discriminator_loss")
def ctgan_discriminator_loss(y_pred_real, y_pred_generated):
    """
    Loss for the Discriminator model in the CTGAN architecture.

    CTGAN is a state-of-the-art tabular data generation architecture
    proposed by Lei Xu et al. in the paper,
    "Modeling Tabular data using Conditional GAN".

    The value is the negated gap between the mean discriminator output on
    real samples and the mean discriminator output on generated samples,
    so it decreases as real samples score higher than generated ones.

    Reference(s):
        https://arxiv.org/abs/1907.00503

    Args:
        y_pred_real: Discriminator's output for real samples
        y_pred_generated: Discriminator's output for generated samples

    Returns:
        Discriminator's loss.
    """
    real_score = ops.mean(y_pred_real)
    generated_score = ops.mean(y_pred_generated)
    score_gap = real_score - generated_score
    return -score_gap