Home / Class / apply_lu_factor_looped_magma Class — pytorch Architecture

apply_lu_factor_looped_magma Class — pytorch Architecture

Architecture documentation for the apply_lu_factor_looped_magma function template in BatchLinearAlgebra.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp lines 1148–1186

template <typename scalar_t>
static void apply_lu_factor_looped_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
#if !AT_MAGMA_ENABLED()
  // This should never be thrown if the calling functions are correct.
  TORCH_CHECK(false, "linalg.lu_factor: PyTorch was not compiled with MAGMA support.");
#else
  // magmaLu and magmaLuNoPiv require infos and pivots tensor to be on CPU
  // the data is later copied back to the appropriate output tensor
  Tensor infos_cpu = at::empty_like(infos, infos.options().device(kCPU).pinned_memory(true));

  auto input_data = input.data_ptr<scalar_t>();
  auto infos_data = infos_cpu.mutable_data_ptr<magma_int_t>();
  auto input_matrix_stride = matrixStride(input);
  auto pivots_stride = pivots.size(-1);
  auto batch_size = batchCount(input);
  magma_int_t m = magma_int_cast(input.size(-2), "m");
  magma_int_t n = magma_int_cast(input.size(-1), "n");
  auto leading_dimension = std::max<magma_int_t>(1, m);

  if (compute_pivots) {
    Tensor pivots_cpu = at::empty_like(pivots, pivots.options().device(kCPU).pinned_memory(true));
    auto pivots_data = pivots_cpu.mutable_data_ptr<magma_int_t>();
    for (decltype(batch_size) i = 0; i < batch_size; i++) {
      scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
      int* pivots_working_ptr = &pivots_data[i * pivots_stride];
      int* infos_working_ptr = &infos_data[i];
      magmaLu<scalar_t>(m, n, input_working_ptr, leading_dimension, pivots_working_ptr, infos_working_ptr);
    }
    pivots.copy_(pivots_cpu);
  } else {
    for (decltype(batch_size) i = 0; i < batch_size; i++) {
      scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
      int* infos_working_ptr = &infos_data[i];
      magmaLuNoPiv<scalar_t>(m, n, input_working_ptr, leading_dimension, infos_working_ptr);
    }
  }
  infos.copy_(infos_cpu);
#endif
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free