Home / Class / bgemm_internal_cublaslt Class — pytorch Architecture

bgemm_internal_cublaslt Class — pytorch Architecture

Architecture documentation for the bgemm_internal_cublaslt class in CUDABlas.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/cuda/CUDABlas.cpp lines 368–605

// Batched strided GEMM (C = alpha * op(A) * op(B) + beta * C, repeated
// num_batches times) dispatched through cuBLASLt (hipBLASLt under ROCm).
//
// Dtype is the A/B element type; C_Dtype (defaults to Dtype) is the output
// element type — for Half/BFloat16 inputs a float C_Dtype selects an FP32
// output matrix layout (CUDA_R_32F cType below).
//
// Returns true on success. Returns false when cuBLASLt cannot (or should not)
// handle this case — no heuristic found, the matmul call failed, or a known-bad
// backend configuration — so the caller can fall back to plain cuBLAS/hipBLAS
// (see the TORCH_WARN near the bottom).
template <typename Dtype, typename C_Dtype = Dtype>
static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) {
#if defined(USE_ROCM) && ROCM_VERSION == 60400
  // regression in ROCm 6.4, planned fix in 6.4.1: hipblaslt TT fp32 calculation errors.
  // Best to disallow hipblaslt for this specific case; returning false routes
  // the caller to the non-Lt backend instead.
  if constexpr (std::is_same_v<Dtype, float>) {
    if (_cublasOpFromChar(transa) == CUBLAS_OP_T && _cublasOpFromChar(transb) == CUBLAS_OP_T) {
        return false;
    }
  }
#endif
  // Defaults describe the plain fp32 case; each constexpr branch below
  // overrides them for its Dtype.
  cudaDataType_t abType = CUDA_R_32F;
  cudaDataType_t cType = CUDA_R_32F;
  cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
  cudaDataType_t scaleType = CUDA_R_32F;
  CuBlasLtMatmulPreference preference;
#ifndef USE_ROCM
  // Storage for half-precision alpha/beta used by the FP16-accumulation path
  // (CUDA only); alpha_ptr/beta_ptr are repointed here when that path is taken.
  at::Half halpha;
  at::Half hbeta;
  // Reduction-scheme mask handed to cuBLASLt; all bits set (-1) presumably
  // means "any scheme allowed" and is only narrowed below when the global
  // reduced-precision-reduction options say so. Also consulted later to decide
  // the n==1 heuristic workaround.
  uint32_t mask = -1;
#endif
  void * alpha_ptr = &alpha;
  void * beta_ptr = &beta;
  if constexpr (std::is_same_v<Dtype, double>) {
    abType = CUDA_R_64F;
    cType = CUDA_R_64F;
    computeType = CUBLAS_COMPUTE_64F;
    scaleType = CUDA_R_64F;
  } else if constexpr (std::is_same_v<Dtype, float>) {
    // Honor the global TF32 setting for fp32 matmuls.
    if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) {
      computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
    }
  } else if constexpr (std::is_same_v<Dtype, c10::complex<double>>) {
    abType = CUDA_C_64F;
    cType = CUDA_C_64F;
    computeType = CUBLAS_COMPUTE_64F;
    scaleType = CUDA_C_64F;
  } else if constexpr (std::is_same_v<Dtype, c10::complex<float>>) {
    abType = CUDA_C_32F;
    cType = CUDA_C_32F;
    scaleType = CUDA_C_32F;
  } else if constexpr (std::is_same_v<Dtype, at::Half>) {
#ifndef USE_ROCM
    // Optional pure-FP16 accumulation (CUDA only, Volta/SM70+ per the major>=7
    // check): compute and scale types drop to 16-bit, so alpha/beta must be
    // passed as Half values rather than the float arguments.
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
      computeType = CUBLAS_COMPUTE_16F;
      scaleType = CUDA_R_16F;
      halpha = alpha;
      hbeta = beta;
      alpha_ptr = &halpha;
      beta_ptr = &hbeta;
    }
#endif
    abType = CUDA_R_16F;
    cType = std::is_same_v<C_Dtype, float> ? CUDA_R_32F : CUDA_R_16F;
#ifndef USE_ROCM
    // Restrict cuBLASLt's split-K / reduced-precision reduction schemes
    // according to the global FP16-reduction option:
    //  - DisallowReducedPrecisionAllowSplitK: full-precision reductions only
    //    (COMPUTE_TYPE scheme) or no reduction at all;
    //  - otherwise (disallow split-K entirely): no reduction scheme at all.
    auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
    if (fp16_reduction !=
        at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
      mask =
          fp16_reduction ==
                  at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
              ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
                 CUBLASLT_REDUCTION_SCHEME_NONE)
              : CUBLASLT_REDUCTION_SCHEME_NONE;
      preference.setAttribute(
          CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, mask);
    }
#endif
  } else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
    abType = CUDA_R_16BF;
    cType = std::is_same_v<C_Dtype, float> ? CUDA_R_32F : CUDA_R_16BF;
#ifndef USE_ROCM
    // Same reduction-scheme narrowing as the Half branch, driven by the
    // BF16-specific global option.
    auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
    if (bf16_reduction !=
        at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
      mask =
          bf16_reduction ==
                  at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
              ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
                 CUBLASLT_REDUCTION_SCHEME_NONE)
              : CUBLASLT_REDUCTION_SCHEME_NONE;
      preference.setAttribute(
          CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, mask);
    }
#endif
  } else {
    // Dependent-false static_assert: fires only if this template is
    // instantiated with an unsupported Dtype.
    static_assert(false && sizeof(Dtype), "at::cuda::blas::bgemm_internal_cublaslt: not implemented");
  }

  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  // Normalizes the leading dimensions in place for the given transpose
  // configuration (helper defined elsewhere in this file).
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);

  CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, opa);
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, opb);
  auto stream = at::cuda::getCurrentCUDAStream();
#ifndef USE_ROCM
  // Experimental SM carveout (CUDA): cap the kernel to a reduced SM count by
  // targeting (total SMs - carveout) via the matmul descriptor.
  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
    computeDesc.setAttribute<int32_t>(
        CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
        at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
            at::globalContext()._SMCarveout_EXPERIMENTAL().value());
  }
#else
  // Experimental SM carveout (ROCm): no descriptor attribute here — instead
  // run on a dedicated carveout stream, synced with the current stream before
  // (and, after the matmul below, after) the launch.
  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
    stream = _getCarveoutStream(
        at::globalContext()._SMCarveout_EXPERIMENTAL().value());
    _syncCurrentWithCarveoutStream(stream, true);
  }
#endif
  // Matrix layouts; the trailing bool flags a transposed layout for A/B.
  CuBlasLtMatrixLayout Adesc(abType, m, k, lda, opa == CUBLAS_OP_T);
  CuBlasLtMatrixLayout Bdesc(abType, k, n, ldb, opb == CUBLAS_OP_T);
  CuBlasLtMatrixLayout Cdesc(cType, m, n, ldc);

  // Strided-batch setup: only attach batch attributes when actually batched.
  if (num_batches > 1) {
    int num_batches_as_int = static_cast<int>(num_batches);
    Adesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int);
    Bdesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int);
    Cdesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int);
    Adesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stridea);
    Bdesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, strideb);
    Cdesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stridec);
  }

#ifndef USE_ROCM
  // Tell the heuristic the true pointer alignments so it doesn't assume the
  // default (256B) and pick an algorithm the buffers can't satisfy.
  uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(a));
  uint32_t b_alignment = _getAlignment(reinterpret_cast<uintptr_t>(b));
  uint32_t c_alignment = _getAlignment(reinterpret_cast<uintptr_t>(c));
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, a_alignment);
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, b_alignment);
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
#endif

  auto ltworkspace = CublasLtWorkspace();
  TORCH_CHECK(ltworkspace.ptr != nullptr, "OOM trying to allocate workspace for cublaslt");
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size);

  cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  int returnedResult = 0;
  // on Blackwell+ (major >= 10), we fake an n > 1 matmul when querying heuristics
  // to prevent cuBLASLt from dispatching to a GEMV kernel for batch-invariance.
  // Only done when reductions are fully disallowed (mask == ..._SCHEME_NONE),
  // i.e. when batch-invariant results were requested.
#ifndef USE_ROCM
  const bool lie_to_cublaslt = mask == CUBLASLT_REDUCTION_SCHEME_NONE && n == 1 && at::cuda::getCurrentDeviceProperties()->major >= 10;
#else
  const bool lie_to_cublaslt = false;
#endif
  if (lie_to_cublaslt) {
     // Query with fake n=2 B/C layouts; the real n=1 descriptors are still
     // used for the actual cublasLtMatmul call below.
     CuBlasLtMatrixLayout FakeBdesc(abType, k, 2, ldb, opb == CUBLAS_OP_T);
     CuBlasLtMatrixLayout FakeCdesc(cType, m, 2, ldc);

     TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
        ltHandle,
        computeDesc.descriptor(),
        Adesc.descriptor(),
        FakeBdesc.descriptor(),
        FakeCdesc.descriptor(),
        FakeCdesc.descriptor(),
        preference.descriptor(),
        1,
        &heuristicResult,
        &returnedResult));
  } else {
    TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
        ltHandle,
        computeDesc.descriptor(),
        Adesc.descriptor(),
        Bdesc.descriptor(),
        Cdesc.descriptor(),
        Cdesc.descriptor(),
        preference.descriptor(),
        1,
        &heuristicResult,
        &returnedResult));
  }
  if (returnedResult == 0) {
    // No algorithm satisfied the preferences; report NOT_SUPPORTED so the
    // warn-and-fallback path below triggers.
    cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED;
  }
  else {
    // C is passed twice: once as the input C matrix and once as the output D
    // (in-place, D == C).
    cublasStatus = cublasLtMatmul(
      ltHandle,
      computeDesc.descriptor(),
      alpha_ptr,
      a,
      Adesc.descriptor(),
      b,
      Bdesc.descriptor(),
      beta_ptr,
      c,
      Cdesc.descriptor(),
      c,
      Cdesc.descriptor(),
      &heuristicResult.algo,
      ltworkspace.ptr,
      ltworkspace.size,
      stream);
#ifdef USE_ROCM
    // Second half of the ROCm carveout handshake: sync the carveout stream
    // back with the current stream after the launch.
    if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
      _syncCurrentWithCarveoutStream(stream, false);
    }
#endif
  }
  if (cublasStatus != CUBLAS_STATUS_SUCCESS) {
    // Any failure is non-fatal: log the full problem description and return
    // false so the caller retries with the non-Lt cublas path.
    TORCH_WARN(
      "bgemm_internal_cublaslt error: ",
      at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
      " when calling cublasLtMatmul with transpose_mat1 ",
      (opa == CUBLAS_OP_T),
      " transpose_mat2 ",
      (opb == CUBLAS_OP_T),
      " m ",
      m,
      " n ",
      n,
      " k ",
      k,
      " lda ",
      lda,
      " ldb ",
      ldb,
      " ldc ",
      ldc,
      " abType ",
      abType,
      " cType ",
      cType,
      " computeType ",
      computeType,
      " scaleType ",
      scaleType,
      ". Will attempt to recover by calling cublas instead.");
    return false;
  }
  return true;
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free