Source code for teras._src.losses.tabnet

from keras import ops
from keras.backend import floatx
from teras._src.api_export import teras_export


[docs] @teras_export("teras.losses.tabnet_reconstruction_loss") def tabnet_reconstruction_loss(real=None, reconstructed=None, mask=None): """ Reconstruction loss for TabNet Pretrainer mode as proposed by Sercan et al. in the paper, "TabNet: Attentive Interpretable Tabular Learning" Reference(s): https://arxiv.org/abs/1908.07442 Args: real: Samples drawn from the input dataset reconstructed: Samples reconstructed by the decoder mask: Mask that indicates the missing-ness of features in a sample Returns: Reconstruction loss for TabNet Pretraining. """ nominator_part = (reconstructed - real) * mask real_samples_population_std = ops.std(ops.cast(real, dtype=floatx())) # divide x = nominator_part / real_samples_population_std # Calculate L2 norm loss = ops.sqrt(ops.sum(ops.square(x))) return loss