apply_orgqr Function — pytorch Architecture
Architecture documentation for the apply_orgqr function template in BatchLinearAlgebraLib.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp lines 1144–1199
// Runs cuSOLVER's orgqr on every matrix of a batched tensor, writing the
// result into `self` in place.
//
// Parameters:
//   self - batched matrices; the last two dimensions are (m, n). Each matrix
//          is handed to orgqr with leading dimension max(1, m) and is
//          overwritten with the result.
//          NOTE(review): presumably holds the Householder reflectors from a
//          prior geqrf call in column-major batched layout - confirm at the
//          call site.
//   tau  - batched scalar factors; the last dimension is k. Consecutive
//          batch entries are assumed to be max(1, k) elements apart.
//
// Preconditions (checked with TORCH_INTERNAL_ASSERT): m >= n >= k.
template <typename scalar_t>
static void apply_orgqr(Tensor& self, const Tensor& tau) {
  auto self_data = self.data_ptr<scalar_t>();
  auto tau_data = tau.const_data_ptr<scalar_t>();
  auto self_matrix_stride = matrixStride(self);
  auto batchsize = cuda_int_cast(batchCount(self), "batch size");
  auto m = cuda_int_cast(self.size(-2), "m");
  auto n = cuda_int_cast(self.size(-1), "n");
  auto k = cuda_int_cast(tau.size(-1), "k");
  auto tau_stride = std::max<int>(1, k);
  auto lda = std::max<int>(1, m);

  // LAPACK's requirement
  TORCH_INTERNAL_ASSERT(m >= n);
  TORCH_INTERNAL_ASSERT(n >= k);

  // cuSOLVER doesn't compute anything for this case, which is wrong
  // the result should be a matrix with 1 on the diagonal
  if (k == 0) {
    self.fill_(0);
    self.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(1);
    return;
  }

  // The handle is the same for every matrix in the batch, so look it up once.
  auto handle = at::cuda::getCurrentCUDASolverDnHandle();

  // get the optimal work size and allocate workspace tensor
  int lwork = 0;
  at::cuda::solver::orgqr_buffersize<scalar_t>(
      handle, m, n, k, self_data, lda, tau_data, &lwork);

  auto info = at::zeros({1}, self.options().dtype(at::kInt));
  auto info_data = info.data_ptr<int>();

  // allocate workspace storage
  // lwork does not change between iterations and the calls below are issued
  // one after another through the same handle, so a single workspace buffer
  // is allocated up front and reused instead of hitting the allocator once
  // per matrix.
  auto& allocator = *at::cuda::getCUDADeviceAllocator();
  auto work_data = allocator.allocate(sizeof(scalar_t) * lwork);
  auto* work_ptr = static_cast<scalar_t*>(work_data.get());

  for (auto i = decltype(batchsize){0}; i < batchsize; i++) {
    scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
    const scalar_t* tau_working_ptr = &tau_data[i * tau_stride];

    at::cuda::solver::orgqr<scalar_t>(
        handle, m, n, k,
        self_working_ptr,
        lda,
        tau_working_ptr,
        work_ptr,
        lwork,
        info_data
    );

    // info from orgqr only reports if the i-th parameter is wrong
    // so we don't need to check it all the time
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info.item().toInt() == 0);
  }
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free