gemm_and_bias Function — pytorch Architecture
Architecture documentation for the gemm_and_bias function template in CUDABlas.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/cuda/CUDABlas.cpp lines 1566–1796
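gemm_and_bias wraps cublasLtMatmul to compute result = alpha * op(mat1) * op(mat2), with the bias addition and an optional ReLU or GELU activation fused into the cuBLASLt epilogue; when bias is nullptr it instead accumulates into the existing result buffer (beta = 1). It returns true on success and false when cuBLASLt reports an error or offers no suitable algorithm, signaling the caller to fall back to the unfused cuBLAS path.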
template <typename Dtype, typename C_Dtype>
bool gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<Dtype> alpha_val,
    const Dtype* mat1_ptr,
    int64_t mat1_ld,
    const Dtype* mat2_ptr,
    int64_t mat2_ld,
    const Dtype* bias,
    C_Dtype* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation) {
  if (std::is_same_v<C_Dtype, float> && std::is_same_v<Dtype, at::BFloat16>) {
#ifdef USE_ROCM
    TORCH_CHECK(false, "gemm input type at::BFloat16 and output type float is not supported for ROCm");
#endif
  } else if (std::is_same_v<C_Dtype, float> && std::is_same_v<Dtype, at::Half>) {
#ifdef USE_ROCM
    TORCH_CHECK(false, "gemm input type at::Half and output type float is not supported for ROCm");
#endif
    if (at::globalContext().allowFP16AccumulationCuBLAS())
      TORCH_CHECK(false, "gemm input type at::Half and output type float is not supported with allowFP16AccumulationCuBLAS");
  }
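  // Defaults assume FP32 data and FP32 compute; the constexpr dispatch below
  // overrides them per input dtype and the global precision settings.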
  using opmath_t = at::opmath_type<Dtype>;
  opmath_t beta_val = bias ? 0 : 1; // bias is added in epilogue unless nullptr
  cudaDataType_t abType = CUDA_R_32F;
  cudaDataType_t cType = CUDA_R_32F;
  cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
  cudaDataType_t scaleType = CUDA_R_32F;
  CuBlasLtMatmulPreference preference;
  void * alpha_ptr = &alpha_val;
  void * beta_ptr = &beta_val;
#ifndef USE_ROCM
  at::Half halpha_val;
  at::Half hbeta_val;
#endif
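  // Half-precision copies of alpha/beta: alpha_ptr/beta_ptr are redirected to
  // these below when FP16 accumulation selects CUDA_R_16F as the scale type.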
  if constexpr (std::is_same_v<Dtype, double>) {
    abType = CUDA_R_64F;
    cType = CUDA_R_64F;
    computeType = CUBLAS_COMPUTE_64F;
    scaleType = CUDA_R_64F;
  } else if constexpr (std::is_same_v<Dtype, float>) {
    if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) {
      computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
    }
  } else if constexpr (std::is_same_v<Dtype, at::Half>) {
#ifndef USE_ROCM
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
      computeType = CUBLAS_COMPUTE_16F;
      scaleType = CUDA_R_16F;
      halpha_val = alpha_val;
      hbeta_val = beta_val;
      alpha_ptr = &halpha_val;
      beta_ptr = &hbeta_val;
    }
#endif
    abType = CUDA_R_16F;
    cType = std::is_same_v<C_Dtype, float> ? CUDA_R_32F : CUDA_R_16F;
#ifndef USE_ROCM
    auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
    if (fp16_reduction !=
        at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
      uint32_t mask =
          fp16_reduction ==
              at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
          ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
             CUBLASLT_REDUCTION_SCHEME_NONE)
          : CUBLASLT_REDUCTION_SCHEME_NONE;
      preference.setAttribute(
          CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, mask);
    }
#endif
  } else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
    abType = CUDA_R_16BF;
    cType = std::is_same_v<C_Dtype, float> ? CUDA_R_32F : CUDA_R_16BF;
#ifndef USE_ROCM
    auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
    if (bf16_reduction !=
        at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
      uint32_t mask =
          bf16_reduction ==
              at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
          ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
             CUBLASLT_REDUCTION_SCHEME_NONE)
          : CUBLASLT_REDUCTION_SCHEME_NONE;
      preference.setAttribute(
          CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, mask);
    }
#endif
  }
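  // Describe the operation itself: transposes, SM carveout, and epilogue.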
  CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
  cublasOperation_t transa = transpose_mat1 ? CUBLAS_OP_T : CUBLAS_OP_N;
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
  cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
  auto stream = at::cuda::getCurrentCUDAStream();
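  // Experimental SM carveout: on CUDA, cap the number of SMs the kernel may
  // target; on ROCm, the same effect comes from a dedicated carveout stream.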
#ifndef USE_ROCM
  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
    computeDesc.setAttribute<int32_t>(
        CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
        at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
            at::globalContext()._SMCarveout_EXPERIMENTAL().value());
  }
#else
  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
    stream = _getCarveoutStream(
        at::globalContext()._SMCarveout_EXPERIMENTAL().value());
    _syncCurrentWithCarveoutStream(stream, true);
  }
#endif
  const auto epilogue = [&]() -> cublasLtEpilogue_t {
    // The cuBLAS documentation indicates that
    // *_<ACTIVATION>_BIAS = *_<ACTIVATION>,
    // but we keep it verbose here for clarity.
    switch (activation) {
      case GEMMAndBiasActivationEpilogue::RELU:
        return bias ? CUBLASLT_EPILOGUE_RELU_BIAS : CUBLASLT_EPILOGUE_RELU;
      case GEMMAndBiasActivationEpilogue::GELU:
        return bias ? CUBLASLT_EPILOGUE_GELU_BIAS : CUBLASLT_EPILOGUE_GELU;
      default:
        return bias ? CUBLASLT_EPILOGUE_BIAS : CUBLASLT_EPILOGUE_DEFAULT;
    }
  }();
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, epilogue);
  if (bias) {
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias);
  }
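  // Matrix layout descriptors (column-major); Cdesc is passed for both C and
  // D below because the output is written in place over result_ptr.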
  CuBlasLtMatrixLayout Adesc(abType, m, k, mat1_ld, transpose_mat1);
  CuBlasLtMatrixLayout Bdesc(abType, k, n, mat2_ld, transpose_mat2);
  CuBlasLtMatrixLayout Cdesc(cType, m, n, result_ld);
  auto ltworkspace = CublasLtWorkspace();
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size);
#ifndef USE_ROCM
  uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat1_ptr));
  uint32_t b_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat2_ptr));
  uint32_t c_alignment = _getAlignment(reinterpret_cast<uintptr_t>(result_ptr));
  uint32_t d_alignment = _getAlignment(reinterpret_cast<uintptr_t>(bias));
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, a_alignment);
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, b_alignment);
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
#endif
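  // Ask for the single best algorithm under the preference constraints;
  // returnedResult stays 0 if none fits.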
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  int returnedResult = 0;
  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
      ltHandle,
      computeDesc.descriptor(),
      Adesc.descriptor(),
      Bdesc.descriptor(),
      Cdesc.descriptor(),
      Cdesc.descriptor(),
      preference.descriptor(),
      1,
      &heuristicResult,
      &returnedResult));
  cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS;
  if (returnedResult == 0) {
    cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED;
  } else {
    cublasStatus = cublasLtMatmul(
        ltHandle,
        computeDesc.descriptor(),
        alpha_ptr,
        mat1_ptr,
        Adesc.descriptor(),
        mat2_ptr,
        Bdesc.descriptor(),
        beta_ptr,
        result_ptr,
        Cdesc.descriptor(),
        result_ptr,
        Cdesc.descriptor(),
        &heuristicResult.algo,
        ltworkspace.ptr,
        ltworkspace.size,
        stream);
#ifdef USE_ROCM
    if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
      _syncCurrentWithCarveoutStream(stream, false);
    }
#endif
  }
  if (cublasStatus != CUBLAS_STATUS_SUCCESS) {
    TORCH_WARN(
        "gemm_and_bias error: ",
        at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
        " when calling cublasLtMatmul with transpose_mat1 ",
        transpose_mat1,
        " transpose_mat2 ",
        transpose_mat2,
        " m ",
        m,
        " n ",
        n,
        " k ",
        k,
        " mat1_ld ",
        mat1_ld,
        " mat2_ld ",
        mat2_ld,
        " result_ld ",
        result_ld,
        " abType ",
        abType,
        " cType ",
        cType,
        " computeType ",
        computeType,
        " scaleType ",
        scaleType,
        ". Will attempt to recover by calling unfused cublas path.");
    return false;
  }
  return true;
}
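For orientation, a minimal caller sketch follows. It is not part of the pytorch sources: the wrapper fused_linear_relu and its buffer names are hypothetical, and it assumes a CUDA build of PyTorch with device-resident, column-major buffers. Only gemm_and_bias and GEMMAndBiasActivationEpilogue come from the entity documented above.

#include <ATen/cuda/CUDABlas.h>
#include <cstdint>

// Hypothetical wrapper: output = relu(weight * input + bias), where weight is
// m x k, input is k x n, and output is m x n, all column-major device buffers.
bool fused_linear_relu(
    const at::Half* weight,
    const at::Half* input,
    const at::Half* bias,  // nullptr drops the bias add; note beta then
                           // becomes 1, so output must be pre-initialized
    at::Half* output,
    int64_t m, int64_t n, int64_t k) {
  // A false return means cuBLASLt had no suitable algorithm or the call
  // failed; ATen callers recover by taking the unfused cublas path.
  return at::cuda::blas::gemm_and_bias(
      /*transpose_mat1=*/false,
      /*transpose_mat2=*/false,
      m, n, k,
      /*alpha_val=*/1.0f,  // at::opmath_type<at::Half> is float
      weight, /*mat1_ld=*/m,
      input, /*mat2_ld=*/k,
      bias,
      output, /*result_ld=*/m,
      at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU);
}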