
weight_norm_first_dim_kernel Class — pytorch Architecture

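Architecture documentation for weight_norm_first_dim_kernel, a function template in WeightNormKernel.cpp from the PyTorch codebase. It is the CPU kernel for weight normalization over the first dimension. The input v is treated as a contiguous M × N row-major buffer. For each of the M rows, the kernel computes the row's L2 norm, stores it in norm[i], and writes the scaled row w[i, :] = v[i, :] * g[i] / norm[i]. Rows are distributed across threads with at::parallel_for. The squared-sum reduction and the scaling use the vectorized helpers vec::map_reduce_all and vec::map. The norm is computed and stored in the accumulation type accscalar_t.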
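The per-row computation can be summarized as follows. Here v is read as an M × N matrix whose row i is the i-th slice along the first dimension. The kernel stores the norm of row i in norm[i] and the result in row i of w.

```latex
\|v_i\|_2 = \sqrt{\sum_{j=0}^{N-1} v_{ij}^2},
\qquad
w_{ij} = \frac{g_i}{\|v_i\|_2}\, v_{ij},
\qquad
0 \le i < M,\; 0 \le j < N
```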


Source Code

aten/src/ATen/native/cpu/WeightNormKernel.cpp lines 17–48

template <typename scalar_t, typename accscalar_t>
void weight_norm_first_dim_kernel(
    TensorBase& w,
    TensorBase& norm,
    const TensorBase& v,
    const TensorBase& g,
    int64_t M, int64_t N) {
  const auto v_data = v.data_ptr<scalar_t>();
  const auto g_data = g.data_ptr<scalar_t>();
  auto w_data = w.data_ptr<scalar_t>();
  auto norm_data = norm.data_ptr<accscalar_t>();

  using Vec = vec::Vectorized<accscalar_t>;
  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      accscalar_t norm_val = vec::map_reduce_all<scalar_t>(
          [](Vec x) { return x * x; },
          [](Vec x, Vec y) { return x + y; },
          v_data + i * N,
          N);
      norm_val = std::sqrt(norm_val);
      norm_data[i] = norm_val;

      accscalar_t a = g_data[i] / norm_val;
      vec::map(
          [a](Vec x) { return x * Vec(a); },
          w_data + i * N,
          v_data + i * N,
          N);
    }
  });
}
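The per-row arithmetic can be checked without ATen using the scalar sketch below. It is a hypothetical reference, not PyTorch code. The name weight_norm_first_dim_reference and the use of std::vector in place of TensorBase are assumptions. The sketch drops the vectorization and at::parallel_for, keeping only the per-row math and the accscalar_t accumulation. Like the original loop, it does not guard against a zero row norm.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical scalar reference (not part of PyTorch) for the per-row
// weight-norm computation performed by weight_norm_first_dim_kernel.
// v and w are M x N row-major buffers; g and norm hold one value per row.
// accscalar_t is the type used for the squared sum, the norm, and the scale.
template <typename scalar_t, typename accscalar_t>
void weight_norm_first_dim_reference(
    std::vector<scalar_t>& w,
    std::vector<accscalar_t>& norm,
    const std::vector<scalar_t>& v,
    const std::vector<scalar_t>& g,
    std::int64_t M, std::int64_t N) {
  // Check buffer sizes against the assumed M x N layout.
  assert(static_cast<std::int64_t>(v.size()) == M * N);
  assert(static_cast<std::int64_t>(w.size()) == M * N);
  assert(static_cast<std::int64_t>(g.size()) == M);
  assert(static_cast<std::int64_t>(norm.size()) == M);

  for (std::int64_t i = 0; i < M; ++i) {
    const scalar_t* v_row = v.data() + i * N;
    scalar_t* w_row = w.data() + i * N;

    // Sum of squares of row i, accumulated in accscalar_t.
    accscalar_t sum_sq = 0;
    for (std::int64_t j = 0; j < N; ++j) {
      const accscalar_t x = static_cast<accscalar_t>(v_row[j]);
      sum_sq += x * x;
    }
    const accscalar_t norm_val = std::sqrt(sum_sq);
    norm[i] = norm_val;

    // Scale every element of the row by g[i] / ||v_i||
    // (no zero-norm guard, matching the original kernel).
    const accscalar_t a = static_cast<accscalar_t>(g[i]) / norm_val;
    for (std::int64_t j = 0; j < N; ++j) {
      w_row[j] = static_cast<scalar_t>(static_cast<accscalar_t>(v_row[j]) * a);
    }
  }
}
```

For example, with M = N = 2, v = {3, 4, 0, 5} and g = {10, 2}, both row norms are 5 and w becomes {6, 8, 0, 2}. Pairing float inputs with a double accscalar_t is an illustrative choice for this sketch, not necessarily the pairing ATen selects.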
