Source code for teras._src.layers.activation.gumbel_softmax

import keras
from teras._src.activations import gumbel_softmax
from teras._src.api_export import teras_export


@teras_export("teras.layers.GumbelSoftmax")
class GumbelSoftmax(keras.layers.Layer):
    """
    Implementation of the Gumbel Softmax activation
    proposed by Eric Jang et al. in the paper,
    "Categorical Reparameterization with Gumbel-Softmax"

    Reference(s):
        https://arxiv.org/abs/1611.01144

    Args:
        temperature: float, Controls the sharpness or smoothness of the
            resulting probability distribution. A higher temperature
            value leads to a smoother and more uniform probability
            distribution. Conversely, a lower temperature value makes
            the distribution concentrated around the category with the
            highest probability. Defaults to 0.2.
        hard: bool, Whether to return soft probabilities or hard one hot
            vectors. Defaults to False.
        seed: int, Seed for random sampling. Defaults to 1337.
    """

    def __init__(self,
                 temperature: float = 0.2,
                 hard: bool = False,
                 seed: int = 1337,
                 **kwargs):
        super().__init__(**kwargs)
        self.temperature = temperature
        self.hard = hard
        self.seed = seed

    def build(self, input_shape):
        """Mark the layer as built; it owns no weights or state to create."""
        self.built = True

    def call(self, logits):
        """Apply the Gumbel Softmax activation to `logits`.

        Forwards the layer's `temperature`, `hard` and `seed` settings to
        the functional `gumbel_softmax` activation and returns its result.
        """
        # NOTE(review): the same `seed` value is forwarded on every call;
        # whether that repeats identical noise across calls depends on how
        # `gumbel_softmax` consumes it -- confirm against its implementation.
        return gumbel_softmax(logits,
                              temperature=self.temperature,
                              hard=self.hard,
                              seed=self.seed)

    def get_config(self):
        """Return the layer configuration used for serialization.

        Every constructor argument is included so that rebuilding the layer
        from this config reproduces the same settings.
        """
        config = super().get_config()
        config.update({
            'temperature': self.temperature,
            'hard': self.hard,
            # `seed` must be serialized too; otherwise a layer restored from
            # its config silently falls back to the default seed.
            'seed': self.seed,
        })
        return config