Home / Class / ApplyInputGradientsChannelsLastColMov Class — pytorch Architecture

ApplyInputGradientsChannelsLastColMov Class — pytorch Architecture

Architecture documentation for the ApplyInputGradientsChannelsLastColMov function template in group_norm_kernel.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/group_norm_kernel.cpp lines 1049–1095

template <typename T, typename PT, typename opmath_t>
inline std::enable_if_t<std::is_same_v<T, opmath_t>, void>
ApplyInputGradientsChannelsLastColMov(
  const T* dY_data,
  const T* X_data,
  T* dX_data,
  const PT* rstd,
  const PT* gamma,
  opmath_t c2,
  opmath_t c3,
  int64_t HxW,
  int64_t C,
  int64_t D) {
  const bool gamma_null = (gamma == nullptr);
  int64_t d = 0;
  auto K = vec::Vectorized<T>::size();
  for (; d < D / K * K; d += K) {
    auto c1 = vec::Vectorized<T>(*rstd) *
        (gamma_null ? vec::Vectorized<T>(1)
                    : vec::Vectorized<T>::loadu(gamma + d));
    for (const auto m : c10::irange(HxW)) {
      const T* X_ptr = X_data + m * C;
      const T* dY_ptr = dY_data + m * C;
      T* dX_ptr = dX_data + m * C;
      auto dy_vec = vec::Vectorized<T>::loadu(dY_ptr + d);
      auto x_vec = vec::Vectorized<T>::loadu(X_ptr + d);
      auto dx_vec = c1 * dy_vec +
        vec::Vectorized<T>(c2) * x_vec + vec::Vectorized<T>(c3);
      dx_vec.store(dX_ptr + d);
    }
  }
  if (D - d > 0) {
    auto c1 = vec::Vectorized<T>(*rstd) *
        (gamma_null ? vec::Vectorized<T>(1)
                    : vec::Vectorized<T>::loadu(gamma + d, D - d));
    for (const auto m : c10::irange(HxW)) {
      const T* X_ptr = X_data + m * C;
      const T* dY_ptr = dY_data + m * C;
      T* dX_ptr = dX_data + m * C;
    auto dy_vec = vec::Vectorized<T>::loadu(dY_ptr + d, D - d);
    auto x_vec = vec::Vectorized<T>::loadu(X_ptr + d, D - d);
    auto dx_vec = c1 * dy_vec +
      vec::Vectorized<T>(c2) * x_vec + vec::Vectorized<T>(c3);
    dx_vec.store(dX_ptr + d, D - d);
    }
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free