Home / Class / batch_norm_cpu_collect_stats_contiguous_internal Class — PyTorch Architecture

batch_norm_cpu_collect_stats_contiguous_internal Class — pytorch Architecture

Architecture documentation for the batch_norm_cpu_collect_stats_contiguous_internal class in batch_norm_kernel.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/batch_norm_kernel.cpp lines 829–888

// Collect per-channel batch-norm statistics for a contiguous (N, C, *) input:
// writes the channel mean into `mean` and the un-normalized sum of squared
// deviations into `var_sum`. Two passes over the data per channel — pass 1
// accumulates the sum to form the mean, pass 2 accumulates squared distances
// from that mean. Accumulation happens in opmath_t (e.g. float for bf16/half)
// with a SIMD main loop and a scalar tail.
template <typename scalar_t, typename param_t>
inline void batch_norm_cpu_collect_stats_contiguous_internal(
    Tensor& mean, Tensor& var_sum, const Tensor& input) {
  using opmath_t = at::opmath_type<scalar_t>;
  using bVec = Vectorized<scalar_t>;
  using fVec = Vectorized<opmath_t>;
  const int64_t n_batch = input.size(0);
  const int64_t n_channel = input.size(1);
  const int64_t image_size = input.numel() / n_batch / n_channel;
  // Total number of elements contributing to each channel's statistics.
  const int64_t N = input.numel() / n_channel;
  // Length of the per-image prefix that fills whole scalar_t vector lanes.
  const int64_t vec_end = image_size - (image_size % bVec::size());

  const scalar_t* input_data = input.const_data_ptr<scalar_t>();
  param_t* mean_data = mean.data_ptr<param_t>();
  param_t* var_sum_data = var_sum.data_ptr<param_t>();

  // Channels are independent, so parallelize with one channel per work item.
  at::parallel_for(0, n_channel, 1, [&](int64_t begin, int64_t end) {
    for (const auto c : c10::irange(begin, end)) {
      // ---- pass 1: channel sum -> mean ----
      opmath_t scalar_sum = opmath_t(0);
      fVec vec_sum = fVec(opmath_t(0));
      for (int64_t n = 0; n < n_batch; n++) {
        const scalar_t* ptr = input_data + (n * n_channel + c) * image_size;
        int64_t d = 0;
        for (; d < vec_end; d += bVec::size()) {
          bVec chunk = bVec::loadu(ptr + d);
          auto [lo, hi] = convert_to_float<scalar_t>(chunk);
          vec_sum += lo;
          vec_sum += hi;
        }
        // Scalar tail for the remainder that doesn't fill a full vector.
        for (; d < image_size; d++) {
          scalar_sum += opmath_t(ptr[d]);
        }
      }
      // TODO: use fast version
      scalar_sum += vec_reduce_all([](fVec& x, fVec& y) { return x + y; }, vec_sum, fVec::size());
      const opmath_t mean_val = scalar_sum / N;
      mean_data[c] = param_t(mean_val);

      // ---- pass 2: sum of squared deviations from the mean ----
      opmath_t scalar_var = opmath_t(0);
      fVec vec_var = fVec(opmath_t(0));
      const fVec mean_vec = fVec(mean_val);
      for (int64_t n = 0; n < n_batch; n++) {
        const scalar_t* ptr = input_data + (n * n_channel + c) * image_size;
        int64_t d = 0;
        for (; d < vec_end; d += bVec::size()) {
          bVec chunk = bVec::loadu(ptr + d);
          auto [lo, hi] = convert_to_float<scalar_t>(chunk);
          fVec dlo = lo - mean_vec;
          fVec dhi = hi - mean_vec;
          vec_var += dlo * dlo;
          vec_var += dhi * dhi;
        }
        // Scalar tail, mirroring pass 1.
        for (; d < image_size; d++) {
          opmath_t diff = opmath_t(ptr[d]) - mean_val;
          scalar_var += diff * diff;
        }
      }
      // TODO: use fast version
      scalar_var += vec_reduce_all([](fVec& x, fVec& y) { return x + y; }, vec_var, fVec::size());
      // Note: not divided by N here — callers normalize var_sum themselves.
      var_sum_data[c] = param_t(scalar_var);
    }
  });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free