Source code for teras._src.layers.tabnet.feature_transformer_layer
import keras
from teras._src.activations import glu
from teras._src.api_export import teras_export
[docs]
@teras_export("teras.layers.TabNetFeatureTransformerLayer")
class TabNetFeatureTransformerLayer(keras.layers.Layer):
    """
    TabNetFeatureTransformerLayer layer that serves as the building block
    for the `TabNetFeatureTransformer` layer which is proposed by Arik
    et al. in the "TabNet: Attentive Interpretable Tabular Learning" paper.

    The layer applies, in order: a bias-free `Dense` projection to
    `dim * 2` units, `BatchNormalization`, and the `glu` activation,
    which brings the last dimension back down to `dim`.

    Reference(s):
        https://arxiv.org/abs/1908.07442

    Args:
        dim: int, the dense layer first maps the inputs to `dim * 2`
            dimension hidden representations and later the glu activation
            maps the hidden representations to `dim`-dimensions.
        batch_momentum: float, the `momentum` passed to the internal
            `keras.layers.BatchNormalization` layer (used for its moving
            mean/variance statistics).
        **kwargs: Additional keyword arguments forwarded to
            `keras.layers.Layer`.
    """

    def __init__(self,
                 dim: int,
                 batch_momentum: float,
                 **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.batch_momentum = batch_momentum
        # No bias: the BatchNormalization that follows re-centers the
        # activations (and, with Keras' default `center=True`, learns its
        # own offset), so a dense bias would be redundant.
        self.dense = keras.layers.Dense(self.dim * 2,
                                        use_bias=False)
        self.batch_norm = keras.layers.BatchNormalization(
            momentum=self.batch_momentum)

    def build(self, input_shape):
        """Build the dense and batch-norm sublayers for `input_shape`."""
        self.dense.build(input_shape)
        # BatchNormalization receives the dense output (last dim == dim * 2),
        # so it must be built with that shape, not the raw input shape.
        dense_output_shape = self.dense.compute_output_shape(input_shape)
        self.batch_norm.build(dense_output_shape)

    def call(self, inputs):
        """Apply dense -> batch normalization -> glu to `inputs`."""
        x = self.dense(inputs)
        x = self.batch_norm(x)
        return glu(x)

    def compute_output_shape(self, input_shape):
        """Return `input_shape` with its last dimension replaced by `dim`."""
        # Convert to a tuple first so list-typed shapes (e.g. a build config
        # restored from JSON) work too; `list + tuple` raises a TypeError.
        return tuple(input_shape[:-1]) + (self.dim,)

    def get_config(self):
        """Return the layer config, including `dim` and `batch_momentum`."""
        config = super().get_config()
        config.update({
            "dim": self.dim,
            "batch_momentum": self.batch_momentum,
        })
        return config