apply_gels Function — PyTorch Architecture
Architecture documentation for the apply_gels function template in BatchLinearAlgebra.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp lines 2092–2128
template <typename scalar_t>
static void apply_gels(const Tensor& a, Tensor& b, Tensor& infos) {
#if !AT_MAGMA_ENABLED()
TORCH_CHECK(false, "torch.linalg.lstsq: MAGMA library not found in "
"compilation. Please rebuild with MAGMA.");
#else
auto trans = MagmaNoTrans;
auto m = magma_int_cast(a.size(-2), "m");
auto n = magma_int_cast(a.size(-1), "n");
TORCH_CHECK(
m >= n,
"torch.linalg.lstsq: only overdetermined systems (input.size(-2) >= input.size(-1)) are allowed on CUDA");
auto nrhs = magma_int_cast(b.size(-1), "nrhs");
auto ldda = std::max<magma_int_t>(1, m);
auto lddb = std::max<magma_int_t>(1, std::max(m, n));
auto nb = magmaGeqrfOptimalBlocksize<scalar_t>(m, n);
auto lwork = (m - n + nb) * (nrhs + nb) + nrhs * nb;
Tensor hwork = at::empty({static_cast<int64_t>(lwork)}, a.scalar_type());
auto* hwork_ptr = hwork.mutable_data_ptr<scalar_t>();
// MAGMA requires infos tensor to live on CPU
infos = infos.to(at::kCPU);
auto infos_data = infos.data_ptr<magma_int_t>();
batch_iterator_with_broadcasting<scalar_t>(a, b,
[&](scalar_t* a_working_ptr, scalar_t* b_working_ptr,
int64_t a_linear_batch_idx) {
magma_int_t* infos_working_ptr = &infos_data[a_linear_batch_idx];
magmaGels<scalar_t>(trans, m, n, nrhs,
a_working_ptr, ldda, b_working_ptr, lddb,
hwork_ptr, lwork, infos_working_ptr);
}
);
#endif
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free