Source code for teras._src.layers.saint.multi_head_inter_sample_attention

import keras
from keras import ops
from teras._src.api_export import teras_export


@teras_export("teras.layers.SAINTMultiHeadInterSampleAttention")
class SAINTMultiHeadInterSampleAttention(keras.layers.Layer):
    """
    Multi Head Inter Sample Attention layer based on the SAINT
    architecture proposed in the
    "SAINT: Improved Neural Networks for Tabular Data" paper.

    Instead of attending over the features of a single sample, the
    batch is reshaped so that every sample becomes one position of a
    single sequence, and multi-head attention is applied across those
    positions (i.e. across samples).

    Reference(s):
        https://arxiv.org/abs/2106.01342

    Args:
        num_heads: int, number of attention heads to use.
        key_dim: int, the paper proposes to use
            `embedding_dim / num_heads` dimensions for your key
            dimensionality.
        value_dim: int, same value as `key_dim` is used by the paper.
            Defaults to `None`, in which case the choice is left to
            `keras.layers.MultiHeadAttention`.
        dropout: float, dropout value to use. Defaults to 0.

    Shapes:
        Input Shape: `(batch_size, num_features, embedding_dim)`
        Output Shape: `(batch_size, num_features, embedding_dim)`
    """

    def __init__(self,
                 num_heads: int,
                 key_dim: int,
                 value_dim: int = None,
                 dropout: float = 0.0,
                 **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.dropout = dropout
        self.multi_head_attention = keras.layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.key_dim,
            value_dim=self.value_dim,
            dropout=self.dropout,
        )

    def build(self, input_shape):
        """Validate the input rank and build the inner attention layer.

        Args:
            input_shape: shape of the inputs, expected to be
                `(batch_size, num_features, embedding_dim)`.

        Raises:
            ValueError: if `input_shape` is not of rank 3.
        """
        if len(input_shape) != 3:
            raise ValueError(
                "Inputs must have the shape of rank 3 "
                "(`batch_size`, `num_features`, `embedding_dim`) but "
                f"received shape {input_shape} with rank "
                f"{len(input_shape)}."
            )
        # The attention layer sees the whole batch as one sequence whose
        # positions are the flattened samples, mirroring the reshape
        # performed in `call`.
        attention_shape = (
            1,
            input_shape[0],
            input_shape[1] * input_shape[2],
        )
        self.multi_head_attention.build(attention_shape, attention_shape)

    def call(self, inputs):
        """Apply self-attention across the samples of the batch.

        Args:
            inputs: tensor of shape
                `(batch_size, num_features, embedding_dim)`.

        Returns:
            Tensor reshaped back to
            `(batch_size, num_features, embedding_dim)`.
        """
        batch_size, num_features, embedding_dim = ops.shape(inputs)
        # Flatten each sample into a single vector and move the batch
        # axis into the sequence position so attention runs between
        # samples rather than between features.
        x = ops.reshape(
            inputs,
            (1, batch_size, num_features * embedding_dim),
        )
        x = self.multi_head_attention(x, x)
        # Restore the original per-sample layout.
        x = ops.reshape(x, (batch_size, num_features, embedding_dim))
        return x

    def compute_output_shape(self, input_shape):
        """Return `input_shape`; the layer preserves the input shape."""
        return input_shape

    def get_config(self):
        """Return the layer configuration as a dict.

        The dict extends the base layer config with the constructor
        arguments `num_heads`, `key_dim`, `value_dim` and `dropout`.
        """
        config = super().get_config()
        config.update({
            "num_heads": self.num_heads,
            "key_dim": self.key_dim,
            "value_dim": self.value_dim,
            "dropout": self.dropout,
        })
        # Without this return the method yields None, which breaks
        # `from_config`, cloning and model saving.
        return config