Home / Class/ scalar_t Class — pytorch Architecture

scalar_t Class — pytorch Architecture

Architecture documentation for the batch_norm_cpu_contiguous_impl function template (indexed here under its template parameter, scalar_t) in batch_norm_kernel.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/batch_norm_kernel.cpp lines 73–123

/// Writes `output = input * alpha(c) + beta(c)` for a contiguous input, where
/// `c` is the channel (dimension 1) of each element.
///
/// This overload is enabled only when `scalar_t` is its own opmath type
/// (see the `enable_if_t` constraint below).
///
/// The per-channel `alpha` / `beta` buffers are filled by
/// `batch_norm_cpu_collect_linear_and_constant_terms` from `weight`, `bias`,
/// `save_mean`, `save_invstd`, `running_mean`, `running_var`, `train` and `eps`.
/// NOTE(review): the exact formula lives in that helper, which is not visible
/// here — presumably alpha = weight * invstd and beta = bias - mean * alpha; confirm.
///
/// @param output  Destination tensor; written through a flat `scalar_t*`.
///                NOTE(review): assumed to have the same contiguous layout and
///                element count as `input` — verify against the caller.
/// @param input   Source tensor; dimension 0 is read as the batch size and
///                dimension 1 as the channel count.
/// @param train   Forwarded to the term-collection helper.
/// @param eps     Forwarded to the term-collection helper.
template<typename scalar_t>
typename std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
batch_norm_cpu_contiguous_impl(Tensor& output, const Tensor& input,
    const Tensor& weight, const Tensor& bias, const Tensor& save_mean, const Tensor& save_invstd,
    const Tensor& running_mean, const Tensor& running_var, bool train, double eps) {

  using Vec = Vectorized<scalar_t>;
  int64_t n_batch = input.size(0);
  int64_t n_channel = input.size(1);

  // An empty input has nothing to transform. Returning here also keeps the
  // image_size computation below from dividing by zero when n_batch or
  // n_channel is 0.
  if (input.numel() == 0) {
    return;
  }

  // Number of elements per (batch, channel) slice.
  const int64_t image_size = input.numel() / n_batch / n_channel;

  // Per-channel scratch buffers for the linear (alpha) and constant (beta) terms.
  Tensor alpha = at::empty({n_channel}, input.options());
  Tensor beta = at::empty({n_channel}, input.options());
  scalar_t* alpha_data = alpha.mutable_data_ptr<scalar_t>();
  scalar_t* beta_data = beta.mutable_data_ptr<scalar_t>();

  batch_norm_cpu_collect_linear_and_constant_terms<scalar_t, scalar_t>(
     alpha_data, beta_data, n_channel, weight, bias,
     save_mean, save_invstd, running_mean, running_var, train, eps);

  scalar_t* output_data = output.mutable_data_ptr<scalar_t>();
  const scalar_t* input_data = input.const_data_ptr<scalar_t>();

  // Apply the linear terms to the input,
  // output(n, c, h, w) = input(n, c, h, w) * alpha(c) + beta(c)
  //
  // loop_size is image_size rounded down to a multiple of the vector width;
  // the remainder is handled by the partial load/store after the main loop.
  const int64_t loop_size = image_size - (image_size % Vec::size());
  at::parallel_for(0, n_batch * n_channel, 1, [&](int64_t begin, int64_t end) {
    // (n, c) is kept in step with the flattened slice index i so that the
    // channel index c is available for the alpha/beta lookup.
    int64_t n = 0;
    int64_t c = 0;
    data_index_init(begin, n, n_batch, c, n_channel);

    for (const auto i : c10::irange(begin, end)) {
      // Broadcast this channel's scalar terms across all vector lanes.
      const Vec alpha_vec(alpha_data[c]);
      const Vec beta_vec(beta_data[c]);
      const int64_t offset = i * image_size;

      // Full-width vector chunks.
      int64_t d = 0;
      for (; d < loop_size; d += Vec::size()) {
        Vec data_vec = Vec::loadu(input_data + offset + d);
        Vec output_vec = data_vec * alpha_vec + beta_vec;
        output_vec.store(output_data + offset + d);
      }

      // Tail: fewer than Vec::size() elements remain, so use the
      // count-taking loadu/store overloads for just those elements.
      const int64_t remaining = image_size - d;
      if (remaining > 0) {
        Vec data_vec = Vec::loadu(input_data + offset + d, remaining);
        Vec output_vec = data_vec * alpha_vec + beta_vec;
        output_vec.store(output_data + offset + d, remaining);
      }

      // move on to next index
      data_index_step(n, n_batch, c, n_channel);
    }
  });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free