batch_norm_cpu_collect_stats_contiguous_internal Function — pytorch Architecture
Architecture documentation for the batch_norm_cpu_collect_stats_contiguous_internal function template in batch_norm_kernel.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/batch_norm_kernel.cpp lines 829–888
// Per-channel statistics over `input`, addressed as
// [n_batch][n_channel][image_size] where
// image_size = numel / n_batch / n_channel.
//
// For every channel c this writes:
//   mean[c]    = (sum of all elements of channel c) / N
//   var_sum[c] = sum over channel c of (x - mean[c])^2   (not divided by N)
// with N = numel / n_channel.
//
// Elements are read as scalar_t, all accumulation happens in
// opmath_t = at::opmath_type<scalar_t>, and results are stored as param_t.
// The statistics are computed in two passes over the data: the first pass
// produces the mean, the second accumulates squared deviations from it.
// The channel range [0, n_channel) is handed to at::parallel_for; each
// invocation of the callback handles channels [begin, end).
template <typename scalar_t, typename param_t>
inline void batch_norm_cpu_collect_stats_contiguous_internal(
    Tensor& mean, Tensor& var_sum, const Tensor& input) {
  using opmath_t = at::opmath_type<scalar_t>;
  using bVec = Vectorized<scalar_t>;
  using fVec = Vectorized<opmath_t>;

  const int64_t n_batch = input.size(0);
  const int64_t n_channel = input.size(1);
  const int64_t image_size = input.numel() / n_batch / n_channel;
  const int64_t N = input.numel() / n_channel;

  const scalar_t* input_data = input.const_data_ptr<scalar_t>();
  param_t* mean_data = mean.data_ptr<param_t>();
  param_t* var_sum_data = var_sum.data_ptr<param_t>();

  // Largest multiple of bVec::size() not exceeding image_size: elements in
  // [0, vec_end) go through the vector path, [vec_end, image_size) through
  // the scalar tail.
  const int64_t vec_end = image_size - (image_size % bVec::size());

  // Lane-wise addition used to collapse a vector accumulator to one scalar.
  const auto add_lanes = [](fVec& x, fVec& y) { return x + y; };

  at::parallel_for(0, n_channel, 1, [&](int64_t begin, int64_t end) {
    for (const auto c : c10::irange(begin, end)) {
      const int64_t channel_offset = c * image_size;

      // ---- Pass 1: mean of channel c ----
      opmath_t tail_sum = opmath_t(0);
      fVec lane_sum = fVec(opmath_t(0));
      for (const auto n : c10::irange(n_batch)) {
        const scalar_t* plane =
            input_data + n * n_channel * image_size + channel_offset;
        int64_t d = 0;
        for (; d < vec_end; d += bVec::size()) {
          const bVec packed = bVec::loadu(plane + d);
          // One bVec widens into two opmath vectors.
          auto [lo, hi] = convert_to_float<scalar_t>(packed);
          lane_sum += lo;
          lane_sum += hi;
        }
        for (; d < image_size; d++) {
          tail_sum += opmath_t(plane[d]);
        }
      }
      // TODO: use fast version
      tail_sum += vec_reduce_all(add_lanes, lane_sum, fVec::size());
      const opmath_t mean_val = tail_sum / N;
      mean_data[c] = param_t(mean_val);

      // ---- Pass 2: sum of squared deviations from the mean ----
      opmath_t tail_sq = opmath_t(0);
      fVec lane_sq = fVec(opmath_t(0));
      const fVec mean_fvec = fVec(mean_val);
      for (const auto n : c10::irange(n_batch)) {
        const scalar_t* plane =
            input_data + n * n_channel * image_size + channel_offset;
        int64_t d = 0;
        for (; d < vec_end; d += bVec::size()) {
          const bVec packed = bVec::loadu(plane + d);
          auto [lo, hi] = convert_to_float<scalar_t>(packed);
          const fVec dev_lo = lo - mean_fvec;
          lane_sq += dev_lo * dev_lo;
          const fVec dev_hi = hi - mean_fvec;
          lane_sq += dev_hi * dev_hi;
        }
        for (; d < image_size; d++) {
          const opmath_t dev = opmath_t(plane[d]) - mean_val;
          tail_sq += dev * dev;
        }
      }
      // TODO: use fast version
      tail_sq += vec_reduce_all(add_lanes, lane_sq, fVec::size());
      var_sum_data[c] = param_t(tail_sq);
    }
  });
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free