# Source code for the teras._src.tasks.generation module.

import keras
from keras import random, ops
from teras._src.models.gans.ctgan.ctgan import CTGAN
from teras._src.models.gans.ctgan.generator import CTGANGenerator
from teras._src.api_export import teras_export


@teras_export("teras.tasks.Generator")
class Generator:
    """
    Generator class that provides methods related to data generation.

    Args:
        model: keras.Model, instance of the trained model that will be
            used to generate data.
        data_transformer: Instance of the data transformer that was used
            to transform data for training. Its ``reverse_transform``
            method is applied to the model's predictions.
        data_sampler: Instance of the data sampler that was used to sample
            data for training. Required when ``model`` is a ``CTGAN`` or
            ``CTGANGenerator``; otherwise it is not used. Defaults to
            ``None``.
    """

    def __init__(
        self,
        model: keras.Model,
        data_transformer,
        data_sampler=None,
    ):
        self.model = model
        self.data_transformer = data_transformer
        self.data_sampler = data_sampler

    def generate(
        self,
        num_samples,
        latent_dim,
        batch_size=None,
        verbose="auto",
        steps=None,
        callbacks=None,
        seed=None,
    ):
        """
        Generates new data samples.

        Noise of shape ``(num_samples, latent_dim)`` is sampled with
        ``keras.random.normal``, fed to ``self.model.predict``, and the
        predictions are mapped back through
        ``self.data_transformer.reverse_transform``.

        It exposes all the arguments taken by the `predict` method.

        Args:
            num_samples: int, number of samples to generate.
            latent_dim: int, latent dimensions for sampling noise.
                It should be the same as used in the model during training.
            batch_size: Forwarded unchanged to ``self.model.predict``.
            verbose: Forwarded unchanged to ``self.model.predict``.
                Defaults to ``"auto"``.
            steps: Forwarded unchanged to ``self.model.predict``.
            callbacks: Forwarded unchanged to ``self.model.predict``.
            seed: Optional seed forwarded to ``keras.random.normal`` when
                sampling the noise.

        Returns:
            The result of ``self.data_transformer.reverse_transform``
            applied to the model's predictions.

        Raises:
            ValueError: If ``self.model`` is a ``CTGAN`` or
                ``CTGANGenerator`` and ``data_sampler`` is ``None``.
        """
        is_ctgan = isinstance(self.model, (CTGAN, CTGANGenerator))

        # Validate before doing any work so a misconfigured Generator fails
        # fast, without sampling noise first.
        if is_ctgan and self.data_sampler is None:
            raise ValueError(
                "For `CTGAN` architecture `data_sampler` cannot be `None`. "
                "You must pass the data sampler instance that was used to "
                "train the architecture. "
                f"Received: {self.data_sampler}"
            )

        z = random.normal((num_samples, latent_dim), seed=seed)

        if is_ctgan:
            # CTGAN models take conditional vectors alongside the noise:
            # one vector per requested sample, appended along axis 1.
            cond_vectors = self.data_sampler.sample_cond_vectors_for_generation(
                batch_size=num_samples
            )
            z = ops.concatenate([z, cond_vectors], axis=1)

        x_generated = self.model.predict(
            z,
            batch_size=batch_size,
            verbose=verbose,
            steps=steps,
            callbacks=callbacks,
        )
        return self.data_transformer.reverse_transform(x_generated)