Home / Function / sgd_math — PyTorch Architecture

sgd_math Function Template — PyTorch Architecture

Architecture documentation for the sgd_math function template (constrained via std::enable_if_t and std::is_same_v to Half/BFloat16) in FusedSGDKernel.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/FusedSGDKernel.cpp lines 14–103

// Fused SGD step for reduced-precision parameters (Half / BFloat16 only; the
// enable_if constraint selects this overload). All arithmetic is carried out
// in opmath_t (float) precision: each low-precision vector is widened into two
// float vectors, updated, and narrowed back on store.
//
// param_ptr        : parameters, updated in place.
// grad_ptr         : gradients; when grad_scale_ptr is set, the unscaled
//                    gradient is written back in place for the caller.
// momentum_buf_ptr : momentum buffer stored as scalar_t, updated in place;
//                    only touched when momentum != 0. NOTE(review): assumed
//                    non-null whenever momentum != 0 — the caller must ensure
//                    this; it is dereferenced unconditionally in that branch.
// weight_decay / momentum / lr / dampening : standard SGD hyper-parameters.
// nesterov         : apply the Nesterov momentum correction.
// maximize         : negate gradients (gradient ascent).
// is_first_step    : seed the momentum buffer from the (processed) gradient.
// grad_scale_ptr   : optional AMP gradient scale; gradients are divided by it.
// size             : number of elements in each array.
template <typename scalar_t, typename opmath_t>
std::enable_if_t<
    std::is_same_v<scalar_t, Half> || std::is_same_v<scalar_t, BFloat16>,
    void>
    inline sgd_math(
  scalar_t* param_ptr,
  scalar_t* grad_ptr,
  scalar_t* momentum_buf_ptr,
  const double weight_decay,
  const double momentum,
  const double lr,
  const double dampening,
  const bool nesterov,
  const bool maximize,
  const bool is_first_step,
  const float* grad_scale_ptr,
  int64_t size
){
  using lpVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<opmath_t>;
  int64_t d = 0;
  // Vectorized main loop: one lpVec (= 2 fVec) worth of elements per step.
  for (; d < size - (size % lpVec::size()); d += lpVec::size()) {
    lpVec param_lpvec = lpVec::loadu(param_ptr + d);
    auto [param_vec1, param_vec2] = vec::convert_to_float<scalar_t>(param_lpvec);
    lpVec grad_lpvec = lpVec::loadu(grad_ptr + d);
    auto [grad_vec1, grad_vec2] = vec::convert_to_float<scalar_t>(grad_lpvec);
    if (grad_scale_ptr) {
      fVec grad_scale_vec(opmath_t(*grad_scale_ptr));
      grad_vec1 = grad_vec1 / grad_scale_vec;
      grad_vec2 = grad_vec2 / grad_scale_vec;
      // Persist the unscaled gradient so downstream consumers observe it,
      // matching the scalar tail loop below.
      vec::convert_from_float<scalar_t>(grad_vec1, grad_vec2).store(grad_ptr + d);
    }
    if (maximize){
      grad_vec1 = grad_vec1 * fVec(opmath_t(-1.0));
      grad_vec2 = grad_vec2 * fVec(opmath_t(-1.0));
    }
    if (weight_decay != 0.0){
      // BUGFIX: cast hyper-parameters through opmath_t, not scalar_t —
      // rounding a double through Half/BFloat16 before widening back to float
      // loses precision; the scalar tail loop already uses opmath_t.
      grad_vec1 = vec::fmadd(param_vec1, fVec(opmath_t(weight_decay)), grad_vec1);
      grad_vec2 = vec::fmadd(param_vec2, fVec(opmath_t(weight_decay)), grad_vec2);
    }
    if (momentum != 0.0) {
      fVec momentum_vec1, momentum_vec2;
      if (is_first_step) {
        momentum_vec1 = grad_vec1;
        momentum_vec2 = grad_vec2;
      } else {
        // BUGFIX: the momentum buffer holds scalar_t (see the store below and
        // the scalar tail loop), so load it as lpVec and widen to float. The
        // previous fVec::loadu reinterpreted Half/BFloat16 bits as float.
        auto [buf_vec1, buf_vec2] =
            vec::convert_to_float<scalar_t>(lpVec::loadu(momentum_buf_ptr + d));
        momentum_vec1 = buf_vec1 * fVec(opmath_t(momentum));
        momentum_vec2 = buf_vec2 * fVec(opmath_t(momentum));
        momentum_vec1 = vec::fmadd(fVec(opmath_t(1 - dampening)), grad_vec1, momentum_vec1);
        momentum_vec2 = vec::fmadd(fVec(opmath_t(1 - dampening)), grad_vec2, momentum_vec2);
      }
      vec::convert_from_float<scalar_t>(momentum_vec1, momentum_vec2).store(momentum_buf_ptr + d);
      if (nesterov) {
        grad_vec1 = vec::fmadd(momentum_vec1, fVec(opmath_t(momentum)), grad_vec1);
        grad_vec2 = vec::fmadd(momentum_vec2, fVec(opmath_t(momentum)), grad_vec2);
      } else {
        grad_vec1 = momentum_vec1;
        grad_vec2 = momentum_vec2;
      }
    }
    // BUGFIX: apply the actual parameter update (param -= lr * grad) and write
    // it back — the vectorized loop previously never stored updated parameters,
    // leaving every SIMD-covered element unmodified (only the scalar tail
    // performed the update).
    param_vec1 = param_vec1 - grad_vec1 * fVec(opmath_t(lr));
    param_vec2 = param_vec2 - grad_vec2 * fVec(opmath_t(lr));
    vec::convert_from_float<scalar_t>(param_vec1, param_vec2).store(param_ptr + d);
  }
  // Scalar tail for the remaining size % lpVec::size() elements; mirrors the
  // vectorized loop exactly, one element at a time in opmath_t precision.
  for (; d < size; d++) {
    opmath_t grad_val = grad_ptr[d];
    opmath_t param_val = param_ptr[d];
    if (grad_scale_ptr) {
      grad_val = grad_ptr[d] / opmath_t(*grad_scale_ptr);
      grad_ptr[d] = grad_val;  // persist the unscaled gradient
    }
    if (maximize) grad_val = -grad_val;
    if (weight_decay != 0.0){
      grad_val += param_val * opmath_t(weight_decay);
    }
    if (momentum != 0.0) {
      opmath_t momentum_buf_var = momentum_buf_ptr[d];
      if (is_first_step) {
        momentum_buf_var = grad_val;
      } else {
        // buf = buf * momentum + grad * (1 - dampening)
        momentum_buf_var = momentum_buf_var * opmath_t(momentum) +
            grad_val * opmath_t(1 - dampening);
      }
      momentum_buf_ptr[d] = momentum_buf_var;
      if (nesterov) {
        grad_val += momentum_buf_var * opmath_t(momentum);
      } else {
        grad_val = momentum_buf_var;
      }
    }
    param_ptr[d] = param_val - grad_val * opmath_t(lr);
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free