# Source code for teras._src.layers.transformer.encoder_layer

import keras
from teras._src.layers.transformer.feedforward import TransformerFeedForward
from teras._src.api_export import teras_export


@teras_export("teras.layers.TransformerEncoderLayer")
class TransformerEncoderLayer(keras.layers.Layer):
    """
    Single encoder block of the Transformer architecture introduced in
    the "Attention is all you need" paper.

    The block chains a self-attention sub-layer (`MultiHeadAttention`)
    and a position-wise `TransformerFeedForward` sub-layer. Each
    sub-layer is wrapped in a residual connection, and
    `LayerNormalization` is optionally applied either before each
    sub-layer (pre-normalization) or after each residual addition
    (post-normalization).

    Reference(s):
        https://arxiv.org/abs/1706.03762

    Args:
        embedding_dim: int, dimensionality of the embeddings used by the
            model, also known as `d_model` or the model dimensionality.
            It is forwarded to the `MultiHeadAttention` layer as
            `key_dim` and to the `TransformerFeedForward` layer as
            `embedding_dim`.
        num_heads: int, number of attention heads used by the
            `MultiHeadAttention` layer. Defaults to 8.
        feedforward_dim: int, hidden dimensionality forwarded to the
            `TransformerFeedForward` layer as `hidden_dim`.
            Defaults to `None`, which is passed through unchanged.
        attention_dropout: float, dropout rate for the
            `MultiHeadAttention` layer. Defaults to 0.
        feedforward_dropout: float, dropout rate for the
            `TransformerFeedForward` layer. Defaults to 0.
        layer_norm_epsilon: float, epsilon used by both
            `LayerNormalization` layers. Defaults to 1e-5.
        use_normalization: bool, whether `LayerNormalization` is applied
            at all. Some architectures skip normalization in the very
            first layer; this flag exists to accommodate them.
            Defaults to `True`.
        pre_normalization: bool, whether to normalize the inputs of the
            attention and feed-forward sub-layers (and add the
            sub-layer outputs to the un-normalized inputs) instead of
            normalizing after each residual addition. It has an effect
            only when `use_normalization` is also `True`.
            Defaults to `False`, as in the original Transformer.

    Shapes:
        Input Shape: `(batch_size, num_features, embedding_dim)`
        Output Shape: `(batch_size, num_features, embedding_dim)`
    """

    def __init__(self,
                 embedding_dim: int,
                 num_heads: int = 8,
                 feedforward_dim: int = None,
                 attention_dropout: float = 0.,
                 feedforward_dropout: float = 0.,
                 layer_norm_epsilon: float = 1e-5,
                 use_normalization: bool = True,
                 pre_normalization: bool = False,
                 **kwargs):
        """Store the hyperparameters and create the sub-layers."""
        super().__init__(**kwargs)

        # Hyperparameters are kept on the instance so that `get_config`
        # can report them.
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.feedforward_dim = feedforward_dim
        self.attention_dropout = attention_dropout
        self.feedforward_dropout = feedforward_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.use_normalization = use_normalization
        self.pre_normalization = pre_normalization

        # Sub-layers are created in a fixed order: attention,
        # feed-forward, residual adds, then (optionally) the norms.
        self.attention = keras.layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embedding_dim,
            dropout=attention_dropout,
        )
        self.feedforward = TransformerFeedForward(
            embedding_dim=embedding_dim,
            hidden_dim=feedforward_dim,
            dropout=feedforward_dropout,
        )
        self.add_1 = keras.layers.Add()
        self.add_2 = keras.layers.Add()

        # The normalization layers only exist when normalization is on.
        if use_normalization:
            self.layer_norm_1 = keras.layers.LayerNormalization(
                epsilon=layer_norm_epsilon)
            self.layer_norm_2 = keras.layers.LayerNormalization(
                epsilon=layer_norm_epsilon)

    def build(self, input_shape):
        """
        Build the feed-forward sub-layer and, when normalization is
        enabled, both `LayerNormalization` sub-layers for `input_shape`.
        """
        self.feedforward.build(input_shape)
        if not self.use_normalization:
            return
        for norm_layer in (self.layer_norm_1, self.layer_norm_2):
            norm_layer.build(input_shape)

    def call(self, inputs):
        """
        Run self-attention followed by the feed-forward sub-layer, each
        with a residual connection, and return the result.
        """
        if self.use_normalization and self.pre_normalization:
            # Pre-norm: normalize what goes *into* each sub-layer and
            # add the sub-layer output to the un-normalized tensor.
            normalized = self.layer_norm_1(inputs)
            attended = self.add_1(
                [self.attention(normalized, normalized), inputs])
            transformed = self.feedforward(self.layer_norm_2(attended))
            return self.add_2([transformed, attended])

        # Post-norm (or no normalization at all): normalize, if enabled,
        # *after* each residual addition.
        attended = self.add_1([self.attention(inputs, inputs), inputs])
        if self.use_normalization:
            attended = self.layer_norm_1(attended)

        outputs = self.add_2([self.feedforward(attended), attended])
        if self.use_normalization:
            outputs = self.layer_norm_2(outputs)
        return outputs

    def get_config(self):
        """Return the base layer config extended with this layer's
        constructor arguments."""
        config = super().get_config()
        constructor_args = (
            "embedding_dim",
            "num_heads",
            "feedforward_dim",
            "attention_dropout",
            "feedforward_dropout",
            "layer_norm_epsilon",
            "use_normalization",
            "pre_normalization",
        )
        config.update({name: getattr(self, name)
                       for name in constructor_args})
        return config