
apply_ldl_factor_cusolver Class — pytorch Architecture

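Architecture documentation for the apply_ldl_factor_cusolver function template in BatchLinearAlgebraLib.cpp from the PyTorch codebase. It computes an LDL factorization of every matrix in a batch of symmetric matrices on the GPU by calling cuSOLVER's sytrf routine once per matrix.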
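For orientation, each `sytrf` call in the listing computes a Bunch-Kaufman factorization of one symmetric matrix. In the form the cuSOLVER documentation states it (summarized from that documentation, not from this page), with the fill mode chosen by `upper`:

```latex
\begin{aligned}
P A P^{T} &= L D L^{T} && \text{if } \texttt{upper = false} \;(\texttt{CUBLAS\_FILL\_MODE\_LOWER}) \\
P A P^{T} &= U D U^{T} && \text{if } \texttt{upper = true} \;(\texttt{CUBLAS\_FILL\_MODE\_UPPER})
\end{aligned}
```

Here L (or U) is unit lower (or upper) triangular, D is block diagonal with 1×1 and 2×2 blocks, and the permutation P together with the block structure is encoded in `pivots`. Only the selected triangle of A is read, and it is overwritten in place by the factors. Following LAPACK's `?sytrf` conventions, an `info` value of 0 means success and a positive value i means D(i,i) is exactly zero.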


Source Code

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp lines 75–115

template <typename scalar_t>
void apply_ldl_factor_cusolver(
    const Tensor& A,
    const Tensor& pivots,
    const Tensor& info,
    bool upper) {
  auto batch_size = batchCount(A);
  auto n = cuda_int_cast(A.size(-2), "A.size(-2)");
  auto lda = cuda_int_cast(A.stride(-1), "A.stride(-1)");
  auto uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;

  auto a_stride = A.dim() > 2 ? A.stride(-3) : 0;
  auto pivots_stride = pivots.dim() > 1 ? pivots.stride(-2) : 0;

  auto a_data = A.data_ptr<scalar_t>();
  auto pivots_data = pivots.data_ptr<int>();
  auto info_data = info.data_ptr<int>();

  auto handle = at::cuda::getCurrentCUDASolverDnHandle();

  int lwork = 0;
  at::cuda::solver::sytrf_bufferSize(handle, n, a_data, lda, &lwork);
  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  auto work = allocator.allocate(sizeof(scalar_t) * lwork);

  for (const auto i : c10::irange(batch_size)) {
    auto* a_working_ptr = &a_data[i * a_stride];
    auto* pivots_working_ptr = &pivots_data[i * pivots_stride];
    auto* info_working_ptr = &info_data[i];
    at::cuda::solver::sytrf(
        handle,
        uplo,
        n,
        a_working_ptr,
        lda,
        pivots_working_ptr,
        reinterpret_cast<scalar_t*>(work.get()),
        lwork,
        info_working_ptr);
  }
}
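The loop is the heart of the function: it walks the batch by pointer arithmetic, giving each `sytrf` call its own matrix, pivot slice, and `info` slot while all calls share one handle and one workspace. The CPU sketch below reproduces that loop shape under stated assumptions: `ldl_nopivot_lower` and `ldl_factor_batched` are hypothetical names, `double` stands in for `scalar_t`, and the per-matrix routine is a plain unpivoted LDL^T rather than the Bunch-Kaufman pivoting that `sytrf` performs. It illustrates the data flow only and is not a substitute for cuSOLVER.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a single sytrf call: an UNPIVOTED LDL^T
// factorization of one column-major n x n matrix, using only its lower
// triangle. On return the strict lower triangle holds L (unit diagonal
// implied) and the diagonal holds D; the strict upper triangle is untouched.
// Returns 0 on success, or j + 1 (1-based) if D(j, j) is zero, and stops there.
// Unlike cuSOLVER's sytrf there is no Bunch-Kaufman pivoting, so it fails on
// matrices such as [[0, 1], [1, 0]] that a pivoted factorization can handle.
int ldl_nopivot_lower(double* a, int n, int lda) {
  auto at = [a, lda](int i, int j) -> double& {
    return a[i + static_cast<int64_t>(j) * lda];
  };
  for (int j = 0; j < n; ++j) {
    // D(j) = A(j, j) - sum_{k<j} L(j, k)^2 * D(k)
    double d = at(j, j);
    for (int k = 0; k < j; ++k) d -= at(j, k) * at(j, k) * at(k, k);
    at(j, j) = d;
    if (d == 0.0) return j + 1;
    // L(i, j) = (A(i, j) - sum_{k<j} L(i, k) * L(j, k) * D(k)) / D(j)
    for (int i = j + 1; i < n; ++i) {
      double s = at(i, j);
      for (int k = 0; k < j; ++k) s -= at(i, k) * at(j, k) * at(k, k);
      at(i, j) = s / d;
    }
  }
  return 0;
}

// Same loop shape as apply_ldl_factor_cusolver: matrix i starts at
// a_data + i * a_stride, its pivots at pivots_data + i * pivots_stride,
// and its status goes to info_data[i].
void ldl_factor_batched(double* a_data, int* pivots_data, int* info_data,
                        int64_t batch_size, int n, int lda,
                        int64_t a_stride, int64_t pivots_stride) {
  assert(lda >= (n > 1 ? n : 1));  // column-major leading-dimension requirement
  for (int64_t i = 0; i < batch_size; ++i) {
    double* a_working_ptr = a_data + i * a_stride;
    int* pivots_working_ptr = pivots_data + i * pivots_stride;
    // No interchanges: in LAPACK's 1-based ipiv convention,
    // ipiv[k] = k + 1 marks a 1x1 block with no row/column swap.
    for (int k = 0; k < n; ++k) pivots_working_ptr[k] = k + 1;
    info_data[i] = ldl_nopivot_lower(a_working_ptr, n, lda);
  }
}
```

For example, calling `ldl_factor_batched` on the column-major 2×2 matrices {4, 2, 2, 3} and {9, 3, 3, 5} stored back to back (`n = lda = 2`, `a_stride = 4`, `pivots_stride = 2`) leaves D = diag(4, 2) and L(1,0) = 0.5 in the first matrix. Note also that the listing queries `lwork` once, before the loop, and reuses one workspace for every matrix; this appears to rely on all matrices in the batch sharing `n` and `lda`.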
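The setup lines before the loop derive everything from A's sizes and strides: `batchCount(A)` gives the number of matrices, `A.stride(-1)` serves as the column-major leading dimension `lda`, and `A.stride(-3)` (or 0 for a single matrix) is the distance between consecutive matrices. A minimal sketch of that bookkeeping follows, using plain `std::vector<int64_t>` sizes and strides in place of a `Tensor`; `batch_count`, `matrix_batch_stride`, and `batch_dims_are_packed` are hypothetical helpers, not PyTorch APIs. The last one makes explicit an assumption the loop appears to rely on: with several batch dimensions, indexing matrix i as `i * stride(-3)` is only correct if those dimensions are packed contiguously, which the caller presumably guarantees by preparing A in a batched column-major layout.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring batchCount(A): the product of every size
// except the last two (the matrix dimensions). A plain 2-D matrix gives 1.
int64_t batch_count(const std::vector<int64_t>& sizes) {
  int64_t count = 1;
  for (std::size_t d = 0; d + 2 < sizes.size(); ++d) count *= sizes[d];
  return count;
}

// Mirrors `A.dim() > 2 ? A.stride(-3) : 0`: the distance between consecutive
// matrices, or 0 when there is only one matrix.
int64_t matrix_batch_stride(const std::vector<int64_t>& strides) {
  return strides.size() > 2 ? strides[strides.size() - 3] : 0;
}

// Conservative check that `i * stride(-3)` reaches matrix i for every flat
// batch index i: each outer batch stride must equal the next batch stride
// times the next batch size.
bool batch_dims_are_packed(const std::vector<int64_t>& sizes,
                           const std::vector<int64_t>& strides) {
  assert(sizes.size() == strides.size());
  for (std::size_t d = 0; d + 3 < sizes.size(); ++d) {
    if (strides[d] != strides[d + 1] * sizes[d + 1]) return false;
  }
  return true;
}
```

For a batch of shape (2, 3, 4, 4) stored column-major per matrix, the strides are (48, 16, 1, 4): `batch_count` gives 6, `lda` would be `stride(-1) = 4`, and `matrix_batch_stride` gives 16, so matrix i starts 16 * i elements into the buffer.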
