apply_ormqr Function Template — PyTorch Architecture
Architecture documentation for the apply_ormqr function template in BatchLinearAlgebraKernel.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/BatchLinearAlgebraKernel.cpp lines 714–769
// Applies the orthogonal/unitary matrix Q, implicitly represented by the
// elementary reflectors stored in `input` together with the scalar factors in
// `tau` (the representation LAPACK's ?ormqr / ?unmqr consumes, e.g. the output
// of ?geqrf), to every matrix of the batch `other`. `other` is overwritten in
// place:
//   left == true   ->  other := op(Q) * other
//   left == false  ->  other := other * op(Q)
// where op(Q) is Q when `transpose` is false, and otherwise Q^T for real dtypes
// or Q^H (conjugate transpose) for complex dtypes.
//
// The optimal LAPACK workspace size is queried once and the same workspace
// buffer is reused for every matrix in the batch.
//
// NOTE(review): the leading dimensions handed to LAPACK (lda = rows of the
// reflector matrix, ldc = m) and the per-batch pointer offsets presume that
// the caller passes column-major (Fortran-contiguous) batched matrices and a
// contiguous `tau` of shape (..., k) — confirm against the dispatching caller.
//
// Raises (via TORCH_CHECK) when PyTorch was built without LAPACK support.
template <typename scalar_t>
void apply_ormqr(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) {
#if !AT_BUILD_WITH_LAPACK()
  TORCH_CHECK(false, "Calling torch.ormqr on a CPU tensor requires compiling ",
    "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
#else
  using value_t = typename c10::scalar_value_type<scalar_t>::type;

  const char side = left ? 'L' : 'R';
  // Complex dtypes need the conjugate transpose ('C'); real dtypes the plain transpose ('T').
  const char trans = transpose ? (input.is_complex() ? 'C' : 'T') : 'N';

  auto input_data = input.const_data_ptr<scalar_t>();
  auto tau_data = tau.const_data_ptr<scalar_t>();
  auto other_data = other.data_ptr<scalar_t>();  // written in place by LAPACK

  auto input_matrix_stride = matrixStride(input);
  auto other_matrix_stride = matrixStride(other);
  auto tau_stride = tau.size(-1);
  auto batch_size = batchCount(input);

  // `other` is m x n; Q is built from k reflectors.
  auto m = other.size(-2);
  auto n = other.size(-1);
  auto k = tau.size(-1);
  // The reflector matrix has m rows when Q is applied from the left, n rows otherwise.
  auto lda = std::max<int64_t>(1, left ? m : n);
  auto ldc = std::max<int64_t>(1, m);
  int info = 0;

  // LAPACK's requirement
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY((left ? m : n) >= k);

  // Workspace query: lwork == -1 asks LAPACK to report the optimal workspace
  // size in `wkopt` without touching `other`.
  int lwork = -1;
  scalar_t wkopt;
  lapackOrmqr<scalar_t>(side, trans, m, n, k, const_cast<scalar_t*>(input_data), lda, const_cast<scalar_t*>(tau_data), other_data, ldc, &wkopt, lwork, &info);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
  // The size comes back in the real part of a scalar_t; convert it explicitly.
  lwork = std::max<int>(1, static_cast<int>(real_impl<scalar_t, value_t>(wkopt)));
  Tensor work = at::empty({lwork}, input.options());
  // Loop-invariant: fetch the workspace pointer once instead of once per batch entry.
  scalar_t* work_data = work.data_ptr<scalar_t>();

  for (const auto i : c10::irange(batch_size)) {
    const scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
    scalar_t* other_working_ptr = &other_data[i * other_matrix_stride];
    const scalar_t* tau_working_ptr = &tau_data[i * tau_stride];

    // now compute the actual result
    // const_cast only because the LAPACK wrapper takes non-const pointers;
    // with ?ormqr the reflector matrix and tau are inputs only.
    lapackOrmqr<scalar_t>(
        side, trans, m, n, k,
        const_cast<scalar_t*>(input_working_ptr), lda,
        const_cast<scalar_t*>(tau_working_ptr),
        other_working_ptr, ldc,
        work_data, lwork, &info);

    // info from lapackOrmqr only reports if the i-th parameter is wrong
    // so we don't need to check it all the time
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
  }
#endif
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free