batch_norm_cpu_backward_channels_last_internal Function — PyTorch Architecture
Architecture documentation for the batch_norm_cpu_backward_channels_last_internal function template in batch_norm_kernel.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/batch_norm_kernel.cpp lines 1120–1310
// Backward pass of batch normalization for channels-last CPU tensors.
//
// First accumulates, per channel c, the reductions
//   sum[c]  = sum over all rows of dy
//   dotp[c] = sum over all rows of (x - mean[c]) * dy
// across N = input.numel() / n_channel rows (batch * spatial positions),
// then uses them to fill whichever of grad_input / grad_weight / grad_bias
// are defined(). Undefined outputs are skipped.
//
// Template parameters:
//   scalar_t - element type of input / grad_output / grad_input.
//   param_t  - element type of the per-channel parameter tensors
//              (weight, running/saved stats, grad_weight, grad_bias).
//
// Statistics source:
//   train == true  -> save_mean / save_invstd (batch statistics).
//   train == false -> running_mean and 1/sqrt(running_var + eps).
//
// All accumulation is carried out in opmath_t for precision.
template <typename scalar_t, typename param_t>
void batch_norm_cpu_backward_channels_last_internal(Tensor& grad_input, Tensor& grad_weight, Tensor& grad_bias,
const Tensor& grad_output, const Tensor& input, const Tensor& weight,
const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd,
bool train, double eps) {
using opmath_t = at::opmath_type<scalar_t>;
// bVec holds scalar_t lanes, fVec holds opmath_t lanes. convert_to_float
// widens one bVec into two fVecs, so the vector loops step by bVec::size()
// and address two fVec-sized halves (offset 0 and fVec::size()).
using bVec = Vectorized<scalar_t>;
using fVec = Vectorized<opmath_t>;
// Channels-last layout: channel is the innermost, contiguous dimension, so
// each of the N rows is a contiguous run of n_channel elements.
int64_t n_channel = input.size(1);
int64_t N = input.numel() / n_channel;
const scalar_t* grad_output_data = grad_output.const_data_ptr<scalar_t>();
const scalar_t* input_data = input.const_data_ptr<scalar_t>();
// Each output pointer is null when that gradient was not requested.
scalar_t* grad_input_data = grad_input.defined() ? grad_input.mutable_data_ptr<scalar_t>() : nullptr;
param_t* grad_weight_data = grad_weight.defined() ? grad_weight.data_ptr<param_t>() : nullptr;
param_t* grad_bias_data = grad_bias.defined() ? grad_bias.data_ptr<param_t>() : nullptr;
// conditional_accessor_1d tolerates undefined tensors (weight, and the
// stats tensors for whichever of the train/eval branches is unused).
auto weight_a = conditional_accessor_1d<const param_t>(weight);
auto save_mean_a = conditional_accessor_1d<const param_t>(save_mean);
auto save_invstd_a = conditional_accessor_1d<const param_t>(save_invstd);
auto running_mean_a = conditional_accessor_1d<const param_t>(running_mean);
auto running_var_a = conditional_accessor_1d<const param_t>(running_var);
// use float as acc type: the buffers below are allocated as kFloat but read
// through opmath_t*, so this path relies on opmath_t being float.
// Parameters are widened from param_t into contiguous float scratch so the
// vectorized loops can load them directly.
bool weight_defined = weight.defined();
Tensor weight_f = at::empty({n_channel}, input.options().dtype(kFloat));
Tensor mean = at::empty({n_channel}, input.options().dtype(kFloat));
Tensor invstd = at::empty({n_channel}, input.options().dtype(kFloat));
opmath_t* weight_data = weight_f.data_ptr<opmath_t>();
opmath_t* mean_data = mean.data_ptr<opmath_t>();
opmath_t* invstd_data = invstd.data_ptr<opmath_t>();
for (const auto c : c10::irange(n_channel)) {
// Missing weight behaves as weight == 1 (affine disabled).
weight_data[c] = weight_defined ? opmath_t(weight_a[c]) : 1;
if (train) {
mean_data[c] = save_mean_a[c];
invstd_data[c] = save_invstd_a[c];
} else {
mean_data[c] = running_mean_a[c];
invstd_data[c] = 1 / std::sqrt(running_var_a[c] + eps);
}
}
// Per-thread accumulation scratch, laid out as [2][num_threads][n_channel]:
// slice 0 holds sum(dy), slice 1 holds dotp. zeros() gives the identity for
// the += accumulation below.
int num_threads = at::get_num_threads();
Tensor buffer = at::zeros({2, num_threads, n_channel}, input.options().dtype(kFloat));
opmath_t* sum_data = buffer.data_ptr<opmath_t>();
opmath_t* dotp_data = sum_data + num_threads * n_channel;
// Pass 1: each thread accumulates sum and dotp for its range of rows into
// its private slice of the buffer (no synchronization needed).
at::parallel_for(0, N, 1, [&](int64_t begin, int64_t end) {
int tid = at::get_thread_num();
// Guard against indexing past the buffer if the thread pool hands out an
// unexpected thread id.
TORCH_CHECK(tid < num_threads, "expect thread id smaller than ", num_threads, ", got thread id ", tid);
opmath_t* sum_ptr = sum_data + tid * n_channel;
opmath_t* dotp_ptr = dotp_data + tid * n_channel;
for (const auto i : c10::irange(begin, end)) {
const scalar_t* x_ptr = input_data + i * n_channel;
const scalar_t* dy_ptr = grad_output_data + i * n_channel;
// Vectorized main loop over full bVec-width chunks of the channel dim.
int64_t d = 0;
for(; d < n_channel - (n_channel % bVec::size()); d += bVec::size()) {
// sum_ptr[d..] += dy
bVec dy_bvec = bVec::loadu(dy_ptr + d);
auto [dy_fvec0, dy_fvec1] = convert_to_float<scalar_t>(dy_bvec);
fVec sum_fvec0 = dy_fvec0 + fVec::loadu(sum_ptr + d);
fVec sum_fvec1 = dy_fvec1 + fVec::loadu(sum_ptr + d + fVec::size());
sum_fvec0.store(sum_ptr + d);
sum_fvec1.store(sum_ptr + d + fVec::size());
// dotp_ptr[d..] += (x - mean) * dy
bVec x_bvec = bVec::loadu(x_ptr + d);
auto [x_fvec0, x_fvec1] = convert_to_float<scalar_t>(x_bvec);
fVec mean_fvec0 = fVec::loadu(mean_data + d);
fVec mean_fvec1 = fVec::loadu(mean_data + d + fVec::size());
fVec dotp_fvec0 = fVec::loadu(dotp_ptr + d);
fVec dotp_fvec1 = fVec::loadu(dotp_ptr + d + fVec::size());
dotp_fvec0 += (x_fvec0 - mean_fvec0) * dy_fvec0;
dotp_fvec1 += (x_fvec1 - mean_fvec1) * dy_fvec1;
dotp_fvec0.store(dotp_ptr + d);
dotp_fvec1.store(dotp_ptr + d + fVec::size());
}
// Scalar tail for the remaining n_channel % bVec::size() channels.
for (; d < n_channel; d++) {
opmath_t dy_val = dy_ptr[d];
opmath_t x_val = x_ptr[d];
opmath_t mean_val = mean_data[d];
sum_ptr[d] += dy_val;
dotp_ptr[d] += (x_val - mean_val) * dy_val;
}
}
});
// Pass 2: reduce the per-thread slices across threads, channel by channel.
at::parallel_for(0, n_channel, 1, [&](int64_t begin, int64_t end) {
for (const auto c : c10::irange(begin, end)) {
// store the final result of sum and dotp in the 1st lane of the intermediate buffer,
// so that we won't need to allocate another buffer to store the temp values.
opmath_t _sum = 0;
for (const auto t : c10::irange(num_threads)) {
_sum += sum_data[t * n_channel + c];
}
sum_data[/* 0 * n_channel + */c] = _sum;
opmath_t _dotp = 0;
for (const auto t : c10::irange(num_threads)) {
_dotp += dotp_data[t * n_channel + c];
}
dotp_data[/* 0 * n_channel + */c] = _dotp;
}
});
// compute grad_input, parallelized over rows. After pass 2, sum_data[c] and
// dotp_data[c] hold the full reductions for channel c.
if (grad_input.defined()) {
at::parallel_for(0, N, 1, [&](int64_t begin, int64_t end) {
for (const auto i : c10::irange(begin, end)) {
scalar_t* dx_ptr = grad_input_data + i * n_channel;
const scalar_t* x_ptr = input_data + i * n_channel;
const scalar_t* dy_ptr = grad_output_data + i * n_channel;
if (train) {
// Training: mean/var depend on the input, so
//   k  = dotp * invstd^2 / N
//   dx = (dy - sum/N - (x - mean) * k) * invstd * w
int64_t d = 0;
for (; d < n_channel - (n_channel % bVec::size()); d += bVec::size()) {
bVec x_bvec = bVec::loadu(x_ptr + d);
auto [x_fvec0, x_fvec1] = convert_to_float<scalar_t>(x_bvec);
fVec mean_fvec0 = fVec::loadu(mean_data + d);
fVec mean_fvec1 = fVec::loadu(mean_data + d + fVec::size());
fVec dotp_fvec0 = fVec::loadu(dotp_data + d);
fVec dotp_fvec1 = fVec::loadu(dotp_data + d + fVec::size());
fVec invstd_fvec0 = fVec::loadu(invstd_data + d);
fVec invstd_fvec1 = fVec::loadu(invstd_data + d + fVec::size());
fVec k_fvec0 = dotp_fvec0 * invstd_fvec0 * invstd_fvec0 / fVec(N);
fVec k_fvec1 = dotp_fvec1 * invstd_fvec1 * invstd_fvec1 / fVec(N);
fVec dx_fvec0 = (x_fvec0 - mean_fvec0) * k_fvec0;
fVec dx_fvec1 = (x_fvec1 - mean_fvec1) * k_fvec1;
bVec dy_bvec = bVec::loadu(dy_ptr + d);
auto [dy_fvec0, dy_fvec1] = convert_to_float<scalar_t>(dy_bvec);
fVec grad_mean_fvec0 = fVec::loadu(sum_data + d) / fVec(N);
fVec grad_mean_fvec1 = fVec::loadu(sum_data + d + fVec::size()) / fVec(N);
fVec w_fvec0 = fVec::loadu(weight_data + d);
fVec w_fvec1 = fVec::loadu(weight_data + d + fVec::size());
dx_fvec0 = (dy_fvec0 - grad_mean_fvec0 - dx_fvec0) * invstd_fvec0 * w_fvec0;
dx_fvec1 = (dy_fvec1 - grad_mean_fvec1 - dx_fvec1) * invstd_fvec1 * w_fvec1;
bVec dx_bvec = convert_from_float<scalar_t>(dx_fvec0, dx_fvec1);
dx_bvec.store(dx_ptr + d);
}
// Scalar tail, same formula as the vector loop above.
for (; d < n_channel; d++) {
opmath_t x_val = x_ptr[d];
opmath_t mean_val = mean_data[d];
opmath_t dotp_val = dotp_data[d];
opmath_t invstd_val = invstd_data[d];
opmath_t k_val = dotp_val * invstd_val * invstd_val / N;
opmath_t dx_val = (x_val - mean_val) * k_val;
opmath_t dy_val = dy_ptr[d];
opmath_t grad_mean_val = sum_data[d] / N;
opmath_t w_val = weight_data[d];
dx_val = (dy_val - grad_mean_val - dx_val) * invstd_val * w_val;
dx_ptr[d] = scalar_t(dx_val);
}
} else { // evaluation mode
// Eval: running stats are constants w.r.t. the input, so the
// gradient reduces to dx = dy * invstd * w.
int64_t d = 0;
for (; d < n_channel - (n_channel % bVec::size()); d += bVec::size()) {
bVec dy_bvec = bVec::loadu(dy_ptr + d);
auto [dy_fvec0, dy_fvec1] = convert_to_float<scalar_t>(dy_bvec);
fVec invstd_fvec0 = fVec::loadu(invstd_data + d);
fVec invstd_fvec1 = fVec::loadu(invstd_data + d + fVec::size());
fVec w_fvec0 = fVec::loadu(weight_data + d);
fVec w_fvec1 = fVec::loadu(weight_data + d + fVec::size());
fVec dx_fvec0 = dy_fvec0 * invstd_fvec0 * w_fvec0;
fVec dx_fvec1 = dy_fvec1 * invstd_fvec1 * w_fvec1;
bVec dx_bvec = convert_from_float<scalar_t>(dx_fvec0, dx_fvec1);
dx_bvec.store(dx_ptr + d);
}
// Scalar tail for the eval-mode formula.
for (; d < n_channel; d++) {
opmath_t dy_val = dy_ptr[d];
opmath_t invstd_val = invstd_data[d];
opmath_t w_val = weight_data[d];
opmath_t dx_val = dy_val * invstd_val * w_val;
dx_ptr[d] = scalar_t(dx_val);
}
}
}
});
}
// grad_weight[c] = dotp[c] * invstd[c]  (gradient w.r.t. the scale).
if (grad_weight.defined()) {
for (const auto c : c10::irange(n_channel)) {
grad_weight_data[c] = param_t(dotp_data[c] * invstd_data[c]);
}
}
// grad_bias[c] = sum[c]  (gradient w.r.t. the shift).
if (grad_bias.defined()) {
for (const auto c : c10::irange(n_channel)) {
grad_bias_data[c] = param_t(sum_data[c]);
}
}
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free