Source code for teras._src.layers.continuous_extraction
import keras
from keras import ops
from teras._src.api_export import teras_export
[docs]
@teras_export("teras.layers.ContinuousExtraction")
class ContinuousExtraction(keras.layers.Layer):
"""
ContinuousExtraction layer extracts the continuous features from the
inputs as is. It helps build a functional model whose inputs contain
both categorical and continuous features that must diverge into two
different branches for separate processing.
Args:
continuous_idx: list, a list of indices of continuous
features in the given dataset.
"""
[docs]
def __init__(self,
             continuous_idx: list,
             **kwargs):
    """Create a non-trainable layer that selects continuous features.

    Args:
        continuous_idx: list, a list of indices of continuous
            features in the given dataset.
        **kwargs: forwarded to the parent ``keras.layers.Layer``
            constructor. A ``trainable`` entry is ignored because this
            layer is always created with ``trainable=False``.
    """
    # The base layer config contains a "trainable" entry, so rebuilding
    # the layer from its config passes ``trainable`` back in via kwargs.
    # Drop it, otherwise it collides with the explicit ``trainable=False``
    # below and raises "got multiple values for keyword argument".
    kwargs.pop("trainable", None)
    super().__init__(trainable=False, **kwargs)
    self.continuous_idx = continuous_idx
def compute_output_shape(self, input_shape):
    """Return ``input_shape`` with axis 1 resized to the number of indices.

    Args:
        input_shape: sequence of dimensions describing the inputs.

    Returns:
        A tuple equal to ``input_shape`` except that the entry at
        position 1 is ``len(self.continuous_idx)``.
    """
    dims = [*input_shape]
    # Only the feature axis changes; every other dimension is kept.
    dims[1] = len(self.continuous_idx)
    return tuple(dims)
def call(self, inputs):
    """Gather the entries of ``inputs`` at ``continuous_idx`` along axis 1.

    Args:
        inputs: input tensor holding the features on axis 1.

    Returns:
        The result of ``ops.take`` on ``inputs`` using
        ``self.continuous_idx`` as indices along axis 1.
    """
    return ops.take(inputs, indices=self.continuous_idx, axis=1)
def get_config(self):
config = super().get_config()
config.update({
"continuous_idx": self.continuous_idx
})