Source code for teras._src.layers.tabnet.attentive_transformer

import keras
from teras._src.activations import sparsemax
from teras._src.api_export import teras_export


[docs] @teras_export("teras.layers.TabNetAttentiveTransformer") class TabNetAttentiveTransformer(keras.layers.Layer): """ TabNetAttentiveTransformer layer proposed by Arik et al. in the "TabNet: Attentive Interpretable Tabular Learning" paper. Reference(s): https://arxiv.org/abs/1908.07442 Args: data_dim: int, dimensionality of the dataset batch_momentum: float, batch momentum """
[docs] def __init__(self, data_dim: int, batch_momentum: float, **kwargs): super().__init__(**kwargs) self.data_dim = data_dim self.batch_momentum = batch_momentum self.dense = keras.layers.Dense(data_dim, use_bias=False) self.batch_norm = keras.layers.BatchNormalization( momentum=self.batch_momentum)
def build(self, input_shape): self.dense.build(input_shape) input_shape = self.dense.compute_output_shape(input_shape) self.batch_norm.build(input_shape) self.built = True def call(self, inputs, prior_scales): x = self.dense(inputs) x = self.batch_norm(x) x = x * prior_scales x = sparsemax(x) return x def compute_output_shape(self, input_shape): return input_shape[:-1] + (self.data_dim,) def get_config(self): config = super().get_config() config.update({ "data_dim": self.data_dim, "batch_momentum": self.batch_momentum })