Source code for teras._src.models.backbones.transformer.encoder
import keras
from teras._src.layers.transformer.encoder_layer import TransformerEncoderLayer
from teras._src.models.backbones.backbone import Backbone
from teras._src.api_export import teras_export
[docs]
@teras_export("teras.models.TransformerEncoderBackbone")
class TransformerEncoderBackbone(Backbone):
    """
    Transformer Encoder model as proposed in the
    "Attention is all you need" paper.

    The backbone is assembled as a functional model: a symbolic input of
    per-sample shape `(input_dim, embedding_dim)` is passed through
    `num_layers` stacked `TransformerEncoderLayer`s and the output of the
    last layer is the output of the model.

    Reference(s):
        https://arxiv.org/abs/1706.03762

    Args:
        input_dim: int, dimensionality of the input data. It is used as
            the first axis of the per-sample input shape
            `(input_dim, embedding_dim)`.
        embedding_dim: int, dimensionality of the embeddings used
            by the model. It is also referred to as the `d_model` or
            model dimensionality.
        num_layers: int, number of `TransformerEncoderLayer`s to use in
            the encoder. Must be 1 or greater. Defaults to 6.
        num_heads: int, number of attention heads to use in the
            `MultiHeadAttention` layer. Defaults to 8.
        feedforward_dim: int, hidden dimensionality to use in the
            `TransformerFeedForward` layer. Defaults to `None`, which is
            forwarded unchanged to every `TransformerEncoderLayer`.
        attention_dropout: float, dropout value to use in the
            `MultiHeadAttention` layer. Defaults to 0.
        feedforward_dropout: float, dropout value to use in the
            `TransformerFeedForward` layer. Defaults to 0.
        layer_norm_epsilon: float, epsilon value to use in the
            `LayerNormalization` layer. Defaults to 1e-5.
        unnormalized_layers: list, list of indices corresponding to
            the layers in which `LayerNormalization` won't be used.
            For instance, if you don't want to use the normalization in
            the first `TransformerEncoderLayer` layer (like FT-Transformer)
            you can pass [0]. If you don't want to normalize first and
            second layer, you can similarly pass [0, 1] and so on.
            Indices are 0-based and must lie in `[0, num_layers)`.
            Defaults to `None`, which is treated as an empty list,
            because the original Transformer architecture and most
            others use normalization in all of their layers.
        pre_normalization: bool, whether to use Pre-Normalization technique
            whereby `LayerNormalization` is applied to inputs of the
            `MultiHeadAttention` or `FeedForward` and then outputs of
            those layers are elementwise added to the original inputs.
            Defaults to `False`, as the original Transformers architecture
            doesn't use pre-normalization.

    Raises:
        ValueError: if `num_layers` is less than 1, or if any index in
            `unnormalized_layers` falls outside `[0, num_layers)`.
    """

    def __init__(self,
                 input_dim: int,
                 embedding_dim: int,
                 num_layers: int = 6,
                 num_heads: int = 8,
                 feedforward_dim: int = None,
                 attention_dropout: float = 0.,
                 feedforward_dropout: float = 0.,
                 layer_norm_epsilon: float = 1e-5,
                 unnormalized_layers: list = None,
                 pre_normalization: bool = False,
                 **kwargs):
        if num_layers < 1:
            raise ValueError(
                f"`num_layers` must be 1 or greater. Received {num_layers}")

        # Work on a private copy: a literal `[]` default would be one list
        # object shared by every call that omits the argument, and storing
        # the caller's own list would alias it.
        if unnormalized_layers is None:
            unnormalized_layers = []
        else:
            unnormalized_layers = list(unnormalized_layers)

        if unnormalized_layers and (
                max(unnormalized_layers) > (num_layers - 1) or
                min(unnormalized_layers) < 0):
            raise ValueError(
                f"Layer indices must be in the interval [0, num_layers). "
                f"Received {unnormalized_layers}"
            )

        # Build the functional graph: symbolic input -> stacked encoder
        # layers -> output of the last layer.
        inputs = keras.layers.Input(shape=(input_dim, embedding_dim))
        x = inputs
        for i in range(num_layers):
            # Layers whose 0-based index is listed skip normalization.
            use_normalization = i not in unnormalized_layers
            x = TransformerEncoderLayer(
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                feedforward_dim=feedforward_dim,
                attention_dropout=attention_dropout,
                feedforward_dropout=feedforward_dropout,
                layer_norm_epsilon=layer_norm_epsilon,
                use_normalization=use_normalization,
                pre_normalization=pre_normalization,
                # Layer names are 1-based even though indices are 0-based.
                name=f"transformer_encoder_layer_{i+1}")(x)
        outputs = x
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

        # Hyperparameters are stored once the base class has been
        # initialized; `get_config` reads them back.
        self.input_dim = input_dim
        self.num_layers = num_layers
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.feedforward_dim = feedforward_dim
        self.attention_dropout = attention_dropout
        self.feedforward_dropout = feedforward_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.unnormalized_layers = unnormalized_layers
        self.pre_normalization = pre_normalization

    def get_config(self):
        """
        Return the parent class config extended with the constructor
        arguments of this backbone.

        Returns:
            dict, the config returned by `super().get_config()` updated
            with every hyperparameter accepted by `__init__`.
        """
        config = super().get_config()
        config.update({
            "input_dim": self.input_dim,
            "embedding_dim": self.embedding_dim,
            "num_layers": self.num_layers,
            "num_heads": self.num_heads,
            "feedforward_dim": self.feedforward_dim,
            "attention_dropout": self.attention_dropout,
            "feedforward_dropout": self.feedforward_dropout,
            "layer_norm_epsilon": self.layer_norm_epsilon,
            # Copy so that editing the returned config cannot alter the
            # attribute held by the model.
            "unnormalized_layers": list(self.unnormalized_layers),
            "pre_normalization": self.pre_normalization,
        })
        return config