A gander at GANs
I recently gave myself a project to bring together two recent interests of mine: learning PyTorch and GANs. PyTorch is an open-source machine learning library for Python that has been gaining traction over competitors like TensorFlow, so I thought it was high time I learned it. Independently of that, I've become fascinated by Generative Adversarial Networks (GANs for short), so I decided to turn my first PyTorch model into a GAN for generating natural language. GANs are one of the most fascinating machine learning developments of the last couple of years, and after attending a couple of talks about their potential applications in NLP, I wanted to better understand what the fuss was about.
Here is the basic idea behind a GAN. You train two neural nets: one that generates fake data, and another that tries to tell fake data apart from real data. The two networks play off each other, with the discriminator being used to improve the generator and vice versa. GANs were originally developed for image generation, which remains the clearest illustration of the idea. The image below is from deeplearning4j's introduction to GANs:
For example, you can train a GAN to generate fake photos of birds that look like real photos of birds. The generator starts out producing images essentially at random, so its output looks like white noise. At this point the discriminator's job is easy: look at a bunch of real images of birds, look at a bunch of noise produced by the generator, and learn to tell the difference.
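Formally, the original GAN paper (Goodfellow et al., 2014) describes this setup as a two-player minimax game. Let D(x) be the discriminator's estimated probability that x is real, and let G(z) be the generator's output for a random noise vector z. The discriminator tries to maximize the following value function, and the generator tries to minimize it:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]$$

The first term rewards the discriminator for recognizing real data. The second term rewards it for rejecting forgeries, and the generator pushes that same term in the opposite direction.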
The trick is that the generator can be trained to minimize the discriminator's success. That is, the weights of the neural net that forges bird images are updated to minimize the likelihood that the discriminator correctly classifies its output as fake. The generator thus learns how to trick the discriminator. But as the generator improves, so does the discriminator, because it is repeatedly trained on the increasingly realistic bird pictures the generator forges.
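To make "minimize the discriminator's success" concrete, here is a minimal sketch of the two losses in plain Python. It assumes the discriminator outputs a probability that its input is real. The helper names (`bce`, `discriminator_loss`, `generator_loss`) are illustrative and are not taken from the post's code. The generator loss uses the common "non-saturating" variant suggested in the original GAN paper: rather than minimizing log(1 − D(G(z))), the generator minimizes −log D(G(z)). In other words, it is scored as if its fakes should be labeled real.

```python
import math


def bce(p, target, eps=1e-12):
    """Binary cross-entropy for one predicted probability p that the input is real."""
    p = min(max(p, eps), 1 - eps)  # clamp to avoid log(0)
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))


def discriminator_loss(d_real, d_fake):
    """The discriminator wants D(real) -> 1 and D(fake) -> 0."""
    real_term = sum(bce(p, 1.0) for p in d_real) / len(d_real)
    fake_term = sum(bce(p, 0.0) for p in d_fake) / len(d_fake)
    return real_term + fake_term


def generator_loss(d_fake):
    """Non-saturating generator loss: score the fakes against the 'real' label."""
    return sum(bce(p, 1.0) for p in d_fake) / len(d_fake)


# Early on, the discriminator easily spots the fakes, so the generator's loss is high.
print(round(generator_loss([0.1, 0.2]), 3))  # 1.956
# As forgeries improve, D(fake) rises and the generator's loss falls.
print(round(generator_loss([0.6, 0.7]), 3))  # 0.434
```

In PyTorch, `torch.nn.BCELoss` computes this same quantity. A real training loop would typically apply it to the discriminator's outputs rather than hand-rolling the loss as above.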
This feedback loop ideally results in pictures that look like actual pictures of birds, as well as a discriminator that is really good at detecting forgeries. The image below, found here, shows some actual forgeries produced by a GAN:
Pretty impressive. Can GANs be used to "forge" language? There has been some work in this area, such as this paper. The results are modest, but I am nonetheless intrigued by the idea of adversarial language generation.
As I was coding my first PyTorch model, I thought it would be fun to turn it into a GAN. This GAN is fed questions from the SQuAD (Stanford Question Answering Dataset) data set and tasked with producing forged English "questions". The generator is a simple feed-forward network that turns a "seed" of 1000 random digits into a string of words. The discriminator is a convolutional neural net whose inputs are sequences of one-hot encoded words. In each round of training, the discriminator is shown a few thousand real questions and then a few thousand fake ones produced by the generator. The generator is then trained to minimize the discriminator's success rate.
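Below is a minimal, standard-library sketch of what the two kinds of discriminator input could look like. Real questions are encoded as fixed-length sequences of one-hot rows, and a feed-forward generator maps a noise vector to one probability distribution over the vocabulary per word position. Everything here is a hypothetical illustration rather than the post's actual code:

- the toy `VOCAB`, `SEQ_LEN`, and `NOISE_DIM`;
- the helpers `one_hot_sequence`, `make_generator`, and `decode`;
- the single random linear layer standing in for the generator.

It also assumes that fakes are passed to the discriminator as softmax rows rather than sampled one-hot words.

```python
import math
import random

# Hypothetical toy vocabulary; a real model would build this from the SQuAD questions.
VOCAB = ["<pad>", "what", "is", "the", "capital", "of", "?", "who", "when"]
WORD_TO_ID = {w: i for i, w in enumerate(VOCAB)}
SEQ_LEN = 6     # assumed maximum question length, in tokens
NOISE_DIM = 16  # the post uses a 1000-number seed; shrunk here for readability


def one_hot_sequence(words, seq_len=SEQ_LEN):
    """Encode a tokenized question as seq_len one-hot rows, padded with <pad>."""
    ids = [WORD_TO_ID[w] for w in words[:seq_len]]
    ids += [WORD_TO_ID["<pad>"]] * (seq_len - len(ids))
    return [[1.0 if j == i else 0.0 for j in range(len(VOCAB))] for i in ids]


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def make_generator(rng):
    """Hypothetical stand-in for a feed-forward generator: one random linear layer
    mapping a noise vector to SEQ_LEN x len(VOCAB) scores, then a softmax per position."""
    vocab_size = len(VOCAB)
    weights = [[rng.gauss(0.0, 0.1) for _ in range(NOISE_DIM)]
               for _ in range(SEQ_LEN * vocab_size)]

    def generate(noise):
        scores = [sum(w * z for w, z in zip(row, noise)) for row in weights]
        return [softmax(scores[p * vocab_size:(p + 1) * vocab_size])
                for p in range(SEQ_LEN)]

    return generate


def decode(rows):
    """Map each row back to its most likely word, dropping padding."""
    words = [VOCAB[max(range(len(row)), key=row.__getitem__)] for row in rows]
    return " ".join(w for w in words if w != "<pad>")


rng = random.Random(0)
generate = make_generator(rng)
real = one_hot_sequence("what is the capital ?".split())
fake = generate([rng.random() for _ in range(NOISE_DIM)])
print(decode(real))  # what is the capital ?
print(decode(fake))  # depends on the untrained random weights
```

Real and fake inputs end up with the same SEQ_LEN × len(VOCAB) shape, so one convolutional discriminator can consume both. Feeding it the generator's softmax rows, rather than hard argmax words, keeps the generator differentiable. Whether the original code takes this approach is an assumption.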
I had figured that my first, rather simple attempt would generate nothing but noise, but the results were actually pretty good. This is very much a work in progress: producing truly convincing forgeries will likely require more features, a more complex model architecture, and multiple objectives. Still, the results below show that the model picks up some basic patterns of question formation rather quickly. Here are some "questions" produced by the model after a few rounds of training:
# what old america legend states ?
# when is original defence eisenhowers ?
# is nyc highest education ?
# what is worlds capital ?
# what killing did napoleon ?
# is welsh british born ?
# what parents of most species ?
# what is indigenous of human stems ?
Yes, even the best ones are admittedly...
But keep in mind that this first pass uses discriminator success (or lack thereof) as its only objective, so the model has no idea what any of these words mean or what their grammatical properties are. Seen in that light, this is an interesting first step.
The code for this can be found on GitHub here.
In any case, what practical use could this have? Beyond the obvious applications in natural language generation, training a really good discriminator, one capable of reliably separating almost-well-formed sentences from actually well-formed ones, could be very useful. Anyone who has ever received an e-mail from a Nigerian prince will appreciate this.