Home / Class / apply_lu_solve Class — pytorch Architecture

apply_lu_solve Class — pytorch Architecture

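Architecture documentation for the apply_lu_solve function template in BatchLinearAlgebraKernel.cpp from the PyTorch codebase. It solves each linear system in the batch of B through LAPACK, using a precomputed LU factorization (LU, pivots). LU and pivots may be broadcast across B's batch dimensions.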

Entity Profile

Source Code

aten/src/ATen/native/BatchLinearAlgebraKernel.cpp lines 1035–1079

template <typename scalar_t>
void apply_lu_solve(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) {
#if !AT_BUILD_WITH_LAPACK()
  TORCH_CHECK(
      false,
      "Calling linalg.lu_solve on a CPU tensor requires compiling ",
      "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
#else
  auto b_data = B.data_ptr<scalar_t>();
  auto lu_data = LU.const_data_ptr<scalar_t>();
  const auto trans = to_blas(transpose);
  auto pivots_data = pivots.const_data_ptr<int>();
  auto b_stride = matrixStride(B);
  auto lu_stride = LU.dim() > 2 ? LU.stride(-3) : 0;
  auto pivots_stride = pivots.dim() > 1 ? pivots.stride(-2) : 0;
  auto batch_size = batchCount(B);

  auto n = LU.size(-2);
  auto nrhs = B.size(-1);
  auto leading_dimension = std::max<int64_t>(1, n);

  int info = 0;

  // lu and pivots tensors can be broadcast to B
  // here we construct a helper indexing tensor to linearly index into LU and pivots
  IntArrayRef lu_batch_shape(LU.sizes().data(), LU.dim() - 2);
  IntArrayRef b_batch_shape(B.sizes().data(), B.dim() - 2);
  BroadcastLinearIndices lu_index(
      batchCount(LU), lu_batch_shape, b_batch_shape);

  for (const auto i : c10::irange(batch_size)) {
    int64_t lu_index_i = lu_index(i);
    scalar_t* b_working_ptr = &b_data[i * b_stride];
    const scalar_t* lu_working_ptr = &lu_data[lu_index_i * lu_stride];
    const int* pivots_working_ptr = &pivots_data[lu_index_i * pivots_stride];

    lapackLuSolve<scalar_t>(trans, n, nrhs, const_cast<scalar_t*>(lu_working_ptr), leading_dimension, const_cast<int*>(pivots_working_ptr),
                            b_working_ptr, leading_dimension, &info);

    // info from lapackLuSolve only reports if the i-th parameter is wrong
    // so we don't need to check it all the time
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
  }
#endif
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free