A Refresher on Batch (Re-)Normalization
January 24, 2018
This post assumes you have a CS231n-ish level of understanding of neural networks (i.e., you have taken a university-level introductory course on deep learning). If you are completely new to neural nets, I highly recommend that very course as a perfect resource for getting up to speed quickly. Honestly, it was while following those lectures that I developed most of the intuitions about CNNs that I still rely on every day. (Thanks Andrej! Thanks Justin!)
If you’re like me, you enjoy throwing CNNs at every pictorial problem that comes your way. You feel confident explaining to your MBA friends how¹ neural nets work, and you like to complain to your CS² friends about the price tag of Nvidia’s new GPUs. If you’re like me, then you have heard of BatchNorm. You have probably used it³, googled for an explanation of what “internal covariate shift” means and, somewhat satisfied, returned to your daily business of tuning some more hyperparameters.
Now, chances are you haven’t heard of batch renormalization⁴, Ioffe’s follow-up paper.
The tl;dr for most people reads something like this:
BatchRenorm is superior to BatchNorm (or at worst on par with it), is implemented in TensorFlow, but comes at the cost of a few extra hyperparameters to tune. If your batch size is very small (say 2 or 4, most likely due to constrained GPU memory), you should probably use it.
Below, I’ll try to give you a refresher on how BatchNorm is related to transfer learning, why you should be just a little paranoid when using BatchNorm (i.e. how it can break down in unexpected ways) and how BatchRenorm will help you go back to sleeping like a baby, but only — to stretch the simile a bit — if you are willing to do a few extra push-ups before going to bed.
Transfer Learning
The transfer learning scenario you are most familiar with is probably this: you have a CNN pre-trained on ImageNet that you now want to use to distinguish between your left-foot socks and your right-foot socks (or something of the sort). Depending on how much data you have, you unfreeze the last few layers or re-train the whole CNN on your very own sock dataset. So far, so good. But there are other transfer learning scenarios that are, if not as omnipresent, just as important⁵.
When, instead of generalizing from one task to the next (like classifying socks instead of dogs and cats), we want our model to generalize from a source domain to a target domain with a different data distribution, this is called domain adaptation. To get an intuitive understanding of why this might be a difficult problem, consider two datasets containing only cats and dogs. Both include your normal variety of cats, but the first only has brown dogs. If we are unlucky, our model trained on the first dataset will not learn differences such as pointy vs. floppy ears, but only that if it sees a brown thing, it must be a dog. If so, we shouldn’t be surprised if our classification accuracy takes a sudden hit when testing on the second dataset⁶.
The takeaway is a universal truth of machine learning: if your data distribution changes under your nose, you are probably in trouble.
BatchNorm is not having any of it
Back to BatchNorm and googling what “reducing internal covariate shift” means. “Internal covariate shift” is just a fancy term for the fact that the input (“data”) distributions of intermediate layers of neural networks change during training. This is not surprising since the input of an intermediate layer is simply the output of the layer before it and as the parameters of this “pre”-layer get updated over time, its output will change too.
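To make “the input distribution of an intermediate layer changes during training” concrete, here is a toy NumPy sketch. Everything in it is hypothetical: a fixed batch of data goes through a ReLU “pre”-layer, and a random perturbation of its weights stands in for a few gradient updates. The data never changes, yet the per-feature statistics of what the next layer sees do.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy setup: a fixed batch of inputs fed through a "pre"-layer
x = rng.normal(size=(1024, 16))
W1 = rng.normal(scale=0.5, size=(16, 8))

def layer2_input(W):
    """What the next layer sees as its input: ReLU(x @ W)."""
    return np.maximum(x @ W, 0.0)

before = layer2_input(W1)
# A random perturbation stands in for a few gradient updates to W1
W1_updated = W1 + rng.normal(scale=0.3, size=W1.shape)
after = layer2_input(W1_updated)

# Same data, but a different input distribution for layer 2
print("per-feature mean before:", before.mean(axis=0).round(2))
print("per-feature mean after: ", after.mean(axis=0).round(2))
print("per-feature std before: ", before.std(axis=0).round(2))
print("per-feature std after:  ", after.std(axis=0).round(2))
```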
Instead of trying to find a clever “internal domain adaptation” technique, Ioffe and Szegedy’s ingenious solution to this problem of changing input distributions is to simply sidestep it. They use BatchNorm to normalize every layer input (per feature, to zero mean and unit variance over the mini-batch) and voilà: no more mess of shape-shifting distributions.
For quick reference here is the algorithm.
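The training-time transform boils down to four steps: compute the mini-batch mean µB and variance σB², normalize, then scale and shift with the learned parameters γ and β. Below is a minimal NumPy sketch of that forward pass, assuming a fully connected layer with inputs shaped (batch, features); for conv layers, the statistics would additionally be pooled over spatial positions. The function name `batchnorm_train` is hypothetical, and `eps` is the usual small constant for numerical stability.

```python
import numpy as np

def batchnorm_train(x, gamma, beta, eps=1e-5):
    """Training-time BatchNorm forward pass for a (batch, features) array.

    Statistics are computed per feature over the mini-batch axis (axis 0).
    """
    mu_b = x.mean(axis=0)                       # mini-batch mean, one per feature
    var_b = x.var(axis=0)                       # mini-batch (biased) variance
    x_hat = (x - mu_b) / np.sqrt(var_b + eps)   # normalize
    y = gamma * x_hat + beta                    # learned scale and shift
    return y, mu_b, var_b

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=5.0, size=(64, 4))
y, mu_b, var_b = batchnorm_train(x, gamma=np.ones(4), beta=np.zeros(4))
print(y.mean(axis=0).round(6), y.std(axis=0).round(3))
```

With γ = 1 and β = 0, each output feature ends up with (approximately) zero mean and unit standard deviation, whatever the input statistics were.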
It turns out that using BatchNorm also makes your model more robust to less careful weight initialization and larger learning rates⁷. And another goodie:
I-S report that the noise introduced by computing the mean and variance over each mini-batch instead of over the entire training set⁸ isn’t just bad news, but acts as regularization and can remove the need to add extra dropout layers.
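How noisy are mini-batch statistics? A quick NumPy sketch, assuming hypothetical activations of a single feature and i.i.d. sampling, shows that the mini-batch mean scatters around the population mean with a spread that shrinks roughly like σ/√m for batch size m. That noise is what acts as a regularizer, and it is also why very small batches become a problem.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical activations of a single feature over the entire training set
activations = rng.normal(loc=2.0, scale=1.0, size=100_000)
print(f"population mean: {activations.mean():.3f}")

spreads = {}
for m in (2, 8, 32, 128):
    # 1,000 i.i.d. mini-batches of size m, and the mean of each one
    batch_means = rng.choice(activations, size=(1000, m)).mean(axis=1)
    spreads[m] = batch_means.std()
    print(f"batch size {m:>3}: std of mini-batch means = {spreads[m]:.3f}, "
          f"1/sqrt(m) = {1 / np.sqrt(m):.3f}")
```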
Why BatchNorm should make you paranoid
You know what I hate? When my code compiles and my model trains, but for some well-hidden reason the model performs much worse than expected. Unfortunately, under certain circumstances, BatchNorm can be that well-hidden reason. To understand when this happens, I highly recommend reading Alex Irpan’s post on the perils of BatchNorm. In any case, here is my executive summary:
When the mini-batch mean (µB) and mini-batch standard deviation (σB) diverge too often from the mean and standard deviation over the entire training set, BatchNorm breaks.
Remember that at inference time we use the moving averages of µB and σB (as an estimate of the statistics of the entire training set) to do the normalization step. Naturally, if your means and standard deviations differ between training and testing, so do your activations, and you shouldn’t be surprised if your results differ (read: are worse), too. This can happen when your mini-batch samples are non-i.i.d. (in plain language: when your sampling procedure is biased; think of first sampling only brown dogs and then sampling only black dogs) or, more commonly, when your batch size is very small⁹. In both cases: welcome back to “shape-shifting distributions”-land.
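The brown-dogs-then-black-dogs scenario can be simulated in a few lines of NumPy. This is a hypothetical single feature whose activations are low for “brown” and high for “black” examples; for brevity only the mean is tracked (the variance behaves analogously), with an exponential moving average of momentum 0.99. With shuffled batches, the moving average lands near the population mean; with biased batches, it ends up somewhere that matches neither the training-time batch statistics nor the population.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical single feature: "brown dog" examples activate low, "black dog" ones high
brown = rng.normal(loc=-2.0, scale=1.0, size=5000)
black = rng.normal(loc=2.0, scale=1.0, size=5000)
data = np.concatenate([brown, black])
m = 50  # batch size: 10,000 examples -> 200 mini-batches

def moving_mean_after(batches, momentum=0.99):
    """Exponential moving average of mini-batch means (what inference uses)."""
    moving_mean = 0.0
    for batch in batches:
        moving_mean = momentum * moving_mean + (1 - momentum) * batch.mean()
    return float(moving_mean)

iid_batches = rng.permutation(data).reshape(-1, m)  # shuffled sampling
biased_batches = data.reshape(-1, m)                # all brown first, then all black

mm_iid = moving_mean_after(iid_batches)
mm_biased = moving_mean_after(biased_batches)
print(f"population mean:             {data.mean():+.3f}")
print(f"moving mean, i.i.d. batches: {mm_iid:+.3f}")
print(f"moving mean, biased batches: {mm_biased:+.3f}")
# During training, a brown-only batch is centered on its own mean (about -2);
# at inference, the same examples are centered on mm_biased instead.
print(f"brown batch mean during training: {biased_batches[0].mean():+.3f}")
```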
Enter BatchRenorm
BatchRenorm tackles this issue of differing statistics at train and inference time head-on. The key insight to bridge the difference is this:
The normalization step at inference time (using estimates of the training set statistics µ and σ) can actually be rewritten as an affine transformation of the normalization step at training time (using mini-batch statistics µB and σB)! And that’s basically all there is to it. Using mini-batch + affine transformation at train time and moving averages at inference time ensures that the output of BatchRenorm is the same during both phases, even when σB != σ and µB != µ.
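Written out (with µ and σ the moving averages, as in the renormalization paper), the inference-time normalization splits into the training-time normalization, a scale r and a shift d:

```latex
\begin{aligned}
\frac{x_i - \mu}{\sigma}
  &= \frac{x_i - \mu_B}{\sigma_B} \cdot r + d,
  \qquad r = \frac{\sigma_B}{\sigma}, \qquad d = \frac{\mu_B - \mu}{\sigma}, \\[4pt]
\text{since} \quad
\frac{x_i - \mu_B}{\sigma_B} \cdot \frac{\sigma_B}{\sigma} + \frac{\mu_B - \mu}{\sigma}
  &= \frac{x_i - \mu_B + \mu_B - \mu}{\sigma}
   = \frac{x_i - \mu}{\sigma}.
\end{aligned}
```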
Here is the algorithm in its non-japanese entirety.
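As a rough stand-in for the algorithm, here is a NumPy sketch of the training-time forward pass as I read the paper: r = clip(σB/σ) to [1/rmax, rmax], d = clip((µB − µ)/σ) to [−dmax, dmax], x̂ = (x − µB)/σB · r + d, y = γx̂ + β, followed by the moving-average updates µ += α(µB − µ) and σ += α(σB − σ). Assumptions: a fully connected layer with (batch, features) inputs, the hypothetical function names `batch_renorm_train` and `batch_renorm_inference`, and an update rate α = 0.01 (what TensorFlow calls momentum corresponds to 1 − α). NumPy has no autodiff, so “treat r and d as constants” only appears as a comment.

```python
import numpy as np

def batch_renorm_train(x, gamma, beta, moving_mean, moving_std,
                       r_max, d_max, update_rate=0.01, eps=1e-5):
    """Training-time BatchRenorm forward pass for a (batch, features) array.

    Returns the output and the updated moving statistics (mu, sigma).
    """
    mu_b = x.mean(axis=0)
    sigma_b = np.sqrt(x.var(axis=0) + eps)
    # Correction factors. In an autodiff framework both would be wrapped in a
    # stop-gradient, i.e. treated as constants during backprop.
    r = np.clip(sigma_b / moving_std, 1.0 / r_max, r_max)
    d = np.clip((mu_b - moving_mean) / moving_std, -d_max, d_max)
    x_hat = (x - mu_b) / sigma_b * r + d
    y = gamma * x_hat + beta
    # Moving averages are updated only after they were used for r and d
    new_mean = moving_mean + update_rate * (mu_b - moving_mean)
    new_std = moving_std + update_rate * (sigma_b - moving_std)
    return y, new_mean, new_std

def batch_renorm_inference(x, gamma, beta, moving_mean, moving_std):
    """Inference uses the moving averages only, exactly like BatchNorm."""
    return gamma * (x - moving_mean) / moving_std + beta

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(4, 3))   # a tiny mini-batch
gamma, beta = np.ones(3), np.zeros(3)
mu, sigma = np.full(3, 2.5), np.full(3, 1.8)      # hypothetical moving averages

# Generous clipping bounds: r and d are not clipped, so train == inference
y_train, new_mu, new_sigma = batch_renorm_train(x, gamma, beta, mu, sigma,
                                                r_max=1e6, d_max=1e6)
y_infer = batch_renorm_inference(x, gamma, beta, mu, sigma)
print("max |train - inference|:", np.abs(y_train - y_infer).max())

# r_max = 1, d_max = 0 forces r = 1 and d = 0: plain BatchNorm
y_bn, _, _ = batch_renorm_train(x, gamma, beta, mu, sigma, r_max=1.0, d_max=0.0)
bn = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + 1e-5)
print("max |renorm(r_max=1, d_max=0) - batchnorm|:", np.abs(y_bn - bn).max())
```

The two printed differences are only floating-point noise: unclipped, the training output matches what inference would produce for the same inputs; fully clipped, it is ordinary BatchNorm.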
An interesting thing to note is that batch renormalization is really just a generalization of BatchNorm and reverts to its predecessor when σB == σ and µB == µ (or, equivalently, r = 1 and d = 0). This raises the question: when should you use BatchRenorm, and when is BatchNorm enough?
BatchRenorm ≥ BatchNorm?
The good news is that, in terms of model performance, you can count on BatchRenorm to always be better than or equal to BatchNorm.
However, using BatchRenorm comes at the added cost of two hyperparameters, rmax and dmax (the bounds to which the correction factors r and d are clipped), for which you have to find the right schedule to get the best performance. So there you have it. As with most things in life, it is a trade-off between your time and your model’s performance. If you have outsourced your hyperparameter tuning to something like Bayesian optimization¹¹, it’s at least still a trade-off between computing resources and performance.
Personally, I will be using BatchRenorm with the fixed schedule mentioned in the paper¹² from now on. If I have very small batch sizes (or some weird mini-batch sampling curiosity as in Irpan’s post) I might bring myself to do some hyperparameter tuning myself.
Let me know how it goes for you.
[1] ^ Notice that I said how, not why. For a lucid, completely non-technical account of how neural networks work, this NYTimes article is a must-read. It also features a hilarious quote by Quoc V. Le, which, if you ask me, wonderfully captures that strange feeling: even though you understand most of the principles and math underlying neural networks, if you dare to dig too deep, it still feels a lot like wizardry when your Inception-ResNet-v2 once again successfully classifies your dog as a Soft Coated Wheaten Terrier and not, say, a West Highland White Terrier [1a].
Pestered by the author for an intuitive mental picture of word embeddings (like word2vec or GloVe), Le answers bluntly: “Gideon, […] I do not generally like trying to visualize thousand-dimensional vectors in three-dimensional space.”
[1a] Full disclosure: I do not own a Terrier.
[2] ^ Feel free to search and replace with EE, Stats, Physics, Math, etc.
[3] ^ In fact, the Ioffe-Szegedy paper introducing the idea of batch normalization has been cited 3397 times (as of Jan. 20th, 2018). This doesn’t necessarily mean that 3397 people have tried to build upon the ideas introduced by I-S or that they have tried to find new ways of tackling the problems addressed by BatchNorm (more on those later), but mainly that BatchNorm has become SOP in most CNN architectures used since 2016. For good measure, let me cite it right here and now:
Ioffe, S., & Szegedy, C. (2015). Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. ICML.
[4] ^ Batch Renormalization has no connection to Renormalization, a collection of techniques in quantum field theory. If you know nothing about quantum field theory (like me), I still encourage you to follow the link to the Wikipedia article and enjoy yourself while reading the first sentence there. Anyway, here is the renormalization we are concerned with:
Ioffe, S. (2017). Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models. NIPS.
[5] ^ Sebastian Ruder wrote a great blog post about the different transfer learning settings, with much more precise definitions than I am using here. There he also quotes Andrew Ng who, during his NIPS 2016 workshop, predicted that transfer learning will be “the next driver of ML success”.
[6] ^ Similarly, imagine that we train on a normal dataset of cats and dogs but are unlucky in our random sampling and only get brown dogs in the beginning: it will probably take longer until our model learns apt differences like pointy vs. floppy ears. We’ll come back to this idea when discussing how BatchNorm can break in unexpected ways.
[7] ^ An easy intuition for why this might be true can be found in Ioffe’s Batch Renormalization paper [4]:
When x is itself a result of applying a linear transform W to the previous layer, batchnorm makes the model invariant to the scale of W (ignoring the small epsilon).
But if the scale of W does not matter, then proper weight initialization is not as important and larger learning rates won’t break your training as easily!
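The invariance is easy to check numerically. In this NumPy sketch (hypothetical data and weights, and only the normalization step, without γ and β), scaling W by 100 leaves the normalized output unchanged up to the tiny effect of epsilon. I-S also point out the flip side for training: the gradient with respect to the scaled weights aW is 1/a times the gradient with respect to W, so larger weights receive proportionally smaller updates.

```python
import numpy as np

def normalize(x, eps=1e-5):
    """BatchNorm's normalization step (no learned scale/shift), per feature."""
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

rng = np.random.default_rng(0)
inputs = rng.normal(size=(32, 10))
W = rng.normal(size=(10, 5))

out = normalize(inputs @ W)
out_scaled = normalize(inputs @ (100.0 * W))  # same weights, 100x larger scale
# Tiny difference: only epsilon breaks exact invariance
print(np.abs(out - out_scaled).max())
```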
[8] ^ We can’t compute the population mean and variance of intermediate layer activations, since passing over the entire training set after every single parameter update would defeat the whole purpose of stochastic gradient descent. That seems straightforward enough. But why can’t we just use the moving averages of mean and variance instead of the mini-batch versions when doing the normalization? The reason, I find, is quite subtle. Disclaimer: If you have never been haunted by this question about BatchNorm I encourage you to skip the rest of this footnote, since it is both verbose and a little technical.
If you ever had to derive the backward pass for BatchNorm yourself, you might remember that the reason the exercise isn’t completely effortless is that the mini-batch mean and variance are (obviously) functions of the current activations in the mini-batch.
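Of course, this is also true for the moving averages. Now, if you ignore this dependence of the normalization step (the calculation of σ and μ) on your mini-batch inputs while calculating your gradients, your model will blow up. But why is this? I-S naturally have the answer. Paraphrasing their argument: consider a layer with input u that adds a learned bias b and then normalizes by subtracting the mean, x̂ = x − E[x] with x = u + b. If the gradient step ignores the dependence of E[x] on b, the update b ← b + ∆b (with ∆b ∝ −∂ℓ/∂x̂) leaves the output unchanged, because u + (b + ∆b) − E[u + (b + ∆b)] = u + b − E[u + b]. So the loss stays fixed, while b grows indefinitely.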
The last thing to add is that if you factor in the dependence of E[x] on b, your gradient is ∆b = 0 (go ahead and check that this is true!) and your model won’t explode. Yay! Back to the business of moving averages: in TensorFlow they are calculated using a momentum α (in words: mean_moving_average = α * mean_moving_average + (1 − α) * mini_batch_mean). As you can probably already tell, even though it’s doable to factor in the dependence of the mini_batch_mean on your current activations, there is no easy way to calculate the gradients of mean_moving_average with respect to the activations x (alas, the downside of recursion). Thus ∆b won’t be 0, thus b will explode, and thus we use the mini-batch mean and variance, not the moving averages. So that’s that.
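The I-S argument can be replayed in NumPy. Assumptions: a hypothetical layer that only adds a bias b to its input u and subtracts the mini-batch mean (no variance normalization), plus a hypothetical squared-error loss whose target has mean 1, which a mean-centered output can never match. The “naive” gradient treats E[x] as a constant (so ∂x̂/∂b = 1); the “full” gradient accounts for E[x] shifting along with b (so ∂x̂/∂b = 1 − 1 = 0).

```python
import numpy as np

rng = np.random.default_rng(0)
m, lr, steps = 16, 0.1, 100
target = np.ones(m)  # hypothetical target with mean 1; centered outputs always have mean 0

def forward(u, b):
    """Add the bias b, then subtract the mini-batch mean E[x]."""
    x = u + b
    return x - x.mean()

b_naive, b_full = 0.0, 0.0
for _ in range(steps):
    u = rng.normal(size=m)
    # Loss l = 0.5 * sum((x_hat - target)**2), so dl/dx_hat = x_hat - target
    dl_dxhat_naive = forward(u, b_naive) - target
    dl_dxhat_full = forward(u, b_full) - target
    # Naive: pretend E[x] is a constant, so dx_hat/db = 1
    b_naive -= lr * np.sum(dl_dxhat_naive * 1.0)
    # Full: E[x] also shifts by b, so dx_hat/db = 1 - 1 = 0
    b_full -= lr * np.sum(dl_dxhat_full * 0.0)

u = rng.normal(size=m)

def loss(b):
    return 0.5 * np.sum((forward(u, b) - target) ** 2)

print(f"b after naive updates: {b_naive:.1f}, b after full updates: {b_full:.1f}")
print(f"loss with b = 0: {loss(0.0):.6f}, loss with naive b: {loss(b_naive):.6f}")
```

The naive bias keeps growing at every step while the loss does not budge, which is exactly the failure mode described above.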
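[9] ^ Small batch sizes lead to a high variance in µB and σB. Going one better (or worse): in the extreme case of a batch size of one (in a fully connected layer, where the statistics are computed over the batch dimension only), it’s not just a problem of fickle means and variances. Every activation equals its own mini-batch mean, so the normalized activations are all zeros and the output of BatchNorm is just β!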
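A three-line NumPy check of the batch-size-one case, assuming a fully connected layer (statistics over axis 0 only) and hypothetical values for the single example, γ and β:

```python
import numpy as np

x = np.array([[0.7, -1.3, 42.0]])   # a single example with three features
gamma = np.array([2.0, 2.0, 2.0])   # hypothetical learned scale
beta = np.array([0.5, -0.5, 1.0])   # hypothetical learned shift
eps = 1e-5

x_hat = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)
y = gamma * x_hat + beta
print(x_hat)  # all zeros: each value equals its own mini-batch mean
print(y)      # equals beta, no matter what the input was
```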
[10] ^ If we backpropagated through r and d instead of treating them as constants, we would be back to the exact same problem discussed in [8]. If you do treat r and d as constants, however, you can go through the motions discussed in [8] and you’ll see that ∆b is still 0, so your model shouldn’t explode.
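Going through those motions explicitly, with x_i = u_i + b for a mini-batch of size m and r, d held constant:

```latex
\begin{aligned}
\mu_B &= \frac{1}{m}\sum_{j=1}^{m} (u_j + b) = \bar{u} + b,
  \qquad \sigma_B \text{ does not depend on } b, \\
\hat{x}_i &= \frac{x_i - \mu_B}{\sigma_B}\, r + d
           = \frac{u_i - \bar{u}}{\sigma_B}\, r + d, \\
\frac{\partial \ell}{\partial b}
  &= \sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i}\,
     \frac{\partial \hat{x}_i}{\partial b} = 0 .
\end{aligned}
```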
[11] ^ Is there an open-source version of Google Vizier yet?
[12] ^ The hyperparameter schedule used in the paper:
For Batch Renorm, we used rmax = 1, dmax = 0 (i.e. simply batchnorm) for the first 5000 training steps, after which these were gradually relaxed to reach rmax = 3 at 40k steps, and dmax = 5 at 25k steps.
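A sketch of that schedule as a function of the training step. The quote only says the bounds were “gradually relaxed”, so the linear ramp below is an assumption, and `renorm_clipping` is a hypothetical helper name. If you use TensorFlow’s batch normalization layer with `renorm=True`, values like these are what go into its `renorm_clipping` dictionary (keys `rmax`, `rmin`, `dmax`); check the documentation of your TensorFlow version for how to feed a changing schedule.

```python
def renorm_clipping(step, warmup=5_000, r_end_step=40_000, d_end_step=25_000,
                    r_final=3.0, d_final=5.0):
    """Return (r_max, d_max) for a given training step.

    Plain BatchNorm (r_max = 1, d_max = 0) during warm-up, then a linear
    ramp (an assumption: the paper only says "gradually relaxed").
    """
    def ramp(start_value, end_value, end_step):
        if step <= warmup:
            return start_value
        if step >= end_step:
            return end_value
        frac = (step - warmup) / (end_step - warmup)
        return start_value + frac * (end_value - start_value)

    return ramp(1.0, r_final, r_end_step), ramp(0.0, d_final, d_end_step)

for step in (0, 5_000, 15_000, 25_000, 40_000, 100_000):
    r_max, d_max = renorm_clipping(step)
    print(f"step {step:>7}: r_max = {r_max:.2f}, d_max = {d_max:.2f}")
```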