Home / Class/ batch_norm_cpu_collect_stats_channels_last_internal Class — pytorch Architecture

batch_norm_cpu_collect_stats_channels_last_internal Class — pytorch Architecture

Architecture documentation for the batch_norm_cpu_collect_stats_channels_last_internal function template in batch_norm_kernel.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/batch_norm_kernel.cpp lines 902–982

// Collects per-channel statistics from `input`, which is read as N consecutive
// rows of n_channel values each (row i starts at input_data + i * n_channel,
// where n_channel = input.size(1) and N = input.numel() / n_channel).
//
// Outputs (both written through raw param_t pointers):
//   mean[c]    = (sum of channel c over all N rows) / N
//   var_sum[c] = sum over all N rows of (x - mean[c])^2   (not divided by N)
//
// Each of the two passes accumulates into a zero-filled
// {num_threads, n_channel} scratch tensor so that every thread only touches
// its own row; the rows are then added up serially, channel by channel.
//
// NOTE(review): the scratch tensor is allocated as kFloat but accessed as
// opmath_t, and every bVec is widened into two fVecs, so this presumably
// targets reduced-precision scalar_t whose opmath_t is float -- confirm
// against the dispatch site.
// NOTE(review): n_channel is used as a divisor without a zero check here;
// presumably callers guarantee a non-empty channel dimension -- confirm.
template <typename scalar_t, typename param_t>
inline void batch_norm_cpu_collect_stats_channels_last_internal(
    Tensor& mean, Tensor& var_sum, const Tensor& input) {
  using opmath_t = at::opmath_type<scalar_t>;
  using bVec = Vectorized<scalar_t>;
  using fVec = Vectorized<opmath_t>;

  const int64_t n_channel = input.size(1);
  const int64_t N = input.numel() / n_channel;
  // Channels in [0, vec_end) are processed one bVec at a time; the remainder
  // [vec_end, n_channel) falls through to the scalar tail loops.
  const int64_t vec_end = n_channel - (n_channel % bVec::size());

  const scalar_t* input_data = input.const_data_ptr<scalar_t>();
  param_t* mean_data = mean.data_ptr<param_t>();
  param_t* var_sum_data = var_sum.data_ptr<param_t>();

  // One accumulator row per thread.
  int num_threads = at::get_num_threads();
  Tensor buffer = at::zeros({num_threads, n_channel}, input.options().dtype(kFloat));
  opmath_t* buffer_data = buffer.data_ptr<opmath_t>();

  // Adds up channel `c` across all per-thread accumulator rows, in thread order.
  const auto reduce_over_threads = [&](int64_t c) {
    opmath_t total = 0;
    for (const auto t : c10::irange(num_threads)) {
      total += buffer_data[t * n_channel + c];
    }
    return total;
  };

  // Pass 1: per-thread partial sums of the raw values.
  at::parallel_for(0, N, 1, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    TORCH_CHECK(tid < num_threads, "expect thread id smaller than ", num_threads, ", got thread id ", tid);
    opmath_t* acc = buffer_data + tid * n_channel;
    for (const auto row : c10::irange(begin, end)) {
      const scalar_t* src = input_data + row * n_channel;
      int64_t ch = 0;
      while (ch < vec_end) {
        bVec packed = bVec::loadu(src + ch);
        // A single bVec widens into a low and a high fVec half.
        auto [x_lo, x_hi] = convert_to_float<scalar_t>(packed);
        fVec acc_lo = fVec::loadu(acc + ch) + x_lo;
        fVec acc_hi = fVec::loadu(acc + ch + fVec::size()) + x_hi;
        acc_lo.store(acc + ch);
        acc_hi.store(acc + ch + fVec::size());
        ch += bVec::size();
      }
      // Scalar tail.
      while (ch < n_channel) {
        acc[ch] += src[ch];
        ++ch;
      }
    }
  });

  // Combine the per-thread sums and turn them into the mean.
  for (const auto c : c10::irange(n_channel)) {
    mean_data[c] = param_t(reduce_over_threads(c) / N);
  }

  // Pass 2: per-thread partial sums of squared deviations. The mean is read
  // back from mean_data, i.e. after it has been stored as param_t.
  buffer.zero_();
  at::parallel_for(0, N, 1, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    TORCH_CHECK(tid < num_threads, "expect thread id smaller than ", num_threads, ", got thread id ", tid);
    opmath_t* acc = buffer_data + tid * n_channel;
    for (const auto row : c10::irange(begin, end)) {
      const scalar_t* src = input_data + row * n_channel;
      int64_t ch = 0;
      while (ch < vec_end) {
        bVec packed = bVec::loadu(src + ch);
        auto [x_lo, x_hi] = convert_to_float<scalar_t>(packed);
        auto [mu_lo, mu_hi] = load2f(mean_data + ch);
        fVec sq_lo = fVec::loadu(acc + ch);
        fVec sq_hi = fVec::loadu(acc + ch + fVec::size());
        const fVec dev_lo = x_lo - mu_lo;
        const fVec dev_hi = x_hi - mu_hi;
        sq_lo += dev_lo * dev_lo;
        sq_hi += dev_hi * dev_hi;
        sq_lo.store(acc + ch);
        sq_hi.store(acc + ch + fVec::size());
        ch += bVec::size();
      }
      // Scalar tail.
      while (ch < n_channel) {
        const opmath_t dev = opmath_t(src[ch]) - opmath_t(mean_data[ch]);
        acc[ch] += dev * dev;
        ++ch;
      }
    }
  });

  // Combine the per-thread squared-deviation sums; no division by N here.
  for (const auto c : c10::irange(n_channel)) {
    var_sum_data[c] = param_t(reduce_over_threads(c));
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free