Home / Class / gemm_and_bias Class — pytorch Architecture

gemm_and_bias Class — pytorch Architecture

Architecture documentation for the gemm_and_bias function template in CUDABlas.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/cuda/CUDABlas.cpp lines 1566–1796

// Runs a single cublasLt matmul with a fused epilogue (bias add and/or
// activation), writing into result_ptr in place.
//
// With a non-null `bias`, beta is 0 and the bias is applied by the epilogue:
//   result = act(alpha * op(mat1) * op(mat2) + bias)
// With `bias == nullptr`, beta is 1 and result_ptr is also passed as the C
// operand, so the existing contents of result are accumulated:
//   result = act(alpha * op(mat1) * op(mat2) + result)
// where op(x) is x or x^T according to transpose_mat1 / transpose_mat2 and
// act is selected by `activation` (RELU, GELU, or none).
//
// Template parameters:
//   Dtype   - element type of mat1, mat2 and bias (double, float, at::Half,
//             or at::BFloat16 are handled explicitly below).
//   C_Dtype - element type of the output. C_Dtype == float is accepted for
//             at::Half / at::BFloat16 inputs (not on ROCm).
//
// Parameters:
//   m, n, k            - problem sizes: op(mat1) is m x k, op(mat2) is k x n,
//                        result is m x n (per the layout descriptors below).
//   alpha_val          - scale applied to the product.
//   mat1_ld, mat2_ld,
//   result_ld          - leading dimensions of the respective buffers.
//   bias               - optional epilogue bias pointer; may be nullptr.
//   activation         - epilogue activation to fuse.
//
// Returns true on success. If no algorithm is found or cublasLtMatmul fails,
// a warning is emitted and false is returned; per the warning text, the
// caller is expected to fall back to an unfused cuBLAS path.
// Throws (via TORCH_CHECK) for unsupported dtype combinations, and via
// TORCH_CUDABLAS_CHECK if the heuristic query itself fails.
template <typename Dtype, typename C_Dtype>
bool gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<Dtype> alpha_val,
    const Dtype* mat1_ptr,
    int64_t mat1_ld,
    const Dtype* mat2_ptr,
    int64_t mat2_ld,
    const Dtype* bias,
    C_Dtype* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation) {

  // Reject reduced-precision-input / float-output combinations that the
  // current backend configuration cannot handle.
  if (std::is_same_v<C_Dtype, float> && std::is_same_v<Dtype, at::BFloat16>) {
    #ifdef USE_ROCM
    TORCH_CHECK(false, "gemm input type at::BFloat16 and output type float is not supported for ROCm");
    #endif
  } else if (std::is_same_v<C_Dtype, float> && std::is_same_v<Dtype, at::Half>) {
    #ifdef USE_ROCM
    TORCH_CHECK(false, "gemm input type at::Half and output type float is not supported for ROCm");
    #endif
    TORCH_CHECK(
        !at::globalContext().allowFP16AccumulationCuBLAS(),
        "gemm input type at::Half and output type float is not supported with allowFP16AccumulationCuBLAS");
  }

  using opmath_t = at::opmath_type<Dtype>;
  opmath_t beta_val = bias ? 0 : 1; // bias is added in epilogue unless nullptr

  // Defaults correspond to fp32 compute/scale; refined per Dtype below.
  cudaDataType_t abType = CUDA_R_32F;
  cudaDataType_t cType = CUDA_R_32F;
  cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
  cudaDataType_t scaleType = CUDA_R_32F;
  CuBlasLtMatmulPreference preference;
  void * alpha_ptr = &alpha_val;
  void * beta_ptr = &beta_val;
#ifndef USE_ROCM
  // Half-precision copies of alpha/beta, used only when the compute/scale
  // type is switched to fp16 below (the scale pointers must match scaleType).
  at::Half halpha_val;
  at::Half hbeta_val;
#endif
  if constexpr (std::is_same_v<Dtype, double>) {
    abType = CUDA_R_64F;
    cType = CUDA_R_64F;
    computeType = CUBLAS_COMPUTE_64F;
    scaleType = CUDA_R_64F;
  } else if constexpr (std::is_same_v<Dtype, float>) {
    if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) {
      computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
    }
  } else if constexpr (std::is_same_v<Dtype, at::Half>) {
#ifndef USE_ROCM
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
      computeType = CUBLAS_COMPUTE_16F;
      scaleType = CUDA_R_16F;
      halpha_val = alpha_val;
      hbeta_val = beta_val;
      alpha_ptr = &halpha_val;
      beta_ptr = &hbeta_val;
    }
#endif
    abType = CUDA_R_16F;
    cType = std::is_same_v<C_Dtype, float> ? CUDA_R_32F : CUDA_R_16F;
#ifndef USE_ROCM
    // Restrict split-K reduction schemes according to the global
    // fp16-reduction setting.
    auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
    if (fp16_reduction !=
        at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
      uint32_t mask =
          fp16_reduction ==
                  at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
              ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
                 CUBLASLT_REDUCTION_SCHEME_NONE)
              : CUBLASLT_REDUCTION_SCHEME_NONE;
      preference.setAttribute(
          CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, mask);
    }
#endif
  } else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
    abType = CUDA_R_16BF;
    cType = std::is_same_v<C_Dtype, float> ? CUDA_R_32F : CUDA_R_16BF;
#ifndef USE_ROCM
    // Same split-K restriction as above, driven by the bf16 setting.
    auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
    if (bf16_reduction !=
        at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
      uint32_t mask =
          bf16_reduction ==
                  at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
              ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
                 CUBLASLT_REDUCTION_SCHEME_NONE)
              : CUBLASLT_REDUCTION_SCHEME_NONE;
      preference.setAttribute(
          CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, mask);
    }
#endif
  }

  CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
  cublasOperation_t transa = transpose_mat1 ? CUBLAS_OP_T : CUBLAS_OP_N;
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
  cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
  auto stream = at::cuda::getCurrentCUDAStream();
#ifndef USE_ROCM
  // SM carveout on CUDA: ask cublasLt to target fewer SMs.
  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
    computeDesc.setAttribute<int32_t>(
        CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
        at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
            at::globalContext()._SMCarveout_EXPERIMENTAL().value());
  }
#else
  // SM carveout on ROCm: run on a dedicated carveout stream, synchronized
  // with the current stream before (true) and after (false) the matmul.
  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
    stream = _getCarveoutStream(
        at::globalContext()._SMCarveout_EXPERIMENTAL().value());
    _syncCurrentWithCarveoutStream(stream, true);
  }
#endif
  const auto epilogue = [&]() -> cublasLtEpilogue_t {
    // The cuBLAS documentation indicates that
    // *_<ACTIVATION>_BIAS = *_<ACTIVATION>,
    // but we keep it verbose here for clarity.
    switch (activation) {
      case GEMMAndBiasActivationEpilogue::RELU:
        return bias ? CUBLASLT_EPILOGUE_RELU_BIAS : CUBLASLT_EPILOGUE_RELU;
      case GEMMAndBiasActivationEpilogue::GELU:
        return bias ? CUBLASLT_EPILOGUE_GELU_BIAS : CUBLASLT_EPILOGUE_GELU;
      default:
        return bias ? CUBLASLT_EPILOGUE_BIAS : CUBLASLT_EPILOGUE_DEFAULT;
    }
  }();
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, epilogue);

  if (bias) {
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias);
  }

  CuBlasLtMatrixLayout Adesc(abType, m, k, mat1_ld, transpose_mat1);
  CuBlasLtMatrixLayout Bdesc(abType, k, n, mat2_ld, transpose_mat2);
  // Cdesc describes both C and D: the output is computed in place.
  CuBlasLtMatrixLayout Cdesc(cType, m, n, result_ld);

  auto ltworkspace = CublasLtWorkspace();
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size);

#ifndef USE_ROCM
  uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat1_ptr));
  uint32_t b_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat2_ptr));
  uint32_t c_alignment = _getAlignment(reinterpret_cast<uintptr_t>(result_ptr));
  // In cublasLt, "D" is the output matrix. Below, result_ptr is passed as both
  // C and D, so D's alignment is exactly C's. The bias vector is not the D
  // operand, so it must not drive this value. In particular, a nullptr bias
  // (address 0 is divisible by every power of two) would overstate the
  // alignment of the real output buffer and could steer the heuristic toward
  // kernels whose alignment requirements result_ptr does not meet.
  uint32_t d_alignment = c_alignment;
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, a_alignment);
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, b_alignment);
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
#endif

  // Ask for the single best algorithm; zero results means unsupported.
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  int returnedResult = 0;
  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
      ltHandle,
      computeDesc.descriptor(),
      Adesc.descriptor(),
      Bdesc.descriptor(),
      Cdesc.descriptor(),
      Cdesc.descriptor(),
      preference.descriptor(),
      1,
      &heuristicResult,
      &returnedResult));
  cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS;
  if (returnedResult == 0) {
    cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED;
  }
  else {
    cublasStatus = cublasLtMatmul(
      ltHandle,
      computeDesc.descriptor(),
      alpha_ptr,
      mat1_ptr,
      Adesc.descriptor(),
      mat2_ptr,
      Bdesc.descriptor(),
      beta_ptr,
      result_ptr,
      Cdesc.descriptor(),
      result_ptr,
      Cdesc.descriptor(),
      &heuristicResult.algo,
      ltworkspace.ptr,
      ltworkspace.size,
      stream);
#ifdef USE_ROCM
    if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
      _syncCurrentWithCarveoutStream(stream, false);
    }
#endif
  }
  // Non-fatal failure: warn with full problem description and let the caller
  // recover.
  if (cublasStatus != CUBLAS_STATUS_SUCCESS) {
    TORCH_WARN(
      "gemm_and_bias error: ",
      at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
      " when calling cublasLtMatmul with transpose_mat1 ",
      transpose_mat1,
      " transpose_mat2 ",
      transpose_mat2,
      " m ",
      m,
      " n ",
      n,
      " k ",
      k,
      " mat1_ld ",
      mat1_ld,
      " mat2_ld ",
      mat2_ld,
      " result_ld ",
      result_ld,
      " abType ",
      abType,
      " cType ",
      cType,
      " computeType ",
      computeType,
      " scaleType ",
      scaleType,
      ". Will attempt to recover by calling unfused cublas path.");
    return false;
  }
  return true;
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free