apply_svd_cusolver_gesvdjBatched Function — pytorch Architecture
Architecture documentation for the apply_svd_cusolver_gesvdjBatched function template in BatchLinearAlgebraLib.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp lines 435–477
template<typename scalar_t>
static void apply_svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
const Tensor& infos, bool compute_uv
) {
using value_t = typename c10::scalar_value_type<scalar_t>::type;
int m = cuda_int_cast(A.size(-2), "m");
int n = cuda_int_cast(A.size(-1), "n");
int batchsize = cuda_int_cast(batchCount(A), "batch size");
auto lda = std::max<int>(1, m);
auto ldu = std::max<int>(1, m);
auto ldv = std::max<int>(1, n);
// Need to pass allocated memory to the function, otherwise it fails
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto dataPtr_U = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * m * ldu) : c10::DataPtr{};
auto dataPtr_V = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * n * ldv) : c10::DataPtr{};
auto A_data = A.data_ptr<scalar_t>();
auto U_data = compute_uv ? U.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_U.get());
auto S_data = S.data_ptr<value_t>();
auto V_data = compute_uv ? V.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_V.get());
TORCH_INTERNAL_ASSERT(m <= 32 && n <= 32, "gesvdjBatched requires both matrix dimensions not greater than 32, but got "
"m = ", m, " n = ", n);
// gesvdj_params controls the numerical accuracy of cusolver gesvdj iterations on GPU
gesvdjInfo_t gesvdj_params;
TORCH_CUSOLVER_CHECK(cusolverDnCreateGesvdjInfo(&gesvdj_params));
// Todo: expose the following two parameters to users
TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetTolerance(gesvdj_params, std::numeric_limits<scalar_t>::epsilon()));
TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetMaxSweeps(gesvdj_params, cusolver_gesvdj_max_sweeps));
TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetSortEig(gesvdj_params, 1));
auto handle = at::cuda::getCurrentCUDASolverDnHandle();
auto jobz = compute_uv ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
at::cuda::solver::gesvdjBatched<scalar_t>(
handle, jobz, m, n, A_data, lda, S_data, U_data, ldu, V_data, ldv,
infos.data_ptr<int>(), gesvdj_params, batchsize
);
TORCH_CUSOLVER_CHECK(cusolverDnDestroyGesvdjInfo(gesvdj_params));
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free