apply_lu_solve_looped_magma Function — pytorch Architecture
Architecture documentation for the apply_lu_solve_looped_magma function (a file-local template function, not a class) in BatchLinearAlgebra.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp lines 1760–1806
template <typename scalar_t>
static void apply_lu_solve_looped_magma(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) {
#if !AT_MAGMA_ENABLED()
TORCH_CHECK(
false,
"Calling linalg.lu_solve on a CUDA tensor requires compiling ",
"PyTorch with MAGMA. Please rebuild with MAGMA.");
#else
auto trans = to_magma(transpose);
auto b_data = B.data_ptr<scalar_t>();
auto lu_data = LU.data_ptr<scalar_t>();
// MAGMA requires pivots to be a CPU tensor
Tensor pivots_cpu = pivots.cpu();
auto pivots_data = pivots_cpu.data_ptr<magma_int_t>();
auto b_stride = matrixStride(B);
auto lu_stride = LU.dim() > 2 ? LU.stride(-3) : 0;
auto pivots_stride = pivots_cpu.dim() > 1 ? pivots_cpu.stride(-2) : 0;
auto batch_size = batchCount(B);
magma_int_t n = magma_int_cast(LU.size(-2), "n");
magma_int_t nrhs = magma_int_cast(B.size(-1), "nrhs");
auto leading_dimension = std::max<magma_int_t>(1, n);
// LU and pivots tensors can be broadcasted to B
// here we construct a helper indexing tensor to linearly index into lu and pivots
IntArrayRef lu_batch_shape(LU.sizes().data(), LU.dim() - 2);
IntArrayRef b_batch_shape(B.sizes().data(), B.dim() - 2);
BroadcastLinearIndices lu_index(
batchCount(LU), lu_batch_shape, b_batch_shape);
int info = 0;
for (decltype(batch_size) i = 0; i < batch_size; i++) {
int64_t lu_index_i = lu_index(i);
scalar_t* b_working_ptr = &b_data[i * b_stride];
scalar_t* lu_working_ptr = &lu_data[lu_index_i * lu_stride];
int* pivots_working_ptr = &pivots_data[lu_index_i * pivots_stride];
magmaLuSolve<scalar_t>(n, nrhs, lu_working_ptr, leading_dimension, pivots_working_ptr, b_working_ptr, leading_dimension, &info, trans);
// info from magmaLuSolve only reports if the i-th parameter is wrong
// so we don't need to check it all the time
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
}
#endif
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free