Source code for teras._src.models.backbones.saint.saint

import keras

from teras._src.layers.cls_token import CLSToken
from teras._src.layers.cls_token_extraction import CLSTokenExtraction
from teras._src.layers.saint.embedding import SAINTEmbedding
from teras._src.layers.saint.encoder_layer import SAINTEncoderLayer
from teras._src.models.backbones.backbone import Backbone
from teras._src.api_export import teras_export


@teras_export("teras.model.SAINTBackbone")
class SAINTBackbone(Backbone):
    """
    SAINT Backbone based on the SAINT architecture proposed in the paper,
    "SAINT: Improved Neural Networks for Tabular Data".

    The backbone is assembled as a functional graph: an input, an optional
    `SAINTEmbedding`, a `CLSToken` layer, `num_layers` stacked
    `SAINTEncoderLayer`s and, optionally, a `CLSTokenExtraction` layer.

    Reference(s):
        https://arxiv.org/abs/2106.01342

    Args:
        input_dim: int, dimensionality of the input dataset.
            i.e. the number of features in the dataset.
        cardinalities: list, a list of cardinalities of all the features
            in the dataset in the same order as the features' occurrence.
            For numerical features, use any value <=0 as indicator at
            the corresponding index.
            You can use the `compute_cardinalities` function from
            `teras.utils` package for this purpose.
        embedding_dim: int, dimensionality of the embeddings used by the
            model. It is also referred to as the `d_model` or model
            dimensionality.
        num_layers: int, number of `SAINTEncoderLayer`s to stack in the
            encoder. Defaults to 6.
        num_heads: int, number of attention heads, forwarded to every
            `SAINTEncoderLayer`. Defaults to 8.
        feedforward_dim: int, hidden dimensionality of the feedforward
            block, forwarded to every `SAINTEncoderLayer`.
            Defaults to `None`.
        attention_dropout: float, dropout value for the attention block,
            forwarded to every `SAINTEncoderLayer`. Defaults to 0.
        feedforward_dropout: float, dropout value for the feedforward
            block, forwarded to every `SAINTEncoderLayer`. Defaults to 0.
        layer_norm_epsilon: float, epsilon value for layer normalization,
            forwarded to every `SAINTEncoderLayer`. Defaults to 1e-5.
        embedd_inputs: bool, whether to use `SAINTEmbedding` layer to
            create embeddings of inputs. Defaults to `True` as this is
            what we want. But when pretraining we want to use the
            `SAINTBackbone` as encoder only, which expects embeddings as
            inputs, so that's when we set this parameter to `False`.
        return_cls_token_only: bool, whether to return only embeddings
            for the `CLS` token. Defaults to `True`.
    """

    def __init__(self,
                 input_dim: int,
                 cardinalities: list,
                 embedding_dim: int,
                 num_layers: int = 6,
                 num_heads: int = 8,
                 feedforward_dim: int = None,
                 attention_dropout: float = 0.,
                 feedforward_dropout: float = 0.,
                 layer_norm_epsilon: float = 1e-5,
                 embedd_inputs: bool = True,
                 return_cls_token_only: bool = True,
                 **kwargs):
        """Build the functional graph and record the hyperparameters.

        See the class docstring for the meaning of each argument. Any
        extra `kwargs` are forwarded to the `Backbone` constructor.
        """
        if embedd_inputs:
            # Raw feature values come in; the backbone embeds them itself.
            inputs = keras.layers.Input((input_dim,), name="inputs")
            x = SAINTEmbedding(
                embedding_dim=embedding_dim,
                cardinalities=cardinalities)(inputs)
        else:
            # Embeddings will be computed by an external embedding layer,
            # so in that case, the model should expect inputs with dims
            # (input_dim, embedding_dim)
            inputs = keras.layers.Input((input_dim, embedding_dim),
                                        name="inputs")
            x = inputs

        x = CLSToken(embedding_dim=embedding_dim)(x)

        for i in range(num_layers):
            x = SAINTEncoderLayer(
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                feedforward_dim=feedforward_dim,
                attention_dropout=attention_dropout,
                feedforward_dropout=feedforward_dropout,
                layer_norm_epsilon=layer_norm_epsilon,
                name=f"saint_encoder_layer_{i}")(x)

        outputs = CLSTokenExtraction()(x) if return_cls_token_only else x

        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

        # Hyperparameters are recorded once the graph has been handed to
        # the base constructor; `get_config` reads them back from here.
        self.input_dim = input_dim
        self.cardinalities = cardinalities
        self.num_layers = num_layers
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.feedforward_dim = feedforward_dim
        self.attention_dropout = attention_dropout
        self.feedforward_dropout = feedforward_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.embedd_inputs = embedd_inputs
        self.return_cls_token_only = return_cls_token_only

    def compute_output_shape(self, input_shape):
        """Return the shape of the backbone's output.

        Args:
            input_shape: tuple or list whose first entry is the batch
                size. The remaining entries are `(input_dim,)` when
                `embedd_inputs` is `True` and `(input_dim, embedding_dim)`
                otherwise.

        Returns:
            tuple, `(batch, 1, embedding_dim)` when
            `return_cls_token_only` is `True`, otherwise
            `(batch, input_dim + 1, embedding_dim)`.
        """
        # `tuple(...)` keeps the concatenation valid for list shapes too.
        batch_dims = tuple(input_shape[:1])
        if self.return_cls_token_only:
            return batch_dims + (1, self.embedding_dim)
        # The full sequence is always rank 3, no matter whether the
        # inputs arrived as raw features or as precomputed embeddings, so
        # the shape must not be derived by appending to `input_shape`.
        # NOTE(review): the `+ 1` assumes `CLSToken` appends exactly one
        # token along axis 1 and that the encoder layers preserve shape --
        # confirm against those layers.
        return batch_dims + (self.input_dim + 1, self.embedding_dim)

    def get_config(self):
        """Return the base config extended with the constructor arguments."""
        config = super().get_config()
        config.update({
            "input_dim": self.input_dim,
            "cardinalities": self.cardinalities,
            "embedding_dim": self.embedding_dim,
            "num_layers": self.num_layers,
            "num_heads": self.num_heads,
            "feedforward_dim": self.feedforward_dim,
            "attention_dropout": self.attention_dropout,
            "feedforward_dropout": self.feedforward_dropout,
            "layer_norm_epsilon": self.layer_norm_epsilon,
            "embedd_inputs": self.embedd_inputs,
            "return_cls_token_only": self.return_cls_token_only,
        })
        return config