gemm_impl Function Template — PyTorch Architecture
Architecture documentation for the gemm_impl function template (parameterized by BLOCK_SIZE and other tile-shape constants) in ck_gemm_template.h from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/hip/ck_gemm_template.h lines 78–232
// Runs one GEMM (C = alpha * op(A) * op(B) + beta * C) on the current HIP
// stream by instantiating a Composable Kernel (CK) XDL CShuffle GEMM kernel.
//
// All non-type template parameters select a fixed kernel tile configuration
// (block size, M/N/K tile extents, per-wave XDL shapes, LDS copy cluster
// layouts, and CShuffle epilogue vector widths); they are forwarded verbatim
// to DeviceGemmMultiD_Xdl_CShuffle_V3. PADDING selects the MNKPadding
// specialization for problem sizes that do not divide the tile shape, and
// TRANSA/TRANSB pick the row-/column-major layouts via CkTensorLayout.
//
// The runtime arguments come from the project macro CUDABLAS_GEMM_ARGTYPES
// (m, n, k, alpha, a, lda, b, ldb, beta, c, ldc — see ATen's CUDABlas.h).
template <
    typename Dtype,
    int BLOCK_SIZE,
    int MBLOCK,
    int NBLOCK,
    int KBLOCK,
    int AK1,
    int BK1,
    int MPER_XDL,
    int NPER_XDL,
    int MPER_WAVE,
    int NPER_WAVE,
    typename ABLOCK_CLUSTER_LENS,
    typename ABLOCK_CLUSTER_ORDER,
    typename ABLOCK_SRC_ORDER,
    int ABLOCK_VECTOR_DIM,
    int ABLOCK_SCALAR_VEC,
    int ABLOCK_SCALAR_VEC_AK1,
    bool ABLOCK_LDS_EXTRAM,
    typename BBLOCK_CLUSTER_LENS,
    typename BBLOCK_CLUSTER_ORDER,
    typename BBLOCK_SRC_ORDER,
    int BBLOCK_VECTOR_DIM,
    int BBLOCK_SCALAR_VEC,
    int BBLOCK_SCALAR_VEC_AK1,
    bool BBLOCK_LDS_EXTRAN,
    int CMPER_WAVE,
    int CNPER_WAVE,
    typename BLOCK_CLUSTER_LENS,
    typename CDE_SCALAR_VEC,
    bool PADDING = false,
    bool TRANSA = false,
    bool TRANSB = false>
void gemm_impl(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
  // Get input information.
  int M = m;
  int N = n;
  int K = k;

  int StrideA = lda;
  int StrideB = ldb;
  int StrideC = ldc;

  // Single K-split batch: no split-K reduction is requested here.
  int KBatch = 1;

  // Convert the scaling factors to float once up front; the accumulator
  // type below is float and the AlphaBetaAdd epilogue consumes them.
  float falpha = alpha;
  float fbeta = beta;

  using ADataType = typename CkMathType<Dtype>::dtype;
  using BDataType = typename CkMathType<Dtype>::dtype;
  using CDataType = typename CkMathType<Dtype>::dtype;
  using DDataType = typename CkMathType<Dtype>::dtype;

  using AccDataType = float;
  using CShuffleDataType = typename CkMathType<Dtype>::dtype;

  // TRANSA/TRANSB map onto CK row-/column-major layouts for A and B.
  using ALayout = typename CkTensorLayout<TRANSA, TRANSB>::a_layout;
  using BLayout = typename CkTensorLayout<TRANSA, TRANSB>::b_layout;

  using DLayout = Row;
  using CLayout = Row;

  using AElementOp = PassThrough;
  using BElementOp = PassThrough;
  using CElementOp = AlphaBetaAdd;

  static constexpr auto GemmDefault =
      ck::tensor_operation::device::GemmSpecialization::Default;
  static constexpr auto GemmMNKPadding =
      ck::tensor_operation::device::GemmSpecialization::MNKPadding;
  static constexpr auto GemmSpec = PADDING ? GemmMNKPadding : GemmDefault;

  using DeviceGemmInstance =
      ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<ALayout,
                                                                    BLayout,
                                                                    ck::Tuple<>,
                                                                    CLayout,
                                                                    ADataType,
                                                                    BDataType,
                                                                    ck::Tuple<>,
                                                                    CDataType,
                                                                    AccDataType,
                                                                    CShuffleDataType,
                                                                    AElementOp,
                                                                    BElementOp,
                                                                    CElementOp,
                                                                    GemmSpec,
                                                                    BLOCK_SIZE,
                                                                    MBLOCK,
                                                                    NBLOCK,
                                                                    KBLOCK,
                                                                    AK1,
                                                                    BK1,
                                                                    MPER_XDL,
                                                                    NPER_XDL,
                                                                    MPER_WAVE,
                                                                    NPER_WAVE,
                                                                    ABLOCK_CLUSTER_LENS,
                                                                    ABLOCK_CLUSTER_ORDER,
                                                                    ABLOCK_SRC_ORDER,
                                                                    ABLOCK_VECTOR_DIM,
                                                                    ABLOCK_SCALAR_VEC,
                                                                    ABLOCK_SCALAR_VEC_AK1,
                                                                    ABLOCK_LDS_EXTRAM,
                                                                    BBLOCK_CLUSTER_LENS,
                                                                    BBLOCK_CLUSTER_ORDER,
                                                                    BBLOCK_SRC_ORDER,
                                                                    BBLOCK_VECTOR_DIM,
                                                                    BBLOCK_SCALAR_VEC,
                                                                    BBLOCK_SCALAR_VEC_AK1,
                                                                    BBLOCK_LDS_EXTRAN,
                                                                    CMPER_WAVE,
                                                                    CNPER_WAVE,
                                                                    BLOCK_CLUSTER_LENS,
                                                                    CDE_SCALAR_VEC>;

  auto gemm = DeviceGemmInstance{};
  auto invoker = gemm.MakeInvoker();

  auto a_element_op = AElementOp{};
  auto b_element_op = BElementOp{};
  // FIX: pass the float-converted scalars. Previously alpha/beta were passed
  // directly and falpha/fbeta were computed but never used; AlphaBetaAdd's
  // scalars are floats (matching AccDataType above).
  auto c_element_op = CElementOp{falpha, fbeta};

  // No auxiliary D tensors are used (the CK instance takes an empty tuple).
  using DDataArrayType = std::array<const void*, 0>;
  DDataArrayType DDataArray;

  // We swap A and B inputs here as a temporary workaround: CK expects a
  // different operand order than the BLAS-style caller, so A/B (and their
  // strides, and M/N) are exchanged when building the argument.
  auto argument = gemm.MakeArgument(
      reinterpret_cast<const void*>(b),
      reinterpret_cast<const void*>(a),
      DDataArray,
      reinterpret_cast<void*>(c),
      N,
      M,
      K,
      StrideB,
      StrideA,
      std::array<ck::index_t, 0>{},
      StrideC,
      KBatch,
      a_element_op,
      b_element_op,
      c_element_op);

  TORCH_CHECK(gemm.IsSupportedArgument(argument), "wrong! device_gemm with the specified compilation parameters does not support this GEMM problem");

  // Launch on the current (HIP-backed) stream; 'false' disables CK's
  // built-in kernel time measurement.
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  invoker.Run(argument, StreamConfig{stream, false});
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free