Home / Class / batch_norm_cpu_backward_contiguous_internal Class — PyTorch Architecture

batch_norm_cpu_backward_contiguous_internal Class — pytorch Architecture

Architecture documentation for the batch_norm_cpu_backward_contiguous_internal class in batch_norm_kernel.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/batch_norm_kernel.cpp lines 996–1102

// Backward pass of batch normalization for contiguous (channels-first, NC[HW])
// input on CPU. For each channel c it reduces two per-channel statistics over
// all N = n_batch * image_size elements:
//   sum  = Σ dY                  (gradient sum)
//   dotp = Σ (x - mean) * dY     (dot product of centered input and grad_output)
// and uses them to fill any of grad_input / grad_weight / grad_bias that the
// caller passed as defined tensors (undefined tensors are skipped).
//
// Template params:
//   scalar_t - element type of input/grad tensors (may be a reduced-precision
//              type; math is done in at::opmath_type<scalar_t>).
//   param_t  - element type of the parameter tensors (weight, running stats).
// Args:
//   train - if true, use save_mean/save_invstd computed in the forward pass;
//           otherwise derive invstd from running_var and eps.
//   eps   - numerical-stability epsilon, only used in eval mode here.
template <typename scalar_t, typename param_t>
void batch_norm_cpu_backward_contiguous_internal(Tensor& grad_input, Tensor& grad_weight, Tensor& grad_bias,
    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
    const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd,
    bool train, double eps) {
  using opmath_t = at::opmath_type<scalar_t>;
  using bVec = Vectorized<scalar_t>;   // vector of storage-type elements
  using fVec = Vectorized<opmath_t>;   // vector of accumulation-type elements
  int64_t n_batch = input.size(0);
  int64_t n_channel = input.size(1);
  // Spatial size per (batch, channel) slice; contiguity is assumed, so a slice
  // is a flat run of image_size elements (see the pointer arithmetic below).
  int64_t image_size = input.numel() / n_batch / n_channel;
  int64_t N = input.numel() / n_channel;  // reduction count per channel

  const scalar_t* grad_output_data = grad_output.const_data_ptr<scalar_t>();
  const scalar_t* input_data = input.const_data_ptr<scalar_t>();

  // Each output tensor is optional; a nullptr means "don't compute this grad".
  scalar_t* grad_input_data = grad_input.defined() ? grad_input.mutable_data_ptr<scalar_t>() : nullptr;
  param_t* grad_weight_data = grad_weight.defined() ? grad_weight.data_ptr<param_t>() : nullptr;
  param_t* grad_bias_data = grad_bias.defined() ? grad_bias.data_ptr<param_t>() : nullptr;
  const bool grad_input_null = grad_input_data == nullptr;
  const bool grad_weight_null = grad_weight_data == nullptr;
  const bool grad_bias_null = grad_bias_data == nullptr;

  // 1-d accessors that tolerate undefined tensors (weight and the stats
  // tensors may be absent depending on train/affine configuration).
  auto weight_a = conditional_accessor_1d<const param_t>(weight);
  auto save_mean_a = conditional_accessor_1d<const param_t>(save_mean);
  auto save_invstd_a = conditional_accessor_1d<const param_t>(save_invstd);
  auto running_mean_a = conditional_accessor_1d<const param_t>(running_mean);
  auto running_var_a = conditional_accessor_1d<const param_t>(running_var);

  // parallel dim reduce on 'channel'
  at::parallel_for(0, n_channel, 1, [&](int64_t begin, int64_t end) {
    for (const auto c : c10::irange(begin, end)) {
      // Missing weight behaves like an all-ones scale.
      opmath_t w = weight.defined() ? opmath_t(weight_a[c]) : 1;

      opmath_t mean, invstd;
      if (train) {
        // Training: reuse the statistics saved by the forward pass.
        mean = save_mean_a[c];
        invstd = save_invstd_a[c];
      } else {
        // Eval: derive invstd from the running variance.
        mean = running_mean_a[c];
        invstd = 1 / std::sqrt(running_var_a[c] + eps);
      }

      // compute 1) sum; 2) dot product of Q(X) and dY.
      // Scalar accumulators handle the vector-width tail; the fVec
      // accumulators carry partial sums across the whole batch loop and are
      // horizontally reduced once afterwards.
      opmath_t sum{0}, dotp{0};
      fVec sum_fvec{0}, dotp_fvec{0};
      for (const auto n : c10::irange(n_batch)) {
        // Contiguous NC[HW] addressing: slice (n, c) starts at
        // (n * n_channel + c) * image_size.
        const scalar_t* x_ptr = input_data + n * n_channel * image_size + c * image_size;
        const scalar_t* dy_ptr = grad_output_data + n * n_channel * image_size + c * image_size;

        int64_t d = 0;
        // Vectorized main loop over full bVec-sized chunks. Each storage
        // vector converts to two float vectors (reduced precision widens 2:1;
        // for float the second half is just the upper lanes).
        for (; d < image_size - (image_size % bVec::size()); d += bVec::size()) {
          bVec dy_bvec = bVec::loadu(dy_ptr + d);
          auto [dy_fvec0, dy_fvec1] = convert_to_float<scalar_t>(dy_bvec);
          sum_fvec += dy_fvec0;
          sum_fvec += dy_fvec1;

          bVec x_bvec = bVec::loadu(x_ptr + d);
          auto [x_fvec0, x_fvec1] = convert_to_float<scalar_t>(x_bvec);
          dotp_fvec += (x_fvec0 - fVec(mean)) * dy_fvec0;
          dotp_fvec += (x_fvec1 - fVec(mean)) * dy_fvec1;
        }
        // Scalar tail for the remaining image_size % bVec::size() elements.
        for (; d < image_size; d++) {
          sum += opmath_t(dy_ptr[d]);
          dotp += (opmath_t(x_ptr[d]) - mean) * opmath_t(dy_ptr[d]);
        }
      }
      // TODO: use fast version
      // Horizontal reduction of the vector accumulators into the scalar totals.
      sum += vec_reduce_all([](fVec& x, fVec& y) { return x + y; }, sum_fvec, fVec::size());
      dotp += vec_reduce_all([](fVec& x, fVec& y) { return x + y; }, dotp_fvec, fVec::size());

      if (!grad_input_null) {
        if (train) {
          // Training-mode input gradient:
          //   dX = (dY - mean(dY) - (x - mean) * dotp * invstd^2 / N) * invstd * w
          // i.e. the standard BN backward that accounts for the gradients of
          // the batch mean and variance.
          opmath_t k = dotp * invstd * invstd / N;
          opmath_t grad_mean = sum / N;
          for (const auto n : c10::irange(n_batch)) {
            const scalar_t* x_ptr = input_data + n * n_channel * image_size + c * image_size;
            scalar_t* dx_ptr = grad_input_data + n * n_channel * image_size + c * image_size;
            const scalar_t* dy_ptr = grad_output_data + n * n_channel * image_size + c * image_size;
            vec::map2(
                [=](fVec x, fVec dy) {
                  fVec dx = (x - fVec(mean)) * fVec(k);
                  return (dy - fVec(grad_mean) - dx) * fVec(invstd) * fVec(w);
                },
                dx_ptr, x_ptr, dy_ptr, image_size);
          }
        } else { // evaluation mode
          // Running stats are constants w.r.t. the input, so the gradient is
          // a plain elementwise scale: dX = dY * invstd * w.
          for (const auto n : c10::irange(n_batch)) {
            scalar_t* dx_ptr = grad_input_data + n * n_channel * image_size + c * image_size;
            const scalar_t* dy_ptr = grad_output_data + n * n_channel * image_size + c * image_size;
            vec::map(
                [=](fVec dy) { return dy * fVec(invstd) * fVec(w); },
                dx_ptr, dy_ptr, image_size);
          }
        }
      }

      if (!grad_weight_null) {
        // d(loss)/d(weight)[c] = Σ (x - mean) * invstd * dY = dotp * invstd
        grad_weight_data[c] = param_t(dotp * invstd);
      }

      if (!grad_bias_null) {
        // d(loss)/d(bias)[c] = Σ dY
        grad_bias_data[c] = param_t(sum);
      }
    }
  });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free