Home / Class / BLOCK_SIZE Class — PyTorch Architecture

BLOCK_SIZE Class — pytorch Architecture

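Architecture documentation for the `gemm_impl` function template in ck_gemm_template.h from the PyTorch codebase. The page is indexed under `BLOCK_SIZE`, which is one of its template parameters rather than a class.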

Entity Profile

Source Code

aten/src/ATen/native/hip/ck_gemm_template.h lines 78–232

// gemm_impl: computes the BLAS-style GEMM described by
// CUDABLAS_GEMM_ARGTYPES(Dtype) using a Composable Kernel (CK)
// DeviceGemmMultiD_Xdl_CShuffle_V3 instance, launched on the current
// CUDA/HIP stream.
//
// Template parameters:
//   Dtype            element type of a, b and c; mapped to the CK element type
//                    through CkMathType<Dtype>::dtype.
//   BLOCK_SIZE .. CDE_SCALAR_VEC
//                    kernel tuning parameters, forwarded unchanged and in the
//                    same order to DeviceGemmMultiD_Xdl_CShuffle_V3.
//   PADDING          when true, uses GemmSpecialization::MNKPadding instead of
//                    GemmSpecialization::Default.
//   TRANSA, TRANSB   select the A/B tensor layouts via CkTensorLayout.
//
// Arguments come from the CUDABLAS_GEMM_ARGTYPES macro: m, n, k, alpha, a,
// lda, b, ldb, beta, c, ldc. Their exact declared types are not visible here.
// NOTE(review): sizes and strides are narrowed to `int` below; confirm that
// callers cannot pass values above INT_MAX.
//
// Fails via TORCH_CHECK if the CK instance reports that it does not support
// the given problem.
template <
    typename Dtype,
    int BLOCK_SIZE,
    int MBLOCK,
    int NBLOCK,
    int KBLOCK,
    int AK1,
    int BK1,
    int MPER_XDL,
    int NPER_XDL,
    int MPER_WAVE,
    int NPER_WAVE,
    typename ABLOCK_CLUSTER_LENS,
    typename ABLOCK_CLUSTER_ORDER,
    typename ABLOCK_SRC_ORDER,
    int ABLOCK_VECTOR_DIM,
    int ABLOCK_SCALAR_VEC,
    int ABLOCK_SCALAR_VEC_AK1,
    bool ABLOCK_LDS_EXTRAM,
    typename BBLOCK_CLUSTER_LENS,
    typename BBLOCK_CLUSTER_ORDER,
    typename BBLOCK_SRC_ORDER,
    int BBLOCK_VECTOR_DIM,
    int BBLOCK_SCALAR_VEC,
    int BBLOCK_SCALAR_VEC_AK1,
    bool BBLOCK_LDS_EXTRAN,
    int CMPER_WAVE,
    int CNPER_WAVE,
    typename BLOCK_CLUSTER_LENS,
    typename CDE_SCALAR_VEC,
    bool PADDING = false,
    bool TRANSA = false,
    bool TRANSB = false>
void gemm_impl(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
  // Get input information.
  int M = m;
  int N = n;
  int K = k;

  int StrideA = lda;
  int StrideB = ldb;
  int StrideC = ldc;

  // Presumably CK's split-K factor; 1 keeps the K dimension unsplit.
  int KBatch = 1;

  // The epilogue works in float (see AccDataType below). Convert the
  // scalars explicitly and build CElementOp from these float copies
  // (previously they were computed but never used).
  float falpha = alpha;
  float fbeta = beta;

  using ADataType = typename CkMathType<Dtype>::dtype;
  using BDataType = typename CkMathType<Dtype>::dtype;
  using CDataType = typename CkMathType<Dtype>::dtype;

  using AccDataType = float;
  using CShuffleDataType = typename CkMathType<Dtype>::dtype;

  using ALayout = typename CkTensorLayout<TRANSA, TRANSB>::a_layout;
  using BLayout = typename CkTensorLayout<TRANSA, TRANSB>::b_layout;
  using CLayout = Row;

  using AElementOp = PassThrough;
  using BElementOp = PassThrough;
  using CElementOp = AlphaBetaAdd;

  static constexpr auto GemmDefault =
      ck::tensor_operation::device::GemmSpecialization::Default;
  static constexpr auto GemmMNKPadding =
      ck::tensor_operation::device::GemmSpecialization::MNKPadding;
  // PADDING opts into the padded specialization.
  static constexpr auto GemmSpec = PADDING ? GemmMNKPadding : GemmDefault;

  // No auxiliary D tensors are used: both the D layout list and the D data
  // type list are empty tuples.
  using DeviceGemmInstance =
      ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<
          ALayout,
          BLayout,
          ck::Tuple<>,
          CLayout,
          ADataType,
          BDataType,
          ck::Tuple<>,
          CDataType,
          AccDataType,
          CShuffleDataType,
          AElementOp,
          BElementOp,
          CElementOp,
          GemmSpec,
          BLOCK_SIZE,
          MBLOCK,
          NBLOCK,
          KBLOCK,
          AK1,
          BK1,
          MPER_XDL,
          NPER_XDL,
          MPER_WAVE,
          NPER_WAVE,
          ABLOCK_CLUSTER_LENS,
          ABLOCK_CLUSTER_ORDER,
          ABLOCK_SRC_ORDER,
          ABLOCK_VECTOR_DIM,
          ABLOCK_SCALAR_VEC,
          ABLOCK_SCALAR_VEC_AK1,
          ABLOCK_LDS_EXTRAM,
          BBLOCK_CLUSTER_LENS,
          BBLOCK_CLUSTER_ORDER,
          BBLOCK_SRC_ORDER,
          BBLOCK_VECTOR_DIM,
          BBLOCK_SCALAR_VEC,
          BBLOCK_SCALAR_VEC_AK1,
          BBLOCK_LDS_EXTRAN,
          CMPER_WAVE,
          CNPER_WAVE,
          BLOCK_CLUSTER_LENS,
          CDE_SCALAR_VEC>;

  auto gemm = DeviceGemmInstance{};
  auto invoker = gemm.MakeInvoker();

  auto a_element_op = AElementOp{};
  auto b_element_op = BElementOp{};
  auto c_element_op = CElementOp{falpha, fbeta};

  // Zero-length pointer array: there are no D tensors to pass.
  using DDataArrayType = std::array<const void*, 0>;
  DDataArrayType DDataArray;

  // We swap A and B inputs here as a temporary workaround.
  // b is passed as the first operand and a as the second. Their extents
  // (N, M) and strides (StrideB, StrideA) are exchanged to match.
  // NOTE(review): this presumably maps the caller's column-major BLAS
  // convention onto CLayout = Row by computing C^T = B^T * A^T; confirm.
  auto argument = gemm.MakeArgument(
      reinterpret_cast<const void*>(b),
      reinterpret_cast<const void*>(a),
      DDataArray,
      reinterpret_cast<void*>(c),
      N,
      M,
      K,
      StrideB,
      StrideA,
      std::array<ck::index_t, 0>{},
      StrideC,
      KBatch,
      a_element_op,
      b_element_op,
      c_element_op);

  TORCH_CHECK(gemm.IsSupportedArgument(argument), "wrong! device_gemm with the specified compilation parameters does not support this GEMM problem");

  // Launch on the current stream. The second StreamConfig field is
  // presumably the kernel-timing flag, which is disabled here.
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  invoker.Run(argument, StreamConfig{stream, false});
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free