Adafactor
Authors: Aolei Cao (ac3237), Ziyang Li (zl986), Junjia Liang (jl4439) (ChemE 6800 Fall 2024)
Stewards: Nathan Preuss, Wei-Han Chen, Tianqi Xiao, Guoqing Hu
Introduction
Problem formulation
===1. Objective===
Minimize the loss function f(x), where x ∈ ℝⁿ and x is the weight vector to be optimized.
===2. Parameters===
- Gradient:
$G_t = \nabla f(x_{t-1})$
- Second moment estimate:
$\hat{V}_t = \hat{\beta}_{2t} \hat{V}_{t-1} + (1 - \hat{\beta}_{2t})(G_t^2 + \epsilon_1 1_n)$
- $\hat{V}_t$ is the running average of the squared gradient.
- $\hat{\beta}_{2t}$ is the corrected decay parameter.
- ε₁ is a regularization constant.
- Step size:
$\alpha_t = \max(\epsilon_2, \mathrm{RMS}(x_{t-1})) \rho_t$
- $\rho_t$ is the relative step size.
- ε₂ is a regularization constant.
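- RMS is the root mean square. For the normalized gradient $U_t$ with entries $u_{xt}$, it is defined as:
$u_{xt} = g_{xt} / \sqrt{\hat{v}_{xt}}$
$\mathrm{RMS}(U_t) = \mathrm{RMS}_{x \in X}(u_{xt}) = \sqrt{\mathrm{Mean}_{x \in X}(g_{xt}^2 / \hat{v}_{xt})}$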
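The step-size and RMS definitions above can be illustrated with a small sketch in plain Python. The function names (`rms`, `step_size`, `normalized_update_rms`) and the sample numbers are assumptions chosen for illustration only; the default for ε₂ is taken from the proposed hyperparameters on this page.

```python
import math

def rms(values):
    """Root mean square of a flat sequence of numbers."""
    return math.sqrt(sum(v * v for v in values) / len(values))

def step_size(x_prev, rho_t, eps2=1e-3):
    """alpha_t = max(eps2, RMS(x_{t-1})) * rho_t."""
    return max(eps2, rms(x_prev)) * rho_t

def normalized_update_rms(grads, v_hat):
    """RMS(U_t) = sqrt(mean(g^2 / v_hat)), computed elementwise."""
    ratios = [g * g / v for g, v in zip(grads, v_hat)]
    return math.sqrt(sum(ratios) / len(ratios))

# Illustrative values only.
x_prev = [3.0, 4.0]
print(rms(x_prev))                        # RMS of the current weights
print(step_size(x_prev, rho_t=0.01))      # scales with the weight magnitude
print(step_size([0.0, 0.0], rho_t=0.01))  # eps2 keeps the step from vanishing
```

The `max` with ε₂ matters when the weights start at or near zero: without it, the step size would be zero and the parameters could never move.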
===3. Problem Formulation===
Adafactor for Weight Vectors
Inputs:
- Initial point: X₀ ∈ ℝⁿ
- Relative step sizes: $\rho_t$ for t = 1 to T
- Second moment decay: $\hat{\beta}_{2t}$ for t = 1 to T, with $\hat{\beta}_{21} = 0$ (the decay at t = 1)
- Regularization constants: ε₁, ε₂
- Clipping threshold: d
Algorithm:
- For t = 1 to T:
- Compute adaptive step size:
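$\alpha_t = \max(\epsilon_2, \mathrm{RMS}(X_{t-1})) \rho_t$
- Compute gradient:
$G_t = \nabla f_t(X_{t-1})$
- Update second moment estimate:
$\hat{V}_t = \hat{\beta}_{2t} \hat{V}_{t-1} + (1 - \hat{\beta}_{2t})(G_t^2 + \epsilon_1 1_n)$
- Compute normalized gradient:
$U_t = G_t / \sqrt{\hat{V}_t}$
- Apply clipping:
$\hat{U}_t = U_t / \max(1, \mathrm{RMS}(U_t) / d)$
- Update parameter:
$X_t = X_{t-1} - \alpha_t \hat{U}_t$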
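The loop for weight vectors can be sketched in NumPy as follows. This is a minimal illustration, not a reference implementation: the name `adafactor_vector`, the `grad_fn` callback, and the quadratic test objective are assumptions made for this example, and the schedules for the relative step size and the decay use the proposed hyperparameters listed on this page.

```python
import numpy as np

def rms(a):
    """Root mean square over all entries of an array."""
    return float(np.sqrt(np.mean(np.square(a))))

def adafactor_vector(grad_fn, x0, T, eps1=1e-30, eps2=1e-3, d=1.0):
    """Run T Adafactor steps on a weight vector (illustrative sketch)."""
    x = np.asarray(x0, dtype=float).copy()
    v = np.zeros_like(x)                      # second moment estimate V_t
    for t in range(1, T + 1):
        rho = min(1e-2, 1.0 / np.sqrt(t))     # relative step size
        beta2 = 1.0 - t ** (-0.8)             # decay; equals 0 at t = 1
        alpha = max(eps2, rms(x)) * rho       # adaptive step size
        g = grad_fn(x)                        # gradient at X_{t-1}
        v = beta2 * v + (1.0 - beta2) * (g * g + eps1)
        u = g / np.sqrt(v)                    # normalized gradient
        u_hat = u / max(1.0, rms(u) / d)      # update clipping
        x = x - alpha * u_hat
    return x

# Hypothetical objective: f(x) = 0.5 * ||x - target||^2, so grad = x - target.
target = np.array([0.5, -1.0])
x_final = adafactor_vector(lambda x: x - target, x0=[1.0, -2.0], T=300)
print(np.linalg.norm(x_final - target))
```

Because the decay is zero at t = 1, the first update uses only the current squared gradient, so no separate bias correction is needed.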
Adafactor for Weight Matrices
Inputs:
- Initial point: $X_0 \in \mathbb{R}^{n \times m}$
- Relative step sizes: $\rho_t$ for t = 1 to T
- Second moment decay: $\hat{\beta}_{2t}$ for t = 1 to T, with $\hat{\beta}_{21} = 0$ (the decay at t = 1)
- Regularization constants: ε₁, ε₂
- Clipping threshold: d
Algorithm:
- For t = 1 to T:
- Compute adaptive step size:
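$\alpha_t = \max(\epsilon_2, \mathrm{RMS}(X_{t-1})) \rho_t$
- Compute gradient:
$G_t = \nabla f_t(X_{t-1})$
- Update row-wise second moment:
$R_t = \hat{\beta}_{2t} R_{t-1} + (1 - \hat{\beta}_{2t})(G_t^2 + \epsilon_1 1_n 1_m^\top) 1_m$
- Update column-wise second moment:
$C_t = \hat{\beta}_{2t} C_{t-1} + (1 - \hat{\beta}_{2t}) 1_n^\top (G_t^2 + \epsilon_1 1_n 1_m^\top)$
- Update overall second moment estimate:
$\hat{V}_t = R_t C_t / (1_n^\top R_t)$
- Compute normalized gradient:
$U_t = G_t / \sqrt{\hat{V}_t}$
- Apply clipping:
$\hat{U}_t = U_t / \max(1, \mathrm{RMS}(U_t) / d)$
- Update parameter:
$X_t = X_{t-1} - \alpha_t \hat{U}_t$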
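The factored update for weight matrices can be sketched in NumPy as below. The names `factored_second_moment` and `adafactor_matrix`, the `grad_fn` callback, and the quadratic test objective are assumptions made for this illustration. The point of the sketch is that only the row sums (length n) and column sums (length m) are stored between steps, rather than a full n × m second moment matrix.

```python
import numpy as np

def rms(a):
    """Root mean square over all entries of an array."""
    return float(np.sqrt(np.mean(np.square(a))))

def factored_second_moment(r, c):
    """Rebuild V_hat = R C / (1^T R) from row sums r and column sums c."""
    return np.outer(r, c) / r.sum()

def adafactor_matrix(grad_fn, x0, T, eps1=1e-30, eps2=1e-3, d=1.0):
    """Run T factored Adafactor steps on a weight matrix (illustrative sketch)."""
    x = np.asarray(x0, dtype=float).copy()
    n, m = x.shape
    r = np.zeros(n)                           # row accumulator R_t
    c = np.zeros(m)                           # column accumulator C_t
    for t in range(1, T + 1):
        rho = min(1e-2, 1.0 / np.sqrt(t))     # relative step size
        beta2 = 1.0 - t ** (-0.8)             # decay; equals 0 at t = 1
        alpha = max(eps2, rms(x)) * rho       # adaptive step size
        g = grad_fn(x)                        # gradient at X_{t-1}
        sq = g * g + eps1                     # G_t^2 + eps1 (elementwise)
        r = beta2 * r + (1.0 - beta2) * sq.sum(axis=1)
        c = beta2 * c + (1.0 - beta2) * sq.sum(axis=0)
        v = factored_second_moment(r, c)      # rank-1 estimate of V_t
        u = g / np.sqrt(v)                    # normalized gradient
        u_hat = u / max(1.0, rms(u) / d)      # update clipping
        x = x - alpha * u_hat
    return x

# Hypothetical objective: f(X) = 0.5 * ||X||_F^2, so the gradient is X itself.
x_start = np.array([[1.0, -2.0, 3.0], [0.5, 4.0, -1.0]])
x_final = adafactor_matrix(lambda x: x, x_start, T=300)
print(np.linalg.norm(x_final))
```

When the squared gradient happens to be exactly rank 1, the row and column sums reproduce it without error; in general the outer product is only an approximation of the full second moment.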
===4. Proposed Hyperparameters for Adafactor===
- Regularization constant 1: ε₁ = 10⁻³⁰
- Regularization constant 2: ε₂ = 10⁻³
- Clipping threshold: d = 1
- Relative step size: $\rho_t = \min(10^{-2}, 1/\sqrt{t})$
- Second moment decay: $\hat{\beta}_{2t} = 1 - t^{-0.8}$
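To see how the two schedules behave over training, they can be tabulated with a short Python sketch. The function names `relative_step_size` and `second_moment_decay` are assumptions chosen for this example; the formulas are the proposed values listed above.

```python
import math

def relative_step_size(t):
    """rho_t = min(1e-2, 1 / sqrt(t)) for step t >= 1."""
    return min(1e-2, 1.0 / math.sqrt(t))

def second_moment_decay(t):
    """beta2_t = 1 - t^(-0.8) for step t >= 1; equals 0 at t = 1."""
    return 1.0 - t ** (-0.8)

# Print both schedules at a few representative steps.
for t in (1, 10, 100, 10_000, 1_000_000):
    print(t, relative_step_size(t), second_moment_decay(t))
```

The relative step size stays at 10⁻² until t exceeds 10⁴, after which it decays as 1/√t. The decay parameter starts at 0 and approaches 1, so later steps average the squared gradient over a longer history.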