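pad_remain_row_col_zero Function Template - PyTorch Architecture
Architecture documentation for the pad_remain_row_col_zero function template in FlashAttentionKernel.cpp from the PyTorch codebase. Given a rows x cols block stored with leading dimension ldi, it zero-fills the padding needed to extend the block to prows x pcols.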
Source Code
aten/src/ATen/native/cpu/FlashAttentionKernel.cpp lines 268–304
template <typename scalar_t>
inline void pad_remain_row_col_zero(
    scalar_t* value_ptr,
    int rows,
    int cols,
    int prows,
    int pcols,
    int ldi) {
  // Number of padding columns to zero in each existing row.
  auto psize = pcols - cols;
  // Nothing to do when the block already has the padded shape.
  if (psize == 0 && prows == rows) {
    return;
  }
  auto vec_size = at::vec::Vectorized<scalar_t>::size();
  auto zero = at::vec::Vectorized<scalar_t>(0);
  // Zero columns [cols, pcols) of the existing rows [0, rows).
  if (psize > 0) {
    for (int i = 0; i < rows; i++) {
      int j = 0;
      // Full-width vector stores.
      for (; j < psize - (psize % vec_size); j += vec_size) {
        zero.store(value_ptr + i * ldi + cols + j);
      }
      // Partial store for the remaining psize - j elements.
      if (j < psize) {
        zero.store(value_ptr + i * ldi + cols + j, psize - j);
      }
    }
  }
  // Zero columns [0, pcols) of the padding rows [rows, prows).
  for (int i = rows; i < prows; i++) {
    int j = 0;
    for (; j < pcols - (pcols % vec_size); j += vec_size) {
      zero.store(value_ptr + i * ldi + j);
    }
    if (j < pcols) {
      zero.store(value_ptr + i * ldi + j, pcols - j);
    }
  }
}
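To make the memory layout concrete, here is a minimal scalar sketch of the same padding logic with no dependency on at::vec. It is an illustration, not the PyTorch implementation: the names pad_remain_row_col_zero_scalar and pad_demo are hypothetical, and the vectorized stores are replaced by plain element writes, which should produce the same result for a row-major buffer.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Scalar sketch (assumption: row-major buffer with leading dimension ldi).
// Extends a rows x cols block to prows x pcols by writing zeros into the
// padding region only; elements outside that region are left untouched.
template <typename scalar_t>
void pad_remain_row_col_zero_scalar(
    scalar_t* value_ptr, int rows, int cols, int prows, int pcols, int ldi) {
  // Zero the trailing columns [cols, pcols) of each existing row.
  for (int i = 0; i < rows; ++i) {
    for (int j = cols; j < pcols; ++j) {
      value_ptr[i * ldi + j] = scalar_t(0);
    }
  }
  // Zero columns [0, pcols) of the extra rows [rows, prows).
  for (int i = rows; i < prows; ++i) {
    for (int j = 0; j < pcols; ++j) {
      value_ptr[i * ldi + j] = scalar_t(0);
    }
  }
}

// Hypothetical demo: pad a 2x3 block to 3x4 inside a buffer with ldi = 5.
// The buffer is pre-filled with 7 so untouched elements remain visible.
inline std::vector<float> pad_demo() {
  const int rows = 2, cols = 3, prows = 3, pcols = 4, ldi = 5;
  std::vector<float> buf(static_cast<std::size_t>(prows) * ldi, 7.0f);
  pad_remain_row_col_zero_scalar(buf.data(), rows, cols, prows, pcols, ldi);
  return buf;
}
```

In this demo, rows 0 and 1 keep their first three values and get a zero in column 3, row 2 is zeroed across columns 0 to 3, and column 4 of every row (beyond pcols but within ldi) is never written. The original kernel reaches the same layout with full-width Vectorized stores followed by one partial store for the tail of each row.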