Home / Class / gemm_internal_cublas_bfloat16_helper Class — pytorch Architecture

gemm_internal_cublas_bfloat16_helper Class — pytorch Architecture

Architecture documentation for the `gemm_internal_cublas_bfloat16_helper` function template in `CUDABlas.cpp` from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/cuda/CUDABlas.cpp lines 1218–1267

// Runs a BFloat16 x BFloat16 GEMM through cublasGemmEx with FP32 alpha/beta
// and an FP32 compute type. C is written as float when C_Dtype is float and
// as BFloat16 otherwise.
//
// The CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE macro supplies the BLAS-style
// arguments (transa/transb, m/n/k, alpha, a/lda, b/ldb, beta, c/ldc). The
// leading dimensions may be rewritten by _cublasAdjustLdLevel3 before use.
//
// On CUDA builds, the handle's math mode is first set from
// allowBF16ReductionCuBLAS(). On all builds, it is reset to
// CUBLAS_DEFAULT_MATH after a successful GEMM.
template <typename C_Dtype>
inline void gemm_internal_cublas_bfloat16_helper(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, C_Dtype)) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  // cublasGemmEx receives alpha/beta as FP32 to match the FP32 compute type.
  float falpha = alpha;
  float fbeta = beta;
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(at::BFloat16);
#if defined(USE_ROCM)
  auto compute_type = CUBLAS_COMPUTE_32F;
#else
  cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
  auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
  // The allow_splitk=False configuration is rejected on this cublasGemmEx
  // path; the message directs users to the cuBLASLt backend instead. It names
  // the BF16 setting because that is the value checked here.
  TORCH_CHECK(bf16_reduction !=
      at::CuBLASReductionOption::DisallowReducedPrecisionDisallowSplitK,
        "torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction("
        "..., allow_splitk=False) requires the cuBLASLt backend");
  // Reduced-precision reductions stay allowed only for
  // AllowReducedPrecisionWithSplitK; every other option sets the DISALLOW bit.
  if (bf16_reduction !=
      at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
    cublas_flags = static_cast<cublasMath_t>(
        cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
  }
  auto compute_type = CUDA_R_32F;
  // cublas_flags is declared only on CUDA builds, so the call that consumes
  // it must live inside this branch too.
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
#endif
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle,
      opa,
      opb,
      m,
      n,
      k,
      &falpha,
      a,
      CUDA_R_16BF,
      lda,
      b,
      CUDA_R_16BF,
      ldb,
      &fbeta,
      c,
      std::is_same_v<C_Dtype, float> ? CUDA_R_32F : CUDA_R_16BF,
      ldc,
      compute_type,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  // NOTE(review): if cublasGemmEx fails, TORCH_CUDABLAS_CHECK presumably
  // throws and this reset is skipped, leaving the handle in the mode set
  // above. Confirm whether handle acquisition re-initializes the math mode.
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free