
apply_lu_solve_looped_magma Function — pytorch Architecture

Architecture documentation for the apply_lu_solve_looped_magma function template in BatchLinearAlgebra.cpp from the pytorch codebase. It solves a batch of linear systems on the GPU by looping over the batch and calling MAGMA's LU-based solver once per matrix, reusing an already-computed LU factorization and its pivots.

Entity Profile

Source Code

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp lines 1760–1806

template <typename scalar_t>
static void apply_lu_solve_looped_magma(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) {
#if !AT_MAGMA_ENABLED()
  TORCH_CHECK(
      false,
      "Calling linalg.lu_solve on a CUDA tensor requires compiling ",
      "PyTorch with MAGMA. Please rebuild with MAGMA.");
#else
  auto trans = to_magma(transpose);
  auto b_data = B.data_ptr<scalar_t>();
  auto lu_data = LU.data_ptr<scalar_t>();

  // MAGMA requires pivots to be a CPU tensor
  Tensor pivots_cpu = pivots.cpu();
  auto pivots_data = pivots_cpu.data_ptr<magma_int_t>();

  auto b_stride = matrixStride(B);
  auto lu_stride = LU.dim() > 2 ? LU.stride(-3) : 0;
  auto pivots_stride = pivots_cpu.dim() > 1 ? pivots_cpu.stride(-2) : 0;
  auto batch_size = batchCount(B);

  magma_int_t n = magma_int_cast(LU.size(-2), "n");
  magma_int_t nrhs = magma_int_cast(B.size(-1), "nrhs");
  auto leading_dimension = std::max<magma_int_t>(1, n);

  // LU and pivots tensors can be broadcasted to B
  // here we construct a helper indexing tensor to linearly index into lu and pivots
  IntArrayRef lu_batch_shape(LU.sizes().data(), LU.dim() - 2);
  IntArrayRef b_batch_shape(B.sizes().data(), B.dim() - 2);
  BroadcastLinearIndices lu_index(
      batchCount(LU), lu_batch_shape, b_batch_shape);

  int info = 0;
  for (decltype(batch_size) i = 0; i < batch_size; i++) {
    int64_t lu_index_i = lu_index(i);
    scalar_t* b_working_ptr = &b_data[i * b_stride];
    scalar_t* lu_working_ptr = &lu_data[lu_index_i * lu_stride];
    int* pivots_working_ptr = &pivots_data[lu_index_i * pivots_stride];

    magmaLuSolve<scalar_t>(n, nrhs, lu_working_ptr, leading_dimension, pivots_working_ptr, b_working_ptr, leading_dimension, &info, trans);

    // info from magmaLuSolve only reports if the i-th parameter is wrong
    // so we don't need to check it all the time
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
  }
#endif
}
