Source code for teras._src.layers.saint.embedding

import keras
from keras import ops
from teras._src.api_export import teras_export
from teras._src.layers.layer_list import LayerList


@teras_export("teras.layers.SAINTEmbedding")
class SAINTEmbedding(keras.layers.Layer):
    """
    SAINTEmbedding layer as proposed in the paper,
    "SAINT: Improved Neural Networks for Tabular Data".

    Every feature (column) of the input gets its own embedding layer:
    a ``Dense`` layer with ReLU activation for numerical features and an
    ``Embedding`` layer for categorical features. The per-feature
    embeddings are joined along axis 1.

    Reference(s):
        https://arxiv.org/abs/2106.01342

    Args:
        embedding_dim: int, dimensionality of the embeddings
        cardinalities: list, a list of cardinalities of all the features
            in the dataset in the same order as the features' occurrence.
            For numerical features, use the value `0` as indicator at
            the corresponding index.
            You can use the `compute_cardinalities` function from
            `teras.utils` package for this purpose.

    Shapes:
        Input Shape: (batch_size, input_dim)
        Output Shape: (batch_size, input_dim, embedding_dim)
    """

    def __init__(self,
                 embedding_dim: int,
                 cardinalities: list,
                 **kwargs):
        """Create one embedding layer per entry of `cardinalities`."""
        super().__init__(**kwargs)
        self.cardinalities = cardinalities
        self.embedding_dim = embedding_dim

        self.embedding_layers = []
        for card in self.cardinalities:
            if card == 0:
                # it's continuous
                self.embedding_layers.append(
                    keras.layers.Dense(self.embedding_dim, activation="relu")
                )
            else:
                # it's categorical; the table gets `card + 1` rows
                self.embedding_layers.append(
                    keras.layers.Embedding(
                        input_dim=card + 1,
                        output_dim=self.embedding_dim))
        self.embedding_layers = LayerList(self.embedding_layers,
                                          sequential=False)

    def build(self, input_shape):
        """Build the per-feature layers for a `(batch_size, 1)` input."""
        # since each embedding layer only operates on a single feature
        self.embedding_layers.build((input_shape[0], 1))

    def compute_output_shape(self, input_shape):
        """Return `input_shape` with `embedding_dim` appended as last axis.

        `input_shape` is coerced to a tuple first so that a shape passed
        as a list does not fail on the list + tuple concatenation.
        """
        return tuple(input_shape) + (self.embedding_dim,)

    def call(self, inputs):
        """Embed every feature column and join the results along axis 1.

        Args:
            inputs: tensor of shape `(batch_size, input_dim)`, as stated
                in the class docstring.

        Returns:
            The concatenation along axis 1 of all per-feature embeddings.
        """
        # Collect the per-feature embeddings and concatenate once at the
        # end. This stays backend-agnostic (no in-place assignment into a
        # preallocated tensor) while avoiding re-copying the accumulated
        # result on every iteration.
        per_feature_embeddings = []
        # `cardinalities` has exactly one entry per embedding layer (see
        # `__init__`), so its length drives the positional indexing.
        for idx in range(len(self.cardinalities)):
            # `indices=[idx]` keeps the feature axis: (batch_size, 1)
            feature = ops.take(inputs, indices=[idx], axis=1)
            feature_embeddings = self.embedding_layers[idx](feature)
            # Rank-2 results lack the feature axis; add it so that every
            # piece can be concatenated along axis 1.
            if len(ops.shape(feature_embeddings)) == 2:
                feature_embeddings = ops.expand_dims(
                    feature_embeddings, axis=1)
            per_feature_embeddings.append(feature_embeddings)
        return ops.concatenate(per_feature_embeddings, axis=1)

    def get_config(self):
        """Return the base layer config plus the constructor arguments."""
        config = super().get_config()
        config.update({
            "embedding_dim": self.embedding_dim,
            "cardinalities": self.cardinalities
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Recreate the layer from a config produced by `get_config`."""
        return cls(**config)