apply_cholesky_cusolver_potrs Function — pytorch Architecture
Architecture documentation for the apply_cholesky_cusolver_potrs function template in BatchLinearAlgebraLib.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp lines 817–868
// Calls the cuSOLVER potrs routine once per batch entry, passing the matrices
// of A_column_major_copy as the `A` argument and the matrices of
// self_working_copy as the `B` argument.
//
// Parameters:
//   self_working_copy   - tensor whose last two dimensions supply n (size(-2))
//                         and nrhs (size(-1)); handed to potrs through a
//                         mutable pointer as `B`. Per the cuSOLVER potrs
//                         contract `B` is overwritten with the solution.
//   A_column_major_copy - tensor handed to potrs as `A`. Presumably holds the
//                         Cholesky factors in column-major layout, as the name
//                         suggests -- confirm against the caller.
//   upper               - true selects CUBLAS_FILL_MODE_UPPER, false selects
//                         CUBLAS_FILL_MODE_LOWER.
//   infos               - int tensor; the pointer to its first element is
//                         passed unchanged to every potrs call, so every batch
//                         entry reports into the same location.
template<typename scalar_t>
static void apply_cholesky_cusolver_potrs(Tensor& self_working_copy, const Tensor& A_column_major_copy, bool upper, Tensor& infos) {
  auto handle = at::cuda::getCurrentCUDASolverDnHandle();
  const auto uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;

  const int64_t n = self_working_copy.size(-2);
  const int64_t nrhs = self_working_copy.size(-1);
  // Leading dimensions are clamped to at least 1 so that empty matrices still
  // yield a non-zero value.
  // NOTE(review): lda is derived from self_working_copy and ldb from
  // A_column_major_copy, yet lda is passed alongside A and ldb alongside B.
  // The two values coincide when A is n x n -- confirm that the caller
  // guarantees this.
  const int64_t lda = std::max<int64_t>(1, n);
  const int64_t batch_size = batchCount(self_working_copy);
  const int64_t self_matrix_stride = matrixStride(self_working_copy);
  scalar_t* self_working_copy_ptr = self_working_copy.data_ptr<scalar_t>();

  scalar_t* A_ptr = A_column_major_copy.data_ptr<scalar_t>();
  const int64_t A_matrix_stride = matrixStride(A_column_major_copy);
  const int64_t ldb = std::max<int64_t>(1, A_column_major_copy.size(-1));

  // Not offset per batch entry: all calls below share this single info slot.
  int* infos_ptr = infos.data_ptr<int>();

#ifdef USE_CUSOLVER_64_BIT
  // 64-bit path: sizes are passed as int64_t and the element type is passed
  // at run time via `datatype`.
  cusolverDnParams_t params;
  cudaDataType datatype = at::cuda::solver::get_cusolver_datatype<scalar_t>();
  // Pass the address of `params` so cuSOLVER can fill in the handle.
  TORCH_CUSOLVER_CHECK(cusolverDnCreateParams(&params));

  // NOTE(review): if xpotrs can throw, `params` would not be destroyed on
  // that path -- consider a scope guard if that turns out to be the case.
  for (int64_t i = 0; i < batch_size; i++) {
    at::cuda::solver::xpotrs(
      handle, params, uplo, n, nrhs, datatype,
      A_ptr + i * A_matrix_stride,
      lda, datatype,
      self_working_copy_ptr + i * self_matrix_stride,
      ldb,
      infos_ptr
    );
  }

  TORCH_CUSOLVER_CHECK(cusolverDnDestroyParams(params));
#else // USE_CUSOLVER_64_BIT
  // 32-bit path: convert every size to int through cuda_int_cast before
  // handing it to potrs.
  int n_32 = cuda_int_cast(n, "n");
  int nrhs_32 = cuda_int_cast(nrhs, "nrhs");
  int lda_32 = cuda_int_cast(lda, "lda");
  int ldb_32 = cuda_int_cast(ldb, "ldb");

  for (int64_t i = 0; i < batch_size; i++) {
    at::cuda::solver::potrs<scalar_t>(
      handle, uplo, n_32, nrhs_32,
      A_ptr + i * A_matrix_stride,
      lda_32,
      self_working_copy_ptr + i * self_matrix_stride,
      ldb_32,
      infos_ptr
    );
  }
#endif // USE_CUSOLVER_64_BIT
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free