bgemm_internal_cublaslt Function Template — pytorch Architecture
Architecture documentation for the bgemm_internal_cublaslt function template in CUDABlas.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/cuda/CUDABlas.cpp lines 368–605
// Runs a (possibly strided-batched) GEMM through cublasLt (hipblasLt on ROCm).
//
// The parameter list is generated by CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE, which
// is defined outside this block; the body uses transa, transb, m, n, k, alpha,
// a, lda, stridea, b, ldb, strideb, beta, c, ldc, stridec and num_batches.
//
// Template parameters:
//   Dtype   - element type of A and B. Selects the cublasLt data, compute and
//             scale types. Handled: double, float, c10::complex<double>,
//             c10::complex<float>, at::Half, at::BFloat16. Any other type hits
//             the static_assert below.
//   C_Dtype - element type of C. For Half/BFloat16 inputs, a float C_Dtype
//             selects a CUDA_R_32F layout for C.
//
// Returns:
//   true  - cublasLtMatmul returned CUBLAS_STATUS_SUCCESS.
//   false - the ROCm 6.4 fp32 transposed/transposed case is rejected up front
//           (no warning), or the heuristic query returned no algorithm, or
//           cublasLtMatmul returned a non-success status. The latter two emit
//           a TORCH_WARN stating that cublas will be tried instead.
//
// NOTE(review): TORCH_CHECK / TORCH_CUDABLAS_CHECK presumably raise on failure
// (workspace allocation, heuristic query) -- their definitions are not visible
// here; confirm before relying on the bool return as the only failure channel.
template <typename Dtype, typename C_Dtype = Dtype>
static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) {
#if defined(USE_ROCM) && ROCM_VERSION == 60400
  // regression in ROCm 6.4, planned fixed in 6.4.1, hipblaslt TT fp32 calculation errors
  // best to disallow hipblaslt for this specific case
  if constexpr (std::is_same_v<Dtype, float>) {
    if (_cublasOpFromChar(transa) == CUBLAS_OP_T && _cublasOpFromChar(transb) == CUBLAS_OP_T) {
      return false;
    }
  }
#endif

  // Defaults describe a plain fp32 GEMM; the per-dtype branches below override
  // whichever of these differ.
  cudaDataType_t abType = CUDA_R_32F;
  cudaDataType_t cType = CUDA_R_32F;
  cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
  cudaDataType_t scaleType = CUDA_R_32F;
  CuBlasLtMatmulPreference preference;
#ifndef USE_ROCM
  // Half-precision copies of the scalars, used only when fp16 accumulation is
  // selected (scaleType becomes CUDA_R_16F).
  at::Half halpha;
  at::Half hbeta;
  // All bits set unless a Half/BFloat16 branch narrows the allowed reduction
  // schemes; also consulted by the heuristic workaround further down.
  uint32_t mask = -1;
#endif
  // The scalars are handed to cublasLtMatmul by pointer. Point at the caller's
  // alpha/beta by default; the fp16-accumulation path re-points these at
  // halpha/hbeta.
  void * alpha_ptr = &alpha;
  void * beta_ptr = &beta;

  if constexpr (std::is_same_v<Dtype, double>) {
    abType = CUDA_R_64F;
    cType = CUDA_R_64F;
    computeType = CUBLAS_COMPUTE_64F;
    scaleType = CUDA_R_64F;
  } else if constexpr (std::is_same_v<Dtype, float>) {
    // Opt into TF32 compute only when the global context asks for it.
    if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) {
      computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
    }
  } else if constexpr (std::is_same_v<Dtype, c10::complex<double>>) {
    abType = CUDA_C_64F;
    cType = CUDA_C_64F;
    computeType = CUBLAS_COMPUTE_64F;
    scaleType = CUDA_C_64F;
  } else if constexpr (std::is_same_v<Dtype, c10::complex<float>>) {
    // computeType intentionally stays at the CUBLAS_COMPUTE_32F default.
    abType = CUDA_C_32F;
    cType = CUDA_C_32F;
    scaleType = CUDA_C_32F;
  } else if constexpr (std::is_same_v<Dtype, at::Half>) {
#ifndef USE_ROCM
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
      // fp16 accumulation: the scale type becomes fp16 as well, so the scalars
      // must be supplied as at::Half rather than in their original type.
      computeType = CUBLAS_COMPUTE_16F;
      scaleType = CUDA_R_16F;
      halpha = alpha;
      hbeta = beta;
      alpha_ptr = &halpha;
      beta_ptr = &hbeta;
    }
#endif
    abType = CUDA_R_16F;
    cType = std::is_same_v<C_Dtype, float> ? CUDA_R_32F : CUDA_R_16F;
#ifndef USE_ROCM
    // Restrict the reduction schemes cublasLt may choose unless reduced
    // precision with split-K is fully allowed.
    auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
    if (fp16_reduction !=
        at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
      mask =
          fp16_reduction ==
                  at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
              ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
                 CUBLASLT_REDUCTION_SCHEME_NONE)
              : CUBLASLT_REDUCTION_SCHEME_NONE;
      preference.setAttribute(
          CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, mask);
    }
#endif
  } else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
    abType = CUDA_R_16BF;
    cType = std::is_same_v<C_Dtype, float> ? CUDA_R_32F : CUDA_R_16BF;
#ifndef USE_ROCM
    // Same reduction-scheme restriction as the Half branch, driven by the
    // BF16 setting.
    auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
    if (bf16_reduction !=
        at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
      mask =
          bf16_reduction ==
                  at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
              ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
                 CUBLASLT_REDUCTION_SCHEME_NONE)
              : CUBLASLT_REDUCTION_SCHEME_NONE;
      preference.setAttribute(
          CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, mask);
    }
#endif
  } else {
    // Dependent expression so the assertion only fires when this branch is
    // instantiated for an unsupported Dtype.
    static_assert(false && sizeof(Dtype), "at::cuda::blas::bgemm_internal_cublaslt: not implemented");
  }

  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  // May rewrite lda/ldb/ldc in place (they are passed by address).
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);

  CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, opa);
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, opb);
  auto stream = at::cuda::getCurrentCUDAStream();
#ifndef USE_ROCM
  // CUDA: express the SM carveout as an SM-count target on the descriptor.
  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
    computeDesc.setAttribute<int32_t>(
        CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
        at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
            at::globalContext()._SMCarveout_EXPERIMENTAL().value());
  }
#else
  // ROCm: run on a dedicated carveout stream instead, synchronised with the
  // current stream before (here) and after (below) the matmul.
  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
    stream = _getCarveoutStream(
        at::globalContext()._SMCarveout_EXPERIMENTAL().value());
    _syncCurrentWithCarveoutStream(stream, true);
  }
#endif

  CuBlasLtMatrixLayout Adesc(abType, m, k, lda, opa == CUBLAS_OP_T);
  CuBlasLtMatrixLayout Bdesc(abType, k, n, ldb, opb == CUBLAS_OP_T);
  CuBlasLtMatrixLayout Cdesc(cType, m, n, ldc);
  if (num_batches > 1) {
    // The batch count is narrowed to int before being set on the layouts.
    int num_batches_as_int = static_cast<int>(num_batches);
    Adesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int);
    Bdesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int);
    Cdesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int);
    Adesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stridea);
    Bdesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, strideb);
    Cdesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stridec);
  }

#ifndef USE_ROCM
  // Tell the heuristic the alignment derived from each operand's address.
  uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(a));
  uint32_t b_alignment = _getAlignment(reinterpret_cast<uintptr_t>(b));
  uint32_t c_alignment = _getAlignment(reinterpret_cast<uintptr_t>(c));
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, a_alignment);
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, b_alignment);
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
#endif

  auto ltworkspace = CublasLtWorkspace();
  TORCH_CHECK(ltworkspace.ptr != nullptr, "OOM trying to allocate workspace for cublaslt");
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size);

  cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  int returnedResult = 0;

  // on Blackwell+, we fake a n > 1 matmul when querying heuristics
  // to prevent cuBLASLt from dispatching to a GEMV kernel for batch-invariance
#ifndef USE_ROCM
  const bool lie_to_cublaslt = mask == CUBLASLT_REDUCTION_SCHEME_NONE && n == 1 && at::cuda::getCurrentDeviceProperties()->major >= 10;
#else
  const bool lie_to_cublaslt = false;
#endif

  if (lie_to_cublaslt) {
    // Heuristic-only layouts with n forced to 2; the real Bdesc/Cdesc are
    // still the ones passed to cublasLtMatmul below.
    CuBlasLtMatrixLayout FakeBdesc(abType, k, 2, ldb, opb == CUBLAS_OP_T);
    CuBlasLtMatrixLayout FakeCdesc(cType, m, 2, ldc);

    TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
        ltHandle,
        computeDesc.descriptor(),
        Adesc.descriptor(),
        FakeBdesc.descriptor(),
        FakeCdesc.descriptor(),
        FakeCdesc.descriptor(),
        preference.descriptor(),
        1,
        &heuristicResult,
        &returnedResult));
  } else {
    TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
        ltHandle,
        computeDesc.descriptor(),
        Adesc.descriptor(),
        Bdesc.descriptor(),
        Cdesc.descriptor(),
        Cdesc.descriptor(),
        preference.descriptor(),
        1,
        &heuristicResult,
        &returnedResult));
  }

  if (returnedResult == 0) {
    // No algorithm was offered: report it as "not supported" so the common
    // warning/fallback path below handles it.
    cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED;
  }
  else {
    // c and Cdesc are passed twice: once as the C operand and once as the
    // output, so the result is written in place.
    cublasStatus = cublasLtMatmul(
        ltHandle,
        computeDesc.descriptor(),
        alpha_ptr,
        a,
        Adesc.descriptor(),
        b,
        Bdesc.descriptor(),
        beta_ptr,
        c,
        Cdesc.descriptor(),
        c,
        Cdesc.descriptor(),
        &heuristicResult.algo,
        ltworkspace.ptr,
        ltworkspace.size,
        stream);
#ifdef USE_ROCM
    if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
      _syncCurrentWithCarveoutStream(stream, false);
    }
#endif
  }

  if (cublasStatus != CUBLAS_STATUS_SUCCESS) {
    TORCH_WARN(
        "bgemm_internal_cublaslt error: ",
        at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
        " when calling cublasLtMatmul with transpose_mat1 ",
        (opa == CUBLAS_OP_T),
        " transpose_mat2 ",
        (opb == CUBLAS_OP_T),
        " m ",
        m,
        " n ",
        n,
        " k ",
        k,
        " lda ",
        lda,
        " ldb ",
        ldb,
        " ldc ",
        ldc,
        " abType ",
        abType,
        " cType ",
        cType,
        " computeType ",
        computeType,
        " scaleType ",
        scaleType,
        ". Will attempt to recover by calling cublas instead.");
    return false;
  }
  return true;
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free