batch_norm_cpu_collect_linear_and_constant_terms Function — pytorch Architecture
Architecture documentation for the batch_norm_cpu_collect_linear_and_constant_terms function template in batch_norm_kernel.cpp from the pytorch codebase.
Source Code
aten/src/ATen/native/cpu/batch_norm_kernel.cpp lines 30–70
template<typename param_t, typename opmath_t>
void batch_norm_cpu_collect_linear_and_constant_terms(
    opmath_t* alpha, opmath_t* beta, int64_t n_channel,
    const Tensor& weight /* optional */, const Tensor& bias /* optional */,
    const Tensor& save_mean, const Tensor& save_invstd,
    const Tensor& running_mean, const Tensor& running_var, bool train, double eps) {

  const param_t* weight_data = weight.defined() ? weight.const_data_ptr<param_t>() : nullptr;
  const param_t* bias_data = bias.defined() ? bias.const_data_ptr<param_t>() : nullptr;

  auto save_mean_a = conditional_accessor_1d<const param_t>(save_mean);
  auto save_invstd_a = conditional_accessor_1d<const param_t>(save_invstd);
  auto running_mean_a = conditional_accessor_1d<const param_t>(running_mean);
  auto running_var_a = conditional_accessor_1d<const param_t>(running_var);

  /// Collect the linear and constant terms regarding the input.
  /// output(n, c, h, w)
  ///     = (input(n, c, h, w) - mean(c)) / sqrt(var(c) + eps) * weight(c)
  ///         + bias(c)
  ///     = input(n, c, h, w) * inv_var(c) * weight(c)
  ///         - mean(c) * inv_var(c) * weight(c) + bias(c),
  /// where inv_var(c) = 1 / sqrt(var(c) + eps).
  /// So the linear term, alpha(c) = inv_var(c) * weight(c),
  ///   the constant term beta(c) = bias(c) - mean(c) * inv_var(c) * weight(c)
  /// Note that this is only a good idea if (input_size >> c), in degenerate
  /// cases where image_size == 1 && batch_size == 1, it is slow.
  for (const auto c : c10::irange(n_channel)) {
    opmath_t mean, invstd;
    if (train) {
      mean = save_mean_a[c];
      invstd = save_invstd_a[c];
    } else {
      mean = running_mean_a[c];
      invstd = 1 / std::sqrt(running_var_a[c] + static_cast<opmath_t>(eps));
    }
    param_t weight_v = weight_data ? weight_data[c] : param_t(1);
    param_t bias_v = bias_data ? bias_data[c] : param_t(0);
    alpha[c] = invstd * weight_v;
    beta[c] = bias_v - mean * alpha[c];
  }
}
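The comment in the listing rewrites batch normalization as a per-channel affine map, output = alpha(c) * input + beta(c). The following is a minimal standalone sketch of that idea, not the ATen implementation. It assumes plain std::vector<double> in place of Tensor, treats an empty weight or bias vector as "undefined", and covers only the eval-mode branch, where invstd is derived from the variance. The names AffineTerms, collect_terms, and batch_norm_reference are hypothetical and exist only for this illustration.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical container for the per-channel terms (not part of ATen).
struct AffineTerms {
  std::vector<double> alpha;  // linear term: invstd(c) * weight(c)
  std::vector<double> beta;   // constant term: bias(c) - mean(c) * alpha(c)
};

// Sketch of the eval-mode branch only: invstd is computed from var and eps.
// An empty weight (or bias) vector stands in for an undefined Tensor and
// falls back to 1 (or 0), mirroring the defaults in the listing above.
AffineTerms collect_terms(const std::vector<double>& mean,
                          const std::vector<double>& var,
                          const std::vector<double>& weight,
                          const std::vector<double>& bias,
                          double eps) {
  const std::size_t n_channel = mean.size();
  assert(var.size() == n_channel);
  AffineTerms t{std::vector<double>(n_channel), std::vector<double>(n_channel)};
  for (std::size_t c = 0; c < n_channel; ++c) {
    const double invstd = 1.0 / std::sqrt(var[c] + eps);
    const double w = weight.empty() ? 1.0 : weight[c];
    const double b = bias.empty() ? 0.0 : bias[c];
    t.alpha[c] = invstd * w;
    t.beta[c] = b - mean[c] * t.alpha[c];
  }
  return t;
}

// Direct evaluation of the unfactored formula, for comparison with
// alpha * x + beta on a single element.
double batch_norm_reference(double x, double mean, double var,
                            double w, double b, double eps) {
  return (x - mean) / std::sqrt(var + eps) * w + b;
}
```

With the terms precomputed once per channel, each element needs only one multiply and one add (alpha[c] * x + beta[c]) instead of a subtraction, a square root, a division, and the affine step. As the source comment notes, this pays off only when there are many elements per channel.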