Home / Template Parameter / BLOCK_SIZE — pytorch Architecture

BLOCK_SIZE Template Parameter — pytorch Architecture

Architecture documentation for BLOCK_SIZE — an `int` non-type template parameter of the `bgemm_kernel_impl` function template in bgemm_kernel_template.h from the pytorch codebase (note: despite the page title, it is not a class).

Entity Profile

Source Code

aten/src/ATen/native/hip/bgemm_kernels/bgemm_kernel_template.h lines 45–157

// Runs one batched GEMM on BFloat16 inputs by instantiating a Composable Kernel (CK)
// DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 device op with the tuning parameters supplied
// as template arguments (block/tile sizes, transfer cluster shapes, pipeline config).
//
// The argument list is provided by the CUDABLAS_BGEMM_ARGTYPES(at::BFloat16) macro
// (defined elsewhere), which — as used below — brings at least a, b, c, m, n, k,
// lda, ldb, ldc and num_batches into scope.
//
// Raises: TORCH_CHECK failure if CK reports the problem/config combination unsupported.
template <
    typename A_DATA_TYPE,
    typename B_DATA_TYPE,
    int BLOCK_SIZE,
    int MBLOCK,
    int NBLOCK,
    int KBLOCK,
    int AK1,
    int BK1,
    int WAVE_TILE_M,
    int WAVE_TILE_N,
    int WAVE_MAP_M,
    int WAVE_MAP_N,
    typename ABLOCK_TRANSFER,
    int ABLOCK_TRANSFER_SSPV,
    int ABLOCK_TRANSFER_DSPV_K1,
    typename BBLOCK_TRANSFER,
    int BBLOCK_TRANSFER_SSPV,
    // Renamed from BBLOCK_TRANSFER_SSPV_K1: this is the *destination* scalar-per-vector
    // (passed as BBlockTransferDstScalarPerVector_BK1 below), parallel to
    // ABLOCK_TRANSFER_DSPV_K1. Positional-only change; callers are unaffected.
    int BBLOCK_TRANSFER_DSPV_K1,
    int CSHUFFLE_MXDL_PWPS,
    int CSHUFFLE_NXDL_PWPS,
    typename CSHUFFLEBLOCK_TRANSFER,
    typename CDESHUFFLEBLOCK_TRANSFER,
    ck::BlockGemmPipelineScheduler LOOP_SCHED,
    ck::BlockGemmPipelineVersion PIPELINE_VERSION,
    ck::tensor_operation::device::GemmSpecialization GEMM_SPEC =
        ck::tensor_operation::device::GemmSpecialization::MNPadding,
    bool TRANSA = false,
    bool TRANSB = false>
void bgemm_kernel_impl(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {

  // Map the ATen element types / transpose flags onto CK's type and layout traits
  // (CkMathType and CkTensorLayout are defined elsewhere in this file).
  using ADataType = typename CkMathType<A_DATA_TYPE>::dtype;
  using BDataType = typename CkMathType<B_DATA_TYPE>::dtype;

  using ALayout = typename CkTensorLayout<TRANSA, TRANSB>::a_layout;
  using BLayout = typename CkTensorLayout<TRANSA, TRANSB>::b_layout;

  // Elementwise ops applied by CK to A, B and the C/D epilogue (aliases defined elsewhere).
  auto a_element_op = AElementOp{};
  auto b_element_op = BElementOp{};
  auto cde_element_op = CDEElementOp{};

  // NOTE: the template-argument order below is fixed by the CK device-op signature;
  // do not reorder.
  auto gemm = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl_CShuffle_V3<
      ALayout,                  // ALayout
      BLayout,                  // BLayout
      DsLayout,                 // DsLayout
      CLayout,                  // CLayout
      ADataType,                // ADataType
      BDataType,                // BDataType
      DsDataType,               // DsDataType
      CDataType,                // CDataType
      AccDataType,              // AccDataType
      CShuffleDataType,         // CshuffleType
      AElementOp,               // AElementwiseOperation
      BElementOp,               // BElementwiseOperation
      CDEElementOp,             // CElementwiseOperation
      GEMM_SPEC,                // GEMMSpecialization
      BLOCK_SIZE,               // BlockSize
      MBLOCK,                   // MPerBlock
      NBLOCK,                   // NPerBlock
      KBLOCK,                   // KPerBlock
      AK1,                      // AK1
      BK1,                      // BK1
      WAVE_TILE_M,              // MPerXDL
      WAVE_TILE_N,              // NPerXDL
      WAVE_MAP_M,               // MXdlPerWave
      WAVE_MAP_N,               // NXdlPerWave
      ABLOCK_TRANSFER,          // ABlockTransferThreadClusterLengths_AK0_M_AK1
      S<1, 0, 2>,               // ABlockTransferThreadClusterArrangeOrder
      S<1, 0, 2>,               // ABlockTransferSrcAccessOrder
      2,                        // ABlockTransferSrcVectorDim
      ABLOCK_TRANSFER_SSPV,     // ABlockTransferSrcScalarPerVector
      ABLOCK_TRANSFER_DSPV_K1,  // ABlockTransferDstScalarPerVector_AK1
      0,                        // ABlockLdsExtraM
      BBLOCK_TRANSFER,          // BBlockTransferThreadClusterLengths_BK0_N_BK1
      S<1, 0, 2>,               // BBlockTransferThreadClusterArrangeOrder
      S<1, 0, 2>,               // BBlockTransferSrcAccessOrder
      2,                        // BBlockTransferSrcVectorDim
      BBLOCK_TRANSFER_SSPV,     // BBlockTransferSrcScalarPerVector
      BBLOCK_TRANSFER_DSPV_K1,  // BBlockTransferDstScalarPerVector_BK1
      0,                        // BBlockLdsAddExtraN
      CSHUFFLE_MXDL_PWPS,       // CShuffleMXdlPerWavePerShuffle
      CSHUFFLE_NXDL_PWPS,       // CShuffleNXdlPerWavePerShuffle
      CSHUFFLEBLOCK_TRANSFER,   // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
      CDESHUFFLEBLOCK_TRANSFER, // CDEShuffleBlockTransferScalarPerVectors
      LOOP_SCHED,               // BlockGemmPipelineScheduler
      PIPELINE_VERSION          // BlockGemmPipelineVersion
      >{};
  auto invoker = gemm.MakeInvoker();
  // CK expects the operands in the opposite order from cuBLAS-style (a, b) arguments,
  // so every A/B-related value is swapped: b<->a, n<->m, ldb<->lda, and the batch
  // strides (n*k for CK's "A", m*k for CK's "B"). The empty {}s are the (unused)
  // auxiliary D-tensor pointers/strides of the MultiD interface.
  auto argument = gemm.MakeArgument(
    b, // A and B are swapped for CK
    a,
    {},
    c,
    n,
    m,
    k,
    num_batches,
    ldb,
    lda,
    {},
    ldc,
    n * k,  // batch_stride_a
    m * k,  // batch_stride_b
    {},
    m * n,  // batch_stride_c
    a_element_op,
    b_element_op,
    cde_element_op
  );
  TORCH_CHECK(gemm.IsSupportedArgument(argument), "wrong! device_gemm with the specified compilation parameters does not support this GEMM problem");
  // Launch on the current PyTorch stream (HIPified from the CUDA API in this .hip build).
  // NOTE(review): the second StreamConfig field appears to disable kernel timing — confirm
  // against the CK StreamConfig definition.
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  invoker.Run(argument, StreamConfig{stream, false});
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free