apply_svd Function Template — PyTorch Architecture
Architecture documentation for the `apply_svd` function template in BatchLinearAlgebraKernel.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/BatchLinearAlgebraKernel.cpp lines 1097–1158
template <typename scalar_t>
void apply_svd(const Tensor& A,
const bool full_matrices,
const bool compute_uv,
const Tensor& U,
const Tensor& S,
const Tensor& Vh,
const Tensor& info) {
#if !AT_BUILD_WITH_LAPACK()
TORCH_CHECK(false, "svd: LAPACK library not found in compilation");
#else
using value_t = typename c10::scalar_value_type<scalar_t>::type;
const auto A_data = A.data_ptr<scalar_t>();
const auto U_data = compute_uv ? U.data_ptr<scalar_t>() : nullptr;
const auto S_data = S.data_ptr<value_t>();
const auto info_data = info.data_ptr<int>();
const auto Vh_data = compute_uv ? Vh.data_ptr<scalar_t>() : nullptr;
const auto A_stride = matrixStride(A);
const auto S_stride = S.size(-1);
const auto U_stride = compute_uv ? matrixStride(U) : 1;
const auto Vh_stride = compute_uv ? matrixStride(Vh) : 1;
const auto batchsize = batchCount(A);
const char jobz = compute_uv ? (full_matrices ? 'A' : 'S') : 'N';
const auto m = A.size(-2);
const auto n = A.size(-1);
const auto lda = A.stride(-1);
const auto ldu= compute_uv ? U.stride(-1) : 1;
const auto ldvh = compute_uv ? Vh.stride(-1) : 1;
auto iwork = std::vector<int>(8 * std::min(m, n));
auto* const iwork_data = iwork.data();
// rwork is just used for the complex decomposition
auto rwork = std::vector<value_t>{};
if (A.is_complex()) {
rwork.resize(std::max(computeLRWorkDim(jobz, m, n), int64_t{1}));
}
auto* const rwork_data = rwork.data();
// Query svd for the optimal lwork size
int lwork = -1;
{
scalar_t wkopt;
lapackSvd<scalar_t, value_t>(jobz, m, n, A_data, lda, S_data, U_data, ldu, Vh_data, ldvh, &wkopt, lwork, rwork_data, iwork_data, info_data);
lwork = lapack_work_to_int(wkopt);
}
auto work = std::vector<scalar_t>(lwork);
auto* const work_data = work.data();
for (const auto i : c10::irange(batchsize)) {
auto* const A_working_ptr = &A_data[i * A_stride];
auto* const S_working_ptr = &S_data[i * S_stride];
auto* const U_working_ptr = compute_uv ? &U_data[i * U_stride] : nullptr;
auto* const Vh_working_ptr = compute_uv ? &Vh_data[i * Vh_stride] : nullptr;
// Compute S, U (optionally) and Vh (optionally)
lapackSvd<scalar_t, value_t>(jobz, m, n, A_working_ptr, lda,
S_working_ptr, U_working_ptr, ldu, Vh_working_ptr, ldvh, work_data, lwork, rwork_data, iwork_data, info_data + i);
}
#endif
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free