Home / Class / weight_norm_backward_last_dim_kernel Class — pytorch Architecture

weight_norm_backward_last_dim_kernel Class — pytorch Architecture

Architecture documentation for weight_norm_backward_last_dim_kernel, a function template defined in WeightNormKernel.cpp in the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/WeightNormKernel.cpp lines 324–400

/// Computes `grad_v` and `grad_g` from `grad_w` and the tensors saved by the
/// forward pass, addressing `grad_v`, `grad_w` and `saved_v` as dense
/// row-major [M, N] buffers (row `i` starts at offset `i * N`) and
/// `grad_g`, `saved_g` and `saved_norm` as N-element buffers (one entry per
/// column).
///
/// @tparam scalar_t     element type of the gradient / weight tensors
/// @tparam accscalar_t  accumulation type; also the element type read from
///                      `saved_norm` and used for the scratch buffer
/// @param grad_v      output, written row by row via `apply_per_row_backward`
/// @param grad_g      output, N values: column sum divided by the saved norm
/// @param grad_w      incoming gradient, read as M rows of N values
/// @param saved_v     saved `v`, read as M rows of N values
/// @param saved_g     saved `g`, N values
/// @param saved_norm  saved norms, N values of `accscalar_t`
/// @param M           number of rows
/// @param N           number of columns
template <typename scalar_t, typename accscalar_t>
void weight_norm_backward_last_dim_kernel(
    TensorBase& grad_v,
    TensorBase& grad_g,
    const TensorBase& grad_w,
    const TensorBase& saved_v,
    const TensorBase& saved_g,
    const TensorBase& saved_norm,
    int64_t M, int64_t N) {
  // Raw pointers to the read-only inputs ...
  scalar_t* const grad_w_ptr = grad_w.data_ptr<scalar_t>();
  scalar_t* const saved_v_ptr = saved_v.data_ptr<scalar_t>();
  scalar_t* const saved_g_ptr = saved_g.data_ptr<scalar_t>();
  accscalar_t* const saved_norm_ptr = saved_norm.data_ptr<accscalar_t>();
  // ... and to the two outputs.
  scalar_t* grad_v_ptr = grad_v.data_ptr<scalar_t>();
  scalar_t* grad_g_ptr = grad_g.data_ptr<scalar_t>();

  // One zero-initialised scratch tensor serves two purposes:
  //   (1) a per-thread accumulator row for the reduction over rows,
  //       i.e. [M, N] -> [num_threads, N];
  //   (2) afterwards, rows 0..2 hold the column sums and the two
  //       coefficient vectors `a` and `b`.
  // Purpose (2) is why the row count is clamped to a minimum of 3.
  const int num_threads = at::get_num_threads();
  const int scratch_rows = std::max(3, num_threads);
  TensorBase scratch = at::detail::empty_cpu({scratch_rows, N}, saved_norm.options()).zero_();
  accscalar_t* const scratch_ptr = scratch.data_ptr<accscalar_t>();

  // Phase 1: every worker accumulates the rows it was handed into the
  // scratch row selected by its own thread id, so no two threads share an
  // accumulator row. `sum_product_per_row` is defined elsewhere; presumably
  // it adds the elementwise product of its two inputs into the first
  // argument -- confirm against its definition.
  at::parallel_for(0, M, 1, [&](int64_t row_begin, int64_t row_end) {
    const int tid = at::get_thread_num();
    TORCH_CHECK(tid < num_threads, "expect thread id smaller than ", num_threads, ", got thread id ", tid);
    accscalar_t* partial = scratch_ptr + tid * N;
    for (const auto row : c10::irange(row_begin, row_end)) {
      const int64_t offset = row * N;
      sum_product_per_row(partial, grad_w_ptr + offset, saved_v_ptr + offset, N);
    }
  });

  // Phase 2: fold the per-thread rows together, column by column, and leave
  // the total in row 0. Row 0 is read (t == 0) before it is overwritten.
  for (const auto col : c10::irange(N)) {
    accscalar_t total = 0;
    for (const auto t : c10::irange(num_threads)) {
      total += scratch_ptr[t * N + col];
    }
    scratch_ptr[col] = total;
  }

  // From here on the scratch rows are reinterpreted:
  //   row 0 -> column sums, row 1 -> coefficient a, row 2 -> coefficient b.
  // Rows 1 and 2 may still contain per-thread partial sums; those were
  // already consumed by phase 2 and are simply overwritten below.
  accscalar_t* col_sum = scratch_ptr;
  accscalar_t* coef_a = scratch_ptr + N;
  accscalar_t* coef_b = scratch_ptr + 2 * N;

  // Phase 3, per column:
  //   grad_g = sum / norm
  //   a      = g / norm
  //   b      = a * grad_g / norm
  for (const auto col : c10::irange(N)) {
    const accscalar_t norm = saved_norm_ptr[col];
    const accscalar_t g = static_cast<accscalar_t>(saved_g_ptr[col]);
    const accscalar_t dg = col_sum[col] / norm;
    grad_g_ptr[col] = static_cast<scalar_t>(dg);

    const accscalar_t scale = g / norm;
    coef_a[col] = scale;
    coef_b[col] = scale * dg / norm;
  }

  // Phase 4: produce grad_v one row at a time; the helper (defined
  // elsewhere) is expected to evaluate grad_v = a * grad_w - b * v.
  at::parallel_for(0, M, 1, [&](int64_t row_begin, int64_t row_end) {
    for (const auto row : c10::irange(row_begin, row_end)) {
      const int64_t offset = row * N;
      apply_per_row_backward(
          grad_v_ptr + offset,
          grad_w_ptr + offset,
          saved_v_ptr + offset,
          coef_a,
          coef_b,
          N);
    }
  });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free