gemm_notrans_ Function Template — pytorch Architecture
Architecture documentation for the gemm_notrans_ function template in BlasKernel.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/BlasKernel.cpp lines 102–137
// Computes c = beta * c + alpha * (a @ b) with neither operand transposed.
//
// Indexing used by this routine: element (i, l) of `a` is a[i + l * lda],
// element (l, j) of `b` is b[l + j * ldb], and element (i, j) of `c` is
// c[i + j * ldc], with i in [0, m), l in [0, k) and j in [0, n).
//
// This overload only participates in overload resolution when scalar_t and
// opmath_t are the same type.
// NOTE(review): __ubsan_ignore_signed_int_overflow__ is defined outside this
// block; presumably it disables UBSan's signed-overflow check here — confirm.
template <typename scalar_t, typename opmath_t, typename out_t>
__ubsan_ignore_signed_int_overflow__
std::enable_if_t<std::is_same_v<scalar_t, opmath_t>, void>
gemm_notrans_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    out_t* c,
    int64_t ldc) {
  // Step 1: c *= beta
  scale_(m, n, beta, c, ldc);

  // Step 2: c += alpha * (a @ b)
  // All index arithmetic below is carried out in uint64_t.
  const uint64_t row_count = m;
  const uint64_t quad_count = row_count / 4;  // number of full groups of 4 rows
  const uint64_t tail_begin = quad_count * 4; // first row not covered by a group

  for (const uint64_t l : c10::irange(k)) {
    const uint64_t a_col = l * lda;
    for (const uint64_t j : c10::irange(n)) {
      const uint64_t c_col = j * ldc;
      // b(l, j) is constant over the row sweep, so fold alpha into it once.
      const opmath_t scaled_b = b[l + j * ldb] * alpha;

      // Main sweep: rows handled four at a time (manually unrolled).
      for (const auto quad : c10::irange(quad_count)) {
        const uint64_t row = quad * 4;
        c[c_col + row + 0] += a[row + 0 + a_col] * scaled_b;
        c[c_col + row + 1] += a[row + 1 + a_col] * scaled_b;
        c[c_col + row + 2] += a[row + 2 + a_col] * scaled_b;
        c[c_col + row + 3] += a[row + 3 + a_col] * scaled_b;
      }

      // Remainder: the last (m % 4) rows, one at a time.
      for (uint64_t row = tail_begin; row < row_count; ++row) {
        c[c_col + row] += a[row + a_col] * scaled_b;
      }
    }
  }
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free