Home / Class / apply_lu_factor_batched_magma Class — pytorch Architecture

apply_lu_factor_batched_magma Class — pytorch Architecture

Architecture documentation for the apply_lu_factor_batched_magma function template in BatchLinearAlgebra.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp lines 1202–1260

// Computes an in-place LU factorization of every matrix in a batch with MAGMA's batched
// getrf routines (magmaLuBatched when pivoting, magmaLuNoPivBatched otherwise).
//
// input:          batch of m x n matrices; matrix i starts at
//                 input.data_ptr<scalar_t>() + i * matrixStride(input) and is overwritten
//                 by the factorization.
// pivots:         only touched when compute_pivots is true. It is first filled with 1s,
//                 and pivot row i starts at i * pivots.size(-1). It is accessed as magma_int_t.
// infos:          one magma_int_t per matrix, passed to MAGMA as its info output
//                 (presumably getrf status codes; confirm against MAGMA docs).
// compute_pivots: true -> partial pivoting, false -> no pivoting.
//
// Fails via TORCH_CHECK when PyTorch was built without MAGMA, or when the MAGMA version
// reported at runtime is older than 2.5.2.
template <typename scalar_t>
static void apply_lu_factor_batched_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
#if !AT_MAGMA_ENABLED()
  TORCH_CHECK(
      false,
      "Calling linalg.lu_factor on a CUDA tensor requires compiling ",
      "PyTorch with MAGMA. Please rebuild with MAGMA.");
#else
  // There is a bug in the batched getrf kernel of MAGMA < 2.5.2, see
  // https://bitbucket.org/icl/magma/issues/13/getrf_batched-kernel-produces-nans-on
  // The version is queried at runtime. std::tuple's operator< compares element-wise
  // lexicographically, so this is a (major, minor, micro) version comparison.
  std::tuple<magma_int_t, magma_int_t, magma_int_t> version;
  magma_version(&std::get<0>(version), &std::get<1>(version), &std::get<2>(version));
  const bool magma_batched_buggy = version < std::make_tuple<magma_int_t, magma_int_t, magma_int_t>(2, 5, 2);
  TORCH_CHECK(!magma_batched_buggy, "linalg.lu_factor has bugs on MAGMA < 2.5.2. Please update your MAGMA version to a newer one.");

  auto input_data = input.data_ptr<scalar_t>();
  auto infos_data = infos.data_ptr<magma_int_t>();
  const auto input_matrix_stride = matrixStride(input);
  const magma_int_t batch_size = magma_int_cast(batchCount(input), "batchCount");

  const magma_int_t m = magma_int_cast(input.size(-2), "m");
  const magma_int_t n = magma_int_cast(input.size(-1), "n");
  // The leading dimension passed to MAGMA must be at least 1, even for m == 0.
  const auto leading_dimension = std::max<magma_int_t>(1, m);

  // MAGMA's batched API takes an array of per-matrix pointers rather than a strided base.
  scalar_t** input_array;
  ALLOCATE_ARRAY(input_array, scalar_t*, batch_size);

  // Set up array of pointers to matrices
  for (int64_t i = 0; i < batch_size; i++) {
    input_array[i] = &input_data[i * input_matrix_stride];
  }

  // needed to run lu tests in parallel, see https://github.com/pytorch/pytorch/issues/82894 for examples
  // of failures
  c10::cuda::device_synchronize();
  MAGMAQueue magma_queue(input.get_device());

  if (compute_pivots) {
    auto pivots_data = pivots.data_ptr<magma_int_t>();
    const auto pivots_stride = pivots.size(-1);
    // fill pivots with ones to avoid memory access violations inside magma kernels
    // magmaLuBatched might not set the values for it
    // see https://github.com/pytorch/pytorch/pull/53064
    pivots.fill_(1);
    magma_int_t** pivots_array;
    ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size);
    for (int64_t i = 0; i < batch_size; i++) {
      pivots_array[i] = &pivots_data[i * pivots_stride];
    }
    magmaLuBatched<scalar_t>(m, n, input_array, leading_dimension, pivots_array, infos_data, batch_size, magma_queue);
  } else {
    magmaLuNoPivBatched<scalar_t>(m, n, input_array, leading_dimension, infos_data, batch_size, magma_queue);
  }

  // block CPU until all operations on the queue are finished
  // this explicit sync prevents garbage results from the subsequent magmaLuSolveBatched call from a different queue
  magma_queue_sync(magma_queue.get_queue());
#endif
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free