apply_lu_factor_batched_magma Function — PyTorch Architecture
Architecture documentation for the apply_lu_factor_batched_magma function in BatchLinearAlgebra.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp lines 1202–1260
/*
  Computes the LU factorization in-place for a batch of square/rectangular
  matrices using MAGMA's batched getrf kernels.

  Args:
  * `input` - [in] batch of matrices to factorize,
              [out] overwritten with the packed LU factors
  * `pivots` - [out] filled with pivot indices when `compute_pivots` is true
               (untouched otherwise)
  * `infos` - [out] per-matrix MAGMA info codes (0 indicates success)
  * `compute_pivots` - selects partial pivoting (magmaLuBatched) vs. the
                       no-pivoting variant (magmaLuNoPivBatched)

  Requires PyTorch built with MAGMA >= 2.5.2; raises a c10::Error otherwise.
*/
template <typename scalar_t>
static void apply_lu_factor_batched_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
#if !AT_MAGMA_ENABLED()
  TORCH_CHECK(
      false,
      "Calling linalg.lu_factor on a CUDA tensor requires compiling ",
      "PyTorch with MAGMA. Please rebuild with MAGMA.");
#else
  // There is a bug in lu_factor_batched_magma in MAGMA < 2.5.2, see
  // https://bitbucket.org/icl/magma/issues/13/getrf_batched-kernel-produces-nans-on
  std::tuple<magma_int_t, magma_int_t, magma_int_t> version;
  magma_version(&std::get<0>(version), &std::get<1>(version), &std::get<2>(version));
  const bool magma_batched_buggy = version < std::make_tuple<magma_int_t, magma_int_t, magma_int_t>(2, 5, 2);
  // NOTE: typo fix — the message previously read "has buggs".
  TORCH_CHECK(!magma_batched_buggy, "linalg.lu_factor has bugs on MAGMA < 2.5.2. Please update your MAGMA version to a newer one.");

  auto input_data = input.data_ptr<scalar_t>();
  auto infos_data = infos.data_ptr<magma_int_t>();
  auto input_matrix_stride = matrixStride(input);
  magma_int_t batch_size = magma_int_cast(batchCount(input), "batchCount");
  magma_int_t m = magma_int_cast(input.size(-2), "m");
  magma_int_t n = magma_int_cast(input.size(-1), "n");
  // MAGMA requires lda >= max(1, m) even for degenerate (m == 0) inputs.
  auto leading_dimension = std::max<magma_int_t>(1, m);

  scalar_t** input_array;
  ALLOCATE_ARRAY(input_array, scalar_t*, batch_size);

  // Set up array of pointers to matrices
  for (int64_t i = 0; i < batch_size; i++) {
    input_array[i] = &input_data[i * input_matrix_stride];
  }

  // needed to run lu tests in parallel, see https://github.com/pytorch/pytorch/issues/82894 for examples
  // of failures
  c10::cuda::device_synchronize();
  MAGMAQueue magma_queue(input.get_device());

  if (compute_pivots) {
    auto pivots_data = pivots.data_ptr<magma_int_t>();
    auto pivots_stride = pivots.size(-1);
    // fill pivots with ones to avoid memory access violations inside magma kernels
    // magmaLuBatched might not set the values for it
    // see https://github.com/pytorch/pytorch/pull/53064
    pivots.fill_(1);
    magma_int_t** pivots_array;
    ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size);
    for (int64_t i = 0; i < batch_size; i++) {
      pivots_array[i] = &pivots_data[i * pivots_stride];
    }
    magmaLuBatched<scalar_t>(m, n, input_array, leading_dimension, pivots_array, infos_data, batch_size, magma_queue);
  } else {
    magmaLuNoPivBatched<scalar_t>(m, n, input_array, leading_dimension, infos_data, batch_size, magma_queue);
  }

  // block CPU until all operations on the queue are finished
  // this explicit sync prevents garbage results from the subsequent magmaLuSolveBatched call from a different queue
  magma_queue_sync(magma_queue.get_queue());
#endif
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free