|
4 | 4 | "cell_type": "markdown",
|
5 | 5 | "metadata": {},
|
6 | 6 | "source": [
|
7 |
| - "## Learning in Neural Networks\n", |
| 7 | + "## Backpropagation for humans\n", |
8 | 8 | "\n",
|
9 | 9 | "\n",
|
10 | 10 | "This is probably the least understood algorithm in Machine Learning but is extremely intuitive. In this post we'll explore how to mathematically derive backpropagation and get an intuition how it works."
|
|
18 | 18 | "The learning process is simply adjusting the weights and biases that's it! The Neural Netowork does this by a process called Backpropagation. The steps are as follows:\n",
|
19 | 19 | "1. Randomly initialise weights\n",
|
20 | 20 | "2. __Forward Pass__: Predict a value using an activation function. \n",
|
21 |
| - "2. See how bad you're performing using loss function. \n", |
22 |
| - "3. __Backward Pass__: Backpropagate the error. That is, tell your network that it's wrong, and also tell what direction it's supposed to go in order to reduce the error. This step updates the weights (here's where the network learns!)\n", |
23 |
| - "4. Repeat steps 2 & 3 until the error is reasonably small or for a specified number of iterations. \n", |
| 21 | + "3 See how bad you're performing using loss function. \n", |
| 22 | + "4. __Backward Pass__: Backpropagate the error. That is, tell your network that it's wrong, and also tell what direction it's supposed to go in order to reduce the error. This step updates the weights (here's where the network learns!)\n", |
| 23 | + "5. Repeat steps 2 & 3 until the error is reasonably small or for a specified number of iterations. \n", |
24 | 24 | "\n",
|
25 | 25 | "Step 3 is the most important step. We'll mathematically derive the equation for updating the values. \n",
|
26 | 26 | "\n",
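Before deriving anything, here's a minimal NumPy sketch of the whole loop on a tiny made-up dataset, just to see the five steps in one place. The toy data, the shapes, and the omission of bias terms are my simplifications, and the two gradient lines inside the loop are exactly the formulas the rest of this post derives, so take them on faith for now. (This sketch uses the rows-as-examples convention `X @ W`, while the derivation below works with $\theta \cdot X^T$.)

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sigmoid_prime(z):
    s = sigmoid(z)
    return s * (1.0 - s)

# Tiny made-up dataset: 4 examples, 2 features, 1 output (shapes are arbitrary)
X = np.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
y = np.array([[0.], [1.], [1.], [0.]])

rng = np.random.default_rng(0)
W1 = rng.normal(size=(2, 3))          # step 1: randomly initialise weights
W2 = rng.normal(size=(3, 1))
lr = 0.5

for i in range(1000):
    # step 2: forward pass
    z2 = X @ W1
    a2 = sigmoid(z2)
    z3 = a2 @ W2
    y_hat = sigmoid(z3)
    # step 3: see how bad we're doing (squared error loss)
    loss = 0.5 * np.sum((y - y_hat) ** 2)
    # step 4: backward pass -- these gradients are what we derive below
    delta3 = (y_hat - y) * sigmoid_prime(z3)
    delta2 = (delta3 @ W2.T) * sigmoid_prime(z2)
    W2 -= lr * (a2.T @ delta3)
    W1 -= lr * (X.T @ delta2)
    if i % 200 == 0:
        print(f"iteration {i}, loss {loss:.4f}")
# step 5: we simply repeated steps 2-4 for a fixed number of iterations
```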
|
|
73 | 73 | "\\end{bmatrix}\n",
|
74 | 74 | "$$\n",
|
75 | 75 | "\n",
|
76 |
| - "And second level weights as:\n", |
| 76 | + "And second layer weights as:\n", |
77 | 77 | "$$\n",
|
78 | 78 | "\\theta_2 = \n",
|
79 | 79 | "\\begin{bmatrix}\n",
|
|
86 | 86 | "$$ z_1^{\\left(2\\right)}=\\theta_{10}^{\\left(1\\right)}+\\theta_{11}^{\\left(1\\right)}x_1+\\theta_{12}^{\\left(1\\right)}x_2 + ....\\text{for all the $z$s}$$\n",
|
87 | 87 | "\n",
|
88 | 88 | "All we do is:\n",
|
89 |
| - "$$ \\tag 3 z^{\\left(2\\right)}=\\theta^{\\left(2\\right)}\\cdot X $$\n", |
| 89 | + "$$ \\tag 3 z^{\\left(2\\right)}=\\theta^{\\left(2\\right)}\\cdot X^T $$\n", |
90 | 90 | "\n",
|
91 | 91 | "And the activity at the second layer is thus\n",
|
92 | 92 | "$$ \\tag 4 a^{\\left(2\\right)}=\\sigma\\left(z^{\\left(2\\right)}\\right) $$\n",
|
93 | 93 | "Which is the same as:\n",
|
94 |
| - "$$ \\tag 5 a^{\\left(2\\right)}=\\sigma\\left(\\theta^{\\left(2\\right)}\\cdot X\\right) $$\n", |
| 94 | + "$$ \\tag 5 a^{\\left(2\\right)}=\\sigma\\left(\\theta^{\\left(2\\right)}\\cdot X^T\\right) $$\n", |
95 | 95 | "\n",
|
96 | 96 | "Repeating the same step for the third layer will give us the output. \n",
|
97 | 97 | "$$ \\tag 6 z^{\\left(3\\right)}=\\theta^{\\left(2\\right)}\\cdot a^{\\left(2\\right)} $$\n",
|
|
105 | 105 | "source": [
|
106 | 106 | "## Forward Pass\n",
|
107 | 107 | "\n",
|
108 |
| - "Let's take an example of a Neural Network to solve the MNIST character recognition problem. Every image is 20x20 pixel in dimension, hence the a single input will (20x20) 400 features. Remember, that the input is the first layer, so the number of neurons in the first layer will be 400. The second layer will be the hidden layer, let's say that the number of neurons in the hidden layers is 25. And since we're predicting whether the image is a number from 0-9 there are 10 discrete outputs, hence the output layer will have 10 neurons. Each of the neuron in output layer will predict a value between 0 and 1. Since these values as probabilities, the value that has the highest probability will be the winner. \n", |
| 108 | + "Let's take an example of a Neural Network to solve the MNIST character recognition problem. Every image is 20x20 pixel in dimension, hence the a single input will (20x20) 400 features. Remember, that the input is the first layer, so the number of neurons in the first layer will be 400. The second layer will be the hidden layer, let's say that the number of neurons in the hidden layers is 25. And since we're predicting whether the image is a number from 0-9 there are 10 discrete outputs, hence the output layer will have 10 neurons. Each of the neuron in output layer will predict a value between 0 and 1. Since these values are probabilities, the value that has the highest probability will be the winner. \n", |
109 | 109 | "\n",
|
110 | 110 | "#### Dimension of (input) X = (5000, 400) \n",
|
111 | 111 | "\n",
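As a quick sanity check, here's how I'd write those dimensions down in NumPy. The +1 columns for the bias terms and the one-hot shape of the labels are my assumptions about the setup rather than something stated above.

```python
import numpy as np

n_examples, n_inputs, n_hidden, n_outputs = 5000, 400, 25, 10

X = np.zeros((n_examples, n_inputs))          # (5000, 400): one flattened 20x20 image per row
y = np.zeros((n_examples, n_outputs))         # (5000, 10): assuming one-hot encoded labels
theta1 = np.zeros((n_hidden, n_inputs + 1))   # (25, 401): +1 column for the bias term
theta2 = np.zeros((n_outputs, n_hidden + 1))  # (10, 26):  +1 column for the bias term
```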
|
|
192 | 192 | "cell_type": "markdown",
|
193 | 193 | "metadata": {},
|
194 | 194 | "source": [
|
195 |
| - "### Easier part $\\frac{\\partial J}{\\partial \\theta^{\\left(2\\right)}}$\n", |
| 195 | + "## Easier part $\\frac{\\partial J}{\\partial \\theta^{\\left(2\\right)}}$\n", |
196 | 196 | "\n",
|
197 | 197 | "Calculating $\\frac{\\partial J}{\\partial \\theta^{\\left(2\\right)}}$ is easier than calculating $\\frac{\\partial J}{\\partial \\theta^{\\left(1\\right)}}$ so we'll start by that first. We'll go step by step and try to understand what each step is accomplishing. \n",
|
198 | 198 | "\n"
|
|
219 | 219 | "\\frac{\\partial J}{\\partial W^{\\left(2\\right)}} &= \\frac{\\partial\\frac{1}{2}\\left(y-\\hat y\\right)^2}{\\partial W^{\\left(2\\right)}} \\\\\n",
|
220 | 220 | "\\notag\n",
|
221 | 221 | "&= (y-\\hat y)\\cdot\\left(-\\frac{\\partial \\hat y}{\\partial W^{\\left(2\\right)}}\\right)\n",
|
222 |
| - "\\end{align} $$\n", |
| 222 | + "\\end{align} \n", |
| 223 | + "$$\n", |
| 224 | + "\n", |
223 | 225 | "We have to differentiate $\\hat y$ to respect the [Chain Rule](https://www.youtube.com/watch?v=6kScLENCXLg). This minus sign in the second term comes from differentiating $-\\hat y$\n",
|
224 | 226 | "\n",
|
225 | 227 | "Using Equation (7) and (8) we have, \n",
|
|
241 | 243 | "In the last part of the equation we'll be differentiating $W^{\\left(2\\right)} \\cdot a^{\\left(2\\right)}$ by $W^{\\left(2\\right)}$. We know that the derivative of $4x$ with respect to $x$ is $4$ so the derivative of $W^{\\left(2\\right)} \\cdot a^{\\left(2\\right)}$ with respect $W^{\\left(2\\right)}$ will be $a^{\\left(2\\right)}$\n",
|
242 | 244 | "\n",
|
243 | 245 | "$$ \n",
|
244 |
| - "\\tag 9\n", |
245 |
| - "\\frac{\\partial J}{\\partial W^{\\left(2\\right)}} = \\left(z-y\\right)\\cdot\\sigma'\\left(z^{\\left(3\\right)}\\right)\\cdot\\left(a^{\\left(2\\right)}\\right)$$\n", |
| 246 | + "\\frac{\\partial J}{\\partial W^{\\left(2\\right)}} = \\left(z-y\\right)\\cdot\\sigma'\\left(z^{\\left(3\\right)}\\right)\\cdot\\left(a^{\\left(2\\right)}\\right)\n", |
| 247 | + "$$\n", |
| 248 | + "\n", |
| 249 | + "We'll denote the error term in the final layer by $\\delta^{(3)}$\n", |
| 250 | + "\n", |
| 251 | + "$$ \n", |
| 252 | + "\\tag{9}\n", |
| 253 | + "\\frac{\\partial J}{\\partial W^{\\left(2\\right)}} = \\delta^{\\left(3\\right)}\\cdot a^{\\left(2\\right)}\n", |
| 254 | + "$$\n", |
246 | 255 | "\n",
|
247 | 256 | "Now, coming back to the summation we ignored at the top of the derivation, we're going to fix that in the implementation using an accumulator matrix which will store the errors for every row and sum it up. "
|
248 | 257 | ]
|
249 | 258 | },
|
250 | 259 | {
|
251 | 260 | "cell_type": "markdown",
|
252 | 261 | "metadata": {},
|
253 |
| - "source": [] |
| 262 | + "source": [ |
| 263 | + "## Sucky part $\\frac{\\partial J}{\\partial \\theta^{\\left(1\\right)}}$\n", |
| 264 | + "\n", |
| 265 | + "It's nearly the same as the previous step, but involves one additional step using chain rule. We'll start in the same way. " |
| 266 | + ] |
| 267 | + }, |
| 268 | + { |
| 269 | + "cell_type": "markdown", |
| 270 | + "metadata": {}, |
| 271 | + "source": [ |
| 272 | + "$\n", |
| 273 | + "\\begin{align}\n", |
| 274 | + "\\tag {from (1)}\n", |
| 275 | + "\\frac{\\partial J}{\\partial W^{\\left(1\\right)}} &= \\frac{\\partial\\frac{1}{2}\\sum_{i=0}^m\\left(y-\\hat y\\right)^2}{\\partial W^{\\left(1\\right)}} \\\\\n", |
| 276 | + "\\notag\n", |
| 277 | + "&=\\frac{\\sum_{i=0}^m\\partial\\frac{1}{2}\\left(y-z\\right)^2}{\\partial W^{\\left(1\\right)}} \\\\\n", |
| 278 | + "&= (y-\\hat y)\\cdot\\left(-\\frac{\\partial \\hat y}{\\partial W^{\\left(1\\right)}}\\right)\n", |
| 279 | + "\\end{align}\n", |
| 280 | + "$\n", |
| 281 | + "Not that we've skipped the summation sign as before. Mathematicians might be cursing me at this point. \n", |
| 282 | + "\n", |
| 283 | + "$$\n", |
| 284 | + "\\begin{align}\n", |
| 285 | + "\\notag\n", |
| 286 | + "\\frac{\\partial J}{\\partial W^{\\left(1\\right)}} &= \\left(z-y\\right)\\cdot\\sigma'\\left(z^{\\left(3\\right)}\\right)\\cdot\\left(\\frac{\\partial z^{\\left(3\\right)}}{\\partial W^{\\left(1\\right)}}\\right) \\\\\n", |
| 287 | + "\\end{align}\n", |
| 288 | + "$$\n", |
| 289 | + "\n", |
| 290 | + "Things start to get a little different here. We cannot directly differentiate $z^{(3)}$ with respect to $W^{(1)}$ because $z^{(3)}$ does not directly depend on $W^{(1)}$. So we will use our good ol' chain rule again and divide it further.\n", |
| 291 | + "\n", |
| 292 | + "$$\n", |
| 293 | + "\\frac{\\partial J}{\\partial W^{\\left(1\\right)}} = \\left(z-y\\right)\\cdot\\sigma'\\left(z^{\\left(3\\right)}\\right)\\cdot \\frac{\\partial z^{\\left(3\\right)}}{\\partial a^{\\left(2\\right)}}\\cdot\\frac{\\partial a^{\\left(2\\right)}}{\\partial W^{\\left(1\\right)}}\n", |
| 294 | + "$$\n", |
| 295 | + "\n", |
| 296 | + "Replacing the value of $\\delta^{(3)}$ from equation (9)\n", |
| 297 | + "\n", |
| 298 | + "$$\n", |
| 299 | + "\\frac{\\partial J}{\\partial W^{\\left(1\\right)}} = \\delta^{(3)} \\cdot \\frac{\\partial z^{\\left(3\\right)}}{\\partial a^{\\left(2\\right)}}\\cdot\\frac{\\partial a^{\\left(2\\right)}}{\\partial W^{\\left(1\\right)}}\n", |
| 300 | + "$$\n", |
| 301 | + "\n", |
| 302 | + "Substituting the value of $z^{(3)}$ from equation (6)\n", |
| 303 | + "\n", |
| 304 | + "$$\n", |
| 305 | + "\\begin{align}\n", |
| 306 | + "\\notag\n", |
| 307 | + "\\frac{\\partial J}{\\partial W^{\\left(1\\right)}} &= \\delta^{(3)} \\cdot \\frac{\\partial z^{\\left(3\\right)}}{\\partial a^{\\left(2\\right)}}\\cdot\\frac{\\partial a^{\\left(2\\right)}}{\\partial W^{\\left(1\\right)}} \\\\\n", |
| 308 | + "&= \\delta^{(3)} \\cdot \\frac{\\partial\\left(W^{\\left(2\\right)}\\cdot a^{\\left(2\\right)}\\right)}{\\partial a^{\\left(2\\right)}} \\cdot\\frac{\\partial a^{\\left(2\\right)}}{\\partial W^{\\left(1\\right)}} \\\\\n", |
| 309 | + "&= \\delta^{(3)} \\cdot W^{(2)} \\cdot \\frac{\\partial a^{\\left(2\\right)}}{\\partial W^{\\left(1\\right)}} \\\\\n", |
| 310 | + "\\tag{Using (4)}\n", |
| 311 | + "&= \\delta^{(3)} \\cdot W^{(2)} \\cdot \\frac{\\partial\\sigma\\left(z^{\\left(2\\right)}\\right)}{\\partial W^{\\left(1\\right)}} \\\\\n", |
| 312 | + "\\tag{We've done this before}\n", |
| 313 | + "&= \\delta^{(3)} \\cdot W^{(2)} \\cdot \\sigma'\\left(z^{\\left(2\\right)}\\right) \\cdot \\frac{\\partial z^{\\left(2\\right)}}{\\partial W^{\\left(1\\right)}}\n", |
| 314 | + "\\end{align}\n", |
| 315 | + "$$\n" |
| 316 | + ] |
254 | 317 | },
|
255 | 318 | {
|
256 | 319 | "cell_type": "code",
|
|