apply_lu_solve_batched_cublas Function — PyTorch Architecture
Architecture documentation for the apply_lu_solve_batched_cublas function in BatchLinearAlgebraLibBlas.cpp from the PyTorch codebase. Given a batch of precomputed LU factorizations and their pivots, the function solves the corresponding batch of linear systems in place in B by dispatching to the batched cuBLAS getrs routine.
Entity Profile
Source Code
aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLibBlas.cpp lines 130–152
template <typename scalar_t>
static void apply_lu_solve_batched_cublas(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) {
  TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(B), "batch_size of LU and B must be the same");
  TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(pivots.unsqueeze(-1)), "batch_size of LU and pivots must be the same");
  const auto trans = to_cublas(transpose);

  auto pivots_data = pivots.const_data_ptr<int>();

  auto batch_size = cuda_int_cast(batchCount(LU), "batch_size");
  auto m = cuda_int_cast(LU.size(-2), "m");
  auto nrhs = cuda_int_cast(B.size(-1), "nrhs");
  auto lda = cuda_int_cast(std::max<int>(1, m), "lda");
  int info = 0;

  // Build device-side arrays of per-batch-entry data pointers, since the
  // batched cuBLAS API expects an array of matrix pointers rather than a
  // single strided allocation.
  Tensor lu_ptr_array = get_device_pointers<scalar_t>(LU);
  Tensor b_ptr_array = get_device_pointers<scalar_t>(B);
  auto lu_ptr_array_data = reinterpret_cast<const scalar_t* const*>(lu_ptr_array.const_data_ptr());
  auto b_ptr_array_data = reinterpret_cast<scalar_t**>(b_ptr_array.data_ptr());

  // Solve the whole batch in place in B via cuBLAS getrsBatched. The
  // host-side info flag only reports invalid arguments, so it is checked
  // in debug builds only.
  auto handle = at::cuda::getCurrentCUDABlasHandle();
  at::cuda::blas::getrsBatched(
      handle, trans, m, nrhs, const_cast<scalar_t**>(lu_ptr_array_data),
      lda, const_cast<int*>(pivots_data), b_ptr_array_data, lda, &info, batch_size);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
}
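The wrapper delegates the actual solve to cuBLAS getrsBatched, whose contract shapes the code above: factored matrices and right-hand sides are passed as device arrays of per-matrix pointers, pivots are 1-based, and the host-side info flag only reports argument errors. The following standalone sketch shows that contract directly against the cuBLAS C API. It is a hedged illustration, not PyTorch code: buffer names and the diagonal test data are invented for the example, and error checking is elided.

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
  const int n = 4, nrhs = 2, batch = 8, lda = n, ldb = n;

  // Well-conditioned test data: diagonal matrices with 4 on the diagonal
  // and all-ones right-hand sides (column-major, one block per batch entry).
  std::vector<double> hostA(n * n * batch, 0.0);
  std::vector<double> hostB(n * nrhs * batch, 1.0);
  for (int i = 0; i < batch; ++i)
    for (int j = 0; j < n; ++j)
      hostA[i * n * n + j * lda + j] = 4.0;

  double *dA, *dB;
  cudaMalloc(&dA, sizeof(double) * hostA.size());
  cudaMalloc(&dB, sizeof(double) * hostB.size());
  cudaMemcpy(dA, hostA.data(), sizeof(double) * hostA.size(), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hostB.data(), sizeof(double) * hostB.size(), cudaMemcpyHostToDevice);

  // Batched cuBLAS takes device arrays of per-matrix pointers; this is
  // the role get_device_pointers plays in the PyTorch wrapper above.
  std::vector<double*> hA(batch), hB(batch);
  for (int i = 0; i < batch; ++i) {
    hA[i] = dA + i * n * n;
    hB[i] = dB + i * n * nrhs;
  }
  double **dAarray, **dBarray;
  cudaMalloc(&dAarray, sizeof(double*) * batch);
  cudaMalloc(&dBarray, sizeof(double*) * batch);
  cudaMemcpy(dAarray, hA.data(), sizeof(double*) * batch, cudaMemcpyHostToDevice);
  cudaMemcpy(dBarray, hB.data(), sizeof(double*) * batch, cudaMemcpyHostToDevice);

  int *dPivots, *dInfoArray;
  cudaMalloc(&dPivots, sizeof(int) * n * batch);
  cudaMalloc(&dInfoArray, sizeof(int) * batch);

  cublasHandle_t handle;
  cublasCreate(&handle);

  // Factor every matrix in place (A = P * L * U), producing the packed
  // factors and 1-based pivots that getrsBatched consumes.
  cublasDgetrfBatched(handle, n, dAarray, lda, dPivots, dInfoArray, batch);

  // Solve all systems in place in B; the host-side info flag only
  // reports invalid arguments, mirroring the debug-only assert above.
  int info = 0;
  cublasDgetrsBatched(handle, CUBLAS_OP_N, n, nrhs, dAarray, lda,
                      dPivots, dBarray, ldb, &info, batch);
  printf("getrsBatched info = %d\n", info);

  cublasDestroy(handle);
  cudaFree(dA); cudaFree(dB); cudaFree(dAarray); cudaFree(dBarray);
  cudaFree(dPivots); cudaFree(dInfoArray);
  return 0;
}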
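In practice this kernel is not called directly; it sits behind PyTorch's public LU entry points. Below is a minimal LibTorch usage sketch, assuming a recent CUDA-enabled build where at::linalg_lu_factor and at::linalg_lu_solve are available (these can route to the batched cuBLAS path for batches of small matrices).

#include <torch/torch.h>
#include <iostream>

int main() {
  if (!torch::cuda::is_available()) return 0;

  // Batch of 8 independent 4x4 systems, each with 2 right-hand sides.
  auto A = torch::randn({8, 4, 4}, torch::kCUDA);
  auto B = torch::randn({8, 4, 2}, torch::kCUDA);

  // Factor once: LU packs the L and U factors, pivots holds the
  // 1-based pivot indices that getrsBatched expects.
  auto [LU, pivots] = at::linalg_lu_factor(A);

  // Solve A X = B for every system in the batch using the factorization.
  auto X = at::linalg_lu_solve(LU, pivots, B);

  // Check the residual: A X should reproduce B up to rounding error.
  std::cout << (torch::matmul(A, X) - B).abs().max().item<double>() << "\n";
  return 0;
}

Keeping factorization and solve as separate steps lets callers amortize a single LU decomposition across many right-hand sides, which is the workload this batched solve path is designed for.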