
apply_gels Function — pytorch Architecture

Architecture documentation for the apply_gels function in BatchLinearAlgebra.cpp from the pytorch codebase.

Entity Profile
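
apply_gels is a file-local function template rather than a class. It implements the MAGMA-backed CUDA path for torch.linalg.lstsq: for each matrix in a (possibly broadcast) batch it calls MAGMA's gels routine to solve an overdetermined least-squares system, writing the solution into b and a per-batch MAGMA status code into infos. Only overdetermined problems (m >= n) are accepted, and builds without MAGMA fail with an explicit error.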

Source Code

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp lines 2092–2128

template <typename scalar_t>
static void apply_gels(const Tensor& a, Tensor& b, Tensor& infos) {
#if !AT_MAGMA_ENABLED()
  TORCH_CHECK(false, "torch.linalg.lstsq: MAGMA library not found in "
    "compilation. Please rebuild with MAGMA.");
#else
  auto trans = MagmaNoTrans;
  auto m = magma_int_cast(a.size(-2), "m");
  auto n = magma_int_cast(a.size(-1), "n");

  TORCH_CHECK(
    m >= n,
    "torch.linalg.lstsq: only overdetermined systems (input.size(-2) >= input.size(-1)) are allowed on CUDA");

  auto nrhs = magma_int_cast(b.size(-1), "nrhs");
  auto ldda = std::max<magma_int_t>(1, m);
  auto lddb = std::max<magma_int_t>(1, std::max(m, n));
  auto nb = magmaGeqrfOptimalBlocksize<scalar_t>(m, n);
  auto lwork = (m - n + nb) * (nrhs + nb) + nrhs * nb;
  Tensor hwork = at::empty({static_cast<int64_t>(lwork)}, a.scalar_type());
  auto* hwork_ptr = hwork.mutable_data_ptr<scalar_t>();

  // MAGMA requires infos tensor to live on CPU
  infos = infos.to(at::kCPU);
  auto infos_data = infos.data_ptr<magma_int_t>();

  batch_iterator_with_broadcasting<scalar_t>(a, b,
    [&](scalar_t* a_working_ptr, scalar_t* b_working_ptr,
      int64_t a_linear_batch_idx) {
      magma_int_t* infos_working_ptr = &infos_data[a_linear_batch_idx];
      magmaGels<scalar_t>(trans, m, n, nrhs,
        a_working_ptr, ldda, b_working_ptr, lddb,
        hwork_ptr, lwork, infos_working_ptr);
    }
  );
#endif
}
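
Reading the listing: the function queries the MAGMA-optimal QR block size nb, allocates a host workspace of (m - n + nb) * (nrhs + nb) + nrhs * nb elements of the working scalar type, moves infos to the CPU because MAGMA writes status codes through a host pointer, and then lets batch_iterator_with_broadcasting walk the broadcast batch dimensions, invoking magmaGels once per matrix pair. The sketch below shows, under stated assumptions, how such a templated kernel is typically reached from a type-erased wrapper through ATen's dtype dispatch macros; the wrapper name gels_magma_sketch is purely illustrative and is not quoted from the PyTorch source.

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

using at::Tensor;

// Hypothetical wrapper (illustrative name, not from the PyTorch source):
// dispatch on the input dtype so that the templated apply_gels<scalar_t>
// kernel is instantiated with the matching C++ scalar type.
static void gels_magma_sketch(const Tensor& a, Tensor& b, Tensor& infos) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(a.scalar_type(), "gels_magma_sketch", [&] {
    apply_gels<scalar_t>(a, b, infos);
  });
}

The per-batch status codes collected in infos are typically checked by the caller after the batched MAGMA calls, so that a nonzero value for any batch element can be surfaced to the torch.linalg.lstsq user as an error.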
