
bgemm_internal_cublas_bfloat16_helper Function — pytorch Architecture

Architecture documentation for the bgemm_internal_cublas_bfloat16_helper function template in CUDABlas.cpp from the pytorch codebase. The helper runs a batched strided GEMM on BFloat16 inputs via cublasGemmStridedBatchedEx, accumulating in 32-bit float and selecting the output dtype (CUDA_R_32F or CUDA_R_16BF) at compile time from the C_Dtype template parameter.

Entity Profile

Source Code

aten/src/ATen/cuda/CUDABlas.cpp lines 732–755

template <typename C_Dtype>
inline void bgemm_internal_cublas_bfloat16_helper(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, C_Dtype)) {
  BGEMM_CHECK_ARGVALUES(at::BFloat16);
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  const float falpha = alpha;
  const float fbeta = beta;
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);

#if defined(USE_ROCM)
  auto compute_type = CUBLAS_COMPUTE_32F;
#else
  auto compute_type = CUDA_R_32F;
#endif
  TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx(handle,
                              opa, opb, (int)m, (int)n, (int)k,
                              (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
                              b, CUDA_R_16BF, (int)ldb, strideb,
                              (void*)&fbeta, c, std::is_same_v<C_Dtype, float> ? CUDA_R_32F : CUDA_R_16BF,
                              (int)ldc, stridec, (int)num_batches,
                              compute_type,
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
