AdaGrad

From Cornell University Computational Optimization Open Textbook - Optimization Wiki

Author: Daniel Villarraga (SYSEN 6800 Fall 2021)

Introduction

AdaGrad is a family of sub-gradient algorithms for stochastic optimization. The algorithms in this family resemble second-order stochastic gradient descent in which the Hessian of the optimized function is replaced by an approximation built from past gradients. The name AdaGrad is short for Adaptive Gradient. Intuitively, AdaGrad adapts the learning rate of each feature to the estimated geometry of the problem; in particular, it tends to assign higher learning rates to infrequent features, so that parameter updates depend less on how often a feature appears and more on how informative it is.

AdaGrad was introduced by Duchi et al.[1] in a highly cited paper published in the Journal of Machine Learning Research in 2011. It is arguably one of the most popular algorithms in machine learning (particularly for training deep neural networks), and it influenced the development of the Adam algorithm[2].

Theory

The objective of AdaGrad is to minimize the expected value of a stochastic objective function with respect to a set of parameters, given a sequence of realizations of the function. As with other sub-gradient-based methods, it does so by updating the parameters in the direction opposite to the sub-gradients. While standard sub-gradient methods use step-sizes that ignore information from past observations, AdaGrad adapts the learning rate of each parameter individually using the history of gradient estimates.
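To make the per-parameter adaptation concrete, the sketch below evaluates the diagonal AdaGrad step-size $\eta\,(\epsilon + \sum_{\tau} (g_\tau^{(j)})^2)^{-1/2}$ (defined in the AdaGrad Update subsection) for a hypothetical gradient history in which one feature is active at every step and another only occasionally. The function name `effective_step_sizes` and the values `eta=0.1`, `eps=1e-8` are illustrative assumptions, not taken from the source.

```python
import numpy as np

def effective_step_sizes(grads, eta=0.1, eps=1e-8):
    """Per-coordinate AdaGrad step-sizes eta / sqrt(eps + sum_tau g_tau^2).

    grads: array of shape (t, m) holding one (sub)gradient per row.
    (Hypothetical helper written for illustration.)
    """
    grads = np.asarray(grads, dtype=float)
    sum_sq = np.sum(grads ** 2, axis=0)  # per-coordinate sum of squared gradients
    return eta / np.sqrt(eps + sum_sq)

# Hypothetical history of 10 gradients in 2 dimensions:
# feature 0 is active at every step, feature 1 only at steps 3 and 7.
grads = np.zeros((10, 2))
grads[:, 0] = 1.0
grads[[3, 7], 1] = 1.0

steps = effective_step_sizes(grads, eta=0.1)
print(steps)  # about 0.0316 for feature 0 and 0.0707 for feature 1
```

A standard sub-gradient method would use $\eta = 0.1$ for both coordinates; here the rarely active feature keeps a step roughly $\sqrt{5}$ times larger than the frequently active one.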

Definitions

$f(x)$: Stochastic objective function with parameters $x$.

$f_t(x)$: Realization of the stochastic objective at time step $t$. For brevity, also written $f_t$.

$g_t(x)$: The gradient (or a subgradient, when $f_t$ is not differentiable) of $f_t(x)$ with respect to $x$, formally $\nabla_x f_t(x)$. For brevity, also written $g_t$.

$x_t$: Parameters at time step $t$.

$G_t$: The sum of the outer products of all subgradients observed up to and including time step $t$, given by $G_t = \sum_{\tau=1}^{t} g_{\tau} g_{\tau}^{\top}$.
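As an illustration, the short NumPy sketch below builds $G_t$ from three hypothetical 2-D subgradients and checks that its diagonal holds the per-coordinate sums of squared gradients, which is all the diagonal variant of AdaGrad needs. The function name `accumulate_outer_products` is an illustrative choice, not part of any library.

```python
import numpy as np

def accumulate_outer_products(grads):
    """Return G_t = sum_{tau=1}^t g_tau g_tau^T for the rows of `grads` (shape (t, m))."""
    grads = np.asarray(grads, dtype=float)
    m = grads.shape[1]
    G = np.zeros((m, m))
    for g in grads:
        G += np.outer(g, g)  # rank-one update g_tau g_tau^T
    return G

# Three hypothetical 2-D subgradients g_1, g_2, g_3 (one per row).
grads = np.array([[1.0, 0.0],
                  [2.0, -1.0],
                  [0.0, 3.0]])
G = accumulate_outer_products(grads)
print(G)                               # [[ 5. -2.] [-2. 10.]]
print(np.diag(G))                      # [ 5. 10.], the sums of squares per coordinate
```

The loop is equivalent to the vectorized product `grads.T @ grads`. Storing the full $G_t$ costs $O(m^2)$ memory, whereas its diagonal needs only $O(m)$.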

Standard Sub-gradient Update

Standard sub-gradient algorithms update the parameters $x$ according to the following rule:

$$x_{t+1} = x_t - \eta g_t$$

where $\eta$ denotes the step-size, often referred to as the learning rate. Writing $x_t^{(j)}$ and $g_t^{(j)}$ for the $j$-th of the $m$ components of $x_t$ and $g_t$, the vector of parameters is updated as follows:

$$\begin{bmatrix} x_{t+1}^{(1)} \\ x_{t+1}^{(2)} \\ \vdots \\ x_{t+1}^{(m)} \end{bmatrix} = \begin{bmatrix} x_{t}^{(1)} \\ x_{t}^{(2)} \\ \vdots \\ x_{t}^{(m)} \end{bmatrix} - \eta \begin{bmatrix} g_{t}^{(1)} \\ g_{t}^{(2)} \\ \vdots \\ g_{t}^{(m)} \end{bmatrix}$$
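A minimal sketch of this fixed step-size update, assuming NumPy and a hypothetical stream of realizations $f_t(x) = \lVert x - z_t \rVert_1$ with $z_t$ drawn from a normal distribution centered at $[1, -2]$ (an example chosen for illustration, not taken from the source). Because $f_t$ is not differentiable where $x^{(j)} = z_t^{(j)}$, `np.sign(x - z)` is used as a subgradient. The same $\eta$ multiplies every coordinate.

```python
import numpy as np

def subgradient_step(x, g, eta):
    """Standard update x_{t+1} = x_t - eta * g_t, with one eta for all coordinates."""
    return x - eta * g

def l1_subgradient(x, z):
    """A subgradient of f_t(x) = ||x - z||_1; np.sign gives 0 where x == z, a valid choice."""
    return np.sign(x - z)

# Hypothetical realizations z_t ~ N([1, -2], I); E||x - z||_1 is minimized at the
# coordinate-wise median of z, which is [1, -2] here.
rng = np.random.default_rng(0)
x = np.zeros(2)
eta = 0.05
for t in range(2000):
    z_t = rng.normal(loc=[1.0, -2.0], scale=1.0)
    x = subgradient_step(x, l1_subgradient(x, z_t), eta)
print(x)  # fluctuates around [1, -2]
```

With a fixed $\eta$ the iterates keep fluctuating around the minimizer instead of settling exactly; decreasing step-sizes are the usual remedy for plain sub-gradient methods.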

AdaGrad Update

The general AdaGrad update rule is given by:

$$x_{t+1} = x_t - \eta G_t^{-1/2} g_t$$

where $G_t^{-1/2}$ is the inverse of the matrix square root of $G_t$. Since $G_t$ is an $m \times m$ matrix, where $m$ is the number of parameters, computing $G_t^{-1/2}$ generally costs $O(m^3)$ operations per step. A simplified version of the update rule therefore uses only the diagonal elements of $G_t$ instead of the whole matrix:

$$x_{t+1} = x_t - \eta \, \mathrm{diag}(G_t)^{-1/2} g_t$$

which can be computed in time linear in the number of parameters, since it only involves elementwise operations. In practice, a small quantity $\epsilon > 0$ is added to each diagonal element of $G_t$ to avoid division by zero (for instance, when the gradient of some coordinate has been zero at every step so far). The resulting update rule is given by:

$$x_{t+1} = x_t - \eta \, \mathrm{diag}(\epsilon I + G_t)^{-1/2} g_t$$

where $I$ denotes the identity matrix. An expanded form of the previous update is presented below:

$$\begin{bmatrix} x_{t+1}^{(1)} \\ x_{t+1}^{(2)} \\ \vdots \\ x_{t+1}^{(m)} \end{bmatrix} = \begin{bmatrix} x_{t}^{(1)} \\ x_{t}^{(2)} \\ \vdots \\ x_{t}^{(m)} \end{bmatrix} - \begin{bmatrix} \eta \frac{1}{\sqrt{\epsilon + G_t^{(1,1)}}} \\ \eta \frac{1}{\sqrt{\epsilon + G_t^{(2,2)}}} \\ \vdots \\ \eta \frac{1}{\sqrt{\epsilon + G_t^{(m,m)}}} \end{bmatrix} \odot \begin{bmatrix} g_{t}^{(1)} \\ g_{t}^{(2)} \\ \vdots \\ g_{t}^{(m)} \end{bmatrix}$$

where the operator $\odot$ denotes the Hadamard (elementwise) product between vectors of the same dimension, and $G_t^{(j,j)} = \sum_{\tau=1}^{t} \big(g_{\tau}^{(j)}\big)^2$ is the $j$-th diagonal element of $G_t$.

From the last expression, the AdaGrad update rule adapts the step-size of each parameter $j$ according to $\eta \, (\epsilon + G_t^{(j,j)})^{-1/2}$, while standard sub-gradient methods use the same fixed step-size $\eta$ for every parameter. Coordinates that have accumulated large gradients therefore take smaller steps, while coordinates whose gradients are rarely nonzero keep larger steps.
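The full-matrix and diagonal update rules can be sketched in Python with NumPy as follows. This is an illustrative sketch under stated assumptions, not the reference implementation of Duchi et al.: the function names and the default values of `eta` and `eps` are hypothetical, and the full-matrix version adds $\epsilon I$ to $G_t$ (mirroring the diagonal rule) because $G_t$ is singular until enough linearly independent subgradients have been observed. Both functions fold the current subgradient $g_t$ into the accumulator before stepping, matching the definition of $G_t$.

```python
import numpy as np

def adagrad_diag_step(x, g, sum_sq, eta=0.1, eps=1e-8):
    """Diagonal AdaGrad: x_{t+1} = x_t - eta * diag(eps*I + G_t)^{-1/2} g_t.

    sum_sq holds diag(G_{t-1}) and is updated with g_t first, so that it equals
    diag(G_t) when the step is taken. Every operation is elementwise: O(m) per step.
    """
    sum_sq = sum_sq + g * g                      # diag(G_t) = diag(G_{t-1}) + g_t**2
    x_new = x - eta * g / np.sqrt(eps + sum_sq)  # Hadamard product with the step-sizes
    return x_new, sum_sq

def adagrad_full_step(x, g, G, eta=0.1, eps=1e-8):
    """Full-matrix AdaGrad: x_{t+1} = x_t - eta * (eps*I + G_t)^{-1/2} g_t.

    Adding eps*I is an assumption mirroring the diagonal rule; without it G_t is
    singular for early t. The eigendecomposition costs O(m^3) per step.
    """
    G = G + np.outer(g, g)                             # G_t = G_{t-1} + g_t g_t^T
    w, V = np.linalg.eigh(eps * np.eye(len(x)) + G)    # symmetric, positive definite
    inv_sqrt = V @ np.diag(1.0 / np.sqrt(w)) @ V.T     # (eps*I + G_t)^{-1/2}
    return x - eta * inv_sqrt @ g, G

# Usage on a hypothetical badly scaled quadratic f(x) = 0.5 * (100 * x_0**2 + x_1**2).
def grad(x):
    return np.array([100.0 * x[0], x[1]])

x, sum_sq = np.array([1.0, 1.0]), np.zeros(2)
for _ in range(500):
    x, sum_sq = adagrad_diag_step(x, grad(x), sum_sq, eta=0.5)
print(x)  # both coordinates shrink towards the minimizer [0, 0]
```

On this quadratic, the diagonal update is (up to $\epsilon$) invariant to rescaling a coordinate's gradient by a constant, so the factor of 100 between the two curvatures does not change how fast each coordinate shrinks. By contrast, plain gradient descent with a fixed step-size does not converge on the same function unless $\eta < 0.02$.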

Algorithm

Regret Bounds

Empirical Performance

Numerical Example

Applications

Summary and Discussion

References

  1. Duchi, J., Hazan, E., & Singer, Y. (2011). Adaptive subgradient methods for online learning and stochastic optimization. Journal of Machine Learning Research, 12(7).
  2. Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.