Source code for teras._src.layers.saint.projection_head

import keras
from teras._src.api_export import teras_export


@teras_export("teras.layers.SAINTProjectionHead")
class SAINTProjectionHead(keras.layers.Layer):
    """
    Projection Head layer that is used in the contrastive learning phase of
    the `SAINTPretrainer` to project embeddings to a lower dimension.
    According to the SAINT paper,
    "The use of a projection head to reduce dimensionality before computing
    contrastive loss is common in vision and indeed also improves results
    on tabular data."

    Structurally it is a two-layer MLP: a `Dense(hidden_dim)` using
    `hidden_activation`, followed by a linear `Dense(output_dim)`.

    Reference(s):
        https://arxiv.org/abs/2106.01342

    Args:
        hidden_dim: int, Dimensionality of the hidden layer.
            In the official implementation, it is computed as follows,
            `hidden_dim = 6 * embedding_dim * number_of_features // 5`
        output_dim: int, Dimensionality of the output layer.
            In the official implementation, it is computed as follows,
            `output_dim = embedding_dim * number_of_features // 2`
        hidden_activation: Activation function to use in the hidden layer.
            Defaults to "relu".
        **kwargs: Additional keyword arguments forwarded to
            `keras.layers.Layer` (e.g. `name`, `dtype`, `trainable`).
    """

    def __init__(
        self,
        hidden_dim: int,
        output_dim: int,
        hidden_activation="relu",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.hidden_activation = hidden_activation
        self.output_dim = output_dim
        self.hidden_block = keras.layers.Dense(
            units=hidden_dim,
            activation=hidden_activation,
            name="projection_head_hidden",
        )
        # No `activation` is given, so this final projection is linear.
        self.output_layer = keras.layers.Dense(
            units=output_dim,
            name="projection_head_output",
        )

    def build(self, input_shape):
        self.hidden_block.build(input_shape)
        # The output layer consumes the hidden layer's output, so it must be
        # built against the hidden block's output shape.
        hidden_shape = self.hidden_block.compute_output_shape(input_shape)
        self.output_layer.build(hidden_shape)

    def call(self, inputs):
        x = self.hidden_block(inputs)
        return self.output_layer(x)

    def compute_output_shape(self, input_shape):
        # `input_shape` may be passed as a list rather than a tuple; normalize
        # it so the concatenation below cannot raise a TypeError.
        return tuple(input_shape[:-1]) + (self.output_dim,)

    def get_config(self):
        # Start from the base Layer config so that `name`, `trainable`,
        # `dtype` and any other base-class fields survive serialization.
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "output_dim": self.output_dim,
                "hidden_activation": self.hidden_activation,
            }
        )
        return config