weight_norm_last_dim_kernel Class — pytorch Architecture

Architecture documentation for weight_norm_last_dim_kernel, a function template (catalogued here as a class) defined in WeightNormKernel.cpp in the PyTorch codebase.

Source Code

aten/src/ATen/native/cpu/WeightNormKernel.cpp lines 129–179
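Reading the indexing in the listing below (v_data + i * N, buffer_data[t * N + j]), v and w are contiguous row-major M × N buffers and norm and g hold one value per last-dimension index j. Assuming sum_norm_per_row accumulates squared entries into each thread's buffer row, as the later std::sqrt(sum) suggests, the kernel computes a per-column L2 norm by reducing over the M rows (the "vertical" reduction) and then rescales each column by g_j / norm_j:

```latex
\mathrm{norm}_j = \sqrt{\sum_{i=0}^{M-1} v_{ij}^{2}}, \qquad
w_{ij} = v_{ij}\,\frac{g_j}{\mathrm{norm}_j}, \qquad
0 \le i < M,\; 0 \le j < N
```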

template <typename scalar_t, typename accscalar_t>
void weight_norm_last_dim_kernel(
    TensorBase& w,
    TensorBase& norm,
    const TensorBase& v,
    const TensorBase& g,
    int64_t M, int64_t N) {
  const auto v_data = v.data_ptr<scalar_t>();
  const auto g_data = g.data_ptr<scalar_t>();
  auto w_data = w.data_ptr<scalar_t>();
  auto norm_data = norm.data_ptr<accscalar_t>();

  int num_threads = at::get_num_threads();
  TensorBase buffer = at::detail::empty_cpu({num_threads, N}, norm.options()).zero_();
  auto buffer_data = buffer.data_ptr<accscalar_t>();

  // vertical parallel reduction
  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    TORCH_CHECK(tid < num_threads, "expect thread id smaller than ", num_threads, ", got thread id ", tid);
    auto buffer_ptr = buffer_data + tid * N;
    for (const auto i : c10::irange(begin, end)) {
      sum_norm_per_row(buffer_ptr, v_data + i * N, N);
    }
  });

  for (const auto j : c10::irange(N)) {
    accscalar_t sum = 0;
    for (const auto t : c10::irange(num_threads)) {
      sum += buffer_data[t * N + j];
    }
    norm_data[j] = std::sqrt(sum);
  }

  // reuse the first row of buffer to store g / norm
  vec::convert(g_data, buffer_data, N);
  using Vec = vec::Vectorized<accscalar_t>;
  vec::map2(
      [](Vec g, Vec norm) { return g / norm; },
      buffer_data,
      buffer_data,
      norm_data,
      N);

  // apply w = v * (g/norm)
  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      apply_norm_per_row(w_data + i * N, v_data + i * N, buffer_data, N);
    }
  });
}
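As a rough, self-contained illustration of the three phases in the listing (per-thread partial sums, a serial combine into norm, then scaling by g / norm), the sketch below reimplements the same algorithm with std::vector and std::thread instead of ATen. The name weight_norm_last_dim_sketch, the even static split of rows across threads (at::parallel_for picks its own chunking), the serial final phase, and the use of float/double in place of scalar_t/accscalar_t are all assumptions for illustration; it does not call PyTorch's sum_norm_per_row, apply_norm_per_row, or vectorized helpers.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <thread>
#include <vector>

// Hypothetical standalone sketch, not PyTorch code.
// Computes w[i][j] = v[i][j] * g[j] / norm[j], norm[j] = sqrt(sum_i v[i][j]^2),
// for row-major M x N matrices v, w. Accumulates in double (like accscalar_t).
// Precondition: num_threads >= 1.
void weight_norm_last_dim_sketch(std::vector<float>& w, std::vector<double>& norm,
                                 const std::vector<float>& v, const std::vector<float>& g,
                                 int64_t M, int64_t N, int num_threads) {
  // One zero-initialized row of partial sums per thread: num_threads x N.
  std::vector<double> buffer(static_cast<size_t>(num_threads) * N, 0.0);

  // Phase 1: vertical reduction. Each thread owns a contiguous block of rows
  // and adds v[i][j]^2 into its own buffer row, so no locks or atomics are needed.
  std::vector<std::thread> workers;
  const int64_t chunk = (M + num_threads - 1) / num_threads;
  for (int t = 0; t < num_threads; ++t) {
    const int64_t begin = std::min<int64_t>(M, t * chunk);
    const int64_t end = std::min<int64_t>(M, begin + chunk);
    workers.emplace_back([&, t, begin, end] {
      double* row = buffer.data() + static_cast<int64_t>(t) * N;
      for (int64_t i = begin; i < end; ++i) {
        for (int64_t j = 0; j < N; ++j) {
          const double x = v[i * N + j];
          row[j] += x * x;
        }
      }
    });
  }
  for (auto& th : workers) th.join();

  // Phase 2: serially combine the per-thread partial sums, then take sqrt.
  norm.assign(static_cast<size_t>(N), 0.0);
  for (int64_t j = 0; j < N; ++j) {
    double sum = 0.0;
    for (int t = 0; t < num_threads; ++t) sum += buffer[t * N + j];
    norm[j] = std::sqrt(sum);
  }

  // Phase 3: reuse buffer row 0 to hold g / norm, then scale every row of v.
  // (Done serially here for brevity; the original parallelizes over rows.)
  for (int64_t j = 0; j < N; ++j) buffer[j] = g[j] / norm[j];
  w.assign(static_cast<size_t>(M * N), 0.0f);
  for (int64_t i = 0; i < M; ++i) {
    for (int64_t j = 0; j < N; ++j) {
      w[i * N + j] = static_cast<float>(v[i * N + j] * buffer[j]);
    }
  }
}
```

For example, with v = {3, 0, 4, 2} viewed as a 2 × 2 matrix and g = {10, 1}, the column norms are 5 and 2, so w becomes {6, 0, 8, 1}. Giving each thread its own buffer row trades a num_threads × N scratch allocation and a small serial combine for a reduction that needs no synchronization inside the hot loop.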