Home / Class / apply_lu_solve_batched_magma (function template) — pytorch Architecture

apply_lu_solve_batched_magma Class — pytorch Architecture

Architecture documentation for the apply_lu_solve_batched_magma function template in BatchLinearAlgebra.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp lines 1822–1884

template <typename scalar_t>
static void apply_lu_solve_batched_magma(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) {
#if !AT_MAGMA_ENABLED()
  TORCH_CHECK(
      false,
      "Calling linalg.lu_solve on a CUDA tensor requires compiling ",
      "PyTorch with MAGMA. Please rebuild with MAGMA.");
#else
  TORCH_INTERNAL_ASSERT(batchCount(B) == batchCount(LU), "batch_size of LU and B must be the same");
  TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(pivots.unsqueeze(-1)), "batch_size of LU and pivots must be the same");
  auto trans = to_magma(transpose);
  auto b_data = B.data_ptr<scalar_t>();
  auto lu_data = LU.data_ptr<scalar_t>();

  magma_int_t n = magma_int_cast(LU.size(-2), "n");
  magma_int_t nrhs = magma_int_cast(B.size(-1), "nrhs");
  auto leading_dimension = std::max<magma_int_t>(1, n);

  auto pivots_data = pivots.data_ptr<magma_int_t>();

  auto b_stride = matrixStride(B);
  auto lu_stride = matrixStride(LU);
  auto pivots_stride = pivots.size(-1);
  magma_int_t batch_size = magma_int_cast(batchCount(B), "batchCount");

  magma_int_t** pivots_array;
  scalar_t** lu_array;
  scalar_t** b_array;

  ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size);
  ALLOCATE_ARRAY(lu_array, scalar_t*, batch_size);
  ALLOCATE_ARRAY(b_array, scalar_t*, batch_size);

  for (int64_t i = 0; i < batch_size; i++) {
    pivots_array[i] = &pivots_data[i * pivots_stride];
    b_array[i] = &b_data[i * b_stride];
    lu_array[i] = &lu_data[i * lu_stride];
  }

  MAGMAQueue magma_queue(B.get_device());

  // Compute the result in batches of 65535
  // that is the maximum allowed number for batch_size in MAGMA
  constexpr int64_t batch_limit = 65535;

  for (int64_t mini_idx = 0; mini_idx < batch_size; mini_idx += batch_limit) {
    int64_t nbatches = std::min(batch_limit, batch_size - mini_idx);
    scalar_t** lu_array_cur = &lu_array[mini_idx];
    scalar_t** b_array_cur = &b_array[mini_idx];
    magma_int_t** pivots_array_cur = &pivots_array[mini_idx];

    int info;
    magmaLuSolveBatched<scalar_t>(
        n, nrhs, lu_array_cur, leading_dimension,
        pivots_array_cur, b_array_cur, leading_dimension,
        info, nbatches, magma_queue, trans);

    // info from magmaLuSolveBatched only reports if the i-th parameter is wrong
    // so we don't need to check it all the time
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
  }
#endif
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free