AdamW
Author: Yufeng Hao (yh2295), Zhengdao Tang (zt278), Yixiao Tian (yt669), Yijie Zhang (yz3384), Zheng Zhou (zz875) (ChemE 6800 Fall 2024)
Stewards: Nathan Preuss, Wei-Han Chen, Tianqi Xiao, Guoqing Hu
Introduction
AdamW is an influential optimization algorithm in deep learning, developed as a modification to the Adam optimizer to decouple weight decay from gradient-based updates (Loshchilov & Hutter, 2017). This decoupling was introduced to address overfitting issues that often arise when using standard Adam, especially for large-scale neural network models.
By applying weight decay separately from the adaptive updates of parameters, AdamW achieves more effective regularization while retaining Adam’s strengths, such as adaptive learning rates and computational efficiency. This characteristic enables AdamW to achieve superior convergence and generalization compared to its predecessor, making it particularly advantageous for complex tasks involving large transformer-based architectures like BERT and GPT (Devlin et al., 2019; Brown et al., 2020).
As deep learning models grow in scale and complexity, AdamW has become a preferred optimizer due to its robust and stable convergence properties. Research has shown that AdamW can yield improved validation accuracy, faster convergence, and better generalization compared to both standard Adam and stochastic gradient descent (SGD) with momentum, especially in large-scale applications (Loshchilov & Hutter, 2017; Devlin et al., 2019; Dosovitskiy et al., 2021).
Algorithm Discussion
The standard Adam optimizer integrates weight decay by adding a term proportional to the parameters directly to the gradient, effectively acting as an L2 regularization term. This approach can interfere with Adam’s adaptive learning rates, leading to suboptimal convergence characteristics (Loshchilov & Hutter, 2017).
AdamW addresses this shortcoming by decoupling the weight decay step from the gradient-based parameter updates. Weight decay is applied after the parameter update is performed, preserving the integrity of the adaptive learning rate mechanism while maintaining effective regularization. This decoupling leads to more stable and predictable training dynamics, which is critical for large-scale models prone to overfitting (Loshchilov & Hutter, 2017).
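The contrast can be made concrete in code. The following is a minimal sketch (not a reference implementation) comparing an L2-style update, in which the decay term is folded into the gradient and therefore rescaled by the adaptive step, with the decoupled update used by AdamW, in which the decay is applied directly to the parameters. Here adam_step stands in for any adaptive transformation of the gradient and is an illustrative placeholder, not a library function.

<syntaxhighlight lang="python">
def adam_l2_update(theta, grad, lr, weight_decay, adam_step):
    # Standard Adam with L2 regularization: the decay term is added to the
    # gradient, so it is also rescaled by the adaptive learning-rate mechanism.
    return theta - lr * adam_step(grad + weight_decay * theta)

def adamw_update(theta, grad, lr, weight_decay, adam_step):
    # AdamW: the adaptive step sees only the raw gradient; weight decay is
    # applied directly to the parameters, outside the adaptive scaling.
    return theta - lr * adam_step(grad) - lr * weight_decay * theta
</syntaxhighlight>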
Algorithm Steps
Given the parameters <math>\theta</math>, a learning rate <math>\alpha</math>, and weight decay <math>\lambda</math>, AdamW follows these steps:

- Initialize:
  - Initialize parameters <math>\theta_0</math>, the first-moment estimate <math>m_0 = 0</math>, and the second-moment estimate <math>v_0 = 0</math>.
  - Set hyperparameters:
    - <math>\alpha</math>: learning rate
    - <math>\beta_1</math>: exponential decay rate for the first moment
    - <math>\beta_2</math>: exponential decay rate for the second moment
    - <math>\epsilon</math>: small constant to avoid division by zero
- For each time step <math>t</math>:
  - Calculate the gradient of the objective function: <math>g_t = \nabla_{\theta_t} f(\theta_t)</math>
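For illustration, the following is a minimal NumPy sketch of one complete AdamW update, combining the gradient step above with the standard moment-estimate, bias-correction, and decoupled weight-decay formulas of Loshchilov & Hutter (2017). The function name and default hyperparameter values are illustrative rather than prescriptive.

<syntaxhighlight lang="python">
import numpy as np

def adamw_step(theta, grad, m, v, t, alpha=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=1e-2):
    """Perform a single AdamW update for a parameter vector theta."""
    # Exponentially decayed first- and second-moment estimates of the gradient.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    # Bias correction, since m and v are initialized at zero.
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # Adaptive gradient step plus decoupled weight decay on the parameters.
    theta = theta - alpha * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * theta)
    return theta, m, v

# Example: a few updates on a toy quadratic objective f(theta) = ||theta||^2 / 2.
theta = np.array([1.0, -2.0])
m = np.zeros_like(theta)
v = np.zeros_like(theta)
for t in range(1, 6):
    grad = theta  # gradient of the toy objective
    theta, m, v = adamw_step(theta, grad, m, v, t)
</syntaxhighlight>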