Adafactor

From Cornell University Computational Optimization Open Textbook - Optimization Wiki

Author: Aolei Cao (ac3237), Ziyang Li (zl986), Junjia Liang (jl4439) (ChemE 6800 Fall 2024)

Stewards: Nathan Preuss, Wei-Han Chen, Tianqi Xiao, Guoqing Hu

Introduction

Problem formulation

1. Objective

Minimize the loss function $f(x)$, where $x \in \mathbb{R}^n$ is the weight vector to be optimized.

2. Parameters

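  • Gradient: $g_t = \nabla f(x_{t-1})$

  • Second moment estimate: $\hat{v}_t = \hat{\beta}_{2t}\,\hat{v}_{t-1} + (1 - \hat{\beta}_{2t})\,(g_t^2 + \epsilon_1 1_n)$

  • Where:
    • $\hat{v}_t$ is the running average of the squared gradient ($g_t^2$ is the element-wise square).
    • $\hat{\beta}_{2t}$ is the corrected decay parameter.
    • $\epsilon_1$ is a regularization constant, and $1_n$ is the all-ones vector of length $n$.
  • Step size: $\alpha_t = \max(\epsilon_2, \mathrm{RMS}(x_{t-1}))\,\rho_t$

  • Where:
    • $\rho_t$ is the relative step size.
    • $\epsilon_2$ is a regularization constant.
    • $\mathrm{RMS}(\cdot)$ is the root mean square, defined as $\mathrm{RMS}(x) = \sqrt{\tfrac{1}{n}\sum_{i=1}^{n} x_i^2}$ (for a matrix, the mean is taken over all entries).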
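The per-step quantities above can be sketched in a few lines of plain Python. This is a minimal illustration, assuming parameters and gradients are flat lists of floats; the function names `rms`, `step_size`, and `update_second_moment` are hypothetical helpers, and the default values of `eps1` and `eps2` follow the hyperparameters proposed for Adafactor.

```python
import math

def rms(values):
    """Root mean square of a flat list: sqrt(mean(x_i^2))."""
    return math.sqrt(sum(v * v for v in values) / len(values))

def step_size(x_prev, rho_t, eps2=1e-3):
    """alpha_t = max(eps2, RMS(x_{t-1})) * rho_t."""
    return max(eps2, rms(x_prev)) * rho_t

def update_second_moment(v_prev, g, beta2_t, eps1=1e-30):
    """v_t = beta2_t * v_{t-1} + (1 - beta2_t) * (g_t^2 + eps1), element-wise."""
    return [beta2_t * v + (1.0 - beta2_t) * (gi * gi + eps1)
            for v, gi in zip(v_prev, g)]

x_prev = [3.0, 4.0]
print(rms(x_prev))                    # sqrt(12.5), about 3.536
print(step_size(x_prev, rho_t=0.01))  # scales with the parameter RMS
print(step_size([0.0, 0.0], 0.01))    # eps2 floor applies: 1e-3 * 0.01
print(update_second_moment([0.0, 0.0], [2.0, -1.0], beta2_t=0.0))
```

Because the step size is proportional to $\mathrm{RMS}(x_{t-1})$, the update is relative to the scale of the parameters; $\epsilon_2$ only matters when that scale is tiny.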

3. Algorithms

Adafactor for Weight Vectors

Inputs:

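  • Initial point: $x_0 \in \mathbb{R}^n$
  • Relative step sizes: $\rho_t$ for $t = 1$ to $T$
  • Second moment decay: $\hat{\beta}_{2t}$ for $t = 1$ to $T$, with $\hat{\beta}_{21} = 0$
  • Regularization constants: $\epsilon_1, \epsilon_2$
  • Clipping threshold: $d$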

Algorithm:

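  • For $t = 1$ to $T$:
    • Compute adaptive step size: $\alpha_t = \max(\epsilon_2, \mathrm{RMS}(x_{t-1}))\,\rho_t$
    • Compute gradient: $g_t = \nabla f(x_{t-1})$
    • Update second moment estimate: $\hat{v}_t = \hat{\beta}_{2t}\,\hat{v}_{t-1} + (1 - \hat{\beta}_{2t})\,(g_t^2 + \epsilon_1 1_n)$
    • Compute normalized gradient: $u_t = g_t / \sqrt{\hat{v}_t}$ (element-wise)
    • Apply clipping: $\hat{u}_t = u_t / \max(1, \mathrm{RMS}(u_t)/d)$
    • Update parameter: $x_t = x_{t-1} - \alpha_t\,\hat{u}_t$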
  • End for
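The vector algorithm can be sketched as a single loop. This is a minimal pure-Python illustration, not a reference implementation: `adafactor_vector` is a hypothetical name, the schedules $\rho_t = \min(10^{-2}, 1/\sqrt{t})$ and $\hat{\beta}_{2t} = 1 - t^{-0.8}$ are the proposed defaults, and the toy quadratic objective is an assumption chosen only so the gradient is easy to write.

```python
import math

def rms(values):
    """Root mean square of a flat list."""
    return math.sqrt(sum(v * v for v in values) / len(values))

def adafactor_vector(grad_fn, x0, num_steps, eps1=1e-30, eps2=1e-3, d=1.0):
    """Run the vector form of Adafactor for num_steps iterations."""
    x = list(x0)
    v = [0.0] * len(x)  # second-moment estimate; ignored at t = 1 since beta2_1 = 0
    for t in range(1, num_steps + 1):
        rho_t = min(1e-2, 1.0 / math.sqrt(t))       # relative step size
        beta2_t = 1.0 - t ** -0.8                   # second-moment decay
        alpha_t = max(eps2, rms(x)) * rho_t         # adaptive step size
        g = grad_fn(x)                              # gradient at x_{t-1}
        v = [beta2_t * vi + (1.0 - beta2_t) * (gi * gi + eps1)
             for vi, gi in zip(v, g)]
        u = [gi / math.sqrt(vi) for gi, vi in zip(g, v)]  # normalized gradient
        clip = max(1.0, rms(u) / d)                 # clipping factor
        x = [xi - alpha_t * ui / clip for xi, ui in zip(x, u)]
    return x

# Toy quadratic f(x) = 0.5 * ||x - target||^2, whose gradient is x - target.
target = [1.0, -2.0, 0.5]

def grad(x):
    return [xi - ti for xi, ti in zip(x, target)]

def loss(x):
    return 0.5 * sum((xi - ti) ** 2 for xi, ti in zip(x, target))

x0 = [0.5, 0.5, 0.5]
x_final = adafactor_vector(grad, x0, num_steps=2000)
print(loss(x0), loss(x_final))
```

The starting point is deliberately non-zero: from an all-zero $x_0$, the $\epsilon_2$ floor makes the first steps only $10^{-3} \cdot 10^{-2} = 10^{-5}$ long, which illustrates how the relative step size ties progress to the parameter scale.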

Adafactor for Weight Matrices

Inputs:

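  • Initial point: $X_0 \in \mathbb{R}^{n \times m}$
  • Relative step sizes: $\rho_t$ for $t = 1$ to $T$
  • Second moment decay: $\hat{\beta}_{2t}$ for $t = 1$ to $T$, with $\hat{\beta}_{21} = 0$
  • Regularization constants: $\epsilon_1, \epsilon_2$
  • Clipping threshold: $d$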

Algorithm:

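  • For $t = 1$ to $T$:
    • Compute adaptive step size: $\alpha_t = \max(\epsilon_2, \mathrm{RMS}(X_{t-1}))\,\rho_t$
    • Compute gradient: $G_t = \nabla f(X_{t-1})$
    • Update row-wise second moment: $R_t = \hat{\beta}_{2t}\,R_{t-1} + (1 - \hat{\beta}_{2t})\,(G_t^2 + \epsilon_1 1_n 1_m^\top)\,1_m$
    • Update column-wise second moment: $C_t = \hat{\beta}_{2t}\,C_{t-1} + (1 - \hat{\beta}_{2t})\,1_n^\top (G_t^2 + \epsilon_1 1_n 1_m^\top)$
    • Update overall second moment estimate: $\hat{V}_t = R_t C_t / (1_n^\top R_t)$
    • Compute normalized gradient: $U_t = G_t / \sqrt{\hat{V}_t}$ (element-wise)
    • Apply clipping: $\hat{U}_t = U_t / \max(1, \mathrm{RMS}(U_t)/d)$
    • Update parameter: $X_t = X_{t-1} - \alpha_t\,\hat{U}_t$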
  • End for
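One step of the matrix algorithm can be sketched with lists of rows. This is a minimal illustration under stated assumptions: `adafactor_matrix_step` is a hypothetical helper, the schedules and constants are the proposed defaults, and the toy target matrix `T` is made up for the example. Note that only the length-$n$ vector `R` and the length-$m$ vector `C` are carried between steps.

```python
import math

def rms_matrix(M):
    """RMS over all entries of a matrix given as a list of rows."""
    return math.sqrt(sum(v * v for row in M for v in row) / (len(M) * len(M[0])))

def adafactor_matrix_step(X, G, R, C, t, eps1=1e-30, eps2=1e-3, d=1.0):
    """One factored Adafactor step for an n x m matrix X with gradient G.

    R (length n) and C (length m) hold the row-wise and column-wise
    second-moment statistics from the previous step."""
    n, m = len(X), len(X[0])
    rho_t = min(1e-2, 1.0 / math.sqrt(t))
    beta2_t = 1.0 - t ** -0.8
    alpha_t = max(eps2, rms_matrix(X)) * rho_t
    sq = [[G[i][j] ** 2 + eps1 for j in range(m)] for i in range(n)]  # G_t^2 + eps1
    R = [beta2_t * R[i] + (1.0 - beta2_t) * sum(sq[i]) for i in range(n)]
    C = [beta2_t * C[j] + (1.0 - beta2_t) * sum(sq[i][j] for i in range(n))
         for j in range(m)]
    total = sum(R)  # 1_n^T R_t
    # Each entry V_hat[i][j] = R[i] * C[j] / total is formed inline, never stored.
    U = [[G[i][j] / math.sqrt(R[i] * C[j] / total) for j in range(m)]
         for i in range(n)]
    clip = max(1.0, rms_matrix(U) / d)
    X = [[X[i][j] - alpha_t * U[i][j] / clip for j in range(m)] for i in range(n)]
    return X, R, C

# Toy problem: f(X) = 0.5 * ||X - T||_F^2, so the gradient is G = X - T.
T = [[1.0, -2.0, 0.5], [2.0, 0.0, -1.0]]

def loss(X):
    return 0.5 * sum((x - y) ** 2 for rx, ry in zip(X, T) for x, y in zip(rx, ry))

X = [[0.5] * 3 for _ in range(2)]
R, C = [0.0] * 2, [0.0] * 3
start = loss(X)
for t in range(1, 2001):
    G = [[x - y for x, y in zip(rx, ry)] for rx, ry in zip(X, T)]
    X, R, C = adafactor_matrix_step(X, G, R, C, t)
print(start, loss(X))
```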

Why Adafactor is more memory-efficient than Adam

Row-wise and Column-wise Second Moment Updates

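Instead of storing the full second-moment matrix $\hat{V}_t \in \mathbb{R}^{n \times m}$, Adafactor keeps only the row-wise statistics $R_t \in \mathbb{R}^{n}$ and the column-wise statistics $C_t \in \mathbb{R}^{1 \times m}$, which reduces the memory requirement from $O(nm)$ to $O(n + m)$.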

Factored Representation of the Second Moment

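Adafactor rebuilds the second-moment estimate from the outer product of the row and column statistics, normalized by their total: $\hat{V}_t = R_t C_t / (1_n^\top R_t)$.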

  • However, this does not cost $O(nm)$ memory, since:
    • The operation is performed element-wise, so $\hat{V}_t$ never has to be stored as a full matrix between steps
    • Only $R_t$ and $C_t$ are stored, instead of the full second-moment matrix
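The factored reconstruction and the memory saving can be checked on a small case. This sketch assumes the first step ($\hat{\beta}_{21} = 0$), so the statistics are plain row and column sums of a non-negative matrix `S` standing in for $G_t^2 + \epsilon_1$; `factored_estimate` is a hypothetical helper. When `S` is rank-one (an outer product of two non-negative vectors), $R_t C_t / (1_n^\top R_t)$ recovers it exactly.

```python
def factored_estimate(S):
    """Rebuild an n x m non-negative matrix from its row and column sums:
    V_hat[i][j] = R[i] * C[j] / sum(R)."""
    R = [sum(row) for row in S]            # n stored numbers
    C = [sum(col) for col in zip(*S)]      # m stored numbers
    total = sum(R)
    V_hat = [[R[i] * C[j] / total for j in range(len(C))] for i in range(len(R))]
    return V_hat, R, C

a = [1.0, 2.0, 3.0]
b = [4.0, 0.5]
S = [[ai * bj for bj in b] for ai in a]    # rank-one stand-in for G_t^2
V_hat, R, C = factored_estimate(S)
max_err = max(abs(V_hat[i][j] - S[i][j]) for i in range(len(a)) for j in range(len(b)))
print("max reconstruction error:", max_err)

# Storage for a hypothetical 4096 x 4096 weight matrix.
n = m = 4096
print("full second moment:", n * m, "numbers; factored:", n + m, "numbers")
```

For general (not rank-one) squared gradients the reconstruction is an approximation, which is the trade-off Adafactor accepts for the $O(n + m)$ storage.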

4. Proposed Hyperparameters for Adafactor

  • Regularization constant 1: $\epsilon_1 = 10^{-30}$
    • Ensures numerical stability by preventing division by zero when the second-moment estimate is used to normalize the gradient, so its value should be very close to zero
  • Regularization constant 2: $\epsilon_2 = 10^{-3}$
    • Sets a lower bound on $\mathrm{RMS}(x_{t-1})$ in the step-size computation, so parameters with very small magnitude (for example, those initialized near zero) still receive non-negligible updates. It is much larger than $\epsilon_1$ because it acts on the parameter scale rather than only guarding against division by zero.
  • Clipping threshold: $d = 1$
    • A threshold of 1 balances stability and learning efficiency. It avoids excessive suppression of large updates, which could hinder learning, while still protecting against extreme updates that could destabilize the model.
  • Relative step size: $\rho_t = \min(10^{-2}, 1/\sqrt{t})$
    • The $\min$ caps the relative step size at $10^{-2}$, an empirically chosen upper bound
    • The $1/\sqrt{t}$ decay promotes convergence, balancing sufficient exploration in early iterations with stability in later iterations
  • Second moment decay: $\hat{\beta}_{2t} = 1 - t^{-0.8}$
    • The form $1 - t^{-0.8}$ equals 0 at $t = 1$ (matching $\hat{\beta}_{21} = 0$) and approaches 1 as training progresses
    • The exponent 0.8 balances rapid adaptation early in training with stable averaging in later iterations
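The two schedules can be tabulated directly. This is a small sketch with hypothetical function names; it simply evaluates $\rho_t = \min(10^{-2}, 1/\sqrt{t})$ and $\hat{\beta}_{2t} = 1 - t^{-0.8}$ at a few steps.

```python
import math

def relative_step_size(t):
    """rho_t = min(1e-2, 1 / sqrt(t))."""
    return min(1e-2, 1.0 / math.sqrt(t))

def second_moment_decay(t):
    """beta2_t = 1 - t^(-0.8)."""
    return 1.0 - t ** -0.8

for t in (1, 2, 100, 10_000, 40_000):
    print(t, relative_step_size(t), round(second_moment_decay(t), 4))
```

The cap is active while $1/\sqrt{t} \ge 10^{-2}$, that is, for the first $10^4$ steps; after that the relative step size decays. The decay factor starts at 0, so the first second-moment estimate is just the first squared gradient, and then rises toward 1.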

5. Discussion

Why Clipping

Adafactor employs clipping to maintain numerical stability, especially since it is designed for use with very large models and often works with unscaled learning rates.

  • Clipping prevents the update step from becoming very large, which would destabilize training
  • Clipping mitigates the effect of unusually large gradients on the update, preventing numerical instability

Therefore, implementing clipping helps ensure stability and efficient training without requiring per-parameter scaling like Adam.
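The clipping rule $\hat{U}_t = U_t / \max(1, \mathrm{RMS}(U_t)/d)$ can be illustrated on two small update vectors. This is a minimal sketch with a hypothetical helper `clip_update`, using flat lists for the update.

```python
import math

def rms(values):
    """Root mean square of a flat list."""
    return math.sqrt(sum(v * v for v in values) / len(values))

def clip_update(u, d=1.0):
    """U_hat = U / max(1, RMS(U) / d): rescale only when RMS(U) exceeds d."""
    scale = max(1.0, rms(u) / d)
    return [ui / scale for ui in u]

small = [0.5, -0.5]     # RMS 0.5 <= d, left unchanged
large = [30.0, -40.0]   # RMS about 35.4 > d, scaled down to RMS d
print(clip_update(small))
print(clip_update(large), rms(clip_update(large)))
```

Because the whole update is divided by a single scalar, clipping shrinks its magnitude without changing its direction.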

Numerical Examples

Applications

Conclusion

Reference