Image Reconstruction with a pre-trained GAN using perceptual and contextual losses

In this post I detail my implementation and some initial results for image reconstruction using a pre-trained Generative Adversarial Network (GAN). I follow the approach of Yeh et al., which uses a perceptual and a contextual loss in the reconstruction stage. This post is part of my Deep Learning (IFT6266) course project.

As usual, my broader plan and a summary of the project can be found here, more details on the project here, and all the code used is in this GitHub repo.

The pre-trained GAN

Following my previous post about the course project, the model was trained for a longer period of time to improve performance. For details on the GAN implementation itself, refer to that post.

The model’s parameters were saved every 5 epochs so that the discriminator and generator can be reused later. This also lets us compare the performance of several snapshots and choose which one to use for image reconstruction/inpainting.
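For illustration, checkpointing every 5 epochs can be as simple as the sketch below, which dumps all parameter arrays to a single .npz file; the function name and storage format are my own assumptions, and the actual repo may save parameters differently.

```python
import numpy as np

# Purely illustrative: save generator/discriminator parameter arrays every `every`
# epochs so any snapshot can be reloaded later for sampling or reconstruction.
def maybe_save(epoch, gen_params, disc_params, every=5, prefix="dcgan"):
    if epoch % every == 0:
        np.savez("{}_epoch{:03d}.npz".format(prefix, epoch),
                 **{"gen_{}".format(i): p for i, p in enumerate(gen_params)},
                 **{"disc_{}".format(i): p for i, p in enumerate(disc_params)})
```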

To select which saved model to use, we start by looking at how the losses evolve over training. This figure is similar to the one in my previous post, differing only in the longer training time and longer averaging period.

[Figure: evolution of the DCGAN discriminator and generator losses over 60 epochs (30 discriminator / 10 generator steps per cycle)]

The first thing that strikes the eye is the increase in the generator’s loss around the 30,000 training-step mark. I interpret this as follows:

  • The discriminator is now more confident that the generated image is fake. A generator loss of 0.5 indicates a score of 0.6, while a loss of 1.0 indicates a score of 0.36. Given this better performance, the discriminator may have figured out some characteristic of the generator’s output that it can recognize as a sign of a fake image.
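To make the loss-to-score conversion explicit: assuming the generator loss is the non-saturating form L_G = -\log D(G(\textbf{z})) (my assumption; the exact form depends on the implementation), the discriminator’s score can be recovered as D(G(\textbf{z})) = e^{-L_G}, which yields the values quoted above.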

One issue with the high variance noticeable towards the end of training is that it could also be explained by the optimization schedule. Since the discriminator and generator are trained for 30 and 10 steps at a time, respectively, high losses occur whenever one model starts its training phase while the other is already performing well.

Selecting the pre-trained model

Now comes the question of selecting the appropriate model to move to the second stage, image reconstruction. To do so, we examine random samples generated at different epochs.

The decision rule I opted for is to select the model generating images that look most realistic to my eye. I feel this is an appropriate criterion since that is how the final images will be evaluated, rather than by measuring the discriminator’s score or the loss.

Starting at 30 epochs on the first row, each subsequent row represents 5 more training epochs, the last one being at 55 epochs.

One thing to remember is that these are generated purely from noise. We can notice some fine detail in the last two rows, in addition to a range of colour, while still having some structure. This leads me to believe that the longer-trained model will be appropriate for the reconstruction stage.

Reconstruction stage

Yeh et al. argue for using a pre-trained GAN and then finding the appropriate image reconstruction by minimizing another loss function over the generator’s input. The noise vector corresponding to that best-fitting image is then passed through the network’s generator to produce the reconstruction.

They split the total loss into two components:

  1. Contextual loss — This component ensures the uncorrupted portion of the image is well reproduced by the generator. This is done by minimizing the distance between the uncorrupted portion of the generated image and the uncorrupted portion of the image being reconstructed, i.e. ||\textbf{M} \odot G(\textbf{z}) - \textbf{M} \odot \textbf{y}||_p^p, where \textbf{M} denotes the uncorrupted mask. Based on empirical results, the authors recommend the l_1-norm, as it gave them sharper images than the l_2-norm.
  2. Perceptual loss — This component ensures the generated image is recognized as real by the network’s discriminator. They recommend minimizing \log(1 - D(G(\textbf{z}))), the \log of the probability that the generated image is fake.

Conceptually, we are therefore trying to find the noise \textbf{z} that minimizes the difference between the uncorrupted part of the image being reconstructed and the uncorrupted part of the generated image, while ensuring the full generated image would be considered real by the discriminator.

By back-propagating the total loss w.r.t. \textbf{z}, and with \lambda controlling the importance of the perceptual loss, we solve the following,

\hat{\textbf{z}} =  \text{argmin}_\textbf{z}( L_{contextual}(\textbf{z}) + \lambda L_{perceptual} (\textbf{z}))
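To make the objective concrete, here is a minimal PyTorch-style sketch of the two loss terms and their weighted sum. The names G, D, M, y and total_loss, as well as the choice of framework, are illustrative assumptions on my part and not the actual code from the repo.

```python
import torch

# G maps a noise vector z to an image, D maps an image to the probability of it
# being real, M is 1 on uncorrupted pixels and 0 on corrupted ones, y is the
# corrupted image. All of these are assumed interfaces for the sketch.

def contextual_loss(z, M, y, G, p=1):
    # ||M * G(z) - M * y||_p^p over the uncorrupted region
    return (M * G(z) - M * y).abs().pow(p).sum()

def perceptual_loss(z, G, D, eps=1e-8):
    # log(1 - D(G(z))): log-probability that the generated image is judged fake
    return torch.log(1.0 - D(G(z)) + eps).sum()

def total_loss(z, M, y, G, D, lam=1e-4):
    # lambda controls the weight of the perceptual term
    return contextual_loss(z, M, y, G) + lam * perceptual_loss(z, G, D)
```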

The image can then be reconstructed by,

\textbf{x}_{reconstructed} = \textbf{M} \odot \textbf{y} + (1 - \textbf{M}) \odot G(\hat{\textbf{z}})

To put it in words, we are only taking the corrupted portion from the generated image, while keeping the uncorrupted part of the original image.

To solve for \hat{\textbf{z}}, I run gradient descent until the l_1-norm of the gradient is < 0.0001. This threshold is another hyper-parameter to choose, but I found it offered a good balance between convergence speed and loss minimization.
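Below is a rough sketch of this optimization and of the blending step above, reusing the hypothetical total_loss helper from the previous snippet. The learning rate, the l_1 gradient-norm stopping rule and the blending formula are the ones mentioned in this post; everything else (noise dimension, step cap, function names) is assumed.

```python
import torch

def reconstruct(M, y, G, D, z_dim=100, lr=0.9, lam=1e-4, max_steps=10000):
    # Plain gradient descent on z, stopped once the l1-norm of the gradient
    # falls below 1e-4 (with max_steps as a safety net).
    z = torch.randn(1, z_dim, requires_grad=True)
    for _ in range(max_steps):
        loss = total_loss(z, M, y, G, D, lam)
        grad, = torch.autograd.grad(loss, z)
        if grad.abs().sum() < 1e-4:
            break
        with torch.no_grad():
            z -= lr * grad
    # Keep the uncorrupted pixels of y and fill the hole with the generator's output.
    with torch.no_grad():
        x_reconstructed = M * y + (1 - M) * G(z)
    return x_reconstructed, z
```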

To further smooth the reconstruction, the authors recommend applying Poisson blending. However, I haven’t had the chance to implement it yet.

An important point in favour of this method is that it doesn’t require the corrupted portion of the image to be the center: the model isn’t trained to fill in a fixed 32×32 crop, so the corruption mask is versatile. I think this is a major advantage over techniques that focus only on the cropped-out center.
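To illustrate that versatility, the mask \textbf{M} is just a binary image, so a centered hole and an arbitrarily placed hole are built the same way. The 64×64 image size and 32×32 hole in this sketch are assumptions for the sake of the example.

```python
import numpy as np

def center_mask(size=64, hole=32):
    # 1 = uncorrupted pixel, 0 = corrupted pixel; hole in the middle of the image
    M = np.ones((size, size), dtype=np.float32)
    lo = (size - hole) // 2
    M[lo:lo + hole, lo:lo + hole] = 0.0
    return M

def random_block_mask(size=64, hole=32, rng=np.random):
    # Same idea, but the corrupted block can sit anywhere in the image
    M = np.ones((size, size), dtype=np.float32)
    r, c = rng.randint(0, size - hole, size=2)
    M[r:r + hole, c:c + hole] = 0.0
    return M
```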

Preliminary results

In this section I detail preliminary results obtained with the procedure described above. The authors mention that \lambda is very important to the quality of the reconstruction; a more in-depth analysis of the hyper-parameters will need to be done.

Below are sets of randomly selected images from the validation set, showing the corrupted, reconstructed and true image. Results are shown for different values of the hyper-parameters of the \hat{\textbf{z}} optimization. Unlike the authors, I used the l_2-norm for the contextual loss.

  • \lambda = 10^{-5} and learning rate = 0.9
  • \lambda = 10^{-4} and learning rate = 0.9
  • \lambda = 10^{-6} and learning rate = 0.9

Analysis

In the model’s defence, for some of the pictures above it simply cannot know what the interior should be. Looking, for example, at the top-left picture of the first set: without the caption, the model cannot know the suit should be red.

It also seems the model is not necessarily trying to fill the hole with the true image; it is simply trying to produce a realistic one that fits well with the contour. This is where the captions would come into play, as they would provide much more information for the reconstruction.

That said, I find these results surprisingly good and I am pleased with them! The reconstructed images look like what I would expect from a generative model: they seem appropriate, especially in the blending of the colours, but in a sense, if you stare at one for too long to judge whether it really looks good, it starts to look weird.

Also, there are some images that simply do not work, but let’s not talk about those; at least I’m not only showing the beautiful ones :).

Next step

One thing I noticed when solving for \hat{\textbf{z}} is that pure gradient descent may not be appropriate, since it can land in a local minimum determined by the initial value of \textbf{z} when the optimization starts. An alternative would be to draw a mini-batch of initial values, apply gradient descent to each of them, and keep the noise vector with the lowest loss (a rough sketch follows below).
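Here is the shape of that multi-restart idea, reusing the hypothetical reconstruct and total_loss helpers from the earlier snippets; this is not implemented yet, just a sketch of the approach.

```python
def reconstruct_best_of(M, y, G, D, n_restarts=8, **kwargs):
    # Optimize several independently initialized noise vectors and keep the one
    # reaching the lowest total loss, to reduce sensitivity to the starting point.
    best_loss, best_x, best_z = float("inf"), None, None
    for _ in range(n_restarts):
        x_rec, z = reconstruct(M, y, G, D, **kwargs)
        loss = float(total_loss(z, M, y, G, D))
        if loss < best_loss:
            best_loss, best_x, best_z = loss, x_rec, z
    return best_x, best_z
```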

In other words, I’ll have to think about this one…

Also, implementing Poisson blending as suggested in the paper might be a good idea!
