Home / Function / batch_norm_cpu_backward_channels_last_internal — pytorch Architecture

batch_norm_cpu_backward_channels_last_internal Function — pytorch Architecture

Architecture documentation for the batch_norm_cpu_backward_channels_last_internal function template in batch_norm_kernel.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/batch_norm_kernel.cpp lines 1120–1310

// Channels-last (NHWC) CPU backward kernel for batch normalization.
//
// Fills whichever of the three gradient tensors are defined():
//   grad_input  - gradient w.r.t. the input activations
//   grad_weight - per-channel gradient w.r.t. the affine scale (gamma)
//   grad_bias   - per-channel gradient w.r.t. the affine shift (beta)
//
// scalar_t is the activation dtype, param_t the parameter dtype. All
// accumulation is done in opmath_t via float (kFloat) scratch buffers,
// so reduced-precision inputs are reduced at higher precision.
//
// In train mode the saved batch statistics (save_mean / save_invstd)
// are used; in eval mode the running statistics are used and invstd is
// recomputed as 1 / sqrt(running_var + eps).
template <typename scalar_t, typename param_t>
void batch_norm_cpu_backward_channels_last_internal(Tensor& grad_input, Tensor& grad_weight, Tensor& grad_bias,
    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
    const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd,
    bool train, double eps) {
  using opmath_t = at::opmath_type<scalar_t>;
  using bVec = Vectorized<scalar_t>;  // vector of activation elements
  using fVec = Vectorized<opmath_t>;  // vector of accumulation elements
  // In channels-last layout each of the N spatial/batch positions is a
  // contiguous run of n_channel values, so per-channel work vectorizes
  // along the innermost (channel) dimension.
  int64_t n_channel = input.size(1);
  int64_t N = input.numel() / n_channel;  // elements per channel

  const scalar_t* grad_output_data = grad_output.const_data_ptr<scalar_t>();
  const scalar_t* input_data = input.const_data_ptr<scalar_t>();

  // Output pointers are null when the corresponding gradient was not requested.
  scalar_t* grad_input_data = grad_input.defined() ? grad_input.mutable_data_ptr<scalar_t>() : nullptr;
  param_t* grad_weight_data = grad_weight.defined() ? grad_weight.data_ptr<param_t>() : nullptr;
  param_t* grad_bias_data = grad_bias.defined() ? grad_bias.data_ptr<param_t>() : nullptr;

  // Accessors that tolerate undefined tensors (e.g. no affine weight,
  // or no running stats in train mode).
  auto weight_a = conditional_accessor_1d<const param_t>(weight);
  auto save_mean_a = conditional_accessor_1d<const param_t>(save_mean);
  auto save_invstd_a = conditional_accessor_1d<const param_t>(save_invstd);
  auto running_mean_a = conditional_accessor_1d<const param_t>(running_mean);
  auto running_var_a = conditional_accessor_1d<const param_t>(running_var);

  // use float as acc type
  //
  // Materialize weight / mean / invstd into contiguous float buffers so
  // the vectorized loops below can load them directly with fVec::loadu.
  bool weight_defined = weight.defined();
  Tensor weight_f = at::empty({n_channel}, input.options().dtype(kFloat));
  Tensor mean = at::empty({n_channel}, input.options().dtype(kFloat));
  Tensor invstd = at::empty({n_channel}, input.options().dtype(kFloat));
  opmath_t* weight_data = weight_f.data_ptr<opmath_t>();
  opmath_t* mean_data = mean.data_ptr<opmath_t>();
  opmath_t* invstd_data = invstd.data_ptr<opmath_t>();

  for (const auto c : c10::irange(n_channel)) {
    // Missing affine weight behaves as gamma == 1.
    weight_data[c] = weight_defined ? opmath_t(weight_a[c]) : 1;

    if (train) {
      // Training: use the statistics saved by the forward pass.
      mean_data[c] = save_mean_a[c];
      invstd_data[c] = save_invstd_a[c];
    } else {
      // Evaluation: derive invstd from the running variance.
      mean_data[c] = running_mean_a[c];
      invstd_data[c] = 1 / std::sqrt(running_var_a[c] + eps);
    }
  }

  // Per-thread scratch: buffer[0] holds per-channel sums of dy,
  // buffer[1] holds per-channel dot products of (x - mean) with dy.
  // Each thread accumulates into its own row to avoid contention.
  int num_threads = at::get_num_threads();
  Tensor buffer = at::zeros({2, num_threads, n_channel}, input.options().dtype(kFloat));
  opmath_t* sum_data = buffer.data_ptr<opmath_t>();
  opmath_t* dotp_data = sum_data + num_threads * n_channel;

  // Pass 1: accumulate, per channel c,
  //   sum[c]  = sum_i dy[i][c]
  //   dotp[c] = sum_i (x[i][c] - mean[c]) * dy[i][c]
  at::parallel_for(0, N, 1, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    TORCH_CHECK(tid < num_threads, "expect thread id smaller than ", num_threads, ", got thread id ", tid);
    opmath_t* sum_ptr = sum_data + tid * n_channel;
    opmath_t* dotp_ptr = dotp_data + tid * n_channel;
    for (const auto i : c10::irange(begin, end)) {
      const scalar_t* x_ptr = input_data + i * n_channel;
      const scalar_t* dy_ptr = grad_output_data + i * n_channel;

      // Vectorized main loop over the largest multiple of bVec::size();
      // convert_to_float splits one scalar_t vector into two opmath_t
      // half-vectors (hence the paired fvec0/fvec1 operations).
      int64_t d = 0;
      for(; d < n_channel - (n_channel % bVec::size()); d += bVec::size()) {
        bVec dy_bvec = bVec::loadu(dy_ptr + d);
        auto [dy_fvec0, dy_fvec1] = convert_to_float<scalar_t>(dy_bvec);
        fVec sum_fvec0 = dy_fvec0 + fVec::loadu(sum_ptr + d);
        fVec sum_fvec1 = dy_fvec1 + fVec::loadu(sum_ptr + d + fVec::size());
        sum_fvec0.store(sum_ptr + d);
        sum_fvec1.store(sum_ptr + d + fVec::size());

        bVec x_bvec = bVec::loadu(x_ptr + d);
        auto [x_fvec0, x_fvec1] = convert_to_float<scalar_t>(x_bvec);
        fVec mean_fvec0 = fVec::loadu(mean_data + d);
        fVec mean_fvec1 = fVec::loadu(mean_data + d + fVec::size());
        fVec dotp_fvec0 = fVec::loadu(dotp_ptr + d);
        fVec dotp_fvec1 = fVec::loadu(dotp_ptr + d + fVec::size());
        dotp_fvec0 += (x_fvec0 - mean_fvec0) * dy_fvec0;
        dotp_fvec1 += (x_fvec1 - mean_fvec1) * dy_fvec1;
        dotp_fvec0.store(dotp_ptr + d);
        dotp_fvec1.store(dotp_ptr + d + fVec::size());
      }
      // Scalar tail for the remaining channels.
      for (; d < n_channel; d++) {
        opmath_t dy_val = dy_ptr[d];
        opmath_t x_val = x_ptr[d];
        opmath_t mean_val = mean_data[d];
        sum_ptr[d] += dy_val;
        dotp_ptr[d] += (x_val - mean_val) * dy_val;
      }
    }
  });

  // Pass 2: reduce the per-thread partial results across threads.
  at::parallel_for(0, n_channel, 1, [&](int64_t begin, int64_t end) {
    for (const auto c : c10::irange(begin, end)) {
      // store the final result of sum and dotp in the first slice (thread 0's
      // row) of the intermediate buffer, so that we won't need to allocate
      // another buffer to store the reduced values.
      opmath_t _sum = 0;
      for (const auto t : c10::irange(num_threads)) {
        _sum += sum_data[t * n_channel + c];
      }
      sum_data[/* 0 * n_channel + */c] = _sum;

      opmath_t _dotp = 0;
      for (const auto t : c10::irange(num_threads)) {
        _dotp += dotp_data[t * n_channel + c];
      }
      dotp_data[/* 0 * n_channel + */c] = _dotp;
    }
  });

  // compute grad_input
  //
  // Training:   dx = (dy - sum/N - (x - mean) * k) * invstd * w
  //             with k = dotp * invstd^2 / N
  //             (standard batch-norm backward; the sum/N and k terms are
  //             the gradients flowing through the batch mean and variance)
  // Evaluation: statistics are constants, so dx = dy * invstd * w
  if (grad_input.defined()) {
    at::parallel_for(0, N, 1, [&](int64_t begin, int64_t end) {
      for (const auto i : c10::irange(begin, end)) {
        scalar_t* dx_ptr = grad_input_data + i * n_channel;
        const scalar_t* x_ptr = input_data + i * n_channel;
        const scalar_t* dy_ptr = grad_output_data + i * n_channel;
        if (train) {
          int64_t d = 0;
          for (; d < n_channel - (n_channel % bVec::size()); d += bVec::size()) {
            bVec x_bvec = bVec::loadu(x_ptr + d);
            auto [x_fvec0, x_fvec1] = convert_to_float<scalar_t>(x_bvec);
            fVec mean_fvec0 = fVec::loadu(mean_data + d);
            fVec mean_fvec1 = fVec::loadu(mean_data + d + fVec::size());
            fVec dotp_fvec0 = fVec::loadu(dotp_data + d);
            fVec dotp_fvec1 = fVec::loadu(dotp_data + d + fVec::size());
            fVec invstd_fvec0 = fVec::loadu(invstd_data + d);
            fVec invstd_fvec1 = fVec::loadu(invstd_data + d + fVec::size());
            // k = dotp * invstd^2 / N
            fVec k_fvec0 = dotp_fvec0 * invstd_fvec0 * invstd_fvec0 / fVec(N);
            fVec k_fvec1 = dotp_fvec1 * invstd_fvec1 * invstd_fvec1 / fVec(N);
            fVec dx_fvec0 = (x_fvec0 - mean_fvec0) * k_fvec0;
            fVec dx_fvec1 = (x_fvec1 - mean_fvec1) * k_fvec1;
            bVec dy_bvec = bVec::loadu(dy_ptr + d);
            auto [dy_fvec0, dy_fvec1] = convert_to_float<scalar_t>(dy_bvec);
            // grad_mean = sum / N
            fVec grad_mean_fvec0 = fVec::loadu(sum_data + d) / fVec(N);
            fVec grad_mean_fvec1 = fVec::loadu(sum_data + d + fVec::size()) / fVec(N);
            fVec w_fvec0 = fVec::loadu(weight_data + d);
            fVec w_fvec1 = fVec::loadu(weight_data + d + fVec::size());
            dx_fvec0 = (dy_fvec0 - grad_mean_fvec0 - dx_fvec0) * invstd_fvec0 * w_fvec0;
            dx_fvec1 = (dy_fvec1 - grad_mean_fvec1 - dx_fvec1) * invstd_fvec1 * w_fvec1;
            bVec dx_bvec = convert_from_float<scalar_t>(dx_fvec0, dx_fvec1);
            dx_bvec.store(dx_ptr + d);
          }
          // Scalar tail: same formula, one channel at a time.
          for (; d < n_channel; d++) {
            opmath_t x_val = x_ptr[d];
            opmath_t mean_val = mean_data[d];
            opmath_t dotp_val = dotp_data[d];
            opmath_t invstd_val = invstd_data[d];
            opmath_t k_val = dotp_val * invstd_val * invstd_val / N;
            opmath_t dx_val = (x_val - mean_val) * k_val;
            opmath_t dy_val = dy_ptr[d];
            opmath_t grad_mean_val = sum_data[d] / N;
            opmath_t w_val = weight_data[d];
            dx_val = (dy_val - grad_mean_val - dx_val) * invstd_val * w_val;
            dx_ptr[d] = scalar_t(dx_val);
          }
        } else { // evaluation mode
          int64_t d = 0;
          for (; d < n_channel - (n_channel % bVec::size()); d += bVec::size()) {
            bVec dy_bvec = bVec::loadu(dy_ptr + d);
            auto [dy_fvec0, dy_fvec1] = convert_to_float<scalar_t>(dy_bvec);
            fVec invstd_fvec0 = fVec::loadu(invstd_data + d);
            fVec invstd_fvec1 = fVec::loadu(invstd_data + d + fVec::size());
            fVec w_fvec0 = fVec::loadu(weight_data + d);
            fVec w_fvec1 = fVec::loadu(weight_data + d + fVec::size());
            fVec dx_fvec0 = dy_fvec0 * invstd_fvec0 * w_fvec0;
            fVec dx_fvec1 = dy_fvec1 * invstd_fvec1 * w_fvec1;
            bVec dx_bvec = convert_from_float<scalar_t>(dx_fvec0, dx_fvec1);
            dx_bvec.store(dx_ptr + d);
          }
          // Scalar tail for the remaining channels.
          for (; d < n_channel; d++) {
            opmath_t dy_val = dy_ptr[d];
            opmath_t invstd_val = invstd_data[d];
            opmath_t w_val = weight_data[d];
            opmath_t dx_val = dy_val * invstd_val * w_val;
            dx_ptr[d] = scalar_t(dx_val);
          }
        }
      }
    });
  }

  // grad_weight[c] = dotp[c] * invstd[c]  (i.e. sum of dy * x_hat)
  if (grad_weight.defined()) {
    for (const auto c : c10::irange(n_channel)) {
      grad_weight_data[c] = param_t(dotp_data[c] * invstd_data[c]);
    }
  }

  // grad_bias[c] = sum of dy over the channel
  if (grad_bias.defined()) {
    for (const auto c : c10::irange(n_channel)) {
      grad_bias_data[c] = param_t(sum_data[c]);
    }
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free