apply_lu_solve Function Template — PyTorch Architecture
Architecture documentation for the apply_lu_solve function template in BatchLinearAlgebraKernel.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/BatchLinearAlgebraKernel.cpp lines 1035–1079
template <typename scalar_t>
void apply_lu_solve(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) {
#if !AT_BUILD_WITH_LAPACK()
TORCH_CHECK(
false,
"Calling linalg.lu_solve on a CPU tensor requires compiling ",
"PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
#else
auto b_data = B.data_ptr<scalar_t>();
auto lu_data = LU.const_data_ptr<scalar_t>();
const auto trans = to_blas(transpose);
auto pivots_data = pivots.const_data_ptr<int>();
auto b_stride = matrixStride(B);
auto lu_stride = LU.dim() > 2 ? LU.stride(-3) : 0;
auto pivots_stride = pivots.dim() > 1 ? pivots.stride(-2) : 0;
auto batch_size = batchCount(B);
auto n = LU.size(-2);
auto nrhs = B.size(-1);
auto leading_dimension = std::max<int64_t>(1, n);
int info = 0;
// lu and pivots tensors can be broadcast to B
// here we construct a helper indexing tensor to linearly index into LU and pivots
IntArrayRef lu_batch_shape(LU.sizes().data(), LU.dim() - 2);
IntArrayRef b_batch_shape(B.sizes().data(), B.dim() - 2);
BroadcastLinearIndices lu_index(
batchCount(LU), lu_batch_shape, b_batch_shape);
for (const auto i : c10::irange(batch_size)) {
int64_t lu_index_i = lu_index(i);
scalar_t* b_working_ptr = &b_data[i * b_stride];
const scalar_t* lu_working_ptr = &lu_data[lu_index_i * lu_stride];
const int* pivots_working_ptr = &pivots_data[lu_index_i * pivots_stride];
lapackLuSolve<scalar_t>(trans, n, nrhs, const_cast<scalar_t*>(lu_working_ptr), leading_dimension, const_cast<int*>(pivots_working_ptr),
b_working_ptr, leading_dimension, &info);
// info from lapackLuSolve only reports if the i-th parameter is wrong
// so we don't need to check it all the time
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
}
#endif
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free