# Recurrent Batch Normalization

Under review as a conference paper at ICLR 2017

RECURRENT BATCH NORMALIZATION

Tim Cooijmans, Nicolas Ballas, César Laurent, Çağlar Gülçehre

Recurrent neural network architectures have recently exhibited state-of-the-art performance on a wide range of sequential problems, including video captioning (Yao et al., 2015). Top-performing models, however, are based on very high-capacity networks that are computationally intensive and costly to train. Effective optimization of recurrent neural networks is thus an active area of study (Pascanu et al., 2012; Martens & Sutskever, 2011; Ollivier, 2013).

It is well-known that for deep feed-forward neural networks, covariate shift (Shimodaira, 2000; Ioffe & Szegedy, 2015) degrades the efficiency of training. Batch normalization (Ioffe & Szegedy, 2015) addresses this by standardizing activations, and has been applied successfully to large-scale recurrent models for speech recognition (Amodei et al., 2015). It has found limited use in stacked RNNs, where the normalization is applied "vertically", i.e. to the input of each RNN, but not "horizontally" between timesteps. RNNs are deeper in the time direction, and as such batch normalization would be most beneficial when applied horizontally. However, Laurent et al. (2016) hypothesized that applying batch normalization in this way hurts training because of exploding gradients due to repeated rescaling.

Our findings run counter to this hypothesis. We show that it is both possible and highly beneficial to apply batch normalization in the hidden-to-hidden transition of recurrent models. In particular, we describe a reparameterization of LSTM (Section 3) that involves batch normalization and demonstrate that it is easier to optimize and generalizes better. In addition, we empirically analyze the gradient backpropagation and show that proper initialization of the batch normalization parameters is crucial to avoiding vanishing gradients (Section 4). We evaluate our proposal on several sequential problems and show (Section 5) that our LSTM reparameterization consistently outperforms the LSTM baseline across tasks, in terms of both time to convergence and performance.
Given an input sequence $X = (x_1, x_2, \ldots, x_T)$, an RNN defines a sequence of hidden states $h_t$ according to

$$h_t = \phi(W_h h_{t-1} + W_x x_t + b), \tag{1}$$

where $W_h \in \mathbb{R}^{d_h \times d_h}$, $W_x \in \mathbb{R}^{d_x \times d_h}$, $b \in \mathbb{R}^{d_h}$ and the initial state $h_0 \in \mathbb{R}^{d_h}$ are model parameters. A popular choice for the activation function $\phi(\cdot)$ is tanh.

RNNs are popular in sequence modeling thanks to their natural ability to process variable-length sequences. However, training RNNs using first-order stochastic gradient descent (SGD) is notoriously difficult due to the well-known problem of exploding/vanishing gradients (Bengio et al., 1994; Hochreiter, 1991; Pascanu et al., 2012). Gradient vanishing occurs when states $h_t$ are not influenced by small changes in much earlier states $h_\tau$, $\tau \ll t$, preventing learning of long-term dependencies in the input data. Although learning long-term dependencies is fundamentally difficult (Bengio et al., 1994), its effects can be mitigated through architectural variations such as the LSTM (Hochreiter & Schmidhuber, 1997) and unitary RNNs (Arjovsky et al., 2015).

In what follows, we focus on the LSTM architecture (Hochreiter & Schmidhuber, 1997), whose recurrent transition is given by

$$\begin{pmatrix} \tilde{f}_t \\ \tilde{i}_t \\ \tilde{o}_t \\ \tilde{g}_t \end{pmatrix} = W_h h_{t-1} + W_x x_t + b, \tag{2}$$

$$c_t = \sigma(\tilde{f}_t) \odot c_{t-1} + \sigma(\tilde{i}_t) \odot \tanh(\tilde{g}_t), \tag{3}$$

$$h_t = \sigma(\tilde{o}_t) \odot \tanh(c_t), \tag{4}$$

where $W_h \in \mathbb{R}^{d_h \times 4d_h}$, $W_x \in \mathbb{R}^{d_x \times 4d_h}$, $b \in \mathbb{R}^{4d_h}$ and the initial states $h_0 \in \mathbb{R}^{d_h}$, $c_0 \in \mathbb{R}^{d_h}$ are model parameters. $\sigma$ is the logistic sigmoid function, and the operator $\odot$ denotes the Hadamard product.

The LSTM differs from simple RNNs in that it has an additional memory cell $c_t$ whose update is nearly linear, which allows the gradient to flow back through time more easily. In addition, unlike the RNN, which overwrites its content at each timestep, the update of the LSTM cell is regulated by a set of gates. The forget gate $f_t$ determines the extent to which information is carried over from the previous timestep, and the input gate $i_t$ controls the flow of information from the current input $x_t$. The output gate $o_t$ allows the model to read from the cell. This carefully controlled interaction with the cell is what allows the LSTM to robustly retain information for long periods of time.
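The gated update described above can be sketched in a few lines of NumPy (a minimal illustration with made-up variable names, not the authors' implementation):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, W_x, W_h, b):
    """One LSTM step: the four gates are slices of a single fused
    pre-activation, mirroring the W_h in R^{d_h x 4 d_h} parameterization."""
    pre = h_prev @ W_h + x_t @ W_x + b            # shape (batch, 4 * d_h)
    f, i, o, g = np.split(pre, 4, axis=-1)
    c_t = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)  # nearly linear cell update
    h_t = sigmoid(o) * np.tanh(c_t)
    return h_t, c_t

# usage: one step on a small random batch
rng = np.random.default_rng(0)
d_x, d_h, batch = 3, 5, 2
W_x = 0.1 * rng.normal(size=(d_x, 4 * d_h))
W_h = 0.1 * rng.normal(size=(d_h, 4 * d_h))
b = np.zeros(4 * d_h)
h, c = np.zeros((batch, d_h)), np.zeros((batch, d_h))
h, c = lstm_step(rng.normal(size=(batch, d_x)), h, c, W_x, W_h, b)
```

Because the cell update adds to `c_prev` rather than overwriting it, gradients can propagate through the `sigmoid(f) * c_prev` path without passing through a saturating nonlinearity at every step.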
## 2.2 BATCH NORMALIZATION

Covariate shift (Shimodaira, 2000) is a phenomenon in machine learning where the features presented to a model change in distribution. In order for learning to succeed in the presence of covariate shift, the model's parameters must be adjusted not just to learn the concept at hand but also to adapt to the changing distribution of the inputs. In deep neural networks, this problem manifests as internal covariate shift (Ioffe & Szegedy, 2015). Batch normalization (Ioffe & Szegedy, 2015) counteracts it by standardizing activations using the transform

$$\mathrm{BN}(h; \gamma, \beta) = \beta + \gamma \odot \frac{h - \widehat{\mathbb{E}}[h]}{\sqrt{\widehat{\mathrm{Var}}[h] + \epsilon}}, \tag{5}$$

where $h \in \mathbb{R}^d$ is the vector of (pre)activations to be normalized, $\gamma \in \mathbb{R}^d$, $\beta \in \mathbb{R}^d$ are model parameters that determine the mean and standard deviation of the normalized activations, and $\epsilon \in \mathbb{R}$ is a regularization hyperparameter. The division should be understood to proceed elementwise.

At training time, the statistics $\mathbb{E}[h]$ and $\mathrm{Var}[h]$ are estimated by the sample mean and sample variance of the current minibatch. This allows for backpropagation through the statistics, preserving the convergence properties of stochastic gradient descent. During inference, the statistics are typically estimated based on the entire training set, so as to produce a deterministic prediction.

## 3 BATCH-NORMALIZED LSTM

This section introduces a reparameterization of LSTM that takes advantage of batch normalization. Contrary to Laurent et al. (2016) and Amodei et al. (2015), we leverage batch normalization in both the input-to-hidden and the hidden-to-hidden transformations. We introduce the batch-normalizing transform $\mathrm{BN}(\cdot;\gamma,\beta)$ into the LSTM as follows:

$$\begin{pmatrix} \tilde{f}_t \\ \tilde{i}_t \\ \tilde{o}_t \\ \tilde{g}_t \end{pmatrix} = \mathrm{BN}(W_h h_{t-1}; \gamma_h, \beta_h) + \mathrm{BN}(W_x x_t; \gamma_x, \beta_x) + b, \tag{6}$$

$$c_t = \sigma(\tilde{f}_t) \odot c_{t-1} + \sigma(\tilde{i}_t) \odot \tanh(\tilde{g}_t), \tag{7}$$

$$h_t = \sigma(\tilde{o}_t) \odot \tanh(\mathrm{BN}(c_t; \gamma_c, \beta_c)). \tag{8}$$

In our formulation, we normalize the recurrent term $W_h h_{t-1}$ and the input term $W_x x_t$ separately. Normalizing these terms individually gives the model better control over the relative contribution of the terms using the $\gamma_h$ and $\gamma_x$ parameters.
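A batch-normalized LSTM step of this form can be sketched as follows (an illustrative NumPy sketch using per-minibatch statistics; the helper names are hypothetical, not the paper's code):

```python
import numpy as np

def batch_norm(h, gamma, beta, eps=1e-3):
    # standardize each feature over the minibatch, then rescale and shift
    return beta + gamma * (h - h.mean(axis=0)) / np.sqrt(h.var(axis=0) + eps)

def bn_lstm_step(x_t, h_prev, c_prev, W_x, W_h, b,
                 gamma_h, gamma_x, gamma_c, beta_c):
    sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
    # recurrent and input terms are normalized separately; beta_h = beta_x = 0,
    # with the bias b accounting for both offsets
    pre = (batch_norm(h_prev @ W_h, gamma_h, 0.0)
           + batch_norm(x_t @ W_x, gamma_x, 0.0) + b)
    f, i, o, g = np.split(pre, 4, axis=-1)
    c_t = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)   # cell update left unnormalized
    h_t = sigmoid(o) * np.tanh(batch_norm(c_t, gamma_c, beta_c))
    return h_t, c_t

# usage: one step on a random minibatch
rng = np.random.default_rng(0)
d_x, d_h, batch = 3, 5, 4
x = rng.normal(size=(batch, d_x))
h0, c0 = np.zeros((batch, d_h)), np.zeros((batch, d_h))
h, c = bn_lstm_step(x, h0, c0,
                    rng.normal(size=(d_x, 4 * d_h)), rng.normal(size=(d_h, 4 * d_h)),
                    np.zeros(4 * d_h),
                    np.full(4 * d_h, 0.1), np.full(4 * d_h, 0.1),
                    np.full(d_h, 0.1), np.zeros(d_h))
```

Note that the cell update itself passes through no normalization, so the nearly linear gradient path through $c_t$ is preserved.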
We set $\beta_h = \beta_x = 0$ to avoid unnecessary redundancy, instead relying on the pre-existing parameter vector $b$ to account for both biases. In order to leave the LSTM dynamics intact and preserve the gradient flow through $c_t$, we do not apply batch normalization in the cell update.

The batch normalization transform relies on batch statistics to standardize the LSTM activations. It would seem natural to share the statistics that are used for normalization across time, just as recurrent neural networks share their parameters over time. However, we find that simply averaging statistics over time severely degrades performance. Although LSTM activations do converge to a stationary distribution, we observe that their statistics during the initial transient differ significantly (see Figure 5 in Appendix A). Consequently, we recommend using separate statistics for each timestep to preserve information of the initial transient phase in the activations.¹

Generalizing the model to sequences longer than those seen during training is straightforward thanks to the rapid convergence of the activations to their steady-state distributions (cf. Figure 5). For our experiments we estimate the population statistics separately for each timestep $1, \ldots, T_{max}$, where $T_{max}$ is the length of the longest training sequence. When at test time we need to generalize beyond $T_{max}$, we use the population statistic of time $T_{max}$ for all time steps beyond it. During training we estimate the statistics across the minibatch, independently for each timestep. At test time we use estimates obtained by averaging the minibatch estimates over the training set.

¹ Note that we separate only the statistics over time and not the $\gamma$ and $\beta$ parameters.

## 4 INITIALIZING $\gamma$ FOR GRADIENT FLOW

Although batch normalization allows for easy control of the pre-activation variance through the $\gamma$ parameters, common practice is to normalize to unit variance.
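The per-timestep bookkeeping described above can be sketched as follows. This is a hypothetical helper: it keeps a running average of minibatch statistics as a common stand-in for the paper's averaging of minibatch estimates over the full training set, and clamps timesteps beyond $T_{max}$ to the last available statistics.

```python
import numpy as np

class PerTimestepBN:
    """Batch norm with a separate running mean/variance for each timestep.
    Timesteps beyond t_max reuse the statistics of the last training timestep."""
    def __init__(self, d, t_max, momentum=0.1, eps=1e-3):
        self.gamma, self.beta = np.full(d, 0.1), np.zeros(d)  # gamma = 0.1 (Section 4)
        self.mean = np.zeros((t_max, d))
        self.var = np.ones((t_max, d))
        self.t_max, self.momentum, self.eps = t_max, momentum, eps

    def __call__(self, h, t, training=True):
        t = min(t, self.t_max - 1)            # generalize beyond the training length
        if training:
            m, v = h.mean(axis=0), h.var(axis=0)
            # running average as an approximation to population statistics
            self.mean[t] += self.momentum * (m - self.mean[t])
            self.var[t] += self.momentum * (v - self.var[t])
        else:
            m, v = self.mean[t], self.var[t]
        return self.beta + self.gamma * (h - m) / np.sqrt(v + self.eps)

# usage: train on three timesteps, then normalize a step past t_max
bn = PerTimestepBN(d=4, t_max=3)
x = np.random.default_rng(0).normal(size=(8, 4))
for t in range(3):
    bn(x, t)                          # training: update statistics for timestep t
y = bn(x, t=10, training=False)       # beyond t_max: falls back to timestep t_max - 1
```

Only the statistics are duplicated per timestep; the `gamma` and `beta` parameters are shared across time, as in the footnote above.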
We suspect that the previous difficulties with recurrent batch normalization reported in Laurent et al. (2016) and Amodei et al. (2015) are largely due to improper initialization of the batch normalization parameters, and $\gamma$ in particular. In this section we demonstrate the impact of $\gamma$ on gradient flow.

Figure 1: Influence of pre-activation variance on gradient propagation. (a) The gradient flow $\|\partial \mathrm{loss} / \partial h_t\|_2$ through a batch-normalized tanh RNN as a function of $\gamma$ (for $\gamma$ ranging from 0.1 to 1.0); high variance causes vanishing gradients. (b) The empirical expected derivative and interquartile range of the tanh nonlinearity as a function of input standard deviation; high variance causes saturation, which decreases the expected derivative.

In Figure 1(a), we show how the pre-activation variance impacts gradient propagation in a simple RNN on the sequential MNIST task described in Section 5.1. Since backpropagation operates in reverse, the plot is best read from right to left. The quantity plotted is the norm of the gradient of the loss with respect to the hidden state at different time steps. For large values of $\gamma$, the norm quickly goes to zero as the gradient is propagated back in time. For small values of $\gamma$ the norm is nearly constant.

To demonstrate what we think is the cause of this vanishing, we drew samples $x$ from a set of centered Gaussian distributions with standard deviation ranging from 0 to 1, and computed the derivative $\tanh'(x) = 1 - \tanh^2(x) \in [0, 1]$ for each. Figure 1(b) shows the empirical distribution of the derivative as a function of standard deviation.
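The saturation experiment behind Figure 1(b) can be reproduced in a few lines: sample centered Gaussians of increasing standard deviation and measure the average derivative of tanh.

```python
import numpy as np

rng = np.random.default_rng(0)
# Empirical expected derivative of tanh under centered Gaussian input:
# larger input standard deviation pushes more mass into the saturated
# regions of tanh, shrinking the average derivative.
means = {}
for std in (0.1, 0.5, 1.0):
    x = rng.normal(0.0, std, size=100_000)
    deriv = 1.0 - np.tanh(x) ** 2            # tanh'(x), always in [0, 1]
    means[std] = deriv.mean()
    print(f"std={std:.1f}  E[tanh'(x)] ~ {means[std]:.3f}")
```

The expected derivative decreases monotonically with the input standard deviation, consistent with the vanishing-gradient behavior seen for large $\gamma$.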
When the input standard deviation is low, the input tends to be close to the origin, where the derivative is close to 1. As the standard deviation increases, the expected derivative decreases as the input is more likely to be in the saturation regime. At unit standard deviation, the expected derivative is much smaller than 1. We conjecture that this is what causes the gradient to vanish, and recommend initializing $\gamma$ to a small value. In our trials we found that values of 0.01 or lower caused instabilities during training. Our choice of 0.1 seems to work well across different tasks.

## 5 EXPERIMENTS

This section presents an empirical evaluation of the proposed batch-normalized LSTM on four different tasks. Note that for all the experiments, we initialize the batch normalization scale and shift parameters $\gamma$ and $\beta$ to 0.1 and 0 respectively.

Figure 2: Accuracy on the validation set for the pixel-by-pixel MNIST classification tasks. The batch-normalized LSTM converges faster than the baseline LSTM, and also generalizes better on permuted sequential MNIST, which requires preserving long-term memory information.

## 5.1 SEQUENTIAL MNIST

We evaluate our batch-normalized LSTM on a sequential version of the MNIST classification task (Le et al., 2015). The model processes each image one pixel at a time and finally predicts the label. We consider both sequential MNIST tasks, MNIST and permuted MNIST (pMNIST). In MNIST, the pixels are processed in scanline order. In pMNIST the pixels are processed in a fixed random order.
Our baseline consists of an LSTM with 100 hidden units, with a softmax classifier to produce a prediction from the final hidden state. We use orthogonal initialization for all weight matrices, except for the hidden-to-hidden weight matrix, which we initialize to be the identity matrix, as this yields better generalization performance on this task for both models. The model is trained using RMSProp (Tieleman & Hinton, 2012).

## 5.3 TEXT8

We evaluate our model on a second character-level language modeling task on the much larger text8 dataset (Mahoney, 2009). This dataset is derived from Wikipedia and consists of a sequence of 100M characters including only alphabetical characters and spaces. We follow Mikolov et al. (2012) and Zhang et al. (2016) and use the first 90M characters for training, the next 5M for validation and the final 5M characters for testing. We train on nonoverlapping sequences of length 180.

Both our baseline and batch-normalized models are LSTMs with 2000 units, trained to predict the next character using a softmax classifier on the hidden state $h_t$. We use stochastic gradient descent on minibatches of size 128, with gradient clipping at 1.0 and step rule determined by Adam (Kingma & Ba, 2014).

Figure 3: Penn Treebank evaluation. (a) Performance in bits-per-character on length-100 subsequences of the Penn Treebank validation sequence during training. (b) Generalization to longer subsequences of Penn Treebank using population statistics; the subsequences are taken from the test sequence.

Figure 4: Training curves on the CNN question-answering tasks. (a) Error rate on the validation set for the Attentive Reader models on a variant of the CNN QA task (Hermann et al., 2015); as detailed in Appendix C, the theoretical lower bound on the error rate on this task is 43%. (b) Error rate on the validation set on the full CNN QA task from Hermann et al. (2015).

Our third variant, BN-e*, follows Amodei et al. (2015). That is, we share statistics over time for normalization of the input terms $W_x x_t$, but not for the recurrent terms $W_h h_t$ or the cell output $c_t$. Doing so avoids many issues involving degenerate statistics due to input sequence padding.

Our fourth and final variant, BN-e**, is like BN-e* but bidirectional. The main difficulty in adapting to bidirectional models also involves padding. Padding poses no problem as long as it is properly ignored (by not updating the hidden states based on padded regions of the input). However, to perform the reverse application of a bidirectional model, it is common to simply reverse the padded sequences, thus moving the padding to the front. This causes similar problems as were observed on the sequential MNIST task (Section 5.1): the hidden states will not diverge during the initial timesteps and hence their variance will be severely underestimated. To get around this, we reverse only the unpadded portion of the input sequences and leave the padding in place.
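A padding-aware reversal of this kind might be sketched as follows (a hypothetical helper, not the paper's code): each sequence's valid prefix is reversed in place while the trailing padding stays at the end.

```python
import numpy as np

def reverse_unpadded(seqs, lengths):
    """Reverse only the unpadded prefix of each padded sequence, so the
    padding stays at the end for the reverse pass of a bidirectional RNN."""
    out = seqs.copy()
    for i, n in enumerate(lengths):
        out[i, :n] = seqs[i, :n][::-1]
    return out

# usage: 0 denotes padding at the end of the first sequence
batch = np.array([[1, 2, 3, 0, 0],
                  [4, 5, 6, 7, 8]])
rev = reverse_unpadded(batch, [3, 5])
print(rev)  # [[3 2 1 0 0], [8 7 6 5 4]]
```

With a naive full reversal, the padding would move to the front and the early hidden states would be degenerate, which is exactly the statistics problem described above.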