Home / Class / apply_lstsq Class — pytorch Architecture

apply_lstsq Class — pytorch Architecture

Architecture documentation for the apply_lstsq function template in BatchLinearAlgebraKernel.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/BatchLinearAlgebraKernel.cpp lines 543–680

template <typename scalar_t>
void apply_lstsq(const Tensor& A, Tensor& B, Tensor& rank, Tensor& singular_values, Tensor& infos, double rcond, LapackLstsqDriverType driver_type) {
#if !AT_BUILD_WITH_LAPACK()
  TORCH_CHECK(
      false,
      "Calling torch.linalg.lstsq on a CPU tensor requires compiling ",
      "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
#else
  using value_t = typename c10::scalar_value_type<scalar_t>::type;
  using driver_t = at::native::LapackLstsqDriverType;

  auto lapack_func = lapackLstsq<driver_t::Gelsd, scalar_t, value_t>;
  static auto driver_type_to_func
    = std::unordered_map<driver_t, decltype(lapack_func)>({
    {driver_t::Gels, lapackLstsq<driver_t::Gels, scalar_t, value_t>},
    {driver_t::Gelsy, lapackLstsq<driver_t::Gelsy, scalar_t, value_t>},
    {driver_t::Gelsd, lapackLstsq<driver_t::Gelsd, scalar_t, value_t>},
    {driver_t::Gelss, lapackLstsq<driver_t::Gelss, scalar_t, value_t>}
  });
  lapack_func = driver_type_to_func[driver_type];

  char trans = 'N';

  auto A_data = A.data_ptr<scalar_t>();
  auto B_data = B.data_ptr<scalar_t>();
  auto m = A.size(-2);
  auto n = A.size(-1);
  auto nrhs = B.size(-1);
  auto lda = std::max<int64_t>(1, m);
  auto ldb = std::max<int64_t>({static_cast<int64_t>(1), m, n});
  auto infos_data = infos.data_ptr<int>();

  // only 'gels' driver does not compute the rank
  int rank_32 = 0;
  int64_t* rank_data = nullptr;
  int64_t* rank_working_ptr = nullptr;
  if (driver_t::Gels != driver_type) {
    rank_data = rank.data_ptr<int64_t>();
    rank_working_ptr = rank_data;
  }

  // 'gelsd' and 'gelss' are SVD-based algorithms
  // so we can get singular values
  value_t* s_data = nullptr;
  value_t* s_working_ptr = nullptr;
  int64_t s_stride = 0;
  if (driver_t::Gelsd == driver_type || driver_t::Gelss == driver_type) {
    s_data = singular_values.data_ptr<value_t>();
    s_working_ptr = s_data;
    s_stride = singular_values.size(-1);
  }

  // 'jpvt' workspace array is used only for 'gelsy' which uses QR factorization with column pivoting
  Tensor jpvt;
  int* jpvt_data = nullptr;
  if (driver_t::Gelsy == driver_type) {
    jpvt = at::empty({std::max<int64_t>(1, n)}, A.options().dtype(at::kInt));
    jpvt_data = jpvt.mutable_data_ptr<int>();
  }

  // Run once the driver, first to get the optimal workspace size
  int lwork = -1; // default value to decide the opt size for workspace arrays
  scalar_t work_opt;
  value_t rwork_opt;
  int iwork_opt = 0;
  lapack_func(trans, m, n, nrhs,
    A_data, lda,
    B_data, ldb,
    &work_opt, lwork,
    infos_data,
    jpvt_data,
    static_cast<value_t>(rcond),
    &rank_32,
    &rwork_opt,
    s_working_ptr,
    &iwork_opt);

  lwork = lapack_work_to_int(work_opt);
  Tensor work = at::empty({lwork}, A.options());
  scalar_t* work_data = work.mutable_data_ptr<scalar_t>();

  // 'rwork' only used for complex inputs and 'gelsy', 'gelsd' and 'gelss' drivers
  Tensor rwork;
  value_t* rwork_data = nullptr;
  if (A.is_complex() && driver_t::Gels != driver_type) {
    int64_t rwork_len = 0;
    switch (driver_type) {
      case driver_t::Gelsy:
        rwork_len = std::max<int64_t>(1, 2 * n);
        break;
      case driver_t::Gelss:
        rwork_len = std::max<int64_t>(1, 5 * std::min(m, n));
        break;
      // case driver_t::Gelsd:
      default:
        rwork_len = std::max<int64_t>(1, rwork_opt);
    }
    rwork = at::empty({rwork_len}, A.options().dtype(c10::toRealValueType(A.scalar_type())));
    rwork_data = rwork.mutable_data_ptr<value_t>();
  }

  // 'iwork' workspace array is relevant only for 'gelsd'
  Tensor iwork;
  int* iwork_data = nullptr;
  if (driver_t::Gelsd == driver_type) {
    iwork = at::empty({std::max<int>(1, iwork_opt)}, A.options().dtype(at::kInt));
    iwork_data = iwork.mutable_data_ptr<int>();
  }

  at::native::batch_iterator_with_broadcasting<scalar_t>(A, B,
    [&](scalar_t* A_working_ptr, scalar_t* B_working_ptr, int64_t A_linear_batch_idx) {
      rank_working_ptr = rank_working_ptr ? &rank_data[A_linear_batch_idx] : nullptr;
      s_working_ptr = s_working_ptr ? &s_data[A_linear_batch_idx * s_stride] : nullptr;
      int* infos_working_ptr = &infos_data[A_linear_batch_idx];

      lapack_func(trans, m, n, nrhs,
        A_working_ptr, lda,
        B_working_ptr, ldb,
        work_data, lwork,
        infos_working_ptr,
        jpvt_data,
        static_cast<value_t>(rcond),
        &rank_32,
        rwork_data,
        s_working_ptr,
        iwork_data);

      // we want the output `rank` Tensor to be of type int64_t,
      // however LAPACK accepts int. That is why we use an integer
      // variable that then gets promoted and written into `rank`.
      // We use this approach over a tensor cast for better performance.
      if (rank_working_ptr) {
        *rank_working_ptr = static_cast<int64_t>(rank_32);
      }
    }
  );
#endif
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free