Source code for teras._src.layers.categorical_extraction

import keras
from keras import ops
from teras._src.api_export import teras_export


@teras_export("teras.layers.CategoricalExtraction")
class CategoricalExtraction(keras.layers.Layer):
    """
    CategoricalExtraction layer extracts categorical features from inputs
    as is.

    It helps build functional models whose inputs contain both categorical
    and continuous features that must diverge into two different branches
    for separate processing.

    Args:
        categorical_idx: list, list of indices of categorical features in
            the given dataset. Any iterable of ints is accepted and stored
            as a list.
    """

    def __init__(self, categorical_idx: list, **kwargs):
        # The base `Layer.get_config()` includes a "trainable" entry, so a
        # layer rebuilt via `from_config` receives it in `kwargs`. Drop it
        # here so the forced `trainable=False` below does not collide with
        # it (which would raise a TypeError for a duplicate keyword).
        kwargs.pop("trainable", None)
        super().__init__(trainable=False, **kwargs)
        # Materialize as a list: `compute_output_shape` calls `len()` on it,
        # `call` may run many times, and `get_config` should emit a plain,
        # re-iterable value.
        self.categorical_idx = list(categorical_idx)

    def compute_output_shape(self, input_shape):
        """Return `input_shape` with axis 1 resized to the number of
        selected categorical features; all other axes are kept."""
        output_shape = list(input_shape)
        output_shape[1] = len(self.categorical_idx)
        return tuple(output_shape)

    def call(self, inputs):
        """Gather the columns at `categorical_idx` along axis 1, unchanged."""
        return ops.take(inputs, indices=self.categorical_idx, axis=1)

    def get_config(self):
        """Return the layer config, including `categorical_idx`, so the
        layer can be re-created with `from_config`."""
        config = super().get_config()
        config.update({"categorical_idx": self.categorical_idx})
        # Previously missing: without this return, get_config() yielded None.
        return config