batch_norm_cpu_collect_stats_channels_last_internal Function — pytorch Architecture
Architecture documentation for the batch_norm_cpu_collect_stats_channels_last_internal function template in batch_norm_kernel.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/batch_norm_kernel.cpp lines 902–982
template <typename scalar_t, typename param_t>
inline void batch_norm_cpu_collect_stats_channels_last_internal(
    Tensor& mean, Tensor& var_sum, const Tensor& input) {
  // Collect per-channel BatchNorm statistics for a channels-last input:
  //   mean[c]    = sum of input[..., c] over all N non-channel positions / N
  //   var_sum[c] = sum of squared deviations from mean[c] (NOT divided by N;
  //                the caller applies the normalization / momentum update).
  //
  // scalar_t is the input dtype; param_t is the dtype of the mean/var_sum
  // tensors (may differ from scalar_t in the mixed-dtype path). All
  // accumulation happens in opmath_t (float for reduced-precision inputs)
  // to preserve accuracy.
  using opmath_t = at::opmath_type<scalar_t>;
  using bVec = Vectorized<scalar_t>;
  using fVec = Vectorized<opmath_t>;

  int64_t n_channel = input.size(1);
  // In channels-last layout each contiguous run of n_channel values is one
  // batch*spatial position, so N is the per-channel reduction length.
  int64_t N = input.numel() / n_channel;

  const scalar_t* input_data = input.const_data_ptr<scalar_t>();
  param_t* mean_data = mean.data_ptr<param_t>();
  param_t* var_sum_data = var_sum.data_ptr<param_t>();

  // Per-thread accumulation buffer: row `tid` holds that thread's partial
  // per-channel sums; the rows are reduced serially after each parallel pass.
  //
  // The buffer dtype must match opmath_t: hard-coding kFloat here would make
  // data_ptr<opmath_t>() throw for double inputs, where
  // opmath_type<double> == double. For float/BFloat16/Half inputs
  // CppTypeToScalarType<opmath_t> is kFloat, same as before.
  int num_threads = at::get_num_threads();
  Tensor buffer = at::zeros(
      {num_threads, n_channel},
      input.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
  opmath_t* buffer_data = buffer.data_ptr<opmath_t>();

  // Pass 1: per-channel sums -> mean.
  at::parallel_for(0, N, 1, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    TORCH_CHECK(tid < num_threads, "expect thread id smaller than ", num_threads, ", got thread id ", tid);
    opmath_t* buffer_ptr = buffer_data + tid * n_channel;
    for (const auto i : c10::irange(begin, end)) {
      const scalar_t* input_ptr = input_data + i * n_channel;
      int64_t d = 0;
      // Vectorized main loop: one bVec of inputs is widened into two fVec
      // accumulator lanes by convert_to_float.
      for (; d < n_channel - (n_channel % bVec::size()); d += bVec::size()) {
        bVec data_bvec = bVec::loadu(input_ptr + d);
        auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
        fVec sum_fvec0 = fVec::loadu(buffer_ptr + d) + data_fvec0;
        fVec sum_fvec1 = fVec::loadu(buffer_ptr + d + fVec::size()) + data_fvec1;
        sum_fvec0.store(buffer_ptr + d);
        sum_fvec1.store(buffer_ptr + d + fVec::size());
      }
      // Scalar tail for the remaining channels.
      for (; d < n_channel; d++) {
        buffer_ptr[d] += input_ptr[d];
      }
    }
  });
  // Serial reduction of the per-thread partial sums into the mean.
  for (const auto c : c10::irange(n_channel)) {
    opmath_t sum = 0;
    for (const auto t : c10::irange(num_threads)) {
      sum += buffer_data[t * n_channel + c];
    }
    mean_data[c] = param_t(sum / N);
  }

  // Pass 2: per-channel sum of squared deviations from the mean -> var_sum.
  buffer.zero_();
  at::parallel_for(0, N, 1, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    TORCH_CHECK(tid < num_threads, "expect thread id smaller than ", num_threads, ", got thread id ", tid);
    opmath_t* buffer_ptr = buffer_data + tid * n_channel;
    for (const auto i : c10::irange(begin, end)) {
      const scalar_t* input_ptr = input_data + i * n_channel;
      int64_t d = 0;
      for (; d < n_channel - (n_channel % bVec::size()); d += bVec::size()) {
        bVec data_bvec = bVec::loadu(input_ptr + d);
        auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
        // load2f loads param_t mean values and widens them to fVec pairs.
        auto [mean_fvec0, mean_fvec1] = load2f(mean_data + d);
        fVec var_fvec0 = fVec::loadu(buffer_ptr + d);
        fVec var_fvec1 = fVec::loadu(buffer_ptr + d + fVec::size());
        var_fvec0 += (data_fvec0 - mean_fvec0) * (data_fvec0 - mean_fvec0);
        var_fvec1 += (data_fvec1 - mean_fvec1) * (data_fvec1 - mean_fvec1);
        var_fvec0.store(buffer_ptr + d);
        var_fvec1.store(buffer_ptr + d + fVec::size());
      }
      // Scalar tail: accumulate squared deviations in opmath_t precision.
      for (; d < n_channel; d++) {
        opmath_t data_val = opmath_t(input_ptr[d]);
        opmath_t mean_val = opmath_t(mean_data[d]);
        buffer_ptr[d] += (data_val - mean_val) * (data_val - mean_val);
      }
    }
  });
  // Serial reduction of the per-thread partial squared sums into var_sum.
  for (const auto c : c10::irange(n_channel)) {
    opmath_t _var_sum = 0;
    for (const auto t : c10::irange(num_threads)) {
      _var_sum += buffer_data[t * n_channel + c];
    }
    var_sum_data[c] = param_t(_var_sum);
  }
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free