Adafactor: Difference between revisions
== Introduction ==
== Problem Formulation ==
<p>Minimize the loss function <b>f(x)</b>, where <b>x ∈ ℝⁿ</b> is the weight vector to be optimized.</p>
<h1>2. Parameters</h1>
<ul>
  <li><b>Gradient:</b>
    <div><i>G<sub>t</sub> = ∇f(x<sub>t-1</sub>)</i></div>
  </li>
  <li><b>Second moment estimate:</b>
    <div><i>V̂<sub>t</sub> = β̂<sub>2t</sub> V̂<sub>t-1</sub> + (1 - β̂<sub>2t</sub>)(G<sub>t</sub>² + ε₁ 1ₙ)</i></div>
    <ul>
      <li><i>V̂<sub>t</sub></i> is the running average of the squared gradient.</li>
      <li><i>β̂<sub>2t</sub></i> is the corrected decay parameter.</li>
      <li><i>ε₁</i> is a regularization constant.</li>
    </ul>
  </li>
  <li><b>Step size:</b>
    <div><i>α<sub>t</sub> = max(ε₂, RMS(x<sub>t-1</sub>)) ρ<sub>t</sub></i></div>
    <ul>
      <li><i>ρ<sub>t</sub></i> is the relative step size.</li>
      <li><i>ε₂</i> is a regularization constant.</li>
      <li><i>RMS</i> is the root mean square of the scaled update <i>u<sub>xt</sub></i>, defined as:
        <div><i>u<sub>xt</sub> = -g<sub>xt</sub> / √v̂<sub>xt</sub></i></div>
        <div><i>RMS(U<sub>t</sub>) = RMS<sub>x ∈ X</sub>(u<sub>xt</sub>) = √Mean<sub>x ∈ X</sub>(g<sub>xt</sub>² / v̂<sub>xt</sub>)</i></div>
      </li>
    </ul>
  </li>
</ul>
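The quantities above can be sketched numerically. The snippet below is a minimal illustration, not the reference implementation; the function names and the toy gradient are ours. It shows the second-moment update and the RMS of the scaled update (the sign in <i>u<sub>xt</sub></i> does not affect the RMS, since components are squared).

```python
import numpy as np

def update_second_moment(v_hat_prev, grad, beta2_hat, eps1):
    """V̂_t = β̂_2t · V̂_{t-1} + (1 - β̂_2t)(G_t² + ε₁·1ₙ)."""
    return beta2_hat * v_hat_prev + (1.0 - beta2_hat) * (grad**2 + eps1)

def rms(x):
    """Root mean square: √Mean(x²)."""
    return np.sqrt(np.mean(x**2))

# Toy gradient; with β̂_2t = 0 (as required at t = 1), V̂_t ≈ G_t².
g = np.array([0.1, -0.2, 0.4])
v_hat = update_second_moment(np.zeros(3), g, beta2_hat=0.0, eps1=1e-30)
u = g / np.sqrt(v_hat)   # scaled update: each component is ±1 here
print(rms(u))            # ≈ 1.0
```

Note that when the second-moment estimate equals the squared gradient exactly, every component of the scaled update has unit magnitude, so its RMS is 1.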
<h1>3. Problem Formulation</h1>
<h2>Adafactor for Weighted Vectors</h2>
<h3>Inputs:</h3>
<ul>
  <li>Initial point: <i>X₀ ∈ ℝⁿ</i></li>
  <li>Relative step sizes: <i>ρ<sub>t</sub></i> for <i>t = 1</i> to <i>T</i></li>
  <li>Second moment decay: <i>β̂<sub>2t</sub></i> for <i>t = 1</i> to <i>T</i>, with <i>β̂<sub>21</sub> = 0</i></li>
  <li>Regularization constants: <i>ε₁, ε₂</i></li>
  <li>Clipping threshold: <i>d</i></li>
</ul>
<h3>Algorithm:</h3>
<ul>
  <li>For <i>t = 1</i> to <i>T</i>:
    <ul>
      <li>Compute adaptive step size:
        <div><i>α<sub>t</sub> = max(ε₂, RMS(X<sub>t-1</sub>)) ρ<sub>t</sub></i></div>
      </li>
      <li>Compute gradient:
        <div><i>G<sub>t</sub> = ∇f<sub>t</sub>(X<sub>t-1</sub>)</i></div>
      </li>
      <li>Update second moment estimate:
        <div><i>V̂<sub>t</sub> = β̂<sub>2t</sub> V̂<sub>t-1</sub> + (1 - β̂<sub>2t</sub>)(G<sub>t</sub>² + ε₁ 1ₙ)</i></div>
      </li>
      <li>Compute normalized gradient:
        <div><i>U<sub>t</sub> = G<sub>t</sub> / √V̂<sub>t</sub></i></div>
      </li>
      <li>Apply clipping:
        <div><i>Û<sub>t</sub> = U<sub>t</sub> / max(1, RMS(U<sub>t</sub>) / d)</i></div>
      </li>
      <li>Update parameter:
        <div><i>X<sub>t</sub> = X<sub>t-1</sub> - α<sub>t</sub> Û<sub>t</sub></i></div>
      </li>
    </ul>
  </li>
</ul>
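The vector algorithm above can be sketched end to end on a toy quadratic <i>f(x) = ½‖x‖²</i>, whose gradient is simply <i>x</i>. This is an illustrative sketch: the decay schedule <i>β̂<sub>2t</sub> = 1 - t<sup>-0.8</sup></i> (which satisfies <i>β̂<sub>21</sub> = 0</i>) and the hyperparameter values are assumptions chosen for the demo, not prescribed by this section.

```python
import numpy as np

def adafactor_vector(x0, grad_fn, T=500, rho=1e-2, eps1=1e-30, eps2=1e-3, d=1.0):
    """Minimal sketch of Adafactor for a weight vector (toy hyperparameters)."""
    x = x0.astype(float).copy()
    v_hat = np.zeros_like(x)
    for t in range(1, T + 1):
        beta2_hat = 1.0 - t ** (-0.8)                 # assumed schedule; β̂_21 = 0
        alpha = max(eps2, np.sqrt(np.mean(x**2))) * rho  # α_t = max(ε₂, RMS(X_{t-1}))·ρ_t
        g = grad_fn(x)                                 # G_t = ∇f_t(X_{t-1})
        v_hat = beta2_hat * v_hat + (1 - beta2_hat) * (g**2 + eps1)
        u = g / np.sqrt(v_hat)                         # U_t = G_t / √V̂_t
        u_hat = u / max(1.0, np.sqrt(np.mean(u**2)) / d)  # clip by RMS(U_t)/d
        x = x - alpha * u_hat                          # X_t = X_{t-1} - α_t·Û_t
    return x

x0 = np.array([1.0, -2.0, 3.0])
x_final = adafactor_vector(x0, lambda x: x)
print(np.linalg.norm(x_final))  # shrinks well below the starting norm ≈ 3.74
```

Because the step size is relative (proportional to the RMS of the current weights, floored at <i>ε₂</i>), the iterates contract toward the minimizer at the origin without any per-problem learning-rate tuning.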
<h2>Adafactor for Weighted Matrices</h2>
<h3>Inputs:</h3>
<ul>
  <li>Initial point: <i>X₀ ∈ ℝ<sup>n×m</sup></i></li>
  <li>Relative step sizes: <i>ρ<sub>t</sub></i> for <i>t = 1</i> to <i>T</i></li>
  <li>Second moment decay: <i>β̂<sub>2t</sub></i> for <i>t = 1</i> to <i>T</i>, with <i>β̂<sub>21</sub> = 0</i></li>
  <li>Regularization constants: <i>ε₁, ε₂</i></li>
  <li>Clipping threshold: <i>d</i></li>
</ul>
<h3>Algorithm:</h3>
<ul>
  <li>For <i>t = 1</i> to <i>T</i>:
    <ul>
      <li>Compute adaptive step size:
        <div><i>α<sub>t</sub> = max(ε₂, RMS(X<sub>t-1</sub>)) ρ<sub>t</sub></i></div>
      </li>
      <li>Compute gradient:
        <div><i>G<sub>t</sub> = ∇f<sub>t</sub>(X<sub>t-1</sub>)</i></div>
      </li>
      <li>Update row-wise second moment:
        <div><i>R<sub>t</sub> = β̂<sub>2t</sub> R<sub>t-1</sub> + (1 - β̂<sub>2t</sub>)(G<sub>t</sub>² + ε₁ 1ₙ 1ₘᵀ) 1ₘ</i></div>
      </li>
      <li>Update column-wise second moment:
        <div><i>C<sub>t</sub> = β̂<sub>2t</sub> C<sub>t-1</sub> + (1 - β̂<sub>2t</sub>) 1ₙᵀ (G<sub>t</sub>² + ε₁ 1ₙ 1ₘᵀ)</i></div>
      </li>
      <li>Update overall second moment estimate:
        <div><i>V̂<sub>t</sub> = R<sub>t</sub> C<sub>t</sub> / (1ₙᵀ R<sub>t</sub>)</i></div>
      </li>
      <li>Compute normalized gradient:
        <div><i>U<sub>t</sub> = G<sub>t</sub> / √V̂<sub>t</sub></i></div>
      </li>
      <li>Apply clipping:
        <div><i>Û<sub>t</sub> = U<sub>t</sub> / max(1, RMS(U<sub>t</sub>) / d)</i></div>
      </li>
      <li>Update parameter:
        <div><i>X<sub>t</sub> = X<sub>t-1</sub> - α<sub>t</sub> Û<sub>t</sub></i></div>
      </li>
    </ul>
  </li>
</ul>
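The factored estimate in the matrix algorithm above is what gives Adafactor its memory savings: instead of storing a full <i>n×m</i> second-moment matrix, it keeps only the row sums <i>R<sub>t</sub></i> and column sums <i>C<sub>t</sub></i> (<i>O(n+m)</i> memory) and reconstructs <i>V̂<sub>t</sub> = R<sub>t</sub> C<sub>t</sub> / (1ₙᵀ R<sub>t</sub>)</i>. A minimal sketch of one update at <i>t = 1</i> (so <i>β̂<sub>2t</sub> = 0</i>), with a randomly generated toy gradient:

```python
import numpy as np

rng = np.random.default_rng(0)
G = rng.normal(size=(4, 3))        # toy gradient G_t for a 4×3 weight matrix
eps1 = 1e-30
sq = G**2 + eps1                   # G_t² + ε₁·1ₙ1ₘᵀ

R = sq.sum(axis=1)                 # R_t = (G_t² + ε₁)·1ₘ  (row sums, length n)
C = sq.sum(axis=0)                 # C_t = 1ₙᵀ·(G_t² + ε₁)  (column sums, length m)
V_hat = np.outer(R, C) / R.sum()   # V̂_t = R_t C_t / (1ₙᵀ R_t), rank-one reconstruction

# The reconstruction always reproduces the row and column sums of G_t² + ε₁:
print(np.allclose(V_hat.sum(axis=1), R))   # True
print(np.allclose(V_hat.sum(axis=0), C))   # True
```

The reconstruction is rank one, so it is exact only when the true second-moment matrix is rank one; in general it is the nonnegative rank-one surrogate that matches the row and column sums of the full estimate.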
== Numerical Examples ==
Revision as of 16:18, 10 December 2024
Author: Aolei Cao (ac3237), Ziyang Li (zl986), Junjia Liang (jl4439) (ChemE 6800 Fall 2024)
Stewards: Nathan Preuss, Wei-Han Chen, Tianqi Xiao, Guoqing Hu