apply_lu_factor_looped_magma Function — PyTorch Architecture
Architecture documentation for the apply_lu_factor_looped_magma function template in BatchLinearAlgebra.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp lines 1148–1186
template <typename scalar_t>
static void apply_lu_factor_looped_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
#if !AT_MAGMA_ENABLED()
// This should never be thrown if the calling functions are correct.
TORCH_CHECK(false, "linalg.lu_factor: PyTorch was not compiled with MAGMA support.");
#else
// magmaLu and magmaLuNoPiv require infos and pivots tensor to be on CPU
// the data is later copied back to the appropriate output tensor
Tensor infos_cpu = at::empty_like(infos, infos.options().device(kCPU).pinned_memory(true));
auto input_data = input.data_ptr<scalar_t>();
auto infos_data = infos_cpu.mutable_data_ptr<magma_int_t>();
auto input_matrix_stride = matrixStride(input);
auto pivots_stride = pivots.size(-1);
auto batch_size = batchCount(input);
magma_int_t m = magma_int_cast(input.size(-2), "m");
magma_int_t n = magma_int_cast(input.size(-1), "n");
auto leading_dimension = std::max<magma_int_t>(1, m);
if (compute_pivots) {
Tensor pivots_cpu = at::empty_like(pivots, pivots.options().device(kCPU).pinned_memory(true));
auto pivots_data = pivots_cpu.mutable_data_ptr<magma_int_t>();
for (decltype(batch_size) i = 0; i < batch_size; i++) {
scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
int* pivots_working_ptr = &pivots_data[i * pivots_stride];
int* infos_working_ptr = &infos_data[i];
magmaLu<scalar_t>(m, n, input_working_ptr, leading_dimension, pivots_working_ptr, infos_working_ptr);
}
pivots.copy_(pivots_cpu);
} else {
for (decltype(batch_size) i = 0; i < batch_size; i++) {
scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
int* infos_working_ptr = &infos_data[i];
magmaLuNoPiv<scalar_t>(m, n, input_working_ptr, leading_dimension, infos_working_ptr);
}
}
infos.copy_(infos_cpu);
#endif
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free