apply_lstsq Function Template — pytorch Architecture
Architecture documentation for the apply_lstsq function template in BatchLinearAlgebraKernel.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/BatchLinearAlgebraKernel.cpp lines 543–680
template <typename scalar_t>
void apply_lstsq(const Tensor& A, Tensor& B, Tensor& rank, Tensor& singular_values, Tensor& infos, double rcond, LapackLstsqDriverType driver_type) {
#if !AT_BUILD_WITH_LAPACK()
TORCH_CHECK(
false,
"Calling torch.linalg.lstsq on a CPU tensor requires compiling ",
"PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
#else
using value_t = typename c10::scalar_value_type<scalar_t>::type;
using driver_t = at::native::LapackLstsqDriverType;
auto lapack_func = lapackLstsq<driver_t::Gelsd, scalar_t, value_t>;
static auto driver_type_to_func
= std::unordered_map<driver_t, decltype(lapack_func)>({
{driver_t::Gels, lapackLstsq<driver_t::Gels, scalar_t, value_t>},
{driver_t::Gelsy, lapackLstsq<driver_t::Gelsy, scalar_t, value_t>},
{driver_t::Gelsd, lapackLstsq<driver_t::Gelsd, scalar_t, value_t>},
{driver_t::Gelss, lapackLstsq<driver_t::Gelss, scalar_t, value_t>}
});
lapack_func = driver_type_to_func[driver_type];
char trans = 'N';
auto A_data = A.data_ptr<scalar_t>();
auto B_data = B.data_ptr<scalar_t>();
auto m = A.size(-2);
auto n = A.size(-1);
auto nrhs = B.size(-1);
auto lda = std::max<int64_t>(1, m);
auto ldb = std::max<int64_t>({static_cast<int64_t>(1), m, n});
auto infos_data = infos.data_ptr<int>();
// only 'gels' driver does not compute the rank
int rank_32 = 0;
int64_t* rank_data = nullptr;
int64_t* rank_working_ptr = nullptr;
if (driver_t::Gels != driver_type) {
rank_data = rank.data_ptr<int64_t>();
rank_working_ptr = rank_data;
}
// 'gelsd' and 'gelss' are SVD-based algorithms
// so we can get singular values
value_t* s_data = nullptr;
value_t* s_working_ptr = nullptr;
int64_t s_stride = 0;
if (driver_t::Gelsd == driver_type || driver_t::Gelss == driver_type) {
s_data = singular_values.data_ptr<value_t>();
s_working_ptr = s_data;
s_stride = singular_values.size(-1);
}
// 'jpvt' workspace array is used only for 'gelsy' which uses QR factorization with column pivoting
Tensor jpvt;
int* jpvt_data = nullptr;
if (driver_t::Gelsy == driver_type) {
jpvt = at::empty({std::max<int64_t>(1, n)}, A.options().dtype(at::kInt));
jpvt_data = jpvt.mutable_data_ptr<int>();
}
// Run once the driver, first to get the optimal workspace size
int lwork = -1; // default value to decide the opt size for workspace arrays
scalar_t work_opt;
value_t rwork_opt;
int iwork_opt = 0;
lapack_func(trans, m, n, nrhs,
A_data, lda,
B_data, ldb,
&work_opt, lwork,
infos_data,
jpvt_data,
static_cast<value_t>(rcond),
&rank_32,
&rwork_opt,
s_working_ptr,
&iwork_opt);
lwork = lapack_work_to_int(work_opt);
Tensor work = at::empty({lwork}, A.options());
scalar_t* work_data = work.mutable_data_ptr<scalar_t>();
// 'rwork' only used for complex inputs and 'gelsy', 'gelsd' and 'gelss' drivers
Tensor rwork;
value_t* rwork_data = nullptr;
if (A.is_complex() && driver_t::Gels != driver_type) {
int64_t rwork_len = 0;
switch (driver_type) {
case driver_t::Gelsy:
rwork_len = std::max<int64_t>(1, 2 * n);
break;
case driver_t::Gelss:
rwork_len = std::max<int64_t>(1, 5 * std::min(m, n));
break;
// case driver_t::Gelsd:
default:
rwork_len = std::max<int64_t>(1, rwork_opt);
}
rwork = at::empty({rwork_len}, A.options().dtype(c10::toRealValueType(A.scalar_type())));
rwork_data = rwork.mutable_data_ptr<value_t>();
}
// 'iwork' workspace array is relevant only for 'gelsd'
Tensor iwork;
int* iwork_data = nullptr;
if (driver_t::Gelsd == driver_type) {
iwork = at::empty({std::max<int>(1, iwork_opt)}, A.options().dtype(at::kInt));
iwork_data = iwork.mutable_data_ptr<int>();
}
at::native::batch_iterator_with_broadcasting<scalar_t>(A, B,
[&](scalar_t* A_working_ptr, scalar_t* B_working_ptr, int64_t A_linear_batch_idx) {
rank_working_ptr = rank_working_ptr ? &rank_data[A_linear_batch_idx] : nullptr;
s_working_ptr = s_working_ptr ? &s_data[A_linear_batch_idx * s_stride] : nullptr;
int* infos_working_ptr = &infos_data[A_linear_batch_idx];
lapack_func(trans, m, n, nrhs,
A_working_ptr, lda,
B_working_ptr, ldb,
work_data, lwork,
infos_working_ptr,
jpvt_data,
static_cast<value_t>(rcond),
&rank_32,
rwork_data,
s_working_ptr,
iwork_data);
// we want the output `rank` Tensor to be of type int64_t,
// however LAPACK accepts int. That is why we use an integer
// variable that then gets promoted and written into `rank`.
// We use this approach over a tensor cast for better performance.
if (rank_working_ptr) {
*rank_working_ptr = static_cast<int64_t>(rank_32);
}
}
);
#endif
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free