Source code for teras._src.models.pretrainers.saint

import keras

from teras._src import backend
from teras._src.api_export import teras_export


[docs] @teras_export("teras.models.SAINTPretrainer") class SAINTPretrainer(backend.models.SAINTPretrainer): """ SAINTPretrainer as proposed in the paper, "SAINT: Improved Neural Networks for Tabular Data". Reference(s): https://arxiv.org/abs/2106.01342 Args: model: keras.Model, instance of `SAINTBackbone` model to pretrain cardinalities: list, a list cardinalities of all the features in the dataset in the same order as the features' occurrence. For numerical features, use the value `0` as indicator at the corresponding index. You can use the `compute_cardinalities` function from `teras.utils` package for this purpose. embedding_dim: int, dimensionality of the embeddings being used in the model. cutmix_probability: float, used by the `CutMix` layer in generation of mask that is used to mix samples together. Defaults to 0.3 mixup_alpha: float, used by the `MixUp` layer in sampling from the `Beta` distribution which is then used to interpolate samples. Defaults to 1. temperature: float, used in the computation of the `contrastive_loss` to scale logits. Defaults to 0.7 lambda_: float, acts as a weight when adding the contrastive loss and the denoising loss together. `loss = constrastive_loss + lambda_ * denoising_loss` Defaults to 10. lambda_c: float, used in the computation of the contrastive loss. Similar to `lambda_` is helps combined two sub-losses within the contrastive loss. Defaults to 0.5 seed: int, seed used in random sampling and shuffling etc. It helps make the model behavior more deterministic. Defaults to (you guessed it) 1337. """
[docs] def __init__(self, model: keras.Model, cardinalities: list, embedding_dim: int, cutmix_probability: float = 0.3, mixup_alpha: float = 1., temperature: float = 0.7, lambda_: float = 10., lambda_c: float = 0.5, seed: int = 1337, **kwargs): super().__init__(model=model, cardinalities=cardinalities, embedding_dim=embedding_dim, cutmix_probability=cutmix_probability, mixup_alpha=mixup_alpha, temperature=temperature, lambda_=lambda_, lambda_c=lambda_c, seed=seed, **kwargs)