Why Gradient Descent Minimizes Training Loss
\[ \newcommand{\R}{\mathbb{R}} \newcommand{\der}{\partial} \newcommand{\dldt}{\frac{\der l}{\der t}} \newcommand{\len}{\text{length}} \newcommand{\en}{\text{energy}} \newcommand{\relu}{ {\small\text{ReLU}} } \newcommand{\dim}{\text{dim}} \newcommand{\tr}{\text{trace}} \newcommand{\lin}{\text{Lin}} \newcommand{\rnk}{\text{rank}} \newcommand{\ht}{\widehat} \newcommand{\dwdt}{\frac{\der w}{\der t}} \newcommand{\l}{\mathscr{l}} \newcommand{\E}{\mathbb E} \]
Every student of deep learning has wondered at some point in their journey: “How can gradient descent on a neural network possibly succeed? Surely, on such a large and non-convex loss landscape, we will get stuck in some local minimum, far from the optimal solution.” And yet, deep learning practitioners quickly learn to trust that a large enough neural network with properly tuned hyper-parameters will successfully fit any training data we throw at it [1]. In this article, we explore a mathematical theory to explain this surprising phenomenon.
One common approach to showing that gradient descent will converge to a global optimum is to show that the loss landscape is convex. This means that the line segment drawn between any pair of points on its surface never passes below the surface. This property can be used to prove that every local minimum is a global minimum [2]. Unfortunately, this technique cannot be applied to neural networks, because the loss function of a neural network is not a convex function of its parameters. This can be demonstrated by a simple counterexample.
Neural network convexity counterexample
Consider a 2-layer MLP with ReLU nonlinearity, trained with the L2 loss on a single data point. It has two parameter matrices \(M\in \R^{k\times m}\) and \(N\in \R^{n\times k}\). Given an input \(x \in \R^m\), the MLP produces the output \[y = N \: \relu(M x)\] and the loss is \(\l = \frac 1 2 | y - q|^2\), where \(q \in \R^n\) is the target corresponding to the input \(x\).
Let \(m=n=1\) and \(k=2\). Our input-output pair is just \(x=1\) and \(q=0\). To show that the function is not convex, we will interpolate the loss between two parameter vectors \(w = (N, M)\) and \(w' = (N', M')\) where \[ N = \begin{bmatrix} 0 & 1 \end{bmatrix} \quad M = \begin{bmatrix} 1 \\ 0 \end{bmatrix} \quad\quad N' = \begin{bmatrix} 1 & 0 \end{bmatrix} \quad M' = \begin{bmatrix} 0 \\ 1 \end{bmatrix} \]
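A minimal plain-Python sketch that evaluates the loss of this \(m=n=1\), \(k=2\) MLP along the segment \(u = s w + (1-s) w'\). \(N\) is stored as a row of two numbers and \(M\) as a column of two numbers, and the helper names (`loss`, `loss_on_segment`) are our own. The sketch compares each value with the closed form \(2s^2(1-s)^2\). Convexity would require every value to be at most \(0\), which is the height of the chord joining the two zero-loss endpoints.

```python
def relu(v):
    return max(v, 0.0)


def loss(N, M, x=1.0, q=0.0):
    """l = 0.5 * (N relu(M x) - q)^2 for m = n = 1, k = 2.

    N is the 1x2 row and M the 2x1 column, each stored as a list of two floats.
    """
    hidden = [relu(Mi * x) for Mi in M]
    y = sum(Ni * hi for Ni, hi in zip(N, hidden))
    return 0.5 * (y - q) ** 2


# The two zero-loss parameter vectors from the text.
N_w, M_w = [0.0, 1.0], [1.0, 0.0]     # w  = (N, M)
N_w2, M_w2 = [1.0, 0.0], [0.0, 1.0]   # w' = (N', M')


def loss_on_segment(s):
    """Loss at u = s * w + (1 - s) * w'."""
    N_u = [s * a + (1 - s) * b for a, b in zip(N_w, N_w2)]
    M_u = [s * a + (1 - s) * b for a, b in zip(M_w, M_w2)]
    return loss(N_u, M_u)


for s in [0.0, 0.25, 0.5, 0.75, 1.0]:
    print(f"s={s:.2f}  loss={loss_on_segment(s):.6f}  closed form={2 * s**2 * (1 - s)**2:.6f}")
```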
Since the target \(q=0\), it’s clear that the loss at \(w\) and \(w'\) is \(0\). But if we define \(u = s w + (1-s) w'\), you can check that the output at \(u\) is \(2s(1-s)\). The loss at \(u\) is therefore \(2 s^2 (1-s)^2\), which is greater than \(0\) for \(s \in (0, 1)\). Convexity would require the loss at \(u\) to be at most \(s \cdot 0 + (1-s) \cdot 0 = 0\). The loss therefore rises above the line segment drawn between these two points, and so it is not convex.
However, while convexity is sufficient to prove convergence, it is not required for convergence to occur. For example, look at the plot below. On the left, we see the loss landscape for a quadratic loss on a two-dimensional parameter space. On the right, we see a loss landscape that is similar, except that it has a bulge that clearly makes it non-convex. And yet it is visually evident that neither plot has a local minimum other than the global one. Gradient descent would optimize successfully in both settings.
In this article, we investigate an alternative property to convexity which we call learning health. Informally, it says that if the loss is large, the gradient must also be large. (In the plot above, both loss landscapes are healthy.)
The central result of this article uses learning health to prove that gradient descent on neural networks will decrease the loss to a small value. The argument goes like this:
- The neural network is healthy near the initialization.
- When the loss landscape is healthy, the loss will decrease exponentially quickly.
- Under the dynamics of gradient descent, the parameters will stay close to initialization for some time.
These three results allow us to conclude that the loss decays exponentially for a meaningful amount of time, and is therefore guaranteed to reach some small value. The bounds we derive improve as the neural network grows, so with a large enough network, the loss can be guaranteed to fall arbitrarily close to zero.
In this article, we apply this theory only to a very simple neural network: a 2-layer MLP trained with the L2 loss on a dataset consisting of a single datapoint. But the general theory we develop here applies more broadly. In follow-up posts, we will show that the results can also be applied to MLPs with multiple layers, trained on multiple data points, and using the cross-entropy loss instead of the L2. There is also existing literature exploring similar ideas (see e.g. [1], [3], [4]).
The plot below visualizes the bound we derive, juxtaposed with the actual training curve of a neural network. Every time you click the train button, a new set of initial parameters, inputs, and targets is randomly sampled, and 200 steps of gradient descent training are performed. You can see for yourself that the loss always approaches 0, and that our bound is non-vacuous.
0. Notation and Definitions
Much of the theory in this article treats the space of parameters as a vector space equipped with an inner product. The parameters of a neural network tend to either be elements of \(\R^{n}\) or matrices in \(\R^{n\times m}\), so we need to define suitable inner products for these two types of vectors.
Definition 0.1: Inner Product. For two vectors \(x,y \in \R^n\) we take the inner product to be \(<x, y> = x^T y\). For two matrices \(M,N\in \R^{n\times m}\) we use \(<M, N> = \tr(M^T N)\). Any inner product induces a norm \(|\cdot|\) via \(|v| = \sqrt{<v,v>}\), and you can verify that, in the \(\R^n\) and \(\R^{n\times m}\) cases, this induced norm takes the following forms: \[\begin{align} |x| = \sqrt{\sum_i x_i^2} \quad\quad |M| = \sqrt{\sum_{ij} M_{ij}^2} \end{align}\] which are also known as the L2 norms of the respective vector spaces. \(\blacksquare\)
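A small plain-Python sketch of Definition 0.1 for matrices stored as lists of rows. The function names are our own. It computes \(\tr(M^T N)\) by actually forming the product \(M^T N\), and checks that the result equals the sum of elementwise products. That identity is why the induced norm is the square root of the sum of squared entries.

```python
import math


def trace_of_MtN(M, N):
    """<M, N> = trace(M^T N), computed by forming the m x m product M^T N."""
    n, m = len(M), len(M[0])
    MtN = [[sum(M[i][j] * N[i][l] for i in range(n)) for l in range(m)] for j in range(m)]
    return sum(MtN[j][j] for j in range(m))


def elementwise_inner(M, N):
    """Sum of M_ij * N_ij over all entries."""
    return sum(a * b for row_M, row_N in zip(M, N) for a, b in zip(row_M, row_N))


def norm(M):
    """Norm induced by the inner product: sqrt(<M, M>)."""
    return math.sqrt(trace_of_MtN(M, M))


M = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
N = [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]
print(trace_of_MtN(M, N), elementwise_inner(M, N), norm(M))
```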
Definition 0.2: Derivative. Let \(f: V \to W\) be a function between two vector spaces. The derivative at a point \(v\in V\) is a linear function \(\der f(v): V\to W\). It tells us how changing the input \(v \to v + u\) will affect the output, if the change \(u\in V\) is small. Concretely, \[ \der f(v)(u) \simeq f(v + u) - f(v) \] Some people would call \(\der f(v)(u)\) the directional derivative of \(f\) along \(u\) at a point \(v\). \(\blacksquare\)
Definition 0.3: Path Derivative. An especially important case is when we are taking the derivative of a function \(h:\R \to V\), also known as a path through \(V\). Here, using the notation \(\der h(t): \R \to V\) is a little cumbersome. Instead, we can use Newton’s notation, which defines \(h'(t) = \der h(t)(1)\in V\). This allows us to think of the derivative of \(h\) as a vector in \(V\), as opposed to a map \(\R \to V\). This is possible because any linear map of the form \(M: \R \to V\) can be turned into scalar-vector multiplication \(M(r) = r \: v\) by defining \(v = M(1)\). \(\blacksquare\)
One important use of the derivative of a path is to define the length and energy of the path.
Definition 0.4: length and energy of a path. Given a normed vector space \(V\) and a differentiable path \(h: [0, t] \to V\), the length and energy of the path are defined as \[\begin{align} \len(h) &= \int_0^t |h'(s)| ds \\ \en(h) &= \int_0^t |h'(s)|^2 ds \end{align}\] \(\blacksquare\)
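A minimal sketch for approximating \(\len(h)\) and \(\en(h)\) from samples of a path, assuming a uniform time grid. It is written in plain Python, and the function name and test paths are our own. On each small interval, \(|h'(s)|\) is replaced by \(|h(s_{i+1}) - h(s_i)| / \Delta s\). The length then becomes the total length of the polyline, and the energy becomes \(\sum_i |h(s_{i+1}) - h(s_i)|^2 / \Delta s\).

```python
import math


def length_and_energy(h, t, steps=10_000):
    """Approximate length(h) and energy(h) for a path h: [0, t] -> R^d.

    h maps a float to a tuple of floats; |h'(s)| is replaced by a finite difference.
    """
    ds = t / steps
    points = [h(i * ds) for i in range(steps + 1)]
    length = energy = 0.0
    for p, p_next in zip(points, points[1:]):
        step = math.dist(p, p_next)       # |h(s_{i+1}) - h(s_i)|
        length += step                    # ~ |h'(s)| ds
        energy += step * step / ds        # ~ |h'(s)|^2 ds
    return length, energy


# Unit-speed circle: |h'| = 1, so length = energy = t.
print(length_and_energy(lambda s: (math.cos(s), math.sin(s)), t=2.0))
# Parabola h(s) = (s, s^2): energy = t + 4 t^3 / 3.
print(length_and_energy(lambda s: (s, s * s), t=1.0))
```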
Definition 0.5: Gradient. Given a vector space \(V\) equipped with an inner product, the gradient of a differentiable function \(f: V \to \R\), denoted \(\nabla f\), is a map \(\nabla f: V\to V\) defined so that \(\nabla f(v)\in V\) is the unique vector satisfying \[ \der f(v)(u) = <\nabla f(v), u> \quad \forall u\in V \] Whenever it is completely clear from the context which function \(f\) we are taking the gradient of, we can use the shorthand \(\hat v = \nabla f(v)\). \(\blacksquare\)
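To make Definitions 0.2 and 0.5 concrete, here is a small plain-Python check. It uses a hypothetical example function \(f(M) = \frac 1 2 |Mx - q|^2\) on \(\R^{n\times m}\), whose gradient is \((Mx - q)x^T\). The sketch estimates the directional derivative \(\der f(M)(U)\) with a central finite difference. It then compares this estimate with \(<\nabla f(M), U> = \tr(\nabla f(M)^T U)\) for a random direction \(U\).

```python
import random


def f(M, x, q):
    """f(M) = 0.5 * |M x - q|^2 (hypothetical example function)."""
    r = [sum(Mij * xj for Mij, xj in zip(Mi, x)) - qi for Mi, qi in zip(M, q)]
    return 0.5 * sum(ri * ri for ri in r)


def grad_f(M, x, q):
    """Gradient of f with respect to M: (M x - q) x^T."""
    r = [sum(Mij * xj for Mij, xj in zip(Mi, x)) - qi for Mi, qi in zip(M, q)]
    return [[ri * xj for xj in x] for ri in r]


def inner(A, B):
    """<A, B> = trace(A^T B) = sum of elementwise products."""
    return sum(a * b for Ai, Bi in zip(A, B) for a, b in zip(Ai, Bi))


def directional_derivative(M, U, x, q, eps=1e-5):
    """Central-difference estimate of df(M)(U) = lim (f(M + eps U) - f(M)) / eps."""
    plus = [[a + eps * u for a, u in zip(Mi, Ui)] for Mi, Ui in zip(M, U)]
    minus = [[a - eps * u for a, u in zip(Mi, Ui)] for Mi, Ui in zip(M, U)]
    return (f(plus, x, q) - f(minus, x, q)) / (2 * eps)


rng = random.Random(0)
n, m = 3, 4
M = [[rng.gauss(0, 1) for _ in range(m)] for _ in range(n)]
U = [[rng.gauss(0, 1) for _ in range(m)] for _ in range(n)]
x = [rng.gauss(0, 1) for _ in range(m)]
q = [rng.gauss(0, 1) for _ in range(n)]
print(directional_derivative(M, U, x, q), inner(grad_f(M, x, q), U))
```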
You might be used to a different definition of the gradient. The one you know turns out to be equivalent to the one we introduce here. Take a look at this great article for an in-depth explanation.
Definition 0.6: Gradient Flow. Let \(V\) be a vector space with an inner product. Given a differentiable and lower-bounded function \(f: V\to \R\), the gradient flow starting at \(v_0\in V\) is defined as the path \(\gamma: [0, \infty) \to V\) satisfying \[ \gamma(0) = v_0 \quad\text{and}\quad \gamma'(t) = - \nabla f\circ \gamma(t) \] The existence and uniqueness of \(\gamma\) follow from the existence and uniqueness theorem for ordinary differential equations (which applies, for example, when \(\nabla f\) is locally Lipschitz). The fact that \(f\) is lower-bounded guarantees that the solution never diverges in finite time, and so \(\gamma\) is defined on the whole domain \([0, \infty)\). \(\blacksquare\)
1. Learning with Gradient Flows
Definition 1.1: Learning Problem. The problem setup consists of a tuple \((W, w_0, f)\) where:
- \(W\) is the parameter space, which controls the behavior of the model. Mathematically, it is a vector space equipped with an inner product \(<\cdot, \cdot>\).
- \(w_0\in W\) is the initial point from which the learning will proceed.
- \(f: W \to \R^+\) is the loss function, which tells us how good the parameters are (presumably at fitting some dataset). The lower the loss, the better the parameters. Since a loss function must be lower bounded, we can assume without loss of generality that \(\inf_{w} f(w) = 0\).
From \(f\), \(W\) and \(w_0\), we define the following objects:
- The path \(\gamma: [0, \infty) \to W\) is the gradient flow of \(f\) starting at \(w_0\). It describes the evolution of the parameters through time.
- The loss curve \(\l:[0, \infty) \to \R\), defined as \(\l(t) = f\circ \gamma(t)\), tells us the amount of loss at any moment in time.
\(\blacksquare\)
This setting is deeply connected to the learning algorithms used in practice to train neural networks. But it is worth pointing out some important differences:
- Learning algorithms used in practice have discrete aspects. For example, gradient descent starts with the initial parameters \(w_0\in W\) and repeatedly applies the update \(w_{t+1} = w_t - \delta \nabla f(w_t)\) for some learning rate hyperparameter \(\delta \in \R^+\). If we identify step \(t\) with time \(t\delta\), then in the limit \(\delta \to 0\) this discrete path converges to the gradient flow \(\gamma\). But it is less clear whether the learning rates used in practice are small enough for the gradient flow to be a good approximation.
- Modern deep learning always uses stochastic gradient descent (SGD). Instead of computing the gradient on the entire dataset, we estimate it by computing the average gradient on a small random subset of the data. These stochastic gradients result in slightly worse updates, but they are much cheaper to compute. This tradeoff is worth making because, for the same amount of compute, it allows us to apply many more updates. Again, as the learning rate approaches \(0\), SGD converges to the gradient flow. But it is worth asking whether or not this approximation applies in practice.
- It’s common for deep learning optimizers to include details like a momentum term, a correction for the variance of the gradients, and so on. Each of these details is yet another reason why practical algorithms may look quite different from gradient flow.
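A minimal sketch of the first point uses the hypothetical one-dimensional loss \(f(w) = \frac c 2 w^2\). For this loss the gradient flow is known in closed form, \(\gamma(t) = w_0 e^{-ct}\). Gradient descent with learning rate \(\delta\) is run for \(T/\delta\) steps, identifying step \(t\) with time \(t\delta\), and its end point is compared with \(\gamma(T)\). In this example the gap shrinks roughly in proportion to \(\delta\).

```python
import math


def gd_endpoint(w0, c, delta, T):
    """Run round(T / delta) steps of w <- w - delta * f'(w) for f(w) = c w^2 / 2."""
    w = w0
    for _ in range(round(T / delta)):
        w = w - delta * c * w
    return w


w0, c, T = 1.0, 3.0, 1.0
flow = w0 * math.exp(-c * T)   # gradient flow: w'(t) = -c w(t)
errors = []
for delta in [0.1, 0.01, 0.001]:
    err = abs(gd_endpoint(w0, c, delta, T) - flow)
    errors.append(err)
    print(f"delta={delta}: |w_GD - gamma(T)| = {err:.2e}")
```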
Ultimately, we will want to prove guarantees for the algorithms we run in practice. But studying the behavior of gradient flows is a useful intermediate step. The rest of this section will lay out some of the key mathematical properties of gradient flows that make them so powerful when deriving learning guarantees.
Result 1.2. The derivative of the loss curve satisfies: \[\l'(t) = - | \nabla f\circ \gamma(t)| ^2\]
Proof
For simplicity let \(w=\gamma(t)\). Then, \[\begin{align} \l'(t) &= \der f(w)( \gamma'(t)) &\quad &\text{(chain rule)} \\ &= < \nabla f(w), \gamma'(t)> &\quad &\text{(gradient def.)} \\ &= -< \nabla f(w), \nabla f(w)> &\quad &\text{(gradient flow def.)} \\ &= - | \nabla f (w)| ^2 \end{align}\]
So, under gradient flow, the magnitude of the gradient tells us how fast the loss is decreasing at any moment in time. A large gradient means fast learning! A consequence of this result is that \(\en(\gamma)\) measures how much the path \(\gamma\) has managed to reduce the loss.
Result 1.3. Let \(f:W \to \R\) be a differentiable function, and let \(\gamma:[0, t] \to W\) be the gradient flow of \(f\) from Definition 1.1, restricted to \([0, t]\). Then:
\[ \en(\gamma) = \l(0) - \l(t) \]
Proof
\[\begin{align} \en(\gamma) &= \int_0^t |\gamma'(s)|^2 ds \\ &= \int_0^t |\nabla f \circ \gamma(s)|^2 ds &\quad &\text{(gradient flow def.)} \\ &= - \int_0^t \l'(s) ds &\quad &\text{(result 1.2)} \\ &= \l(0) - \l(t) \end{align}\]
Thus, we can reformulate questions about \(\l(t)\) into questions about \(\en(\gamma)\); an interesting shift in perspective. For our purposes, the most important implication is that \(\en(\gamma) \le \l(0)\), i.e., the energy is bounded by the initial loss. This follows from Result 1.3 and from the assumption in Definition 1.1 that \(f(w) \ge 0\).
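A quick numerical illustration of Result 1.3, under two stated assumptions. First, the gradient flow is approximated by gradient descent with a small step \(dt\). Second, the loss is the hypothetical quadratic \(f(w) = \frac 1 2 (w_1^2 + 4 w_2^2)\) started at \(w_0 = (1, 1)\), so \(\l(0) = 2.5\). The energy is accumulated as \(\sum_j |\nabla f(w_j)|^2 \, dt\). Because of the discretization, it only matches \(\l(0) - \l(T)\) up to an error of order \(dt\).

```python
def energy_vs_loss_drop(dt=1e-4, T=5.0):
    """Hypothetical example f(w) = 0.5 * (w1^2 + 4 w2^2), started at w0 = (1, 1).

    Returns (approximate energy of the path, l(0) - l(T)).
    """
    def f(w):
        return 0.5 * (w[0] ** 2 + 4.0 * w[1] ** 2)

    def grad(w):
        return (w[0], 4.0 * w[1])

    w = (1.0, 1.0)
    loss0 = f(w)
    energy = 0.0
    for _ in range(round(T / dt)):
        g = grad(w)
        # gamma'(t) = -grad f, so |gamma'(t)|^2 dt = |grad f|^2 dt
        energy += (g[0] ** 2 + g[1] ** 2) * dt
        w = (w[0] - dt * g[0], w[1] - dt * g[1])
    return energy, loss0 - f(w)


energy, drop = energy_vs_loss_drop()
print(f"energy ~ {energy:.5f}, l(0) - l(T) = {drop:.5f}")
```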
The last step before we are ready to prove the main result of this section is to establish an inequality between the length and energy of a generic path (not necessarily a gradient flow).
Result 1.4. If \(h: [0, t] \to V\) is a differentiable path, then
\[ \len(h)^2 \le t\; \en(h) \]
Proof
Just note that \[\begin{align} \len(h) &= \int_0^t |h'(s)| ds \\ &\le \sqrt{\int_0^t 1^2 ds } \sqrt{\int_0^t | h'(s)|^2 ds } \;\;\;\;\;\; \text{(by Cauchy-Schwarz)} \\ &= \sqrt{ t \; \en(h)} \end{align}\] Squaring both sides gives the result.
We are finally ready to show that, in a short amount of time \(t\), the parameters \(\gamma(t)\) cannot move very far away from the initialization.
Result 1.5. Let \(\gamma\) be the gradient flow of \(f\) starting at \(w_0\) as in Definition 1.1. Then: \[ |\gamma(t) - w_0| \le \sqrt{t\;\l(0)} \]
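A sketch that checks Result 1.5 along a discretized gradient flow, using small-step gradient descent. The loss is the hypothetical non-convex function \(f(w_1, w_2) = \frac 1 2 (w_1 w_2 - 1)^2\), whose infimum is \(0\). The sketch records the largest value of \(|w(t) - w_0| / \sqrt{t\,\l(0)}\) seen along the path, which the result says should not exceed \(1\).

```python
import math


def worst_displacement_ratio(w0=(0.5, 0.1), dt=1e-3, T=5.0):
    """Max over t of |w(t) - w0| / sqrt(t * l(0)) for f(w) = 0.5 * (w1 * w2 - 1)^2."""
    loss0 = 0.5 * (w0[0] * w0[1] - 1.0) ** 2
    w1, w2 = w0
    worst = 0.0
    for step in range(1, round(T / dt) + 1):
        r = w1 * w2 - 1.0
        # grad f = (r * w2, r * w1); both coordinates are updated simultaneously
        w1, w2 = w1 - dt * r * w2, w2 - dt * r * w1
        t = step * dt
        disp = math.hypot(w1 - w0[0], w2 - w0[1])
        worst = max(worst, disp / math.sqrt(t * loss0))
    return worst


print(worst_displacement_ratio())
```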
Proof
\[\begin{align} |\gamma(t) - \gamma(0)| &= \left |\int_0^t \gamma'(s) ds \right| \\ &\le \int_0^t \left | \gamma'(s) \right| ds &\quad &\text{(triangle inequality for integrals)} \\ &= \len(\gamma) \\ &\le \sqrt{t \; \en(\gamma)} &\quad &\text{(result 1.4)} \\ &\le \sqrt{t \; \l(0)} &\quad &\text{(result 1.3 and } \l(t) \ge 0 \text{)} \end{align}\]
2. A Replacement for Convexity
As described in the introduction, our proof uses the concept of learning health instead of convexity. We define it the following way: a loss function \(f : W \to \mathbb{R}^+\) has learning health if there exists \(\alpha \in \R^+\) such that for all \(w \in W\) with \(f(w) > 0\), \[ \frac{|\nabla f(w)|^2}{f(w)} \ge \alpha \] One implication of this definition is that no spurious local minima can exist in a healthy loss landscape. A spurious local minimum is a point \(w\) with zero gradient but positive loss, and the condition rules out such points: wherever \(\nabla f(w) = 0\), we must have \(f(w) = 0\), which is the global minimum.
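As a sanity check of the claim that learning health does not require convexity, the sketch below uses the function \(f(w) = w^2 + 3\sin^2 w\). This one-dimensional example is chosen here for illustration and is not from this article. Its second derivative \(2 + 6\cos 2w\) is negative at \(w = \pi/2\), so it is not convex. Yet the ratio \(|f'(w)|^2 / f(w)\), estimated on a grid over \([-6, 6]\) excluding the global minimum at \(w = 0\), stays bounded away from zero. A grid minimum is only an estimate of \(\alpha\) on that range, not a proof.

```python
import math


def f(w):
    return w * w + 3.0 * math.sin(w) ** 2


def df(w):
    # d/dw [3 sin^2 w] = 6 sin w cos w = 3 sin 2w
    return 2.0 * w + 3.0 * math.sin(2.0 * w)


def d2f(w):
    return 2.0 + 6.0 * math.cos(2.0 * w)


# w = 0 is the global minimum (f = 0), where the ratio is undefined.
grid = [i * 1e-3 for i in range(-6000, 6001) if i != 0]
health = min(df(w) ** 2 / f(w) for w in grid)
print("min |f'|^2 / f on grid:", health)
print("f'' at pi/2:", d2f(math.pi / 2))   # negative, so f is not convex
```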
But the following result tells us that learning health also implies something much stronger. On a healthy loss landscape, the loss of a gradient flow is guaranteed to decay to \(0\) exponentially quickly.
Result 2.1. If \(f: W\to \R^+\) satisfies \(\frac{|\nabla f(w)|^2}{f(w)} \ge \alpha\) for some \(\alpha \in \R^+\), then \[ \l(t) \le \l(0) \; e^{-\alpha t} \]
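A numerical illustration of Result 2.1, not a proof, reusing the illustrative non-convex function \(f(w) = w^2 + 3 \sin^2 w\). The sketch estimates \(\alpha\) as the grid minimum of \(|f'|^2/f\) on \((0, 6]\). It then approximates the gradient flow from \(w_0 = 4\) by gradient descent with step \(dt = 10^{-3}\), and records the largest ratio \(\l(t) / (\l(0) e^{-\alpha t})\) along the way. Result 2.1 predicts that this ratio stays at or below \(1\).

```python
import math


def f(w):
    return w * w + 3.0 * math.sin(w) ** 2


def df(w):
    return 2.0 * w + 3.0 * math.sin(2.0 * w)


# Grid estimate of the health constant alpha on (0, 6]; the path below stays in this range.
alpha = min(df(w) ** 2 / f(w) for w in (i * 1e-3 for i in range(1, 6001)))

dt, T, w = 1e-3, 5.0, 4.0
loss0 = f(w)
worst_ratio = 0.0   # largest l(t) / (l(0) * exp(-alpha * t)) seen so far
for step in range(1, round(T / dt) + 1):
    w -= dt * df(w)  # one small gradient-descent step approximating the flow
    t = step * dt
    worst_ratio = max(worst_ratio, f(w) / (loss0 * math.exp(-alpha * t)))

print(f"alpha ~ {alpha:.3f}, final loss {f(w):.3e}, worst ratio {worst_ratio:.3e}")
```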
Proof
\[\begin{align} \frac{\der \ln \l(t)}{\der t} &= \frac{\l'(t)}{\l(t)} = -\frac{|\nabla f \circ \gamma(t)|^2}{f\circ \gamma(t)} \le -\alpha \end{align}\] so \[ \ln \l(t) - \ln \l(0) = \int_0^t \frac{\der \ln \l(s)}{\der s} ds \le - \int_0^t \alpha ds = - \alpha t \] And then \(\ln \l(t) \le \ln \l(0) - \alpha t\). To conclude, use the fact that the exponential is a monotone function: \[ \l(t) = e^{\ln \l(t)} \le e^{\ln \l(0) - \alpha t} = \l(0) e^{ - \alpha t} \]
Unfortunately, neural networks don’t satisfy the learning health property for all \(w\in W\). (To see that, just consider an MLP with “degenerate” parameters, where all the weight matrices are \(0\). That MLP will have \(0\) gradient, even when it has high loss.) But that does not prevent us from using learning health to derive guarantees. In Section 3 we will see that a properly-initialized 2-layer MLP does indeed satisfy a relaxed version of this property: there exist \(\alpha,\beta \in \R^+\) such that for all \(w \in W\), \[ \frac{|\nabla f(w)|^2}{f(w)} \ge \alpha - \beta \; |w-w_0| \]
The main result of this section is a guarantee that can be applied to any learning problem satisfying this relaxed property.
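Result 2.2. If \(f\) and \(w_0\) satisfy \(\frac{|\nabla f(w)|^2}{f(w)} \ge \alpha - \beta \; |w-w_0|\) for all \(w\in W\), then \[ \l(\infty) \le \l(0) \; \exp({-\frac {\alpha^3}{3 \beta^2 \l(0)} }) \] where \(\l(\infty) = \lim_{t\to\infty} \l(t)\), which exists because \(\l\) is non-increasing and bounded below.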
Proof
The proof follows an argument very similar to that of Result 2.1. Writing \(w = \gamma(t)\),
\[\begin{align} \frac{\der \ln \l(t)}{\der t} &= -\frac{|\nabla f(w)|^2}{f(w)} \quad &\text{(result 1.2)} \\ &\le - \alpha + \beta \; |w - w_0| \\ &\le - \alpha + \beta \sqrt {\l(0) t} \quad &\text{(result 1.5)} \end{align}\]
Integrating both sides from \(0\) to \(t\) we get \[ \ln \l(t) - \ln \l(0) \le - \alpha t + \frac{2}{3} \beta \sqrt {\l(0)} \; t^{3/2} \]
which implies \[ \l(t) \le \l(0) \exp(- \alpha t + \frac{2}{3} \beta \sqrt{\l(0)} \; t^{3/2}) \]
Now, we want to find the value of \(t\) that minimizes the term in the exponential. Setting its derivative to \(0\) gives \(- \alpha + \beta \sqrt {\l(0)} \; t^{1/2}=0\), which is solved by \(t^* = \frac {\alpha^2} {\beta^2 \l(0)}\). At this point the exponent equals \[ -\alpha t^* + \frac{2}{3} \beta \sqrt{\l(0)} \; (t^*)^{3/2} = -\frac{\alpha^3}{\beta^2 \l(0)} + \frac{2 \alpha^3}{3 \beta^2 \l(0)} = -\frac{\alpha^3}{3 \beta^2 \l(0)} \] This implies that \[ \l(t^*) \le \l(0) \exp({-\frac{\alpha^3}{3\beta^2 \l(0)}}) \] And, since \(\l'(t) = -|\nabla f \circ \gamma(t)|^2 \le 0\), the loss is monotonically non-increasing, which implies that \(\l(\infty) \le \l(t^*) \le \l(0)\exp({-\frac{\alpha^3}{3\beta^2 \l(0)}})\), proving the result.
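A short check of the calculus in this proof, with hypothetical constants \(\alpha = 2\), \(\beta = 0.5\), \(\l(0) = 1.5\). The exponent \(-\alpha t + \frac 2 3 \beta \sqrt{\l(0)}\, t^{3/2}\) is minimized by brute force on a grid. The result is compared with \(t^* = \alpha^2 / (\beta^2 \l(0))\) and with the closed-form minimum \(-\alpha^3 / (3 \beta^2 \l(0))\).

```python
import math


def exponent(t, alpha, beta, loss0):
    """The exponent -alpha t + (2/3) beta sqrt(l(0)) t^{3/2} from the bound on l(t)."""
    return -alpha * t + (2.0 / 3.0) * beta * math.sqrt(loss0) * t ** 1.5


alpha, beta, loss0 = 2.0, 0.5, 1.5   # hypothetical constants
t_star = alpha ** 2 / (beta ** 2 * loss0)
closed_form = -alpha ** 3 / (3 * beta ** 2 * loss0)

# Brute-force minimum of the exponent on a fine grid over [0, 3 t*].
grid = [3 * t_star * i / 100_000 for i in range(100_001)]
t_best = min(grid, key=lambda t: exponent(t, alpha, beta, loss0))
print(t_star, t_best, exponent(t_star, alpha, beta, loss0), closed_form)
```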
3. The Simplest Neural Network
The objective of this section is straightforward: to take the absolute simplest neural network and show that it satisfies the conditions of Result 2.2. We’ll look at a 2-layer MLP with ReLU nonlinearity. It has two parameter matrices \(M\in \R^{k\times m}\) and \(N\in \R^{n\times k}\). The ReLU is denoted by \(\sigma\). Given an input \(x \in \R^m\), the MLP produces the output \[ y = N \sigma(M x) \] It is convenient to define intermediate variables by breaking down the computation into steps: \[ a = Mx, \quad b=\sigma(a), \quad y=Nb \]
Parameter space. The weight vectors are pairs of matrices \(w = (N, M)\), and so the parameter space is \[W= \R^{n\times k} \oplus \R^{k\times m}\]
The loss function. In keeping with the philosophy of studying the simplest example, we’ll take the loss function to be the L2 error on a single input-target pair \((x, q)\in \R^m \oplus \R^n\). \[ f(w) = \frac 1 2 |y - q|^2 \] We will need \(x\) to be normalized, so we just assume that \(|x|^2 = m\). It’s also useful to define the loss as a variable \(L = f(w)\).
Initialization. The standard way to initialize a neural network is to independently sample all the entries in the weight matrices from some distribution. In our case, we use \[ M_{ij} \sim \mathcal N (0, {\small \frac 2 m} ) \quad\quad N_{ij} \sim \mathcal N (0, {\small \frac 2 k}) \] This is the commonly-used He initialization, which works well for networks with ReLU nonlinearities. This initialization provides us with the guarantee that, with high probability, \[ |b|^2 \simeq k \]
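An empirical check of this claim, written as a plain-Python sketch. The function name and the sizes \(k = 4000\), \(m = 16\) are arbitrary choices. The sketch samples \(M\) with He initialization, applies it to an input rescaled so that \(|x|^2 = m\), and reports \(|a|^2/k\) and \(|b|^2/k\). The derivation below predicts values near \(2\) and \(1\) respectively, with fluctuations that shrink like \(1/\sqrt k\).

```python
import math
import random


def he_init_stats(k=4000, m=16, seed=0):
    """Sample M with entries N(0, 2/m), apply it to an input with |x|^2 = m,
    and return (|a|^2 / k, |b|^2 / k) for a = M x and b = ReLU(a)."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(m)]
    scale = math.sqrt(m / sum(v * v for v in x))  # enforce |x|^2 = m
    x = [v * scale for v in x]
    std = math.sqrt(2.0 / m)
    # Each a_i uses its own fresh row of M.
    a = [sum(rng.gauss(0.0, std) * xj for xj in x) for _ in range(k)]
    b = [max(ai, 0.0) for ai in a]
    return sum(ai * ai for ai in a) / k, sum(bi * bi for bi in b) / k


print(he_init_stats())
```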
Derivation of this property
Remember that we are sampling the initial weight matrices via \[ M_{ij} \sim \mathcal N (0, {\small \frac 2 m} ) \quad\quad N_{ij} \sim \mathcal N (0, {\small \frac 2 k}) \] Since \(M,N\) are random variables, so are \(a\) and \(b\). He initialization uses carefully-chosen variances to ensure that the entries of the initial activations \(a = M x\) and \(b = \sigma(a)\) are within some reasonable range. In particular, we want each entry \(b_i\) to have second moment \(\E[b_i^2] = 1\). Since the entries are independent (each \(b_i\) depends only on its own row of \(M\)), that will mean that \(|b|^2 \simeq k\), and so \(b\) will be close to the sphere of radius \(\sqrt k\) with high probability. (But this is only the case if the input vector is also normalized; that is why we needed the assumption that \(|x|^2 = m\).) We will now prove these statements.
First we want to understand \(\E[a_i^2]\): \[\begin{align} \E[a_i^2] &= \E[ ({\small \sum_j M_{ij} x_j})^2] \\ &= \E[ {\small \sum_j M_{ij}^2 x_j^2 + \sum_{l\neq j} M_{ij} x_j M_{il} x_l } ] \\ &= \sum_j \E[M_{ij}^2 x_j^2] \\ &= \frac 2 m |x|^2= 2 \end{align}\] where we used the independence of different entries of \(M\) and the fact that \(\E[M_{ij}]=0\). Then \[ \E[|a|^2] = \sum_i \E[a_i^2] = 2k \] Let \(p_M\) denote the probability density function (PDF) of the entries \(M_{ij}\) (all entries have the same PDF because they are independent and identically distributed). The PDF of \(a_i\) is the nested integral of \(\delta(z - M_i^T x)\) as \(M_{i1},\cdots,M_{im}\) range from \(-\infty\) to \(\infty\), where \(\delta\) denotes the Dirac delta function and \(M_i^T\) is the \(i\)-th row of \(M\): \[ p_{a}(z) = \int_{-\infty}^\infty \cdots \int_{-\infty}^\infty p_M(M_{i1}) \cdots p_M(M_{im}) \; \delta (z - M_i^T x)\; d M_{i1} \cdots d M_{im} \] Recall that a distribution is symmetric if \(p(z) = p(-z)\). Since \(p_M\) is a zero-mean Gaussian, it is symmetric. The next thing we need to prove is that \(p_a\) is symmetric too. \[\begin{align} p_{a}(z) &= \int_{-\infty}^\infty \cdots \int_{-\infty}^\infty p_M(M_{i1}) \cdots p_M(M_{im}) \; \delta (z - M_i^T x)\; d M_{i1} \cdots d M_{im} \\ &= \int_{-\infty}^\infty \cdots \int_{-\infty}^\infty p_M(-M_{i1}) \cdots p_M(-M_{im}) \; \delta (z + M_i^T x)\; d M_{i1} \cdots d M_{im} \\ &= \int_{-\infty}^\infty \cdots \int_{-\infty}^\infty p_M(M_{i1}) \cdots p_M(M_{im}) \; \delta (-z - M_i^T x)\; d M_{i1} \cdots d M_{im} \\ &= p_a(-z) \end{align}\] The first step uses the change of variable \(M_{ij}\to -M_{ij}\), and the second one exploits the symmetry of \(p_M\) and \(\delta\).
Finally, we compute the term we care about: \[\begin{align} \E[b_i^2] &= \int_{-\infty}^\infty p_a(a_i) \sigma(a_i)^2 da_i \\ &= \int_{-\infty}^0 p_a(a_i) \:0\: da_i + \int_0^\infty p_a(a_i) a_i^2 da_i \\ &= \frac 1 2 \int_{-\infty}^\infty p_a(a_i) a_i^2 da_i \quad \text{(using symmetry)}\\ &= \frac 1 2 \E[a_i^2] \\ &= 1 \end{align}\] and so \(\E[|b|^2] = \sum_i \E[b_i^2] = k\). We have only been working out \(\E[|b|^2] = k\), but we wanted to make claims about the particular samples themselves being approximately \(|b|^2\simeq k\). The rigorous way to go about doing so would be to prove statements like \(\Pr( k - \epsilon \le |b|^2 \le k+\epsilon) \ge 1 - \eta\) for small \(\epsilon\) and \(\eta\). But this type of argument is very tedious and adds little insight about the ideas this document is exploring, so instead, in the rest of this article we just assume that \(|b|^2 = k\). When the dimension \(k\) is large, this approximation will be very accurate.
Gradients. The gradients of the loss \(\l\) wrt all the intermediate variables and weights are: \[ \ht y = y -q, \quad \ht b = N^T \ht y, \quad \ht a =\der \sigma(a)^T \ht b, \quad \ht M = \ht a x^T, \quad \ht N = \ht y b^T \]
In this notation, a hat on top of a variable denotes the gradient of the loss function wrt that variable.
Derivation of the gradients
Note that \(|\hat y |^2 = 2 \l\), so when the loss is large, the gradient of \(\l\) wrt \(y\) is large too. Ultimately we are trying to show something similar about the gradients of \(\l\) wrt the weights (in order to then apply Result 2.2).
Below, we’ve written a short derivation of all these gradient formulas, but using some unconventional notation and techniques. If you find them confusing, just work out the gradients in your own way and confirm you get the same answers.
Starting with \(\ht y\), the gradient of \(\l\) wrt \(y\). Let \(\dot y \in \R^n\) denote an arbitrary change to \(y\). Then \[\begin{align} <\ht y, \dot y> =\frac {\der \l} {\der y} (\dot y) &\simeq \frac 1 2 |y +\dot y - q|^2 - \frac 1 2 |y - q|^2 \\ &= \frac 1 2 ( <\dot y, y-q> + <\dot y, \dot y> + <y-q, \dot y> ) \\ &\simeq <y-q, \dot y> \quad \text{(dropping the higher-order term)} \end{align}\] Now, let’s see how a change in \(b\) affects \(y\). Like before, let \(\dot b\) denote a change to \(b\). The derivative \(\frac {\der y}{\der b}(\dot b) \simeq N(b+ \dot b) - N b = N \dot b\). So, the gradient of \(\l\) wrt \(b\) satisfies \[ <\ht b, \dot b> = {\frac {\der \l}{\der b}(\dot b)} = {\frac {\der \l}{\der y}} \left( {\frac {\der y}{\der b}}(\dot b) \right) = <\ht y, {\frac {\der y}{\der b}} (\dot b)> = <\hat y, N \dot b> = <N^T \ht y, \dot b> \] Now, let’s look at \(N\). The derivative \(\frac {\der y}{\der N}(\dot N) \simeq (N+\dot N) b - N b = \dot N b\). And the gradient \[ <\ht N, \dot N> = <\ht y, \dot N b> = <\ht y b^T, \dot N> \] Recall that, since we are taking the inner product of two matrices, we are using the trace under the hood. To see why the last step is true, just use the cyclic property of the trace. Finally, the gradients of \(a\) and \(M\): \[\begin{gather} <\hat a, \dot a> = <\hat b, \der \sigma(a)(\dot a)> = <\der \sigma(a)^T \ht b, \dot a> \\ <\ht M, \dot M> = <\ht a, \dot M x> = <\ht a x^T, \dot M> \end{gather}\]
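The formulas can also be checked numerically. This plain-Python sketch, whose helper names are our own, compares \(\ht N = \ht y b^T\) and \(\ht M = \ht a x^T\) with central finite differences of the loss for a small randomly initialized MLP. The loss is piecewise quadratic in each single weight. As a result, the finite differences are essentially exact unless a perturbation moves some \(a_i\) across the ReLU kink at \(0\).

```python
import math
import random


def forward(N, M, x):
    a = [sum(Mij * xj for Mij, xj in zip(Mi, x)) for Mi in M]    # a = M x
    b = [max(ai, 0.0) for ai in a]                                # b = sigma(a)
    y = [sum(Nij * bj for Nij, bj in zip(Ni, b)) for Ni in N]     # y = N b
    return a, b, y


def mlp_loss(N, M, x, q):
    _, _, y = forward(N, M, x)
    return 0.5 * sum((yi - qi) ** 2 for yi, qi in zip(y, q))


def analytic_grads(N, M, x, q):
    a, b, y = forward(N, M, x)
    y_hat = [yi - qi for yi, qi in zip(y, q)]                                        # y - q
    b_hat = [sum(N[i][j] * y_hat[i] for i in range(len(N))) for j in range(len(b))]  # N^T y_hat
    a_hat = [bh if ai > 0 else 0.0 for bh, ai in zip(b_hat, a)]                       # dsigma(a)^T b_hat
    N_hat = [[yh * bj for bj in b] for yh in y_hat]                                   # y_hat b^T
    M_hat = [[ah * xj for xj in x] for ah in a_hat]                                   # a_hat x^T
    return N_hat, M_hat


def max_fd_error(n=3, k=5, m=4, h=1e-6, seed=1):
    rng = random.Random(seed)
    x = [rng.gauss(0, 1) for _ in range(m)]
    q = [rng.gauss(0, 1) for _ in range(n)]
    M = [[rng.gauss(0, math.sqrt(2 / m)) for _ in range(m)] for _ in range(k)]
    N = [[rng.gauss(0, math.sqrt(2 / k)) for _ in range(k)] for _ in range(n)]
    N_hat, M_hat = analytic_grads(N, M, x, q)
    worst = 0.0
    for W, W_hat in ((N, N_hat), (M, M_hat)):
        for i in range(len(W)):
            for j in range(len(W[0])):
                old = W[i][j]                 # perturb one weight in place, then restore it
                W[i][j] = old + h
                up = mlp_loss(N, M, x, q)
                W[i][j] = old - h
                down = mlp_loss(N, M, x, q)
                W[i][j] = old
                worst = max(worst, abs((up - down) / (2 * h) - W_hat[i][j]))
    return worst


print(max_fd_error())
```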
Gradient health. The gradient of the loss wrt the matrix \(N\) satisfies \[ |\ht{N} |^2 \ge 2 L ( k - 2 \sqrt {km} |w_0 - w| ) \]
Activation and gradient bounds
First, let’s review a basic fact. Given a matrix \(A\in \R^{n\times m}\) and \(x\in \R^m\), we have that \(|Ax| \le |A| \; |x|\). This follows from the SVD decomposition \(A = R \Lambda S^T\), where \(R,S\) are orthogonal matrices and \(\Lambda\) is a diagonal matrix with the singular values \(\lambda_1,\cdots, \lambda_n \ge 0\) on its diagonal. First, note that \[ |A^T A |^2 = | S \Lambda^2 S^T|^2 = | \Lambda^2 |^2 = \sum_i \lambda_i^4 \le \left( \sum_i \lambda_i^2 \right)^2 = | \Lambda |^4 = | R \Lambda S^T|^4 = | A |^4 \] So \[\begin{align} |A x|^2 &= \tr(x^T A^T A x) = \tr(A^T A x x^T) = <A^T A, x x^T> \\ &\le |A^T A| |x x^T| = |A^T A| |x|^2 \quad \text{(Cauchy-Schwarz)} \\ &\le |A|^2 |x|^2 \end{align}\]
Now, recall that \(w_0 = (N_0, M_0)\), and let \(a_0=M_0 x\), \(b_0=\sigma(a_0)\) and \(y_0=N_0 b_0\) be the activations at initialization. We want to derive upper bounds on \(|a - a_0|\) and \(|b - b_0|\) based on \(|w-w_0|\). First, \[\begin{align*} |a_0 - a| &= |(M_0 - M) x| \\ & \le |M_0 - M| \; |x| \quad \text{(using the fact we just proved)} \\ & \le |w-w_0| \; |x| \\ & = \sqrt m \; |w-w_0| \end{align*}\] where the second-to-last step used \(|w - w_0|^2 = |N - N_0|^2 + |M - M_0|^2\), and the last step used the assumption that \(|x| = \sqrt m\) (the inputs are normalized). To bound \(|b_0 - b|\), we use the fact that the ReLU \(\sigma\) is 1-Lipschitz. So \[ |b_0 - b| = |\sigma(a) - \sigma(a_0) | \le |a_0 - a | \le \sqrt m \; |w-w_0| \]
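A randomized sanity check of the two inequalities used above, written as a plain-Python sketch with arbitrary sizes. For random \(M_0\), random perturbations \(M - M_0\), and inputs with \(|x| = \sqrt m\), it reports the largest observed values of \(|a - a_0| / (|M - M_0|\,|x|)\) and \(|b - b_0| / |a - a_0|\). Both should be at most \(1\). Only \(M\) is perturbed here, so \(|M - M_0| = |w - w_0|\).

```python
import math
import random


def frob(A):
    """L2 (Frobenius) norm of a matrix stored as a list of rows."""
    return math.sqrt(sum(v * v for row in A for v in row))


def matvec(A, x):
    return [sum(a * xj for a, xj in zip(row, x)) for row in A]


def worst_ratios(trials=200, k=8, m=5, seed=0):
    rng = random.Random(seed)
    worst_a = worst_b = 0.0
    for _ in range(trials):
        x = [rng.gauss(0.0, 1.0) for _ in range(m)]
        scale = math.sqrt(m / sum(v * v for v in x))
        x = [v * scale for v in x]                        # |x| = sqrt(m)
        M0 = [[rng.gauss(0.0, math.sqrt(2.0 / m)) for _ in range(m)] for _ in range(k)]
        dM = [[rng.gauss(0.0, 0.3) for _ in range(m)] for _ in range(k)]
        M = [[u + d for u, d in zip(r0, rd)] for r0, rd in zip(M0, dM)]
        a0, a = matvec(M0, x), matvec(M, x)
        b0 = [max(v, 0.0) for v in a0]
        b = [max(v, 0.0) for v in a]
        da = math.dist(a, a0)                             # |a - a0|
        db = math.dist(b, b0)                             # |b - b0|
        norm_x = math.sqrt(sum(v * v for v in x))
        worst_a = max(worst_a, da / (frob(dM) * norm_x))  # |Ax| <= |A||x| with A = M - M0
        worst_b = max(worst_b, db / da)                   # ReLU is 1-Lipschitz
    return worst_a, worst_b


print(worst_ratios())
```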
Recall that our weight initialization guarantees that \(|b_0| \simeq \sqrt k\) (for simplicity we assume exact equality). Now we can conclude with: \[\begin{align} | \ht{N} |^2 &= |\ht y b^T|^2 = \text{trace}(b \ht y^T \ht y b^T) = |\ht{y}|^2 |b|^2 = 2 L |b|^2 \\ &\ge 2 L (|b_0| - |b_0 - b|)^2 = 2 L ( \sqrt k - |b_0 - b| )^2 \\ &= 2 L ( k - 2 \sqrt k |b_0 - b| + |b_0 - b| ^2 ) \\ &\ge 2 L ( k - 2 \sqrt k |b_0 - b| ) \\ &\ge 2 L ( k - 2 \sqrt {km} |w_0 - w| ) \end{align}\] The second line squares the triangle inequality \(|b| \ge |b_0| - |b_0 - b|\), which is valid when \(|b_0 - b| \le \sqrt k\). If instead \(|b_0 - b| > \sqrt k\), the right-hand side of the fourth line is negative and the final bound holds trivially.
We could also attempt to derive a lower bound for \(|\ht M|^2\), but it is not really necessary to do so. We already have enough to apply Result 2.2.
Learning Guarantees. Since \(|\nabla f(w)|^2 = |\ht N|^2 + |\ht M|^2 \ge |\ht N|^2\), the previous result implies that \[ \frac{|\nabla f(w)|^2}{f(w)} \ge 2k - 4 \sqrt {km} |w_0 - w| \] So by setting \(\alpha = 2k\) and \(\beta = 4 \sqrt {km}\), the application of Result 2.2 gives \[ \l(\infty) \le \l(0) \; \text{exp}({-\frac {k^2}{6 m \l(0)} }) \] From looking at the above equation, it is apparent that the scale of the MLP helps it learn. The initial loss \(\l(0)\) does not grow with \(k\): given \(b\), each entry of \(y = Nb\) at initialization has variance \(\frac 2 k |b|^2 = 2\). So by growing \(k\), we can very quickly guarantee that the loss is decreased to any desired value.
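A plain-Python sketch of this guarantee for a single output (\(n = 1\)); the function name, sizes, learning rate and seed are our own choices. It trains the 2-layer MLP with small-step gradient descent as a stand-in for gradient flow and compares the final loss with the bound. There are two deliberate deviations from the text. First, the bound is evaluated with the measured \(|b_0|^2\) in place of \(k\). The same derivation with \(\alpha = 2|b_0|^2\) and \(\beta = 4|b_0|\sqrt m\) gives \(\l(0)\exp(-|b_0|^4 / (6 m \l(0)))\). Second, the target is set to the hypothetical value \(q = y_0 + 3\), so that \(\l(0) = 4.5\) regardless of the random draw.

```python
import math
import random


def train_and_bound(k=16, m=10, lr=1e-3, steps=3000, seed=0):
    """Gradient descent on y = N relu(M x) with one output, L2 loss, one datapoint.

    Returns (initial loss, final loss, Result 2.2 bound using the measured |b_0|^2).
    """
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(m)]
    scale = math.sqrt(m / sum(v * v for v in x))
    x = [v * scale for v in x]                                   # |x|^2 = m
    M = [[rng.gauss(0.0, math.sqrt(2.0 / m)) for _ in range(m)] for _ in range(k)]
    N = [rng.gauss(0.0, math.sqrt(2.0 / k)) for _ in range(k)]   # the single row of N

    def forward(M, N):
        a = [sum(Mij * xj for Mij, xj in zip(Mi, x)) for Mi in M]
        b = [max(ai, 0.0) for ai in a]
        return a, b, sum(Nj * bj for Nj, bj in zip(N, b))

    _, b0, y0 = forward(M, N)
    q = y0 + 3.0                     # hypothetical target: l(0) = 0.5 * 3^2 = 4.5
    loss0 = 0.5 * (y0 - q) ** 2
    b0_sq = sum(v * v for v in b0)
    bound = loss0 * math.exp(-b0_sq ** 2 / (6.0 * m * loss0))

    for _ in range(steps):
        a, b, y = forward(M, N)
        y_hat = y - q                                                     # gradient wrt y
        a_hat = [Nj * y_hat if ai > 0 else 0.0 for Nj, ai in zip(N, a)]  # dsigma(a)^T N^T y_hat
        N = [Nj - lr * y_hat * bj for Nj, bj in zip(N, b)]                # N_hat = y_hat b^T
        M = [[Mij - lr * ah * xj for Mij, xj in zip(Mi, x)]               # M_hat = a_hat x^T
             for Mi, ah in zip(M, a_hat)]

    _, _, y = forward(M, N)
    return loss0, 0.5 * (y - q) ** 2, bound


loss0, final_loss, bound = train_and_bound()
print(f"l(0) = {loss0:.3f}, final loss = {final_loss:.3e}, bound = {bound:.3e}")
```

The sizes are kept small on purpose. For large \(k\), the bound quickly drops below anything that double-precision arithmetic can resolve.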
Below is an interactive visualization showing how this simple 2-layer MLP learns. The red line shows the bound we’ve just derived.
Acknowledgments
Huge thanks to Nahr for constructing the live training visualizations of the post and to Sean Zhang for bringing relevant literature to our attention.
References
Citation
@misc{gelada2024,
author = {Gelada, Carles and Buckman, Jacob},
publisher = {Manifest AI},
title = {Why {Gradient} {Descent} {Minimizes} {Training} {Loss}},
date = {2024-09-23},
langid = {en}
}