adam_math Function Template — pytorch Architecture
Architecture documentation for the adam_math function template in FusedAdamKernel.cpp from the pytorch codebase. adam_math performs a single fused Adam step over Half or BFloat16 tensors; the ADAM_MODE adam_mode non-type template parameter selects between classic Adam (L2-style) and AdamW (decoupled) weight decay at compile time. Low-precision values are widened to opmath_t (float) for the arithmetic, and the main loop is vectorized with at::vec::Vectorized.
Entity Profile
Source Code
aten/src/ATen/native/cpu/FusedAdamKernel.cpp lines 14–154
template <typename scalar_t, typename opmath_t, ADAM_MODE adam_mode>
std::enable_if_t<
    std::is_same_v<scalar_t, Half> || std::is_same_v<scalar_t, BFloat16>,
    void>
inline adam_math(
    scalar_t* param_ptr,
    scalar_t* exp_avg_ptr,
    scalar_t* exp_avg_sq_ptr,
    scalar_t* grad_ptr,
    scalar_t* max_exp_avg_sq_ptr,
    double lr,
    double bias_correction1,
    double bias_correction2,
    double exp_avg_grad_coefficient,
    double exp_avg_sq_grad_coefficient,
    double bias_correction2_sqrt,
    double eps,
    double weight_decay,
    double beta2,
    bool amsgrad,
    bool maximize,
    const float* grad_scale_ptr,
    int64_t size
){
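  // The precomputed inputs are presumably the standard Adam derivations
  // (not visible in this excerpt): bias_correction1 = 1 - beta1^step,
  // bias_correction2 = 1 - beta2^step, bias_correction2_sqrt = sqrt(bias_correction2),
  // exp_avg_grad_coefficient = 1 - beta1, exp_avg_sq_grad_coefficient = 1 - beta2.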
  double step_size = lr / bias_correction1;
  using lpVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<opmath_t>;
  int64_t d = 0;
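  // Vectorized main loop: each iteration loads lpVec::size() low-precision
  // (Half/BFloat16) elements and widens them into two float vectors.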
  for (; d < size - (size % lpVec::size()); d += lpVec::size()) {
    lpVec param_lpvec = lpVec::loadu(param_ptr + d);
    auto [param_vec1, param_vec2] = vec::convert_to_float<scalar_t>(param_lpvec);
    lpVec grad_lpvec = lpVec::loadu(grad_ptr + d);
    auto [grad_vec1, grad_vec2] = vec::convert_to_float<scalar_t>(grad_lpvec);
    if (grad_scale_ptr) {
      grad_vec1 = grad_vec1 / fVec(float(*grad_scale_ptr));
      grad_vec2 = grad_vec2 / fVec(float(*grad_scale_ptr));
      lpVec grad_vec_to_store = vec::convert_from_float<scalar_t>(grad_vec1, grad_vec2);
      grad_vec_to_store.store(grad_ptr + d);
    }
    if (maximize){
      grad_vec1 = grad_vec1 * fVec(opmath_t(-1.0));
      grad_vec2 = grad_vec2 * fVec(opmath_t(-1.0));
    }
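    // Weight decay: ORIGINAL folds an L2 term into the gradient, while ADAMW
    // applies decoupled decay directly to the parameters.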
    if (weight_decay != 0.f){
      if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
        grad_vec1 += param_vec1 * fVec(opmath_t(weight_decay));
        grad_vec2 += param_vec2 * fVec(opmath_t(weight_decay));
      } else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
        param_vec1 = param_vec1 * fVec(opmath_t(1 - lr * weight_decay));
        param_vec2 = param_vec2 * fVec(opmath_t(1 - lr * weight_decay));
      }
    }
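    // First moment update. The mask/blendv below evaluates the lerp from
    // whichever endpoint is closer to the result (|weight| < 0.5 means the
    // result is nearer exp_avg), reducing floating-point rounding error;
    // the scalar tail loop mirrors the same trick.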
    lpVec exp_avg_lpvec = lpVec::loadu(exp_avg_ptr + d);
    auto [exp_avg_vec1, exp_avg_vec2] = vec::convert_to_float<scalar_t>(exp_avg_lpvec);
    // exp_avg.lerp_(grad, 1 - beta1)
    const fVec lerp_weight = fVec(opmath_t(exp_avg_grad_coefficient));
    auto mask = lerp_weight.abs() < fVec(0.5);
    auto coeff = fVec::blendv(lerp_weight - fVec(1), lerp_weight, mask);
    auto base1 = fVec::blendv(grad_vec1, exp_avg_vec1, mask);
    exp_avg_vec1 = vec::fmadd(coeff, grad_vec1 - exp_avg_vec1, base1);
    auto base2 = fVec::blendv(grad_vec2, exp_avg_vec2, mask);
    exp_avg_vec2 = vec::fmadd(coeff, grad_vec2 - exp_avg_vec2, base2);
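    // Second moment: exp_avg_sq = beta2 * exp_avg_sq
    //                           + exp_avg_sq_grad_coefficient * grad^2.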
    lpVec exp_avg_sq_lpvec = lpVec::loadu(exp_avg_sq_ptr + d);
    auto [exp_avg_sq_vec1, exp_avg_sq_vec2] = vec::convert_to_float<scalar_t>(exp_avg_sq_lpvec);
    exp_avg_sq_vec1 = exp_avg_sq_vec1 * fVec(opmath_t(beta2)) +
        fVec(opmath_t(exp_avg_sq_grad_coefficient)) * grad_vec1 * grad_vec1;
    exp_avg_sq_vec2 = exp_avg_sq_vec2 * fVec(opmath_t(beta2)) +
        fVec(opmath_t(exp_avg_sq_grad_coefficient)) * grad_vec2 * grad_vec2;
    vec::convert_from_float<scalar_t>(exp_avg_vec1, exp_avg_vec2).store(exp_avg_ptr + d);
    vec::convert_from_float<scalar_t>(exp_avg_sq_vec1, exp_avg_sq_vec2).store(exp_avg_sq_ptr + d);
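    // Denominator: sqrt(v) / bias_correction2_sqrt + eps; with AMSGrad the
    // running maximum of the second moment is used in place of v.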
    fVec denom_vec1, denom_vec2;
    if (amsgrad) {
      lpVec max_exp_avg_sq_lpvec = lpVec::loadu(max_exp_avg_sq_ptr + d);
      auto [max_exp_avg_sq_vec1, max_exp_avg_sq_vec2] = vec::convert_to_float<scalar_t>(max_exp_avg_sq_lpvec);
      max_exp_avg_sq_vec1 = maximum(max_exp_avg_sq_vec1, exp_avg_sq_vec1);
      max_exp_avg_sq_vec2 = maximum(max_exp_avg_sq_vec2, exp_avg_sq_vec2);
      vec::convert_from_float<scalar_t>(max_exp_avg_sq_vec1, max_exp_avg_sq_vec2).store(max_exp_avg_sq_ptr + d);
      denom_vec1 =
          (max_exp_avg_sq_vec1.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps));
      denom_vec2 =
          (max_exp_avg_sq_vec2.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps));
    } else {
      denom_vec1 =
          (exp_avg_sq_vec1.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps));
      denom_vec2 =
          (exp_avg_sq_vec2.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps));
    }
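    // Parameter step: param -= (lr / bias_correction1) * exp_avg / denom.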
    param_vec1 = param_vec1 + fVec(opmath_t(-step_size)) * exp_avg_vec1 / denom_vec1;
    param_vec2 = param_vec2 + fVec(opmath_t(-step_size)) * exp_avg_vec2 / denom_vec2;
    vec::convert_from_float<scalar_t>(param_vec1, param_vec2).store(param_ptr + d);
  }
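  // Scalar tail loop: handles the remaining size % lpVec::size() elements
  // with the same update rule, computed in opmath_t (float) precision.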
  for (; d < size; d++) {
    opmath_t grad_val = grad_ptr[d];
    opmath_t param_val = param_ptr[d];
    if (grad_scale_ptr) {
      grad_val = grad_ptr[d] / float(*grad_scale_ptr);
      grad_ptr[d] = grad_val;
    }
    if (maximize) grad_val = -grad_val;
    if (weight_decay != 0.f){
      if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
        grad_val += param_val * opmath_t(weight_decay);
      } else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
        param_val = param_val * opmath_t(1 - lr * weight_decay);
      }
    }
    // exp_avg.lerp_(grad, 1 - beta1)
    opmath_t exp_avg_var = exp_avg_ptr[d];
    auto is_lerp_weight_small = std::abs(opmath_t(exp_avg_grad_coefficient)) < opmath_t(0.5);
    if (is_lerp_weight_small) {
      exp_avg_var = exp_avg_var + opmath_t(exp_avg_grad_coefficient) * (grad_val - exp_avg_var);
    } else {
      exp_avg_var = grad_val - (grad_val - exp_avg_var) * (opmath_t(1) - opmath_t(exp_avg_grad_coefficient));
    }
    exp_avg_ptr[d] = scalar_t(exp_avg_var);
    opmath_t exp_avg_sq_var = exp_avg_sq_ptr[d];
    exp_avg_sq_var = exp_avg_sq_var * opmath_t(beta2);
    exp_avg_sq_var = exp_avg_sq_var +
        opmath_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
    exp_avg_sq_ptr[d] = scalar_t(exp_avg_sq_var);
    opmath_t demon_val;
    if (amsgrad) {
      opmath_t max_exp_avg_sq_var = max_exp_avg_sq_ptr[d];
      max_exp_avg_sq_var = std::max(max_exp_avg_sq_var, exp_avg_sq_var);
      max_exp_avg_sq_ptr[d] = scalar_t(max_exp_avg_sq_var);
      demon_val =
          std::sqrt(max_exp_avg_sq_var) / opmath_t(bias_correction2_sqrt) + opmath_t(eps);
    } else {
      demon_val = std::sqrt(exp_avg_sq_var) / opmath_t(bias_correction2_sqrt) + opmath_t(eps);
    }
    param_ptr[d] = param_val - opmath_t(step_size) * exp_avg_var / demon_val;
  }
}
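To make the numerics easier to follow, here is a minimal scalar sketch of the same update rule. It is a hypothetical standalone reference, not PyTorch code: it assumes the standard Adam derivations of the precomputed inputs (bias_correction1 = 1 - beta1^step, bias_correction2_sqrt = sqrt(1 - beta2^step)) and omits gradient unscaling, maximize, and AMSGrad for brevity.

#include <cmath>
#include <cstdint>

enum class AdamMode { Original, AdamW };

// One Adam/AdamW step over a flat double buffer (reference sketch).
void adam_step_reference(
    double* param, const double* grad, double* exp_avg, double* exp_avg_sq,
    int64_t size, int64_t step, double lr, double beta1, double beta2,
    double eps, double weight_decay, AdamMode mode) {
  const double bias_correction1 = 1.0 - std::pow(beta1, step);
  const double bias_correction2_sqrt = std::sqrt(1.0 - std::pow(beta2, step));
  const double step_size = lr / bias_correction1;
  for (int64_t d = 0; d < size; ++d) {
    double g = grad[d];
    if (weight_decay != 0.0) {
      if (mode == AdamMode::Original) {
        g += param[d] * weight_decay;         // L2 penalty folded into grad
      } else {
        param[d] *= 1.0 - lr * weight_decay;  // decoupled (AdamW) decay
      }
    }
    // exp_avg.lerp_(grad, 1 - beta1)
    exp_avg[d] += (1.0 - beta1) * (g - exp_avg[d]);
    exp_avg_sq[d] = beta2 * exp_avg_sq[d] + (1.0 - beta2) * g * g;
    const double denom =
        std::sqrt(exp_avg_sq[d]) / bias_correction2_sqrt + eps;
    param[d] -= step_size * exp_avg[d] / denom;
  }
}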