weight_norm_backward_first_dim_kernel Class — pytorch Architecture

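Architecture documentation for the weight_norm_backward_first_dim_kernel function template in WeightNormKernel.cpp from the pytorch codebase. Although it is indexed here as a class, it is a free function template: for each of the M rows of an M x N layout it computes the weight-normalization gradients grad_g and grad_v, with rows distributed across threads by at::parallel_for.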

Source Code

aten/src/ATen/native/cpu/WeightNormKernel.cpp lines 181–228

template <typename scalar_t, typename accscalar_t>
void weight_norm_backward_first_dim_kernel(
    TensorBase& grad_v,
    TensorBase& grad_g,
    const TensorBase& grad_w,
    const TensorBase& saved_v,
    const TensorBase& saved_g,
    const TensorBase& saved_norm,
    int64_t M, int64_t N) {
  const auto grad_w_data = grad_w.data_ptr<scalar_t>();
  const auto saved_v_data = saved_v.data_ptr<scalar_t>();
  const auto saved_g_data = saved_g.data_ptr<scalar_t>();
  const auto saved_norm_data = saved_norm.data_ptr<accscalar_t>();
  auto grad_v_data = grad_v.data_ptr<scalar_t>();
  auto grad_g_data = grad_g.data_ptr<scalar_t>();

  using Vec = vec::Vectorized<accscalar_t>;
  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      accscalar_t per_dim_sum_val = vec::map2_reduce_all<scalar_t>(
          [](Vec grad_w, Vec saved_v) { return grad_w * saved_v; },
          [](Vec x, Vec y) { return x + y; },
          grad_w_data + i * N,
          saved_v_data + i * N,
          N);

      accscalar_t saved_norm_val = saved_norm_data[i];
      accscalar_t saved_g_val = accscalar_t(saved_g_data[i]);
      accscalar_t grad_g_val = per_dim_sum_val / saved_norm_val;

      // grad_g = sum / norm
      // grad_v = (g / norm) * (grad_w - v * (sum / norm^2))
      //  let a = g / norm
      //      b = a * grad_g / norm
      // grad_v = a * grad_w - b * v
      grad_g_data[i] = scalar_t(grad_g_val);
      accscalar_t a = saved_g_val / saved_norm_val;
      accscalar_t b = a * grad_g_val / saved_norm_val;

      vec::map2(
          [a, b](Vec grad_w, Vec v) { return Vec(a) * grad_w - Vec(b) * v; },
          grad_v_data + i * N,
          grad_w_data + i * N,
          saved_v_data + i * N,
          N);
    }
  });
}
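
The per-row arithmetic in the kernel above can be checked against a plain scalar version. The following is a minimal sketch, not the ATen implementation: the name weight_norm_backward_first_dim_ref is hypothetical, std::vector<double> stands in for TensorBase, and vectorization and at::parallel_for are left out. It assumes row-major M x N storage and that norm[i] already holds the L2 norm of row i of v, as saved_norm does in the listing.

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical scalar reference for the per-row math in the kernel above.
// grad_w, v, grad_v: row-major M x N buffers. g, norm, grad_g: length M.
// Assumes norm[i] is the precomputed L2 norm of row i of v.
void weight_norm_backward_first_dim_ref(
    std::vector<double>& grad_v,
    std::vector<double>& grad_g,
    const std::vector<double>& grad_w,
    const std::vector<double>& v,
    const std::vector<double>& g,
    const std::vector<double>& norm,
    int64_t M, int64_t N) {
  for (int64_t i = 0; i < M; ++i) {
    // sum = <grad_w[i, :], v[i, :]>, the reduction done by map2_reduce_all.
    double sum = 0.0;
    for (int64_t j = 0; j < N; ++j) {
      sum += grad_w[i * N + j] * v[i * N + j];
    }

    // grad_g = sum / norm
    const double grad_g_val = sum / norm[i];
    grad_g[i] = grad_g_val;

    // grad_v = a * grad_w - b * v, with a = g / norm and b = a * grad_g / norm.
    const double a = g[i] / norm[i];
    const double b = a * grad_g_val / norm[i];
    for (int64_t j = 0; j < N; ++j) {
      grad_v[i * N + j] = a * grad_w[i * N + j] - b * v[i * N + j];
    }
  }
}
```

For a single row v = (3, 4) with norm 5, g = 2 and grad_w = (1, 1), the dot product is 7, so grad_g = 1.4, a = 0.4 and b = 0.112, giving grad_v = (0.064, -0.048). The result is orthogonal to v (3 * 0.064 + 4 * (-0.048) = 0), which is expected because grad_v is grad_w with its component along v removed, scaled by g / norm.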
