Source code for teras._src.tasks.imputation
import keras
from teras._src.api_export import teras_export
[docs]
@teras_export("teras.tasks.Imputer")
class Imputer:
    """
    Imputer task class used to impute missing data using the trained `model`
    instance.

    Args:
        model: keras.Model, trained instance of a keras model that will be
            used to impute missing data.
            Currently, `teras` offers `GAIN` and `PCGAIN` architectures for
            imputation.
        data_transformer: Instance of a relevant data transformer that is
            used during the data transformation step before training the
            respective architecture. It is only required when `impute` is
            called with `reverse_transform=True`, in which case its
            `reverse_transform` method is applied to the raw model output.
    """

    def __init__(self,
                 model: keras.Model,
                 data_transformer=None):
        self.model = model
        self.data_transformer = data_transformer

    def impute(self, x, reverse_transform=True, batch_size=None,
               verbose="auto", steps=None, callbacks=None):
        """
        Imputes the missing data.

        The `batch_size`, `verbose`, `steps` and `callbacks` arguments are
        forwarded unchanged to the model's `predict` method.

        Args:
            x: pd.DataFrame, dataset with missing values. It is passed to
                `model.predict` as-is, without any transformation here.
                NOTE(review): confirm whether callers must first transform
                it with `data_transformer`.
            reverse_transform: bool, default True, whether to reverse
                transform the raw imputed data to the original format.
            batch_size: Forwarded to `model.predict`.
            verbose: Forwarded to `model.predict`.
            steps: Forwarded to `model.predict`.
            callbacks: Forwarded to `model.predict`.

        Returns:
            If `reverse_transform` is True, the result of
            `data_transformer.reverse_transform` applied to the raw
            predictions, i.e. imputed data in the original format.
            Otherwise, the raw output of `model.predict`.

        Raises:
            ValueError: If `reverse_transform` is True but no
                `data_transformer` was given to the constructor.
        """
        # Validate before predicting so a misconfiguration fails fast
        # instead of after a potentially expensive prediction pass.
        if reverse_transform and self.data_transformer is None:
            raise ValueError(
                "To reverse transform the raw imputed data, "
                "`data_transformer` must not be None. "
                "Please pass the instance of `DataTransformer` class used "
                "to transform the input data as the `data_transformer` "
                "argument when constructing the `Imputer`. \n"
                "Or alternatively, you can set `reverse_transform` "
                "parameter to False, and manually reverse transform the "
                "generated raw data to original format using the "
                "`reverse_transform` method of `DataTransformer` instance.")
        x_imputed = self.model.predict(x,
                                       batch_size=batch_size,
                                       verbose=verbose,
                                       steps=steps,
                                       callbacks=callbacks)
        if reverse_transform:
            x_imputed = self.data_transformer.reverse_transform(x_imputed)
        return x_imputed