Sure! I'll try to give a more intuitive explanation. Basically, it's been known for a while that intelligence is closely related to compression. Why? Suppose you're trying to describe a phenomenon you see (such as an IQ test). To describe precisely what's going on, you'd have to write it in some kind of formal language, like a programming language. Maybe your description looks like:
This wavy piece can be described by algorithm #1, while this spiky piece can be described by algorithm #2, while this...
More precisely, you try to write your phenomenon as a weighted sum of these algorithms:
phenomenon = sum over algorithms i of ( weight_i * algorithm_i )
There are exponentially more algorithms with more bits, so if you want this sum to ever converge, you need to give exponentially smaller weights to longer algorithms. Thus, most of the weight is concentrated on shorter algorithms, so a simple explanation is going to be a really good one!
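To make the "exponentially smaller weights" point concrete, here's a quick numerical sketch of my own (not anything from the paper): count the programs of each length and shrink their individual weights fast enough that the total stays finite.

```python
# Toy illustration: there are 2**L possible bit-string "programs" of length L, so if
# each kept weight 2**(-L) every length would contribute the same total and the sum
# over all lengths would diverge. Shrinking the per-program weight faster (here
# 2**(-2*L)) makes the sum converge, with almost all of the weight on short programs.
total = 0.0
for length in range(1, 30):
    num_programs = 2 ** length            # exponentially many programs of this length
    weight_each = 2.0 ** (-2 * length)    # exponentially smaller weight per program
    total += num_programs * weight_each   # this length contributes 2**(-length) in total
print(total)  # ~1.0, and lengths 1-5 already account for ~97% of it
```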
What the authors are trying to do is find a simple (small number of bits) algorithm that reconstructs the puzzles and the example solutions they're given. As a byproduct, the algorithm will construct a solution to the final problem that's part of the puzzle. If the algorithm is simple enough, it won't be able to just memorize the given examples—it has to actually learn the trick to solving the puzzle.
Now, they could have just started enumerating programs, beginning with `0`, `1`, `00`, `01`, ..., and seeing what their computer did with the bits. Eventually, they might hit on a simple bit sequence that the computer interprets as an actual program, and in fact one that solves the puzzle. But that's very slow, and in fact the halting problem says you can't rule out some of your programs running forever (and your search getting stuck). So the authors turned to a specialized kind of computer, one that they know will stop in a finite number of steps...
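For contrast, here's roughly what that brute-force enumeration would look like, with the usual workaround of capping each program's running time so the search can't get stuck. This is my own sketch of the approach they *didn't* take; `run_program` and `solves_puzzle` are hypothetical stand-ins supplied by the caller.

```python
from itertools import product

def bit_strings():
    """Yield candidate programs in length order: 0, 1, 00, 01, 10, 11, 000, ..."""
    length = 1
    while True:
        for bits in product("01", repeat=length):
            yield "".join(bits)
        length += 1

def brute_force_search(run_program, solves_puzzle, max_steps=10_000):
    """Enumerate programs shortest-first, capping each run so a non-halting
    program can't stall the search forever."""
    for bits in bit_strings():
        output = run_program(bits, max_steps=max_steps)  # None if the step budget ran out
        if output is not None and solves_puzzle(output):
            return bits
```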
...and that "computer" is a fixed-size neural network! The bit sequence they feed in determines (1) the inputs to the neural network, and (2) the weights in the neural network. Now, they cheat a little: they actually just specify the inputs/weights directly, and then figure out what bits would have given them those inputs/weights. That's because it's easier to search in the input/weight space; people do this all the time with neural networks.
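In code, "specify the inputs/weights and work out what bits would have produced them" is the standard reparameterization trick. A minimal sketch, with tensor names and sizes that are mine rather than the paper's:

```python
import torch

# Keep trainable (mu, log_sigma) for each tensor and sample concrete values from
# them, so gradient descent can do the searching instead of bit-string enumeration.
def sample(mu, log_sigma):
    eps = torch.randn_like(mu)              # the random "raw bits"
    return mu + torch.exp(log_sigma) * eps  # a concrete input/weight tensor

mu = torch.zeros(64, requires_grad=True)
log_sigma = torch.zeros(64, requires_grad=True)
z = sample(mu, log_sigma)  # feed z into the fixed-size network
```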
They initialize the space of inputs/weights as random normal distributions, but they want to change these distributions to be concentrated in areas that correctly solve the puzzle. This means they need additional bits to specify how to change the distributions. How many extra bits does it take to specify a distribution q, if you started with a distribution p? Well, it's
- sum q(x) log p(x) + sum q(x) log q(x)
(expected # bits to encode a sample from q using p's code) - (expected # bits using q's own code)
This is known as the KL-divergence, which we write as KL(q||p). They want to minimize the length of their program, which means they want to minimize the expected number of additional bits they have to use, i.e. KL(q(inputs)||p(inputs)) + KL(q(weights)||p(weights)).
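As a code sketch, that extra-bits cost is essentially a one-liner in most frameworks; here it is for normal distributions (the shapes below are just for illustration, not the paper's actual tensors):

```python
import math
import torch
from torch.distributions import Normal, kl_divergence

p = Normal(torch.zeros(64), torch.ones(64))           # prior p(inputs)
q = Normal(torch.randn(64), torch.full((64,), 0.5))   # learned q(inputs)
extra_nats = kl_divergence(q, p).sum()                # KL(q || p), in nats
extra_bits = extra_nats / math.log(2)                 # convert nats -> bits
```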
There's a final piece of the puzzle: they want their computer to give exactly the correct answer for the example solutions they know. So, if the neural network outputs an incorrect value, they need extra bits to say it was incorrect, and actually here's the correct value. Again, the expected number of bits is just going to be a KL-divergence, this time between the neural network's output and the correct answers.
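Here's a sketch of what that corrector cost looks like, assuming the network outputs a probability for each possible value of each output cell (shapes are illustrative, not the paper's):

```python
import math
import torch
import torch.nn.functional as F

# If the network assigns probability p to the true value of a cell, encoding the
# truth costs -log2(p) bits; summed over cells, that's just the cross-entropy
# against the known answers.
logits = torch.randn(100, 10)           # e.g. 100 output cells, 10 possible values each
targets = torch.randint(0, 10, (100,))  # the known correct values
correction_nats = F.cross_entropy(logits, targets, reduction="sum")
correction_bits = correction_nats / math.log(2)
```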
Putting this all together, they have a simple computer (neural network + corrector), and a way to measure the bitlength of the various "programs" they can feed into the computer (inputs/weights). Every program will give the correct answers for the known information, but the very simplest programs are much more likely to give the correct answer for the unknown puzzles too! So, they just have to train their distributions q(inputs), q(weights) to concentrate on programs that have short bitlengths, by minimizing the loss function

KL(q(inputs)||p(inputs)) + KL(q(weights)||p(weights)) + (expected # bits to correct the outputs)
They specify p(inputs) as the standard normal distribution, p(weights) as a normal distribution with variance around 1/(dimension of inputs) (so the values in the neural network don't explode), and finally have trainable parameters for the mean and variance of q(inputs) and q(weights).
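Here's a minimal sketch of one training step that puts the pieces together, under my own assumptions: a tiny linear stand-in for their network, made-up shapes, and Adam as the optimizer. Only the structure of the loss follows the description above.

```python
import math
import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence

d_in, n_cells, n_vals = 64, 100, 10                   # illustrative sizes

p_inputs = Normal(torch.zeros(d_in), torch.ones(d_in))          # standard normal prior
w_std = 1.0 / math.sqrt(d_in)                                   # variance ~ 1/(input dimension)
p_weights = Normal(torch.zeros(d_in, n_cells * n_vals),
                   torch.full((d_in, n_cells * n_vals), w_std))

# Trainable mean / log-std for q(inputs) and q(weights)
q_in_mu = torch.zeros(d_in, requires_grad=True)
q_in_ls = torch.zeros(d_in, requires_grad=True)
q_w_mu = torch.zeros(d_in, n_cells * n_vals, requires_grad=True)
q_w_ls = torch.full((d_in, n_cells * n_vals), math.log(w_std), requires_grad=True)
opt = torch.optim.Adam([q_in_mu, q_in_ls, q_w_mu, q_w_ls], lr=1e-2)

def training_step(targets):  # targets: known correct values, long tensor of shape (n_cells,)
    x = q_in_mu + torch.exp(q_in_ls) * torch.randn(d_in)                  # sample inputs
    w = q_w_mu + torch.exp(q_w_ls) * torch.randn(d_in, n_cells * n_vals)  # sample weights
    logits = (x @ w).view(n_cells, n_vals)                                # stand-in "network"

    kl_bits = (kl_divergence(Normal(q_in_mu, torch.exp(q_in_ls)), p_inputs).sum()
               + kl_divergence(Normal(q_w_mu, torch.exp(q_w_ls)), p_weights).sum()) / math.log(2)
    correction_bits = F.cross_entropy(logits, targets, reduction="sum") / math.log(2)

    loss = kl_bits + correction_bits   # total description length, in bits
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()
```

Minimizing this loss is just "find the shortest program consistent with the known examples", written in a form gradient descent can handle.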