gemm_tunable Function Template — pytorch Architecture
Architecture documentation for the gemm_tunable function template in CUDABlas.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/cuda/CUDABlas.cpp lines 1439–1478
/// Routes a GEMM call to the tunable operator that matches the requested
/// transpose combination.
///
/// The parameter list is produced by CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE; this
/// body reads transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c and ldc
/// from it and copies each one into a tunable::GemmParams<DType>.
/// NOTE(review): the exact parameter types are defined by the macro, which is
/// not visible here — confirm against its definition.
///
/// Transpose flags: 'n' or 'N' selects BlasOp::N; every other character
/// selects BlasOp::T.
template <typename DType, typename C_Dtype>
inline void gemm_tunable(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(DType, C_Dtype)) {
  // Bundle the call arguments into the parameter struct the tunable op takes.
  tunable::GemmParams<DType> params;
  params.transa = transa;
  params.transb = transb;
  params.m = m;
  params.n = n;
  params.k = k;
  params.alpha = alpha;
  params.a = a;
  params.lda = lda;
  params.b = b;
  params.ldb = ldb;
  params.beta = beta;
  params.c = c;
  params.ldc = ldc;

  // Anything other than 'n'/'N' counts as transposed.
  const bool transa_ = ((transa != 'n') && (transa != 'N'));
  const bool transb_ = ((transb != 'n') && (transb != 'N'));

  // The transpose layout is a template argument of GemmTunableOp, so each
  // combination needs its own type. Each branch keeps one function-local
  // static instance that is reused by later calls taking the same branch.
  // The four branches cover every (transa_, transb_) pair, so no fallback
  // branch is needed.
  if (transa_ && transb_) {
    static tunable::GemmTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::T> gemm{};
    gemm(&params);
  } else if (transa_) {
    static tunable::GemmTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::N> gemm{};
    gemm(&params);
  } else if (transb_) {
    static tunable::GemmTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::T> gemm{};
    gemm(&params);
  } else {
    static tunable::GemmTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::N> gemm{};
    gemm(&params);
  }
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free