batch_norm_cpu_backward_contiguous_internal Function Template — pytorch Architecture
Architecture documentation for the batch_norm_cpu_backward_contiguous_internal function template in batch_norm_kernel.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/batch_norm_kernel.cpp lines 996–1102
// Backward pass of batch normalization for contiguous (channels-first, NCxHW)
// input. Computes any subset of grad_input, grad_weight, grad_bias — each is
// only written if the corresponding output tensor is defined.
//
// Template parameters:
//   scalar_t — element type of input/grad tensors (may be a reduced-precision
//              type such as half/bfloat16, given the convert_to_float calls;
//              accumulation is done in at::opmath_type<scalar_t>).
//   param_t  — element type of the per-channel parameter/stat tensors
//              (weight, running stats, saved stats, grad_weight, grad_bias).
//
// Parameters:
//   grad_input/grad_weight/grad_bias — outputs; skipped when undefined.
//   grad_output, input               — dY and X, both contiguous.
//   weight                           — gamma; treated as 1 when undefined.
//   running_mean/running_var         — used in eval mode (train == false).
//   save_mean/save_invstd            — batch stats used in train mode.
//   train                            — selects which statistics to use.
//   eps                              — added to running_var before rsqrt in
//                                      eval mode.
template <typename scalar_t, typename param_t>
void batch_norm_cpu_backward_contiguous_internal(Tensor& grad_input, Tensor& grad_weight, Tensor& grad_bias,
const Tensor& grad_output, const Tensor& input, const Tensor& weight,
const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd,
bool train, double eps) {
// opmath_t is the higher-precision accumulation type; bVec/fVec are the
// SIMD vector types for the storage and accumulation types respectively.
using opmath_t = at::opmath_type<scalar_t>;
using bVec = Vectorized<scalar_t>;
using fVec = Vectorized<opmath_t>;
int64_t n_batch = input.size(0);
int64_t n_channel = input.size(1);
// Number of spatial elements per (batch, channel) slice.
int64_t image_size = input.numel() / n_batch / n_channel;
// Total element count per channel across all batches (normalization count).
int64_t N = input.numel() / n_channel;
const scalar_t* grad_output_data = grad_output.const_data_ptr<scalar_t>();
const scalar_t* input_data = input.const_data_ptr<scalar_t>();
// Optional outputs: null pointer means "do not compute this gradient".
scalar_t* grad_input_data = grad_input.defined() ? grad_input.mutable_data_ptr<scalar_t>() : nullptr;
param_t* grad_weight_data = grad_weight.defined() ? grad_weight.data_ptr<param_t>() : nullptr;
param_t* grad_bias_data = grad_bias.defined() ? grad_bias.data_ptr<param_t>() : nullptr;
const bool grad_input_null = grad_input_data == nullptr;
const bool grad_weight_null = grad_weight_data == nullptr;
const bool grad_bias_null = grad_bias_data == nullptr;
// 1-D accessors that tolerate undefined tensors (e.g. weight in affine=False,
// or running stats in train mode).
auto weight_a = conditional_accessor_1d<const param_t>(weight);
auto save_mean_a = conditional_accessor_1d<const param_t>(save_mean);
auto save_invstd_a = conditional_accessor_1d<const param_t>(save_invstd);
auto running_mean_a = conditional_accessor_1d<const param_t>(running_mean);
auto running_var_a = conditional_accessor_1d<const param_t>(running_var);
// parallel dim reduce on 'channel'
at::parallel_for(0, n_channel, 1, [&](int64_t begin, int64_t end) {
for (const auto c : c10::irange(begin, end)) {
// gamma for this channel; identity (1) when weight is undefined.
opmath_t w = weight.defined() ? opmath_t(weight_a[c]) : 1;
opmath_t mean, invstd;
if (train) {
// Training: use the batch statistics saved by the forward pass.
mean = save_mean_a[c];
invstd = save_invstd_a[c];
} else {
// Eval: recompute invstd from the running variance.
mean = running_mean_a[c];
invstd = 1 / std::sqrt(running_var_a[c] + eps);
}
// compute 1) sum; 2) dot product of Q(X) and dY.
// i.e. sum = Σ dY  (grad_bias), dotp = Σ (X - mean) * dY  (drives
// grad_weight and the train-mode grad_input correction terms).
// Scalar accumulators handle the vector-width tail; fVec accumulators
// handle the vectorized main loop. Both persist across the batch loop.
opmath_t sum{0}, dotp{0};
fVec sum_fvec{0}, dotp_fvec{0};
for (const auto n : c10::irange(n_batch)) {
// Contiguous layout: slice (n, c) starts at (n * C + c) * image_size.
const scalar_t* x_ptr = input_data + n * n_channel * image_size + c * image_size;
const scalar_t* dy_ptr = grad_output_data + n * n_channel * image_size + c * image_size;
int64_t d = 0;
// Vectorized main loop: load bVec::size() storage elements, widen to
// two fVec lanes of opmath_t, and accumulate.
for (; d < image_size - (image_size % bVec::size()); d += bVec::size()) {
bVec dy_bvec = bVec::loadu(dy_ptr + d);
auto [dy_fvec0, dy_fvec1] = convert_to_float<scalar_t>(dy_bvec);
sum_fvec += dy_fvec0;
sum_fvec += dy_fvec1;
bVec x_bvec = bVec::loadu(x_ptr + d);
auto [x_fvec0, x_fvec1] = convert_to_float<scalar_t>(x_bvec);
dotp_fvec += (x_fvec0 - fVec(mean)) * dy_fvec0;
dotp_fvec += (x_fvec1 - fVec(mean)) * dy_fvec1;
}
// Scalar tail for the remaining (image_size % bVec::size()) elements.
for (; d < image_size; d++) {
sum += opmath_t(dy_ptr[d]);
dotp += (opmath_t(x_ptr[d]) - mean) * opmath_t(dy_ptr[d]);
}
}
// TODO: use fast version
// Horizontal reduction of the vector accumulators into the scalars.
sum += vec_reduce_all([](fVec& x, fVec& y) { return x + y; }, sum_fvec, fVec::size());
dotp += vec_reduce_all([](fVec& x, fVec& y) { return x + y; }, dotp_fvec, fVec::size());
if (!grad_input_null) {
if (train) {
// Training-mode gradient (statistics depend on the input):
//   dX = (dY - mean(dY) - (X - mean) * dotp * invstd^2 / N) * invstd * w
opmath_t k = dotp * invstd * invstd / N;
opmath_t grad_mean = sum / N;
for (const auto n : c10::irange(n_batch)) {
const scalar_t* x_ptr = input_data + n * n_channel * image_size + c * image_size;
scalar_t* dx_ptr = grad_input_data + n * n_channel * image_size + c * image_size;
const scalar_t* dy_ptr = grad_output_data + n * n_channel * image_size + c * image_size;
vec::map2(
[=](fVec x, fVec dy) {
fVec dx = (x - fVec(mean)) * fVec(k);
return (dy - fVec(grad_mean) - dx) * fVec(invstd) * fVec(w);
},
dx_ptr, x_ptr, dy_ptr, image_size);
}
} else { // evaluation mode
// Eval-mode gradient (statistics are constants): dX = dY * invstd * w
for (const auto n : c10::irange(n_batch)) {
scalar_t* dx_ptr = grad_input_data + n * n_channel * image_size + c * image_size;
const scalar_t* dy_ptr = grad_output_data + n * n_channel * image_size + c * image_size;
vec::map(
[=](fVec dy) { return dy * fVec(invstd) * fVec(w); },
dx_ptr, dy_ptr, image_size);
}
}
}
// d(gamma) = Σ (X - mean) * dY * invstd
if (!grad_weight_null) {
grad_weight_data[c] = param_t(dotp * invstd);
}
// d(beta) = Σ dY
if (!grad_bias_null) {
grad_bias_data[c] = param_t(sum);
}
}
});
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free