apply_triangular_solve_batched_magma Function — pytorch Architecture
Architecture documentation for the apply_triangular_solve_batched_magma function template in BatchLinearAlgebra.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp lines 1341–1405
// Runs MAGMA's batched triangular solve over every matrix pair in A and b.
//
// A             - batch of coefficient matrices; only its data pointer, sizes
//                 and matrix stride are read here.
// b             - batch of right-hand sides. The solver receives pointers into
//                 b's storage (presumably it overwrites b with the solution, as
//                 BLAS TRSM does - confirm against the MAGMA wrapper).
// left          - selects MagmaLeft (true) or MagmaRight (false).
// upper         - selects MagmaUpper (true) or MagmaLower (false).
// transpose     - converted to a magma_trans_t via to_magma().
// unitriangular - selects MagmaUnit (true) or MagmaNonUnit (false).
//
// When ATen is built without MAGMA this function always fails a TORCH_CHECK.
// NOTE(review): the leading dimensions are taken from size(-2), which assumes
// column-major, batch-contiguous inputs - confirm with the callers.
template <typename scalar_t>
static void apply_triangular_solve_batched_magma(const Tensor& A, const Tensor& b, bool left, bool upper, TransposeType transpose, bool unitriangular) {
#if !AT_MAGMA_ENABLED()
  TORCH_CHECK(false, "triangular_solve: MAGMA library not found in "
      "compilation. Please rebuild with MAGMA.");
#else
  // Translate the boolean / enum options into MAGMA's flag types.
  magma_uplo_t uplo = upper ? MagmaUpper : MagmaLower;
  magma_trans_t trans = to_magma(transpose);
  magma_diag_t diag = unitriangular ? MagmaUnit : MagmaNonUnit;
  magma_side_t side = left ? MagmaLeft : MagmaRight;

  auto A_data = A.data_ptr<scalar_t>();
  auto b_data = b.data_ptr<scalar_t>();

  // This allows to pass rectangular A and b when left = True
  magma_int_t m = magma_int_cast(left ? A.size(-1) : b.size(-2), "m");
  magma_int_t n = magma_int_cast(b.size(-1), "n");

  // magma returns early if m <= 0 || n <= 0 for magmaTriangularSolveBatched
  // magmaTriangularSolve is calling cuBLAS and it prints
  // ** On entry to DTRSM parameter number 9 had an illegal value
  // so let's use proper lda parameter here
  //
  // The clamp to 1 is done in int64_t and the result is converted with
  // magma_int_cast, exactly like m, n and batch_size, so an oversized leading
  // dimension goes through the same conversion helper instead of being
  // silently truncated by an implicit narrowing conversion.
  magma_int_t lda = magma_int_cast(std::max<int64_t>(1, A.size(-2)), "lda");
  magma_int_t ldb = magma_int_cast(std::max<int64_t>(1, b.size(-2)), "ldb");
  magma_int_t batch_size = magma_int_cast(batchCount(A), "batch_size");

  auto A_mat_stride = matrixStride(A);
  auto b_mat_stride = matrixStride(b);

  // Per-matrix pointer tables consumed by the batched MAGMA routine.
  scalar_t** A_array;
  scalar_t** b_array;
  ALLOCATE_ARRAY(A_array, scalar_t*, batch_size);
  ALLOCATE_ARRAY(b_array, scalar_t*, batch_size);

  // Set up the created arrays: entry i points at the i-th matrix of the batch.
  for (int64_t i = 0; i < batch_size; i++) {
    A_array[i] = &A_data[i * A_mat_stride];
    b_array[i] = &b_data[i * b_mat_stride];
  }

  MAGMAQueue magma_queue(b.get_device());

  // Each MAGMA call handles at most batch_limit matrices.
  constexpr int64_t batch_limit = 65535;

  // Compute as many batches of 65535 as possible
  // The number of "mini"-batches are floor(batch_size / batch_limit)
  // and these cover floor(batch_size / batch_limit) * batch_limit matrix solves
  int64_t mini_batches = batch_size / batch_limit;
  int64_t mini_idx; // this is outside the loop because it is used for the case batch_size % batch_limit != 0
  for (mini_idx = 0; mini_idx < mini_batches * batch_limit; mini_idx += batch_limit) {
    scalar_t** A_array_cur = &A_array[mini_idx];
    scalar_t** b_array_cur = &b_array[mini_idx];

    magmaTriangularSolveBatched<scalar_t>(
        side, uplo, trans, diag, m, n, A_array_cur,
        lda, b_array_cur, ldb, batch_limit, magma_queue);
  }

  // Compute whatever is left = batch_size - floor(batch_size / batch_limit) * batch_limit
  // which concisely is equal to batch_size % batch_limit
  // (mini_idx now indexes the first matrix not covered by the loop above).
  if (batch_size % batch_limit != 0) {
    magmaTriangularSolveBatched<scalar_t>(
        side, uplo, trans, diag, m, n, &A_array[mini_idx],
        lda, &b_array[mini_idx], ldb, batch_size % batch_limit, magma_queue);
  }
#endif
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free