Home / Class / run_backward_parallel_cdist Class — PyTorch Architecture

run_backward_parallel_cdist Class — PyTorch Architecture

Architecture documentation for run_backward_parallel_cdist, a static member function template (listed here under "Class"), defined in DistanceOpsKernel.cpp in the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/DistanceOpsKernel.cpp lines 356–390

  // Computes the cdist backward into `result` on the CPU.
  //
  // The last dimension of t1/t2 (size m) is split into groups of Vec::size()
  // adjacent columns:
  //   * Every complete group is handed to backward_down_column_cdist<F>.
  //     The groups are distributed across threads with at::parallel_for.
  //   * If m is not a multiple of Vec::size(), one more call afterwards
  //     handles the m % Vec::size() leftover columns, passing that count
  //     explicitly.
  //
  //   F      - functor forwarded unchanged to backward_down_column_cdist
  //            (presumably selects the p-norm gradient variant; TODO confirm).
  //   result - written through result.data_ptr(). It is offset by the same
  //            column offsets as t1, so it presumably holds the gradient
  //            w.r.t. t1 (verify against caller). result.size(0) is passed
  //            down as `d`.
  //   grad   - incoming gradient. It must be contiguous (see `gs` below).
  //   t1, t2 - inputs. size(-2) gives the row counts r1 / r2, and
  //            t1.size(-1) gives the column count m.
  //   p      - the p value, broadcast into a Vec for the column kernel.
  //   dist   - forward distances, read-only.
  template <typename F>
  static void run_backward_parallel_cdist(Tensor& result, const Tensor & grad, const Tensor & t1, const Tensor & t2, const scalar_t p, const Tensor& dist) {
    const int64_t r1 = t1.size(-2);
    const int64_t r2 = t2.size(-2);
    const int64_t m = t1.size(-1);
    const int64_t d = result.size(0);
    const int64_t l1_size = r1 * m;
    const int64_t l2_size = r2 * m;
    // Only gradients that collapse to 1D are supported. Rather than checking
    // that here, grad is made contiguous before backward, so its stride is 1.
    // grad.stride(-1) is deliberately not consulted: when the last dimension
    // has size 1, the reported stride can be bogus.
    const int64_t gs = 1;

    const scalar_t * const grad_start = grad.const_data_ptr<scalar_t>();
    const scalar_t * const dist_start = dist.const_data_ptr<scalar_t>();
    const scalar_t * const t1_start = t1.const_data_ptr<scalar_t>();
    const scalar_t * const t2_start = t2.const_data_ptr<scalar_t>();
    scalar_t * const res_start = result.data_ptr<scalar_t>();

    // How many complete Vec::size()-wide column groups fit in m, and how many
    // columns are left over after them.
    const int64_t full_groups = m / Vec::size();
    const int64_t tail_cols = m % Vec::size();
    // Parallel grain (in column groups). It shrinks as r1 grows.
    const int64_t grain = internal::GRAIN_SIZE / (16 * r1);

    at::parallel_for(0, full_groups, grain, [=](int64_t group_begin, int64_t group_end) {
      const Vec p_vec(p);
      for (int64_t group = group_begin; group < group_end; ++group) {
        // Element offset of this group's first column from each buffer's start.
        const int64_t col = group * Vec::size();
        backward_down_column_cdist<F>(t1_start + col, t2_start + col, res_start + col, grad_start, dist_start, p_vec, r1, r2, m, d, gs, l1_size, l2_size);
      }
    });

    if (tail_cols != 0) {
      // The leftover columns begin immediately after the last complete group.
      const int64_t tail_col = m - tail_cols;
      backward_down_column_cdist<F>(t1_start + tail_col, t2_start + tail_col, res_start + tail_col, grad_start, dist_start, Vec(p), r1, r2, m, d, gs, l1_size, l2_size, tail_cols);
    }
  }

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free