apply_svd_cusolver_gesvdj Function Template — pytorch Architecture
Architecture documentation for the apply_svd_cusolver_gesvdj function template in BatchLinearAlgebraLib.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp lines 353–423
template<typename scalar_t>
static void apply_svd_cusolver_gesvdj(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
const Tensor& infos, bool full_matrices, bool compute_uv) {
using value_t = typename c10::scalar_value_type<scalar_t>::type;
int m = cuda_int_cast(A.size(-2), "m");
int n = cuda_int_cast(A.size(-1), "n");
int k = std::min(m, n);
// Need to pass allocated memory to the function, otherwise it fails
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto dataPtr_U = !compute_uv ? allocator.allocate(sizeof(scalar_t)* m * k) : c10::DataPtr{};
auto dataPtr_V = !compute_uv ? allocator.allocate(sizeof(scalar_t)* n * k) : c10::DataPtr{};
auto A_data = A.data_ptr<scalar_t>();
auto U_data = compute_uv ? U.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_U.get());
auto S_data = S.data_ptr<value_t>();
auto V_data = compute_uv ? V.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_V.get());
auto A_stride = matrixStride(A);
auto U_stride = compute_uv ? matrixStride(U) : 0;
auto S_stride = S.size(-1);
auto V_stride = compute_uv ? matrixStride(V) : 0;
int batchsize = cuda_int_cast(batchCount(A), "batch size");
int lda = A.stride(-1);
int ldu = compute_uv ? U.stride(-1) : m;
int ldv = compute_uv ? V.stride(-1) : n;
auto handle = at::cuda::getCurrentCUDASolverDnHandle();
auto jobz = compute_uv ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
int econ = full_matrices ? 0 : 1;
// gesvdj_params controls the numerical accuracy of cusolver gesvdj iterations on GPU
gesvdjInfo_t gesvdj_params;
TORCH_CUSOLVER_CHECK(cusolverDnCreateGesvdjInfo(&gesvdj_params));
// Todo: expose the following two parameters to users
TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetTolerance(gesvdj_params, std::numeric_limits<scalar_t>::epsilon()));
TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetMaxSweeps(gesvdj_params, cusolver_gesvdj_max_sweeps));
int lwork = -1;
at::cuda::solver::gesvdj_buffersize<scalar_t>(
handle, jobz, econ, m, n, A_data, lda, S_data, U_data, ldu, V_data, ldv, &lwork, gesvdj_params);
TORCH_INTERNAL_ASSERT(lwork >= 0, "gesvdj_buffersize failed to get needed buffer size, got lwork = ", lwork);
auto dataPtr = allocator.allocate(sizeof(scalar_t)*lwork);
for(int i = 0; i < batchsize; i++){
at::cuda::solver::gesvdj<scalar_t>(
handle, jobz, econ, m, n,
A_data + i * A_stride,
lda,
S_data + i * S_stride,
U_data + i * U_stride,
ldu,
V_data + i * V_stride,
ldv,
reinterpret_cast<scalar_t*>(dataPtr.get()),
lwork,
infos.data_ptr<int>() + i,
gesvdj_params
);
// // The following code can be used to check or report the gesvdj residual.
// // Note: this will introduce a device-host sync and may negatively affect the performance
// double residual = 0;
// TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjGetResidual(handle, gesvdj_params, &residual));
// printf("gesvdj residual = %.6e\n", residual);
}
TORCH_CUSOLVER_CHECK(cusolverDnDestroyGesvdjInfo(gesvdj_params));
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free