Source code for teras._src.models.pretrainers.tab_transformer

import keras

from teras._src import backend
from teras._src.api_export import teras_export


@teras_export("teras.models.TabTransformerMLMPretrainer")
class TabTransformerMLMPretrainer(backend.models.TabTransformerMLMPretrainer):
    """
    Pretrainer for `TabTransformerBackbone` based on the Masked Language
    Modelling (MLM) objective, following Huang et al.,
    "TabTransformer: Tabular Data Modeling Using Contextual Embeddings".

    This class adds no logic of its own: every constructor argument is
    handed over unchanged to the backend implementation it subclasses.

    Reference(s):
        https://arxiv.org/abs/2012.06678

    Args:
        model: keras.Model, the `TabTransformerBackbone` instance that
            is to be pretrained.
        data_dim: int, dimensionality of the input dataset.
        missing_rate: float, fraction of the original features to make
            missing. Must be in the range [0, 1).
            Defaults to 0.3 (or 30%).
        mask_seed: int, seed used when generating the mask.
            Defaults to 1337.
        **kwargs: any further keyword arguments, passed through to the
            backend class constructor as-is.
    """

    def __init__(self,
                 model: keras.Model,
                 data_dim: int,
                 missing_rate: float = 0.3,
                 mask_seed: int = 1337,
                 **kwargs):
        # Gather the named settings first so the hand-off to the backend
        # class is a single, easy-to-scan call.
        pretrainer_config = dict(
            model=model,
            data_dim=data_dim,
            missing_rate=missing_rate,
            mask_seed=mask_seed,
        )
        super().__init__(**pretrainer_config, **kwargs)
@teras_export("teras.models.TabTransformerRTDPretrainer")
class TabTransformerRTDPretrainer(backend.models.TabTransformerRTDPretrainer):
    """
    Pretrainer for `TabTransformerBackbone` based on the Replaced Token
    Detection (RTD) objective, following Huang et al.,
    "TabTransformer: Tabular Data Modeling Using Contextual Embeddings".

    This class adds no logic of its own: every constructor argument is
    handed over unchanged to the backend implementation it subclasses.

    Reference(s):
        https://arxiv.org/abs/2012.06678

    Args:
        model: keras.Model, the `TabTransformerBackbone` instance that
            is to be pretrained.
        data_dim: int, dimensionality of the input dataset.
        replace_rate: float, fraction of the original features to
            replace. Must be in the range [0, 1).
            Defaults to 0.3 (or 30%).
        mask_seed: int, seed used when generating the mask.
            Defaults to 1337.
        shuffle_seed: int, seed used when shuffling the inputs.
            Defaults to 1999.
        **kwargs: any further keyword arguments, passed through to the
            backend class constructor as-is.
    """

    def __init__(self,
                 model: keras.Model,
                 data_dim: int,
                 replace_rate: float = 0.3,
                 mask_seed: int = 1337,
                 shuffle_seed: int = 1999,
                 **kwargs):
        # Gather the named settings first so the hand-off to the backend
        # class is a single, easy-to-scan call.
        pretrainer_config = dict(
            model=model,
            data_dim=data_dim,
            replace_rate=replace_rate,
            mask_seed=mask_seed,
            shuffle_seed=shuffle_seed,
        )
        super().__init__(**pretrainer_config, **kwargs)