Home / Class/ apply_gels_batched Class — pytorch Architecture

apply_gels_batched Class — pytorch Architecture

Architecture documentation for the apply_gels_batched function template in BatchLinearAlgebraLibBlas.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLibBlas.cpp lines 240–292

template <typename scalar_t>
inline void apply_gels_batched(const Tensor& A, Tensor& B, Tensor& infos) {
  auto trans = CUBLAS_OP_N;
  auto m = cuda_int_cast(A.size(-2), "m");
  auto n = cuda_int_cast(A.size(-1), "n");

  auto nrhs = cuda_int_cast(B.size(-1), "nrhs");
  // cuBLAS from cuda10 and older doesn't work with nrhs == 0 (cuda11 works)
  // so we need to put this early return
  if (nrhs == 0) {
    return;
  }

  auto batch_size = cuda_int_cast(batchCount(B), "batch_size");
  auto lda = std::max<int>(1, m);
  auto ldb = std::max<int>(1, m);

  // cuBLAS's requirement
  TORCH_CHECK(
    m >= n,
    "torch.linalg.lstsq: only overdetermined systems (input.size(-2) >= input.size(-1)) are allowed on CUDA with cuBLAS backend.");

  // cuBLAS documentation says:
  // Matrices Aarray[i] should not overlap; otherwise, undefined behavior is expected.
  // explicitly broadcast the batch dimensions of A
  IntArrayRef A_batch_sizes(A.sizes().data(), A.dim() - 2);
  IntArrayRef B_batch_sizes(B.sizes().data(), B.dim() - 2);
  std::vector<int64_t> expand_batch_portion = at::infer_size(A_batch_sizes, B_batch_sizes);
  expand_batch_portion.insert(expand_batch_portion.end(), {A.size(-2), A.size(-1)});
  Tensor A_expanded = A.expand({expand_batch_portion});
  Tensor A_broadcasted = cloneBatchedColumnMajor(A_expanded);

  // cuBLAS batched gels requires input to be the device array of pointers to device single matrices
  Tensor A_ptr_array = get_device_pointers<scalar_t>(A_broadcasted);
  Tensor B_ptr_array = get_device_pointers<scalar_t>(B);
  auto A_ptr_array_data = reinterpret_cast<scalar_t**>(A_ptr_array.data_ptr());
  auto B_ptr_array_data = reinterpret_cast<scalar_t**>(B_ptr_array.data_ptr());

  auto infos_data = infos.data_ptr<int>();
  auto handle = at::cuda::getCurrentCUDABlasHandle();
  int info;

  at::cuda::blas::gelsBatched<scalar_t>(
    handle, trans, m, n, nrhs,
    A_ptr_array_data, lda,
    B_ptr_array_data, ldb,
    &info,
    infos_data,
    batch_size);

  // negative info indicates that an argument to gelsBatched call is invalid
  TORCH_INTERNAL_ASSERT(info == 0);
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free