
gemm_internal_cublas_half_helper Function — pytorch Architecture

Architecture documentation for the gemm_internal_cublas_half_helper function template in CUDABlas.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/cuda/CUDABlas.cpp lines 1103–1216

template <typename C_Dtype>
inline void gemm_internal_cublas_half_helper(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::Half, C_Dtype)) {
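  // The CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE macro expands (approximately) to the
  // usual column-major GEMM argument list: char transa/transb, int64_t m/n/k,
  // float alpha/beta, const at::Half* a/b with lda/ldb, and C_Dtype* c with ldc,
  // where C_Dtype is at::Half or float; see aten/src/ATen/cuda/CUDABlas.h for
  // the exact definition.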
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  float falpha = alpha;
  float fbeta = beta;
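  // alpha/beta default to float so accumulation happens in fp32; the at::Half
  // copies below are used only on the CUDA path when fp16 accumulation is allowed.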
#ifndef USE_ROCM
  at::Half halpha;
  at::Half hbeta;
  auto compute_type = CUDA_R_32F;
#endif
  void * alpha_ptr = &falpha;
  void * beta_ptr = &fbeta;
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(at::Half);
#ifdef USE_ROCM
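  // ROCm path: dispatch straight to rocblas_gemm_ex with fp16 inputs, an fp16 or
  // fp32 output (depending on C_Dtype), and fp32 compute/accumulation.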
  int flag = 0;
  rocblas_datatype c_type = std::is_same<C_Dtype, float>::value ? rocblas_datatype_f32_r : rocblas_datatype_f16_r;
  rocblas_datatype d_type = c_type;
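  // During a backward pass, optionally request rocBLAS's alternate fp16
  // implementation (only when built with USE_GEMM_FLAGS_FP16_ALT_IMPL).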
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
  flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
      (rocblas_handle)handle,
      hipOperationToRocOperation(opa),
      hipOperationToRocOperation(opb),
      m,
      n,
      k,
      alpha_ptr,
      a,
      rocblas_datatype_f16_r,
      lda,
      b,
      rocblas_datatype_f16_r,
      ldb,
      beta_ptr,
      c,
      c_type,
      ldc,
      c,
      d_type,
      ldc,
      rocblas_datatype_f32_r,
      rocblas_gemm_algo_standard,
      0,
      flag)));
#else
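  // CUDA path: on compute capability 7.0+ (Volta and newer), honor
  // allowFP16AccumulationCuBLAS() by switching alpha/beta and the compute type to fp16.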
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
    compute_type = CUDA_R_16F;
    halpha = alpha;
    hbeta = beta;
    alpha_ptr = &halpha;
    beta_ptr = &hbeta;
  }
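  // Compute capability 5.0+ uses cublasGemmEx with tensor-op math; the math-mode
  // flag below can additionally disallow reduced-precision (fp16) reductions.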
  if (prop->major >= 5) {
    cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
    auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
    TORCH_CHECK(fp16_reduction !=
        at::CuBLASReductionOption::DisallowReducedPrecisionDisallowSplitK,
          "torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction("
          "..., allow_splitk=False) requires the cuBLASLt backend");
    if (fp16_reduction !=
        at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
      cublas_flags = static_cast<cublasMath_t>(
          cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
    }
    // Disallow fp16 reductions that could lead to unexpected overflow issues.
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
    TORCH_CUDABLAS_CHECK(cublasGemmEx(
        handle,
        opa,
        opb,
        m,
        n,
        k,
        alpha_ptr,
        a,
        CUDA_R_16F,
        lda,
        b,
        CUDA_R_16F,
        ldb,
        beta_ptr,
        c,
        std::is_same_v<C_Dtype, float> ? CUDA_R_32F : CUDA_R_16F,
        ldc,
        compute_type,
        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  } else {
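    // Fallback for pre-5.0 devices: cublasSgemmEx takes fp16 inputs but always
    // computes in fp32.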
    TORCH_CUDABLAS_CHECK(cublasSgemmEx(
        handle,
        opa,
        opb,
        m,
        n,
        k,
        &falpha,
        a,
        CUDA_R_16F,
        lda,
        b,
        CUDA_R_16F,
        ldb,
        &fbeta,
        c,
        std::is_same_v<C_Dtype, float> ? CUDA_R_32F : CUDA_R_16F,
        ldc));
  }
#endif
}
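
For context, half-precision GEMMs reach this helper through the at::cuda::blas::gemm entry point declared in aten/src/ATen/cuda/CUDABlas.h. The sketch below is illustrative only: it assumes the at::Half specialization of that entry point (column-major layout, float alpha/beta as the opmath type) and simplifies the internal dispatch layering; it is not part of the documented source.

#include <cstdint>
#include <ATen/cuda/CUDABlas.h>

// Illustrative caller: C = A * B for column-major m x k and k x n half matrices.
// All pointers are assumed to reference device memory on the current CUDA device.
void half_gemm_example(const at::Half* a, const at::Half* b, at::Half* c,
                       int64_t m, int64_t n, int64_t k) {
  at::cuda::blas::gemm<at::Half>(
      'n', 'n',            // no transpose for either operand
      m, n, k,
      1.0f,                // alpha (float: the opmath type of at::Half)
      a, /*lda=*/m,
      b, /*ldb=*/k,
      0.0f,                // beta
      c, /*ldc=*/m);
}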
