apply_triangular_solve Function Template — PyTorch Architecture

Architecture documentation for the apply_triangular_solve function template in BatchLinearAlgebraKernel.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/BatchLinearAlgebraKernel.cpp lines 788–818

template<typename scalar_t>
void apply_triangular_solve(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
#if !AT_BUILD_WITH_BLAS()
  TORCH_CHECK(
      false,
      "Calling torch.triangular_solve on a CPU tensor requires compiling ",
      "PyTorch with BLAS. Please use PyTorch built with BLAS support.");
#else
  char uplo = upper ? 'U' : 'L';
  char diag = unitriangular ? 'U' : 'N';
  char side = left ? 'L' : 'R';
  const char trans = to_blas(transpose);

  auto A_data = A.const_data_ptr<scalar_t>();
  auto B_data = B.data_ptr<scalar_t>();
  auto A_mat_stride = matrixStride(A);
  auto B_mat_stride = matrixStride(B);
  auto batch_size = batchCount(A);
  // This allows rectangular A and B to be passed when left = true
  auto m = left ? A.size(-1) : B.size(-2);
  auto n = B.size(-1);
  auto lda = std::max<int64_t>(1, A.size(-2));
  auto ldb = std::max<int64_t>(1, B.size(-2));

  for (const auto i : c10::irange(batch_size)) {
    const scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
    scalar_t* B_working_ptr = &B_data[i * B_mat_stride];
    blasTriangularSolve<scalar_t>(side, uplo, trans, diag, m, n, const_cast<scalar_t*>(A_working_ptr), lda, B_working_ptr, ldb);
  }
#endif
}
