Imputing missing data in teras

Using state-of-the-art deep learning models to impute missing values in tabular data can be quite a challenge, not just because of how complex the model architecture can get, but also because of the data preprocessing and transformation steps involved. But teras makes it as easy as doing a classification or regression task.

As of v0.3, teras offers two GAN-based architectures for data imputation: GAIN and PCGAIN.

For the sake of this tutorial, we’ll use the GAIN architecture.

So without further ado, let’s get to coding!

As always, the first step is to configure your backend. I’ll be using JAX because, in my experience, it’s usually the fastest of the three.

To configure your backend for teras, you need to set the KERAS_BACKEND environment variable.

NOTE: You need to configure your backend before importing teras/keras.

import os
os.environ["KERAS_BACKEND"] = "jax"

For this tutorial, we’ll be using the Boston Housing dataset made available by Keras.

from keras.datasets import boston_housing

(X_train, y_train), (X_test, y_test) = boston_housing.load_data()
NOTE: The JAX backend needs both jax and jaxlib installed. If importing keras fails with "ModuleNotFoundError: jax requires jaxlib to be installed", see https://github.com/google/jax#installation for installation instructions.

Since our task here is self-supervised, we don’t need labels or a held-out test set to compute any metrics, so let’s combine all the data into a single array.

import numpy as np

dataset = np.concatenate([np.concatenate([X_train, y_train[:, np.newaxis]], axis=1),
                          np.concatenate([X_test, y_test[:, np.newaxis]], axis=1)],
                         axis=0)
dataset.shape
(506, 14)
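The nested np.concatenate calls above append each target vector as an extra column and then stack the train and test rows. An equivalent, arguably easier to read formulation uses np.column_stack and np.vstack. It is sketched here with random stand-in arrays that have the same shapes as the Keras Boston Housing split (404 training rows, 102 test rows, 13 features), so it runs without downloading anything.

```python
import numpy as np

rng = np.random.default_rng(0)
# Random stand-ins with the same shapes as boston_housing.load_data()
X_train, y_train = rng.random((404, 13)), rng.random(404)
X_test, y_test = rng.random((102, 13)), rng.random(102)

# Append each target vector as the last column, then stack train and test rows
dataset = np.vstack([
    np.column_stack([X_train, y_train]),
    np.column_stack([X_test, y_test]),
])
print(dataset.shape)  # (506, 14)

# Identical to the nested np.concatenate version
reference = np.concatenate(
    [np.concatenate([X_train, y_train[:, np.newaxis]], axis=1),
     np.concatenate([X_test, y_test[:, np.newaxis]], axis=1)],
    axis=0,
)
assert np.array_equal(dataset, reference)
```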

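It’s always a good idea to normalize the dataset. Note that scikit-learn’s Normalizer rescales each sample (row) to unit norm; if you’d rather scale each feature (column) independently, MinMaxScaler is the usual choice.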

from sklearn.preprocessing import Normalizer

normalizer = Normalizer()
dataset = normalizer.fit_transform(dataset)
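To make the row-versus-column distinction concrete, here is a NumPy-only sketch of what Normalizer does with its default norm="l2" (each row divided by its Euclidean length), next to per-feature min-max scaling as done by MinMaxScaler. Both functions are illustrative re-implementations, not the scikit-learn code.

```python
import numpy as np


def l2_normalize_rows(x: np.ndarray) -> np.ndarray:
    """Row-wise scaling, like sklearn's Normalizer(norm="l2")."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    norms[norms == 0] = 1.0  # leave all-zero rows unchanged
    return x / norms


def min_max_columns(x: np.ndarray) -> np.ndarray:
    """Column-wise scaling to [0, 1], like sklearn's MinMaxScaler."""
    col_min = x.min(axis=0)
    col_range = x.max(axis=0) - col_min
    col_range[col_range == 0] = 1.0  # avoid division by zero for constant columns
    return (x - col_min) / col_range


x = np.array([[3.0, 4.0],
              [6.0, 8.0]])
print(l2_normalize_rows(x))  # both rows become [0.6, 0.8]
print(min_max_columns(x))    # rows become [0, 0] and [1, 1]
```

Note how row-wise normalization makes the two rows identical, because they differ only by a scale factor, while min-max scaling preserves their ordering within each column.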

Now, this dataset doesn’t contain any missing values in itself, so we’ll inject some ourselves to simulate a real-world scenario.

For that, teras offers a handy utility for quickly simulating such situations, conveniently named inject_missing_values.

from teras.utils import inject_missing_values

print("# of missing values: ")
print("Before injecting: ", np.isnan(dataset).sum())
dataset = inject_missing_values(dataset, 0.2)
print("After injecting: ", np.isnan(dataset).sum())
# of missing values: 
Before injecting:  0
After injecting:  1426
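Simulating missingness boils down to picking a random subset of cells and setting them to NaN. The sketch below does this completely at random (MCAR); `inject_mcar` is a hypothetical stand-in, not the teras implementation, whose sampling details may differ. For reference, 1426 missing cells out of 506 × 14 = 7084 is about 20.1%, in line with the 0.2 rate passed above.

```python
from typing import Optional

import numpy as np


def inject_mcar(x: np.ndarray, rate: float, seed: Optional[int] = None) -> np.ndarray:
    """Hypothetical helper: return a copy of `x` with roughly `rate` of its cells set to NaN.

    Each cell is dropped independently with probability `rate`
    (missing completely at random).
    """
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    rng = np.random.default_rng(seed)
    out = x.astype(float, copy=True)  # NaN needs a float dtype; never mutate the input
    mask = rng.random(out.shape) < rate
    out[mask] = np.nan
    return out


data = np.ones((506, 14))
corrupted = inject_mcar(data, 0.2, seed=42)
print(np.isnan(corrupted).mean())  # close to 0.2; the exact value depends on the seed
print(np.isnan(data).sum())        # 0, the original array is untouched
```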

The GAIN architecture that we’ll be using requires the dataset in the form (x_generator, x_discriminator).

There’s a handy data utility function in teras for this purpose named create_gain_dataset.

NOTE: As of teras v0.3.0, you need to have TensorFlow installed to use this function, since it uses tf.data to create a TensorFlow dataset that Keras 3 can then consume with any backend. The same is true for the data sampling classes available in teras. You may not like TensorFlow, but you cannot not like tf.data.

from teras.data_utils import create_gain_dataset

gain_dataset = create_gain_dataset(dataset)

# Remember to batch your TensorFlow dataset
BATCH_SIZE = 64
gain_dataset = gain_dataset.batch(BATCH_SIZE)

Now let’s import GAIN.

Since GAIN is a generative adversarial network, it requires instances of a generator and a discriminator, which we’ll also import.

from teras.models import GAIN
from teras.models import GAINGenerator
from teras.models import GAINDiscriminator

If you look at the documentation, you’ll see that instantiating either GAINGenerator or GAINDiscriminator requires a positional argument, data_dim. It’s usually the same as the input dimensionality of the dataset; it’s named this way for cases where the input dataset’s dimensionality differs from the original dataset’s because of data transformations or other preprocessing.

Anyway, here data_dim refers to the dimensionality of the original dataset.

dataset.shape[1]
14
generator = GAINGenerator(data_dim=dataset.shape[1])

discriminator = GAINDiscriminator(data_dim=dataset.shape[1])

gain = GAIN(generator,
            discriminator)

NOTE: You can customize these models further by specifying various keyword arguments. Check the docs! I’ll just stick with the defaults for the sake of this tutorial.

Now let’s compile our model. Note that we’re not passing a loss function to the GAIN instance’s compile method, because these specialized architectures compute their losses internally.

import keras

gain.compile(generator_optimizer=keras.optimizers.Adam(),
             discriminator_optimizer=keras.optimizers.Adam())
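For intuition about the losses GAIN computes internally, here is a NumPy sketch of the two objectives described in the original GAIN paper (Yoon et al., 2018). Here m is the mask (1 = observed, 0 = missing), x_bar is the generator’s output for every cell, and d_prob is the discriminator’s per-cell probability that a value was observed. The function name and the alpha value are illustrative assumptions; teras’s own implementation may differ in weighting and numerical details.

```python
import numpy as np

EPS = 1e-8  # keeps log() away from zero


def gain_losses(x, m, x_bar, d_prob, alpha=100.0):
    """Sketch of the GAIN discriminator and generator losses.

    x:      data with missing cells filled by a placeholder (e.g. 0)
    m:      mask, 1 where a value is observed, 0 where it is missing
    x_bar:  generator output (a value for every cell)
    d_prob: discriminator's probability that each cell was observed
    alpha:  weight of the reconstruction term (illustrative value)
    """
    # Discriminator: per-cell binary cross-entropy of predicting the mask
    d_loss = -np.mean(m * np.log(d_prob + EPS) + (1 - m) * np.log(1 - d_prob + EPS))
    # Generator, adversarial part: make imputed cells look "observed"
    g_adv = -np.mean((1 - m) * np.log(d_prob + EPS))
    # Generator, reconstruction part: match the observed cells
    mse = np.mean((m * x - m * x_bar) ** 2) / np.mean(m)
    return d_loss, g_adv + alpha * mse


x = np.array([[0.5, 0.0], [0.2, 0.9]])
m = np.array([[1.0, 0.0], [1.0, 1.0]])
x_bar = np.array([[0.5, 0.3], [0.2, 0.9]])  # exact on all observed cells
d_prob = np.full_like(x, 0.5)               # an undecided discriminator
d_loss, g_loss = gain_losses(x, m, x_bar, d_prob)
print(round(d_loss, 4))  # 0.6931, i.e. log(2)
```

With the discriminator outputting 0.5 everywhere, its loss is log 2, and because the generator reproduces every observed cell, the generator loss reduces to the adversarial term alone.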

The rule of thumb for GAN-based models in teras is to ALWAYS build them yourself: the dataset we pass to such architectures usually deviates from a normal (X, y) paired dataset, so Keras can’t infer the expected input shape and fails to build these models automatically.

So let’s build the model ourselves!

gain.build((BATCH_SIZE, dataset.shape[1]))

Now, if and only if you’re using the JAX backend, you’ll have to call the build_optimizers method when using any GAN-based model or any other model that uses more than one optimizer. It isn’t needed for other backends like TensorFlow or PyTorch, nor is it needed for architectures that use a single optimizer, which covers the vast majority of cases.

Since we ARE using the JAX backend, we’ll call this method.

gain.build_optimizers()

WARNING: Calling the build_optimizers method on a backend other than JAX will result in an error!
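If you want the same script to run on any backend, one option is to guard the call with the active backend’s name, which Keras 3 reports via keras.backend.backend(). The wrapper below is a hypothetical convenience, not part of teras; it takes the backend name as an argument so its logic can be exercised without Keras installed.

```python
def maybe_build_optimizers(model, backend_name: str) -> bool:
    """Hypothetical helper: call `model.build_optimizers()` only on JAX.

    Per this tutorial, build_optimizers is required for multi-optimizer
    models on the JAX backend and raises an error on other backends.
    Returns True if the method was called.
    """
    if backend_name == "jax":
        model.build_optimizers()
        return True
    return False


# Usage in this tutorial, after gain.build(...):
#   import keras
#   maybe_build_optimizers(gain, keras.backend.backend())
```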

history = gain.fit(gain_dataset, epochs=2)
Epoch 1/2
8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 384ms/step - discriminator_loss: 0.7368 - generator_loss: 48.5096
Epoch 2/2
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - discriminator_loss: 0.7002 - generator_loss: 47.1353
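The 8/8 in the progress bar is the number of batches per epoch. With 506 rows and BATCH_SIZE = 64, tf.data’s batch() yields seven full batches plus one partial batch, since it keeps the remainder unless drop_remainder=True. A quick check of that arithmetic:

```python
import math

n_rows, batch_size = 506, 64
steps_per_epoch = math.ceil(n_rows / batch_size)             # full batches + 1 partial
last_batch = n_rows - (steps_per_epoch - 1) * batch_size     # rows in the final batch
print(steps_per_epoch, last_batch)  # 8 58
```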

Now the model is trained. Cool! But a model we can’t put to use is useless, so let’s put it to work.

To impute data with missing values, you can either use the predict method of the trained GAIN instance or use the Imputer class available in the teras.tasks module. The Imputer class may not add much here, but it can be very useful when you transform your data with a data transformer class.

So, assuming you already know how to use predict, we’ll use the Imputer class here. It offers an impute method that takes a dataset with missing values and returns the imputed data. If a data transformer instance is passed in during instantiation, impute returns the imputed data in its original format.

Since we’re not using a data transformer, we’ll set the impute method’s reverse_transform parameter to False; otherwise it will raise an error.

from teras.tasks import Imputer

gain_imputer = Imputer(gain)

imputed_dataset = gain_imputer.impute(dataset, reverse_transform=False)
16/16 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step
print("Missing values in the original dataset: ", np.isnan(dataset).sum())
print("Missing values in the imputed dataset: ", np.isnan(imputed_dataset).sum())
Missing values in the original dataset:  1426
Missing values in the imputed dataset:  0
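In the original GAIN formulation, the final imputation keeps every observed value and takes the model’s output only for the missing cells, i.e. x_hat = m * x + (1 - m) * x_bar. This tutorial doesn’t show whether predict or impute in teras already applies that step, so the hypothetical helper below is a sketch you could use to apply or verify it on NumPy arrays.

```python
import numpy as np


def merge_imputations(x_missing: np.ndarray, x_bar: np.ndarray) -> np.ndarray:
    """Hypothetical helper: keep observed cells, fill NaN cells from `x_bar`."""
    observed = ~np.isnan(x_missing)
    return np.where(observed, x_missing, x_bar)


x_missing = np.array([[0.1, np.nan], [np.nan, 0.4]])
x_bar = np.array([[0.9, 0.2], [0.3, 0.9]])  # e.g. model output for every cell
x_hat = merge_imputations(x_missing, x_bar)
# Observed 0.1 and 0.4 are kept; the NaNs are filled with 0.2 and 0.3
assert not np.isnan(x_hat).any()
```

With the arrays from this tutorial, np.allclose(imputed_dataset[~np.isnan(dataset)], dataset[~np.isnan(dataset)]) tells you whether the observed values came through the imputer unchanged.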

And that wraps it up! As you saw, it’s super easy and intuitive to use complex, state-of-the-art architectures for data imputation, thanks to teras!

If you have any questions or run into an issue, reach us on Twitter at @TerasML or file an issue on the teras GitHub repository.