Adamax

From Cornell University Computational Optimization Open Textbook - Optimization Wiki
Revision as of 02:04, 15 December 2024 by Fall2024 Team13

Author: Chengcong Xu (cx253), Jessica Liu (hl2482), Xiaolin Bu (xb58), Qiaoyue Ye (qy252), Haoru Feng (hf352) (ChemE 6800 Fall 2024)

Stewards: Nathan Preuss, Wei-Han Chen, Tianqi Xiao, Guoqing Hu

Introduction

Adamax is an optimization algorithm introduced by Kingma and Ba in their Adam optimizer paper (2014). It modifies Adam by replacing the second-moment term, which scales each update by the root mean square (an $L^2$ norm) of past gradients, with an exponentially weighted infinity norm ($L^\infty$). This change is intended to make Adamax more robust and numerically stable, especially with sparse gradients, noisy updates, or gradients whose magnitudes vary widely.

Adamax dynamically adjusts learning rates for individual parameters, making it well-suited for training deep neural networks, large-scale machine learning models, and tasks involving high-dimensional parameter spaces.

Algorithm Discussion

The Adamax optimizer, a variant of Adam, adapts the step size of each parameter using two quantities. The first is the first moment estimate, an exponential moving average of past gradients. The second is the infinity norm of past gradients, a maximum of their absolute values in which older values are discounted. It is designed to be robust to sparse gradients and stable when gradient magnitudes vary. Replacing Adam's second-moment estimate with the infinity norm removes the squaring and square root from the update, and the denominator no longer needs bias correction. Per-parameter adaptive learning rates, the core benefit of Adam, are retained.
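Kingma and Ba motivate the infinity-norm update by generalizing Adam's $L^2$-based second moment to an $L^p$ norm and letting $p \to \infty$. The derivation below is sketched in their notation and shows where the recursion for $u_t$ comes from:

```latex
\begin{aligned}
v_t &= \beta_2^{p}\, v_{t-1} + (1 - \beta_2^{p})\, |g_t|^{p}
     = (1 - \beta_2^{p}) \sum_{i=1}^{t} \beta_2^{p(t-i)}\, |g_i|^{p}, \\
u_t &= \lim_{p \to \infty} (v_t)^{1/p}
     = \lim_{p \to \infty} \Big( (1 - \beta_2^{p})^{1/p} \Big( \sum_{i=1}^{t} \beta_2^{p(t-i)}\, |g_i|^{p} \Big)^{1/p} \Big) \\
    &= \max\big( \beta_2^{t-1} |g_1|,\ \beta_2^{t-2} |g_2|,\ \dots,\ \beta_2 |g_{t-1}|,\ |g_t| \big)
\end{aligned}
```

This maximum satisfies the recursion $u_t = \max(\beta_2 u_{t-1}, |g_t|)$ with $u_0 = 0$. In the limit, the factor $(1 - \beta_2^p)^{1/p}$ tends to 1. For that reason $u_t$, unlike Adam's second moment, needs no bias correction.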

Given parameters $\theta$ of an objective $J(\theta)$, a learning rate $\alpha$, and decay rates $\beta_1$ and $\beta_2$, Adamax follows these steps:

Initialize

  • Initialize parameters $\theta_0$, the first-moment estimate $m_0 = 0$, and the exponentially weighted infinity norm $u_0 = 0$.
  • Set hyperparameters:
      $\alpha$: Learning rate
      $\beta_1$: Exponential decay rate for the first moment
      $\beta_2$: Exponential decay rate for the infinity norm
      $\epsilon$: Small constant to avoid division by zero
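A minimal Python sketch of the initialization step follows. It is not a reference implementation. The names `AdamaxConfig`, `AdamaxState`, and `init_state` are hypothetical, and parameters are stored as plain Python lists. The defaults $\alpha = 0.002$, $\beta_1 = 0.9$, $\beta_2 = 0.999$ are the settings Kingma and Ba suggest. $\epsilon$ does not appear in their Adamax algorithm. The value $10^{-8}$ is an assumption that matches a common implementation choice (it is, for example, PyTorch's default for `torch.optim.Adamax`).

```python
from dataclasses import dataclass
from typing import List


@dataclass
class AdamaxConfig:
    # alpha, beta1, beta2 defaults follow Kingma and Ba's suggested settings.
    # eps is an assumption: a safeguard added by implementations, not in the paper.
    alpha: float = 0.002
    beta1: float = 0.9
    beta2: float = 0.999
    eps: float = 1e-8


@dataclass
class AdamaxState:
    m: List[float]  # first-moment estimate, one entry per parameter
    u: List[float]  # exponentially weighted infinity norm, one entry per parameter
    t: int = 0      # number of updates taken so far


def init_state(theta0: List[float]) -> AdamaxState:
    """Create zero-initialized Adamax state (m_0 = 0, u_0 = 0) sized like theta0."""
    n = len(theta0)
    return AdamaxState(m=[0.0] * n, u=[0.0] * n, t=0)


cfg = AdamaxConfig(alpha=0.1)   # override the learning rate, keep the other defaults
state = init_state([2.0])       # one scalar parameter starting at 2.0
print(cfg, state)
```

Keeping the hyperparameters separate from the per-parameter state mirrors the two bullets above: the configuration is fixed for a run, while `m`, `u`, and `t` change at every step.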

For each time step $t = 1, 2, \dots$ (all operations act element-wise on the parameter vector)

  • Compute gradient: $g_t = \nabla_{\theta} J(\theta_{t-1})$
  • Update first moment estimate: $m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$
  • Update infinity norm: $u_t = \max(\beta_2 u_{t-1}, |g_t|)$
  • Bias correction for the first moment: $\hat{m}_t = \frac{m_t}{1 - \beta_1^t}$
  • Parameter update: $\theta_t = \theta_{t-1} - \alpha \frac{\hat{m}_t}{u_t + \epsilon}$
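One time step of the update can be written as a pure-Python function that acts element-wise on lists. The name `adamax_step` and its signature are hypothetical. A library implementation would typically vectorize these operations with arrays.

```python
from typing import List, Tuple


def adamax_step(
    theta: List[float],
    grad: List[float],
    m: List[float],
    u: List[float],
    t: int,
    alpha: float = 0.002,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
) -> Tuple[List[float], List[float], List[float]]:
    """Apply one Adamax update at time step t (t >= 1), element-wise.

    Returns the new (theta, m, u); the input lists are not modified.
    """
    new_theta, new_m, new_u = [], [], []
    bias_correction = 1.0 - beta1 ** t  # 1 - beta1^t
    for th, g, m_prev, u_prev in zip(theta, grad, m, u):
        m_t = beta1 * m_prev + (1.0 - beta1) * g   # first moment estimate
        u_t = max(beta2 * u_prev, abs(g))          # infinity norm
        m_hat = m_t / bias_correction              # bias-corrected first moment
        new_theta.append(th - alpha * m_hat / (u_t + eps))
        new_m.append(m_t)
        new_u.append(u_t)
    return new_theta, new_m, new_u


theta, m, u = adamax_step([2.0, 5.0], [4.0, 0.0], [0.0, 0.0], [0.0, 0.0], t=1, alpha=0.1)
print(theta, m, u)
```

In this call the first coordinate moves from 2.0 to about 1.9. The second coordinate stays at 5.0: its gradient, $m_t$, and $u_t$ are all zero, so $\epsilon$ keeps the division defined and the step is zero.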

Pseudocode for Adamax

For $t = 1$ to $T$:

  $g_t = \nabla_{\theta} J(\theta_{t-1})$
  $m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$
  $u_t = \max(\beta_2 u_{t-1}, |g_t|)$
  $\hat{m}_t = \frac{m_t}{1 - \beta_1^t}$
  $\theta_t = \theta_{t-1} - \alpha \frac{\hat{m}_t}{u_t + \epsilon}$
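The pseudocode can be turned into a short, runnable Python loop. This is a sketch under stated assumptions. The function name `adamax` and its `grad_fn` callback are hypothetical, parameters are plain Python lists, and the loop records every iterate so the trajectory can be inspected.

```python
from typing import Callable, List


def adamax(
    grad_fn: Callable[[List[float]], List[float]],
    theta0: List[float],
    T: int,
    alpha: float = 0.002,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
) -> List[List[float]]:
    """Run T Adamax iterations and return the trajectory [theta_0, ..., theta_T].

    grad_fn(theta) must return the gradient of the objective J at theta.
    """
    theta = list(theta0)
    m = [0.0] * len(theta)  # m_0 = 0
    u = [0.0] * len(theta)  # u_0 = 0
    trajectory = [list(theta)]
    for t in range(1, T + 1):
        g = grad_fn(theta)                                            # g_t = grad J(theta_{t-1})
        m = [beta1 * mi + (1.0 - beta1) * gi for mi, gi in zip(m, g)]  # first moment
        u = [max(beta2 * ui, abs(gi)) for ui, gi in zip(u, g)]         # infinity norm
        m_hat = [mi / (1.0 - beta1 ** t) for mi in m]                  # bias correction
        theta = [th - alpha * mh / (ui + eps) for th, mh, ui in zip(theta, m_hat, u)]
        trajectory.append(list(theta))
    return trajectory


# f(x) = x^2 has gradient 2x; start at x = 2.0 with alpha = 0.1
traj = adamax(lambda x: [2.0 * x[0]], [2.0], T=2, alpha=0.1)
print(traj)
```

For the quadratic used in the numerical example, this call returns iterates of about 2.0, 1.9, and 1.8025.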

Numerical Examples

To illustrate the Adamax optimization algorithm, we will minimize the quadratic function $f(x) = x^2$ with step-by-step calculations.

Problem Setup

  • Optimization objective: minimize $f(x) = x^2$, which attains its minimum $f(x) = 0$ at $x = 0$.
  • Initial parameter: start with $x_0 = 2.0$.
  • Gradient formula: $g_t = f'(x_{t-1}) = 2x_{t-1}$, evaluated at the previous iterate. Its sign determines the direction of the parameter change.
  • Hyperparameters:
      Learning rate: $\alpha = 0.1$, controls the step size.
      First moment decay rate: $\beta_1 = 0.9$, determines how strongly past gradients influence the current gradient estimate.
      Infinity norm decay rate: $\beta_2 = 0.999$, governs how quickly the infinity norm used for scaling updates decays.
      Numerical stability constant: $\epsilon = 10^{-8}$, prevents division by zero.
  • Initialization: $m_0 = 0$, $u_0 = 0$, $t = 0$

Step-by-Step Calculations

Iteration 1

$t = 1$

  • Gradient calculation: $g_1 = 2x_0 = 2 \cdot 2.0 = 4.0$

The gradient points in the direction of steepest increase of $f(x)$, so the update moves against it. Here the gradient is positive, so $x$ must decrease from $x_0$ to reduce the function.

  • First moment update: $m_1 = \beta_1 m_0 + (1 - \beta_1) g_1 = 0.9 \cdot 0 + 0.1 \cdot 4.0 = 0.4$

The first moment $m_1$ is an exponentially weighted moving average of the gradients, which smooths out fluctuations. Because it starts from $m_0 = 0$, it is biased toward zero in early iterations.

  • Infinity norm update: $u_1 = \max(\beta_2 u_0, |g_1|) = \max(0.999 \cdot 0, 4.0) = 4.0$

The infinity norm $u_1$ tracks the largest recent gradient magnitude, with older magnitudes discounted by a factor of $\beta_2$ per step. Dividing by it normalizes the update, which keeps step sizes stable.

  • Bias-corrected learning rate: $\hat{\alpha} = \frac{\alpha}{1 - \beta_1^t} = \frac{0.1}{1 - 0.9^1} = 1.0$

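Dividing $\alpha$ by $1 - \beta_1^t$ compensates for the bias of $m_t$ toward zero caused by initializing $m_0 = 0$. This is equivalent to using the bias-corrected moment $\hat{m}_t$ in the update rule, since $\hat{\alpha} m_t = \alpha \hat{m}_t$.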

  • Parameter update: $x_1 = x_0 - \frac{\hat{\alpha} \cdot m_1}{u_1 + \epsilon} = 2.0 - \frac{1.0 \cdot 0.4}{4.0 + 10^{-8}} \approx 1.9$

The parameter moves closer to the function's minimum at $x = 0$.
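The iteration-1 arithmetic can be checked with a few lines of Python. The script also confirms that scaling the learning rate by $1/(1 - \beta_1^t)$ gives the same step as using the bias-corrected moment $\hat{m}_t$ from the algorithm description. The variable names are illustrative only.

```python
# Iteration 1 of the worked example: f(x) = x**2, starting from x0 = 2.0
alpha, beta1, beta2, eps = 0.1, 0.9, 0.999, 1e-8
x0, m0, u0, t = 2.0, 0.0, 0.0, 1

g1 = 2.0 * x0                              # gradient of x**2 at x0
m1 = beta1 * m0 + (1.0 - beta1) * g1       # first moment update
u1 = max(beta2 * u0, abs(g1))              # infinity norm update
alpha_hat = alpha / (1.0 - beta1 ** t)     # bias-corrected learning rate
x1 = x0 - alpha_hat * m1 / (u1 + eps)      # parameter update

# The same step written with the bias-corrected moment m_hat = m1 / (1 - beta1**t)
m_hat = m1 / (1.0 - beta1 ** t)
x1_alt = x0 - alpha * m_hat / (u1 + eps)

print(g1, round(m1, 6), u1, round(alpha_hat, 6), round(x1, 6), round(x1_alt, 6))
```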

Iteration 2

$t = 2$

  • Gradient calculation: $g_2 = 2x_1 = 2 \cdot 1.9 = 3.8$
  • First moment update: $m_2 = \beta_1 m_1 + (1 - \beta_1) g_2 = 0.9 \cdot 0.4 + 0.1 \cdot 3.8 = 0.74$
  • Infinity norm update: $u_2 = \max(\beta_2 u_1, |g_2|) = \max(0.999 \cdot 4.0, 3.8) = 3.996$ (the decayed previous maximum is larger than $|g_2|$, so it is kept)
  • Bias-corrected learning rate: $\hat{\alpha} = \frac{\alpha}{1 - \beta_1^t} = \frac{0.1}{1 - 0.9^2} \approx 0.526$
  • Parameter update: $x_2 = x_1 - \frac{\hat{\alpha} \cdot m_2}{u_2 + \epsilon} = 1.9 - \frac{0.526 \cdot 0.74}{3.996 + 10^{-8}} \approx 1.803$

The parameter continues to approach the minimum at $x = 0$.
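The same kind of check works for iteration 2, starting from the iteration-1 values $x_1 = 1.9$, $m_1 = 0.4$, $u_1 = 4.0$. Running it gives $m_2 = 0.74$, $u_2 = 3.996$, and $x_2 \approx 1.8025$. As before, the variable names are illustrative.

```python
# Iteration 2 of the worked example, starting from the iteration-1 results
alpha, beta1, beta2, eps = 0.1, 0.9, 0.999, 1e-8
x1, m1, u1, t = 1.9, 0.4, 4.0, 2

g2 = 2.0 * x1                              # gradient of x**2 at x1
m2 = beta1 * m1 + (1.0 - beta1) * g2       # 0.9*0.4 + 0.1*3.8
u2 = max(beta2 * u1, abs(g2))              # decayed max 0.999*4.0 exceeds |g2| = 3.8
alpha_hat = alpha / (1.0 - beta1 ** t)     # 0.1 / (1 - 0.81)
x2 = x1 - alpha_hat * m2 / (u2 + eps)      # parameter update

print(f"g2={g2:.4f} m2={m2:.4f} u2={u2:.4f} alpha_hat={alpha_hat:.4f} x2={x2:.4f}")
```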

Summary

Over these two iterations, Adamax moves the parameter $x$ from 2.0 to about 1.803, toward the minimum at $x = 0$. Each step is close to the learning rate $\alpha = 0.1$ (0.1, then about 0.097). This is because, while the gradient keeps the same sign, the bias-corrected first moment and the infinity norm have similar magnitudes. Normalizing by the infinity norm therefore keeps the step size bounded rather than proportional to the raw gradient.

Applications

Natural Language Processing

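Adamax can be used to train NLP models, including transformer-based architectures such as BERT and GPT, although those models were originally trained with Adam. Embedding layers produce sparse gradients, because only the rows for tokens in the current batch receive nonzero gradients. Adamax's handling of such gradients makes it a reasonable choice for tasks such as text classification, machine translation, and named entity recognition.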

Computer Vision

In computer vision, Adamax can be used to train deep CNNs, such as ResNet and DenseNet, for tasks like image classification and object detection.

Reinforcement Learning

Adamax has been applied in training reinforcement learning agents, particularly in environments where gradient updates are inconsistent or noisy, such as robotic control and policy optimization.

Generative Models

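For generative models, including GANs and VAEs, Adamax offers an alternative to Adam. GAN training updates the generator and discriminator against each other, which can make gradients noisy. There, Adamax's bounded update size can help keep optimization stable.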

Time-Series Forecasting

Adamax can be applied to time-series forecasting, for example in finance and economics, where gradients are often noisy. Its bounded, per-parameter steps are intended to keep training stable in that setting.

Advantages over Other Approaches

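  • Stability: Each parameter's step is normalized by a decayed maximum of its past gradient magnitudes, so the update size stays roughly bounded by $\alpha$ even when gradients vary sharply.
  • Sparse gradient handling: Zero or near-zero gradients are common in NLP tasks, for example for the embedding rows of tokens absent from a batch. While a parameter's gradient is zero, $u_t$ keeps a memory of its earlier gradient magnitudes, decaying only by the factor $\beta_2$ per step, so the next nonzero gradient produces a well-scaled step.
  • Efficiency: Like Adam, Adamax keeps two state values per parameter ($m_t$ and $u_t$) and uses only element-wise operations. The max replaces Adam's squaring and square root, and $u_t$ needs no bias correction.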
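To make the stability point concrete, the sketch below feeds a scalar Adamax update a hypothetical stream of Gaussian gradients with rare spikes 1000 times larger, and records the magnitude of each step. A spike raises $u_t$ at the same moment it enters $m_t$, so the step stays at roughly $\alpha$. This is consistent with Kingma and Ba's observation that Adamax's update magnitude is bounded by about $\alpha$. With $\beta_1 = 0.9$ and $\beta_2 = 0.999$, the ratio $|\hat{m}_t| / u_t$ cannot exceed about 1.009. The gradient stream and the helper name `adamax_step_sizes` are illustrative assumptions.

```python
import random


def adamax_step_sizes(grads, alpha=0.002, beta1=0.9, beta2=0.999, eps=1e-8):
    """Return |theta_t - theta_{t-1}| for a scalar parameter fed the given gradients."""
    m, u, sizes = 0.0, 0.0, []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1.0 - beta1) * g          # first moment
        u = max(beta2 * u, abs(g))                 # infinity norm
        m_hat = m / (1.0 - beta1 ** t)             # bias-corrected first moment
        sizes.append(abs(alpha * m_hat / (u + eps)))
    return sizes


# Hypothetical noisy gradient stream: unit Gaussian noise with ~5% spikes scaled by 1000
rng = random.Random(0)
grads = [rng.gauss(0.0, 1.0) * (1000.0 if rng.random() < 0.05 else 1.0) for _ in range(500)]
sizes = adamax_step_sizes(grads)
print(f"largest |gradient|: {max(abs(g) for g in grads):.1f}")
print(f"largest step / alpha: {max(sizes) / 0.002:.4f}")
```

The largest gradient in this stream is far bigger than typical values, yet the largest step stays within about 1% of $\alpha$.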

Conclusion

Adamax is a robust and efficient variant of the Adam optimizer that replaces the RMS-based second moment with the infinity norm. Its handling of sparse gradients, noisy updates, and large parameter spaces makes it a useful alternative to Adam in natural language processing, computer vision, reinforcement learning, and generative modeling.

Future advancements may involve integrating Adamax with learning rate schedules and regularization techniques to further enhance its performance.

References