Adafactor: Difference between revisions

From Cornell University Computational Optimization Open Textbook - Optimization Wiki
Jump to navigation Jump to search
Line 4: Line 4:


== Introduction ==
== Introduction ==
Adafactor is an efficient, adaptive learning rate optimization algorithm proposed by Noam Shazeer and Mitchell Stern from Google Research in 2018. <sup>1</sup>
Unlike traditional Adam optimizers, Adafactor does not store complete second-order moment matrices. Instead, it employs a factorization approach that only maintains gradient statistics for the rows and columns of parameter matrices, significantly reducing memory usage. Moreover, Adafactor uses an adaptive learning rate, allowing it to dynamically adjust step sizes without the need for manually setting a global learning rate or relying heavily on hyperparameter tuning. Its design also defaults to not performing bias correction, yet it remains stable in scenarios involving large-batch training data.[1] This efficiency makes it an ideal choice for training ultra-large-scale models such as T5.<sup>2</sup>
Adafactor’s efficient memory usage and outstanding performance make it widely applicable in scenarios such as Natural Language Processing (NLP). Compared to the Adam optimizer, Adafactor significantly reduces memory and computational resource requirements while maintaining comparable performance when training large-scale language models and vision models. <sup>3,6</sup>
== Problem formulation ==
== Problem formulation ==
=== 1. Objective ===
=== 1. Objective ===

Revision as of 20:49, 11 December 2024

Author: Aolei Cao (ac3237), Ziyang Li (zl986), Junjia Liang (jl4439) (ChemE 6800 Fall 2024)

Stewards: Nathan Preuss, Wei-Han Chen, Tianqi Xiao, Guoqing Hu

Introduction

Adafactor is an efficient, adaptive learning rate optimization algorithm proposed by Noam Shazeer and Mitchell Stern from Google Research in 2018. 1

Unlike traditional Adam optimizers, Adafactor does not store complete second-order moment matrices. Instead, it employs a factorization approach that only maintains gradient statistics for the rows and columns of parameter matrices, significantly reducing memory usage. Moreover, Adafactor uses an adaptive learning rate, allowing it to dynamically adjust step sizes without the need for manually setting a global learning rate or relying heavily on hyperparameter tuning. Its design also defaults to not performing bias correction, yet it remains stable in scenarios involving large-batch training data.[1] This efficiency makes it an ideal choice for training ultra-large-scale models such as T5.2

Adafactor’s efficient memory usage and outstanding performance make it widely applicable in scenarios such as Natural Language Processing (NLP). Compared to the Adam optimizer, Adafactor significantly reduces memory and computational resource requirements while maintaining comparable performance when training large-scale language models and vision models. 3,6

Problem formulation

1. Objective

Minimize the loss function $ f(x) $, where $ x \in R^n $ and $ x $ is the weight vector to be optimized.

2. Parameters

  • Gradient:

$ G_t = \nabla f(x_{t-1}) $

  • Second moment estimate:

$ \hat{V}_t = \hat{\beta}_{2t} \hat{V}_{t-1} + (1 - \hat{\beta}_{2t})(G_t^2 + \epsilon_1 1_n) $

  • Where:
    • $ \hat{V}_t $ is the running average of the squared gradient.
    • $ \hat{\beta}_{2t} $ is the corrected decay parameter.
    • $ \epsilon_1 $ is a regularization constant.
  • Step size:

$ \alpha_t = \max(\epsilon_2, \text{RMS}(x_{t-1})) \rho_t $

  • Where:
    • $ \rho_t $ is the relative step size.
    • $ \epsilon_2 $ is a regularization constant.
    • $ \text{RMS} $ is the root mean square, defined as:
      • $ u_{xt} = \frac{-g_{xt}}{\sqrt{\hat{v}_{xt}}} $
      • $ \text{RMS}(U_t) = \text{RMS}_{x \in X}(u_{xt}) = \sqrt{\text{Mean}_{x \in X}\left(\frac{(g_{xt})^2}{\hat{v}_{xt}}\right)} $

3. Algorithms

Adafactor for Weighted Vectors

Inputs:

  • Initial point: $ X_0 \in \mathbb{R}^n $
  • Relative step sizes: $ \rho_t $ for $ t = 1 $ to $ T $
  • Second moment decay: $ \hat{\beta}_{2t} $ for $ t = 1 $ to $ T $, with $ \hat{\beta}_{21} = 0 $
  • Regularization constants: $ \epsilon_1, \epsilon_2 $
  • Clipping threshold: $ d $

Algorithm:

  • For $ t = 1 $ to $ T $:
    • Compute adaptive step size: $ \alpha_t = \max(\epsilon_2, \text{RMS}(X_{t-1})) \rho_t $
    • Compute gradient: $ G_t = \nabla f_t(X_{t-1}) $
    • Update second moment estimate: $ \hat{V}_t = \hat{\beta}_{2t} \hat{V}_{t-1} + (1 - \hat{\beta}_{2t})(G_t^2 + \epsilon_1 1_n) $
    • Compute normalized gradient: $ U_t = \frac{G_t}{\sqrt{\hat{V}_t}} $
    • Apply clipping: $ \hat{U}_t = \frac{U_t}{\max(1, \text{RMS}(U_t) / d)} $
    • Update parameter: $ X_t = X_{t-1} - \alpha_t \hat{U}_t $
  • End for

Adafactor for Weighted Matrices

Inputs:

  • Initial point: $ X_0 \in \mathbb{R}^{n \times m} $
  • Relative step sizes: $ \rho_t $ for $ t = 1 $ to $ T $
  • Second moment decay: $ \hat{\beta}_{2t} $ for $ t = 1 $ to $ T $, with $ \hat{\beta}_{21} = 0 $
  • Regularization constants: $ \epsilon_1, \epsilon_2 $
  • Clipping threshold: $ d $

Algorithm:

  • For $ t = 1 $ to $ T $:
    • Compute adaptive step size: $ \alpha_t = \max(\epsilon_2, \text{RMS}(X_{t-1})) \rho_t $
    • Compute gradient: $ G_t = \nabla f_t(X_{t-1}) $
    • Update row-wise second moment: $ R_t = \hat{\beta}_{2t} R_{t-1} + (1 - \hat{\beta}_{2t})(G_t^2 + \epsilon_1 1_n 1_m^T) 1_m $
    • Update column-wise second moment: $ C_t = \hat{\beta}_{2t} C_{t-1} + (1 - \hat{\beta}_{2t}) 1_n^T (G_t^2 + \epsilon_1 1_n 1_m^T) $
    • Update overall second moment estimate: $ \hat{V}_t = \frac{R_t C_t}{1_n^T R_t} $
    • Compute normalized gradient: $ U_t = \frac{G_t}{\sqrt{\hat{V}_t}} $
    • Apply clipping: $ \hat{U}_t = \frac{U_t}{\max(1, \text{RMS}(U_t) / d)} $
    • Update parameter: $ X_t = X_{t-1} - \alpha_t \hat{U}_t $
  • End for

4. Proposed Hyperparameters for Adafactor

  • Regularization constant 1: $ \epsilon_1 = 10^{-30} $
  • Regularization constant 2: $ \epsilon_2 = 10^{-3} $
  • Clipping threshold: $ d = 1 $
  • Relative step size: $ \rho_t = \min(10^{-2}, 1/\sqrt{t}) $
  • Second moment decay: $ \hat{\beta}_{2t} = 1 - t^{-0.8} $

Numerical Examples

Step-by-step instructions for determining the result of the first iteration.

Problem setup

Initial weights ($ X_0 $):

$ X_0 = \begin{bmatrix} 0.7 &-0.5& 0.9\\ -1.1 & 0.8& -1.6\\1.2&-0.7& 0.4 \end{bmatrix} $

Gradient for first iteration (​$ G_1 $):

Gradient of the loss function with respect to X

$ G_1 = \begin{bmatrix} 0.3&-0.2&0.4\\ -0.5&0.6&-0.1\\0.2&-0.4 &0.3 \end{bmatrix} $

Hyperparameters setup

$ \epsilon_1 = 10^{-30} $ (Minimum learning rate scaling factor))

$ \epsilon_2 = 10^{-3} $ (Regularization constant)

$ d = 1 $ (Clipping threshold)

$ \rho_t = \min(10^{-2}, 1/\sqrt{t}) $ (Relative step size)

$ \hat{\beta}_{2t} = 1 - t^{-0.8} $ (Second moment decay)

Step 1: Learning Rate Scaling

Define the relative step size

$ \rho_1 = \min(10^{-2}, 1/\sqrt{1})= 10^{-2} $

Step 1.1: Root Mean Square(RMS) calculation for $ X_0 $

Root Mean Square(RMS) calculation for $ X_0 $

RMS formula

$ RMS(X_0) = \sqrt{\tfrac{1}{n}\sum_{i=1}^n X_0[i]^2} $

Substitute the initial weights

$ RMS(X_0) = \sqrt{\tfrac{1}{9}(0.72^2+(-0.5)^2+0.9^2+(-1.1)^2+0.8^2+(-0.6)^2+1.2^2+(-0.7)^2+0.4^2)} $

$ RMS(X_0) = \sqrt{\frac{6.85}{9}}\approx 0.806 $

Step 1.2: Find the Learning Rate Scaling ($ \alpha_t $):

Learning rate formula

$ \alpha_1 = max(\epsilon_2,RMS(X_0))\cdot p_1 $

Substitute the RMS

$ \alpha_1 = max(0.001,0.806)\cdot 0.01=0.00806 $


Step 2: Compute $ G^{2}_t $​ (Element-wise Square of Gradient)

Compute the squared value of each element in the gradient matrix $ G_t $.

$ G^{2}_1 = \begin{bmatrix} 0.3^2&(-0.2)^2&0.4^2\\ (-0.5)^2&0.6^2&(-0.1)^2\\0.2^2&(-0.4)^2 &0.3^2 \end{bmatrix} $



$ G^{2}_1 = \begin{bmatrix} 0.09& 0.04&0.16\\ 0.25&0.36&0.01\\0.04&0.16&0.09\end{bmatrix} $

Step 3: Find the moment estimate

Compute the exponential moving average of squared gradients to capture the variance or scale of gradients.

Step 3.1: Compute row moments ($ R_t $)

This equation computes the row-wise second moments ($ R_t $ ​) as an exponential moving average of past moments ($ R_{t-1} $) and the current row-wise mean of squared gradients ( $ G^{2}_t $​ ), with a balance controlled by ($ \hat{\beta}_{2t} $).

For $ G^{2}_t=\mathbb{R}^{m\times n} $

$ R_t = \hat{\beta_{2t}} \cdot R_{t-1} + (1-\hat{\beta})\cdot (\tfrac{1}{m}\sum_{j=1}^m G^{2}_t[i,j]+\epsilon_1) $

Since $ \hat{\beta}_{2t} = 1 - t^{-0.8} $, for first iteration: $ \hat{\beta}_{21} = 0 $. And because $ \epsilon_1 $ is too small, we can ignore it. The update of $ R_t $ is:

$ R_{1} = \tfrac{1}{m}\textstyle \sum_{j=1}^m \displaystyle G^{2}_1[i,j] $

Row-wise mean ($ R_t $):

$ R_1 = \begin{bmatrix} \tfrac{0.09+0.04+0.16}{3} \\ \tfrac{0.25+0.36+0.01}{3}\\\tfrac{0.04+0.16+0.09}{3} \end{bmatrix} = \begin{bmatrix} 0.0967\\ 0.2067\\0.0967\end{bmatrix} $

Step 3.2: Compute column moments ($ C_t $)

The process is same as row moments.

$ C_t = \hat{\beta}\cdot C_{{t-1}} + (1-\hat{\beta})\cdot (\tfrac{1}{n}\sum_{j=1}^n G^{2}_t[i,j]+\epsilon_1) $

Column-wise mean ($ C_t $):

$ C_1 = \begin{bmatrix} \tfrac{0.09+025+0.04}{3} \\ \tfrac{0.04+0.36+0.16}{3}\\\tfrac{0.16+0.01+0.09}{3} \end{bmatrix} = \begin{bmatrix} 0.1267\\ 0.1867\\0.0867\end{bmatrix} $

Step 3.3: Second Moment Estimate ($ \hat{V_t} $)

The Second Moment Estimate is calculated as the outer product of the row moments ($ R_t $​) and column moments ($ C_t $​).

$ \hat{V}_t = R_t \otimes C_t $

$ \hat{V}_1 = \begin{bmatrix} 0.0967\\0.2067\\0.0967 \end{bmatrix} \otimes \begin{bmatrix} 0.1267&0.1867&0.0867\\ \end{bmatrix} $


$ \hat{V}_1 = \begin{bmatrix} 0.0122&0.0180&0.0084\\ 0.0262&0.0386&0.0179\\ 0.0122&0.0180&0.0084\end{bmatrix} $

Step 4: Update the vector ($ U_t $)

Computed by scaling the gradient matrix $ G_t $​ element-wise with the inverse square root of the second moment estimate ($ \hat{V_t} $​)

step 4.1: Find the vector value of $ U_t $

Formula of $ U_t $

$ U_t = \frac{G_t}{\sqrt{\hat{V_t}+\epsilon_1}} $

Substitute $ C_t $ and $ V_t $

$ U_1 = \frac{\begin{bmatrix}0.3&-0.2&0.4 \\ -0.5&0.6&-0.1\\0.2&-0.4&0.3 \end{bmatrix}}{\sqrt{\begin{bmatrix} 0.0122&0.0180&0.0084\\ 0.0262&0.0386&0.0179\\0.0122&0.0180&0.0084 \end{bmatrix}}} $


$ U_1 = \begin{bmatrix} 2.711&-1.489&4.370\\-3.090&3.055&-0.747\\1.807&-2.978&3.278 \end{bmatrix} $

step 4.2: Clipped Update Vector $ \hat{U_t} $

Scale the update vector ( $ U_t $​ ) to ensure its RMS value does not exceed a predefined clipping threshold ($ d $), maintaining stability in updates.

Formula of $ \hat{U_t} $

$ \hat{U_t} = \frac{U_t}{max(1,\tfrac{RMS(U_t)}{d}) } $

Compute RMS of $ U_t $

$ RMS(U_1) = \sqrt{\tfrac{1}{9} \sum_{i=1}^9 U_t[i]^2} \approx 3.303 $

Since RMS($ U_t $​)>d, scale $ U_t $​ by $ \tfrac{1}{3.303} $

$ \hat{U_1} = \begin{bmatrix} 0.965&-0.53&1.556 \\-1.1&1.088&-0.266\\0.664&-1.06&1.167 \end{bmatrix} $


Step 5: Weight Update ($ X_1 $)

Adjust the weights ($ X_t $) by subtracting the product of the learning rate ($ \alpha_t $) and the clipped update vector ($ \hat{U_t} $ ).

$ X_1 = X_0 - \alpha \cdot \hat{U_t} $

The result for first iteration.

$ X_1 = \begin{bmatrix} 0.7 &-0.5& 0.9\\ -1.1 & 0.8& -1.6\\1.2&-0.7& 0.4 \end{bmatrix} - 0.00806 \cdot \begin{bmatrix} 0.965&-0.53&1.556 \\-1.1&1.088&-0.266\\0.664&-1.06&1.167 \end{bmatrix} $

$ X_1 = \begin{bmatrix} 0.692&-0.496&0.887 \\-1.091&0.791&-0.596\\ 1.195&-0.691&0.391\end{bmatrix} $




Applications

Conclusion

Reference