
pad_remain_row_col_zero Class — pytorch Architecture

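Architecture documentation for pad_remain_row_col_zero, a function template defined in FlashAttentionKernel.cpp in the pytorch codebase. The page lists it under "Class", but the source below shows an inline function template, not a class.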

Source Code

aten/src/ATen/native/cpu/FlashAttentionKernel.cpp lines 268–304

template <typename scalar_t>
inline void pad_remain_row_col_zero(
    scalar_t* value_ptr,
    int rows,
    int cols,
    int prows,
    int pcols,
    int ldi) {
  auto psize = pcols - cols;
  if (psize == 0 && prows == rows) {
    return;
  }
  auto vec_size = at::vec::Vectorized<scalar_t>::size();
  auto zero = at::vec::Vectorized<scalar_t>(0);
  if (psize > 0) {
    for (int i = 0; i < rows; i++) {
      int j = 0;
      for (; j < psize - (psize % vec_size); j += vec_size) {
        zero.store(value_ptr + i * ldi + cols + j);
      }
      if (j < psize) {
        zero.store(value_ptr + i * ldi + cols + j, psize - j);
      }
    }
  }

  for (int i = rows; i < prows; i++) {
    int j = 0;
    for (; j < pcols - (pcols % vec_size); j += vec_size) {
      zero.store(value_ptr + i * ldi + j);
    }
    if (j < pcols) {
      zero.store(value_ptr + i * ldi + j, pcols - j);
    }
  }

}
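
Reading the listing above, the function zeroes two regions of a row-major buffer whose rows are ldi elements apart: the right-hand strip of columns [cols, pcols) in the first rows rows, and all pcols columns of the extra rows [rows, prows). Each region is written in full Vectorized<scalar_t> chunks, followed by one partial store for any leftover tail. The following is a minimal scalar sketch of the same effect, written without ATen so it compiles on its own. The names pad_remain_row_col_zero_ref and example_padded_buffer are hypothetical and are not part of PyTorch, and std::fill_n stands in for the vectorized stores.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical scalar reference, NOT part of PyTorch: zero the padding region
// of a row-major buffer whose consecutive rows are ldi elements apart.
template <typename scalar_t>
void pad_remain_row_col_zero_ref(
    scalar_t* value_ptr, int rows, int cols, int prows, int pcols, int ldi) {
  // Right-hand padding: columns [cols, pcols) of each existing row.
  if (pcols > cols) {
    for (int i = 0; i < rows; i++) {
      std::fill_n(value_ptr + i * ldi + cols, pcols - cols, scalar_t(0));
    }
  }
  // Bottom padding: columns [0, pcols) of each extra row in [rows, prows).
  for (int i = rows; i < prows; i++) {
    std::fill_n(value_ptr + i * ldi, pcols, scalar_t(0));
  }
}

// Illustrative usage: a 2x3 block of sevens padded to 3x4, with ldi == 4.
inline std::vector<float> example_padded_buffer() {
  std::vector<float> buf(3 * 4, 7.0f);
  pad_remain_row_col_zero_ref(buf.data(), /*rows=*/2, /*cols=*/3,
                              /*prows=*/3, /*pcols=*/4, /*ldi=*/4);
  return buf;
}
```

In this sketch the first two rows keep their three original values and gain one trailing zero, while the third row is entirely zero. Elements at column pcols and beyond are never written, which matters when ldi is larger than pcols.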
