gemv Class — pytorch Architecture
Architecture documentation for the gemv function template in BlasKernel.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/BlasKernel.cpp lines 541–610
// Matrix-vector product: y = alpha * op(A) * x + beta * y.
//
// Element (i, j) of A is read from a[i + lda * j], with i in [0, m) and
// j in [0, n).  x and y are walked with strides incx and incy.
//
//   trans == 'T' or 't' : op(A) = A^T; x supplies m elements, y receives n.
//   any other value     : op(A) = A;   x supplies n elements, y receives m.
//
// When beta == 0 the previous contents of y are never read, so NaN/Inf values
// already stored in y do not leak into the result.
//
// If the build has BLAS and blas_impl::gemv_use_fast_path accepts the
// arguments, the work is forwarded to blas_impl::gemv_fast_path; otherwise the
// portable loops below run, accumulating in at::opmath_type<scalar_t>.
template<typename scalar_t>
void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, int64_t lda, const scalar_t *x, int64_t incx, scalar_t beta, scalar_t *y, int64_t incy) {
  // With a single column, lda is never used to step to another column, so it
  // is overwritten with m.
  // NOTE(review): presumably this keeps the lda check below from rejecting
  // callers that pass an arbitrary lda for a one-column matrix - confirm.
  if (n == 1) lda = m;
#if AT_BUILD_WITH_BLAS()
  if (blas_impl::gemv_use_fast_path<scalar_t>(trans, m, n, alpha, lda, incx, beta, incy)) {
    TORCH_CHECK(lda >= std::max<int64_t>(1L, m), "lda should be at least max(1,", m, "), but have ", lda);
    // The fast path takes int arguments by pointer.
    // NOTE(review): the narrowing casts assume gemv_use_fast_path has already
    // verified that every value fits in an int - confirm.
    int i_m = static_cast<int>(m);
    int i_n = static_cast<int>(n);
    int i_lda = static_cast<int>(lda);
    int i_incx = static_cast<int>(incx);
    int i_incy = static_cast<int>(incy);
    blas_impl::gemv_fast_path<scalar_t>(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy);
    return;
  }
#endif
  using opmath_t = at::opmath_type<scalar_t>;
  const bool beta_is_zero = (beta == scalar_t(0));

  if ((trans == 'T') || (trans == 't')) {
    // y[i] = alpha * dot(A(:, i), x) + beta * y[i]
    for (const auto i : c10::irange(n)) {
      opmath_t sum = 0;
      const scalar_t *row_ = a + lda * i;
      for (const auto j : c10::irange(m)) {
        sum += static_cast<opmath_t>(x[j * incx]) * static_cast<opmath_t>(row_[j]);
      }
      if (beta_is_zero) {
        // Plain store: y is not read, so stale NaN/Inf cannot propagate.
        y[i * incy] = alpha * sum;
      } else {
        y[i * incy] = beta * y[i * incy] + alpha * sum;
      }
    }
    return;
  }

  // Non-transposed: y += (alpha * x[j]) * A(:, j), one column at a time.
  // Pre-scale y by beta unless that is a no-op (beta == 1) or y is about to be
  // overwritten anyway (beta == 0).
  if (!beta_is_zero && beta != scalar_t(1)) scal<scalar_t>(m, beta, y, incy);

  constexpr bool is_low_precision = !std::is_same_v<opmath_t, scalar_t>;
  if constexpr (is_low_precision) {
    // Accumulate in opmath_t scratch storage so partial sums are not rounded
    // to scalar_t after every column; y is written once at the end.
    std::vector<opmath_t> sum;
    sum.resize(m);
    for (const auto j : c10::irange(n)) {
      const scalar_t *column_ = a + lda * j;
      opmath_t z = alpha * static_cast<opmath_t>(x[j * incx]);
      for (const auto i : c10::irange(m)) {
        sum[i] += z * column_[i];
      }
    }
    if (beta_is_zero) {
      for (const auto i : c10::irange(m)) {
        y[i * incy] = sum[i];
      }
    } else {
      for (const auto i : c10::irange(m)) {
        y[i * incy] += sum[i];
      }
    }
  } else {
    // Output values are ignored if beta is 0, and set to 0; nans and infs are
    // not propagated.  Clearing y before the column loop (instead of while
    // processing column 0) also yields the correct all-zero result when
    // n == 0, and keeps a loop-invariant test out of the innermost loop.
    if (beta_is_zero) {
      for (const auto i : c10::irange(m)) {
        y[i * incy] = 0;
      }
    }
    for (const auto j : c10::irange(n)) {
      const scalar_t *column_ = a + lda * j;
      opmath_t z = alpha * static_cast<opmath_t>(x[j * incx]);
      for (const auto i : c10::irange(m)) {
        y[i * incy] += z * column_[i];
      }
    }
  }
  return;
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free