Source code for teras._src.layers.saint.reconstruction_head
import keras
from keras import ops
from teras._src.layers.layer_list import LayerList
from teras._src.api_export import teras_export
[docs]
@teras_export("teras.layers.SAINTReconstructionHead")
class SAINTReconstructionHead(keras.layers.Layer):
"""
SAINT Reconstruction Head layer for `SAINTPretrainer`.
For each feature in the dataset, it creates an MLP with a hidden
layer and an output layer with dimensions equal to the cardinality of
the feature in the case of a categorical features and 1 in the case
of a continuous feature.
Args:
cardinalities: list or ndarray, a list or 1d-array of
cardinalities of all the features in the dataset in the
same order as the features' occurrence.
For numerical features, use 0 as indicator at
the corresponding index of the array.
You can use the `compute_cardinalities` function from
`teras.utils` package for this purpose.
embedding_dim: int, Dimensionality of embeddings being used in
the model,
"""
[docs]
def __init__(self,
cardinalities: list,
embedding_dim: int,
**kwargs):
super().__init__(**kwargs)
self.cardinalities = cardinalities
self.embedding_dim = embedding_dim
self.reconstruction_blocks = LayerList([
LayerList([
keras.layers.Dense(embedding_dim * 5,
activation="relu",
name=f"hidden_rb_{i}"),
keras.layers.Dense((card + 1) if card == 0 else card)],
sequential=True,
name=f"reconstruction_block_{i}"
)
for i, card in enumerate(self.cardinalities)
],
sequential=False,
name="reconstruction_blocks"
)
def build(self, input_shape):
self.reconstruction_blocks.build(input_shape)
def call(self, inputs):
reconstructed_features = self.reconstruction_blocks[0](
inputs[:, 0, :]
)
for idx, layer in enumerate(self.reconstruction_blocks[1:],
start=1):
r_f = layer(inputs[:, idx, :])
reconstructed_features = ops.concatenate(
[reconstructed_features, r_f],
axis=1
)
return reconstructed_features
def compute_output_shape(self, input_shape):
return input_shape[:1] + (len(self.cardinalities),)
def get_config(self):
config = super().get_config()
config.update({
"cardinalities": self.cardinalities,
"embedding_dim": self.embedding_dim,
})
return config