Home / Class / apply_svd_cusolver_gesvdj Class — pytorch Architecture

apply_svd_cusolver_gesvdj Class — pytorch Architecture

Architecture documentation for the apply_svd_cusolver_gesvdj class in BatchLinearAlgebraLib.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp lines 353–423

// Computes the SVD of a batch of m-by-n matrices with cuSOLVER's Jacobi-based
// `gesvdj` driver, issuing one gesvdj call per matrix in the batch.
//
// A:             input batch (the driver may use it as scratch).
// U, S, V:       outputs for left singular vectors, singular values, and right
//                singular vectors. U and V are only read when compute_uv is true.
// infos:         int tensor receiving one cuSOLVER status code per matrix.
// full_matrices: when false, economy-size U/V are requested (econ = 1).
// compute_uv:    when false only singular values are computed; cuSOLVER still
//                requires valid U/V device buffers, so dummies are allocated.
template<typename scalar_t>
static void apply_svd_cusolver_gesvdj(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
  const Tensor& infos, bool full_matrices, bool compute_uv) {
  // value_t is the real type underlying scalar_t (e.g. float for
  // c10::complex<float>); singular values are always real.
  using value_t = typename c10::scalar_value_type<scalar_t>::type;
  int m = cuda_int_cast(A.size(-2), "m");
  int n = cuda_int_cast(A.size(-1), "n");
  int k = std::min(m, n);

  // Need to pass allocated memory to the function, otherwise it fails
  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  auto dataPtr_U = !compute_uv ? allocator.allocate(sizeof(scalar_t)* m * k) : c10::DataPtr{};
  auto dataPtr_V = !compute_uv ? allocator.allocate(sizeof(scalar_t)* n * k) : c10::DataPtr{};

  auto A_data = A.data_ptr<scalar_t>();
  auto U_data = compute_uv ? U.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_U.get());
  auto S_data = S.data_ptr<value_t>();
  auto V_data = compute_uv ? V.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_V.get());
  auto A_stride = matrixStride(A);
  auto U_stride = compute_uv ? matrixStride(U) : 0;
  auto S_stride = S.size(-1);
  auto V_stride = compute_uv ? matrixStride(V) : 0;

  int batchsize = cuda_int_cast(batchCount(A), "batch size");
  // Leading dimensions; the dummy U/V buffers allocated above are tightly
  // packed, hence ldu = m and ldv = n in the !compute_uv case.
  int lda = A.stride(-1);
  int ldu = compute_uv ? U.stride(-1) : m;
  int ldv = compute_uv ? V.stride(-1) : n;

  auto handle = at::cuda::getCurrentCUDASolverDnHandle();
  auto jobz = compute_uv ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
  int econ = full_matrices ? 0 : 1;  // econ == 1 requests economy-size factors

  // gesvdj_params controls the numerical accuracy of cusolver gesvdj iterations on GPU
  gesvdjInfo_t gesvdj_params;
  TORCH_CUSOLVER_CHECK(cusolverDnCreateGesvdjInfo(&gesvdj_params));

  // Todo: expose the following two parameters to users
  // NOTE: the tolerance must be the *real* epsilon (value_t). For complex
  // scalar_t, std::numeric_limits<scalar_t> is the unspecialized primary
  // template whose epsilon() is a default-constructed value — meaningless as
  // a convergence tolerance — so we use value_t, matching the dtype of S.
  TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetTolerance(gesvdj_params, std::numeric_limits<value_t>::epsilon()));
  TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetMaxSweeps(gesvdj_params, cusolver_gesvdj_max_sweeps));

  // Query the required device workspace size; a negative lwork after the call
  // means the size query itself failed.
  int lwork = -1;
  at::cuda::solver::gesvdj_buffersize<scalar_t>(
    handle, jobz, econ, m, n, A_data, lda, S_data, U_data, ldu, V_data, ldv, &lwork, gesvdj_params);
  TORCH_INTERNAL_ASSERT(lwork >= 0, "gesvdj_buffersize failed to get needed buffer size, got lwork = ", lwork);

  // One workspace is reused across the whole batch; the per-matrix calls
  // below are issued sequentially on the same handle.
  auto dataPtr = allocator.allocate(sizeof(scalar_t)*lwork);

  // NOTE(review): if TORCH_CUSOLVER_CHECK throws inside this loop,
  // gesvdj_params is not destroyed — presumably acceptable here since a
  // cuSOLVER failure is fatal for the op; confirm against project policy.
  for(int i = 0; i < batchsize; i++){
    at::cuda::solver::gesvdj<scalar_t>(
      handle, jobz, econ, m, n,
      A_data + i * A_stride,
      lda,
      S_data + i * S_stride,
      U_data + i * U_stride,
      ldu,
      V_data + i * V_stride,
      ldv,
      reinterpret_cast<scalar_t*>(dataPtr.get()),
      lwork,
      infos.data_ptr<int>() + i,
      gesvdj_params
    );

    // // The following code can be used to check or report the gesvdj residual.
    // // Note: this will introduce a device-host sync and may negatively affect the performance
    // double residual = 0;
    // TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjGetResidual(handle, gesvdj_params, &residual));
    // printf("gesvdj residual = %.6e\n", residual);
  }

  TORCH_CUSOLVER_CHECK(cusolverDnDestroyGesvdjInfo(gesvdj_params));
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free