
batch_norm_cpu_collect_linear_and_constant_terms Class — pytorch Architecture

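Architecture documentation for the batch_norm_cpu_collect_linear_and_constant_terms function template in batch_norm_kernel.cpp from the pytorch codebase. The function folds the batch-norm statistics and the optional affine parameters into two per-channel coefficients, a linear term alpha and a constant term beta.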


Source Code

aten/src/ATen/native/cpu/batch_norm_kernel.cpp lines 30–70

```cpp
template<typename param_t, typename opmath_t>
void batch_norm_cpu_collect_linear_and_constant_terms(
    opmath_t* alpha, opmath_t* beta, int64_t n_channel,
    const Tensor& weight /* optional */, const Tensor& bias /* optional */,
    const Tensor& save_mean, const Tensor& save_invstd,
    const Tensor& running_mean, const Tensor& running_var, bool train, double eps) {

  const param_t* weight_data = weight.defined() ? weight.const_data_ptr<param_t>() : nullptr;
  const param_t* bias_data = bias.defined() ? bias.const_data_ptr<param_t>() : nullptr;

  auto save_mean_a = conditional_accessor_1d<const param_t>(save_mean);
  auto save_invstd_a = conditional_accessor_1d<const param_t>(save_invstd);
  auto running_mean_a = conditional_accessor_1d<const param_t>(running_mean);
  auto running_var_a = conditional_accessor_1d<const param_t>(running_var);

  /// Collect the linear and constant terms regarding the input.
  /// output(n, c, h, w)
  ///     = (input(n, c, h, w) - mean(c)) / sqrt(var(c) + eps) * weight(c)
  ///         + bias(c)
  ///     = input(n, c, h, w) * inv_var(c) * weight(c)
  ///         - mean(c) * inv_var(c) * weight(c) + bias(c),
  /// where inv_var(c) = 1 / sqrt(var(c) + eps).
  /// So the linear term, alpha(c) = inv_var(c) * weight(c),
  ///   the constant term beta(c) = bias(c) - mean(c) * inv_var(c) * weight(c)
  /// Note that this is only a good idea if (input_size >> c), in degenerate
  /// cases where image_size == 1 && batch_size == 1, it is slow.
  for (const auto c : c10::irange(n_channel)) {
    opmath_t mean, invstd;
    if (train) {
      mean = save_mean_a[c];
      invstd = save_invstd_a[c];
    } else {
      mean = running_mean_a[c];
      invstd = 1 / std::sqrt(running_var_a[c] + static_cast<opmath_t>(eps));
    }
    param_t weight_v = weight_data ? weight_data[c] : param_t(1);
    param_t bias_v = bias_data ? bias_data[c] : param_t(0);
    alpha[c] = invstd * weight_v;
    beta[c] = bias_v - mean * alpha[c];
  }
}
```
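A worked numeric check of the fold described in the code comment, using made-up values for a single channel in evaluation mode: running mean 2, running variance 3, eps = 1 (chosen only so the square root is exact), weight 4, bias 1, and input x = 5. The direct formula and the folded form alpha(c) * x + beta(c) should agree.

```latex
\begin{aligned}
\mathrm{inv\_var} &= \frac{1}{\sqrt{3 + 1}} = \tfrac{1}{2}, \\
\alpha &= \mathrm{inv\_var} \cdot \mathrm{weight} = \tfrac{1}{2} \cdot 4 = 2, \\
\beta &= \mathrm{bias} - \mathrm{mean} \cdot \alpha = 1 - 2 \cdot 2 = -3, \\
\text{direct: } & (5 - 2) \cdot \tfrac{1}{2} \cdot 4 + 1 = 7, \\
\text{folded: } & 5 \cdot \alpha + \beta = 10 - 3 = 7.
\end{aligned}
```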

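The loop can be mirrored outside ATen as a minimal C++ sketch, assuming plain `std::vector` buffers in place of `at::Tensor` and `std::optional` in place of an undefined weight or bias tensor. The names `collect_alpha_beta` and `apply_folded` are hypothetical and are not PyTorch APIs. The second function only illustrates how the folded coefficients would be consumed on a contiguous NCHW buffer according to the derivation in the comment; it does not reproduce PyTorch's vectorized apply loop.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical stand-in for the ATen kernel: vectors replace tensors, and an
// empty std::optional plays the role of an undefined weight/bias tensor.
template <typename param_t, typename opmath_t>
void collect_alpha_beta(std::vector<opmath_t>& alpha, std::vector<opmath_t>& beta,
                        const std::optional<std::vector<param_t>>& weight,
                        const std::optional<std::vector<param_t>>& bias,
                        const std::vector<param_t>& save_mean,
                        const std::vector<param_t>& save_invstd,
                        const std::vector<param_t>& running_mean,
                        const std::vector<param_t>& running_var,
                        bool train, double eps) {
  assert(alpha.size() == beta.size());
  for (std::size_t c = 0; c < alpha.size(); ++c) {
    opmath_t mean, invstd;
    if (train) {
      // Training: reuse the batch statistics saved by the forward pass.
      mean = save_mean[c];
      invstd = save_invstd[c];
    } else {
      // Inference: derive 1 / sqrt(var + eps) from the running variance.
      mean = running_mean[c];
      invstd = 1 / std::sqrt(static_cast<opmath_t>(running_var[c]) +
                             static_cast<opmath_t>(eps));
    }
    // A missing affine parameter behaves like weight = 1, bias = 0.
    const opmath_t w = weight ? static_cast<opmath_t>((*weight)[c]) : opmath_t(1);
    const opmath_t b = bias ? static_cast<opmath_t>((*bias)[c]) : opmath_t(0);
    alpha[c] = invstd * w;          // linear term
    beta[c] = b - mean * alpha[c];  // constant term
  }
}

// Apply y = x * alpha[c] + beta[c] to a contiguous NCHW buffer, where hw is
// the number of spatial elements per channel (H * W).
template <typename opmath_t>
std::vector<opmath_t> apply_folded(const std::vector<opmath_t>& input, std::size_t batch,
                                   const std::vector<opmath_t>& alpha,
                                   const std::vector<opmath_t>& beta, std::size_t hw) {
  const std::size_t n_channel = alpha.size();
  assert(input.size() == batch * n_channel * hw);
  std::vector<opmath_t> out(input.size());
  for (std::size_t n = 0; n < batch; ++n) {
    for (std::size_t c = 0; c < n_channel; ++c) {
      const std::size_t base = (n * n_channel + c) * hw;
      for (std::size_t i = 0; i < hw; ++i) {
        out[base + i] = input[base + i] * alpha[c] + beta[c];
      }
    }
  }
  return out;
}
```

With the worked values above (eval mode, `param_t = float`, `opmath_t = double`), `collect_alpha_beta` yields alpha = 2 and beta = -3, and `apply_folded` maps an input of 5 to 7. Precomputing the coefficients costs O(C) work up front, after which each element needs one multiply and one add instead of a subtract, two multiplies, and an add. The trade pays off only when each channel spans many elements, which is consistent with the source comment's warning about the degenerate `image_size == 1 && batch_size == 1` case.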