gemm_internal_cublas_half_helper Function — pytorch Architecture
Architecture documentation for the gemm_internal_cublas_half_helper function template in CUDABlas.cpp from the pytorch codebase.
Entity Profile
gemm_internal_cublas_half_helper is a function template (not a class) implementing half-precision (at::Half) GEMM for PyTorch's cuBLAS backend. The A and B operands are always fp16; the C_Dtype template parameter selects the output dtype, which may be fp16 or fp32. On ROCm builds the helper lowers to rocblas_gemm_ex; on CUDA builds it dispatches to cublasGemmEx on compute capability 5.0 and newer (optionally with fp16 accumulation on 7.0+ when the global context allows it) and falls back to cublasSgemmEx on older devices.
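The parameter list is generated by the CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE macro. As a rough sketch, assuming it mirrors PyTorch's classic CUDABLAS_GEMM_ARGTYPES macro with the C pointer retyped to C_Dtype, the expanded signature would look roughly like this (the exact expansion is an assumption, not quoted from the source):

// Hypothetical expansion of CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::Half, C_Dtype);
// the real macro lives in ATen's CUDABlas headers and may differ in detail.
template <typename C_Dtype>
inline void gemm_internal_cublas_half_helper(
    char transa, char transb,             // 'n', 't', or 'c' per operand
    int64_t m, int64_t n, int64_t k,      // problem dimensions
    at::opmath_type<at::Half> alpha,      // == float for at::Half
    const at::Half* a, int64_t lda,
    const at::Half* b, int64_t ldb,
    at::opmath_type<at::Half> beta,       // == float for at::Half
    C_Dtype* c, int64_t ldc);             // output may be fp16 or fp32

This is consistent with the body below: alpha and beta arrive as float (at::opmath_type<at::Half>), which is why the direct assignments to falpha and fbeta compile without a conversion step.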
Source Code
aten/src/ATen/cuda/CUDABlas.cpp lines 1103–1216
template <typename C_Dtype>
inline void gemm_internal_cublas_half_helper(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::Half, C_Dtype)) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  float falpha = alpha;
  float fbeta = beta;
#ifndef USE_ROCM
  at::Half halpha;
  at::Half hbeta;
  auto compute_type = CUDA_R_32F;
#endif
  void* alpha_ptr = &falpha;
  void* beta_ptr = &fbeta;
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(at::Half);
#ifdef USE_ROCM
  int flag = 0;
  rocblas_datatype c_type = std::is_same<C_Dtype, float>::value ? rocblas_datatype_f32_r : rocblas_datatype_f16_r;
  rocblas_datatype d_type = c_type;
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
  flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
      (rocblas_handle)handle,
      hipOperationToRocOperation(opa),
      hipOperationToRocOperation(opb),
      m,
      n,
      k,
      alpha_ptr,
      a,
      rocblas_datatype_f16_r,
      lda,
      b,
      rocblas_datatype_f16_r,
      ldb,
      beta_ptr,
      c,
      c_type,
      ldc,
      c,
      d_type,
      ldc,
      rocblas_datatype_f32_r,
      rocblas_gemm_algo_standard,
      0,
      flag)));
#else
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
    compute_type = CUDA_R_16F;
    halpha = alpha;
    hbeta = beta;
    alpha_ptr = &halpha;
    beta_ptr = &hbeta;
  }
  if (prop->major >= 5) {
    cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
    auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
    TORCH_CHECK(
        fp16_reduction !=
            at::CuBLASReductionOption::DisallowReducedPrecisionDisallowSplitK,
        "torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction("
        "..., allow_splitk=False) requires the cuBLASLt backend");
    if (fp16_reduction !=
        at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
      cublas_flags = static_cast<cublasMath_t>(
          cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
    }
    // Disallow fp16 reductions that could lead to unexpected overflow issues.
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
    TORCH_CUDABLAS_CHECK(cublasGemmEx(
        handle,
        opa,
        opb,
        m,
        n,
        k,
        alpha_ptr,
        a,
        CUDA_R_16F,
        lda,
        b,
        CUDA_R_16F,
        ldb,
        beta_ptr,
        c,
        std::is_same_v<C_Dtype, float> ? CUDA_R_32F : CUDA_R_16F,
        ldc,
        compute_type,
        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  } else {
    TORCH_CUDABLAS_CHECK(cublasSgemmEx(
        handle,
        opa,
        opb,
        m,
        n,
        k,
        &falpha,
        a,
        CUDA_R_16F,
        lda,
        b,
        CUDA_R_16F,
        ldb,
        &fbeta,
        c,
        std::is_same_v<C_Dtype, float> ? CUDA_R_32F : CUDA_R_16F,
        ldc));
  }
#endif
}
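This helper sits behind PyTorch's device-side GEMM entry point rather than being called directly. A minimal sketch of driving the at::Half path through the public at::cuda::blas::gemm wrapper follows; the buffer shapes and names are illustrative assumptions, and which internal branch is taken (cublasGemmEx with fp32 or fp16 accumulation, or the cublasSgemmEx fallback) depends on the device's compute capability and the global context flags shown above:

// Minimal sketch: invoking the fp16 GEMM path via the public wrapper.
// Assumes a CUDA build of PyTorch; sizes and tensor names are illustrative.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDABlas.h>

void half_gemm_example() {
  int64_t m = 64, n = 64, k = 64;
  // cuBLAS expects column-major data; a row-major ATen tensor of shape
  // {cols, rows} has the same memory layout as a column-major rows x cols matrix.
  auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
  at::Tensor a = at::randn({k, m}, opts);  // column-major m x k, lda = m
  at::Tensor b = at::randn({n, k}, opts);  // column-major k x n, ldb = k
  at::Tensor c = at::zeros({n, m}, opts);  // column-major m x n, ldc = m

  // alpha/beta are at::opmath_type<at::Half>, i.e. float.
  at::cuda::blas::gemm<at::Half>(
      'n', 'n', m, n, k,
      1.0f,
      a.data_ptr<at::Half>(), m,
      b.data_ptr<at::Half>(), k,
      0.0f,
      c.data_ptr<at::Half>(), m);
}

Note that with an fp16 C matrix, cuBLAS may still accumulate in fp32 (compute_type stays CUDA_R_32F) unless allowFP16AccumulationCuBLAS() is enabled on a compute capability 7.0+ device, as the branch in the listing shows.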