Source code for teras._src.layers.cls_token

import keras
from keras import ops
from teras._src.api_export import teras_export


[docs] @teras_export("teras.layers.CLSToken") class CLSToken(keras.layers.Layer): """ CLS Token layer that makes it possible to append CLS token embedding to the input embeddings in the sequential or functional models. The idea of CLS token was introduced in the "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" paper. Reference(s): https://arxiv.org/abs/1810.04805 Args: embedding_dim: int, dimensionality of the input embeddings Shapes: Input Shape: `(batch_size, num_features, embedding_dim)` Output Shape: `(batch_size, num_features + 1, embedding_dim)` """
[docs] def __init__(self, embedding_dim: int, **kwargs): super().__init__(**kwargs) self.embedding_dim = embedding_dim
def build(self, input_shape=None): self.cls_token = self.add_weight( shape=(1, self.embedding_dim), initializer="random_normal", ) def call(self, inputs): # TODO Remove the call to `convert_to_tensor` as soon as Keras # fixes `broadcast_to` method for JAX backend token_broadcasted = ops.broadcast_to( ops.convert_to_tensor(self.cls_token), shape=(ops.shape(inputs)[0], *ops.shape(self.cls_token))) return ops.concatenate([token_broadcasted, inputs], axis=1) def compute_output_shape(self, input_shape): return (input_shape[0], input_shape[1] + 1, input_shape[2]) def get_config(self): config = super().get_config() config.update({ "embedding_dim": self.embedding_dim }) return config