apply_ormqr Function — pytorch Architecture
Architecture documentation for the apply_ormqr template function in BatchLinearAlgebraLib.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp lines 1071–1124
template <typename scalar_t>
static void apply_ormqr(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) {
auto side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
auto trans = transpose ? (input.is_complex() ? CUBLAS_OP_C : CUBLAS_OP_T) : CUBLAS_OP_N;
auto input_data = input.const_data_ptr<scalar_t>();
auto tau_data = tau.const_data_ptr<scalar_t>();
auto other_data = other.data_ptr<scalar_t>();
auto input_matrix_stride = matrixStride(input);
auto other_matrix_stride = matrixStride(other);
auto tau_stride = tau.size(-1);
auto batch_size = batchCount(input);
auto m = cuda_int_cast(other.size(-2), "m");
auto n = cuda_int_cast(other.size(-1), "n");
auto k = cuda_int_cast(tau.size(-1), "k");
auto lda = std::max<int>(1, left ? m : n);
auto ldc = std::max<int>(1, m);
// get the optimal work size and allocate workspace tensor
int lwork;
at::cuda::solver::ormqr_bufferSize<scalar_t>(
at::cuda::getCurrentCUDASolverDnHandle(), side, trans, m, n, k, input_data, lda, tau_data, other_data, ldc, &lwork);
auto info = at::zeros({1}, input.options().dtype(at::kInt));
auto info_data = info.data_ptr<int>();
for (auto i = decltype(batch_size){0}; i < batch_size; i++) {
const scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
scalar_t* other_working_ptr = &other_data[i * other_matrix_stride];
const scalar_t* tau_working_ptr = &tau_data[i * tau_stride];
auto handle = at::cuda::getCurrentCUDASolverDnHandle();
// allocate workspace storage
auto& allocator = *at::cuda::getCUDADeviceAllocator();
auto work_data = allocator.allocate(sizeof(scalar_t)*lwork);
at::cuda::solver::ormqr<scalar_t>(
handle, side, trans, m, n, k,
input_working_ptr,
lda,
tau_working_ptr,
other_working_ptr,
ldc,
static_cast<scalar_t*>(work_data.get()),
lwork,
info_data
);
// info from ormqr only reports if the i-th parameter is wrong
// so we don't need to check it all the time
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info.item().toInt() == 0);
}
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free