Source code for teras._src.layers.saint.encoder_layer
import keras
from teras._src.layers.transformer.encoder_layer import TransformerEncoderLayer
from teras._src.layers.transformer.feedforward import TransformerFeedForward
from teras._src.layers.saint.multi_head_inter_sample_attention import SAINTMultiHeadInterSampleAttention
from teras._src.api_export import teras_export
[docs]
@teras_export("teras.layers.SAINTEncoderLayer")
class SAINTEncoderLayer(keras.layers.Layer):
    """
    SAINTEncoderLayer layer as proposed in the paper,
    "SAINT: Improved Neural Networks for Tabular Data".

    The layer chains two blocks:

    1. a self attention block, implemented by `TransformerEncoderLayer`;
    2. an inter sample attention block, made of
       `SAINTMultiHeadInterSampleAttention` followed by
       `TransformerFeedForward`, where each of the two is wrapped in a
       residual addition and a `LayerNormalization`.

    Reference(s):
        https://arxiv.org/abs/2106.01342

    Args:
        embedding_dim: int, dimensionality of the embeddings
        num_heads: int, number of attention heads to use in the
            `MultiHeadAttention` layer. Defaults to 8.
        feedforward_dim: int, hidden dimensionality to use in the
            `TransformerFeedForward` layer. Defaults to `None`, which is
            forwarded unchanged to the sub-layers (presumably they pick
            their own default in that case -- confirm against
            `TransformerFeedForward`).
        attention_dropout: float, dropout value to use in the
            `MultiHeadAttention` layer. Defaults to 0.
        feedforward_dropout: float, dropout value to use in the
            `TransformerFeedForward` layer. Defaults to 0.
        layer_norm_epsilon: float, epsilon value to use in the
            `LayerNormalization` layer. Defaults to 1e-5.
    """

    def __init__(self,
                 embedding_dim: int,
                 num_heads: int = 8,
                 feedforward_dim: int = None,
                 attention_dropout: float = 0.,
                 feedforward_dropout: float = 0.,
                 layer_norm_epsilon: float = 1e-5,
                 **kwargs):
        super().__init__(**kwargs)
        # Constructor arguments are kept as attributes so that
        # `get_config` can report them for serialization.
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.feedforward_dim = feedforward_dim
        self.attention_dropout = attention_dropout
        self.feedforward_dropout = feedforward_dropout
        self.layer_norm_epsilon = layer_norm_epsilon

        # ====== Self Attention Block ========
        self.self_attention_block = TransformerEncoderLayer(
            embedding_dim=self.embedding_dim,
            num_heads=self.num_heads,
            feedforward_dim=self.feedforward_dim,
            attention_dropout=self.attention_dropout,
            feedforward_dropout=self.feedforward_dropout,
            layer_norm_epsilon=self.layer_norm_epsilon,
            name="Self_Attention_Block",
        )

        # ====== Inter Sample Attention Block ========
        self.inter_sample_attention = SAINTMultiHeadInterSampleAttention(
            num_heads=self.num_heads,
            key_dim=self.embedding_dim // self.num_heads,  # ref: paper
            value_dim=None,
            dropout=self.attention_dropout,
        )
        self.isab_feed_forward = TransformerFeedForward(
            embedding_dim=self.embedding_dim,
            hidden_dim=self.feedforward_dim,
            activation="gelu",
            dropout=self.feedforward_dropout,
        )
        self.isab_norm_1 = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
        )
        self.isab_norm_2 = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
        )

    def build(self, input_shape):
        """Build every sub-layer using this layer's own `input_shape`.

        The same shape is handed to all sub-layers, which relies on each
        of them preserving the shape of its input (consistent with
        `compute_output_shape` returning `input_shape`).
        """
        self.self_attention_block.build(input_shape)
        self.inter_sample_attention.build(input_shape)
        self.isab_feed_forward.build(input_shape)
        self.isab_norm_1.build(input_shape)
        self.isab_norm_2.build(input_shape)

    def call(self, inputs):
        """Apply the self attention block, then the inter sample
        attention block, and return the result."""
        # ====== Self Attention Block ========
        x = self.self_attention_block(inputs)

        # ====== Inter Sample Attention Block ========
        # attention -> residual add -> layer norm
        residue = x
        x = self.inter_sample_attention(x)
        x = keras.layers.add([x, residue])
        x = self.isab_norm_1(x)

        # feed forward -> residual add -> layer norm
        residue = x
        x = self.isab_feed_forward(x)
        x = keras.layers.add([x, residue])
        x = self.isab_norm_2(x)
        return x

    def compute_output_shape(self, input_shape):
        """Return `input_shape` unchanged as the output shape."""
        return input_shape

    def get_config(self):
        """Return the layer configuration, i.e. the base layer config
        extended with this layer's constructor arguments."""
        config = super().get_config()
        config.update({
            "embedding_dim": self.embedding_dim,
            "num_heads": self.num_heads,
            "feedforward_dim": self.feedforward_dim,
            "attention_dropout": self.attention_dropout,
            "feedforward_dropout": self.feedforward_dropout,
            "layer_norm_epsilon": self.layer_norm_epsilon,
        })
        # Without this return the method yields None, which breaks
        # saving / cloning via `from_config(layer.get_config())`.
        return config