Under review as a conference paper at ICLR 2017

ON LARGE-BATCH TRAINING FOR DEEP LEARNING: GENERALIZATION GAP AND SHARP MINIMA

Nitish Shirish Keskar, Northwestern University, Evanston, IL 60208

[email protected]
Dheevatsa Mudigere, Intel Corporation, Bangalore, India

[email protected]
Jorge Nocedal, Northwestern University, Evanston, IL 60208

[email protected]
Mikhail Smelyanskiy

For background on deep networks and their training, see (Bengio et al., 2016) and the references therein. The problem of training these networks is one of non-convex optimization. Mathematically, this can be represented as:

    min_{x ∈ ℝⁿ}  f(x) := (1/M) ∑_{i=1}^{M} f_i(x),    (1)

where f_i is a loss function for data point i ∈ {1, 2, …, M} which captures the deviation of the model prediction from the data, and x is the vector of weights being optimized. The process of optimizing this function is also called training of the network. Stochastic Gradient Descent (SGD) (Bottou, 1998; Sutskever et al., 2013) and its variants are often used for training deep networks. These methods minimize the objective function f by iteratively taking steps of the form:

    x_{k+1} = x_k − α_k · (1/|B_k|) ∑_{i ∈ B_k} ∇f_i(x_k),    (2)

where B_k ⊂ {1, 2, …, M} is the batch sampled from the data set and α_k is the step size at iteration k. These methods can be interpreted as gradient descent using noisy gradients, which are often referred to as mini-batch gradients with batch size |B_k|. SGD and its variants are employed in a small-batch regime, where |B_k| ≪ M and typically |B_k| ∈ {32, 64, …, 512}. These configurations have been successfully used in practice for a large number of applications; see e.g. (Simonyan & Zisserman, 2014; Graves et al., 2013; Mnih et al., 2013). Many theoretical properties of these methods are known. These include guarantees of: (a) convergence to minimizers of strongly-convex functions and to stationary points for non-convex functions (Bottou et al., 2016), (b) saddle-point avoidance (Ge et al., 2015; Lee et al., 2016), and (c) robustness to input data (Hardt et al., 2015). Stochastic gradient methods have, however, a major drawback: owing to the sequential nature of the iteration and small batch sizes, there is limited avenue for parallelization.

Work was performed when the author was an intern at Intel Corporation.
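The update (2) can be sketched in a few lines. The following is a minimal illustration; the function name and the `grad_fi` interface are our own, not part of the paper:

```python
import numpy as np

def sgd_step(x, grad_fi, batch, lr):
    """One mini-batch SGD iteration, Eq. (2):
    x_{k+1} = x_k - alpha_k * (1/|B_k|) * sum_{i in B_k} grad f_i(x_k).
    `grad_fi(x, i)` returns the gradient of the i-th per-example loss at x;
    `batch` is the sampled index set B_k and `lr` is the step size alpha_k."""
    g = np.mean([grad_fi(x, i) for i in batch], axis=0)
    return x - lr * g
```

For instance, with per-example losses f_i(x) = ½(x − t_i)², a full-batch step with lr = 1 moves x directly to the mean of the targets t_i.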
While some efforts have been made to parallelize SGD for deep learning (Dean et al., 2012; Das et al., 2016; Zhang et al., 2015), the speed-ups and scalability obtained are often limited by the small batch sizes. One natural avenue for improving parallelism is to increase the batch size |B_k|. This increases the amount of computation per iteration, which can be effectively distributed. However, practitioners have observed that this leads to a loss in generalization performance; see e.g. (LeCun et al., 2012). In other words, the performance of the model on testing data sets is often worse when trained with large-batch methods than with small-batch methods. In our experiments, we have found the drop in generalization (also called the generalization gap) to be as high as 5% even for smaller networks.

In this paper, we present numerical results that shed light on this drawback of large-batch methods. We observe that the generalization gap is correlated with a marked sharpness of the minimizers obtained by large-batch methods. This motivates efforts at remedying the generalization problem, as a training algorithm that employs large batches without sacrificing generalization performance would have the ability to scale to a much larger number of nodes than is possible today. This could potentially reduce training time by orders of magnitude; we present an idealized performance model in Appendix C to support this claim.

The paper is organized as follows. In the remainder of this section, we define the notation used in this paper, and in Section 2 we present our main findings and their supporting numerical evidence. In Section 3 we explore the performance of small-batch methods, and in Section 4 we briefly discuss the relationship between our results and recent theoretical work. We conclude with open questions concerning the generalization gap, sharp minima, and possible modifications to make large-batch training viable.
In Appendix E, we present some attempts to overcome the problems of large-batch training.

1.1 NOTATION

We use the notation f_i to denote the composition of the loss function and a prediction function corresponding to the i-th data point. The vector of weights is denoted by x and is subscripted by k to denote an iteration. We use the term small-batch (SB) method to denote SGD, or one of its variants like ADAM (Kingma & Ba, 2015), run with a small batch size, and large-batch (LB) method to denote its large-batch counterpart. Among the conjectured causes of the generalization gap are: (i) LB methods over-fit the model; (ii) LB methods are attracted to saddle points; (iii) LB methods lack the explorative properties of SB methods and tend to zoom in on the minimizer closest to the initial point; (iv) SB and LB methods converge to qualitatively different minimizers with differing generalization properties. The data presented in this paper supports the last two conjectures.

The main observation of this paper is as follows:

The lack of generalization ability is due to the fact that large-batch methods tend to converge to sharp minimizers of the training function. These minimizers are characterized by a significant number of large positive eigenvalues of ∇²f(x), and tend to generalize less well. In contrast, small-batch methods converge to flat minimizers characterized by numerous small eigenvalues of ∇²f(x).

We have observed that the loss-function landscape of deep neural networks is such that large-batch methods are attracted to regions with sharp minimizers and, unlike small-batch methods, are unable to escape the basins of attraction of these minimizers. The concepts of sharp and flat minimizers have been discussed in the statistics and machine learning literature (Hochreiter & Schmidhuber, 1997); see Figure 1 for a hypothetical illustration. This can be explained through the lens of minimum description length (MDL) theory, which states that statistical models that require fewer bits to describe (i.e., are of low complexity) generalize better (Rissanen, 1983).
Since flat minimizers can be specified with lower precision than sharp minimizers, they tend to have better generalization performance. Alternative explanations are proffered through the Bayesian view of learning (MacKay, 1992) and through the lens of Gibbs free energy; see e.g. Chaudhari et al. (2016).

[Figure 1: A conceptual sketch of flat and sharp minima of the training and testing functions. The y-axis indicates the value of the loss function f(x) and the x-axis the variables (parameters).]

2.2 NUMERICAL EXPERIMENTS

In this section, we present numerical results to support the observations made above. To this end, we make use of the visualization technique employed by Goodfellow et al. (2014b) and a proposed heuristic metric of sharpness (Equation (4)). We consider 6 multi-class classification network configurations for our experiments; they are described in Table 1. The details about the data sets and network configurations are presented in Appendices A and B, respectively. As is common for such problems, we use the mean cross-entropy loss as the objective function f. The networks were chosen to exemplify popular configurations used in practice, like AlexNet (Krizhevsky et al., 2012) and VGGNet (Simonyan & Zisserman, 2014). The generalization gap is not due to over-fitting or over-training; see Figure 2 for the training–testing curves of the F2 and C1 networks, which are representative of the rest. As such, early-stopping heuristics aimed at preventing models from over-fitting would not help reduce the generalization gap. The difference between the training and testing accuracies for the networks is due to the specific choice of the network (e.g. AlexNet, VGGNet etc.) and is not the focus of this study. Rather, our goal is to study the source of the testing performance disparity of the two regimes, SB and LB, on a given network model.
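As a concrete instance of the objective in (1), the mean cross-entropy loss over M examples can be written as below. This is an illustrative sketch with our own helper names, not the experiment code:

```python
import numpy as np

def mean_cross_entropy(probs, labels):
    """f(x) = (1/M) * sum_i f_i(x), where f_i is the cross-entropy
    -log p_i(true class). `probs` is an (M, num_classes) array of predicted
    class probabilities (e.g. softmax outputs) and `labels` holds the true
    class index of each example."""
    M = labels.shape[0]
    eps = 1e-12  # guard against log(0)
    return float(-np.mean(np.log(probs[np.arange(M), labels] + eps)))
```

A model that places probability 1 on every true class attains (numerically) zero loss; less confident predictions are penalized logarithmically.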
[Figure 2: Training and testing accuracy for SB and LB methods as a function of epochs. (a) Network F2; (b) Network C1.]

2.2.1 PARAMETRIC PLOTS

We first present parametric 1-D plots of the function as described in (Goodfellow et al., 2014b). Let x*_s and x*_l indicate the solutions obtained by running ADAM using small and large batch sizes, respectively. We plot the loss function, on both training and testing data sets, along a line segment containing the two points. Specifically, for α ∈ [−1, 2], we plot the function f(α·x*_l + (1 − α)·x*_s) and also superimpose the classification accuracy at the intermediate points; see Figure 3.¹ For this experiment, we randomly chose a pair of SB and LB minimizers from the 5 trials used to generate the data in Table 2. The plots show that the LB minima are strikingly sharper than the SB minima in this one-dimensional manifold. The plots in Figure 3 only explore a linear slice of the function, but in Figure 7 in Appendix D, we plot f(sin(απ/2)·x*_l + cos(απ/2)·x*_s) to monitor the function along a curved path between the two minimizers. There too, the relative sharpness of the minima is evident.

2.2.2 SHARPNESS OF MINIMA

So far, we have used the term sharp minimizer loosely, but we noted that this concept has received attention in the literature (Hochreiter & Schmidhuber, 1997). Sharpness of a minimizer can be characterized by the magnitude of the eigenvalues of ∇²f(x), but since computing the full spectrum is prohibitively expensive for large networks, we employ a computationally feasible alternative: we measure the maximum value that f attains in a small neighborhood of the minimizer. Given A ∈ ℝ^{n×p}, whose columns span the subspace to be explored, and ε > 0, we define the constraint set

    C_ε = { z ∈ ℝ^p : −ε(|(A⁺x)_i| + 1) ≤ z_i ≤ ε(|(A⁺x)_i| + 1), i ∈ {1, 2, …, p} },    (3)

where A⁺ denotes the pseudo-inverse of A. Thus ε controls the size of the box. We can now define our measure of sharpness (or sensitivity).

Metric 2.1.
Given x ∈ ℝⁿ, ε > 0 and A ∈ ℝ^{n×p}, we define the (C_ε, A)-sharpness of f at x as:

    φ_{x,f}(ε, A) := ( max_{y ∈ C_ε} f(x + Ay) − f(x) ) / (1 + f(x)) × 100.    (4)

Unless specified otherwise, we use this metric for sharpness for the rest of the paper; if A is not specified, it is assumed to be the identity matrix, I_n. (We note in passing that, in the convex optimization literature, the term sharp minimum has a different definition (Ferris, 1988), but that concept is not useful for our purposes.) In Tables 3 and 4, we present the values of the sharpness metric (4) for the minimizers of the various problems. Table 3 explores the full space (i.e., A = I_n) whereas Table 4 uses a randomly sampled n × 100 dimensional matrix A. We report results with two values of ε: 10⁻³ and 5·10⁻⁴. In all experiments, we solve the maximization problem in Equation (4) inexactly by applying 10 iterations of L-BFGS-B (Byrd et al., 1995). This limit on the number of iterations was necessitated by the large cost of evaluating the true objective f. Both tables show a 1–2 order-of-magnitude difference between the values of our metric for the SB and LB regimes. These results reinforce the view that the solutions obtained by large-batch methods define points of larger sensitivity of the training function. In Appendix E, we describe approaches that attempt to remedy this generalization problem of LB methods. These approaches include data augmentation, conservative training and adversarial training. Our preliminary findings show that these approaches help reduce the generalization gap but still lead to relatively sharp minimizers and, as such, do not completely remedy the problem.

¹The code to reproduce the parametric plot on exemplary networks can be found in our GitHub repository: https://github.com/keskarnitish/large-batch-training.
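Metric 2.1 and the inexact L-BFGS-B maximization described above can be sketched as follows. This is a simplified reading of Equation (4) using SciPy, with our own function name and interface; the authors' actual implementation may differ:

```python
import numpy as np
from scipy.optimize import minimize

def sharpness(f, x, eps=1e-3, A=None, maxiter=10, seed=0):
    """(C_eps, A)-sharpness of f at x, Metric 2.1 / Eq. (4):
    100 * (max_{y in C_eps} f(x + A y) - f(x)) / (1 + f(x)).
    The box C_eps follows Eq. (3): |y_i| <= eps * (|(A^+ x)_i| + 1)."""
    n = x.size
    if A is None:
        A = np.eye(n)          # full-space case, A = I_n
    radii = eps * (np.abs(np.linalg.pinv(A) @ x) + 1.0)
    rng = np.random.default_rng(seed)
    y0 = rng.uniform(-radii, radii)   # random start inside the box
    # Maximize f over the box by minimizing its negation; only a few L-BFGS-B
    # iterations (the paper uses 10, since evaluating f is costly).
    res = minimize(lambda y: -f(x + A @ y), y0, method="L-BFGS-B",
                   bounds=list(zip(-radii, radii)),
                   options={"maxiter": maxiter})
    return float((-res.fun - f(x)) / (1.0 + f(x)) * 100.0)
```

On a toy quadratic, a minimizer with large curvature scores far higher than a flat one, mirroring the SB/LB contrast in Tables 3 and 4.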
[Figure 3: Parametric plots, linear interpolation. Panels: (a) F1, (b) F2, (c) C1, (d) C2, (e) C3, (f) C4. The left vertical axis corresponds to the cross-entropy loss, f, and the right vertical axis to classification accuracy; solid lines indicate the training data set and dashed lines the testing data set. α = 0 corresponds to the SB minimizer and α = 1 to the LB minimizer.]

Note that Metric 2.1 is closely related to the spectrum of ∇²f(x). Assuming ε to be small enough, when A = I_n the value (4) relates to the largest eigenvalue of ∇²f(x), and when A is randomly sampled it approximates the Ritz value of ∇²f(x) projected onto the column space of A. We conclude this section by noting that the sharp minimizers identified in our experiments do not resemble a cone, i.e., the function does not increase rapidly along all (or even most) directions. By sampling the loss function in a neighborhood of LB solutions, we observe that it rises steeply only along a low-dimensional subspace (e.g. 5% of the whole space); in most other directions, the function is relatively flat.

Table 3: Sharpness of Minima in Full Space; ε is defined in (3).
            ε = 10⁻³                      ε = 5·10⁻⁴
         SB              LB               SB              LB
F1    1.23 ± 0.83    205.14 ± 69.52    0.61 ± 0.27    42.90 ± 17.14
F2    1.39 ± 0.02    310.64 ± 38.46    0.90 ± 0.05    93.15 ± 6.81
C1   28.58 ± 3.13    707.23 ± 43.04    7.08 ± 0.88   227.31 ± 23.23
C2    8.68 ± 1.32    925.32 ± 38.29    2.07 ± 0.86   175.31 ± 18.28
C3   29.85 ± 5.98    258.75 ± 8.96     8.56 ± 0.99   105.11 ± 13.22
C4   12.83 ± 3.84    421.84 ± 36.97    4.07 ± 0.87   109.35 ± 16.57

Table 4: Sharpness of Minima in Random Subspaces of Dimension 100

            ε = 10⁻³                      ε = 5·10⁻⁴
         SB              LB               SB              LB
F1    0.11 ± 0.00      9.22 ± 0.56     0.05 ± 0.00     9.17 ± 0.14
F2    0.29 ± 0.02     23.63 ± 0.54     0.05 ± 0.00     6.28 ± 0.19
C1    2.18 ± 0.23    137.25 ± 21.60    0.71 ± 0.15    29.50 ± 7.48
C2    0.95 ± 0.34     25.09 ± 2.61     0.31 ± 0.08     5.82 ± 0.52
C3   17.02 ± 2.20    236.03 ± 31.26    4.03 ± 1.45    86.96 ± 27.39
C4    6.05 ± 1.13     72.99 ± 10.96    1.89 ± 0.33    19.85 ± 4.12

3 SUCCESS OF SMALL-BATCH METHODS

It is often reported that, when increasing the batch size for a problem, there exists a threshold after which the quality of the model deteriorates. This behavior can be observed for the F2 and C1 networks in Figure 4. In both of these experiments, there is a batch size (approximately 15000 for F2 and 500 for C1) after which there is a large drop in testing accuracy. Notice also that the upward drift in the value of sharpness is considerably reduced around this threshold. Similar thresholds exist for the other networks in Table 1.

Let us now consider the behavior of SB methods, which use noisy gradients in the step computation. From the results reported in the previous section, it appears that noise in the gradient pushes the iterates out of the basin of attraction of sharp minimizers and encourages movement towards a flatter minimizer, where the noise will not cause an exit from that basin. When the batch size is greater than the threshold mentioned above, th