Home / Class / apply_svd_cusolver_gesvdjBatched Class — pytorch Architecture

apply_svd_cusolver_gesvdjBatched Class — pytorch Architecture

Architecture documentation for the apply_svd_cusolver_gesvdjBatched function template (listed here under "Class") in BatchLinearAlgebraLib.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp lines 435–477

template<typename scalar_t>
static void apply_svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
  const Tensor& infos, bool compute_uv
) {
  using value_t = typename c10::scalar_value_type<scalar_t>::type;
  int m = cuda_int_cast(A.size(-2), "m");
  int n = cuda_int_cast(A.size(-1), "n");
  int batchsize = cuda_int_cast(batchCount(A), "batch size");
  auto lda = std::max<int>(1, m);
  auto ldu = std::max<int>(1, m);
  auto ldv = std::max<int>(1, n);

  // Need to pass allocated memory to the function, otherwise it fails
  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  auto dataPtr_U = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * m * ldu) : c10::DataPtr{};
  auto dataPtr_V = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * n * ldv) : c10::DataPtr{};

  auto A_data = A.data_ptr<scalar_t>();
  auto U_data = compute_uv ? U.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_U.get());
  auto S_data = S.data_ptr<value_t>();
  auto V_data = compute_uv ? V.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_V.get());

  TORCH_INTERNAL_ASSERT(m <= 32 && n <= 32, "gesvdjBatched requires both matrix dimensions not greater than 32, but got "
                        "m = ", m, " n = ", n);

  // gesvdj_params controls the numerical accuracy of cusolver gesvdj iterations on GPU
  gesvdjInfo_t gesvdj_params;
  TORCH_CUSOLVER_CHECK(cusolverDnCreateGesvdjInfo(&gesvdj_params));

  // Todo: expose the following two parameters to users
  TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetTolerance(gesvdj_params, std::numeric_limits<scalar_t>::epsilon()));
  TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetMaxSweeps(gesvdj_params, cusolver_gesvdj_max_sweeps));
  TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetSortEig(gesvdj_params, 1));

  auto handle = at::cuda::getCurrentCUDASolverDnHandle();
  auto jobz = compute_uv ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
  at::cuda::solver::gesvdjBatched<scalar_t>(
    handle, jobz, m, n, A_data, lda, S_data, U_data, ldu, V_data, ldv,
    infos.data_ptr<int>(), gesvdj_params, batchsize
  );

  TORCH_CUSOLVER_CHECK(cusolverDnDestroyGesvdjInfo(gesvdj_params));
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free