AdamW

From Cornell University Computational Optimization Open Textbook - Optimization Wiki
Revision as of 17:20, 12 December 2024 by Fall2024 Wiki Team8

Author: Yufeng Hao (yh2295), Zhengdao Tang (zt278), Yixiao Tian (yt669), Yijie Zhang (yz3384), Zheng Zhou (zz875) (ChemE 6800 Fall 2024)

Stewards: Nathan Preuss, Wei-Han Chen, Tianqi Xiao, Guoqing Hu

Introduction

AdamW is an influential optimization algorithm in deep learning, developed as a modification to the Adam optimizer to decouple weight decay from gradient-based updates (Loshchilov & Hutter, 2017). This decoupling was introduced to address overfitting issues that often arise when using standard Adam, especially for large-scale neural network models.

By applying weight decay separately from the adaptive updates of parameters, AdamW achieves more effective regularization while retaining Adam’s strengths, such as adaptive learning rates and computational efficiency. This characteristic enables AdamW to achieve superior convergence and generalization compared to its predecessor, making it particularly advantageous for complex tasks involving large transformer-based architectures like BERT and GPT (Devlin et al., 2019; Brown et al., 2020).

As deep learning models grow in scale and complexity, AdamW has become a preferred optimizer due to its robust and stable convergence properties. Research has shown that AdamW can yield improved validation accuracy, faster convergence, and better generalization compared to both standard Adam and stochastic gradient descent (SGD) with momentum, especially in large-scale applications (Loshchilov & Hutter, 2017; Devlin et al., 2019; Dosovitskiy et al., 2021).

Algorithm Discussion

The standard Adam optimizer integrates weight decay by adding a term proportional to the parameters directly to the gradient, effectively acting as an L2 regularization term. This approach can interfere with Adam’s adaptive learning rates, leading to suboptimal convergence characteristics (Loshchilov & Hutter, 2017).

AdamW addresses this shortcoming by decoupling the weight decay step from the gradient-based parameter updates. Weight decay is applied after the parameter update is performed, preserving the integrity of the adaptive learning rate mechanism while maintaining effective regularization. This decoupling leads to more stable and predictable training dynamics, which is critical for large-scale models prone to overfitting (Loshchilov & Hutter, 2017).
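The contrast can be sketched in a few lines of plain Python (an illustrative scalar version; the function and variable names are ours, not from the original paper). In Adam with L2 regularization the decay term λθ is folded into the gradient and therefore rescaled by the adaptive denominator, while AdamW subtracts αλθ from the parameter directly:

```python
import math

def adam_l2_step(theta, g, m, v, t, lr=0.001, b1=0.9, b2=0.999, eps=1e-8, wd=0.01):
    """Adam with coupled L2 regularization: the decay term enters the
    gradient, so it is rescaled by the adaptive denominator."""
    g = g + wd * theta                      # decay mixed into the gradient
    m = b1 * m + (1 - b1) * g               # first-moment estimate
    v = b2 * v + (1 - b2) * g * g           # second-moment estimate
    m_hat = m / (1 - b1 ** t)               # bias correction
    v_hat = v / (1 - b2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

def adamw_step(theta, g, m, v, t, lr=0.001, b1=0.9, b2=0.999, eps=1e-8, wd=0.01):
    """AdamW: decay applied directly to the parameter, untouched by
    the adaptive scaling."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    theta = theta - lr * (m_hat / (math.sqrt(v_hat) + eps) + wd * theta)
    return theta, m, v
```

With a large gradient, the coupled version's decay contribution is largely normalized away by the correspondingly large second-moment estimate, whereas AdamW's decay acts at full strength on the parameter — the behavior Loshchilov & Hutter identify as the source of the difference in regularization.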

Algorithm Steps

Given parameters $\theta$, a learning rate $\alpha$, and a weight decay coefficient $\lambda$, AdamW follows these steps:

  • Initialize:
    • Initialize parameters $\theta_0$, the first-moment estimate $m_0 = 0$, and the second-moment estimate $v_0 = 0$.
    • Set hyperparameters:
      • $\alpha$: learning rate
      • $\beta_1$: exponential decay rate for the first moment
      • $\beta_2$: exponential decay rate for the second moment
      • $\epsilon$: small constant to avoid division by zero
  • For each time step $t = 1, 2, \ldots$:
    • Compute Gradient:
      • Calculate the gradient of the objective function:
        $g_t = \nabla_\theta f_t(\theta_{t-1})$
    • Update First Moment Estimate:
      • Update the exponentially decaying average of past gradients:
        $m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t$
    • Update Second Moment Estimate:
      • Update the exponentially decaying average of squared gradients (element-wise square):
        $v_t = \beta_2 v_{t-1} + (1 - \beta_2)\,(g_t \odot g_t)$
        where $g_t \odot g_t$ denotes element-wise multiplication of $g_t$ with itself.
    • Bias Correction:
      • Compute bias-corrected first and second moment estimates:
        $\hat{m}_t = \dfrac{m_t}{1 - \beta_1^t}$
        $\hat{v}_t = \dfrac{v_t}{1 - \beta_2^t}$
    • Parameter Update with Weight Decay:
      • Update parameters with weight decay applied separately from the gradient step:
        $\theta_t = \theta_{t-1} - \alpha \left( \dfrac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda\, \theta_{t-1} \right)$
      • This form highlights that weight decay is applied as a separate additive term to the parameter update, reinforcing the decoupling concept.

Pseudocode for AdamW

Initialize $\theta_0$, $m_0 = 0$, $v_0 = 0$

Set hyperparameters: $\alpha$, $\beta_1$, $\beta_2$, $\epsilon$, $\lambda$

For $t = 1$ to $T$:

  $g_t = \nabla_\theta f_t(\theta_{t-1})$  # Compute gradient
  $m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t$
  $v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2$
  $\hat{m}_t = m_t / (1 - \beta_1^t)$
  $\hat{v}_t = v_t / (1 - \beta_2^t)$
  $\theta_t = \theta_{t-1} - \alpha\, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)$  # Update parameters without weight decay
  $\theta_t = \theta_t - \alpha \lambda\, \theta_{t-1}$  # Apply decoupled weight decay
Numerical Examples

To demonstrate the functionality of the AdamW algorithm, a straightforward numerical example is presented. This example utilizes small dimensions and simplified values to clearly illustrate the key calculations and steps involved in the algorithm.

Example Setup

Consider the following:

  • Initial parameter: $\theta_0 = 1.0$
  • Learning rate: $\alpha = 0.1$
  • Weight decay: $\lambda = 0.01$
  • First-moment decay rate: $\beta_1 = 0.9$
  • Second-moment decay rate: $\beta_2 = 0.999$
  • Small constant: $\epsilon = 10^{-8}$
  • Objective function gradient: $g_t = \nabla_\theta f(\theta_{t-1})$

For this example, assume we have a simple quadratic function:

$f(\theta) = \theta^2$

The gradient of this function is:

$g = \nabla_\theta f(\theta) = 2\theta$
Step-by-Step Calculation

Initialization

  • First moment estimate: $m_0 = 0$
  • Second moment estimate: $v_0 = 0$
  • Initial parameter: $\theta_0 = 1.0$

Iteration 1

  • Step 1: Compute Gradient:
    $g_1 = 2\theta_0 = 2 \cdot 1.0 = 2.0$
  • Step 2: Update First Moment Estimate:
    $m_1 = \beta_1 m_0 + (1 - \beta_1) g_1 = 0.9 \cdot 0 + 0.1 \cdot 2.0 = 0.2$
  • Step 3: Update Second Moment Estimate:
    $v_1 = \beta_2 v_0 + (1 - \beta_2) g_1^2 = 0.999 \cdot 0 + 0.001 \cdot 4.0 = 0.004$
  • Step 4: Bias Correction for First Moment:
    $\hat{m}_1 = m_1 / (1 - \beta_1^1) = 0.2 / 0.1 = 2.0$
  • Step 5: Bias Correction for Second Moment:
    $\hat{v}_1 = v_1 / (1 - \beta_2^1) = 0.004 / 0.001 = 4.0$
  • Step 6: Parameter Update with Weight Decay:
    • Gradient Update: $\alpha\, \hat{m}_1 / (\sqrt{\hat{v}_1} + \epsilon)$
    • Simplify the denominator: $\sqrt{4.0} + 10^{-8} \approx 2.0$
    • Compute the update: $0.1 \cdot 2.0 / 2.0 = 0.1$
    • Weight Decay: $\alpha \lambda \theta_0 = 0.1 \cdot 0.01 \cdot 1.0 = 0.001$
    • Updated Parameter: $\theta_1 = 1.0 - 0.1 - 0.001 = 0.899$

Iteration 2

  • Step 1: Compute Gradient:
    $g_2 = 2\theta_1 = 2 \cdot 0.899 = 1.798$
  • Step 2: Update First Moment Estimate:
    $m_2 = 0.9 \cdot 0.2 + 0.1 \cdot 1.798 = 0.3598$
  • Step 3: Update Second Moment Estimate:
    $v_2 = 0.999 \cdot 0.004 + 0.001 \cdot 1.798^2 \approx 0.007229$
  • Step 4: Bias Correction for First Moment:
    $\hat{m}_2 = 0.3598 / (1 - 0.9^2) = 0.3598 / 0.19 \approx 1.8937$
  • Step 5: Bias Correction for Second Moment:
    $\hat{v}_2 = 0.007229 / (1 - 0.999^2) \approx 0.007229 / 0.001999 \approx 3.6162$
  • Step 6: Parameter Update with Weight Decay:
    • Gradient Update: $\alpha\, \hat{m}_2 / (\sqrt{\hat{v}_2} + \epsilon)$
    • Simplify the denominator: $\sqrt{3.6162} + 10^{-8} \approx 1.9016$
    • Compute the update: $0.1 \cdot 1.8937 / 1.9016 \approx 0.0996$
    • Weight Decay: $\alpha \lambda \theta_1 = 0.1 \cdot 0.01 \cdot 0.899 \approx 0.0009$
    • Updated Parameter: $\theta_2 = 0.899 - 0.0996 - 0.0009 \approx 0.7985$
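A short script can verify this kind of hand calculation (plain Python; the hyperparameter values below are the illustrative ones assumed here: $\theta_0 = 1.0$, $\alpha = 0.1$, $\lambda = 0.01$, $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$, with $f(\theta) = \theta^2$):

```python
import math

# Illustrative hyperparameters assumed for this worked example.
theta, m, v = 1.0, 0.0, 0.0
lr, wd, b1, b2, eps = 0.1, 0.01, 0.9, 0.999, 1e-8

history = []
for t in (1, 2):
    g = 2.0 * theta                        # gradient of f(theta) = theta^2
    m = b1 * m + (1 - b1) * g              # first-moment estimate
    v = b2 * v + (1 - b2) * g * g          # second-moment estimate
    m_hat = m / (1 - b1 ** t)              # bias correction
    v_hat = v / (1 - b2 ** t)
    update = lr * m_hat / (math.sqrt(v_hat) + eps)
    decay = lr * wd * theta                # decoupled decay on theta_{t-1}
    theta = theta - update - decay
    history.append(theta)

print(history)   # theta_1 ≈ 0.8990, theta_2 ≈ 0.7985
```

Running more iterations continues the descent toward the minimum at $\theta = 0$, with the decoupled decay contributing a small pull toward zero at every step.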