AdamW
Revision as of 17:30, 12 December 2024
Author: Yufeng Hao (yh2295), Zhengdao Tang (zt278), Yixiao Tian (yt669), Yijie Zhang (yz3384), Zheng Zhou (zz875) (ChemE 6800 Fall 2024)
Stewards: Nathan Preuss, Wei-Han Chen, Tianqi Xiao, Guoqing Hu
Introduction
AdamW is an influential optimization algorithm in deep learning, developed as a modification to the Adam optimizer to decouple weight decay from gradient-based updates (Loshchilov & Hutter, 2017). This decoupling was introduced to address overfitting issues that often arise when using standard Adam, especially for large-scale neural network models.
By applying weight decay separately from the adaptive updates of parameters, AdamW achieves more effective regularization while retaining Adam’s strengths, such as adaptive learning rates and computational efficiency. This characteristic enables AdamW to achieve superior convergence and generalization compared to its predecessor, making it particularly advantageous for complex tasks involving large transformer-based architectures like BERT and GPT (Devlin et al., 2019; Brown et al., 2020).
As deep learning models grow in scale and complexity, AdamW has become a preferred optimizer due to its robust and stable convergence properties. Research has shown that AdamW can yield improved validation accuracy, faster convergence, and better generalization compared to both standard Adam and stochastic gradient descent (SGD) with momentum, especially in large-scale applications (Loshchilov & Hutter, 2017; Devlin et al., 2019; Dosovitskiy et al., 2021).
Algorithm Discussion
The standard Adam optimizer integrates weight decay by adding a term proportional to the parameters directly to the gradient, effectively acting as an L2 regularization term. This approach can interfere with Adam’s adaptive learning rates, leading to suboptimal convergence characteristics (Loshchilov & Hutter, 2017).
AdamW addresses this shortcoming by decoupling the weight decay step from the gradient-based parameter updates. Weight decay is applied after the parameter update is performed, preserving the integrity of the adaptive learning rate mechanism while maintaining effective regularization. This decoupling leads to more stable and predictable training dynamics, which is critical for large-scale models prone to overfitting (Loshchilov & Hutter, 2017).
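The difference can be sketched in a few lines of Python. This is an illustrative scalar sketch, not a reference implementation: the two functions perform one update step each, and the only difference is where the decay term <math>\lambda \theta</math> enters. The default values mirror those used in the numerical example later in this article.

```python
import math

def adam_l2_step(theta, g, m, v, t, lr=0.1, b1=0.9, b2=0.999, eps=1e-8, lam=0.01):
    """One step of Adam with L2 regularization: the decay term is folded
    into the gradient, so it is later rescaled by the adaptive factor."""
    g = g + lam * theta                      # decay enters the gradient
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    return theta - lr * m_hat / (math.sqrt(v_hat) + eps), m, v

def adamw_step(theta, g, m, v, t, lr=0.1, b1=0.9, b2=0.999, eps=1e-8, lam=0.01):
    """One step of AdamW: the decay term is applied to the parameter
    directly, outside the adaptive rescaling."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    return theta - lr * (m_hat / (math.sqrt(v_hat) + eps) + lam * theta), m, v
```

Starting from <math>\theta = 10</math> with gradient <math>g = 20</math>, the Adam + L2 step lands near 9.9 because the decay contribution is almost cancelled by the adaptive normalization, while the AdamW step gives 9.89: the full decay is applied regardless of the gradient scale.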
Algorithm Steps
Given parameters <math>\theta</math>, a learning rate <math>\alpha</math>, and a weight decay coefficient <math>\lambda</math>, AdamW follows these steps:
- Initialize:
  - Initialize parameters <math>\theta_0</math>, the first-moment estimate <math>m_0 = 0</math>, and the second-moment estimate <math>v_0 = 0</math>.
  - Set hyperparameters:
    - <math>\alpha</math>: learning rate
    - <math>\beta_1</math>: exponential decay rate for the first moment
    - <math>\beta_2</math>: exponential decay rate for the second moment
    - <math>\epsilon</math>: small constant to avoid division by zero
    - <math>\lambda</math>: weight decay coefficient
- For each time step <math>t = 1, 2, \ldots</math>:
- Compute Gradient:
  - Calculate the gradient of the objective function: <math>g_t = \nabla_{\theta_t} f(\theta_t)</math>
- Update First Moment Estimate:
  - Update the exponentially decaying average of past gradients: <math>m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t</math>
- Update Second Moment Estimate:
  - Update the exponentially decaying average of squared gradients: <math>v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2</math>, where <math>g_t^2</math> denotes element-wise multiplication of <math>g_t</math> with itself.
- Bias Correction:
  - Compute bias-corrected first and second moment estimates: <math>\hat{m}_t = \frac{m_t}{1 - \beta_1^t},</math> <math>\hat{v}_t = \frac{v_t}{1 - \beta_2^t}</math>
- Parameter Update with Weight Decay:
  - Update parameters <math>\theta_t</math> with weight decay applied separately from the gradient step: <math>\theta_{t+1} = \theta_t - \alpha \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_t \right)</math>
  - This form highlights that weight decay <math>\lambda \theta_t</math> is applied as a separate additive term in the parameter update, reinforcing the decoupling concept.
Pseudocode for AdamW
Initialize theta_0, m_0 = 0, v_0 = 0
Set hyperparameters: alpha, beta1, beta2, epsilon, lambda
For t = 1 to T:
    g_t = gradient of f at theta_t                                    # Compute gradient
    m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    m_hat = m_t / (1 - beta1^t)
    v_hat = v_t / (1 - beta2^t)
    theta_{t+1} = theta_t - alpha * m_hat / (sqrt(v_hat) + epsilon)   # Update parameters without weight decay
    theta_{t+1} = theta_{t+1} - alpha * lambda * theta_t              # Apply decoupled weight decay
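The pseudocode above translates directly into Python. The following is a minimal scalar sketch (a real implementation would operate element-wise on parameter tensors); the weight decay term uses the pre-update value of the parameter, matching the combined update formula in the algorithm steps, and the defaults mirror the numerical example below rather than typical training values.

```python
import math

def adamw(theta, grad_fn, steps, lr=0.1, b1=0.9, b2=0.999, eps=1e-8, lam=0.01):
    """Minimal scalar AdamW loop; grad_fn(theta) returns the gradient
    of the objective at theta."""
    m, v = 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad_fn(theta)                           # compute gradient
        m = b1 * m + (1 - b1) * g                    # first moment estimate
        v = b2 * v + (1 - b2) * g * g                # second moment estimate
        m_hat = m / (1 - b1 ** t)                    # bias-corrected moments
        v_hat = v / (1 - b2 ** t)
        adaptive = m_hat / (math.sqrt(v_hat) + eps)  # adaptive gradient step
        theta = theta - lr * (adaptive + lam * theta)  # decoupled weight decay
    return theta
```

For example, `adamw(10.0, lambda th: 2 * th, steps=1)` performs one step on <math>f(\theta) = \theta^2</math> and returns approximately 9.89.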
Numerical Examples
To demonstrate the functionality of the AdamW algorithm, a straightforward numerical example is presented. This example utilizes small dimensions and simplified values to clearly illustrate the key calculations and steps involved in the algorithm.
Example Setup
Consider the following:
- Initial parameter: <math>\theta_0 = 10</math>
- Learning rate: <math>\alpha = 0.1</math>
- Weight decay: <math>\lambda = 0.01</math>
- First-moment decay rate: <math>\beta_1 = 0.9</math>
- Second-moment decay rate: <math>\beta_2 = 0.999</math>
- Small constant: <math>\epsilon = 10^{-8}</math>
- Objective function gradient: <math>g_t = \nabla f(\theta_t)</math>
For this example, assume we have a simple quadratic function: <math>f(\theta) = \theta^2</math>
The gradient of this function is: <math>g_t = 2 \theta_t</math>
Step-by-Step Calculation
Initialization
- First moment estimate: <math>m_0 = 0</math>
- Second moment estimate: <math>v_0 = 0</math>
- Initial parameter: <math>\theta_0 = 10</math>
Iteration 1
- Step 1: Compute Gradient: <math>g_1 = 2 \theta_0 = 2 \times 10 = 20</math>
- Step 2: Update First Moment Estimate: <math>m_1 = \beta_1 m_0 + (1 - \beta_1) g_1 = 0.9 \times 0 + 0.1 \times 20 = 2</math>
- Step 3: Update Second Moment Estimate: <math>v_1 = \beta_2 v_0 + (1 - \beta_2) g_1^2 = 0.999 \times 0 + 0.001 \times 400 = 0.4</math>
- Step 4: Bias Correction for First Moment: <math>\hat{m}_1 = \frac{m_1}{1 - \beta_1^1} = \frac{2}{0.1} = 20</math>
- Step 5: Bias Correction for Second Moment: <math>\hat{v}_1 = \frac{v_1}{1 - \beta_2^1} = \frac{0.4}{0.001} = 400</math>
- Step 6: Parameter Update with Weight Decay:
  - Gradient Update: <math>\theta_{1} = \theta_{0} - \alpha \times \frac{\hat{m}_1}{\sqrt{\hat{v}_1} + \epsilon} = 10 - 0.1 \times \frac{20}{\sqrt{400} + 10^{-8}}</math>
  - Simplify the denominator: <math>\sqrt{\hat{v}_1} + \epsilon = 20 + 10^{-8} \approx 20</math>
  - Compute the update: <math>\theta_{1} = 10 - 0.1 \times 1 = 9.9</math>
  - Weight Decay: <math>\theta_{1} = 9.9 - \alpha \times \lambda \times \theta_{0} = 9.9 - 0.1 \times 0.01 \times 10 = 9.89</math>
  - Updated Parameter: <math>\theta_{1} = 9.89</math>
Iteration 2
- Step 1: Compute Gradient: <math>g_2 = 2 \theta_1 = 2 \times 9.89 = 19.78</math>
- Step 2: Update First Moment Estimate: <math>m_2 = \beta_1 m_1 + (1 - \beta_1) g_2 = 0.9 \times 2 + 0.1 \times 19.78 = 3.778</math>
- Step 3: Update Second Moment Estimate: <math>v_2 = \beta_2 v_1 + (1 - \beta_2) g_2^2 = 0.999 \times 0.4 + 0.001 \times 19.78^2 = 0.3996 + 0.3912484 \approx 0.79085</math>
- Step 4: Bias Correction for First Moment: <math>\hat{m}_2 = \frac{m_2}{1 - \beta_1^2} = \frac{3.778}{1 - 0.81} = \frac{3.778}{0.19} \approx 19.8842</math>
- Step 5: Bias Correction for Second Moment: <math>\hat{v}_2 = \frac{v_2}{1 - \beta_2^2} = \frac{0.79085}{1 - 0.998001} = \frac{0.79085}{0.001999} \approx 395.62</math>
- Step 6: Parameter Update with Weight Decay:
  - Gradient Update: <math>\theta_{2} = \theta_{1} - \alpha \times \frac{\hat{m}_2}{\sqrt{\hat{v}_2} + \epsilon} = 9.89 - 0.1 \times \frac{19.8842}{\sqrt{395.62} + 10^{-8}}</math>
  - Simplify the denominator: <math>\sqrt{\hat{v}_2} + \epsilon \approx 19.8903</math>
  - Compute the update: <math>\theta_{2} \approx 9.89 - 0.1 \times 0.9997 = 9.89 - 0.09997 = 9.79003</math>
  - Weight Decay: <math>\theta_{2} = 9.79003 - \alpha \times \lambda \times \theta_{1} = 9.79003 - 0.1 \times 0.01 \times 9.89 \approx 9.78014</math>
  - Updated Parameter: <math>\theta_{2} \approx 9.7801</math>
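The iterations above can be checked with a short Python script that applies the update formula with plain floats and rounds the result to four decimal places as in the text:

```python
import math

# Worked example settings: f(theta) = theta^2, theta_0 = 10,
# lr = 0.1, weight decay = 0.01, beta1 = 0.9, beta2 = 0.999, eps = 1e-8.
theta, m, v = 10.0, 0.0, 0.0
lr, lam, b1, b2, eps = 0.1, 0.01, 0.9, 0.999, 1e-8

trace = []
for t in (1, 2):
    g = 2 * theta                         # gradient of theta^2
    m = b1 * m + (1 - b1) * g             # first moment estimate
    v = b2 * v + (1 - b2) * g * g         # second moment estimate
    m_hat = m / (1 - b1 ** t)             # bias-corrected moments
    v_hat = v / (1 - b2 ** t)
    # Combined update: adaptive step plus decoupled decay on pre-update theta
    theta = theta - lr * (m_hat / (math.sqrt(v_hat) + eps) + lam * theta)
    trace.append(round(theta, 4))

print(trace)  # [9.89, 9.7801]
```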
Explanations for Each Step
Step 1: The gradient is calculated based on the current parameter value. For the function <math>f(\theta) = \theta^2</math>, the gradient <math>g_t = 2 \theta_t</math> represents the slope of the function at <math>\theta_t</math>.
Steps 2 and 3: The first and second moment estimates (<math>m_t</math> and <math>v_t</math>) are updated using exponentially decaying averages of past gradients and squared gradients, respectively. These updates help the optimizer adjust the learning rate dynamically for each parameter, improving efficiency.
Steps 4 and 5: Bias correction is applied to the moment estimates to address their initial bias toward zero. This correction is particularly important during the early stages of optimization, ensuring more accurate estimates.
Step 6: The parameter is updated in two key parts:
- Gradient Update: The parameter is adjusted in the opposite direction of the gradient. This adjustment is scaled by the learning rate and adapted using the corrected moment estimates.
- Weight Decay: A regularization term is applied by reducing the parameter's value slightly. This encourages smaller parameter values, which helps to prevent overfitting.
By repeatedly performing these steps, the AdamW optimizer effectively moves the parameters closer to the function's minimum while controlling overfitting through the use of decoupled weight decay.