
apply_svd Function — pytorch Architecture

Architecture documentation for the apply_svd function template in BatchLinearAlgebraKernel.cpp from the pytorch codebase. apply_svd drives the CPU singular value decomposition: for each matrix in a batched input it calls the LAPACK SVD driver (via lapackSvd), writing the singular values into S and, optionally, the singular vectors into U and Vh.

Entity Profile

Source Code

aten/src/ATen/native/BatchLinearAlgebraKernel.cpp lines 1097–1158

template <typename scalar_t>
void apply_svd(const Tensor& A,
               const bool full_matrices,
               const bool compute_uv,
               const Tensor& U,
               const Tensor& S,
               const Tensor& Vh,
               const Tensor& info) {
#if !AT_BUILD_WITH_LAPACK()
  TORCH_CHECK(false, "svd: LAPACK library not found in compilation");
#else
  using value_t = typename c10::scalar_value_type<scalar_t>::type;
  const auto A_data = A.data_ptr<scalar_t>();
  const auto U_data = compute_uv ? U.data_ptr<scalar_t>() : nullptr;
  const auto S_data = S.data_ptr<value_t>();
  const auto info_data = info.data_ptr<int>();
  const auto Vh_data = compute_uv ? Vh.data_ptr<scalar_t>() : nullptr;
  const auto A_stride = matrixStride(A);
  const auto S_stride = S.size(-1);
  const auto U_stride = compute_uv ? matrixStride(U) : 1;
  const auto Vh_stride = compute_uv ? matrixStride(Vh) : 1;
  const auto batchsize = batchCount(A);
  const char jobz = compute_uv ? (full_matrices ? 'A' : 'S') : 'N';

  const auto m = A.size(-2);
  const auto n = A.size(-1);
  const auto lda = A.stride(-1);
  const auto ldu = compute_uv ? U.stride(-1) : 1;
  const auto ldvh = compute_uv ? Vh.stride(-1) : 1;

  auto iwork = std::vector<int>(8 * std::min(m, n));
  auto* const iwork_data = iwork.data();

  // rwork is just used for the complex decomposition
  auto rwork = std::vector<value_t>{};
  if (A.is_complex()) {
    rwork.resize(std::max(computeLRWorkDim(jobz, m, n), int64_t{1}));
  }
  auto* const rwork_data = rwork.data();

  // Query svd for the optimal lwork size
  int lwork = -1;
  {
    scalar_t wkopt;
    lapackSvd<scalar_t, value_t>(jobz, m, n, A_data, lda, S_data, U_data, ldu, Vh_data, ldvh, &wkopt, lwork, rwork_data, iwork_data, info_data);
    lwork = lapack_work_to_int(wkopt);
  }
  auto work = std::vector<scalar_t>(lwork);
  auto* const work_data = work.data();

  for (const auto i : c10::irange(batchsize)) {
    auto* const A_working_ptr = &A_data[i * A_stride];
    auto* const S_working_ptr = &S_data[i * S_stride];
    auto* const U_working_ptr = compute_uv ? &U_data[i * U_stride] : nullptr;
    auto* const Vh_working_ptr = compute_uv ? &Vh_data[i * Vh_stride] : nullptr;

    // Compute S, U (optionally) and Vh (optionally)
    lapackSvd<scalar_t, value_t>(jobz, m, n, A_working_ptr, lda,
                        S_working_ptr, U_working_ptr, ldu, Vh_working_ptr, ldvh, work_data, lwork, rwork_data, iwork_data, info_data + i);
  }
#endif
}
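
How this function is reached (illustrative): apply_svd is a per-dtype template, so the file wraps it in a dtype-erased kernel that selects scalar_t at runtime and hands the template a column-major copy of A, since LAPACK overwrites its input. The sketch below is a simplified reconstruction, not a verbatim excerpt; AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES and cloneBatchedColumnMajor are existing ATen helpers, while svd_kernel_sketch is an illustrative name.

#include <ATen/Dispatch.h>
#include <ATen/native/LinearAlgebraUtils.h>

void svd_kernel_sketch(const Tensor& A,
                       const bool full_matrices,
                       const bool compute_uv,
                       const Tensor& U,
                       const Tensor& S,
                       const Tensor& Vh,
                       const Tensor& info) {
  // Pick scalar_t (float/double/complex) from A's runtime dtype, then run the
  // template on a column-major clone because the LAPACK driver destroys A.
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "linalg_svd_cpu", [&] {
    apply_svd<scalar_t>(cloneBatchedColumnMajor(A),
                        full_matrices, compute_uv, U, S, Vh, info);
  });
}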

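The lwork = -1 call near the middle of the listing is LAPACK's standard workspace-query convention: the driver is invoked once to report the optimal work-buffer size, the buffer is allocated once, and only then is the decomposition run for each matrix in the batch. Below is a minimal, self-contained C++ sketch of that same two-phase pattern against the raw Fortran dgesdd_ symbol; it is illustrative code, not part of the PyTorch source, and assumes a system LAPACK is available at link time.

// Build with e.g.: g++ gesdd_demo.cpp -llapack
#include <algorithm>
#include <cstdio>
#include <vector>

extern "C" void dgesdd_(const char* jobz, const int* m, const int* n,
                        double* a, const int* lda, double* s,
                        double* u, const int* ldu, double* vt, const int* ldvt,
                        double* work, const int* lwork, int* iwork, int* info);

int main() {
  const int m = 4, n = 3;
  // Column-major 4x3 input, the layout LAPACK expects.
  std::vector<double> a = {1, 2, 3, 4,  5, 6, 7, 8,  9, 10, 11, 12};
  std::vector<double> s(std::min(m, n));
  std::vector<double> u(m * m), vt(n * n);    // jobz == 'A': full U and V^T
  std::vector<int> iwork(8 * std::min(m, n));
  const char jobz = 'A';
  const int lda = m, ldu = m, ldvt = n;
  int info = 0;

  // Phase 1: workspace query. lwork == -1 makes dgesdd report the optimal
  // work size in wkopt without computing anything.
  double wkopt = 0;
  int lwork = -1;
  dgesdd_(&jobz, &m, &n, a.data(), &lda, s.data(), u.data(), &ldu,
          vt.data(), &ldvt, &wkopt, &lwork, iwork.data(), &info);

  // Phase 2: allocate the reported buffer once, then run the factorization.
  lwork = static_cast<int>(wkopt);
  std::vector<double> work(lwork);
  dgesdd_(&jobz, &m, &n, a.data(), &lda, s.data(), u.data(), &ldu,
          vt.data(), &ldvt, work.data(), &lwork, iwork.data(), &info);

  std::printf("info = %d, singular values: %g %g %g\n", info, s[0], s[1], s[2]);
  return info;
}

In apply_svd the query is performed once and the same work, rwork and iwork buffers are reused for every matrix in the batch loop, which is valid because jobz, m and n are shared across the whole batch.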