Home / Class / apply_lu_factor Class — PyTorch Architecture

apply_lu_factor Class — pytorch Architecture

Architecture documentation for the apply_lu_factor class in BatchLinearAlgebraKernel.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/BatchLinearAlgebraKernel.cpp lines 970–1013

template <typename scalar_t>
void apply_lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
#if !AT_BUILD_WITH_LAPACK()
  TORCH_CHECK(
      false,
      "Calling torch.linalg.lu_factor on a CPU tensor requires compiling ",
      "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
#else
  // This path always pivots (LAPACK getrf has no non-pivoting variant here).
  TORCH_CHECK(compute_pivots, "linalg.lu_factor: LU without pivoting is not implemented on the CPU");

  // Factorizes each matrix of the batched `input` in place (LU with partial
  // pivoting); pivot indices go to `pivots` and per-matrix LAPACK status codes
  // to `infos`.
  auto input_data = input.data_ptr<scalar_t>();
  auto pivots_data = pivots.data_ptr<int>();  // LAPACK takes 32-bit pivot indices
  auto infos_data = infos.data_ptr<int>();
  auto input_matrix_stride = matrixStride(input);
  auto pivots_stride = pivots.size(-1);
  auto batch_size = batchCount(input);
  auto m = input.size(-2);
  auto n = input.size(-1);
  // getrf requires lda >= max(1, m); the max guards the m == 0 edge case.
  auto leading_dimension = std::max<int64_t>(1, m);

  // Factor matrices [start, end) of the batch; success/failure of each
  // factorization is reported through the matching entry of `infos`.
  const auto loop = [&](int64_t start, int64_t end) {
    for (const auto i : c10::irange(start, end)) {
      scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
      int* pivots_working_ptr = &pivots_data[i * pivots_stride];
      int* infos_working_ptr = &infos_data[i];
      lapackLu<scalar_t>(
          m,
          n,
          input_working_ptr,
          leading_dimension,
          pivots_working_ptr,
          infos_working_ptr);
    }
  };
  // avoid overflow: m and n are int64_t, so cubing std::min(m, n) in integer
  // arithmetic could overflow (signed overflow is UB) for very large matrices.
  // Do the heuristic arithmetic in floating point instead.
  const double matrix_rank = static_cast<double>(std::min(m, n));
  // A heuristic tested on a 32 core/socket ICX system
  // https://github.com/pytorch/pytorch/pull/93037#discussion_r1090112948
  // NOTE(review): std::min(1.0, ...) makes chunk_size_per_thread 1 for small
  // matrices and 0 ("parallelize as finely as possible") for large ones —
  // presumably intentional per the linked discussion; confirm before changing.
  int64_t chunk_size_per_thread = static_cast<int64_t>(
      std::min(1.0, 3200.0 / (matrix_rank * matrix_rank * matrix_rank)));
  int64_t grain_size = chunk_size_per_thread * at::get_num_threads();
  at::parallel_for(0, batch_size, grain_size, loop);
#endif
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free