
Deep Learning Super Sampling Using Generative Adversarial Networks

[Images: DLSS super sampling examples]

This article provides a general framework for a DLSS (Deep Learning Super Sampling) Generative Adversarial Network written with the Keras library. You can use this generic DLSS GAN template to train super resolution models on your own image datasets.

The complete version of the code is available at this GitHub repository.

Dataset

This template is written so that you should be able to use any RGB images as your dataset. A good example dataset can be downloaded from Kaggle here. This dataset provides two sets of images: one set of low resolution images at 96×96 pixels and one set of high resolution images at 384×384 pixels.

To use this dataset, you only need to define your settings within the user specified parameters. Set your input_path to the LR folder within the downloaded dataset, and set your output_path to the HR folder. You also need to specify the dimensions of your images: set input_dimensions to (96,96,3) and output_dimensions to (384,384,3). Note that the input and output dimensions are compatible in this case. Two image dimensions are compatible when the ratio between the high resolution and the low resolution is a multiple of 2; here, the 384×384 images are 4 times larger than the 96×96 images. This compatibility is required because of the way the UpSampling2D layer works in Keras. The layer doubles the resolution of its input, and the super_sampling_ratio parameter determines how many times the layer is applied in our model to get from the input dimensions to the output dimensions.
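To make the arithmetic concrete, here is a small worked example (using the Kaggle dataset dimensions above) showing how the compatibility rule turns into the super_sampling_ratio that the script computes later:

# Worked example using the Kaggle dataset dimensions above
input_dimensions = (96,96,3)
output_dimensions = (384,384,3)

# 384 / 96 = 4, so the UpSampling2D layer must be applied twice: 96 -> 192 -> 384
super_sampling_ratio = int(output_dimensions[0] / input_dimensions[0] / 2)
print(super_sampling_ratio)  # 2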

It’s also important to note that you don’t need a dataset that comes with two sets of images. You can, for example, use this dataset, which comes with only one set of 256×256 images. Just set the input_path and output_path parameters to the same folder and the images will be processed correctly. In this example, set your output_dimensions to (256,256,3) and your input_dimensions to (64,64,3) or (128,128,3), and the images will be resized accordingly.

Imports

In [1]:
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D, MaxPooling2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
import skimage.transform as st

This project requires the following libraries:

  • Keras (I use 2.3.1)

  • Tensorflow (I use 1.14.0)

  • scikit-learn (sklearn)

  • scikit-image (skimage)

  • Numpy

  • Matplotlib

  • PIL

Set Parameters

To customize this template to your specific needs, you need to define its parameters. The current parameters are designed for a GPU with 8 GB of VRAM. If you are working with a less powerful GPU, reduce the number of convolutional filters and the kernel size accordingly; the batch size can also be reduced to fit memory constraints (see the example after the parameter code below). The rest of the parameters are explained below:

User Specified Parameters:

  • input_path: File path pointing to the folder containing the low resolution dataset.

  • output_path: File path pointing to the folder containing the high resolution dataset.

  • input_dimensions: Dimensions of the images inside the low resolution dataset. The image sizes must be compatible meaning output_dimensions / input_dimensions is a multiple of 2.

  • output_dimensions: Dimensions of the images inside the high resolution dataset. The image sizes must be compatible meaning output_dimensions / input_dimensions is a multiple of 2.

  • super_sampling_ratio: Integer representing the ratio of the difference in size between the two image resolutions. This integer specifies how many times the Upsampling2D and MaxPooling2D layers are used in the models.

  • model_path: File path pointing to the folder where you want to save the model as well as generated samples.

  • interval: Integer representing how many epochs between saving your model.

  • epochs: Integer representing how many epochs to train the model.

  • batch: Integer representing how many images to train at one time.

  • conv_filters: Integer representing how many convolutional filters are used in each convolutional layer of the Generator and the Discriminator.

  • kernel: Tuple representing the size of the kernels used in the convolutional layers.

  • png: Boolean flag, set to True if the data has PNGs to remove alpha layer from images.

In [2]:
# Folder containing input (low resolution) dataset
input_path = r'D:\Downloads\selfie2anime\trainB'

# Folder containing output (high resolution) dataset
output_path = r'D:\Downloads\selfie2anime\trainB'

# Dimensions of the images inside the dataset.
# NOTE: The image sizes must be compatible meaning output_dimensions / input_dimensions is a multiple of 2
input_dimensions = (128,128,3)

# Dimensions of the images inside the dataset.
# NOTE: The image sizes must be compatible meaning output_dimensions / input_dimensions is a multiple of 2
output_dimensions = (256,256,3)

# How many times to increase the resolution by 2 (by applying the UpSampling2D layer)
super_sampling_ratio = int(output_dimensions[0] / input_dimensions[0] / 2)

# Folder where you want to save the model as well as generated samples
model_path = r"C:\Users\Vee\Desktop\python\GAN\DLSS\results"

# How many epochs between saving your model
interval = 5

# How many epochs to run the model
epoch = 100

# How many images to train at one time.
# Ideally this number would be a factor of the size of your dataset
batch = 25

# How many convolutional filters for each convolutional layer of the generator and the discriminator
conv_filters = 64

# Size of kernel used in the convolutional layers
kernel = (5,5)

# Boolean flag, set to True if the data has pngs to remove alpha layer from images
png = True
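If you are working with a less powerful GPU, a reduced configuration along these lines should lower memory usage (the values here are illustrative suggestions, not tested settings):

# Illustrative reduced settings for a smaller GPU (assumed values)
conv_filters = 32  # Half as many filters per convolutional layer
kernel = (3,3)     # Smaller convolution kernels
batch = 10         # Fewer images held in memory per training step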

Create Deep Convolutional GAN Class

This class contains 6 methods.

  • __init__(self): The class is initialized by defining the dimensions of the input image as well as the output image. The Generator and Discriminator models get initialized using build_generator() and build_discriminator().

  • build_generator(self): Defines Generator model. The Convolutional and UpSampling2D layers increase the resolution of the image by a factor of super_sampling_ratio * 2. Gets called when the DCGAN class is initialized.

  • build_discriminator(self): Defines Discriminator model. The Convolutional and MaxPooling2D layers downsample from output_dimensions to 1 scalar prediction. Gets called when the DCGAN class is initialized.

  • load_data(self): Loads data from the user specified file paths, input_path and output_path. Reshapes images from input_path to have input_dimensions. Reshapes images from output_path to have output_dimensions. Gets called in the train() method.

  • train(self, epochs, batch_size, save_interval): Trains the Generative Adversarial Network. Each epoch trains the model using the entire dataset split up into chunks defined by batch_size. If epoch is at save_interval, then the method calls save_imgs() to generate samples and saves the model at the current epoch.

  • save_imgs(self, epoch, gen_imgs, interpolated): Saves the model and generates prediction samples for a given epoch at the user specified path, model_path. Each sample contains 8 interpolated images and 8 Deep Learned Super Sampled images for comparison.

Initialization

In [3]:
class DCGAN():
    
    # Initialize parameters, generator, and discriminator models
    def __init__(self):
        
        # Set dimensions of the output image
        self.img_rows = output_dimensions[0]
        self.img_cols = output_dimensions[1]
        self.channels = output_dimensions[2]
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        
        # Shape of low resolution input image
        self.latent_dim = input_dimensions
        
        # Choose the optimizer for the models
        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes low resolution images as input and generates high resolution images
        z = Input(shape = self.latent_dim)
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        valid = self.discriminator(img)

        # The combined model  (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, valid)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

When the DCGAN class is initialized, we define the size of the images the Neural Network should expect from the dataset. The output dimensions are stored in the tuple img_shape, and the input dimensions are stored in the tuple latent_dim.

The optimizer we are using for both models is the Adam optimizer. Feel free to experiment with the learning rate and beta values of the optimizer and see what kind of results you get.

The architecture of the Generative Adversarial Network is defined here, with both models using Binary Cross Entropy loss. The choice of Binary Cross Entropy as the loss function is explained here. Feel free to experiment with other loss functions but just keep in mind that both models must use the same loss function.
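To make the loss choice concrete, here is a minimal NumPy sketch (for illustration only) of what Binary Cross Entropy computes for the Discriminator's scalar predictions:

import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    # Negative log-likelihood of the true labels under the predictions
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

# A real image (label 1) scored 0.9 is penalized lightly,
# while the same image scored 0.1 is penalized heavily
print(binary_crossentropy(np.array([1.0]), np.array([0.9])))  # ~0.105
print(binary_crossentropy(np.array([1.0]), np.array([0.1])))  # ~2.303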

Load Data

    # Load data from the specified file paths
    def load_data(self):
        
        # Initializing arrays for data and image file paths
        data = []
        small = []
        paths = []
        
        # Get the file paths of all image files in this folder
        for r, d, f in os.walk(output_path):
            for file in f:
                if '.jpg' in file or '.png' in file:
                    paths.append(os.path.join(r, file))
                    
        # For each file add high resolution image to array
        for path in paths:
            img = Image.open(path)
            
            # Resize Image
            y = np.array(img.resize((self.img_rows,self.img_cols)))
            
            # Remove alpha channel if images are PNG
            if(png):
                y = y[...,:3]
                
            data.append(y)
          
        paths = []
        
        # Get the file paths of all image files in this folder
        for r, d, f in os.walk(input_path):
            for file in f:
                if '.jpg' in file or '.png' in file:
                    paths.append(os.path.join(r, file))
                    
        # For each file add low resolution image to array
        for path in paths:
            img = Image.open(path)
            
            # Resize Image
            x = np.array(img.resize((self.latent_dim[0],self.latent_dim[1])))
            
            # Remove alpha channel if images are PNG
            if(png):
                x = x[...,:3]
                
            small.append(x)
        
            
        # Return x_train and y_train reshaped to 4 dimensions
        y_train = np.array(data)
        y_train = y_train.reshape(len(data),self.img_rows,self.img_cols,self.channels)
        x_train = np.array(small)
        x_train = x_train.reshape(len(small),self.latent_dim[0],self.latent_dim[1],self.latent_dim[2])
        
        del data
        del small
        del paths
        
        # Shuffle indexes of data
        X_shuffle, Y_shuffle = shuffle(x_train, y_train)
        
        return X_shuffle, Y_shuffle

The first method we are adding to the DCGAN class is load_data(). It preprocesses all images within the user specified paths, input_path and output_path. The images in each folder get resized to input_dimensions and output_dimensions respectively. This method gets called inside the train() method to load the data before training.

Before returning the two arrays, we shuffle x_train and y_train together. The train() method trains the models on the dataset sequentially, incrementing by the batch size each iteration, so shuffling is important to avoid introducing biases tied to the order in which the dataset happens to be stored.
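The key property of sklearn's shuffle() is that it applies the same random permutation to every array it is given, so each low resolution image stays paired with its high resolution target. A minimal sketch with stand-in arrays:

from sklearn.utils import shuffle
import numpy as np

x = np.array([1, 2, 3, 4])      # Stand-ins for low resolution images
y = np.array([10, 20, 30, 40])  # The matching high resolution images

# Both arrays are permuted in the same random order, preserving the pairing
x_s, y_s = shuffle(x, y, random_state=0)
print(x_s, y_s)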

Build Generator

    # Define Generator model
    def build_generator(self):

        model = Sequential()
        
        # 1st Convolutional Layer / Input Layer
        model.add(Conv2D(conv_filters, kernel_size=kernel, padding="same", input_shape=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        
        # Upsample the data as many times as needed to reach output resolution
        for i in range(super_sampling_ratio):
        
            # Super Sampling Convolutional Layer
            model.add(Conv2D(conv_filters, kernel_size=kernel, padding="same"))
            model.add(LeakyReLU(alpha=0.2))

            # Upsample the data (Double the resolution)
            model.add(UpSampling2D())

        # Convolutional Layer
        model.add(Conv2D(conv_filters, kernel_size=kernel, padding="same"))
        model.add(LeakyReLU(alpha=0.2))

        # Convolutional Layer
        model.add(Conv2D(conv_filters, kernel_size=kernel, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        
        # Final Convolutional Layer (Output Layer)
        model.add(Conv2D(3, kernel_size=kernel, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.summary()

        noise = Input(shape=self.latent_dim)
        img = model(noise)

        return Model(noise, img)

The second method we are adding to the DCGAN class is build_generator(). This method is called when the class is first initialized. The architecture of the Generator model is designed here. The model summary will give you a clearer idea of what is actually happening inside this model.

[Image: Generator model summary]

The input of the Generator model is a tensor representing an RGB image, in this case a (128,128,3) image. This tensor is then upsampled to (256,256,3) at the output.

The output Convolutional layer contains 3 filters representing the Red, Green, and Blue channels of an RGB image respectively.
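To see the upsampling behavior in isolation, here is a minimal sketch of a single UpSampling2D layer doubling a (128,128,3) input, built with the same Keras API used above:

from keras.layers import Input, UpSampling2D
from keras.models import Model

# UpSampling2D repeats each row and column, doubling the spatial resolution
inp = Input(shape=(128,128,3))
out = UpSampling2D()(inp)

Model(inp, out).summary()  # Output shape: (None, 256, 256, 3)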

Build Discriminator

    # Define Discriminator model
    def build_discriminator(self):

        model = Sequential()

        # Input Layer
        model.add(Conv2D(conv_filters, kernel_size=kernel, input_shape=self.img_shape,activation = "relu", padding="same"))
        
        # Downsample the image as many times as needed
        for i in range(super_sampling_ratio):
            
            # Convolutional Layer
            model.add(Conv2D(conv_filters, kernel_size=kernel))
            model.add(LeakyReLU(alpha=0.2))
        
            # Downsample the data (Half the resolution)
            model.add(MaxPooling2D(pool_size=(2, 2)))
        
        # Convolutional Layer
        model.add(Conv2D(conv_filters, kernel_size=kernel, strides = 2))
        model.add(LeakyReLU(alpha=0.2))

        # Convolutional Layer
        model.add(Conv2D(conv_filters, kernel_size=kernel, strides = 2))
        model.add(LeakyReLU(alpha=0.2))
        
        model.add(Flatten())
        
        # Output Layer
        model.add(Dense(1, activation='sigmoid'))

        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

The third method we are adding to the DCGAN class is build_discriminator(). This method is called when the class is first initialized. The architecture of the Discriminator model is designed here. The model summary will give you a clearer idea of what is actually happening inside this model.

[Image: Discriminator model summary]

The input of the Discriminator model is a tensor representing an RGB image, in this case a (256,256,3) image. The tensor is then downsampled to 252×252, 126×126, 61×61, and finally 29×29. This 29×29 tensor is flattened and passed to the output layer.

The final dense layer outputs a single scalar number, representing the prediction of the discriminator model. This prediction represents the confidence of the model in determining if the input image is “real”. A prediction of 1 means the model thinks that the image is from the original dataset. A prediction of 0 means that the model thinks that the image was generated by the Generator model.
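The resolutions quoted above follow from the standard output size formula for unpadded ('valid') convolutions; this quick sketch reproduces the chain:

# Output size of an unpadded ('valid') convolution: floor((n - k) / s) + 1
def conv_out(n, k=5, s=1):
    return (n - k) // s + 1

n = 256               # The first layer uses padding='same', so 256 stays 256
n = conv_out(n)       # 252, after the unpadded 5x5 convolution in the loop
n = n // 2            # 126, after MaxPooling2D
n = conv_out(n, s=2)  # 61, after the first strided convolution
n = conv_out(n, s=2)  # 29, after the second strided convolution
print(n)              # The 29x29 feature maps are then flattened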

Train

    # Train the Generative Adversarial Network
    def train(self, epochs, batch_size, save_interval):
        
        # Prevent script from crashing from bad user input
        if(epochs <= 0):
            epochs = 1
        
        if(batch_size <= 0):
            batch_size = 1

        # Load the dataset
        X_train, Y_train = self.load_data()
        
        # Normalizing data to be between 0 and 1
        X_train = X_train / 255
        Y_train = Y_train / 255

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
        
        # Placeholder arrays for Loss function values
        g_loss_epochs = np.zeros((epochs, 1))
        d_loss_epochs = np.zeros((epochs, 1))
        
        # Training the GAN
        for epoch in range(1, epochs + 1):
            
            # Initialize indexes for training data
            start = 0
            end = start + batch_size
            
            # Array to sum up all loss function values
            discriminator_loss_real = []
            discriminator_loss_fake = []
            generator_loss = []
            
            # Iterate through dataset training one batch at a time
            for i in range(int(len(X_train)/batch_size)):
                
                # Get batch of images
                imgs_output = Y_train[start:end]
                imgs_input = X_train[start:end]

                # Train Discriminator

                # Make predictions on current batch using generator
                gen_imgs = self.generator.predict(imgs_input)

                # Train the discriminator (real classified as ones and generated as zero)
                d_loss_real = self.discriminator.train_on_batch(imgs_output, valid)
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                #  Train Generator

                # Train the generator (wants discriminator to mistake images as real)
                g_loss = self.combined.train_on_batch(imgs_input, valid)
                
                # Add loss for current batch to sum over entire epoch
                # (d_loss[0] is the discriminator loss, d_loss[1] its accuracy)
                discriminator_loss_real.append(d_loss[0])
                discriminator_loss_fake.append(d_loss[1])
                generator_loss.append(g_loss)
                
                # Increment image indexes
                start = start + batch_size
                end = end + batch_size

            # Record the average loss of each model for this epoch
            # (epoch-end logic below is assumed from the training log further down)
            g_loss_epochs[epoch - 1] = np.mean(generator_loss)
            d_loss_epochs[epoch - 1] = np.mean(discriminator_loss_real)

            # Print progress: epoch, discriminator loss and accuracy, generator loss
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" %
                  (epoch, np.mean(discriminator_loss_real),
                   100 * np.mean(discriminator_loss_fake),
                   np.mean(generator_loss)))

            # At every save_interval, generate samples and save the model
            if epoch % save_interval == 0:

                # Super sample 8 images and build nearest neighbor
                # interpolated counterparts for the comparison figure
                gen_imgs = self.generator.predict(X_train[0:8])
                interpolated = np.array([st.resize(img, (self.img_rows, self.img_cols), order=0)
                                         for img in X_train[0:8]])
                self.save_imgs(epoch, gen_imgs, interpolated)

        return g_loss_epochs, d_loss_epochs

The fourth method we are adding to the DCGAN class is train(). This method will train the network for a specified number of epochs in increments specified by the batch size. When the training completes, the method will return two arrays representing the loss values of both models across every epoch. The loss values can be plotted using Matplotlib.

You should track the loss values and stop training the network if it starts collapsing. The network collapses if one of the models gets close to 0 loss.

[Image: example plot of generator and discriminator loss]

If the Generator gets close to 0 loss, it has figured out how to make an image that fools the Discriminator every time. This usually results in the Generator only being able to produce one type of image, a failure known as mode collapse.

If the Discriminator gets close to 0 loss, it has figured out how to distinguish the training data from generated images very accurately. The Generator then becomes unable to keep learning from the Discriminator, a failure known as the vanishing gradient problem.
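One simple way to automate this check (a sketch, not part of the original script) is to scan the per-epoch loss arrays that train() returns and flag either model approaching zero:

import numpy as np

# Hypothetical helper: flag a collapsed network from the per-epoch losses
def collapsed(g_loss, d_loss, tol=1e-3):
    return bool(np.min(g_loss) < tol or np.min(d_loss) < tol)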

To avoid losing our progress when our network collapses, we will save the model every few epochs. The user defined parameter, interval, will determine how often the model gets saved. Every time the current epoch lands on the defined interval, save_imgs() gets called. The method will save an image of some predicted samples to get a snapshot of how good the model was during that epoch.

Save Images

    # Save the model and generate prediction samples for a given epoch
    def save_imgs(self, epoch, gen_imgs, interpolated):
        
        # Define number of columns and rows
        r, c = 4, 4
        
        # Placeholder array for Matplotlib figure subplots
        subplots = []
        
        # Create figure with title
        fig = plt.figure(figsize= (40, 40))
        fig.suptitle("Epoch: " + str(epoch), fontsize=65)
        
        # Initialize counters needed to track indexes across multiple arrays
        img_count = 0
        index_count = 0
        x_count = 0
        
        # Loop through columns and rows of the figure
        for i in range(1, c+1):
            for j in range(1, r+1):
                # If row is even, plot the predictions
                if(j % 2 == 0):
                    img = gen_imgs[index_count]
                    index_count = index_count + 1
                # If row is odd, plot the interpolated images
                else:
                    img = interpolated[x_count]
                    x_count = x_count + 1
                # Add image to figure, add subplot to array
                subplots.append(fig.add_subplot(r, c, img_count + 1))
                plt.imshow(img)
                img_count = img_count + 1
        
        # Add title to columns of figure
        subplots[0].set_title("Interpolated", fontsize=45)
        subplots[1].set_title("Predicted", fontsize=45)
        subplots[2].set_title("Interpolated", fontsize=45)
        subplots[3].set_title("Predicted", fontsize=45)
                
        # Save figure to .png image in specified folder
        fig.savefig(model_path + "\\epoch_%d.png" % epoch)
        plt.close()
        
        # save model to .h5 file in specified folder
        self.generator.save(model_path + "\\generator" + str(epoch) + ".h5")

The fifth and last method we are adding to the DCGAN class is save_imgs(). This method will save the model at the current epoch and plot 8 super sampled images compared to their interpolated counterparts. The generated sample will allow you to compare the quality of the DLSS model vs. Nearest Neighbor Interpolation.

[Image: generated sample comparing interpolated and predicted images]

This method is currently configured to save every 5 epochs. This can be adjusted with the interval parameter. Frequently saving your model is a good way to track the progress your network makes during training.

Initializing The DCGAN Class

We are now done with creating the DCGAN class and ready to train our Generative Adversarial Network. First, we need to create an instance of the class and assign it to a variable.

In [4]:

dcgan = DCGAN()

This will initialize the Generator and Discriminator models and print their summaries.

Training The Generative Adversarial Network

Now that we have our DCGAN class object, we just need to call the train() method to start training. With this script, you should generally pick a high number of epochs and track the loss values throughout the process. If the network starts collapsing, stop the training early and check the generated samples to figure out which saved model performed best.

The train() method returns two arrays containing the loss values of the two models throughout training. We will assign these values to g_loss, and d_loss and plot them.

In [5]:

g_loss, d_loss = dcgan.train(epochs=epoch, batch_size=batch, save_interval=interval)
 
1 [D loss: 0.640689, acc.: 59.37%] [G loss: 0.967596]
2 [D loss: 0.575859, acc.: 73.34%] [G loss: 1.787223]
3 [D loss: 0.656025, acc.: 61.31%] [G loss: 1.042790]
4 [D loss: 0.656616, acc.: 60.19%] [G loss: 0.998186]
5 [D loss: 0.674997, acc.: 56.04%] [G loss: 0.893507]

Plot Loss

In [6]:

plt.plot(g_loss)
plt.plot(d_loss)
plt.title('GAN Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Generator', 'Discriminator'], loc='upper left')
plt.show()
[Image: plot of generator and discriminator loss across epochs]

Analyze Results

Once you have trained a model you are satisfied with, it is time to test your model on various images. You can do this using this provided script. To use it, you must first define a few parameters. Set input_dimensions and output_dimensions to the same values you used in the training script. Set model_path to the path of the H5 model you want to use; the H5 models get saved during save_imgs() in the folder specified by model_path in the training script. Set the dataset_path parameter to the folder containing the images you want to test your model on. If the images contain any PNGs, set the png boolean flag to True to remove the alpha layers from the images. To make animated GIFs of your results, place the frames of a video inside the dataset_path folder. Lastly, set the save_path parameter to the folder where you want the results of the model inference to be saved.

In [19]:

from keras.models import load_model

model = load_model(model_path)

Load Images and Super Sample

In [29]:

paths = []
count = 0

for r, d, f in os.walk(dataset_path):
    for file in f:
        if '.png' in file or '.jpg' in file:
            paths.append(os.path.join(r, file))

for path in paths:
    
    # Select image
    img = Image.open(path)

    #create plot
    f, axarr = plt.subplots(1,3,figsize=(15,15),gridspec_kw={'width_ratios': [1,super_sampling_ratio,super_sampling_ratio]})
    axarr[0].set_xlabel('Original Image (' + str(input_dimensions[0]) + 'x' + str(input_dimensions[1]) + ')', fontsize=10)
    axarr[1].set_xlabel('Interpolated Image (' + str(output_dimensions[0]) + 'x' + str(output_dimensions[1]) + ')', fontsize=10)
    axarr[2].set_xlabel('Super Sampled Image (' + str(output_dimensions[0]) + 'x' + str(output_dimensions[1]) + ')', fontsize=10)

    #original image
    x = img.resize((input_dimensions[0],input_dimensions[1]))
    
    #interpolated (resized) image
    y = x.resize((output_dimensions[0],output_dimensions[1]))
    
    
    x = np.array(x)
    y = np.array(y)
    
    # Remove alpha channel if images are PNG
    if(png):
        x = x[...,:3]
        y = y[...,:3]
    
    #plotting first two images
    axarr[0].imshow(x)
    axarr[1].imshow(y)
    
    #plotting super sampled image
    x = x.reshape(1,input_dimensions[0],input_dimensions[1],input_dimensions[2])/255
    result = np.array(model.predict_on_batch(x))*255
    result = result.reshape(output_dimensions[0],output_dimensions[1],output_dimensions[2])
    np.clip(result, 0, 255, out=result)
    result = result.astype('uint8')
                
    axarr[2].imshow(result)
    
    # Save image
    f.savefig(save_path + '\\frame_%d.png' % count)
    
    # Increment file name counter
    count = count + 1

Once you have set these parameters, the script will reshape each image to the size specified by input_dimensions and feed it as input to your model. The script then plots the image output by your model next to what the input image would look like upscaled with Nearest Neighbor Interpolation. These generated samples are saved in the folder specified by the save_path parameter and are a good way to analyze the quality of your model: you can compare the way an image looks before it enters the model, after it passes through the model, and when it is super sampled with another technique. The results of some of my models are displayed below.

Results

[Images: super sampled results from trained models]

Conclusion

This article provides a general framework for training a DLSS Generative Adversarial Network using Keras. You can create your own DLSS models for various resolutions and datasets using this script. The full version of this code can be found here.

Once you have trained a model you are satisfied with, you can use this script to generate outputs and analyze your results. The script will also provide you with code to create GIFs of super sampled video frames.
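The GIF step itself is not shown in the listing above; a minimal sketch of stitching the saved frames together with PIL might look like the following (it assumes the save_path folder and the frame_%d.png naming used by the inference loop):

import os
from PIL import Image

# Collect the saved frames in order (frame_0.png, frame_1.png, ...)
frames = []
count = 0
while os.path.exists(os.path.join(save_path, 'frame_%d.png' % count)):
    frames.append(Image.open(os.path.join(save_path, 'frame_%d.png' % count)))
    count = count + 1

# Write an animated GIF at 100 ms per frame, looping forever
if frames:
    frames[0].save(os.path.join(save_path, 'result.gif'), save_all=True,
                   append_images=frames[1:], duration=100, loop=0)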

Github

Dataset

Jupyter Notebook
