
apply_lu_solve_batched_cublas Function Template — PyTorch Architecture

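Architecture documentation for apply_lu_solve_batched_cublas, a static function template in BatchLinearAlgebraLibBlas.cpp from the PyTorch codebase. Given an existing LU factorization (the LU factors plus their pivot indices), it solves a batch of linear systems A X = B on the GPU, in place in B, by calling the cuBLAS getrsBatched routine.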


Source Code

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLibBlas.cpp lines 130–152

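template <typename scalar_t>
static void apply_lu_solve_batched_cublas(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) {
  // LU and B must hold the same number of matrices. pivots has shape (*, n), so it is
  // unsqueezed to (*, n, 1) before batchCount() multiplies out the leading batch dimensions.
  TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(B), "batch_size of LU and B must be the same");
  TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(pivots.unsqueeze(-1)), "batch_size of LU and pivots must be the same");
  // Map TransposeType (no transpose, transpose, conjugate transpose) to the cuBLAS operation.
  const auto trans = to_cublas(transpose);

  // Pivots are 32-bit ints; cuBLAS reads n of them per matrix, stored back to back.
  auto pivots_data = pivots.const_data_ptr<int>();
  // cuBLAS takes int sizes; cuda_int_cast checks that each value fits in an int.
  auto batch_size = cuda_int_cast(batchCount(LU), "batch_size");
  auto m = cuda_int_cast(LU.size(-2), "m");
  auto nrhs = cuda_int_cast(B.size(-1), "nrhs");
  // B also has m rows, so the same leading dimension is passed as both lda and ldb.
  auto lda = cuda_int_cast(std::max<int>(1, m), "lda");
  int info = 0;

  // get_device_pointers builds a device tensor holding the address of every matrix in the
  // batch: the pointer-array form that the batched cuBLAS API takes.
  Tensor lu_ptr_array = get_device_pointers<scalar_t>(LU);
  Tensor b_ptr_array = get_device_pointers<scalar_t>(B);
  auto lu_ptr_array_data = reinterpret_cast<const scalar_t* const*>(lu_ptr_array.const_data_ptr());
  auto b_ptr_array_data = reinterpret_cast<scalar_t**>(b_ptr_array.data_ptr());

  auto handle = at::cuda::getCurrentCUDABlasHandle();
  // For every batch entry i, solve op(A_i) X_i = B_i from the LU factors and pivots,
  // overwriting B_i with X_i.
  at::cuda::blas::getrsBatched(handle, trans, m, nrhs, const_cast<scalar_t**>(lu_ptr_array_data),
    lda, const_cast<int*>(pivots_data), b_ptr_array_data, lda, &info, batch_size);
  // info is a host-side value that only flags an invalid argument, not a singular matrix.
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
}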
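The second assertion unsqueezes pivots because batchCount() treats the last two dimensions of a tensor as the matrix dimensions. Without the extra trailing dimension, the last batch dimension of a pivots tensor of shape (*, n) would be mistaken for a row count. The sketch below illustrates this on plain shape vectors; batch_count and unsqueeze_last are hypothetical helpers that mimic batchCount() and Tensor::unsqueeze(-1), not ATen code.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for ATen's batchCount(): the product of every
// dimension except the last two (the matrix rows and columns).
std::int64_t batch_count(const std::vector<std::int64_t>& shape) {
  std::int64_t result = 1;
  for (std::size_t i = 0; i + 2 < shape.size(); ++i) {
    result *= shape[i];
  }
  return result;
}

// Hypothetical stand-in for Tensor::unsqueeze(-1): append a trailing size-1 dimension.
std::vector<std::int64_t> unsqueeze_last(std::vector<std::int64_t> shape) {
  shape.push_back(1);
  return shape;
}
```

For LU of shape {4, 3, 5, 5} and pivots of shape {4, 3, 5}, batch_count returns 12 for LU but only 4 for the raw pivots shape. After unsqueeze_last, the pivots shape becomes {4, 3, 5, 1} and also yields 12, which is what the assertion compares.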
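To make the getrsBatched call more concrete, here is a minimal CPU reference sketch of what it computes in the no-transpose case. It assumes LAPACK conventions: column-major n x n LU factors, with a unit lower triangular L stored below the diagonal and U on and above it, and 1-based pivot indices with n entries per matrix stored back to back. The function lu_solve_batched_reference is hypothetical and is neither PyTorch nor cuBLAS code. It takes host arrays of per-matrix pointers, mirroring the device pointer arrays built by get_device_pointers, and it omits the transposed cases and error reporting.

```cpp
#include <cstddef>
#include <utility>

// Hypothetical CPU reference for a batched LU solve (trans == no transpose).
// For each batch entry i:
//   lu[i]  -> n x n column-major LU factors with leading dimension lda,
//   piv    -> n 1-based pivots per matrix, stored back to back,
//   b[i]   -> n x nrhs column-major right-hand sides (leading dimension ldb),
//             overwritten with the solution X.
void lu_solve_batched_reference(int n, int nrhs, const double* const* lu, int lda,
                                const int* piv, double* const* b, int ldb,
                                int batch_size) {
  for (int i = 0; i < batch_size; ++i) {
    const double* A = lu[i];
    double* B = b[i];
    const int* p = piv + static_cast<std::size_t>(i) * n;

    // 1. Apply the row interchanges in order: swap row r with row p[r] - 1.
    for (int r = 0; r < n; ++r) {
      const int s = p[r] - 1;
      if (s != r) {
        for (int c = 0; c < nrhs; ++c) std::swap(B[r + c * ldb], B[s + c * ldb]);
      }
    }

    for (int c = 0; c < nrhs; ++c) {
      double* x = B + c * ldb;
      // 2. Forward substitution with L (unit diagonal, so no division).
      for (int r = 0; r < n; ++r) {
        for (int k = 0; k < r; ++k) x[r] -= A[r + k * lda] * x[k];
      }
      // 3. Back substitution with U.
      for (int r = n - 1; r >= 0; --r) {
        for (int k = r + 1; k < n; ++k) x[r] -= A[r + k * lda] * x[k];
        x[r] /= A[r + r * lda];
      }
    }
  }
}
```

Each right-hand side is solved in three steps: apply the recorded row interchanges, forward-substitute with L, then back-substitute with U. getrsBatched computes the same result on the device for every matrix in the batch, which is why the PyTorch wrapper only needs to supply the pointer arrays, the shared sizes, and the pivot buffer.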
