Source code for teras._src.models.pretrainers.tabnet.tabnet

import keras

from teras._src import backend
from teras._src.api_export import teras_export


[docs] @teras_export("teras.models.TabNetPretrainer") class TabNetPretrainer(backend.models.TabNetPretrainer): """ TabNetPretrainer for pretraining `TabNetEncoder` as proposed by Arik et al. in the paper, "TabNet: Attentive Interpretable Tabular Learning" Reference(s): https://arxiv.org/abs/1908.07442 Args: encoder: keras.Model, instance of `TabNetEncoder` to pretrain decoder: keras.Model, instance of `TabNetDecoder` missing_feature_probability: float, probability of missing features seed: int, seed for generating mask. Defaults to 1337 """
[docs] def __init__(self, encoder: keras.Model, decoder: keras.Model, missing_feature_probability: float = 0.3, seed: int = 1337, **kwargs): super().__init__( encoder=encoder, decoder=decoder, missing_feature_probability=missing_feature_probability, seed=seed, **kwargs)