apply_cholesky_solve Function — PyTorch Architecture
Architecture documentation for the apply_cholesky_solve function template in BatchLinearAlgebra.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp lines 940–1007
// Runs the MAGMA Cholesky-solve wrappers on `b` using `A`.
//
// b     : tensor whose data pointer is handed to MAGMA as the right-hand side;
//         a 2-D `b` takes the single-matrix path, anything else the batched path.
// A     : tensor whose data pointer is handed to MAGMA as the coefficient input
//         (presumably an already-computed Cholesky factor -- confirm with caller).
// upper : selects MagmaUpper when true, MagmaLower otherwise.
// info  : receives the status value reported by the MAGMA wrapper (0 on the
//         paths where every call reported 0).
//
// When the build has no MAGMA support the function raises via TORCH_CHECK.
template <typename scalar_t>
static void apply_cholesky_solve(Tensor& b, Tensor& A, bool upper, int64_t& info) {
#if !AT_MAGMA_ENABLED()
  TORCH_CHECK(false, "cholesky_solve: MAGMA library not found in "
      "compilation. Please rebuild with MAGMA.");
#else
  const magma_uplo_t triangle = upper ? MagmaUpper : MagmaLower;

  auto A_base = A.data_ptr<scalar_t>();
  auto b_base = b.data_ptr<scalar_t>();

  const magma_int_t n = magma_int_cast(A.size(-2), "A.size(-2)");
  // The same leading dimension is used for both operands; clamp to at least 1.
  const magma_int_t ld = std::max<magma_int_t>(1, n);
  const magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)");

  int status = 0;

  // Single-matrix case: one direct call, then we are done.
  if (b.dim() == 2) {
    magmaCholeskySolve<scalar_t>(triangle, n, nrhs, A_base, ld,
                                 b_base, ld, &status);
    info = status;
    return;
  }

  // Batched case: build per-matrix pointer tables for A and b.
  const auto A_step = matrixStride(A);
  const auto b_step = matrixStride(b);
  const magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount");

  scalar_t** A_ptrs;
  scalar_t** b_ptrs;
  ALLOCATE_ARRAY(A_ptrs, scalar_t*, batch_size);
  ALLOCATE_ARRAY(b_ptrs, scalar_t*, batch_size);

  for (int64_t k = 0; k < batch_size; ++k) {
    A_ptrs[k] = A_base + k * A_step;
    b_ptrs[k] = b_base + k * b_step;
  }

  MAGMAQueue magma_queue(b.get_device());

  // The batched wrapper is invoked on at most `chunk` matrices per call.
  constexpr int64_t chunk = 65535;
  // `leftover` matrices remain after all full chunks; `full_end` is the index
  // one past the last matrix covered by a full chunk.
  const int64_t leftover = batch_size % chunk;
  const int64_t full_end = batch_size - leftover;

  int64_t offset = 0;
  while (offset < full_end) {
    magmaCholeskySolveBatched<scalar_t>(
        triangle, n, nrhs, &A_ptrs[offset], ld, &b_ptrs[offset], ld,
        status, chunk, magma_queue);
    if (status != 0) {
      // Stop at the first chunk that reports a nonzero status.
      break;
    }
    offset += chunk;
  }

  // Handle the trailing partial chunk, but only if nothing failed so far.
  if (leftover != 0 && status == 0) {
    magmaCholeskySolveBatched<scalar_t>(
        triangle, n, nrhs, &A_ptrs[offset], ld, &b_ptrs[offset], ld,
        status, leftover, magma_queue);
  }

  info = status;
#endif
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free