apply_ldl_factor_cusolver Function — pytorch Architecture
Architecture documentation for the apply_ldl_factor_cusolver function template in BatchLinearAlgebraLib.cpp from the pytorch codebase. The function computes the LDL^T (Bunch-Kaufman) factorization of a batch of symmetric matrices on the GPU: it queries cuSOLVER once for a scratch workspace, then loops over the batch, factoring each matrix in place with the sytrf routine and recording per-matrix pivots and status codes.
Entity Profile
Source Code
aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp lines 75–115
template <typename scalar_t>
void apply_ldl_factor_cusolver(
    const Tensor& A,
    const Tensor& pivots,
    const Tensor& info,
    bool upper) {
  auto batch_size = batchCount(A);
  auto n = cuda_int_cast(A.size(-2), "A.size(-2)");
  auto lda = cuda_int_cast(A.stride(-1), "A.stride(-1)");
  auto uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;

  auto a_stride = A.dim() > 2 ? A.stride(-3) : 0;
  auto pivots_stride = pivots.dim() > 1 ? pivots.stride(-2) : 0;

  auto a_data = A.data_ptr<scalar_t>();
  auto pivots_data = pivots.data_ptr<int>();
  auto info_data = info.data_ptr<int>();

  auto handle = at::cuda::getCurrentCUDASolverDnHandle();

  // The workspace size is queried once and the buffer is reused for
  // every matrix in the batch.
  int lwork = 0;
  at::cuda::solver::sytrf_bufferSize(handle, n, a_data, lda, &lwork);
  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  auto work = allocator.allocate(sizeof(scalar_t) * lwork);

  // cuSOLVER exposes no batched sytrf, so each matrix is factored
  // individually at its per-matrix pointer offset.
  for (const auto i : c10::irange(batch_size)) {
    auto* a_working_ptr = &a_data[i * a_stride];
    auto* pivots_working_ptr = &pivots_data[i * pivots_stride];
    auto* info_working_ptr = &info_data[i];
    at::cuda::solver::sytrf(
        handle,
        uplo,
        n,
        a_working_ptr,
        lda,
        pivots_working_ptr,
        reinterpret_cast<scalar_t*>(work.get()),
        lwork,
        info_working_ptr);
  }
}
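A templated helper like this is not called directly; in ATen it is typically reached through a scalar-type dispatch macro that instantiates the template for the tensor's dtype. The sketch below is a minimal illustration of that convention, assuming a wrapper named ldl_factor_cusolver; it is not the verbatim pytorch source.

// Hedged sketch: a non-templated entry point dispatching to the
// templated kernel above. The wrapper name is illustrative.
void ldl_factor_cusolver(
    const Tensor& LD,
    const Tensor& pivots,
    const Tensor& info,
    bool upper) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      LD.scalar_type(), "ldl_factor_cusolver", [&] {
        apply_ldl_factor_cusolver<scalar_t>(LD, pivots, info, upper);
      });
}

The macro expands to a switch over LD.scalar_type() and binds scalar_t inside the lambda to the matching C++ type (float, double, c10::complex<float>, or c10::complex<double>).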
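The at::cuda::solver::sytrf_bufferSize and at::cuda::solver::sytrf calls are thin ATen wrappers over cuSOLVER's dense sytrf routines. The standalone sketch below shows the same query-allocate-factor sequence against the raw cuSOLVER API for a single float matrix; error handling is elided and the function name is hypothetical, not part of pytorch.

#include <cuda_runtime.h>
#include <cusolverDn.h>

// Hypothetical standalone example of the pattern; not pytorch code.
void ldl_factor_single(float* d_A, int n, int lda, int* d_ipiv, int* d_info) {
  cusolverDnHandle_t handle;
  cusolverDnCreate(&handle);

  // 1. Query the required device workspace size.
  int lwork = 0;
  cusolverDnSsytrf_bufferSize(handle, n, d_A, lda, &lwork);

  // 2. Allocate the workspace (pytorch uses the CUDA caching
  //    allocator here instead of a raw cudaMalloc).
  float* d_work = nullptr;
  cudaMalloc(&d_work, sizeof(float) * lwork);

  // 3. Factor A in place: d_ipiv receives the pivot indices and d_info
  //    the LAPACK-style status (0 on success, > 0 if D has a zero pivot).
  cusolverDnSsytrf(
      handle, CUBLAS_FILL_MODE_LOWER, n, d_A, lda, d_ipiv, d_work, lwork, d_info);

  cudaFree(d_work);
  cusolverDnDestroy(handle);
}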